Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 77 additions & 21 deletions src/quantem/core/datastructures/vector.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import (
Any,
List,
Expand Down Expand Up @@ -95,6 +96,14 @@ class Vector(AutoSerialize):
# Set field data from flattened array
v['field2'].set_flattened(new_values) # Must match total length

Cell Access and Indexing:
-------------------------
# Access a specific cell
cell_data = v[0] # Returns a CellView object for the first cell with all fields

# Access a specific field in a cell
cell_field_data = v[0]['field0'] # Returns a 1-D array of the field0 value for the first row

Advanced Operations:
-------------------
# Complex field calculations
Expand Down Expand Up @@ -161,7 +170,7 @@ def __init__(
@classmethod
def from_shape(
cls,
shape: Tuple[int, ...],
shape: ArrayLike,
num_fields: Optional[int] = None,
fields: Optional[List[str]] = None,
units: Optional[List[str]] = None,
Expand All @@ -172,25 +181,39 @@ def from_shape(

Parameters
----------
shape : Tuple[int, ...]
The shape of the vector (dimensions)
num_fields : Optional[int]
Number of fields in the vector
name : Optional[str]
Name of the vector
fields : Optional[List[str]]
List of field names
units : Optional[List[str]]
List of units for each field
shape
The fixed indexed dimensions of the ragged vector.
Accepts any array-like input that can be converted to a tuple of integers
Including single integers for 1D vectors and empty shapes for 0D vectors.
num_fields
Number of fields in the vector (ignored if `fields` is provided).
fields
List of field names (mutually exclusive with `num_fields`).
units
Unit strings per field. If None, defaults are used.
name
Optional name.

Returns
-------
Vector
A new Vector instance
A new Vector instance.
"""
validated_shape = validate_shape(shape)
# --- Normalize 'shape' to a tuple[int, ...] to satisfy validate_shape ---
if isinstance(shape, (int, np.integer)):
shape_tuple: Tuple[int, ...] = (int(shape),)
elif isinstance(shape, tuple):
shape_tuple = tuple(int(s) for s in shape)
elif isinstance(shape, Sequence):
shape_tuple = tuple(int(s) for s in shape)
else:
raise TypeError(f"Unsupported type for shape: {type(shape)}")

# validate_shape expects a tuple and applies your project-specific checks
validated_shape = validate_shape(shape_tuple)
ndim = len(validated_shape)

# --- Fields / num_fields handling (unchanged) ---
if fields is not None:
validated_fields = validate_fields(fields)
validated_num_fields = len(validated_fields)
Expand Down Expand Up @@ -446,16 +469,18 @@ def __getitem__(
np.asarray(i) if isinstance(i, (list, np.ndarray)) else i for i in normalized
)

# Check if we should return a numpy array (all indices are integers)
return_np = all(isinstance(i, (int, np.integer)) for i in idx_converted[: len(self.shape)])
# Check if we should return a single-cell view (all indices are integers)
return_cell = all(
isinstance(i, (int, np.integer)) for i in idx_converted[: len(self.shape)]
)
if len(idx_converted) < len(self.shape):
return_np = False
return_cell = False

if return_np:
view = self._data
for i in idx_converted:
view = view[i]
return cast(NDArray[Any], view)
if return_cell:
# Return a CellView so atoms[0]['x'] works;
# still behaves like ndarray via __array__ when used numerically.
indices_tuple = tuple(int(i) for i in idx_converted[: len(self.shape)])
return _CellView(self, indices_tuple)

# Handle fancy indexing and slicing
def get_indices(dim_idx: Any, dim_size: int) -> np.ndarray:
Expand Down Expand Up @@ -1024,3 +1049,34 @@ def __getitem__(
def __array__(self) -> np.ndarray:
"""Convert to numpy array when needed."""
return self.flatten()


class _CellView:
"""
View over a single Vector cell (fixed indices over the indexed dims).
Supports item access by field name, e.g., v[0]['x'] -> 1D array for that cell.
Behaves like a numpy array via __array__ for backward compatibility.
"""

def __init__(self, vector: "Vector", indices: Tuple[int, ...]) -> None:
self.vector = vector
self.indices = indices # tuple of ints, one per indexed dimension

@property
def array(self) -> NDArray:
ref = self.vector._data
for i in self.indices:
ref = ref[i]
return ref # shape: (rows, num_fields)

def __array__(self) -> np.ndarray:
# Allows numpy to transparently consume this as an ndarray
return self.array

def __getitem__(self, field_name: str) -> NDArray:
if not isinstance(field_name, str):
raise TypeError("Use a field name string, e.g. cell['x']")
if field_name not in self.vector._fields:
raise KeyError(f"Field '{field_name}' not found.")
j = self.vector._fields.index(field_name)
return self.array[:, j]
36 changes: 36 additions & 0 deletions tests/datastructures/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,3 +431,39 @@ def test_set_data_methods(self):

with pytest.raises(ValueError):
v.set_data(np.array([[1.0]]), 0, 0) # Wrong number of fields

def test_cell_view(self):
"""Test the _CellView class for cell-level access."""

# Create test data
data = [
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
np.array([[7.0, 8.0, 9.0]]),
np.array([[10.0, 11.0, 12.0], [13.0, 14.0, 15.0], [16.0, 17.0, 18.0]]),
]

# Create a Vector from the data
v = Vector.from_data(
data=data,
fields=["field0", "field1", "field2"],
name="test_vector",
units=["unit0", "unit1", "unit2"],
)

# Test cell view access
cell_data = v[0]
cell_data_array = cell_data.__array__()
np.testing.assert_array_equal(cell_data_array, data[0])
assert isinstance(cell_data_array, np.ndarray)

# Test field access through cell view
cell0_field0 = cell_data["field0"]
cell0_field0_direct = v[0]["field0"]
cell0_field0_data = np.array([1.0, 4.0])
np.testing.assert_array_equal(cell0_field0, cell0_field0_direct)
np.testing.assert_array_equal(cell0_field0, cell0_field0_data)

with pytest.raises(KeyError):
cell_data["nonexistent_field"]
with pytest.raises(IndexError):
v[3]["field0"]