@@ -326,6 +326,14 @@ def generate_initial_tasks(args):
326326 return tasks_map , max_subgraph_size , running_states
327327
328328
329+ def extract_model_name_and_subgraph_idx (subgraph_path ):
330+ # Parse model name and subgraph index
331+ model_name_with_subgraph_idx = subgraph_path .rstrip ("/" ).split (os .sep )[- 1 ]
332+ model_name = "_" .join (model_name_with_subgraph_idx .split ("_" )[:- 1 ])
333+ subgraph_idx = int (model_name_with_subgraph_idx .split ("_" )[- 1 ])
334+ return model_name , subgraph_idx
335+
336+
329337def generate_refined_tasks (base_output_dir , current_pass_id ):
330338 """Generates tasks for Pass > 0 based on previous pass results."""
331339 prev_pass_dir = get_decompose_workspace_path (base_output_dir , current_pass_id - 1 )
@@ -340,10 +348,7 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
340348 prev_tasks_map = prev_config .tasks_map
341349
342350 for subgraph_path in sorted (prev_config .incorrect_models ):
343- # Parse model name and subgraph index
344- model_name_with_subgraph_idx = subgraph_path .rstrip ("/" ).split (os .sep )[- 1 ]
345- model_name = "_" .join (model_name_with_subgraph_idx .split ("_" )[:- 1 ])
346- subgraph_idx = int (model_name_with_subgraph_idx .split ("_" )[- 1 ])
351+ model_name , subgraph_idx = extract_model_name_and_subgraph_idx (subgraph_path )
347352
348353 assert model_name in prev_tasks_map
349354 pre_task_for_model = prev_tasks_map [model_name ]
@@ -525,12 +530,18 @@ def main(args):
525530 )
526531 print (f"\n --- Phase 3: Analysis (torlance={ tolerance } ) ---" )
527532 next_round_models = sorted (get_incorrect_models (tolerance , pass_log_path ))
533+ original_model_paths = [
534+ model_name
535+ for subgraph_path in next_round_models
536+ for model_name , _ in [extract_model_name_and_subgraph_idx (subgraph_path )]
537+ ]
538+
528539 running_states [f"pass_{ current_pass_id + 1 } " ] = {
529- "num_incorrect_models" : len (next_round_models ),
540+ "num_incorrect_models" : len (set ( original_model_paths ) ),
530541 "incorrect_models" : list (next_round_models ),
531542 }
532543
533- print (f"[Analysis] Found { len (next_round_models )} incorrect subgraphs.\n " )
544+ print (f"[Analysis] Found { len (next_round_models )} incorrect subgraphs." )
534545 for idx , model_path in enumerate (next_round_models ):
535546 print (f"- [{ idx } ] { model_path } " )
536547
0 commit comments