
Commit

update functions for loading arc-agi dataset
tanganke committed Nov 26, 2024
1 parent 0c3b521 commit a3cda5c
Showing 4 changed files with 216 additions and 32 deletions.
7 changes: 6 additions & 1 deletion fusion_bench/dataset/arc_agi/__init__.py
@@ -1 +1,6 @@
from .arc_agi import load_tokenized_arc_agi_dataset
from .arc_agi import (
load_tokenized_arc_agi_dataset_for_ttt,
load_tokenized_arc_agi_dataset,
process_task,
process_task_for_ttt,
)
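
As a rough usage sketch of the expanded public API (the tokenizer checkpoint below is only illustrative; both loaders accept the split and max_num_tasks arguments shown in the diff that follows):

from transformers import AutoTokenizer

from fusion_bench.dataset.arc_agi import (
    load_tokenized_arc_agi_dataset,
    load_tokenized_arc_agi_dataset_for_ttt,
)

# Any tokenizer with a chat template works; this checkpoint is illustrative.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Plain tokenized dataset; returns a single Dataset when a split is given.
train_set = load_tokenized_arc_agi_dataset(tokenizer, split="train", max_num_tasks=8)

# Test-time-training variant, which additionally augments each task (see preprocess.py).
ttt_set = load_tokenized_arc_agi_dataset_for_ttt(tokenizer, split="train", max_num_tasks=8)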
170 changes: 150 additions & 20 deletions fusion_bench/dataset/arc_agi/arc_agi.py
@@ -13,7 +13,7 @@
from typing_extensions import TYPE_CHECKING

from .arc import Example, Task
from .preprocess import get_augmenters, process_task
from .preprocess import get_augmenters, process_task_for_ttt, process_task

if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@@ -65,7 +65,7 @@ def _join_list(lists: List[List[Any]]) -> List[Any]:
return ans


def _to_task(
def _to_tasks(
train_data: List[Dict[str, Any]],
test_data: List[Dict[str, Any]],
name: str,
@@ -87,7 +87,7 @@ def _to_task(
return tasks


def _tokenizer_tasks(
def tokenizer_tasks_for_ttt(
tasks: List[Task],
tokenizer: "PreTrainedTokenizer",
use_data_augmentation: bool = True,
@@ -106,7 +106,7 @@ def _tokenizer_tasks(

formatter = _get_formatter("new")
processor = functools.partial(
process_task,
process_task_for_ttt,
augmenters=augmenters_to_apply,
formatter=formatter,
tokenizer=tokenizer,
@@ -133,7 +133,34 @@ def _tokenizer_tasks(
return dataset


def load_tokenized_arc_agi_dataset(
def tokenizer_tasks(
tasks: List[Task],
tokenizer: "PreTrainedTokenizer",
):
formatter = _get_formatter("new")
processor = functools.partial(
process_task, formatter=formatter, tokenizer=tokenizer
)

# with Pool(multiprocessing.cpu_count()) as p:
# data = p.map(processor, tasks)
data = _join_list(
[
processor(task)
for task in tqdm(
tasks,
desc="Processing tasks",
dynamic_ncols=True,
leave=False,
disable=not rank_zero_only.rank == 0,
)
]
)
dataset = Dataset.from_list(data)
return dataset


def load_tokenized_arc_agi_dataset_for_ttt(
tokenizer: Optional["PreTrainedTokenizer"],
path: str = "dataartist/arc-agi",
split: Optional[str] = None,
@@ -144,10 +171,7 @@ def load_tokenized_arc_agi_dataset(
max_num_tasks: Optional[int] = None,
):
# regularize split
if split.lower() == "train":
split = "training"
if split.lower() == "test":
split = "evaluation"
split = split.lower() if split is not None else split

# load cached dataset if available
if cache_path is not None and fusion_bench.utils.path.path_is_dir_and_not_empty(
@@ -164,27 +188,30 @@ def load_tokenized_arc_agi_dataset(
), "Cached dataset not found. Need tokenizer to process the raw data."

# load raw dataset
datasets = load_dataset(path, split=split)
datasets = load_dataset(path)
datasets = DatasetDict(
{"train": datasets["training"], "test": datasets["evaluation"]}
)
if split is None:
converted_datasets = {
converted_datasets: Dict[str, List[Task]] = {
"train": _join_list(
[
_to_task(
_to_tasks(
task["train"],
task["test"],
task["id"],
)
for task in datasets["training"]
for task in datasets["train"]
]
),
"test": _join_list(
[
_to_task(
_to_tasks(
task["train"],
task["test"],
task["id"],
)
for task in datasets["evaluation"]
for task in datasets["test"]
]
),
}
@@ -195,7 +222,7 @@ def load_tokenized_arc_agi_dataset(
for split in converted_datasets
}
converted_datasets = {
split: _tokenizer_tasks(
split: tokenizer_tasks_for_ttt(
converted_datasets[split],
tokenizer,
use_data_augmentation,
@@ -210,25 +237,128 @@ def _tokenizer_tasks(
)
}
converted_datasets = DatasetDict(converted_datasets)
else:
else: # split is not None
converted_datasets = _join_list(
[
_to_task(
_to_tasks(
task["train"],
task["test"],
task["id"],
)
for task in datasets
for task in datasets[split]
]
)
if max_num_tasks is not None:
# limit the number of tasks, useful for debugging
converted_datasets = converted_datasets[:max_num_tasks]
converted_datasets = _tokenizer_tasks(
converted_datasets = tokenizer_tasks_for_ttt(
converted_datasets, tokenizer, use_data_augmentation, permute_n, seed
)

if cache_path is not None and rank_zero_only.rank == 0:
os.makedirs(cache_path, exist_ok=True)
converted_datasets.save_to_disk(cache_path)
return converted_datasets


def load_tokenized_arc_agi_dataset(
tokenizer: Optional["PreTrainedTokenizer"],
path: str = "dataartist/arc-agi",
split: Optional[str] = None,
cache_path: Optional[str] = None,
max_num_tasks: Optional[int] = None,
):
"""
Loads and tokenizes the ARC-AGI dataset.
Args:
tokenizer (Optional[PreTrainedTokenizer]): The tokenizer to use for tokenizing the dataset.
path (str, optional): The path to the dataset. Defaults to "dataartist/arc-agi".
split (Optional[str], optional): The dataset split to load (e.g., "train", "test"). Defaults to None.
cache_path (Optional[str], optional): The path to cache the processed dataset. Defaults to None.
max_num_tasks (Optional[int], optional): The maximum number of tasks to load. Useful for debugging. Defaults to None.
Returns:
DatasetDict or Dataset: The tokenized dataset, either as a DatasetDict if split is None, or as a Dataset if a specific split is specified.
"""
# regularize split
split = split.lower() if split is not None else split

# load cached dataset if available
if cache_path is not None and fusion_bench.utils.path.path_is_dir_and_not_empty(
cache_path
):
datasets = load_from_disk(cache_path)
if split is None and split in datasets.column_names:
return datasets[split]
else:
return datasets
else:
assert (
tokenizer is not None
), "Cached dataset not found. Need tokenizer to process the raw data."

# load raw dataset
datasets = load_dataset(path)
datasets = DatasetDict(
{"train": datasets["training"], "test": datasets["evaluation"]}
)
if split is None:
converted_datasets: Dict[str, List[Task]] = {
"train": _join_list(
[
_to_tasks(
task["train"],
task["test"],
task["id"],
)
for task in datasets["train"]
]
),
"test": _join_list(
[
_to_tasks(
task["train"],
task["test"],
task["id"],
)
for task in datasets["test"]
]
),
}
if max_num_tasks is not None:
# limit the number of tasks, useful for debugging
converted_datasets = {
split: converted_datasets[split][:max_num_tasks]
for split in converted_datasets
}
converted_datasets = {
split: tokenizer_tasks(converted_datasets[split], tokenizer)
for split in tqdm(
converted_datasets,
desc="Processing splits",
dynamic_ncols=True,
disable=not rank_zero_only.rank == 0,
)
}
converted_datasets = DatasetDict(converted_datasets)
else: # split is not None
converted_datasets = _join_list(
[
_to_tasks(
task["train"],
task["test"],
task["id"],
)
for task in datasets[split]
]
)
if max_num_tasks is not None:
# limit the number of tasks, useful for debugging
converted_datasets = converted_datasets[:max_num_tasks]
converted_datasets = tokenizer_tasks(converted_datasets, tokenizer)

if cache_path is not None and rank_zero_only.rank == 0:
os.makedirs(cache_path, exist_ok=True)
converted_datasets.save_to_disk(cache_path)
return converted_datasets
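
For reference, both loaders follow the same caching pattern visible above: when cache_path points at a non-empty directory, the processed dataset is read back with load_from_disk and the tokenizer may be None; otherwise the raw dataset is downloaded, converted, tokenized, and written to cache_path on rank 0. A hedged sketch (the cache directory is an arbitrary example):

from transformers import AutoTokenizer

from fusion_bench.dataset.arc_agi import load_tokenized_arc_agi_dataset

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative
cache_dir = "outputs/arc_agi_tokenized"  # arbitrary example path

# First call: tokenizes the raw data and writes the cache directory.
datasets = load_tokenized_arc_agi_dataset(tokenizer, cache_path=cache_dir)

# Subsequent calls: the cache is reused, so no tokenizer is required.
datasets = load_tokenized_arc_agi_dataset(None, cache_path=cache_dir)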
58 changes: 50 additions & 8 deletions fusion_bench/dataset/arc_agi/preprocess.py
@@ -112,19 +112,41 @@ def get_augmenters(
def format_and_filter(
formatter: MessageRepresenter,
tokenizer: "PreTrainedTokenizer",
task,
task: Task,
):
"""
Formats and filters a task for model input.
Args:
formatter (MessageRepresenter): The formatter to encode the task.
tokenizer (PreTrainedTokenizer): The tokenizer to tokenize the conversation.
task: The task to be formatted and filtered.
Returns:
Dict[str, Any]: A dictionary containing the formatted data with keys:
- "input_ids": The tokenized input IDs.
- "attention_mask": The attention mask for the input IDs.
- "labels": The labels for the input IDs.
- "task_id": The task ID.
- "num_prompt_tokens": The number of prompt tokens.
- "num_output_tokens": The number of output tokens.
"""
task_id = task.name
task = formatter.encode(task)
conversation = task[0] + [task[1]]
assert conversation[-1]["role"] == "assistant", "Last message should be assistant"
prompt_tokens = tokenizer.apply_chat_template(
conversation[:-1], tokenize=True, add_generation_prompt=True
)
output_tokens = tokenizer.encode(conversation[-1]["content"] + tokenizer.eos_token)
generation_tokens = tokenizer.apply_chat_template(conversation, tokenize=True)
output_tokens = generation_tokens[len(prompt_tokens) :]
data = {
"input_ids": prompt_tokens + output_tokens,
"attention_mask": [1] * len(prompt_tokens) + [1] * len(output_tokens),
"labels": [-100] * len(prompt_tokens) + output_tokens,
"task_id": task_id,
"num_prompt_tokens": len(prompt_tokens),
"num_output_tokens": len(output_tokens),
}
return data

@@ -136,6 +158,19 @@ def get_test_time_train_data(
permute_n: int = 1,
seed: int = 0,
) -> List[Task]:
"""
Generates augmented training data for test-time training.
Args:
original_task (Task): The original task containing training examples.
augmenters (List[Augmenter]): A list of augmenters to apply to the tasks.
n (int, optional): The number of examples to leave out for testing. Defaults to 1.
permute_n (int, optional): The number of times to permute the augmented tasks. Defaults to 1.
seed (int, optional): The random seed for reproducibility. Defaults to 0.
Returns:
List[Task]: A list of augmented tasks.
"""
rng = np.random.RandomState(seed)
train_examples = original_task.train_examples.copy()
initial_tasks = []
@@ -150,7 +185,7 @@ def get_test_time_train_data(
for comb in combs:
initial_tasks.append(
Task(
name="",
name=original_task.name,
train_examples=[examples[j] for j in comb],
test_example=examples[i],
)
@@ -183,7 +218,6 @@ def get_test_time_train_data(
color_and_permute_augmented_tasks.append(new_task)

augmented_tasks = color_and_permute_augmented_tasks + augmented_tasks

augmented_tasks = list(set(augmented_tasks))

return augmented_tasks
@@ -193,13 +227,12 @@ def get_formatted_data(
task: Task,
augmenters: List[Augmenter],
formatter: MessageRepresenter,
tokenizer,
tokenizer: "PreTrainedTokenizer",
leave_n: int = 1,
permute_n: int = 1,
seed: int = 0,
max_tokens: int = 8192,
):

train_data = get_test_time_train_data(
task, augmenters, n=leave_n, permute_n=permute_n, seed=seed
)
@@ -213,11 +246,11 @@ def get_formatted_data(
return formatted_data


def process_task(
def process_task_for_ttt(
task: Task,
augmenters: List[Augmenter],
formatter: MessageRepresenter,
tokenizer,
tokenizer: "PreTrainedTokenizer",
permute_n: int = 1,
Nmax: int = 250,
seed: int = 0,
@@ -254,3 +287,12 @@ def process_task(
train = train[:Nmax]

return train


def process_task(
task: Task,
formatter: MessageRepresenter,
tokenizer: "PreTrainedTokenizer",
):
formatted = format_and_filter(formatter, tokenizer, task)
return [formatted]
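
The main behavioral change in format_and_filter above is how the assistant tokens are obtained: instead of re-encoding the assistant message plus the EOS token on its own, the full conversation is run through apply_chat_template once and the prompt-length prefix is sliced off, so the supervised span matches the chat template exactly. A standalone sketch of that pattern, with an illustrative tokenizer and a toy conversation:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative

conversation = [
    {"role": "user", "content": "1 2\n3 4"},
    {"role": "assistant", "content": "4 3\n2 1"},
]

# Prompt tokens end where the model is expected to start generating.
prompt_tokens = tokenizer.apply_chat_template(
    conversation[:-1], tokenize=True, add_generation_prompt=True
)

# Tokenize the whole conversation once, then slice off the prompt prefix.
generation_tokens = tokenizer.apply_chat_template(conversation, tokenize=True)
output_tokens = generation_tokens[len(prompt_tokens):]

# Supervise only the assistant span; prompt positions get the -100 ignore index.
input_ids = prompt_tokens + output_tokens
labels = [-100] * len(prompt_tokens) + output_tokens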
