diff --git a/graph_net/test/graph_variable_rename_test.sh b/graph_net/test/graph_variable_rename_test.sh index bac616870..bc79565f8 100755 --- a/graph_net/test/graph_variable_rename_test.sh +++ b/graph_net/test/graph_variable_rename_test.sh @@ -2,6 +2,7 @@ GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print( os.path.dirname(graph_net.__file__))") +WORKSPACE=/tmp/graph_variable_rename_workspace # input model path MODEL_NAME=resnet18 @@ -16,7 +17,7 @@ config_json_str=$(cat < "$WORKSPACE/validation.log" 2>&1 + +python3 -m graph_net.plot_ESt \ + --benchmark-path "$WORKSPACE/validation.log" \ + --output-dir "$WORKSPACE" diff --git a/graph_net/torch/backend/graph_variable_renamer_validator_backend.py b/graph_net/torch/backend/graph_variable_renamer_validator_backend.py new file mode 100755 index 000000000..cca0f6e2b --- /dev/null +++ b/graph_net/torch/backend/graph_variable_renamer_validator_backend.py @@ -0,0 +1,90 @@ +import torch +from pathlib import Path +from typing import Dict +from graph_net.tensor_meta import TensorMeta +import os +import importlib.util + + +class RenamedModelAdapter(torch.nn.Module): + def __init__(self, renamed_model: torch.nn.Module, mapping: Dict[str, str]): + super().__init__() + self.model = renamed_model + self.mapping = mapping + if hasattr(renamed_model, "__graph_net_file_path__"): + self.__graph_net_file_path__ = renamed_model.__graph_net_file_path__ + + def forward(self, **kwargs): + new_kwargs = self._convert_by_name_mapping(kwargs) + return self.model(**new_kwargs) + + def _convert_by_name_mapping(self, kwargs): + new_kwargs = {} + for old_name, value in kwargs.items(): + if old_name in self.mapping: + new_name = self.mapping[old_name] + new_kwargs[new_name] = value + return new_kwargs + + +class GraphVariableRenamerValidatorBackend: + def _get_rename_mapping(self, model_dir: Path): + mapping = {} + for meta_file in ["input_meta.py", "weight_meta.py"]: + meta_path = model_dir / meta_file + if not meta_path.exists(): + continue + metas = TensorMeta.unserialize_from_py_file(str(meta_path)) + for m in metas: + if m.original_name: + mapping[m.original_name] = m.name + return mapping + + def _load_model_instance(self, path: str, device: str) -> torch.nn.Module: + class_name = "GraphModule" + model_file = os.path.join(path, "model.py") + + spec = importlib.util.spec_from_file_location(class_name, model_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + ModelClass = getattr(module, class_name) + instance = ModelClass().to(device) + return instance + + def _make_config( + self, + model_path_prefix: str, + renamed_root: str, + renamed_dentry: str = "_renamed", + ): + return { + "model_path_prefix": model_path_prefix, + "renamed_root": renamed_root, + "renamed_dentry": renamed_dentry, + } + + def __call__(self, model: torch.nn.Module) -> torch.nn.Module: + config = self._make_config(**self.config) + model_path = os.path.dirname(model.__class__.__graph_net_file_path__) + model_name = os.path.basename(model_path) + renamed_dir_name = f"{model_name}_renamed" + renamed_model_dir = os.path.join(config["renamed_root"], renamed_dir_name) + + print(f"[GraphVariableRenamerValidatorBackend] Processing: {model_name}") + print( + f"[GraphVariableRenamerValidatorBackend] Loading from: {renamed_model_dir}" + ) + + device = model.__class__.__graph_net_device__ + renamed_model = self._load_model_instance(renamed_model_dir, device) + mapping = self._get_rename_mapping(Path(renamed_model_dir)) + assert ( + mapping + ), f"Mapping is empty for {renamed_dir_name} at {renamed_model_dir}" + adapter = RenamedModelAdapter(renamed_model, mapping) + return adapter.eval() + + def synchronize(self): + if torch.cuda.is_available(): + torch.cuda.synchronize() diff --git a/graph_net/torch/graph_variable_renamer.py b/graph_net/torch/graph_variable_renamer.py index 8d9b962c7..88bd20976 100755 --- a/graph_net/torch/graph_variable_renamer.py +++ b/graph_net/torch/graph_variable_renamer.py @@ -79,8 +79,10 @@ def __call__(self, rel_model_path): module, inputs = get_torch_module_and_inputs(src_model_path) gm = parse_sole_graph_module(module, inputs) gm = self.rename_graph_variables(gm, inputs, src_model_path) + model_name = os.path.basename(rel_model_path.rstrip(os.sep)) + new_rel_path = f"{model_name}_renamed" dst_model_path = os.path.realpath( - os.path.join(self.config["output_dir"], rel_model_path) + os.path.join(self.config["output_dir"], new_rel_path) ) Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True) shutil.copytree(src_model_path, dst_model_path, dirs_exist_ok=True) @@ -158,45 +160,47 @@ def _get_input_names_from_signature(self, module): def rename_graph_variables( self, gm: torch.fx.GraphModule, sample_inputs, model_path ): - in_cnt = 0 - w_cnt = 0 - tmp_cnt = 0 - - arg_iter = iter(sample_inputs) + counters = {"in": 0, "w": 0, "tmp": 0} + # graph may not have input, only contain weights + arg_iter = iter(sample_inputs) if sample_inputs else iter([]) for node in gm.graph.nodes: - if "original_name" not in node.meta: - node.meta["original_name"] = node.name - - if node.op == "placeholder": - real_arg = next(arg_iter) - is_weight = not self.data_input_predicator(model_path, node.name) - if node.type is not None: - if isinstance(node.type, type) and issubclass( - node.type, torch.nn.parameter.Parameter - ): - is_weight = True - elif real_arg is not None: - if isinstance(real_arg, torch.nn.Parameter): - is_weight = True - - if is_weight: - new_name = f"w_{w_cnt}" - w_cnt += 1 - else: - new_name = f"in_{in_cnt}" - in_cnt += 1 - - node.name = new_name - node.target = new_name - - elif node.op == "get_attr": - node.name = f"w_{w_cnt}" - w_cnt += 1 - - elif node.op != "output": - node.name = f"tmp_{tmp_cnt}" - tmp_cnt += 1 - + self._process_single_node(node, arg_iter, counters, model_path) gm.graph.lint() gm.recompile() return gm + + def _process_single_node(self, node, arg_iter, counters, model_path): + if "original_name" not in node.meta: + node.meta["original_name"] = node.name + if node.op == "placeholder": + self._handle_placeholder(node, arg_iter, counters, model_path) + elif node.op == "get_attr": + self._apply_rename(node, "w", counters) + elif node.op != "output": + self._apply_rename(node, "tmp", counters) + else: + # Do nothing + pass + + def _handle_placeholder(self, node, arg_iter, counters, model_path): + real_arg = next(arg_iter, None) + is_weight = self._is_weight_node(node, real_arg, model_path) + prefix = "w" if is_weight else "in" + self._apply_rename(node, prefix, counters, update_target=True) + + def _apply_rename(self, node, prefix, counters, update_target=False): + new_name = f"{prefix}_{counters[prefix]}" + counters[prefix] += 1 + node.name = new_name + if update_target: + node.target = new_name + + def _is_weight_node(self, node, real_arg, model_path): + is_not_data_input = not self.data_input_predicator(model_path, node.name) + is_parameter_type = ( + node.type is not None + and isinstance(node.type, type) + and issubclass(node.type, torch.nn.parameter.Parameter) + ) + is_parameter_value = isinstance(real_arg, torch.nn.Parameter) + return is_not_data_input or is_parameter_type or is_parameter_value diff --git a/graph_net/torch/test_compiler.py b/graph_net/torch/test_compiler.py old mode 100644 new mode 100755 index 601c00f06..781aab970 --- a/graph_net/torch/test_compiler.py +++ b/graph_net/torch/test_compiler.py @@ -25,6 +25,9 @@ from graph_net.torch.backend.range_decomposer_validator_backend import ( RangeDecomposerValidatorBackend, ) +from graph_net.torch.backend.graph_variable_renamer_validator_backend import ( + GraphVariableRenamerValidatorBackend, +) from graph_net import test_compiler_util from graph_net import path_utils @@ -38,6 +41,7 @@ "nope": NopeBackend(), "unstable_to_stable": UnstableToStableBackend(), "range_decomposer_validator": RangeDecomposerValidatorBackend(), + "graph_variable_renamer_validator": GraphVariableRenamerValidatorBackend(), }