Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
165 commits
Select commit Hold shift + click to select a range
e059dd8
Tomo base plotting fixes
Sep 16, 2025
993ba86
Semi-working NeRF implementation; bug where it's not converging corre…
Sep 17, 2025
66502b2
NeRF kinda working; bugs with convergence
Sep 17, 2025
bcb7d1d
TV loss squeeze axis
Sep 17, 2025
0bb881b
NeRF works, barebones. Start writing into quantem style
Sep 18, 2025
a6a7014
Pre-transferring to quantem code
Sep 22, 2025
e162dde
Baseline working Nerf reconstructions; created a TomoDDP class
Sep 23, 2025
bb40c4b
Tomo optimizers + schedulers working; need to do object_models
Sep 23, 2025
c704ce3
Working DDP with objects and stuff
Sep 23, 2025
1960f35
Soft constraints also work
Sep 23, 2025
b0ccd8c
Multi-step training with different schedulers kind of working. Need t…
Sep 23, 2025
f9a8afa
Implemented ObjectINN with create volume in it (create_volume new ver…
Sep 23, 2025
f3b5b86
Tomo-NeRF working fully; just need to clean up a little bit
Sep 23, 2025
9c0685e
Merge pull request #103 from electronmicroscopy/tomography-hpc-2
cedriclim1 Sep 23, 2025
867bb8e
Initial commit for background subtraction
cophus Oct 7, 2025
0d09a39
Updating plots for background subtraction
cophus Oct 7, 2025
f34b9e1
Adding function overloading
cophus Oct 7, 2025
9f9ef88
More conservative fitting order
cophus Oct 7, 2025
da10551
Switching from @overload to TypeVar
cophus Oct 10, 2025
bc46375
Merge pull request #109 from cophus/tomo-background
cedriclim1 Oct 14, 2025
f1dc677
Tomo changes
Oct 14, 2025
0136682
Merge branch 'dev' into tomography
Oct 14, 2025
a57d885
Enforcing positivity optional in tomography_dataset with clamp flag
cedriclim1 Oct 16, 2025
078f351
Added SIREN, HSIREN, and Finer with allowed complex inputs
Oct 14, 2025
d83deb1
Added FinerActivation to get_activation_function
Oct 14, 2025
4227e58
Fixed PtychoLite to call from cnn.py instead of cnn2d (now removed)
Oct 14, 2025
2e70de3
Siren + HSiren need np.sqrt instead of torch.sqrt for .uniform_
Oct 14, 2025
9bbb54e
net_list fix on Siren models
Oct 14, 2025
0aec975
Softplus missing from self.net_list
Oct 14, 2025
35eae44
Print pred on cross_Correlation align stack
Oct 23, 2025
0a99af5
Adding back in validation set
Nov 12, 2025
3971489
Added pretraining functionality, in core ML added custom loss functi…
Nov 21, 2025
69efcfd
Updates
Dec 6, 2025
9937bc5
Test
cedriclim1 Dec 15, 2025
700f205
Outlining dataset_models.py, and the top level Tomography class
cedriclim1 Dec 16, 2025
b90499d
Need to think about tomography_ddp a little bit more, also what shoul…
cedriclim1 Dec 17, 2025
d73c498
SIRT Reconstructions working. TomographyConventional looks a little c…
cedriclim1 Dec 19, 2025
bc98f4b
Implemented tomography_opt.py
cedriclim1 Dec 19, 2025
0fa3740
Starting to write the ML methods for Tomography; Need to figure out h…
cedriclim1 Dec 19, 2025
31bc2eb
Starting DDP stuff
cedriclim1 Jan 20, 2026
acb998e
Pulling from dev
cedriclim1 Jan 20, 2026
70ffb31
Implementing pretraining for object_models, along with added function…
cedriclim1 Jan 20, 2026
0ed758a
Object pretraining INR working
cedriclim1 Jan 20, 2026
c90984c
DDPMixin in ML, pretraining working for objects
cedriclim1 Jan 21, 2026
97c80d6
Added cosine annealing to set_scheduler in OptimizerMixin
cedriclim1 Jan 21, 2026
9900fc3
Reworking TomographyINRDatasets; need to figure out what to do for au…
cedriclim1 Jan 22, 2026
4bda57f
Some device switching bugs that need to be addressed.
cedriclim1 Jan 23, 2026
cc9188f
Working reconstruction loop, need to figure out this device stuff and…
cedriclim1 Jan 23, 2026
d08c334
Various updates
cedriclim1 Jan 26, 2026
9a37e05
Logger implementation
cedriclim1 Jan 27, 2026
cfee59a
DDP bug where some projection idx's don't get optimized.
Jan 27, 2026
370934a
DDP projection indices fixed; added hard constraints to the forward m…
Jan 28, 2026
b80cd68
NVIDIA Profiling testing added in the reconstruction loop in tomograp…
cedriclim1 Jan 28, 2026
a61cce3
Starting profiling of the reconstruction loop; need to move stuff ove…
cedriclim1 Jan 29, 2026
8f6cb02
Val + train test split implemented - cuBLAS error after adding this n…
cedriclim1 Jan 30, 2026
791a644
Small updates
cedriclim1 Jan 30, 2026
4b7f108
Implemented a working TomographyLite, need to test AutoSerialize, and…
cedriclim1 Jan 31, 2026
49b2b26
Added option to only learn parts of the pose i.e, shifts or tilt axis…
cedriclim1 Feb 1, 2026
da6bbf9
Fix for imaging_utils.py
cedriclim1 Feb 1, 2026
0d6d769
DDP fix validation sampler initialization.
Feb 1, 2026
4005d7a
Save volume DDP fix
cedriclim1 Feb 2, 2026
c3afcbc
Merge branch 'dev' into tomography_refactor
cedriclim1 Feb 2, 2026
e1cd567
Updates
cedriclim1 Feb 2, 2026
e9bf8da
Reinstantiating new dataloaders in the training loop; kinda messy, ne…
cedriclim1 Feb 4, 2026
aea1c1b
Fixed some device issues, one process is still hanging on cuda:0, nee…
cedriclim1 Feb 4, 2026
c4c6e8a
Removed profiling stuff
cedriclim1 Feb 23, 2026
162485f
Merged object_models.py from eds_tomography, this allows for multimod…
cedriclim1 Feb 23, 2026
75721e1
Object creation - channels on the last axis.
cedriclim1 Feb 23, 2026
46dda7c
constraints.py; Fixed type hinting and added more description to the …
cedriclim1 Feb 24, 2026
5c617d3
Cleaned up DDP module; offloaded loading of weights to the object_mod…
cedriclim1 Feb 24, 2026
9465aa7
Added deterministic random winner initialization in inr.py
cedriclim1 Feb 24, 2026
05901bf
inr.py, ignore some type hinting stuff that seems not correct...?
cedriclim1 Feb 24, 2026
953ee94
Removed MSE and L1 implementations in
cedriclim1 Feb 24, 2026
bde606f
Background subtraction in imaging_utils.py
cedriclim1 Feb 24, 2026
c0b61a3
Fixed _token calling .from_data in TomographyDatasetBase
cedriclim1 Feb 24, 2026
d42611b
Reconciled params and get_optimization_parameters type-hinting. Addit…
cedriclim1 Feb 24, 2026
aea34b1
Object models also consistent with dataset models in terms of params …
cedriclim1 Feb 24, 2026
038e94d
More type hinting
cedriclim1 Feb 24, 2026
edab112
Remove params function, use nn.Module.parameters() directly in get_op…
cedriclim1 Feb 24, 2026
888e6b9
Updated TomographyPixDataset docstring
cedriclim1 Feb 24, 2026
c0a8777
Removed DatasetValue from the forward call in TomographyINRDataset, t…
cedriclim1 Feb 24, 2026
6d9cbf6
Class instantiation for ObjectPixelated; from_uniform and from_array
cedriclim1 Feb 24, 2026
ddd473d
Pretraining in ObjectINR, implemented reset()
cedriclim1 Feb 24, 2026
d51aa51
.to(device) called again at the reconstruct call just to make sure de…
cedriclim1 Feb 24, 2026
16aa0f2
For TomographyLite classes, .from_dataset now only allows for tilt_se…
cedriclim1 Feb 24, 2026
945064e
Made tomography_utils.py; moved background_subtraction and other util…
cedriclim1 Feb 24, 2026
c2b7d26
Made save_volume more explicit in top-level Tomography.py
cedriclim1 Feb 24, 2026
2b7848d
Logging now logs all channels if there are any additional ones. In th…
cedriclim1 Feb 24, 2026
384ce2f
Added abstract method decorator
cedriclim1 Feb 24, 2026
446a6da
core/ml/loss_functions; Changed everything to modules check if this w…
cedriclim1 Feb 25, 2026
c65b63b
dataset_models.py bug fixes. Distribute model is not working with cud…
cedriclim1 Feb 26, 2026
f7df0a4
Sparsity loss implemented from origin/f_inr_tomography
cedriclim1 Mar 2, 2026
424bd87
:Merge branch 'tomography_refactor' of https://github.com/electronmic…
cedriclim1 Mar 2, 2026
50e21b0
get_loss_module functionality, allowing for kwargs in the loss function.
cedriclim1 Mar 2, 2026
1b4aeeb
Removed KE regularization from DINR paper -- will probably move off t…
cedriclim1 Mar 2, 2026
bfc6e06
Fixed all linting issues in object_models.py. @amccray, ask about ign…
cedriclim1 Mar 2, 2026
886ef22
ddp.py linting errors addressed by ignoring. Again @arthurmccray, sin…
cedriclim1 Mar 2, 2026
471bd6b
Cleaned up tomography_base.py; inheriting from ObjectConstraints to …
cedriclim1 Mar 2, 2026
9336e4f
Cleaned up some tomography.py; lots of linter errors still but meh
cedriclim1 Mar 2, 2026
ce28e29
_token fix for TomographyBase and some other stuff
cedriclim1 Mar 2, 2026
9ecd8d7
Added TomographyDatasetConstraints to dataset_models.py. Perhaps ther…
cedriclim1 Mar 3, 2026
a535083
Reworded in top-level tomography: obj_constraints and dset_constraints
cedriclim1 Mar 3, 2026
2e87529
Added pbar to INR reconstruction with verbose option. Deleted print s…
cedriclim1 Mar 3, 2026
bd9dd80
Deleted some more unnecessary prints in the logger.
cedriclim1 Mar 3, 2026
4ebf7c2
Scientific notation for the outputs in the progress bar
cedriclim1 Mar 3, 2026
f2d2142
Added an obj_view method in objectinr that does the transpose for com…
cedriclim1 Mar 3, 2026
ceea355
DDP errors when saving the dataloader, added a rebuild dataloader met…
cedriclim1 Mar 3, 2026
bf34a1e
Updates from comments from PR from quantem-tutorials
cedriclim1 Mar 3, 2026
9fce382
Starting OptimizerParams and SchedulerParams stuff in optimizer_mixin.py
cedriclim1 Mar 4, 2026
6f89ac0
More scheduler support for optimizer_mixin.py; still allows for dicti…
cedriclim1 Mar 4, 2026
18d186f
Optimizer and Scheduler hinting - tested on tomography tutorials.
cedriclim1 Mar 4, 2026
213988c
Fixes
cedriclim1 Mar 4, 2026
137ccac
Merge branch 'opt_sched_typehinting' into tomography_refactor_optsche…
cedriclim1 Mar 5, 2026
737a7bc
Fixes to TomographyOpt.py
cedriclim1 Mar 5, 2026
87867d2
Merge remote-tracking branch 'upstream/dev' into tomography_refactor_…
cedriclim1 Mar 5, 2026
4fb5862
Bugs in tomography_opt fixed; scheduler and optimizer of pose not bei…
cedriclim1 Mar 5, 2026
04a7d65
Bug fixes in OptimizerMixin
cedriclim1 Mar 5, 2026
d3beb50
Scheduler bug fix as well in parse_dict
cedriclim1 Mar 5, 2026
b40f946
Optimizer mixin hotfix fastforward
cedriclim1 Mar 5, 2026
aacd82d
Added new tests as well
cedriclim1 Mar 5, 2026
1fb2ca9
Added DatasetConstraintParams, works with dictionary inputs still
cedriclim1 Mar 5, 2026
2967c9d
Merge branch 'dev' into tomography_refactor
cedriclim1 Mar 5, 2026
f5d6351
OptimizerMixin fix tests
cedriclim1 Mar 5, 2026
c1a7e23
Docstrings for dataset_models and object_models
cedriclim1 Mar 5, 2026
2c91064
Readded dataset models type back in
cedriclim1 Mar 5, 2026
882d9e9
Bug
cedriclim1 Mar 5, 2026
64ee01e
NumPy style docstrings
cedriclim1 Mar 5, 2026
47f31ea
Some typos in the docstrings
cedriclim1 Mar 5, 2026
586f3bc
Added some __str__ for the OptimizerParams; not sure if needed but be…
cedriclim1 Mar 5, 2026
922dc12
Verbose, if false will just start printing
cedriclim1 Mar 10, 2026
8472dfb
Added relaxation term to SIRT reconstruction, stable gradients
cedriclim1 Mar 16, 2026
0d34fc0
Merge branch 'tomography_refactor' of https://github.com/electronmicr…
cedriclim1 Mar 16, 2026
3621568
Relaxation parameters for SIRT reconstructions
cedriclim1 Mar 17, 2026
614ffa1
DDP dist.all_reduce for scheduler stepping
cedriclim1 Mar 19, 2026
8772d89
.to() for object_models fixed; Note should check if world_size > 1 an…
cedriclim1 Mar 20, 2026
196b9e9
@arthurmccray cuBLAS error on the first iteration is due to torch.aut…
cedriclim1 Mar 22, 2026
3e19602
@arthurmccray all Python 3.14 conflicts have been resolved.
cedriclim1 Mar 22, 2026
c7be969
TomographyLite updated to match ObjConstraints and DatasetConstraints
cedriclim1 Mar 23, 2026
a7f9b5d
Fix for TomographyLite when reloading
cedriclim1 Mar 23, 2026
73c282f
Show metrics total loss + lrs from different optimizers after finishe…
cedriclim1 Mar 24, 2026
70ecd70
Added show_metrics to SIRT reconstructions as well
cedriclim1 Mar 24, 2026
c91148f
Added a check to make sure that if a list(tuple) is given it should …
cedriclim1 Mar 24, 2026
fd13a8e
Merge branch 'dev' into tomography_refactor
cedriclim1 Mar 24, 2026
c2f46fe
Setting up DDP for obj_model print statement in TomographyBase now on…
cedriclim1 Mar 24, 2026
3d74da8
Merge branch 'tomography_refactor' of https://github.com/electronmicr…
cedriclim1 Mar 24, 2026
8e1c1ef
uv.lock fix?
cedriclim1 Mar 24, 2026
b922cbb
num_samples_per_ray schedule provided set to only global_rank 0 print
cedriclim1 Mar 24, 2026
c81b2b5
Type-hinting ignore fix in inr.py
cedriclim1 Mar 24, 2026
62616d7
to_numpy() added in tomography_utils.py, and edited background_subtract
cedriclim1 Mar 24, 2026
3567b4e
tomography_utils.py diff_shift_2d updated docstring
cedriclim1 Mar 24, 2026
5923624
plot_losses() function implemented both for conventional and INR algo…
cedriclim1 Mar 24, 2026
0ec02df
Quick fix on object_models DDP instantiation, global_rank is not defi…
cedriclim1 Mar 24, 2026
8afb028
logger_tomography fixed linting
cedriclim1 Mar 24, 2026
4ebb1d9
TomographyDataset quantile linter error fixed
cedriclim1 Mar 24, 2026
435e623
Linter error in .reconstruct
cedriclim1 Mar 24, 2026
5c10509
Fixed all linter issues in top-level tomography.py
cedriclim1 Mar 24, 2026
fc0c675
logger_tomography.py type-hinting fix
cedriclim1 Mar 24, 2026
462e7a0
dataset_models.py type-hinting
cedriclim1 Mar 24, 2026
238bffd
radon.py type-hinting fix
cedriclim1 Mar 24, 2026
1b3b0cd
tomography_opt.py linting errors fixed by Claude
cedriclim1 Mar 24, 2026
ab44488
Removed tomography_old and put forkserver back in instead of spawn.
cedriclim1 Mar 24, 2026
8d85f30
Nevermind back to spawn, DataLoader doesn't like to be in forkserver
cedriclim1 Mar 24, 2026
d6e2aa8
Zero-pad docs fixed
cedriclim1 Mar 24, 2026
9ebc845
Merge branch 'tomography_refactor' of https://github.com/electronmicr…
cedriclim1 Mar 24, 2026
28021af
fixing couple typehints
arthurmccray Mar 25, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,4 @@ dev = [
"pre-commit>=4.2.0",
"ruff>=0.11.5",
"tomli>=2.2.1",
]
]
100 changes: 100 additions & 0 deletions src/quantem/core/ml/constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, ClassVar, Self

