Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions pina/_src/solver/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,9 @@ def _compute_condition_loss(self, condition, data, batch_idx):
data = dict(data)
data["input"] = data["input"].clone()

# Prepare condition data, e.g. by enabling gradient for regularizations
data = self._prepare_condition_data(data=data)

# Compute and store the residual tensor for the condition
self.residual_tensor = condition.evaluate(data, self)

Expand All @@ -296,11 +299,55 @@ def _compute_condition_loss(self, condition, data, batch_idx):
# Compute the tensor loss from the residual tensor
condition_tensor_loss = self._loss_from_residual(condition_name)

# Optional regularization hook, e.g gradient-enhanced or residual-based
condition_tensor_loss = self._regularize_condition_loss(
condition_tensor_loss=condition_tensor_loss,
condition_name=condition_name,
data=data,
batch_idx=batch_idx,
)

# Compute the scalar loss from the tensor loss and return it
condition_scalar_loss = self._apply_reduction(condition_tensor_loss)

return condition_scalar_loss

def _prepare_condition_data(self, data):
"""
Prepare the condition data for loss computation. This method can be
overridden by mixins to implement specific data preparation steps, such
as enabling gradient tracking for inputs in gradient-enhanced solvers.

:param dict data: The original condition data.
:return: The prepared condition data.
:rtype: dict
"""
return data

def _regularize_condition_loss(
self,
condition_tensor_loss,
condition_name,
data,
batch_idx,
):
"""
Regularize the condition loss if needed. This method can be overridden
by mixins to implement specific regularization strategies, such as
adding a gradient penalty in gradient-enhanced solvers or applying
residual-based attention.

:param condition_tensor_loss: The original tensor loss for the
condition.
:type condition_tensor_loss: torch.Tensor | LabelTensor
:param str condition_name: The name of the condition.
:param dict data: The data corresponding to the condition.
:param int batch_idx: The index of the current batch.
:return: The regularized tensor loss for the condition.
:rtype: torch.Tensor | LabelTensor
"""
return condition_tensor_loss

def _loss_from_residual(self, condition_name=None):
"""
Compute the tensor loss from the residual tensor.
Expand Down
11 changes: 11 additions & 0 deletions pina/_src/solver/causal_physics_informed_single_model_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ def _compute_condition_loss(self, condition, data, batch_idx):
data = dict(data)
data["input"] = data["input"].clone()

# Prepare condition data, e.g. by enabling gradient for regularizations
data = self._prepare_condition_data(data=data)

# Extract the temporal domain
time_domain = self.problem.temporal_domain

Expand Down Expand Up @@ -251,6 +254,14 @@ def _compute_condition_loss(self, condition, data, batch_idx):
# Compute the tensor loss from the residual tensor
condition_tensor_loss = self._loss_from_residual(condition_name)

# Optional regularization hook
condition_tensor_loss = self._regularize_condition_loss(
condition_tensor_loss=condition_tensor_loss,
condition_name=condition_name,
data=data,
batch_idx=batch_idx,
)

# Append the loss for the current time step to the list
time_loss.append(condition_tensor_loss)

Expand Down
11 changes: 11 additions & 0 deletions pina/_src/solver/competitive_physics_informed_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ def _compute_condition_loss(self, condition, data, batch_idx):
data = dict(data)
data["input"] = data["input"].clone()

# Prepare condition data, e.g. by enabling gradient for regularizations
data = self._prepare_condition_data(data=data)

# Compute and store the residual tensor for the condition
self.residual_tensor = condition.evaluate(data, self)

Expand All @@ -229,6 +232,14 @@ def _compute_condition_loss(self, condition, data, batch_idx):
# Compute the tensor loss from the residual tensor
condition_tensor_loss = self._loss_from_residual(condition_name)

# Optional regularization hook, e.g gradient-enhanced or residual-based
condition_tensor_loss = self._regularize_condition_loss(
condition_tensor_loss=condition_tensor_loss,
condition_name=condition_name,
data=data,
batch_idx=batch_idx,
)

# Compute the scalar loss from the tensor loss and return it
condition_scalar_loss = self._apply_reduction(condition_tensor_loss)

Expand Down
58 changes: 55 additions & 3 deletions pina/_src/solver/mixin/ensemble_mixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module for the ensemble mixin class."""

import torch
from pina._src.solver.base_solver import BaseSolver
from pina._src.solver.mixin.multi_model_mixin import MultiModelMixin


Expand All @@ -16,14 +17,65 @@ class EnsembleMixin(MultiModelMixin):

