11import torch
22from pathlib import Path
3- from typing import Dict , Any
4- from graph_net .imp_util import load_module
3+ from typing import Dict
54from graph_net .tensor_meta import TensorMeta
5+ import os
6+ import importlib .util
67
78
89class RenamedModelAdapter (torch .nn .Module ):
@@ -14,104 +15,75 @@ def __init__(self, renamed_model: torch.nn.Module, mapping: Dict[str, str]):
1415 self .__graph_net_file_path__ = renamed_model .__graph_net_file_path__
1516
1617 def forward (self , ** kwargs ):
18+ new_kwargs = self ._convert_by_name_mapping (kwargs )
19+ return self .model (** new_kwargs )
20+
21+ def _convert_by_name_mapping (self , kwargs ):
1722 new_kwargs = {}
1823 for old_name , value in kwargs .items ():
1924 if old_name in self .mapping :
2025 new_name = self .mapping [old_name ]
2126 new_kwargs [new_name ] = value
22- return self . model ( ** new_kwargs )
27+ return new_kwargs
2328
2429
2530class GraphVariableRenamerValidatorBackend :
26- def __init__ (self , config : Dict [str , Any ] = None ):
27- self .config = config or {}
28-
29- def _get_rename_mapping (self , model_dir : Path ) -> Dict [str , str ]:
31+ def _get_rename_mapping (self , model_dir : Path ):
3032 mapping = {}
31- if not model_dir .exists ():
32- print (f"[ValidatorBackend] Error: Model dir does not exist: { model_dir } " )
33- return mapping
34-
3533 for meta_file in ["input_meta.py" , "weight_meta.py" ]:
3634 meta_path = model_dir / meta_file
37- if meta_path .exists ():
38- try :
39- metas = TensorMeta .unserialize_from_py_file (str (meta_path ))
40- for m in metas :
41- if m .original_name :
42- mapping [m .original_name ] = m .name
43- except Exception as e :
44- print (
45- f"[ValidatorBackend] Warning: Failed to parse { meta_path } : { e } "
46- )
35+ if not meta_path .exists ():
36+ continue
37+ metas = TensorMeta .unserialize_from_py_file (str (meta_path ))
38+ for m in metas :
39+ if m .original_name :
40+ mapping [m .original_name ] = m .name
4741 return mapping
4842
49- def _load_renamed_model (
50- self , model_dir : Path , device : torch .device
51- ) -> torch .nn .Module :
52- model_py_path = model_dir / "model.py"
53- if not model_py_path .exists ():
54- raise FileNotFoundError (f"Renamed model not found at { model_py_path } " )
55-
56- py_module = load_module (str (model_py_path ))
57-
58- if not hasattr (py_module , "GraphModule" ):
59- raise ValueError (f"GraphModule class not found in { model_py_path } " )
60-
61- GraphModule = getattr (py_module , "GraphModule" )
62- GraphModule .__graph_net_file_path__ = str (model_py_path )
63-
64- model = GraphModule ()
65- model .to (device )
66- model .eval ()
67- return model
68-
69- def __call__ (self , original_model : torch .nn .Module ) -> torch .nn .Module :
70- renamed_root = self .config .get ("renamed_root" )
71- if not renamed_root :
72- raise ValueError ("Config 'renamed_root' is missing!" )
73-
74- default_prefix = str (Path (__file__ ).resolve ().parent .parent .parent .parent )
75- model_path_prefix = self .config .get ("model_path_prefix" , default_prefix )
76-
77- if not hasattr (original_model , "__graph_net_file_path__" ):
78- raise ValueError ("Original model missing __graph_net_file_path__" )
79-
80- orig_abs_path = Path (original_model .__class__ .__graph_net_file_path__ ).resolve ()
81- orig_model_dir = orig_abs_path .parent
82-
83- try :
84- rel_model_path = orig_model_dir .relative_to (model_path_prefix )
85- except ValueError :
86- print (
87- f"[ValidatorBackend] Warning: Model path { orig_model_dir } is not under prefix { model_path_prefix } . Fallback to leaf name."
88- )
89- rel_model_path = orig_model_dir .name
90-
91- renamed_model_dir = Path (renamed_root ) / rel_model_path
92-
93- print (f"[ValidatorBackend] Original Path: { orig_model_dir } " )
94- print (f"[ValidatorBackend] Relative Path: { rel_model_path } " )
95- print (f"[ValidatorBackend] Loading Renamed: { renamed_model_dir } " )
96-
97- try :
98- device = next (original_model .parameters ()).device
99- except StopIteration :
100- device = torch .device ("cpu" )
101-
102- renamed_core_model = self ._load_renamed_model (renamed_model_dir , device )
103- mapping = self ._get_rename_mapping (renamed_model_dir )
104-
105- if not mapping :
106- print (
107- f"[ValidatorBackend] Warning: Mapping is empty for { rel_model_path } . Check input_meta.py generation."
108- )
109-
110- adapter = RenamedModelAdapter (renamed_core_model , mapping )
111- adapter .to (device )
112- adapter .eval ()
113-
114- return adapter
43+ def _load_model_instance (self , path : str , device : str ) -> torch .nn .Module :
44+ class_name = "GraphModule"
45+ model_file = os .path .join (path , "model.py" )
46+
47+ spec = importlib .util .spec_from_file_location (class_name , model_file )
48+ module = importlib .util .module_from_spec (spec )
49+ spec .loader .exec_module (module )
50+
51+ ModelClass = getattr (module , class_name )
52+ instance = ModelClass ().to (device )
53+ return instance
54+
55+ def _make_config (
56+ self ,
57+ model_path_prefix : str ,
58+ renamed_root : str ,
59+ renamed_dentry : str = "_renamed" ,
60+ ):
61+ return {
62+ "model_path_prefix" : model_path_prefix ,
63+ "renamed_root" : renamed_root ,
64+ "renamed_dentry" : renamed_dentry ,
65+ }
66+
67+ def __call__ (self , model : torch .nn .Module ) -> torch .nn .Module :
68+ config = self ._make_config (** self .config )
69+ model_path = os .path .dirname (model .__class__ .__graph_net_file_path__ )
70+ model_name = os .path .basename (model_path )
71+ renamed_dir_name = f"{ model_name } _renamed"
72+ renamed_model_dir = os .path .join (config ["renamed_root" ], renamed_dir_name )
73+
74+ print (f"[GraphVariableRenamerValidatorBackend] Processing: { model_name } " )
75+ print (
76+ f"[GraphVariableRenamerValidatorBackend] Loading from: { renamed_model_dir } "
77+ )
78+
79+ device = model .__class__ .__graph_net_device__
80+ renamed_model = self ._load_model_instance (renamed_model_dir , device )
81+ mapping = self ._get_rename_mapping (Path (renamed_model_dir ))
82+ assert (
83+ mapping
84+ ), f"Mapping is empty for { renamed_dir_name } at { renamed_model_dir } "
85+ adapter = RenamedModelAdapter (renamed_model , mapping )
86+ return adapter .eval ()
11587
11688 def synchronize (self ):
11789 if torch .cuda .is_available ():
0 commit comments