Commit e4dde82

Add benchmark_multi_table_aws (#507)
1 parent 9714363 commit e4dde82

5 files changed: +321 additions, -18 deletions

sdgym/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
     benchmark_multi_table,
     benchmark_single_table,
     benchmark_single_table_aws,
+    benchmark_multi_table_aws,
 )
 from sdgym.cli.collect import collect_results
 from sdgym.cli.summary import make_summary_spreadsheet
@@ -36,6 +37,7 @@
     'DatasetExplorer',
     'ResultsExplorer',
     'benchmark_multi_table',
+    'benchmark_multi_table_aws',
     'benchmark_single_table',
     'benchmark_single_table_aws',
     'collect_results',

sdgym/benchmark.py

Lines changed: 131 additions & 14 deletions
@@ -107,7 +107,7 @@
     'TVAESynthesizer',
 ]
 SDV_MULTI_TABLE_SYNTHESIZERS = ['HMASynthesizer']
-
+MODALITY_IDX = 10
 SDV_SYNTHESIZERS = SDV_SINGLE_TABLE_SYNTHESIZERS + SDV_MULTI_TABLE_SYNTHESIZERS


@@ -220,32 +220,46 @@ def _get_metainfo_increment(top_folder, s3_client=None):
     return max(increments) + 1 if increments else 0


-def _setup_output_destination_aws(output_destination, synthesizers, datasets, s3_client):
+def _setup_output_destination_aws(
+    output_destination,
+    synthesizers,
+    datasets,
+    modality,
+    s3_client,
+):
     paths = defaultdict(dict)
     s3_path = output_destination[len(S3_PREFIX) :].rstrip('/')
     parts = s3_path.split('/')
     bucket_name = parts[0]
     prefix_parts = parts[1:]
     paths['bucket_name'] = bucket_name
     today = datetime.today().strftime('%m_%d_%Y')
-    top_folder = '/'.join(prefix_parts + [f'SDGym_results_{today}'])
+
+    modality_prefix = '/'.join(prefix_parts + [modality])
+    top_folder = f'{modality_prefix}/SDGym_results_{today}'
     increment = _get_metainfo_increment(f's3://{bucket_name}/{top_folder}', s3_client)
     suffix = f'({increment})' if increment >= 1 else ''
     s3_client.put_object(Bucket=bucket_name, Key=top_folder + '/')
+    synthetic_data_extension = 'zip' if modality == 'multi_table' else 'csv'
     for dataset in datasets:
         dataset_folder = f'{top_folder}/{dataset}_{today}'
         s3_client.put_object(Bucket=bucket_name, Key=dataset_folder + '/')
-        paths[dataset]['meta'] = f's3://{bucket_name}/{dataset_folder}/meta.yaml'
+
         for synth_name in synthesizers:
             final_synth_name = f'{synth_name}{suffix}'
             synth_folder = f'{dataset_folder}/{final_synth_name}'
             s3_client.put_object(Bucket=bucket_name, Key=synth_folder + '/')
             paths[dataset][final_synth_name] = {
-                'synthesizer': f's3://{bucket_name}/{synth_folder}/{final_synth_name}.pkl',
-                'synthetic_data': f's3://{bucket_name}/{synth_folder}/{final_synth_name}_synthetic_data.csv',
-                'benchmark_result': f's3://{bucket_name}/{synth_folder}/{final_synth_name}_benchmark_result.csv',
-                'results': f's3://{bucket_name}/{top_folder}/results{suffix}.csv',
-                'metainfo': f's3://{bucket_name}/{top_folder}/metainfo{suffix}.yaml',
+                'synthesizer': (f's3://{bucket_name}/{synth_folder}/{final_synth_name}.pkl'),
+                'synthetic_data': (
+                    f's3://{bucket_name}/{synth_folder}/'
+                    f'{final_synth_name}_synthetic_data.{synthetic_data_extension}'
+                ),
+                'benchmark_result': (
+                    f's3://{bucket_name}/{synth_folder}/{final_synth_name}_benchmark_result.csv'
+                ),
+                'metainfo': (f's3://{bucket_name}/{top_folder}/metainfo{suffix}.yaml'),
+                'results': (f's3://{bucket_name}/{top_folder}/results{suffix}.csv'),
             }

         s3_client.put_object(
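For illustration, with modality='multi_table', an output destination of s3://my-bucket (a hypothetical bucket), a dataset named fake_hotels and the HMASynthesizer, the keys built above would look roughly like the following (where <MM_DD_YYYY> is today's date and the '(increment)' suffix only appears when a previous run already exists under the same folder):

    s3://my-bucket/multi_table/SDGym_results_<MM_DD_YYYY>/
    s3://my-bucket/multi_table/SDGym_results_<MM_DD_YYYY>/fake_hotels_<MM_DD_YYYY>/HMASynthesizer/HMASynthesizer.pkl
    s3://my-bucket/multi_table/SDGym_results_<MM_DD_YYYY>/fake_hotels_<MM_DD_YYYY>/HMASynthesizer/HMASynthesizer_synthetic_data.zip
    s3://my-bucket/multi_table/SDGym_results_<MM_DD_YYYY>/fake_hotels_<MM_DD_YYYY>/HMASynthesizer/HMASynthesizer_benchmark_result.csv
    s3://my-bucket/multi_table/SDGym_results_<MM_DD_YYYY>/results.csv
    s3://my-bucket/multi_table/SDGym_results_<MM_DD_YYYY>/metainfo.yaml

The whole tree is now namespaced under the modality, and the synthetic data lands in a .zip archive (one CSV per table) for multi-table runs versus a single .csv for single-table runs.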
@@ -279,7 +293,9 @@ def _setup_output_destination(
         The s3 client that can be used to read / write to s3. Defaults to ``None``.
     """
     if s3_client:
-        return _setup_output_destination_aws(output_destination, synthesizers, datasets, s3_client)
+        return _setup_output_destination_aws(
+            output_destination, synthesizers, datasets, modality, s3_client
+        )

     if output_destination is None:
         return {}
@@ -1571,7 +1587,7 @@ def _get_s3_script_content(
     return f"""
 import boto3
 import cloudpickle
-from sdgym.benchmark import _run_jobs, _write_metainfo_file, _update_metainfo_file
+from sdgym.benchmark import _run_jobs, _write_metainfo_file, _update_metainfo_file, MODALITY_IDX
 from io import StringIO
 from sdgym.result_writer import S3ResultsWriter

@@ -1583,8 +1599,9 @@ def _get_s3_script_content(
 )
 response = s3_client.get_object(Bucket='{bucket_name}', Key='{job_args_key}')
 job_args_list = cloudpickle.loads(response['Body'].read())
+modality = job_args_list[0][MODALITY_IDX]
 result_writer = S3ResultsWriter(s3_client=s3_client)
-_write_metainfo_file({synthesizers}, job_args_list, 'single_table', result_writer)
+_write_metainfo_file({synthesizers}, job_args_list, modality, result_writer)
 scores = _run_jobs(None, job_args_list, False, result_writer=result_writer)
 metainfo_filename = job_args_list[0][-1]['metainfo']
 _update_metainfo_file(metainfo_filename, result_writer)
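With MODALITY_IDX exported from sdgym.benchmark, the generated EC2 script no longer hardcodes 'single_table' when writing the metainfo file; it reads the modality out of the first job-args tuple. A minimal sketch of what the generated script does with that tuple (the exact layout produced by _generate_job_args_list is assumed here; only index MODALITY_IDX and the trailing paths dict are used):

    # Conceptual view of one job-args tuple inside the generated script
    # (assumed layout: index 10 -> modality string, last element -> output paths dict).
    modality = job_args_list[0][MODALITY_IDX]   # 'single_table' or 'multi_table'
    output_paths = job_args_list[0][-1]         # per-run S3 destinations
    metainfo_filename = output_paths['metainfo']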
@@ -1619,7 +1636,7 @@ def _get_user_data_script(access_key, secret_key, region_name, script_content):

 echo "======== Install Dependencies in venv ============"
 pip install --upgrade pip
-pip install sdgym[all]
+pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@feature_branch/mutli_table_benchmark"
 pip install s3fs

 echo "======== Write Script ==========="
@@ -1644,13 +1661,14 @@ def _run_on_aws(
     aws_secret_access_key,
 ):
     bucket_name, job_args_key = _store_job_args_in_s3(output_destination, job_args_list, s3_client)
+    synthesizer_names = [{'name': synthesizer['name']} for synthesizer in synthesizers]
     script_content = _get_s3_script_content(
         aws_access_key_id,
         aws_secret_access_key,
         S3_REGION,
         bucket_name,
         job_args_key,
-        synthesizers,
+        synthesizer_names,
     )

     # Create a session and EC2 client using the provided S3 client's credentials
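The new synthesizer_names list strips each synthesizer entry down to its name before it is interpolated into the generated script, presumably so that only simple, printable values end up embedded in the script text. A rough sketch of the effect, assuming each entry carries a 'name' key alongside other non-printable values (the input shape below is an assumption, not taken from this commit):

    # Hypothetical input: entries hold a name plus other objects.
    synthesizers = [{'name': 'HMASynthesizer', 'synthesizer': object()}]

    synthesizer_names = [{'name': synthesizer['name']} for synthesizer in synthesizers]
    print(synthesizer_names)  # [{'name': 'HMASynthesizer'}]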
@@ -1917,3 +1935,102 @@ def benchmark_multi_table(
     _update_metainfo_file(metainfo_filename, result_writer)

     return scores
+
+
+def benchmark_multi_table_aws(
+    output_destination,
+    aws_access_key_id=None,
+    aws_secret_access_key=None,
+    synthesizers=DEFAULT_MULTI_TABLE_SYNTHESIZERS,
+    sdv_datasets=DEFAULT_MULTI_TABLE_DATASETS,
+    additional_datasets_folder=None,
+    limit_dataset_size=False,
+    compute_quality_score=True,
+    compute_diagnostic_score=True,
+    timeout=None,
+):
+    """Run the SDGym benchmark on multi-table datasets.
+
+    Args:
+        output_destination (str):
+            An S3 bucket or filepath. The results output folder will be written here.
+            Should be structured as:
+            s3://{s3_bucket_name}/{path_to_file} or s3://{s3_bucket_name}.
+        aws_access_key_id (str): The AWS access key id. Optional.
+        aws_secret_access_key (str): The AWS secret access key. Optional.
+        synthesizers (list[string]):
+            The synthesizer(s) to evaluate. Defaults to
+            ``[HMASynthesizer, MultiTableUniformSynthesizer]``. The available options are:
+
+            - ``HMASynthesizer``
+            - ``MultiTableUniformSynthesizer``
+        sdv_datasets (list[str] or ``None``):
+            Names of the SDV demo datasets to use for the benchmark. Defaults to
+            ``[adult, alarm, census, child, expedia_hotel_logs, insurance, intrusion, news,
+            covtype]``. Use ``None`` to disable using any sdv datasets.
+        additional_datasets_folder (str or ``None``):
+            The path to an S3 bucket. Datasets found in this folder are run in addition
+            to the SDV datasets. If ``None``, no additional datasets are used.
+        limit_dataset_size (bool):
+            Use this flag to limit the size of the datasets for faster evaluation. If ``True``,
+            limit the size of every table to 1,000 rows (randomly sampled) and the first 10
+            columns.
+        compute_quality_score (bool):
+            Whether or not to evaluate an overall quality score. Defaults to ``True``.
+        compute_diagnostic_score (bool):
+            Whether or not to evaluate an overall diagnostic score. Defaults to ``True``.
+        timeout (int or ``None``):
+            The maximum number of seconds to wait for synthetic data creation. If ``None``, no
+            timeout is enforced.
+
+    Returns:
+        pandas.DataFrame:
+            A table containing one row per synthesizer + dataset.
+    """
+    s3_client = _validate_output_destination(
+        output_destination,
+        aws_keys={
+            'aws_access_key_id': aws_access_key_id,
+            'aws_secret_access_key': aws_secret_access_key,
+        },
+    )
+    if not synthesizers:
+        synthesizers = []
+
+    _ensure_uniform_included(synthesizers, modality='multi_table')
+    synthesizers = _import_and_validate_synthesizers(
+        synthesizers=synthesizers,
+        custom_synthesizers=None,
+        modality='multi_table',
+    )
+    job_args_list = _generate_job_args_list(
+        limit_dataset_size=limit_dataset_size,
+        sdv_datasets=sdv_datasets,
+        additional_datasets_folder=additional_datasets_folder,
+        sdmetrics=None,
+        timeout=timeout,
+        output_destination=output_destination,
+        compute_quality_score=compute_quality_score,
+        compute_diagnostic_score=compute_diagnostic_score,
+        compute_privacy_score=None,
+        synthesizers=synthesizers,
+        detailed_results_folder=None,
+        s3_client=s3_client,
+        modality='multi_table',
+    )
+    if not job_args_list:
+        return _get_empty_dataframe(
+            compute_diagnostic_score=compute_diagnostic_score,
+            compute_quality_score=compute_quality_score,
+            compute_privacy_score=None,
+            sdmetrics=None,
+        )
+
+    _run_on_aws(
+        output_destination=output_destination,
+        synthesizers=synthesizers,
+        s3_client=s3_client,
+        job_args_list=job_args_list,
+        aws_access_key_id=aws_access_key_id,
+        aws_secret_access_key=aws_secret_access_key,
+    )
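A minimal way to invoke the new entry point, based on the signature above (bucket, credentials and timeout below are placeholders, not values from this commit):

    from sdgym import benchmark_multi_table_aws

    benchmark_multi_table_aws(
        output_destination='s3://my-sdgym-results',  # hypothetical bucket
        aws_access_key_id='AKIA...',                  # placeholder credentials
        aws_secret_access_key='...',                  # placeholder credentials
        timeout=3600,                                 # cap synthetic data creation at one hour
    )

As with benchmark_single_table_aws, the heavy lifting happens on an EC2 instance launched by _run_on_aws, and the outputs are written to the destination laid out by _setup_output_destination_aws.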

sdgym/result_writer.py

Lines changed: 13 additions & 0 deletions
@@ -154,3 +154,16 @@ def write_yaml(self, data, file_path, append=False):
         run_data.update(data)
         new_content = yaml.dump(run_data)
         self.s3_client.put_object(Body=new_content.encode(), Bucket=bucket, Key=key)
+
+    def write_zipped_dataframes(self, data, file_path, index=False):
+        """Write a dictionary of DataFrames to a ZIP file in S3."""
+        bucket, key = parse_s3_path(file_path)
+        zip_buffer = io.BytesIO()
+        with zipfile.ZipFile(zip_buffer, mode='w', compression=zipfile.ZIP_DEFLATED) as zf:
+            for table_name, table in data.items():
+                csv_buf = io.StringIO()
+                table.to_csv(csv_buf, index=index)
+                zf.writestr(f'{table_name}.csv', csv_buf.getvalue())
+
+        zip_buffer.seek(0)
+        self.s3_client.upload_fileobj(zip_buffer, bucket, key)
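write_zipped_dataframes stores each table as '<table_name>.csv' inside a single ZIP archive. Reading such an archive back into per-table DataFrames could look like the following sketch (a local helper written for this note, not part of the commit; it assumes the ZIP bytes have already been downloaded):

    import io
    import zipfile

    import pandas as pd


    def read_zipped_dataframes(zip_bytes):
        """Rebuild the {table_name: DataFrame} dict from a ZIP produced as above."""
        tables = {}
        with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
            for member in zf.namelist():
                if member.endswith('.csv'):
                    with zf.open(member) as csv_file:
                        tables[member[: -len('.csv')]] = pd.read_csv(csv_file)
        return tables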
