[P1] Transitioning from peft to pyreft for Classification Approach #92

Open
SaBay89 opened this issue May 28, 2024 · 2 comments

SaBay89 commented May 28, 2024

I'm encountering problems when training a classification model. While the peft library works without errors and initiates training, using pyreft instead results in the following error:

ValueError: You should supply an encoding or a list of encodings to this method that includes input_ids, but you provided ['labels']

Features of the input for training: ['input_ids', 'attention_mask', 'labels']
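One plausible reading of this error, offered as an assumption that the thread does not confirm: the Hugging Face Trainer drops dataset columns whose names do not appear in the signature of the model's forward method (the `remove_unused_columns` behaviour), so a wrapped model whose forward takes differently named arguments could lose `input_ids` and `attention_mask` before the collator pads the batch. The sketch below imitates that filtering in plain Python. `plain_forward`, `wrapped_forward`, and `kept_columns` are hypothetical names for illustration, not library code.

```python
import inspect

def plain_forward(input_ids=None, attention_mask=None, labels=None):
    """Stand-in for a classifier's forward: takes the tokenizer fields by name."""

def wrapped_forward(base=None, unit_locations=None, labels=None):
    """Stand-in for a wrapper whose forward takes the inputs bundled as `base`."""

def kept_columns(forward_fn, columns, label_names=("labels",)):
    """Keep only columns that match a forward() parameter or a label name."""
    allowed = set(inspect.signature(forward_fn).parameters) | set(label_names)
    return [name for name in columns if name in allowed]

columns = ["input_ids", "attention_mask", "labels"]
print(kept_columns(plain_forward, columns))    # ['input_ids', 'attention_mask', 'labels']
print(kept_columns(wrapped_forward, columns))  # ['labels']
```

If this is the cause, setting `remove_unused_columns=False` in `TrainingArguments` would keep every column. Whether that alone is sufficient for pyreft is not established in this thread.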

  • Loading the Model
    from transformers import AutoModelForSequenceClassification

    model = AutoModelForSequenceClassification.from_pretrained(
        "google/gemma-2b-it",
        num_labels=len(classes),
        id2label=id2class,
        label2id=class2id,
        # quantization_config=bnb_config,
        device_map={"": 0},
        problem_type="multi_label_classification",
    )

  • peft wrapping
    from peft import LoraConfig, get_peft_model

    modules = ['k_proj', 'o_proj', 'up_proj', 'q_proj', 'down_proj', 'gate_proj', 'v_proj']
    lora_config = LoraConfig(
        r=64,
        lora_alpha=32,
        target_modules=modules,
        lora_dropout=0.05,
        bias="none",
        task_type="SEQ_CLS",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

--> trainable params: 78,794,752 || all params: 2,585,315,328 || trainable%: 3.0478

  • Start Training
    trainer = Trainer(
        model,
        args=training_args,
        train_dataset=tokenized_dataset_train,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()

This works without any errors and the training starts.

The same with pyreft

  • with pyreft wrapping (after loading the model)
    import pyreft

    reft_config = pyreft.ReftConfig(representations={
        "component": "block_output",
        # Duplicate key: Python keeps only this last "component" value.
        "component": "model.layers[0].output",
        "low_rank_dimension": 4,
        "intervention": pyreft.LoreftIntervention(
            embed_dim=model.config.hidden_size, low_rank_dimension=4),
    })
    model = pyreft.get_reft_model(model, reft_config)
    model.set_device("cuda")
    model.print_trainable_parameters()

--> trainable intervention params: 16,388 || trainable model params: 0
--> model params: 2,506,520,576 || trainable%: 0.0006538147006218711

The error also arises when I use the pyreft.ReftTrainerForSequenceClassification class instead of the standard Trainer class. While peft works seamlessly with the Trainer class, pyreft seems to require a different approach to data preparation.
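As a rough illustration of what a different approach to data preparation could mean: pyreft's own examples attach an extra field, `intervention_locations`, to every example so the trainer knows which token positions to intervene on. The sketch below assumes a single intervention applied at the last non-padded token and a nesting of [intervention][position]. The helper name `add_last_position_location` is hypothetical, and the exact shape pyreft expects should be checked against its examples.

```python
def add_last_position_location(example):
    """Return a copy of `example` with an `intervention_locations` field.

    The field name is assumed from pyreft's examples. The location is the
    index of the last token whose attention mask is 1.
    """
    mask = example["attention_mask"]
    last_position = max(i for i, m in enumerate(mask) if m == 1)
    out = dict(example)
    out["intervention_locations"] = [[last_position]]  # one intervention, one position
    return out

example = {
    "input_ids": [2, 17, 42, 0, 0],
    "attention_mask": [1, 1, 1, 0, 0],
    "labels": [0.0, 1.0],
}
prepared = add_last_position_location(example)
print(prepared["intervention_locations"])  # [[2]]
```

The original example is left unmodified, so the helper can be mapped over a dataset without side effects.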

I apologize for any misrepresentation that may arise from directly comparing peft to pyreft. My intention is not to diminish the capabilities of pyreft but rather to highlight the specific challenges I've encountered while transitioning from a familiar tool to a new one.

frankaging changed the title from "Transitioning from peft to pyreft for Classification Approach" to "[P1] Transitioning from peft to pyreft for Classification Approach" May 28, 2024
frankaging self-assigned this May 28, 2024
frankaging added the question (Further information is requested) label May 28, 2024
frankaging (Collaborator) commented

@SaBay89 hi, thanks for raising the question.

please take a look at our GLUE experiment setup for sequence classification tasks:
https://github.com/stanfordnlp/pyreft/blob/main/examples/loreft/train.py#L363

what are the fields of your training dataset (tokenized_dataset_train)? does it contain all the needed fields?
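A quick way to answer this is to compare the dataset's columns against the fields the trainer will read. The sketch below is plain Python. The `REQUIRED_FIELDS` tuple is an assumption (in particular, `intervention_locations` is taken from pyreft's examples, not from this thread), and with a Hugging Face dataset the column list would come from `tokenized_dataset_train.column_names`.

```python
# Assumed set of fields; adjust to whatever the trainer in use actually reads.
REQUIRED_FIELDS = ("input_ids", "attention_mask", "labels", "intervention_locations")

def missing_fields(column_names, required=REQUIRED_FIELDS):
    """List the required fields that are absent, in the order of `required`."""
    present = set(column_names)
    return [name for name in required if name not in present]

# Columns as reported in the original post.
reported_columns = ["input_ids", "attention_mask", "labels"]
print(missing_fields(reported_columns))  # ['intervention_locations']
```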

xinlanz commented Jun 11, 2024

May I ask whether you have solved this problem? I ran into the same problem.
