
Commit 6538119

Merge branch 'PaddlePaddle:develop' into develop
2 parents 4847ee3 + 4fddaee commit 6538119

10 files changed, +837 −50 lines

.pre-commit-config.yaml

Lines changed: 1 addition & 7 deletions
@@ -11,17 +11,11 @@ repos:
       - id: ruff-check
         args: [--fix, --exit-non-zero-on-fix, --no-cache]
 
-  - repo: https://github.com/PFCCLab/typos-pre-commit-mirror.git
-    rev: v1.39.2
-    hooks:
-      - id: typos
-        args: [--force-exclude]
-
   - repo: https://github.com/Lucas-C/pre-commit-hooks.git
     rev: v1.5.1
     hooks:
       - id: remove-crlf
       - id: remove-tabs
         name: Tabs remover (Python)
         files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
-          args: [--whitespaces-count, '4']
+        args: [--whitespaces-count, '4']

graph_net/config/empty_cstr_torch_samples_list.txt

Lines changed: 503 additions & 0 deletions
Large diffs are not rendered by default.

graph_net/constraint_util.py

Lines changed: 21 additions & 17 deletions
@@ -25,7 +25,7 @@ def __init__(self, config=None):
         self.model_runnable_predicator = self._make_model_runnable_predicator(
             self.config
         )
-        self.num_successful_handled_models = 0
+        self.num_handled_models = 0
 
     def _make_data_input_predicator(self, config):
         module = load_module(config["data_input_predicator_filepath"])
@@ -51,7 +51,7 @@ def _make_config(
         model_path_prefix="",
         resume=False,
         last_model_log_file=None,
-        limits_successfully_handled_models=None,
+        limits_handled_models=None,
     ):
         if data_input_predicator_config is None:
             data_input_predicator_config = {}
@@ -72,7 +72,7 @@ def _make_config(
             "dimension_generalizer_class_name": dimension_generalizer_class_name,
             "dimension_generalizer_config": dimension_generalizer_config,
             "last_model_log_file": last_model_log_file,
-            "limits_successfully_handled_models": limits_successfully_handled_models,
+            "limits_handled_models": limits_handled_models,
         }
 
     def __call__(self, model_path):
@@ -125,16 +125,15 @@ def is_dyn_dim_cstr_feasible(dyn_dim_cstr):
         )
         self._save_dyn_dim_cstr(dyn_dim_cstr, model_path)
         self._save_dim_gen_pass_names(dim_gen_pass_names, model_path)
-        if len(dyn_dim_cstr.symbols) > 0:
-            self.num_successful_handled_models += 1
-            limits = self.config["limits_successfully_handled_models"]
-            if limits is not None:
-                if self.num_successful_handled_models > limits:
-                    print(
-                        "`num_successful_handled_models` exceeds config `limits_successfully_handled_models`",
-                        file=sys.stderr,
-                    )
-                    sys.exit(0)
+        self.num_handled_models += 1
+        limits = self.config["limits_handled_models"]
+        if limits is not None:
+            if self.num_handled_models >= limits:
+                print(
+                    "`num_handled_models` exceeds config `limits_handled_models`",
+                    file=sys.stderr,
+                )
+                sys.exit(0)
 
     def get_dimension_generalizer(self):
         if hasattr(self, "_dim_generalizer"):
@@ -159,6 +158,7 @@ def get_model(self, model_path):
     def _try_dimension_generalization(self, dim_axes_pairs, model_path, inputs):
         logging.warning("enter _try_dimension_generalization")
         if self.config["dimension_generalizer_filepath"] is None:
+            self._save_model_to_log_file(model_path)
             yield model_path, ()
             return
         model = self.get_model(model_path)
@@ -168,6 +168,7 @@ def _try_dimension_generalization(self, dim_axes_pairs, model_path, inputs):
         need_rewrite = dim_gen_pass.need_rewrite(inputs)
         logging.warning("after need_rewrite")
         if not need_rewrite:
+            self._save_model_to_log_file(model_path)
             yield model_path, ()
             return
 
@@ -177,11 +178,14 @@ def _try_dimension_generalization(self, dim_axes_pairs, model_path, inputs):
         with tempfile.TemporaryDirectory() as tmp_dir:
             shutil.copytree(Path(model_path), Path(tmp_dir), dirs_exist_ok=True)
             dim_gen_pass.save_graph_module(graph_module, tmp_dir)
-            if self.config["last_model_log_file"] is not None:
-                log_file = Path(self.config["last_model_log_file"])
-                shutil.copy(Path(tmp_dir) / "model.py", log_file)
+            self._save_model_to_log_file(tmp_dir)
             yield tmp_dir, dim_gen_pass.get_pass_names()
 
+    def _save_model_to_log_file(self, model_path):
+        if self.config["last_model_log_file"] is not None:
+            log_file = Path(self.config["last_model_log_file"])
+            shutil.copy(Path(model_path) / "model.py", log_file)
+
     def _save_dim_gen_pass_names(self, dim_gen_pass_names, model_path):
         from graph_net.graph_net_json_file_util import kDimensionGeneralizationPasses
 
@@ -324,7 +328,7 @@ def append_dim_gen_pass_names(dim_gen_pass_names):
         )
 
     for i, picked_dim in enumerate(unique_dims):
-        logging.warning(f"{i=} {picked_dim=}")
+        logging.warning(f"{i=} {picked_dim=} {dim2axes[picked_dim]=}")
        cur_dyn_dim_cstr = copy.deepcopy(dyn_dim_cstr)

        def filter_fn(input_name, input_idx, axis, dim):
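The net effect of the `limits_handled_models` change is that the handler now counts every processed model (not only those that produced dynamic-dimension symbols) and exits once the count reaches the configured limit (`>=` instead of `>`). A minimal sketch of that counting behaviour, using hypothetical standalone names rather than the real handler class:

    import sys

    class HandlerSketch:
        """Illustration only; mirrors the counter/limit logic added in constraint_util.py."""

        def __init__(self, limits_handled_models=None):
            self.num_handled_models = 0
            self.limits_handled_models = limits_handled_models

        def handle(self, model_path):
            # ... per-model work would happen here ...
            self.num_handled_models += 1
            if self.limits_handled_models is not None:
                if self.num_handled_models >= self.limits_handled_models:
                    print(
                        "`num_handled_models` exceeds config `limits_handled_models`",
                        file=sys.stderr,
                    )
                    sys.exit(0)

    # With limits_handled_models=1, the process exits after the first handled model.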

graph_net/tools/batch_init_input_tensor_constraints.sh

Lines changed: 17 additions & 1 deletion
@@ -19,11 +19,27 @@ config_json_str=$(cat <<EOF
     "model_runnable_predicator_class_name": "$model_runnable_predicator",
     "dimension_generalizer_filepath": "$GRAPH_NET_ROOT/torch/static_to_dynamic.py",
     "dimension_generalizer_class_name": "StaticToDynamic",
+    "dimension_generalizer_config": {
+      "pass_names": [
+        "batch_call_method_view_pass",
+        "tuple_arg_call_method_view_pass",
+        "naive_call_method_reshape_pass",
+        "naive_call_method_expand_pass",
+        "non_batch_call_method_expand_pass",
+        "non_batch_call_function_arange_pass",
+        "non_batch_call_function_getitem_slice_pass",
+        "non_batch_call_function_full_pass",
+        "non_batch_call_function_full_plus_one_pass",
+        "non_batch_call_function_zeros_pass",
+        "non_batch_call_function_arange_plus_one_pass"
+      ]
+    },
+    "limits_handled_models": 1,
     "last_model_log_file": "/tmp/a.py"
   }
 }
 EOF
 )
 CONFIG=$(echo $config_json_str | base64 -w 0)
 
-python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt --handler-config=$CONFIG
+python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/empty_cstr_torch_samples_list.txt --handler-config=$CONFIG
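The script serialises its JSON config with `base64 -w 0` before handing it to `graph_net.model_path_handler` via `--handler-config`. A minimal round-trip sketch of that encoding scheme in Python; the actual decoding code inside `model_path_handler` is not part of this diff, so this only illustrates the format:

    import base64
    import json

    config = {"limits_handled_models": 1, "last_model_log_file": "/tmp/a.py"}

    # Roughly equivalent to `echo $config_json_str | base64 -w 0`
    # (echo additionally appends a trailing newline before encoding).
    encoded = base64.b64encode(json.dumps(config).encode("utf-8")).decode("ascii")

    # Decoding restores the original dict.
    decoded = json.loads(base64.b64decode(encoded))
    assert decoded == config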

graph_net/torch/constraint_util.py

Lines changed: 3 additions & 2 deletions
@@ -8,15 +8,16 @@ def __init__(self, config):
         self.config = config
 
     def __call__(self, model_path, input_var_name: str) -> bool:
-        return not ("_self_" in input_var_name)
+        return not (
+            "_self_" in input_var_name or "_instance_modules_" in input_var_name
+        )
 
 
 class ModelRunnablePredicator:
     def __init__(self, config):
         if config is None:
             config = {}
 
-        graph_net_root = os.path.dirname(graph_net.__file__)
         decorator_config = {"use_dummy_inputs": True}
         self.predicator = RunModelPredicator(decorator_config)
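The predicate change above additionally excludes input variables whose names contain `_instance_modules_`. A minimal sketch of the filter's behaviour on hypothetical input names (the example names below are illustrative, not taken from the repository):

    def keep_input(input_var_name: str) -> bool:
        # Same expression as the updated __call__ above, isolated for illustration.
        return not (
            "_self_" in input_var_name or "_instance_modules_" in input_var_name
        )

    assert keep_input("l_x_") is True
    assert keep_input("l_self_linear_weight") is False       # already excluded before
    assert keep_input("l_instance_modules_0_bias") is False  # newly excluded by this commit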

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+import torch.fx as fx
+from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
+import os
+
+
+class ConcretePass(DimensionGeneralizationPass):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def get_pass_name(cls) -> bool:
+        return os.path.basename(__file__)[:-3]
+
+    def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
+        if 0 not in self.axes:
+            return False
+        return any(self._node_need_rewrite(node) for node in traced_module.graph.nodes)
+
+    def _node_need_rewrite(self, node) -> bool:
+        if not (node.op == "call_method"):
+            return False
+        if not (node.target == "view"):
+            return False
+        if not (len(node.args) >= 2):
+            return False
+        if not (isinstance(node.args[1], int)):
+            return False
+        if not (self.dim == node.args[1]):
+            return False
+        return True
+
+    def rewrite(self, traced_module: fx.GraphModule) -> fx.GraphModule:
+        """
+        Fx Pass: Replaces hardcoded constants in 'view' ops that match an input tensor dimension
+        with a dynamic 'size()' call. The primary goal is to dynamicize the batch size (axis 0).
+        """
+        # Create a new graph to hold the rewritten nodes
+        new_graph = fx.Graph()
+
+        # Create a map to link nodes from the old graph to nodes in the new graph
+        val_map = {}
+
+        def get_new_tuple_args(input_tensor_node, view_args):
+            # --- Dependency on ShapeProp Results ---
+            # input_shape is the static shape (e.g., batch_size, C, H, W)
+            input_meta = input_tensor_node.meta.get("tensor_meta")
+            if input_meta is None:
+                raise RuntimeError(
+                    f"Node {input_tensor_node.name} lacks tensor_meta. Did ShapeProp run?"
+                )
+
+            input_shape = input_meta.shape
+
+            # Find the new list of view arguments
+            new_view_args = []
+            for axis_idx, target_dim in enumerate(view_args):
+                if not isinstance(target_dim, int) or target_dim < 1:
+                    new_view_args.append(
+                        val_map[target_dim] if target_dim in val_map else target_dim
+                    )
+                    continue
+
+                if axis_idx == 0 and target_dim == input_shape[axis_idx]:
+                    new_input_node = val_map[input_tensor_node]
+                    size_node = new_graph.call_method(
+                        "size", args=(new_input_node, axis_idx)
+                    )
+                    best_match = size_node
+                else:
+                    best_match = target_dim
+                new_view_args.append(best_match)
+            return tuple(new_view_args)
+
+        for node in traced_module.graph.nodes:
+            if self._node_need_rewrite(node):
+                # Get the input tensor node
+                input_tensor_node = node.args[0]
+                # Get the target shape arguments for view (e.g., 1, -1, 6, 64)
+                view_args = node.args[1:]
+                print(f"{view_args=}")
+                new_view_args = get_new_tuple_args(input_tensor_node, view_args)
+
+                # --- Rebuild the view node ---
+                # 1. Map the input tensor node to the new graph node
+                new_input_node = val_map[input_tensor_node]
+
+                # 2. Insert the new view node into the new graph
+                # with new_graph.inserting_after(new_input_node):
+                new_node = new_graph.call_method(
+                    "view", args=(new_input_node, *new_view_args)
+                )
+
+                # 3. Map the old node to the new node
+                val_map[node] = new_node
+
+            else:
+                # Copy other nodes to the new graph
+                new_node = new_graph.node_copy(node, lambda x: val_map[x])
+                val_map[node] = new_node
+
+        # Replace the old graph with the new graph and return
+        traced_module.graph = new_graph
+        traced_module.recompile()
+        return traced_module
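The rewrite above depends on ShapeProp metadata and replaces a hard-coded leading `view` dimension with a `size(0)` call. A minimal sketch of the kind of graph the pass targets, traced and shape-propagated with standard torch.fx utilities; the pass itself is not invoked here because its constructor arguments (such as `axes` and `dim`) are set by machinery outside this diff:

    import torch
    import torch.fx as fx
    from torch.fx.passes.shape_prop import ShapeProp


    class TinyModel(torch.nn.Module):
        def forward(self, x):
            # Hard-coded batch size 32; the pass would replace it with x.size(0).
            return x.view(32, -1)


    traced = fx.symbolic_trace(TinyModel())
    # ShapeProp fills node.meta["tensor_meta"], which rewrite() reads via get_new_tuple_args.
    ShapeProp(traced).propagate(torch.randn(32, 3, 8, 8))

    print(traced.code)
    # Before the pass:          view = x.view(32, -1)
    # After the pass (roughly): size = x.size(0); view = x.view(size, -1)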

graph_net/torch/dim_gen_passes/naive_call_method_view_pass.py

Lines changed: 11 additions & 5 deletions
@@ -1,4 +1,3 @@
-import torch
 import torch.fx as fx
 from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
 import os
@@ -14,10 +13,17 @@ def get_pass_name(cls) -> bool:
     def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
         if 0 not in self.axes:
             return False
-        for node in traced_module.graph.nodes:
-            if node.op == "call_method" and node.target == "view":
-                return True
-        return False
+        return any(self._node_need_rewrite(node) for node in traced_module.graph.nodes)
+
+    def _node_need_rewrite(self, node) -> bool:
+        if not (node.op == "call_method"):
+            return False
+        if not (node.target == "view"):
+            return False
+        print(f"{self.dim=} {node.args[1:]=}")
+        if self.dim not in node.args[1:]:
+            return False
+        return True
 
     def rewrite(self, traced_module: fx.GraphModule) -> fx.GraphModule:
         """
