Skip to content

Add optional custom log_rewards to all GFlowNet losses#503

Merged
josephdviviano merged 11 commits into
masterfrom
custom-log-rewards
Apr 3, 2026
Merged

Add optional custom log_rewards to all GFlowNet losses#503
josephdviviano merged 11 commits into
masterfrom
custom-log-rewards

Conversation

@josephdviviano
Copy link
Copy Markdown
Collaborator

@josephdviviano josephdviviano commented Apr 3, 2026

Adds a log_rewards: torch.Tensor | None = None keyword-only parameter to all GFlowNet loss() and get_scores() methods (TB, DB, SubTB, FM, RTB, LPV, TrustPCL, RelativeLPV, and variants). When None (default), behavior is unchanged — environment rewards are used as before. When provided, the custom rewards replace environment rewards, enabling intrinsic reward schemes (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025).

Files changed

File Changes
src/gfn/gflownet/base.py TrajectoryBasedGFlowNet.get_scores()log_rewards is keyword-only
src/gfn/gflownet/trajectory_balance.py TBGFlowNet.loss(), RelativeTBBase.get_scores()/_compute_rtb_scores(), RelativeTrajectoryBalanceGFlowNet.loss(), TrustPCLGFlowNet.loss(), LogPartitionVarianceGFlowNet.loss(), RelativeLogPartitionVarianceGFlowNet.loss()log_rewards is keyword-only
src/gfn/gflownet/detailed_balance.py DBGFlowNet.get_scores(), DBGFlowNet.loss()log_rewards is keyword-only; raises ValueError when forward_looking=True and log_rewards is provided (environment rewards cannot be fully replaced in that mode)
src/gfn/gflownet/flow_matching.py FMGFlowNet.loss()log_rewards is keyword-only
src/gfn/gflownet/sub_trajectory_balance.py SubTBGFlowNet.calculate_targets(), SubTBGFlowNet.get_scores(), SubTBGFlowNet.loss()log_rewards is keyword-only
tutorials/examples/train_hypergrid_gafn.py Updated subclass override to match new keyword-only signature
testing/test_custom_log_rewards.py Tests for all loss variants: TB, DB, SubTB, FM, LPV, RTB, TrustPCL (including alpha² scaling), RelativeLPV, and a test that DBGFlowNet raises when forward_looking=True with custom log_rewards

API design

log_rewards is placed after * in every signature, making it keyword-only. This preserves full backward compatibility — all existing positional callers (loss(env, trajs, False) etc.) are unaffected.

Limitation

Custom log_rewards are not supported with DBGFlowNet(forward_looking=True). The forward-looking objective calls env.log_reward() for intermediate-state flow adjustments, so custom rewards cannot fully replace environment rewards. Passing log_rewards in this mode raises ValueError.

josephdviviano and others added 2 commits April 2, 2026 22:38
Add a `log_rewards` parameter to all loss/score methods across TB, DB,
SubTB, FM, RTB, LPV, and their variants. When None (default), behavior
is unchanged. When provided, the custom rewards are used instead of
environment rewards, enabling intrinsic reward schemes.

Based on the proposal in #312 by @Idriss-Malek, re-implemented on
current master to resolve conflicts with the refactored codebase.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
14 tests covering all GFlowNet loss variants (TB, DB, SubTB, FM, LPV,
RTB) verify two properties:

1. Identity: passing log_rewards equal to the environment rewards
   reproduces the default loss exactly.
2. Override: passing different log_rewards produces a different loss,
   proving the custom rewards are actually used.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds support for optionally overriding environment-provided terminal log-rewards with a user-supplied log_rewards tensor across GFlowNet loss/score computations, enabling intrinsic/custom reward schemes while keeping default behavior unchanged when log_rewards=None.

Changes:

  • Adds log_rewards: torch.Tensor | None = None to multiple loss() / get_scores() APIs and threads it through the score computations.
  • Updates the HyperGrid GAFN tutorial subclass override to match the new loss() signature and pass log_rewards through to super().
  • Adds a new test suite validating identity/override behavior for several loss variants when log_rewards is provided.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 12 comments.

