Commits (45)
986793a  1119 (roll-away, Nov 19, 2025)
282a32d  1120 (roll-away, Nov 20, 2025)
0187ae0  1120.2 (roll-away, Nov 20, 2025)
5d46f55  model_path (roll-away, Nov 20, 2025)
39b4139  remove unnecessary files and pre-committed (roll-away, Nov 20, 2025)
b775e46  remove unnecessary files and pre-committed (roll-away, Nov 21, 2025)
44ad76f  1121 remove unnecessary files (roll-away, Nov 21, 2025)
0fc84c4  modify rev version (roll-away, Nov 21, 2025)
19dc60b  modify rev version (roll-away, Nov 21, 2025)
d6eda81  modify rev version (roll-away, Nov 21, 2025)
956ad33  accuracy issues targeted (roll-away, Nov 21, 2025)
8c8070b  test script and modify feature (roll-away, Nov 21, 2025)
ef7d4b6  return set[str] (roll-away, Nov 21, 2025)
181b293  add logfile for test (roll-away, Nov 21, 2025)
2aac268  filter can get the number of kernels in naive_graph_decomposer (roll-away, Nov 24, 2025)
00d5b4b  Merge branch 'PaddlePaddle:develop' into develop (roll-away, Nov 24, 2025)
75c3e61  post extract process feature (roll-away, Nov 25, 2025)
fe89add  remove unnecessary code blocks and variables (roll-away, Nov 25, 2025)
ca860b3  modify the way of counting kernels used (roll-away, Nov 25, 2025)
c21717f  modify the way of counting kernels used (roll-away, Nov 25, 2025)
de54e88  modify script, rename files and variables (roll-away, Nov 25, 2025)
9363023  add failure protection and log output when removing directories (roll-away, Nov 26, 2025)
adff744  Merge branch 'PaddlePaddle:develop' into develop (roll-away, Nov 27, 2025)
ca20508  add a script to check fusability of a given model (roll-away, Dec 1, 2025)
fc0071c  Merge branch 'PaddlePaddle:develop' into develop (roll-away, Dec 1, 2025)
9a28d45  Merge branch 'develop' of github.com:roll-away/GraphNet into develop (roll-away, Dec 1, 2025)
513cc38  add a script to check if a given model is fully fusable (roll-away, Dec 1, 2025)
4847ee3  Merge branch 'PaddlePaddle:develop' into develop (roll-away, Dec 1, 2025)
6538119  Merge branch 'PaddlePaddle:develop' into develop (roll-away, Dec 1, 2025)
22a2772  add a script to check if a given model is fully fusable (roll-away, Dec 1, 2025)
684dba9  a script to check if a given model is fully fusable (roll-away, Dec 1, 2025)
bfe0848  Merge branch 'PaddlePaddle:develop' into develop (roll-away, Dec 1, 2025)
f8cc102  add a script to check if a given model is fully fusionable (roll-away, Dec 1, 2025)
f131cfb  add a script to check if a given model is fully fusionable (roll-away, Dec 1, 2025)
f7f3d2a  add a script to find fully fusionable subgraph (roll-away, Dec 1, 2025)
353e7bd  find the biggest fully fusionable subgraph (roll-away, Dec 2, 2025)
b703458  update new codes (roll-away, Dec 8, 2025)
0b687cf  get fusible subgraph test (roll-away, Dec 8, 2025)
e70b44b  get fusible subgraph test (roll-away, Dec 8, 2025)
7dbb6e9  modify get fully fusible subgraph (roll-away, Dec 9, 2025)
f71b56b  improve fully_fusible_subgraph_extractor.py efficiency (lixinqi, Dec 9, 2025)
93fabbf  Merge pull request #1 from lixinqi/lxq_fusibletest (roll-away, Dec 9, 2025)
6df0cd0  backup code (lixinqi, Dec 9, 2025)
babdde5  Improve efficiency of test/fully_fusible_subgraph_extractor_test.sh (lixinqi, Dec 9, 2025)
48467f7  Merge pull request #2 from lixinqi/lxq_fusibletest (roll-away, Dec 9, 2025)
31 changes: 31 additions & 0 deletions graph_net/test/naive_decomposer_and_post_extract_process_test.sh
@@ -0,0 +1,31 @@
#!/bin/bash
# bash graph_net/test/naive_decomposer_and_post_extract_process_test.sh

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(graph_net.__file__))")

# input model path
MODEL_NAME=resnet18
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
decorator_config_json_str=$(cat <<EOF
{
"decorator_path": "$GRAPH_NET_ROOT/torch/extractor.py",
"decorator_config": {
"name": "$MODEL_NAME",
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
"custom_extractor_config": {
"output_dir": "/tmp/naive_decompose_workspace",
"split_positions": [8, 16, 32],
"group_head_and_tail": true,
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
"filter_config": {},
"post_extract_process_path":"$GRAPH_NET_ROOT/torch/post_extract_process_count_kernels.py",
"post_extract_process_class_name": "PostExtractProcess"
}
}
}
EOF
)
DECORATOR_CONFIG=$(echo "$decorator_config_json_str" | base64 -w 0)

python3 -m graph_net.torch.run_model --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --decorator-config=$DECORATOR_CONFIG
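For context, the script packs the decorator config as single-line base64 so it survives argument passing. A minimal sketch of the reverse step, assuming the receiving side decodes it like this (the helper name decode_decorator_config is illustrative, not part of this PR):

import base64
import json

def decode_decorator_config(encoded: str) -> dict:
    # Reverse of `base64 -w 0` in the shell script: decode the bytes, then parse the JSON.
    return json.loads(base64.b64decode(encoded).decode("utf-8"))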
9 changes: 9 additions & 0 deletions graph_net/test/naive_graph_decomposer_test.sh
100644 → 100755
@@ -1,4 +1,13 @@
#!/bin/bash
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
GRAPH_NET_DIR=$(dirname "$SCRIPT_DIR")
PROJECT_ROOT=$(dirname "$GRAPH_NET_DIR")

# Add the project root to the Python path
export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"
Collaborator (commenting on lines 2-7):

Forcing users to modify PYTHONPATH is not a good approach. If the script fails, it is the user's responsibility to set PYTHONPATH in .bashrc.

Contributor Author:

Removed; a comment in the script now explains how to run it.

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(graph_net.__file__))")
21 changes: 21 additions & 0 deletions graph_net/torch/naive_graph_decomposer.py
@@ -32,6 +32,8 @@ def make_config(
output_dir="./tmp/naive_decomposer_dir",
filter_path=None,
filter_config=None,
post_extract_process_path=None,
post_extract_process_class_name=None,
):
for pos in split_positions:
assert isinstance(
@@ -44,6 +46,8 @@ def make_config(
"output_dir": output_dir,
"filter_path": filter_path,
"filter_config": filter_config if filter_config is not None else {},
"post_extract_process_path": post_extract_process_path,
"post_extract_process_class_name": post_extract_process_class_name,
}

def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
@@ -71,6 +75,7 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
self.seq_no = seq_no
self.extracted = False
name = f"{parent_graph_extractor.name}_{self.seq_no}"
self.model_name = name
self.builtin_extractor = BuiltinGraphExtractor(
name=name,
dynamic=False,
@@ -79,11 +84,15 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
workspace_path=self.parent_graph_extractor.config["output_dir"],
)
self.filter = self.make_filter(self.parent_graph_extractor.config)
self.post_extract_process = self.make_post_extract_process(
self.parent_graph_extractor.config
)

def forward(self, *args):
if not self.extracted:
if self.need_extract(self.submodule, args):
self.builtin_extractor(self.submodule, args)
self._post_extract_process()
self.extracted = True
return self.submodule(*args)

@@ -92,8 +101,20 @@ def need_extract(self, gm, sample_inputs):
return True
return self.filter(gm, sample_inputs)

def _post_extract_process(self):
# The hook is optional: make_post_extract_process returns None when no path is configured.
if self.post_extract_process is None:
return None
model_path = os.path.join(
self.parent_graph_extractor.config["output_dir"], self.model_name
)
return self.post_extract_process(model_path)

def make_filter(self, config):
if config["filter_path"] is None:
return None
module = imp_util.load_module(config["filter_path"])
return module.GraphFilter(config["filter_config"])

def make_post_extract_process(self, config):
if config["post_extract_process_path"] is None:
return None
module = imp_util.load_module(config["post_extract_process_path"])
# Look up the class named in the config instead of hard-coding PostExtractProcess,
# and construct it with the config dict rather than the module path string.
process_class = getattr(module, config["post_extract_process_class_name"])
return process_class(config)
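For reference, the minimal contract a post-extract hook must satisfy: a class constructed with the config dict and called with the extracted model's directory, returning a bool. A hedged sketch (the real hook in this PR is post_extract_process_count_kernels.py; this class name would be passed via post_extract_process_class_name):

class LoggingPostExtractProcess:
    def __init__(self, config):
        # Receives the decomposer config dict from make_post_extract_process.
        self.config = config

    def __call__(self, model_path=None):
        # Return False to signal the extracted model was rejected.
        if model_path is None:
            return False
        print(f"post-extract hook ran on {model_path}")
        return True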
1 change: 0 additions & 1 deletion graph_net/torch/naive_subgraph_filter.py
@@ -3,5 +3,4 @@ def __init__(self, config):
self.config = config

def __call__(self, gm, sample_inputs):
print(f"GraphFilter\n{gm.code}")
return True
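A custom filter module must expose a class named GraphFilter, since make_filter looks it up by that name and constructs it with filter_config. A minimal sketch that keeps only small subgraphs; the "max_nodes" key is an assumed config value for illustration, not one this PR defines:

class GraphFilter:
    def __init__(self, config):
        self.config = config
        # "max_nodes" is an assumed config key, not defined by this PR.
        self.max_nodes = config.get("max_nodes", 64)

    def __call__(self, gm, sample_inputs):
        # Extract only subgraphs with at most max_nodes FX nodes.
        return len(list(gm.graph.nodes)) <= self.max_nodes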
79 changes: 79 additions & 0 deletions graph_net/torch/post_extract_process_count_kernels.py
@@ -0,0 +1,79 @@
from graph_net.torch import utils
import importlib.util
import shutil
import torch
from typing import Type
from torch.profiler import profile, record_function, ProfilerActivity


class PostExtractProcess:
def __init__(self, config):
self.config = config

def __call__(self, model_path=None):
if model_path is None:
return False
# Load the extracted GraphModule class from the generated model.py.
model_class = load_class_from_file(
f"{model_path}/model.py", class_name="GraphModule"
)
assert model_class is not None
model = model_class()
Collaborator @Xreki (Nov 25, 2025):

Lines 26-51 are all unnecessary. You already know model_path, and once the model object is defined at line 58 you can call compile_and_count_kernels(model, state_dict) directly. Before applying torch.compile, you can also run model(**state_dict) once to check whether this sample works and executes successfully.
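A minimal sketch of the compile_and_count_kernels helper the reviewer is suggesting, built from the count_kernels function defined later in this file (the helper itself is not part of this PR):

def compile_and_count_kernels(model, state_dict) -> int:
    # Run eagerly first, as suggested, to verify the sample executes at all.
    model(**state_dict)
    compiled_model = torch.compile(model)
    return count_kernels(compiled_model, state_dict)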

print(f"{model_path=}")

inputs_params = utils.load_converted_from_text(model_path)
params = inputs_params["weight_info"]
state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}

model(**state_dict)
compiled_model = torch.compile(model)
compiled_num_of_kernels = count_kernels(compiled_model, state_dict)
if compiled_num_of_kernels == 1:
print(model_path, "is fully fusible: compiles to a single kernel")
return True
else:
print(model_path, "is not fully fusible")
shutil.rmtree(model_path)
return False


def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
spec = importlib.util.spec_from_file_location("unnamed", file_path)
unnamed = importlib.util.module_from_spec(spec)
spec.loader.exec_module(unnamed)
model_class = getattr(unnamed, class_name, None)
return model_class


def count_kernels(model, sample_inputs) -> int:
"""
Count the number of CUDA kernel launches performed during a model's forward pass.
Args:
model: the graph module to profile.
sample_inputs: dict of tensors passed to the model as keyword arguments.
Returns:
int: The number of CUDA kernel launches observed.
Behavior:
- Runs the model once inside a PyTorch profiler context.
- Sums the counts of the 'cudaLaunchKernel' (runtime API) and
'cuLaunchKernel' (driver API) events, which correspond to kernel launches.
"""
model.eval()
# Profile one inference pass with the PyTorch profiler.
with profile(
activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
record_shapes=True,
) as prof:
with record_function("model_inference"):
output = model(**sample_inputs)
events = prof.key_averages()

total_count = 0
for e in events:
# Count both runtime-API (cudaLaunchKernel) and driver-API (cuLaunchKernel) launches.
if e.key in ("cuLaunchKernel", "cudaLaunchKernel"):
total_count += e.count
return total_count
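A usage sketch of count_kernels, assuming a CUDA device is available; the toy linear model and its inputs are illustrative only:

if __name__ == "__main__":
    # Count kernel launches for a single linear layer on GPU.
    toy_model = torch.nn.Linear(8, 8).cuda()
    toy_inputs = {"input": torch.randn(4, 8, device="cuda")}
    print("kernel launches:", count_kernels(toy_model, toy_inputs))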