Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion graph_net/test/graph_variable_rename_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
os.path.dirname(graph_net.__file__))")
WORKSPACE=/tmp/graph_variable_rename_workspace

# input model path
MODEL_NAME=resnet18
Expand All @@ -16,7 +17,7 @@ config_json_str=$(cat <<EOF
"data_input_predicator_class_name": "NaiveDataInputPredicator",
"model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
"model_runnable_predicator_class_name": "ModelRunnablePredicator",
"output_dir": "/tmp/graph_variable_rename_workspace"
"output_dir": "$WORKSPACE"
}
}
EOF
Expand All @@ -25,3 +26,23 @@ CONFIG=$(echo $config_json_str | base64 -w 0)

python3 -m graph_net.model_path_handler --model-path samples/$MODEL_PATH_IN_SAMPLES --handler-config=$CONFIG
# python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/decomposition_error_tmp_torch_samples_list.txt --handler-config=$CONFIG

test_compiler_config_json_str=$(cat <<EOF
{
"model_path_prefix": "$GRAPH_NET_ROOT",
"renamed_root": "$WORKSPACE"
}
EOF
)
TEST_COMPILER_CONFIG=$(echo $test_compiler_config_json_str | base64 -w 0)

python3 -m graph_net.torch.test_compiler \
--model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES \
--compiler graph_variable_renamer_validator \
--device cuda \
--config $TEST_COMPILER_CONFIG \
> "$WORKSPACE/validation.log" 2>&1

python3 -m graph_net.plot_ESt \
--benchmark-path "$WORKSPACE/validation.log" \
--output-dir "$WORKSPACE"
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import torch
from pathlib import Path
from typing import Dict
from graph_net.tensor_meta import TensorMeta
import os
import importlib.util


class RenamedModelAdapter(torch.nn.Module):
def __init__(self, renamed_model: torch.nn.Module, mapping: Dict[str, str]):
super().__init__()
self.model = renamed_model
self.mapping = mapping
if hasattr(renamed_model, "__graph_net_file_path__"):
self.__graph_net_file_path__ = renamed_model.__graph_net_file_path__

def forward(self, **kwargs):
new_kwargs = self._convert_by_name_mapping(kwargs)
return self.model(**new_kwargs)

def _convert_by_name_mapping(self, kwargs):
new_kwargs = {}
for old_name, value in kwargs.items():
if old_name in self.mapping:
new_name = self.mapping[old_name]
new_kwargs[new_name] = value
return new_kwargs


class GraphVariableRenamerValidatorBackend:
def _get_rename_mapping(self, model_dir: Path):
mapping = {}
for meta_file in ["input_meta.py", "weight_meta.py"]:
meta_path = model_dir / meta_file
if not meta_path.exists():
continue
metas = TensorMeta.unserialize_from_py_file(str(meta_path))
for m in metas:
if m.original_name:
mapping[m.original_name] = m.name
return mapping

def _load_model_instance(self, path: str, device: str) -> torch.nn.Module:
class_name = "GraphModule"
model_file = os.path.join(path, "model.py")

