Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion autoemulate/core/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from matplotlib.figure import Figure

from autoemulate.core.types import DistributionLike, GaussianLike, NumpyLike, TensorLike
from autoemulate.emulators.base import Emulator
from autoemulate.emulators.base import Emulator, PyTorchBackend


def display_figure(fig: Figure):
Expand Down Expand Up @@ -601,3 +601,58 @@ def plot_calibration_from_distributions(
ax.legend()

return fig, ax


def plot_loss(
model: PyTorchBackend,
title: str | None = None,
figsize: tuple[int, int] | None = None,
):
"""
Plot the per-epoch training loss for a model using the PyTorch backend.

This function visualizes the training loss curve stored in the model's
``loss_history`` attribute. The model must also provide an ``epochs``
attribute. If either attribute is missing, an ``AttributeError`` is raised.

Parameters
----------
model : PyTorchBackend
A model instance using the PyTorch backend. It must provide both
``loss_history`` and ``epochs`` attributes.
title : str, optional
Title for the plot. If ``None``, no title is added.
figsize : tuple of int, optional
Size of the figure as ``(width, height)`` in inches. Defaults to
``(6, 6)`` if not provided.

Returns
-------
fig : matplotlib.figure.Figure
The created matplotlib Figure object.
ax : matplotlib.axes.Axes
The Axes on which the loss curve is plotted.

Raises
------
AttributeError
If the model does not provide ``loss_history`` or ``epochs``.
"""
if not hasattr(model, "loss_history"):
msg = "Emulator does not have a Loss history"
raise AttributeError(msg)

history = model.loss_history

if figsize is None:
figsize = (6, 6)

fig, ax = plt.subplots(figsize=figsize)
ax.plot(range(1, len(history) + 1), history)
ax.set_xlabel("Epochs")
ax.set_ylabel("Train Loss")

if title:
ax.set_title(title)

return fig, ax
38 changes: 38 additions & 0 deletions tests/core/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import numpy as np
import pytest
from autoemulate.core import plotting
from autoemulate.emulators.polynomials import PolynomialRegression
from autoemulate.emulators.random_forest import RandomForest


def test_display_figure_jupyter(monkeypatch):
Expand Down Expand Up @@ -64,3 +66,39 @@ def test_plot_xy():
def test_calculate_subplot_layout(n_plots, n_cols, expected):
result = plotting.calculate_subplot_layout(n_plots, n_cols)
assert result == expected


@pytest.mark.parametrize(
("model_class", "should_raise", "title"),
[
(PolynomialRegression, False, "Training Curve"),
(RandomForest, True, "My Loss Plot"),
(PolynomialRegression, False, None),
],
)
def test_plot_loss(model_class, should_raise, title):
np.random.seed(42)
x = np.random.rand(20, 2)
y = (x[:, 0] + 2 * x[:, 1] > 1).astype(int)

model = model_class(x, y)
model.fit(x, y)

if should_raise:
with pytest.raises(AttributeError):
fig, ax = plotting.plot_loss(model=model, title=title)
return

fig, ax = plotting.plot_loss(model=model, title=title)

if title is not None:
assert ax.get_title() == title

assert ax.get_xlabel() == "Epochs"
assert ax.get_ylabel() == "Train Loss"

epochs = np.arange(1, len(model.loss_history) + 1)
line_x, line_y = ax.get_lines()[0].get_data()

assert np.allclose(line_x, epochs)
assert np.allclose(line_y, model.loss_history)