Skip to content

Commit 5750ae7

Browse files
lixinqi and JewelRoam authored
1) support resume; 2) fix bugs in torch/backend/range_decomposer_validator_backend.py (#441)
* debug_typical_sequence * support model-path-prefix in splitting positions * fix * fix * Improve efficiency of test/typical_sequence_decomposer_test.sh * 1) support resume; 2) fix bugs in torch/backend/range_decomposer_validator_backend.py --------- Co-authored-by: JewelRoam <[email protected]>
1 parent 201d5b9 commit 5750ae7

File tree

9 files changed

+209
-73
lines changed

9 files changed

+209
-73
lines changed

graph_net/model_path_handler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import traceback
23
from graph_net.imp_util import load_module
34
import logging
45
import sys
@@ -52,6 +53,9 @@ def handle_model_path_list_in_current_process(handler, args):
5253
except KeyboardInterrupt:
5354
print("KeyboardInterrupt")
5455
return
56+
except Exception:
57+
print("------------[model_path_handler failed]------------", flush=True)
58+
traceback.print_exc()
5559

5660

5761
def handle_model_path_list_in_subprocess(args):
Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,4 @@
1-
samples/timm/convnextv2_base.fcmae_ft_in1k
2-
samples/timm/hgnet_tiny.paddle_in1k
3-
samples/timm/mobilenetv4_conv_aa_large.e230_r384_in12k
4-
samples/timm/regnety_080_tv.tv2_in1k
5-
samples/timm/res2net50_14w_8s.in1k
6-
samples/torchaudio/wavlm_base
71
samples/torchgeometric/RECT_L
8-
samples/torchvision/vgg16_bn
92
samples/transformers-auto-model/bge-small-en-v1.5
103
samples/transformers-auto-model/distilbert_distilbert-base-multilingual-cased
114
samples/transformers-auto-model/OFA-Sys_chinese-clip-vit-large-patch14
@@ -17,4 +10,3 @@ samples/transformers-auto-model/opus-mt-en-tw
1710
samples/transformers-auto-model/opus-mt-fi-niu
1811
samples/transformers-auto-model/opus-mt-tc-bible-big-deu_eng_fra_por_spa-bat
1912
samples/transformers-auto-model/opus-mt-tc-bible-big-gmw-deu_eng_fra_por_spa
20-
samples/ultralytics/yolov3-tinyu

graph_net/test/typical_sequence_decomposer_test.sh

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,28 @@ mkdir -p "$DECOMPOSE_PATH"
99
# model_list="$GRAPH_NET_ROOT/graph_net/config/small100_torch_samples_list.txt"
1010
model_list="$GRAPH_NET_ROOT/graph_net/test/dev_model_list/validation_error_model_list.txt"
1111

