Fix/trace2tree: fix cross-rank GPU attribution and merged-trace hang#577
Conversation
Thanks for the detailed write-up and the performance investigation.

On fixes 1 and 2 (cross-rank attribution and the hang): as noted in the TraceFusion docs, merged traces produced by TraceFusion are intended only for visual analysis in Perfetto, not for automated analysis. The tree perf / trace2tree pipeline operates on single-rank traces; the intended workflow for multi-rank analysis is to generate perf reports per rank individually, then analyze or compare the resulting report sheets. Rather than making the trace2tree internals handle merged-trace correlation collisions, I think the better approach would be to detect merged/multi-rank input early and surface a clear error pointing users to the per-rank workflow. That keeps the single-rank code path simple.

On fix 3 (the O(subtree) → O(1) lookup in `tree_perf.py`): this is a clean win, since the propagated `gpu_events` field already carries the needed kernel UIDs.

Suggestion: could we split this into two PRs? Land the fix-3 change on its own, and handle the merged-trace behavior separately.
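The early-detection approach suggested above could look roughly like the sketch below. This is a hypothetical helper, not part of trace2tree; the function name `check_single_rank` and the event categories checked are illustrative, assuming Chrome-trace-style events that carry a `pid` field.

```python
def check_single_rank(trace_events):
    """Hypothetical guard: fail fast if a trace looks merged.

    In a Chrome-trace-style event list, GPU kernel events carry the
    `pid` of the process that launched them. A TraceFusion-merged
    trace contains GPU events from several processes, so seeing more
    than one distinct pid among GPU events is a strong signal that
    the input is a merged multi-rank trace.
    """
    gpu_pids = {e["pid"] for e in trace_events
                if e.get("cat") in ("kernel", "gpu_memcpy", "gpu_memset")
                and "pid" in e}
    if len(gpu_pids) > 1:
        raise ValueError(
            f"Trace contains GPU events from {len(gpu_pids)} processes; "
            "this looks like a merged multi-rank trace. Generate per-rank "
            "perf reports and compare those instead.")
    return True
```

A guard like this would run once at load time, keeping the single-rank code path free of any merged-trace special-casing.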
Hey @ajassani, thank you for your review! Yes, your comments make sense; I can turn this into two PRs so that fix 3 is independent of fixes 1 and 2. Thank you!
# $O(subtree \times N_{launchers})$ traversal in `tree_perf.py`
`get_kernel_launchers` computed subtree GPU time by calling `_compute_subtree_kernel_time_us(event)`, which called `loop_and_aggregate_kernels` (a full recursive subtree traversal) *for every launcher*. Since `add_gpu_ops_to_tree` already propagates all GPU kernel UIDs up to every ancestor via `event["gpu_events"]`, this is now an $O(1)$ field lookup.
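The before/after can be sketched as follows. This is a minimal illustration, not the actual trace2tree code: the event layout (dicts keyed by UID with `children`, `kernels`, and a propagated `gpu_events` list) and both function names are assumptions.

```python
def subtree_kernel_time_us_slow(event, events_by_uid):
    """Old path: a full recursive subtree walk for every launcher,
    O(subtree) per call and O(subtree x N_launchers) overall."""
    total = sum(events_by_uid[k]["dur"] for k in event.get("kernels", []))
    for child_uid in event.get("children", []):
        total += subtree_kernel_time_us_slow(events_by_uid[child_uid],
                                             events_by_uid)
    return total

def subtree_kernel_time_us_fast(event, events_by_uid):
    """New path: add_gpu_ops_to_tree has already propagated every
    kernel UID in the subtree into event["gpu_events"], so reading
    that field replaces the tree walk entirely."""
    return sum(events_by_uid[k]["dur"] for k in event.get("gpu_events", []))
```

Both return the same total; the fast path just reads the precomputed field instead of re-walking the subtree per launcher.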
Split PR at request of @ajassani
#577 (comment)
<!--
Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights
reserved.
See LICENSE for license information.
-->
# Pull Request Template
> **Note to AMDers:**
> This is a public repository. Please do **not** upload any confidential or customer data. Make sure all such data has been anonymized or removed before making this PR. If you need to attach any private files or links, please insert an internal OneDrive link or a Jira ticket link instead.
## The problem
When PyTorch profiles a single rank, each GPU kernel launch is assigned a correlation ID unique within that session. When traces from K ranks are merged into one file, each rank's correlation IDs restart from the same numeric range, causing collisions.
`add_gpu_ops_to_tree` uses these correlation IDs to link CPU runtime events to their GPU kernels, and was vulnerable to these collisions in two ways. First, `_get_graph_gpu_events` looked up GPU kernels by correlation ID alone, so in a merged trace every graph launch event claimed kernels from all K ranks sharing that ID, producing incorrect attribution. Second, the per-kernel ancestor walk had no mechanism to deduplicate GPU events that appeared as children of multiple runtime parents, causing $O(K^2 \times N_{gpu})$ complexity during propagation.
Together, these were causing issues such as incorrect cross-rank GPU attribution and indefinite hangs when processing merged trace files from PyTorch profiling of multi-node workloads.
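The collision mechanism can be illustrated with a small sketch. This is not the trace2tree code; the field names (`pid`, `correlation`) follow the Chrome trace format, and keying by `(pid, correlation)` is shown only as one way to disambiguate ranks.

```python
from collections import defaultdict

# Two ranks whose profiler sessions both start correlation IDs at 1.
kernels = [
    {"pid": 0, "correlation": 1, "name": "gemm_rank0"},
    {"pid": 1, "correlation": 1, "name": "gemm_rank1"},
]

# Collision-prone index: correlation ID alone. In a merged trace a
# rank-0 launch with correlation 1 now "claims" both ranks' kernels.
by_corr = defaultdict(list)
for k in kernels:
    by_corr[k["correlation"]].append(k)

# Collision-free index: include the process/rank in the key, so each
# launch only ever sees kernels from its own rank.
by_rank_corr = defaultdict(list)
for k in kernels:
    by_rank_corr[(k["pid"], k["correlation"])].append(k)
```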
## This PR

Fixes two bugs in `TraceToTree.add_gpu_ops_to_tree` that caused indefinite hangs and incorrect GPU event attribution when processing large merged multi-rank traces, and removes a related $O(subtree)$ per-launcher traversal in `TreePerfAnalyzer`.

### Three changes

1. **`trace_to_tree.py`**: `_get_graph_gpu_events` looked up GPU kernels by correlation ID alone via `linking_id_to_gpu_events[corr]` (introduced in #522). In merged traces this bucket contains GPU kernels from all K ranks sharing that correlation, so every graph launch was incorrectly claiming foreign-rank kernels as its own, inflating `gpu_events`, `total_subtree_kernel_time`, and `kernel_details` with cross-rank data.

2. **`tree_perf.py`**: `get_kernel_launchers` computed subtree GPU time by calling `_compute_subtree_kernel_time_us(event)`, which called `loop_and_aggregate_kernels` (a full recursive subtree traversal) for every launcher. Since `add_gpu_ops_to_tree` already propagates all GPU kernel UIDs up to every ancestor via `event["gpu_events"]`, this is now an $O(1)$ field lookup.

3. **`trace_to_tree.py`**: In merged K-rank traces, correlation ID collisions caused GPU kernels from all K ranks to become linked as children of every runtime event sharing that correlation. The ancestor walk then ran for all of those cross-rank kernels, producing $O(K \times N_{gpu} \times depth)$ individual `list.append()` calls and an indefinite hang. Replaced the per-kernel ancestor walk with a single BFS topological sort seeded from `cpu_root_nodes`, followed by a reverse-order `list.extend()` propagation pass. A visited set ensures each event is processed exactly once regardless of how many parents claim it, collapsing traversal from $O(K^2 \times N_{gpu})$ to $O(N)$. GC is disabled during propagation to eliminate cyclic collector overhead.

## Testing
**Result:** a large multi-node merged PyTorch trace file that previously ran indefinitely now completes in 235.9 seconds.
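The BFS-plus-reverse-extend propagation described in change 3 can be sketched as below. This is a simplified stand-in for the real `add_gpu_ops_to_tree` logic: the event layout (a dict of UID to event with `children` and `gpu_events` lists) and the function name are assumptions.

```python
from collections import deque
import gc

def propagate_gpu_events(events, cpu_root_uids):
    """Propagate kernel UIDs from leaves up to every ancestor in one pass.

    A BFS from the roots yields a topological order (parents before
    children); walking that order in reverse extends each child's
    gpu_events into its parent after the child's own list is complete.
    The visited set ensures an event claimed by multiple parents is
    enqueued and ordered exactly once.
    """
    order, visited = [], set()
    queue = deque(cpu_root_uids)
    while queue:
        uid = queue.popleft()
        if uid in visited:
            continue
        visited.add(uid)
        order.append(uid)
        queue.extend(events[uid].get("children", []))

    # The per-event lists never form reference cycles, so the cyclic
    # collector only adds overhead during this append-heavy pass.
    gc.disable()
    try:
        for uid in reversed(order):  # children before parents
            ev = events[uid]
            for child_uid in ev.get("children", []):
                ev.setdefault("gpu_events", []).extend(
                    events[child_uid].get("gpu_events", []))
    finally:
        gc.enable()
    return events
```

Because each event is visited once and each `gpu_events` list is extended once per parent edge, the total work is linear in the number of events and edges rather than quadratic in the number of colliding ranks.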