Skip to content

Commit 7cfd4eb

Browse files
committed
Support fixed-start method.
1 parent d7c91a2 commit 7cfd4eb

File tree

1 file changed

+80
-42
lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 80 additions & 42 deletions
Original file line number  Diff line number  Diff line change
@@ -111,8 +111,9 @@ def _print(self):
111111

112112
@dataclass
113113
class DecomposeConfig:
114+
method: str
115+
tolerance: int | List[int]
114116
max_subgraph_size: int = -1
115-
incorrect_models: List[str] = field(default_factory=list)
116117
tasks_map: Dict[str, Union[int, str, list, dict]] = field(default_factory=dict)
117118
running_states: Dict[str, Union[int, str, list, dict]] = field(default_factory=dict)
118119

@@ -139,6 +140,11 @@ def load(self, work_dir):
139140
def get_config_path(self, work_dir) -> str:
140141
return os.path.join(work_dir, "decompose_config.json")
141142

143+
def get_incorrect_models(self, pass_id):
144+
pass_key = get_pass_name(pass_id)
145+
assert pass_key in self.running_states
146+
return self.running_states[pass_key]["incorrect_models"]
147+
142148
def update_running_states(self, pass_id, **kwargs):
143149
pass_key = get_pass_name(pass_id)
144150
if self.running_states.get(pass_key, None) is None:
@@ -229,8 +235,13 @@ def run_decomposer_for_multi_models(
229235
)
230236
for model_name, task_info in tasks_map.items():
231237
original_path = task_info["original_path"]
232-
233238
split_positions = sorted(list(task_info["split_positions"]))
239+
240+
method = "fixed-start"
241+
if method == "fixed-start":
242+
assert len(split_positions) >= 3, f"{split_positions=}"
243+
split_positions = [0, split_positions[1]]
244+
234245
rectified_model_path = get_rectfied_model_path(original_path)
235246
assert os.path.exists(
236247
rectified_model_path
@@ -279,18 +290,22 @@ def run_evaluation(
279290
), f"[ERROR] test failed for {samples_dir}, please check the log."
280291

281292

282-
def reconstruct_split_positions_for_subgraph(
283-
split_positions, subgraph_idx, max_subgraph_size
293+
def reconstruct_split_positions_for_subgraphs(
294+
split_positions, subgraph_idxs, max_subgraph_size
284295
):
285-
assert (
286-
subgraph_idx < len(split_positions) - 1
287-
), f"subgraph_idx {subgraph_idx} is out of bounds of split_positions: {split_positions}."
296+
subgraph_idxs = [subgraph_idxs] if isinstance(subgraph_idxs, int) else subgraph_idxs
288297

289-
start_pos, end_pos = split_positions[subgraph_idx : subgraph_idx + 2]
290-
new_split_positions = set(
291-
range(start_pos, end_pos + max_subgraph_size - 1, max_subgraph_size)
292-
)
293-
return sorted(list(new_split_positions))
298+
new_split_positions = []
299+
for subgraph_idx in subgraph_idxs:
300+
assert (
301+
subgraph_idx < len(split_positions) - 1
302+
), f"subgraph_idx {subgraph_idx} is out of bounds of split_positions: {split_positions}."
303+
304+
start_pos, end_pos = split_positions[subgraph_idx : subgraph_idx + 2]
305+
new_split_positions = new_split_positions + list(
306+
range(start_pos, end_pos + max_subgraph_size, max_subgraph_size)
307+
)
308+
return sorted(list(set(new_split_positions)))
294309

295310

296311
def generate_initial_tasks(args):
@@ -299,9 +314,9 @@ def generate_initial_tasks(args):
299314
initial_failures = get_ranged_incorrect_models(args.tolerance, args.log_file)
300315

301316
tasks_map = {}
302-
max_subgraph_size = args.max_subgraph_size
317+
max_subgraph_size = min(args.max_subgraph_size, kMaxGraphSize // 2)
303318

304-
initial_split_positions = reconstruct_split_positions_for_subgraph(
319+
initial_split_positions = reconstruct_split_positions_for_subgraphs(
305320
[0, kMaxGraphSize], 0, max_subgraph_size
306321
)
307322
for model_path in initial_failures:
@@ -328,42 +343,61 @@ def extract_model_name_and_subgraph_idx(subgraph_path):
328343
return model_name, subgraph_idx
329344

330345

331-
def generate_successor_tasks(base_output_dir, current_pass_id):
346+
def collect_incorrect_subgraph_idxs(args, target_model_names, incorrect_models):
347+
model_name2subgraph_idxs = {}
348+
for subgraph_path in sorted(incorrect_models):
349+
model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
350+
print(f"{subgraph_path=}")
351+
print(f"{model_name=}, {subgraph_idx=}")
352+
assert model_name in target_model_names, f"{model_name=}, {subgraph_idx=}"
353+
354+
if model_name not in model_name2subgraph_idxs:
355+
model_name2subgraph_idxs[model_name] = []
356+
model_name2subgraph_idxs[model_name].append(subgraph_idx)
357+
358+
if args.method == "fixed-start":
359+
print(model_name2subgraph_idxs)
360+
for model_name in target_model_names:
361+
if model_name not in model_name2subgraph_idxs:
362+
model_name2subgraph_idxs[model_name] = [1]
363+
else:
364+
assert len(
365+
model_name2subgraph_idxs[model_name]
366+
) == 1 and model_name2subgraph_idxs[model_name] == [0]
367+
return model_name2subgraph_idxs
368+
369+
370+
def generate_successor_tasks(args, base_output_dir, current_pass_id):
332371
"""Generates tasks for Pass > 0 based on previous pass results."""
333372
prev_pass_dir = get_decompose_workspace_path(base_output_dir, current_pass_id - 1)
334373
print(f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})...")
335374

336375
prev_config = DecomposeConfig.load(prev_pass_dir)
337376
max_subgraph_size = prev_config.max_subgraph_size // 2
338-
if not prev_config.incorrect_models:
377+
incorrect_models = prev_config.get_incorrect_models(current_pass_id)
378+
if args.method != "fixed-start" and not incorrect_models:
339379
return {}, max_subgraph_size, prev_config.running_states
340380

341381
tasks_map = {}
342382
prev_tasks_map = prev_config.tasks_map
343383

344-
for subgraph_path in sorted(prev_config.incorrect_models):
345-
model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
346-
print(f"{subgraph_path=}")
384+
target_model_names = list(prev_tasks_map.keys())
385+
model_name2subgraph_idxs = collect_incorrect_subgraph_idxs(
386+
args, target_model_names, incorrect_models
387+
)
347388

348-
assert model_name in prev_tasks_map
389+
for model_name, subgraph_idxs in model_name2subgraph_idxs.items():
349390
pre_task_for_model = prev_tasks_map[model_name]
350391

351392
prev_split_positions = pre_task_for_model.get("split_positions", [])
352-
split_positions = reconstruct_split_positions_for_subgraph(
353-
prev_split_positions, subgraph_idx, max_subgraph_size
393+
split_positions = reconstruct_split_positions_for_subgraphs(
394+
prev_split_positions, subgraph_idxs, max_subgraph_size
354395
)
355-
if model_name not in tasks_map:
356-
tasks_map[model_name] = {
357-
"original_path": pre_task_for_model["original_path"],
358-
"split_positions": list(sorted(split_positions)),
359-
}
360-
else:
361-
merged_split_positions = (
362-
tasks_map[model_name]["split_positions"] + split_positions
363-
)
364-
tasks_map[model_name]["split_positions"] = list(
365-
sorted(set(merged_split_positions))
366-
)
396+
397+
tasks_map[model_name] = {
398+
"original_path": pre_task_for_model["original_path"],
399+
"split_positions": split_positions,
400+
}
367401
print(f"{tasks_map=}")
368402

369403
return tasks_map, max_subgraph_size, prev_config.running_states
@@ -374,7 +408,7 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
374408
tasks_map, max_subgraph_size, running_states = generate_initial_tasks(args)
375409
else:
376410
tasks_map, max_subgraph_size, running_states = generate_successor_tasks(
377-
base_output_dir, current_pass_id
411+
args, base_output_dir, current_pass_id
378412
)
379413

380414
print(f"[Init] initial max_subgraph_size: {max_subgraph_size}")
@@ -393,7 +427,6 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
393427
)
394428
sys.exit(0)
395429

396-
sys.exit(0)
397430
return tasks_map, max_subgraph_size, running_states
398431

399432

@@ -402,6 +435,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
402435

403436
failed_decomposition = []
404437
need_decompose = True if len(tasks_map) > 0 else False
438+
method = "fixed-start"
405439

406440
while need_decompose:
407441
decomposed_samples_dir = os.path.join(
@@ -426,6 +460,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
426460
not failed_decomposition
427461
and num_decomposed_samples == len(tasks_map)
428462
and max_subgraph_size > 1
463+
and method != "fixed-start"
429464
):
430465
need_decompose = True
431466
shutil.rmtree(decomposed_samples_dir)
@@ -435,7 +470,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
435470
split_positions = task_info["split_positions"]
436471
if not split_positions or len(split_positions) < 2:
437472
continue
438-
new_split_positions = reconstruct_split_positions_for_subgraph(
473+
new_split_positions = reconstruct_split_positions_for_subgraphs(
439474
split_positions, 0, max_subgraph_size
440475
)
441476
task_info["split_positions"] = new_split_positions
@@ -458,8 +493,7 @@ def count_unique_original_models(incorrect_models):
458493
return len(original_model_paths)
459494

460495

461-
def print_summary_and_suggestion(next_round_models, max_subgraph_size):
462-
"""Print suggestion/result."""
496+
def print_summary_and_suggestion(args, next_round_models, max_subgraph_size):
463497
print("\n" + "=" * 80)
464498
if next_round_models and max_subgraph_size > 1:
465499
print(f">>> [SUGGESTION] Issues remain (Count: {len(next_round_models)}).")
@@ -485,6 +519,8 @@ def main(args):
485519
args, current_pass_id, base_output_dir
486520
)
487521
decompose_config = DecomposeConfig(
522+
method=args.method,
523+
tolerance=args.tolerance,
488524
max_subgraph_size=max_subgraph_size,
489525
tasks_map=tasks_map,
490526
running_states=running_states,
@@ -517,7 +553,6 @@ def main(args):
517553
run_evaluation(args.framework, args.test_config, work_dir, log_path)
518554

519555
# --- Step 4: Analysis ---
520-
next_pass_incorrect_models = set()
521556
if task_controller.task_scheduler["post_analysis"]:
522557
tolerance = (
523558
args.tolerance[0] if isinstance(args.tolerance, list) else args.tolerance
@@ -530,15 +565,17 @@ def main(args):
530565
num_incorrect_models=num_original_models,
531566
incorrect_models=list(next_pass_incorrect_models),
532567
)
568+
533569
print(
534570
f"[Analysis] Found {len(next_pass_incorrect_models)} incorrect subgraphs ({num_original_models} original models)."
535571
)
536572
for idx, model_path in enumerate(next_pass_incorrect_models):
537573
print(f"- [{idx}] {model_path}")
538-
print_summary_and_suggestion(next_pass_incorrect_models, max_subgraph_size)
574+
print_summary_and_suggestion(
575+
args, next_pass_incorrect_models, max_subgraph_size
576+
)
539577

540578
# --- Step 5: Save States ---
541-
decompose_config.incorrect_models = list(next_pass_incorrect_models)
542579
decompose_config.save(work_dir)
543580

544581

@@ -550,6 +587,7 @@ def main(args):
550587
parser.add_argument(
551588
"--test-config", type=str, required=True, help="Base64 encoded test config"
552589
)
590+
parser.add_argument("--method", type=str, required=True)
553591
parser.add_argument(
554592
"--tolerance",
555593
type=int,

0 commit comments

Comments
 (0)