@@ -27,6 +27,28 @@ def get_pass_name(pass_id):
2727 return f"pass_{ pass_id } "
2828
2929
def get_ranged_incorrect_models(tolerance_args: List[int], log_path: str) -> set:
    """Return models that fail at the start tolerance but pass at the end tolerance.

    Args:
        tolerance_args: One or two tolerance levels. With one level, returns all
            models failing at that level. With two, returns the set difference
            Fail(start) - Fail(end), i.e. models fixed by relaxing the tolerance.
        log_path: Path to the batch test result log to parse.

    Returns:
        A set of model paths; empty if the log is missing or no tolerances given.
    """
    if not os.path.exists(log_path):
        return set()
    # Guard against an empty tolerance list (argparse nargs="+" normally
    # prevents this, but direct callers may pass []).
    if not tolerance_args:
        return set()

    t_start = tolerance_args[0]
    models_start = set(get_incorrect_models(t_start, log_path))

    if len(tolerance_args) == 1:
        return models_start

    t_end = tolerance_args[1]
    models_end = set(get_incorrect_models(t_end, log_path))

    print(f"[Filter] Tolerance Range: {t_start} -> {t_end}")
    print(
        f"[Filter] Fail({t_start}): {len(models_start)}, Fail({t_end}): {len(models_end)}"
    )

    # Models failing at the looser tolerance but not the stricter one.
    return models_start - models_end
