Skip to content

Commit e5b1ce8

Browse files
authored
Add Explicit dtype and device Support for Calculators and Ensure Compatibility with Potentials (#143)
* Refactor parameter handling in calculators and potentials for improved dtype and device management * Updated docstrings and changelog, added an assertion to check for an instance of the potential, and resolved the TorchScript Potential/Calculator incompatibility. * Update changelog and add test for potential and calculator compatibility
1 parent 04edb22 commit e5b1ce8

File tree

16 files changed

+140
-79
lines changed

16 files changed

+140
-79
lines changed

docs/extensions/versions_list.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def run(self):
4343
:margin: 0 0 0 0\n"""
4444

4545
for group_i, (version_short, group) in enumerate(grouped_versions.items()):
46-
4746
if group_i < 3:
4847
generated_content += f"""
4948
.. grid-item::

docs/src/references/changelog.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,16 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
2424
`Unreleased <https://github.com/lab-cosmo/torch-pme/>`_
2525
-------------------------------------------------------
2626

27+
Added
28+
#####
29+
30+
* Added ``dtype`` and ``device`` for ``Calculator`` classses
31+
2732
Fixed
2833
#####
2934

35+
* Ensured consistency of ``dtype`` and ``device`` in the ``Potential`` and
36+
``Calculator`` classses
3037
* Fixed consistency of ``dtype`` and ``device`` in the ``SplinePotential`` class
3138
* Fix inconsistent ``cutoff`` in neighbor list example
3239
* All calculators now check if the cell is zero if the potential is range-separated

examples/5-autograd-demo.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -463,16 +463,16 @@ def forward(self, positions, cell, charges):
463463

464464
print(
465465
f"""
466-
Delta-Value: {value-jit_value}
466+
Delta-Value: {value - jit_value}
467467
468468
Delta-Position gradients:
469-
{positions.grad.T-jit_positions.grad.T}
469+
{positions.grad.T - jit_positions.grad.T}
470470
471471
Delta-Cell gradients:
472-
{cell.grad-jit_cell.grad}
472+
{cell.grad - jit_cell.grad}
473473
474474
Delta-Charges gradients:
475-
{charges.grad.T-jit_charges.grad.T}
475+
{charges.grad.T - jit_charges.grad.T}
476476
"""
477477
)
478478

src/torchpme/calculators/calculator.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,37 @@ class Calculator(torch.nn.Module):
2626
will come from a full (True) or half (False, default) neighbor list.
2727
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
2828
common values.
29+
:param dtype: type used for the internal buffers and parameters
30+
:param device: device used for the internal buffers and parameters
2931
"""
3032

3133
def __init__(
3234
self,
3335
potential: Potential,
3436
full_neighbor_list: bool = False,
3537
prefactor: float = 1.0,
38+
dtype: Optional[torch.dtype] = None,
39+
device: Optional[torch.device] = None,
3640
):
3741
super().__init__()
38-
# TorchScript requires to initialize all attributes in __init__
39-
self._device = torch.device("cpu")
40-
self._dtype = torch.float32
4142

43+
assert isinstance(potential, Potential), (
44+
f"Potential must be an instance of Potential, got {type(potential)}"
45+
)
46+
47+
self.device = "cpu" if device is None else device
48+
self.dtype = torch.get_default_dtype() if dtype is None else dtype
4249
self.potential = potential
4350

51+
assert self.dtype == self.potential.dtype, (
52+
f"Potential and Calculator must have the same dtype, got {self.dtype} and "
53+
f"{self.potential.dtype}"
54+
)
55+
assert self.device == self.potential.device, (
56+
f"Potential and Calculator must have the same device, got {self.device} and "
57+
f"{self.potential.device}"
58+
)
59+
4460
self.full_neighbor_list = full_neighbor_list
4561

4662
self.prefactor = prefactor

src/torchpme/calculators/ewald.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import torch
24

35
from ..lib import generate_kvectors_for_ewald
@@ -53,6 +55,8 @@ class EwaldCalculator(Calculator):
5355
:obj:`False`, a "half" neighbor list is expected.
5456
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
5557
common values.
58+
:param dtype: type used for the internal buffers and parameters
59+
:param device: device used for the internal buffers and parameters
5660
"""
5761

5862
def __init__(
@@ -61,11 +65,15 @@ def __init__(
6165
lr_wavelength: float,
6266
full_neighbor_list: bool = False,
6367
prefactor: float = 1.0,
68+
dtype: Optional[torch.dtype] = None,
69+
device: Optional[torch.device] = None,
6470
):
6571
super().__init__(
6672
potential=potential,
6773
full_neighbor_list=full_neighbor_list,
6874
prefactor=prefactor,
75+
dtype=dtype,
76+
device=device,
6977
)
7078
if potential.smearing is None:
7179
raise ValueError(

src/torchpme/calculators/p3m.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import torch
24

35
from ..lib.kspace_filter import P3MKSpaceFilter
@@ -40,6 +42,8 @@ class P3MCalculator(PMECalculator):
4042
set to :py:obj:`False`, a "half" neighbor list is expected.
4143
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
4244
common values.
45+
:param dtype: type used for the internal buffers and parameters
46+
:param device: device used for the internal buffers and parameters
4347
4448
For an **example** on the usage for any calculator refer to :ref:`userdoc-how-to`.
4549
"""
@@ -51,6 +55,8 @@ def __init__(
5155
interpolation_nodes: int = 4,
5256
full_neighbor_list: bool = False,
5357
prefactor: float = 1.0,
58+
dtype: Optional[torch.dtype] = None,
59+
device: Optional[torch.device] = None,
5460
):
5561
self.mesh_spacing: float = mesh_spacing
5662

@@ -62,6 +68,8 @@ def __init__(
6268
potential=potential,
6369
full_neighbor_list=full_neighbor_list,
6470
prefactor=prefactor,
71+
dtype=dtype,
72+
device=device,
6573
)
6674

6775
self.kspace_filter: P3MKSpaceFilter = P3MKSpaceFilter(

src/torchpme/calculators/pme.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import torch
24
from torch import profiler
35

@@ -45,6 +47,8 @@ class PMECalculator(Calculator):
4547
set to :obj:`False`, a "half" neighbor list is expected.
4648
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
4749
common values.
50+
:param dtype: type used for the internal buffers and parameters
51+
:param device: device used for the internal buffers and parameters
4852
"""
4953

5054
def __init__(
@@ -54,11 +58,15 @@ def __init__(
5458
interpolation_nodes: int = 4,
5559
full_neighbor_list: bool = False,
5660
prefactor: float = 1.0,
61+
dtype: Optional[torch.dtype] = None,
62+
device: Optional[torch.device] = None,
5763
):
5864
super().__init__(
5965
potential=potential,
6066
full_neighbor_list=full_neighbor_list,
6167
prefactor=prefactor,
68+
dtype=dtype,
69+
device=device,
6270
)
6371

6472
if potential.smearing is None:
@@ -69,8 +77,8 @@ def __init__(
6977
self.mesh_spacing: float = mesh_spacing
7078

7179
self.kspace_filter: KSpaceFilter = KSpaceFilter(
72-
cell=torch.eye(3),
73-
ns_mesh=torch.ones(3, dtype=int),
80+
cell=torch.eye(3, dtype=self.dtype, device=self.device),
81+
ns_mesh=torch.ones(3, dtype=int, device=self.device),
7482
kernel=self.potential,
7583
fft_norm="backward",
7684
ifft_norm="forward",
@@ -79,8 +87,8 @@ def __init__(
7987
self.interpolation_nodes: int = interpolation_nodes
8088

8189
self.mesh_interpolator: MeshInterpolator = MeshInterpolator(
82-
cell=torch.eye(3),
83-
ns_mesh=torch.ones(3, dtype=int),
90+
cell=torch.eye(3, dtype=self.dtype, device=self.device),
91+
ns_mesh=torch.ones(3, dtype=int, device=self.device),
8492
interpolation_nodes=self.interpolation_nodes,
8593
method="Lagrange", # convention for classic PME
8694
)

src/torchpme/lib/mesh_interpolator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,8 +432,7 @@ def mesh_to_points(self, mesh_vals: torch.Tensor) -> torch.Tensor:
432432
"""
433433
if mesh_vals.dim() != 4:
434434
raise ValueError(
435-
f"`mesh_vals` of dimension {mesh_vals.dim()} has to be of "
436-
"dimension 4"
435+
f"`mesh_vals` of dimension {mesh_vals.dim()} has to be of dimension 4"
437436
)
438437

439438
return (

src/torchpme/potentials/combined.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,7 @@ def __init__(
4747
dtype=dtype,
4848
device=device,
4949
)
50-
if dtype is None:
51-
dtype = torch.get_default_dtype()
52-
if device is None:
53-
device = torch.device("cpu")
50+
5451
smearings = [pot.smearing for pot in potentials]
5552
if not all(smearings) and any(smearings):
5653
raise ValueError(
@@ -76,7 +73,9 @@ def __init__(
7673
"The number of initial weights must match the number of potentials being combined"
7774
)
7875
else:
79-
initial_weights = torch.ones(len(potentials), dtype=dtype, device=device)
76+
initial_weights = torch.ones(
77+
len(potentials), dtype=self.dtype, device=self.device
78+
)
8079
# for torchscript
8180
self.potentials = torch.nn.ModuleList(potentials)
8281
if learnable_weights:

src/torchpme/potentials/coulomb.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,17 @@ def __init__(
3838
device: Optional[torch.device] = None,
3939
):
4040
super().__init__(smearing, exclusion_radius, dtype, device)
41-
if dtype is None:
42-
dtype = torch.get_default_dtype()
43-
if device is None:
44-
device = torch.device("cpu")
4541

4642
# constants used in the forwward
4743
self.register_buffer(
4844
"_rsqrt2",
49-
torch.rsqrt(torch.tensor(2.0, dtype=dtype, device=device)),
45+
torch.rsqrt(torch.tensor(2.0, dtype=self.dtype, device=self.device)),
5046
)
5147
self.register_buffer(
5248
"_sqrt_2_on_pi",
53-
torch.sqrt(torch.tensor(2.0 / torch.pi, dtype=dtype, device=device)),
49+
torch.sqrt(
50+
torch.tensor(2.0 / torch.pi, dtype=self.dtype, device=self.device)
51+
),
5452
)
5553

5654
def from_dist(self, dist: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)