Skip to content

Commit e058173

Browse files
committed
add GraphVariableRenamerValidatorBackend
1 parent 964fc7f commit e058173

File tree

4 files changed

+130
-148
lines changed

4 files changed

+130
-148
lines changed

graph_net/test/graph_variable_rename_test.sh

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,19 @@ CONFIG=$(echo $config_json_str | base64 -w 0)
2727
python3 -m graph_net.model_path_handler --model-path samples/$MODEL_PATH_IN_SAMPLES --handler-config=$CONFIG
2828
# python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/decomposition_error_tmp_torch_samples_list.txt --handler-config=$CONFIG
2929

30+
test_compiler_config_json_str=$(cat <<EOF
31+
{
32+
"renamed_root": "$WORKSPACE"
33+
}
34+
EOF
35+
)
36+
TEST_COMPILER_CONFIG=$(echo $test_compiler_config_json_str | base64 -w 0)
3037

3138
python3 -m graph_net.torch.test_compiler \
3239
--model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES \
33-
--compiler graph_variable_renamer \
40+
--compiler graph_variable_renamer_validator \
3441
--device cuda \
42+
--config $TEST_COMPILER_CONFIG \
3543
> "$WORKSPACE/validation.log" 2>&1
3644

3745
python3 -m graph_net.plot_ESt \

graph_net/torch/backend/graph_variable_rename_backend.py

Lines changed: 0 additions & 143 deletions
This file was deleted.
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import torch
2+
from pathlib import Path
3+
from typing import Dict, Any
4+
from graph_net.imp_util import load_module
5+
from graph_net.tensor_meta import TensorMeta
6+
7+
8+
class RenamedModelAdapter(torch.nn.Module):
9+
def __init__(self, renamed_model: torch.nn.Module, mapping: Dict[str, str]):
10+
super().__init__()
11+
self.model = renamed_model
12+
self.mapping = mapping
13+
if hasattr(renamed_model, "__graph_net_file_path__"):
14+
self.__graph_net_file_path__ = renamed_model.__graph_net_file_path__
15+
16+
def forward(self, **kwargs):
17+
new_kwargs = {}
18+
for old_name, value in kwargs.items():
19+
if old_name in self.mapping:
20+
new_name = self.mapping[old_name]
21+
new_kwargs[new_name] = value
22+
return self.model(**new_kwargs)
23+
24+
25+
class GraphVariableRenamerValidatorBackend:
26+
def __init__(self, config: Dict[str, Any] = None):
27+
self.config = config or {}
28+
29+
def _get_rename_mapping(self, model_dir: Path) -> Dict[str, str]:
30+
mapping = {}
31+
if not model_dir.exists():
32+
print(f"[ValidatorBackend] Error: Model dir does not exist: {model_dir}")
33+
return mapping
34+
35+
for meta_file in ["input_meta.py", "weight_meta.py"]:
36+
meta_path = model_dir / meta_file
37+
if meta_path.exists():
38+
try:
39+
metas = TensorMeta.unserialize_from_py_file(str(meta_path))
40+
for m in metas:
41+
if m.original_name:
42+
mapping[m.original_name] = m.name
43+
except Exception as e:
44+
print(
45+
f"[ValidatorBackend] Warning: Failed to parse {meta_path}: {e}"
46+
)
47+
return mapping
48+
49+
def _load_renamed_model(
50+
self, model_dir: Path, device: torch.device
51+
) -> torch.nn.Module:
52+
model_py_path = model_dir / "model.py"
53+
if not model_py_path.exists():
54+
raise FileNotFoundError(f"Renamed model not found at {model_py_path}")
55+
56+
py_module = load_module(str(model_py_path))
57+
58+
if not hasattr(py_module, "GraphModule"):
59+
raise ValueError(f"GraphModule class not found in {model_py_path}")
60+
61+
GraphModule = getattr(py_module, "GraphModule")
62+
GraphModule.__graph_net_file_path__ = str(model_py_path)
63+
64+
model = GraphModule()
65+
model.to(device)
66+
model.eval()
67+
return model
68+
69+
def __call__(self, original_model: torch.nn.Module) -> torch.nn.Module:
70+
renamed_root = self.config.get("renamed_root")
71+
if not renamed_root:
72+
raise ValueError("Config 'renamed_root' is missing!")
73+
74+
default_prefix = str(Path(__file__).resolve().parent.parent.parent.parent)
75+
model_path_prefix = self.config.get("model_path_prefix", default_prefix)
76+
77+
if not hasattr(original_model, "__graph_net_file_path__"):
78+
raise ValueError("Original model missing __graph_net_file_path__")
79+
80+
orig_abs_path = Path(original_model.__class__.__graph_net_file_path__).resolve()
81+
orig_model_dir = orig_abs_path.parent
82+
83+
try:
84+
rel_model_path = orig_model_dir.relative_to(model_path_prefix)
85+
except ValueError:
86+
print(
87+
f"[ValidatorBackend] Warning: Model path {orig_model_dir} is not under prefix {model_path_prefix}. Fallback to leaf name."
88+
)
89+
rel_model_path = orig_model_dir.name
90+
91+
renamed_model_dir = Path(renamed_root) / rel_model_path
92+
93+
print(f"[ValidatorBackend] Original Path: {orig_model_dir}")
94+
print(f"[ValidatorBackend] Relative Path: {rel_model_path}")
95+
print(f"[ValidatorBackend] Loading Renamed: {renamed_model_dir}")
96+
97+
try:
98+
device = next(original_model.parameters()).device
99+
except StopIteration:
100+
device = torch.device("cpu")
101+
102+
renamed_core_model = self._load_renamed_model(renamed_model_dir, device)
103+
mapping = self._get_rename_mapping(renamed_model_dir)
104+
105+
if not mapping:
106+
print(
107+
f"[ValidatorBackend] Warning: Mapping is empty for {rel_model_path}. Check input_meta.py generation."
108+
)
109+
110+
adapter = RenamedModelAdapter(renamed_core_model, mapping)
111+
adapter.to(device)
112+
adapter.eval()
113+
114+
return adapter
115+
116+
def synchronize(self):
117+
if torch.cuda.is_available():
118+
torch.cuda.synchronize()

graph_net/torch/test_compiler.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@
2424
from graph_net.torch.backend.range_decomposer_validator_backend import (
2525
RangeDecomposerValidatorBackend,
2626
)
27-
28-
from graph_net.torch.backend.graph_variable_rename_backend import (
29-
GraphVariableRenameBackend,
27+
from graph_net.torch.backend.graph_variable_renamer_validator_backend import (
28+
GraphVariableRenamerValidatorBackend,
3029
)
3130
from graph_net import test_compiler_util
3231
from graph_net import path_utils
@@ -41,7 +40,7 @@
4140
"nope": NopeBackend(),
4241
"unstable_to_stable": UnstableToStableBackend(),
4342
"range_decomposer_validator": RangeDecomposerValidatorBackend(),
44-
"graph_variable_renamer": GraphVariableRenameBackend(),
43+
"graph_variable_renamer_validator": GraphVariableRenamerValidatorBackend(),
4544
}
4645

4746

0 commit comments

Comments
 (0)