Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions upath/_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,24 @@ def _fsspec_protocol_equals(p0: str, p1: str) -> bool:
except KeyError:
raise ValueError(f"Protocol not known: {p1!r}")

if o0 == o1:
return True

if isinstance(o0, dict):
o0 = o0.get("class")
elif isinstance(o0, type):
if o0.__module__:
o0 = o0.__module__ + "." + o0.__name__
else:
o0 = o0.__name__
if isinstance(o1, dict):
o1 = o1.get("class")
elif isinstance(o1, type):
if o1.__module__:
o1 = o1.__module__ + "." + o1.__name__
else:
o1 = o1.__name__

return o0 == o1


Expand Down
64 changes: 45 additions & 19 deletions upath/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from upath._protocol import compatible_protocol
from upath._protocol import get_upath_protocol
from upath._stat import UPathStatResult
from upath.registry import _get_implementation_protocols
from upath.registry import available_implementations
from upath.registry import get_upath_class
from upath.types import UNSET_DEFAULT
from upath.types import JoinablePathLike
Expand Down Expand Up @@ -404,7 +406,7 @@ def _fs_factory(

_protocol_dispatch: bool | None = None

def __new__(
def __new__( # noqa C901
cls,
*args: JoinablePathLike,
protocol: str | None = None,
Expand Down Expand Up @@ -435,6 +437,27 @@ def __new__(
if "incompatible with" in str(e):
raise _IncompatibleProtocolError(str(e)) from e
raise

# subclasses should default to their own protocol
if protocol is None and cls is not UPath:
impl_protocols = _get_implementation_protocols(cls)
if not pth_protocol and impl_protocols:
pth_protocol = impl_protocols[0]
elif pth_protocol and pth_protocol not in impl_protocols:
msg_protocol = pth_protocol
if not pth_protocol:
msg_protocol = "'' (empty string)"
msg = (
f"{cls.__name__!s}(...) detected protocol {msg_protocol!s}"
f" which is incompatible with {cls.__name__}."
)
if not pth_protocol or pth_protocol not in available_implementations():
msg += (
" Did you forget to register the subclass for this protocol"
" with upath.registry.register_implementation()?"
)
raise _IncompatibleProtocolError(msg)

# determine which UPath subclass to dispatch to
upath_cls: type[UPath] | None
if cls._protocol_dispatch or cls._protocol_dispatch is None:
Expand Down Expand Up @@ -468,26 +491,24 @@ def __new__(
raise RuntimeError("UPath.__new__ expected cls to be subclass of UPath")

else:
msg_protocol = repr(pth_protocol)
msg_protocol = pth_protocol
if not pth_protocol:
msg_protocol += " (empty string)"
msg_protocol = "'' (empty string)"
msg = (
f"{cls.__name__!s}(...) detected protocol {msg_protocol!s} and"
f" returns a {upath_cls.__name__} instance that isn't a direct"
f" subclass of {cls.__name__}. This will raise an exception in"
" future universal_pathlib versions. To prevent the issue, use"
" UPath(...) to create instances of unrelated protocols or you"
f" can instead derive your subclass {cls.__name__!s}(...) from"
f" {upath_cls.__name__} or alternatively override behavior via"
f" registering the {cls.__name__} implementation with protocol"
f" {msg_protocol!s} replacing the default implementation."
)
warnings.warn(
msg,
DeprecationWarning,
stacklevel=2,
f"{cls.__name__!s}(...) detected protocol {msg_protocol!s}"
f" which is incompatible with {cls.__name__}."
)
upath_cls = cls
if (
# find a better way
(not pth_protocol and cls.__name__ not in ["CloudPath", "LocalPath"])
or pth_protocol
and pth_protocol not in available_implementations()
):
msg += (
" Did you forget to register the subclass for this protocol"
" with upath.registry.register_implementation()?"
)
raise _IncompatibleProtocolError(msg)

return object.__new__(upath_cls)

Expand Down Expand Up @@ -520,7 +541,6 @@ def __init__(
Additional storage options for the path.

"""

# todo: avoid duplicating this call from __new__
protocol = get_upath_protocol(
args[0] if args else "",
Expand All @@ -539,6 +559,12 @@ def __init__(
if not compatible_protocol(protocol, *args):
raise ValueError("can't combine incompatible UPath protocols")

# subclasses should default to their own protocol
if not protocol:
impl_protocols = _get_implementation_protocols(type(self))
if impl_protocols:
protocol = impl_protocols[0]

if args:
args0 = args[0]
if isinstance(args0, UPath):
Expand Down
27 changes: 27 additions & 0 deletions upath/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def __setitem__(self, item: str, value: type[upath.UPath] | str) -> None:
)
if not item or item in self._m:
get_upath_class.cache_clear() # type: ignore[attr-defined]
_get_implementation_protocols.cache_clear() # type: ignore[attr-defined]
self._m[item] = value

def __delitem__(self, __v: str) -> None:
Expand Down Expand Up @@ -211,6 +212,32 @@ def register_implementation(
_registry[protocol] = cls


@lru_cache  # type: ignore[misc]
def _get_implementation_protocols(cls: type[upath.UPath]) -> list[str]:
    """Return the protocols registered for a UPath class without triggering imports.

    Protocols are collected, in this order and without duplicates, from:

    1. the first mapping layer of the registry, matched by class identity,
    2. ``known_implementations``, matched by fully qualified class name,
    3. the registry's ``_entries``, matched by module and attribute name.

    Classes defined in ``upath.implementations._experimental`` have no
    registry entry; their protocol is derived from the class name instead.

    Raises:
        ValueError: if ``cls`` is not a subclass of ``upath.UPath``.
    """
    if not issubclass(cls, upath.UPath):
        raise ValueError(f"{cls!r} is not a UPath subclass")

    module_name = cls.__module__
    class_name = cls.__name__

    if module_name == "upath.implementations._experimental":
        # experimental fallback implementations have no registry entry;
        # drop the first character and the last four of the class name
        # (presumably a leading "_" and a trailing "Path" -- TODO confirm)
        return [class_name[1:-4].lower()]

    qualified_name = f"{module_name}.{class_name}"

    # A dict is used as an insertion-ordered set: the first occurrence of a
    # protocol fixes its position, later duplicates are ignored.
    ordered: dict[str, None] = {}

    # already materialised implementations, compared by identity
    for proto, impl in _registry._m.maps[0].items():  # type: ignore[attr-defined]
        if impl is cls:
            ordered.setdefault(proto, None)

    # implementations known by their dotted path
    for proto, fqn in _registry.known_implementations.items():
        if fqn == qualified_name:
            ordered.setdefault(proto, None)

    # entries exposing ``module`` / ``attr`` for this class
    for proto, entry in _registry._entries.items():
        if entry.module == module_name and entry.attr == class_name:
            ordered.setdefault(proto, None)

    return list(ordered)


# --- get_upath_class type overloads ------------------------------------------

if TYPE_CHECKING: # noqa: C901
Expand Down
129 changes: 110 additions & 19 deletions upath/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from upath import UPath
from upath.implementations.cloud import GCSPath
from upath.implementations.cloud import S3Path
from upath.registry import get_upath_class
from upath.registry import register_implementation
from upath.types import ReadablePath
from upath.types import WritablePath

Expand Down Expand Up @@ -112,12 +114,35 @@ def test_subclass(local_testdir):
class MyPath(UPath):
pass

with pytest.warns(
DeprecationWarning, match=r"MyPath\(...\) detected protocol '' .*"
):
path = MyPath(local_testdir)
assert str(path) == pathlib.Path(local_testdir).as_posix()
with pytest.raises(ValueError, match=r".*incompatible with"):
MyPath(local_testdir)


@pytest.fixture(scope="function")
def upath_registry_snapshot():
    """Save and restore the upath registry state around a test.

    A shallow copy of the first mapping layer of the registry is taken before
    the test runs and written back afterwards, even if the test fails.

    Both registry-derived caches are cleared on teardown: the class lookup
    cache of ``get_upath_class`` and the per-class protocol cache of
    ``_get_implementation_protocols``. Without clearing the latter, protocol
    lists computed while a test had extra registrations would leak into
    subsequent tests.
    """
    from upath.registry import _get_implementation_protocols
    from upath.registry import _registry

    # Save the current state of the registry's mutable mapping
    saved_m = _registry._m.maps[0].copy()
    try:
        yield
    finally:
        # Restore the registry state in place so existing references stay valid
        _registry._m.maps[0].clear()
        _registry._m.maps[0].update(saved_m)
        # Drop everything that was cached from the temporary registry state
        get_upath_class.cache_clear()
        _get_implementation_protocols.cache_clear()


def test_subclass_registered(upath_registry_snapshot):
class MyPath(UPath):
pass

register_implementation("memory", MyPath, clobber=True)
path = MyPath("memory:///test_path")
assert str(path) == "memory:///test_path"
assert issubclass(MyPath, UPath)
assert isinstance(path, MyPath)
assert isinstance(path, pathlib_abc.ReadablePath)
assert isinstance(path, pathlib_abc.WritablePath)
assert not isinstance(path, pathlib.Path)
Expand Down Expand Up @@ -453,33 +478,99 @@ def test_open_a_local_upath(tmp_path, protocol):
@pytest.mark.parametrize(
"uri,protocol",
[
# s3 compatible protocols
("s3://bucket/folder", "s3"),
("gs://bucket/folder", "gs"),
("s3a://bucket/folder", "s3a"),
("bucket/folder", "s3"),
# gcs compatible
("gs://bucket/folder", "gs"),
("gcs://bucket/folder", "gcs"),
("bucket/folder", "gs"),
# azure compatible
("az://container/blob", "az"),
("abfs://container/blob", "abfs"),
("abfss://container/blob", "abfss"),
("adl://container/blob", "adl"),
# memory
("memory://folder", "memory"),
("/folder", "memory"),
# file/local
("file:/tmp/folder", "file"),
("/tmp/folder", "file"),
("file:/tmp/folder", "local"),
("/tmp/folder", "local"),
("/tmp/folder", ""),
("a/b/c", ""),
# http/https
("http://example.com/path", "http"),
("https://example.com/path", "https"),
# ftp
("ftp://example.com/path", "ftp"),
# sftp/ssh
("sftp://example.com/path", "sftp"),
("ssh://example.com/path", "ssh"),
# smb
("smb://server/share/path", "smb"),
# hdfs
("hdfs://namenode/path", "hdfs"),
# webdav - requires base_url, skip for now
# github
("github://owner:repo@branch/path", "github"),
# data
("data:text/plain;base64,SGVsbG8=", "data"),
# huggingface
("hf://datasets/user/repo/path", "hf"),
],
)
def test_constructor_compatible_protocol_uri(uri, protocol):
p = UPath(uri, protocol=protocol)
assert p.protocol == protocol


@pytest.mark.parametrize(
"uri,protocol",
[
("s3://bucket/folder", "gs"),
("gs://bucket/folder", "s3"),
("memory://folder", "s3"),
("file:/tmp/folder", "s3"),
("s3://bucket/folder", ""),
("memory://folder", ""),
("file:/tmp/folder", ""),
],
)
# One sample URI per protocol; insertion order determines the case order below
_PROTOCOL_URIS = {
    "s3": "s3://bucket/folder",
    "gs": "gs://bucket/folder",
    "az": "az://container/blob",
    "memory": "memory://folder",
    "file": "file:/tmp/folder",
    "http": "http://example.com/path",
    "ftp": "ftp://example.com/path",
    "sftp": "sftp://example.com/path",
    "smb": "smb://server/share/path",
    "hdfs": "hdfs://namenode/path",
}

# (uri, protocol) pairs where the URI belongs to a different protocol than the
# one requested, followed by every protocol-prefixed URI paired with an
# explicit empty protocol.
_INCOMPATIBLE_CASES = [
    (uri, target_protocol)
    for target_protocol in _PROTOCOL_URIS
    for uri_protocol, uri in _PROTOCOL_URIS.items()
    if uri_protocol != target_protocol
] + [(uri, "") for uri in _PROTOCOL_URIS.values()]


@pytest.mark.parametrize("uri,protocol", _INCOMPATIBLE_CASES)
def test_constructor_incompatible_protocol_uri(uri, protocol):
with pytest.raises(ValueError, match=r".*incompatible with"):
with pytest.raises(TypeError, match=r".*incompatible with"):
UPath(uri, protocol=protocol)


# Cases for instantiating a protocol-specific subclass with a foreign URI.
# Only protocols that have registered implementations retrievable via
# get_upath_class are used: each entry is (uri, protocol) where the URI's own
# protocol differs from the protocol whose class will be instantiated.
_SUBCLASS_INCOMPATIBLE_CASES = [
    (uri, target_protocol)
    for target_protocol in _PROTOCOL_URIS
    for uri_protocol, uri in _PROTOCOL_URIS.items()
    if uri_protocol != target_protocol
]


@pytest.mark.parametrize("uri,protocol", _SUBCLASS_INCOMPATIBLE_CASES)
def test_subclass_constructor_incompatible_protocol_uri(uri, protocol):
    """The class registered for ``protocol`` must reject a URI of another protocol."""
    path_cls = get_upath_class(protocol)
    rejects_mismatch = pytest.raises(TypeError, match=r".*incompatible with")
    with rejects_mismatch:
        path_cls(uri)
31 changes: 31 additions & 0 deletions upath/tests/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,34 @@ class MyPath(UPath):
a = MyPath(".", protocol="memory")

assert isinstance(a, MyPath)


# One sample URI per protocol for the compatibility tests; the empty-string
# key stands for a plain path without any protocol prefix.
_PROTOCOL_URIS = {
    "s3": "s3://bucket/folder",
    "gs": "gs://bucket/folder",
    "memory": "memory://folder",
    "file": "file:/tmp/folder",
    "http": "http://example.com/path",
    "": "/tmp/folder",
}

# (uri, protocol) pairs whose URI carries a protocol prefix that differs from
# the requested protocol; the prefix-less sample URI is never used as input.
_PROXY_INCOMPATIBLE_CASES = [
    (uri, target_protocol)
    for target_protocol in _PROTOCOL_URIS
    for uri_protocol, uri in _PROTOCOL_URIS.items()
    if uri_protocol and uri_protocol != target_protocol
]


@pytest.mark.parametrize("uri,protocol", _PROXY_INCOMPATIBLE_CASES)
def test_proxy_subclass_incompatible_protocol_uri(uri, protocol):
    """A ProxyUPath subclass must raise TypeError on a URI/protocol mismatch."""

    class MyProxyPath(ProxyUPath):
        pass

    # ProxyUPath wraps the underlying path, so the incompatibility error is
    # expected to surface through the proxy constructor as well.
    rejects_mismatch = pytest.raises(TypeError, match=r".*incompatible with")
    with rejects_mismatch:
        MyProxyPath(uri, protocol=protocol)