@@ -55,6 +55,12 @@ class OpStat:
     op_dtypes: dict[str, int] = field(default_factory=dict)
     count: int = 0
 
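+    # Merge another OpStat of the same op: accumulate the total count and the
+    # per-dtype counts (used when aggregating stats across subgraphs).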
+    def update(self, other):
+        if isinstance(other, OpStat) and self.op_name == other.op_name:
+            self.count += other.count
+            for name, count in other.op_dtypes.items():
+                self.op_dtypes[name] = self.op_dtypes.get(name, 0) + count
+
 
 def resolve_native_multi_head_attention(*args, **kwargs):
     query, key, value = args[0], args[1], args[2]
@@ -132,19 +138,23 @@ def resolve_with_real_tensor(op_func, device, meta_args, meta_kwargs):
         return None
 
 
-def collect_op_stats_manual(model, input_dict, device):
-    try:
-        # FX symbolic trace
-        traced = torch.fx.symbolic_trace(model)
-        # print(traced.graph)
-    except Exception:
-        print("Failed to FX symbolic_trace")
-        return False, None
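+# Loosen TorchDynamo capture limits so more of the model is captured into FX
+# graphs instead of falling back to eager execution.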
+torch._dynamo.config.capture_scalar_outputs = True
+torch._dynamo.config.capture_dynamic_output_shape_ops = True
+torch._dynamo.config.capture_sparse_compute = True
+torch._dynamo.config.raise_on_ctx_manager_usage = False
+torch._dynamo.config.allow_rnn = True
 
 
-    # Use meta tensors as input to avoid actually running the model
-    meta_input_dict = convert_real_to_meta(input_dict)
 
-    def get_output_dtype(out):
+class GraphMetaExecutor:
+    def __init__(self, device):
+        self.device = device
+        self.op_stats = {}
+        self.is_complete = True
+        self.num_ops = 0
+        self.num_ops_misses_dtypes = 0
+        self.subgraph_counter = 0
+
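+    # Best-effort dtype lookup on a node's output; returns None when it
+    # cannot be determined.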
+    def get_output_dtype(self, out):
         if isinstance(out, torch.Tensor):
             return out.dtype
         if (
@@ -156,102 +166,160 @@ def get_output_dtype(out):
         else:
             return None
 
-    is_complete = True
-    op_stats = {}
-    node_outputs = {}
-    for node in traced.graph.nodes:
+    def get_op_name_and_func(self, gm, node, node_outputs):
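+        # Map an FX node to a printable op name and, for call_* nodes, the
+        # callable that executes it.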
         op_name = None
-        dtype = None
-        if node.op == "placeholder":
-            node_outputs[node.name] = meta_input_dict[node.target]
-            op_name = node.op
-        elif node.op in ["call_function", "call_module", "call_method"]:
-            node_args = torch.fx.map_arg(
-                node.args,
-                lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n,
-            )
-            node_kwargs = torch.fx.map_arg(
-                node.kwargs,
-                lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n,
-            )
-
-            try:
-                if node.op == "call_module":
-                    # classname of module
-                    submod = traced.get_submodule(node.target)
-                    op_name = submod.__class__.__name__
-                    op_func = submod
-                elif node.op == "call_function":
-                    op_name = node.target.__name__
-                    op_func = node.target
-                elif node.op == "call_method":
-                    op_name = node.target
-                    self_obj = (
-                        node_outputs[node.args[0].name]
-                        if isinstance(node.args[0], torch.fx.Node)
-                        else node.args[0]
-                    )
-                    op_func = getattr(self_obj, node.target)
-                    node_args = node_args[1:]
-
-                # print(f"node.op={node.op}, op_name={op_name}, node.args={node.args}")
-                if op_name == "_native_multi_head_attention":
-                    out = resolve_native_multi_head_attention(*node_args, **node_kwargs)
-                elif op_name == "to":
-                    # print(f"node.op={node.op}, op_name={op_name}, node.args={node.args}")
-                    out = resolve_tensor_to(
-                        node_outputs[node.args[0].name], *node_args, **node_kwargs
-                    )
-                elif op_name == "item":
-                    out = resolve_tensor_item(node_outputs[node.args[0].name])
-                else:
-                    out = op_func(*node_args, **node_kwargs)
-                node_outputs[node.name] = out
-                dtype = get_output_dtype(out)
-            except Exception:
-                out = resolve_with_real_tensor(op_func, device, node_args, node_kwargs)
-                node_outputs[node.name] = out
-                if out is not None:
-                    dtype = get_output_dtype(out)
-                else:
-                    print(
-                        f"dtype inference failed: node.op={node.op}, op_name={op_name}"
-                    )
-                    is_complete = False
-        elif node.op == "get_attr":
-            op_name = node.op
-            out = resolve_get_attr(traced, node)
-            node_outputs[node.name] = out
-            dtype = get_output_dtype(out)
-        elif node.op == "output":
-            op_name = node.op
-            node_args = torch.fx.map_arg(
-                node.args,
-                lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n,
-            )
-            node_outputs[node.name] = node_args[0] if len(node_args) == 1 else node_args
-            dtype = get_output_dtype(node_args[0])
-        else:
-            assert False, f"node.op: {node.op}"
-
+        op_func = None
+        try:
+            if node.op == "call_module":
+                # classname of module
+                submod = gm.get_submodule(node.target)
+                op_name = submod.__class__.__name__
+                op_func = submod
+            elif node.op == "call_function":
+                op_name = node.target.__name__
+                op_func = node.target
+            elif node.op == "call_method":
+                op_name = node.target
+                self_obj = (
+                    node_outputs[node.args[0].name]
+                    if isinstance(node.args[0], torch.fx.Node)
+                    else node.args[0]
+                )
+                op_func = getattr(self_obj, node.target)
+            elif node.op in ["get_attr", "placeholder", "output"]:
+                op_name = node.op
+        except Exception:
+            pass
+        return op_name, op_func
+
+    def update_op_stats(self, op_stats, op_name, op_dtype):
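+        # Count one occurrence of op_name under the observed output dtype.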
         if op_name is not None:
-            dtype_str = str(dtype).replace("torch.", "")
+            dtype_str = str(op_dtype).replace("torch.", "")
             if op_stats.get(op_name, None) is None:
                 op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1)
             else:
                 op_stats[op_name].op_dtypes[dtype_str] = (
                     op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1
                 )
-                op_stats[op_name].count = op_stats[op_name].count + 1
-    return is_complete, op_stats
+                op_stats[op_name].count += 1
+
+    def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
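+        # Invoked by torch.compile as a custom backend: Dynamo hands over each
+        # captured FX subgraph plus its sample inputs. We interpret the graph
+        # on meta tensors to collect op/dtype stats, then return gm.forward so
+        # the original, uncompiled code still executes.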
+        # Use meta tensors as input to avoid actually running the model
+        meta_sample_inputs = convert_real_to_meta(sample_inputs)
+
+        op_stats = {}
+        num_ops_misses_dtypes = 0
+
+        input_idx = 0
+        node_outputs = {}
+        for node in gm.graph.nodes:
+            out = None
+            op_dtype = None
+            op_name, op_func = self.get_op_name_and_func(gm, node, node_outputs)
+            if node.op == "placeholder":
+                out = meta_sample_inputs[input_idx]
+                input_idx += 1
+            elif node.op in ["call_function", "call_module", "call_method"]:
+                # Default to empty args so the fallback below never sees
+                # unbound names if the argument mapping itself raises.
+                node_args, node_kwargs = (), {}
+                try:
+                    node_args = torch.fx.map_arg(
+                        node.args,
+                        lambda n: node_outputs[n.name]
+                        if isinstance(n, torch.fx.Node)
+                        else n,
+                    )
+                    node_kwargs = torch.fx.map_arg(
+                        node.kwargs,
+                        lambda n: node_outputs[n.name]
+                        if isinstance(n, torch.fx.Node)
+                        else n,
+                    )
+                    if node.op == "call_method":
+                        node_args = node_args[1:]
+
+                    if op_name == "_native_multi_head_attention":
+                        out = resolve_native_multi_head_attention(
+                            *node_args, **node_kwargs
+                        )
+                    elif op_name == "to":
+                        out = resolve_tensor_to(
+                            node_outputs[node.args[0].name], *node_args, **node_kwargs
+                        )
+                    elif op_name == "item":
+                        out = resolve_tensor_item(node_outputs[node.args[0].name])
+                    else:
+                        assert op_func is not None, f"op_func of {node} is None."
+                        out = op_func(*node_args, **node_kwargs)
+                except Exception:
+                    out = resolve_with_real_tensor(
+                        op_func, self.device, node_args, node_kwargs
+                    )
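+                # out is still None when even the real-tensor fallback failed,
+                # so the dtype of this op stays unknown.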
+                if out is None:
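+                    # Report only the first failure per subgraph to limit log noise.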
+                    if num_ops_misses_dtypes == 0:
+                        print(
+                            f"dtype inference failed: node.op={node.op}, op_name={op_name}"
+                        )
+                    num_ops_misses_dtypes += 1
+            elif node.op == "get_attr":
+                out = resolve_get_attr(gm, node)
+            elif node.op == "output":
+                pass
+            else:
+                assert False, f"node.op: {node.op}"
+
+            if out is not None:
+                node_outputs[node.name] = out
+                op_dtype = self.get_output_dtype(out)
+
+            if node.op not in ["placeholder", "output"]:
+                self.update_op_stats(op_stats, op_name, op_dtype)
+
+        if num_ops_misses_dtypes > 0:
+            self.is_complete = False
+            self.num_ops_misses_dtypes += num_ops_misses_dtypes
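+        # Fold this subgraph's stats into the totals kept across all subgraphs.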
+        num_ops = 0
+        for name, stat in op_stats.items():
+            num_ops += stat.count
+            if name in self.op_stats.keys():
+                self.op_stats[name].update(stat)
+            else:
+                self.op_stats[name] = stat
+        self.num_ops += num_ops
+        self.subgraph_counter += 1
+        return gm.forward
+
+    def summary(self):
+        print(
+            f"In total: {self.subgraph_counter} subgraphs, {self.num_ops} operators, "
+            f"{self.num_ops_misses_dtypes} operators failed dtype inference."
+        )
 
 
-def collect_op_stats_with_make_fx(model, input_dict, arg_types):
+def collect_op_stats_with_compile(model, sample_inputs, device):
+    assert isinstance(model, torch.nn.Module), f"{type(model)=}"
+    meta_executor = GraphMetaExecutor(device)
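+    # Each Dynamo graph break yields a separate subgraph, so meta_executor may
+    # be invoked several times for one model.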
+    compiled_model = torch.compile(model, backend=meta_executor)
+    compiled_model(*sample_inputs)
+    meta_executor.summary()
+    return "compile", meta_executor.is_complete, meta_executor.op_stats
+
+
+def collect_op_stats_manual(model, sample_inputs, device):
+    try:
+        # FX symbolic trace
+        traced = torch.fx.symbolic_trace(model)
+        # print(traced.graph)
+    except Exception:
+        print("Failed to FX symbolic_trace")
+        return False, None
+
+    meta_executor = GraphMetaExecutor(device)
+    meta_executor(traced, sample_inputs)
+    meta_executor.summary()
+    return meta_executor.is_complete, meta_executor.op_stats
+
+
+def collect_op_stats_with_make_fx(model, sample_inputs):
     # Use meta tensors as input to avoid actually running the model
-    meta_input_list = []
-    for arg_name in arg_types.keys():
-        x = input_dict[arg_name]
-        meta_input_list.append(convert_real_to_meta(x))
+    meta_input_list = convert_real_to_meta(sample_inputs)
 
     try:
         # Generate FX Graph, and automatically fill in meta information
@@ -325,14 +393,19 @@ def collect_model_stats(model_path, device, log_prompt):
     model = model_class()
     arg_types = get_argument_types(model_class, "forward")
     input_dict = get_input_dict(model_path, device)
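+    # compiled_model(*sample_inputs) passes arguments positionally, so order
+    # the inputs to match the forward() signature.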
+    ordered_input_list = [input_dict[arg_name] for arg_name in arg_types.keys()]
 
     num_ops = 0
     num_outputs = 0
     ops_count_dict = {}
     op_dtypes = {}
-    method, is_complete, op_stats = collect_op_stats(
-        model, input_dict, arg_types, device
+    method, is_complete, op_stats = collect_op_stats_with_compile(
+        model, ordered_input_list, device
     )
+
+    # method, is_complete, op_stats = collect_op_stats(
+    #     model, input_dict, arg_types, device
+    # )
     if op_stats is not None:
         for op_name, stat in sorted(op_stats.items()):
             if op_name == "placeholder":
@@ -474,5 +547,4 @@ def main(args):
         help="Log prompt for stats log filtering.",
     )
     args = parser.parse_args()
-    print(f"[CollectStats Arguments] {args}")
     main(args=args)