Skip to content

Commit

Permalink
minor code clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Jun 6, 2024
1 parent 09eaddf commit ed987af
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
6 changes: 5 additions & 1 deletion folktexts/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,10 @@ def run(self, fit_threshold: int | False = False) -> float:
s_test = self.dataset.get_sensitive_attribute_data().loc[y_test.index]

# Get LLM risk-estimate predictions for each row in the test set
test_predictions_save_path = self._get_predictions_save_path("test")
self._y_test_scores = self.llm_clf.predict_proba(
data=X_test,
predictions_save_path=self._get_predictions_save_path("test"),
predictions_save_path=test_predictions_save_path,
labels=y_test,  # used only to save alongside predictions on disk
)

Expand All @@ -199,6 +200,9 @@ def run(self, fit_threshold: int | False = False) -> float:
model_name=self.llm_clf.model_name,
)

# Record the path where test predictions were saved
self._results["predictions_path"] = test_predictions_save_path.as_posix()

# Log main results
msg = (
f"\n** Test results **\n"
Expand Down
2 changes: 1 addition & 1 deletion folktexts/cli/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def launch_experiment_job(exp: Experiment):

# Concurrency limits:
# > each job uses this amount of resources out of a pool of 10k
"concurrency_limits": "user.llm_clf:500", # 20 jobs in parallel
"concurrency_limits": "user.folktexts:100", # 100 jobs in parallel

"+MaxRunningPrice": MAX_RUNNING_PRICE,
"+RunningPriceExceededAction": classad.quote("restart"),
Expand Down
16 changes: 6 additions & 10 deletions folktexts/cli/launch_acs_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
from pprint import pprint

from folktexts._io import load_json, save_json
from folktexts._utils import get_current_date
from folktexts.llm_utils import get_model_folder_path, get_model_size_B

from .experiments import Experiment, launch_experiment_job

# All ACS prediction tasks
ACS_TASKS = (
"ACSIncome",
# "ACSEmployment", # TODO: get other ACS tasks running
# "ACSEmployment", # TODO: run on other ACS tasks
# "ACSMobility",
# "ACSTravelTime",
# "ACSPublicCoverage",
Expand All @@ -26,14 +25,13 @@
# Useful paths #
################
ROOT_DIR = Path("/fast/groups/sf")
# ROOT_DIR = Path("/fast/acruz")
# ROOT_DIR = Path("~").expanduser().resolve() # on local machine

# ACS data directory
ACS_DATA_DIR = ROOT_DIR / "data"

# Directory to save results in (make sure it exists)
RESULTS_DIR = ROOT_DIR / "folktexts-results" / get_current_date()
RESULTS_DIR = ROOT_DIR / "folktexts-results"
RESULTS_DIR.mkdir(exist_ok=True, parents=False)

# Models save directory
Expand All @@ -49,14 +47,12 @@
BATCH_SIZE = 30
CONTEXT_SIZE = 500
CORRECT_ORDER_BIAS = True
FIT_THRESHOLD = 100

VERBOSE = True

JOB_CPUS = 4
JOB_MEMORY_GB = 60
# JOB_BID = 50
JOB_BID = 505
JOB_BID = 250

# LLMs to evaluate
LLM_MODELS = [
Expand All @@ -75,8 +71,8 @@
# # ** Large models **
"01-ai/Yi-34B",
"01-ai/Yi-34B-Chat",
# "mistralai/Mixtral-8x7B-v0.1",
# "mistralai/Mixtral-8x7B-Instruct-v0.1",
"mistralai/Mixtral-8x7B-v0.1",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
"meta-llama/Meta-Llama-3-70B",
"meta-llama/Meta-Llama-3-70B-Instruct",
# "mistralai/Mixtral-8x22B-v0.1",
Expand Down Expand Up @@ -115,7 +111,7 @@ def make_llm_as_clf_experiment(
experiment_kwargs.setdefault("batch_size", math.ceil(BATCH_SIZE / n_shots))
experiment_kwargs.setdefault("context_size", CONTEXT_SIZE * n_shots)
experiment_kwargs.setdefault("data_dir", ACS_DATA_DIR.as_posix())
experiment_kwargs.setdefault("fit_threshold", FIT_THRESHOLD)
# experiment_kwargs.setdefault("fit_threshold", FIT_THRESHOLD)

# Define experiment
exp = Experiment(
Expand Down

0 comments on commit ed987af

Please sign in to comment.