50+
51+
3052class TaskController :
3153 def __init__ (self , args ):
3254 self .root_output_dir = os .path .abspath (args .output_dir )
@@ -198,10 +220,10 @@ def run_decomposer_for_multi_models(
198220 )
199221 for model_name , task_info in tasks_map .items ():
200222 original_path = task_info ["original_path" ]
201- split_positions = calculate_split_positions_for_subgraph (
202- task_info ["subgraph_size" ], max_subgraph_size
203- )
204- task_info [ " split_positions" ] = split_positions
223+
224+ split_positions = task_info ["split_positions" ]
225+ if isinstance ( split_positions , set ):
226+ split_positions = sorted ( list ( split_positions ))
205227
206228 rectified_model_path = get_rectfied_model_path (original_path )
207229 assert os .path .exists (
@@ -262,35 +284,39 @@ def reconstruct_subgraph_size(split_positions: List[int]) -> List[list]:
262284 return subgraph_size
263285
264286
def calculate_split_positions_for_subgraph(subgraph_range, max_subgraph_size):
    """Compute ascending split positions covering a subgraph range.

    Args:
        subgraph_range: ``[start_pos, end_pos]`` positions in the original
            model; ``end_pos`` may be ``float("inf")`` to mean "to the end"
            (clamped to ``kMaxGraphSize``).
        max_subgraph_size: Stride between consecutive split positions.

    Returns:
        A sorted list of unique split positions starting at ``start_pos``.
    """
    assert isinstance(subgraph_range, (list, tuple)) and len(subgraph_range) == 2

    # subgraph_range holds the start and end position in the original model.
    start_pos, end_pos = subgraph_range
    end_pos = kMaxGraphSize if end_pos == float("inf") else end_pos

    # range() with a positive step already yields unique, ascending values,
    # so no set()/sorted() round-trip is needed.
    return list(range(start_pos, end_pos + 1, max_subgraph_size))
275297
276298
277299def generate_initial_tasks (args ):
278300 """Generates tasks for Pass 0 based on the initial log file."""
279301 print (f"[Init] Pass 0: Reading from log file: { args .log_file } " )
280- initial_failures = get_incorrect_models (args .tolerance , args .log_file )
281- t1_incorrect_models = get_incorrect_models (1 , args .log_file )
282- initial_failures = initial_failures - t1_incorrect_models
302+ initial_failures = get_ranged_incorrect_models (args .tolerance , args .log_file )
283303
284304 tasks_map = {}
305+ max_subgraph_size = args .max_subgraph_size
306+
285307 for model_path in initial_failures :
286308 model_name = get_model_name_with_subgraph_tag (model_path )
309+
310+ initial_range = [0 , kMaxGraphSize ]
311+ initial_splits = calculate_split_positions_for_subgraph (
312+ initial_range , max_subgraph_size
313+ )
314+
287315 tasks_map [model_name ] = {
288316 "original_path" : model_path ,
289- "subgraph_size" : [0 , kMaxGraphSize ],
290- "split_positions" : set (),
317+ "split_positions" : list (sorted (initial_splits )),
291318 }
292319
293- max_subgraph_size = args .max_subgraph_size
294320 running_states = {
295321 "pass_0" : {
296322 "num_incorrect_models" : len (initial_failures ),
@@ -322,20 +348,25 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
322348 assert model_name in prev_tasks_map
323349 pre_task_for_model = prev_tasks_map [model_name ]
324350
325- # Reconstruct previous subgraph size to locate the failing segment
326351 prev_split_positions = pre_task_for_model .get ("split_positions" , [])
327- subgraph_size = reconstruct_subgraph_size (prev_split_positions )
352+ subgraph_ranges = reconstruct_subgraph_size (prev_split_positions )
353+
328354 assert subgraph_idx < len (
329- subgraph_size
355+ subgraph_ranges
330356 ), f"subgraph_idx { subgraph_idx } is out of bounds for { model_name } (previous split_positions: { prev_split_positions } )"
331357
358+ new_splits = calculate_split_positions_for_subgraph (
359+ subgraph_ranges [subgraph_idx ], max_subgraph_size
360+ )
361+
332362 if model_name not in tasks_map :
333363 tasks_map [model_name ] = {
334364 "original_path" : pre_task_for_model ["original_path" ],
335- "subgraph_size" : subgraph_size [subgraph_idx ],
336- "split_positions" : set (),
365+ "split_positions" : list (sorted (new_splits )),
337366 }
338-
367+ else :
368+ new_splits = set (tasks_map [model_name ]["split_positions" ]) + set (new_splits )
369+ tasks_map [model_name ]["split_positions" ] = list (sorted (new_splits ))
339370 return tasks_map , max_subgraph_size , prev_config .running_states
340371
341372
@@ -399,11 +430,23 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
399430 need_decompose = True
400431 shutil .rmtree (decomposed_samples_dir )
401432 os .makedirs (decomposed_samples_dir , exist_ok = True )
433+ max_subgraph_size = max (1 , max_subgraph_size // 2 )
402434 for model_name , task_info in tasks_map .items ():
403- task_info ["subgraph_size" ][1 ] = (
404- task_info ["subgraph_size" ][0 ] + max_subgraph_size
435+ splits = task_info ["split_positions" ]
436+ if not splits or len (splits ) < 2 :
437+ continue
438+ if isinstance (splits , set ):
439+ splits = sorted (list (splits ))
440+ start_pos = splits [0 ]
441+ first_segment_end = splits [1 ]
442+ new_splits = list (
443+ range (start_pos , first_segment_end + 1 , max_subgraph_size )
405444 )
406- max_subgraph_size = max (1 , max_subgraph_size // 2 )
445+
446+ if new_splits [- 1 ] != first_segment_end :
447+ new_splits .append (first_segment_end )
448+
449+ task_info ["split_positions" ] = sorted (list (set (new_splits )))
407450 else :
408451 need_decompose = False
409452 print ()
@@ -458,6 +501,7 @@ def main(args):
458501 "failed_decomposition_models"
459502 ] = list (failed_decomposition )
460503 else :
504+ print ("\n --- Phase 1: Decomposition (skipped) ---" , flush = True )
461505 config = DecomposeConfig .load (pass_work_dir )
462506 max_subgraph_size = config .max_subgraph_size
463507 tasks_map = config .tasks_map
@@ -466,19 +510,26 @@ def main(args):
466510 # --- Step 3: Evaluation ---
467511 pass_log_path = os .path .join (pass_work_dir , "batch_test_result.log" )
468512 if task_controller .task_scheduler ["run_evaluation" ]:
469- print ("\n --- Phase 2: Evaluation ---" )
513+ print (f "\n --- Phase 2: Evaluation { task_controller . test_module_name } ---" )
470514 run_evaluation (args .framework , args .test_config , pass_work_dir , pass_log_path )
471515
472516 # --- Step 4: Analysis ---
473517 next_round_models = set ()
474518 if task_controller .task_scheduler ["post_analysis" ]:
475- print ("\n --- Phase 3: Analysis ---" )
476- next_round_models = sorted (get_incorrect_models (args .tolerance , pass_log_path ))
477- print (f"[Analysis] Found { len (next_round_models )} incorrect subgraphs.\n " )
519+ tolerance = (
520+ args .tolerance [0 ] if isinstance (args .tolerance , list ) else args .tolerance
521+ )
522+ print (f"\n --- Phase 3: Analysis (torlance={ tolerance } ) ---" )
523+ next_round_models = sorted (get_incorrect_models (tolerance , pass_log_path ))
478524 running_states [f"pass_{ current_pass_id + 1 } " ] = {
479525 "num_incorrect_models" : len (next_round_models ),
480526 "incorrect_models" : list (next_round_models ),
481527 }
528+
529+ print (f"[Analysis] Found { len (next_round_models )} incorrect subgraphs.\n " )
530+ for idx , model_path in enumerate (next_round_models ):
531+ print (f"- [{ idx } ] { model_path } " )
532+
482533 print_summary_and_suggestion (next_round_models , max_subgraph_size )
483534
484535 # --- Step 5: Save States ---
@@ -500,7 +551,11 @@ def main(args):
500551 "--test-config" , type = str , required = True , help = "Base64 encoded test config"
501552 )
502553 parser .add_argument (
503- "--tolerance" , type = int , required = True , help = "Tolerance level range [-10, 5)"
554+ "--tolerance" ,
555+ type = int ,
556+ nargs = "+" ,
557+ required = True ,
558+ help = "Tolerance level range [-10, 5)" ,
504559 )
505560 parser .add_argument ("--max-subgraph-size" , type = int , default = 4096 )
506561 args = parser .parse_args ()
0 commit comments