Skip to content
32 changes: 32 additions & 0 deletions news/optimize-messages.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
**Added:**

* 'SNMFOptimizer.objective_log' attribute: a list of dictionaries that tracks the
  optimization process, recording the step, iteration, objective value, and
  timestamp at each update. Each entry uses the 'step', 'iteration', 'objective',
  and 'timestamp' keys.
* 'SNMFOptimizer(verbose: bool)' option (default ``True``) and the
  'SNMFOptimizer.verbose' attribute, allowing users to toggle diagnostic
  console output.

**Changed:**

* Reworded all printed progress messages for improved readability and gated
  them behind the new 'verbose' flag.
* Refactored convergence checks and step-size calculations to pull objective
  values directly from 'objective_log' instead of relying on a separate
  history array.

**Deprecated:**

* <news item>

**Removed:**

* Removed the 'SNMFOptimizer._objective_history' list, which was made redundant
by the comprehensive 'SNMFOptimizer.objective_log' tracking system.

**Fixed:**

* <news item>

**Security:**

* <news item>
155 changes: 98 additions & 57 deletions src/diffpy/stretched_nmf/snmf_class.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import time

import cvxpy as cp
import numpy as np
from scipy.optimize import minimize
Expand Down Expand Up @@ -82,6 +84,7 @@ def __init__(
eta=0,
random_state=None,
show_plots=False,
verbose=True,
):
"""Initialize an instance of sNMF with estimator
hyperparameters.
Expand Down Expand Up @@ -126,6 +129,7 @@ def __init__(
self.eta = eta
self.random_state = random_state
self.show_plots = show_plots
self.verbose = verbose

self._rng = np.random.default_rng(self.random_state)
self._plotter = SNMFPlotter() if self.show_plots else None
Expand Down Expand Up @@ -214,6 +218,7 @@ def _initialize_factors(
[1, -2, 1],
offsets=[0, 1, 2],
shape=(self.n_signals_ - 2, self.n_signals_),
dtype=float,
)

def fit(
Expand Down Expand Up @@ -303,7 +308,14 @@ def fit(
self.stretch_.copy(),
]
self.objective_difference_ = None
self._objective_history = [self.objective_function_]
self.objective_log = [
{
"step": "start",
"iteration": 0,
"objective": self.objective_function_,
"timestamp": time.time(),
}
]

# Set up tracking variables for _update_components()
self._prev_components = None
Expand All @@ -322,13 +334,15 @@ def fit(
sparsity_term = self.eta * np.sum(
np.sqrt(self.components_)
) # Square root penalty
objective_without_penalty = (
base_obj = (
self.objective_function_ - regularization_term - sparsity_term
)
print(
f"Start, Objective function: {self.objective_function_:.5e}"
f", Obj - reg/sparse: {objective_without_penalty:.5e}"
)
if self.verbose:
print(
f"\n--- Start ---"
f"\nTotal Objective : {self.objective_function_:.5e}"
f"\nBase Obj (No Reg) : {base_obj:.5e}"
)

# Main optimization loop
for outiter in range(self.max_iter):
Expand All @@ -347,35 +361,23 @@ def fit(
sparsity_term = self.eta * np.sum(
np.sqrt(self.components_)
) # Square root penalty
objective_without_penalty = (
base_obj = (
self.objective_function_ - regularization_term - sparsity_term
)
print(
f"Obj fun: {self.objective_function_:.5e}, "
f"Obj - reg/sparse: {objective_without_penalty:.5e}, "
f"Iter: {self._outer_iter}"
)
obj_diff = (
self.objective_function - regularization_term - sparsity_term
)
print(
f"Obj fun: {self.objective_function:.5e}, "
f", Obj - reg/sparse: {obj_diff:.5e}"
f"Iter: {self.outiter}"
)

convergence_threshold = self.objective_function_ * self.tol
# Convergence check: Stop if diffun is small
# and at least min_iter iterations have passed
print(
"Checking if ",
self.objective_difference_,
" < ",
self.objective_function_ * self.tol,
)
if self.verbose:
print(
f"\n--- Iteration {self._outer_iter} ---"
f"\nTotal Objective : {self.objective_function_:.5e}"
f"\nBase Obj (No Reg) : {base_obj:.5e}"
"\nConvergence Check : Δ "
f"({self.objective_difference_:.2e})"
f" < Threshold ({convergence_threshold:.2e})\n"
)
if (
self.objective_difference_ is not None
and self.objective_difference_
< self.objective_function_ * self.tol
self.objective_difference_ < convergence_threshold
and outiter >= self.min_iter
):
self.converged_ = True
Expand All @@ -387,6 +389,8 @@ def fit(
return self

def _normalize_results(self):
if self.verbose:
print("\nNormalizing results after convergence...")
# Select our best results for normalization
self.components_ = self.best_matrices_[0]
self.weights_ = self.best_matrices_[1]
Expand Down Expand Up @@ -420,13 +424,17 @@ def _normalize_results(self):
self._update_components()
self.residuals_ = self._get_residual_matrix()
self.objective_function_ = self._get_objective_function()
print(
f"Objective function after normalize_components: "
f"{self.objective_function_:.5e}"
self.objective_log.append(
{
"step": "c_norm",
"iteration": outiter,
"objective": self.objective_function_,
"timestamp": time.time(),
}
)
self._objective_history.append(self.objective_function_)
self.objective_difference_ = (
self._objective_history[-2] - self._objective_history[-1]
self.objective_log[-2]["objective"]
- self.objective_log[-1]["objective"]
)
if self._plotter is not None:
self._plotter.update(
Expand All @@ -435,27 +443,41 @@ def _normalize_results(self):
stretch=self.stretch_,
update_tag="normalize components",
)
convergence_threshold = self.objective_function_ * self.tol
if self.verbose:
print(
f"\n--- Iteration {outiter} after normalization---"
f"\nTotal Objective : {self.objective_function_:.5e}"
"\nConvergence Check : Δ "
f"({self.objective_difference_:.2e})"
f" < Threshold ({convergence_threshold:.2e})\n"
)
if (
self.objective_difference_
< self.objective_function_ * self.tol
self.objective_difference_ < convergence_threshold
and outiter >= 7
):
break

def _outer_loop(self):
if self.verbose:
print("Updating components and weights...")
for inner_iter in range(4):
self._inner_iter = inner_iter
self._prev_grad_components = self._grad_components.copy()
self._update_components()
self.residuals_ = self._get_residual_matrix()
self.objective_function_ = self._get_objective_function()
print(
f"Objective function after _update_components: "
f"{self.objective_function_:.5e}"
self.objective_log.append(
{
"step": "c",
"iteration": self._outer_iter,
"objective": self.objective_function_,
"timestamp": time.time(),
}
)
self._objective_history.append(self.objective_function_)
self.objective_difference_ = (
self._objective_history[-2] - self._objective_history[-1]
self.objective_log[-2]["objective"]
- self.objective_log[-1]["objective"]
)
if self.objective_function_ < self.best_objective_:
self.best_objective_ = self.objective_function_
Expand All @@ -475,13 +497,18 @@ def _outer_loop(self):
self._update_weights()
self.residuals_ = self._get_residual_matrix()
self.objective_function_ = self._get_objective_function()
print(
f"Objective function after _update_weights: "
f"{self.objective_function_:.5e}"
self.objective_log.append(
{
"step": "w",
"iteration": self._outer_iter,
"objective": self.objective_function_,
"timestamp": time.time(),
}
)
self._objective_history.append(self.objective_function_)

self.objective_difference_ = (
self._objective_history[-2] - self._objective_history[-1]
self.objective_log[-2]["objective"]
- self.objective_log[-1]["objective"]
)
if self.objective_function_ < self.best_objective_:
self.best_objective_ = self.objective_function_
Expand All @@ -499,10 +526,11 @@ def _outer_loop(self):
)

self.objective_difference_ = (
self._objective_history[-2] - self._objective_history[-1]
self.objective_log[-2]["objective"]
- self.objective_log[-1]["objective"]
)
if (
self._objective_history[-3] - self.objective_function_
self.objective_log[-3]["objective"] - self.objective_function_
< self.objective_difference_ * 1e-3
):
break
Expand All @@ -512,13 +540,17 @@ def _outer_loop(self):
self._update_stretch()
self.residuals_ = self._get_residual_matrix()
self.objective_function_ = self._get_objective_function()
print(
f"Objective function after _update_stretch: "
f"{self.objective_function_:.5e}"
self.objective_log.append(
{
"step": "s",
"iteration": self._outer_iter,
"objective": self.objective_function_,
"timestamp": time.time(),
}
)
self._objective_history.append(self.objective_function_)
self.objective_difference_ = (
self._objective_history[-2] - self._objective_history[-1]
self.objective_log[-2]["objective"]
- self.objective_log[-1]["objective"]
)
if self.objective_function_ < self.best_objective_:
self.best_objective_ = self.objective_function_
Expand Down Expand Up @@ -804,7 +836,12 @@ def _solve_quadratic_program(self, t, m):

# Solve using a QP solver
prob = cp.Problem(objective, constraints)
prob.solve(solver=cp.OSQP, verbose=False)
prob.solve(
solver=cp.OSQP,
verbose=False,
polish=False, # TODO keep? removes polish message
# solver_verbose=False
)

# Get the solution
return np.maximum(
Expand All @@ -814,6 +851,7 @@ def _solve_quadratic_program(self, t, m):
def _update_components(self):
"""Updates `components` using gradient-based optimization with
adaptive step size."""

# Compute stretched components using the interpolation function
stretched_components, _, _ = (
self._compute_stretched_components()
Expand Down Expand Up @@ -878,8 +916,8 @@ def _update_components(self):
)
self.components_ = mask * self.components_

objective_improvement = self._objective_history[
-1
objective_improvement = self.objective_log[-1][
"objective"
] - self._get_objective_function(
residuals=self._get_residual_matrix()
)
Expand Down Expand Up @@ -962,6 +1000,9 @@ def _update_stretch(self):
"""Updates stretching matrix using constrained optimization
(equivalent to fmincon in MATLAB)."""

if self.verbose:
print("Updating stretch factors...")

# Flatten stretch for compatibility with the optimizer
# (since SciPy expects 1D input)
stretch_flat_initial = self.stretch_.flatten()
Expand Down
Loading