Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 15 additions & 46 deletions pina/_src/core/trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Trainer utilities built on top of the PyTorch Lightning Trainer class."""

import sys
import warnings
import torch
import lightning
Expand All @@ -13,7 +12,7 @@
check_positive_integer,
)

# set the warning for compile options
# Set custom warning format and filter warnings
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=UserWarning)

Expand All @@ -23,8 +22,8 @@ class Trainer(lightning.pytorch.Trainer):
PINA-specific extension of :class:`lightning.pytorch.Trainer`.

The trainer configures solver execution, dataset splitting, batching,
logging, compilation support, device placement for unknown parameters, and
gradient tracking requirements for physics-informed solvers.
logging, device placement for unknown parameters, and gradient tracking
requirements for physics-informed solvers.
"""

# Available batching modes
Expand All @@ -41,7 +40,6 @@ def __init__(
train_size=1.0,
test_size=0.0,
val_size=0.0,
compile=False,
batching_mode="common_batch_size",
automatic_batching=False,
num_workers=0,
Expand All @@ -64,9 +62,6 @@ def __init__(
Default is ``0.0``.
:param float test_size: The fraction of samples assigned to the test
split. Must belong to the interval ``[0, 1]``. Default is ``0.0``.
:param bool compile: Whether to compile the model before training.
Compilation is disabled on Windows and with Python 3.14 or later.
Default is ``False``.
:param str batching_mode: The strategy used to aggregate batches across
dataloaders. Available options are ``"common_batch_size"`` for
uniform batch sizes across conditions, ``"proportional"`` for batch
Expand All @@ -91,26 +86,33 @@ def __init__(
not a float in the interval ``[0, 1]``.
:raises ValueError: If the sum of ``train_size``, ``val_size``, and
``test_size`` is not equal to 1.
:raises ValueError: If ``compile``, ``automatic_batching``,
``pin_memory``, or ``shuffle`` is not a boolean.
:raises ValueError: If ``automatic_batching``, ``pin_memory``, or
``shuffle`` is not a boolean.
:raises AssertionError: If ``num_workers`` is a negative integer.
:raises ValueError: If ``batch_size``, when provided, is not a positive
integer.
:raises ValueError: If ``batching_mode`` is not one of the available
options.
:raises UserWarning: If compilation is requested on an unsupported
platform or Python version.
:raises UserWarning: If the provided ``batching_mode`` is incompatible
with the ``batch_size``.
:raises RuntimeError: If any domain in the problem has not been
discretised.
"""
# Backward compatibility: compile has been removed
if "compile" in kwargs:
warnings.warn(
"`compile` is deprecated and no longer used. Compilation is "
"now disabled and the argument will be ignored.",
DeprecationWarning,
stacklevel=2,
)
kwargs.pop("compile")

# Check consistency
check_consistency(solver, BaseSolver)
check_consistency(train_size, float)
check_consistency(test_size, float)
check_consistency(val_size, float)
check_consistency(compile, bool)
check_consistency(automatic_batching, bool)
check_consistency(pin_memory, bool)
check_consistency(shuffle, bool)
Expand Down Expand Up @@ -147,19 +149,6 @@ def __init__(
# Initialize the parent class with the provided keyword arguments
super().__init__(**kwargs)

# Disable compilation for Windows and Python 3.14+
if sys.platform == "win32" or sys.version_info >= (3, 14) and compile:

# Raise a warning if compilation is requested but not supported
warnings.warn(
"Model compilation is not supported on Windows or with Python "
"3.14+. Compilation has been disabled.",
UserWarning,
)

# Set compile to False if not supported
compile = False

# Raise warning if batch size and batching mode are incompatible
if batch_size is None and batching_mode != "common_batch_size":
warnings.warn(
Expand Down Expand Up @@ -189,7 +178,6 @@ def __init__(

# Initialize the class attributes
self.solver = solver
self.compile = compile
self.batch_size = batch_size

# Move the unknown parameters to the correct device
Expand Down Expand Up @@ -299,22 +287,3 @@ def solver(self, solver):
:param BaseSolver solver: The solver instance to attach.
"""
self._solver = solver

@property
def compile(self):
"""
Return whether model compilation is enabled.

:return: ``True`` if compilation is enabled, otherwise ``False``.
:rtype: bool
"""
return self._compile

@compile.setter
def compile(self, value):
"""
Set the value of compile.

:param bool value: Whether compilation is required or not.
"""
self._compile = value
28 changes: 0 additions & 28 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
@pytest.mark.parametrize("automatic_batching", [True, False])
@pytest.mark.parametrize("pin_memory", [True, False])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("compile", [True, False])
@pytest.mark.parametrize("batch_size", [None, 5])
@pytest.mark.parametrize(
"train_size, test_size, val_size", [(0.8, 0.1, 0.1), (0.7, 0.2, 0.1)]
Expand All @@ -26,7 +25,6 @@ def test_constructor(
train_size,
test_size,
val_size,
compile,
batching_mode,
automatic_batching,
pin_memory,
Expand All @@ -39,7 +37,6 @@ def test_constructor(
train_size=train_size,
test_size=test_size,
val_size=val_size,
compile=compile,
batching_mode=batching_mode if batch_size else "common_batch_size",
automatic_batching=automatic_batching,
num_workers=0,
Expand All @@ -55,7 +52,6 @@ def test_constructor(
train_size=train_size,
test_size=test_size,
val_size=val_size,
compile=compile,
batching_mode=batching_mode if batch_size else "common_batch_size",
automatic_batching=automatic_batching,
num_workers=0,
Expand All @@ -71,23 +67,6 @@ def test_constructor(
train_size=0.5,
test_size=0.3,
val_size=0.3,
compile=compile,
batching_mode=batching_mode if batch_size else "common_batch_size",
automatic_batching=automatic_batching,
num_workers=0,
pin_memory=pin_memory if batch_size else False,
shuffle=shuffle,
)

# Should raise ValueError if compile is not a boolean
with pytest.raises(ValueError):
Trainer(
solver=solver,
batch_size=batch_size,
train_size=train_size,
test_size=test_size,
val_size=val_size,
compile="not_a_boolean",
batching_mode=batching_mode if batch_size else "common_batch_size",
automatic_batching=automatic_batching,
num_workers=0,
Expand All @@ -103,7 +82,6 @@ def test_constructor(
train_size=train_size,
test_size=test_size,
val_size=val_size,
compile=compile,
batching_mode=batching_mode if batch_size else "common_batch_size",
automatic_batching="not_a_boolean",
num_workers=0,
Expand All @@ -119,7 +97,6 @@ def test_constructor(
train_size=train_size,
test_size=test_size,
val_size=val_size,
compile=compile,
batching_mode=batching_mode if batch_size else "common_batch_size",
automatic_batching=automatic_batching,
num_workers=0,
Expand All @@ -135,7 +112,6 @@ def test_constructor(
train_size=train_size,
test_size=test_size,
val_size=val_size,
compile=compile,
batching_mode=batching_mode if batch_size else "common_batch_size",
automatic_batching=automatic_batching,
num_workers=0,
Expand All @@ -151,7 +127,6 @@ def test_constructor(
train_size=train_size,
test_size=test_size,
val_size=val_size,
compile=compile,
batching_mode=batching_mode if batch_size else "common_batch_size",
automatic_batching=automatic_batching,
num_workers=-1,
Expand All @@ -167,7 +142,6 @@ def test_constructor(
train_size=train_size,
test_size=test_size,
val_size=val_size,
compile=compile,
batching_mode=batching_mode if batch_size else "common_batch_size",
automatic_batching=automatic_batching,
num_workers=0,
Expand All @@ -183,7 +157,6 @@ def test_constructor(
train_size=train_size,
test_size=test_size,
val_size=val_size,
compile=compile,
batching_mode="invalid_mode",
automatic_batching=automatic_batching,
num_workers=0,
Expand All @@ -206,7 +179,6 @@ def test_constructor(
train_size=train_size,
test_size=test_size,
val_size=val_size,
compile=compile,
batching_mode=batching_mode if batch_size else "common_batch_size",
automatic_batching=automatic_batching,
num_workers=0,
Expand Down
Loading