Skip to content

Commit 102a334

Browse files
committed
fix style
1 parent 31e8d84 commit 102a334

File tree

3 files changed

+102
-135
lines changed

3 files changed

+102
-135
lines changed

graph_net/test/graph_variable_rename_test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ python3 -m graph_net.model_path_handler --model-path samples/$MODEL_PATH_IN_SAMP
2929

3030
test_compiler_config_json_str=$(cat <<EOF
3131
{
32+
"model_path_prefix": "$GRAPH_NET_ROOT",
3233
"renamed_root": "$WORKSPACE"
3334
}
3435
EOF

graph_net/torch/backend/graph_variable_renamer_validator_backend.py

Lines changed: 59 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import torch
22
from pathlib import Path
3-
from typing import Dict, Any
4-
from graph_net.imp_util import load_module
3+
from typing import Dict
54
from graph_net.tensor_meta import TensorMeta
5+
import os
6+
import importlib.util
67

78

89
class RenamedModelAdapter(torch.nn.Module):
@@ -14,104 +15,75 @@ def __init__(self, renamed_model: torch.nn.Module, mapping: Dict[str, str]):
1415
self.__graph_net_file_path__ = renamed_model.__graph_net_file_path__
1516

1617
def forward(self, **kwargs):
18+
new_kwargs = self._convert_by_name_mapping(kwargs)
19+
return self.model(**new_kwargs)
20+
21+
def _convert_by_name_mapping(self, kwargs):
1722
new_kwargs = {}
1823
for old_name, value in kwargs.items():
1924
if old_name in self.mapping:
2025
new_name = self.mapping[old_name]
2126
new_kwargs[new_name] = value
22-
return self.model(**new_kwargs)
27+
return new_kwargs
2328

2429

2530
class GraphVariableRenamerValidatorBackend:
26-
def __init__(self, config: Dict[str, Any] = None):
27-
self.config = config or {}
28-
29-
def _get_rename_mapping(self, model_dir: Path) -> Dict[str, str]:
31+
def _get_rename_mapping(self, model_dir: Path):
3032
mapping = {}
31-
if not model_dir.exists():
32-
print(f"[ValidatorBackend] Error: Model dir does not exist: {model_dir}")
33-
return mapping
34-
3533
for meta_file in ["input_meta.py", "weight_meta.py"]:
3634
meta_path = model_dir / meta_file
37-
if meta_path.exists():
38-
try:
39-
metas = TensorMeta.unserialize_from_py_file(str(meta_path))
40-
for m in metas:
41-
if m.original_name:
42-
mapping[m.original_name] = m.name
43-
except Exception as e:
44-
print(
45-
f"[ValidatorBackend] Warning: Failed to parse {meta_path}: {e}"
46-
)
35+
if not meta_path.exists():
36+
continue
37+
metas = TensorMeta.unserialize_from_py_file(str(meta_path))
38+
for m in metas:
39+
if m.original_name:
40+
mapping[m.original_name] = m.name
4741
return mapping
4842

49-
def _load_renamed_model(
50-
self, model_dir: Path, device: torch.device
51-
) -> torch.nn.Module:
52-
model_py_path = model_dir / "model.py"
53-
if not model_py_path.exists():
54-
raise FileNotFoundError(f"Renamed model not found at {model_py_path}")
55-
56-
py_module = load_module(str(model_py_path))
57-
58-
if not hasattr(py_module, "GraphModule"):
59-
raise ValueError(f"GraphModule class not found in {model_py_path}")
60-
61-
GraphModule = getattr(py_module, "GraphModule")
62-
GraphModule.__graph_net_file_path__ = str(model_py_path)
63-
64-
model = GraphModule()
65-
model.to(device)
66-
model.eval()
67-
return model
68-
69-
def __call__(self, original_model: torch.nn.Module) -> torch.nn.Module:
70-
renamed_root = self.config.get("renamed_root")
71-
if not renamed_root:
72-
raise ValueError("Config 'renamed_root' is missing!")
73-
74-
default_prefix = str(Path(__file__).resolve().parent.parent.parent.parent)
75-
model_path_prefix = self.config.get("model_path_prefix", default_prefix)
76-
77-
if not hasattr(original_model, "__graph_net_file_path__"):
78-
raise ValueError("Original model missing __graph_net_file_path__")
79-
80-
orig_abs_path = Path(original_model.__class__.__graph_net_file_path__).resolve()
81-
orig_model_dir = orig_abs_path.parent
82-
83-
try:
84-
rel_model_path = orig_model_dir.relative_to(model_path_prefix)
85-
except ValueError:
86-
print(
87-
f"[ValidatorBackend] Warning: Model path {orig_model_dir} is not under prefix {model_path_prefix}. Fallback to leaf name."
88-
)
89-
rel_model_path = orig_model_dir.name
90-
91-
renamed_model_dir = Path(renamed_root) / rel_model_path
92-
93-
print(f"[ValidatorBackend] Original Path: {orig_model_dir}")
94-
print(f"[ValidatorBackend] Relative Path: {rel_model_path}")
95-
print(f"[ValidatorBackend] Loading Renamed: {renamed_model_dir}")
96-
97-
try:
98-
device = next(original_model.parameters()).device
99-
except StopIteration:
100-
device = torch.device("cpu")
101-
102-
renamed_core_model = self._load_renamed_model(renamed_model_dir, device)
103-
mapping = self._get_rename_mapping(renamed_model_dir)
104-
105-
if not mapping:
106-
print(
107-
f"[ValidatorBackend] Warning: Mapping is empty for {rel_model_path}. Check input_meta.py generation."
108-
)
109-
110-
adapter = RenamedModelAdapter(renamed_core_model, mapping)
111-
adapter.to(device)
112-
adapter.eval()
113-
114-
return adapter
43+
def _load_model_instance(self, path: str, device: str) -> torch.nn.Module:
44+
class_name = "GraphModule"
45+
model_file = os.path.join(path, "model.py")
46+
47+
spec = importlib.util.spec_from_file_location(class_name, model_file)
48+
module = importlib.util.module_from_spec(spec)
49+
spec.loader.exec_module(module)
50+
51+
ModelClass = getattr(module, class_name)
52+
instance = ModelClass().to(device)
53+
return instance
54+
55+
def _make_config(
56+
self,
57+
model_path_prefix: str,
58+
renamed_root: str,
59+
renamed_dentry: str = "_renamed",
60+
):
61+
return {
62+
"model_path_prefix": model_path_prefix,
63+
"renamed_root": renamed_root,
64+
"renamed_dentry": renamed_dentry,
65+
}
66+
67+
def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
68+
config = self._make_config(**self.config)
69+
model_path = os.path.dirname(model.__class__.__graph_net_file_path__)
70+
model_name = os.path.basename(model_path)
71+
renamed_dir_name = f"{model_name}_renamed"
72+
renamed_model_dir = os.path.join(config["renamed_root"], renamed_dir_name)
73+
74+
print(f"[GraphVariableRenamerValidatorBackend] Processing: {model_name}")
75+
print(
76+
f"[GraphVariableRenamerValidatorBackend] Loading from: {renamed_model_dir}"
77+
)
78+
79+
device = model.__class__.__graph_net_device__
80+
renamed_model = self._load_model_instance(renamed_model_dir, device)
81+
mapping = self._get_rename_mapping(Path(renamed_model_dir))
82+
assert (
83+
mapping
84+
), f"Mapping is empty for {renamed_dir_name} at {renamed_model_dir}"
85+
adapter = RenamedModelAdapter(renamed_model, mapping)
86+
return adapter.eval()
11587

11688
def synchronize(self):
11789
if torch.cuda.is_available():

graph_net/torch/graph_variable_renamer.py

Lines changed: 42 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,10 @@ def __call__(self, rel_model_path):
7979
module, inputs = get_torch_module_and_inputs(src_model_path)
8080
gm = parse_sole_graph_module(module, inputs)
8181
gm = self.rename_graph_variables(gm, inputs, src_model_path)
82+
model_name = os.path.basename(rel_model_path.rstrip(os.sep))
83+
new_rel_path = f"{model_name}_renamed"
8284
dst_model_path = os.path.realpath(
83-
os.path.join(self.config["output_dir"], rel_model_path)
85+
os.path.join(self.config["output_dir"], new_rel_path)
8486
)
8587
Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
8688
shutil.copytree(src_model_path, dst_model_path, dirs_exist_ok=True)
@@ -158,55 +160,47 @@ def _get_input_names_from_signature(self, module):
158160
def rename_graph_variables(
159161
self, gm: torch.fx.GraphModule, sample_inputs, model_path
160162
):
161-
in_cnt = 0
162-
w_cnt = 0
163-
tmp_cnt = 0
164-
163+
counters = {"in": 0, "w": 0, "tmp": 0}
164+
# graph may not have input, only contain weights
165165
arg_iter = iter(sample_inputs) if sample_inputs else iter([])
166-
167166
for node in gm.graph.nodes:
168-
old_name = node.name
169-
170-
if "original_name" not in node.meta:
171-
node.meta["original_name"] = node.name
172-
173-
new_name = old_name
174-
should_update_target = False
175-
176-
if node.op == "placeholder":
177-
real_arg = next(arg_iter)
178-
is_weight = not self.data_input_predicator(model_path, node.name)
179-
if node.type is not None:
180-
if isinstance(node.type, type) and issubclass(
181-
node.type, torch.nn.parameter.Parameter
182-
):
183-
is_weight = True
184-
elif real_arg is not None:
185-
if isinstance(real_arg, torch.nn.Parameter):
186-
is_weight = True
187-
188-
if is_weight:
189-
new_name = f"w_{w_cnt}"
190-
w_cnt += 1
191-
else:
192-
new_name = f"in_{in_cnt}"
193-
in_cnt += 1
194-
195-
should_update_target = True
196-
197-
elif node.op == "get_attr":
198-
node.name = f"w_{w_cnt}"
199-
w_cnt += 1
200-
201-
elif node.op != "output":
202-
node.name = f"tmp_{tmp_cnt}"
203-
tmp_cnt += 1
204-
205-
if new_name != old_name:
206-
node.name = new_name
207-
if should_update_target:
208-
node.target = new_name
209-
167+
self._process_single_node(node, arg_iter, counters, model_path)
210168
gm.graph.lint()
211169
gm.recompile()
212170
return gm
171+
172+
def _process_single_node(self, node, arg_iter, counters, model_path):
173+
if "original_name" not in node.meta:
174+
node.meta["original_name"] = node.name
175+
if node.op == "placeholder":
176+
self._handle_placeholder(node, arg_iter, counters, model_path)
177+
elif node.op == "get_attr":
178+
self._apply_rename(node, "w", counters)
179+
elif node.op != "output":
180+
self._apply_rename(node, "tmp", counters)
181+
182+
def _handle_placeholder(self, node, arg_iter, counters, model_path):
183+
real_arg = next(arg_iter, None)
184+
is_weight = self._is_weight_node(node, real_arg, model_path)
185+
prefix = "w" if is_weight else "in"
186+
self._apply_rename(node, prefix, counters, update_target=True)
187+
188+
def _apply_rename(self, node, prefix, counters, update_target=False):
189+
new_name = f"{prefix}_{counters[prefix]}"
190+
counters[prefix] += 1
191+
if node.name != new_name:
192+
node.name = new_name
193+
if update_target:
194+
node.target = new_name
195+
196+
def _is_weight_node(self, node, real_arg, model_path):
197+
if not self.data_input_predicator(model_path, node.name):
198+
return True
199+
if node.type is not None:
200+
if isinstance(node.type, type) and issubclass(
201+
node.type, torch.nn.parameter.Parameter
202+
):
203+
return True
204+
if real_arg is not None and isinstance(real_arg, torch.nn.Parameter):
205+
return True
206+
return False

0 commit comments

Comments
 (0)