     'TVAESynthesizer',
 ]
 SDV_MULTI_TABLE_SYNTHESIZERS = ['HMASynthesizer']
-
+MODALITY_IDX = 10
 SDV_SYNTHESIZERS = SDV_SINGLE_TABLE_SYNTHESIZERS + SDV_MULTI_TABLE_SYNTHESIZERS


@@ -220,32 +220,46 @@ def _get_metainfo_increment(top_folder, s3_client=None):
     return max(increments) + 1 if increments else 0


-def _setup_output_destination_aws(output_destination, synthesizers, datasets, s3_client):
+def _setup_output_destination_aws(
+    output_destination,
+    synthesizers,
+    datasets,
+    modality,
+    s3_client,
+):
     paths = defaultdict(dict)
     s3_path = output_destination[len(S3_PREFIX) :].rstrip('/')
     parts = s3_path.split('/')
     bucket_name = parts[0]
     prefix_parts = parts[1:]
     paths['bucket_name'] = bucket_name
     today = datetime.today().strftime('%m_%d_%Y')
-    top_folder = '/'.join(prefix_parts + [f'SDGym_results_{today}'])
+
+    modality_prefix = '/'.join(prefix_parts + [modality])
+    top_folder = f'{modality_prefix}/SDGym_results_{today}'
     increment = _get_metainfo_increment(f's3://{bucket_name}/{top_folder}', s3_client)
     suffix = f'({increment})' if increment >= 1 else ''
     s3_client.put_object(Bucket=bucket_name, Key=top_folder + '/')
+    synthetic_data_extension = 'zip' if modality == 'multi_table' else 'csv'
     for dataset in datasets:
         dataset_folder = f'{top_folder}/{dataset}_{today}'
         s3_client.put_object(Bucket=bucket_name, Key=dataset_folder + '/')
-        paths[dataset]['meta'] = f's3://{bucket_name}/{dataset_folder}/meta.yaml'
+
         for synth_name in synthesizers:
             final_synth_name = f'{synth_name}{suffix}'
             synth_folder = f'{dataset_folder}/{final_synth_name}'
             s3_client.put_object(Bucket=bucket_name, Key=synth_folder + '/')
             paths[dataset][final_synth_name] = {
-                'synthesizer': f's3://{bucket_name}/{synth_folder}/{final_synth_name}.pkl',
-                'synthetic_data': f's3://{bucket_name}/{synth_folder}/{final_synth_name}_synthetic_data.csv',
-                'benchmark_result': f's3://{bucket_name}/{synth_folder}/{final_synth_name}_benchmark_result.csv',
-                'results': f's3://{bucket_name}/{top_folder}/results{suffix}.csv',
-                'metainfo': f's3://{bucket_name}/{top_folder}/metainfo{suffix}.yaml',
+                'synthesizer': (f's3://{bucket_name}/{synth_folder}/{final_synth_name}.pkl'),
+                'synthetic_data': (
+                    f's3://{bucket_name}/{synth_folder}/'
+                    f'{final_synth_name}_synthetic_data.{synthetic_data_extension}'
+                ),
+                'benchmark_result': (
+                    f's3://{bucket_name}/{synth_folder}/{final_synth_name}_benchmark_result.csv'
+                ),
+                'metainfo': (f's3://{bucket_name}/{top_folder}/metainfo{suffix}.yaml'),
+                'results': (f's3://{bucket_name}/{top_folder}/results{suffix}.csv'),
             }

     s3_client.put_object(
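Note (not part of the diff): for orientation, here is a minimal sketch of the S3 key layout the new code registers per synthesizer. The bucket name, prefix, dataset, synthesizer, and date below are placeholders, and the suffix is empty because an increment of 0 is assumed.

```python
# Illustration only: keys produced for output_destination='s3://my-bucket/experiments',
# modality='multi_table', dataset 'insurance', synthesizer 'HMASynthesizer',
# on 05_01_2025 with increment 0 (so the '(n)' suffix is empty).
bucket_name = 'my-bucket'
top_folder = 'experiments/multi_table/SDGym_results_05_01_2025'
synth_folder = f'{top_folder}/insurance_05_01_2025/HMASynthesizer'

expected_paths = {
    'synthesizer': f's3://{bucket_name}/{synth_folder}/HMASynthesizer.pkl',
    # Multi-table synthetic data now gets a .zip extension; single-table keeps .csv.
    'synthetic_data': f's3://{bucket_name}/{synth_folder}/HMASynthesizer_synthetic_data.zip',
    'benchmark_result': f's3://{bucket_name}/{synth_folder}/HMASynthesizer_benchmark_result.csv',
    'metainfo': f's3://{bucket_name}/{top_folder}/metainfo.yaml',
    'results': f's3://{bucket_name}/{top_folder}/results.csv',
}
```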
@@ -279,7 +293,9 @@ def _setup_output_destination(
             The s3 client that can be used to read / write to s3. Defaults to ``None``.
     """
     if s3_client:
-        return _setup_output_destination_aws(output_destination, synthesizers, datasets, s3_client)
+        return _setup_output_destination_aws(
+            output_destination, synthesizers, datasets, modality, s3_client
+        )

     if output_destination is None:
         return {}
@@ -1571,7 +1587,7 @@ def _get_s3_script_content(
     return f"""
 import boto3
 import cloudpickle
-from sdgym.benchmark import _run_jobs, _write_metainfo_file, _update_metainfo_file
+from sdgym.benchmark import _run_jobs, _write_metainfo_file, _update_metainfo_file, MODALITY_IDX
 from io import StringIO
 from sdgym.result_writer import S3ResultsWriter

@@ -1583,8 +1599,9 @@ def _get_s3_script_content(
15831599)
15841600response = s3_client.get_object(Bucket='{ bucket_name } ', Key='{ job_args_key } ')
15851601job_args_list = cloudpickle.loads(response['Body'].read())
1602+ modality = job_args_list[0][MODALITY_IDX]
15861603result_writer = S3ResultsWriter(s3_client=s3_client)
1587- _write_metainfo_file({ synthesizers } , job_args_list, 'single_table' , result_writer)
1604+ _write_metainfo_file({ synthesizers } , job_args_list, modality , result_writer)
15881605scores = _run_jobs(None, job_args_list, False, result_writer=result_writer)
15891606metainfo_filename = job_args_list[0][-1]['metainfo']
15901607_update_metainfo_file(metainfo_filename, result_writer)
@@ -1619,7 +1636,7 @@ def _get_user_data_script(access_key, secret_key, region_name, script_content):
16191636
16201637 echo "======== Install Dependencies in venv ============"
16211638 pip install --upgrade pip
1622- pip install sdgym[all]
1639+ pip install " sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@feature_branch/mutli_table_benchmark"
16231640 pip install s3fs
16241641
16251642 echo "======== Write Script ==========="
@@ -1644,13 +1661,14 @@ def _run_on_aws(
     aws_secret_access_key,
 ):
     bucket_name, job_args_key = _store_job_args_in_s3(output_destination, job_args_list, s3_client)
+    synthesizer_names = [{'name': synthesizer['name']} for synthesizer in synthesizers]
     script_content = _get_s3_script_content(
         aws_access_key_id,
         aws_secret_access_key,
         S3_REGION,
         bucket_name,
         job_args_key,
-        synthesizers,
+        synthesizer_names,
     )

     # Create a session and EC2 client using the provided S3 client's credentials
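Note (one plausible reading, not stated in the diff): only the synthesizer names are embedded in the generated EC2 script text, while the full synthesizer objects still reach the instance through the cloudpickled job args stored in S3. The sketch below assumes each entry in ``synthesizers`` behaves like a dict with a ``'name'`` key, which is all the list comprehension relies on.

```python
# Assumed entry shape: any mapping with a 'name' key (the extra fields are placeholders).
synthesizers = [{'name': 'HMASynthesizer', 'synthesizer': object()}]
synthesizer_names = [{'name': synthesizer['name']} for synthesizer in synthesizers]
assert synthesizer_names == [{'name': 'HMASynthesizer'}]
```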
@@ -1917,3 +1935,102 @@ def benchmark_multi_table(
     _update_metainfo_file(metainfo_filename, result_writer)

     return scores
+
+
+def benchmark_multi_table_aws(
+    output_destination,
+    aws_access_key_id=None,
+    aws_secret_access_key=None,
+    synthesizers=DEFAULT_MULTI_TABLE_SYNTHESIZERS,
+    sdv_datasets=DEFAULT_MULTI_TABLE_DATASETS,
+    additional_datasets_folder=None,
+    limit_dataset_size=False,
+    compute_quality_score=True,
+    compute_diagnostic_score=True,
+    timeout=None,
+):
+    """Run the SDGym benchmark on multi-table datasets.
+
+    Args:
+        output_destination (str):
+            An S3 bucket or filepath. The results output folder will be written here.
+            Should be structured as:
+            s3://{s3_bucket_name}/{path_to_file} or s3://{s3_bucket_name}.
+        aws_access_key_id (str): The AWS access key id. Optional.
+        aws_secret_access_key (str): The AWS secret access key. Optional.
+        synthesizers (list[string]):
+            The synthesizer(s) to evaluate. Defaults to
+            ``[HMASynthesizer, MultiTableUniformSynthesizer]``. The available
+            options are:
+            - ``HMASynthesizer``
+            - ``MultiTableUniformSynthesizer``
+        sdv_datasets (list[str] or ``None``):
+            Names of the SDV demo datasets to use for the benchmark. Defaults to
+            ``[adult, alarm, census, child, expedia_hotel_logs, insurance, intrusion, news,
+            covtype]``. Use ``None`` to disable using any sdv datasets.
+        additional_datasets_folder (str or ``None``):
+            The path to an S3 bucket. Datasets found in this folder are
+            run in addition to the SDV datasets. If ``None``, no additional datasets are used.
+        limit_dataset_size (bool):
+            Use this flag to limit the size of the datasets for faster evaluation. If ``True``,
+            limit the size of every table to 1,000 rows (randomly sampled) and the first 10
+            columns.
+        compute_quality_score (bool):
+            Whether or not to evaluate an overall quality score. Defaults to ``True``.
+        compute_diagnostic_score (bool):
+            Whether or not to evaluate an overall diagnostic score. Defaults to ``True``.
+        timeout (int or ``None``):
+            The maximum number of seconds to wait for synthetic data creation. If ``None``, no
+            timeout is enforced.
+
+    Returns:
+        pandas.DataFrame:
+            A table containing one row per synthesizer + dataset.
+    """
+    s3_client = _validate_output_destination(
+        output_destination,
+        aws_keys={
+            'aws_access_key_id': aws_access_key_id,
+            'aws_secret_access_key': aws_secret_access_key,
+        },
+    )
+    if not synthesizers:
+        synthesizers = []
+
+    _ensure_uniform_included(synthesizers, modality='multi_table')
+    synthesizers = _import_and_validate_synthesizers(
+        synthesizers=synthesizers,
+        custom_synthesizers=None,
+        modality='multi_table',
+    )
+    job_args_list = _generate_job_args_list(
+        limit_dataset_size=limit_dataset_size,
+        sdv_datasets=sdv_datasets,
+        additional_datasets_folder=additional_datasets_folder,
+        sdmetrics=None,
+        timeout=timeout,
+        output_destination=output_destination,
+        compute_quality_score=compute_quality_score,
+        compute_diagnostic_score=compute_diagnostic_score,
+        compute_privacy_score=None,
+        synthesizers=synthesizers,
+        detailed_results_folder=None,
+        s3_client=s3_client,
+        modality='multi_table',
+    )
+    if not job_args_list:
+        return _get_empty_dataframe(
+            compute_diagnostic_score=compute_diagnostic_score,
+            compute_quality_score=compute_quality_score,
+            compute_privacy_score=None,
+            sdmetrics=None,
+        )
+
+    _run_on_aws(
+        output_destination=output_destination,
+        synthesizers=synthesizers,
+        s3_client=s3_client,
+        job_args_list=job_args_list,
+        aws_access_key_id=aws_access_key_id,
+        aws_secret_access_key=aws_secret_access_key,
+    )
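Note (not part of the diff): a minimal usage sketch for the new entry point. The bucket, credentials, and dataset name below are placeholders, and the benchmark itself runs on the EC2 instance that ``_run_on_aws`` launches, writing its outputs back to the S3 destination.

```python
# Hypothetical call; bucket, credentials, and dataset names are placeholders.
from sdgym.benchmark import benchmark_multi_table_aws

benchmark_multi_table_aws(
    output_destination='s3://my-sdgym-results/experiments',
    aws_access_key_id='AKIA...',
    aws_secret_access_key='...',
    synthesizers=['HMASynthesizer', 'MultiTableUniformSynthesizer'],
    sdv_datasets=['fake_hotels'],
    limit_dataset_size=True,
    timeout=3600,
)
# Based on the path logic above, results should land under
# s3://my-sdgym-results/experiments/multi_table/SDGym_results_<date>/.
```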