Skip to content

Fix Adam.set_params re-zeroing fp16 moment buffers#1595

Open
Nas01010101 wants to merge 1 commit into
NVIDIA:mainfrom
Nas01010101:fix/adam-fp16-set-params
Open

Fix Adam.set_params re-zeroing fp16 moment buffers#1595
Nas01010101 wants to merge 1 commit into
NVIDIA:mainfrom
Nas01010101:fix/adam-fp16-set-params

Conversation

@Nas01010101

@Nas01010101 Nas01010101 commented Jun 29, 2026

Copy link
Copy Markdown

Description

Fixes #1593.

warp.optim.Adam.set_params() silently re-zeroed its moment buffers on every call when the parameters are fp16, discarding accumulated optimizer state. Moment buffers are intentionally kept in float32 even for fp16 params, but the reuse check compared the existing buffer dtype against the parameter dtype (self.m[i].dtype != param.dtype), which for fp16 is always float32 != float16 → always true → reallocate to zeros. This affects the documented "set params later via set_params" path; the first call from __init__ allocates regardless. SGD.set_params() is not affected because its buffer dtype matches the param dtype.

Changes

  • warp/_src/optim/adam.py: compare the existing moment buffers against the moment-buffer dtype (dtype) instead of param.dtype when deciding whether to reuse them. float32/vec3 behavior is unchanged (there dtype == param.dtype).
  • warp/tests/test_adam.py: add test_adam_set_params_preserves_fp16_state covering fp32/fp16/vec3.
  • CHANGELOG.md: add a Fixed entry.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • CHANGELOG.md is updated for any user-facing changes under the Unreleased section.

Validation summary

Added test_adam_set_params_preserves_fp16_state, which dirties the moment buffers, calls set_params again with the same parameters, and asserts the buffers are reused (same object) and their contents preserved, for fp32, fp16, and vec3.

  • Verified the new test fails on main without this change (the fp16 case re-zeros the buffers): python -m warp.tests -k set_params_preservesFAILED (failures=2) on CPU and CUDA.
  • Verified the new test passes with this change: same command → OK on CPU and CUDA.
  • No regression: test_adam, test_sgd, test_grad, and test_linear_solvers all pass on CPU and CUDA.

Bug fix

Reproduces without this PR applied (CPU or CUDA):

import warp as wp
import warp.optim
wp.init()

p = wp.zeros(8, dtype=wp.float16, requires_grad=True)
opt = warp.optim.Adam([p], lr=1e-3)
opt.m[0].fill_(1.0)          # pretend we have accumulated moment state
opt.set_params([p])          # same params -> should be a no-op
print((opt.m[0].numpy() == 1.0).all())   # main: False (state wiped); with fp32 params: True

Summary by CodeRabbit

  • Bug Fixes

    • Fixed a mixed-precision issue where optimizer moment buffers could be reset/reallocated on repeated parameter setup calls, losing accumulated state.
    • Improved optimizer state reuse by correctly validating moment-buffer dtype compatibility rather than parameter dtype.
  • Tests

    • Added a unit test to confirm Adam.set_params() preserves existing first/second moment buffers (no reallocation or resets) across fp32, fp16, and vector parameter types.

@copy-pr-bot

copy-pr-bot Bot commented Jun 29, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Jun 29, 2026

Copy link
Copy Markdown

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: b0157790-ac47-487b-bf13-96bdbdbf45bc

📥 Commits

Reviewing files that changed from the base of the PR and between cde6077 and 73391e0.

📒 Files selected for processing (3)
  • CHANGELOG.md
  • warp/_src/optim/adam.py
  • warp/tests/test_adam.py
✅ Files skipped from review due to trivial changes (1)
  • CHANGELOG.md
🚧 Files skipped from review as they are similar to previous changes (2)
  • warp/tests/test_adam.py
  • warp/_src/optim/adam.py

📝 Walkthrough

Walkthrough

Adam.set_params now compares existing moment buffer dtype against the computed moment dtype variable instead of param.dtype, preventing moment buffers from being reallocated and zeroed on repeated calls for fp16 parameters. A regression test and changelog entry accompany the fix.

Changes

Adam fp16 moment buffer fix

Layer / File(s) Summary
Fix dtype check and regression test
warp/_src/optim/adam.py, warp/tests/test_adam.py, CHANGELOG.md
Moment buffer reuse guard compares self.m[i].dtype against the computed moment dtype instead of param.dtype. The new test test_adam_set_params_preserves_fp16_state checks buffer identity and unchanged contents for wp.float32, wp.float16, and wp.vec3, and the changelog records the fix.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~5 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 75.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly matches the main fix to Adam.set_params re-zeroing fp16 moment buffers.
Linked Issues check ✅ Passed The code change, regression test, and changelog update all align with issue #1593's fp16 state-preservation fix.
Out of Scope Changes check ✅ Passed The PR stays focused on the Adam fp16 buffer reuse bug and its regression coverage, with no unrelated changes.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Warning

