Skip to content

WIP: adding a jax/flax formatter#9902

Open
koaning wants to merge 7 commits into
mainfrom
koaning/flax-nnx-formatter
Open

WIP: adding a jax/flax formatter#9902
koaning wants to merge 7 commits into
mainfrom
koaning/flax-nnx-formatter

Conversation

@koaning

@koaning koaning commented Jun 16, 2026

Copy link
Copy Markdown
Contributor

📝 Summary

We've got a PyTorch formatter, which looks like this:

CleanShot 2026-06-16 at 14 13 25@2x

This PR adds a formatter for Jax too, via Flax.

CleanShot 2026-06-16 at 14 38 20@2x

It's very much like PyTorch, but uses Jax primitives under the hood.

In an attempt to keep things DRY I figured it couldn't hurt to create a _nn_tree.py common helper file that both the Flax/PyTorch formatter could use. It certainly falls in the "not must-have but nice-to-have"-category, but one that might help me show this to some Jax people.

Review in cubic

Copilot AI review requested due to automatic review settings June 16, 2026 12:39
@vercel

vercel Bot commented Jun 16, 2026

Copy link
Copy Markdown

The latest updates on your projects. Learn more about Vercel for GitHub.

Project Deployment Actions Updated (UTC)
marimo-docs Ready Ready Preview, Comment Jun 25, 2026 8:21am

Request Review

@koaning koaning added the enhancement New feature or request label Jun 16, 2026
The Flax formatter previously surfaced non-trainable state (BatchNorm
running stats, dropout PRNG keys) as a "+N state" note. PyTorch's
formatter ignores buffers, so the two were inconsistent. Drop the state
note and count only nnx.Param, matching the PyTorch behavior. Updates
tests and the smoke-test notebook accordingly.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a rich HTML tree formatter for Flax NNX (flax.nnx.Module), analogous to the existing PyTorch formatter, and factors shared NN-tree presentation (CSS/legend/helpers) into a common module.

Changes:

  • Introduce FlaxFormatter + flax_formatters.format() for rendering NNX modules as a collapsible tree.
  • Extract shared NN-tree presentation utilities into marimo/_output/formatters/_nn_tree.py and refactor PyTorch formatter to use it.
  • Add optional test coverage + smoke test notebook/script for the new formatter and declare flax as an optional test dependency.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
tests/_output/formatters/test_flax_formatters.py Adds unit tests for Flax/NXX HTML tree rendering and formatter registration.
pyproject.toml Adds flax to the optional test dependency set (Python 3.12 marker).
marimo/_smoke_tests/formatters/flax_formatters.py Adds a smoke-test marimo app demonstrating Flax formatter output on several model shapes.
marimo/_output/formatters/pytorch_formatters.py Refactors PyTorch formatter to reuse shared NN-tree presentation utilities.
marimo/_output/formatters/formatters.py Registers the Flax formatter factory for lazy activation on import flax.
marimo/_output/formatters/flax_formatters.py Implements Flax NNX module tree extraction + HTML rendering and formatter registration.
marimo/_output/formatters/_nn_tree.py New shared CSS/legend/helpers for NN-tree formatters.
marimo/_dependencies/dependencies.py Adds DependencyManager.flax.
Comments suppressed due to low confidence (1)

marimo/_output/formatters/flax_formatters.py:330

  • FlaxFormatter.register() imports nnx unconditionally. Because registration runs automatically on import flax via the import hook, a missing/unsupported flax.nnx (or partial install) would raise and can break importing flax entirely. Guard the nnx import and avoid the self-import; you can reference the local format() directly.

Comment thread marimo/_smoke_tests/formatters/flax_formatters.py Outdated
Comment thread marimo/_output/formatters/flax_formatters.py Outdated

@cubic-dev-ai cubic-dev-ai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 issue found across 3 files (changes from recent commits).

Reply with feedback, questions, or to request a fix.

Re-trigger cubic

Comment thread marimo/_output/formatters/flax_formatters.py Outdated

@cubic-dev-ai cubic-dev-ai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 issues found and verified against the latest diff

Reply with feedback, questions, or to request a fix.

Re-trigger cubic

Comment thread pyproject.toml
Comment thread marimo/_smoke_tests/formatters/flax_formatters.py Outdated
@mscolnick mscolnick requested a review from manzt June 16, 2026 14:02
@koaning koaning changed the title Adding a jax/flax formatter WIP: adding a jax/flax formatter Jun 16, 2026
@koaning koaning removed the request for review from manzt June 16, 2026 14:36
The tree-walking and node HTML construction was duplicated nearly
line-for-line between the two formatters; only CSS/helpers had been
shared. Introduce framework-agnostic TreeNode/LeafBody types and
render_node/render_model in _nn_tree.py, so each formatter now only
extracts data from its own modules. Also address PR review feedback:
dim params-less Flax leaves (data-frozen) so the shared legend is
accurate, raise the jax floor to >=0.4.27 (required by flax 0.10), and
fix a stale smoke-test comment about the attention category.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

@cubic-dev-ai cubic-dev-ai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 issue found across 6 files (changes from recent commits).

Tip: Review your code locally with the cubic CLI to iterate faster.

Re-trigger cubic

Comment thread marimo/_output/formatters/_nn_tree.py Outdated
_fmt_integer only handled K/M, so large (LLM-scale) models showed
unwieldy raw counts. Add a billions branch and extend the test.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants