Skip to content

Commit 864e7b3

Browse files
committed
Record the number of original incorrect models.
1 parent 3d30e86 commit 864e7b3

File tree

1 file changed

+30
-17
lines changed

1 file changed

+30
-17
lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ def get_pass_name(pass_id):
2828

2929

3030
def get_ranged_incorrect_models(tolerance_args: List[int], log_path: str) -> set:
31-
if not os.path.exists(log_path):
32-
return set()
31+
assert os.path.exists(log_path)
3332

3433
t_start = tolerance_args[0]
3534
models_start = set(get_incorrect_models(t_start, log_path))
@@ -40,13 +39,10 @@ def get_ranged_incorrect_models(tolerance_args: List[int], log_path: str) -> set
4039
t_end = tolerance_args[1]
4140
models_end = set(get_incorrect_models(t_end, log_path))
4241

43-
print(f"[Filter] Tolerance Range: {t_start} -> {t_end}")
4442
print(
45-
f"[Filter] Fail({t_start}): {len(models_start)}, Fail({t_end}): {len(models_end)}"
43+
f"[Init] number of incorrect models: {len(models_start)} (tolerance={t_start}) - {len(models_end)} (tolerance={t_end})"
4644
)
47-
48-
diff_set = models_start - models_end
49-
return diff_set
45+
return models_start - models_end
5046

5147

5248
class TaskController:
@@ -326,6 +322,14 @@ def generate_initial_tasks(args):
326322
return tasks_map, max_subgraph_size, running_states
327323

328324

325+
def extract_model_name_and_subgraph_idx(subgraph_path):
326+
# Parse model name and subgraph index
327+
model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
328+
model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
329+
subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
330+
return model_name, subgraph_idx
331+
332+
329333
def generate_refined_tasks(base_output_dir, current_pass_id):
330334
"""Generates tasks for Pass > 0 based on previous pass results."""
331335
prev_pass_dir = get_decompose_workspace_path(base_output_dir, current_pass_id - 1)
@@ -340,10 +344,7 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
340344
prev_tasks_map = prev_config.tasks_map
341345

342346
for subgraph_path in sorted(prev_config.incorrect_models):
343-
# Parse model name and subgraph index
344-
model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
345-
model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
346-
subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
347+
model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
347348

348349
assert model_name in prev_tasks_map
349350
pre_task_for_model = prev_tasks_map[model_name]
@@ -382,11 +383,11 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
382383
base_output_dir, current_pass_id
383384
)
384385

385-
print(f"[INFO] initial max_subgraph_size: {max_subgraph_size}")
386-
print(f"[INFO] number of incorrect models: {len(tasks_map)}")
387-
for model_name, task_info in tasks_map.items():
386+
print(f"[Init] initial max_subgraph_size: {max_subgraph_size}")
387+
print(f"[Init] number of incorrect models: {len(tasks_map)}")
388+
for idx, (model_name, task_info) in enumerate(tasks_map.items()):
388389
original_path = task_info["original_path"]
389-
print(f"- {original_path}")
390+
print(f"- [{idx}] {original_path}")
390391

391392
if not tasks_map:
392393
print("[FINISHED] No models need processing.")
@@ -525,12 +526,24 @@ def main(args):
525526
)
526527
print(f"\n--- Phase 3: Analysis (torlance={tolerance}) ---")
527528
next_round_models = sorted(get_incorrect_models(tolerance, pass_log_path))
529+
original_model_paths = set(
530+
[
531+
model_name
532+
for subgraph_path in next_round_models
533+
for model_name, _ in [
534+
extract_model_name_and_subgraph_idx(subgraph_path)
535+
]
536+
]
537+
)
538+
528539
running_states[f"pass_{current_pass_id + 1}"] = {
529-
"num_incorrect_models": len(next_round_models),
540+
"num_incorrect_models": len(original_model_paths),
530541
"incorrect_models": list(next_round_models),
531542
}
532543

533-
print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
544+
print(
545+
f"[Analysis] Found {len(next_round_models)} incorrect subgraphs ({len(original_model_paths)} original models)."
546+
)
534547
for idx, model_path in enumerate(next_round_models):
535548
print(f"- [{idx}] {model_path}")
536549

0 commit comments

Comments
 (0)