Log chosen/rejected entropy #1159

Open
jacklanchantin wants to merge 8 commits into online_training from jacklanchantin/log_metrics

Conversation


@jacklanchantin (Contributor) commented May 1, 2025

What does this PR do? Please describe:

  • Adds logging of entropy for the chosen and rejected sequences separately in online DPO training (a minimal sketch of the computation follows this list).
  • A few other small changes.
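
For reference, a minimal sketch of the kind of per-sequence entropy being logged, assuming hypothetical `logits`/`target_mask` tensors (illustrative only, not this PR's actual code):

```python
import torch
import torch.nn.functional as F

def sequence_entropy(logits: torch.Tensor, target_mask: torch.Tensor) -> torch.Tensor:
    """Mean per-token entropy over the target positions of each sequence.

    logits:      [batch, seq_len, vocab]
    target_mask: [batch, seq_len], 1.0 on target tokens, 0.0 elsewhere
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Entropy of each token's predictive distribution: -sum(p * log p).
    token_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # [batch, seq_len]
    return (token_entropy * target_mask).sum(dim=-1) / target_mask.sum(dim=-1).clamp(min=1.0)

# Logged separately per preference pair, e.g.:
# chosen_logit_entropy   = sequence_entropy(chosen_logits, chosen_mask).mean()
# rejected_logit_entropy = sequence_entropy(rejected_logits, rejected_mask).mean()
```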

Check list:

  • Was the content of this PR discussed and approved via a GitHub issue? (no need for typos or documentation improvements)
  • Did you read the contributor guideline?
  • Did you make sure that your PR does only one thing instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (no need for typos, documentation, or minor internal changes)

@facebook-github-bot added the CLA Signed label May 1, 2025. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
@jacklanchantin changed the base branch from main to online_training May 1, 2025 21:25
     return outputs

-def reward_from_model(self, prompt_list, batch_size=64):
+def reward_from_model(self, prompt_list, batch_size=16):
@jacklanchantin (Contributor, Author) commented May 1, 2025

I was getting some vLLM CUDA OOMs with batch_size=64 because we have a custom vLLM model.
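
For context, a rough sketch of the chunked scoring loop this batch size drives; `self._score_chunk` is a hypothetical helper, not the repo's actual implementation:

```python
def reward_from_model(self, prompt_list, batch_size=16):
    # Score prompts in small chunks; with the custom vLLM model,
    # batch_size=64 could exhaust GPU memory.
    rewards = []
    for start in range(0, len(prompt_list), batch_size):
        chunk = prompt_list[start : start + batch_size]
        rewards.extend(self._score_chunk(chunk))  # hypothetical helper
    return rewards
```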

Comment on lines +81 to +82
register("chosen_logit_entropy", "Chosen Logit Entropy", 51, format_as_float)
register("rejected_logit_entropy","Rejected Logit Entropy", 51, format_as_float)
@jacklanchantin (Contributor, Author)

These are the only two metrics added; the rest of the lines are just formatting changes.
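
Roughly, these calls feed a descriptor table along these lines (a guess at the shape for illustration, not the repo's actual registry):

```python
from typing import Callable

# Hypothetical descriptor table; the real registry may differ.
_METRIC_DESCRIPTORS: dict[str, tuple[str, int, Callable]] = {}

def register(name: str, display_name: str, priority: int, formatter: Callable) -> None:
    # Map a metric key to its display name, sort priority, and formatter.
    _METRIC_DESCRIPTORS[name] = (display_name, priority, formatter)

def format_as_float(value) -> str:
    return f"{float(value):.3f}"
```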

-per_seq_loss = (
-    (per_token_loss * target_mask).sum(dim=-1)
-).mean(dim=1)
+per_seq_loss = ((per_token_loss * target_mask).sum(dim=-1)).mean(dim=1)
@jacklanchantin (Contributor, Author)

Formatting only, no functional change.

) # [Batch x Rollouts, 1]

# entropy for all N rollouts
logit_entropy = self.get_all_rollouts_entropy(rollouts)
@jacklanchantin (Contributor, Author)

Not sure if we want this: previously logit_entropy was computed only for the chosen sequences, whereas here I'm computing it for all rollouts.
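
A sketch of the difference, reusing the hypothetical `sequence_entropy` helper from above (the `rollouts` attributes are illustrative, not the actual data structure):

```python
def get_all_rollouts_entropy(self, rollouts):
    # Previously: entropy over the chosen sequences only.
    # Now: average per-sequence entropy across every rollout,
    # chosen and rejected alike.
    entropies = [sequence_entropy(r.logits, r.target_mask) for r in rollouts]
    return torch.stack(entropies).mean()
```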

@jacklanchantin requested a review from uralik May 2, 2025 00:08
@jacklanchantin marked this pull request as ready for review May 2, 2025 00:08
@jacklanchantin requested a review from cbalioglu as a code owner May 2, 2025 00:08
@jacklanchantin changed the title from Jacklanchantin/log metrics to Log chosne/rejected entropy May 2, 2025
@jacklanchantin changed the title from Log chosne/rejected entropy to Log chosen/rejected entropy May 2, 2025