# direct installation
pip install git+https://github.com/poseidonchan/w1ot.git
# Or, clone the repo and install it
# git clone https://github.com/poseidonchan/w1ot.git
# cd w1ot
# pip install -e .

For general usage:

from w1ot import w1ot, w2ot
from w1ot.data import make_2d_data, plot_2d_data
# create the toy data for training
source, target = make_2d_data(dataset='circles', n_samples=2**17, noise=0.01)
# initialize the model
device = 'cuda'  # or 'cpu' if no GPU is available
model = w1ot(source, target, 0.1, device, path='./saved_models/w1ot/circles')
# fit the Kantorovich potential
model.fit_potential_function(num_iters=10000, resume_from_checkpoint=True)
# visualize the Kantorovich potential
model.plot_2dpotential()
# fit the step-size (distance) function
model.fit_distance_function(num_iters=10000, resume_from_checkpoint=True)
# create the testing data
source, target = make_2d_data(dataset='circles', n_samples=2000, noise=0.01)
# apply the learned transport map
transported = model.transport(source)
# visualize the result without markers
plot_2d_data(source, target, transported, False, 0.5)
# Alternative: w2ot
# model = w2ot(source, target, 0.1, device, path='./saved_models/w2ot/circles')
# model.fit_potential_function(num_iters=10000, resume_from_checkpoint=True)
# transported = model.transport(source)

For single-cell data in h5ad format:

from w1ot.experiment import PerturbModel
# Data requirement: the model automatically applies normalization and log1p transformation if the maximum value in the data exceeds 50.
# If you have already preprocessed the data, refer to the general usage above and use the w1ot model directly, which offers more flexibility.
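# For reference, a minimal sketch of equivalent manual preprocessing
# (assumption: standard scanpy calls; the model's built-in preprocessing may differ in details):
# import scanpy as sc
# if source_train_adata.X.max() > 50:             # values look like raw counts
#     sc.pp.normalize_total(source_train_adata)   # library-size normalization
#     sc.pp.log1p(source_train_adata)             # log1p transformation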
# Initialization
model = PerturbModel(model_name="w1ot", # also supports "w2ot", "scgen"
                     source_adata=source_train_adata, # data must be split beforehand (see the splitting sketch after this call)
                     target_adata=target_train_adata,
                     perturbation_attribute=perturbation_attribute, # only used during evaluation
                     latent_dim=8,
                     embedding=True, # whether to use the embedding model (VAE); if used, OT is performed in the latent space
                     output_dir=model_output_dir,
                     hidden_layers=[32, 32], # hidden layer sizes for the embedding model
                     num_iters=10000, # training iterations for the embedding model
                     device="cuda")
# Training
model.train()
# Inference
transported_adata = model.predict(source_test_adata) # transported_adata will carry the same metadata as source_test_adata
# Evaluation
metrics = model.evaluate(source_adata,
                         target_adata,
                         top_k=50 # evaluate on the top-k DEGs, computed automatically with scanpy.tl.rank_genes_groups
                         )
# if the embedding model is used, performance is evaluated in both the embedding space and the cell space; otherwise the embedding_* metrics are np.nan
embedding_r2, embedding_l2, embedding_mmd, cell_r2, cell_l2, cell_mmd = metrics

To reproduce the experiments efficiently, we suggest installing Ray (2.37.0) and configuring your own Ray cluster. After that, you can run the experiment scripts in the Experiments folder.
# for example:
python ./Experiments/4i.py

If you do not have access to sufficient computational resources, reproducing the experiments can be very slow (the w2ot model in particular is time-consuming).
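If you have not used Ray before, the lines below are a minimal sketch of starting and inspecting a local Ray runtime (assuming Ray 2.37.0 was installed with pip); multi-node cluster setup follows the standard Ray documentation, and how the experiment scripts pick up the cluster is not shown here.

# minimal local Ray setup sketch (not part of the w1ot package)
import ray
ray.init()                       # start a local Ray instance
print(ray.cluster_resources())   # check the CPUs/GPUs Ray can see
ray.shutdown()                   # stop the local instance when done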
If you find this work useful, please cite:

@misc{chen2024fastscalablewasserstein1neural,
      title={Fast and scalable Wasserstein-1 neural optimal transport solver for single-cell perturbation prediction},
      author={Yanshuo Chen and Zhengmian Hu and Wei Chen and Heng Huang},
      year={2024},
      eprint={2411.00614},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2411.00614},
}

If you have any questions, feel free to email [email protected].