import numpy as np
import torch
from numpy.typing import NDArray


@dataclass(slots=False)
class Constraints(ABC):
"""
Any model that inherits from BaseConstraints will contain a Constraints instance that contains soft and hard constraints.
"""

soft_constraint_keys = []
hard_constraint_keys = []

@property
def allowed_keys(self) -> list[str]:
"""
List of all allowed keys.
"""
return self.hard_constraint_keys + self.soft_constraint_keys

def copy(self) -> Self:
"""
Copy the constraints.
"""
return deepcopy(self)

def __str__(self) -> str:
hard = "\n".join(f"{key}: {getattr(self, key)}" for key in self.hard_constraint_keys)
soft = "\n".join(f"{key}: {getattr(self, key)}" for key in self.soft_constraint_keys)

# Fix: Move the replace operations outside the f-string or assign to variables
hard_indented = hard.replace("\n", "\n ")
soft_indented = soft.replace("\n", "\n ")

return (
"Constraints:\n"
" Hard constraints:\n"
f" {hard_indented}\n"
" Soft constraints:\n"
f" {soft_indented}"
)


class BaseConstraints(ABC):
    """
    Mixin base class for models that carry a Constraints instance.

    Subclasses override ``DEFAULT_CONSTRAINTS`` with their own Constraints
    dataclass and must implement ``apply_hard_constraints`` and
    ``apply_soft_constraints``.
    """

    # Each subclass supplies its own default Constraints dataclass; a copy is
    # taken per instance so the shared default is never mutated in place.
    DEFAULT_CONSTRAINTS = Constraints()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._soft_constraint_losses = []
        self.constraints = self.DEFAULT_CONSTRAINTS.copy()

    @property
    def soft_constraint_losses(self) -> NDArray[np.float32]:
        """Accumulated soft-constraint loss history as a float32 array."""
        return np.array(self._soft_constraint_losses, dtype=np.float32)

    @property
    def constraints(self) -> Constraints:
        """
        Constraints for the model.
        """
        return self._constraints

    @constraints.setter
    def constraints(self, constraints: Constraints | dict[str, Any]):
        """
        Set the constraints from a Constraints instance (replaces the current
        one) or from a dictionary (updates attributes on the current one).
        """
        if isinstance(constraints, Constraints):
            self._constraints = constraints
            return
        if isinstance(constraints, dict):
            # Update the existing constraints object attribute-by-attribute.
            for key, value in constraints.items():
                setattr(self._constraints, key, value)
            return
        raise ValueError(f"Invalid constraints type: {type(constraints)}")

    # --- Required methods that need to be implemented in subclasses ---
    @abstractmethod
    def apply_hard_constraints(self, *args, **kwargs) -> torch.Tensor:
        """
        Apply hard constraints to the model.
        """
        raise NotImplementedError

    @abstractmethod
    def apply_soft_constraints(self, *args, **kwargs) -> torch.Tensor:
        """
        Apply soft constraints to the model.
        """
        raise NotImplementedError
