Fix Adam.set_params re-zeroing fp16 moment buffers#1595
Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Path: .coderabbit.yml Review profile: CHILL Plan: Enterprise Run ID: 📒 Files selected for processing (3)
✅ Files skipped from review due to trivial changes (1)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 WalkthroughWalkthrough
ChangesAdam fp16 moment buffer fix
Estimated code review effort🎯 2 (Simple) | ⏱️ ~5 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Warning There were issues while running some tools. Please review the errors and either fix the tool's configuration or disable the tool if it's a critical failure. 🔧 OpenGrep (1.23.0)warp/_src/optim/adam.py┌──────────────┐ �[32m✔�[39m �[1mOpengrep OSS�[0m [00.14][ERROR]: unable to find a config; path warp/tests/test_adam.py┌──────────────┐ �[32m✔�[39m �[1mOpengrep OSS�[0m [00.12][ERROR]: unable to find a config; path 🔧 markdownlint-cli2 (0.22.1)CHANGELOG.mdmarkdownlint-cli2 wrapper config was not available before execution Comment |
Greptile SummaryThis PR fixes a bug in
Confidence Score: 5/5Safe to merge — the change is a two-line targeted fix to a documented dtype-mismatch bug, with no side-effects on fp32 or vec3 optimizers. The fix is minimal and provably correct: the moment-buffer dtype is always resolved before the reuse check, so comparing against it rather than the parameter dtype is strictly more accurate. fp32 and vec3 codepaths are unaffected because their moment dtype equals their parameter dtype. The new regression test covers all three supported dtypes using both object-identity and value assertions, and the author confirmed it fails on main and passes with the fix applied. No files require special attention. Important Files Changed
Reviews (2): Last reviewed commit: "Fix Adam.set_params re-zeroing fp16 mome..." | Re-trigger Greptile |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@warp/_src/optim/adam.py`:
- Around line 134-137: The Adam moment-buffer reuse check in set_params() only
validates shape and dtype, so buffers can be incorrectly reused when params move
to a different device. Update the compatibility guard for self.m[i] and
self.v[i] to also compare param.device against the existing buffer device, and
recreate the buffers in Adam when the device differs so state never mixes
devices.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Enterprise
Run ID: bb662c80-8836-48e4-a5de-5b8901b757de
📒 Files selected for processing (3)
CHANGELOG.mdwarp/_src/optim/adam.pywarp/tests/test_adam.py
| if self.m[i] is None or self.m[i].shape != param.shape or self.m[i].dtype != dtype: | ||
| self.m[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device) | ||
| if self.v[i] is None or self.v[i].shape != param.shape or self.v[i].dtype != param.dtype: | ||
| if self.v[i] is None or self.v[i].shape != param.shape or self.v[i].dtype != dtype: | ||
| self.v[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device) |
There was a problem hiding this comment.
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
Include device in the moment-buffer reuse guard.
If set_params() is called with same-shaped params on a different device, these branches will reuse self.m[i]/self.v[i] from the old device because only shape and dtype are checked. That leaves Adam holding mixed-device state and can fail on the next step. Make param.device part of the compatibility check.
Suggested fix
- if self.m[i] is None or self.m[i].shape != param.shape or self.m[i].dtype != dtype:
+ if (
+ self.m[i] is None
+ or self.m[i].shape != param.shape
+ or self.m[i].dtype != dtype
+ or self.m[i].device != param.device
+ ):
self.m[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device)
- if self.v[i] is None or self.v[i].shape != param.shape or self.v[i].dtype != dtype:
+ if (
+ self.v[i] is None
+ or self.v[i].shape != param.shape
+ or self.v[i].dtype != dtype
+ or self.v[i].device != param.device
+ ):
self.v[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if self.m[i] is None or self.m[i].shape != param.shape or self.m[i].dtype != dtype: | |
| self.m[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device) | |
| if self.v[i] is None or self.v[i].shape != param.shape or self.v[i].dtype != param.dtype: | |
| if self.v[i] is None or self.v[i].shape != param.shape or self.v[i].dtype != dtype: | |
| self.v[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device) | |
| if ( | |
| self.m[i] is None | |
| or self.m[i].shape != param.shape | |
| or self.m[i].dtype != dtype | |
| or self.m[i].device != param.device | |
| ): | |
| self.m[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device) | |
| if ( | |
| self.v[i] is None | |
| or self.v[i].shape != param.shape | |
| or self.v[i].dtype != dtype | |
| or self.v[i].device != param.device | |
| ): | |
| self.v[i] = wp.zeros(shape=param.shape, dtype=dtype, device=param.device) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@warp/_src/optim/adam.py` around lines 134 - 137, The Adam moment-buffer reuse
check in set_params() only validates shape and dtype, so buffers can be
incorrectly reused when params move to a different device. Update the
compatibility guard for self.m[i] and self.v[i] to also compare param.device
against the existing buffer device, and recreate the buffers in Adam when the
device differs so state never mixes devices.
There was a problem hiding this comment.
Could be out of scope for this PR. It targets the fp16 moment-state reset in #1593. The device-compatibility gap is pre-existing and not specific to the fp16 fix: the same shape/dtype-only guard exists in SGD.set_params, so reusing buffers across a device change is an issue for both optimizers. Could be better to change device handling for Adam and SGD in a separate follow-up rather than widen this change. Will let maintainers review that!
There was a problem hiding this comment.
@Nas01010101, that makes sense — thanks for the clarification.
Agreed: this is a real pre-existing gap, but it’s orthogonal to the fp16 state-preservation fix in #1595. Handling Adam.set_params() and SGD.set_params() together in a separate follow-up sounds like the cleaner approach, so I won’t push on this PR further. If you’d like, I can help draft or open a follow-up issue for that shared device-compatibility work.
(ᵔᴥᵔ)
✏️ Learnings added
Learnt from: Nas01010101
Repo: NVIDIA/warp PR: 1595
File: warp/_src/optim/adam.py:134-137
Timestamp: 2026-06-29T03:33:58.805Z
Learning: In NVIDIA/warp PR `#1595`, the device-compatibility issue in `warp/_src/optim/adam.py` `Adam.set_params()` was confirmed as a valid pre-existing concern but intentionally kept out of scope because the PR only fixes fp16 moment-state reset. The same shape/dtype-only reuse guard also exists in `SGD.set_params()`, so future review comments should prefer a consistent follow-up covering both optimizers rather than requesting an `Adam`-only fix in that PR.
Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.
🧠 Learnings used
Learnt from: shi-eric
Repo: NVIDIA/warp PR: 1433
File: warp/_src/utils.py:1571-1577
Timestamp: 2026-05-07T19:57:16.716Z
Learning: When reviewing code in NVIDIA/warp that uses `wp.config.log_level` or `ScopedLogLevel`, treat log levels as open-ended integer thresholds following Python `logging` semantics (e.g., DEBUG=10, INFO=20, WARNING=30, ERROR=40). Arbitrary intermediate integer values (e.g., `LOG_INFO + 5`) are valid and should not be flagged as invalid. Also, do not recommend changing `ScopedLogLevel.__init__` to accept only the four named constants, since it must remain consistent with how `wp.config.log_level` is used elsewhere in the `warp/_src` codebase.
Adam moment buffers are always fp32, even for fp16 params. The realloc guard in set_params compared the existing buffer dtype against the param dtype, so for fp16 params (moment fp32 != param fp16) it was always true and the moment buffers were silently re-allocated to zeros on every set_params call, discarding optimizer state. Compare against the moment buffer dtype instead. fp32/vec3 behaviour is unchanged. Adds a regression test covering fp32, fp16, and vec3. Fixes NVIDIA#1593. Signed-off-by: Anas <anaselghoudane@gmail.com>
cde6077 to
73391e0
Compare
Description
Fixes #1593.
warp.optim.Adam.set_params()silently re-zeroed its moment buffers on every call when the parameters arefp16, discarding accumulated optimizer state. Moment buffers are intentionally kept infloat32even forfp16params, but the reuse check compared the existing buffer dtype against the parameter dtype (self.m[i].dtype != param.dtype), which forfp16is alwaysfloat32 != float16→ always true → reallocate to zeros. This affects the documented "set params later viaset_params" path; the first call from__init__allocates regardless.SGD.set_params()is not affected because its buffer dtype matches the param dtype.Changes
warp/_src/optim/adam.py: compare the existing moment buffers against the moment-buffer dtype (dtype) instead ofparam.dtypewhen deciding whether to reuse them.float32/vec3behavior is unchanged (theredtype == param.dtype).warp/tests/test_adam.py: addtest_adam_set_params_preserves_fp16_statecoveringfp32/fp16/vec3.CHANGELOG.md: add aFixedentry.Checklist
Unreleasedsection.Validation summary
Added
test_adam_set_params_preserves_fp16_state, which dirties the moment buffers, callsset_paramsagain with the same parameters, and asserts the buffers are reused (same object) and their contents preserved, forfp32,fp16, andvec3.mainwithout this change (thefp16case re-zeros the buffers):python -m warp.tests -k set_params_preserves→FAILED (failures=2)on CPU and CUDA.OKon CPU and CUDA.test_adam,test_sgd,test_grad, andtest_linear_solversall pass on CPU and CUDA.Bug fix
Reproduces without this PR applied (CPU or CUDA):
Summary by CodeRabbit
Bug Fixes
Tests
Adam.set_params()preserves existing first/second moment buffers (no reallocation or resets) across fp32, fp16, and vector parameter types.