Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1430,6 +1430,10 @@ def __init__( # noqa C901
self.step = 0
self.last_reported_step = 0
self.last_reported_uvm_stats: list[float] = []
# Track number of times detailed memory breakdown has been reported
self.detailed_mem_breakdown_report_count = 0
# Set max number of reports for detailed memory breakdown
self.max_detailed_mem_breakdown_reports = 10

# Check whether to use TBE v2
is_experimental = False
Expand Down Expand Up @@ -1885,14 +1889,14 @@ def _report_tbe_mem_usage(self) -> None:
tbe_id=self.uuid,
)

# Check if detailed memory breakdown is enabled via environment variable
# Set FBGEMM_TBE_MEM_BREAKDOWN=1 to enable expensive detailed breakdown
enable_detailed_breakdown = (
int(os.environ.get("FBGEMM_TBE_MEM_BREAKDOWN", "0")) == 1
)

if not enable_detailed_breakdown:
# Only report the detailed breakdown for the first
# `max_detailed_mem_breakdown_reports` reportable steps, since static sparse
# memory (weights, optimizer states, cache) is constant.
if (
self.detailed_mem_breakdown_report_count
>= self.max_detailed_mem_breakdown_reports
):
return
self.detailed_mem_breakdown_report_count += 1

# Tensor groups for sparse memory categorization
weight_tensors = ["weights_dev", "weights_host", "weights_uvm"]
Expand Down
Loading