Pass train_dtype into _load_model_memory_efficient#106
Conversation
The OSFT memory-efficient loading path was extracting the training dtype indirectly from base_kwargs["torch_dtype"] instead of accepting it as an explicit parameter. This made it fragile and inconsistent with the SFT distributed loading path which takes train_dtype directly. Thread train_dtype through: setup_model() → setup_osft_model_distributed() → from_pretrained() → _load_model_memory_efficient(), with backward- compatible fallback to torch_dtype from base_kwargs when not provided. Closes #34 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: multica-agent <github@multica.ai>
|
Warning Review limit reached
More reviews will be available in 4 minutes. Learn how PR review limits work. Your organization has run out of usage credits. Purchase more in the billing tab. ⌛ How to resolve this issue?After more reviews become available, a review can be triggered using the We recommend that you space out your commits to avoid hitting the rate limit. 🚦 How do rate limits work?CodeRabbit enforces hourly rate limits for each developer per organization. Our paid plans include higher PR review limits than trial, open-source, and free plans. In all cases, reviews become available again over time. During sustained high-volume PR review activity, CodeRabbit may temporarily slow when the next review becomes available. Please see our Fair Usage Limits Policy for further information. ℹ️ Review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (3)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
Summary
train_dtypeparameter to_load_model_memory_efficient()instead of extracting it indirectly frombase_kwargs["torch_dtype"]train_dtypethrough the full OSFT distributed loading path:setup_model()→setup_osft_model_distributed()→from_pretrained()→_load_model_memory_efficient()torch_dtypefrombase_kwargswhentrain_dtypeis not providedCloses #34
Test plan
train_dtypeoverridesbase_kwargs["torch_dtype"], one verifying backward-compatible fallbackTestLazyInitTokenizerAlignmentpassCo-Authored-By: Claude Opus 4.6 (1M context) noreply@anthropic.com