 from typing import Optional

 import torch
-from torch.special import gammainc, gammaincc, gammaln
+from torch.special import gammainc

-from .potential import Potential
-
-
-def gamma(x: torch.Tensor) -> torch.Tensor:
-    """
-    (Complete) Gamma function.
+from torchpme.lib import gamma, gammaincc_over_powerlaw

-    pytorch has not implemented the commonly used (complete) Gamma function. We define
-    it in a custom way to make autograd work as in
-    https://discuss.pytorch.org/t/is-there-a-gamma-function-in-pytorch/17122
-    """
-    return torch.exp(gammaln(x))
+from .potential import Potential


 class InversePowerLawPotential(Potential):
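For reference, the deleted module-level helper indicates what the newly imported `gamma` evaluates. A minimal sketch, assuming `torchpme.lib.gamma` keeps the same autograd-friendly definition:

```python
import torch
from torch.special import gammaln


def gamma(x: torch.Tensor) -> torch.Tensor:
    # complete Gamma function as exp(lgamma(x)), kept differentiable through gammaln
    return torch.exp(gammaln(x))
```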
@@ -46,16 +37,16 @@ class InversePowerLawPotential(Potential):

     def __init__(
         self,
-        exponent: float,
+        exponent: int,
         smearing: Optional[float] = None,
         exclusion_radius: Optional[float] = None,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
     ):
         super().__init__(smearing, exclusion_radius, dtype, device)

-        if exponent <= 0 or exponent > 3:
-            raise ValueError(f"`exponent` p={exponent} has to satisfy 0 < p <= 3")
+        # function call to check the validity of the exponent
+        gammaincc_over_powerlaw(exponent, torch.tensor(1.0, dtype=dtype, device=device))
         self.register_buffer(
             "exponent", torch.tensor(exponent, dtype=self.dtype, device=self.device)
         )
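With the explicit range check gone, an unsupported exponent is expected to fail inside `gammaincc_over_powerlaw` when the potential is built rather than when it is evaluated. A hypothetical usage sketch; the top-level import path and the choice of `exponent=2` as a supported value are assumptions:

```python
from torchpme import InversePowerLawPotential  # top-level import path assumed

# integer exponent, e.g. p = 2 for a 1/r^2 potential; an unsupported value
# should now raise inside gammaincc_over_powerlaw during construction
pot = InversePowerLawPotential(exponent=2, smearing=1.0)
print(pot.exponent)  # stored as a buffer, so it follows .to(dtype=..., device=...)
```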
@@ -130,9 +121,7 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
         # for consistency reasons.
         masked = torch.where(x == 0, 1.0, x)  # avoid NaNs in backwards, see Coulomb
         return torch.where(
-            k_sq == 0,
-            0.0,
-            prefac * gammaincc(peff, masked) / masked**peff * gamma(peff),
+            k_sq == 0, 0.0, prefac * gammaincc_over_powerlaw(exponent, masked)
         )

     def self_contribution(self) -> torch.Tensor:
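Read off from the expression it replaces, `gammaincc_over_powerlaw(exponent, x)` corresponds to `gamma(peff) * gammaincc(peff, x) / x**peff`. A minimal sketch of that reading; `peff = (3 - exponent) / 2` is an assumption taken from the surrounding `lr_from_k_sq` code, which this hunk does not show, and the actual `torchpme.lib` version presumably special-cases the supported integer exponents (this naive composition only makes sense for `peff > 0`, i.e. `exponent < 3`):

```python
import torch
from torch.special import gammaincc, gammaln


def gammaincc_over_powerlaw(exponent: int, x: torch.Tensor) -> torch.Tensor:
    # gamma(peff) * gammaincc(peff, x) / x**peff, matching the replaced call site;
    # peff = (3 - exponent) / 2 is assumed from lr_from_k_sq
    peff = torch.tensor((3 - exponent) / 2, dtype=x.dtype, device=x.device)
    return torch.exp(gammaln(peff)) * gammaincc(peff, x) / x**peff
```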
@@ -145,7 +134,11 @@ def self_contribution(self) -> torch.Tensor:
         return 1 / gamma(phalf + 1) / (2 * self.smearing**2) ** phalf

     def background_correction(self) -> torch.Tensor:
-        # "charge neutrality" correction for 1/r^p potential
+        # the "charge neutrality" correction for the 1/r^p potential diverges for
+        # p = 3 and is not needed for p > 3, so set it to zero in both cases (see
+        # the SI of https://doi.org/10.48550/arXiv.2412.03281)
+        if self.exponent >= 3:
+            return torch.tensor(0.0, dtype=self.dtype, device=self.device)
         if self.smearing is None:
             raise ValueError(
                 "Cannot compute background correction without specifying `smearing`."
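A hypothetical check of the new early return, assuming `p = 3` is among the exponents accepted by the constructor and the same import path as above:

```python
from torchpme import InversePowerLawPotential  # top-level import path assumed

pot = InversePowerLawPotential(exponent=3, smearing=1.0)
# p >= 3 hits the early return added above and yields exactly zero
assert pot.background_correction().item() == 0.0
```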