diff --git a/tests/python/multidevice/test_multidevice.py b/tests/python/multidevice/test_multidevice.py index 99858ca2e41..6cabb89a14d 100644 --- a/tests/python/multidevice/test_multidevice.py +++ b/tests/python/multidevice/test_multidevice.py @@ -29,7 +29,6 @@ def test_sizes_and_ranks(multidevice_test): @pytest.mark.mpi def test_pointwise(multidevice_test): num_devices = multidevice_test.size - mesh = nvfuser.multidevice.DeviceMesh(torch.arange(num_devices)) with FusionDefinition() as fd: inp_tv = fd.define_tensor((-1, -1), contiguity=False, dtype=DataType.Float) @@ -37,6 +36,7 @@ def test_pointwise(multidevice_test): tv2 = fd.ops.add(tv1, tv1) fd.add_output(tv2) + mesh = nvfuser.multidevice.DeviceMesh(torch.arange(num_devices)) for tv in [inp_tv, tv1, tv2]: tv.set_device_mesh(mesh) @@ -50,6 +50,53 @@ def test_pointwise(multidevice_test): torch.testing.assert_close(out.cpu(), out_ref) +@pytest.mark.mpi +def test_transpose(multidevice_test): + d = multidevice_test.size + cp_size = 2 + if d != cp_size * cp_size: + pytest.skip(f"{d=} must equal {cp_size=}^2.") + + c = 128 + with FusionDefinition() as fd: + inp_tv = fd.define_tensor( + (-1, c, -1, -1, cp_size), contiguity=True, dtype=DataType.BFloat16 + ) + out_tv = fd.ops.set(inp_tv) + fd.add_output(out_tv) + + mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d).reshape(cp_size, cp_size)) + for tv in [inp_tv, out_tv]: + tv.set_device_mesh(mesh) + + inp_tv.axis(4).parallelize(nvfuser.ParallelType.mesh_y) + inp_tv.outer_split(3, cp_size) + inp_tv.axis(3).parallelize(nvfuser.ParallelType.mesh_x) + + out_tv.axis(4).parallelize(nvfuser.ParallelType.mesh_y) + out_tv.outer_split(3, cp_size) + out_tv.axis(3).parallelize(nvfuser.ParallelType.mesh_x) + out_tv.set_allocation_domain( + ( + out_tv.axis(2), + out_tv.axis(0), + out_tv.axis(1), + out_tv.axis(3), + out_tv.axis(4), + out_tv.axis(5), + ), + True, + ) + + b = 3 + s = cp_size * 5 + inp_ref = torch.randn(b, c, s, s, cp_size, dtype=torch.bfloat16) + out_ref = inp_ref + + inp = multidevice_test.shard_tensor(inp_ref, inp_tv) + fd.execute([inp]) + + class QkvFormat(Enum): BHSE = auto() BSHE = auto()