Skip to content

Commit e6c4d97

Browse files
committed
Validate the correctness of graph variable rename
1 parent 5750ae7 commit e6c4d97

File tree

4 files changed

+174
-4
lines changed

4 files changed

+174
-4
lines changed

graph_net/test/graph_variable_rename_test.sh

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
44
os.path.dirname(graph_net.__file__))")
5+
WORKSPACE=/tmp/graph_variable_rename_workspace
56

67
# input model path
78
MODEL_NAME=resnet18
@@ -16,7 +17,7 @@ config_json_str=$(cat <<EOF
1617
"data_input_predicator_class_name": "NaiveDataInputPredicator",
1718
"model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
1819
"model_runnable_predicator_class_name": "ModelRunnablePredicator",
19-
"output_dir": "/tmp/graph_variable_rename_workspace"
20+
"output_dir": "$WORKSPACE"
2021
}
2122
}
2223
EOF
@@ -25,3 +26,14 @@ CONFIG=$(echo $config_json_str | base64 -w 0)
2526

2627
python3 -m graph_net.model_path_handler --model-path samples/$MODEL_PATH_IN_SAMPLES --handler-config=$CONFIG
2728
# python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/decomposition_error_tmp_torch_samples_list.txt --handler-config=$CONFIG
29+
30+
31+
python3 -m graph_net.torch.test_compiler \
32+
--model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES \
33+
--compiler graph_variable_renamer \
34+
--device cuda \
35+
> "$WORKSPACE/validation.log" 2>&1
36+
37+
python3 -m graph_net.plot_ESt \
38+
--benchmark-path "$WORKSPACE/validation.log" \
39+
--output-dir "$WORKSPACE"
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import torch
2+
import graph_net
3+
from pathlib import Path
4+
from typing import Any, Dict
5+
from graph_net.imp_util import load_module
6+
from graph_net.torch.graph_variable_renamer import GraphVariableRenamer
7+
from graph_net.tensor_meta import TensorMeta
8+
9+
10+
class RenamedModelAdapter(torch.nn.Module):
11+
def __init__(self, renamed_model, mapping):
12+
super().__init__()
13+
self.model = renamed_model
14+
self.mapping = mapping
15+
16+
def forward(self, **kwargs):
17+
new_kwargs = {}
18+
for old_name, value in kwargs.items():
19+
if old_name in self.mapping:
20+
new_name = self.mapping[old_name]
21+
new_kwargs[new_name] = value
22+
23+
return self.model(**new_kwargs)
24+
25+
26+
class GraphVariableRenameBackend:
27+
def __init__(self, config: Dict[str, Any] = None):
28+
if config is None:
29+
config = {}
30+
self.config = config
31+
self.workspace_path = Path(
32+
self.config.get("workspace_path", "./tmp/graph_variable_rename_workspace")
33+
)
34+
self.workspace_path.mkdir(parents=True, exist_ok=True)
35+
36+
def _get_default_paths(self):
37+
lib_root = Path(graph_net.__file__).parent
38+
default_util_path = str(lib_root / "torch/constraint_util.py")
39+
return default_util_path
40+
41+
def _get_rename_mapping(self, dst_model_dir: Path) -> Dict[str, str]:
42+
mapping = {}
43+
44+
input_meta_path = dst_model_dir / "input_meta.py"
45+
if input_meta_path.exists():
46+
metas = TensorMeta.unserialize_from_py_file(str(input_meta_path))
47+
for m in metas:
48+
if m.original_name:
49+
mapping[m.original_name] = m.name
50+
51+
weight_meta_path = dst_model_dir / "weight_meta.py"
52+
if weight_meta_path.exists():
53+
metas = TensorMeta.unserialize_from_py_file(str(weight_meta_path))
54+
for m in metas:
55+
if m.original_name:
56+
mapping[m.original_name] = m.name
57+
58+
return mapping
59+
60+
def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
61+
print("\n[GraphVariableRenameBackend] Starting rename process...")
62+
63+
if not hasattr(model.__class__, "__graph_net_file_path__"):
64+
raise ValueError(
65+
"Input model must be a GraphNet model with __graph_net_file_path__ attribute."
66+
)
67+
68+
src_file_path = Path(model.__class__.__graph_net_file_path__).resolve()
69+
src_model_dir = src_file_path.parent
70+
model_rel_path = src_model_dir.name
71+
model_path_prefix = str(src_model_dir.parent)
72+
default_util_path = self._get_default_paths()
73+
data_input_predicator_filepath = self.config.get(
74+
"data_input_predicator_filepath", default_util_path
75+
)
76+
data_input_predicator_class_name = self.config.get(
77+
"data_input_predicator_class_name", "NaiveDataInputPredicator"
78+
)
79+
data_input_predicator_config = self.config.get(
80+
"data_input_predicator_config", {}
81+
)
82+
model_runnable_predicator_filepath = self.config.get(
83+
"model_runnable_predicator_filepath", default_util_path
84+
)
85+
model_runnable_predicator_class_name = self.config.get(
86+
"model_runnable_predicator_class_name", "ModelRunnablePredicator"
87+
)
88+
model_runnable_predicator_config = self.config.get(
89+
"model_runnable_predicator_config", {}
90+
)
91+
92+
output_dir = str(self.workspace_path)
93+
94+
renamer_config = {
95+
"output_dir": output_dir,
96+
"model_path_prefix": model_path_prefix,
97+
"data_input_predicator_filepath": data_input_predicator_filepath,
98+
"data_input_predicator_class_name": data_input_predicator_class_name,
99+
"data_input_predicator_config": data_input_predicator_config,
100+
"model_runnable_predicator_filepath": model_runnable_predicator_filepath,
101+
"model_runnable_predicator_class_name": model_runnable_predicator_class_name,
102+
"model_runnable_predicator_config": model_runnable_predicator_config,
103+
}
104+
105+
print(f"[Backend Info] Model Source Dir: {src_model_dir}")
106+
print(f"[Backend Info] Calculated Prefix: {model_path_prefix}")
107+
108+
try:
109+
renamer = GraphVariableRenamer(renamer_config)
110+
renamer(model_rel_path)
111+
except Exception as e:
112+
print(f"[Error] GraphVariableRenamer execution failed: {e}")
113+
raise e
114+
115+
dst_model_dir = self.workspace_path / model_rel_path
116+
print(f"[Success] Renamed model saved to {dst_model_dir}")
117+
118+
renamed_core_model = self._load_model(dst_model_dir)
119+
name_mapping = self._get_rename_mapping(dst_model_dir)
120+
121+
adapter_model = RenamedModelAdapter(renamed_core_model, name_mapping)
122+
adapter_model.eval()
123+
124+
return adapter_model
125+
126+
def _load_model(self, model_dir: Path) -> torch.nn.Module:
127+
model_py_path = model_dir / "model.py"
128+
if not model_py_path.exists():
129+
raise FileNotFoundError(f"Renamed model not found at {model_py_path}")
130+
131+
py_module = load_module(str(model_py_path))
132+
133+
if hasattr(py_module, "GraphModule"):
134+
GraphModule = getattr(py_module, "GraphModule")
135+
GraphModule.__graph_net_file_path__ = str(model_py_path)
136+
model = GraphModule()
137+
return model
138+
else:
139+
raise ValueError(f"GraphModule class not found in {model_py_path}")
140+
141+
def synchronize(self):
142+
if torch.cuda.is_available():
143+
torch.cuda.synchronize()