Show a summary per file
File Description
src/gfn/gflownet/base.py Adds optional log_rewards override support to trajectory score computation API.
src/gfn/gflownet/trajectory_balance.py Threads log_rewards through TB/RTB/LPV-family losses and score paths (incl. variants).
src/gfn/gflownet/detailed_balance.py Adds optional log_rewards override to DB scoring and loss computation.
src/gfn/gflownet/flow_matching.py Adds optional log_rewards override for terminating reward-matching term in FM loss.
src/gfn/gflownet/sub_trajectory_balance.py Adds optional log_rewards override through SubTB targets/scores/loss (incl. TB-weighting fast path).
tutorials/examples/train_hypergrid_gafn.py Updates tutorial subclass loss() override to forward log_rewards.
testing/test_custom_log_rewards.py New tests for identity/override behavior of custom log_rewards across several losses.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread src/gfn/gflownet/base.py
Comment on lines 420 to 424
def get_scores(
self,
trajectories: Trajectories,
log_rewards: torch.Tensor | None = None,
recalculate_all_logprobs: bool = True,
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding log_rewards before recalculate_all_logprobs changes the positional calling convention for get_scores(). Existing code that previously did get_scores(trajs, False, env) would now pass False as log_rewards, breaking backward compatibility. Consider keeping existing positional args intact and making log_rewards keyword-only (e.g., place it after env / behind a *).

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in commit 11caee9. Made log_rewards keyword-only (using *) in all public loss() and get_scores() methods across base.py, trajectory_balance.py, detailed_balance.py, flow_matching.py, sub_trajectory_balance.py, and the tutorial subclass. The existing positional parameters (env, trajectories/transitions, recalculate_all_logprobs, reduction) remain in their original positions, so callers using positional args continue to work unchanged.

Comment thread src/gfn/gflownet/trajectory_balance.py Outdated
Comment on lines +90 to +94
@@ -91,6 +91,7 @@ def loss(
self,
env: Env,
trajectories: Trajectories,
log_rewards: torch.Tensor | None = None,
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding log_rewards before recalculate_all_logprobs breaks backward compatibility for callers that pass recalculate_all_logprobs positionally (e.g., loss(env, trajs, False)). To preserve the public API, keep the existing positional args and add log_rewards as keyword-only (or append it after reduction).

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in commit 11caee9. log_rewards is now keyword-only (after *) in TBGFlowNet.loss(), preserving the positional order (env, trajectories, recalculate_all_logprobs, reduction).

Comment on lines +591 to +595
rtb_loss = super().loss(
env,
trajectories,
log_rewards=log_rewards,
recalculate_all_logprobs=recalculate_all_logprobs,
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Custom log_rewards are now forwarded through the TrustPCL wrapper, but there’s no test asserting identity/override behavior for TrustPCLGFlowNet.loss() specifically. Adding a small test (similar to RTB) would guard against regressions in the alpha^2 scaling wrapper.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in commit 001ecd3. TestTrustPCLCustomLogRewards covers identity, override, and an alpha^2 scaling test that verifies TrustPCL.loss(log_rewards=r) == alpha^2 * RTB.loss(log_rewards=r) (with matching beta=1/alpha) after custom rewards pass through the wrapper.

Comment on lines +717 to +721
scores = self._compute_rtb_scores(
env,
trajectories,
log_rewards=log_rewards,
recalculate_all_logprobs=recalculate_all_logprobs,
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RelativeLogPartitionVarianceGFlowNet.loss() now accepts custom log_rewards, but the custom-reward behavior isn’t covered by tests (new tests cover LPV and RTB, but not this variant). Consider adding an identity/override test for this loss too.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in commit 001ecd3. TestRelativeLPVCustomLogRewards covers identity and override tests for RelativeLogPartitionVarianceGFlowNet.loss().

Comment thread src/gfn/gflownet/detailed_balance.py Outdated
Comment on lines +208 to +212
@@ -209,6 +209,7 @@
self,
env: Env,
transitions: Transitions,
log_rewards: torch.Tensor | None = None,
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding log_rewards before recalculate_all_logprobs changes the positional calling convention for get_scores(env, transitions, ...). Existing code that called get_scores(env, transitions, False) will now pass False as log_rewards. To keep backward compatibility, make log_rewards keyword-only or move it after the existing optional args.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in commit 11caee9. log_rewards is now keyword-only in DBGFlowNet.get_scores(), preserving (env, transitions, recalculate_all_logprobs) positional order.

Comment thread src/gfn/gflownet/sub_trajectory_balance.py Outdated
Comment thread src/gfn/gflownet/flow_matching.py Outdated
Comment on lines +278 to +282
@@ -279,6 +279,7 @@ def loss(
self,
env: DiscreteEnv,
states_container: StatesContainer[DiscreteStates],
log_rewards: torch.Tensor | None = None,
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding log_rewards before recalculate_all_logprobs changes the positional calling convention for loss(). Existing code that calls loss(env, samples, False) (to set recalculate_all_logprobs=False) would now pass False as log_rewards. Consider making log_rewards keyword-only or placing it after the existing optional args.

Copilot uses AI. Check for mistakes.
Comment on lines 242 to 244
log_rewards: torch.Tensor | None = None,
recalculate_all_logprobs: bool = True,
reduction: str = "mean",
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This example override updates the signature to include log_rewards, but it also propagates the same positional-argument backward-compatibility break (a third positional arg would now be interpreted as log_rewards). If the library API is updated to make log_rewards keyword-only / appended, mirror that here too.

Suggested change
log_rewards: torch.Tensor | None = None,
recalculate_all_logprobs: bool = True,
reduction: str = "mean",
recalculate_all_logprobs: bool = True,
reduction: str = "mean",
*,
log_rewards: torch.Tensor | None = None,

Copilot uses AI. Check for mistakes.
Comment thread testing/test_custom_log_rewards.py Outdated
Comment on lines 339 to 345
def loss(
self,
env: Env,
transitions: Transitions,
log_rewards: torch.Tensor | None = None,
recalculate_all_logprobs: bool = True,
reduction: str = "mean",
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same backward-compatibility concern for loss(): inserting log_rewards before recalculate_all_logprobs means a call like loss(env, transitions, False) would now pass False as log_rewards. Consider making log_rewards keyword-only (or appending it after the existing optional args) to avoid breaking positional callers.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in commit 11caee9. log_rewards is now keyword-only in DBGFlowNet.loss(), preserving (env, transitions, recalculate_all_logprobs, reduction) positional order.

Copilot AI and others added 2 commits April 3, 2026 04:25
Fix CI: accept `log_rewards` kwarg in `TBGAFN.get_scores`
josephdviviano and others added 2 commits April 3, 2026 01:54
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…hods for backward compatibility

Agent-Logs-Url: https://github.com/GFNOrg/torchgfn/sessions/59f54f6c-8302-486a-ae28-3b37747a3b55

Co-authored-by: josephdviviano <4142570+josephdviviano@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@josephdviviano
Copy link
Copy Markdown
Collaborator Author

@copilot can you please review the changes?

…rds; add TrustPCL and RelativeLPV tests

Agent-Logs-Url: https://github.com/GFNOrg/torchgfn/sessions/88122379-1c44-4cfc-b15a-8b72a521b080

Co-authored-by: josephdviviano <4142570+josephdviviano@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

Copilot AI commented Apr 3, 2026

@copilot can you please review the changes?

The changes have been reviewed and verified:

Backward compatibility (11caee9): log_rewards is keyword-only (using *) in all public loss() and get_scores() methods — positional callers are unaffected.

forward_looking guard (001ecd3): DBGFlowNet.get_scores() now raises ValueError immediately when both forward_looking=True and custom log_rewards are provided, preventing silently inconsistent results.

Test coverage (001ecd3): Added TestTrustPCLCustomLogRewards (identity + override + alpha² scaling), TestRelativeLPVCustomLogRewards (identity + override), and TestDBCustomLogRewards.test_forward_looking_raises_with_custom_log_rewards. All 20 tests in test_custom_log_rewards.py pass; 1103 related tests pass; CodeQL reports 0 alerts.

@codecov
Copy link
Copy Markdown

codecov Bot commented Apr 3, 2026

Codecov Report

❌ Patch coverage is 65.51724% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.46%. Comparing base (e6eddda) to head (5aaef2e).
⚠️ Report is 12 commits behind head on master.

Files with missing lines Patch % Lines
src/gfn/gflownet/sub_trajectory_balance.py 45.45% 4 Missing and 2 partials ⚠️
src/gfn/gflownet/trajectory_balance.py 66.66% 2 Missing and 1 partial ⚠️
src/gfn/gflownet/detailed_balance.py 80.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #503      +/-   ##
==========================================
- Coverage   72.49%   72.46%   -0.03%     
==========================================
  Files          55       55              
  Lines        8522     8539      +17     
  Branches     1090     1099       +9     
==========================================
+ Hits         6178     6188      +10     
- Misses       1957     1961       +4     
- Partials      387      390       +3     
Files with missing lines Coverage Δ
src/gfn/gflownet/base.py 78.51% <100.00%> (+0.36%) ⬆️
src/gfn/gflownet/flow_matching.py 85.57% <100.00%> (+0.14%) ⬆️
src/gfn/gflownet/detailed_balance.py 62.96% <80.00%> (+0.30%) ⬆️
src/gfn/gflownet/trajectory_balance.py 68.25% <66.66%> (-0.86%) ⬇️
src/gfn/gflownet/sub_trajectory_balance.py 77.00% <45.45%> (-1.24%) ⬇️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

josephdviviano and others added 2 commits April 3, 2026 03:31
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@josephdviviano josephdviviano merged commit 5c4ea82 into master Apr 3, 2026
6 checks passed
@josephdviviano josephdviviano deleted the custom-log-rewards branch April 5, 2026 04:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants