diff --git a/README.md b/README.md
index 9996c72..f313bd7 100644
--- a/README.md
+++ b/README.md
@@ -54,33 +54,37 @@ With libcachesim installed, you can start cache simulation for some eviction alg
 ```python
 import libcachesim as lcs

-# Step 1: Get one trace from S3 bucket
-URI = "cache_dataset_oracleGeneral/2007_msr/msr_hm_0.oracleGeneral.zst"
-dl = lcs.DataLoader()
-dl.load(URI)
-
-# Step 2: Open trace and process efficiently
+# Step 1: Open a trace hosted on S3 (find more via https://github.com/cacheMon/cache_dataset)
+URI = "s3://cache-datasets/cache_dataset_oracleGeneral/2007_msr/msr_hm_0.oracleGeneral.zst"
 reader = lcs.TraceReader(
-    trace = dl.get_cache_path(URI),
+    trace = URI,
     trace_type = lcs.TraceType.ORACLE_GENERAL_TRACE,
     reader_init_params = lcs.ReaderInitParam(ignore_obj_size=False)
 )

-# Step 3: Initialize cache
-cache = lcs.S3FIFO(cache_size=1024*1024)
-
-# Step 4: Process entire trace efficiently (C++ backend)
-obj_miss_ratio, byte_miss_ratio = cache.process_trace(reader)
-print(f"Object miss ratio: {obj_miss_ratio:.4f}, Byte miss ratio: {byte_miss_ratio:.4f}")
+# Step 2: Initialize cache
+cache = lcs.S3FIFO(
+    cache_size=1024*1024,
+    # Cache specific parameters
+    small_size_ratio=0.2,
+    ghost_size_ratio=0.8,
+    move_to_main_threshold=2,
+)

-# Step 4.1: Process with limited number of requests
-cache = lcs.S3FIFO(cache_size=1024*1024)
-obj_miss_ratio, byte_miss_ratio = cache.process_trace(
-    reader,
-    start_req=0,
-    max_req=1000
+# Step 3: Process entire trace efficiently (C++ backend)
+req_miss_ratio, byte_miss_ratio = cache.process_trace(reader)
+print(f"Request miss ratio: {req_miss_ratio:.4f}, Byte miss ratio: {byte_miss_ratio:.4f}")
+
+# Step 3.1: Further process the first 1000 requests again
+cache = lcs.S3FIFO(
+    cache_size=1024 * 1024,
+    # Cache specific parameters
+    small_size_ratio=0.2,
+    ghost_size_ratio=0.8,
+    move_to_main_threshold=2,
 )
-print(f"Object miss ratio: {obj_miss_ratio:.4f}, Byte miss ratio: {byte_miss_ratio:.4f}")
+req_miss_ratio, byte_miss_ratio = cache.process_trace(reader, start_req=0, max_req=1000)
+print(f"Request miss ratio: {req_miss_ratio:.4f}, Byte miss ratio: {byte_miss_ratio:.4f}")
 ```

 ## Plugin System
diff --git a/docs/src/en/getting_started/quickstart.md b/docs/src/en/getting_started/quickstart.md
index 70bfe22..4246046 100644
--- a/docs/src/en/getting_started/quickstart.md
+++ b/docs/src/en/getting_started/quickstart.md
@@ -56,33 +56,37 @@ With libcachesim installed, you can start cache simulation for some eviction alg
     ```python
     import libcachesim as lcs

-    # Step 1: Get one trace from S3 bucket
-    URI = "cache_dataset_oracleGeneral/2007_msr/msr_hm_0.oracleGeneral.zst"
-    dl = lcs.DataLoader()
-    dl.load(URI)
-
-    # Step 2: Open trace and process efficiently
+    # Step 1: Open a trace hosted on S3 (find more via https://github.com/cacheMon/cache_dataset)
+    URI = "s3://cache-datasets/cache_dataset_oracleGeneral/2007_msr/msr_hm_0.oracleGeneral.zst"
     reader = lcs.TraceReader(
-        trace = dl.get_cache_path(URI),
+        trace = URI,
         trace_type = lcs.TraceType.ORACLE_GENERAL_TRACE,
         reader_init_params = lcs.ReaderInitParam(ignore_obj_size=False)
     )

-    # Step 3: Initialize cache
-    cache = lcs.S3FIFO(cache_size=1024*1024)
-
-    # Step 4: Process entire trace efficiently (C++ backend)
-    obj_miss_ratio, byte_miss_ratio = cache.process_trace(reader)
-    print(f"Object miss ratio: {obj_miss_ratio:.4f}, Byte miss ratio: {byte_miss_ratio:.4f}")
+    # Step 2: Initialize cache
+    cache = lcs.S3FIFO(
+        cache_size=1024*1024,
+        # Cache specific parameters
+        small_size_ratio=0.2,
+        ghost_size_ratio=0.8,
+        move_to_main_threshold=2,
+    )

-    # Step 4.1: Process with limited number of requests
-    cache = lcs.S3FIFO(cache_size=1024*1024)
-    obj_miss_ratio, byte_miss_ratio = cache.process_trace(
-        reader,
-        start_req=0,
-        max_req=1000
+    # Step 3: Process entire trace efficiently (C++ backend)
+    req_miss_ratio, byte_miss_ratio = cache.process_trace(reader)
+    print(f"Request miss ratio: {req_miss_ratio:.4f}, Byte miss ratio: {byte_miss_ratio:.4f}")
+
+    # Step 3.1: Further process the first 1000 requests again
+    cache = lcs.S3FIFO(
+        cache_size=1024 * 1024,
+        # Cache specific parameters
+        small_size_ratio=0.2,
+        ghost_size_ratio=0.8,
+        move_to_main_threshold=2,
     )
-    print(f"Object miss ratio: {obj_miss_ratio:.4f}, Byte miss ratio: {byte_miss_ratio:.4f}")
+    req_miss_ratio, byte_miss_ratio = cache.process_trace(reader, start_req=0, max_req=1000)
+    print(f"Request miss ratio: {req_miss_ratio:.4f}, Byte miss ratio: {byte_miss_ratio:.4f}")
     ```

 The above example demonstrates the basic workflow of using `libcachesim` for cache simulation:
diff --git a/examples/basic_usage.py b/examples/basic_usage.py
index 2a4bd60..d1bf0b1 100644
--- a/examples/basic_usage.py
+++ b/examples/basic_usage.py
@@ -1,25 +1,26 @@
 import libcachesim as lcs

-# Step 1: Get one trace from S3 bucket
-URI = "cache_dataset_oracleGeneral/2007_msr/msr_hm_0.oracleGeneral.zst"
-dl = lcs.DataLoader()
-dl.load(URI)
-
-# Step 2: Open trace and process efficiently
+# Step 1: Open a trace hosted on S3 (find more via https://github.com/cacheMon/cache_dataset)
+URI = "s3://cache-datasets/cache_dataset_oracleGeneral/2007_msr/msr_hm_0.oracleGeneral.zst"
 reader = lcs.TraceReader(
-    trace=dl.get_cache_path(URI),
+    trace=URI,
     trace_type=lcs.TraceType.ORACLE_GENERAL_TRACE,
     reader_init_params=lcs.ReaderInitParam(ignore_obj_size=False),
 )

-# Step 3: Initialize cache
-cache = lcs.S3FIFO(cache_size=1024 * 1024)
+# Step 2: Initialize cache
+cache = lcs.S3FIFO(
+    cache_size=1024 * 1024,
+    # Cache specific parameters
+    small_size_ratio=0.2,
+    ghost_size_ratio=0.8,
+    move_to_main_threshold=2,
+)

-# Step 4: Process entire trace efficiently (C++ backend)
-obj_miss_ratio, byte_miss_ratio = cache.process_trace(reader)
-print(f"Object miss ratio: {obj_miss_ratio:.4f}, Byte miss ratio: {byte_miss_ratio:.4f}")
+# Step 3: Process entire trace efficiently (C++ backend)
+req_miss_ratio, byte_miss_ratio = cache.process_trace(reader)
+print(f"Request miss ratio: {req_miss_ratio:.4f}, Byte miss ratio: {byte_miss_ratio:.4f}")

-# Step 4.1: Process with limited number of requests
-cache = lcs.S3FIFO(cache_size=1024 * 1024)
-obj_miss_ratio, byte_miss_ratio = cache.process_trace(reader, start_req=0, max_req=1000)
-print(f"Object miss ratio: {obj_miss_ratio:.4f}, Byte miss ratio: {byte_miss_ratio:.4f}")
+# Step 3.1: Further process the first 1000 requests again
+req_miss_ratio, byte_miss_ratio = cache.process_trace(reader, start_req=0, max_req=1000)
+print(f"Request miss ratio: {req_miss_ratio:.4f}, Byte miss ratio: {byte_miss_ratio:.4f}")
diff --git a/libcachesim/__init__.py b/libcachesim/__init__.py
index c9fc1e7..86baf3a 100644
--- a/libcachesim/__init__.py
+++ b/libcachesim/__init__.py
@@ -59,7 +59,6 @@ from .trace_analyzer import TraceAnalyzer
 from .synthetic_reader import SyntheticReader, create_zipf_requests, create_uniform_requests
 from .util import Util

-from .data_loader import DataLoader

 __all__ = [
     # Core classes
@@ -118,8 +117,6 @@
     "create_uniform_requests",
     # Utilities
     "Util",
-    # Data loader
-    "DataLoader",
     # Metadata
     "__doc__",
     "__version__",
diff --git a/libcachesim/_s3_cache.py b/libcachesim/_s3_cache.py
new file mode 100644
index 0000000..36874ff
--- /dev/null
+++ b/libcachesim/_s3_cache.py
@@ -0,0 +1,366 @@
+"""S3 Bucket data loader with local caching (HuggingFace-style)."""
+
+from __future__ import annotations
+
+import logging
+import os
+import re
+import shutil
+from pathlib import Path
+from typing import Optional, Union
+from urllib.parse import urlparse
+
+logger = logging.getLogger(__name__)
+
+
+class _DataLoader:
+    """Internal S3 data loader with local caching."""
+
+    DEFAULT_BUCKET = "cache-datasets"
+    DEFAULT_CACHE_DIR = Path(os.environ.get("LCS_HUB_CACHE", Path.home() / ".cache/libcachesim/hub"))
+
+    # Characters that are problematic on various filesystems
+    INVALID_CHARS = set('<>:"|?*\x00')
+    # Reserved names on Windows
+    RESERVED_NAMES = {
+        "CON",
+        "PRN",
+        "AUX",
+        "NUL",
+        "COM1",
+        "COM2",
+        "COM3",
+        "COM4",
+        "COM5",
+        "COM6",
+        "COM7",
+        "COM8",
+        "COM9",
+        "LPT1",
+        "LPT2",
+        "LPT3",
+        "LPT4",
+        "LPT5",
+        "LPT6",
+        "LPT7",
+        "LPT8",
+        "LPT9",
+    }
+
+    def __init__(
+        self, bucket_name: str = DEFAULT_BUCKET, cache_dir: Optional[Union[str, Path]] = None, use_auth: bool = False
+    ):
+        self.bucket_name = self._validate_bucket_name(bucket_name)
+        self.cache_dir = Path(cache_dir) if cache_dir else self.DEFAULT_CACHE_DIR
+        self.use_auth = use_auth
+        self._s3_client = None
+        self._ensure_cache_dir()
+
+    def _validate_bucket_name(self, bucket_name: str) -> str:
+        """Validate S3 bucket name according to AWS rules."""
+        if not bucket_name:
+            raise ValueError("Bucket name cannot be empty")
+
+        if len(bucket_name) < 3 or len(bucket_name) > 63:
+            raise ValueError("Bucket name must be between 3 and 63 characters")
+
+        if not re.match(r"^[a-z0-9.-]+$", bucket_name):
+            raise ValueError("Bucket name can only contain lowercase letters, numbers, periods, and hyphens")
+
+        if bucket_name.startswith(".") or bucket_name.endswith("."):
+            raise ValueError("Bucket name cannot start or end with a period")
+
+        if bucket_name.startswith("-") or bucket_name.endswith("-"):
+            raise ValueError("Bucket name cannot start or end with a hyphen")
+
+        if ".." in bucket_name:
+            raise ValueError("Bucket name cannot contain consecutive periods")
+
+        return bucket_name
+
+    def _validate_and_sanitize_key(self, key: str) -> str:
+        """Validate and sanitize S3 key for safe local filesystem usage."""
+        if not key:
+            raise ValueError("S3 key cannot be empty")
+
+        if len(key) > 1024:  # S3 limit is 1024 bytes
+            raise ValueError("S3 key is too long (max 1024 characters)")
+
+        # Check for path traversal attempts
+        if ".." in key:
+            raise ValueError("S3 key cannot contain '..' (path traversal not allowed)")
+
+        if key.startswith("/"):
+            raise ValueError("S3 key cannot start with '/'")
+
+        # Split key into parts and validate each part
+        parts = key.split("/")
+        sanitized_parts = []
+
+        for part in parts:
+            if not part:  # Empty part (double slash)
+                continue
+
+            # Check for reserved names (case insensitive)
+            if part.upper() in self.RESERVED_NAMES:
+                raise ValueError(f"S3 key contains reserved name: {part}")
+
+            # Check for invalid characters
+            if any(c in self.INVALID_CHARS for c in part):
+                raise ValueError(f"S3 key contains invalid characters in part: {part}")
+
+            # Check if part is too long for filesystem
+            if len(part) > 255:  # Most filesystems have 255 char limit per component
+                raise ValueError(f"S3 key component too long: {part}")
+
+            sanitized_parts.append(part)
+
+        if not sanitized_parts:
+            raise ValueError("S3 key resulted in empty path after sanitization")
+
+        return "/".join(sanitized_parts)
+
+    def _ensure_cache_dir(self) -> None:
+        (self.cache_dir / self.bucket_name).mkdir(parents=True, exist_ok=True)
+
+    def _get_available_disk_space(self, path: Path) -> int:
+        """Get available disk space in bytes."""
+        try:
+            stat = os.statvfs(path)
+            return stat.f_bavail * stat.f_frsize
+        except (OSError, AttributeError):
+            # Fallback for Windows or other systems
+            try:
+                import shutil
+
+                return shutil.disk_usage(path).free
+            except Exception:
+                logger.warning("Could not determine available disk space")
+                return float("inf")  # Assume unlimited space if we can't check
+
+    @property
+    def s3_client(self):
+        if self._s3_client is None:
+            try:
+                import boto3
+                from botocore.config import Config
+                from botocore import UNSIGNED
+
+                self._s3_client = boto3.client(
+                    "s3", config=None if self.use_auth else Config(signature_version=UNSIGNED)
+                )
+            except ImportError:
+                raise ImportError("Install boto3: pip install boto3")
+        return self._s3_client
+
+    def _cache_path(self, key: str) -> Path:
+        """Create cache path that mirrors S3 structure after validation."""
+        sanitized_key = self._validate_and_sanitize_key(key)
+        cache_path = self.cache_dir / self.bucket_name / sanitized_key
+
+        # Double-check that the resolved path is still within cache directory
+        try:
+            cache_path.resolve().relative_to(self.cache_dir.resolve())
+        except ValueError:
+            raise ValueError(f"S3 key resolves outside cache directory: {key}")
+
+        return cache_path
+
+    def _get_object_size(self, key: str) -> int:
+        """Get the size of an S3 object without downloading it."""
+        try:
+            response = self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
+            return response["ContentLength"]
+        except Exception as e:
+            logger.warning(f"Could not determine object size for s3://{self.bucket_name}/{key}: {e}")
+            return 0
+
+    def _download(self, key: str, dest: Path) -> None:
+        temp = dest.with_suffix(dest.suffix + ".tmp")
+        temp.parent.mkdir(parents=True, exist_ok=True)
+
+        try:
+            # Check available disk space before downloading
+            object_size = self._get_object_size(key)
+            if object_size > 0:
+                available_space = self._get_available_disk_space(temp.parent)
+                if object_size > available_space:
+                    raise RuntimeError(
+                        f"Insufficient disk space. Need {object_size / (1024**3):.2f} GB, "
+                        f"but only {available_space / (1024**3):.2f} GB available"
+                    )
+
+            logger.info(f"Downloading s3://{self.bucket_name}/{key}")
+            obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
+            with open(temp, "wb") as f:
+                f.write(obj["Body"].read())
+            shutil.move(str(temp), str(dest))
+            logger.info(f"Saved to: {dest}")
+        except Exception as e:
+            if temp.exists():
+                temp.unlink()
+            raise RuntimeError(f"Download failed for s3://{self.bucket_name}/{key}: {e}")
+
+    def load(self, key: str, force: bool = False, mode: str = "rb") -> Union[bytes, str]:
+        path = self._cache_path(key)
+        if not path.exists() or force:
+            self._download(key, path)
+        with open(path, mode) as f:
+            return f.read()
+
+    def get_cached_path(self, key: str) -> str:
+        """Get the local cached file path, downloading if necessary."""
+        path = self._cache_path(key)
+        if not path.exists():
+            self._download(key, path)
+        return str(path)
+
+    def is_cached(self, key: str) -> bool:
+        try:
+            return self._cache_path(key).exists()
+        except ValueError:
+            return False
+
+    def clear_cache(self, key: Optional[str] = None) -> None:
+        if key:
+            try:
+                path = self._cache_path(key)
+                if path.exists():
+                    path.unlink()
+                    logger.info(f"Cleared: {path}")
+            except ValueError as e:
+                logger.warning(f"Cannot clear cache for invalid key {key}: {e}")
+        else:
+            shutil.rmtree(self.cache_dir, ignore_errors=True)
+            logger.info(f"Cleared entire cache: {self.cache_dir}")
+
+    def list_cached_files(self) -> list[str]:
+        if not self.cache_dir.exists():
+            return []
+        return [str(p) for p in self.cache_dir.rglob("*") if p.is_file() and not p.name.endswith(".tmp")]
+
+    def get_cache_size(self) -> int:
+        return sum(p.stat().st_size for p in self.cache_dir.rglob("*") if p.is_file())
+
+    def list_s3_objects(self, prefix: str = "", delimiter: str = "/") -> dict:
+        """
+        List S3 objects and pseudo-folders under a prefix.
+
+        Args:
+            prefix: The S3 prefix to list under (like folder path)
+            delimiter: Use "/" to simulate folder structure
+
+        Returns:
+            A dict with two keys:
+            - "folders": list of sub-prefixes (folders)
+            - "files": list of object keys (files)
+        """
+        paginator = self.s3_client.get_paginator("list_objects_v2")
+        result = {"folders": [], "files": []}
+
+        for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix, Delimiter=delimiter):
+            # CommonPrefixes are like subdirectories
+            result["folders"].extend(cp["Prefix"] for cp in page.get("CommonPrefixes", []))
+            result["files"].extend(obj["Key"] for obj in page.get("Contents", []))
+
+        return result
+
+
+# Global data loader instance
+_data_loader = _DataLoader()
+
+
+def set_cache_dir(cache_dir: Union[str, Path]) -> None:
+    """
+    Set the global cache directory for S3 downloads.
+
+    Args:
+        cache_dir: Path to the cache directory
+
+    Example:
+        >>> import libcachesim as lcs
+        >>> lcs.set_cache_dir("/tmp/my_cache")
+    """
+    global _data_loader
+    _data_loader = _DataLoader(cache_dir=cache_dir)
+
+
+def get_cache_dir() -> Path:
+    """
+    Get the current cache directory.
+
+    Returns:
+        Path to the current cache directory
+
+    Example:
+        >>> import libcachesim as lcs
+        >>> print(lcs.get_cache_dir())
+        /home/user/.cache/libcachesim/hub
+    """
+    return _data_loader.cache_dir
+
+
+def clear_cache(s3_path: Optional[str] = None) -> None:
+    """
+    Clear cached files.
+
+    Args:
+        s3_path: Specific S3 path to clear, or None to clear all cache
+
+    Example:
+        >>> import libcachesim as lcs
+        >>> # Clear specific file
+        >>> lcs.clear_cache("s3://cache-datasets/trace1.lcs.zst")
+        >>> # Clear all cache
+        >>> lcs.clear_cache()
+    """
+    if s3_path and s3_path.startswith("s3://"):
+        parsed = urlparse(s3_path)
+        bucket = parsed.netloc
+        key = parsed.path.lstrip("/")
+        if bucket == _data_loader.bucket_name:
+            _data_loader.clear_cache(key)
+        else:
+            logger.warning(f"Cannot clear cache for different bucket: {bucket}")
+    else:
+        _data_loader.clear_cache(s3_path)
+
+
+def get_cache_size() -> int:
+    """
+    Get total size of cached files in bytes.
+
+    Returns:
+        Total cache size in bytes
+
+    Example:
+        >>> import libcachesim as lcs
+        >>> size_mb = lcs.get_cache_size() / (1024**2)
+        >>> print(f"Cache size: {size_mb:.2f} MB")
+    """
+    return _data_loader.get_cache_size()
+
+
+def list_cached_files() -> list[str]:
+    """
+    List all cached files.
+
+    Returns:
+        List of cached file paths
+
+    Example:
+        >>> import libcachesim as lcs
+        >>> files = lcs.list_cached_files()
+        >>> for file in files:
+        ...     print(file)
+    """
+    return _data_loader.list_cached_files()
+
+
+def get_data_loader(bucket_name: str = None) -> _DataLoader:
+    """Get data loader instance for a specific bucket or the global one."""
+    global _data_loader
+    if bucket_name is None or bucket_name == _data_loader.bucket_name:
+        return _data_loader
+    else:
+        return _DataLoader(bucket_name=bucket_name, cache_dir=_data_loader.cache_dir.parent)
diff --git a/libcachesim/data_loader.py b/libcachesim/data_loader.py
deleted file mode 100644
index f889364..0000000
--- a/libcachesim/data_loader.py
+++ /dev/null
@@ -1,118 +0,0 @@
-"""S3 Bucket data loader with local caching (HuggingFace-style)."""
-
-from __future__ import annotations
-
-import hashlib
-import logging
-import shutil
-from pathlib import Path
-from typing import Optional, Union
-from urllib.parse import quote
-
-logger = logging.getLogger(__name__)
-
-
-class DataLoader:
-    DEFAULT_BUCKET = "cache-datasets"
-    DEFAULT_CACHE_DIR = Path.home() / ".cache/libcachesim_hub"
-
-    def __init__(
-        self, bucket_name: str = DEFAULT_BUCKET, cache_dir: Optional[Union[str, Path]] = None, use_auth: bool = False
-    ):
-        self.bucket_name = bucket_name
-        self.cache_dir = Path(cache_dir) if cache_dir else self.DEFAULT_CACHE_DIR
-        self.use_auth = use_auth
-        self._s3_client = None
-        self._ensure_cache_dir()
-
-    def _ensure_cache_dir(self) -> None:
-        (self.cache_dir / self.bucket_name).mkdir(parents=True, exist_ok=True)
-
-    @property
-    def s3_client(self):
-        if self._s3_client is None:
-            try:
-                import boto3
-                from botocore.config import Config
-                from botocore import UNSIGNED
-
-                self._s3_client = boto3.client(
-                    "s3", config=None if self.use_auth else Config(signature_version=UNSIGNED)
-                )
-            except ImportError:
-                raise ImportError("Install boto3: pip install boto3")
-        return self._s3_client
-
-    def _cache_path(self, key: str) -> Path:
-        safe_name = hashlib.sha256(key.encode()).hexdigest()[:16] + "_" + quote(key, safe="")
-        return self.cache_dir / self.bucket_name / safe_name
-
-    def _download(self, key: str, dest: Path) -> None:
-        temp = dest.with_suffix(dest.suffix + ".tmp")
-        temp.parent.mkdir(parents=True, exist_ok=True)
-
-        try:
-            logger.info(f"Downloading s3://{self.bucket_name}/{key}")
-            obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
-            with open(temp, "wb") as f:
-                f.write(obj["Body"].read())
-            shutil.move(str(temp), str(dest))
-            logger.info(f"Saved to: {dest}")
-        except Exception as e:
-            if temp.exists():
-                temp.unlink()
-            raise RuntimeError(f"Download failed for s3://{self.bucket_name}/{key}: {e}")
-
-    def load(self, key: str, force: bool = False, mode: str = "rb") -> Union[bytes, str]:
-        path = self._cache_path(key)
-        if not path.exists() or force:
-            self._download(key, path)
-        with open(path, mode) as f:
-            return f.read()
-
-    def is_cached(self, key: str) -> bool:
-        return self._cache_path(key).exists()
-
-    def get_cache_path(self, key: str) -> Path:
-        return self._cache_path(key).as_posix()
-
-    def clear_cache(self, key: Optional[str] = None) -> None:
-        if key:
-            path = self._cache_path(key)
-            if path.exists():
-                path.unlink()
-                logger.info(f"Cleared: {path}")
-        else:
-            shutil.rmtree(self.cache_dir, ignore_errors=True)
-            logger.info(f"Cleared entire cache: {self.cache_dir}")
-
-    def list_cached_files(self) -> list[str]:
-        if not self.cache_dir.exists():
-            return []
-        return [str(p) for p in self.cache_dir.rglob("*") if p.is_file() and not p.name.endswith(".tmp")]
-
-    def get_cache_size(self) -> int:
-        return sum(p.stat().st_size for p in self.cache_dir.rglob("*") if p.is_file())
-
-    def list_s3_objects(self, prefix: str = "", delimiter: str = "/") -> dict:
-        """
-        List S3 objects and pseudo-folders under a prefix.
-
-        Args:
-            prefix: The S3 prefix to list under (like folder path)
-            delimiter: Use "/" to simulate folder structure
-
-        Returns:
-            A dict with two keys:
-            - "folders": list of sub-prefixes (folders)
-            - "files": list of object keys (files)
-        """
-        paginator = self.s3_client.get_paginator("list_objects_v2")
-        result = {"folders": [], "files": []}
-
-        for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix, Delimiter=delimiter):
-            # CommonPrefixes are like subdirectories
-            result["folders"].extend(cp["Prefix"] for cp in page.get("CommonPrefixes", []))
-            result["files"].extend(obj["Key"] for obj in page.get("Contents", []))
-
-        return result
diff --git a/libcachesim/trace_reader.py b/libcachesim/trace_reader.py
index d282a68..e593dbb 100644
--- a/libcachesim/trace_reader.py
+++ b/libcachesim/trace_reader.py
@@ -1,12 +1,15 @@
-"""Wrapper of Reader"""
+"""Wrapper of Reader with S3 support."""

 import logging
 from typing import overload, Union, Optional
 from collections.abc import Iterator
+from urllib.parse import urlparse

 from .protocols import ReaderProtocol
-
 from .libcachesim_python import TraceType, SamplerType, Request, ReaderInitParam, Reader, Sampler, ReadDirection
+from ._s3_cache import get_data_loader
+
+logger = logging.getLogger(__name__)


 class TraceReader(ReaderProtocol):
@@ -28,6 +31,10 @@ def __init__(
             self._reader = trace
             return

+        # Handle S3 URIs
+        if isinstance(trace, str) and trace.startswith("s3://"):
+            trace = self._resolve_s3_path(trace)
+
         if reader_init_params is None:
             reader_init_params = ReaderInitParam()

@@ -36,6 +43,74 @@
         self._reader = Reader(trace, trace_type, reader_init_params)

+    def _validate_s3_uri(self, s3_uri: str) -> tuple[str, str]:
+        """
+        Validate and parse S3 URI.
+
+        Args:
+            s3_uri: S3 URI like "s3://bucket/key"
+
+        Returns:
+            Tuple of (bucket, key)
+
+        Raises:
+            ValueError: If URI is invalid
+        """
+        parsed = urlparse(s3_uri)
+
+        if parsed.scheme != "s3":
+            raise ValueError(f"Invalid S3 URI scheme. Expected 's3', got '{parsed.scheme}': {s3_uri}")
+
+        if not parsed.netloc:
+            raise ValueError(f"Missing bucket name in S3 URI: {s3_uri}")
+
+        bucket = parsed.netloc
+        key = parsed.path.lstrip("/")
+
+        if not key:
+            raise ValueError(f"Missing key (object path) in S3 URI: {s3_uri}")
+
+        # Check for path traversal in the key part only
+        if ".." in key:
+            raise ValueError(f"S3 key contains path traversal patterns: {key}")
+
+        # Check for double slashes in the key part (after s3://)
+        if "//" in key:
+            raise ValueError(f"S3 key contains double slashes: {key}")
+
+        # Check for backslashes (not valid in URLs)
+        if "\\" in s3_uri:
+            raise ValueError(f"S3 URI contains backslashes: {s3_uri}")
+
+        return bucket, key
+
+    def _resolve_s3_path(self, s3_path: str) -> str:
+        """
+        Resolve S3 path to local cached file path.
+
+        Args:
+            s3_path: S3 URI like "s3://bucket/key"
+
+        Returns:
+            Local file path
+        """
+        try:
+            bucket, key = self._validate_s3_uri(s3_path)
+        except ValueError as e:
+            raise ValueError(f"Invalid S3 URI: {e}")
+
+        # Get data loader for this bucket
+        try:
+            loader = get_data_loader(bucket)
+        except ValueError as e:
+            raise ValueError(f"Invalid bucket name '{bucket}': {e}")
+
+        logger.info(f"Resolving S3 path: {s3_path}")
+        try:
+            return loader.get_cached_path(key)
+        except ValueError as e:
+            raise ValueError(f"Invalid S3 key '{key}': {e}")
+
     @property
     def n_read_req(self) -> int:
         return self._reader.n_read_req
diff --git a/src/libCacheSim b/src/libCacheSim
index 4a2627d..e8a7194 160000
--- a/src/libCacheSim
+++ b/src/libCacheSim
@@ -1 +1 @@
-Subproject commit 4a2627d4221cd49c0182cc6aa1ddd82f3006cf83
+Subproject commit e8a7194d861857d4c4481b53cf24a1bfd33df2ad
diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py
index 8712a8c..6f72fcf 100644
--- a/tests/test_analyzer.py
+++ b/tests/test_analyzer.py
@@ -1,4 +1,4 @@
-from libcachesim import TraceAnalyzer, TraceReader, DataLoader, AnalysisOption, AnalysisParam
+from libcachesim import TraceAnalyzer, TraceReader, AnalysisOption, AnalysisParam
 import os
 import pytest

@@ -9,13 +9,8 @@ def test_analyzer_common():
     """

     # Add debugging and error handling
-    loader = DataLoader()
-    loader.load("cache_dataset_oracleGeneral/2020_tencentBlock/1K/tencentBlock_1621.oracleGeneral.zst")
-    file_path = loader.get_cache_path(
-        "cache_dataset_oracleGeneral/2020_tencentBlock/1K/tencentBlock_1621.oracleGeneral.zst"
-    )
-
-    reader = TraceReader(file_path)
+    URI = "s3://cache-datasets/cache_dataset_oracleGeneral/2007_msr/msr_hm_0.oracleGeneral.zst"
+    reader = TraceReader(trace=URI)

     # For this specific small dataset (only 4 objects), configure analysis options more conservatively
     # to avoid bounds issues with the analysis modules
diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py
deleted file mode 100644
index 5aba6f5..0000000
--- a/tests/test_data_loader.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from libcachesim import DataLoader
-
-
-def test_data_loader_common():
-    loader = DataLoader()
-    loader.load("cache_dataset_oracleGeneral/2007_msr/msr_hm_0.oracleGeneral.zst")
-    path = loader.get_cache_path("cache_dataset_oracleGeneral/2007_msr/msr_hm_0.oracleGeneral.zst")
-    filles = loader.list_s3_objects("cache_dataset_oracleGeneral/2007_msr/")
diff --git a/tests/test_reader.py b/tests/test_reader.py
index 688217a..a49466b 100644
--- a/tests/test_reader.py
+++ b/tests/test_reader.py
@@ -7,8 +7,8 @@ import pytest
 import tempfile
 import os

-from libcachesim import TraceReader, SyntheticReader, DataLoader
-from libcachesim.libcachesim_python import TraceType, SamplerType, Request, ReqOp, ReaderInitParam, Sampler
+from libcachesim import TraceReader, SyntheticReader
+from libcachesim.libcachesim_python import TraceType, SamplerType, Request, ReaderInitParam, Sampler


 class TestSyntheticReader:
@@ -331,6 +331,14 @@ def test_invalid_sampling_ratio(self):
         finally:
             os.unlink(temp_file)

+    def test_trace_reader_s3(self):
+        """Test trace reader with S3"""
+        URI = "s3://cache-datasets/cache_dataset_oracleGeneral/2007_msr/msr_hm_0.oracleGeneral.zst"
+        reader = TraceReader(trace=URI)
+        for req in reader:
+            assert req.valid == True
+            break
+

 class TestReaderCompatibility:
     """Test compatibility between different reader types"""
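
Below is a minimal sketch of how the new `s3://` workflow and the cache-management helpers introduced in `libcachesim/_s3_cache.py` could be exercised together. It assumes these helpers are re-exported at the package level as their docstrings suggest (`lcs.get_cache_dir`, `lcs.get_cache_size`, `lcs.clear_cache`, `lcs.list_cached_files`); if they are not, they can be imported from `libcachesim._s3_cache` instead.

```python
import libcachesim as lcs

# Opening an s3:// URI transparently downloads the object into the local hub cache
# (~/.cache/libcachesim/hub by default, overridable via the LCS_HUB_CACHE env var).
URI = "s3://cache-datasets/cache_dataset_oracleGeneral/2007_msr/msr_hm_0.oracleGeneral.zst"
reader = lcs.TraceReader(trace=URI)

# Assumption: the helpers below are exported as lcs.*; otherwise use
# `from libcachesim._s3_cache import get_cache_dir, get_cache_size, clear_cache, list_cached_files`.
print(lcs.get_cache_dir())                         # current hub cache directory
print(f"{lcs.get_cache_size() / 1024**2:.2f} MB")  # total size of cached traces
for path in lcs.list_cached_files():               # every cached file on disk
    print(path)

lcs.clear_cache(URI)  # drop a single cached object by its s3:// URI
lcs.clear_cache()     # or wipe the whole cache
```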