Add optional custom log_rewards to all GFlowNet losses #503
Add a `log_rewards` parameter to all loss/score methods across TB, DB, SubTB, FM, RTB, LPV, and their variants. When None (default), behavior is unchanged. When provided, the custom rewards are used instead of environment rewards, enabling intrinsic reward schemes. Based on the proposal in #312 by @Idriss-Malek, re-implemented on current master to resolve conflicts with the refactored codebase. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
14 tests covering all GFlowNet loss variants (TB, DB, SubTB, FM, LPV, RTB) verify two properties:
1. Identity: passing log_rewards equal to the environment rewards reproduces the default loss exactly.
2. Override: passing different log_rewards produces a different loss, proving the custom rewards are actually used.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
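The identity and override properties can be illustrated without the library. The following is a minimal sketch, not the PR's test code: `env_log_rewards` and `toy_tb_loss` are hypothetical stand-ins, and the loss is a simplified trajectory-balance-style squared residual.

```python
import math


def env_log_rewards(terminal_states):
    """Hypothetical stand-in for the environment's terminal log-rewards."""
    return [math.log(1.0 + s) for s in terminal_states]


def toy_tb_loss(log_pf, log_pb, terminal_states, log_z=0.0, *, log_rewards=None):
    """Toy TB-style loss: mean of (log Z + log P_F - log P_B - log R)^2."""
    if log_rewards is None:
        # Default path: fall back to the environment's rewards.
        log_rewards = env_log_rewards(terminal_states)
    scores = [
        log_z + pf - pb - r for pf, pb, r in zip(log_pf, log_pb, log_rewards)
    ]
    return sum(s * s for s in scores) / len(scores)


log_pf = [-1.2, -0.7, -2.1]
log_pb = [-0.9, -0.4, -1.5]
terminal_states = [0.0, 1.0, 3.0]

default_loss = toy_tb_loss(log_pf, log_pb, terminal_states)

# Identity: custom rewards equal to the environment rewards.
identity_loss = toy_tb_loss(
    log_pf, log_pb, terminal_states,
    log_rewards=env_log_rewards(terminal_states),
)

# Override: custom rewards that differ from the environment rewards.
shifted = [r + 1.0 for r in env_log_rewards(terminal_states)]
override_loss = toy_tb_loss(log_pf, log_pb, terminal_states, log_rewards=shifted)
```

In this sketch `identity_loss` equals `default_loss` exactly because both paths run the same arithmetic on the same values, while `override_loss` differs because the residuals change.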
Pull request overview
Adds support for optionally overriding environment-provided terminal log-rewards with a user-supplied log_rewards tensor across GFlowNet loss/score computations, enabling intrinsic/custom reward schemes while keeping default behavior unchanged when log_rewards=None.
Changes:
- Adds `log_rewards: torch.Tensor | None = None` to multiple `loss()`/`get_scores()` APIs and threads it through the score computations.
- Updates the HyperGrid GAFN tutorial subclass override to match the new `loss()` signature and pass `log_rewards` through to `super()`.
- Adds a new test suite validating identity/override behavior for several loss variants when `log_rewards` is provided.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| `src/gfn/gflownet/base.py` | Adds optional `log_rewards` override support to trajectory score computation API. |
| `src/gfn/gflownet/trajectory_balance.py` | Threads `log_rewards` through TB/RTB/LPV-family losses and score paths (incl. variants). |
| `src/gfn/gflownet/detailed_balance.py` | Adds optional `log_rewards` override to DB scoring and loss computation. |
| `src/gfn/gflownet/flow_matching.py` | Adds optional `log_rewards` override for terminating reward-matching term in FM loss. |
| `src/gfn/gflownet/sub_trajectory_balance.py` | Adds optional `log_rewards` override through SubTB targets/scores/loss (incl. TB-weighting fast path). |
| `tutorials/examples/train_hypergrid_gafn.py` | Updates tutorial subclass `loss()` override to forward `log_rewards`. |
| `testing/test_custom_log_rewards.py` | New tests for identity/override behavior of custom `log_rewards` across several losses. |
    def get_scores(
        self,
        trajectories: Trajectories,
        log_rewards: torch.Tensor | None = None,
        recalculate_all_logprobs: bool = True,
Adding log_rewards before recalculate_all_logprobs changes the positional calling convention for get_scores(). Existing code that previously did get_scores(trajs, False, env) would now pass False as log_rewards, breaking backward compatibility. Consider keeping existing positional args intact and making log_rewards keyword-only (e.g., place it after env / behind a *).
@copilot apply changes based on this feedback
Fixed in commit 11caee9. Made log_rewards keyword-only (using *) in all public loss() and get_scores() methods across base.py, trajectory_balance.py, detailed_balance.py, flow_matching.py, sub_trajectory_balance.py, and the tutorial subclass. The existing positional parameters (env, trajectories/transitions, recalculate_all_logprobs, reduction) remain in their original positions, so callers using positional args continue to work unchanged.
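The positional break and the keyword-only fix described in this thread can be reproduced with plain Python functions. This is a sketch under stated assumptions: the three functions below are toy stand-ins that only echo how their arguments were bound, not the library's methods.

```python
def loss_inserted(env, trajectories, log_rewards=None,
                  recalculate_all_logprobs=True, reduction="mean"):
    """Toy signature with log_rewards inserted before existing optionals."""
    return {
        "log_rewards": log_rewards,
        "recalculate_all_logprobs": recalculate_all_logprobs,
    }


def loss_keyword_only(env, trajectories, recalculate_all_logprobs=True,
                      reduction="mean", *, log_rewards=None):
    """Toy signature with log_rewards after `*`, so it is keyword-only."""
    return {
        "log_rewards": log_rewards,
        "recalculate_all_logprobs": recalculate_all_logprobs,
    }


# A legacy caller that sets recalculate_all_logprobs=False positionally.
legacy_args = ("env", "trajs", False)

broken = loss_inserted(*legacy_args)      # False is bound to log_rewards
fixed = loss_keyword_only(*legacy_args)   # False is bound as intended

# New callers opt in explicitly by keyword.
custom = loss_keyword_only("env", "trajs", False, log_rewards=[0.0, -1.0])

# Passing log_rewards positionally is rejected rather than silently misbound.
try:
    loss_keyword_only("env", "trajs", False, "mean", [0.0, -1.0])
    positional_rejected = False
except TypeError:
    positional_rejected = True
```

The keyword-only form trades a little call-site verbosity for the guarantee that an existing positional call can never be reinterpreted.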
    @@ -91,6 +91,7 @@ def loss(
            self,
            env: Env,
            trajectories: Trajectories,
            log_rewards: torch.Tensor | None = None,
Adding log_rewards before recalculate_all_logprobs breaks backward compatibility for callers that pass recalculate_all_logprobs positionally (e.g., loss(env, trajs, False)). To preserve the public API, keep the existing positional args and add log_rewards as keyword-only (or append it after reduction).
@copilot apply changes based on this feedback
Fixed in commit 11caee9. log_rewards is now keyword-only (after *) in TBGFlowNet.loss(), preserving the positional order (env, trajectories, recalculate_all_logprobs, reduction).
    rtb_loss = super().loss(
        env,
        trajectories,
        log_rewards=log_rewards,
        recalculate_all_logprobs=recalculate_all_logprobs,
Custom log_rewards are now forwarded through the TrustPCL wrapper, but there’s no test asserting identity/override behavior for TrustPCLGFlowNet.loss() specifically. Adding a small test (similar to RTB) would guard against regressions in the alpha^2 scaling wrapper.
@copilot apply changes based on this feedback
Added in commit 001ecd3. TestTrustPCLCustomLogRewards covers identity, override, and an alpha^2 scaling test that verifies TrustPCL.loss(log_rewards=r) == alpha^2 * RTB.loss(log_rewards=r) (with matching beta=1/alpha) after custom rewards pass through the wrapper.
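The shape of the alpha^2 scaling check can be sketched with toy functions. Everything here is an assumption for illustration: `toy_rtb_loss` uses a simplified relative-trajectory-balance-style residual and `toy_trust_pcl_loss` is a hypothetical wrapper; neither reproduces the library's actual score computation.

```python
import math


def toy_rtb_loss(log_post, log_prior, log_z, beta, *, log_rewards):
    """Toy RTB-style loss: mean of (log Z + log p_post - log p_prior - beta * log R)^2."""
    scores = [
        log_z + lp - lq - beta * r
        for lp, lq, r in zip(log_post, log_prior, log_rewards)
    ]
    return sum(s * s for s in scores) / len(scores)


def toy_trust_pcl_loss(log_post, log_prior, log_z, alpha, *, log_rewards):
    """Hypothetical wrapper: forwards log_rewards with beta = 1/alpha, scales by alpha^2."""
    inner = toy_rtb_loss(
        log_post, log_prior, log_z, 1.0 / alpha, log_rewards=log_rewards
    )
    return alpha ** 2 * inner


log_post = [-1.0, -2.0, -0.5]
log_prior = [-1.5, -1.0, -0.7]
log_z, alpha = 0.3, 2.0
custom = [0.5, -0.2, 1.0]

wrapped = toy_trust_pcl_loss(log_post, log_prior, log_z, alpha, log_rewards=custom)
reference = alpha ** 2 * toy_rtb_loss(
    log_post, log_prior, log_z, 1.0 / alpha, log_rewards=custom
)

# If the wrapper dropped log_rewards, a different reward vector would not change it.
other = toy_trust_pcl_loss(
    log_post, log_prior, log_z, alpha, log_rewards=[r + 1.0 for r in custom]
)
```

The comparison against `reference` only holds if the custom rewards survive the wrapper, which is the regression this kind of test guards against.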
    scores = self._compute_rtb_scores(
        env,
        trajectories,
        log_rewards=log_rewards,
        recalculate_all_logprobs=recalculate_all_logprobs,
RelativeLogPartitionVarianceGFlowNet.loss() now accepts custom log_rewards, but the custom-reward behavior isn’t covered by tests (new tests cover LPV and RTB, but not this variant). Consider adding an identity/override test for this loss too.
@copilot apply changes based on this feedback
Added in commit 001ecd3. TestRelativeLPVCustomLogRewards covers identity and override tests for RelativeLogPartitionVarianceGFlowNet.loss().
    @@ -209,6 +209,7 @@
            self,
            env: Env,
            transitions: Transitions,
            log_rewards: torch.Tensor | None = None,
Adding log_rewards before recalculate_all_logprobs changes the positional calling convention for get_scores(env, transitions, ...). Existing code that called get_scores(env, transitions, False) will now pass False as log_rewards. To keep backward compatibility, make log_rewards keyword-only or move it after the existing optional args.
@copilot apply changes based on this feedback
Fixed in commit 11caee9. log_rewards is now keyword-only in DBGFlowNet.get_scores(), preserving (env, transitions, recalculate_all_logprobs) positional order.
    @@ -279,6 +279,7 @@ def loss(
            self,
            env: DiscreteEnv,
            states_container: StatesContainer[DiscreteStates],
            log_rewards: torch.Tensor | None = None,
Adding log_rewards before recalculate_all_logprobs changes the positional calling convention for loss(). Existing code that calls loss(env, samples, False) (to set recalculate_all_logprobs=False) would now pass False as log_rewards. Consider making log_rewards keyword-only or placing it after the existing optional args.
        log_rewards: torch.Tensor | None = None,
        recalculate_all_logprobs: bool = True,
        reduction: str = "mean",
This example override updates the signature to include log_rewards, but it also propagates the same positional-argument backward-compatibility break (a third positional arg would now be interpreted as log_rewards). If the library API is updated to make log_rewards keyword-only / appended, mirror that here too.
    -     log_rewards: torch.Tensor | None = None,
    -     recalculate_all_logprobs: bool = True,
    -     reduction: str = "mean",
    +     recalculate_all_logprobs: bool = True,
    +     reduction: str = "mean",
    +     *,
    +     log_rewards: torch.Tensor | None = None,
    def loss(
        self,
        env: Env,
        transitions: Transitions,
        log_rewards: torch.Tensor | None = None,
        recalculate_all_logprobs: bool = True,
        reduction: str = "mean",
Same backward-compatibility concern for loss(): inserting log_rewards before recalculate_all_logprobs means a call like loss(env, transitions, False) would now pass False as log_rewards. Consider making log_rewards keyword-only (or appending it after the existing optional args) to avoid breaking positional callers.
@copilot apply changes based on this feedback
Fixed in commit 11caee9. log_rewards is now keyword-only in DBGFlowNet.loss(), preserving (env, transitions, recalculate_all_logprobs, reduction) positional order.
Agent-Logs-Url: https://github.com/GFNOrg/torchgfn/sessions/1919552f-66e6-465f-9209-ebd7b2f8629a Co-authored-by: josephdviviano <4142570+josephdviviano@users.noreply.github.com>
Fix CI: accept `log_rewards` kwarg in `TBGAFN.get_scores`
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…hods for backward compatibility Agent-Logs-Url: https://github.com/GFNOrg/torchgfn/sessions/59f54f6c-8302-486a-ae28-3b37747a3b55 Co-authored-by: josephdviviano <4142570+josephdviviano@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@copilot can you please review the changes?
…rds; add TrustPCL and RelativeLPV tests Agent-Logs-Url: https://github.com/GFNOrg/torchgfn/sessions/88122379-1c44-4cfc-b15a-8b72a521b080 Co-authored-by: josephdviviano <4142570+josephdviviano@users.noreply.github.com>
The changes have been reviewed and verified: Backward compatibility (
Test coverage ( |
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #503 +/- ##
==========================================
- Coverage 72.49% 72.46% -0.03%
==========================================
Files 55 55
Lines 8522 8539 +17
Branches 1090 1099 +9
==========================================
+ Hits 6178 6188 +10
- Misses 1957 1961 +4
- Partials 387 390 +3
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Adds a `log_rewards: torch.Tensor | None = None` keyword-only parameter to all GFlowNet `loss()` and `get_scores()` methods (TB, DB, SubTB, FM, RTB, LPV, TrustPCL, RelativeLPV, and variants). When `None` (default), behavior is unchanged: environment rewards are used as before. When provided, the custom rewards replace environment rewards, enabling intrinsic reward schemes (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025).

Files changed
- `src/gfn/gflownet/base.py`: `TrajectoryBasedGFlowNet.get_scores()`; `log_rewards` is keyword-only
- `src/gfn/gflownet/trajectory_balance.py`: `TBGFlowNet.loss()`, `RelativeTBBase.get_scores()`/`_compute_rtb_scores()`, `RelativeTrajectoryBalanceGFlowNet.loss()`, `TrustPCLGFlowNet.loss()`, `LogPartitionVarianceGFlowNet.loss()`, `RelativeLogPartitionVarianceGFlowNet.loss()`; `log_rewards` is keyword-only
- `src/gfn/gflownet/detailed_balance.py`: `DBGFlowNet.get_scores()`, `DBGFlowNet.loss()`; `log_rewards` is keyword-only; raises `ValueError` when `forward_looking=True` and `log_rewards` is provided (environment rewards cannot be fully replaced in that mode)
- `src/gfn/gflownet/flow_matching.py`: `FMGFlowNet.loss()`; `log_rewards` is keyword-only
- `src/gfn/gflownet/sub_trajectory_balance.py`: `SubTBGFlowNet.calculate_targets()`, `SubTBGFlowNet.get_scores()`, `SubTBGFlowNet.loss()`; `log_rewards` is keyword-only
- `tutorials/examples/train_hypergrid_gafn.py`
- `testing/test_custom_log_rewards.py`: `DBGFlowNet` raises when `forward_looking=True` with custom `log_rewards`

API design
`log_rewards` is placed after `*` in every signature, making it keyword-only. This preserves full backward compatibility: all existing positional callers (`loss(env, trajs, False)` etc.) are unaffected.

Limitation
Custom `log_rewards` are not supported with `DBGFlowNet(forward_looking=True)`. The forward-looking objective calls `env.log_reward()` for intermediate-state flow adjustments, so custom rewards cannot fully replace environment rewards. Passing `log_rewards` in this mode raises `ValueError`.
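The guard described in this limitation can be sketched as follows. `ToyDBLoss` is a hypothetical class written for illustration, and its `get_scores()` only returns the terminal log-rewards it would use; the real method's internals and error message are not shown in this PR description.

```python
class ToyDBLoss:
    """Toy stand-in for a detailed-balance loss with a forward-looking mode."""

    def __init__(self, forward_looking=False):
        self.forward_looking = forward_looking

    def get_scores(self, env_log_rewards, recalculate_all_logprobs=True,
                   *, log_rewards=None):
        # Forward-looking mode also needs environment rewards at intermediate
        # states, so a terminal-only override cannot fully replace them.
        if self.forward_looking and log_rewards is not None:
            raise ValueError(
                "custom log_rewards are not supported when forward_looking=True"
            )
        # Otherwise use the override when given, else the environment rewards.
        return list(env_log_rewards if log_rewards is None else log_rewards)


env_rewards = [0.0, -0.5]
custom_rewards = [1.0, 2.0]

standard = ToyDBLoss(forward_looking=False)
used_custom = standard.get_scores(env_rewards, log_rewards=custom_rewards)

forward = ToyDBLoss(forward_looking=True)
used_env = forward.get_scores(env_rewards)  # default path still works

try:
    forward.get_scores(env_rewards, log_rewards=custom_rewards)
    raised = False
except ValueError:
    raised = True
```

Failing loudly here avoids a silent mix of custom terminal rewards with environment rewards at intermediate states.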