From 046015b03b6bc02ac2648a324cafdb14cafe0a40 Mon Sep 17 00:00:00 2001 From: Ahmed Shuaibi Date: Mon, 24 Nov 2025 14:40:13 -0800 Subject: [PATCH] enable detailed memory breakdown for fixed number of iterations Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2167 Remove the environment variable FBGEMM_TBE_MEM_BREAKDOWN and instead log the detailed memory breakdown only for the first 10 memory-usage reports of each instance (the report counter is initialized per instance in __init__). This change simplifies configuration and allows for rule-based control of the feature. Reviewed By: spcyppt Differential Revision: D87287099 --- ...it_table_batched_embeddings_ops_training.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 597446a36b..ade8262fe5 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -1430,6 +1430,10 @@ def __init__( # noqa C901 self.step = 0 self.last_reported_step = 0 self.last_reported_uvm_stats: list[float] = [] + # Track number of times detailed memory breakdown has been reported + self.detailed_mem_breakdown_report_count = 0 + # Set max number of reports for detailed memory breakdown + self.max_detailed_mem_breakdown_reports = 10 # Check whether to use TBE v2 is_experimental = False @@ -1885,14 +1889,14 @@ def _report_tbe_mem_usage(self) -> None: tbe_id=self.uuid, ) - # Check if detailed memory breakdown is enabled via environment variable - # Set FBGEMM_TBE_MEM_BREAKDOWN=1 to enable expensive detailed breakdown - enable_detailed_breakdown = ( - int(os.environ.get("FBGEMM_TBE_MEM_BREAKDOWN", "0")) == 1 - ) - - if not enable_detailed_breakdown: + # Only report detailed breakdown for the first max_detailed_mem_breakdown_reports reportable + # steps since static sparse memory (weights, optimizer states, cache) is constant + if (
self.detailed_mem_breakdown_report_count + >= self.max_detailed_mem_breakdown_reports + ): return + self.detailed_mem_breakdown_report_count += 1 # Tensor groups for sparse memory categorization weight_tensors = ["weights_dev", "weights_host", "weights_uvm"]