34 changes: 21 additions & 13 deletions dreadnode/dataset.py
@@ -12,7 +12,8 @@
from dreadnode.logging_ import print_info, print_success, print_warning
from dreadnode.storage.datasets.manager import DatasetManager
from dreadnode.storage.datasets.manifest import DatasetManifest, create_manifest
-from dreadnode.storage.datasets.metadata import DatasetMetadata, VersionInfo
+from dreadnode.storage.datasets.metadata import DatasetMetadata
+from dreadnode.util import valid_version


class Dataset:
@@ -244,7 +245,7 @@ def _ensure_version_bump(
return True

# 5. Data Changed: Bump Version
-dataset.metadata.update_version(latest_version or dataset.metadata.version)
+dataset.metadata.version = dataset.metadata.bump_patch_version(latest_version)
print_warning(f"[+] Changes detected. Auto-bumping version to {dataset.metadata.version}")
return True

@@ -317,7 +318,8 @@ def save_dataset_to_disk(
latest_version = dataset_manager.get_version_from_path(latest_path)
if not latest_version:
raise ValueError(f"Could not determine latest version from path: {latest_path}")
-if dataset_manager.compare_versions(dataset.metadata.version, latest_version) <= 0:
+
+if dataset.metadata.version <= latest_version:
if overwrite:
print_warning(
"[!] Overwrite enabled: Proceeding to overwrite existing dataset version."
@@ -333,9 +335,9 @@
_persist_dataset(
dataset=dataset,
path_str=dataset_manager.get_cache_save_uri(
-ref=dataset.metadata.ref,
-version=dataset.metadata.version,
-with_version=True,
+ref=dataset.metadata.versioned_ref,
+# version=dataset.metadata.version,
+# with_version=True,
),
create_dir=True,
dataset_manager=dataset_manager,
@@ -368,11 +370,17 @@ def push_dataset(
)
latest_local_path = dataset_manager.get_latest_cache_save_uri(dataset.metadata.ref)
if latest_remote_path and latest_local_path:
-compared = dataset_manager.compare_versions_from_paths(
-remote_path=latest_remote_path,
-local_path=latest_local_path,
-)
-latest_path = latest_remote_path if compared == 1 else latest_local_path
+remote_latest = dataset_manager.get_version_from_path(latest_remote_path)
+local_latest = dataset_manager.get_version_from_path(latest_local_path)
+if remote_latest and local_latest:
+if remote_latest > local_latest:
+latest_path = latest_remote_path
+else:
+latest_path = latest_local_path
+elif remote_latest:
+latest_path = latest_remote_path
+elif local_latest:
+latest_path = latest_local_path
else:
latest_path = latest_remote_path or latest_local_path

@@ -432,8 +440,8 @@ def load_dataset(
fsm: The DatasetManager instance.
kwargs: Additional arguments to pass to pyarrow.dataset.load_dataset.
"""
-if version is not None:
-VersionInfo.from_string(version)
+if version and not valid_version(version):
+raise ValueError(f"Invalid version string: {version}")

protocol = get_protocol(uri)

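
The dataset.py changes drop the string-based compare_versions helper in favor of comparing parsed version objects directly (dataset.metadata.version <= latest_version). A minimal sketch of the semantics this relies on, assuming valid_version in dreadnode.util is a thin wrapper over packaging (the real helper may differ):

```python
from packaging.version import InvalidVersion, Version


def valid_version(value: str) -> bool:
    # Assumed shape of dreadnode.util.valid_version: True when the
    # string parses as a PEP 440 version, False otherwise.
    try:
        Version(value)
    except InvalidVersion:
        return False
    return True


# Version objects order semantically, which is what the new
# `dataset.metadata.version <= latest_version` check depends on.
assert Version("1.2.10") > Version("1.2.9")  # numeric, not lexicographic
assert valid_version("0.1.0")
assert not valid_version("not-a-version")
```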
8 changes: 4 additions & 4 deletions dreadnode/main.py
@@ -1374,10 +1374,10 @@ def save_dataset_to_disk(

if version is not None:
try:
-parse_version(version)
+parsed = parse_version(version)
except InvalidVersion as e:
raise ValueError(f"Invalid version string: {version}") from e
-ds.metadata.version = version
+ds.metadata.version = parsed
ds.metadata.auto_version = False
elif strategy is not None:
ds.metadata.auto_version = True
@@ -1425,10 +1425,10 @@ def push_dataset(

if version is not None:
try:
-parse_version(version)
+parsed = parse_version(version)
except InvalidVersion as e:
raise ValueError(f"Invalid version string: {version}") from e
-ds.metadata.version = version
+ds.metadata.version = parsed
ds.metadata.auto_version = False
elif strategy is not None:
ds.metadata.auto_version = True
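
Both call sites in main.py fix the same bug: the old code parsed the version only to validate it, then still assigned the raw string to ds.metadata.version; now the parsed Version object is stored. A sketch of the pattern, assuming parse_version behaves like packaging's Version constructor:

```python
from packaging.version import InvalidVersion, Version


def coerce_version(version: str) -> Version:
    # Parse first, then return the Version object so metadata.version
    # holds a comparable value rather than a raw string.
    try:
        return Version(version)
    except InvalidVersion as e:
        raise ValueError(f"Invalid version string: {version}") from e
```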
74 changes: 12 additions & 62 deletions dreadnode/storage/datasets/manager.py
@@ -22,8 +22,8 @@
)
from dreadnode.logging_ import console as logging_console
from dreadnode.logging_ import print_info
-from dreadnode.storage.datasets.metadata import DatasetMetadata, VersionInfo
-from dreadnode.util import resolve_endpoint
+from dreadnode.storage.datasets.metadata import DatasetMetadata, DatasetVersion
+from dreadnode.util import parse_version, resolve_endpoint, valid_version


class DatasetManager:
@@ -90,7 +90,7 @@ def _needs_refresh(self) -> bool:
return (expiry - now).total_seconds() < FS_CREDENTIAL_REFRESH_BUFFER

@staticmethod
-def get_version_from_path(path: str) -> str | None:
+def get_version_from_path(path: str) -> DatasetVersion | None:
"""
Extracts version string from a given path.
Assumes version is the last part of the path.
@@ -100,13 +100,10 @@ def get_version_from_path(path: str) -> DatasetVersion | None:
resolved_path = Path(path).resolve()
version_candidate = resolved_path.name

-try:
-_ = VersionInfo.from_string(version_candidate)
-
-except ValueError:
+if not valid_version(version_candidate):
return None

-return version_candidate
+return parse_version(version_candidate)

@classmethod
def configure(
@@ -137,49 +134,6 @@ def check_cache(self, uri: str, version: str | None = None) -> bool:

return target_path.exists()

-def compare_versions(self, specified_version: str, local_version: str) -> int:
-"""
-Compares two version strings.
-Returns:
-1 if specified_version > local_version
--1 if specified_version < local_version
-0 if equal
-"""
-
-specified_ver = VersionInfo.from_string(specified_version)
-local_ver = VersionInfo.from_string(local_version)
-
-if specified_ver > local_ver:
-return 1
-if local_ver > specified_ver:
-return -1
-return 0
-
-def compare_versions_from_paths(self, remote_path: str, local_path: str) -> int | None:
-"""
-Compares versions extracted from remote and local paths.
-Returns:
-1 if remote version > local version
--1 if remote version < local version
-0 if equal
-None if versions cannot be determined
-"""
-
-remote_version_str = self.get_version_from_path(remote_path)
-local_version_str = self.get_version_from_path(local_path)
-
-if not remote_version_str or not local_version_str:
-return None
-
-remote_version = VersionInfo.from_string(remote_version_str)
-local_version = VersionInfo.from_string(local_version_str)
-
-if remote_version > local_version:
-return 1
-if local_version > remote_version:
-return -1
-return 0
-
def ensure_dir(self, fs: FileSystem, path: str) -> None:
"""
Creates directory if local. Skips if S3.
@@ -202,17 +156,15 @@ def delete_remote_dataset_record(self, dataset_id_or_key: UUID | str) -> None:
def get_cache_save_uri(
self,
ref: str,
-*,
-version: str | None = None,
-with_version: bool = True,
+# *,
+# version: DatasetVersion | None = None,
+# with_version: bool = True,
) -> str:
"""
Constructs the full local cache path.
Example: /home/user/.dreadnode/datasets/main/my-dataset/1.0.0
"""
dataset_uri = Path(f"{self.cache_root}/datasets/{ref}")
-if with_version and version:
-dataset_uri = dataset_uri / version

dataset_uri = dataset_uri.resolve()

@@ -333,7 +285,7 @@ def get_remote_save_uri(self, metadata: DatasetMetadata) -> tuple[UUID, str]:
upload_request = CreateDatasetRequest(
org_key=metadata.organization,
key=metadata.name,
-version=metadata.version,
+version=metadata.version.public,
tags=metadata.tags,
)

@@ -342,7 +294,7 @@ def get_remote_save_uri(self, metadata: DatasetMetadata) -> tuple[UUID, str]:
user_data_access_response = create_response.user_data_access_response
else:
update_response = self._api.update_dataset_version(
-existing_dataset.id, metadata.version
+existing_dataset.id, str(metadata.version)
)
dataset_id = update_response.dataset.id
user_data_access_response = update_response.credentials
@@ -430,8 +382,6 @@ def resolve_latest_local_version(self, uri: str) -> str | None:
return None
latest: str = sorted(versions, reverse=True)[0]
# ensure it's a valid version
-try:
-_ = VersionInfo.from_string(latest)
-except ValueError as e:
-raise ValueError(f"No valid versions found in {uri}") from e
+if not valid_version(latest):
+return None
return latest
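
With get_version_from_path returning a parsed DatasetVersion instead of a string, callers can compare results directly, which is what makes compare_versions and compare_versions_from_paths removable. A standalone sketch of the refactored method, assuming parse_version/valid_version wrap packaging:

```python
from pathlib import Path

from packaging.version import InvalidVersion, Version


def get_version_from_path(path: str) -> Version | None:
    # The version is assumed to be the final path component.
    candidate = Path(path).resolve().name
    try:
        return Version(candidate)
    except InvalidVersion:
        return None


assert get_version_from_path("/home/user/.dreadnode/datasets/main/my-dataset/1.0.0") == Version("1.0.0")
assert get_version_from_path("/home/user/.dreadnode/datasets/main/my-dataset") is None
```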
9 changes: 6 additions & 3 deletions dreadnode/storage/datasets/manifest.py
@@ -12,6 +12,7 @@

from dreadnode.constants import MANIFEST_FILE
from dreadnode.logging_ import console as logging_console
+from dreadnode.storage.datasets.metadata import DatasetVersion


class FileEntry(BaseModel):
@@ -146,8 +147,8 @@ def compute_file_hash(

def create_manifest(
path: str,
-version: str,
-parent_version: str | None = None,
+version: DatasetVersion,
+parent_version: DatasetVersion | None = None,
previous_manifest: DatasetManifest | None = None,
exclude_patterns: list[str] | None = None,
algorithm: str = "sha256",
@@ -156,7 +157,9 @@
if fs is None:
raise ValueError("FileSystem must be provided")

-manifest = DatasetManifest(version=version, parent_version=parent_version)
+manifest = DatasetManifest(
+version=version.public, parent_version=parent_version.public if parent_version else None
+)
exclude_patterns = exclude_patterns or []

selector = FileSelector(path, recursive=True)
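
The manifest itself still stores plain strings: version.public is packaging's public version, which drops any local segment that str(version) would keep. A quick illustration:

```python
from packaging.version import Version

v = Version("1.0.0+local.build")
assert str(v) == "1.0.0+local.build"
assert v.public == "1.0.0"  # local segment stripped; this is what the
                            # manifest stores for version fields
```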
89 changes: 23 additions & 66 deletions dreadnode/storage/datasets/metadata.py
@@ -2,70 +2,38 @@
import uuid
from datetime import datetime, timezone
from pathlib import Path
-from typing import Any
+from typing import Annotated, Any

import coolname
from packaging.version import Version
from pyarrow.fs import FileSystem
-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import (
+BaseModel,
+BeforeValidator,
+Field,
+PlainSerializer,
+WithJsonSchema,
+field_validator,
+model_validator,
+)

from dreadnode.common_types import VersionStrategy
+from dreadnode.util import parse_version


-class VersionInfo(BaseModel):
-major: int
-minor: int
-patch: int
-
-@field_validator("major", "minor", "patch", mode="before")
-@classmethod
-def validate_non_negative(cls, v: Any) -> int:
-iv = int(v)
-if iv < 0:
-raise ValueError("Version numbers must be non-negative integers")
-return iv
-
-def to_string(self) -> str:
-return f"{self.major}.{self.minor}.{self.patch}"
-
-def increment_major(self) -> None:
-self.major += 1
-self.minor = 0
-self.patch = 0
-
-def increment_minor(self) -> None:
-self.minor += 1
-self.patch = 0
-
-def increment_patch(self) -> None:
-self.patch += 1
-
-@staticmethod
-def from_string(version_str: str) -> "VersionInfo":
-version = Version(version_str)
-parts = [version.major, version.minor, version.micro]
-if len(parts) != 3:
-raise ValueError("Version string must be in the format 'major.minor.patch'")
-return VersionInfo(
-major=int(parts[0]),
-minor=int(parts[1]),
-patch=int(parts[2]),
-)
-
-def __gt__(self, other: "VersionInfo") -> bool:
-if self.major != other.major:
-return self.major > other.major
-if self.minor != other.minor:
-return self.minor > other.minor
-return self.patch > other.patch
+DatasetVersion = Annotated[
+Version,
+BeforeValidator(parse_version),
+PlainSerializer(lambda v: str(v), return_type=str),
+WithJsonSchema({"type": "string"}),
+]


class DatasetMetadata(BaseModel):
id: uuid.UUID = Field(default_factory=uuid.uuid4)
organization: str | None = None
name: str = Field(default=coolname.generate_slug(2))
uri: str | None = None
-version: str = Field(default=VersionInfo(major=0, minor=1, patch=0).to_string())
+version: DatasetVersion = Field(default=Version("0.1.0"))
license: str = Field(default="This dataset is not licensed.")
tags: list[str] = Field(default_factory=list)
readme: str = Field(default="# Dataset README\n\n")
@@ -131,7 +99,11 @@ def load(cls, path: str, fs: FileSystem) -> "DatasetMetadata":
with fs.open_input_stream(path) as f:
data = json.load(f)

-return cls.model_validate(data)
+return cls(**data)

+@staticmethod
+def bump_patch_version(version: DatasetVersion) -> DatasetVersion:
+return Version(f"{version.major}.{version.minor}.{version.micro + 1}")
+
@property
def ref(self) -> str:
@@ -149,21 +121,6 @@ def save(self, path: str, fs: FileSystem) -> None:
with fs.open_output_stream(path) as f:
f.write(json_bytes)

-def set_version(self, version: VersionInfo) -> None:
-self.version = version.to_string()
-
-def update_version(self, version: str) -> None:
-version_info = VersionInfo.from_string(version)
-if self.auto_version_strategy == "major":
-version_info.increment_major()
-elif self.auto_version_strategy == "minor":
-version_info.increment_minor()
-elif self.auto_version_strategy == "patch":
-version_info.increment_patch()
-else:
-raise ValueError("part must be 'major', 'minor', or 'patch'")
-self.version = version_info.to_string()
-
def set_license(self, license_content: str | Path) -> None:
"""
Accepts raw string content OR a local Path object to read from.
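
The Annotated alias replaces the hand-rolled VersionInfo model: pydantic now parses incoming strings into packaging Version objects, serializes them back to strings, and advertises a plain string in the JSON schema. A self-contained sketch of the same pattern — _parse is an assumption about how dreadnode.util.parse_version handles already-parsed values, and arbitrary_types_allowed is needed because Version is not a pydantic-native type:

```python
from typing import Annotated

from packaging.version import Version
from pydantic import (
    BaseModel,
    BeforeValidator,
    ConfigDict,
    PlainSerializer,
    WithJsonSchema,
)


def _parse(value: object) -> Version:
    # Assumed behavior of parse_version: pass Version instances
    # through untouched, parse everything else as a string.
    return value if isinstance(value, Version) else Version(str(value))


DatasetVersion = Annotated[
    Version,
    BeforeValidator(_parse),
    PlainSerializer(str, return_type=str),
    WithJsonSchema({"type": "string"}),
]


class Meta(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    version: DatasetVersion = Version("0.1.0")


m = Meta.model_validate({"version": "1.2.3"})
assert m.version > Version("1.2.2")  # real semantic comparison
assert m.model_dump()["version"] == "1.2.3"  # dumped back to a string
```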