Skip to content

Commit ca796fb

Browse files
authored
Merge pull request #940 from alan-turing-institute/fix_loss_history_bug
Make loss_history instance variable (not class variable)
2 parents eaa6b03 + 9d487a2 commit ca796fb

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

autoemulate/emulators/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import random
22
from abc import ABC, abstractmethod
3-
from typing import ClassVar
43

54
import numpy as np
65
import torch
@@ -545,7 +544,6 @@ class PyTorchBackend(nn.Module, Emulator):
545544
batch_size: int = 16
546545
shuffle: bool = True
547546
epochs: int = 10
548-
loss_history: ClassVar[list[float]] = []
549547
verbose: bool = False
550548
loss_fn: nn.Module = nn.MSELoss()
551549
optimizer_cls: type[optim.Optimizer] = optim.Adam
@@ -570,6 +568,8 @@ def _fit(self, x: TensorLike, y: TensorLike):
570568
y: OutputLike or None
571569
Target values (not needed if x is a DataLoader).
572570
"""
571+
self.loss_history: list[float] = []
572+
573573
self.train() # Set model to training mode
574574

575575
# Convert input to DataLoader if not already

0 commit comments

Comments
 (0)