diff --git a/warp/__init__.pyi b/warp/__init__.pyi index 8a85b58c73..c021b75b99 100644 --- a/warp/__init__.pyi +++ b/warp/__init__.pyi @@ -1724,13 +1724,13 @@ class transformd: @over def zeros( shape: int | tuple[int, ...] | list[int] | None = None, - dtype: type = float, + dtype: type[DType] | None = None, device: DeviceLike = None, requires_grad: _builtins.bool = False, pinned: _builtins.bool = False, retain_grad: _builtins.bool = False, **kwargs, -) -> array: +) -> array[DType]: """Return a zero-initialized array.""" ... diff --git a/warp/_src/context.py b/warp/_src/context.py index 55acf88079..97f398d792 100644 --- a/warp/_src/context.py +++ b/warp/_src/context.py @@ -66,7 +66,7 @@ def ParamSpec(name): from warp._src.codegen import WarpCodegenTypeError, synchronized from warp._src.logger import LOG_DEBUG, LOG_WARNING, get_logger, log_debug, log_error, log_info, log_warning from warp._src.texture import Texture1D, Texture2D, Texture3D, texture1d_t, texture2d_t, texture3d_t -from warp._src.types import LAUNCH_MAX_DIMS, Array, LaunchBounds, launch_bounds_t, type_repr +from warp._src.types import LAUNCH_MAX_DIMS, Array, DType, LaunchBounds, launch_bounds_t, type_repr _wp_module_name_ = "warp.context" @@ -7647,13 +7647,13 @@ def unmap(self): def zeros( shape: int | tuple[int, ...] | list[int] | None = None, - dtype: type = float, + dtype: type[DType] | None = None, device: DeviceLike = None, requires_grad: bool = False, pinned: bool = False, retain_grad: bool = False, **kwargs, -) -> warp.array: +) -> warp.array[DType]: """Return a zero-initialized array. Args: @@ -7667,6 +7667,8 @@ def zeros( Returns: A warp.array object representing the allocation """ + if dtype is None: + dtype = float arr = empty( shape=shape, @@ -7712,13 +7714,13 @@ def zeros_like( def ones( shape: int | tuple[int, ...] | list[int] | None = None, - dtype: type = float, + dtype: type[DType] | None = None, device: DeviceLike = None, requires_grad: bool = False, pinned: bool = False, retain_grad: bool = False, **kwargs, -) -> warp.array: +) -> warp.array[DType]: """Return a one-initialized array. Args: @@ -7732,6 +7734,8 @@ def ones( Returns: A warp.array object representing the allocation """ + if dtype is None: + dtype = float return full( shape=shape, @@ -7771,13 +7775,13 @@ def ones_like( def full( shape: int | tuple[int, ...] | list[int] | None = None, value: Any = 0, - dtype: type | None = None, + dtype: type[DType] | None = None, device: DeviceLike = None, requires_grad: bool = False, pinned: bool = False, retain_grad: bool = False, **kwargs, -) -> warp.array: +) -> warp.array[DType]: """Return an array with all elements initialized to the given value. Args: @@ -7902,13 +7906,13 @@ def clone( def empty( shape: int | tuple[int, ...] | list[int] | None = None, - dtype=float, + dtype: type[DType] | None = None, device: DeviceLike = None, requires_grad: bool = False, pinned: bool = False, retain_grad: bool = False, **kwargs, -) -> warp.array: +) -> warp.array[DType]: """Return an uninitialized array. Args: @@ -7923,6 +7927,9 @@ def empty( A warp.array object representing the allocation """ + if dtype is None: + dtype = float + # backwards compatibility for case where users called wp.empty(n=length, ...) if "n" in kwargs: shape = (kwargs["n"],) @@ -7998,12 +8005,12 @@ def empty_like( def from_numpy( arr: np.ndarray, - dtype: type | None = None, + dtype: type[DType] | None = None, shape: Sequence[int] | None = None, device: DeviceLike | None = None, requires_grad: bool = False, retain_grad: bool = False, -) -> warp.array: +) -> warp.array[DType]: """Return a Warp array created from a NumPy array. Args: diff --git a/warp/_src/types.py b/warp/_src/types.py index 2d4db5e76a..833089c01d 100644 --- a/warp/_src/types.py +++ b/warp/_src/types.py @@ -3072,7 +3072,7 @@ def __new__(cls, *args, **kwargs): def __init__( self, data: list | tuple | npt.NDArray | None = None, - dtype: Any = Any, + dtype: type[DType] = Any, shape: int | tuple[int, ...] | list[int] | None = None, strides: tuple[int, ...] | None = None, ptr: int | None = None,