There were issues while running some tools. Please review the errors and either fix the tool's configuration or disable the tool if it's a critical failure.

🔧 OpenGrep (1.23.0)
warp/_src/optim/adam.py

┌──────────────┐
│ Opengrep CLI │
└──────────────┘

�[32m✔�[39m �[1mOpengrep OSS�[0m
�[32m✔�[39m Basic security coverage for first-party code vulnerabilities.

[00.14][ERROR]: unable to find a config; path .coderabbit-opengrep-fallback.yml does not exist

warp/tests/test_adam.py

┌──────────────┐
│ Opengrep CLI │
└──────────────┘

�[32m✔�[39m �[1mOpengrep OSS�[0m
�[32m✔�[39m Basic security coverage for first-party code vulnerabilities.

[00.12][ERROR]: unable to find a config; path .coderabbit-opengrep-fallback.yml does not exist

🔧 markdownlint-cli2 (0.22.1)
CHANGELOG.md

markdownlint-cli2 wrapper config was not available before execution


Comment @coderabbitai help to get the list of available commands.

@greptile-apps

greptile-apps Bot commented Jun 29, 2026

Copy link
Copy Markdown

Greptile Summary

This PR fixes a bug in Adam.set_params() where moment buffers were reallocated (and zeroed) on every call for fp16 parameters, silently discarding accumulated optimizer state. The root cause was comparing the existing buffer's dtype against the parameter dtype rather than the moment-buffer dtype; since moments are always kept in fp32 even for fp16 params, the check float32 != float16 was always True, triggering a reallocation.

  • warp/_src/optim/adam.py: Two-line fix — both the first-moment (m) and second-moment (v) reuse checks now compare against dtype (the resolved moment-buffer dtype) instead of param.dtype.
  • warp/tests/test_adam.py: New regression test test_adam_set_params_preserves_fp16_state verifies buffer identity and content preservation for fp32, fp16, and vec3 after a redundant set_params call.

Confidence Score: 5/5

Safe to merge — the change is a two-line targeted fix to a documented dtype-mismatch bug, with no side-effects on fp32 or vec3 optimizers.

The fix is minimal and provably correct: the moment-buffer dtype is always resolved before the reuse check, so comparing against it rather than the parameter dtype is strictly more accurate. fp32 and vec3 codepaths are unaffected because their moment dtype equals their parameter dtype. The new regression test covers all three supported dtypes using both object-identity and value assertions, and the author confirmed it fails on main and passes with the fix applied.

No files require special attention.

Important Files Changed

Filename Overview
warp/_src/optim/adam.py Two-line targeted fix: dtype comparison in both buffer-reuse guards changed from param.dtype to dtype (the resolved moment-buffer dtype), correctly handling fp16 params whose moments are stored as fp32.
warp/tests/test_adam.py New regression test covers all three supported dtypes, checks both buffer identity (assertIs) and content preservation after a no-op set_params call.
CHANGELOG.md Bug-fix entry added under Unreleased/Fixed with a link to GH-1593; correctly placed and formatted.

Reviews (2): Last reviewed commit: "Fix Adam.set_params re-zeroing fp16 mome..." | Re-trigger Greptile

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@warp/_src/optim/adam.py`:
- Around line 134-137: The Adam moment-buffer reuse check in set_params() only
validates shape and dtype, so buffers can be incorrectly reused when params move
to a different device. Update the compatibility guard for self.m[i] and
self.v[i] to also compare param.device against the existing buffer device, and
recreate the buffers in Adam when the device differs so state never mixes
devices.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: bb662c80-8836-48e4-a5de-5b8901b757de

📥 Commits

Reviewing files that changed from the base of the PR and between ced4300 and cde6077.

📒 Files selected for processing (3)
  • CHANGELOG.md
  • warp/_src/optim/adam.py
  • warp/tests/test_adam.py

Comment thread warp/_src/optim/adam.py
Comment on lines +134 to 137
if self.m[i] is None or self.m[i].shape != param.shape or self.m[i].dtype != dtype:
self.m[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device)
if self.v[i] is None or self.v[i].shape != param.shape or self.v[i].dtype != param.dtype:
if self.v[i] is None or self.v[i].shape != param.shape or self.v[i].dtype != dtype:
self.v[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device)

@coderabbitai coderabbitai Bot Jun 29, 2026

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Include device in the moment-buffer reuse guard.

If set_params() is called with same-shaped params on a different device, these branches will reuse self.m[i]/self.v[i] from the old device because only shape and dtype are checked. That leaves Adam holding mixed-device state and can fail on the next step. Make param.device part of the compatibility check.

Suggested fix
-                if self.m[i] is None or self.m[i].shape != param.shape or self.m[i].dtype != dtype:
+                if (
+                    self.m[i] is None
+                    or self.m[i].shape != param.shape
+                    or self.m[i].dtype != dtype
+                    or self.m[i].device != param.device
+                ):
                     self.m[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device)
-                if self.v[i] is None or self.v[i].shape != param.shape or self.v[i].dtype != dtype:
+                if (
+                    self.v[i] is None
+                    or self.v[i].shape != param.shape
+                    or self.v[i].dtype != dtype
+                    or self.v[i].device != param.device
+                ):
                     self.v[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if self.m[i] is None or self.m[i].shape != param.shape or self.m[i].dtype != dtype:
self.m[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device)
if self.v[i] is None or self.v[i].shape != param.shape or self.v[i].dtype != param.dtype:
if self.v[i] is None or self.v[i].shape != param.shape or self.v[i].dtype != dtype:
self.v[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device)
if (
self.m[i] is None
or self.m[i].shape != param.shape
or self.m[i].dtype != dtype
or self.m[i].device != param.device
):
self.m[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device)
if (
self.v[i] is None
or self.v[i].shape != param.shape
or self.v[i].dtype != dtype
or self.v[i].device != param.device
):
self.v[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@warp/_src/optim/adam.py` around lines 134 - 137, The Adam moment-buffer reuse
check in set_params() only validates shape and dtype, so buffers can be
incorrectly reused when params move to a different device. Update the
compatibility guard for self.m[i] and self.v[i] to also compare param.device
against the existing buffer device, and recreate the buffers in Adam when the
device differs so state never mixes devices.

@Nas01010101 Nas01010101 Jun 29, 2026

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be out of scope for this PR. It targets the fp16 moment-state reset in #1593. The device-compatibility gap is pre-existing and not specific to the fp16 fix: the same shape/dtype-only guard exists in SGD.set_params, so reusing buffers across a device change is an issue for both optimizers. Could be better to change device handling for Adam and SGD in a separate follow-up rather than widen this change. Will let maintainers review that!

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Nas01010101, that makes sense — thanks for the clarification.

Agreed: this is a real pre-existing gap, but it’s orthogonal to the fp16 state-preservation fix in #1595. Handling Adam.set_params() and SGD.set_params() together in a separate follow-up sounds like the cleaner approach, so I won’t push on this PR further. If you’d like, I can help draft or open a follow-up issue for that shared device-compatibility work.

(ᵔᴥᵔ)


✏️ Learnings added
Learnt from: Nas01010101
Repo: NVIDIA/warp PR: 1595
File: warp/_src/optim/adam.py:134-137
Timestamp: 2026-06-29T03:33:58.805Z
Learning: In NVIDIA/warp PR `#1595`, the device-compatibility issue in `warp/_src/optim/adam.py` `Adam.set_params()` was confirmed as a valid pre-existing concern but intentionally kept out of scope because the PR only fixes fp16 moment-state reset. The same shape/dtype-only reuse guard also exists in `SGD.set_params()`, so future review comments should prefer a consistent follow-up covering both optimizers rather than requesting an `Adam`-only fix in that PR.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.


🧠 Learnings used
Learnt from: shi-eric
Repo: NVIDIA/warp PR: 1433
File: warp/_src/utils.py:1571-1577
Timestamp: 2026-05-07T19:57:16.716Z
Learning: When reviewing code in NVIDIA/warp that uses `wp.config.log_level` or `ScopedLogLevel`, treat log levels as open-ended integer thresholds following Python `logging` semantics (e.g., DEBUG=10, INFO=20, WARNING=30, ERROR=40). Arbitrary intermediate integer values (e.g., `LOG_INFO + 5`) are valid and should not be flagged as invalid. Also, do not recommend changing `ScopedLogLevel.__init__` to accept only the four named constants, since it must remain consistent with how `wp.config.log_level` is used elsewhere in the `warp/_src` codebase.

Adam moment buffers are always fp32, even for fp16 params. The realloc
guard in set_params compared the existing buffer dtype against the param
dtype, so for fp16 params (moment fp32 != param fp16) it was always true
and the moment buffers were silently re-allocated to zeros on every
set_params call, discarding optimizer state. Compare against the moment
buffer dtype instead. fp32/vec3 behaviour is unchanged. Adds a regression
test covering fp32, fp16, and vec3.

Fixes NVIDIA#1593.

Signed-off-by: Anas <anaselghoudane@gmail.com>
@Nas01010101 Nas01010101 force-pushed the fix/adam-fp16-set-params branch from cde6077 to 73391e0 Compare June 29, 2026 03:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Adam.set_params re-zeros fp16 optimizer moment state on every call

1 participant