diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index da81d372..22e50b01 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -1,14 +1,20 @@ name: PyNeon CI -on: [push, pull_request] +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] jobs: - ruff-format: + format: runs-on: ubuntu-latest permissions: contents: write steps: - uses: actions/checkout@v4 + with: + fetch-depth: 0 # to be able to push changes - name: Set up Python uses: actions/setup-python@v4 @@ -26,6 +32,7 @@ jobs: run: ruff format . - name: Commit changes if any + if: github.event_name == 'push' # only push changes on push events run: | git config --local user.email "github-actions[bot]@users.noreply.github.com" git config --local user.name "github-actions[bot]" @@ -33,8 +40,42 @@ jobs: git commit -m "Format code with isort and ruff" || echo "No changes to commit" git push + tests: + runs-on: ${{ matrix.os }} + env: + PYTHONFAULTHANDLER: "1" + needs: format # waits until formatting is done + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.10", "3.11", "3.12", "3.13"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install test dependencies + run: pip install .[dev] + + - name: Run tests + if: matrix.os != 'windows-latest' + run: pytest tests -p no:cacheprovider -p no:faulthandler -p no:unraisableexception + + - name: Run tests (Windows, disable MSMF) + if: matrix.os == 'windows-latest' + env: + OPENCV_VIDEOIO_PRIORITY_MSMF: "0" + run: pytest tests -p no:cacheprovider -p no:faulthandler -p no:unraisableexception + build-docs: runs-on: ubuntu-latest + needs: format # waits until formatting is done + if: github.ref == 'refs/heads/main' && github.event_name == 'push' steps: - uses: actions/checkout@v4 @@ -44,7 +85,7 @@ jobs: python-version: "3.13" - name: Install 
Pandoc - run: sudo apt-get install pandoc + run: sudo apt-get install -y pandoc - name: Install docs dependencies run: pip install .[doc] @@ -65,7 +106,6 @@ jobs: - name: Deploy (GitHub Pages) uses: peaceiris/actions-gh-pages@v3 - if: github.ref == 'refs/heads/main' with: github_token: ${{ secrets.GITHUB_TOKEN }} publish_dir: build/html diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml deleted file mode 100644 index 83997075..00000000 --- a/.github/workflows/tests.yml +++ /dev/null @@ -1,25 +0,0 @@ -name: Run tests - -on: [push, pull_request] - -jobs: - test: - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.10", "3.11", "3.12", "3.13"] - - steps: - - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - - name: Install test dependencies - run: pip install .[dev] - - - name: Run tests - run: pytest tests \ No newline at end of file diff --git a/.gitignore b/.gitignore index b5808b9d..e5ea2cf5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ # Test data data/ tests/outputs/ +source/tutorials/export/ # Ruff .ruff_cache/ diff --git a/README.md b/README.md index d6d0cda4..3dc61078 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ ![GitHub License](https://img.shields.io/github/license/ncc-brain/PyNeon?style=plastic) ![Website](https://img.shields.io/website?url=https%3A%2F%2Fncc-brain.github.io%2FPyNeon%2F&up_message=online&style=plastic&label=Documentation) +[![PyNeon CI](https://github.com/ncc-brain/PyNeon/actions/workflows/main.yml/badge.svg)](https://github.com/ncc-brain/PyNeon/actions/workflows/main.yml) # PyNeon @@ -14,14 +15,18 @@ PyNeon supports both **native** (data stored in the companion device) and [**Pup Documentation for PyNeon is available at https://ncc-brain.github.io/PyNeon/ which includes detailed references for classes and functions, as well as step-by-step tutorials presented as Jupyter 
notebooks. +We also created a few sample datasets containing short Neon recordings for testing and tutorial purposes. These datasets can be found on [OSF](https://doi.org/10.17605/OSF.IO/3N85H). We also provide a utility function `get_sample_data()` to download these sample datasets directly from PyNeon. + ## Key Features -- [(Tutorial)](https://ncc-brain.github.io/PyNeon/tutorials/read_recording.html) Easy API for reading in datasets, recordings, or individual modalities of data. +- Easy API for reading in datasets, recordings, or individual modalities of data. + - [Tutorial](https://ncc-brain.github.io/PyNeon/tutorials/read_recording_cloud.html) for reading data in Pupil Cloud format + - [Tutorial](https://ncc-brain.github.io/PyNeon/tutorials/read_recording_native.html) for reading data in native format - [(Tutorial)](https://ncc-brain.github.io/PyNeon/tutorials/interpolate_and_concat.html) Various preprocessing functions, including data cropping, interpolation, concatenation, etc. - [(Tutorial)](https://ncc-brain.github.io/PyNeon/tutorials/pupil_size_and_epoching.html) Flexible epoching of data for trial-based analysis. - [(Tutorial)](https://ncc-brain.github.io/PyNeon/tutorials/video.html) Methods for working with scene video, including scanpath estimation and AprilTags-based mapping. -- [(Tutorial)](https://ncc-brain.github.io/PyNeon/tutorials/export_to_bids.html) Exportation to [Motion-BIDS](https://www.nature.com/articles/s41597-024-03559-8) (and forthcoming Eye-Tracking-BIDS) format for interoperability across the cognitive neuroscience community. +- [(Tutorial)](https://ncc-brain.github.io/PyNeon/tutorials/export_to_bids.html) Exportation to [Motion-BIDS](https://doi.org/10.1038/s41597-024-03559-8) and [Eye-Tracking-BIDS](https://doi.org/10.64898/2026.02.03.703514) formats for interoperability across the cognitive neuroscience community. 
## Installation diff --git a/pyneon/__init__.py b/pyneon/__init__.py index ae5173bf..44d625b0 100644 --- a/pyneon/__init__.py +++ b/pyneon/__init__.py @@ -1,17 +1,14 @@ # ruff: noqa: E402 __version__ = "0.0.1" -from typeguard import install_import_hook - -install_import_hook("pyneon") - from .dataset import Dataset -from .epochs import Epochs, construct_times_df, events_to_times_df +from .epochs import Epochs, construct_epochs_info, events_to_epochs_info from .events import Events from .recording import Recording from .stream import Stream from .utils import * -from .video import Video +from .video import Video, find_homographies +from .vis import plot_marker_layout __all__ = [ "Dataset", @@ -20,6 +17,8 @@ "Events", "Epochs", "Video", - "construct_times_df", - "events_to_times_df", + "plot_marker_layout", + "find_homographies", + "construct_epochs_info", + "events_to_epochs_info", ] diff --git a/pyneon/dataset.py b/pyneon/dataset.py index 93b7811b..c45d22d3 100644 --- a/pyneon/dataset.py +++ b/pyneon/dataset.py @@ -8,9 +8,10 @@ class Dataset: """ - Holder for multiple recordings. It reads from a directory containing a multiple - recordings downloaded from Pupil Cloud with the **Timeseries CSV** or - **Timeseries CSV and Scene Video** option. For example, a dataset with 2 recordings + Container for multiple recordings. Reads from a directory containing multiple + recordings. + + For example, a dataset with 2 recordings downloaded from Pupil Cloud would have the following folder structure: .. code-block:: text @@ -28,48 +29,76 @@ class Dataset: ├── enrichment_info.txt └── sections.csv - Individual recordings will be read into :class:`pyneon.Recording` objects based on - ``sections.csv``. They are accessible through the ``recordings`` attribute. + Or a dataset with multiple native recordings: + + .. code-block:: text + + dataset_dir/ + ├── recording_dir_1/ + │ ├── info.json + │ ├── blinks ps1.raw + │ ├── blinks ps1.time + │ ├── blinks.dtype + │ └── ... 
+ └── recording_dir_2/ + ├── info.json + ├── blinks ps1.raw + ├── blinks ps1.time + ├── blinks.dtype + └── ... + + Individual recordings will be read into :class:`Recording` instances + (based on ``sections.csv``, if available) and accessible through the + ``recordings`` attribute. Parameters ---------- dataset_dir : str or pathlib.Path Path to the directory containing the dataset. - custom : bool, optional - Whether to expect a custom dataset structure. If ``False``, the dataset - is expected to follow the standard Pupil Cloud dataset structure with a - ``sections.csv`` file. If True, every directory in ``dataset_dir`` is - considered a recording directory, and the ``sections`` attribute is - constructed from the ``info`` of recordings found. - Defaults to ``False``. Attributes ---------- dataset_dir : pathlib.Path Path to the directory containing the dataset. recordings : list of Recording - List of :class:`pyneon.Recording` objects for each recording in the dataset. + List of :class:`Recording` instances for each recording in the dataset. sections : pandas.DataFrame DataFrame containing the sections of the dataset. 
+ Examples + -------- + >>> from pyneon import Dataset + >>> dataset = Dataset("path/to/dataset") + >>> print(dataset) + + Dataset | 2 recordings + + >>> rec = dataset.recordings[0] + >>> print(rec) + + Data format: cloud + Recording ID: 56fcec49-d660-4d67-b5ed-ba8a083a448a + Wearer ID: 028e4c69-f333-4751-af8c-84a09af079f5 + Wearer name: Pilot + Recording start time: 2025-12-18 17:13:49.460000 + Recording duration: 8235000000 ns (8.235 s) """ - def __init__(self, dataset_dir: str | Path, custom: bool = False): + def __init__(self, dataset_dir: str | Path): dataset_dir = Path(dataset_dir) if not dataset_dir.is_dir(): raise FileNotFoundError(f"Directory not found: {dataset_dir}") - self.dataset_dir = dataset_dir - self.recordings = list() + self.dataset_dir: Path = dataset_dir + self.recordings: list[Recording] = list() - if not custom: - sections_path = dataset_dir.joinpath("sections.csv") - if not sections_path.is_file(): - raise FileNotFoundError(f"sections.csv not found in {dataset_dir}") - self.sections = pd.read_csv(sections_path) + sections_path = dataset_dir / "sections.csv" + if sections_path.is_file(): + self.sections = pd.read_csv(sections_path) recording_ids = self.sections["recording id"] + # Assert if recording IDs are correct for rec_id in recording_ids: rec_id_start = rec_id.split("-")[0] rec_dir = [ @@ -104,44 +133,44 @@ def __init__(self, dataset_dir: str | Path, custom: bool = False): RuntimeWarning, ) - # Rebuild a `sections` DataFrame from the Recording objects + # Rebuild a `sections` DataFrame from the Recording instances sections = [] - for i, rec in enumerate(self.recordings): + for rec in self.recordings: sections.append( { - "section id": i, + "section id": None, "recording id": rec.recording_id, - "recording name": rec.recording_id, - "wearer id": rec.info["wearer_id"], - "wearer name": rec.info["wearer_name"], + "recording name": None, + "wearer id": rec.info.get("wearer_id", None), + "wearer name": rec.info.get("wearer_name", None), 
"section start time [ns]": rec.start_time, - "section end time [ns]": rec.start_time + rec.info["duration"], + "section end time [ns]": rec.start_time + + rec.info.get("duration", 0), } ) self.sections = pd.DataFrame(sections) def __repr__(self): + """Return a string representation of the Dataset. + + Returns + ------- + str + Summary showing the number of recordings. + """ return f"Dataset | {len(self.recordings)} recordings" def __len__(self): + """Return the number of recordings in the dataset. + + Returns + ------- + int + Number of recordings. + """ return len(self.recordings) def __getitem__(self, index: int) -> Recording: """Get a Recording by index.""" return self.recordings[index] - - def load_enrichment(self, enrichment_dir: str | Path): - """ - Load enrichment information from an enrichment directory. The directory must - contain an enrichment_info.txt file. Enrichment data will be parsed for each - recording ID and added to Recording object in the dataset. - - The method is currently being developed and is not yet implemented. - - Parameters - ---------- - enrichment_dir : str or pathlib.Path - Path to the directory containing the enrichment information. - """ - raise NotImplementedError("Enrichment loading is not yet implemented.") diff --git a/pyneon/epochs.py b/pyneon/epochs.py index e47dec49..88774d31 100644 --- a/pyneon/epochs.py +++ b/pyneon/epochs.py @@ -1,6 +1,7 @@ -import warnings +from functools import cached_property from numbers import Number from typing import Literal, Optional +from warnings import warn import matplotlib.pyplot as plt import numpy as np @@ -8,49 +9,28 @@ from .events import Events from .stream import Stream +from .utils.doc_decorators import fill_doc +from .utils.variables import circular_columns from .vis import plot_epochs -def _check_overlap(times_df: pd.DataFrame) -> bool: - """ - Emits warnings if any adjacent epochs overlap in time. 
- """ - times_df = times_df.sort_values("t_ref") - overlap = False - overlap_epochs = [] - for i in range(1, times_df.shape[0]): - # Check if the current epoch overlaps with the previous epoch - if ( - times_df["t_ref"].iloc[i] - times_df["t_before"].iloc[i] - < times_df["t_ref"].iloc[i - 1] + times_df["t_after"].iloc[i - 1] - ): - overlap_epochs.append((i - 1, i)) - overlap = True - if overlap: - warnings.warn( - f"The following epochs overlap in time:\n{overlap_epochs}", RuntimeWarning - ) - return overlap - - +@fill_doc class Epochs: """ - Class to create and manage epochs in the data streams. + Class to create and analyze epochs in the data streams. Parameters ---------- source : Stream or Events - Data to create epochs from. - times_df : pandas.DataFrame, shape (n_epochs, 4), optional - DataFrame containing epoch information with the following columns: - - ``t_ref``: Reference time of the epoch, in nanoseconds.\n - ``t_before``: Time before the reference time to start the epoch, in nanoseconds.\n - ``t_after``: Time after the reference time to end the epoch, in nanoseconds.\n - ``description``: Description or label associated with the epoch. + Data to create epochs from. Can be either a :class:`Stream` or + a :class:`Events` instance. + %(epochs_info)s Must not have empty values. + See :func:`events_to_epochs_info` or :func:`construct_epochs_info` + for helper functions to create this DataFrame. 
+ Notes ----- An epoch spans the temporal range of ``t_ref - t_before`` to ``t_ref + t_after`` as shown below: @@ -63,98 +43,235 @@ class Epochs: Attributes ---------- - epochs : pandas.DataFrame - DataFrame containing epoch information with the following columns: - - ``t_ref`` (int64): Reference time of the epoch, in nanoseconds.\n - ``t_before`` (int64): Time before the reference time to start the epoch, in nanoseconds.\n - ``t_after`` (int64): Time after the reference time to end the epoch, in nanoseconds.\n - ``description`` (str): Description or label associated with the epoch.\n - ``data`` (object): DataFrame containing the data for each epoch. - data : pandas.DataFrame - Annotated data with epoch information. In addition to the original data columns, - the following columns are added: - - ``epoch index`` (Int32): ID of the epoch the data belongs to.\n - ``epoch time`` (Int64): Time relative to the epoch reference time, in nanoseconds.\n - ``epoch description`` (str): Description or label associated with the epoch. - - If epochs overlap, data annotations are always overwritten by the latest epoch. + %(epochs_info)s + + ======= ================================ + Column Description + ======= ================================ + t_start Start time of the epoch (``t_ref - t_before``). + t_end End time of the epoch (``t_ref + t_after``). + ======= ================================ + + source : Stream or Events + The source data used to create epochs. 
""" - def __init__(self, source: Stream | Events, times_df: pd.DataFrame): - if times_df.isnull().values.any(): - raise ValueError("times_df should not have any empty values") + def __init__(self, source: Stream | Events, epochs_info: pd.DataFrame): + if epochs_info.empty or epochs_info.isnull().values.any(): + raise ValueError("epochs_info must not be empty or contain NaN values.") + + epochs_info = epochs_info.sort_values("t_ref").reset_index(drop=True) + epochs_info.index.name = "epoch index" + epochs_info["t_start"] = epochs_info["t_ref"] - epochs_info["t_before"] + epochs_info["t_end"] = epochs_info["t_ref"] + epochs_info["t_after"] - # Sort by t_ref - assert times_df.shape[0] > 0, "times_df must have at least one row" - times_df = times_df.sort_values("t_ref").reset_index(drop=True) # Set columns to appropriate data types (check if columns are present along the way) - times_df = times_df.astype( + epochs_info = epochs_info.astype( { "t_ref": "int64", "t_before": "int64", "t_after": "int64", + "t_start": "int64", + "t_end": "int64", "description": "str", } ) + self.epochs_info: pd.DataFrame = epochs_info + self.source: Stream | Events = source.copy() + self._check_overlap() - if isinstance(source, Stream): - self.source_class = Stream - self.is_uniformly_sampled = source.is_uniformly_sampled - self.sf = source.sampling_freq_effective - elif isinstance(source, Events): - self.source_class = Events - self.is_uniformly_sampled = None - self.sf = None + def __len__(self): + return self.epochs_info.shape[0] + + def _check_overlap(self) -> list[tuple[int, int] | None]: + overlap_epochs = [] + for i in range(1, self.epochs_info.shape[0]): + # Check if the current epoch overlaps with the previous epoch + if ( + self.epochs_info["t_ref"].iloc[i] - self.epochs_info["t_before"].iloc[i] + < self.epochs_info["t_ref"].iloc[i - 1] + + self.epochs_info["t_after"].iloc[i - 1] + ): + overlap_epochs.append((i - 1, i)) + if overlap_epochs: + warn( + f"The following epochs overlap 
in time:\n{overlap_epochs}", + RuntimeWarning, + ) + return overlap_epochs - # Create epochs - self.epochs, self.data = _create_epochs(source, times_df) + @cached_property + def epochs_dict(self) -> dict[int, Stream | Events | None]: + """ + Dictionary of epochs indexed by epoch index. Each epoch contains + data cropped from the source between ``t_start`` and ``t_end``. + If no data is found for an epoch, its value is ``None``. - def __len__(self): - return self.epochs.shape[0] + Returns + ------- + dict of int to Stream or Events or None + Dictionary mapping epoch indices to their corresponding data. + """ + epochs = {} + empty_epochs = [] + for epoch_index in self.epochs_info.index: + t_ref = self.epochs_info.at[epoch_index, "t_ref"] + t_start = self.epochs_info.at[epoch_index, "t_start"] + t_end = self.epochs_info.at[epoch_index, "t_end"] + try: + epoch = self.source.crop(t_start, t_end, by="timestamp", inplace=False) + ts = epoch.ts if isinstance(epoch, Stream) else epoch.start_ts + epoch.data["epoch time [ns]"] = ts - t_ref + epochs[int(epoch_index)] = epoch + except ValueError: + empty_epochs.append(int(epoch_index)) + epochs[int(epoch_index)] = None + if empty_epochs: + warn(f"No data found for epoch(s): {empty_epochs}.", RuntimeWarning) + return epochs + + @property + def empty_epochs(self) -> list[int]: + """Indices of epochs that contain no data. + + Returns + ------- + list of int + List of epoch indices for which no data was found. + """ + return [ + int(epoch_index) + for epoch_index, epoch in self.epochs_dict.items() + if epoch is None + ] + + def annotate_source(self) -> pd.DataFrame: + """ + Create index-wise annotations of epoch indices for the source data. + + Returns + ------- + pandas.DataFrame + DataFrame with index matching the source data indices and a column + "epoch index" containing lists of epoch indices that include + each data point. 
+ """ + source = self.source + epochs_info = self.epochs_info + + # Timestamps from the source + ts = source.ts if isinstance(source, Stream) else source.start_ts + source_index = source.data.index + annot = {i: [] for i in source_index} # Initialize empty lists for each index + + # Iterate over each event time to create epochs + empty_epochs = [] + for i, row in epochs_info.iterrows(): + t_ref_i, t_before_i, t_after_i = row[ + ["t_ref", "t_before", "t_after"] + ].to_list() + + start_time = t_ref_i - t_before_i + end_time = t_ref_i + t_after_i + mask = np.logical_and(ts >= start_time, ts <= end_time) + + if not mask.any(): + empty_epochs.append(int(i)) + + # Annotate the data with the epoch index + for idx in source_index[mask]: + annot[idx].append(int(i)) + + if empty_epochs: + warn(f"No data found for epoch(s): {empty_epochs}.", RuntimeWarning) + + annot_df = pd.DataFrame.from_dict( + annot, orient="index", columns=["epoch index"] + ) + return annot_df @property def t_ref(self) -> np.ndarray: - """The reference time for each epoch in UTC nanoseconds.""" - return self.epochs["t_ref"].to_numpy() + """Reference time for each epoch in Unix nanoseconds. + + Returns + ------- + numpy.ndarray + Array of reference timestamps in nanoseconds. + """ + return self.epochs_info["t_ref"].to_numpy() @property def t_before(self) -> np.ndarray: - """The time before the reference time for each epoch in nanoseconds.""" - return self.epochs["t_before"].to_numpy() + """Time before the reference time for each epoch in nanoseconds. + + Returns + ------- + numpy.ndarray + Array of time durations before reference in nanoseconds. + """ + return self.epochs_info["t_before"].to_numpy() @property def t_after(self) -> np.ndarray: - """The time after the reference time for each epoch in nanoseconds.""" - return self.epochs["t_after"].to_numpy() + """Time after the reference time for each epoch in nanoseconds. 
+ + Returns + ------- + numpy.ndarray + Array of time durations after reference in nanoseconds. + """ + return self.epochs_info["t_after"].to_numpy() @property def description(self) -> np.ndarray: - """The description or label for each epoch.""" - return self.epochs["description"].to_numpy() + """Description or label for each epoch. + + Returns + ------- + numpy.ndarray + Array of description strings. + """ + return self.epochs_info["description"].to_numpy() @property def columns(self) -> pd.Index: - return self.data.columns[:-3] + if self.data.empty: + return pd.Index([]) + return self.data.columns.drop("epoch time [ns]", errors="ignore") @property def dtypes(self) -> pd.Series: - """The data types of the epoched data.""" - return self.data.dtypes[:-3] + """Data types of the epoched data.""" + if self.data.empty: + return pd.Series(dtype=object) + return self.data.drop(columns=["epoch time [ns]"], errors="ignore").dtypes @property def is_equal_length(self) -> bool: - """Whether all epochs have the same length.""" + """Whether all epochs have the same length. + + Returns + ------- + bool + True if all epochs have identical t_before and t_after durations. + """ return np.allclose(self.t_before, self.t_before[0]) and np.allclose( self.t_after, self.t_after[0] ) @property def has_overlap(self) -> bool: - """Whether any adjacent epochs overlap.""" - return _check_overlap(self.epochs) + """Whether any adjacent epochs overlap in time. + + Returns + ------- + bool + True if any adjacent epochs have overlapping time intervals. + """ + return self._check_overlap() != [] + @fill_doc def plot( self, column_name: Optional[str] = None, @@ -168,23 +285,17 @@ def plot( Parameters ---------- column_name : str - Name of the column to plot for :class:`pyneon.Epochs` constructed - from a :class:`pyneon.Stream`. If :class:`pyneon.Epochs` was constructed - from a :class:`pyneon.Events`, this parameter is ignored. Defaults to None. 
+ Name of the column to plot for :class:`Epochs` constructed + from a :class:`Stream`. If :class:`Epochs` was constructed + from an :class:`Events`, this parameter is ignored. Defaults to None. cmap_name : str Colormap to use for different epochs. Defaults to 'cool'. - ax : matplotlib.axes.Axes or None - Axis to plot the data on. If ``None``, a new figure is created. - Defaults to ``None``. - show : bool - Show the figure if ``True``. Defaults to True. + %(ax_param)s + %(show_param)s Returns ------- - fig : matplotlib.figure.Figure - Figure object containing the plot. - ax : matplotlib.axes.Axes - Axis object containing the plot. + %(fig_ax_returns)s """ fig_ax = plot_epochs( self, @@ -195,269 +306,242 @@ def plot( ) return fig_ax + @fill_doc def to_numpy( self, column_names: str | list[str] = "all", + sampling_rate: Optional[Number] = None, + float_kind: str | int = "linear", + other_kind: str | int = "nearest", ) -> tuple[np.ndarray, dict]: """ - Converts epochs into a 3D array with dimensions (n_epochs, n_channels, n_times). + Converts epochs into a 3D array with dimensions (n_epochs, n_channels, n_times). Acts similarly as :meth:`mne.Epochs.get_data`. - Requires the epoch to be created from a uniformly-sampled :class:`pyneon.Stream`. + + Requires the epoch to be created from a :class:`Stream`. Parameters ---------- column_names : str or list of str, optional - Column names to include in the NumPy array. If 'all', all columns are included. - Only columns that can be converted to int or float can be included. - Default is 'all'. + Column names to include in the NumPy array. If "all", all columns are included. + Only numerical columns can be included. + Defaults to "all". + sampling_rate : numbers.Number, optional + Desired sampling rate in Hz for the output NumPy array. + If None, the nominal sampling rate of the source Stream is used. + Defaults to None. 
+ %(interp_kind_params)s Returns ------- - numpy_epochs : numpy.ndarray + numpy.ndarray NumPy array of shape (n_epochs, n_channels, n_times). - info : dict A dictionary containing: - "column_ids": List of provided column names.\n - "t_rel": The common time grid, in nanoseconds.\n - "nan_flag": Boolean indicating whether NaN values were found in the data. - - Notes - ----- - - The time grid (``t_rel``) is in nanoseconds. - - If `NaN` values are present after interpolation, they are noted in ``nan_flag``. + ============ ================================ + epoch_times The common time grid, in nanoseconds. + column_names List of provided column names. + nan_flag Boolean indicating whether NaN values were found in the data. + ============ ================================ """ - if self.source_class != Stream or not self.is_uniformly_sampled: + if not isinstance(self.source, Stream): + raise TypeError("The source must be a Stream to convert to NumPy array.") + if not self.is_equal_length: raise ValueError( - "The source must be a uniformly-sampled Stream to convert to NumPy array." + "Epochs must have equal length (t_before and t_after) to convert to NumPy array." 
) - if not self.is_equal_length: - raise ValueError("Epochs must have equal length to convert to NumPy array.") - - t_before = self.t_before[0] - t_after = self.t_after[0] - - times = np.linspace( - -t_before, t_after, int((t_before + t_after) * self.sf * 1e-9) + 1 + sf = ( + self.source.sampling_freq_nominal + if sampling_rate is None + else sampling_rate ) - n_times = len(times) + # Check if column names (str or list) are all in the source columns if column_names == "all": - columns = self.columns.to_list() - else: - columns = [column_names] if isinstance(column_names, str) else column_names - for col in columns: - if col not in self.columns: - raise ValueError(f"Column '{col}' doesn't exist in the data.") - - n_columns = len(columns) - - # Initialize the NumPy array - # MNE convention: (n_epochs, n_channels, n_times) - epochs_np = np.full((len(self), n_columns, n_times - 2), np.nan) + column_names = self.source.columns.to_list() + if isinstance(column_names, str): + column_names = [column_names] + for col in column_names: + if col not in self.source.columns: + raise ValueError(f"Column '{col}' not found in source Stream.") + + epoch_times = np.arange( + -self.epochs_info["t_before"].iloc[0], + self.epochs_info["t_after"].iloc[0], + step=int(1e9 / sf), + dtype="int64", + ) # Interpolate each epoch onto the common time grid - for i, epoch in self.epochs.iterrows(): - epoch_data = epoch["data"].copy() - epoch_time = epoch_data["epoch time"].to_numpy() - for j, col in enumerate(columns): - y = epoch_data[col].to_numpy() - interp_values = np.interp( - times, epoch_time, y, left=np.nan, right=np.nan - ) - interp_values = interp_values[1:-1] # Exclude the first and last values - epochs_np[i, j, :] = interp_values + epochs_np = np.full((len(self), len(column_names), len(epoch_times)), np.nan) + for i, row in self.epochs_info.iterrows(): + t_ref = row["t_ref"] + new_ts = epoch_times + t_ref + epoch_data = self.source.interpolate( + new_ts, + float_kind=float_kind, + 
other_kind=other_kind, + max_gap_ms=None, + inplace=False, + ).data[column_names] + epochs_np[i, :, :] = epoch_data.to_numpy().T - # check if there are any NaN values in the data - nan_flag = np.isnan(epochs_np).any() - if nan_flag: - warnings.warn("NaN values were found in the data.", RuntimeWarning) - - # Return an object holding the column ids, times, and data info = { - "column_ids": columns, - "epoch_times": times[1:-1] * 1e-9, # Convert to seconds - "nan_flag": nan_flag, + "epoch_times": epoch_times, + "column_names": column_names, + "nan_flag": np.isnan(epochs_np).any(), } return epochs_np, info - def baseline_correction( + def apply_baseline( self, baseline: tuple[Number | None, Number | None] = (None, 0), - method: str = "mean", + method: Literal["mean", "regression"] = "mean", + exclude_cols: list[str] = [], inplace: bool = True, - ) -> Optional[pd.DataFrame]: + ) -> dict | None: """ - Perform baseline correction on epochs. + Apply baseline correction to epochs. Only applied to columns of float type. + + The baseline data is extracted and used to compute the correction. + When ``method="mean"``, the mean of the baseline window is subtracted from the entire epoch. + When ``method="regression"``, a linear trend is fitted to the baseline window and subtracted. + + For columns containing circular data (e.g., "yaw [deg]"), the correction is applied on unwrapped + data (rad) and the result is wrapped back to degrees after correction. Parameters ---------- - baseline : (t_min, t_max), iterable of float | None - Start and end of the baseline window **in seconds**, relative to - the event trigger (t_ref = 0). ``None`` means "from the first / - up to the last sample". Default: (None, 0.0) -> the pre-trigger - part of each epoch. - method : "mean" or "linear", optional - * "mean" - Subtract the scalar mean of the baseline window. 
- * "linear" - Fit a first-order (:math:`y = at + b`) model *within* the - baseline window and remove the fitted trend from the entire - epoch (a very small, fast version of MNE's regression detrending). - - Defaults to "mean". - inplace : bool - If True, overwrite epochs data. - Otherwise return a **new, corrected** pandas.DataFrame - and leave the object unchanged. - Defaults to True. + baseline : tuple[Number or None, Number or None], optional + Time window (relative to reference) for baseline computation in seconds. + Defaults to (None, 0), which uses all data before the reference time. + method : {"mean", "regression"}, optional + Baseline correction method. Defaults to "mean". + exclude_cols : list of str, optional + Columns to exclude from baseline correction. Defaults to []. + inplace : bool, optional + If ``True``, replace :attr:`epochs_dict`. Otherwise returns a new instance of dict. + Defaults to ``True``. Returns ------- - - pandas.DataFrame - The baseline-corrected data (same shape & dtypes as original data). - + dict or None + A new dict with modified epoch data if ``inplace=False``, otherwise ``None``. + See :attr:`epochs_dict` for details. """ - if self.source_class != Stream: - raise ValueError( - "Baseline correction is only supported for epochs created from a Stream." 
- ) - def _fit_and_subtract(epoch_df: pd.DataFrame, chan_cols: list[str]) -> None: - """In-place mean or linear detrend on *one* epoch DF.""" - # mask rows within the baseline window (epoch time is int64 ns) - t_rel_sec = epoch_df["epoch time"].to_numpy() * 1e-9 + def _get_baseline_mask( + t_rel_sec: np.ndarray, + t_min: Number | None, + t_max: Number | None, + ) -> np.ndarray: + """Create boolean mask for baseline window.""" if t_min is None: - mask = t_rel_sec <= t_max + return t_rel_sec <= t_max elif t_max is None: - mask = t_rel_sec >= t_min + return t_rel_sec >= t_min else: - mask = (t_rel_sec >= t_min) & (t_rel_sec <= t_max) - - if not mask.any(): - warnings.warn( - "Baseline window is empty for at least one epoch.", - RuntimeWarning, - ) - return # nothing to correct - + return (t_rel_sec >= t_min) & (t_rel_sec <= t_max) + + def _apply_baseline_correction( + epoch_df: pd.DataFrame, + cols_to_correct: list[str], + circ_cols: list[str], + baseline_mask: np.ndarray, + epoch_time_s: np.ndarray, + method: str, + ) -> None: + """Consolidated baseline correction for linear and circular float columns.""" if method == "mean": - baseline_mean = epoch_df.loc[mask, chan_cols].mean() - epoch_df.loc[:, chan_cols] = epoch_df[chan_cols] - baseline_mean - elif method == "linear": - t_base = t_rel_sec[mask] - for col in chan_cols: - y = epoch_df.loc[mask, col].to_numpy() - - # Check for NaNs, length, and constant input + baseline_means = epoch_df.loc[baseline_mask, cols_to_correct].mean() + epoch_df.loc[:, cols_to_correct] -= baseline_means + elif method == "regression": + t_base = epoch_time_s[baseline_mask] + for col in cols_to_correct: + y = epoch_df.loc[baseline_mask, col].to_numpy() + if ( len(t_base) < 2 or np.any(np.isnan(t_base)) or np.any(np.isnan(y)) ): - warnings.warn( - f"Skipping linear baseline correction for '{col}' due to insufficient or invalid data.", - RuntimeWarning, - ) continue - if np.all(t_base == t_base[0]): - warnings.warn( - f"Skipping linear 
baseline correction for '{col}' due to constant timestamps.", - RuntimeWarning, - ) - continue - - # Now it's safe to fit - a, b = np.polyfit(t_base, y, 1) - epoch_df.loc[:, col] = epoch_df[col] - (a * t_rel_sec + b) - else: - raise ValueError("method must be 'mean' or 'linear'") - # ------------------------------------------------------------------ - # 1. Parse parameters - # ------------------------------------------------------------------ + # Fit trend on baseline and subtract from the whole trial + coeffs = np.polyfit(t_base, y, 1) + trend = np.polyval(coeffs, epoch_time_s) + epoch_df.loc[:, col] -= trend + + # Wrap circular columns back to range after correction + for col in circ_cols: + vals = epoch_df[col].to_numpy() + vals_rad = vals * (2 * np.pi / 360) + valid = ~np.isnan(vals_rad) + vals_unwrapped_rad = np.full_like(vals_rad, np.nan) + vals_unwrapped_rad[valid] = np.unwrap(vals_rad[valid]) + vals_deg_unwrapped = vals_unwrapped_rad * (360 / (2 * np.pi)) + vals_deg_wrapped = ((vals_deg_unwrapped + 180) % 360) - 180 + epoch_df.loc[:, col] = vals_deg_wrapped + + if not isinstance(self.source, Stream): + raise TypeError("Baseline correction requires the source to be a Stream.") + + # Parse parameters t_min, t_max = baseline if t_min is not None and t_max is not None and (t_max < t_min): raise ValueError("baseline[1] must be >= baseline[0]") - chan_cols = self.columns.to_list() - - # Work on a copy unless the caller wants in-place modification + # Determine target streams/data if inplace: - epochs_copy = self.epochs - data_copy = self.data + epochs_to_process = self.epochs_dict else: - epochs_copy = self.epochs.copy(deep=True) - data_copy = self.data.copy(deep=True) - - for idx, row in epochs_copy.iterrows(): - epoch_df: pd.DataFrame = row["data"] - _fit_and_subtract(epoch_df, chan_cols) - # write back (only needed when we are working on a *copy*) - if not inplace: - epochs_copy.at[idx, "data"] = epoch_df - # update the global data DF as well - mask = 
data_copy["epoch index"] == idx - data_copy.loc[mask, chan_cols] = epoch_df[chan_cols].to_numpy() - - if not inplace: - return data_copy - - -def _create_epochs( - source: Stream | Events, times_df: pd.DataFrame -) -> tuple[pd.DataFrame, pd.DataFrame]: - """ - Create epochs DataFrame and annotate the data with epoch information. - """ - _check_overlap(times_df) - - data = source.data.copy() - data["epoch index"] = pd.Series(dtype="Int32") - data["epoch time"] = pd.Series(dtype="Int64") - data["epoch description"] = pd.Series(dtype="str") - - # check for source type - if isinstance(source, Stream): - ts = source.ts - elif isinstance(source, Events): - ts = source.start_ts - else: - raise ValueError("Source must be a Stream or Events.") - - epochs = times_df.copy().reset_index(drop=True) - epochs["data"] = pd.Series(dtype="object") - - # Iterate over each event time to create epochs - for i, row in times_df.iterrows(): - t_ref_i, t_before_i, t_after_i, description_i = row[ - ["t_ref", "t_before", "t_after", "description"] - ].to_list() - - start_time = t_ref_i - t_before_i - end_time = t_ref_i + t_after_i - mask = np.logical_and(ts >= start_time, ts <= end_time) - - if not mask.any(): - warnings.warn(f"No data found for epoch {i}.", RuntimeWarning) - epochs.at[i, "epoch data"] = pd.DataFrame() - continue - - data.loc[mask, "epoch index"] = i - data.loc[mask, "epoch description"] = str(description_i) - data.loc[mask, "epoch time"] = ( - data.loc[mask].index.to_numpy() - t_ref_i - ).astype("int64") - - local_data = data.loc[mask].copy() - local_data.drop(columns=["epoch index", "epoch description"], inplace=True) - epochs.at[i, "data"] = local_data - - return epochs, data + epochs_to_process = self.epochs_dict.copy() + + # Process each epoch + for idx, epoch in epochs_to_process.items(): + if epoch is None: + continue + + epoch_df = epoch.data.copy() + # Only apply to float columns and respect excludes + cols_to_correct = [ + c + for c in 
epoch_df.select_dtypes(include=[float]).columns + if c not in exclude_cols + ] + + if not cols_to_correct: + continue + + # Get baseline mask + epoch_time_s = epoch_df["epoch time [ns]"].to_numpy() * 1e-9 + baseline_mask = _get_baseline_mask(epoch_time_s, t_min, t_max) + + if not baseline_mask.any(): + warn(f"Baseline window is empty for epoch {idx}.", RuntimeWarning) + continue + + # Identify which target float columns are circular + epoch_circ_cols = [c for c in cols_to_correct if c in circular_columns] + + # Apply baseline correction (mean or regression) to the selected float columns + _apply_baseline_correction( + epoch_df, + cols_to_correct, + epoch_circ_cols, + baseline_mask, + epoch_time_s, + method, + ) + # Assign corrected data back to the epoch + epoch.data = epoch_df + return None if inplace else epochs_to_process -def events_to_times_df( +@fill_doc +def events_to_epochs_info( events: "Events", t_before: Number, t_after: Number, @@ -465,35 +549,66 @@ def events_to_times_df( event_name: str | list[str] = "all", ) -> pd.DataFrame: """ - Construct a ``times_df`` DataFrame suitable for creating epochs from event data. - For "simple" ``events`` (blinks, fixations, saccades), all events are used. - For more complex ``events`` (e.g., from "events.csv", or concatenated events), - the user can specify which events to include by a ``name`` column. + Construct an ``epochs_info`` DataFrame suitable for creating epochs around event onsets. + + For simple event classes ("blinks", "fixations", "saccades"), all events + in the input are used automatically. For more complex or combined event collections + (e.g., loaded from ``events.csv``), you can either include all events + (``event_name="all"``) or filter by specific names using ``event_name``. Parameters ---------- events : Events Events instance containing the event times. t_before : numbers.Number - Time before the event start time to start the epoch. Units specified by ``t_unit``. 
+ Time before each event start to begin the epoch. + Interpreted according to ``t_unit``. t_after : numbers.Number - Time after the event start time to end the epoch. Units specified by ``t_unit``. + Time after each event start to end the epoch. + Interpreted according to ``t_unit``. t_unit : str, optional Unit of time for ``t_before`` and ``t_after``. - Can be "s", "ms", "us", or "ns". Default is "s". + Can be "s", "ms", "us", or "ns". Defaults to "s". event_name : str or list of str, optional - Only used if ``events`` includes more than one event type. - If "all", all events are used. Otherwise, the ``name`` column is used to filter events - whose names are in the list. Default to "all". + Only used if ``events.type`` is not one of "blinks", "fixations", or "saccades". + In that case, ``events.data`` must have a ``name`` column indicating event labels. + If ``"all"``, all events from ``events.data`` are included, + and their ``name`` values become the epoch descriptions. + If a string or list is provided, only matching events are included. + Defaults to "all". Returns ------- - pandas.DataFrame - DataFrame with columns: ``t_ref``, ``t_before``, ``t_after``, ``description`` (all in ns). 
+ %(epochs_info)s + + Examples + -------- + Create ``epochs_info`` from blink events: + + >>> epochs_info = events_to_epochs_info(blinks, t_before=1, t_after=1) + >>> print(epochs_info.head()) + t_ref t_before t_after description + 0 1766068460987724691 1000000000 1000000000 blink + 1 1766068462919464691 1000000000 1000000000 blink + 2 1766068463785334691 1000000000 1000000000 blink + 3 1766068464836328691 1000000000 1000000000 blink + 4 1766068465932322691 1000000000 1000000000 blink + + Create ``epochs_info`` from "flash onset" events: + + >>> epochs_info = events_to_epochs_info( + events, t_before=0.5, t_after=3, event_name="flash onset") + >>> print(epochs_info.head()) + t_ref t_before t_after description + 0 1766068461745390000 500000000 3000000000 flash onset + 1 1766068465647497000 500000000 3000000000 flash onset + 2 1766068469642822000 500000000 3000000000 flash onset + 3 1766068473635128000 500000000 3000000000 flash onset + 4 1766068477629326000 500000000 3000000000 flash onset """ + t_ref = events.start_ts if events.type in ["blinks", "fixations", "saccades"]: description = events.type[:-1] # Remove the 's' at the end - t_ref = events.start_ts else: if "name" not in events.data.columns: raise ValueError( @@ -502,30 +617,25 @@ def events_to_times_df( names = events.data["name"].astype(str) if event_name == "all": - t_ref = events.data.index.to_numpy() description = names.to_numpy() else: - if isinstance(event_name, str): - event_name = [event_name] - mask = names.isin(event_name) - if not mask.any(): - raise ValueError(f"No events found matching names: {event_name}") - filtered_data = events.data[mask] - t_ref = filtered_data.index.to_numpy() - description = filtered_data["name"].to_numpy() + matching_events = events.filter_by_name(event_name) + t_ref = matching_events.start_ts + description = matching_events.data["name"].to_numpy() - times_df = construct_times_df( + epochs_info = construct_epochs_info( t_ref, t_before, t_after, description, - "ns", - 
t_unit, + t_ref_unit="ns", + t_other_unit=t_unit, ) - return times_df + return epochs_info -def construct_times_df( +@fill_doc +def construct_epochs_info( t_ref: np.ndarray, t_before: np.ndarray | Number, t_after: np.ndarray | Number, @@ -535,9 +645,9 @@ def construct_times_df( global_t_ref: int = 0, ) -> pd.DataFrame: """ - Handles the construction of the ``times_df`` DataFrame for creating epochs. It populates - single values for `t_before`, `t_after`, and `description` to match the length of `t_ref`. - and converts all times to UTC timestamps in nanoseconds. + Construct the ``epochs_info`` DataFrame for creating epochs. It populates + single values for ``t_before``, ``t_after``, and ``description`` to match the length of ``t_ref`` + and converts all times to Unix timestamps in nanoseconds. Parameters ---------- @@ -564,18 +674,15 @@ def construct_times_df( Global reference time (in nanoseconds) to be added to `t_ref`. Unit is nanosecond. Defaults to 0. This is useful when the reference times are relative to a global start time - (for instance :attr:`pyneon.Stream.first_ts`). + (for instance :attr:`Stream.first_ts`). Returns ------- - pandas.DataFrame - DataFrame with columns: ``t_ref``, ``t_before``, ``t_after``, ``description`` (all in ns). 
+ %(epochs_info)s """ - if n_epoch := len(t_ref) == 0: + if (n_epoch := len(t_ref)) == 0: raise ValueError("t_ref must not be empty") - else: - n_epoch = len(t_ref) time_factors = {"s": 1e9, "ms": 1e6, "us": 1e3, "ns": 1} @@ -586,7 +693,9 @@ def construct_times_df( if isinstance(x, np.ndarray): # Ensure it's the same length as t_ref if len(x) != n_epoch: - raise ValueError(f"{name} must have the same length as t_ref") + raise ValueError( + f"{name} must have the same length as t_ref ({n_epoch}), got {len(x)}" + ) elif isinstance(x, (Number, str)): x = np.repeat(x, n_epoch) else: @@ -594,7 +703,7 @@ def construct_times_df( # Construct the event times DataFrame # Do rounding as they should be timestamps already - times_df = pd.DataFrame( + epochs_info = pd.DataFrame( { "t_ref": t_ref * time_factors[t_ref_unit] + global_t_ref, "t_before": t_before * time_factors[t_other_unit], @@ -602,7 +711,7 @@ def construct_times_df( "description": description, } ) - times_df = times_df.astype( + epochs_info = epochs_info.astype( { "t_ref": "int64", "t_before": "int64", @@ -610,4 +719,4 @@ def construct_times_df( "description": "str", } ) - return times_df + return epochs_info diff --git a/pyneon/events.py b/pyneon/events.py index cdb0d5ea..0b232f77 100644 --- a/pyneon/events.py +++ b/pyneon/events.py @@ -8,6 +8,7 @@ import pandas as pd from .tabular import BaseTabular +from .utils import _apply_homography, _validate_df_columns from .utils.doc_decorators import fill_doc from .utils.variables import native_to_cloud_column_map @@ -111,7 +112,7 @@ def _load_native_events_data( return data, files -def _infer_events_type_and_id(data: pd.DataFrame) -> tuple[str, Optional[str]]: +def _infer_events_type(data: pd.DataFrame) -> str: """ Infer event type based on presence of specific columns. If multiple or no matches found, return "custom". 
@@ -121,6 +122,7 @@ def _infer_events_type_and_id(data: pd.DataFrame) -> tuple[str, Optional[str]]: "saccade id": "saccades", "fixation id": "fixations", "name": "events", + "event id": "events", } reverse_map = {v: k for k, v in col_map.items()} @@ -128,10 +130,15 @@ def _infer_events_type_and_id(data: pd.DataFrame) -> tuple[str, Optional[str]]: types = {col_map[c] for c in data.columns if c in col_map} if len(types) != 1: # None or more than one match → custom event type - return "custom", None + data.index = pd.RangeIndex(start=0, stop=len(data), name="event id") + return "custom" type = types.pop() - id_name = None if type == "events" else reverse_map[type] - return type, id_name + if type == "events": + data.index.name = "event id" + else: + data.set_index(reverse_map[type], inplace=True) + data.index = data.index.astype(np.int64) + return type class Events(BaseTabular): @@ -166,11 +173,8 @@ class Events(BaseTabular): Path to the source file(s). ``None`` if initialized from a DataFrame. data : pandas.DataFrame Event data with standardized column names. - type : {"blinks", "fixations", "saccades", "events", "custom"} + type : str Inferred event type based on data columns. - id_name : str or None - Column name holding event IDs (e.g., ``blink id``, ``fixation id``, - ``saccade id``). ``None`` for ``events`` and ``custom`` types. 
Examples -------- @@ -191,12 +195,16 @@ class Events(BaseTabular): >>> saccades = Events(df) """ + file: Optional[Path] + data: pd.DataFrame + type: str + def __init__(self, source: pd.DataFrame | Path | str, type: Optional[str] = None): if isinstance(source, str): source = Path(source) if isinstance(source, Path): if not source.is_file(): - raise FileNotFoundError(f"File does not exist: {source}") + raise FileNotFoundError(f"{source} does not exist") if source.suffix == ".csv": self.file = source data = pd.read_csv(source) @@ -206,12 +214,18 @@ def __init__(self, source: pd.DataFrame | Path | str, type: Optional[str] = None data = source.copy(deep=True) self.file = None super().__init__(data) - self.type, self.id_name = _infer_events_type_and_id(self.data) + self.type = _infer_events_type(self.data) def __getitem__(self, index) -> pd.Series: """Get an event series by index.""" return self.data.iloc[index] + def __repr__(self) -> str: + return f"""Events type: {self.type} +Number of samples: {len(self)} +Columns: {list(self.data.columns)} +""" + @property def start_ts(self) -> np.ndarray: """ @@ -223,9 +237,9 @@ def start_ts(self) -> np.ndarray: If no ``start timestamp [ns]`` or ``timestamp [ns]`` column is found in the instance. """ if self.type == "events": - return self.data["timestamp [ns]"].to_numpy() + return self.data["timestamp [ns]"].to_numpy(np.int64) if "start timestamp [ns]" in self.data.columns: - return self.data["start timestamp [ns]"].to_numpy() + return self.data["start timestamp [ns]"].to_numpy(np.int64) else: raise ValueError("No `start timestamp [ns]` column found in the instance.") @@ -240,7 +254,7 @@ def end_ts(self) -> np.ndarray: If no ``end timestamp [ns]`` column is found in the instance. 
""" if "end timestamp [ns]" in self.data.columns: - return self.data["end timestamp [ns]"].to_numpy() + return self.data["end timestamp [ns]"].to_numpy(np.int64) else: raise ValueError("No `end timestamp [ns]` column found in the instance.") @@ -262,51 +276,84 @@ def durations(self) -> np.ndarray: @property def id(self) -> np.ndarray: """ - Event ID. + Event IDs. + """ + return self.data.index.to_numpy(np.int32) - Raises - ------ - ValueError - If no ID column (e.g., `` id``) is found in the instance. + @property + def id_name(self) -> Optional[str]: """ - if self.id_name in self.data.columns and self.id_name is not None: - return self.data[self.id_name].to_numpy() - else: - raise ValueError("No ID column (e.g., ` id`) found in the instance.") + Name of the event ID column based on event type. + + Returns + ------- + str or None + The ID column name (e.g., ``"blink id"``, ``"fixation id"``, ``"saccade id"``, + ``"event id"``) for known event types, or ``None`` for custom event types. + """ + id_map = { + "blinks": "blink id", + "fixations": "fixation id", + "saccades": "saccade id", + "events": "event id", + } + return id_map.get(self.type, None) @fill_doc def crop( self, tmin: Optional[Number] = None, tmax: Optional[Number] = None, - by: Literal["timestamp", "row"] = "timestamp", + by: Literal["timestamp", "sample"] = "timestamp", inplace: bool = False, ) -> Optional["Events"]: """ - Crop data to a specific time range based on timestamps or row numbers. + Extract a subset of events within a specified temporal range. + + The ``by`` parameter determines how ``tmin`` and ``tmax`` are interpreted: + - ``"timestamp"``: Absolute Unix timestamps in nanoseconds (based on event start times) + - ``"sample"``: Zero-based event indices + + Both bounds are inclusive. If either bound is omitted, it defaults to the + events' natural boundary (earliest or latest event). Parameters ---------- - tmin : number, optional - Start timestamp/row to crop the data to. 
If ``None``, - the minimum timestamp/row in the data is used. Defaults to ``None``. - tmax : number, optional - End timestamp/row to crop the data to. If ``None``, - the maximum timestamp/row in the data is used. Defaults to ``None``. - by : "timestamp" or "row", optional - Whether tmin and tmax are UTC timestamps in nanoseconds - or row numbers of the stream data. - Defaults to "timestamp". - - %(inplace)s + tmin : numbers.Number, optional + Lower bound of the range to extract (inclusive). If ``None``, + starts from the first event. Defaults to ``None``. + tmax : numbers.Number, optional + Upper bound of the range to extract (inclusive). If ``None``, + extends to the last event. Defaults to ``None``. + by : {"timestamp", "sample"}, optional + Unit used to interpret ``tmin`` and ``tmax``. Defaults to ``"timestamp"``. + + %(inplace_param)s Returns ------- - Events or None - Cropped events if ``inplace=False``, otherwise ``None``. + %(events_or_none_returns)s + + Raises + ------ + ValueError + If both ``tmin`` and ``tmax`` are ``None``, or if no events + fall within the specified range. + + Examples + -------- + Crop fixations to the first 5 seconds: + + >>> fixations_5s = fixations.crop(tmin=rec.gaze.first_ts, + ... tmax=rec.gaze.first_ts + 5e9, + ... by="timestamp") + + Extract the first 100 blinks: + + >>> first_100 = blinks.crop(tmin=0, tmax=99, by="sample") """ if tmin is None and tmax is None: - raise ValueError("At least one of tmin or tmax must be provided") + raise ValueError("At least one of `tmin` or `tmax` must be provided") if by == "timestamp": t = self.start_ts else: @@ -324,20 +371,32 @@ def crop( @fill_doc def restrict(self, other: "Stream", inplace: bool = False) -> Optional["Events"]: """ - Temporally crop the events to the range of timestamps a stream. - Equivalent to ``crop(other.first_ts, other.last_ts)``. + Align events to match a stream's temporal range. 
+ + This method filters events to include only those whose start times fall + between the first and last timestamps of the reference stream. It is + equivalent to calling + ``crop(tmin=other.first_ts, tmax=other.last_ts, by="timestamp")``. + + Useful for limiting event analysis to periods when a particular data stream + is available. Parameters ---------- other : Stream - Stream to restrict to. + Reference stream whose temporal boundaries define the cropping range. - %(inplace)s + %(inplace_param)s Returns ------- - Events or None - Restricted events if ``inplace=False``, otherwise ``None``. + %(events_or_none_returns)s + + Examples + -------- + Analyze only blinks that occurred during recorded gaze data: + + >>> blinks_with_gaze = blinks.restrict(gaze) """ return self.crop(other.first_ts, other.last_ts, by="timestamp", inplace=inplace) @@ -346,7 +405,7 @@ def filter_by_duration( self, dur_min: Optional[Number] = None, dur_max: Optional[Number] = None, - reset_id: bool = True, + reset_id: bool = False, inplace: bool = False, ) -> Optional["Events"]: """ @@ -355,22 +414,23 @@ def filter_by_duration( Parameters ---------- dur_min : number, optional - Minimum duration (in milliseconds) of events to keep. + Minimum duration (in milliseconds) of events to keep (inclusive). If ``None``, no minimum duration filter is applied. Defaults to ``None``. dur_max : number, optional - Maximum duration (in milliseconds) of events to keep. + Maximum duration (in milliseconds) of events to keep (inclusive). If ``None``, no maximum duration filter is applied. Defaults to ``None``. reset_id : bool, optional - Whether to reset event IDs after filtering. Also resets the DataFrame index. - Defaults to ``True``. + Whether to reset event IDs after filtering. + Defaults to ``False``. - %(inplace)s + %(inplace_param)s Returns ------- - Events or None - Filtered events if ``inplace=False``, otherwise ``None``. 
+ %(events_or_none_returns)s """ + if "duration [ms]" not in self.data.columns: + raise ValueError("No `duration [ms]` column found in the instance.") if dur_min is None and dur_max is None: raise ValueError("At least one of dur_min or dur_max must be provided") dur_min = dur_min if dur_min is not None else self.durations.min() @@ -379,14 +439,153 @@ def filter_by_duration( if not mask.any(): raise ValueError("No data found in the specified duration range") print(f"Filtering out {len(self) - mask.sum()} out of {len(self)} events.") + + inst = self if inplace else self.copy() + inst.data = self.data[mask].copy() + if reset_id: + # Reset without losing original index name + inst.data.index = pd.RangeIndex( + start=0, stop=len(inst.data), name=inst.data.index.name + ) + return None if inplace else inst + + @fill_doc + def filter_by_name( + self, + names: str | list[str], + col_name: str = "name", + reset_id: bool = False, + inplace: bool = False, + ) -> Optional["Events"]: + """ + Filter events by matching values in a specified column. + Designed primarily for filtering :attr:`Recording.events` by their names. + + This method selects only the events whose value in ``col_name`` matches + one or more of the provided ``names``. If no events match, a + ``ValueError`` is raised. + + Parameters + ---------- + names : str or list of str + Event name or list of event names to keep. Matching is exact + and case-sensitive. + col_name : str, optional + Name of the column in ``self.data`` to use for filtering. + Must exist in the ``Events`` instance's DataFrame. + Defaults to ``"name"``. + reset_id : bool, optional + Whether to reset event IDs after filtering. + Defaults to ``False``. 
+ %(inplace_param)s + + Returns + ------- + %(events_or_none_returns)s + """ + if col_name not in self.data.columns: + raise KeyError(f"No `{col_name}` column found in the instance.") + + names = [names] if isinstance(names, str) else names + mask = self.data[col_name].isin(names) + if not mask.any(): + raise ValueError( + f"No data found matching the specified event names {names}" + ) + inst = self if inplace else self.copy() inst.data = self.data[mask].copy() if reset_id: - if self.id_name is not None: - inst.data[self.id_name] = np.arange(len(inst.data)) + 1 - inst.data.reset_index(drop=True, inplace=True) - else: - raise KeyError( - "Cannot reset event IDs as no event ID column is known for this instance." - ) + inst.data.index = pd.RangeIndex( + start=0, stop=len(inst.data), name=inst.data.index.name + ) + return None if inplace else inst + + @fill_doc + def apply_homographies( + self, + homographies: "Stream", + max_gap_ms: Number = 500, + overwrite: bool = False, + inplace: bool = False, + ) -> Optional["Events"]: + """ + Compute fixation locations in surface coordinates using provided homographies + based on fixation pixel coordinates and append them to the events data. + + Since homographies are estimated per video frame and might not be available + for every frame, they need to be resampled/interpolated to the timestamps of the + fixation data before application. + + The events data must contain the required fixation columns: + ``fixation x [px]`` and ``fixation y [px]``. + The output events will contain two new columns: + ``fixation x [surface coord]`` and ``fixation y [surface coord]``. + + Parameters + ---------- + %(homographies)s + Returned by :func:`pyneon.find_homographies`. + %(max_gap_ms_param)s + overwrite : bool, optional + Only applicable if surface fixation columns already exist. + If ``True``, overwrite existing columns. If ``False``, raise an error. + Defaults to ``False``. 
+ %(inplace_param)s + + Returns + ------- + %(events_or_none_returns)s + """ + inst = self if inplace else self.copy() + data = inst.data + + if not overwrite and ( + "fixation x [surface coord]" in data.columns + or "fixation y [surface coord]" in data.columns + ): + raise ValueError( + "Events already contain fixation surface data. " + "Use overwrite=True to overwrite existing columns." + ) + + required_cols = ["fixation x [px]", "fixation y [px]"] + _validate_df_columns(data, required_cols, df_name="Events data") + + data["fixation x [surface coord]"] = np.nan + data["fixation y [surface coord]"] = np.nan + + event_ts = inst.start_ts + homographies_data = homographies.interpolate( + event_ts, float_kind="linear", max_gap_ms=max_gap_ms + ).data + + h_cols = [f"homography ({i},{j})" for i in range(3) for j in range(3)] + _validate_df_columns(homographies_data, h_cols, df_name="homographies") + + homographies_data = homographies_data.dropna() + + x_col = data.columns.get_loc("fixation x [surface coord]") + y_col = data.columns.get_loc("fixation y [surface coord]") + + for event_idx, ts in enumerate(event_ts): + if ts not in homographies_data.index: + continue + + h_row = homographies_data.loc[ts] + if isinstance(h_row, pd.DataFrame): + h_row = h_row.iloc[0] + + fix_vals = data.iloc[event_idx][required_cols].values + if pd.isna(fix_vals).any(): + continue + + fix_points = np.asarray(fix_vals, dtype=np.float64).reshape(1, -1) + H_flat = h_row[h_cols].values + H = H_flat.reshape(3, 3) + fix_trans = _apply_homography(fix_points, H) + + data.iat[event_idx, x_col] = fix_trans[:, 0] + data.iat[event_idx, y_col] = fix_trans[:, 1] + return None if inplace else inst diff --git a/pyneon/export/__init__.py b/pyneon/export/__init__.py index eeb01b8b..4c91e8af 100644 --- a/pyneon/export/__init__.py +++ b/pyneon/export/__init__.py @@ -1,4 +1,4 @@ -from .export_bids import export_eye_bids, export_motion_bids +from .export_bids import export_eye_tracking_bids, export_motion_bids 
from .export_cloud import export_cloud_format -__all__ = ["export_motion_bids", "export_eye_bids", "export_cloud_format"] +__all__ = ["export_motion_bids", "export_eye_tracking_bids", "export_cloud_format"] diff --git a/pyneon/export/_bids_parameters.py b/pyneon/export/_bids_parameters.py index 0f6df32f..c900bc56 100644 --- a/pyneon/export/_bids_parameters.py +++ b/pyneon/export/_bids_parameters.py @@ -1,3 +1,5 @@ +from ..utils.variables import nominal_sampling_rates + MOTION_META_DEFAULT = { "TaskName": "", "TaskDescription": "", @@ -9,15 +11,39 @@ "InstitutionName": "", "InstitutionAddress": "", "InstitutionalDepartmentName": "", - "SamplingFrequency": "", - "ACCELChannelCount": 3, - "GYROChannelCount": 3, + "SamplingFrequency": nominal_sampling_rates["imu"], + "ACCELChannelCount": 0, + "ANGACCELChannelCount": 0, + "GYROChannelCount": 0, + "JNTANGChannelCount": 0, + "LATENCYChannelCount": 0, + "MAGNChannelCount": 0, + "MISCChannelCount": 0, "MissingValues": "n/a", - "MotionChannelCount": 13, - "ORNTChannelCount": 7, + "MotionChannelCount": 0, + "ORNTChannelCount": 0, + "POSChannelCount": 0, + "SamplingFrequencyEffective": nominal_sampling_rates["imu"], "SubjectArtefactDescription": "", "TrackedPointsCount": 0, - "TrackingSystemName": "IMU included in Neon", + "TrackingSystemName": "Neon IMU", + "VELChannelCount": 0, +} + +MOTION_CHANNEL_MAP = { + "gyro x": {"component": "x", "type": "GYRO", "units": "deg/s"}, + "gyro y": {"component": "y", "type": "GYRO", "units": "deg/s"}, + "gyro z": {"component": "z", "type": "GYRO", "units": "deg/s"}, + "acceleration x": {"component": "x", "type": "ACCEL", "units": "g"}, + "acceleration y": {"component": "y", "type": "ACCEL", "units": "g"}, + "acceleration z": {"component": "z", "type": "ACCEL", "units": "g"}, + "roll": {"component": "x", "type": "ORNT", "units": "deg"}, + "pitch": {"component": "y", "type": "ORNT", "units": "deg"}, + "yaw": {"component": "z", "type": "ORNT", "units": "deg"}, + "quaternion w": {"component": 
"quat_w", "type": "ORNT", "units": "arbitrary"}, + "quaternion x": {"component": "quat_x", "type": "ORNT", "units": "arbitrary"}, + "quaternion y": {"component": "quat_y", "type": "ORNT", "units": "arbitrary"}, + "quaternion z": {"component": "quat_z", "type": "ORNT", "units": "arbitrary"}, } EYE_META_DEFAULT = { @@ -27,10 +53,8 @@ "timestamp", "x_coordinate", "y_coordinate", - "azimuth", - "elevation", - "pupil_size_left", - "pupil_size_right", + "left_pupil_diameter", + "right_pupil_diameter", ], "DeviceSerialNumber": "", "Manufacturer": "Pupil Labs", @@ -39,23 +63,61 @@ "PhysioType": "eyetrack", "EnvironmentCoorinates": "top-left", "RecordedEye": "cyclopean", - "SampleCoordinateUnits": "pixels", "SampleCoordinateSystem": "gaze-in-world", "EyeTrackingMethod": "real-time neural network", - "azimuth": { - "Description": "Azimuth of the gaze ray in relation to the scene camera", - "Units": "degrees", + "timestamp": { + "Description": "UTC timestamp in nanoseconds of the sample", + "Units": "ns", + }, + "x_coordinate": { + # Description adapted from https://docs.pupil-labs.com/neon/data-collection/data-format/#gaze-csv + "Description": "X-coordinate of the mapped gaze point in world camera pixel coordinates.", + "Units": "pixel", }, - "elevation": { - "Description": "Elevation of the gaze ray in relation to the scene camera", - "Units": "degrees", + "y_coordinate": { + "Description": "Y-coordinate of the mapped gaze point in world camera pixel coordinates.", + "Units": "pixel", }, - "pupil_size_left": { + "left_pupil_diameter": { + # Description adapted from https://docs.pupil-labs.com/neon/data-collection/data-format/#_3d-eye-states-csv "Description": "Physical diameter of the pupil of the left eye", "Units": "mm", }, - "pupil_size_right": { + "right_pupil_diameter": { "Description": "Physical diameter of the pupil of the right eye", "Units": "mm", }, } + +EYE_EVENTS_META_DEFAULT = { + "Columns": [ + "onset", + "duration", + "trial_type", + "message", + ], + 
"Description": "Eye events and messages logged by Neon", + "OnsetSource": "timestamp", + "onset": { + "Description": "UTC timestamp in nanoseconds of the start of the event", + "Units": "ns", + }, + "duration": { + "Description": "Event duration", + "Units": "s", + }, + "trial_type": { + "Description": "Type of trial event", + "Levels": { + "fixation": { + "Description": "Fixation event", + }, + "saccade": { + "Description": "Saccade event", + }, + "blink": { + "Description": "Blink event", + }, + }, + }, +} diff --git a/pyneon/export/export_bids.py b/pyneon/export/export_bids.py index f7a10df9..31d5f183 100644 --- a/pyneon/export/export_bids.py +++ b/pyneon/export/export_bids.py @@ -1,17 +1,39 @@ import datetime +import gzip import json import re from pathlib import Path from typing import TYPE_CHECKING, Optional +from warnings import warn import pandas as pd -from ._bids_parameters import MOTION_META_DEFAULT +from ._bids_parameters import ( + EYE_EVENTS_META_DEFAULT, + EYE_META_DEFAULT, + MOTION_CHANNEL_MAP, + MOTION_META_DEFAULT, +) if TYPE_CHECKING: from ..recording import Recording +def _infer_prefix_from_dir(rec, output_dir): + # Infer sub and ses names from motion_dir + sub_name = f"sub-{rec.info['wearer_name']}" + ses_name = None + parent_dir = output_dir.parent + if parent_dir.name.startswith("sub-"): + sub_name = parent_dir.name + ses_name = None + elif parent_dir.name.startswith("ses-"): + ses_name = parent_dir.name + if parent_dir.parent.name.startswith("sub-"): + sub_name = parent_dir.parent.name + return sub_name, ses_name + + def export_motion_bids( rec: "Recording", motion_dir: str | Path, @@ -27,7 +49,7 @@ def export_motion_bids( Parameters ---------- rec : Recording - Recording object containing the IMU data. + Recording instance containing the IMU data. motion_dir : str or pathlib.Path Output directory to save the Motion-BIDS formatted data. prefix : str, optional @@ -49,45 +71,70 @@ def export_motion_bids( ---------- .. 
[1] Jeung, S., Cockx, H., Appelhoff, S., Berg, T., Gramann, K., Grothkopp, S., ... & Welzel, J. (2024). Motion-BIDS: an extension to the brain imaging data structure to organize motion data for reproducible research. *Scientific Data*, 11(1), 716. """ + imu = rec.imu.interpolate(max_gap_ms=None) motion_dir = Path(motion_dir) if not motion_dir.is_dir(): - raise FileNotFoundError(f"Directory not found: {motion_dir}") + motion_dir.mkdir(parents=True) if motion_dir.name != "motion": raise RuntimeWarning( f"Directory name {motion_dir.name} is not 'motion' as specified by Motion-BIDS" ) + + # Infer sub and ses names from motion_dir + sub_name, ses_name = _infer_prefix_from_dir(rec, motion_dir) + sub_ses_name = f"{sub_name}_{ses_name}" if ses_name else sub_name + + # If prefix is not provided, construct it using the inferred sub and ses names if prefix is None: - prefix = f"sub-{rec.info['wearer_name']}_task-XXX_tracksys-NeonIMU" + if ses_name is None: + prefix = f"{sub_name}_task-TaskName_tracksys-NeonIMU" + else: + prefix = f"{sub_name}_{ses_name}_task-TaskName_tracksys-NeonIMU" + + # Check if required fields are in the prefix + for field in ["sub-", "task-", "tracksys-"]: + if field not in prefix: + raise ValueError(f"Prefix must contain '{field}