@@ -111,8 +111,9 @@ def _print(self):
111111
112112@dataclass
113113class DecomposeConfig :
114+ method : str
115+ tolerance : int | List [int ]
114116 max_subgraph_size : int = - 1
115- incorrect_models : List [str ] = field (default_factory = list )
116117 tasks_map : Dict [str , Union [int , str , list , dict ]] = field (default_factory = dict )
117118 running_states : Dict [str , Union [int , str , list , dict ]] = field (default_factory = dict )
118119
@@ -139,6 +140,11 @@ def load(self, work_dir):
139140 def get_config_path (self , work_dir ) -> str :
140141 return os .path .join (work_dir , "decompose_config.json" )
141142
143+ def get_incorrect_models (self , pass_id ):
144+ pass_key = get_pass_name (pass_id )
145+ assert pass_key in self .running_states
146+ return self .running_states [pass_key ]["incorrect_models" ]
147+
142148 def update_running_states (self , pass_id , ** kwargs ):
143149 pass_key = get_pass_name (pass_id )
144150 if self .running_states .get (pass_key , None ) is None :
@@ -229,8 +235,13 @@ def run_decomposer_for_multi_models(
229235 )
230236 for model_name , task_info in tasks_map .items ():
231237 original_path = task_info ["original_path" ]
232-
233238 split_positions = sorted (list (task_info ["split_positions" ]))
239+
240+ method = "fixed-start"
241+ if method == "fixed-start" :
242+ assert len (split_positions ) >= 3 , f"{ split_positions = } "
243+ split_positions = [0 , split_positions [1 ]]
244+
234245 rectified_model_path = get_rectfied_model_path (original_path )
235246 assert os .path .exists (
236247 rectified_model_path
@@ -279,18 +290,22 @@ def run_evaluation(
279290 ), f"[ERROR] test failed for { samples_dir } , please check the log."
280291
281292
282- def reconstruct_split_positions_for_subgraph (
283- split_positions , subgraph_idx , max_subgraph_size
293+ def reconstruct_split_positions_for_subgraphs (
294+ split_positions , subgraph_idxs , max_subgraph_size
284295):
285- assert (
286- subgraph_idx < len (split_positions ) - 1
287- ), f"subgraph_idx { subgraph_idx } is out of bounds of split_positions: { split_positions } ."
296+ subgraph_idxs = [subgraph_idxs ] if isinstance (subgraph_idxs , int ) else subgraph_idxs
288297
289- start_pos , end_pos = split_positions [subgraph_idx : subgraph_idx + 2 ]
290- new_split_positions = set (
291- range (start_pos , end_pos + max_subgraph_size - 1 , max_subgraph_size )
292- )
293- return sorted (list (new_split_positions ))
298+ new_split_positions = []
299+ for subgraph_idx in subgraph_idxs :
300+ assert (
301+ subgraph_idx < len (split_positions ) - 1
302+ ), f"subgraph_idx { subgraph_idx } is out of bounds of split_positions: { split_positions } ."
303+
304+ start_pos , end_pos = split_positions [subgraph_idx : subgraph_idx + 2 ]
305+ new_split_positions = new_split_positions + list (
306+ range (start_pos , end_pos + max_subgraph_size , max_subgraph_size )
307+ )
308+ return sorted (list (set (new_split_positions )))
294309
295310
296311def generate_initial_tasks (args ):
@@ -299,9 +314,9 @@ def generate_initial_tasks(args):
299314 initial_failures = get_ranged_incorrect_models (args .tolerance , args .log_file )
300315
301316 tasks_map = {}
302- max_subgraph_size = args .max_subgraph_size
317+ max_subgraph_size = min ( args .max_subgraph_size , kMaxGraphSize // 2 )
303318
304- initial_split_positions = reconstruct_split_positions_for_subgraph (
319+ initial_split_positions = reconstruct_split_positions_for_subgraphs (
305320 [0 , kMaxGraphSize ], 0 , max_subgraph_size
306321 )
307322 for model_path in initial_failures :
@@ -328,42 +343,61 @@ def extract_model_name_and_subgraph_idx(subgraph_path):
328343 return model_name , subgraph_idx
329344
330345
331- def generate_successor_tasks (base_output_dir , current_pass_id ):
346+ def collect_incorrect_subgraph_idxs (args , target_model_names , incorrect_models ):
347+ model_name2subgraph_idxs = {}
348+ for subgraph_path in sorted (incorrect_models ):
349+ model_name , subgraph_idx = extract_model_name_and_subgraph_idx (subgraph_path )
350+ print (f"{ subgraph_path = } " )
351+ print (f"{ model_name = } , { subgraph_idx = } " )
352+ assert model_name in target_model_names , f"{ model_name = } , { subgraph_idx = } "
353+
354+ if model_name not in model_name2subgraph_idxs :
355+ model_name2subgraph_idxs [model_name ] = []
356+ model_name2subgraph_idxs [model_name ].append (subgraph_idx )
357+
358+ if args .method == "fixed-start" :
359+ print (model_name2subgraph_idxs )
360+ for model_name in target_model_names :
361+ if model_name not in model_name2subgraph_idxs :
362+ model_name2subgraph_idxs [model_name ] = [1 ]
363+ else :
364+ assert len (
365+ model_name2subgraph_idxs [model_name ]
366+ ) == 1 and model_name2subgraph_idxs [model_name ] == [0 ]
367+ return model_name2subgraph_idxs
368+
369+
370+ def generate_successor_tasks (args , base_output_dir , current_pass_id ):
332371 """Generates tasks for Pass > 0 based on previous pass results."""
333372 prev_pass_dir = get_decompose_workspace_path (base_output_dir , current_pass_id - 1 )
334373 print (f"[Init] Resuming from Pass_{ current_pass_id - 1 } (Dir: { prev_pass_dir } )..." )
335374
336375 prev_config = DecomposeConfig .load (prev_pass_dir )
337376 max_subgraph_size = prev_config .max_subgraph_size // 2
338- if not prev_config .incorrect_models :
377+ incorrect_models = prev_config .get_incorrect_models (current_pass_id )
378+ if args .method != "fixed-start" and not incorrect_models :
339379 return {}, max_subgraph_size , prev_config .running_states
340380
341381 tasks_map = {}
342382 prev_tasks_map = prev_config .tasks_map
343383
344- for subgraph_path in sorted (prev_config .incorrect_models ):
345- model_name , subgraph_idx = extract_model_name_and_subgraph_idx (subgraph_path )
346- print (f"{ subgraph_path = } " )
384+ target_model_names = list (prev_tasks_map .keys ())
385+ model_name2subgraph_idxs = collect_incorrect_subgraph_idxs (
386+ args , target_model_names , incorrect_models
387+ )
347388
348- assert model_name in prev_tasks_map
389+ for model_name , subgraph_idxs in model_name2subgraph_idxs . items ():
349390 pre_task_for_model = prev_tasks_map [model_name ]
350391
351392 prev_split_positions = pre_task_for_model .get ("split_positions" , [])
352- split_positions = reconstruct_split_positions_for_subgraph (
353- prev_split_positions , subgraph_idx , max_subgraph_size
393+ split_positions = reconstruct_split_positions_for_subgraphs (
394+ prev_split_positions , subgraph_idxs , max_subgraph_size
354395 )
355- if model_name not in tasks_map :
356- tasks_map [model_name ] = {
357- "original_path" : pre_task_for_model ["original_path" ],
358- "split_positions" : list (sorted (split_positions )),
359- }
360- else :
361- merged_split_positions = (
362- tasks_map [model_name ]["split_positions" ] + split_positions
363- )
364- tasks_map [model_name ]["split_positions" ] = list (
365- sorted (set (merged_split_positions ))
366- )
396+
397+ tasks_map [model_name ] = {
398+ "original_path" : pre_task_for_model ["original_path" ],
399+ "split_positions" : split_positions ,
400+ }
367401 print (f"{ tasks_map = } " )
368402
369403 return tasks_map , max_subgraph_size , prev_config .running_states
@@ -374,7 +408,7 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
374408 tasks_map , max_subgraph_size , running_states = generate_initial_tasks (args )
375409 else :
376410 tasks_map , max_subgraph_size , running_states = generate_successor_tasks (
377- base_output_dir , current_pass_id
411+ args , base_output_dir , current_pass_id
378412 )
379413
380414 print (f"[Init] initial max_subgraph_size: { max_subgraph_size } " )
@@ -393,7 +427,6 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
393427 )
394428 sys .exit (0 )
395429
396- sys .exit (0 )
397430 return tasks_map , max_subgraph_size , running_states
398431
399432
@@ -402,6 +435,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
402435
403436 failed_decomposition = []
404437 need_decompose = True if len (tasks_map ) > 0 else False
438+ method = "fixed-start"
405439
406440 while need_decompose :
407441 decomposed_samples_dir = os .path .join (
@@ -426,6 +460,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
426460 not failed_decomposition
427461 and num_decomposed_samples == len (tasks_map )
428462 and max_subgraph_size > 1
463+ and method != "fixed-start"
429464 ):
430465 need_decompose = True
431466 shutil .rmtree (decomposed_samples_dir )
@@ -435,7 +470,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
435470 split_positions = task_info ["split_positions" ]
436471 if not split_positions or len (split_positions ) < 2 :
437472 continue
438- new_split_positions = reconstruct_split_positions_for_subgraph (
473+ new_split_positions = reconstruct_split_positions_for_subgraphs (
439474 split_positions , 0 , max_subgraph_size
440475 )
441476 task_info ["split_positions" ] = new_split_positions
@@ -458,8 +493,7 @@ def count_unique_original_models(incorrect_models):
458493 return len (original_model_paths )
459494
460495
461- def print_summary_and_suggestion (next_round_models , max_subgraph_size ):
462- """Print suggestion/result."""
496+ def print_summary_and_suggestion (args , next_round_models , max_subgraph_size ):
463497 print ("\n " + "=" * 80 )
464498 if next_round_models and max_subgraph_size > 1 :
465499 print (f">>> [SUGGESTION] Issues remain (Count: { len (next_round_models )} )." )
@@ -485,6 +519,8 @@ def main(args):
485519 args , current_pass_id , base_output_dir
486520 )
487521 decompose_config = DecomposeConfig (
522+ method = args .method ,
523+ tolerance = args .tolerance ,
488524 max_subgraph_size = max_subgraph_size ,
489525 tasks_map = tasks_map ,
490526 running_states = running_states ,
@@ -517,7 +553,6 @@ def main(args):
517553 run_evaluation (args .framework , args .test_config , work_dir , log_path )
518554
519555 # --- Step 4: Analysis ---
520- next_pass_incorrect_models = set ()
521556 if task_controller .task_scheduler ["post_analysis" ]:
522557 tolerance = (
523558 args .tolerance [0 ] if isinstance (args .tolerance , list ) else args .tolerance
@@ -530,15 +565,17 @@ def main(args):
530565 num_incorrect_models = num_original_models ,
531566 incorrect_models = list (next_pass_incorrect_models ),
532567 )
568+
533569 print (
534570 f"[Analysis] Found { len (next_pass_incorrect_models )} incorrect subgraphs ({ num_original_models } original models)."
535571 )
536572 for idx , model_path in enumerate (next_pass_incorrect_models ):
537573 print (f"- [{ idx } ] { model_path } " )
538- print_summary_and_suggestion (next_pass_incorrect_models , max_subgraph_size )
574+ print_summary_and_suggestion (
575+ args , next_pass_incorrect_models , max_subgraph_size
576+ )
539577
540578 # --- Step 5: Save States ---
541- decompose_config .incorrect_models = list (next_pass_incorrect_models )
542579 decompose_config .save (work_dir )
543580
544581
@@ -550,6 +587,7 @@ def main(args):
550587 parser .add_argument (
551588 "--test-config" , type = str , required = True , help = "Base64 encoded test config"
552589 )
590+ parser .add_argument ("--method" , type = str , required = True )
553591 parser .add_argument (
554592 "--tolerance" ,
555593 type = int ,
0 commit comments