Skip to content

Commit 04edb22

Browse files
authored
Merge pull request #138 from lab-cosmo/splinegpu
SplinePotential device compatibility
2 parents 9742132 + 24cd5de commit 04edb22

File tree

4 files changed

+46
-4
lines changed

4 files changed

+46
-4
lines changed

docs/src/references/changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
2727
Fixed
2828
#####
2929

30+
* Fixed consistency of ``dtype`` and ``device`` in the ``SplinePotential`` class
3031
* Fix inconsistent ``cutoff`` in neighbor list example
3132
* All calculators now check if the cell is zero if the potential is range-separated
3233

src/torchpme/potentials/spline.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ def __init__(
7474
if len(y_grid) != len(r_grid):
7575
raise ValueError("Length of radial grid and value array mismatch.")
7676

77+
r_grid = r_grid.to(dtype=dtype, device=device)
78+
y_grid = y_grid.to(dtype=dtype, device=device)
79+
7780
if reciprocal:
7881
if torch.min(r_grid) <= 0.0:
7982
raise ValueError(
@@ -89,6 +92,8 @@ def __init__(
8992
k_grid = torch.pi * 2 * torch.reciprocal(r_grid).flip(dims=[0])
9093
else:
9194
k_grid = r_grid.clone()
95+
else:
96+
k_grid = k_grid.to(dtype=dtype, device=device)
9297

9398
if yhat_grid is None:
9499
# computes automatically!
@@ -98,6 +103,8 @@ def __init__(
98103
y_grid,
99104
compute_second_derivatives(r_grid, y_grid),
100105
)
106+
else:
107+
yhat_grid = yhat_grid.to(dtype=dtype, device=device)
101108

102109
# the function is defined for k**2, so we define the grid accordingly
103110
if reciprocal:
@@ -108,12 +115,14 @@ def __init__(
108115
self._krn_spline = CubicSpline(k_grid**2, yhat_grid)
109116

110117
if y_at_zero is None:
111-
self._y_at_zero = self._spline(torch.tensor([0.0]))
118+
self._y_at_zero = self._spline(torch.zeros(1, dtype=dtype, device=device))
112119
else:
113120
self._y_at_zero = y_at_zero
114121

115122
if yhat_at_zero is None:
116-
self._yhat_at_zero = self._krn_spline(torch.tensor([0.0]))
123+
self._yhat_at_zero = self._krn_spline(
124+
torch.zeros(1, dtype=dtype, device=device)
125+
)
117126
else:
118127
self._yhat_at_zero = yhat_at_zero
119128

@@ -140,7 +149,7 @@ def self_contribution(self) -> torch.Tensor:
140149
return self._y_at_zero
141150

142151
def background_correction(self) -> torch.Tensor:
143-
return torch.tensor([0.0])
152+
return torch.zeros(1)
144153

145154
from_dist.__doc__ = Potential.from_dist.__doc__
146155
lr_from_dist.__doc__ = Potential.lr_from_dist.__doc__

src/torchpme/utils/splines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def compute_second_derivatives(
198198
d2y = _solve_tridiagonal(a, b, c, d)
199199

200200
# Converts back to the original dtype
201-
return d2y.to(x_points.dtype)
201+
return d2y.to(dtype=x_points.dtype, device=x_points.device)
202202

203203

204204
def compute_spline_ft(

tests/test_potentials.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,3 +573,35 @@ def test_combined_potential_learnable_weights():
573573
loss.backward()
574574
optimizer.step()
575575
assert torch.allclose(combined.weights, weights - 0.1)
576+
577+
578+
@pytest.mark.parametrize("device", ["cpu", "cuda"])
579+
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
580+
@pytest.mark.parametrize(
581+
"potential_class", [CoulombPotential, InversePowerLawPotential, SplinePotential]
582+
)
583+
def test_potential_device_dtype(potential_class, device, dtype):
584+
if device == "cuda" and not torch.cuda.is_available():
585+
pytest.skip("CUDA is not available")
586+
587+
smearing = 1.0
588+
exponent = 1.0
589+
590+
if potential_class is InversePowerLawPotential:
591+
potential = potential_class(
592+
exponent=exponent, smearing=smearing, dtype=dtype, device=device
593+
)
594+
elif potential_class is SplinePotential:
595+
x_grid = torch.linspace(0, 20, 100, device=device, dtype=dtype)
596+
y_grid = torch.exp(-(x_grid**2) * 0.5)
597+
potential = potential_class(
598+
r_grid=x_grid, y_grid=y_grid, reciprocal=False, dtype=dtype, device=device
599+
)
600+
else:
601+
potential = potential_class(smearing=smearing, dtype=dtype, device=device)
602+
603+
dists = torch.linspace(0.1, 10.0, 100, device=device, dtype=dtype)
604+
potential_lr = potential.lr_from_dist(dists)
605+
606+
assert potential_lr.device.type == device
607+
assert potential_lr.dtype == dtype

0 commit comments

Comments
 (0)