Skip to content

Commit 3301ecb

Browse files
committed
Record the number of original incorrect models.
1 parent 3d30e86 commit 3301ecb

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,14 @@ def generate_initial_tasks(args):
326326
return tasks_map, max_subgraph_size, running_states
327327

328328

329+
def extract_model_name_and_subgraph_idx(subgraph_path):
330+
# Parse model name and subgraph index
331+
model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
332+
model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
333+
subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
334+
return model_name, subgraph_idx
335+
336+
329337
def generate_refined_tasks(base_output_dir, current_pass_id):
330338
"""Generates tasks for Pass > 0 based on previous pass results."""
331339
prev_pass_dir = get_decompose_workspace_path(base_output_dir, current_pass_id - 1)
@@ -340,10 +348,7 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
340348
prev_tasks_map = prev_config.tasks_map
341349

342350
for subgraph_path in sorted(prev_config.incorrect_models):
343-
# Parse model name and subgraph index
344-
model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
345-
model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
346-
subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
351+
model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
347352

348353
assert model_name in prev_tasks_map
349354
pre_task_for_model = prev_tasks_map[model_name]
@@ -525,12 +530,18 @@ def main(args):
525530
)
526531
print(f"\n--- Phase 3: Analysis (torlance={tolerance}) ---")
527532
next_round_models = sorted(get_incorrect_models(tolerance, pass_log_path))
533+
original_model_paths = [
534+
model_name
535+
for subgraph_path in next_round_models
536+
for model_name, _ in [extract_model_name_and_subgraph_idx(subgraph_path)]
537+
]
538+
528539
running_states[f"pass_{current_pass_id + 1}"] = {
529-
"num_incorrect_models": len(next_round_models),
540+
"num_incorrect_models": len(set(original_model_paths)),
530541
"incorrect_models": list(next_round_models),
531542
}
532543

533-
print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
544+
print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.")
534545
for idx, model_path in enumerate(next_round_models):
535546
print(f"- [{idx}] {model_path}")
536547

0 commit comments

Comments
 (0)