Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions iris/experimental/iris_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,29 @@ def _translate(self, ptr, from_rank, to_rank):

return translated_ptr

@gluon.jit
def as_remote(self, ptr, rank):
    """
    Re-target a pointer from this rank's address space to ``rank``'s.

    Thin helper around ``_translate``: the source rank is taken from
    ``self.cur_rank`` so callers only supply the destination. The
    result is intended to be handed straight to ``gl.load`` /
    ``gl.store``.

    Args:
        ptr: Pointer expressed in the current rank's address space
        rank: ID of the rank whose copy should be addressed

    Returns:
        The pointer produced by ``_translate`` for the target rank

    Example::

        remote_ptr = ctx.as_remote(buf + offsets, target_rank)
        data = gl.load(remote_ptr, mask=mask)
    """
    # The translation always starts from the rank this context runs on.
    source_rank = self.cur_rank
    return self._translate(ptr, source_rank, rank)

@gluon.jit
def load(self, pointer, from_rank, mask=None, other=None, cache_modifier=None, volatile=False):
"""
Expand Down Expand Up @@ -1358,6 +1381,54 @@ def is_symmetric(self, tensor: torch.Tensor) -> bool:
"""
return self.heap.is_symmetric(tensor)

def as_remote(self, tensor: torch.Tensor, rank: int) -> torch.Tensor:
    """
    Build a view of a symmetric tensor that addresses another rank's copy.

    The input must live on the symmetric heap. The result has the same
    shape, dtype, and strides as ``tensor``; only the underlying address
    differs: the tensor's byte offset from this rank's heap base is
    re-applied to the heap base of ``rank``. No data is copied. Handy for
    doing the pointer translation once, outside a loop, or for handing
    already-translated pointers to kernels.

    Args:
        tensor (torch.Tensor): Tensor located on the symmetric heap
        rank (int): Rank whose copy the returned view should address

    Returns:
        torch.Tensor: View backed by the target rank's symmetric heap

    Raises:
        ValueError: If ``tensor`` is not on the symmetric heap, or ``rank``
            lies outside ``[0, num_ranks)``

    Example:
        >>> import iris.experimental.iris_gluon as iris_gl
        >>> ctx = iris_gl.iris(heap_size=2**30)
        >>> buf = ctx.zeros(1024, dtype=torch.float32)
        >>> remote_buf = ctx.as_remote(buf, target_rank)
        >>> # remote_buf.data_ptr() now points to target_rank's copy
    """
    # Validate inputs up front so no address arithmetic runs on bad data.
    if not self.is_symmetric(tensor):
        raise ValueError("as_remote requires a tensor on the symmetric heap")
    if rank < 0 or rank >= self.num_ranks:
        raise ValueError(f"rank {rank} out of range [0, {self.num_ranks})")

    # Same byte offset, different heap base.
    bases = self.heap.heap_bases
    here = int(bases[self.cur_rank].item())
    there = int(bases[rank].item())
    remote_addr = there + (tensor.data_ptr() - here)

    dims = tensor.shape
    steps = tensor.stride()
    item_bytes = tensor.element_size()

    # Bytes spanned from the first element through the farthest one
    # reachable with these strides; an empty tensor spans nothing.
    span_bytes = 0
    if tensor.numel() != 0:
        farthest = 0
        for extent, step in zip(dims, steps):
            farthest += (extent - 1) * step
        span_bytes = (farthest + 1) * item_bytes

    from iris.tensor_utils import tensor_from_ptr

    backing = tensor_from_ptr(remote_addr, span_bytes, dtype=tensor.dtype, device=str(tensor.device))
    # Re-impose the original layout on the flat backing tensor.
    return torch.as_strided(backing, dims, steps)


def iris(heap_size=1 << 30):
"""
Expand Down
47 changes: 47 additions & 0 deletions iris/iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,53 @@ def is_symmetric(self, tensor: torch.Tensor) -> bool:
"""
return self.heap.is_symmetric(tensor)

def as_remote(self, tensor: torch.Tensor, rank: int) -> torch.Tensor:
    """
    Build a view of a symmetric tensor that addresses another rank's copy.

    The input must live on the symmetric heap. The result has the same
    shape, dtype, and strides as ``tensor``; only the underlying address
    differs: the tensor's byte offset from this rank's heap base is
    re-applied to the heap base of ``rank``. No data is copied. Handy for
    doing the pointer translation once, outside a loop, or for handing
    already-translated pointers to kernels.

    Args:
        tensor (torch.Tensor): Tensor located on the symmetric heap
        rank (int): Rank whose copy the returned view should address

    Returns:
        torch.Tensor: View backed by the target rank's symmetric heap

    Raises:
        ValueError: If ``tensor`` is not on the symmetric heap, or ``rank``
            lies outside ``[0, num_ranks)``

    Example:
        >>> ctx = iris.iris(heap_size=2**30)
        >>> buf = ctx.zeros(1024, dtype=torch.float32)
        >>> remote_buf = ctx.as_remote(buf, target_rank)
        >>> # remote_buf.data_ptr() now points to target_rank's copy
    """
    # Validate inputs up front so no address arithmetic runs on bad data.
    if not self.is_symmetric(tensor):
        raise ValueError("as_remote requires a tensor on the symmetric heap")
    if rank < 0 or rank >= self.num_ranks:
        raise ValueError(f"rank {rank} out of range [0, {self.num_ranks})")

    # Same byte offset, different heap base.
    bases = self.heap.heap_bases
    here = int(bases[self.cur_rank].item())
    there = int(bases[rank].item())
    remote_addr = there + (tensor.data_ptr() - here)

    dims = tensor.shape
    steps = tensor.stride()
    item_bytes = tensor.element_size()

    # Bytes spanned from the first element through the farthest one
    # reachable with these strides; an empty tensor spans nothing.
    span_bytes = 0
    if tensor.numel() != 0:
        farthest = 0
        for extent, step in zip(dims, steps):
            farthest += (extent - 1) * step
        span_bytes = (farthest + 1) * item_bytes

    from iris.tensor_utils import tensor_from_ptr

    backing = tensor_from_ptr(remote_addr, span_bytes, dtype=tensor.dtype, device=str(tensor.device))
    # Re-impose the original layout on the flat backing tensor.
    return torch.as_strided(backing, dims, steps)

def full(self, size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
"""
Creates a tensor of size size filled with fill_value. The tensor's dtype is inferred from fill_value.
Expand Down
Loading
Loading