@@ -28,8 +28,7 @@ def get_pass_name(pass_id):
2828
2929
3030def get_ranged_incorrect_models (tolerance_args : List [int ], log_path : str ) -> set :
31- if not os .path .exists (log_path ):
32- return set ()
31+ assert os .path .exists (log_path )
3332
3433 t_start = tolerance_args [0 ]
3534 models_start = set (get_incorrect_models (t_start , log_path ))
@@ -40,13 +39,10 @@ def get_ranged_incorrect_models(tolerance_args: List[int], log_path: str) -> set
4039 t_end = tolerance_args [1 ]
4140 models_end = set (get_incorrect_models (t_end , log_path ))
4241
43- print (f"[Filter] Tolerance Range: { t_start } -> { t_end } " )
4442 print (
45- f"[Filter] Fail( { t_start } ) : { len (models_start )} , Fail( { t_end } ): { len (models_end )} "
43+ f"[Init] number of incorrect models : { len (models_start )} (tolerance= { t_start } ) - { len (models_end )} (tolerance= { t_end } ) "
4644 )
47-
48- diff_set = models_start - models_end
49- return diff_set
45+ return models_start - models_end
5046
5147
5248class TaskController :
@@ -326,6 +322,14 @@ def generate_initial_tasks(args):
326322 return tasks_map , max_subgraph_size , running_states
327323
328324
325+ def extract_model_name_and_subgraph_idx (subgraph_path ):
326+ # Parse model name and subgraph index
327+ model_name_with_subgraph_idx = subgraph_path .rstrip ("/" ).split (os .sep )[- 1 ]
328+ model_name = "_" .join (model_name_with_subgraph_idx .split ("_" )[:- 1 ])
329+ subgraph_idx = int (model_name_with_subgraph_idx .split ("_" )[- 1 ])
330+ return model_name , subgraph_idx
331+
332+
329333def generate_refined_tasks (base_output_dir , current_pass_id ):
330334 """Generates tasks for Pass > 0 based on previous pass results."""
331335 prev_pass_dir = get_decompose_workspace_path (base_output_dir , current_pass_id - 1 )
@@ -340,10 +344,7 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
340344 prev_tasks_map = prev_config .tasks_map
341345
342346 for subgraph_path in sorted (prev_config .incorrect_models ):
343- # Parse model name and subgraph index
344- model_name_with_subgraph_idx = subgraph_path .rstrip ("/" ).split (os .sep )[- 1 ]
345- model_name = "_" .join (model_name_with_subgraph_idx .split ("_" )[:- 1 ])
346- subgraph_idx = int (model_name_with_subgraph_idx .split ("_" )[- 1 ])
347+ model_name , subgraph_idx = extract_model_name_and_subgraph_idx (subgraph_path )
347348
348349 assert model_name in prev_tasks_map
349350 pre_task_for_model = prev_tasks_map [model_name ]
@@ -382,11 +383,11 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
382383 base_output_dir , current_pass_id
383384 )
384385
385- print (f"[INFO ] initial max_subgraph_size: { max_subgraph_size } " )
386- print (f"[INFO ] number of incorrect models: { len (tasks_map )} " )
387- for model_name , task_info in tasks_map .items ():
386+ print (f"[Init ] initial max_subgraph_size: { max_subgraph_size } " )
387+ print (f"[Init ] number of incorrect models: { len (tasks_map )} " )
388+ for idx , ( model_name , task_info ) in enumerate ( tasks_map .items () ):
388389 original_path = task_info ["original_path" ]
389- print (f"- { original_path } " )
390+ print (f"- [ { idx } ] { original_path } " )
390391
391392 if not tasks_map :
392393 print ("[FINISHED] No models need processing." )
@@ -525,12 +526,24 @@ def main(args):
525526 )
526527 print (f"\n --- Phase 3: Analysis (torlance={ tolerance } ) ---" )
527528 next_round_models = sorted (get_incorrect_models (tolerance , pass_log_path ))
529+ original_model_paths = set (
530+ [
531+ model_name
532+ for subgraph_path in next_round_models
533+ for model_name , _ in [
534+ extract_model_name_and_subgraph_idx (subgraph_path )
535+ ]
536+ ]
537+ )
538+
528539 running_states [f"pass_{ current_pass_id + 1 } " ] = {
529- "num_incorrect_models" : len (next_round_models ),
540+ "num_incorrect_models" : len (original_model_paths ),
530541 "incorrect_models" : list (next_round_models ),
531542 }
532543
533- print (f"[Analysis] Found { len (next_round_models )} incorrect subgraphs.\n " )
544+ print (
545+ f"[Analysis] Found { len (next_round_models )} incorrect subgraphs ({ len (original_model_paths )} original models)."
546+ )
534547 for idx , model_path in enumerate (next_round_models ):
535548 print (f"- [{ idx } ] { model_path } " )
536549
0 commit comments