67 changes: 66 additions & 1 deletion backends/cadence/aot/tests/test_quantizer_ops.py
@@ -56,7 +56,6 @@
CadenceW8A32MixedQuantizer, # TODO: T247438158 Add test coverage
CadenceRmsNormNopQuantizer, # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
CadenceWakeWordQuantizer, # TODO: T247438162 Add test coverage
CadenceWith16BitConvActivationsQuantizer, # TODO: T247438221 Add test coverage
CadenceWithLayerNormQuantizer, # TODO: T247438410 Add test coverage
CadenceWithSoftmaxQuantizer, # TODO: T247438418 Add test coverage
}
@@ -93,6 +92,24 @@
# For linear: [input_activation, weight]
[qconfig_A16.input_activation, qconfig_A16.weight],
),
(
"conv1d_A16",
lambda self: self._build_conv1d_graph(),
CadenceWith16BitConvActivationsQuantizer(),
torch.ops.aten.conv1d.default,
qconfig_A16.output_activation,
# For conv1d: [input_activation, weight]
[qconfig_A16.input_activation, qconfig_A16.weight],
),
(
"conv2d_A16",
lambda self: self._build_conv2d_graph(),
CadenceWith16BitConvActivationsQuantizer(),
torch.ops.aten.conv2d.default,
qconfig_A16.output_activation,
# For conv2d: [input_activation, weight]
[qconfig_A16.input_activation, qconfig_A16.weight],
),
]
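
The two new cases follow the tuple schema of the existing linear_A16 case above: (test name, graph-builder thunk, quantizer instance, target op, expected output-activation qconfig, expected per-input qconfigs). As a minimal standalone sketch of how `parameterized.expand` turns such tuples into individual tests — the real `test_quantizer_annotation` body sits below the visible hunk, so only `parameterized.expand` itself is taken as given here:

```python
import unittest
from parameterized import parameterized

class Demo(unittest.TestCase):
    # Element 0 of each tuple is folded into the generated test id
    # (e.g. test_square_0_doubles) and is also passed as the first argument.
    @parameterized.expand([
        ("doubles", 2, 4),
        ("triples", 3, 9),
    ])
    def test_square(self, name, base, expected):
        self.assertEqual(base * base, expected)
```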

# Derive the set of tested quantizer classes from the test cases.
Expand Down Expand Up @@ -149,6 +166,54 @@ def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
return gm, linear_nodes[0]

def _build_conv1d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with a conv1d operation (no bias)."""
builder = GraphBuilder()
# Input shape: (batch, in_channels, length)
x = builder.placeholder("x", torch.randn(1, 3, 10))
# Weight shape: (out_channels, in_channels, kernel_size)
weight = builder.placeholder("weight", torch.randn(6, 3, 3))
conv1d = builder.call_operator(
op=torch.ops.aten.conv1d.default,
args=(x, weight),
meta=NodeMetadata(
{"source_fn_stack": [("conv1d", torch.ops.aten.conv1d.default)]}
),
)
builder.output([conv1d])
gm = builder.get_graph_module()

conv1d_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.conv1d.default,
)
self.assertEqual(len(conv1d_nodes), 1, "Should find exactly one conv1d node")
return gm, conv1d_nodes[0]
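
A quick sanity check on the shapes chosen above (not part of the PR): with stride 1 and no padding, conv1d maps length L to L - kernel_size + 1, so the (1, 3, 10) input and (6, 3, 3) weight yield a (1, 6, 8) output.

```python
import torch

# Mirrors _build_conv1d_graph's tensors; 10 - 3 + 1 = 8 along the length dim.
out = torch.conv1d(torch.randn(1, 3, 10), torch.randn(6, 3, 3))
assert out.shape == (1, 6, 8)
```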

def _build_conv2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with a conv2d operation (no bias)."""
builder = GraphBuilder()
# Input shape: (batch, in_channels, height, width)
x = builder.placeholder("x", torch.randn(1, 3, 8, 8))
# Weight shape: (out_channels, in_channels, kernel_h, kernel_w)
weight = builder.placeholder("weight", torch.randn(6, 3, 3, 3))
conv2d = builder.call_operator(
op=torch.ops.aten.conv2d.default,
args=(x, weight),
meta=NodeMetadata(
{"source_fn_stack": [("conv2d", torch.ops.aten.conv2d.default)]}
),
)
builder.output([conv2d])
gm = builder.get_graph_module()

conv2d_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.conv2d.default,
)
self.assertEqual(len(conv2d_nodes), 1, "Should find exactly one conv2d node")
return gm, conv2d_nodes[0]
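
For context on what these builders feed into, here is a hedged sketch of the kind of assertion `test_quantizer_annotation` presumably makes. The builder and quantizer names come from this diff; the `annotate()` call and the `quantization_annotation` meta key are the stock `torch.ao.quantization.quantizer` convention, assumed rather than read from the hidden test body:

```python
# Sketch only; not the PR's actual test body.
gm, conv_node = self._build_conv2d_graph()
CadenceWith16BitConvActivationsQuantizer().annotate(gm)
self.assertIsNotNone(
    conv_node.meta.get("quantization_annotation"),
    "conv2d node should be annotated by the quantizer",
)
```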

@parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
def test_quantizer_annotation(
self,