
Commit 00b070d

Support fixed-start method.
1 parent d7c91a2 commit 00b070d

File tree

1 file changed: +61 -33 lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 61 additions & 33 deletions

@@ -231,6 +231,12 @@ def run_decomposer_for_multi_models(
         original_path = task_info["original_path"]

         split_positions = sorted(list(task_info["split_positions"]))
+
+        method = "fixed-start"
+        if method == "fixed-start":
+            assert len(split_positions) >= 3, f"{split_positions=}"
+            split_positions = [0, split_positions[1]]
+
         rectified_model_path = get_rectfied_model_path(original_path)
         assert os.path.exists(
             rectified_model_path

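For illustration, a minimal self-contained sketch (not part of the commit) of what the fixed-start truncation above does to a task's split positions; the position values are invented:

split_positions = sorted([0, 64, 128, 192])  # hypothetical split positions

method = "fixed-start"
if method == "fixed-start":
    assert len(split_positions) >= 3, f"{split_positions=}"
    # Keep only position 0 and the first split after it.
    split_positions = [0, split_positions[1]]

print(split_positions)  # [0, 64]
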
@@ -279,18 +285,22 @@ def run_evaluation(
     ), f"[ERROR] test failed for {samples_dir}, please check the log."


-def reconstruct_split_positions_for_subgraph(
-    split_positions, subgraph_idx, max_subgraph_size
+def reconstruct_split_positions_for_subgraphs(
+    split_positions, subgraph_idxs, max_subgraph_size
 ):
-    assert (
-        subgraph_idx < len(split_positions) - 1
-    ), f"subgraph_idx {subgraph_idx} is out of bounds of split_positions: {split_positions}."
+    subgraph_idxs = [subgraph_idxs] if isinstance(subgraph_idxs, int) else subgraph_idxs

-    start_pos, end_pos = split_positions[subgraph_idx : subgraph_idx + 2]
-    new_split_positions = set(
-        range(start_pos, end_pos + max_subgraph_size - 1, max_subgraph_size)
-    )
-    return sorted(list(new_split_positions))
+    new_split_positions = []
+    for subgraph_idx in subgraph_idxs:
+        assert (
+            subgraph_idx < len(split_positions) - 1
+        ), f"subgraph_idx {subgraph_idx} is out of bounds of split_positions: {split_positions}."
+
+        start_pos, end_pos = split_positions[subgraph_idx : subgraph_idx + 2]
+        new_split_positions = new_split_positions + list(
+            range(start_pos, end_pos + max_subgraph_size - 1, max_subgraph_size)
+        )
+    return sorted(list(set(new_split_positions)))


 def generate_initial_tasks(args):

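As a worked example of the renamed helper (assuming the reconstruct_split_positions_for_subgraphs defined above is in scope; the numbers are made up): each selected range is re-split into positions spaced max_subgraph_size apart starting from that range's start position, and the union is deduplicated and sorted. A bare int index is accepted as well as a list.

positions = reconstruct_split_positions_for_subgraphs(
    split_positions=[0, 10, 20, 30],
    subgraph_idxs=[0, 2],
    max_subgraph_size=4,
)
print(positions)  # [0, 4, 8, 12, 20, 24, 28, 32]
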
@@ -299,9 +309,9 @@ def generate_initial_tasks(args):
     initial_failures = get_ranged_incorrect_models(args.tolerance, args.log_file)

     tasks_map = {}
-    max_subgraph_size = args.max_subgraph_size
+    max_subgraph_size = min(args.max_subgraph_size, kMaxGraphSize // 2)

-    initial_split_positions = reconstruct_split_positions_for_subgraph(
+    initial_split_positions = reconstruct_split_positions_for_subgraphs(
         [0, kMaxGraphSize], 0, max_subgraph_size
     )
     for model_path in initial_failures:

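A quick numeric illustration of the new cap (kMaxGraphSize here is a placeholder value; the real constant is defined elsewhere in this module): the initial subgraph size can no longer exceed half of kMaxGraphSize, whatever value args.max_subgraph_size carries.

kMaxGraphSize = 1024      # placeholder value, for illustration only
requested_size = 800      # stands in for args.max_subgraph_size
max_subgraph_size = min(requested_size, kMaxGraphSize // 2)
print(max_subgraph_size)  # 512
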
@@ -328,7 +338,29 @@ def extract_model_name_and_subgraph_idx(subgraph_path):
     return model_name, subgraph_idx


-def generate_successor_tasks(base_output_dir, current_pass_id):
+def collect_incorrect_subgraph_idxs(args, model_names, incorrect_models):
+    model_name2subgraph_idxs = {}
+    for subgraph_path in sorted(incorrect_models):
+        model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
+        print(f"{subgraph_path=}")
+
+        if model_name not in model_name2subgraph_idxs:
+            model_name2subgraph_idxs[model_name] = []
+        model_name2subgraph_idxs[model_name].append(subgraph_idx)
+
+    if args.method == "fixed-start":
+        for model_name in model_names:
+            if model_name not in model_name2subgraph_idxs:
+                model_name2subgraph_idxs[model_name] = [1]
+            else:
+                assert (
+                    len(model_name2subgraph_idxs[model_name]) == 1
+                    and model_name2subgraph_idxs[model_name] == 0
+                )
+    return model_name2subgraph_idxs
+
+
+def generate_successor_tasks(args, base_output_dir, current_pass_id):
     """Generates tasks for Pass > 0 based on previous pass results."""
     prev_pass_dir = get_decompose_workspace_path(base_output_dir, current_pass_id - 1)
     print(f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})...")

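A small self-contained sketch of just the fixed-start backfill inside collect_incorrect_subgraph_idxs above; the model names and indices are invented. Models that produced no failing subgraph are still assigned subgraph index 1.

model_names = ["model_a", "model_b"]
model_name2subgraph_idxs = {"model_a": [0]}   # model_b had no failing subgraph

method = "fixed-start"                        # stands in for args.method
if method == "fixed-start":
    for model_name in model_names:
        if model_name not in model_name2subgraph_idxs:
            model_name2subgraph_idxs[model_name] = [1]

print(model_name2subgraph_idxs)  # {'model_a': [0], 'model_b': [1]}
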
@@ -341,29 +373,23 @@ def generate_successor_tasks(base_output_dir, current_pass_id):
     tasks_map = {}
     prev_tasks_map = prev_config.tasks_map

-    for subgraph_path in sorted(prev_config.incorrect_models):
-        model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
-        print(f"{subgraph_path=}")
+    model_name2subgraph_idxs = collect_incorrect_subgraph_idxs(
+        args, list(prev_tasks_map.keys()), prev_config.incorrect_models
+    )

+    for model_name, subgraph_idxs in model_name2subgraph_idxs.items():
         assert model_name in prev_tasks_map
         pre_task_for_model = prev_tasks_map[model_name]

         prev_split_positions = pre_task_for_model.get("split_positions", [])
-        split_positions = reconstruct_split_positions_for_subgraph(
-            prev_split_positions, subgraph_idx, max_subgraph_size
+        split_positions = reconstruct_split_positions_for_subgraphs(
+            prev_split_positions, subgraph_idxs, max_subgraph_size
         )
-        if model_name not in tasks_map:
-            tasks_map[model_name] = {
-                "original_path": pre_task_for_model["original_path"],
-                "split_positions": list(sorted(split_positions)),
-            }
-        else:
-            merged_split_positions = (
-                tasks_map[model_name]["split_positions"] + split_positions
-            )
-            tasks_map[model_name]["split_positions"] = list(
-                sorted(set(merged_split_positions))
-            )
+
+        tasks_map[model_name] = {
+            "original_path": pre_task_for_model["original_path"],
+            "split_positions": split_positions,
+        }
     print(f"{tasks_map=}")

     return tasks_map, max_subgraph_size, prev_config.running_states

@@ -374,7 +400,7 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
         tasks_map, max_subgraph_size, running_states = generate_initial_tasks(args)
     else:
         tasks_map, max_subgraph_size, running_states = generate_successor_tasks(
-            base_output_dir, current_pass_id
+            args, base_output_dir, current_pass_id
         )

     print(f"[Init] initial max_subgraph_size: {max_subgraph_size}")

@@ -393,7 +419,6 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
         )
         sys.exit(0)

-    sys.exit(0)
     return tasks_map, max_subgraph_size, running_states


@@ -402,6 +427,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa

     failed_decomposition = []
     need_decompose = True if len(tasks_map) > 0 else False
+    method = "fixed-start"

     while need_decompose:
         decomposed_samples_dir = os.path.join(

@@ -426,6 +452,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
             not failed_decomposition
             and num_decomposed_samples == len(tasks_map)
             and max_subgraph_size > 1
+            and method != "fixed-start"
         ):
             need_decompose = True
             shutil.rmtree(decomposed_samples_dir)

@@ -435,7 +462,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
             split_positions = task_info["split_positions"]
             if not split_positions or len(split_positions) < 2:
                 continue
-            new_split_positions = reconstruct_split_positions_for_subgraph(
+            new_split_positions = reconstruct_split_positions_for_subgraphs(
                 split_positions, 0, max_subgraph_size
             )
            task_info["split_positions"] = new_split_positions

@@ -550,6 +577,7 @@ def main(args):
     parser.add_argument(
         "--test-config", type=str, required=True, help="Base64 encoded test config"
     )
+    parser.add_argument("--method", type=str, required=True)
     parser.add_argument(
         "--tolerance",
         type=int,
