Adding Lora implementation for nn.Conv1d #2333

Open: wants to merge 5 commits into main

CCLDArjun

Resolves #2241

My comment in #2241 shows that the shapes match in the Enformer model.

I'm unsure how to test it further other than by running it in a training loop.
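If a full training loop feels like overkill, two cheap properties can be checked directly on a LoRA-wrapped Conv1d: with lora_B initialized to zero, the wrapped layer must reproduce the base layer exactly, and folding the low-rank update into the base weight must give the same output as the unmerged forward pass. The sketch below is standalone PyTorch, not PEFT's implementation or the code in this PR; LoraConv1dSketch, r, and alpha are hypothetical names, and it assumes groups=1 and the default zero padding mode.

```python
import torch
import torch.nn.functional as F
from torch import nn


class LoraConv1dSketch(nn.Module):
    """Hypothetical minimal LoRA wrapper for nn.Conv1d (groups=1, zero padding mode only)."""

    def __init__(self, base: nn.Conv1d, r: int = 4, alpha: int = 8):
        super().__init__()
        self.base = base
        # A mirrors the base kernel size, stride, padding and dilation but outputs only r channels
        self.lora_A = nn.Conv1d(
            base.in_channels, r, base.kernel_size,
            stride=base.stride, padding=base.padding, dilation=base.dilation, bias=False,
        )
        # B is a 1x1 convolution from r channels back to out_channels
        self.lora_B = nn.Conv1d(r, base.out_channels, 1, bias=False)
        nn.init.zeros_(self.lora_B.weight)  # zero init: the adapter starts as a no-op
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + self.lora_B(self.lora_A(x)) * self.scaling

    def delta_weight(self):
        # delta_W[o, i, k] = scaling * sum_r B[o, r, 0] * A[r, i, k]
        a = self.lora_A.weight.flatten(1)  # (r, in_channels * kernel_size)
        b = self.lora_B.weight.squeeze(-1)  # (out_channels, r)
        return (b @ a).view_as(self.base.weight) * self.scaling


torch.manual_seed(0)
base = nn.Conv1d(3, 5, kernel_size=3, padding=1)
lora = LoraConv1dSketch(base)
x = torch.randn(2, 3, 10)  # (batch, channels, length)

# Check 1: with B zero-initialized, the wrapped layer reproduces the base layer
assert torch.allclose(lora(x), base(x))

# Check 2: after B changes (random values stand in for training), merging must match the unmerged output
nn.init.normal_(lora.lora_B.weight)
merged = F.conv1d(
    x, base.weight + lora.delta_weight(), base.bias,
    stride=base.stride, padding=base.padding, dilation=base.dilation,
)
assert torch.allclose(lora(x), merged, atol=1e-5)
```

The merge check is the more informative of the two: the reshape in delta_weight only lines up with the base weight when lora_A mirrors the base kernel and lora_B is a 1x1 convolution.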

Member

@BenjaminBossan left a comment

Thanks for adding a Conv1d implementation for LoRA. In general this looks good; I have a few small comments, please check them. Please also run make style to satisfy the linter.

Before merging, however, let's ensure that the code works correctly by adding some tests. We already have a "test factory" for the different LoRA layer types, so this is a matter of adding an entry for Conv1d. To do this, look at this code:

```python
import torch
from torch import nn


class ModelMha(nn.Module):
    def __init__(self):
        super().__init__()
        self.mha = nn.MultiheadAttention(10, 2)  # embed_dim=10, num_heads=2
        self.lin0 = nn.Linear(10, 2)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = X.float()
        # self-attention: X serves as query, key, and value
        X, _ = self.mha(X, X, X)
        X = self.lin0(X)
        X = self.sm(X)
        return X


class MockTransformerWrapper:
    """Mock class to behave like a transformers model.

    This is needed because the tests initialize the model by calling transformers_class.from_pretrained.
    """

    @classmethod
    def from_pretrained(cls, model_id, torch_dtype=None):
        # set the seed so that from_pretrained always returns the same model
        torch.manual_seed(0)

        if torch_dtype is None:
            torch_dtype = torch.float32

        # MLP, ModelEmbConv1D, ModelConv2D, etc. are the other test models,
        # defined elsewhere in the test module (not shown in this excerpt)
        if model_id == "MLP":
            return MLP().to(torch_dtype)
        if model_id == "EmbConv1D":
            return ModelEmbConv1D().to(torch_dtype)
        if model_id == "Conv2d":
            return ModelConv2D().to(torch_dtype)
        if model_id == "Conv3d":
            return ModelConv3D().to(torch_dtype)
        if model_id == "MLP_LayerNorm":
            return MLP_LayerNorm().to(torch_dtype)
        if model_id == "MLP2":
            return MLP2().to(torch_dtype)
        if model_id == "Conv2d2":
            return ModelConv2D2().to(torch_dtype)
        if model_id == "MHA":
            return ModelMha().to(torch_dtype)

        raise ValueError(f"model_id {model_id} not implemented")
```

What we need is a model similar to ModelMha but using Conv1d instead. The input size should be 10. The from_pretrained method should then be updated to dispatch to the new model.
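A minimal sketch of what such a model could look like, assuming the test inputs are 2D batches with 10 features per row that get reshaped into the (batch, channels, length) layout nn.Conv1d expects. ModelConv1D, the "Conv1d" model id, and the layer names conv1d and lin0 are hypothetical choices that follow the naming of the existing test models; the torch.arange(90).view(9, 10) input at the end is just an example batch.

```python
import torch
from torch import nn


class ModelConv1D(nn.Module):
    """Hypothetical Conv1d test model, modeled on ModelMha."""

    def __init__(self):
        super().__init__()
        # 1 input channel, 1 output channel, kernel size 2: length 10 -> length 9
        self.conv1d = nn.Conv1d(1, 1, 2)
        self.relu = nn.ReLU()
        self.flat = nn.Flatten()
        self.lin0 = nn.Linear(9, 2)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        # (batch, 10) -> (batch, channels=1, length=10), the layout nn.Conv1d expects
        X = X.float().reshape(-1, 1, 10)
        X = self.conv1d(X)  # (batch, 1, 9)
        X = self.relu(X)
        X = self.flat(X)  # (batch, 9)
        X = self.lin0(X)  # (batch, 2)
        X = self.sm(X)
        return X


# Matching (hypothetical) branch for MockTransformerWrapper.from_pretrained:
#     if model_id == "Conv1d":
#         return ModelConv1D().to(torch_dtype)

out = ModelConv1D()(torch.arange(90).view(9, 10))
print(out.shape)  # torch.Size([9, 2])
```

Flattening after the convolution keeps lin0 an ordinary nn.Linear, so a test row can target the Conv1d layer alone or both layers together.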

After this, it's only a matter of adding a row to the test cases, following this format:

```python
("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}),
("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}),
("Conv2d 1 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}),
("Conv2d 2 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"], "use_dora": True}),
("Conv3d 1 LoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d"]}),
("Conv3d 2 LoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d", "lin0"]}),
("Conv3d 1 LoRA with DoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d"], "use_dora": True}),
("Conv3d 2 LoRA with DoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d", "lin0"], "use_dora": True}),
```

I hope this makes sense. LMK if you have questions.

Review threads on src/peft/tuners/lora/layer.py (resolved)
@CCLDArjun
Author

make style was editing 22 different files, so I ran ruff directly on the files I changed. I think it should be good now.

@BenjaminBossan
Member

I think the most likely explanation is that you were using a different ruff version from the one used on CI, which would also explain why CI still fails. Could you please make sure you use the same version, ruff-0.6.9?
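To catch this kind of version mismatch before pushing, a small guard can compare the locally installed ruff against the CI version. This is a minimal sketch using the standard library's importlib.metadata, assuming ruff was installed as a Python package (e.g. via pip); check_ruff_version is a hypothetical helper, not part of the repository.

```python
from importlib.metadata import PackageNotFoundError, version

# CI version mentioned in the review comment; update this if CI changes
EXPECTED_RUFF_VERSION = "0.6.9"


def check_ruff_version(expected: str = EXPECTED_RUFF_VERSION) -> str:
    """Hypothetical helper: return the installed ruff version, or raise if it differs from CI."""
    try:
        installed = version("ruff")
    except PackageNotFoundError:
        raise RuntimeError(f"ruff is not installed; run: pip install ruff=={expected}") from None
    if installed != expected:
        raise RuntimeError(
            f"ruff {installed} is installed but CI uses {expected}; run: pip install ruff=={expected}"
        )
    return installed
```

Calling check_ruff_version() before make style turns a surprisingly large reformatting diff into a one-line error that names the fix.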

@CCLDArjun
Author

@BenjaminBossan Yep, 0.6.9 works much better; make style is happy now.

Successfully merging this pull request may close these issues.

Request for adding the lora implementation for Conv1d rather than transformers.utils.Conv1d
2 participants