diff --git a/src/quantem/core/datastructures/vector.py b/src/quantem/core/datastructures/vector.py
index 9bc513a4..bc528bca 100644
--- a/src/quantem/core/datastructures/vector.py
+++ b/src/quantem/core/datastructures/vector.py
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from typing import (
     Any,
     List,
@@ -95,6 +96,14 @@ class Vector(AutoSerialize):
     # Set field data from flattened array
     v['field2'].set_flattened(new_values)  # Must match total length
 
+    Cell Access and Indexing:
+    -------------------------
+    # Access a specific cell
+    cell_data = v[0]  # Returns a CellView object for the first cell with all fields
+
+    # Access a specific field in a cell
+    cell_field_data = v[0]['field0']  # 1-D array of the field0 values in the first cell
+
     Advanced Operations:
     -------------------
     # Complex field calculations
@@ -161,7 +170,7 @@ def __init__(
     @classmethod
     def from_shape(
         cls,
-        shape: Tuple[int, ...],
+        shape: ArrayLike,
         num_fields: Optional[int] = None,
         fields: Optional[List[str]] = None,
         units: Optional[List[str]] = None,
@@ -172,25 +181,41 @@ def from_shape(
 
         Parameters
         ----------
-        shape : Tuple[int, ...]
-            The shape of the vector (dimensions)
-        num_fields : Optional[int]
-            Number of fields in the vector
-        name : Optional[str]
-            Name of the vector
-        fields : Optional[List[str]]
-            List of field names
-        units : Optional[List[str]]
-            List of units for each field
+        shape
+            The fixed indexed dimensions of the ragged vector.
+            Accepts an int (1D vector), a sequence of ints, or a NumPy array of ints;
+            an empty sequence gives a 0D vector.
+        num_fields
+            Number of fields in the vector (ignored if `fields` is provided).
+        fields
+            List of field names (mutually exclusive with `num_fields`).
+        units
+            Unit strings per field. If None, defaults are used.
+        name
+            Optional name.
 
         Returns
         -------
         Vector
-            A new Vector instance
+            A new Vector instance.
         """
-        validated_shape = validate_shape(shape)
+        # Normalize 'shape' to a tuple[int, ...] before handing it to validate_shape.
+        if isinstance(shape, (int, np.integer)):
+            shape_tuple: Tuple[int, ...] = (int(shape),)
+        elif isinstance(shape, np.ndarray):
+            # ndarray is not a collections.abc.Sequence; atleast_1d also covers 0-d arrays.
+            shape_tuple = tuple(int(s) for s in np.atleast_1d(shape).tolist())
+        elif isinstance(shape, Sequence) and not isinstance(shape, (str, bytes)):
+            # str/bytes are Sequences too, but "23" must not silently become (2, 3).
+            shape_tuple = tuple(int(s) for s in shape)
+        else:
+            raise TypeError(f"Unsupported type for shape: {type(shape)}")
+
+        # validate_shape expects a tuple of ints.
+        validated_shape = validate_shape(shape_tuple)
         ndim = len(validated_shape)
 
+        # --- Fields / num_fields handling ---
         if fields is not None:
             validated_fields = validate_fields(fields)
             validated_num_fields = len(validated_fields)
@@ -446,16 +471,20 @@ def __getitem__(
             np.asarray(i) if isinstance(i, (list, np.ndarray)) else i for i in normalized
         )
 
-        # Check if we should return a numpy array (all indices are integers)
-        return_np = all(isinstance(i, (int, np.integer)) for i in idx_converted[: len(self.shape)])
+        # Check if we should return a single-cell view (all indices are integers)
+        return_cell = all(
+            isinstance(i, (int, np.integer)) for i in idx_converted[: len(self.shape)]
+        )
         if len(idx_converted) < len(self.shape):
-            return_np = False
+            return_cell = False
 
-        if return_np:
-            view = self._data
-            for i in idx_converted:
-                view = view[i]
-            return cast(NDArray[Any], view)
+        if return_cell:
+            # Return a CellView so v[0]['x'] works; np.asarray(v[0]) still gives the
+            # cell's ndarray.
+            indices_tuple = tuple(int(i) for i in idx_converted[: len(self.shape)])
+            cell = _CellView(self, indices_tuple)
+            _ = cell.array  # resolve now so an out-of-range index raises IndexError here
+            return cell
 
         # Handle fancy indexing and slicing
         def get_indices(dim_idx: Any, dim_size: int) -> np.ndarray:
@@ -1024,3 +1053,47 @@ def __getitem__(
     def __array__(self) -> np.ndarray:
         """Convert to numpy array when needed."""
         return self.flatten()
+
+
+class _CellView:
+    """
+    View over a single Vector cell (fixed indices over the indexed dims).
+
+    A string key selects a field, e.g. v[0]['x'] -> 1D array of 'x' for that cell.
+    Any other key (int, slice, tuple, ...) is forwarded to the cell's ndarray, so
+    indexing, len(), iteration and np.asarray() keep working as they did when v[0]
+    returned a bare ndarray. Arithmetic is not forwarded: use np.asarray(cell).
+    """
+
+    def __init__(self, vector: "Vector", indices: Tuple[int, ...]) -> None:
+        self.vector = vector
+        self.indices = indices  # tuple of ints, one per indexed dimension
+
+    @property
+    def array(self) -> NDArray:
+        """The cell's array, looked up in the parent Vector on every access."""
+        ref = self.vector._data
+        for i in self.indices:
+            ref = ref[i]
+        return ref  # shape: (rows, num_fields)
+
+    def __array__(self, dtype: Any = None, copy: Optional[bool] = None) -> np.ndarray:
+        # NumPy calls __array__(dtype) when a dtype is requested and, from NumPy 2,
+        # may also pass copy=...; a zero-argument signature raises TypeError there.
+        arr = np.asarray(self.array, dtype=dtype)
+        return arr.copy() if copy else arr
+
+    def __len__(self) -> int:
+        return len(self.array)
+
+    def __iter__(self):
+        return iter(self.array)
+
+    def __getitem__(self, field_name: Any) -> NDArray:
+        """Return one field (str key) or index the cell's array (any other key)."""
+        if not isinstance(field_name, str):
+            return self.array[field_name]
+        if field_name not in self.vector._fields:
+            raise KeyError(f"Field '{field_name}' not found.")
+        j = self.vector._fields.index(field_name)
+        return self.array[:, j]
diff --git a/tests/datastructures/test_vector.py b/tests/datastructures/test_vector.py
index e2085b76..0322331b 100644
--- a/tests/datastructures/test_vector.py
+++ b/tests/datastructures/test_vector.py
@@ -431,3 +431,57 @@ def test_set_data_methods(self):
 
         with pytest.raises(ValueError):
             v.set_data(np.array([[1.0]]), 0, 0)  # Wrong number of fields
+
+    def test_cell_view(self):
+        """Test the _CellView class for cell-level access."""
+
+        # Create test data
+        data = [
+            np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
+            np.array([[7.0, 8.0, 9.0]]),
+            np.array([[10.0, 11.0, 12.0], [13.0, 14.0, 15.0], [16.0, 17.0, 18.0]]),
+        ]
+
+        # Create a Vector from the data
+        v = Vector.from_data(
+            data=data,
+            fields=["field0", "field1", "field2"],
+            name="test_vector",
+            units=["unit0", "unit1", "unit2"],
+        )
+
+        # Test cell view access
+        cell_data = v[0]
+        cell_data_array = np.asarray(cell_data)
+        np.testing.assert_array_equal(cell_data_array, data[0])
+        assert isinstance(cell_data_array, np.ndarray)
+
+        # A requested dtype reaches __array__ as a positional argument
+        assert np.asarray(cell_data, dtype=np.float32).dtype == np.float32
+
+        # Array-style access that worked on the old ndarray return value
+        assert len(cell_data) == 2
+        np.testing.assert_array_equal(cell_data[1], data[0][1])
+        np.testing.assert_array_equal(cell_data[:, 0], data[0][:, 0])
+
+        # Test field access through cell view
+        cell0_field0 = cell_data["field0"]
+        cell0_field0_direct = v[0]["field0"]
+        cell0_field0_data = np.array([1.0, 4.0])
+        np.testing.assert_array_equal(cell0_field0, cell0_field0_direct)
+        np.testing.assert_array_equal(cell0_field0, cell0_field0_data)
+
+        with pytest.raises(KeyError):
+            cell_data["nonexistent_field"]
+        with pytest.raises(IndexError):
+            v[3]
+        with pytest.raises(IndexError):
+            v[3]["field0"]
+
+    def test_from_shape_array_like(self):
+        """from_shape accepts an int, a sequence or an ndarray, but not a str."""
+        assert tuple(Vector.from_shape(3, num_fields=2).shape) == (3,)
+        assert tuple(Vector.from_shape([4, 3], num_fields=2).shape) == (4, 3)
+        assert tuple(Vector.from_shape(np.array([4, 3]), num_fields=2).shape) == (4, 3)
+        with pytest.raises(TypeError):
+            Vector.from_shape("43", num_fields=2)