181 changes: 181 additions & 0 deletions src/quantem/core/ml/ddp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, DistributedSampler, random_split

from quantem.tomography.dataset_models import DatasetModelType


def worker_init_fn(worker_id):
    """
    DataLoader worker initializer that hides all CUDA devices from the worker
    process by clearing ``CUDA_VISIBLE_DEVICES``, keeping dataset code in
    workers on the CPU.

    Parameters
    ----------
    worker_id : int
        Worker index supplied by the DataLoader (unused).
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = ""


class DDPMixin:
"""
Class for setting up all distributed training.

-
"""

def setup_distributed(self, device: str | torch.device | None = None):
"""
Initializes parameters depending if multiple-GPU training, single-GPU training, or CPU training.
"""
if "RANK" in os.environ:
if not dist.is_initialized():
dist.init_process_group(
backend="nccl" if torch.cuda.is_available() else "gloo", init_method="env://"
)

self.world_size = dist.get_world_size()
self.global_rank = dist.get_rank()
self.local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(self.local_rank)
device = torch.device("cuda", self.local_rank)
else:
self.world_size = 1
self.global_rank = 0
self.local_rank = 0

if torch.cuda.is_available():
device = torch.device("cuda:0" if device is None else device)
torch.cuda.set_device(device.index)
else:
device = torch.device("cpu")

if device.type == "cuda":
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

self.device = device

def setup_dataloader(
self,
dataset: Dataset | DatasetModelType,
batch_size: int,
num_workers: int = 0,
val_fraction: float = 0.0,
):
pin_mem = self.device.type == "cuda"
persist = num_workers > 0

if val_fraction > 0.0:
train_dataset, val_dataset = random_split(dataset, [1 - val_fraction, val_fraction]) # type: ignore[reportArgumentType] --> dataset inherits from torch Dataset so this is fine.
else:
train_dataset = dataset
val_dataset = None

if self.world_size > 1:
shuffle = True
train_sampler = DistributedSampler(
train_dataset, # type: ignore[reportArgumentType] --> Torch datasets do not have a len method, but still works.
num_replicas=self.world_size,
rank=self.global_rank,
shuffle=shuffle,
)

if val_dataset:
val_sampler = DistributedSampler(
val_dataset,
num_replicas=self.world_size,
rank=self.global_rank,
shuffle=False,
)
else:
val_sampler = None
shuffle = False

else:
train_sampler = None
val_sampler = None
shuffle = True

train_dataloader = DataLoader(
train_dataset, # type: ignore[reportArgumentType] --> Torch datasets do not have a len method, but still works.
batch_size=batch_size,
num_workers=num_workers,
sampler=train_sampler,
shuffle=shuffle,
pin_memory=pin_mem,
drop_last=True,
persistent_workers=persist,
multiprocessing_context="spawn",
worker_init_fn=worker_init_fn,
)

if val_dataset:
val_dataloader = DataLoader(
val_dataset,
batch_size=batch_size * 4,
num_workers=num_workers,
sampler=val_sampler,
shuffle=False,
pin_memory=pin_mem,
drop_last=False,
persistent_workers=persist,
multiprocessing_context="spawn",
worker_init_fn=worker_init_fn,
)
val_dataloader = val_dataloader
else:
val_dataloader = None

if self.global_rank == 0:
print("Dataloader setup complete:")
print(f" Total train samples: {len(train_dataset)}") # pyright: ignore[reportArgumentType] --> Torch datasets do not have a len method, but still works.
print(f" Local batch size: {batch_size}")
print(f" Global batch size: {batch_size * self.world_size}")
print(f" Train batches per GPU per epoch: {len(train_dataloader)}")

if val_dataset:
print(f" Total val samples: {len(val_dataset)}")
print(f" Val batches per GPU per epoch: {len(val_dataloader)}") # pyright: ignore[reportArgumentType] --> Torch datasets do not have a len method, but still works.

return train_dataloader, train_sampler, val_dataloader, val_sampler

def distribute_model(
self,
model: nn.Module,
) -> nn.Module | nn.parallel.DistributedDataParallel:
"""
Wraps the model with DistributedDataParallel if mulitple GPUs are available.

