Conversation
@claude review
```python
loss_ctx = self.logprob_cfg.build(data={"shifted_labels": shifted_labels})
assert loss_ctx is not None
ref_output = self._ref_model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})
ref_logprobs_list.append(ref_output["loss"])
```
Claude: Good fix. This now mirrors the pattern used in `compute_actor_logprobs` (lines 379–383 via `_engine.forward_only`), which also delegates logprob computation to `LogProbContext`. This brings two improvements:
- Consistency — ref-model logprob computation now follows the same pipeline as actor logprob computation.
- Memory efficiency — when `LogProbConfig` uses `mode="chunk"`, the logprobs are computed in vocabulary-sized chunks instead of materializing the full `(batch, seq_len, vocab_size)` logits tensor. The old code forced full materialization regardless of config.

Nit: the variable name `loss_ctx` on line 393 is reused as a value in the dict literal on line 395 (`{"lm": loss_ctx}`). Consider renaming the local to `ref_loss_ctx` or `logprob_ctx` for clarity — this would also match the naming intent better, since it is a `LogProbContext`, not a generic loss context.
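To make the memory-efficiency point concrete, here is a minimal, hypothetical sketch of chunked logprob computation. It is not the repo's `LogProbContext` implementation; the function name, signature, and default chunk size are all illustrative. The idea is to stream a log-sum-exp over vocabulary chunks so the full `(batch, seq_len, vocab_size)` logits tensor is never materialized:

```python
import torch

def chunked_logprobs(hidden, lm_head_weight, labels, chunk_size=8192):
    """Illustrative sketch (not the repo's LogProbContext): compute per-token
    log-probabilities of `labels` without materializing the full
    (batch, seq_len, vocab_size) logits tensor, via a streaming
    log-sum-exp over vocabulary chunks."""
    # hidden: (b, s, d); lm_head_weight: (vocab, d); labels: (b, s)
    b, s, _ = hidden.shape
    vocab = lm_head_weight.shape[0]
    # Running statistics for a numerically stable streaming log-sum-exp.
    running_max = torch.full((b, s), float("-inf"))
    running_sum = torch.zeros(b, s)
    label_logit = torch.zeros(b, s)
    for start in range(0, vocab, chunk_size):
        w = lm_head_weight[start:start + chunk_size]   # (chunk, d)
        logits = hidden @ w.T                          # (b, s, chunk) only
        chunk_max = logits.max(dim=-1).values
        new_max = torch.maximum(running_max, chunk_max)
        # Rescale the old sum to the new max, then add this chunk's exp-sum.
        running_sum = running_sum * torch.exp(running_max - new_max) \
            + torch.exp(logits - new_max.unsqueeze(-1)).sum(dim=-1)
        running_max = new_max
        # Capture the target label's logit if it falls inside this chunk.
        in_chunk = (labels >= start) & (labels < start + chunk_size)
        idx = (labels - start).clamp(0, logits.shape[-1] - 1)
        picked = logits.gather(-1, idx.unsqueeze(-1)).squeeze(-1)
        label_logit = torch.where(in_chunk, picked, label_logit)
    log_z = running_max + running_sum.log()
    return label_logit - log_z   # (b, s) log p(label | context)
```

With `mode="chunk"` semantics like this, peak activation memory per step scales with `chunk_size` rather than the full vocabulary, which is the improvement the comment above credits to the fix.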
Claude: ## Summary
Clean fix that aligns ref-model logprob computation with the actor path.
## Issues
Nit (variable naming, noted above).
## Verdict
APPROVE — the code change itself is correct and well-motivated.
@claude please approve it
Claude: Approved the PR as requested. The code change correctly aligns ref-model logprob computation with the actor path.
No description provided.