-
Notifications
You must be signed in to change notification settings - Fork 44
check if a graph can be fully fused into a single cuda kernel #381
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 21 commits
986793a
282a32d
0187ae0
5d46f55
39b4139
b775e46
44ad76f
0fc84c4
19dc60b
d6eda81
956ad33
8c8070b
ef7d4b6
181b293
2aac268
00d5b4b
75c3e61
fe89add
ca860b3
c21717f
de54e88
9363023
adff744
ca20508
fc0071c
9a28d45
513cc38
4847ee3
6538119
22a2772
684dba9
bfe0848
f8cc102
f131cfb
f7f3d2a
353e7bd
b703458
0b687cf
e70b44b
7dbb6e9
f71b56b
93fabbf
6df0cd0
babdde5
48467f7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
#!/bin/bash
# bash graph_net/test/naive_decomposer_and_post_extract_process_test.sh
#
# End-to-end test: decompose a sample model into subgraphs and run the
# post-extract process that keeps only subgraphs fusible into a single
# CUDA kernel.

# Locate the installed graph_net package directory.
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
os.path.dirname(graph_net.__file__))")

# Input model path (relative to the samples directory).
MODEL_NAME=resnet18
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME

decorator_config_json_str=$(cat <<EOF
{
  "decorator_path": "$GRAPH_NET_ROOT/torch/extractor.py",
  "decorator_config": {
    "name": "$MODEL_NAME",
    "custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
    "custom_extractor_config": {
      "output_dir": "/tmp/naive_decompose_workspace",
      "split_positions": [8, 16, 32],
      "group_head_and_tail": true,
      "filter_path": "$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
      "filter_config": {},
      "post_extract_process_path": "$GRAPH_NET_ROOT/torch/post_extract_process_count_kernels.py",
      "post_extract_process_class_name": "PostExtractProcess"
    }
  }
}
EOF
)

# Quote the variable and use printf so the JSON reaches base64 byte-for-byte.
# An unquoted `echo $var` word-splits and collapses whitespace, which can
# mangle the encoded payload.
DECORATOR_CONFIG=$(printf '%s' "$decorator_config_json_str" | base64 -w 0)

python3 -m graph_net.torch.run_model \
    --model-path "$GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES" \
    --decorator-config="$DECORATOR_CONFIG"
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,13 @@ | ||
| #!/bin/bash | ||
| SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) | ||
| GRAPH_NET_DIR=$(dirname "$SCRIPT_DIR") | ||
| PROJECT_ROOT=$(dirname "$GRAPH_NET_DIR") | ||
|
|
||
| # Add the project root directory to the Python path | ||
| export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH" | ||
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
| GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print( | ||
|
|
||
lixinqi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| from graph_net.torch import utils | ||
| import importlib.util | ||
| import shutil | ||
| import torch | ||
| from typing import Type | ||
| from torch.profiler import profile, record_function, ProfilerActivity | ||
|
|
||
|
|
||
class PostExtractProcess:
    """Post-extraction hook that keeps only subgraphs fusible into one CUDA kernel.

    Invoked by the graph extractor after a subgraph has been dumped to
    ``model_path``.  The dumped graph is compiled with ``torch.compile`` and
    profiled once; when the compiled forward launches exactly one CUDA kernel
    the directory is kept, otherwise it is deleted.
    """

    def __init__(self, config):
        # Opaque configuration supplied by the extractor; stored to satisfy
        # the hook interface (not consumed here).
        self.config = config

    def __call__(self, model_path=None):
        """Return True iff the graph at *model_path* fuses into a single kernel.

        Side effect: removes *model_path* from disk when the graph is NOT
        fully fusible.  Returns False when *model_path* is None.

        Raises:
            RuntimeError: if ``model.py`` does not define ``GraphModule``.
        """
        if model_path is None:
            return False
        # Load the dumped GraphModule class from <model_path>/model.py.
        model_class = load_class_from_file(
            f"{model_path}/model.py", class_name="GraphModule"
        )
        # Explicit check instead of `assert`: asserts are stripped under -O,
        # which would let a None class crash later with a confusing error.
        if model_class is None:
            raise RuntimeError(f"GraphModule not found in {model_path}/model.py")
        model = model_class()

        print(f"{model_path=}")

        # Rebuild the recorded tensors; the dumped graph takes its weights as
        # keyword inputs, hence the **state_dict calls below.
        inputs_params = utils.load_converted_from_text(f"{model_path}")
        params = inputs_params["weight_info"]
        state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}

        # Eager warm-up run, then profile the compiled model.
        model(**state_dict)
        compiled_model = torch.compile(model)

        compiled_num_of_kernels = count_kernels(compiled_model, state_dict)
        if compiled_num_of_kernels == 1:
            print(model_path, "can be fully integrated")
            return True
        print(model_path, "can not be fully integrated")
        # Not fusible into one kernel: discard the extracted subgraph.
        shutil.rmtree(model_path)
        return False
|
|
||
|
|
||
def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
    """Dynamically import the Python file at *file_path* and fetch *class_name*.

    The file is executed as an anonymous module; returns the class object,
    or None when the module does not define a class of that name.
    """
    spec = importlib.util.spec_from_file_location("unnamed", file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, class_name, None)
|
|
||
|
|
||
def count_kernels(model, sample_inputs) -> int:
    """
    Count the number of CUDA kernel launches performed during a model's forward pass.

    Args:
        model: graph model whose forward pass is profiled once.
        sample_inputs: dict of keyword tensor inputs for the forward call.
    Returns:
        int: the number of kernel launches observed.
    Behavior:
        - Runs the model once inside a PyTorch profiler context.
        - Sums the counts of the 'cuLaunchKernel' / 'cudaLaunchKernel'
          events, which correspond to CUDA kernel launches.
    """
    model.eval()
    launch_keys = ("cuLaunchKernel", "cudaLaunchKernel")
    # Use the PyTorch profiler to record one inference pass.
    with profile(
        activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
        record_shapes=True,
    ) as prof:
        with record_function("model_inference"):
            model(**sample_inputs)
    return sum(
        evt.count for evt in prof.key_averages() if evt.key in launch_keys
    )
Uh oh!
There was an error while loading. Please reload this page.