Add additional metadata to the columns, dataset and benchmark #65

Merged · 6 commits · Dec 7, 2023
19 changes: 17 additions & 2 deletions docs/quickstart.md
@@ -53,8 +53,23 @@ Similarly, you can easily access a dataset.
```python
import polaris as po

dataset = po.load_dataset("org_or_user/name")
dataset.get_data(col=..., row=...)
# Load the dataset from the hub
dataset = po.load_dataset("polaris/hello-world-dataset")

# Get information on the dataset size
dataset.size()

# Load a datapoint in memory
dataset.get_data(
row=dataset.rows[0],
col=dataset.columns[0],
)

# Or, similarly:
dataset[dataset.rows[0], dataset.columns[0]]

# Get the first 10 rows in memory
dataset[:10]
```

## Core concepts
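Before moving to the next file, a minimal sketch of what the expanded dataset accessors from the quickstart are expected to return. This assumes the `polaris/hello-world-dataset` artifact from the snippet above is reachable; the printed labels are illustrative only.

```python
import polaris as po

# Same artifact as in the quickstart snippet above
dataset = po.load_dataset("polaris/hello-world-dataset")

# size() reports the dataset shape as an (n_rows, n_columns) tuple
n_rows, n_cols = dataset.size()

# rows/columns expose the labels that back the new indexing behaviour
print(dataset.rows[:3])     # e.g. [0, 1, 2]
print(dataset.columns[:3])  # e.g. ['smiles', 'target'] (illustrative names)

# Fetch a single datapoint by label, or via the indexing shorthand
x = dataset.get_data(row=dataset.rows[0], col=dataset.columns[0])
x = dataset[dataset.rows[0], dataset.columns[0]]
```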
109 changes: 101 additions & 8 deletions polaris/benchmark/_base.py
@@ -7,11 +7,14 @@
import numpy as np
import pandas as pd
from pydantic import (
Field,
FieldValidationInfo,
computed_field,
field_serializer,
field_validator,
model_validator,
)
from sklearn.utils.multiclass import type_of_target

from polaris._artifact import BaseArtifactModel
from polaris.dataset import Dataset, Subset
@@ -22,7 +25,15 @@
from polaris.utils.dict2html import dict2html
from polaris.utils.errors import InvalidBenchmarkError, PolarisChecksumError
from polaris.utils.misc import listit
from polaris.utils.types import AccessType, DataFormat, HubOwner, PredictionsType, SplitType
from polaris.utils.types import (
AccessType,
DataFormat,
HubOwner,
PredictionsType,
SplitType,
TargetType,
TaskType,
)

ColumnsType = Union[str, list[str]]

@@ -47,13 +58,26 @@ class BenchmarkSpecification(BaseArtifactModel):
```python
import polaris as po

benchmark = po.load_benchmark("/path/to/benchmark")
# Load the benchmark from the Hub
benchmark = po.load_benchmark("polaris/hello_world_benchmark")

# Get the train and test data-loaders
train, test = benchmark.get_train_test_split()

# Work your magic
predictions = ...
# Use the training data to train your model
# Get the input as an array with 'train.inputs' and 'train.targets'
# Or simply iterate over the train object.
for x, y in train:
...

# Work your magic to accurately predict the test set
predictions = [0.0 for x in test]

benchmark.evaluate(predictions)
# Evaluate your predictions
results = benchmark.evaluate(predictions)

# Submit your results
results.upload_to_hub(owner="dummy-user")
```

Attributes:
@@ -68,6 +92,7 @@ class BenchmarkSpecification(BaseArtifactModel):
readme: Markdown text that can be used to provide a formatted description of the benchmark.
If using the Polaris Hub, it is worth noting that this field is more easily edited through the Hub UI
as it provides a rich text editor for writing markdown.
target_types: A dictionary that maps target columns to their type. If not specified, this is automatically inferred.
For additional meta-data attributes, see the [`BaseArtifactModel`][polaris._artifact.BaseArtifactModel] class.
"""

@@ -83,6 +108,9 @@ class BenchmarkSpecification(BaseArtifactModel):

# Additional meta-data
readme: str = ""
target_types: dict[str, Optional[Union[TargetType, str]]] = Field(
default_factory=dict, validate_default=True
)

@field_validator("dataset")
def _validate_dataset(cls, v):
@@ -175,9 +203,31 @@ def _validate_split(cls, v, info: FieldValidationInfo):
raise InvalidBenchmarkError("The predefined split contains invalid indices")
return v

@field_validator("target_types")
def _validate_target_types(cls, v, info: FieldValidationInfo):
"""Try to automatically infer the target types if not already set"""

dataset = info.data.get("dataset")
target_cols = info.data.get("target_cols")
if dataset is None or target_cols is None:
return v

for target in target_cols:
if target not in v:
target_type = type_of_target(dataset[:, target])
if target_type == "continuous":
v[target] = TargetType.REGRESSION
elif target_type in ["binary", "multiclass"]:
v[target] = TargetType.CLASSIFICATION
else:
v[target] = None
elif not isinstance(v[target], TargetType):
v[target] = TargetType(v[target])
return v

@model_validator(mode="after")
@classmethod
def _validate_checksum(cls, m: "BenchmarkSpecification"):
def _validate_model(cls, m: "BenchmarkSpecification"):
"""
If a checksum is provided, verify it matches what the checksum should be.
If no checksum is provided, make sure it is set.
@@ -221,6 +271,11 @@ def _serialize_split(self, v):
"""Convert any tuple to list to make sure it's serializable"""
return listit(v)

@field_serializer("target_types")
def _serialize_target_types(self, v):
"""Convert from enum to string to make sure it's serializable"""
return {k: v.value for k, v in self.target_types.items()}

@staticmethod
def _compute_checksum(dataset, target_cols, input_cols, split, metrics):
"""
@@ -254,6 +309,46 @@ def _compute_checksum(dataset, target_cols, input_cols, split, metrics):
checksum = hash_fn.hexdigest()
return checksum

@computed_field
@property
def n_train_datapoints(self) -> int:
"""The size of the train set."""
return len(self.split[0])

@computed_field
@property
def n_test_sets(self) -> int:
"""The number of test sets"""
return len(self.split[1]) if isinstance(self.split[1], dict) else 1

@computed_field
@property
def n_test_datapoints(self) -> dict[str, int]:
"""The size of (each of) the test set(s)."""
if self.n_test_sets == 1:
return {"test": len(self.split[1])}
else:
return {k: len(v) for k, v in self.split[1].items()}

@computed_field
@property
def n_classes(self) -> dict[str, int]:
"""The number of classes for each of the target columns."""
n_classes = {}
for target in self.target_cols:
target_type = self.target_types[target]
if target_type is None or target_type == TargetType.REGRESSION:
continue
n_classes[target] = self.dataset.table.loc[:, target].nunique()
return n_classes

@computed_field
@property
def task_type(self) -> TaskType:
"""The high-level task type of the benchmark."""
v = TaskType.MULTI_TASK if len(self.target_cols) > 1 else TaskType.SINGLE_TASK
return v.value

def get_train_test_split(
self, input_format: DataFormat = "dict", target_format: DataFormat = "dict"
) -> tuple[Subset, Union["Subset", dict[str, Subset]]]:
@@ -417,8 +512,6 @@ def _repr_dict_(self) -> dict:
repr_dict.pop("dataset")
repr_dict.pop("split")
repr_dict["dataset_name"] = self.dataset.name
repr_dict["n_input_cols"] = len(self.input_cols)
repr_dict["n_target_cols"] = len(self.target_cols)
return repr_dict

def _repr_html_(self):
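Before the next file, a standalone sketch of the inference rule used by `_validate_target_types` above: `sklearn.utils.multiclass.type_of_target` classifies the raw column values, and the validator maps that result onto the new `TargetType` metadata. The toy arrays below are invented for illustration; only the mapping mirrors the diff.

```python
import numpy as np
from sklearn.utils.multiclass import type_of_target

# Invented target columns, purely for illustration
columns = {
    "log_solubility": np.array([1.2, -0.4, 3.1, 0.7]),         # continuous values
    "is_active": np.array([0, 1, 1, 0]),                        # binary labels
    "assay_outcome": np.array(["low", "mid", "high", "low"]),   # multiclass labels
}

for name, values in columns.items():
    kind = type_of_target(values)
    if kind == "continuous":
        inferred = "REGRESSION"
    elif kind in ("binary", "multiclass"):
        inferred = "CLASSIFICATION"
    else:
        inferred = None  # other layouts (e.g. multilabel) are left unset
    print(f"{name}: {kind} -> {inferred}")

# Expected output:
# log_solubility: continuous -> REGRESSION
# is_active: binary -> CLASSIFICATION
# assay_outcome: multiclass -> CLASSIFICATION
```

Explicit entries in `target_types` bypass this inference, and plain strings are coerced to the enum by the same validator.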
15 changes: 14 additions & 1 deletion polaris/benchmark/_definitions.py
@@ -1,6 +1,7 @@
from pydantic import field_validator
from pydantic import computed_field, field_validator

from polaris.benchmark._base import BenchmarkSpecification
from polaris.utils.types import TaskType


class SingleTaskBenchmarkSpecification(BenchmarkSpecification):
@@ -16,6 +17,12 @@ def validate_target_cols(cls, v):
raise ValueError("A single-task benchmark should specify a single target column")
return v

@computed_field
@property
def task_type(self) -> TaskType:
"""The high-level task type of the benchmark."""
return TaskType.SINGLE_TASK.value


class MultiTaskBenchmarkSpecification(BenchmarkSpecification):
"""Subclass for any multi-task benchmark specification
Expand All @@ -29,3 +36,9 @@ def validate_target_cols(cls, v):
if not len(v) > 1:
raise ValueError("A multi-task benchmark should specify at least two target columns")
return v

@computed_field
@property
def task_type(self) -> TaskType:
"""The high-level task type of the benchmark."""
return TaskType.MULTI_TASK.value
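The subclasses above pin `task_type` rather than deriving it from the number of target columns. A minimal pydantic sketch of the same override pattern, with invented class and field names:

```python
from pydantic import BaseModel, computed_field


class Spec(BaseModel):
    target_cols: list[str]

    @computed_field
    @property
    def task_type(self) -> str:
        # Derived from the column count in the base class
        return "multi_task" if len(self.target_cols) > 1 else "single_task"


class SingleTaskSpec(Spec):
    @computed_field
    @property
    def task_type(self) -> str:
        # Pinned in the subclass, independent of the column count
        return "single_task"


print(Spec(target_cols=["a", "b"]).model_dump())
# {'target_cols': ['a', 'b'], 'task_type': 'multi_task'}
print(SingleTaskSpec(target_cols=["a"]).model_dump())
# {'target_cols': ['a'], 'task_type': 'single_task'}
```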
17 changes: 17 additions & 0 deletions polaris/dataset/_column.py
@@ -1,6 +1,8 @@
import enum
from typing import Dict, Optional, Union

import numpy as np
from numpy.typing import DTypeLike
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from pydantic.alias_generators import to_camel

@@ -34,6 +36,7 @@ class ColumnAnnotation(BaseModel):
modality: Union[str, Modality] = Modality.UNKNOWN
description: Optional[str] = None
user_attributes: Dict[str, str] = Field(default_factory=dict)
dtype: Optional[Union[np.dtype, str]] = None

model_config = ConfigDict(arbitrary_types_allowed=True, alias_generator=to_camel, populate_by_name=True)

@@ -44,7 +47,21 @@ def _validate_modality(cls, v):
v = Modality[v.upper()]
return v

@field_validator("dtype")
def _validate_dtype(cls, v):
"""Tries to convert a string to the Enum"""
if isinstance(v, str):
v = np.dtype(v)
return v

@field_serializer("modality")
def _serialize_modality(self, v: Modality):
"""Return the modality as a string, keeping it serializable"""
return v.name

@field_serializer("dtype")
def _serialize_dtype(self, v: Optional[DTypeLike]):
"""Return the modality as a string, keeping it serializable"""
if v is not None:
v = v.name
return v
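A small round-trip sketch of the new `dtype` field: the validator coerces a string to `np.dtype`, and the serializer writes it back out as the dtype's `.name`. The description text is made up; the import path follows the docs referenced elsewhere in this PR.

```python
import numpy as np
from polaris.dataset import ColumnAnnotation

# Strings are coerced to a numpy dtype on validation
ann = ColumnAnnotation(description="Illustrative column", dtype="float64")
assert ann.dtype == np.dtype("float64")

# On serialization the dtype is written back as its string name
assert ann.model_dump()["dtype"] == "float64"
```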
58 changes: 56 additions & 2 deletions polaris/dataset/_dataset.py
@@ -11,6 +11,7 @@
from loguru import logger
from pydantic import (
Field,
computed_field,
field_validator,
model_validator,
)
@@ -55,6 +56,8 @@ class Dataset(BaseArtifactModel):
annotations: Each column _can be_ annotated with a [`ColumnAnnotation`][polaris.dataset.ColumnAnnotation] object.
Importantly, this is used to annotate whether a column is a pointer column.
source: The data source, e.g. a DOI, Github repo or URI.
license: The dataset license.
curation_reference: A reference to the curation process, e.g. a DOI, Github repo or URI.
For additional meta-data attributes, see the [`BaseArtifactModel`][polaris._artifact.BaseArtifactModel] class.

Raises:
@@ -72,6 +75,7 @@
annotations: Dict[str, ColumnAnnotation] = Field(default_factory=dict)
source: Optional[HttpUrlString] = None
license: Optional[License] = None
curation_reference: Optional[HttpUrlString] = None

# Config
cache_dir: Optional[str] = None # Where to cache the data to if cache() is called.
@@ -106,6 +110,7 @@ def _validate_model(cls, m: "Dataset"):
for c in m.table.columns:
if c not in m.annotations:
m.annotations[c] = ColumnAnnotation()
m.annotations[c].dtype = m.table[c].dtype

# Verify the checksum
# NOTE (cwognum): Is it still reasonable to always verify this as the dataset size grows?
@@ -152,6 +157,28 @@ def _compute_checksum(table):
checksum = hash_fn.hexdigest()
return checksum

@computed_field
@property
def n_rows(self) -> int:
"""The number of rows in the dataset."""
return len(self.rows)

@computed_field
@property
def n_columns(self) -> int:
"""The number of columns in the dataset."""
return len(self.columns)

@property
def rows(self) -> list:
"""Return all row indices for the dataset"""
return self.table.index.tolist()

@property
def columns(self) -> list:
"""Return all columns for the dataset"""
return self.table.columns.tolist()

def get_data(self, row: Union[str, int], col: str) -> np.ndarray:
"""Since the dataset might contain pointers to external files, data retrieval is more complicated
than just indexing the `table` attribute. This method provides an end-point for seamlessly
@@ -453,7 +480,7 @@ def _get_cache_path(self, column: str, value: str) -> Optional[str]:
return self._path_to_hash[column][value]

def size(self):
return len(self), len(self.table.columns)
return self.n_rows, self.n_columns

def _split_index_from_path(self, path: str) -> Tuple[str, Optional[int]]:
"""
@@ -499,6 +526,33 @@ def fn(path):
table[c] = table[c].apply(fn)
return table

def __getitem__(self, item):
"""Allows for indexing the dataset directly"""
ret = self.table.loc[item]
if isinstance(ret, pd.Series):
# Load the data from the pointer columns

if len(ret) == self.n_columns:
# Returning a row
ret = ret.to_dict()
for k in ret.keys():
ret[k] = self.get_data(item, k)

if len(ret) == self.n_rows:
# Returning a column
if self.annotations[ret.name].is_pointer:
ret = [self.get_data(item, ret.name) for item in ret.index]
return np.array(ret)

# Returning a dataframe
if isinstance(ret, pd.DataFrame):
for c in ret.columns:
if self.annotations[c].is_pointer:
ret[c] = [self.get_data(item, c) for item in ret.index]
return ret

return ret

def _repr_dict_(self) -> dict:
"""Utility function for pretty-printing to the command line and jupyter notebooks"""
repr_dict = self.model_dump()
Expand All @@ -510,7 +564,7 @@ def _repr_html_(self):
return dict2html(self._repr_dict_())

def __len__(self):
return len(self.table)
return self.n_rows

def __repr__(self):
return json.dumps(self._repr_dict_(), indent=2)
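Finally, a hedged summary of the indexing behaviour added in `__getitem__`: row, column, and slice lookups resolve pointer columns through `get_data`, while a plain (row, column) lookup returns the underlying table value. The artifact name is the quickstart placeholder.

```python
import polaris as po

dataset = po.load_dataset("polaris/hello-world-dataset")  # quickstart placeholder name

# A full row: dict of column -> value, with pointer columns loaded via get_data()
row = dataset[dataset.rows[0]]

# A full column: numpy array, loading pointer columns if needed
col = dataset[:, dataset.columns[0]]

# A slice of rows: DataFrame with pointer columns materialized
subset = dataset[:10]

# A single (row, column) lookup returns the underlying table value
value = dataset[dataset.rows[0], dataset.columns[0]]

# The new computed fields expose the shape as serializable metadata
print(dataset.n_rows, dataset.n_columns)
```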