
Commit c3e78da

Merge branch 'develop' into opt_saved_results

2 parents 7d9581f + 410311b

17 files changed, +1151 −45 lines

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 81 additions & 26 deletions
@@ -27,6 +27,28 @@ def get_pass_name(pass_id):
     return f"pass_{pass_id}"
 
 
+def get_ranged_incorrect_models(tolerance_args: List[int], log_path: str) -> set:
+    if not os.path.exists(log_path):
+        return set()
+
+    t_start = tolerance_args[0]
+    models_start = set(get_incorrect_models(t_start, log_path))
+
+    if len(tolerance_args) == 1:
+        return models_start
+
+    t_end = tolerance_args[1]
+    models_end = set(get_incorrect_models(t_end, log_path))
+
+    print(f"[Filter] Tolerance Range: {t_start} -> {t_end}")
+    print(
+        f"[Filter] Fail({t_start}): {len(models_start)}, Fail({t_end}): {len(models_end)}"
+    )
+
+    diff_set = models_start - models_end
+    return diff_set
+
+
 class TaskController:
     def __init__(self, args):
         self.root_output_dir = os.path.abspath(args.output_dir)
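
The new helper's difference set keeps the models that fail at the first tolerance level but no longer fail at the second, so a two-value range isolates failures attributable to that band. A minimal sketch of the set arithmetic, with hypothetical fail sets standing in for get_incorrect_models output:

# Hypothetical fail sets at two tolerance levels (model paths as strings).
models_start = {"models/resnet", "models/bert", "models/vit"}  # fail at t_start
models_end = {"models/vit"}                                    # still fail at t_end

# get_ranged_incorrect_models keeps the failures that clear up at t_end:
diff_set = models_start - models_end
assert diff_set == {"models/resnet", "models/bert"}
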
@@ -198,10 +220,10 @@ def run_decomposer_for_multi_models(
     )
     for model_name, task_info in tasks_map.items():
         original_path = task_info["original_path"]
-        split_positions = calculate_split_positions_for_subgraph(
-            task_info["subgraph_size"], max_subgraph_size
-        )
-        task_info["split_positions"] = split_positions
+
+        split_positions = task_info["split_positions"]
+        if isinstance(split_positions, set):
+            split_positions = sorted(list(split_positions))
 
         rectified_model_path = get_rectfied_model_path(original_path)
         assert os.path.exists(
@@ -262,35 +284,39 @@ def reconstruct_subgraph_size(split_positions: List[int]) -> List[list]:
     return subgraph_size
 
 
-def calculate_split_positions_for_subgraph(subgraph_size, max_subgraph_size):
-    assert isinstance(subgraph_size, (list, tuple)) and len(subgraph_size) == 2
+def calculate_split_positions_for_subgraph(subgraph_range, max_subgraph_size):
+    assert isinstance(subgraph_range, (list, tuple)) and len(subgraph_range) == 2
 
     # subgraph_size: the start and end position in original model.
-    start_pos, end_pos = subgraph_size
+    start_pos, end_pos = subgraph_range
     end_pos = kMaxGraphSize if end_pos == float("inf") else end_pos
 
-    split_positions = list(range(start_pos, end_pos + 1, max_subgraph_size))
-    deduplicated_splits = list(dict.fromkeys(split_positions))
+    split_positions = set(range(start_pos, end_pos + 1, max_subgraph_size))
+    deduplicated_splits = list(sorted(split_positions))
     return deduplicated_splits
 
 
 def generate_initial_tasks(args):
     """Generates tasks for Pass 0 based on the initial log file."""
     print(f"[Init] Pass 0: Reading from log file: {args.log_file}")
-    initial_failures = get_incorrect_models(args.tolerance, args.log_file)
-    t1_incorrect_models = get_incorrect_models(1, args.log_file)
-    initial_failures = initial_failures - t1_incorrect_models
+    initial_failures = get_ranged_incorrect_models(args.tolerance, args.log_file)
 
     tasks_map = {}
+    max_subgraph_size = args.max_subgraph_size
+
     for model_path in initial_failures:
         model_name = get_model_name_with_subgraph_tag(model_path)
+
+        initial_range = [0, kMaxGraphSize]
+        initial_splits = calculate_split_positions_for_subgraph(
+            initial_range, max_subgraph_size
+        )
+
         tasks_map[model_name] = {
             "original_path": model_path,
-            "subgraph_size": [0, kMaxGraphSize],
-            "split_positions": set(),
+            "split_positions": list(sorted(initial_splits)),
        }
 
-    max_subgraph_size = args.max_subgraph_size
     running_states = {
         "pass_0": {
             "num_incorrect_models": len(initial_failures),
@@ -322,20 +348,25 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
         assert model_name in prev_tasks_map
         pre_task_for_model = prev_tasks_map[model_name]
 
-        # Reconstruct previous subgraph size to locate the failing segment
         prev_split_positions = pre_task_for_model.get("split_positions", [])
-        subgraph_size = reconstruct_subgraph_size(prev_split_positions)
+        subgraph_ranges = reconstruct_subgraph_size(prev_split_positions)
+
         assert subgraph_idx < len(
-            subgraph_size
+            subgraph_ranges
         ), f"subgraph_idx {subgraph_idx} is out of bounds for {model_name} (previous split_positions: {prev_split_positions})"
 
+        new_splits = calculate_split_positions_for_subgraph(
+            subgraph_ranges[subgraph_idx], max_subgraph_size
+        )
+
         if model_name not in tasks_map:
             tasks_map[model_name] = {
                 "original_path": pre_task_for_model["original_path"],
-                "subgraph_size": subgraph_size[subgraph_idx],
-                "split_positions": set(),
+                "split_positions": list(sorted(new_splits)),
             }
-
+        else:
+            new_splits = set(tasks_map[model_name]["split_positions"]) | set(new_splits)
+            tasks_map[model_name]["split_positions"] = list(sorted(new_splits))
     return tasks_map, max_subgraph_size, prev_config.running_states
 
 
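Assuming reconstruct_subgraph_size pairs consecutive split positions back into [start, end] ranges (consistent with how subgraph_ranges[subgraph_idx] is indexed above), the refine-and-merge step behaves like this sketch with hypothetical numbers:

# Hypothetical pass: previous splits [0, 4, 8] give ranges [[0, 4], [4, 8]].
prev_split_positions = [0, 4, 8]
subgraph_ranges = [[0, 4], [4, 8]]  # assumed reconstruct_subgraph_size output

# Refine the failing segment subgraph_idx = 1 with max_subgraph_size = 2:
new_splits = list(range(4, 8 + 1, 2))        # -> [4, 6, 8]

# Merging into an existing task is a set union, stored as a sorted list:
existing = [0, 4, 8]
merged = sorted(set(existing) | set(new_splits))
assert merged == [0, 4, 6, 8]
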
@@ -399,11 +430,23 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
         need_decompose = True
         shutil.rmtree(decomposed_samples_dir)
         os.makedirs(decomposed_samples_dir, exist_ok=True)
+        max_subgraph_size = max(1, max_subgraph_size // 2)
         for model_name, task_info in tasks_map.items():
-            task_info["subgraph_size"][1] = (
-                task_info["subgraph_size"][0] + max_subgraph_size
+            splits = task_info["split_positions"]
+            if not splits or len(splits) < 2:
+                continue
+            if isinstance(splits, set):
+                splits = sorted(list(splits))
+            start_pos = splits[0]
+            first_segment_end = splits[1]
+            new_splits = list(
+                range(start_pos, first_segment_end + 1, max_subgraph_size)
             )
-            max_subgraph_size = max(1, max_subgraph_size // 2)
+
+            if new_splits[-1] != first_segment_end:
+                new_splits.append(first_segment_end)
+
+            task_info["split_positions"] = sorted(list(set(new_splits)))
     else:
         need_decompose = False
         print()
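
On a retry, max_subgraph_size is now halved once up front and only the first segment is re-split; the stride can miss the segment end, hence the explicit append. A sketch with small hypothetical numbers:

# Previous splits [0, 7, 14], max_subgraph_size halved from 8 to 4:
splits = [0, 7, 14]
max_subgraph_size = max(1, 8 // 2)                    # -> 4
start_pos, first_segment_end = splits[0], splits[1]   # first segment [0, 7]
new_splits = list(range(start_pos, first_segment_end + 1, max_subgraph_size))
if new_splits[-1] != first_segment_end:               # 4-stride stops at 4
    new_splits.append(first_segment_end)
assert sorted(set(new_splits)) == [0, 4, 7]           # replaces split_positions
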
@@ -473,12 +516,20 @@ def main(args):
     next_round_models = set()
     if task_controller.task_scheduler["post_analysis"]:
         print("\n--- Phase 3: Analysis ---")
-        next_round_models = sorted(get_incorrect_models(args.tolerance, pass_log_path))
+        tolerance = (
+            args.tolerance[0] if isinstance(args.tolerance, list) else args.tolerance
+        )
+        next_round_models = sorted(get_incorrect_models(tolerance, pass_log_path))
         print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
         running_states[f"pass_{current_pass_id + 1}"] = {
             "num_incorrect_models": len(next_round_models),
             "incorrect_models": list(next_round_models),
         }
+
+        print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
+        for idx, model_path in enumerate(next_round_models):
+            print(f"- [{idx}] {model_path}")
+
     print_summary_and_suggestion(next_round_models, max_subgraph_size)
 
     # --- Step 5: Save States ---
@@ -500,7 +551,11 @@ def main(args):
         "--test-config", type=str, required=True, help="Base64 encoded test config"
     )
     parser.add_argument(
-        "--tolerance", type=int, required=True, help="Tolerance level range [-10, 5)"
+        "--tolerance",
+        type=int,
+        nargs="+",
+        required=True,
+        help="Tolerance level range [-10, 5)",
     )
     parser.add_argument("--max-subgraph-size", type=int, default=4096)
     args = parser.parse_args()
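
With nargs="+", argparse collects --tolerance into a list of one or more ints, which is why get_ranged_incorrect_models takes a list and main() unwraps args.tolerance[0]. A small standalone example of the flag's behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--tolerance", type=int, nargs="+", required=True)

assert parser.parse_args(["--tolerance", "3"]).tolerance == [3]
assert parser.parse_args(["--tolerance", "3", "1"]).tolerance == [3, 1]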

graph_net/test/dimension_generalization_test.sh

File mode changed: 100644 → 100755
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+{
+  "framework": "torch",
+  "num_devices_required": 1,
+  "num_nodes_required": 1,
+  "dynamic": false,
+  "model_name": "error_model"
+}

graph_net/test/error_model/input_meta.py

Whitespace-only changes.

graph_net/test/error_model/input_tensor_constraints.py

Whitespace-only changes.
Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+import torch
+
+from torch import device
+
+
+class GraphModule(torch.nn.Module):
+    def forward(
+        self,
+        add_22,
+        extended_attention_mask_2,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_weight_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_bias_,
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_weight_,
+    ):
+        hidden_states_66 = torch.nn.functional.layer_norm(
+            add_22,
+            (32,),
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_bias_,
+            1e-12,
+        )
+        add_22 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_3_modules_output_modules_layer_norm_parameters_bias_ = (None)
+        linear_44 = torch.nn.functional.linear(
+            hidden_states_66,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_bias_,
+        )
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_query_parameters_bias_ = (None)
+        view_16 = linear_44.view(2, -1, 4, 8)
+        linear_44 = None
+        query_layer_4 = view_16.transpose(1, 2)
+        view_16 = None
+        linear_45 = torch.nn.functional.linear(
+            hidden_states_66,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_bias_,
+        )
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_key_parameters_bias_ = (None)
+        view_17 = linear_45.view(2, -1, 4, 8)
+        linear_45 = None
+        key_layer_4 = view_17.transpose(1, 2)
+        view_17 = None
+        linear_46 = torch.nn.functional.linear(
+            hidden_states_66,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_bias_,
+        )
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_self_modules_value_parameters_bias_ = (None)
+        view_18 = linear_46.view(2, -1, 4, 8)
+        linear_46 = None
+        value_layer_4 = view_18.transpose(1, 2)
+        view_18 = None
+        transpose_25 = key_layer_4.transpose(-1, -2)
+        key_layer_4 = None
+        attention_scores_22 = torch.matmul(query_layer_4, transpose_25)
+        query_layer_4 = transpose_25 = None
+        attention_scores_23 = attention_scores_22 / 2.8284271247461903
+        attention_scores_22 = None
+        eps = torch.tensor(1e-8, device=attention_scores_23.device)
+        nan_val = eps / (eps - eps)
+        attention_scores_23 = attention_scores_23 + nan_val
+        nan_val = None
+        to_8 = extended_attention_mask_2.to(device(type="cuda", index=0))
+        extended_attention_mask_2 = None
+        attention_scores_24 = attention_scores_23 + to_8
+        attention_scores_23 = to_8 = None
+        _log_api_usage_once_4 = torch._C._log_api_usage_once("python.nn_module")
+        _log_api_usage_once_4 = None
+        attention_probs_14 = torch.nn.functional.softmax(
+            attention_scores_24, -1, _stacklevel=5
+        )
+        attention_scores_24 = None
+        attention_probs_dropped_4 = torch.nn.functional.dropout(
+            attention_probs_14, 0.0, False, False
+        )
+        attention_probs_14 = None
+        context_layer_22 = torch.matmul(attention_probs_dropped_4, value_layer_4)
+        attention_probs_dropped_4 = value_layer_4 = None
+        permute_14 = context_layer_22.permute(0, 2, 1, 3)
+        context_layer_22 = None
+        context_layer_23 = permute_14.contiguous()
+        permute_14 = None
+        context_layer_24 = context_layer_23.view(2, 14, 32)
+        context_layer_23 = None
+        hidden_states_67 = torch.nn.functional.linear(
+            context_layer_24,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_bias_,
+        )
+        context_layer_24 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_dense_parameters_bias_ = (None)
+        hidden_states_68 = torch.nn.functional.dropout(
+            hidden_states_67, 0.0, False, False
+        )
+        hidden_states_67 = None
+        add_24 = hidden_states_68 + hidden_states_66
+        hidden_states_68 = hidden_states_66 = None
+        hidden_states_69 = torch.nn.functional.layer_norm(
+            add_24,
+            (32,),
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_bias_,
+            1e-12,
+        )
+        add_24 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_attention_modules_output_modules_layer_norm_parameters_bias_ = (None)
+        hidden_states_70 = torch.nn.functional.linear(
+            hidden_states_69,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_bias_,
+        )
+        l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_intermediate_modules_dense_parameters_bias_ = (None)
+        hidden_states_71 = torch.nn.functional.gelu(hidden_states_70)
+        hidden_states_70 = None
+        hidden_states_72 = torch.nn.functional.linear(
+            hidden_states_71,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_bias_,
+        )
+        hidden_states_71 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_dense_parameters_bias_ = (None)
+        hidden_states_73 = torch.nn.functional.dropout(
+            hidden_states_72, 0.0, False, False
+        )
+        hidden_states_72 = None
+        nan_val = torch.tensor(0.0, device=hidden_states_73.device) / torch.tensor(
+            0.0, device=hidden_states_73.device
+        )
+        hidden_states_73 = hidden_states_73 + nan_val
+        nan_val = None
+        add_25 = hidden_states_73 + hidden_states_69
+        hidden_states_73 = hidden_states_69 = None
+        hidden_states_74 = torch.nn.functional.layer_norm(
+            add_25,
+            (32,),
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_weight_,
+            l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_bias_,
+            1e-12,
+        )
+        add_25 = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_weight_ = l_l_self_modules_text_model_modules_encoder_modules_layer_modules_4_modules_output_modules_layer_norm_parameters_bias_ = (None)
+        return (hidden_states_74,)
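
The new module reads like a captured transformer text-encoder layer with faults injected deliberately, matching the error_model fixture name in the config above. The two injection sites behave differently under IEEE float rules: eps / (eps - eps) divides a nonzero value by zero and yields inf (despite the nan_val name), while 0.0 / 0.0 yields nan; either one poisons every downstream activation. A standalone sketch of the trick:

import torch

eps = torch.tensor(1e-8)
print(eps / (eps - eps))                       # tensor(inf): nonzero / 0.0
print(torch.tensor(0.0) / torch.tensor(0.0))   # tensor(nan): 0.0 / 0.0

# Adding either value into an activation corrupts everything downstream:
x = torch.ones(3)
print(x + torch.tensor(float("nan")))          # tensor([nan, nan, nan])
print(x + torch.tensor(float("inf")))          # tensor([inf, inf, inf])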
