WIP: adding a jax/flax formatter#9902
Conversation
|
The latest updates on your projects. Learn more about Vercel for GitHub.
|
The Flax formatter previously surfaced non-trainable state (BatchNorm running stats, dropout PRNG keys) as a "+N state" note. PyTorch's formatter ignores buffers, so the two were inconsistent. Drop the state note and count only nnx.Param, matching the PyTorch behavior. Updates tests and the smoke-test notebook accordingly. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
There was a problem hiding this comment.
Pull request overview
Adds a rich HTML tree formatter for Flax NNX (flax.nnx.Module), analogous to the existing PyTorch formatter, and factors shared NN-tree presentation (CSS/legend/helpers) into a common module.
Changes:
- Introduce
FlaxFormatter+flax_formatters.format()for rendering NNX modules as a collapsible tree. - Extract shared NN-tree presentation utilities into
marimo/_output/formatters/_nn_tree.pyand refactor PyTorch formatter to use it. - Add optional test coverage + smoke test notebook/script for the new formatter and declare
flaxas an optional test dependency.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/_output/formatters/test_flax_formatters.py | Adds unit tests for Flax/NXX HTML tree rendering and formatter registration. |
| pyproject.toml | Adds flax to the optional test dependency set (Python 3.12 marker). |
| marimo/_smoke_tests/formatters/flax_formatters.py | Adds a smoke-test marimo app demonstrating Flax formatter output on several model shapes. |
| marimo/_output/formatters/pytorch_formatters.py | Refactors PyTorch formatter to reuse shared NN-tree presentation utilities. |
| marimo/_output/formatters/formatters.py | Registers the Flax formatter factory for lazy activation on import flax. |
| marimo/_output/formatters/flax_formatters.py | Implements Flax NNX module tree extraction + HTML rendering and formatter registration. |
| marimo/_output/formatters/_nn_tree.py | New shared CSS/legend/helpers for NN-tree formatters. |
| marimo/_dependencies/dependencies.py | Adds DependencyManager.flax. |
Comments suppressed due to low confidence (1)
marimo/_output/formatters/flax_formatters.py:330
FlaxFormatter.register()importsnnxunconditionally. Because registration runs automatically onimport flaxvia the import hook, a missing/unsupportedflax.nnx(or partial install) would raise and can break importingflaxentirely. Guard thennximport and avoid the self-import; you can reference the localformat()directly.
There was a problem hiding this comment.
1 issue found across 3 files (changes from recent commits).
Reply with feedback, questions, or to request a fix.
Re-trigger cubic
There was a problem hiding this comment.
2 issues found and verified against the latest diff
Reply with feedback, questions, or to request a fix.
Re-trigger cubic
The tree-walking and node HTML construction was duplicated nearly line-for-line between the two formatters; only CSS/helpers had been shared. Introduce framework-agnostic TreeNode/LeafBody types and render_node/render_model in _nn_tree.py, so each formatter now only extracts data from its own modules. Also address PR review feedback: dim params-less Flax leaves (data-frozen) so the shared legend is accurate, raise the jax floor to >=0.4.27 (required by flax 0.10), and fix a stale smoke-test comment about the attention category. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
There was a problem hiding this comment.
1 issue found across 6 files (changes from recent commits).
Tip: Review your code locally with the cubic CLI to iterate faster.
Re-trigger cubic
_fmt_integer only handled K/M, so large (LLM-scale) models showed unwieldy raw counts. Add a billions branch and extend the test. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
📝 Summary
We've got a PyTorch formatter, which looks like this:
This PR adds a formatter for Jax too, via Flax.
It's very much like PyTorch, but uses Jax primitives under the hood.
In an attempt to keep things DRY I figured it couldn't hurt to create a
_nn_tree.pycommon helper file that both the Flax/PyTorch formatter could use. It certainly falls in the "not must-have but nice-to-have"-category, but one that might help me show this to some Jax people.