def forward(self, x):
    """
    Forward pass for ensemble solvers. If an active model index is set, only
    that model is evaluated. Otherwise, all models are evaluated and their
    outputs are stacked together.

    :param x: The input data.
    :type x: torch.Tensor | LabelTensor | Data | Graph
    :return: The output of the active model if an active model index is
        set, otherwise the outputs of all models stacked along a new
        leading dimension.
    :rtype: torch.Tensor | LabelTensor | Data | Graph
    """
    # Retrieve the index of the active model; the attribute may never have
    # been assigned, hence the getattr with a None default
    active_idx = getattr(self, "_active_model_idx", None)

    # If an active model index is set, evaluate only that model
    if active_idx is not None:
        return self.models[active_idx](x)

    # Otherwise, evaluate all models and stack outputs (no averaging)
    return torch.stack(
        [self.models[idx](x) for idx in range(self.num_models)]
    )

def _compute_condition_loss(self, condition, data, batch_idx):
    """
    Compute the ensemble-averaged scalar loss for a condition.

    Every model index in ``range(self.num_models)`` is made active in turn
    through ``self._active_model_idx``, the scalar loss is computed with
    ``BaseSolver._compute_condition_loss``, and the mean of the per-model
    losses is returned. The value ``_active_model_idx`` had on entry is put
    back afterwards, also when a loss computation raises.

    :param BaseCondition condition: The condition for which to compute the
        loss.
    :param dict data: The data corresponding to the condition.
    :param int batch_idx: The index of the current batch.
    :return: The scalar loss for the condition, averaged over the models.
    :rtype: torch.Tensor
    """
    # Save the active model index found on entry (None if it was never set)
    saved_idx = getattr(self, "_active_model_idx", None)

    # One scalar loss per ensemble model
    per_model_losses = []

    try:
        for idx in range(self.num_models):
            # Make the current model the active one for this evaluation
            self._active_model_idx = idx

            # Delegate the actual loss computation to the base solver
            per_model_losses.append(
                BaseSolver._compute_condition_loss(
                    self, condition, data, batch_idx
                )
            )
    finally:
        # Put the saved index back, whether or not an exception occurred
        self._active_model_idx = saved_idx

    # Average the per-model scalar losses
    return torch.stack(per_model_losses).mean()
