Changes from 10 commits
11 changes: 1 addition & 10 deletions .gitignore
@@ -31,6 +31,7 @@ wheels/

# Data files
examples/**/*data/
examples/**/*results/
examples/**/*.csv
examples/**/*.npy

@@ -43,16 +44,6 @@ outputs/
# mkdocs site
site/

# Training examples
examples/training/single_table/data/**
examples/training/single_table/results/**
examples/training/multi_table/data/**
examples/training/multi_table/results/**
examples/synthesizing/single_table/data/**
examples/synthesizing/single_table/results/**
examples/synthesizing/multi_table/data/**
examples/synthesizing/multi_table/results/**

# Training Logs
*.err
*.out
105 changes: 105 additions & 0 deletions examples/gan/README.md
@@ -0,0 +1,105 @@
# CTGAN Single-Table Example

This example walks through training a single-table [CTGAN](https://arxiv.org/pdf/1907.00503)
model using the [CTGAN](https://github.com/sdv-dev/CTGAN/) library and then synthesizing
data with the trained model.


## Downloading data

First, we need the data. Download it from this
[Google Drive link](https://drive.google.com/file/d/1J5qDuMHHg4dm9c3ISmb41tcTHSu1SVUC/view?usp=drive_link),
extract the files, and place them in a `/data` folder within this folder
(`examples/gan`).

> [!NOTE]
> If you wish to change the data folder, you can do so by editing the `base_data_dir` attribute
> of the [`config.yaml`](config.yaml) file.

Here is a description of the files that have been extracted:
- `trans.csv`: The training data. It consists of information about bank transactions and it
contains 20,000 data points.
- `trans_domain.json`: Metadata about the columns in `trans.csv`, such as data types and sizes.
- `dataset_meta.json`: Metadata about the relationship between the tables. Since this is a
single-table example, it will only contain information about the `trans` table.
- `meta_info.json`: Metadata about the dataset, namely which columns are numerical and
which ones are categorical, the target column and the task type (e.g. `regression`).
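
As a rough illustration of how these files fit together, the sketch below reads the table
name from `dataset_meta.json` and resolves the column index lists in `meta_info.json` to
column names, similar to what the scripts in this example do. The `tables`, `num_col_idx`
and `cat_col_idx` keys are the ones those scripts read; the `describe_dataset` and
`build_demo_dir` helpers, the demo column names and the demo indices are made up for
illustration and will not match the real files.

```python
import csv
import json
import tempfile
from pathlib import Path


def describe_dataset(data_dir: Path) -> dict:
    """Return the table name and its numerical/categorical column names."""
    with open(data_dir / "dataset_meta.json") as f:
        tables = json.load(f)["tables"]
    # Single-table example: unpacking fails unless there is exactly one table.
    (table_name,) = tables.keys()

    # Only the CSV header is needed to turn column indices into names.
    with open(data_dir / f"{table_name}.csv", newline="") as f:
        columns = next(csv.reader(f))

    with open(data_dir / "meta_info.json") as f:
        meta_info = json.load(f)

    return {
        "table_name": table_name,
        "numerical": [columns[i] for i in meta_info["num_col_idx"]],
        "categorical": [columns[i] for i in meta_info["cat_col_idx"]],
    }


def build_demo_dir() -> Path:
    """Create a tiny stand-in for the data folder (all contents are made up)."""
    demo_dir = Path(tempfile.mkdtemp())
    (demo_dir / "dataset_meta.json").write_text(json.dumps({"tables": {"trans": {}}}))
    (demo_dir / "meta_info.json").write_text(json.dumps({"num_col_idx": [1], "cat_col_idx": [0]}))
    (demo_dir / "trans.csv").write_text("type,amount\nA,10.5\nB,3.0\n")
    return demo_dir


if __name__ == "__main__":
    print(describe_dataset(build_demo_dir()))
```

To try it on the downloaded files, call `describe_dataset(Path("examples/gan/data"))` from
the project's root folder.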


## Kicking off training

To kick off training, simply run the command below from the project's root folder:

```bash
python -m examples.gan.train
```


## Training results

The result files will be saved inside a `/results` folder within this folder
(`examples/gan`).

> [!NOTE]
> If you wish to change the save folder, you can do so by editing the `results_dir` attribute
> of the [`config.yaml`](config.yaml) file.

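In the `/results` folder, there will be a file called `trained_ctgan_model.pkl`,
which is a pickle file containing the trained model. Since the model is trained and saved
with SDV's `CTGANSynthesizer`, you can load it using that class's `load` function:

```python
from pathlib import Path

from sdv.single_table import CTGANSynthesizer

results_file = Path("examples/gan/results/trained_ctgan_model.pkl")

# Load the synthesizer that the training script saved
ctgan = CTGANSynthesizer.load(results_file)
```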

## Synthesizing data

To synthesize some data with the trained model, run:

```bash
python -m examples.gan.synthesize
```

If there is already a trained model in the `/results` folder, it will use that model.
Otherwise it will train one from scratch. At the end of the script, it will save the
synthesized data to `/results/trans_synthetic.csv`.


## Evaluating the quality of the synthetic data

### Alpha Precision

To run a round of evaluation with [Alpha Precision](https://arxiv.org/abs/2301.07573)
metrics on a set of synthetic data, run the `midst_alpha_precision_eval` script:

```bash
python -m midst_toolkit.evaluation.quality.scripts.midst_alpha_precision_eval \
--synthetic_data_path examples/gan/results/trans_synthetic.csv \
--real_data examples/gan/data/trans.csv \
--meta_info_path examples/gan/data/meta_info.json \
--save_directory examples/gan/results/
```

It will save the evaluation results under the `/results/model.txt` file.

### Additional Metrics

The calculation of additional metrics is set up in the `evaluate.py` file. The metrics are the
Kolmogorov-Smirnov (KS) test, Total Variation Distance (TVD), Correlation Matrix Difference,
and Mutual Information Difference.
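
For intuition about the first two metrics, here is a minimal sketch of their textbook
definitions for a single column: the two-sample KS statistic (the largest gap between the
two empirical CDFs) for numerical values, and the TVD (half the summed absolute difference
in category frequencies) for categorical values. This is not the `midst_toolkit`
implementation, which may preprocess and aggregate across columns differently; the
function names and sample values are made up for illustration.

```python
from bisect import bisect_right
from collections import Counter


def ks_statistic(real: list[float], synthetic: list[float]) -> float:
    """Largest absolute gap between the empirical CDFs of the two samples."""
    real_sorted = sorted(real)
    synthetic_sorted = sorted(synthetic)
    # The gap can only change at observed values, so checking those is enough.
    points = sorted(set(real_sorted) | set(synthetic_sorted))
    return max(
        abs(
            bisect_right(real_sorted, x) / len(real_sorted)
            - bisect_right(synthetic_sorted, x) / len(synthetic_sorted)
        )
        for x in points
    )


def total_variation_distance(real: list[str], synthetic: list[str]) -> float:
    """Half the sum of absolute differences between category frequencies."""
    real_counts = Counter(real)
    synthetic_counts = Counter(synthetic)
    categories = set(real_counts) | set(synthetic_counts)
    return 0.5 * sum(
        abs(real_counts[c] / len(real) - synthetic_counts[c] / len(synthetic))
        for c in categories
    )


if __name__ == "__main__":
    print(ks_statistic([1, 2, 3, 4], [3, 4, 5, 6]))
    print(total_variation_distance(["a", "a", "b", "b"], ["a", "b", "b", "b"]))
```

Both quantities are 0 when the two samples have identical distributions and 1 when they do
not overlap at all.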

To compute those metrics, you can run the command below. The name of the table should be
defined in the `dataset_meta.json` file. The script expects the real data under
`/data/{table_name}.csv` and the synthetic data under
`/results/{table_name}_synthetic.csv`.

```bash
python -m examples.gan.evaluate
```

The results will be saved in the `/results/evaluation.json` file.
11 changes: 11 additions & 0 deletions examples/gan/config.yaml
@@ -0,0 +1,11 @@
# Training example configuration
# Base data directory (can be overridden from command line)
base_data_dir: examples/gan/data
results_dir: examples/gan/results

training:
  epochs: 300
  verbose: True

synthesizing:
  sample_size: 20000
79 changes: 79 additions & 0 deletions examples/gan/evaluate.py
@@ -0,0 +1,79 @@
import json
from logging import INFO
from pathlib import Path

import hydra
import pandas as pd
from omegaconf import DictConfig

from examples.gan.utils import get_table_name
from midst_toolkit.common.logger import log
from midst_toolkit.evaluation.quality.correlation_matrix_difference import CorrelationMatrixDifference
from midst_toolkit.evaluation.quality.kolmogorov_smirnov_total_variation import KolmogorovSmirnovAndTotalVariation
from midst_toolkit.evaluation.quality.mutual_information_difference import MutualInformationDifference


@hydra.main(config_path=".", config_name="config", version_base=None)
def main(config: DictConfig) -> None:
    """
    Run the evaluation pipeline for the additional quality metrics.

    The metrics are Kolmogorov-Smirnov and Total Variation Distance, Correlation Matrix
    Difference, and Mutual Information Difference.

    It will read the table name from the `dataset_meta.json` file in the
    `config.base_data_dir` folder, load the real data from `{table_name}.csv` in that
    folder and the synthetic data from `{table_name}_synthetic.csv` in the
    `config.results_dir` folder, and then compute the metrics.

    It will also need the `meta_info.json` file for the information about categorical and
    numerical columns.

    The results will be saved in the `config.results_dir` folder under `evaluation.json`.

    Args:
        config: Configuration as an OmegaConf DictConfig object.
    """
    log(INFO, "Loading data...")

    table_name = get_table_name(config.base_data_dir)

    real_data = pd.read_csv(Path(config.base_data_dir) / f"{table_name}.csv")
    synthetic_data = pd.read_csv(Path(config.results_dir) / f"{table_name}_synthetic.csv")

    with open(Path(config.base_data_dir) / "meta_info.json", "r") as f:
        meta_info = json.load(f)

    numerical_columns = [real_data.columns[i] for i in meta_info["num_col_idx"]]
    categorical_columns = [real_data.columns[i] for i in meta_info["cat_col_idx"]]

    results = {}

    # KS and TVD
    ks_tvd_metric = KolmogorovSmirnovAndTotalVariation(categorical_columns, numerical_columns, do_preprocess=True)
    ks_tvd_score = ks_tvd_metric.compute(real_data, synthetic_data)

    log(INFO, f"Kolmogorov-Smirnov and Total Variation Distance score: {ks_tvd_score}")
    results["ks_tvd"] = ks_tvd_score

    # Correlation Matrix Difference
    cmd_metric = CorrelationMatrixDifference(categorical_columns, numerical_columns, do_preprocess=True)
    cmd_result = cmd_metric.compute(real_data, synthetic_data)

    log(INFO, f"Correlation Matrix Difference score: {cmd_result}")
    results["correlation_matrix_difference"] = cmd_result

    # Mutual Information Difference
    mid_metric = MutualInformationDifference(categorical_columns, numerical_columns, do_preprocess=True)
    mid_result = mid_metric.compute(real_data, synthetic_data)
    mid_result["score"] = mid_result["mutual_inf_diff"] / mid_result["mi_mat_dims"]

    log(INFO, f"Mutual Information Difference score: {mid_result}")
    results["mutual_information_difference"] = mid_result

    log(INFO, "Saving results...")
    with open(Path(config.results_dir) / "evaluation.json", "w") as f:
        json.dump(results, f, indent=4)

    log(INFO, "Done!")


if __name__ == "__main__":
    main()
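
The Mutual Information Difference step above adds a normalized `score` by dividing
`mutual_inf_diff` by `mi_mat_dims` before the results are written to `evaluation.json`.
A minimal standalone sketch of that step and of the JSON round trip follows. The helper
names and the numeric values are made up for illustration; only the dictionary keys and
the output file name come from the script.

```python
import json
import tempfile
from pathlib import Path


def add_normalized_score(mid_result: dict) -> dict:
    """Return a copy of the result with `score` = mutual_inf_diff / mi_mat_dims."""
    scored = dict(mid_result)
    scored["score"] = scored["mutual_inf_diff"] / scored["mi_mat_dims"]
    return scored


def save_results(results: dict, results_dir: Path) -> Path:
    """Write the results to `evaluation.json` inside an existing results folder."""
    output_file = results_dir / "evaluation.json"
    with open(output_file, "w") as f:
        json.dump(results, f, indent=4)
    return output_file


if __name__ == "__main__":
    # Placeholder numbers, not real evaluation output.
    mid_result = add_normalized_score({"mutual_inf_diff": 3.0, "mi_mat_dims": 6})
    output_file = save_results({"mutual_information_difference": mid_result}, Path(tempfile.mkdtemp()))
    print(json.loads(output_file.read_text()))
```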
47 changes: 47 additions & 0 deletions examples/gan/synthesize.py
@@ -0,0 +1,47 @@
from logging import INFO
from pathlib import Path

import hydra
from omegaconf import DictConfig
from sdv.single_table import CTGANSynthesizer # type: ignore[import-untyped]

from examples.gan.train import main as train_main
from examples.gan.utils import get_table_name
from midst_toolkit.common.logger import log


@hydra.main(config_path=".", config_name="config", version_base=None)
def main(config: DictConfig) -> None:
    """
    Run the synthesizing pipeline for a single-table CTGAN model.

    It will load the config and then data from the `config.base_data_dir` folder,
    load the trained model (or train one if it doesn't exist) and save the results
    in the `config.results_dir` folder.

    Args:
        config: Configuration as an OmegaConf DictConfig object.
    """
    results_file = Path(config.results_dir) / "trained_ctgan_model.pkl"

    if not results_file.exists():
        log(INFO, f"Trained model not found at {results_file}. Training a new model from scratch.")
        train_main(config)

    log(INFO, f"Loading model from {results_file}...")
    ctgan = CTGANSynthesizer.load(results_file)

    log(INFO, f"Synthesizing data of size {config.synthesizing.sample_size}...")
    synthetic_data = ctgan.sample(num_rows=config.synthesizing.sample_size)

    table_name = get_table_name(config.base_data_dir)
    synthetic_data_file = Path(config.results_dir) / f"{table_name}_synthetic.csv"

    log(INFO, f"Saving synthetic data to {synthetic_data_file}...")
    synthetic_data.to_csv(synthetic_data_file, index=False)

    log(INFO, "Done!")


if __name__ == "__main__":
    main()
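
The script above follows a "train only if the saved model is missing, then load" pattern.
A minimal standalone sketch of that pattern is shown below. Plain `pickle` and a
placeholder dictionary stand in for `CTGANSynthesizer.save`/`load` and the real model, and
`load_or_train` and `dummy_train` are hypothetical names that do not exist in this example.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Callable


def load_or_train(model_file: Path, train: Callable[[Path], None]) -> Any:
    """Load the pickled model at `model_file`, training it first if the file is missing."""
    if not model_file.exists():
        # `train` is expected to write the model to `model_file`.
        train(model_file)
    with open(model_file, "rb") as f:
        return pickle.load(f)


def dummy_train(model_file: Path) -> None:
    """Stand-in for the training entry point: writes a placeholder 'model' to disk."""
    model_file.parent.mkdir(parents=True, exist_ok=True)
    with open(model_file, "wb") as f:
        pickle.dump({"name": "placeholder-model"}, f)


if __name__ == "__main__":
    model_file = Path(tempfile.mkdtemp()) / "results" / "trained_model.pkl"
    print(load_or_train(model_file, dummy_train))  # trains, then loads
    print(load_or_train(model_file, dummy_train))  # loads the existing file
```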
55 changes: 55 additions & 0 deletions examples/gan/train.py
@@ -0,0 +1,55 @@
import json
from logging import INFO
from pathlib import Path

import hydra
import pandas as pd
from omegaconf import DictConfig
from sdv.single_table import CTGANSynthesizer # type: ignore[import-untyped]

from examples.gan.utils import get_metadata, get_table_name
from midst_toolkit.common.logger import log


@hydra.main(config_path=".", config_name="config", version_base=None)
def main(config: DictConfig) -> None:
    """
    Run the training pipeline for a single-table CTGAN model.

    It will load the config and then data from the `config.base_data_dir` folder,
    train the model and save the results in the `config.results_dir` folder.

    Args:
        config: Configuration as an OmegaConf DictConfig object.
    """
    log(INFO, "Loading data...")

    table_name = get_table_name(config.base_data_dir)

    with open(Path(config.base_data_dir) / f"{table_name}_domain.json", "r") as f:
        domain_info = json.load(f)

    real_data = pd.read_csv(Path(config.base_data_dir) / f"{table_name}.csv")

    metadata, real_data_without_ids = get_metadata(real_data, domain_info)

    log(INFO, "Fitting CTGAN...")

    ctgan = CTGANSynthesizer(
        metadata=metadata,
        epochs=config.training.epochs,
        verbose=config.training.verbose,
    )
    ctgan.fit(real_data_without_ids)

    log(INFO, "Done!")
    log(INFO, "Saving model...")
    results_file = Path(config.results_dir) / "trained_ctgan_model.pkl"
    results_file.parent.mkdir(parents=True, exist_ok=True)

    ctgan.save(results_file)
    log(INFO, f"Model saved to {results_file}")


if __name__ == "__main__":
    main()
69 changes: 69 additions & 0 deletions examples/gan/utils.py
@@ -0,0 +1,69 @@
import json
from pathlib import Path
from typing import Any

import pandas as pd
from sdv.metadata import SingleTableMetadata # type: ignore[import-untyped]


def get_table_name(base_data_dir: Path) -> str:
    """
    Get the name of the table from the dataset metadata.

    Args:
        base_data_dir: The base directory containing the dataset metadata.

    Returns:
        The name of the table.
    """
    with open(Path(base_data_dir) / "dataset_meta.json", "r") as f:
        dataset_meta = json.load(f)

    assert len(dataset_meta["tables"]) == 1, (
        "Only one table is supported for single-table training. "
        f"Got {len(dataset_meta['tables'])} tables: {dataset_meta['tables'].keys()}"
    )

    return list(dataset_meta["tables"].keys())[0]


def get_metadata(
    # [PR review comment, Collaborator] Maybe since this function is specific to single-table
    # (even the return data type is specific for single table), we should include single-table
    # somewhere in the name. For example, get_single_table_metadata(...).
    #
    # [PR review comment, Collaborator] Or maybe a better name would be
    # get_single_table_svd_metadata() to highlight the return object type.
    data: pd.DataFrame,
    domain_dictionary: dict[str, Any] | None = None,
) -> tuple[SingleTableMetadata, pd.DataFrame]:
    """
    Get the metadata for a single-table dataset.

    Args:
        data: The dataframe containing the data.
        domain_dictionary: The domain dictionary containing metadata about the data columns.

    Returns:
        A tuple containing the metadata and the dataframe without the id columns.
    """
    metadata = SingleTableMetadata()
    data_without_ids = data.drop(columns=[column_name for column_name in data.columns if "_id" in column_name])
    metadata.detect_from_dataframe(data_without_ids)
    # [PR review comment, Collaborator] Maybe a comment about detect_from_dataframe() would
    # be helpful.

    if domain_dictionary is not None:
        for column_name in data_without_ids.columns:
            if domain_dictionary[column_name]["type"] == "discrete":
                if domain_dictionary[column_name]["size"] < 1000:
                    metadata.update_column(
                        column_name=column_name,
                        sdtype="categorical",
                    )
                else:
                    metadata.update_column(
                        column_name=column_name,
                        sdtype="numerical",
                    )
            else:
                metadata.update_column(
                    column_name=column_name,
                    sdtype="numerical",
                )

    metadata.remove_primary_key()

    return metadata, data_without_ids
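
The column-type rule inside `get_metadata` can be read as a small pure function: a column
becomes `categorical` only when its domain entry is `discrete` with fewer than 1000
distinct values, and `numerical` otherwise. The sketch below restates that rule and the
`_id` column filter without SDV so the behaviour is easy to check. The helper names, the
demo column names and the `"continuous"` type value are assumptions made for illustration;
only the `type`, `size` and `discrete` names come from the code above.

```python
from typing import Any

MAX_CATEGORICAL_SIZE = 1000  # same threshold as in get_metadata


def drop_id_columns(column_names: list[str]) -> list[str]:
    """Keep only the columns whose name does not contain "_id" (substring match)."""
    return [name for name in column_names if "_id" not in name]


def choose_sdtype(column_domain: dict[str, Any]) -> str:
    """Pick the sdtype for one column from its domain entry."""
    is_discrete = column_domain["type"] == "discrete"
    if is_discrete and column_domain["size"] < MAX_CATEGORICAL_SIZE:
        return "categorical"
    # Large discrete domains and all non-discrete columns are treated as numerical.
    return "numerical"


if __name__ == "__main__":
    # Made-up domain entries for illustration.
    demo_domain = {
        "type": {"type": "discrete", "size": 3},
        "account": {"type": "discrete", "size": 5000},
        "amount": {"type": "continuous", "size": 12000},
    }
    columns = drop_id_columns(["trans_id", "account_id", *demo_domain])
    print({name: choose_sdtype(demo_domain[name]) for name in columns})
```

Note that the `_id` filter is a substring match, so it also drops any column whose name
merely contains `_id` somewhere in the middle.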