diff --git a/tsl/metrics/torch/metric_base.py b/tsl/metrics/torch/metric_base.py index f187f03..fe25de4 100644 --- a/tsl/metrics/torch/metric_base.py +++ b/tsl/metrics/torch/metric_base.py @@ -1,12 +1,15 @@ import inspect from copy import deepcopy from functools import partial -from typing import Any +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import torch from torchmetrics import Metric from torchmetrics.utilities.checks import _check_same_shape +from tsl.typing import Slicer +from tsl.utils.python_utils import parse_slicing_string + def convert_to_masked_metric(metric_fn, **kwargs): """ @@ -36,56 +39,115 @@ class MaskedMetric(Metric): In particular a `MaskedMetric` accounts for missing values in the input sequences by accepting a boolean mask as additional input. + Multiple metric functions can be specified, + in which case they will be averaged. + Weights can be assigned to perform a + weighted average of the different metrics. Args: - metric_fn: Base function to compute the metric point wise. - mask_nans (bool, optional): Whether to automatically mask nan values. - mask_inf (bool, optional): Whether to automatically mask infinite + metric_fn (Sequence[callable], callable): + Base function to compute the metric + point-wise, multiple functions can be passed as a sequence. + mask_nans (bool): Whether to automatically mask nan values. + (default: :obj:`False`) + mask_inf (bool): Whether to automatically mask infinite values. - at (int, optional): Whether to compute the metric only w.r.t. a certain - time step. + (default: :obj:`False`) + metric_fn_kwargs (Sequence[dict], dict, optional): + Keyword arguments needed by :obj:`metric_fn`. + Use a sequence of keyword arguments if different :obj:`metric_fn` + require different arguments. + (default: :obj:`None`) + metric_fn_kwargs (Sequence[float], float, optional): + Weight assigned to each :obj:`metric_fn`. + Use a sequence if different :obj:`metric_fn` + require different weights. + (default: :obj:`None`) + at (str, Sequence[Tuple[Slicer, ...] | str], tuple[Slicer, ...], + Slicer, optional): + Numpy style slicing to define specific parts + of the output to compute the metrics on. + Either one for all metric or a sequence for each metric. + Slicing can either be a proper slicing tuple + or a string representation containing just + the part you would put inside square brackets + to index an array/tensor. + (default: :obj:`None`) + full_state_update (bool, optional): Set this to overwrite the + :obj:`full_state_update` value of the + :obj:`torchmetrics.Metric` base class. + (default: :obj:`None`) """ is_differentiable: bool = None higher_is_better: bool = None full_state_update: bool = None - def __init__(self, - metric_fn, - mask_nans=False, - mask_inf=False, - metric_fn_kwargs=None, - at=None, - full_state_update: bool = None, - **kwargs: Any): - # set 'full_state_update' before Metric instantiation - if full_state_update is not None: - self.__dict__['full_state_update'] = full_state_update - super(MaskedMetric, self).__init__(**kwargs) - + def __init__( + self, + metric_fn: Union[Sequence[Callable], Callable], + metric_fn_kwargs: Optional[Union[Sequence[Dict[str, Any]], + Dict[str, Any]]] = None, + mask_nans: bool = False, + mask_inf: bool = False, + at: Union[str, Sequence[Union[Tuple[Slicer, ...], str]], + tuple[Slicer, ...], Slicer] = ..., + weights: Optional[Sequence[float]] = None, + full_state_update: Optional[bool] = None, + **kwargs: Any, + ): + super().__init__( + metric_fn=None, + mask_nans=mask_nans, + mask_inf=mask_inf, + metric_fn_kwargs=None, + at=None, + full_state_update=full_state_update, + **kwargs, + ) + assert ( + len({ + len(e) + for e in (metric_fn, metric_fn_kwargs, at, weights) + if isinstance(e, Sequence) + }) == 1 + ), "All sequences used as masked metric arguments " \ + "must have the same length." if metric_fn_kwargs is None: - metric_fn_kwargs = dict() - - self.metric_fn = partial(metric_fn, **metric_fn_kwargs) - + metric_fn_kwargs = {} + if isinstance(metric_fn, Sequence) and isinstance( + metric_fn_kwargs, Sequence): + self.metric_fn = tuple( + partial(fn, **fn_kwargs) + for fn, fn_kwargs in zip(metric_fn, metric_fn_kwargs)) + elif isinstance(metric_fn, Sequence): + self.metric_fn = tuple( + partial(fn, **metric_fn_kwargs) for fn in metric_fn) + else: + self.metric_fn = (partial(metric_fn, **metric_fn_kwargs), ) + if isinstance(at, str) or not isinstance(at, Sequence): + at = (at, ) + at = list( + parse_slicing_string(e) if isinstance(e, str) else e for e in at) + self.at = at * len(self.metric_fn) if len(at) == 1 else at + if weights is None: + self.weights = (1.0, ) * len(self.metric_fn) + else: + self.weights = weights self.mask_nans = mask_nans self.mask_inf = mask_inf - if at is None: - self.at = slice(None) - else: - self.at = slice(at, at + 1) - self.add_state('value', - dist_reduce_fx='sum', - default=torch.tensor(0., dtype=torch.float)) - self.add_state('numel', - dist_reduce_fx='sum', - default=torch.tensor(0., dtype=torch.float)) - - def _check_mask(self, mask, val): + self.add_state("value", + dist_reduce_fx="sum", + default=torch.tensor(0.0, dtype=torch.float)) + self.add_state("numel", + dist_reduce_fx="sum", + default=torch.tensor(0.0, dtype=torch.float)) + + def _check_mask(self, mask, val, at=...): if mask is None: mask = torch.ones_like(val, dtype=torch.bool) else: - mask = mask.bool() + mask = mask[at].bool() _check_same_shape(mask, val) if self.mask_nans: mask = mask & ~torch.isnan(val) @@ -93,32 +155,21 @@ def _check_mask(self, mask, val): mask = mask & ~torch.isinf(val) return mask - def _compute_masked(self, y_hat, y, mask): - _check_same_shape(y_hat, y) - val = self.metric_fn(y_hat, y) - mask = self._check_mask(mask, val) - val = torch.where(mask, val, torch.zeros_like(val)) - return val.sum(), mask.sum() - - def _compute_std(self, y_hat, y): - _check_same_shape(y_hat, y) - val = self.metric_fn(y_hat, y) - return val.sum(), val.numel() - def is_masked(self, mask): return self.mask_inf or self.mask_nans or (mask is not None) def update(self, y_hat, y, mask=None): - y_hat = y_hat[:, self.at] - y = y[:, self.at] - if mask is not None: - mask = mask[:, self.at] - if self.is_masked(mask): - val, numel = self._compute_masked(y_hat, y, mask) - else: - val, numel = self._compute_std(y_hat, y) - self.value += val - self.numel += numel + _check_same_shape(y_hat, y) + for i in range(len(self.metric_fn)): + val = self.metric_fn[i](y_hat[self.at[i]], y[self.at[i]]) + if self.is_masked(mask): + mask = self._check_mask(mask, val, self.at[i]) + val[~mask] = 0 + numel = mask.sum() + else: + numel = val.numel() + self.value += val.sum() * self.weights[i] + self.numel += numel def compute(self): if self.numel > 0: diff --git a/tsl/metrics/torch/metrics.py b/tsl/metrics/torch/metrics.py index 5018db2..4716a8a 100644 --- a/tsl/metrics/torch/metrics.py +++ b/tsl/metrics/torch/metrics.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional import torch from torch.nn import functional as F @@ -14,11 +14,18 @@ class MaskedMAE(MaskedMetric): """Mean Absolute Error Metric. Args: - mask_nans (bool, optional): Whether to automatically mask nan values. - mask_inf (bool, optional): Whether to automatically mask infinite + mask_nans (bool): Whether to automatically mask nan values. + (default: :obj:`False`) + mask_inf (bool): Whether to automatically mask infinite values. + (default: :obj:`False`) at (int, optional): Whether to compute the metric only w.r.t. a certain - time step. + time step. + (default: :obj:`None`) + dim (int): The index of the dimension that represents time in a batch. + Relevant only when also 'at' is defined. + Default assumes [b t n f] format. + (default: :obj:`1`) """ is_differentiable: bool = True @@ -26,15 +33,17 @@ class MaskedMAE(MaskedMetric): full_state_update: bool = False def __init__(self, - mask_nans=False, - mask_inf=False, - at=None, + mask_nans: bool = False, + mask_inf: bool = False, + at: Optional[int] = None, + dim: int = 1, **kwargs: Any): super(MaskedMAE, self).__init__(metric_fn=F.l1_loss, mask_nans=mask_nans, mask_inf=mask_inf, metric_fn_kwargs={'reduction': 'none'}, at=at, + dim=dim, **kwargs) @@ -42,22 +51,33 @@ class MaskedMAPE(MaskedMetric): """Mean Absolute Percentage Error Metric. Args: - mask_nans (bool, optional): Whether to automatically mask nan values. + mask_nans (bool): Whether to automatically mask nan values. + (default: :obj:`False`) at (int, optional): Whether to compute the metric only w.r.t. a certain time step. + (default: :obj:`None`) + dim (int): The index of the dimension that represents time in a batch. + Relevant only when also 'at' is defined. + Default assumes [b t n f] format. + (default: :obj:`1`) """ is_differentiable: bool = True higher_is_better: bool = False full_state_update: bool = False - def __init__(self, mask_nans=False, at=None, **kwargs: Any): + def __init__(self, + mask_nans: bool = False, + at: Optional[int] = None, + dim: int = 1, + **kwargs: Any): super(MaskedMAPE, self).__init__(metric_fn=mape, mask_nans=mask_nans, mask_inf=True, metric_fn_kwargs={'reduction': 'none'}, at=at, + dim=dim, **kwargs) @@ -65,11 +85,18 @@ class MaskedMSE(MaskedMetric): """Mean Squared Error Metric. Args: - mask_nans (bool, optional): Whether to automatically mask nan values. - mask_inf (bool, optional): Whether to automatically mask infinite + mask_nans (bool): Whether to automatically mask nan values. + (default: :obj:`False`) + mask_inf (bool): Whether to automatically mask infinite values. + (default: :obj:`False`) at (int, optional): Whether to compute the metric only w.r.t. a certain time step. + (default: :obj:`None`) + dim (int): The index of the dimension that represents time in a batch. + Relevant only when also 'at' is defined. + Default assumes [b t n f] format. + (default: :obj:`1`) """ is_differentiable: bool = True @@ -77,15 +104,17 @@ class MaskedMSE(MaskedMetric): full_state_update: bool = False def __init__(self, - mask_nans=False, - mask_inf=False, - at=None, + mask_nans: bool = False, + mask_inf: bool = False, + at: Optional[int] = None, + dim: int = 1, **kwargs: Any): super(MaskedMSE, self).__init__(metric_fn=F.mse_loss, mask_nans=mask_nans, mask_inf=mask_inf, metric_fn_kwargs={'reduction': 'none'}, at=at, + dim=dim, **kwargs) @@ -93,11 +122,18 @@ class MaskedMRE(MaskedMetric): """Mean Relative Error Metric. Args: - mask_nans (bool, optional): Whether to automatically mask nan values. - mask_inf (bool, optional): Whether to automatically mask infinite + mask_nans (bool): Whether to automatically mask nan values. + (default: :obj:`False`) + mask_inf (bool): Whether to automatically mask infinite values. + (default: :obj:`False`) at (int, optional): Whether to compute the metric only w.r.t. a certain time step. + (default: :obj:`None`) + dim (int): The index of the dimension that represents time in a batch. + Relevant only when also 'at' is defined. + Default assumes [b t n f] format. + (default: :obj:`1`) """ is_differentiable: bool = True @@ -105,15 +141,17 @@ class MaskedMRE(MaskedMetric): full_state_update: bool = False def __init__(self, - mask_nans=False, - mask_inf=False, - at=None, + mask_nans: bool = False, + mask_inf: bool = False, + at: Optional[int] = None, + dim: int = 1, **kwargs: Any): super(MaskedMRE, self).__init__(metric_fn=F.l1_loss, mask_nans=mask_nans, mask_inf=mask_inf, metric_fn_kwargs={'reduction': 'none'}, at=at, + dim=dim, **kwargs) self.add_state('tot', dist_reduce_fx='sum', @@ -138,10 +176,11 @@ def compute(self): return self.value def update(self, y_hat, y, mask=None): - y_hat = y_hat[:, self.at] - y = y[:, self.at] - if mask is not None: - mask = mask[:, self.at] + if self.at is not None: + y_hat = y_hat.select(self.dim, self.at) + y = y.select(self.dim, self.at) + if mask is not None: + mask = mask.select(self.dim, self.at) if self.is_masked(mask): val, numel, tot = self._compute_masked(y_hat, y, mask) else: diff --git a/tsl/metrics/torch/pinball_loss.py b/tsl/metrics/torch/pinball_loss.py index b530ff3..45bd45d 100644 --- a/tsl/metrics/torch/pinball_loss.py +++ b/tsl/metrics/torch/pinball_loss.py @@ -1,3 +1,5 @@ +from typing import Any, Callable, Optional + from tsl.metrics.torch import pinball_loss from tsl.metrics.torch.metric_base import MaskedMetric @@ -7,16 +9,24 @@ class MaskedPinballLoss(MaskedMetric): Args: q (float): Target quantile. - mask_nans (bool, optional): Whether to automatically mask nan values. - mask_inf (bool, optional): Whether to automatically mask infinite + mask_nans (bool): Whether to automatically mask nan values. + (default: :obj:`False`) + mask_inf (bool): Whether to automatically mask infinite values. - compute_on_step (bool, optional): Whether to compute the metric + (default: :obj:`False`) + compute_on_step (bool): Whether to compute the metric right-away or if accumulate the results. This should be :obj:`True` when using the metric to compute a loss function, :obj:`False` if the metric is used for logging the aggregate error across different mini-batches. + (default: :obj:`True`) at (int, optional): Whether to compute the metric only w.r.t. a certain time step. + (default: :obj:`None`) + dim (int): The index of the dimension that represents time in a batch. + Relevant only when also 'at' is defined. + Default assumes [b t n f] format. + (default: :obj:`1`) """ is_differentiable: bool = True @@ -24,14 +34,15 @@ class MaskedPinballLoss(MaskedMetric): full_state_update: bool = False def __init__(self, - q, - mask_nans=False, - mask_inf=False, - compute_on_step=True, - dist_sync_on_step=False, - process_group=None, - dist_sync_fn=None, - at=None): + q: float, + mask_nans: bool = False, + mask_inf: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Any = None, + dist_sync_fn: Callable = None, + at: Optional[int] = None, + dim: int = 1): super(MaskedPinballLoss, self).__init__(metric_fn=pinball_loss, mask_nans=mask_nans, @@ -41,4 +52,5 @@ def __init__(self, process_group=process_group, dist_sync_fn=dist_sync_fn, metric_fn_kwargs={'q': q}, - at=at) + at=at, + dim=dim) diff --git a/tsl/typing.py b/tsl/typing.py index 2a41a12..1e49f65 100644 --- a/tsl/typing.py +++ b/tsl/typing.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Literal, Optional, Tuple, Type, Union +from typing import Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union from numpy import ndarray from pandas import DataFrame, DatetimeIndex, PeriodIndex, TimedeltaIndex @@ -33,3 +33,5 @@ "linear"]] ModelReturnOptions = Type[Union[Tensor, Dict, List, Tuple]] + +Slicer = TypeVar("Slicer", slice, type(Ellipsis), int) diff --git a/tsl/utils/python_utils.py b/tsl/utils/python_utils.py index c8382cd..b1224c7 100644 --- a/tsl/utils/python_utils.py +++ b/tsl/utils/python_utils.py @@ -1,9 +1,12 @@ import inspect import os +import re from argparse import ArgumentParser from typing import (Any, Callable, List, Mapping, Optional, Sequence, Set, Type, Union) +import numpy as np + def ensure_list(value: Any) -> List: # if isinstance(value, Sequence) and not isinstance(value, str): @@ -129,3 +132,40 @@ def filter_kwargs(target: Union[Callable, Type], kwargs: Mapping): for k, v in kwargs.items() if k in signature['signature'] } return kwargs + + +def parse_slicing_element(e: str) -> type(Ellipsis) | slice | list[Any] | int: + """ + Parses single slicing elements. + + Args: + e: string representing the slicing element. + + Returns: + The parsed element. + """ + if e == "...": + return Ellipsis + elif ":" in e: + return slice(*(int(i) if not i == "" else None for i in e.split(":"))) + elif e.startswith("[") and e.endswith("]"): + return list(int(i) for i in e[1:-1].split(",")) + else: + return int(e) + + +def parse_slicing_string(s: str) -> tuple[int | slice | type(Ellipsis)]: + """ + Parses slicing elements obtained by splitting a string at each comma + considering elements inside square brackets as individual elements. + + Args: + s: string to parse. + + Returns: + A tuple containing the parsed elements. + """ + return np.index_exp[( + parse_slicing_element(e) + for e in re.split(r'\s*,\s*(?![^\[\]]*])', s.replace(" ", "")) + if not e == "")]