Skip to content

Commit ef92564

Browse files
committed
Record the number of original incorrect models.
1 parent 3d30e86 commit ef92564

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ def get_pass_name(pass_id):
2828

2929

3030
def get_ranged_incorrect_models(tolerance_args: List[int], log_path: str) -> set:
31-
if not os.path.exists(log_path):
32-
return set()
31+
assert os.path.exists(log_path)
3332

3433
t_start = tolerance_args[0]
3534
models_start = set(get_incorrect_models(t_start, log_path))
@@ -40,13 +39,10 @@ def get_ranged_incorrect_models(tolerance_args: List[int], log_path: str) -> set
4039
t_end = tolerance_args[1]
4140
models_end = set(get_incorrect_models(t_end, log_path))
4241

43-
print(f"[Filter] Tolerance Range: {t_start} -> {t_end}")
4442
print(
45-
f"[Filter] Fail({t_start}): {len(models_start)}, Fail({t_end}): {len(models_end)}"
43+
f"[Init] number of incorrect models: {len(models_start)} (tolerance={t_start}) - {len(models_end)} (tolerance={t_end})"
4644
)
47-
48-
diff_set = models_start - models_end
49-
return diff_set
45+
return models_start - models_end
5046

5147

5248
class TaskController:
@@ -326,6 +322,14 @@ def generate_initial_tasks(args):
326322
return tasks_map, max_subgraph_size, running_states
327323

328324

325+
def extract_model_name_and_subgraph_idx(subgraph_path):
326+
# Parse model name and subgraph index
327+
model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
328+
model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
329+
subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
330+
return model_name, subgraph_idx
331+
332+
329333
def generate_refined_tasks(base_output_dir, current_pass_id):
330334
"""Generates tasks for Pass > 0 based on previous pass results."""
331335
prev_pass_dir = get_decompose_workspace_path(base_output_dir, current_pass_id - 1)
@@ -340,10 +344,7 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
340344
prev_tasks_map = prev_config.tasks_map
341345

342346
for subgraph_path in sorted(prev_config.incorrect_models):
343-
# Parse model name and subgraph index
344-
model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
345-
model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
346-
subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
347+
model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
347348

348349
assert model_name in prev_tasks_map
349350
pre_task_for_model = prev_tasks_map[model_name]
@@ -525,12 +526,18 @@ def main(args):
525526
)
526527
print(f"\n--- Phase 3: Analysis (torlance={tolerance}) ---")
527528
next_round_models = sorted(get_incorrect_models(tolerance, pass_log_path))
529+
original_model_paths = [
530+
model_name
531+
for subgraph_path in next_round_models
532+
for model_name, _ in [extract_model_name_and_subgraph_idx(subgraph_path)]
533+
]
534+
528535
running_states[f"pass_{current_pass_id + 1}"] = {
529-
"num_incorrect_models": len(next_round_models),
536+
"num_incorrect_models": len(set(original_model_paths)),
530537
"incorrect_models": list(next_round_models),
531538
}
532539

533-
print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
540+
print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.")
534541
for idx, model_path in enumerate(next_round_models):
535542
print(f"- [{idx}] {model_path}")
536543

0 commit comments

Comments
 (0)