From bc6a091a9f971e59cfd2d3d29b4391c2a1a2fbe3 Mon Sep 17 00:00:00 2001 From: ishtihoss Date: Thu, 25 Jun 2026 20:23:52 -0700 Subject: [PATCH] Fix Upsample align_corners singleton output --- python/mlx/nn/layers/upsample.py | 2 +- python/tests/test_nn.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/python/mlx/nn/layers/upsample.py b/python/mlx/nn/layers/upsample.py index e6bd282af1..f4499e50d6 100644 --- a/python/mlx/nn/layers/upsample.py +++ b/python/mlx/nn/layers/upsample.py @@ -12,7 +12,7 @@ def _scaled_indices(N, scale, align_corners, dim, ndims): M = int(scale * N) if align_corners: - indices = mx.arange(M, dtype=mx.float32) * ((N - 1) / (M - 1)) + indices = mx.arange(M, dtype=mx.float32) * ((N - 1) / max(M - 1, 1)) else: step = 1 / scale start = ((M - 1) * step - N + 1) / 2 diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 98256df813..a50e3e1c81 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -1471,6 +1471,37 @@ def test_upsample(self): "Upsample(scale_factor=(2.0, 3.0), mode='nearest', align_corners=False)", ) + def test_upsample_align_corners_one_dim(self): + x = mx.arange(1, 5).reshape((1, 2, 2, 1)) + + up = nn.Upsample(scale_factor=0.5, mode="linear", align_corners=True) + out = up(x) + self.assertEqual(out.shape, (1, 1, 1, 1)) + self.assertTrue(np.allclose(out, x[:, :1, :1, :])) + + up = nn.Upsample(scale_factor=(0.5, 2), mode="linear", align_corners=True) + out = up(x) + expected = mx.array([[[[1.0], [4.0 / 3.0], [5.0 / 3.0], [2.0]]]]) + self.assertEqual(out.shape, (1, 1, 4, 1)) + self.assertTrue(np.allclose(out, expected)) + + up = nn.Upsample(scale_factor=(2, 0.5), mode="linear", align_corners=True) + out = up(x) + expected = mx.array([[[[1.0]], [[5.0 / 3.0]], [[7.0 / 3.0]], [[3.0]]]]) + self.assertEqual(out.shape, (1, 4, 1, 1)) + self.assertTrue(np.allclose(out, expected)) + + up = nn.Upsample(scale_factor=0.5, mode="cubic", align_corners=True) + out = up(x) + self.assertEqual(out.shape, (1, 1, 1, 1)) + self.assertTrue(np.allclose(out, x[:, :1, :1, :])) + + x_1d = mx.arange(0, 4).reshape((1, 4, 1)).astype(mx.float32) + up = nn.Upsample(scale_factor=0.25, mode="linear", align_corners=True) + out = up(x_1d) + self.assertEqual(out.shape, (1, 1, 1)) + self.assertTrue(np.allclose(out, x_1d[:, :1, :])) + def test_pooling(self): # Test 1d pooling x = mx.array(