wip wip wip
lukaszkolodziejczyk committed Feb 26, 2025
1 parent 5e19df0 commit a21c8e1
Showing 5 changed files with 304 additions and 26 deletions.
185 changes: 185 additions & 0 deletions coh.ipynb
@@ -0,0 +1,185 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CATEGORIES PER SEQUENCE"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"df_tgt = pd.read_csv(\"https://github.com/mostly-ai/public-demo-data/raw/refs/heads/dev/baseball/batting.csv.gz\")\n",
"df_tgt.head(2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mostlyai.qa._coherence import pull_data_for_coherence\n",
"\n",
"df_tgt = pull_data_for_coherence(df_tgt=df_tgt, tgt_context_key=\"players_id\")\n",
"df_tgt.head(2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mostlyai.qa._coherence import calculate_categories_per_sequence\n",
"\n",
"categories_per_sequence_df = calculate_categories_per_sequence(df=df_tgt, context_key=\"players_id\")\n",
"categories_per_sequence_df.head(2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mostlyai.qa._accuracy import calculate_numeric_uni_kdes\n",
"\n",
"trn_num_kdes = calculate_numeric_uni_kdes(categories_per_sequence_df)\n",
"trn_num_kdes[\"team\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mostlyai.qa._accuracy import bin_data\n",
"\n",
"\n",
"cat_share_per_sequence_binned = bin_data(categories_per_sequence_df, bins=10)[0]\n",
"cat_share_per_sequence_binned.head(2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mostlyai.qa._accuracy import calculate_categorical_uni_counts\n",
"\n",
"\n",
"trn_bin_col_cnts = calculate_categorical_uni_counts(df=cat_share_per_sequence_binned, hash_rare_values=False)\n",
"trn_bin_col_cnts[\"team\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mostlyai.qa._accuracy import plot_univariate\n",
"\n",
"for col in categories_per_sequence_df.columns:\n",
" if col != \"players_id\": # Skip the context key\n",
" display(\n",
" plot_univariate(\n",
" col_name=col,\n",
" trn_num_kde=trn_num_kdes.get(col),\n",
" syn_num_kde=trn_num_kdes.get(col),\n",
" trn_cat_col_cnts=None,\n",
" syn_cat_col_cnts=None,\n",
" trn_bin_col_cnts=trn_bin_col_cnts[col],\n",
" syn_bin_col_cnts=trn_bin_col_cnts[col],\n",
" accuracy=0.5,\n",
" )\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SEQUENCES PER CATEGORY"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mostlyai.qa._coherence import calculate_sequences_per_category\n",
"\n",
"sequences_per_category_dict, sequences_per_category_binned_dict = calculate_sequences_per_category(\n",
" df=df_tgt, context_key=\"players_id\"\n",
")\n",
"display(sequences_per_category_dict[\"team\"].head(2))\n",
"display(sequences_per_category_binned_dict[\"team\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mostlyai.qa._accuracy import plot_univariate\n",
"\n",
"for col in categories_per_sequence_df.columns:\n",
" if col != \"players_id\": # Skip the context key\n",
" display(\n",
" plot_univariate(\n",
" col_name=col,\n",
" trn_num_kde=None,\n",
" syn_num_kde=None,\n",
" trn_cat_col_cnts=sequences_per_category_dict[col],\n",
" syn_cat_col_cnts=sequences_per_category_dict[col],\n",
" trn_bin_col_cnts=sequences_per_category_binned_dict[col],\n",
" syn_bin_col_cnts=sequences_per_category_binned_dict[col],\n",
" accuracy=0.5,\n",
" trn_col_total=df_tgt[\"players_id\"].nunique(),\n",
" syn_col_total=df_tgt[\"players_id\"].nunique(),\n",
" )\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
39 changes: 30 additions & 9 deletions mostlyai/qa/_accuracy.py
@@ -434,6 +434,8 @@ def plot_univariate(
     trn_bin_col_cnts: pd.Series,
     syn_bin_col_cnts: pd.Series,
     accuracy: float | None,
+    trn_col_total: int | None = None,
+    syn_col_total: int | None = None,
 ) -> go.Figure:
     # either numerical/datetime KDEs or categorical counts must be provided

@@ -485,8 +487,16 @@ def plot_univariate(
         fig.layout.xaxis2.update(type="category")
     else:
         fig.layout.yaxis.update(tickformat=".0%")
-        trn_line1, syn_line1 = plot_univariate_distribution_categorical(trn_cat_col_cnts, syn_cat_col_cnts)
-        trn_line2, syn_line2 = plot_univariate_binned(trn_bin_col_cnts, syn_bin_col_cnts, sort_by_frequency=True)
+        trn_line1, syn_line1 = plot_univariate_distribution_categorical(
+            trn_cat_col_cnts, syn_cat_col_cnts, trn_col_total, syn_col_total
+        )
+        trn_line2, syn_line2 = plot_univariate_binned(
+            trn_bin_col_cnts,
+            syn_bin_col_cnts,
+            sort_by_frequency=True,
+            trn_col_total=trn_col_total,
+            syn_col_total=syn_col_total,
+        )
         # prevent Plotly from trying to convert strings to dates
         fig.layout.xaxis.update(type="category")
         fig.layout.xaxis2.update(type="category")
@@ -505,6 +515,8 @@ def plot_univariate(
 def prepare_categorical_plot_data_distribution(
     trn_col_cnts: pd.Series,
     syn_col_cnts: pd.Series,
+    trn_col_total: int | None = None,
+    syn_col_total: int | None = None,
 ) -> pd.DataFrame:
     trn_col_cnts_idx = trn_col_cnts.index.to_series().astype("string").fillna(NA_BIN).replace("", EMPTY_BIN)
     syn_col_cnts_idx = syn_col_cnts.index.to_series().astype("string").fillna(NA_BIN).replace("", EMPTY_BIN)
@@ -517,8 +529,8 @@ def prepare_categorical_plot_data_distribution(
df["synthetic_cnt"] = df["synthetic_cnt"].fillna(0.0)
df["avg_cnt"] = (df["target_cnt"] + df["synthetic_cnt"]) / 2
df = df[df["avg_cnt"] > 0]
df["target_pct"] = df["target_cnt"] / df["target_cnt"].sum()
df["synthetic_pct"] = df["synthetic_cnt"] / df["synthetic_cnt"].sum()
df["target_pct"] = df["target_cnt"] / (trn_col_total or df["target_cnt"].sum())
df["synthetic_pct"] = df["synthetic_cnt"] / (syn_col_total or df["synthetic_cnt"].sum())
df = df.rename(columns={"index": "category"})
if df["category"].dtype.name == "category":
df["category_code"] = df["category"].cat.codes
@@ -532,6 +544,8 @@ def prepare_categorical_plot_data_binned(
     trn_bin_col_cnts: pd.Series,
     syn_bin_col_cnts: pd.Series,
     sort_by_frequency: bool,
+    trn_col_total: int | None = None,
+    syn_col_total: int | None = None,
 ) -> pd.DataFrame:
     t = trn_bin_col_cnts.to_frame("target_cnt").reset_index(names="category")
     s = syn_bin_col_cnts.to_frame("synthetic_cnt").reset_index(names="category")
@@ -540,8 +554,8 @@ def prepare_categorical_plot_data_binned(
df["synthetic_cnt"] = df["synthetic_cnt"].fillna(0.0)
df["avg_cnt"] = (df["target_cnt"] + df["synthetic_cnt"]) / 2
df = df[df["avg_cnt"] > 0]
df["target_pct"] = df["target_cnt"] / df["target_cnt"].sum()
df["synthetic_pct"] = df["synthetic_cnt"] / df["synthetic_cnt"].sum()
df["target_pct"] = df["target_cnt"] / (trn_col_total or df["target_cnt"].sum())
df["synthetic_pct"] = df["synthetic_cnt"] / (syn_col_total or df["synthetic_cnt"].sum())
if df["category"].dtype.name == "category":
df["category_code"] = df["category"].cat.codes
else:
@@ -554,10 +568,13 @@ def prepare_categorical_plot_data_binned(


 def plot_univariate_distribution_categorical(
-    trn_cat_col_cnts: pd.Series, syn_cat_col_cnts: pd.Series
+    trn_cat_col_cnts: pd.Series,
+    syn_cat_col_cnts: pd.Series,
+    trn_col_total: int | None = None,
+    syn_col_total: int | None = None,
 ) -> tuple[go.Scatter, go.Scatter]:
     # prepare data
-    df = prepare_categorical_plot_data_distribution(trn_cat_col_cnts, syn_cat_col_cnts)
+    df = prepare_categorical_plot_data_distribution(trn_cat_col_cnts, syn_cat_col_cnts, trn_col_total, syn_col_total)
     df = df.sort_values("avg_cnt", ascending=False)
     # trim labels
     df["category"] = trim_labels(df["category"], max_length=10)
@@ -587,9 +604,13 @@ def plot_univariate_binned(
     trn_bin_col_cnts: pd.Series,
     syn_bin_col_cnts: pd.Series,
     sort_by_frequency: bool = False,
+    trn_col_total: int | None = None,
+    syn_col_total: int | None = None,
 ) -> tuple[go.Scatter, go.Scatter]:
     # prepare data
-    df = prepare_categorical_plot_data_binned(trn_bin_col_cnts, syn_bin_col_cnts, sort_by_frequency)
+    df = prepare_categorical_plot_data_binned(
+        trn_bin_col_cnts, syn_bin_col_cnts, sort_by_frequency, trn_col_total, syn_col_total
+    )
     # prepare plots
     trn_line = go.Scatter(
         mode="lines+markers",
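Why the new `trn_col_total` / `syn_col_total` arguments: by default the percentages are normalized by the sum of the plotted counts, but for sequences-per-category plots one sequence can fall into several categories, so the natural denominator is the total number of sequences instead. A minimal sketch of the fallback semantics (toy data; names are illustrative, not from the library):

import pandas as pd

cnt = pd.Series({"NYA": 30, "BOS": 20})  # sequences per category (toy counts)
seq_total = 100                          # hypothetical total number of sequences

# without an explicit total, shares sum to 1 across the plotted categories
print(cnt / cnt.sum())                   # NYA 0.6, BOS 0.4

# with a total, shares are fractions of all sequences and need not sum to 1;
# `x or y` falls back to y when x is None (or 0)
print(cnt / (seq_total or cnt.sum()))    # NYA 0.3, BOS 0.2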
88 changes: 88 additions & 0 deletions mostlyai/qa/_coherence.py
@@ -24,6 +24,7 @@
)
from mostlyai.qa._common import CHARTS_COLORS, CHARTS_FONTS
from mostlyai.qa._filesystem import TemporaryWorkspace
from mostlyai.qa._sampling import harmonize_dtype


def plot_store_coherences(
@@ -215,3 +216,90 @@ def calculate_coh_univariates(
    )
    accuracies["accuracy"], accuracies["accuracy_max"] = zip(*results)
    return accuracies


def pull_data_for_coherence(
    *,
    df_tgt: pd.DataFrame,
    tgt_context_key: str,
    max_sequence_length: int = 100,
) -> pd.DataFrame:
    """
    Prepare a sequential dataset for coherence metrics.
    """
    # randomly sample at most max_sequence_length rows per sequence
    df_tgt = df_tgt.sample(frac=1).reset_index(drop=True)
    df_tgt = df_tgt[df_tgt.groupby(tgt_context_key).cumcount() < max_sequence_length].reset_index(drop=True)

    # harmonize dtypes of all columns except tgt_context_key
    df_tgt = df_tgt.apply(lambda col: harmonize_dtype(col) if col.name != tgt_context_key else col)

    # TODO: discretize columns
    for col in df_tgt.columns:
        if col == tgt_context_key:
            continue
        df_tgt[col] = pd.Categorical(df_tgt[col], ordered=True)

    # Example output (pd.DataFrame):
    # | players_id | year   | team | league | G    | AB    | R    | H    | HR  | RBI  | SB  | CS  | BB   | SO   |
    # |------------|--------|------|--------|------|-------|------|------|-----|------|-----|-----|------|------|
    # | borowha01  | 1943.0 | NYA  | AL     | 29.0 | 74.0  | 2.0  | 15.0 | 0.0 | 7.0  | 0.0 | 0.0 | 5.0  | 17.0 |
    # | wallaja02  | 1946.0 | PHA  | AL     | 63.0 | 194.0 | 16.0 | 38.0 | 5.0 | 11.0 | 1.0 | 0.0 | 14.0 | 47.0 |
    # players_id dtype: original, other columns dtype: category
    return df_tgt
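As an aside, a toy illustration of the cumcount-based cap used above (made-up data, not part of the commit): shuffling first makes the cap keep a random subset of each sequence rather than its first rows.

import pandas as pd

df = pd.DataFrame({"players_id": ["a"] * 5 + ["b"] * 2, "year": list(range(7))})
max_sequence_length = 3

# shuffle rows, then keep at most 3 rows per players_id
df = df.sample(frac=1).reset_index(drop=True)
df = df[df.groupby("players_id").cumcount() < max_sequence_length].reset_index(drop=True)
print(df["players_id"].value_counts())  # a: 3, b: 2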


def calculate_categories_per_sequence(df: pd.DataFrame, context_key: str) -> pd.DataFrame:
    """
    Calculate the number of distinct categories per sequence for all columns except the context key.
    """
    # Example output (pd.DataFrame):
    # | players_id | year | team | league | G  | AB | R  | H  | HR | RBI | SB | CS | BB | SO |
    # |------------|------|------|--------|----|----|----|----|----|-----|----|----|----|----|
    # | aardsda01  | 9    | 8    | 2      | 9  | 3  | 1  | 1  | 1  | 1   | 1  | 1  | 1  | 2  |
    # | aaronha01  | 23   | 3    | 2      | 18 | 21 | 20 | 23 | 17 | 20  | 15 | 10 | 22 | 19 |
    # players_id dtype: original, other columns dtype: int64
    return df.groupby(context_key).nunique().reset_index()
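A runnable toy version of that aggregation (made-up data, not from the commit):

import pandas as pd

df = pd.DataFrame({
    "players_id": ["a", "a", "a", "b"],
    "team": ["NYA", "NYA", "BOS", "PHA"],
})
# one output row per sequence; each cell counts the distinct values seen in that sequence
print(df.groupby("players_id").nunique().reset_index())
#   players_id  team
# 0          a     2
# 1          b     1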


def calculate_sequences_per_category(
    df: pd.DataFrame, context_key: str
) -> tuple[dict[str, pd.Series], dict[str, pd.Series]]:
    """
    Calculate the number of sequences per category for all columns except the context key.
    """
    # replace all null values with '(n/a)'
    df = df.copy()
    for col in df.columns:
        if col == context_key:
            continue
        # add the '(n/a)' category if needed, then replace nulls
        if df[col].isna().any():
            df[col] = df[col].cat.add_categories("(n/a)")
            df.loc[df[col].isna(), col] = "(n/a)"

    # Example output for "team" (pd.Series):
    # team
    # ALT     18
    # ANA    164
    # Name: players_id, dtype: int64
    sequences_per_category_dict = {
        col: df.groupby(col)[context_key].nunique().rename_axis(None) for col in df.columns if col != context_key
    }

    # bin each column to its top 9 categories (by the number of sequences they occur in),
    # folding all remaining categories into '(other)'
    df = df.copy()
    for col in df.columns:
        if col == context_key:
            continue
        top_categories = sequences_per_category_dict[col].nlargest(9).index.tolist()
        not_in_top_categories_mask = ~df[col].isin(top_categories)
        if not_in_top_categories_mask.any():
            df[col] = df[col].cat.add_categories("(other)")
            df.loc[not_in_top_categories_mask, col] = "(other)"
        df[col] = df[col].cat.remove_unused_categories()
    sequences_per_category_binned_dict = {
        col: df.groupby(col)[context_key].nunique().rename_axis(None) for col in df.columns if col != context_key
    }
    return sequences_per_category_dict, sequences_per_category_binned_dict
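And a toy walk-through of the two outputs (made-up data; the top-k cut shrunk from 9 to 2 so that '(other)' shows up):

import pandas as pd

df = pd.DataFrame({
    "players_id": ["a", "a", "b", "c", "c"],
    "team": pd.Categorical(["NYA", "BOS", "NYA", "PHA", "NYA"]),
})

# distinct sequences per category
per_cat = df.groupby("team")["players_id"].nunique().rename_axis(None)
print(per_cat)  # BOS 1, NYA 3, PHA 1

# keep the top 2 categories, fold the rest into "(other)"
top = per_cat.nlargest(2).index.tolist()
mask = ~df["team"].isin(top)
df["team"] = df["team"].cat.add_categories("(other)")
df.loc[mask, "team"] = "(other)"
df["team"] = df["team"].cat.remove_unused_categories()
print(df.groupby("team")["players_id"].nunique().rename_axis(None))  # BOS 1, NYA 3, (other) 1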