12+
op_names_extractor_config_json_str=$(cat <<EOF
13+
{
14+
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/typical_sequence_split_points.py",
15+
"handler_class_name": "OpNamesExtractor",
16+
"handler_config": {
17+
"resume": true,
18+
"model_path_prefix": "$GRAPH_NET_ROOT",
19+
"output_dir": "$DECOMPOSE_PATH"
20+
}
21+
}
22+
EOF
23+
)
24+
OP_NAMES_EXTRACTOR_CONFIG=$(echo $op_names_extractor_config_json_str | base64 -w 0)
25+
26+
python3 -m graph_net.model_path_handler \
27+
--model-path-list $model_list \
28+
--handler-config=$OP_NAMES_EXTRACTOR_CONFIG \
29+
1230
python3 -m graph_net.torch.typical_sequence_split_points \
31+
--enable-resume \
1332
--model-list "$model_list" \
14-
--model-path-prefix "$GRAPH_NET_ROOT" \
33+
--op-names-path-prefix "$DECOMPOSE_PATH" \
1534
--device "cuda" \
1635
--window-size 10 \
1736
--fold-policy default \
@@ -23,6 +42,7 @@ decompose_config_json_str=$(cat <<EOF
2342
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/graph_decomposer.py",
2443
"handler_class_name": "RangeDecomposerExtractor",
2544
"handler_config": {
45+
"resume": true,
2646
"model_path_prefix": "$GRAPH_NET_ROOT",
2747
"output_dir": "$DECOMPOSE_PATH",
2848
"split_results_path": "$DECOMPOSE_PATH/split_results.json",
@@ -37,10 +57,10 @@ DECOMPOSE_CONFIG=$(echo $decompose_config_json_str | base64 -w 0)
3757
python3 -m graph_net.model_path_handler \
3858
--model-path-list $model_list \
3959
--handler-config=$DECOMPOSE_CONFIG \
40-
--use-subprocess
4160

4261
test_compiler_config_json_str=$(cat <<EOF
4362
{
63+
"model_path_prefix": "$GRAPH_NET_ROOT",
4464
"decomposed_root": "$DECOMPOSE_PATH"
4565
}
4666
EOF
@@ -53,7 +73,7 @@ python3 -m graph_net.torch.test_compiler \
5373
--device cuda \
5474
--config $TEST_COMPILER_CONFIG \
5575
--model-path-prefix $GRAPH_NET_ROOT \
56-
> "$DECOMPOSE_PATH/validation.log" 2>&1
76+
2>&1 | tee "$DECOMPOSE_PATH/validation.log"
5777

5878
python3 -m graph_net.plot_ESt \
5979
--benchmark-path "$DECOMPOSE_PATH/validation.log" \

graph_net/torch/backend/range_decomposer_validator_backend.py

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,41 @@
11
import torch
2+
import inspect
23
import torch.nn as nn
34
import os
45
import importlib.util
56
from typing import List
67

78

89
class ComposedModel(nn.Module):
    """Chain a list of subgraph modules into a single module.

    The first subgraph is called with keyword arguments, restricted to the
    parameter names its ``forward`` declares. Every later subgraph is called
    with the previous subgraph's output unpacked as positional arguments.
    """

    def __init__(self, subgraphs: List[nn.Module]):
        super().__init__()
        self.subgraphs = nn.ModuleList(subgraphs)

    def forward(self, **kwargs):
        """Run all subgraphs in order and return the last one's output.

        Returns ``None`` when there are no subgraphs.
        """
        output = None
        for subgraph in self.subgraphs:
            if output is None:
                # No intermediate result yet: feed the caller's keyword inputs.
                output = subgraph(**self._convert_inputs(subgraph, kwargs))
            else:
                # NOTE(review): presumably every subgraph returns a tuple/list
                # of values; a bare tensor would be unpacked along dim 0 here.
                output = subgraph(*output)
        return output

    def _convert_inputs(self, subgraph, input_kwargs):
        """Return the subset of ``input_kwargs`` accepted by ``subgraph.forward``.

        Raises:
            AssertionError: if ``subgraph.forward`` declares a parameter that
                is absent from ``input_kwargs``, or if a dropped keyword is
                not of the form ``s<digits>``.
        """
        input_keywords = set(input_kwargs)
        sub_graph_arg_names = set(inspect.signature(subgraph.forward).parameters)
        assert (
            len(sub_graph_arg_names - input_keywords) == 0
        ), f"{(sub_graph_arg_names - input_keywords)=}"
        # Only keywords named "s<digits>" may be dropped silently; anything
        # else that the subgraph does not accept is treated as an error.
        for remainder in input_keywords - sub_graph_arg_names:
            assert (
                remainder.startswith("s") and remainder[1:].isdigit()
            ), f"{remainder=}"
        return {
            name: value
            for name, value in input_kwargs.items()
            if name in sub_graph_arg_names
        }
38+
2539

2640
class RangeDecomposerValidatorBackend:
2741
def _load_model_instance(self, path: str, device: str) -> torch.nn.Module:
@@ -36,40 +50,56 @@ def _load_model_instance(self, path: str, device: str) -> torch.nn.Module:
3650
instance = ModelClass().to(device)
3751
return instance
3852

39-
def _make_config(self, decomposed_root, decomposed_model_name_suffix="_decomposed"):
53+
def _make_config(
54+
self,
55+
model_path_prefix: str,
56+
decomposed_root: str,
57+
decomposed_dentry: str = "_decomposed",
58+
):
4059
return {
60+
"model_path_prefix": model_path_prefix,
4161
"decomposed_root": decomposed_root,
42-
"decomposed_model_name_suffix": decomposed_model_name_suffix,
62+
"decomposed_dentry": decomposed_dentry,
4363
}
4464

65+
def _get_rel_model_path(self, model_path) -> str:
66+
model_path = os.path.realpath(model_path)
67+
model_path_prefix = os.path.realpath(self.config["model_path_prefix"])
68+
assert model_path.startswith(model_path_prefix)
69+
rel_model_path = model_path[len(model_path_prefix) :]
70+
if rel_model_path.startswith("/"):
71+
rel_model_path = rel_model_path[1:]
72+
assert not rel_model_path.startswith("/")
73+
return rel_model_path
74+
75+
def _get_model_name_order(self, name):
76+
lst = name.split("_")
77+
if not (len(lst) > 0):
78+
return -1
79+
if not (lst[-1].isdigit()):
80+
return -1
81+
return int(lst[-1])
82+
4583
def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
4684
config = self._make_config(**self.config)
47-
model_file_path = model.__class__.__graph_net_file_path__
48-
model_dir = os.path.dirname(model_file_path)
49-
model_name = os.path.basename(model_dir)
85+
model_path = os.path.dirname(model.__class__.__graph_net_file_path__)
86+
rel_model_path = self._get_rel_model_path(model_path)
5087
decomposed_parent_dir = os.path.join(
51-
config["decomposed_root"], f"{model_name}_decomposed"
88+
config["decomposed_root"], rel_model_path, config["decomposed_dentry"]
5289
)
5390
subgraph_paths = []
54-
for name in sorted(os.listdir(decomposed_parent_dir)):
91+
dentries = os.listdir(decomposed_parent_dir)
92+
for name in sorted(dentries, key=self._get_model_name_order):
5593
full_path = os.path.join(decomposed_parent_dir, name)
56-
if os.path.isdir(full_path) and name[-1].isdigit():
94+
if os.path.isdir(full_path) and self._get_model_name_order(name) >= 0:
5795
subgraph_paths.append(full_path)
5896

59-
print(
60-
f"[RangeDecomposerValidatorBackend] Found subgraphs: {[os.path.basename(p) for p in subgraph_paths]}"
61-
)
62-
6397
device = model.__class__.__graph_net_device__
6498
subgraph_instances = []
6599

66100
for path in subgraph_paths:
67101
instance = self._load_model_instance(path, device)
68102
subgraph_instances.append(instance)
69-
dir_name = os.path.basename(path)
70-
print(
71-
f"[RangeDecomposerValidatorBackend] Loaded and instantiated '{dir_name}'"
72-
)
73103

74104
composed_model = ComposedModel(subgraph_instances)
75105
return composed_model.eval()

graph_net/torch/fx_graph_parse_util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ def _get_name_pattern2replacement(names_from_signature, names_from_placeholder):
113113

114114

115115
def _rename_placeholder(name, pattern2replacement):
116-
assert name[:2] == "L_" or name[:2] == "l_", f"{name=}"
116+
if not (name[:2] == "L_" or name[:2] == "l_"):
117+
return name
117118
name = name[2:]
118119
if name[0] == "l":
119120
name = "L" + name[1:]

graph_net/torch/graph_decomposer.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from pathlib import Path
23
import torch
34
import json
45
from graph_net.torch.decompose_util import convert_to_submodules_graph
@@ -79,7 +80,7 @@ def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
7980
def get_naive_decomposer_extractor(self, submodule, seq_no):
8081
return NaiveDecomposerExtractorModule(
8182
config=self.config,
82-
parent_graph_name=self.name,
83+
parent_graph_rel_model_path=self.name,
8384
submodule=submodule,
8485
seq_no=seq_no,
8586
)
@@ -145,7 +146,7 @@ def get_naive_decomposer_extractor(self, model_path):
145146
def fn(submodule, seq_no):
146147
return NaiveDecomposerExtractorModule(
147148
config=self.config,
148-
parent_graph_name=os.path.basename(model_path),
149+
parent_graph_rel_model_path=os.path.basename(model_path),
149150
submodule=submodule,
150151
seq_no=seq_no,
151152
)
@@ -165,6 +166,7 @@ def __init__(self, config: dict = None):
165166

166167
def _make_config(
167168
self,
169+
resume: bool = False,
168170
split_results_path=None,
169171
group_head_and_tail=False,
170172
chain_style=False,
@@ -181,6 +183,7 @@ def _make_config(
181183
f"split_results_path should be a valid JSON file path, but got {split_results_path=}"
182184
)
183185
return {
186+
"resume": resume,
184187
"split_results_path": split_results_path,
185188
"group_head_and_tail": group_head_and_tail,
186189
"chain_style": chain_style,
@@ -190,10 +193,25 @@ def _make_config(
190193
"model_path_prefix": model_path_prefix,
191194
}
192195

196+
def _is_model_handled(self, rel_model_path, split_positions):
197+
num_subgraphs = len(split_positions) + 1
198+
decomposed_model_path = Path(self.config["output_dir"]) / rel_model_path
199+
num_decomposed = len(list(decomposed_model_path.rglob("model.py")))
200+
if num_decomposed > 0:
201+
assert (
202+
num_subgraphs <= num_decomposed
203+
), f"{num_subgraphs=} {num_decomposed=} {str(decomposed_model_path)=}"
204+
return num_subgraphs == num_decomposed
205+
193206
def __call__(self, rel_model_path):
194207
model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
195208
split_results = load_json(self.config["split_results_path"])
196209
split_positions = split_results[rel_model_path]["split_positions"]
210+
if self.config["resume"] and self._is_model_handled(
211+
rel_model_path, split_positions
212+
):
213+
return
214+
torch.cuda.empty_cache()
197215
config = {
198216
"split_positions": split_positions,
199217
"group_head_and_tail": self.config.get("group_head_and_tail", False),
@@ -203,16 +221,16 @@ def __call__(self, rel_model_path):
203221
gm = parse_sole_graph_module(module, inputs)
204222
rewrited_gm: torch.fx.GraphModule = convert_to_submodules_graph(
205223
gm,
206-
submodule_hook=self.get_naive_decomposer_extractor(model_path),
224+
submodule_hook=self.get_naive_decomposer_extractor(rel_model_path),
207225
**config,
208226
)
209227
rewrited_gm(*inputs)
210228

211-
def get_naive_decomposer_extractor(self, model_path):
229+
def get_naive_decomposer_extractor(self, rel_model_path):
212230
def fn(submodule, seq_no):
213231
return NaiveDecomposerExtractorModule(
214232
config=self.config,
215-
parent_graph_name=os.path.basename(model_path),
233+
parent_graph_rel_model_path=rel_model_path,
216234
submodule=submodule,
217235
seq_no=seq_no,
218236
)
@@ -224,7 +242,7 @@ class NaiveDecomposerExtractorModule(torch.nn.Module):
224242
def __init__(
225243
self,
226244
config: dict,
227-
parent_graph_name: str,
245+
parent_graph_rel_model_path: str,
228246
submodule: torch.nn.Module,
229247
seq_no: int,
230248
):
@@ -233,34 +251,28 @@ def __init__(
233251
self.submodule = submodule
234252
self.seq_no = seq_no
235253
self.extracted = False
236-
self.parent_graph_name = parent_graph_name
254+
self.parent_graph_rel_model_path = parent_graph_rel_model_path
255+
parent_graph_model_name = os.path.basename(parent_graph_rel_model_path)
237256
if self.seq_no is None:
238-
self.model_name = parent_graph_name
257+
self.model_name = parent_graph_model_name
239258
else:
240-
submodule_name = f"{parent_graph_name}_{self.seq_no}"
259+
submodule_name = f"{parent_graph_model_name}_{self.seq_no}"
241260
self.model_name = submodule_name
242261
self.builtin_extractor = BuiltinGraphExtractor(
243262
name=submodule_name,
244263
dynamic=False,
245264
mut_graph_codes=[],
246265
placeholder_auto_rename=False,
247266
workspace_path=os.path.join(
248-
self.config["output_dir"], f"{parent_graph_name}_decomposed"
267+
self.config["output_dir"], parent_graph_rel_model_path, "_decomposed"
249268
),
250269
)
251270
self.filter = self.make_filter(self.config)
252271

253272
def _get_model_path(self):
254273
return os.path.join(
255274
self.config["output_dir"],
256-
f"{self.parent_graph_name}_decomposed",
257-
self.model_name,
258-
)
259-
260-
def _get_model_path(self):
261-
return os.path.join(
262-
self.config["output_dir"],
263-
f"{self.parent_graph_name}_decomposed",
275+
f"{self.parent_graph_model_name}/_decomposed",
264276
self.model_name,
265277
)
266278

0 commit comments

Comments
 (0)