diff --git a/graph_net/tools/typical_sequence_decompose.sh b/graph_net/tools/typical_sequence_decompose.sh
index 5c9150cec..db7f826cc 100755
--- a/graph_net/tools/typical_sequence_decompose.sh
+++ b/graph_net/tools/typical_sequence_decompose.sh
@@ -68,6 +68,7 @@ python3 -m graph_net.model_path_handler \
     "handler_path": "$GRAPH_NET_ROOT/graph_net/torch/graph_variable_renamer.py",
     "handler_class_name": "GraphVariableRenamer",
     "handler_config": {
+        "resume": true,
         "model_path_prefix": "$DECOMPOSE_WORKSPACE",
         "data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
         "data_input_predicator_class_name": "NaiveDataInputPredicator",
diff --git a/graph_net/torch/graph_decomposer.py b/graph_net/torch/graph_decomposer.py
index ed01cb5b6..dba668801 100644
--- a/graph_net/torch/graph_decomposer.py
+++ b/graph_net/torch/graph_decomposer.py
@@ -3,6 +3,7 @@
 from pathlib import Path
 import torch
 import json
+import sys
 from graph_net.torch.decompose_util import convert_to_submodules_graph
 from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
 import graph_net.imp_util as imp_util
@@ -209,6 +210,12 @@ def __call__(self, rel_model_path):
         )
         model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
         split_results = load_json(self.config["split_results_path"])
+        if (
+            split_results[rel_model_path]["split_positions"] is None
+            or len(split_results[rel_model_path]["split_positions"]) == 0
+        ):
+            sys.stderr.write(f"Error: {rel_model_path} has no split positions.\n")
+            return
         split_positions = split_results[rel_model_path]["split_positions"]
         if self.config["resume"] and self._is_model_handled(
             rel_model_path, split_positions
diff --git a/graph_net/torch/graph_variable_renamer.py b/graph_net/torch/graph_variable_renamer.py
index e8e58712b..ae11f0242 100755
--- a/graph_net/torch/graph_variable_renamer.py
+++ b/graph_net/torch/graph_variable_renamer.py
@@ -2,6 +2,7 @@
 import torch
 import shutil
 import inspect
+import tempfile
 from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
 from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
 from graph_net.tensor_meta import TensorMeta
@@ -37,8 +38,9 @@ def _make_model_runnable_predicator(self, config):

     def _make_config(
         self,
-        data_input_predicator_filepath,
-        model_runnable_predicator_filepath,
+        resume: bool = False,
+        data_input_predicator_filepath=None,
+        model_runnable_predicator_filepath=None,
         output_dir="./tmp/graph_variable_renamer_dir",
         filter_path=None,
         filter_config=None,
@@ -59,6 +61,7 @@
         if model_runnable_predicator_config is None:
             model_runnable_predicator_config = {}
         return {
+            "resume": resume,
             "output_dir": output_dir,
             "filter_path": filter_path,
             "filter_config": filter_config if filter_config is not None else {},
@@ -82,12 +85,20 @@ def __call__(self, rel_model_path):
         dst_model_path = os.path.realpath(
             os.path.join(self.config["output_dir"], rel_model_path)
         )
+        if self.config["resume"] and os.path.exists(
+            os.path.join(dst_model_path, "model.py")
+        ):
+            return
         Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
-        shutil.copytree(src_model_path, dst_model_path, dirs_exist_ok=True)
-        self._update_model_py_file(gm, dst_model_path)
-        self._update_weight_meta_py_file(src_model_path, dst_model_path)
-        self._update_input_meta_py_file(src_model_path, dst_model_path)
-        self._try_run(dst_model_path)
+        with tempfile.TemporaryDirectory(prefix="graph_variable_renamer_") as temp_dir:
+            temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))
+            shutil.copytree(src_model_path, temp_model_path, dirs_exist_ok=True)
+            self._update_model_py_file(gm, temp_model_path)
+            self._update_weight_meta_py_file(src_model_path, temp_model_path)
+            self._update_input_meta_py_file(src_model_path, temp_model_path)
+            print("Try to run renamed model...")
+            self._try_run(temp_model_path)
+            shutil.copytree(temp_model_path, dst_model_path)

     def _try_run(self, model_path):
         assert self.model_runnable_predicator(
diff --git a/graph_net/torch/typical_sequence_split_points.py b/graph_net/torch/typical_sequence_split_points.py
index dc0f669d4..aeece76d0 100644
--- a/graph_net/torch/typical_sequence_split_points.py
+++ b/graph_net/torch/typical_sequence_split_points.py
@@ -116,7 +116,9 @@ def _resolve_token_to_ops(
         return [f"Unknown({tid})"]

     def _load_op_names_from_file(self, txt_path: Path) -> List[str]:
-        assert txt_path.exists(), f"{str(txt_path)=}"
+        if not txt_path.exists():
+            print(f"File not found: {txt_path}")
+            return []
         return txt_path.read_text().split("\n")

     def _calculate_token_lengths(
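
The `graph_variable_renamer.py` change above stages all rewrites in a temporary directory and copies the result into `output_dir` only after the renamed model has been try-run, so a failing or interrupted model never leaves a half-written destination that a later `resume: true` pass would wrongly skip. A minimal sketch of that stage-validate-publish pattern, using hypothetical `stage_and_publish`, `update_fn`, and `try_run_fn` names rather than the renamer's real helpers:

```python
import os
import shutil
import tempfile


def stage_and_publish(src_model_path, dst_model_path, update_fn, try_run_fn, resume=False):
    """Rewrite a model in a temp dir and publish it only after it runs.

    Nothing appears at dst_model_path unless try_run_fn succeeds, so a
    resume pass can treat an existing model.py as a completed model.
    """
    # Resume: skip models that were already published successfully.
    if resume and os.path.exists(os.path.join(dst_model_path, "model.py")):
        return

    os.makedirs(os.path.dirname(dst_model_path) or ".", exist_ok=True)
    with tempfile.TemporaryDirectory(prefix="graph_variable_renamer_") as temp_dir:
        temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))
        shutil.copytree(src_model_path, temp_model_path)
        update_fn(temp_model_path)   # rewrite model.py / weight_meta.py / input_meta.py
        try_run_fn(temp_model_path)  # raise if the rewritten model is not runnable
        # Publish only the validated tree; the temp dir is cleaned up on exit.
        shutil.copytree(temp_model_path, dst_model_path, dirs_exist_ok=True)
```

Because the resume check only looks for a published `model.py`, any model that fails inside the temporary directory is simply retried on the next run.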