Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion tests/python/multidevice/test_multidevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ def test_sizes_and_ranks(multidevice_test):
@pytest.mark.mpi
def test_pointwise(multidevice_test):
num_devices = multidevice_test.size
mesh = nvfuser.multidevice.DeviceMesh(torch.arange(num_devices))

with FusionDefinition() as fd:
inp_tv = fd.define_tensor((-1, -1), contiguity=False, dtype=DataType.Float)
tv1 = fd.ops.relu(inp_tv)
tv2 = fd.ops.add(tv1, tv1)
fd.add_output(tv2)

mesh = nvfuser.multidevice.DeviceMesh(torch.arange(num_devices))
for tv in [inp_tv, tv1, tv2]:
tv.set_device_mesh(mesh)

Expand All @@ -50,6 +50,53 @@ def test_pointwise(multidevice_test):
torch.testing.assert_close(out.cpu(), out_ref)


@pytest.mark.mpi
def test_transpose(multidevice_test):
d = multidevice_test.size
cp_size = 2
if d != cp_size * cp_size:
pytest.skip(f"{d=} must equal {cp_size=}^2.")

c = 128
with FusionDefinition() as fd:
inp_tv = fd.define_tensor(
(-1, c, -1, -1, cp_size), contiguity=True, dtype=DataType.BFloat16
)
out_tv = fd.ops.set(inp_tv)
fd.add_output(out_tv)

mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d).reshape(cp_size, cp_size))
for tv in [inp_tv, out_tv]:
tv.set_device_mesh(mesh)

inp_tv.axis(4).parallelize(nvfuser.ParallelType.mesh_y)
inp_tv.outer_split(3, cp_size)
inp_tv.axis(3).parallelize(nvfuser.ParallelType.mesh_x)

out_tv.axis(4).parallelize(nvfuser.ParallelType.mesh_y)
out_tv.outer_split(3, cp_size)
out_tv.axis(3).parallelize(nvfuser.ParallelType.mesh_x)
out_tv.set_allocation_domain(
(
out_tv.axis(2),
out_tv.axis(0),
out_tv.axis(1),
out_tv.axis(3),
out_tv.axis(4),
out_tv.axis(5),
),
True,
)

b = 3
s = cp_size * 5
inp_ref = torch.randn(b, c, s, s, cp_size, dtype=torch.bfloat16)
out_ref = inp_ref

inp = multidevice_test.shard_tensor(inp_ref, inp_tv)
fd.execute([inp])


class QkvFormat(Enum):
BHSE = auto()
BSHE = auto()
Expand Down