Skip to content

Commit 2afe318

Browse files
authored
Merge branch 'alan-turing-institute:main' into plot_loss
2 parents fe247b0 + ca796fb commit 2afe318

File tree

3 files changed

+44
-3
lines changed

3 files changed

+44
-3
lines changed

autoemulate/core/compare.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(
5555
self,
5656
x: InputLike,
5757
y: InputLike,
58+
test_data: tuple[InputLike, InputLike] | None = None,
5859
models: list[type[Emulator] | str] | None = None,
5960
x_transforms_list: list[list[Transform | dict]] | None = None,
6061
y_transforms_list: list[list[Transform | dict]] | None = None,
@@ -81,6 +82,9 @@ def __init__(
8182
Input features.
8283
y: InputLike or None
8384
Target values (not needed if x is a Dataset).
85+
test_data: tuple[InputLike, InputLike] | None
86+
Optional test data as a tuple (x_test, y_test). If None, a random split
87+
from the provided data is used. Defaults to None.
8488
models: list[type[Emulator]] | None
8589
List of emulator classes to compare. If None, all available emulators
8690
are used.
@@ -164,7 +168,17 @@ def __init__(
164168
self.models = updated_models
165169
if random_seed is not None:
166170
set_random_seed(seed=random_seed)
167-
self.train_val, self.test = self._random_split(self._convert_to_dataset(x, y))
171+
172+
if test_data is None:
173+
self.train_val, self.test = self._random_split(
174+
self._convert_to_dataset(x, y)
175+
)
176+
else:
177+
self.train_val = self._convert_to_dataset(x, y)
178+
test_x, test_y = self._move_tensors_to_device(
179+
*self._convert_to_tensors(*test_data)
180+
)
181+
self.test = self._convert_to_dataset(test_x, test_y)
168182

169183
# Run the compare method with the provided models
170184
if not self.models:

autoemulate/emulators/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import random
22
from abc import ABC, abstractmethod
3-
from typing import ClassVar
43

54
import numpy as np
65
import torch
@@ -545,7 +544,6 @@ class PyTorchBackend(nn.Module, Emulator):
545544
batch_size: int = 16
546545
shuffle: bool = True
547546
epochs: int = 10
548-
loss_history: ClassVar[list[float]] = []
549547
verbose: bool = False
550548
loss_fn: nn.Module = nn.MSELoss()
551549
optimizer_cls: type[optim.Optimizer] = optim.Adam
@@ -570,6 +568,8 @@ def _fit(self, x: TensorLike, y: TensorLike):
570568
y: OutputLike or None
571569
Target values (not needed if x is a DataLoader).
572570
"""
571+
self.loss_history: list[float] = []
572+
573573
self.train() # Set model to training mode
574574

575575
# Convert input to DataLoader if not already

tests/core/test_compare.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from autoemulate.emulators import DEFAULT_EMULATORS
1111
from autoemulate.emulators.base import Emulator
1212
from torch.distributions import Transform
13+
from torch.utils.data import TensorDataset
1314

1415

1516
@pytest.mark.parametrize("device", SUPPORTED_DEVICES)
@@ -462,3 +463,29 @@ def __call__(
462463
metric_names = [m.name for m in result.test_metrics]
463464
assert "custom_r2" in metric_names
464465
assert "rmse" in metric_names
466+
467+
468+
def test_ae_with_fixed_test_data(sample_data_for_ae_compare):
469+
"""Test AutoEmulate with a fixed test dataset."""
470+
x, y = sample_data_for_ae_compare
471+
models: list[str | type[Emulator]] = ["mlp", "RandomForest"]
472+
473+
# Create fixed test set
474+
test_size = 25
475+
x_test, y_test = x[:test_size], y[:test_size]
476+
x_train, y_train = x[test_size:], y[test_size:]
477+
478+
ae = AutoEmulate(
479+
x_train,
480+
y_train,
481+
models=models,
482+
test_data=(x_test, y_test),
483+
n_iter=2,
484+
n_splits=2,
485+
model_params={}, # Skip tuning for speed
486+
)
487+
488+
assert isinstance(ae.test, TensorDataset)
489+
assert ae.test.tensors == (x_test, y_test)
490+
assert isinstance(ae.train_val, TensorDataset)
491+
assert ae.train_val.tensors == (x_train, y_train)

0 commit comments

Comments
 (0)