Updates notebooks with GPT-4o results, minor bug fixes (#17)
* fix minor bug where the web API responds with a percentage instead of a probability

* bumped version

* improve ChatGPT system prompt for numeric answers

* added GPT-4o plots

* added GPT-4o table and results on ACSIncome

* updated score distribution plots with GPT-4o results

* minor plots update

* one-row score distribution plots

* added score bias plots for GPT-4o

* calibration curves with multiple-choice prompting

* install now requires packaging>=22.0
AndreFCruz authored Jan 17, 2025
1 parent b75f4a6 commit ec4911f
Showing 57 changed files with 48 additions and 32 deletions.
folktexts/classifier/base.py (2 changes: 1 addition & 1 deletion)
@@ -371,7 +371,7 @@ def compute_risk_estimates_for_dataframe(
)

# Store risk estimates for current question
- batch_risk_scores[:, q_idx] = risk_estimates_batch
+ batch_risk_scores[:, q_idx] = np.clip(risk_estimates_batch, 0, 1)

risk_scores[start_idx: end_idx] = batch_risk_scores.mean(axis=1)

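For context, a minimal sketch of what the added np.clip call guards against (illustrative values, not from the library): a decoded numeric answer can occasionally land slightly outside [0, 1], and clipping bounds each per-question estimate before it is averaged into risk_scores.

```python
import numpy as np

# Hypothetical raw estimates for one batch; some fall slightly outside
# the valid probability range.
risk_estimates_batch = np.array([0.12, 1.03, -0.01, 0.57])

# The fix above bounds each estimate to [0, 1] before averaging.
clipped = np.clip(risk_estimates_batch, 0, 1)
print(clipped)  # -> [0.12, 1.0, 0.0, 0.57]
```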
folktexts/classifier/web_api_classifier.py (18 changes: 15 additions & 3 deletions)
@@ -156,7 +156,8 @@ def _query_webapi_batch(
# NOTE: Models often generate "0." instead of directly outputting the fractional part
# > Therefore: for multi-token answers, extra forward passes may be required
else:
-     num_forward_passes = question.num_forward_passes + 2  # +2 tokens for "0."
+     # Add extra tokens for textual prefix, e.g., "The probability is: ..."
+     num_forward_passes = question.num_forward_passes + 2

api_call_params = dict(
    temperature=1,
@@ -175,8 +176,12 @@

# Get system prompt depending on Q&A type
if isinstance(question, DirectNumericQA):
-     # system_prompt = "Please respond with number."
-     system_prompt = "Please respond with number representing the estimated probability."
+     system_prompt = "Your response must start with a number representing the estimated probability."
+     # system_prompt = (
+     #     "You are a highly specialized assistant that always responds with a single number. "
+     #     "For every input, you must analyze the request and respond with only the relevant single number, "
+     #     "without any additional text, explanation, or symbols."
+     # )
elif isinstance(question, MultipleChoiceQA):
    system_prompt = "Please respond with a single letter."
else:
@@ -278,6 +283,13 @@ def _decode_risk_estimate_from_api_response(
# Using full text answer as it more tightly relates to the ChatGPT web answer
risk_estimate = risk_estimate_full_text

+ if risk_estimate > 1:
+     logging.info(
+         f"Got risk estimate > 1: {risk_estimate}. Using "
+         f"output as a percentage: {risk_estimate / 100.0} instead."
+     )
+     risk_estimate = risk_estimate / 100.0

except Exception:
    logging.error(
        f"Failed to extract numeric response from message='{response_message}';\n"
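Taken together with the clipping fix above, this decoding change treats numeric answers greater than 1 as percentages. A self-contained sketch of that normalization (the helper name is hypothetical, not part of the library API):

```python
import logging

def normalize_risk_estimate(value: float) -> float:
    """Map a raw numeric answer to a probability.

    Values above 1 are interpreted as percentages, e.g. a model answering
    "85" instead of "0.85", mirroring the fix in the diff above.
    """
    if value > 1:
        logging.info("Got risk estimate > 1: %s. Using %s instead.", value, value / 100.0)
        value = value / 100.0
    return value

assert normalize_risk_estimate(85.0) == 0.85
assert normalize_risk_estimate(0.85) == 0.85
```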
notebooks/utils.py (2 changes: 2 additions & 0 deletions)
@@ -83,6 +83,8 @@ def prettify_model_name(name: str) -> str:
"gemma-2-27b-it": "Gemma 2 27B (it)",
"openai/gpt-4o-mini": "GPT 4o mini (it)",
"penai/gpt-4o-mini": "GPT 4o mini (it)",
"openai/gpt-4o": "GPT 4o (it)",
"penai/gpt-4o": "GPT 4o (it)",
}

if name in dct:
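Note that the mapping carries both the canonical "openai/gpt-4o" key and a "penai/gpt-4o" variant, presumably to tolerate names whose leading character was stripped elsewhere. A minimal sketch of the lookup behaviour (illustrative dictionary; assuming the function falls back to the raw name when a key is missing):

```python
# Illustrative subset of the mapping used by prettify_model_name
dct = {
    "openai/gpt-4o": "GPT 4o (it)",
    "penai/gpt-4o": "GPT 4o (it)",  # tolerates the truncated provider prefix
}

for name in ("openai/gpt-4o", "penai/gpt-4o", "some-unknown-model"):
    print(dct.get(name, name))
# -> GPT 4o (it)
# -> GPT 4o (it)
# -> some-unknown-model
```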
pyproject.toml (4 changes: 2 additions & 2 deletions)
@@ -1,5 +1,5 @@
[build-system]
- requires = ["setuptools>=64", "wheel"]
+ requires = ["setuptools>=64", "wheel", "packaging>=22.0"]
build-backend = "setuptools.build_meta"

[project]
@@ -32,7 +32,7 @@ classifiers = [
"Programming Language :: Python :: 3.12",
]

- version = "0.0.26"
+ version = "0.0.27"
requires-python = ">=3.8"
dynamic = [
"readme",
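The new build requirement matches the commit message "install now requires packaging>=22.0"; packaging is the library commonly used by setuptools-based builds to parse versions and specifiers. A tiny sketch of the kind of parsing it provides (illustrative, not taken from the build code):

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

# Version comparison, e.g. for the bump in this commit
assert Version("0.0.27") > Version("0.0.26")

# Specifier matching, e.g. against requires-python = ">=3.8"
assert Version("3.12") in SpecifierSet(">=3.8")
```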
Binary file modified results/imgs/calibration-curves-base-and-instr.numeric-prompt.pdf
Binary file modified results/imgs/score-distribution.multiple-choice-prompt.pdf
Binary file modified results/imgs/score-distribution.numeric-prompt.pdf
Binary file modified results/imgs/score_bias.Asian_v_Black_score_bias.numeric-prompt.pdf
Binary file modified results/imgs/under_over_score.ACSIncome.multiple-choice-prompt.pdf
Binary file modified results/imgs/under_over_score.ACSIncome.numeric-prompt.pdf
(The remaining changed binary files are not shown in this view.)
Changed file: LaTeX results table (multiple-choice prompting; file name not shown in this view)
@@ -3,19 +3,20 @@
& ece (mult. choice) & brier score loss (mult. choice) & roc auc (mult. choice) & accuracy (mult. choice) & ece & brier score loss & roc auc & accuracy \\
Model & & & & & & & & \\
\midrule
- GPT 4o mini (it) & 0.24 & 0.24 & \cellcolor{cyan!17.6} 0.85 & 0.74 & \cellcolor{cyan!25.0} 0.05 & \cellcolor{cyan!25.0} 0.16 & \cellcolor{cyan!20.6} 0.83 & \cellcolor{cyan!24.4} 0.78 \\
- Mixtral 8x22B (it) & 0.21 & \cellcolor{cyan!3.6} 0.22 & \cellcolor{cyan!11.2} 0.85 & \cellcolor{cyan!11.1} 0.76 & 0.11 & \cellcolor{cyan!15.5} 0.17 & \cellcolor{cyan!25.0} 0.84 & \cellcolor{cyan!18.9} 0.77 \\
- Mixtral 8x22B & \cellcolor{cyan!13.2} 0.17 & \cellcolor{cyan!21.6} 0.19 & \cellcolor{cyan!16.5} 0.85 & 0.68 & 0.13 & \cellcolor{cyan!3.6} 0.18 & \cellcolor{cyan!9.6} 0.82 & \cellcolor{cyan!2.4} 0.74 \\
- Llama 3 70B (it) & 0.27 & 0.27 & \cellcolor{cyan!25.0} 0.86 & 0.69 & 0.25 & 0.23 & \cellcolor{cyan!22.1} 0.84 & 0.67 \\
- Llama 3 70B & 0.20 & \cellcolor{cyan!14.9} 0.20 & \cellcolor{cyan!20.8} 0.86 & 0.70 & 0.27 & 0.24 & \cellcolor{cyan!6.7} 0.82 & 0.54 \\
- Mixtral 8x7B (it) & \cellcolor{cyan!16.8} 0.16 & \cellcolor{cyan!25.0} 0.18 & \cellcolor{cyan!19.7} 0.86 & \cellcolor{cyan!25.0} 0.78 & 0.10 & \cellcolor{cyan!15.5} 0.17 & \cellcolor{cyan!22.8} 0.84 & \cellcolor{cyan!12.8} 0.76 \\
- Mixtral 8x7B & \cellcolor{cyan!10.6} 0.17 & \cellcolor{cyan!11.5} 0.21 & 0.83 & 0.65 & \cellcolor{cyan!4.0} 0.07 & \cellcolor{cyan!13.1} 0.17 & \cellcolor{cyan!3.7} 0.81 & \cellcolor{cyan!25.0} 0.78 \\
- Yi 34B (it) & \cellcolor{cyan!1.3} 0.19 & \cellcolor{cyan!18.8} 0.19 & \cellcolor{cyan!19.7} 0.86 & 0.72 & 0.22 & 0.21 & 0.80 & 0.48 \\
- Yi 34B & 0.25 & \cellcolor{cyan!2.5} 0.22 & \cellcolor{cyan!13.3} 0.85 & 0.62 & 0.15 & 0.19 & \cellcolor{cyan!17.7} 0.83 & 0.61 \\
- Llama 3 8B (it) & 0.32 & 0.30 & \cellcolor{cyan!13.3} 0.85 & 0.62 & 0.23 & 0.23 & \cellcolor{cyan!0.1} 0.81 & 0.67 \\
- Llama 3 8B & 0.25 & 0.26 & 0.81 & \cellcolor{orange!20.8} 0.38 & 0.14 & 0.24 & 0.63 & \cellcolor{orange!3.7} 0.40 \\
- Mistral 7B (it) & 0.21 & \cellcolor{cyan!4.7} 0.22 & 0.83 & \cellcolor{cyan!16.5} 0.77 & 0.16 & 0.19 & \cellcolor{cyan!14.7} 0.83 & 0.70 \\
- Gemma 7B (it) & \cellcolor{orange!14.7} 0.61 & \cellcolor{orange!4.7} 0.59 & \cellcolor{cyan!3.8} 0.84 & \cellcolor{orange!25.0} 0.37 & 0.33 & 0.30 & 0.78 & 0.42 \\
+ GPT 4o (it) & \cellcolor{cyan!3.9} 0.18 & \cellcolor{cyan!20.5} 0.19 & \cellcolor{cyan!25.0} 0.87 & \cellcolor{cyan!25.0} 0.80 & 0.08 & \cellcolor{cyan!25.0} 0.16 & \cellcolor{cyan!25.0} 0.85 & \cellcolor{cyan!25.0} 0.78 \\
+ GPT 4o mini (it) & 0.24 & 0.24 & \cellcolor{cyan!7.7} 0.85 & 0.74 & \cellcolor{cyan!25.0} 0.05 & \cellcolor{cyan!22.6} 0.16 & \cellcolor{cyan!10.9} 0.83 & \cellcolor{cyan!22.6} 0.78 \\
+ Mixtral 8x22B (it) & 0.21 & \cellcolor{cyan!3.6} 0.22 & \cellcolor{cyan!1.6} 0.85 & \cellcolor{cyan!3.9} 0.76 & 0.11 & \cellcolor{cyan!13.2} 0.17 & \cellcolor{cyan!15.1} 0.84 & \cellcolor{cyan!17.1} 0.77 \\
+ Mixtral 8x22B & \cellcolor{cyan!13.2} 0.17 & \cellcolor{cyan!21.6} 0.19 & \cellcolor{cyan!6.7} 0.85 & 0.68 & 0.13 & \cellcolor{cyan!1.4} 0.18 & \cellcolor{cyan!0.4} 0.82 & \cellcolor{cyan!0.8} 0.74 \\
+ Llama 3 70B (it) & 0.27 & 0.27 & \cellcolor{cyan!14.8} 0.86 & 0.69 & 0.25 & 0.23 & \cellcolor{cyan!12.3} 0.84 & 0.67 \\
+ Llama 3 70B & 0.20 & \cellcolor{cyan!14.9} 0.20 & \cellcolor{cyan!10.8} 0.86 & 0.70 & 0.27 & 0.24 & 0.82 & 0.54 \\
+ Mixtral 8x7B (it) & \cellcolor{cyan!16.8} 0.16 & \cellcolor{cyan!25.0} 0.18 & \cellcolor{cyan!9.8} 0.86 & \cellcolor{cyan!17.4} 0.78 & 0.10 & \cellcolor{cyan!13.2} 0.17 & \cellcolor{cyan!13.0} 0.84 & \cellcolor{cyan!11.1} 0.76 \\
+ Mixtral 8x7B & \cellcolor{cyan!10.6} 0.17 & \cellcolor{cyan!11.5} 0.21 & 0.83 & 0.65 & \cellcolor{cyan!4.0} 0.07 & \cellcolor{cyan!10.8} 0.17 & 0.81 & \cellcolor{cyan!23.2} 0.78 \\
+ Yi 34B (it) & \cellcolor{cyan!1.3} 0.19 & \cellcolor{cyan!18.8} 0.19 & \cellcolor{cyan!9.8} 0.86 & 0.72 & 0.22 & 0.21 & 0.80 & 0.48 \\
+ Yi 34B & 0.25 & \cellcolor{cyan!2.5} 0.22 & \cellcolor{cyan!3.7} 0.85 & 0.62 & 0.15 & 0.19 & \cellcolor{cyan!8.1} 0.83 & 0.61 \\
+ Llama 3 8B (it) & 0.32 & 0.30 & \cellcolor{cyan!3.7} 0.85 & 0.62 & 0.23 & 0.23 & 0.81 & 0.67 \\
+ Llama 3 8B & 0.25 & 0.26 & 0.81 & \cellcolor{orange!20.9} 0.38 & 0.14 & 0.24 & 0.63 & \cellcolor{orange!3.8} 0.40 \\
+ Mistral 7B (it) & 0.21 & \cellcolor{cyan!4.7} 0.22 & 0.83 & \cellcolor{cyan!9.2} 0.77 & 0.16 & 0.19 & \cellcolor{cyan!5.3} 0.83 & 0.70 \\
+ Gemma 7B (it) & \cellcolor{orange!14.7} 0.61 & \cellcolor{orange!4.7} 0.59 & 0.84 & \cellcolor{orange!25.0} 0.37 & 0.33 & 0.30 & 0.78 & 0.42 \\
Mistral 7B & 0.20 & 0.23 & 0.80 & 0.73 & \cellcolor{orange!15.7} 0.36 & 0.32 & 0.75 & 0.49 \\
Gemma 7B & 0.24 & 0.27 & 0.76 & \cellcolor{orange!25.0} 0.37 & 0.15 & 0.20 & 0.80 & 0.73 \\
Gemma 2B (it) & \cellcolor{orange!25.0} 0.63 & \cellcolor{orange!25.0} 0.63 & 0.73 & \cellcolor{orange!25.0} 0.37 & 0.28 & 0.31 & \cellcolor{orange!25.0} 0.50 & \cellcolor{orange!25.0} 0.37 \\
Changed file: LaTeX results table (numeric prompting; file name not shown in this view)
@@ -3,19 +3,20 @@
& ece (num) & brier score loss (num) & roc auc (num) & accuracy (num) & ece & brier score loss & roc auc & accuracy \\
Model & & & & & & & & \\
\midrule
- GPT 4o mini (it) & \cellcolor{cyan!25.0} 0.05 & \cellcolor{cyan!25.0} 0.16 & \cellcolor{cyan!20.6} 0.83 & \cellcolor{cyan!24.4} 0.78 & 0.24 & 0.24 & \cellcolor{cyan!17.6} 0.85 & 0.74 \\
- Mixtral 8x22B (it) & 0.11 & \cellcolor{cyan!15.5} 0.17 & \cellcolor{cyan!25.0} 0.84 & \cellcolor{cyan!18.9} 0.77 & 0.21 & \cellcolor{cyan!3.6} 0.22 & \cellcolor{cyan!11.2} 0.85 & \cellcolor{cyan!11.1} 0.76 \\
- Mixtral 8x22B & 0.13 & \cellcolor{cyan!3.6} 0.18 & \cellcolor{cyan!9.6} 0.82 & \cellcolor{cyan!2.4} 0.74 & \cellcolor{cyan!13.2} 0.17 & \cellcolor{cyan!21.6} 0.19 & \cellcolor{cyan!16.5} 0.85 & 0.68 \\
- Llama 3 70B (it) & 0.25 & 0.23 & \cellcolor{cyan!22.1} 0.84 & 0.67 & 0.27 & 0.27 & \cellcolor{cyan!25.0} 0.86 & 0.69 \\
- Llama 3 70B & 0.27 & 0.24 & \cellcolor{cyan!6.7} 0.82 & 0.54 & 0.20 & \cellcolor{cyan!14.9} 0.20 & \cellcolor{cyan!20.8} 0.86 & 0.70 \\
- Mixtral 8x7B (it) & 0.10 & \cellcolor{cyan!15.5} 0.17 & \cellcolor{cyan!22.8} 0.84 & \cellcolor{cyan!12.8} 0.76 & \cellcolor{cyan!16.8} 0.16 & \cellcolor{cyan!25.0} 0.18 & \cellcolor{cyan!19.7} 0.86 & \cellcolor{cyan!25.0} 0.78 \\
- Mixtral 8x7B & \cellcolor{cyan!4.0} 0.07 & \cellcolor{cyan!13.1} 0.17 & \cellcolor{cyan!3.7} 0.81 & \cellcolor{cyan!25.0} 0.78 & \cellcolor{cyan!10.6} 0.17 & \cellcolor{cyan!11.5} 0.21 & 0.83 & 0.65 \\
- Yi 34B (it) & 0.22 & 0.21 & 0.80 & 0.48 & \cellcolor{cyan!1.3} 0.19 & \cellcolor{cyan!18.8} 0.19 & \cellcolor{cyan!19.7} 0.86 & 0.72 \\
- Yi 34B & 0.15 & 0.19 & \cellcolor{cyan!17.7} 0.83 & 0.61 & 0.25 & \cellcolor{cyan!2.5} 0.22 & \cellcolor{cyan!13.3} 0.85 & 0.62 \\
- Llama 3 8B (it) & 0.23 & 0.23 & \cellcolor{cyan!0.1} 0.81 & 0.67 & 0.32 & 0.30 & \cellcolor{cyan!13.3} 0.85 & 0.62 \\
- Llama 3 8B & 0.14 & 0.24 & 0.63 & \cellcolor{orange!3.7} 0.40 & 0.25 & 0.26 & 0.81 & \cellcolor{orange!20.8} 0.38 \\
- Mistral 7B (it) & 0.16 & 0.19 & \cellcolor{cyan!14.7} 0.83 & 0.70 & 0.21 & \cellcolor{cyan!4.7} 0.22 & 0.83 & \cellcolor{cyan!16.5} 0.77 \\
- Gemma 7B (it) & 0.33 & 0.30 & 0.78 & 0.42 & \cellcolor{orange!14.7} 0.61 & \cellcolor{orange!4.7} 0.59 & \cellcolor{cyan!3.8} 0.84 & \cellcolor{orange!25.0} 0.37 \\
+ GPT 4o (it) & 0.08 & \cellcolor{cyan!25.0} 0.16 & \cellcolor{cyan!25.0} 0.85 & \cellcolor{cyan!25.0} 0.78 & \cellcolor{cyan!3.9} 0.18 & \cellcolor{cyan!20.5} 0.19 & \cellcolor{cyan!25.0} 0.87 & \cellcolor{cyan!25.0} 0.80 \\
+ GPT 4o mini (it) & \cellcolor{cyan!25.0} 0.05 & \cellcolor{cyan!22.6} 0.16 & \cellcolor{cyan!10.9} 0.83 & \cellcolor{cyan!22.6} 0.78 & 0.24 & 0.24 & \cellcolor{cyan!7.7} 0.85 & 0.74 \\
+ Mixtral 8x22B (it) & 0.11 & \cellcolor{cyan!13.2} 0.17 & \cellcolor{cyan!15.1} 0.84 & \cellcolor{cyan!17.1} 0.77 & 0.21 & \cellcolor{cyan!3.6} 0.22 & \cellcolor{cyan!1.6} 0.85 & \cellcolor{cyan!3.9} 0.76 \\
+ Mixtral 8x22B & 0.13 & \cellcolor{cyan!1.4} 0.18 & \cellcolor{cyan!0.4} 0.82 & \cellcolor{cyan!0.8} 0.74 & \cellcolor{cyan!13.2} 0.17 & \cellcolor{cyan!21.6} 0.19 & \cellcolor{cyan!6.7} 0.85 & 0.68 \\
+ Llama 3 70B (it) & 0.25 & 0.23 & \cellcolor{cyan!12.3} 0.84 & 0.67 & 0.27 & 0.27 & \cellcolor{cyan!14.8} 0.86 & 0.69 \\
+ Llama 3 70B & 0.27 & 0.24 & 0.82 & 0.54 & 0.20 & \cellcolor{cyan!14.9} 0.20 & \cellcolor{cyan!10.8} 0.86 & 0.70 \\
+ Mixtral 8x7B (it) & 0.10 & \cellcolor{cyan!13.2} 0.17 & \cellcolor{cyan!13.0} 0.84 & \cellcolor{cyan!11.1} 0.76 & \cellcolor{cyan!16.8} 0.16 & \cellcolor{cyan!25.0} 0.18 & \cellcolor{cyan!9.8} 0.86 & \cellcolor{cyan!17.4} 0.78 \\
+ Mixtral 8x7B & \cellcolor{cyan!4.0} 0.07 & \cellcolor{cyan!10.8} 0.17 & 0.81 & \cellcolor{cyan!23.2} 0.78 & \cellcolor{cyan!10.6} 0.17 & \cellcolor{cyan!11.5} 0.21 & 0.83 & 0.65 \\
+ Yi 34B (it) & 0.22 & 0.21 & 0.80 & 0.48 & \cellcolor{cyan!1.3} 0.19 & \cellcolor{cyan!18.8} 0.19 & \cellcolor{cyan!9.8} 0.86 & 0.72 \\
+ Yi 34B & 0.15 & 0.19 & \cellcolor{cyan!8.1} 0.83 & 0.61 & 0.25 & \cellcolor{cyan!2.5} 0.22 & \cellcolor{cyan!3.7} 0.85 & 0.62 \\
+ Llama 3 8B (it) & 0.23 & 0.23 & 0.81 & 0.67 & 0.32 & 0.30 & \cellcolor{cyan!3.7} 0.85 & 0.62 \\
+ Llama 3 8B & 0.14 & 0.24 & 0.63 & \cellcolor{orange!3.8} 0.40 & 0.25 & 0.26 & 0.81 & \cellcolor{orange!20.9} 0.38 \\
+ Mistral 7B (it) & 0.16 & 0.19 & \cellcolor{cyan!5.3} 0.83 & 0.70 & 0.21 & \cellcolor{cyan!4.7} 0.22 & 0.83 & \cellcolor{cyan!9.2} 0.77 \\
+ Gemma 7B (it) & 0.33 & 0.30 & 0.78 & 0.42 & \cellcolor{orange!14.7} 0.61 & \cellcolor{orange!4.7} 0.59 & 0.84 & \cellcolor{orange!25.0} 0.37 \\
Mistral 7B & \cellcolor{orange!15.7} 0.36 & 0.32 & 0.75 & 0.49 & 0.20 & 0.23 & 0.80 & 0.73 \\
Gemma 7B & 0.15 & 0.20 & 0.80 & 0.73 & 0.24 & 0.27 & 0.76 & \cellcolor{orange!25.0} 0.37 \\
Gemma 2B (it) & 0.28 & 0.31 & \cellcolor{orange!25.0} 0.50 & \cellcolor{orange!25.0} 0.37 & \cellcolor{orange!25.0} 0.63 & \cellcolor{orange!25.0} 0.63 & 0.73 & \cellcolor{orange!25.0} 0.37 \\
