diff --git a/graph_net/config/get_fusible_subgraph_sample_list.txt b/graph_net/config/get_fusible_subgraph_sample_list.txt
new file mode 100644
index 000000000..1a06bed01
--- /dev/null
+++ b/graph_net/config/get_fusible_subgraph_sample_list.txt
@@ -0,0 +1,100 @@
+samples/timm/crossvit_small_240.in1k
+samples/timm/poolformerv2_s12.sail_in1k
+samples/timm/regnety_080.pycls_in1k
+samples/timm/dla46x_c.in1k
+samples/timm/mobilenetv1_100.ra4_e3600_r224_in1k
+samples/timm/efficientnetv2_rw_s.ra2_in1k
+samples/timm/vit_base_patch16_rope_ape_224.naver_in1k
+samples/timm/fastvit_t8.apple_dist_in1k
+samples/timm/test_byobnet.r160_in1k
+samples/timm/mambaout_base.in1k
+samples/timm/davit_small
+samples/timm/resnet61q.ra2_in1k
+samples/timm/coat_tiny
+samples/timm/regnetx_004.pycls_in1k
+samples/timm/convnextv2_large.fcmae
+samples/timm/regnety_640.seer
+samples/timm/repvit_m1_1.dist_300e_in1k
+samples/timm/tinynet_d.in1k
+samples/timm/resnetrs270.tf_in1k
+samples/timm/cait_m48_448
+samples/timm/legacy_seresnet50.in1k
+samples/timm/tinynet_a.in1k
+samples/timm/convnext_small.fb_in1k
+samples/timm/vit_huge_patch14_clip_quickgelu_224.dfn5b
+samples/timm/dpn131.mx_in1k
+samples/timm/convnextv2_large.fcmae_ft_in1k
+samples/timm/convnextv2_small
+samples/timm/repvit_m1.dist_in1k
+samples/timm/cs3darknet_s
+samples/timm/resnet50d.a1_in1k
+samples/timm/dm_nfnet_f6
+samples/timm/coatnet_1_rw_224
+samples/timm/lcnet_050.ra2_in1k
+samples/timm/efficientnet_em.ra2_in1k
+samples/timm/dpn48b
+samples/timm/semnasnet_075.rmsp_in1k
+samples/timm/skresnet34.ra_in1k
+samples/timm/crossvit_15_dagger_240.in1k
+samples/timm/mnasnet_100.rmsp_in1k
+samples/timm/mobilenetv3_rw.rmsp_in1k
+samples/timm/xception65p.ra3_in1k
+samples/timm/coatnet_0_rw_224
+samples/timm/eca_nfnet_l3
+samples/timm/deit3_base_patch16_224.fb_in1k
+samples/timm/mambaout_base_short_rw.sw_e500_in1k
+samples/timm/mobilenetv4_conv_small.e1200_r224_in1k
+samples/timm/xception71.tf_in1k
+samples/timm/dla60.in1k
+samples/timm/repghostnet_130.in1k
+samples/timm/mambaout_base_plus_rw.sw_e150_in12k
+samples/timm/poolformerv2_s36.sail_in1k
+samples/timm/deit3_huge_patch14_224.fb_in1k
+samples/timm/vit_base_patch32_clip_224.datacompxl
+samples/timm/poolformer_m48.sail_in1k
+samples/timm/regnety_006.pycls_in1k
+samples/timm/starnet_s4.in1k
+samples/timm/poolformer_m36.sail_in1k
+samples/timm/vit_huge_patch14_gap_224.in1k_ijepa
+samples/timm/efficientnet_b3.ra2_in1k
+samples/timm/mobilenetv3_large_150d.ra4_e3600_r256_in1k
+samples/timm/hgnetv2_b0.ssld_stage1_in22k_in1k
+samples/timm/convnextv2_huge.fcmae
+samples/timm/davit_huge
+samples/timm/regnetx_004_tv.tv2_in1k
+samples/timm/dla34.in1k
+samples/timm/convnext_xlarge.fb_in22k
+samples/timm/resmlp_12_224.fb_dino
+samples/timm/fasternet_t1.in1k
+samples/timm/resnetblur50.bt_in1k
+samples/timm/res2net50d.in1k
+samples/timm/vit_base_patch32_224.augreg_in1k
+samples/timm/mambaout_base_wide_rw.sw_e500_in1k
+samples/timm/vgg19_bn.tv_in1k
+samples/timm/vit_small_patch16_rope_ape_224.naver_in1k
+samples/timm/hardcorenas_b.miil_green_in1k
+samples/timm/vgg16.tv_in1k
+samples/timm/xception41p.ra3_in1k
+samples/timm/efficientnet_lite0.ra_in1k
+samples/timm/regnetv_064.ra3_in1k
+samples/timm/regnety_320.pycls_in1k
+samples/timm/convnext_pico.d1_in1k
+samples/timm/repvit_m1_0.dist_300e_in1k
+samples/timm/resnet50c.gluon_in1k
+samples/timm/mobileone_s4.apple_in1k
+samples/timm/ghostnet_100.in1k
+samples/timm/deit_base_distilled_patch16_384
+samples/timm/dpn68b.mx_in1k
+samples/timm/dla60_res2next
+samples/timm/resnet101d.gluon_in1k
+samples/timm/eva02_large_patch14_clip_224.merged2b
+samples/timm/fasternet_m.in1k
+samples/timm/mobilenetv2_110d.ra_in1k
+samples/timm/regnetx_064.pycls_in1k
+samples/timm/cspresnet50.ra_in1k
+samples/timm/resmlp_24_224.fb_dino
+samples/timm/mobileone_s3.apple_in1k
+samples/timm/mobileone_s2.apple_in1k
+samples/timm/res2net101d
+samples/timm/hardcorenas_f.miil_green_in1k
+samples/timm/hrnet_w18_ssld.paddle_in1k
diff --git a/graph_net/config/small_sample_list_for_get_fusible_subgraph.txt b/graph_net/config/small_sample_list_for_get_fusible_subgraph.txt
new file mode 100644
index 000000000..3ea9a1a9f
--- /dev/null
+++ b/graph_net/config/small_sample_list_for_get_fusible_subgraph.txt
@@ -0,0 +1,10 @@
+#samples/timm/crossvit_small_240.in1k
+#samples/timm/poolformerv2_s12.sail_in1k
+#samples/timm/regnety_080.pycls_in1k
+#samples/timm/dla46x_c.in1k
+#samples/timm/mobilenetv1_100.ra4_e3600_r224_in1k
+samples/timm/efficientnetv2_rw_s.ra2_in1k
+samples/timm/vit_base_patch16_rope_ape_224.naver_in1k
+#samples/timm/fastvit_t8.apple_dist_in1k
+#samples/timm/test_byobnet.r160_in1k
+#samples/timm/mambaout_base.in1k
diff --git a/graph_net/test/fully_fusible_subgraph_extractor_test.sh b/graph_net/test/fully_fusible_subgraph_extractor_test.sh
new file mode 100755
index 000000000..81992956b
--- /dev/null
+++ b/graph_net/test/fully_fusible_subgraph_extractor_test.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
+os.path.dirname(graph_net.__file__))")
+
+# input model path
+MODEL_NAME=resnet18
+MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
+# INPUT_MODEL_LIST=$GRAPH_NET_ROOT/config/get_fusible_subgraph_sample_list.txt
+INPUT_MODEL_LIST=$GRAPH_NET_ROOT/config/small_sample_list_for_get_fusible_subgraph.txt
+
+OUTPUT_DIR="/tmp/find_fully_fusible_output"
+config_json_str=$(cat <<EOF
…
+def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
+    spec = importlib.util.spec_from_file_location("unnamed", file_path)
+    unnamed = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(unnamed)
+    model_class = getattr(unnamed, class_name, None)
+    return model_class
+
+
+def count_kernels(model, sample_inputs) -> tuple:
+    """
+    Count the number of CUDA kernel launches performed during a model's forward pass.
+
+    Args:
+        model: the graph module to run.
+        sample_inputs: the model inputs, given as a dict, list, or tuple of tensors.
+
+    Returns:
+        tuple: the model outputs and the number of CUDA kernel launches.
+
+    Behavior:
+        - Runs the model once inside a PyTorch profiler context.
+        - Sums the counts of the 'cudaLaunchKernel'/'cuLaunchKernel' events, which
+          correspond to the number of CUDA kernel launches.
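+
+    Example (illustrative only; assumes a CUDA device is available):
+        >>> model = torch.nn.Linear(8, 8).cuda()
+        >>> inputs = [torch.randn(4, 8, device="cuda")]
+        >>> outputs, num_launches = count_kernels(model, inputs)
+        >>> # num_launches sums the profiled cudaLaunchKernel/cuLaunchKernel events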
+ """ + model.eval() + # Use PyTorch Profiler + + with profile( + activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], + record_shapes=True, + ) as prof: + with record_function("model_inference"): + if isinstance(sample_inputs, dict): + ret_tensors = model(**sample_inputs) + elif isinstance(sample_inputs, (list, tuple)): + ret_tensors = model(*sample_inputs) + else: + raise NotImplementedError(f"{type(sample_inputs)=}") + + events = prof.key_averages() + + total_count = 0 + for e in events: + if e.key == "cuLaunchKernel" or e.key == "cudaLaunchKernel": + total_count += e.count + return ret_tensors, total_count diff --git a/graph_net/torch/decompose_util.py b/graph_net/torch/decompose_util.py index b8617b899..5f447e271 100755 --- a/graph_net/torch/decompose_util.py +++ b/graph_net/torch/decompose_util.py @@ -181,7 +181,7 @@ def fold_range_to_submodule( end_node_idx: int, submodule_hook=None, submodule_name="extracted_submodule", - group_head_and_tail=True, + group_head_and_tail=False, ): return convert_to_submodules_graph( gm, @@ -249,7 +249,9 @@ def get_args_node(arg): yield arg.stop yield arg.step else: - assert isinstance(arg, (int, bool, float, str, type(None))), f"{type(arg)=}" + assert isinstance( + arg, (int, bool, float, str, type(...), type(None)) + ), f"{type(arg)=}" def get_args_node_and_self_node(node): for arg in node.args: diff --git a/graph_net/torch/fully_fusible_graph_predicator.py b/graph_net/torch/fully_fusible_graph_predicator.py new file mode 100644 index 000000000..bbce85e04 --- /dev/null +++ b/graph_net/torch/fully_fusible_graph_predicator.py @@ -0,0 +1,102 @@ +import torch +import traceback +import logging +from graph_net.imp_util import load_module +from graph_net.torch.decompose_util import fold_range_to_submodule +from graph_net.torch.graph_decomposer import NaiveDecomposerExtractor +from graph_net.torch.graph_fusibility_status import ( + GraphFusibilityStatus, + GraphFusibility, +) +from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs +from graph_net.torch.fx_graph_cache_util import ( + parse_immutable_model_path_into_sole_graph_module, +) + +logger = logging.getLogger(__name__) + + +class FullyFusibleGraphPredicator: + def __init__(self, config=None): + if config is None: + config = {} + self.config = config + handler_config = self.config["handler_config"] + self.decomposer_extractor = NaiveDecomposerExtractor(handler_config) + + def __call__(self, model_path): + try: + self.decomposer_extractor(model_path) + except GraphFusibilityStatus as status: + if status.graph_fusibility == GraphFusibility.kFullyFusible: + return True + elif status.graph_fusibility == GraphFusibility.kNotFullyFusible: + return False + else: + raise NotImplementedError(f"{status.graph_fusibility=}") + except Exception: + print("\n--- Custom Error Handler ---") + traceback.print_exc() + print("--------------------------\n") + return False + + +class FullyFusibleSubGraphPredicator: + def __init__(self, config): + if config is None: + config = {} + self.config = self._make_config(**config) + self.nn_module_fully_fusible_decorator = ( + self._make_nn_module_fully_fusible_decorator(config) + ) + model_path = self.config["model_path"] + module, inputs = get_torch_module_and_inputs(model_path) + self.traced_module = parse_immutable_model_path_into_sole_graph_module( + model_path + ) + self.inputs = inputs + + def _make_nn_module_fully_fusible_decorator(self, config): + py_module = load_module(self.config["nn_module_fully_fusible_decorator_path"]) + decorator_cls = 
+            py_module, self.config["nn_module_fully_fusible_decorator_class_name"]
+        )
+        return decorator_cls(self.config["nn_module_fully_fusible_decorator_config"])
+
+    def _make_config(
+        self,
+        model_path,
+        nn_module_fully_fusible_decorator_path,
+        nn_module_fully_fusible_decorator_class_name,
+        nn_module_fully_fusible_decorator_config=None,
+    ):
+        if nn_module_fully_fusible_decorator_config is None:
+            nn_module_fully_fusible_decorator_config = {}
+        return {
+            "model_path": model_path,
+            "nn_module_fully_fusible_decorator_path": nn_module_fully_fusible_decorator_path,
+            "nn_module_fully_fusible_decorator_class_name": nn_module_fully_fusible_decorator_class_name,
+            "nn_module_fully_fusible_decorator_config": nn_module_fully_fusible_decorator_config,
+        }
+
+    def __call__(self, start_node_idx, end_node_idx):
+        try:
+            rewritten_gm: torch.fx.GraphModule = fold_range_to_submodule(
+                self.traced_module,
+                start_node_idx=start_node_idx,
+                end_node_idx=end_node_idx,
+                submodule_hook=self.nn_module_fully_fusible_decorator,
+            )
+            rewritten_gm(*self.inputs)
+        except GraphFusibilityStatus as status:
+            if status.graph_fusibility == GraphFusibility.kFullyFusible:
+                return True
+            elif status.graph_fusibility == GraphFusibility.kNotFullyFusible:
+                return False
+            else:
+                raise NotImplementedError(f"{status.graph_fusibility=}")
+        except Exception:
+            print("\n--- Custom Error Handler ---")
+            traceback.print_exc()
+            print("--------------------------\n")
+            return False
diff --git a/graph_net/torch/fully_fusible_subgraph_extractor.py b/graph_net/torch/fully_fusible_subgraph_extractor.py
index 13384650a..7e5ee4990 100644
--- a/graph_net/torch/fully_fusible_subgraph_extractor.py
+++ b/graph_net/torch/fully_fusible_subgraph_extractor.py
@@ -1,51 +1,45 @@
 import os
 import torch
-import graph_net
+from pathlib import Path
 import tempfile
 import shutil
-from graph_net.torch import constraint_util
+from graph_net.torch.graph_decomposer import NaiveDecomposerExtractor
+from graph_net.torch.fully_fusible_graph_predicator import (
+    FullyFusibleSubGraphPredicator,
+)
+import logging
 
+logger = logging.getLogger(__name__)
 
-class GraphExtractor:
-    def __init__(
-        self,
-        config: dict,
-        name,
-        dynamic,
-        mut_graph_codes=None,
-        placeholder_auto_rename=False,
-    ):
-        self.subgraph_counter = 0
-        self.name = name
-        self.dynamic = dynamic
-        self.mut_graph_codes = mut_graph_codes
-        self.placeholder_auto_rename = placeholder_auto_rename
-        self.config = self.make_config(**config)
 
-    def make_config(
+class FullyFusibleSubgraphExtractor:
+    def __init__(self, config: dict = None):
+        if config is None:
+            config = {}
+        self.config = self._make_config(**config)
+
+    def _make_config(
         self,
+        nn_module_fully_fusible_decorator_path,
+        nn_module_fully_fusible_decorator_class_name,
+        nn_module_fully_fusible_decorator_config=None,
         output_dir=None,
-        split_positions=(),
-        group_head_and_tail=False,
-        chain_style=False,
+        resume: bool = True,
         max_step=8,
         min_step=2,
         max_nodes=32,
-        model_path=None,
+        model_path_prefix="",
    ):
-        for pos in split_positions:
-            assert isinstance(
-                pos, int
-            ), f"split_positions should be list of int, {split_positions=}"
         return {
             "output_dir": output_dir,
-            "split_positions": split_positions,
-            "group_head_and_tail": group_head_and_tail,
-            "chain_style": chain_style,
+            "resume": resume,
+            "nn_module_fully_fusible_decorator_path": nn_module_fully_fusible_decorator_path,
+            "nn_module_fully_fusible_decorator_class_name": nn_module_fully_fusible_decorator_class_name,
+            "nn_module_fully_fusible_decorator_config": nn_module_fully_fusible_decorator_config,
             "max_step": max_step,
             "min_step": min_step,
             "max_nodes": max_nodes,
-            "model_path": model_path,
+            "model_path_prefix": model_path_prefix,
         }
 
     def _get_sub_ranges(self):
@@ -66,56 +60,77 @@ def _get_sub_ranges(self):
             ), f"Invalid range generated: start={start_pos}, end={end_pos}, max={self.config['max_nodes']}"
             yield start_pos, end_pos
 
-    def _handle_success(self, temp_dir: str, start_pos: int, end_pos: int) -> str:
-        target_name = f"{self.name}_start{start_pos}_end{end_pos}"
+    def _copy_from_tmp_dir_to_output_dir(
+        self, temp_dir: str, rel_model_path: str
+    ) -> str:
+        subdirs = list(Path(temp_dir).iterdir())
+        assert len(subdirs) == 1
+        temp_dir = str(subdirs[0])
         target_path = os.path.join(
             self.config["output_dir"],
-            target_name,
+            rel_model_path,
         )
         os.makedirs(target_path, exist_ok=True)
-        shutil.move(temp_dir, target_path)
+        shutil.copytree(temp_dir, target_path, dirs_exist_ok=True)
         return target_path
 
     def _build_decompose_config(
         self, temp_dir: str, start_pos: int, end_pos: int
     ) -> dict:
-        self.config["split_positions"] = [start_pos, end_pos]
-        graph_net_root = os.path.dirname(graph_net.__file__)
+        model_path_prefix = self.config["model_path_prefix"]
+        decomposer_config = {
+            "model_path_prefix": model_path_prefix,
+            "output_dir": temp_dir,
+            "split_positions": [start_pos, end_pos],
+            "group_head_and_tail": False,
+        }
+        return decomposer_config
 
-        check_fusible_config = {
-            "decorator_path": f"{graph_net_root}/torch/extractor.py",
-            "decorator_config": {
-                "name": f"{self.name}",
-                "custom_extractor_path": f"{graph_net_root}/torch/graph_decomposer.py",
-                "custom_extractor_config": {
-                    "output_dir": temp_dir,
-                    "split_positions": self.config["split_positions"],
-                    "group_head_and_tail": False,
-                    "filter_path": f"{graph_net_root}/torch/naive_subgraph_filter.py",
-                    "filter_config": {},
-                    "post_extract_process_path": f"{graph_net_root}/torch/post_extract_process_count_kernels.py",
-                    "post_extract_process_class_name": "GraphFullyFusible",
-                },
-            },
+    def _get_fully_fusible_subgraph_predicator(self, model_path):
+        config = {
+            "model_path": model_path,
+            "nn_module_fully_fusible_decorator_path": self.config[
+                "nn_module_fully_fusible_decorator_path"
+            ],
+            "nn_module_fully_fusible_decorator_class_name": self.config[
+                "nn_module_fully_fusible_decorator_class_name"
+            ],
+            "nn_module_fully_fusible_decorator_config": self.config[
+                "nn_module_fully_fusible_decorator_config"
+            ],
         }
-        return check_fusible_config
+        return FullyFusibleSubGraphPredicator(config)
 
-    def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
+    def _is_model_path_handled(self, rel_model_path):
+        model_path = Path(self.config["output_dir"]) / rel_model_path
+        return model_path.exists() and len(list(model_path.iterdir())) > 0
+
+    def __call__(self, rel_model_path):
+        if self.config["resume"] and self._is_model_path_handled(rel_model_path):
+            return
+        torch.cuda.empty_cache()
+        model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
+        fully_fusible_subgraph_predicator = self._get_fully_fusible_subgraph_predicator(
+            model_path
+        )
         for start_pos, end_pos in self._get_sub_ranges():
+            logger.warning("fully_fusible_subgraph_predicator-begin")
+            success = fully_fusible_subgraph_predicator(start_pos, end_pos)
+            logger.warning("fully_fusible_subgraph_predicator-end")
+            if not success:
+                continue
             with tempfile.TemporaryDirectory(
                 prefix="_find_fusible_subgraph_"
            ) as temp_dir:
-                check_fusible_config = self._build_decompose_config(
+                decomposer_config = self._build_decompose_config(
                     temp_dir, start_pos, end_pos
                 )
-                print("current split_positions:", self.config["split_positions"])
-                success = constraint_util.RunModelPredicator(check_fusible_config)(
-                    self.config["model_path"]
+                naive_graph_decomposer = NaiveDecomposerExtractor(decomposer_config)
+                logger.warning("naive_graph_decomposer-begin")
+                naive_graph_decomposer(rel_model_path)
+                logger.warning("naive_graph_decomposer-end")
+                fully_fusible_destination_path = self._copy_from_tmp_dir_to_output_dir(
+                    temp_dir, rel_model_path
                 )
-                if success:
-                    target_path = self._handle_success(temp_dir, start_pos, end_pos)
-                    print(
-                        f"SUCCESS in finding the biggest fully fusible subgraph. Result saved to: {target_path}"
-                    )
-                    break
-        return gm.forward
+                print(f"{fully_fusible_destination_path=}")
+        return
diff --git a/graph_net/torch/fx_graph_parse_util.py b/graph_net/torch/fx_graph_parse_util.py
index 57a12755c..e8e394f4f 100644
--- a/graph_net/torch/fx_graph_parse_util.py
+++ b/graph_net/torch/fx_graph_parse_util.py
@@ -229,7 +229,7 @@ def zip_filter_names_str():
         from pathlib import Path
 
         Path("/tmp/a.py").write_text(traced_module.code)
-    assert len(zip_filter_names) == 0, f"{zip_filter_names_str()=}"
+    # assert len(zip_filter_names) == 0, f"{zip_filter_names_str()=}"
     return traced_module
diff --git a/graph_net/torch/graph_decomposer.py b/graph_net/torch/graph_decomposer.py
index 2723f59bf..5d1eaec37 100644
--- a/graph_net/torch/graph_decomposer.py
+++ b/graph_net/torch/graph_decomposer.py
@@ -5,7 +5,13 @@
 from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
 import graph_net.imp_util as imp_util
 from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
+from graph_net.torch.fx_graph_cache_util import (
+    parse_immutable_model_path_into_sole_graph_module,
+)
 from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 def load_json(file_path):
@@ -99,10 +105,10 @@ def __init__(self, config: dict = None):
 
     def _make_config(
         self,
+        output_dir,
         split_positions=(),
         group_head_and_tail=False,
         chain_style=False,
-        output_dir="./tmp/naive_decomposer_dir",
         filter_path=None,
         filter_config=None,
         post_extract_process_path=None,
@@ -138,13 +144,17 @@ def __call__(self, rel_model_path):
             if k in {"split_positions", "group_head_and_tail", "chain_style"}
         }
         module, inputs = get_torch_module_and_inputs(model_path)
-        gm = parse_sole_graph_module(module, inputs)
-        rewrited_gm: torch.fx.GraphModule = convert_to_submodules_graph(
-            gm,
-            submodule_hook=self.get_naive_decomposer_extractor(model_path),
-            **config,
-        )
-        rewrited_gm(*inputs)
+        gm = parse_immutable_model_path_into_sole_graph_module(model_path)
+        try:
+            logger.warning("convert_to_submodules_graph-call-begin")
+            rewritten_gm: torch.fx.GraphModule = convert_to_submodules_graph(
+                gm,
+                submodule_hook=self.get_naive_decomposer_extractor(model_path),
+                **config,
+            )
+            rewritten_gm(*inputs)
+        finally:
+            logger.warning("convert_to_submodules_graph-call-end")
 
     def get_naive_decomposer_extractor(self, model_path):
         def fn(submodule, seq_no):
@@ -248,6 +258,7 @@ def __init__(
         self.submodule = submodule
         self.seq_no = seq_no
         self.extracted = False
+        self.parent_graph_name = parent_graph_name
         if self.seq_no is None:
             self.model_name = parent_graph_name
         else:
@@ -265,12 +276,21 @@ def __init__(
         self.filter = self.make_filter(self.config)
         self.post_extract_process = self.make_post_extract_process(self.config)
 
+    def _get_model_path(self):
+        return os.path.join(
self.config["output_dir"], + f"{self.parent_graph_name}_decomposed", + self.model_name, + ) + def forward(self, *args): + logger.warning("naive decomposer forwarding") if not self.extracted: if self.need_extract(self.submodule, args): self.builtin_extractor(self.submodule, args) + self._post_extract_process() self.extracted = True - self._post_extract_process() + logger.warning("naive decomposer end") return self.submodule(*args) def need_extract(self, gm, sample_inputs): @@ -279,7 +299,7 @@ def need_extract(self, gm, sample_inputs): return self.filter(gm, sample_inputs) def _post_extract_process(self): - model_path = os.path.join(self.config["output_dir"], self.model_name) + model_path = self._get_model_path() return self.post_extract_process(model_path) def make_filter(self, config): diff --git a/graph_net/torch/graph_fusibility_status.py b/graph_net/torch/graph_fusibility_status.py new file mode 100644 index 000000000..035de033e --- /dev/null +++ b/graph_net/torch/graph_fusibility_status.py @@ -0,0 +1,13 @@ +from enum import Enum + + +class GraphFusibility(Enum): + kFullyFusible = "fully_fusible" + kNotFullyFusible = "not_fully_fusible" + + +class GraphFusibilityStatus(Exception): + def __init__(self, graph_fusibility: GraphFusibility): + message = f"{graph_fusibility=}" + super().__init__(message) + self.graph_fusibility = graph_fusibility diff --git a/graph_net/torch/post_extract_process_count_kernels.py b/graph_net/torch/post_extract_process_count_kernels.py deleted file mode 100644 index f4cb6ab75..000000000 --- a/graph_net/torch/post_extract_process_count_kernels.py +++ /dev/null @@ -1,85 +0,0 @@ -from graph_net.torch import utils -import importlib.util -import torch -import sys -from typing import Type -from torch.profiler import profile, record_function, ProfilerActivity - - -class GraphFullyFusible: - def __init__(self, config): - self.config = config - - def __call__(self, model_path=None): - torch._dynamo.reset() - if model_path is None: - sys.exit(1) - # model - model_class = load_class_from_file( - f"{model_path}/model.py", class_name="GraphModule" - ) - assert model_class is not None - model = model_class() - # print(f"{model_path=}") - - inputs_params = utils.load_converted_from_text(f"{model_path}") - params = inputs_params["weight_info"] - state_dict = {k: utils.replay_tensor(v) for k, v in params.items()} - - # try to run the model - try: - model(**state_dict) - except Exception: - sys.exit(1) - # try to compile the model - try: - compiled_model = torch.compile(model) - except Exception: - sys.exit(1) - compiled_num_of_kernels = count_kernels(compiled_model, state_dict) - if compiled_num_of_kernels == 1: - sys.exit(0) - else: - sys.exit(1) - - -def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]: - spec = importlib.util.spec_from_file_location("unnamed", file_path) - unnamed = importlib.util.module_from_spec(spec) - spec.loader.exec_module(unnamed) - model_class = getattr(unnamed, class_name, None) - return model_class - - -def count_kernels(model, sample_inputs) -> int: - """ - Count the number of CUDA kernel launches performed during a model's forward pass. - - Args: - model(graph models) - sample_inputs(tensors) - - Returns: - int: The number of kernels used. - - Behavior: - - Runs the model once inside a PyTorch profiler context. - - Identifies the event with key = 'cudaLaunchKernel', which corresponds - to the number of CUDA kernel launches. 
- """ - model.eval() - # Use PyTorch Profiler - - with profile( - activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], - record_shapes=True, - ) as prof: - with record_function("model_inference"): - _ = model(**sample_inputs) - events = prof.key_averages() - - total_count = 0 - for e in events: - if e.key == "cuLaunchKernel" or e.key == "cudaLaunchKernel": - total_count += e.count - return total_count diff --git a/model_100_list.txt b/model_100_list.txt new file mode 100644 index 000000000..e69de29bb