Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions onnxruntime/contrib_ops/cpu/maxpool_with_mask.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,16 @@ class MaxpoolWithMask : public OpKernel, public PoolBase {
"Mask and input spatial dimensions mismatch at dimension ", i,
": mask=", m_shape[i], " input=", x_shape[i]);
}
// x_shape.NumDimensions() >= 3 is guaranteed above, so this subtraction cannot underflow.
const size_t input_spatial_rank = x_shape.NumDimensions() - 2;
// The pooling kernel rank drives the 1D/2D/3D dispatch below, which reads x_shape[2..4] and
// output_dims[2..4]. Require it to match the input spatial rank so those reads stay in bounds.
ORT_RETURN_IF_NOT(pool_attrs_.kernel_shape.size() == input_spatial_rank,
"Pooling kernel rank must equal input spatial rank. Got kernel rank: ",
pool_attrs_.kernel_shape.size(), " input spatial rank: ", input_spatial_rank);
// Only 1D/2D/3D pooling is implemented by the dispatch below; a larger rank would match no case.
ORT_RETURN_IF_NOT(input_spatial_rank >= 1 && input_spatial_rank <= 3,
"Only 1D, 2D, and 3D pooling are supported. Got input spatial rank: ", input_spatial_rank);

TensorShapeVector pads = pool_attrs_.pads;
TensorShapeVector kernel_shape = pool_attrs_.kernel_shape;
Expand Down
65 changes: 65 additions & 0 deletions onnxruntime/test/contrib_ops/maxpool_mask_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,5 +161,70 @@ TEST(ContribOpTest, MaxPoolWithMask_MaskEmptyBatchDim) {
"Mask N and C dimensions must be greater than 0");
}

// Verifies that Compute() rejects a pooling kernel whose rank differs from the
// number of spatial dimensions of the input tensor.
TEST(ContribOpTest, MaxPoolWithMask_KernelRankMismatch) {
  // X is rank 3, i.e. a single spatial dimension of extent 8.
  const std::vector<int64_t> input_shape = {1, 1, 8};
  const std::vector<float> input_data(8, 1.0f);

  // The mask mirrors X exactly, so the mask-vs-input checks that run earlier
  // in Compute() succeed and execution gets as far as the kernel-rank check.
  const std::vector<int64_t> mask_shape = {1, 1, 8};
  const std::vector<int32_t> mask_data(8, 1);

  // Dummy output: the run is expected to fail, so these are never compared.
  const std::vector<int64_t> output_shape = {1, 1, 1, 1};
  const std::vector<float> output_data = {1.0f};

  // Attributes describe 2D pooling, which disagrees with the 1D input above.
  const std::vector<int64_t> strides = {1, 1};
  const std::vector<int64_t> pads = {0, 0, 0, 0};
  const std::vector<int64_t> kernel_shape = {8, 8};

  OpTester tester("MaxpoolWithMask", 1, onnxruntime::kMSDomain);

  // Leave the input shapes out of the graph. Without them ONNX shape inference
  // has nothing to check (convPoolShapeInference bails out early when
  // hasInputShape is false), so Graph::Resolve() accepts the model and the
  // mismatch is only detected by the guard inside Compute().
  tester.AddShapeToTensorData(false);

  tester.AddAttribute("auto_pad", "");
  tester.AddAttribute("strides", strides);
  tester.AddAttribute("pads", pads);
  tester.AddAttribute("kernel_shape", kernel_shape);

  tester.AddInput<float>("X", input_shape, input_data);
  tester.AddInput<int32_t>("M", mask_shape, mask_data);
  tester.AddOutput<float>("Y", output_shape, output_data);

  tester.Run(BaseTester::ExpectResult::kExpectFailure,
             "Pooling kernel rank must equal input spatial rank");
}

// Verifies that Compute() rejects pooling over more than three spatial
// dimensions even when the kernel rank agrees with the input spatial rank.
TEST(ContribOpTest, MaxPoolWithMask_KernelRankTooLarge) {
  // X is rank 6, i.e. four spatial dimensions of extent 2 (16 elements).
  const std::vector<int64_t> input_shape = {1, 1, 2, 2, 2, 2};
  const std::vector<float> input_data(16, 1.0f);

  // The mask mirrors X, so the earlier mask checks pass; the 4D kernel below
  // also satisfies the kernel-rank equality check.
  const std::vector<int64_t> mask_shape = {1, 1, 2, 2, 2, 2};
  const std::vector<int32_t> mask_data(16, 1);

  // Dummy output: the run is expected to fail, so these are never compared.
  const std::vector<int64_t> output_shape = {1, 1, 1, 1};
  const std::vector<float> output_data = {1.0f};

  // A consistent 4D pooling configuration; only its rank is out of the
  // supported 1D/2D/3D range.
  const std::vector<int64_t> strides = {1, 1, 1, 1};
  const std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 0, 0};
  const std::vector<int64_t> kernel_shape = {2, 2, 2, 2};

  OpTester tester("MaxpoolWithMask", 1, onnxruntime::kMSDomain);

  // Omit input shapes from the graph so ONNX shape inference is bypassed and
  // the model gets as far as Compute(), where the rank limit is enforced.
  tester.AddShapeToTensorData(false);

  tester.AddAttribute("auto_pad", "");
  tester.AddAttribute("strides", strides);
  tester.AddAttribute("pads", pads);
  tester.AddAttribute("kernel_shape", kernel_shape);

  tester.AddInput<float>("X", input_shape, input_data);
  tester.AddInput<int32_t>("M", mask_shape, mask_data);
  tester.AddOutput<float>("Y", output_shape, output_data);

  tester.Run(BaseTester::ExpectResult::kExpectFailure,
             "Only 1D, 2D, and 3D pooling are supported");
}

} // namespace test
} // namespace onnxruntime
Loading