Returns the model.
"""
model = model.to(self.device)

if self.world_size > 1:
model = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[self.local_rank],
output_device=self.local_rank,
find_unused_parameters=False,
broadcast_buffers=True,
bucket_cap_mb=100,
gradient_as_bucket_view=True,
)

if self.global_rank == 0:
print("Model wrapped with DDP and compiled")

if self.world_size > 1:
if self.global_rank == 0:
print("Model built, distributed, and compiled successfully")

else:
print("Model built, compiled successfully")

return model

@property
def device(self) -> torch.device:
return self._device

@device.setter
def device(self, device: torch.device | str):
if isinstance(device, str):
device = torch.device(device)
self._device = device
20 changes: 19 additions & 1 deletion src/quantem/core/ml/inr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
hsiren: bool = False,
dtype: torch.dtype = torch.float32,
final_activation: str | Callable = "identity",
winner_initialization: bool | int = False,
) -> None:
"""Initialize Siren.

Expand Down Expand Up @@ -59,7 +60,7 @@ def __init__(
self.alpha = alpha
self.hsiren = hsiren
self.dtype = dtype

self.winner_initialization = winner_initialization
self.final_activation = final_activation

self._build()
Expand Down Expand Up @@ -109,6 +110,21 @@ def _build(self) -> None:
net_list.append(self._final_activation)
self.net = nn.Sequential(*net_list)

if self.winner_initialization:
if type(self.winner_initialization) is int:
rng = torch.Generator()
rng.manual_seed(self.winner_initialization)
else:
rng = torch.Generator()
rng.manual_seed(42)
with torch.no_grad():
self.net[0].linear.weight += ( # type: ignore[reportAttributeAccessIssue]
torch.randn_like(self.net[0].linear.weight) * 5 / self.first_omega_0 # type:ignore
)
self.net[1].linear.weight += ( # type: ignore[reportAttributeAccessIssue]
torch.randn_like(self.net[1].linear.weight) * 0.1 / self.hidden_omega_0 # type:ignore
)

def forward(self, coords: torch.Tensor) -> torch.Tensor:
    """Evaluate the network (``self.net``) on the input coordinates."""
    output = self.net(coords)
    return output
Expand Down Expand Up @@ -201,6 +217,7 @@ def __init__(
alpha: float = 1.0,
dtype: torch.dtype = torch.float32,
final_activation: str | Callable = "identity",
winner_initialization: bool | int = False,
) -> None:
"""Initialize HSiren.

Expand Down Expand Up @@ -236,4 +253,5 @@ def __init__(
hsiren=True,
dtype=dtype,
final_activation=final_activation,
winner_initialization=winner_initialization,
)
Loading