Skip to content

Commit

Permalink
add packedloss doc
Browse files Browse the repository at this point in the history
  • Loading branch information
gorold committed Jun 28, 2024
1 parent e5f30c0 commit c869da0
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/uni2ts/loss/packed/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@


class PackedLoss(abc.ABC):
"""
Abstract base class for loss functions supporting packed inputs.
Subclasses should implement the _loss_func method which computes the loss function per token.
"""

def __call__(
self,
pred: Any,
Expand All @@ -34,6 +39,15 @@ def __call__(
sample_id: Optional[Int[torch.Tensor, "*batch seq_len"]] = None,
variate_id: Optional[Int[torch.Tensor, "*batch seq_len"]] = None,
) -> Float[torch.Tensor, ""]:
"""
:param pred: predictions
:param target: target labels
:param prediction_mask: 1 for predictions, 0 for non-predictions
:param observed_mask: 1 for observed values, 0 for non-observed values
:param sample_id: integer array representing the sample id
:param variate_id: integer array representing the variate id
:return: loss
"""
if observed_mask is None:
observed_mask = torch.ones_like(target, dtype=torch.bool)
if sample_id is None:
Expand Down Expand Up @@ -96,6 +110,8 @@ def __repr__(self) -> str:


class PackedPointLoss(PackedLoss):
"""Abstract base class for loss functions on point forecasts."""

@abc.abstractmethod
def _loss_func(
self,
Expand All @@ -109,6 +125,8 @@ def _loss_func(


class PackedDistributionLoss(PackedLoss):
"""Abstract base class for loss functions on probabilistic (distribution) forecasts."""

@abc.abstractmethod
def _loss_func(
self,
Expand Down

0 comments on commit c869da0

Please sign in to comment.