ENH Allow disabling input dtype casting for LoRA (#2353)
Provides the disable_input_dtype_casting context manager to prevent the
input dtype from being cast during the forward call of a PEFT layer.

Normally, the dtype of the weight and input need to match, which is why
the dtype is cast. However, in certain circumstances, this is handled
by forward hooks, e.g. when using layerwise casting in diffusers. In
that case, PEFT casting the dtype interferes with the layerwise casting,
which is why the option to disable it is given.

Right now, this only supports LoRA. LoKr and LoHa don't cast the input
dtype anyway. Therefore, the PEFT methods most relevant for diffusers
are covered.
BenjaminBossan authored Feb 4, 2025
1 parent 2825774 commit db9dd3f
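For orientation, a minimal usage sketch of the new context manager (the model checkpoint and LoRA config below are illustrative, not part of this commit):

```python
import torch
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model
from peft.helpers import disable_input_dtype_casting

# Illustrative setup: any model with LoRA adapters works the same way.
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM"))

input_ids = torch.tensor([[1, 2, 3]])

# Inside this context, LoRA layers leave the input dtype untouched; this is useful
# when forward hooks (e.g. layerwise casting in diffusers) already manage dtypes.
with disable_input_dtype_casting(model):
    output = model(input_ids=input_ids)
```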
Showing 14 changed files with 181 additions and 30 deletions.
7 changes: 6 additions & 1 deletion docs/source/package_reference/helpers.md
@@ -14,4 +14,9 @@ A collection of helper functions for PEFT.
## Temporarily Rescaling Adapter Scale in LoraLayer Modules

[[autodoc]] helpers.rescale_adapter_scale
- all
- all

## Context manager to disable input dtype casting in the `forward` method of LoRA layers

[[autodoc]] helpers.disable_input_dtype_casting
- all
43 changes: 42 additions & 1 deletion src/peft/helpers.py
@@ -18,8 +18,10 @@
from functools import update_wrapper
from types import MethodType

from torch import nn

from .peft_model import PeftConfig, PeftModel
from .tuners.lora.layer import LoraLayer
from .tuners.lora import LoraLayer


def update_forward_signature(model: PeftModel) -> None:
@@ -209,3 +211,42 @@ def rescale_adapter_scale(model, multiplier):
# restore original scaling values after exiting the context
for module, scaling in original_scaling.items():
module.scaling = scaling


@contextmanager
def disable_input_dtype_casting(model: nn.Module, active: bool = True):
"""
Context manager that disables input dtype casting to the dtype of the weight. Currently, this specifically
works for LoRA.

Parameters:
    model (nn.Module):
        The model containing PEFT modules whose input dtype casting is to be adjusted.
    active (bool):
        Whether the context manager is active (default) or inactive.
"""
# Additional info: Normally, the dtype of the weight and input need to match, which is why the dtype is cast.
# However, in certain circumstances, this is handled by forward hooks, e.g. when using layerwise casting in
# diffusers. In that case, PEFT casting the dtype interferes with the layerwise casting, which is why the option to
# disable it is given.
if not active:
yield
return

original_values = {}
for name, module in model.named_modules():
if not isinstance(module, LoraLayer):
continue
original_values[name] = module.cast_input_dtype_enabled
module.cast_input_dtype_enabled = False

try:
yield
finally:
for name, module in model.named_modules():
if not isinstance(module, LoraLayer):
continue
if name in original_values:
module.cast_input_dtype_enabled = original_values[name]
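To illustrate the scenario described in the comment above, here is a minimal, self-contained sketch (mirroring the new test added in this commit) in which external pre-/post-forward hooks manage the weight dtype, the way layerwise casting in diffusers does; the toy MLP and hook names are illustrative, not part of the commit:

```python
import torch
from torch import nn

from peft import LoraConfig, get_peft_model
from peft.helpers import disable_input_dtype_casting


def upcast_params_pre_hook(module, args):
    # Emulate layerwise casting: compute in float32 although storage is float16.
    for param in module.parameters(recurse=False):
        param.data = param.data.float()
    return args


def downcast_params_post_hook(module, args, output):
    for param in module.parameters(recurse=False):
        param.data = param.data.half()
    return output


base = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 2))
model = get_peft_model(base, LoraConfig(target_modules=["0"])).to(dtype=torch.float16)

# Register the casting hooks on every submodule that directly holds parameters.
for module in model.modules():
    if next(module.parameters(recurse=False), None) is not None:
        module.register_forward_pre_hook(upcast_params_pre_hook)
        module.register_forward_hook(downcast_params_post_hook)

x = torch.randn(4, 10)  # float32 input
with disable_input_dtype_casting(model):
    # PEFT leaves the float32 input alone, so it matches the weights upcast by the hooks;
    # without the context manager, PEFT would cast x to float16 and the matmul would fail.
    output = model(x)
```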
4 changes: 1 addition & 3 deletions src/peft/tuners/adalora/bnb.py
@@ -129,9 +129,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.dtype)

output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T
if requires_conversion:
3 changes: 1 addition & 2 deletions src/peft/tuners/adalora/gptq.py
@@ -55,8 +55,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
if x.dtype != torch.float32:
x = x.float()
x = self._cast_input_dtype(x, torch.float32)

output = (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum
# TODO: here, the dtype conversion is applied on the *whole expression*,
2 changes: 1 addition & 1 deletion src/peft/tuners/adalora/layer.py
@@ -180,7 +180,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
scaling = self.scaling[active_adapter]
ranknum = self.ranknum[active_adapter] + 1e-5

x = x.to(lora_A.dtype)
x = self._cast_input_dtype(x, lora_A.dtype)
result += (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum

return result
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/aqlm.py
@@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor):
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/awq.py
@@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor):
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
12 changes: 4 additions & 8 deletions src/peft/tuners/lora/bnb.py
@@ -204,9 +204,7 @@ def _mixed_batch_forward(
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
# layer output
@@ -243,9 +241,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
output = lora_B(lora_A(dropout(x))) * scaling
@@ -470,7 +466,7 @@ def _mixed_batch_forward(
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
# layer output
@@ -514,7 +510,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
output = lora_B(lora_A(dropout(x))) * scaling
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/eetq.py
@@ -76,7 +76,7 @@ def forward(self, x: torch.Tensor):
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/gptq.py
@@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor):
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
8 changes: 2 additions & 6 deletions src/peft/tuners/lora/hqq.py
@@ -178,9 +178,7 @@ def _mixed_batch_forward(
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
# layer output
@@ -218,9 +216,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
19 changes: 17 additions & 2 deletions src/peft/tuners/lora/layer.py
@@ -62,6 +62,8 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
self.lora_magnitude_vector = torch.nn.ModuleDict() # for DoRA
self._caches: dict[str, Any] = {}
self.ephemeral_gpu_offload: bool = ephemeral_gpu_offload
# flag to enable/disable casting of input to weight dtype during forward call
self.cast_input_dtype_enabled: bool = True
self.kwargs = kwargs

base_layer = self.get_base_layer()
@@ -492,6 +494,19 @@ def _mixed_batch_forward(

return result

def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""
Cast the dtype of the input of the forward method to the given dtype, if input dtype casting is enabled.

Usually, we want to enable this to align the input dtype with the dtype of the weight, but by setting
layer.cast_input_dtype_enabled = False, this can be disabled if necessary.

Enabling or disabling can be managed via the peft.helpers.disable_input_dtype_casting context manager.
"""
if (not self.cast_input_dtype_enabled) or (x.dtype == dtype):
return x
return x.to(dtype=dtype)


# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# and modified to work with PyTorch FSDP
@@ -703,7 +718,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
@@ -1268,7 +1283,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
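The `cast_input_dtype_enabled` flag introduced in `_cast_input_dtype` above can also be flipped per layer by hand; this is essentially what `disable_input_dtype_casting` does under the hood. A minimal sketch with an illustrative toy model, not part of the commit:

```python
from torch import nn

from peft import LoraConfig, get_peft_model
from peft.tuners.lora.layer import LoraLayer

model = get_peft_model(nn.Sequential(nn.Linear(8, 8)), LoraConfig(target_modules=["0"]))

# Disable input dtype casting on LoRA layers directly (remember the old values
# if you want to restore them later, as the context manager does).
for module in model.modules():
    if isinstance(module, LoraLayer):
        module.cast_input_dtype_enabled = False
```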
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/tp_layer.py
@@ -205,7 +205,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
103 changes: 102 additions & 1 deletion tests/test_helpers.py
@@ -16,11 +16,13 @@
import pytest
import torch
from diffusers import StableDiffusionPipeline
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import LoraConfig, get_peft_model
from peft.helpers import check_if_peft_model, rescale_adapter_scale
from peft.helpers import check_if_peft_model, disable_input_dtype_casting, rescale_adapter_scale
from peft.tuners.lora.layer import LoraLayer
from peft.utils import infer_device


class TestCheckIsPeftModel:
@@ -369,3 +371,102 @@ def test_merging_adapter(self, tokenizer):
logits_merged_scaling = model(**inputs).logits

assert torch.allclose(logits_merged_scaling, logits_unmerged_scaling, atol=1e-4, rtol=1e-4)


class TestDisableInputDtypeCasting:
"""Test the context manager `disable_input_dtype_casting` that temporarily disables input dtype casting
in the model.
The test works as follows:
We create a simple MLP and convert it to a PeftModel. The model dtype is set to float16. Then a pre-foward hook is
added that casts the model parameters to float32. Moreover, a post-forward hook is added that casts the weights
back to float16. The input dtype is float32.
Without the disable_input_dtype_casting context, what would happen is that PEFT detects that the input dtype is
float32 but the weight dtype is float16, so it casts the input to float16. Then the pre-forward hook casts the
weight to float32, which results in a RuntimeError.
With the disable_input_dtype_casting context, the input dtype is left as float32 and there is no error. We also add
a hook to record the dtype of the result from the LoraLayer to ensure that it is indeed float32.
"""

device = infer_device()
dtype_record = []

@torch.no_grad()
def cast_params_to_fp32_pre_hook(self, module, input):
for param in module.parameters(recurse=False):
param.data = param.data.float()
return input

@torch.no_grad()
def cast_params_to_fp16_hook(self, module, input, output):
for param in module.parameters(recurse=False):
param.data = param.data.half()
return output

def record_dtype_hook(self, module, input, output):
self.dtype_record.append(output[0].dtype)

@pytest.fixture
def inputs(self):
return torch.randn(4, 10, device=self.device, dtype=torch.float32)

@pytest.fixture
def base_model(self):
class MLP(nn.Module):
def __init__(self, bias=True):
super().__init__()
self.lin0 = nn.Linear(10, 20, bias=bias)
self.lin1 = nn.Linear(20, 2, bias=bias)
self.sm = nn.LogSoftmax(dim=-1)

def forward(self, X):
X = self.lin0(X)
X = self.lin1(X)
X = self.sm(X)
return X

return MLP()

@pytest.fixture
def model(self, base_model):
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(base_model, config).to(device=self.device, dtype=torch.float16)
# Register hooks on the submodule that holds parameters
for module in model.modules():
if sum(p.numel() for p in module.parameters()) > 0:
module.register_forward_pre_hook(self.cast_params_to_fp32_pre_hook)
module.register_forward_hook(self.cast_params_to_fp16_hook)
if isinstance(module, LoraLayer):
module.register_forward_hook(self.record_dtype_hook)
return model

def test_disable_input_dtype_casting_active(self, model, inputs):
self.dtype_record.clear()
with disable_input_dtype_casting(model, active=True):
model(inputs)
assert self.dtype_record == [torch.float32]

def test_no_disable_input_dtype_casting(self, model, inputs):
msg = r"expected m.*1 and m.*2 to have the same dtype"
with pytest.raises(RuntimeError, match=msg):
model(inputs)

def test_disable_input_dtype_casting_inactive(self, model, inputs):
msg = r"expected m.*1 and m.*2 to have the same dtype"
with pytest.raises(RuntimeError, match=msg):
with disable_input_dtype_casting(model, active=False):
model(inputs)

def test_disable_input_dtype_casting_inactive_after_existing_context(self, model, inputs):
# this is to ensure that when the context is left, we return to the previous behavior
with disable_input_dtype_casting(model, active=True):
model(inputs)

# after the context exited, we're back to the error
msg = r"expected m.*1 and m.*2 to have the same dtype"
with pytest.raises(RuntimeError, match=msg):
model(inputs)
