Skip to content

Commit 6054aab

Browse files
committed
Use 'epochs' as determiner for model history presence
1 parent 2afe318 commit 6054aab

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

autoemulate/core/plotting.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -642,13 +642,14 @@ def plot_loss(
642642
"""
643643

644644
try:
645-
history = model.loss_history
646-
except:
647-
raise AttributeError("Emulator does not have a Loss history")
645+
has_epochs = model.epochs
646+
except Exception:
647+
raise AttributeError
648648

649649
if figsize is None:
650650
figsize = (6, 6)
651651

652+
history = model.loss_history
652653
fig, ax = plt.subplots(figsize=figsize)
653654
ax.plot(range(1, len(history) + 1), history)
654655
ax.set_xlabel("Epochs")

tests/core/test_plotting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def test_plot_loss(model_class, should_raise, title):
8989

9090
if should_raise:
9191
with pytest.raises(AttributeError):
92-
plotting.plot_loss(model=model, title=title)
92+
fig, ax = plotting.plot_loss(model=model, title=title)
9393
return
9494

9595
fig, ax = plotting.plot_loss(model=model, title=title)

0 commit comments

Comments
 (0)