@@ -21,13 +21,17 @@ class TorchExperiment(BaseExperiment):

     Parameters
     ----------
-    datamodule : L.LightningDataModule
-        A PyTorch Lightning DataModule that handles data loading and preparation.
+    data_module : type
+        A PyTorch Lightning DataModule class (not an instance) that
+        handles data loading and preparation. It will be instantiated
+        with hyperparameters during optimization.
     lightning_module : type
         A PyTorch Lightning Module class (not an instance) that will be instantiated
         with hyperparameters during optimization.
     trainer_kwargs : dict, optional (default=None)
         A dictionary of keyword arguments to pass to the PyTorch Lightning Trainer.
+    dm_kwargs : dict, optional (default=None)
+        A dictionary of keyword arguments to pass to the data module upon instantiation.
     objective_metric : str, optional (default='val_loss')
         The metric used to evaluate the model's performance. This should correspond
         to a metric logged in the LightningModule during validation.
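Design note on the hunk above: passing the DataModule class plus dm_kwargs, rather than a ready-made instance, lets each evaluation construct a fresh, stateless data pipeline. A minimal standalone sketch of that deferred-instantiation pattern (FakeDataModule and evaluate are invented names, not part of this PR):

# Illustrative sketch only: shows why a class + kwargs is stored
# instead of a pre-built instance.
class FakeDataModule:
    def __init__(self, batch_size):
        self.batch_size = batch_size

def evaluate(data_module_cls, dm_kwargs):
    data = data_module_cls(**dm_kwargs)  # fresh instance per trial
    return data.batch_size

print(evaluate(FakeDataModule, {"batch_size": 16}))  # prints 16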
@@ -93,14 +97,12 @@ class TorchExperiment(BaseExperiment):
     ...     def val_dataloader(self):
     ...         return DataLoader(self.val, batch_size=self.batch_size)
     >>>
-    >>> datamodule = RandomDataModule(batch_size=16)
-    >>> datamodule.setup()
-    >>>
     >>> # Create Experiment
     >>> experiment = TorchExperiment(
-    ...     datamodule=datamodule,
+    ...     data_module=RandomDataModule,
     ...     lightning_module=SimpleLightningModule,
     ...     trainer_kwargs={'max_epochs': 3},
+    ...     dm_kwargs={'batch_size': 16},
     ...     objective_metric="val_loss"
     ... )
     >>>
@@ -118,14 +120,16 @@ class TorchExperiment(BaseExperiment):

     def __init__(
         self,
-        datamodule,
+        data_module,
         lightning_module,
         trainer_kwargs=None,
+        dm_kwargs=None,
         objective_metric: str = "val_loss",
     ):
-        self.datamodule = datamodule
+        self.data_module = data_module
         self.lightning_module = lightning_module
         self.trainer_kwargs = trainer_kwargs or {}
+        self.dm_kwargs = dm_kwargs or {}
         self.objective_metric = objective_metric
 
         super().__init__()
@@ -174,7 +178,8 @@ def _evaluate(self, params):
         try:
             model = self.lightning_module(**params)
             trainer = L.Trainer(**self._trainer_kwargs)
-            trainer.fit(model, self.datamodule)
+            data = self.data_module(**self.dm_kwargs)
+            trainer.fit(model, data)
 
             val_result = trainer.callback_metrics.get(self.objective_metric)
             metadata = {}
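For context, a self-contained sketch of the evaluation flow this hunk implements: build the model from the trial params, construct a fresh data module from the stored class and kwargs, fit, then read the objective metric from trainer.callback_metrics. TinyModule and TinyData are invented stand-ins for whatever the user passes in, not classes from this PR:

import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning as L

class TinyModule(L.LightningModule):
    def __init__(self, lr):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        self.lr = lr

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", torch.nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

class TinyData(L.LightningDataModule):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.ds = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))

    def train_dataloader(self):
        return DataLoader(self.ds, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.ds, batch_size=self.batch_size)

# Same sequence as the patched _evaluate: model from params, trainer,
# fresh data module from class + kwargs, fit, read the logged metric.
model = TinyModule(**{"lr": 1e-3})
trainer = L.Trainer(max_epochs=1, logger=False,
                    enable_progress_bar=False, enable_model_summary=False)
data = TinyData(**{"batch_size": 8})
trainer.fit(model, data)
print(trainer.callback_metrics.get("val_loss"))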
@@ -265,17 +270,16 @@ def train_dataloader(self):
             def val_dataloader(self):
                 return DataLoader(self.val, batch_size=self.batch_size)
 
-        datamodule = RandomDataModule(batch_size=16)
-
         params = {
-            "datamodule": datamodule,
+            "data_module": RandomDataModule,
             "lightning_module": SimpleLightningModule,
             "trainer_kwargs": {
                 "max_epochs": 1,
                 "enable_progress_bar": False,
                 "enable_model_summary": False,
                 "logger": False,
             },
+            "dm_kwargs": {"batch_size": 16},
             "objective_metric": "val_loss",
         }
 
@@ -339,17 +343,16 @@ def train_dataloader(self):
             def val_dataloader(self):
                 return DataLoader(self.val, batch_size=self.batch_size)
 
-        datamodule2 = RegressionDataModule(batch_size=16, num_samples=150)
-
         params2 = {
-            "datamodule": datamodule2,
+            "data_module": RegressionDataModule,
             "lightning_module": RegressionModule,
             "trainer_kwargs": {
                 "max_epochs": 1,
                 "enable_progress_bar": False,
                 "enable_model_summary": False,
                 "logger": False,
             },
+            "dm_kwargs": {"batch_size": 8, "num_samples": 200},
             "objective_metric": "val_loss",
         }
 
@@ -370,4 +373,5 @@ def _get_score_params(cls):
"""
score_params1 = {"input_dim": 10, "hidden_dim": 20, "lr": 0.001}
score_params2 = {"num_layers": 3, "hidden_size": 64, "dropout": 0.2}

return [score_params1, score_params2]
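Putting the pieces of this diff together: the params dict configures the experiment, and a score-params dict supplies the hyperparameters that _evaluate forwards to the lightning module. A hedged sketch of the plausible wiring; _evaluate's return value is not shown in this diff, so the final comment is an assumption:

# Assumed wiring, based only on the signatures visible in this diff:
experiment = TorchExperiment(**params)
result = experiment._evaluate(score_params1)
# _evaluate fits once and reads objective_metric ("val_loss") from
# trainer.callback_metrics; its exact return shape is not shown here.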