spec = importlib.util.spec_from_file_location(class_name, model_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

ModelClass = getattr(module, class_name)
instance = ModelClass().to(device)
return instance

def _make_config(
self,
model_path_prefix: str,
renamed_root: str,
renamed_dentry: str = "_renamed",
):
return {
"model_path_prefix": model_path_prefix,
"renamed_root": renamed_root,
"renamed_dentry": renamed_dentry,
}

def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
config = self._make_config(**self.config)
model_path = os.path.dirname(model.__class__.__graph_net_file_path__)
model_name = os.path.basename(model_path)
renamed_dir_name = f"{model_name}_renamed"
renamed_model_dir = os.path.join(config["renamed_root"], renamed_dir_name)

print(f"[GraphVariableRenamerValidatorBackend] Processing: {model_name}")
print(
f"[GraphVariableRenamerValidatorBackend] Loading from: {renamed_model_dir}"
)

device = model.__class__.__graph_net_device__
renamed_model = self._load_model_instance(renamed_model_dir, device)
mapping = self._get_rename_mapping(Path(renamed_model_dir))
assert (
mapping
), f"Mapping is empty for {renamed_dir_name} at {renamed_model_dir}"
adapter = RenamedModelAdapter(renamed_model, mapping)
return adapter.eval()

def synchronize(self):
if torch.cuda.is_available():
torch.cuda.synchronize()
82 changes: 43 additions & 39 deletions graph_net/torch/graph_variable_renamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ def __call__(self, rel_model_path):
module, inputs = get_torch_module_and_inputs(src_model_path)
gm = parse_sole_graph_module(module, inputs)
gm = self.rename_graph_variables(gm, inputs, src_model_path)
model_name = os.path.basename(rel_model_path.rstrip(os.sep))
new_rel_path = f"{model_name}_renamed"
dst_model_path = os.path.realpath(
os.path.join(self.config["output_dir"], rel_model_path)
os.path.join(self.config["output_dir"], new_rel_path)
)
Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
shutil.copytree(src_model_path, dst_model_path, dirs_exist_ok=True)
Expand Down Expand Up @@ -158,45 +160,47 @@ def _get_input_names_from_signature(self, module):
def rename_graph_variables(
self, gm: torch.fx.GraphModule, sample_inputs, model_path
):
in_cnt = 0
w_cnt = 0
tmp_cnt = 0

arg_iter = iter(sample_inputs)
counters = {"in": 0, "w": 0, "tmp": 0}
# graph may not have input, only contain weights
arg_iter = iter(sample_inputs) if sample_inputs else iter([])
for node in gm.graph.nodes:
if "original_name" not in node.meta:
node.meta["original_name"] = node.name

if node.op == "placeholder":
real_arg = next(arg_iter)
is_weight = not self.data_input_predicator(model_path, node.name)
if node.type is not None:
if isinstance(node.type, type) and issubclass(
node.type, torch.nn.parameter.Parameter
):
is_weight = True
elif real_arg is not None:
if isinstance(real_arg, torch.nn.Parameter):
is_weight = True

if is_weight:
new_name = f"w_{w_cnt}"
w_cnt += 1
else:
new_name = f"in_{in_cnt}"
in_cnt += 1

node.name = new_name
node.target = new_name

elif node.op == "get_attr":
node.name = f"w_{w_cnt}"
w_cnt += 1

elif node.op != "output":
node.name = f"tmp_{tmp_cnt}"
tmp_cnt += 1

self._process_single_node(node, arg_iter, counters, model_path)
gm.graph.lint()
gm.recompile()
return gm

def _process_single_node(self, node, arg_iter, counters, model_path):
if "original_name" not in node.meta:
node.meta["original_name"] = node.name
if node.op == "placeholder":
self._handle_placeholder(node, arg_iter, counters, model_path)
elif node.op == "get_attr":
self._apply_rename(node, "w", counters)
elif node.op != "output":
self._apply_rename(node, "tmp", counters)
else:
# Do nothing
pass

def _handle_placeholder(self, node, arg_iter, counters, model_path):
real_arg = next(arg_iter, None)
is_weight = self._is_weight_node(node, real_arg, model_path)
prefix = "w" if is_weight else "in"
self._apply_rename(node, prefix, counters, update_target=True)

def _apply_rename(self, node, prefix, counters, update_target=False):
new_name = f"{prefix}_{counters[prefix]}"
counters[prefix] += 1
node.name = new_name
if update_target:
node.target = new_name

def _is_weight_node(self, node, real_arg, model_path):
is_not_data_input = not self.data_input_predicator(model_path, node.name)
is_parameter_type = (
node.type is not None
and isinstance(node.type, type)
and issubclass(node.type, torch.nn.parameter.Parameter)
)
is_parameter_value = isinstance(real_arg, torch.nn.Parameter)
return is_not_data_input or is_parameter_type or is_parameter_value
4 changes: 4 additions & 0 deletions graph_net/torch/test_compiler.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from graph_net.torch.backend.range_decomposer_validator_backend import (
RangeDecomposerValidatorBackend,
)
from graph_net.torch.backend.graph_variable_renamer_validator_backend import (
GraphVariableRenamerValidatorBackend,
)
from graph_net import test_compiler_util
from graph_net import path_utils

Expand All @@ -38,6 +41,7 @@
"nope": NopeBackend(),
"unstable_to_stable": UnstableToStableBackend(),
"range_decomposer_validator": RangeDecomposerValidatorBackend(),
"graph_variable_renamer_validator": GraphVariableRenamerValidatorBackend(),
}


Expand Down