@@ -231,6 +231,12 @@ def run_decomposer_for_multi_models(
         original_path = task_info["original_path"]
 
         split_positions = sorted(list(task_info["split_positions"]))
+
+        method = "fixed-start"
+        if method == "fixed-start":
+            assert len(split_positions) >= 3, f"{split_positions=}"
+            split_positions = [0, split_positions[1]]
+
         rectified_model_path = get_rectfied_model_path(original_path)
         assert os.path.exists(
             rectified_model_path
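
To illustrate the fixed-start truncation added above: only the first split point after 0 is kept, so each task collapses to a single leading subgraph. The positions below are invented for illustration.

# Sketch of the truncation above, with made-up split positions.
split_positions = sorted([0, 40, 80, 120])   # at least 3 entries are required
split_positions = [0, split_positions[1]]    # fixed-start keeps only [0, 40]
print(split_positions)                       # [0, 40]
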
@@ -279,18 +285,22 @@ def run_evaluation(
     ), f"[ERROR] test failed for {samples_dir}, please check the log."
 
 
-def reconstruct_split_positions_for_subgraph(
-    split_positions, subgraph_idx, max_subgraph_size
+def reconstruct_split_positions_for_subgraphs(
+    split_positions, subgraph_idxs, max_subgraph_size
 ):
-    assert (
-        subgraph_idx < len(split_positions) - 1
-    ), f"subgraph_idx {subgraph_idx} is out of bounds of split_positions: {split_positions}."
+    subgraph_idxs = [subgraph_idxs] if isinstance(subgraph_idxs, int) else subgraph_idxs
 
-    start_pos, end_pos = split_positions[subgraph_idx : subgraph_idx + 2]
-    new_split_positions = set(
-        range(start_pos, end_pos + max_subgraph_size - 1, max_subgraph_size)
-    )
-    return sorted(list(new_split_positions))
+    new_split_positions = []
+    for subgraph_idx in subgraph_idxs:
+        assert (
+            subgraph_idx < len(split_positions) - 1
+        ), f"subgraph_idx {subgraph_idx} is out of bounds of split_positions: {split_positions}."
+
+        start_pos, end_pos = split_positions[subgraph_idx : subgraph_idx + 2]
+        new_split_positions = new_split_positions + list(
+            range(start_pos, end_pos + max_subgraph_size - 1, max_subgraph_size)
+        )
+    return sorted(list(set(new_split_positions)))
 
 
 def generate_initial_tasks(args):
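
For reference, a standalone sketch of what the renamed helper computes; the split positions and subgraph sizes below are invented for illustration, not taken from a real run.

# Standalone copy of the helper above, for illustration only.
def reconstruct_split_positions_for_subgraphs(split_positions, subgraph_idxs, max_subgraph_size):
    subgraph_idxs = [subgraph_idxs] if isinstance(subgraph_idxs, int) else subgraph_idxs
    new_split_positions = []
    for subgraph_idx in subgraph_idxs:
        # Subgraph i spans [split_positions[i], split_positions[i + 1]].
        start_pos, end_pos = split_positions[subgraph_idx : subgraph_idx + 2]
        # Re-split that span into chunks of at most max_subgraph_size.
        new_split_positions.extend(range(start_pos, end_pos + max_subgraph_size - 1, max_subgraph_size))
    return sorted(set(new_split_positions))

print(reconstruct_split_positions_for_subgraphs([0, 100, 200], 0, 25))
# -> [0, 25, 50, 75, 100]
print(reconstruct_split_positions_for_subgraphs([0, 100, 200], [0, 1], 50))
# -> [0, 50, 100, 150, 200]
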
@@ -299,9 +309,9 @@ def generate_initial_tasks(args):
     initial_failures = get_ranged_incorrect_models(args.tolerance, args.log_file)
 
     tasks_map = {}
-    max_subgraph_size = args.max_subgraph_size
+    max_subgraph_size = min(args.max_subgraph_size, kMaxGraphSize // 2)
 
-    initial_split_positions = reconstruct_split_positions_for_subgraph(
+    initial_split_positions = reconstruct_split_positions_for_subgraphs(
         [0, kMaxGraphSize], 0, max_subgraph_size
     )
     for model_path in initial_failures:
@@ -328,7 +338,29 @@ def extract_model_name_and_subgraph_idx(subgraph_path):
     return model_name, subgraph_idx
 
 
-def generate_successor_tasks(base_output_dir, current_pass_id):
+def collect_incorrect_subgraph_idxs(args, model_names, incorrect_models):
+    model_name2subgraph_idxs = {}
+    for subgraph_path in sorted(incorrect_models):
+        model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
+        print(f"{subgraph_path=}")
+
+        if model_name not in model_name2subgraph_idxs:
+            model_name2subgraph_idxs[model_name] = []
+        model_name2subgraph_idxs[model_name].append(subgraph_idx)
+
+    if args.method == "fixed-start":
+        for model_name in model_names:
+            if model_name not in model_name2subgraph_idxs:
+                model_name2subgraph_idxs[model_name] = [1]
+            else:
+                assert (
+                    len(model_name2subgraph_idxs[model_name]) == 1
+                    and model_name2subgraph_idxs[model_name] == [0]
+                )
+    return model_name2subgraph_idxs
+
+
+def generate_successor_tasks(args, base_output_dir, current_pass_id):
     """Generates tasks for Pass > 0 based on previous pass results."""
     prev_pass_dir = get_decompose_workspace_path(base_output_dir, current_pass_id - 1)
     print(f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})...")
@@ -341,29 +373,23 @@ def generate_successor_tasks(base_output_dir, current_pass_id):
     tasks_map = {}
     prev_tasks_map = prev_config.tasks_map
 
-    for subgraph_path in sorted(prev_config.incorrect_models):
-        model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
-        print(f"{subgraph_path=}")
+    model_name2subgraph_idxs = collect_incorrect_subgraph_idxs(
+        args, list(prev_tasks_map.keys()), prev_config.incorrect_models
+    )
 
+    for model_name, subgraph_idxs in model_name2subgraph_idxs.items():
         assert model_name in prev_tasks_map
         pre_task_for_model = prev_tasks_map[model_name]
 
         prev_split_positions = pre_task_for_model.get("split_positions", [])
-        split_positions = reconstruct_split_positions_for_subgraph(
-            prev_split_positions, subgraph_idx, max_subgraph_size
+        split_positions = reconstruct_split_positions_for_subgraphs(
+            prev_split_positions, subgraph_idxs, max_subgraph_size
         )
-        if model_name not in tasks_map:
-            tasks_map[model_name] = {
-                "original_path": pre_task_for_model["original_path"],
-                "split_positions": list(sorted(split_positions)),
-            }
-        else:
-            merged_split_positions = (
-                tasks_map[model_name]["split_positions"] + split_positions
-            )
-            tasks_map[model_name]["split_positions"] = list(
-                sorted(set(merged_split_positions))
-            )
+
+        tasks_map[model_name] = {
+            "original_path": pre_task_for_model["original_path"],
+            "split_positions": split_positions,
+        }
     print(f"{tasks_map=}")
 
     return tasks_map, max_subgraph_size, prev_config.running_states
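
A minimal sketch of the fixed-start defaulting that collect_incorrect_subgraph_idxs performs, using invented model names: models without a recorded failure are assigned subgraph index 1, while models that did fail are expected to have failed only at index 0.

# Sketch of the fixed-start defaulting above, with invented names.
model_names = ["model_a", "model_b"]
# Suppose only model_a produced an incorrect subgraph (index 0) in the previous pass.
model_name2subgraph_idxs = {"model_a": [0]}

for model_name in model_names:
    if model_name not in model_name2subgraph_idxs:
        # Models with no recorded failure fall back to subgraph index 1.
        model_name2subgraph_idxs[model_name] = [1]

print(model_name2subgraph_idxs)
# {'model_a': [0], 'model_b': [1]}
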
@@ -374,7 +400,7 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
         tasks_map, max_subgraph_size, running_states = generate_initial_tasks(args)
     else:
         tasks_map, max_subgraph_size, running_states = generate_successor_tasks(
-            base_output_dir, current_pass_id
+            args, base_output_dir, current_pass_id
         )
 
     print(f"[Init] initial max_subgraph_size: {max_subgraph_size}")
@@ -393,7 +419,6 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
         )
         sys.exit(0)
 
-    sys.exit(0)
     return tasks_map, max_subgraph_size, running_states
 
 
@@ -402,6 +427,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
 
     failed_decomposition = []
    need_decompose = True if len(tasks_map) > 0 else False
+    method = "fixed-start"
 
     while need_decompose:
         decomposed_samples_dir = os.path.join(
@@ -426,6 +452,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
             not failed_decomposition
             and num_decomposed_samples == len(tasks_map)
             and max_subgraph_size > 1
+            and method != "fixed-start"
         ):
             need_decompose = True
             shutil.rmtree(decomposed_samples_dir)
@@ -435,7 +462,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
                 split_positions = task_info["split_positions"]
                 if not split_positions or len(split_positions) < 2:
                     continue
-                new_split_positions = reconstruct_split_positions_for_subgraph(
+                new_split_positions = reconstruct_split_positions_for_subgraphs(
                     split_positions, 0, max_subgraph_size
                 )
                 task_info["split_positions"] = new_split_positions
@@ -550,6 +577,7 @@ def main(args):
     parser.add_argument(
         "--test-config", type=str, required=True, help="Base64 encoded test config"
     )
+    parser.add_argument("--method", type=str, required=True)
     parser.add_argument(
         "--tolerance",
         type=int,