graph_net/torch/graph_variable_renamer.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,17 @@ def rename_graph_variables(
162162
w_cnt = 0
163163
tmp_cnt = 0
164164

165-
arg_iter = iter(sample_inputs)
165+
arg_iter = iter(sample_inputs) if sample_inputs else iter([])
166+
166167
for node in gm.graph.nodes:
168+
old_name = node.name
169+
167170
if "original_name" not in node.meta:
168171
node.meta["original_name"] = node.name
169172

173+
new_name = old_name
174+
should_update_target = False
175+
170176
if node.op == "placeholder":
171177
real_arg = next(arg_iter)
172178
is_weight = not self.data_input_predicator(model_path, node.name)
@@ -186,8 +192,7 @@ def rename_graph_variables(
186192
new_name = f"in_{in_cnt}"
187193
in_cnt += 1
188194

189-
node.name = new_name
190-
node.target = new_name
195+
should_update_target = True
191196

192197
elif node.op == "get_attr":
193198
node.name = f"w_{w_cnt}"
@@ -197,6 +202,11 @@ def rename_graph_variables(
197202
node.name = f"tmp_{tmp_cnt}"
198203
tmp_cnt += 1
199204

205+
if new_name != old_name:
206+
node.name = new_name
207+
if should_update_target:
208+
node.target = new_name
209+
200210
gm.graph.lint()
201211
gm.recompile()
202212
return gm

graph_net/torch/test_compiler.py

100644100755
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
from graph_net.torch.backend.range_decomposer_validator_backend import (
2626
RangeDecomposerValidatorBackend,
2727
)
28+
29+
from graph_net.torch.backend.graph_variable_rename_backend import (
30+
GraphVariableRenameBackend,
31+
)
2832
from graph_net import test_compiler_util
2933
from graph_net import path_utils
3034

@@ -38,6 +42,7 @@
3842
"nope": NopeBackend(),
3943
"unstable_to_stable": UnstableToStableBackend(),
4044
"range_decomposer_validator": RangeDecomposerValidatorBackend(),
45+
"graph_variable_renamer": GraphVariableRenameBackend(),
4146
}
4247

4348

0 commit comments

Comments
 (0)