456 add support for traces generated by jax 08 suppress warning #572

Open
devalshahamd wants to merge 35 commits into main from
456-add-support-for-traces-generated-by-jax-08_suppress_warning

@devalshahamd
Contributor

This pull request introduces improvements to error and warning handling in several modules, focusing on making warning messages less noisy and more informative. The changes include limiting the number of warning logs for missing or unresolved items, adding summary warnings when limits are exceeded, and improving robustness when encountering unknown data types.

Enhanced warning handling and logging:

  • Added counters and limits to suppress excessive warnings about missing hlo_op entries in trace_to_tree.py. Now, only the first 10 missing entries trigger warnings, and a summary warning is logged if more are encountered.
  • Updated the _resolve_operand_references method in util.py to accept a max_warnings parameter, limit warning logs to the first 10 unresolved operand references, and log a summary if more are found.
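
The warning limit described above can be sketched roughly as follows. This is a minimal illustration, not the code from this PR: the `LimitedWarner` class, its method names, and the logger name are hypothetical; only the idea of a `max_warnings` cap followed by a single summary comes from the description.

```python
import logging

logger = logging.getLogger("warning_limiter_sketch")  # hypothetical logger name


class LimitedWarner:
    """Log at most `max_warnings` individual warnings, then one summary."""

    def __init__(self, what, max_warnings=10):
        self.what = what
        self.max_warnings = max_warnings
        self.count = 0

    def warn(self, detail):
        # Count every occurrence, but only log the first `max_warnings`.
        self.count += 1
        if self.count <= self.max_warnings:
            logger.warning("%s: %s", self.what, detail)

    def summarize(self):
        # Emit a single summary line if anything was suppressed.
        suppressed = max(self.count - self.max_warnings, 0)
        if suppressed:
            logger.warning(
                "%s: %d further warnings suppressed (%d total)",
                self.what, suppressed, self.count,
            )
        return suppressed
```

A caller would invoke `warn()` once per missing entry inside its loop and `summarize()` once after the loop, so the log volume stays bounded regardless of trace size.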

Robustness improvements:

  • In tree_perf.py, added support for the pred dtype and improved handling of unknown dtypes by logging a warning and skipping them instead of raising an error.
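
As a rough sketch of the skip-instead-of-raise behavior, assuming a simple dtype-to-byte-width lookup: the `DTYPE_BYTES` table, the `total_bytes` helper, and the one-byte width for `pred` are illustrative assumptions, and the actual mapping in tree_perf.py may differ.

```python
import logging

logger = logging.getLogger("dtype_sketch")  # hypothetical logger name

# Hypothetical byte-width table; the real mapping in tree_perf.py may differ.
DTYPE_BYTES = {
    "pred": 1,  # boolean predicate, assumed here to occupy one byte
    "bf16": 2, "f16": 2, "f32": 4, "f64": 8,
    "s8": 1, "s32": 4, "u8": 1, "u32": 4,
}


def total_bytes(tensors):
    """Sum bytes over (dtype, shape) pairs, skipping unknown dtypes."""
    total = 0
    for dtype, shape in tensors:
        width = DTYPE_BYTES.get(dtype)
        if width is None:
            # Unknown dtype: warn and skip rather than raising.
            logger.warning("Unknown dtype %r, skipping", dtype)
            continue
        numel = 1
        for dim in shape:
            numel *= dim
        total += width * numel
    return total
```

Skipping means the byte count for an op with an unrecognized dtype is an underestimate, which is the trade-off for not aborting the whole analysis.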

gphuang and others added 30 commits February 20, 2026 12:32
…sing (#456)

Replace hardcoded `tensorboard_plugin_profile` imports with `xprof` (falling
back to the old package for backward compatibility). This enables loading
traces generated by JAX 0.8.0+, which produce HLO instruction IDs > INT_MAX
that crash tensorboard-plugin-profile <= 2.19.0.
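
The preferred-package-with-fallback import can be sketched generically as below. The helper name `import_first_available` is hypothetical and the submodule layout of either package is deliberately not assumed; the sketch only shows the ordering and the fallback warning.

```python
import importlib
import logging

logger = logging.getLogger("import_fallback_sketch")  # hypothetical logger name


def import_first_available(module_names):
    """Return (name, module) for the first importable module, in order."""
    for index, name in enumerate(module_names):
        try:
            module = importlib.import_module(name)
        except ImportError:
            continue
        if index > 0:
            # A later entry was used, so the preferred package is missing.
            logger.warning("Falling back to %s", name)
        return name, module
    raise ImportError(f"None of {list(module_names)} could be imported")


# Intended use (not executed here):
# name, profiler = import_first_available(["xprof", "tensorboard_plugin_profile"])
```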

Also fix GPU PID detection: xprof remaps device PIDs (e.g. 1 -> 1001),
breaking the previous `pid <= 100` heuristic. Now uses process metadata
(`/device:GPU` in process name) which is robust to any PID scheme.
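
A minimal sketch of metadata-based GPU detection, assuming the trace follows the Chrome trace event layout in which `process_name` metadata events have `ph == "M"` and carry the name under `args["name"]`. The function name `find_gpu_pids` is hypothetical.

```python
def find_gpu_pids(events):
    """Collect PIDs whose process_name metadata mentions '/device:GPU'."""
    gpu_pids = set()
    for event in events:
        # Only process_name metadata events describe what a PID represents.
        if event.get("ph") != "M" or event.get("name") != "process_name":
            continue
        process_name = (event.get("args") or {}).get("name") or ""
        if "/device:GPU" in process_name:
            gpu_pids.add(event.get("pid"))
    return gpu_pids
```

Because the check reads the process name rather than the numeric PID, it gives the same answer whether a device appears as PID 1 or as a remapped PID such as 1001.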

Fix pre-existing bug where JaxTreePerfAnalyzer.from_file did not extract
and pass metadata_events to build_tree, causing a TypeError on main.

Co-authored-by: Cursor <cursoragent@cursor.com>
Three issues discovered when using xprof as the trace processing backend:

1. get_dict() gated custom_call_target and backend_config extraction on
   metadata= being present in the HLO text. xprof's graph_viewer omits
   metadata= (even with show_metadata=True), so custom_call_target was
   silently dropped. Fix: extract these fields independently.

2. xprof's graph_viewer emits operands as bare references (%bitcast.39.0)
   without inline type annotations (bf16[...] %bitcast.39.0). Add a
   post-processing pass (_resolve_operand_references) that looks up each
   reference in the hlo_ops dict and substitutes its output type.

3. xprof's xspace_to_tool_names() no longer extracts .hlo_proto.pb files
   as a side effect. process_protobuf_file now returns {} gracefully
   instead of crashing with IndexError when files are missing.
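
The operand post-processing pass from item 2 can be sketched as follows. The layout of `hlo_ops` (a dict keyed by `%name` with `"output"` and `"operands"` fields) is an assumption made for illustration, and this standalone function is not the `_resolve_operand_references` method itself.

```python
def resolve_operand_references(hlo_ops):
    """Prefix bare '%name' operands with the referenced op's output type.

    Returns the list of references that could not be resolved.
    """
    unresolved = []
    for op in hlo_ops.values():
        resolved = []
        for operand in op.get("operands", []):
            if operand.startswith("%"):  # bare reference, no inline type
                target = hlo_ops.get(operand)
                if target is not None and target.get("output"):
                    operand = f"{target['output']} {operand}"
                else:
                    unresolved.append(operand)
            resolved.append(operand)
        op["operands"] = resolved
    return unresolved
```

Operands that already carry an inline type are left untouched, so the pass is safe to run on HLO text from either backend.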

Also fix test_compare_perf_report.py to only compare columns present in
both reference and generated DataFrames (avoids KeyError from new columns
added by recent PRs).

Relax test_tree_event_cats assertion to tolerate minor categorization
differences between xprof and tensorboard-plugin-profile.

Co-authored-by: Cursor <cursoragent@cursor.com>
- Add xprof==2.20.7 to install_requires for JAX trace processing
- Raise clear RuntimeError when xprof returns None (e.g. permission issues)

Co-authored-by: Cursor <cursoragent@cursor.com>
- jax_conv_minimal -> jax_conv_minimal_legacy (JAX ~0.6, full perf model)
- jax08_conv -> jax_conv_minimal_08 (JAX 0.8, comparable minimal conv)
- Update test_jax_conv_analysis.py path
- Document naming in tests/traces/README.md
- Add *.SSTABLE to .gitignore (xprof cache)
- Remove legacy SSTABLE from repo

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
- Update add_gpu_ops_to_tree and _categorize_gpu_kernel_ops docstrings to
  reflect process_name-based GPU detection (not pid <= 100)
- Clarify setup.py xprof comment: preferred library, required for JAX 0.8+
- Include converter library name in DataLoader RuntimeError when conversion fails
- Log warning when falling back to tensorboard-plugin-profile
- Log warning when HLO operand reference cannot be resolved
- Add explicit missing-columns assert in test_compare_perf_report

Co-authored-by: Cursor <cursoragent@cursor.com>
CI was using unpinned black (latest from pip), which can produce different
formatting than local. Pin to match local version so lint passes.

Co-authored-by: Cursor <cursoragent@cursor.com>
Reference reports may include these columns (from main) while generated
reports on this branch may not. Add to cols_ignore to avoid assertion
failure on UnaryElementwise and other sheets.

Co-authored-by: Cursor <cursoragent@cursor.com>
Resolve conflicts:
- tree_perf.py: keep defensive metadata_events None check from 456
- test_compare_perf_report.py: keep 456's missing_cols assert and cols_ignore

Co-authored-by: Cursor <cursoragent@cursor.com>
protobuf 4.25.8 is incompatible with grpcio-status (requires >=6.31.1).
Older protobuf causes RET_CHECK failure on JAX 0.8 traces with HLO ids > INT_MAX.

- Add protobuf>=6.31.1,<7.0.0 to setup.py install_requires
- Upgrade protobuf in regression-tests JAX step before running tests
- Remove tensorboard-plugin-profile from JAX step (use xprof from main install)

Co-authored-by: Cursor <cursoragent@cursor.com>
- gpu_event_analyser: use pid in gpu_pids instead of pid < 100 (xprof remaps PIDs)
- gpu_event_analyser: add _is_gpu_event helper, fix get_breakdown_df_multigpu
- trace_to_tree: add _is_gpu_event helper, guard name/args against None
- test_jax_conv_analysis: clarify test_tree_event_cats assertion
- test_jax_perf_report: add fixture to cleanup tmpdirs after tests

Co-authored-by: Cursor <cursoragent@cursor.com>
Resolve conflicts:
- regression-tests.yml: adopt main's reorganized workflow (#506), add protobuf>=6.31.1 for JAX tests, use xprof from setup.py
- test_compare_perf_report.py: accept deletion (replaced by test_perf_report_regression + test_compare_perf_reports)

Made-with: Cursor
Fixes #533

The performance model is currently bare-bones: it assumes that O(n) operations
are strictly necessary for these reductions.

It could be tuned further, but that would depend significantly on the
implementation. The theoretical minimum is something like n/2 + log(n),
which is unlikely to be significantly more accurate.

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
xprof 2.20.7 does not generate .hlo_proto.pb sidecar files (the
ConvertMultiXSpaceToHloProto side effect was removed in 2.20.2),
causing missing hlo_op metadata on JAX 0.8 traces.

xprof 2.20.1 is the last version that retains:
- .hlo_proto.pb sidecar generation via xspace_to_tool_names()
- hlo_op and correlation_id on GPU kernel events
- JAX 0.8 trace loading (with benign INT_MAX warnings)

Tested on both jax_conv_minimal_08 (JAX 0.8) and
jax_conv_minimal_legacy (JAX 0.6) traces.
These are generated artifacts produced by xprof as a side effect of
xspace_to_tool_names(). Committing them masks sidecar generation
failures — xprof 2.20.7 silently produces no sidecars, but tests
passed because the pre-committed files were found on disk.

With xprof 2.20.1, sidecars are regenerated at runtime for both
JAX 0.6 and 0.8 traces, so checking them in is unnecessary.
Add string-typed groupby columns as tie-breakers to short_kernels_summary
sort so rows with the same duration sum have reproducible ordering.
Regenerate all reference xlsx files. Fix tmpdir leak in jax perf report
test, clarify protobuf pin comment, and update README sidecar description.
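
The tie-breaker idea can be illustrated in plain Python, independent of how short_kernels_summary is actually built. The column names `dur_sum`, `name`, and `category` are hypothetical placeholders.

```python
def sort_summary(rows, tie_breakers=("name", "category")):
    """Sort by descending duration sum, then by string columns ascending."""

    def key(row):
        # Negate the duration so larger sums sort first; the string columns
        # then break ties so equal sums always come out in the same order.
        return (-row["dur_sum"],) + tuple(str(row[c]) for c in tie_breakers)

    return sorted(rows, key=key)
```

Without the extra key components, rows with equal duration sums keep whatever order they arrived in, which can differ between runs and makes reference-file comparisons flaky.
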
@devalshahamd requested a review from gphuang March 28, 2026 00:43
Base automatically changed from 456-add-support-for-traces-generated-by-jax-08 to main March 29, 2026 05:24
Contributor

@gphuang left a comment

Note: the branch carries 34 commits from PR #501 that are already merged, causing the conflicts.

The changes are sensible! Approved.


4 participants