diff --git a/dreadnode/dataset.py b/dreadnode/dataset.py
index 05ea2ef..cf3cb66 100644
--- a/dreadnode/dataset.py
+++ b/dreadnode/dataset.py
@@ -12,7 +12,8 @@
 from dreadnode.logging_ import print_info, print_success, print_warning
 from dreadnode.storage.datasets.manager import DatasetManager
 from dreadnode.storage.datasets.manifest import DatasetManifest, create_manifest
-from dreadnode.storage.datasets.metadata import DatasetMetadata, VersionInfo
+from dreadnode.storage.datasets.metadata import DatasetMetadata
+from dreadnode.util import valid_version
 
 
 class Dataset:
@@ -244,7 +245,9 @@ def _ensure_version_bump(
         return True
 
     # 5. Data Changed: Bump Version
-    dataset.metadata.update_version(latest_version or dataset.metadata.version)
+    dataset.metadata.version = dataset.metadata.bump_patch_version(
+        latest_version or dataset.metadata.version
+    )
     print_warning(f"[+] Changes detected. Auto-bumping version to {dataset.metadata.version}")
 
     return True
@@ -317,8 +320,9 @@ def save_dataset_to_disk(
         latest_version = dataset_manager.get_version_from_path(latest_path)
         if not latest_version:
             raise ValueError(f"Could not determine latest version from path: {latest_path}")
-        if dataset_manager.compare_versions(dataset.metadata.version, latest_version) <= 0:
+
+        if dataset.metadata.version <= latest_version:
             if overwrite:
                 print_warning(
                     "[!] Overwrite enabled: Proceeding to overwrite existing dataset version."
                 )
@@ -333,9 +337,7 @@ def save_dataset_to_disk(
     _persist_dataset(
         dataset=dataset,
         path_str=dataset_manager.get_cache_save_uri(
-            ref=dataset.metadata.ref,
-            version=dataset.metadata.version,
-            with_version=True,
+            ref=dataset.metadata.versioned_ref,
         ),
         create_dir=True,
         dataset_manager=dataset_manager,
@@ -368,11 +370,14 @@ def push_dataset(
     )
     latest_local_path = dataset_manager.get_latest_cache_save_uri(dataset.metadata.ref)
     if latest_remote_path and latest_local_path:
-        compared = dataset_manager.compare_versions_from_paths(
-            remote_path=latest_remote_path,
-            local_path=latest_local_path,
-        )
-        latest_path = latest_remote_path if compared == 1 else latest_local_path
+        remote_latest = dataset_manager.get_version_from_path(latest_remote_path)
+        local_latest = dataset_manager.get_version_from_path(latest_local_path)
+        if remote_latest and local_latest:
+            latest_path = latest_remote_path if remote_latest > local_latest else latest_local_path
+        elif remote_latest:
+            latest_path = latest_remote_path
+        else:
+            latest_path = latest_local_path
     else:
         latest_path = latest_remote_path or latest_local_path
 
@@ -432,8 +437,8 @@ def load_dataset(
         fsm: The DatasetManager instance.
         kwargs: Additional arguments to pass to pyarrow.dataset.load_dataset.
     """
-    if version is not None:
-        VersionInfo.from_string(version)
+    if version and not valid_version(version):
+        raise ValueError(f"Invalid version string: {version}")
 
     protocol = get_protocol(uri)
 
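The dataset.py changes replace the string-based `compare_versions` helper with direct `packaging.version.Version` ordering. A quick standalone sanity check of those semantics (not part of the patch; assumes only `packaging` is installed):

```python
# Sketch: how packaging.version.Version ordering replaces the old
# compare_versions() helper. Requires only the `packaging` package.
from packaging.version import Version

local = Version("1.2.0")
remote = Version("1.10.0")

# Version compares numerically per component, so 1.10.0 > 1.2.0,
# unlike naive string comparison, which would get this wrong.
assert remote > local
assert "1.10.0" < "1.2.0"  # lexicographic comparison is misleading here

# This is what `if dataset.metadata.version <= latest_version:` relies on
# once both sides are Version objects.
print(max(local, remote))  # -> 1.10.0
```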
""" - if version is not None: - VersionInfo.from_string(version) + if version and not valid_version(version): + raise ValueError(f"Invalid version string: {version}") protocol = get_protocol(uri) diff --git a/dreadnode/main.py b/dreadnode/main.py index d8965db..6bda568 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -1374,10 +1374,10 @@ def save_dataset_to_disk( if version is not None: try: - parse_version(version) + parsed = parse_version(version) except InvalidVersion as e: raise ValueError(f"Invalid version string: {version}") from e - ds.metadata.version = version + ds.metadata.version = parsed ds.metadata.auto_version = False elif strategy is not None: ds.metadata.auto_version = True @@ -1425,10 +1425,10 @@ def push_dataset( if version is not None: try: - parse_version(version) + parsed = parse_version(version) except InvalidVersion as e: raise ValueError(f"Invalid version string: {version}") from e - ds.metadata.version = version + ds.metadata.version = parsed ds.metadata.auto_version = False elif strategy is not None: ds.metadata.auto_version = True diff --git a/dreadnode/storage/datasets/manager.py b/dreadnode/storage/datasets/manager.py index 647191f..97fd782 100644 --- a/dreadnode/storage/datasets/manager.py +++ b/dreadnode/storage/datasets/manager.py @@ -22,8 +22,8 @@ ) from dreadnode.logging_ import console as logging_console from dreadnode.logging_ import print_info -from dreadnode.storage.datasets.metadata import DatasetMetadata, VersionInfo -from dreadnode.util import resolve_endpoint +from dreadnode.storage.datasets.metadata import DatasetMetadata, DatasetVersion +from dreadnode.util import parse_version, resolve_endpoint, valid_version class DatasetManager: @@ -90,7 +90,7 @@ def _needs_refresh(self) -> bool: return (expiry - now).total_seconds() < FS_CREDENTIAL_REFRESH_BUFFER @staticmethod - def get_version_from_path(path: str) -> str | None: + def get_version_from_path(path: str) -> DatasetVersion | None: """ Extracts version string from a given path. Assumes version is the last part of the path. @@ -100,13 +100,10 @@ def get_version_from_path(path: str) -> str | None: resolved_path = Path(path).resolve() version_candidate = resolved_path.name - try: - _ = VersionInfo.from_string(version_candidate) - - except ValueError: + if not valid_version(version_candidate): return None - return version_candidate + return parse_version(version_candidate) @classmethod def configure( @@ -137,49 +134,6 @@ def check_cache(self, uri: str, version: str | None = None) -> bool: return target_path.exists() - def compare_versions(self, specified_version: str, local_version: str) -> int: - """ - Compares two version strings. - Returns: - 1 if specified_version > local_version - -1 if specified_version < local_version - 0 if equal - """ - - specified_ver = VersionInfo.from_string(specified_version) - local_ver = VersionInfo.from_string(local_version) - - if specified_ver > local_ver: - return 1 - if local_ver > specified_ver: - return -1 - return 0 - - def compare_versions_from_paths(self, remote_path: str, local_path: str) -> int | None: - """ - Compares versions extracted from remote and local paths. 
diff --git a/dreadnode/storage/datasets/manager.py b/dreadnode/storage/datasets/manager.py
index 647191f..97fd782 100644
--- a/dreadnode/storage/datasets/manager.py
+++ b/dreadnode/storage/datasets/manager.py
@@ -22,8 +22,8 @@
 )
 from dreadnode.logging_ import console as logging_console
 from dreadnode.logging_ import print_info
-from dreadnode.storage.datasets.metadata import DatasetMetadata, VersionInfo
-from dreadnode.util import resolve_endpoint
+from dreadnode.storage.datasets.metadata import DatasetMetadata, DatasetVersion
+from dreadnode.util import parse_version, resolve_endpoint, valid_version
 
 
 class DatasetManager:
@@ -90,7 +90,7 @@ def _needs_refresh(self) -> bool:
         return (expiry - now).total_seconds() < FS_CREDENTIAL_REFRESH_BUFFER
 
     @staticmethod
-    def get_version_from_path(path: str) -> str | None:
+    def get_version_from_path(path: str) -> DatasetVersion | None:
         """
         Extracts version string from a given path.
         Assumes version is the last part of the path.
@@ -100,13 +100,10 @@ def get_version_from_path(path: str) -> str | None:
         resolved_path = Path(path).resolve()
         version_candidate = resolved_path.name
 
-        try:
-            _ = VersionInfo.from_string(version_candidate)
-
-        except ValueError:
+        if not valid_version(version_candidate):
             return None
 
-        return version_candidate
+        return parse_version(version_candidate)
 
     @classmethod
     def configure(
@@ -137,49 +134,6 @@ def check_cache(self, uri: str, version: str | None = None) -> bool:
 
         return target_path.exists()
 
-    def compare_versions(self, specified_version: str, local_version: str) -> int:
-        """
-        Compares two version strings.
-        Returns:
-            1 if specified_version > local_version
-            -1 if specified_version < local_version
-            0 if equal
-        """
-
-        specified_ver = VersionInfo.from_string(specified_version)
-        local_ver = VersionInfo.from_string(local_version)
-
-        if specified_ver > local_ver:
-            return 1
-        if local_ver > specified_ver:
-            return -1
-        return 0
-
-    def compare_versions_from_paths(self, remote_path: str, local_path: str) -> int | None:
-        """
-        Compares versions extracted from remote and local paths.
-        Returns:
-            1 if remote version > local version
-            -1 if remote version < local version
-            0 if equal
-            None if versions cannot be determined
-        """
-
-        remote_version_str = self.get_version_from_path(remote_path)
-        local_version_str = self.get_version_from_path(local_path)
-
-        if not remote_version_str or not local_version_str:
-            return None
-
-        remote_version = VersionInfo.from_string(remote_version_str)
-        local_version = VersionInfo.from_string(local_version_str)
-
-        if remote_version > local_version:
-            return 1
-        if local_version > remote_version:
-            return -1
-        return 0
-
     def ensure_dir(self, fs: FileSystem, path: str) -> None:
         """
         Creates directory if local. Skips if S3.
@@ -202,17 +156,12 @@ def delete_remote_dataset_record(self, dataset_id_or_key: UUID | str) -> None:
 
     def get_cache_save_uri(
         self,
         ref: str,
-        *,
-        version: str | None = None,
-        with_version: bool = True,
     ) -> str:
         """
         Constructs the full local cache path.
-        Example: /home/user/.dreadnode/datasets/main/my-dataset/1.0.0
+        Example: /home/user/.dreadnode/datasets/main/my-dataset
         """
         dataset_uri = Path(f"{self.cache_root}/datasets/{ref}")
-        if with_version and version:
-            dataset_uri = dataset_uri / version
 
         dataset_uri = dataset_uri.resolve()
@@ -333,7 +282,7 @@ def get_remote_save_uri(self, metadata: DatasetMetadata) -> tuple[UUID, str]:
         upload_request = CreateDatasetRequest(
             org_key=metadata.organization,
             key=metadata.name,
-            version=metadata.version,
+            version=metadata.version.public,
             tags=metadata.tags,
         )
@@ -342,7 +291,7 @@ def get_remote_save_uri(self, metadata: DatasetMetadata) -> tuple[UUID, str]:
             user_data_access_response = create_response.user_data_access_response
         else:
             update_response = self._api.update_dataset_version(
-                existing_dataset.id, metadata.version
+                existing_dataset.id, str(metadata.version)
             )
             dataset_id = update_response.dataset.id
             user_data_access_response = update_response.credentials
@@ -430,8 +379,6 @@ def resolve_latest_local_version(self, uri: str) -> str | None:
             return None
         latest: str = sorted(versions, reverse=True)[0]
         # ensure it's a valid version
-        try:
-            _ = VersionInfo.from_string(latest)
-        except ValueError as e:
-            raise ValueError(f"No valid versions found in {uri}") from e
+        if not valid_version(latest):
+            return None
         return latest
diff --git a/dreadnode/storage/datasets/manifest.py b/dreadnode/storage/datasets/manifest.py
index ea78883..c92e622 100644
--- a/dreadnode/storage/datasets/manifest.py
+++ b/dreadnode/storage/datasets/manifest.py
@@ -12,6 +12,7 @@
 
 from dreadnode.constants import MANIFEST_FILE
 from dreadnode.logging_ import console as logging_console
+from dreadnode.storage.datasets.metadata import DatasetVersion
 
 
 class FileEntry(BaseModel):
@@ -146,8 +147,8 @@ def compute_file_hash(
 
 def create_manifest(
     path: str,
-    version: str,
-    parent_version: str | None = None,
+    version: DatasetVersion,
+    parent_version: DatasetVersion | None = None,
     previous_manifest: DatasetManifest | None = None,
     exclude_patterns: list[str] | None = None,
     algorithm: str = "sha256",
@@ -156,7 +157,9 @@ def create_manifest(
     if fs is None:
         raise ValueError("FileSystem must be provided")
 
-    manifest = DatasetManifest(version=version, parent_version=parent_version)
+    manifest = DatasetManifest(
+        version=version.public, parent_version=parent_version.public if parent_version else None
+    )
 
     exclude_patterns = exclude_patterns or []
     selector = FileSelector(path, recursive=True)
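The manifest now stores `version.public` rather than the raw string. A small sketch (assumes only `packaging`) of what `.public` normalizes away, namely the local segment after `+`:

```python
# Sketch: Version.public vs str(Version) for a version carrying
# local build metadata (PEP 440 local version segment).
from packaging.version import Version

v = Version("1.4.0+local.build")
print(v.public)  # -> "1.4.0" (local segment stripped)
print(str(v))    # -> "1.4.0+local.build" (full form, as sent to the API)
print(v.major, v.minor, v.micro)  # -> 1 4 0
```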
diff --git a/dreadnode/storage/datasets/metadata.py b/dreadnode/storage/datasets/metadata.py
index c463e43..3ccc3c4 100644
--- a/dreadnode/storage/datasets/metadata.py
+++ b/dreadnode/storage/datasets/metadata.py
@@ -2,62 +2,30 @@
 import uuid
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import Any
+from typing import Annotated, Any
 
 import coolname
 from packaging.version import Version
 from pyarrow.fs import FileSystem
-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import (
+    BaseModel,
+    BeforeValidator,
+    Field,
+    PlainSerializer,
+    WithJsonSchema,
+    field_validator,
+    model_validator,
+)
 
 from dreadnode.common_types import VersionStrategy
+from dreadnode.util import parse_version
 
-
-class VersionInfo(BaseModel):
-    major: int
-    minor: int
-    patch: int
-
-    @field_validator("major", "minor", "patch", mode="before")
-    @classmethod
-    def validate_non_negative(cls, v: Any) -> int:
-        iv = int(v)
-        if iv < 0:
-            raise ValueError("Version numbers must be non-negative integers")
-        return iv
-
-    def to_string(self) -> str:
-        return f"{self.major}.{self.minor}.{self.patch}"
-
-    def increment_major(self) -> None:
-        self.major += 1
-        self.minor = 0
-        self.patch = 0
-
-    def increment_minor(self) -> None:
-        self.minor += 1
-        self.patch = 0
-
-    def increment_patch(self) -> None:
-        self.patch += 1
-
-    @staticmethod
-    def from_string(version_str: str) -> "VersionInfo":
-        version = Version(version_str)
-        parts = [version.major, version.minor, version.micro]
-        if len(parts) != 3:
-            raise ValueError("Version string must be in the format 'major.minor.patch'")
-        return VersionInfo(
-            major=int(parts[0]),
-            minor=int(parts[1]),
-            patch=int(parts[2]),
-        )
-
-    def __gt__(self, other: "VersionInfo") -> bool:
-        if self.major != other.major:
-            return self.major > other.major
-        if self.minor != other.minor:
-            return self.minor > other.minor
-        return self.patch > other.patch
+DatasetVersion = Annotated[
+    Version,
+    BeforeValidator(parse_version),
+    PlainSerializer(lambda v: str(v), return_type=str),
+    WithJsonSchema({"type": "string"}),
+]
 
 
 class DatasetMetadata(BaseModel):
@@ -65,7 +33,7 @@ class DatasetMetadata(BaseModel):
     organization: str | None = None
     name: str = Field(default=coolname.generate_slug(2))
     uri: str | None = None
-    version: str = Field(default=VersionInfo(major=0, minor=1, patch=0).to_string())
+    version: DatasetVersion = Field(default=Version("0.1.0"))
     license: str = Field(default="This dataset is not licensed.")
     tags: list[str] = Field(default_factory=list)
     readme: str = Field(default="# Dataset README\n\n")
@@ -131,7 +99,11 @@ def load(cls, path: str, fs: FileSystem) -> "DatasetMetadata":
         with fs.open_input_stream(path) as f:
             data = json.load(f)
 
-        return cls.model_validate(data)
+        return cls(**data)
+
+    @staticmethod
+    def bump_patch_version(version: DatasetVersion) -> DatasetVersion:
+        return Version(f"{version.major}.{version.minor}.{version.micro + 1}")
 
     @property
     def ref(self) -> str:
@@ -149,21 +121,6 @@ def save(self, path: str, fs: FileSystem) -> None:
         with fs.open_output_stream(path) as f:
             f.write(json_bytes)
 
-    def set_version(self, version: VersionInfo) -> None:
-        self.version = version.to_string()
-
-    def update_version(self, version: str) -> None:
-        version_info = VersionInfo.from_string(version)
-        if self.auto_version_strategy == "major":
-            version_info.increment_major()
-        elif self.auto_version_strategy == "minor":
-            version_info.increment_minor()
-        elif self.auto_version_strategy == "patch":
-            version_info.increment_patch()
-        else:
-            raise ValueError("part must be 'major', 'minor', or 'patch'")
-        self.version = version_info.to_string()
-
     def set_license(self, license_content: str | Path) -> None:
         """
         Accepts raw string content OR a local Path object to read from.
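How the new `DatasetVersion` annotated alias behaves in a pydantic v2 model, as a standalone sketch. The `Meta` model and its `arbitrary_types_allowed` config are illustrative assumptions (the real `DatasetMetadata` presumably configures pydantic to accept the `Version` type somewhere outside this diff):

```python
# Sketch: round-tripping a packaging Version through a pydantic v2 model
# via Annotated validators/serializers. Meta is a hypothetical model.
from typing import Annotated

from packaging.version import Version, parse
from pydantic import BaseModel, BeforeValidator, ConfigDict, PlainSerializer, WithJsonSchema


def parse_version(v: str | Version) -> Version:
    return v if isinstance(v, Version) else parse(v)


DatasetVersion = Annotated[
    Version,
    BeforeValidator(parse_version),          # coerce str -> Version on input
    PlainSerializer(str, return_type=str),   # emit str on dump
    WithJsonSchema({"type": "string"}),      # advertise "string" in JSON schema
]


class Meta(BaseModel):
    # Version is an arbitrary (non-pydantic) class, so it must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    version: DatasetVersion = Version("0.1.0")


m = Meta(version="1.2.3")          # string input coerced by BeforeValidator
assert isinstance(m.version, Version)
assert m.version > Version("1.2")  # real Version ordering, not string compare
print(m.model_dump_json())         # -> {"version":"1.2.3"}
```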
diff --git a/dreadnode/util.py b/dreadnode/util.py
index fe63ccf..ec38e2f 100644
--- a/dreadnode/util.py
+++ b/dreadnode/util.py
@@ -179,10 +179,12 @@ def valid_key(key: str) -> bool:
     return bool(re.fullmatch(r"[a-z0-9-]+", key))
 
 
-def valid_version(version: str) -> bool:
+def valid_version(version: str | Version) -> bool:
     """
     Check if the version is valid (semantic versioning format).
     """
+    if isinstance(version, Version):
+        return True
     try:
         parsed = parse(version)
         return isinstance(parsed, Version)
@@ -957,3 +959,22 @@ def bump_version(version_str: str, strategy: VersionStrategy) -> str:
 
     # Reconstruct the version string
     return f"{major}.{minor}.{micro}"
+
+
+def parse_version(version: str | Version) -> Version:
+    """
+    Parse a version string into a Version object.
+
+    Args:
+        version: The version string (e.g., "1.0.1", "2.3") or an existing Version.
+
+    Returns:
+        The parsed Version object.
+
+    Raises:
+        InvalidVersion: If the version string cannot be parsed.
+    """
+    if isinstance(version, Version):
+        return version
+
+    return parse(version)
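Finally, a sketch of the patch-bump arithmetic `DatasetMetadata.bump_patch_version` relies on, assuming only `packaging`: it rebuilds the version from the release triple, which also means any pre-release or local suffix is dropped on bump.

```python
# Sketch: patch bump by reconstructing from major/minor/micro.
from packaging.version import Version


def bump_patch_version(version: Version) -> Version:
    return Version(f"{version.major}.{version.minor}.{version.micro + 1}")


print(bump_patch_version(Version("0.1.0")))    # -> 0.1.1
print(bump_patch_version(Version("2.3")))      # -> 2.3.1 (micro defaults to 0)
print(bump_patch_version(Version("1.0.0rc1"))) # -> 1.0.1 (rc suffix is dropped)
```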