Skip to content

Commit

Permalink
feat: Update evaluation notebooks to three annotators
Browse files Browse the repository at this point in the history
  • Loading branch information
saattrupdan committed Sep 8, 2022
1 parent f131808 commit e54b3fd
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 70 deletions.
138 changes: 69 additions & 69 deletions notebooks/evaluate_agreement.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"from tqdm.auto import tqdm\n",
"from sklearn.metrics import recall_score, precision_score, f1_score\n",
"import nltk\n",
"from statsmodels.stats import inter_rater as irr\n",
"from src.hatespeech.attack import load_attack\n",
"nltk.download('punkt')\n",
"pd.set_option('max_colwidth', None)\n",
Expand All @@ -36,7 +37,6 @@
"cell_type": "markdown",
"id": "a0cc514a-da4e-4d4b-b169-510f148aefbb",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
Expand All @@ -52,8 +52,9 @@
"source": [
"processed_dir = Path(\"data\") / \"processed\"\n",
"final_dir = Path(\"data\") / \"final\"\n",
"dan_path = processed_dir / \"dr_offensive_annotated_dan.csv\"\n",
"anders_path = processed_dir / \"dr_offensive_annotated_anders.csv\""
"dan_path = processed_dir / \"annotated-off-dan.csv\"\n",
"anders_path = processed_dir / \"annotated-off-anders.csv\"\n",
"oliver_path = processed_dir / \"annotated-off-oliver.csv\""
]
},
{
Expand All @@ -65,7 +66,11 @@
"source": [
"dan_df = pd.read_csv(dan_path, sep=\"\\t\")\n",
"anders_df = pd.read_csv(anders_path, sep=\"\\t\")\n",
"dan_df.head()"
"oliver_df = pd.read_csv(oliver_path, sep=\"\\t\")\n",
"oliver_df.label = oliver_df.label.map(\n",
" lambda lbl: \"Could be offensive, depending on context\" if lbl == \"Not sure\" else lbl\n",
")\n",
"oliver_df.head()"
]
},
{
Expand All @@ -75,14 +80,32 @@
"metadata": {},
"outputs": [],
"source": [
"cohen_kappa_score(dan_df.label, anders_df.label)"
"labels = np.stack(\n",
" [\n",
" dan_df.label.astype('category').cat.codes, \n",
" anders_df.label.astype('category').cat.codes,\n",
" oliver_df.label.astype('category').cat.codes,\n",
" ],\n",
" axis=-1\n",
")\n",
"labels"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12a53ca2-2498-4783-a384-c8196432eeaf",
"metadata": {},
"outputs": [],
"source": [
"agg, _ = irr.aggregate_raters(labels)\n",
"irr.fleiss_kappa(agg)"
]
},
{
"cell_type": "markdown",
"id": "2097f4a8-d0f2-471b-a096-1d04220f5add",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
Expand All @@ -96,8 +119,8 @@
"metadata": {},
"outputs": [],
"source": [
"label_df = pd.concat([dan_df[[\"label\"]], anders_df[[\"label\"]]], axis=1)\n",
"label_df.columns = [\"Dan\", \"Anders\"]\n",
"label_df = pd.concat([dan_df[[\"label\"]], anders_df[[\"label\"]], oliver_df[[\"label\"]]], axis=1)\n",
"label_df.columns = [\"Dan\", \"Anders\", \"Oliver\"]\n",
"label_df.head()"
]
},
Expand All @@ -124,132 +147,109 @@
"metadata": {},
"outputs": [],
"source": [
"label_df.groupby([\"Dan\", \"Anders\"]).size().unstack(fill_value=0)"
"comparisons = [\n",
" [\"Dan\", \"Anders\"],\n",
" [\"Dan\", \"Oliver\"],\n",
" [\"Anders\", \"Oliver\"],\n",
" [\"Dan\", \"Anders\", \"Oliver\"],\n",
"]\n",
"for comparison in comparisons:\n",
" display(label_df.groupby(comparison).size().unstack(fill_value=0))"
]
},
{
"cell_type": "markdown",
"id": "758ffaae-b206-4f7c-b515-bb27c861ac42",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
"## Extract dataframe with agreed labels"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b43609d8-33bb-4bf9-b021-e2e2d30bf9bc",
"metadata": {},
"outputs": [],
"source": [
"indices_with_agreement = label_df.query('Dan == Anders and Dan != \"Missing context\"').index.tolist()\n",
"agreement_df = dan_df.loc[indices_with_agreement][[\"text\", \"label\"]].reset_index().rename(columns=dict(index=\"idx\"))\n",
"agreement_df.head()"
"## Extract dataframe with majority labels"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ab7879f9-70c4-4e83-a0c0-c9855cf9fe03",
"id": "fd1e9d8c-e425-4b90-9b09-d84a45ba0842",
"metadata": {},
"outputs": [],
"source": [
"agreement_df.label.value_counts()"
"majority_vote_df = dan_df.copy()[[\"text\", \"label\"]]\n",
"majority_vote_df.label = label_df.mode(axis=1).dropna(axis=1)\n",
"majority_vote_df = majority_vote_df.reset_index().rename(columns=dict(index=\"idx\"))\n",
"majority_vote_df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d576b178-81c4-42fb-b214-1e8e7bcae932",
"id": "546d5c74-c094-40cd-8537-8d7674f1c8cd",
"metadata": {},
"outputs": [],
"source": [
"val_df_pos = agreement_df.query(\"label == 'Offensive'\").sample(frac=0.5, random_state=4242)\n",
"val_df_neg = agreement_df.query(\"label == 'Not offensive'\").sample(frac=0.5, random_state=4242)\n",
"val_df = pd.concat((val_df_pos, val_df_neg), axis=0).sample(frac=1.).reset_index(drop=True)\n",
"val_df.label.value_counts()"
"agreed_df = pd.concat(\n",
" [label_df[col] == majority_vote_df.label for col in label_df.columns], axis=1\n",
")\n",
"agreed_df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd95f826-32d9-4d02-bf9f-ea2c9085cc0e",
"id": "71bbb096-a2e1-41a6-af6b-bbf08fb54e8d",
"metadata": {},
"outputs": [],
"source": [
"test_df = agreement_df[~agreement_df.idx.isin(val_df.idx)]\n",
"test_df.label.value_counts()"
"num_agreed = agreed_df.sum(axis=1)\n",
"majority_vote_df[\"num_agreed\"] = num_agreed\n",
"majority_vote_df = majority_vote_df.query(\"label != 'Missing context'\")\n",
"majority_vote_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "50af4f2c-3211-4ffc-be27-ae8ea58d6fee",
"id": "8ffd2ace-23a1-45df-aa70-bc9bcd481c8a",
"metadata": {},
"outputs": [],
"source": [
"agreement_df.to_parquet(processed_dir / \"dr_offensive_annotated_agreement.parquet\")\n",
"val_df.to_parquet(final_dir / \"dr_offensive_val.parquet\")\n",
"test_df.to_parquet(final_dir / \"dr_offensive_test.parquet\")"
"majority_vote_df.num_agreed.value_counts()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "efa2bf0e-e7e9-4af4-8525-c301f5abdbf7",
"id": "589084c0-a961-4ae3-9e95-a870d6abdbbc",
"metadata": {},
"outputs": [],
"source": [
"model_id = 'models/xlmr-base2'\n",
"\n",
"# Load tokenizer and model\n",
"if model_id == 'attack':\n",
" tok, model = load_attack()\n",
"else:\n",
" tok = AutoTokenizer.from_pretrained(model_id)\n",
" model = AutoModelForSequenceClassification.from_pretrained(model_id)\n",
"\n",
"# Get logits\n",
"logits = torch.stack(\n",
" [get_logits(doc, tok, model) for doc in tqdm(val_df.text, leave=False)]\n",
")"
"val_df_pos = majority_vote_df.query(\"label == 'Offensive'\").sample(frac=0.5, random_state=4242)\n",
"val_df_neg = majority_vote_df.query(\"label == 'Not offensive'\").sample(frac=0.5, random_state=4242)\n",
"val_df = pd.concat((val_df_pos, val_df_neg), axis=0).sample(frac=1.).reset_index(drop=True)\n",
"val_df.label.value_counts()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba99d812-85dc-4608-befb-a66ebdad64f9",
"id": "e45b1a1e-8018-464b-8ac5-1fee09da8082",
"metadata": {},
"outputs": [],
"source": [
"# Add the logits and equivalent probabilities to the validation dataframe\n",
"val_df[\"model_logits\"] = logits.tolist()\n",
"val_df[\"model_probs\"] = torch.sigmoid(logits).tolist()\n",
"val_df.head()"
"test_df = majority_vote_df[~majority_vote_df.idx.isin(val_df.idx)]\n",
"test_df.label.value_counts()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "35df8d18-159b-41e9-963d-63bd645c6e55",
"id": "6e3c38b7-7ede-4d80-9b39-b5777a817a9c",
"metadata": {},
"outputs": [],
"source": [
"# Get the sample indices on which the model was wrong\n",
"wrong_idxs = (\n",
" torch.nonzero((logits > 0) != torch.tensor(val_labels)).squeeze(1).tolist()\n",
")\n",
"\n",
"# Get the samples on which the model was wrong\n",
"wrong_df = val_df.loc[wrong_idxs]\n",
"\n",
"# Sort the dataframe by absolute value of logits\n",
"wrong_df = wrong_df.sort_values(by='model_logits', key=lambda x: abs(x), ascending=False)\n",
"\n",
"wrong_df"
"majority_vote_df.to_parquet(processed_dir / \"annotated-off.parquet\")\n",
"val_df.to_parquet(final_dir / \"val-off.parquet\")\n",
"test_df.to_parquet(final_dir / \"test-off.parquet\")"
]
}
],
Expand Down
1 change: 0 additions & 1 deletion notebooks/evaluate_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@
"cell_type": "markdown",
"id": "30a58c05-2323-4615-a2f2-7dedbc656b31",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
Expand Down

0 comments on commit e54b3fd

Please sign in to comment.