Issue with finetuning with Corda #2317

Open

sirluk opened this issue Jan 9, 2025 · 11 comments

sirluk (Contributor) commented Jan 9, 2025

System Info

peft master branch (commit 8d3039b)

Who can help?

@BenjaminBossan @5eqn

Hi, I would like to try out CorDA for my fine-tuning use case, but looking at the loss curves something seems to be going wrong, so I wanted to verify that I implemented CorDA correctly.

This is the relevant code snippet from my script. I have a tokenized dataset, which I wrap in a dataloader with batch size 1 and pass to the preprocess_corda function. Once preprocess_corda is done computing, I can just instantiate the PEFT model as usual with the required config, correct?

I would greatly appreciate some feedback.
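For clarity, this is the essence of what my script does (a condensed sketch of the full reproduction below; corda_loader stands in for my dataloader and target modules are omitted):

import torch
from peft import LoraConfig, get_peft_model
from peft.tuners.lora.config import CordaConfig
from peft.tuners.lora.corda import preprocess_corda

corda_config = CordaConfig(corda_method="ipm")
lora_config = LoraConfig(init_lora_weights="corda", r=16, corda_config=corda_config)

def run_corda_model():
    # forward passes over the calibration dataloader (batch size 1), no gradients needed
    for batch in corda_loader:
        with torch.no_grad():
            model(batch["input_ids"].to(model.device))

# 1) collect covariance statistics, 2) then wrap the model as usual
preprocess_corda(model, lora_config, run_model=run_corda_model)
model = get_peft_model(model, lora_config)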

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

# imports
import torch
from functools import partial
from datasets import load_dataset, interleave_datasets, DatasetDict
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import get_peft_model, LoraConfig
from peft.tuners.lora.corda import preprocess_corda
from peft.tuners.lora.config import CordaConfig


# functions
def _tokenize_fn(prompts, completions, tokenizer):
    prompt_tokens = tokenizer(prompts, add_special_tokens=False)["input_ids"]
    input_tokens = tokenizer([x+y for x, y in zip(prompts, completions)], add_special_tokens=False)["input_ids"]
    input_tokens = [[tokenizer.bos_token_id]+x+[tokenizer.eos_token_id] for x in input_tokens]
    prompt_length = [len(x)+1 for x in prompt_tokens] # +1 for the bos token
    input_length = [len(x) for x in input_tokens]
    return {"input_ids": input_tokens, "prompt_length": prompt_length, "input_length": input_length}

class _TokenizerPromptSource:

    def __init__(self, tokenizer_path, space_after_prompt=True):
        
        # import promptsource
        from promptsource_custom.templates import DatasetTemplates
        self.dataset_templates = DatasetTemplates
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.space_after_prompt = space_after_prompt

    def __call__(self, examples):
        examples = [dict(zip(examples.keys(), e)) for e in zip(*examples.values())]
        prompts, completions = zip(*[self.prompt.apply(e) for e in examples])
        if self.space_after_prompt:
            prompts = [p + " " for p in prompts]
        return _tokenize_fn(prompts, completions, self.tokenizer)
    
class TokenizerWinogrande(_TokenizerPromptSource):

    def __init__(self, tokenizer_path):
        super().__init__(tokenizer_path)
        self.prompt = self.dataset_templates("winogrande", "winogrande_xl")["multiple_choice_simple"]

class TokenizerHellaswag(_TokenizerPromptSource):

    def __init__(self, tokenizer_path):
        super().__init__(tokenizer_path)
        self.prompt = self.dataset_templates("hellaswag")["multiple_choice_simple"]

class TokenizerArcChallenge(_TokenizerPromptSource):

    def __init__(self, tokenizer_path):
        super().__init__(tokenizer_path)
        self.prompt = self.dataset_templates("ai2_arc", "ARC-Challenge")["multiple_choice_simple"]

class TokenizerArcEasy(_TokenizerPromptSource):

    def __init__(self, tokenizer_path):
        super().__init__(tokenizer_path)
        self.prompt = self.dataset_templates("ai2_arc", "ARC-Easy")["multiple_choice_simple"]

class TokenizerPIQA(_TokenizerPromptSource):

    def __init__(self, tokenizer_path):
        super().__init__(tokenizer_path)
        self.prompt = self.dataset_templates("piqa")["multiple_choice_simple"]

class TokenizerSIQA(_TokenizerPromptSource):

    def __init__(self, tokenizer_path):
        super().__init__(tokenizer_path)
        self.prompt = self.dataset_templates("social_i_qa")["multiple_choice_simple"]

class TokenizerOpenBookQA(_TokenizerPromptSource):

    def __init__(self, tokenizer_path):
        super().__init__(tokenizer_path)
        self.prompt = self.dataset_templates("openbookqa", "main")["multiple_choice_simple"]

class TokenizerBoolQ(_TokenizerPromptSource):

    def __init__(self, tokenizer_path):
        super().__init__(tokenizer_path)
        self.prompt = self.dataset_templates("super_glue", "boolq")["multiple_choice_simple"]

class DataCollator:
    def __init__(self, eos_token_id, max_length = None):
        self.eos_token_id = eos_token_id
        self.max_length = max_length

    def __call__(self, batch):
        batch = {k: [item[k] for item in batch] for k in batch[0]}
        input_lengths = torch.stack(batch["input_length"])
        prompt_lengths = torch.stack(batch["prompt_length"])
        input_ids = torch.nn.utils.rnn.pad_sequence(batch["input_ids"], batch_first=True, padding_value=self.eos_token_id)
        col_indices = torch.arange(input_ids.size(1)).unsqueeze(0)
        attention_mask = col_indices < input_lengths.unsqueeze(1)
        label_mask = torch.logical_or(col_indices < prompt_lengths.unsqueeze(1), ~attention_mask)
        labels = input_ids.masked_fill(label_mask, -100)
        if self.max_length is not None:
            input_ids = input_ids[:, :self.max_length]
            attention_mask = attention_mask[:, :self.max_length]
            labels = labels[:, :self.max_length]
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# constants
CORDA = False
SEED = 0
BATCH_SIZE = 4
NUM_EPOCHS = 1
LEARNING_RATE = 5e-4
GRADIENT_ACCUMULATION_STEPS = 8
MODEL_NAME = "meta-llama/Llama-2-7b-hf"
MODEL_MAX_LENGTH = 1024
QA_DATASETS = [
    "Rowan/hellaswag",
    "allenai/winogrande",
    "allenai/ai2_arc_challenge",
    "allenai/ai2_arc_easy",
    "ybisk/piqa",
    "allenai/social_i_qa",
    "allenai/openbookqa",
    "boolq"
]
LOAD_DATASET_KWARGS = {
    "Rowan/hellaswag": {"path": "Rowan/hellaswag"},
    "allenai/winogrande": {"path": "allenai/winogrande", "name": "winogrande_xl"},
    "allenai/ai2_arc_challenge": {"path": "allenai/ai2_arc", "name": "ARC-Challenge"},
    "allenai/ai2_arc_easy": {"path": "allenai/ai2_arc", "name": "ARC-Easy"},
    "ybisk/piqa": {"path": "ybisk/piqa"},
    "allenai/social_i_qa": {"path": "allenai/social_i_qa"},
    "allenai/openbookqa": {"path": "allenai/openbookqa", "name": "main"},
    "boolq": {"path": "aps/super_glue", "name": "boolq"}
}
TOKENIZE_MAP = {
    "Rowan/hellaswag": TokenizerHellaswag,
    "allenai/winogrande": TokenizerWinogrande,
    "allenai/ai2_arc_challenge": TokenizerArcChallenge,
    "allenai/ai2_arc_easy": TokenizerArcEasy,
    "ybisk/piqa": TokenizerPIQA,
    "allenai/social_i_qa": TokenizerSIQA,
    "allenai/openbookqa": TokenizerOpenBookQA,
    "boolq": TokenizerBoolQ
}

# load model
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model.cuda()

# load dataset
datasets = []
for dataset_name in QA_DATASETS:
    tokenizer_cls = TOKENIZE_MAP[dataset_name]
    tokenizer_wrapper = tokenizer_cls(tokenizer_path=MODEL_NAME)
    load_dataset_kwargs = LOAD_DATASET_KWARGS[dataset_name]
    if load_dataset_kwargs["path"] is not None:
        load_dataset_kwargs["path"] = load_dataset_kwargs["path"]
    datasets.append(load_dataset(**load_dataset_kwargs, trust_remote_code=True))
    datasets[-1] = datasets[-1].map(tokenizer_wrapper, batched=True, remove_columns=datasets[-1]["train"].column_names)
    datasets[-1].set_format(type="torch")
    datasets[-1] = datasets[-1].shuffle(seed=SEED)
all_splits = set([n for ds in datasets for n in ds.keys()])
datasets = DatasetDict({split: interleave_datasets([ds[split] for ds in datasets if split in ds]) for split in all_splits})
data_collator = DataCollator(tokenizer_wrapper.tokenizer.eos_token_id, MODEL_MAX_LENGTH)

# get peft config
target_modules = [n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]


if CORDA:
    corda_config = CordaConfig(corda_method="ipm")
    lora_config = LoraConfig(
        init_lora_weights="corda",
        target_modules=target_modules,
        lora_alpha=1,
        lora_dropout=0,
        r=16,
        corda_config=corda_config
    )
    sampled_dataset = datasets["train"].select(list(range(256)))
    corda_data_loader = torch.utils.data.DataLoader(
        sampled_dataset,
        batch_size=1,
        collate_fn=data_collator,
        shuffle=True
    )
    def run_model(model, corda_data_loader):
        for batch in corda_data_loader:
            input_ids = batch["input_ids"]
            input_ids = input_ids.to(model.device)
            with torch.no_grad():
                model(input_ids)
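    # preprocess_corda will call run_model with no arguments, so the model and dataloader are bound via partial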
    run_model = partial(run_model, model=model, corda_data_loader=corda_data_loader)
    preprocess_corda(model, lora_config, run_model=run_model)
else:
    lora_config = LoraConfig(
        init_lora_weights=True,
        target_modules=target_modules,
        lora_alpha=1,
        lora_dropout=0,
        r=16
    )

model = get_peft_model(model, lora_config)

training_args = TrainingArguments(
    output_dir="output",
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    seed=SEED,
    learning_rate=LEARNING_RATE,
    remove_unused_columns=False,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    report_to=[]
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=datasets["train"],
    eval_dataset=datasets["validation"] if "validation" in datasets else None,
    data_collator=data_collator
)
trainer.train()

Expected behavior

I tried to follow the CorDA example in the documentation and thought it should work like this.

BenjaminBossan (Member) commented:

Thanks for trying out CorDA and reporting your problem. To diagnose further, could you please provide a bit more info:

  1. What exactly is wrong with the loss curves? Do you have a plot of the loss with vs without CorDA?
  2. What model and dataset are you using? Could you perhaps share a full example?

sirluk (Contributor, Author) commented Jan 9, 2025

@BenjaminBossan sorry for being a bit imprecise. I have now figured out that it seems to be just one of my seeds that causes problems for CorDA:
[loss curve plot: one seed diverges from the others]

I updated my initial comment with a full training script. The script also requires the attached folder promptsource_custom.zip, which contains the fine-tuning templates.

Since the performance overall seems to be a bit lacking for my dataset, I was also wondering whether the authors have suggestions for which CorDA hyperparameters I could change, e.g. increasing the number of samples for preprocess_corda. The average performance evaluated on these benchmarks is a bit worse than vanilla LoRA (0.79 vs. 0.89):

  • hellaswag
  • winogrande
  • arc_challenge
  • arc_easy
  • piqa
  • siqa
  • openbookqa
  • boolq

5eqn (Contributor) commented Jan 10, 2025

I've noticed that the per-question length in your sampled_dataset is much shorter than what we experimented with. For example, a question in allenai/ai2_arc has about 60 tokens, but our official example uses a sequence length of 2048. That example uses cross attention as well, but I am able to demonstrate the superiority of CorDA without cross attention with a sequence length of about 600 and sample_count = 256.

If the length per question is 10x shorter than in our experiment and the sample count remains the same, the covariance matrix might be low-rank. Since we apply damping for both KPM and IPM, the damping might be too large, causing the scales of the LoRA matrices $A$ and $B$ to be too unbalanced.

You can try increasing the sample count to about 2560, based on the average length per question, or using cross attention (like our official example) if needed; a back-of-envelope version of this calculation is sketched below.
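As a rough illustration of the scaling suggested above (the token counts are the approximate figures from this thread, not measured values):

# keep the total number of calibration tokens roughly comparable to the reference
# experiment without cross attention (256 samples at ~600 tokens each)
reference_total_tokens = 256 * 600
avg_tokens_per_sample = 60  # e.g. a typical allenai/ai2_arc question
suggested_sample_count = reference_total_tokens // avg_tokens_per_sample
print(suggested_sample_count)  # 2560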

sirluk (Contributor, Author) commented Jan 10, 2025

@5eqn thanks a lot for your answer! I will try increasing the sample count. One additional question regarding that: in my script, am I defining the run_model function correctly with a dataloader that processes the dataset in batches of 1?

5eqn (Contributor) commented Jan 10, 2025

One additional question regarding that: in my script, am I defining the run_model function correctly with a dataloader that processes the dataset in batches of 1?

Yes, it's correct.

iboing (Contributor) commented Jan 10, 2025

Hi @sirluk, it seems that only one seed (the grey curve) causes the problem. It occurs due to numerical error, because the process involves SVD and matrix inversion. One suggestion for detecting such issues is to compare the initial performance (e.g. on wiki or some benchmark before fine-tuning) of the pre-trained model and the PEFT model initialized with CorDA; they should be as close as possible. A minimal sketch of such a check is shown below.
Regarding the performance, you could try increasing the sample count as suggested by @5eqn. Besides, you may try adjusting the composition of the sampling dataset. It is possible that sampling from one particular sub-dataset, or from a few sub-datasets, works better than uniformly sampling from all sub-datasets.
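Concretely, something like this (eval_loader and the 20-batch budget are placeholders, not from this thread):

import torch

@torch.no_grad()
def avg_loss(m, loader, n_batches=20):
    m.eval()
    device = next(m.parameters()).device
    losses = []
    for i, batch in enumerate(loader):
        if i >= n_batches:
            break
        batch = {k: v.to(device) for k, v in batch.items()}
        losses.append(m(**batch).loss.item())
    return sum(losses) / len(losses)

base_loss = avg_loss(model, eval_loader)         # pre-trained model, before get_peft_model
peft_model = get_peft_model(model, lora_config)  # CorDA-initialized adapters
corda_loss = avg_loss(peft_model, eval_loader)   # should be very close to base_loss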

sirluk (Contributor, Author) commented Jan 11, 2025

@iboing thanks for your additional suggestions! Increasing the sample count unfortunately didn't work, as I was running into OOM errors (even when only increasing it from 256 to 512). But I will try assembling the dataset differently.

sirluk (Contributor, Author) commented Jan 12, 2025

@5eqn @iboing
Something else I noticed during CorDA fine-tuning is that some matrices seem to be stored in the LoRA layers for preprocess_corda. These matrices are then not removed, even though I think we don't need them for fine-tuning, and they lead to memory errors for me.
For example, each LoRA module has the attributes covariance_matrix and eigens. With plain LoRA my batch size fit on an A100, but with CorDA I needed 100 GB of memory for the same batch size. A rough way to see how much memory these leftover buffers take is sketched below.
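For illustration, a minimal sketch (not part of my training script) to estimate how much memory the leftover covariance matrices alone occupy; eigens holds additional decomposition results that are not counted here:

import torch

total_bytes = 0
for name, module in model.named_modules():
    cov = getattr(module, "covariance_matrix", None)
    if torch.is_tensor(cov):
        total_bytes += cov.numel() * cov.element_size()
print(f"leftover covariance matrices: {total_bytes / 1e9:.2f} GB")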

5eqn (Contributor) commented Jan 13, 2025

Hi @sirluk, thanks for pointing this out! I've fixed this in our development branch by removing the unnecessary mean and std fields and pruning temporary fields when they are no longer needed. I have run all CorDA tests locally. You can try the development branch, or fix this by manually removing the redundant fields after get_peft_model:

for name, module in peft_model.base_model.named_modules():
    for field in ("sample_count", "covariance_matrix", "mean", "std", "corda_method", "rank", "eigens"):
        if hasattr(module, field):
            delattr(module, field)
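If the buffers were sitting on the GPU, it may also be worth releasing the cached memory once they are deleted, e.g.:

import gc
import torch

gc.collect()              # drop the now-unreferenced CorDA buffers
torch.cuda.empty_cache()  # return the freed blocks to the CUDA allocator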

BenjaminBossan (Member) commented Jan 13, 2025

Thanks for this fruitful discussion. Regarding this optimization, could a PR be created for PEFT too? It would probably also help address the memory issue with CorDA that I reported earlier.

I think it would also be great if the example docs could be updated to contain the insights discussed above regarding the size of the sample dataset.

5eqn (Contributor) commented Jan 13, 2025

Thanks for this fruitful discussion. Regarding this optimization, could a PR be created for PEFT too?

Sure, I'll open the PR after adding the documentation.
