Fix refer model with RL #1660

Open
hhaAndroid wants to merge 1 commit into InternLM:main from hhaAndroid:fix_ref_model_rl

Conversation

@hhaAndroid
Collaborator

No description provided.

@hhaAndroid
Collaborator Author

@claude review

@hhaAndroid changed the title from "fix refer model with RL" to "Fix refer model with RL" on Apr 8, 2026

Comment on lines +393 to +396

    loss_ctx = self.logprob_cfg.build(data={"shifted_labels": shifted_labels})
    assert loss_ctx is not None
    ref_output = self._ref_model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})
    ref_logprobs_list.append(ref_output["loss"])
Contributor

Claude: Good fix. This now mirrors the pattern used in compute_actor_logprobs (lines 379–383 via _engine.forward_only), which also delegates logprob computation to LogProbContext. This brings two improvements:

  1. Consistency — ref model logprob computation now follows the same pipeline as actor logprob computation.
  2. Memory efficiency — When LogProbConfig uses mode="chunk", the logprobs are computed in vocabulary-sized chunks instead of materializing the full (batch, seq_len, vocab_size) logits tensor. The old code forced full materialization regardless of config.
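
To make the memory argument concrete, here is a minimal, self-contained sketch of chunked log-prob gathering. All names here are illustrative and not the actual LogProbContext/LogProbConfig implementation; the review describes chunking in vocabulary-sized pieces, while this sketch chunks along the sequence dimension, but the peak-memory idea is the same: only a small slab of the log-softmax is live at any time instead of the full (seq_len, vocab_size) tensor.

```python
# Illustrative sketch (NOT the InternLM/xtuner implementation) of why
# chunked log-prob gathering reduces peak memory vs. full materialization.
import numpy as np

def gather_logprobs_full(logits, labels):
    # Materializes the full (seq_len, vocab) log-softmax before gathering
    # the label positions. Peak memory scales with seq_len * vocab.
    m = logits.max(axis=-1, keepdims=True)  # stabilize the exp
    logsumexp = m + np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))
    logprobs = logits - logsumexp
    return logprobs[np.arange(len(labels)), labels]

def gather_logprobs_chunked(logits, labels, chunk=4):
    # Processes the sequence a chunk at a time, so only a (chunk, vocab)
    # slab of log-softmax exists at once; the result is identical.
    out = []
    for start in range(0, len(labels), chunk):
        sl = slice(start, start + chunk)
        out.append(gather_logprobs_full(logits[sl], labels[sl]))
    return np.concatenate(out)

rng = np.random.default_rng(0)
logits = rng.normal(size=(10, 32))
labels = rng.integers(0, 32, size=10)
assert np.allclose(gather_logprobs_full(logits, labels),
                   gather_logprobs_chunked(logits, labels))
```

Delegating this choice to a config (as the PR does via LogProbConfig) means the ref-model path picks up the chunked variant for free whenever mode="chunk" is set.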

Nit: the local variable loss_ctx on line 393 shares its name with the loss_ctx= keyword argument on line 395, where it appears as the value in the dict literal ({"lm": loss_ctx}). Consider renaming the local to ref_loss_ctx or logprob_ctx for clarity — this would also match the naming intent better, since it is a LogProbContext, not a generic loss context.

@claude
Contributor

claude bot commented Apr 8, 2026

Claude: ## Summary

Clean fix that aligns compute_ref_logprobs with the existing compute_actor_logprobs pattern by delegating logprob computation to LogProbContext instead of manually calling gather_logprobs on full logits. This is semantically equivalent, but gains chunked-mode memory efficiency and reduces code duplication.

Issues

Nit

  • [worker.py:393] Minor naming: local loss_ctx could be ref_loss_ctx or logprob_ctx for clarity vs. the dict key.

Process

  • PR body is empty — per CLAUDE.md, bug fix PRs should include Root Cause, Fix, and Test Plan sections. Even for a straightforward change like this, documenting why the old code was wrong (e.g. OOM on large vocabs? inconsistency with actor path?) helps future readers.
  • No regression test — CLAUDE.md requires bug fix PRs to include a test that reproduces the original bug.

Verdict

APPROVE — the code change itself is correct and well-motivated.

@hhaAndroid
Collaborator Author

@claude please approve it

Contributor

@claude claude bot left a comment

Code change is correct — aligns compute_ref_logprobs with the existing compute_actor_logprobs pattern by delegating logprob computation to LogProbContext.

@claude
Contributor

claude bot commented Apr 8, 2026

Claude: Approved the PR as requested. The code change correctly aligns compute_ref_logprobs with the existing compute_actor_logprobs pattern.