Comment thread
GiovanniCanali marked this conversation as resolved.
64 changes: 36 additions & 28 deletions pina/_src/solver/mixin/gradient_enhanced_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,41 +81,52 @@ def _init_gradient_enhanced_components(
self.regularization_weight = regularization_weight
self.regularized_conditions = regularized_conditions

def _compute_condition_loss(self, condition, data, batch_idx):
def _prepare_condition_data(self, data):
"""
Compute the scalar loss for a given condition and its data.
Prepare the condition data for loss computation. This method can be
overridden by mixins to implement specific data preparation steps, such
as enabling gradient tracking for inputs in gradient-enhanced solvers.

:param BaseCondition condition: The condition for which to compute the
loss.
:param dict data: The data corresponding to the condition.
:param int batch_idx: The index of the current batch.
:return: The scalar loss for the condition.
:rtype: torch.Tensor
:param dict data: The original condition data.
:return: The prepared condition data.
:rtype: dict
"""
# Clone the input tensor if it exists to avoid in-place modifications
if "input" in data and hasattr(data["input"], "clone"):
data = dict(data)
data["input"] = data["input"].clone()

# If data does not require grad, force requires_grad to True
if "input" in data and not data["input"].requires_grad:
data["input"].requires_grad_(True)

# Compute and store the residual tensor for the condition
self.residual_tensor = condition.evaluate(data, self)
self.residual_tensor.labels = [
f"res_{i}" for i in range(self.residual_tensor.shape[1])
]

# Retrieve condition name for more complex weighting schemes
condition_name = condition.name if hasattr(condition, "name") else None

# Compute the tensor loss from the residual tensor
condition_tensor_loss = self._loss_from_residual(condition_name)
return data

def _regularize_condition_loss(
self,
condition_tensor_loss,
condition_name,
data,
batch_idx,
):
"""
Regularize the condition loss if needed. This method can be overridden
by mixins to implement specific regularization strategies, such as
adding a gradient penalty in gradient-enhanced solvers or applying
residual-based attention.

:param condition_tensor_loss: The original tensor loss for the
condition.
:type condition_tensor_loss: torch.Tensor | LabelTensor
:param str condition_name: The name of the condition.
:param dict data: The data corresponding to the condition.
:param int batch_idx: The index of the current batch.
:return: The regularized tensor loss for the condition.
:rtype: torch.Tensor | LabelTensor
"""
# Regularize the loss with the gradient penalty if needed
if condition_name in self.regularized_conditions:

# Apply labels to the residual tensor for gradient computation
self.residual_tensor.labels = [
f"res_{i}" for i in range(self.residual_tensor.shape[1])
]

# Compute the gradient of the residual with respect to spatial input
residual_gradient = grad(
output_=self.residual_tensor,
Expand All @@ -134,7 +145,4 @@ def _compute_condition_loss(self, condition, data, batch_idx):
# Add the gradient penalty to the original condition tensor loss
condition_tensor_loss = condition_tensor_loss + penalty

# Compute the scalar loss from the tensor loss and return it
condition_scalar_loss = self._apply_reduction(condition_tensor_loss)

return condition_scalar_loss
return condition_tensor_loss
44 changes: 19 additions & 25 deletions pina/_src/solver/mixin/residual_based_attention_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,31 +94,28 @@ def _init_residual_attention_components(
self.register_buffer(f"weight_{cond}", torch.zeros((n_pts, 1)))
self.weight_buffers[cond] = f"weight_{cond}"

def _compute_condition_loss(self, condition, data, batch_idx):
def _regularize_condition_loss(
self,
condition_tensor_loss,
condition_name,
data,
batch_idx,
):
"""
Compute the scalar loss for a given condition and its data.

:param BaseCondition condition: The condition for which to compute the
loss.
Regularize the condition loss if needed. This method can be overridden
by mixins to implement specific regularization strategies, such as
adding a gradient penalty in gradient-enhanced solvers or applying
residual-based attention.

:param condition_tensor_loss: The original tensor loss for the
condition.
:type condition_tensor_loss: torch.Tensor | LabelTensor
:param str condition_name: The name of the condition.
:param dict data: The data corresponding to the condition.
:param int batch_idx: The index of the current batch.
:return: The scalar loss for the condition.
:rtype: torch.Tensor
:return: The regularized tensor loss for the condition.
:rtype: torch.Tensor | LabelTensor
"""
# Clone the input tensor if it exists to avoid in-place modifications
if "input" in data and hasattr(data["input"], "clone"):
data = dict(data)
data["input"] = data["input"].clone()

# Compute and store the residual tensor for the condition
self.residual_tensor = condition.evaluate(data, self)

# Retrieve condition name for more complex weighting schemes
condition_name = condition.name

# Compute the tensor loss from the residual tensor
condition_tensor_loss = self._loss_from_residual(condition_name)

# Apply residual-based attention mechanism if needed
if condition_name in self.regularized_conditions:

Expand Down Expand Up @@ -150,7 +147,4 @@ def _compute_condition_loss(self, condition, data, batch_idx):
# Weight the condition tensor loss with attention weights
condition_tensor_loss = condition_tensor_loss * weights[idx]

# Compute the scalar loss from the tensor loss and return it
condition_scalar_loss = self._apply_reduction(condition_tensor_loss)

return condition_scalar_loss
return condition_tensor_loss
11 changes: 11 additions & 0 deletions pina/_src/solver/self_adaptive_physics_informed_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ def _compute_condition_loss(self, condition, data, batch_idx):
data = dict(data)
data["input"] = data["input"].clone()

# Prepare condition data, e.g. by enabling gradient for regularizations
data = self._prepare_condition_data(data=data)

# Compute and store the residual tensor for the condition
self.residual_tensor = condition.evaluate(data, self)

Expand All @@ -253,6 +256,14 @@ def _compute_condition_loss(self, condition, data, batch_idx):
# Compute the tensor loss from the residual tensor
condition_tensor_loss = self._loss_from_residual(condition_name)

# Optional regularization hook, e.g gradient-enhanced or residual-based
condition_tensor_loss = self._regularize_condition_loss(
condition_tensor_loss=condition_tensor_loss,
condition_name=condition_name,
data=data,
batch_idx=batch_idx,
)

# Get the correct indices to retrieve the weights for the current batch
len_residuals = self.residual_tensor.shape[0]

Expand Down
7 changes: 6 additions & 1 deletion tests/test_solver/test_autoregressive_ensemble_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,12 @@ def test_train_load_restore(clean_tmp_dir, use_lt):
)

# Assert that the predictions from the loaded solver match original ones
assert new_solver.forward(test_pts).shape == (n_traj, t_steps, n_feats)
assert new_solver.forward(test_pts).shape == (
n_models,
n_traj,
t_steps,
n_feats,
)
assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape
torch.testing.assert_close(
new_solver.forward(test_pts), solver.forward(test_pts)
Expand Down
Loading
Loading