Commit b26630f

Merge branch 'develop' into opt_saved_results

2 parents: 7d9581f + 410311b

17 files changed: +1154 −48 lines

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 84 additions & 29 deletions
@@ -27,6 +27,28 @@ def get_pass_name(pass_id):
     return f"pass_{pass_id}"
 
 
+def get_ranged_incorrect_models(tolerance_args: List[int], log_path: str) -> set:
+    if not os.path.exists(log_path):
+        return set()
+
+    t_start = tolerance_args[0]
+    models_start = set(get_incorrect_models(t_start, log_path))
+
+    if len(tolerance_args) == 1:
+        return models_start
+
+    t_end = tolerance_args[1]
+    models_end = set(get_incorrect_models(t_end, log_path))
+
+    print(f"[Filter] Tolerance Range: {t_start} -> {t_end}")
+    print(
+        f"[Filter] Fail({t_start}): {len(models_start)}, Fail({t_end}): {len(models_end)}"
+    )
+
+    diff_set = models_start - models_end
+    return diff_set
+
+
 class TaskController:
     def __init__(self, args):
         self.root_output_dir = os.path.abspath(args.output_dir)
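
The new helper treats a two-element tolerance list as a range: it keeps models that fail at the looser level but pass at the stricter one, via a plain set difference. A minimal sketch of that semantics, with hypothetical model names standing in for what get_incorrect_models would return:

    # Hypothetical outputs of get_incorrect_models at two tolerance levels.
    fails_at_start = {"model_a", "model_b", "model_c"}  # e.g. tolerance -3
    fails_at_end = {"model_c"}                          # e.g. tolerance 1

    # get_ranged_incorrect_models keeps models failing only within the range.
    ranged = fails_at_start - fails_at_end
    assert ranged == {"model_a", "model_b"}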
@@ -198,10 +220,10 @@ def run_decomposer_for_multi_models(
         )
         for model_name, task_info in tasks_map.items():
             original_path = task_info["original_path"]
-            split_positions = calculate_split_positions_for_subgraph(
-                task_info["subgraph_size"], max_subgraph_size
-            )
-            task_info["split_positions"] = split_positions
+
+            split_positions = task_info["split_positions"]
+            if isinstance(split_positions, set):
+                split_positions = sorted(list(split_positions))
 
             rectified_model_path = get_rectfied_model_path(original_path)
             assert os.path.exists(
@@ -262,35 +284,39 @@ def reconstruct_subgraph_size(split_positions: List[int]) -> List[list]:
     return subgraph_size
 
 
-def calculate_split_positions_for_subgraph(subgraph_size, max_subgraph_size):
-    assert isinstance(subgraph_size, (list, tuple)) and len(subgraph_size) == 2
+def calculate_split_positions_for_subgraph(subgraph_range, max_subgraph_size):
+    assert isinstance(subgraph_range, (list, tuple)) and len(subgraph_range) == 2
 
     # subgraph_size: the start and end position in original model.
-    start_pos, end_pos = subgraph_size
+    start_pos, end_pos = subgraph_range
     end_pos = kMaxGraphSize if end_pos == float("inf") else end_pos
 
-    split_positions = list(range(start_pos, end_pos + 1, max_subgraph_size))
-    deduplicated_splits = list(dict.fromkeys(split_positions))
+    split_positions = set(range(start_pos, end_pos + 1, max_subgraph_size))
+    deduplicated_splits = list(sorted(split_positions))
     return deduplicated_splits
 
 
 def generate_initial_tasks(args):
     """Generates tasks for Pass 0 based on the initial log file."""
     print(f"[Init] Pass 0: Reading from log file: {args.log_file}")
-    initial_failures = get_incorrect_models(args.tolerance, args.log_file)
-    t1_incorrect_models = get_incorrect_models(1, args.log_file)
-    initial_failures = initial_failures - t1_incorrect_models
+    initial_failures = get_ranged_incorrect_models(args.tolerance, args.log_file)
 
     tasks_map = {}
+    max_subgraph_size = args.max_subgraph_size
+
     for model_path in initial_failures:
         model_name = get_model_name_with_subgraph_tag(model_path)
+
+        initial_range = [0, kMaxGraphSize]
+        initial_splits = calculate_split_positions_for_subgraph(
+            initial_range, max_subgraph_size
+        )
+
         tasks_map[model_name] = {
             "original_path": model_path,
-            "subgraph_size": [0, kMaxGraphSize],
-            "split_positions": set(),
+            "split_positions": list(sorted(initial_splits)),
        }
 
-    max_subgraph_size = args.max_subgraph_size
     running_states = {
         "pass_0": {
             "num_incorrect_models": len(initial_failures),
@@ -322,20 +348,25 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
         assert model_name in prev_tasks_map
         pre_task_for_model = prev_tasks_map[model_name]
 
-        # Reconstruct previous subgraph size to locate the failing segment
         prev_split_positions = pre_task_for_model.get("split_positions", [])
-        subgraph_size = reconstruct_subgraph_size(prev_split_positions)
+        subgraph_ranges = reconstruct_subgraph_size(prev_split_positions)
+
         assert subgraph_idx < len(
-            subgraph_size
+            subgraph_ranges
         ), f"subgraph_idx {subgraph_idx} is out of bounds for {model_name} (previous split_positions: {prev_split_positions})"
 
+        new_splits = calculate_split_positions_for_subgraph(
+            subgraph_ranges[subgraph_idx], max_subgraph_size
+        )
+
         if model_name not in tasks_map:
             tasks_map[model_name] = {
                 "original_path": pre_task_for_model["original_path"],
-                "subgraph_size": subgraph_size[subgraph_idx],
-                "split_positions": set(),
+                "split_positions": list(sorted(new_splits)),
             }
-
+        else:
+            new_splits = set(tasks_map[model_name]["split_positions"]) | set(new_splits)
+            tasks_map[model_name]["split_positions"] = list(sorted(new_splits))
     return tasks_map, max_subgraph_size, prev_config.running_states
 
 
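The else branch merges freshly computed splits into any positions already scheduled for the model. Python sets combine with the | union operator (+ raises a TypeError for sets), so the merge amounts to the following sketch with made-up positions:

    existing = [0, 64, 128]        # positions already in tasks_map
    new_splits = {128, 160, 192}   # positions for the newly refined range

    merged = sorted(set(existing) | new_splits)
    assert merged == [0, 64, 128, 160, 192]
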
@@ -399,11 +430,23 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
         need_decompose = True
         shutil.rmtree(decomposed_samples_dir)
         os.makedirs(decomposed_samples_dir, exist_ok=True)
+        max_subgraph_size = max(1, max_subgraph_size // 2)
         for model_name, task_info in tasks_map.items():
-            task_info["subgraph_size"][1] = (
-                task_info["subgraph_size"][0] + max_subgraph_size
+            splits = task_info["split_positions"]
+            if not splits or len(splits) < 2:
+                continue
+            if isinstance(splits, set):
+                splits = sorted(list(splits))
+            start_pos = splits[0]
+            first_segment_end = splits[1]
+            new_splits = list(
+                range(start_pos, first_segment_end + 1, max_subgraph_size)
             )
-        max_subgraph_size = max(1, max_subgraph_size // 2)
+
+            if new_splits[-1] != first_segment_end:
+                new_splits.append(first_segment_end)
+
+            task_info["split_positions"] = sorted(list(set(new_splits)))
     else:
         need_decompose = False
         print()
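
On a decomposition retry, the step is halved up front and only the first segment (between the first two existing positions) is re-split; the segment's end is re-appended when the halved step does not land exactly on it. With illustrative numbers:

    splits = [0, 100, 200]               # previous split positions
    max_subgraph_size = max(1, 64 // 2)  # previous step 64, halved to 32

    start_pos, first_segment_end = splits[0], splits[1]
    new_splits = list(range(start_pos, first_segment_end + 1, max_subgraph_size))
    if new_splits[-1] != first_segment_end:
        new_splits.append(first_segment_end)  # keep the segment boundary

    assert sorted(set(new_splits)) == [0, 32, 64, 96, 100]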
@@ -458,6 +501,7 @@ def main(args):
                 "failed_decomposition_models"
             ] = list(failed_decomposition)
     else:
+        print("\n--- Phase 1: Decomposition (skipped) ---", flush=True)
         config = DecomposeConfig.load(pass_work_dir)
         max_subgraph_size = config.max_subgraph_size
         tasks_map = config.tasks_map
@@ -466,19 +510,26 @@ def main(args):
     # --- Step 3: Evaluation ---
     pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
     if task_controller.task_scheduler["run_evaluation"]:
-        print("\n--- Phase 2: Evaluation ---")
+        print(f"\n--- Phase 2: Evaluation {task_controller.test_module_name} ---")
         run_evaluation(args.framework, args.test_config, pass_work_dir, pass_log_path)
 
     # --- Step 4: Analysis ---
     next_round_models = set()
     if task_controller.task_scheduler["post_analysis"]:
-        print("\n--- Phase 3: Analysis ---")
-        next_round_models = sorted(get_incorrect_models(args.tolerance, pass_log_path))
-        print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
+        tolerance = (
+            args.tolerance[0] if isinstance(args.tolerance, list) else args.tolerance
+        )
+        print(f"\n--- Phase 3: Analysis (tolerance={tolerance}) ---")
+        next_round_models = sorted(get_incorrect_models(tolerance, pass_log_path))
         running_states[f"pass_{current_pass_id + 1}"] = {
             "num_incorrect_models": len(next_round_models),
             "incorrect_models": list(next_round_models),
         }
+
+        print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
+        for idx, model_path in enumerate(next_round_models):
+            print(f"- [{idx}] {model_path}")
+
     print_summary_and_suggestion(next_round_models, max_subgraph_size)
 
 
@@ -500,7 +551,11 @@ def main(args):
         "--test-config", type=str, required=True, help="Base64 encoded test config"
     )
    parser.add_argument(
-        "--tolerance", type=int, required=True, help="Tolerance level range [-10, 5)"
+        "--tolerance",
+        type=int,
+        nargs="+",
+        required=True,
+        help="Tolerance level range [-10, 5)",
    )
    parser.add_argument("--max-subgraph-size", type=int, default=4096)
    args = parser.parse_args()
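
With nargs="+", argparse always delivers --tolerance as a list of one or more integers, which is why the analysis phase unwraps args.tolerance[0] and get_ranged_incorrect_models branches on the list length. Standard argparse behavior, shown on a throwaway parser:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tolerance", type=int, nargs="+", required=True)

    # One value: filter at a single tolerance level.
    assert parser.parse_args(["--tolerance", "-3"]).tolerance == [-3]
    # Two values: the start and end of the tolerance range.
    assert parser.parse_args(["--tolerance", "-3", "1"]).tolerance == [-3, 1]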

graph_net/test/dimension_generalization_test.sh

File mode changed from 100644 to 100755.
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+{
+    "framework": "torch",
+    "num_devices_required": 1,
+    "num_nodes_required": 1,
+    "dynamic": false,
+    "model_name": "error_model"
+}

graph_net/test/error_model/input_meta.py

Whitespace-only changes.

graph_net/test/error_model/input_tensor_constraints.py

Whitespace-only changes.
