Hotswap allow different alpha scalings and ranks #2177

Merged
Changes from 2 commits

Commits (29):
355e2ab
[WIP] Allow different alphas for hotswapping
BenjaminBossan Oct 24, 2024
135dfda
Only enable hotswap tests in CI for now
BenjaminBossan Oct 24, 2024
9b7d979
More debugging
BenjaminBossan Oct 24, 2024
3a6d4b5
Merge branch 'main' into hotswap-allow-different-alpha-scalings
BenjaminBossan Oct 24, 2024
088a9d9
More debugging
BenjaminBossan Oct 24, 2024
3da33ba
Take alpha from new config
BenjaminBossan Oct 24, 2024
0bba669
Remove some debug code
BenjaminBossan Oct 24, 2024
c73dafc
Merge branch 'main' into hotswap-allow-different-alpha-scalings
BenjaminBossan Nov 28, 2024
552ecac
Compiled hotswapping works w/ different scalings
BenjaminBossan Nov 28, 2024
2dcd6a1
[WIP] Hotswap with different ranks
BenjaminBossan Nov 28, 2024
620d43e
Make hotswapping with different ranks work
BenjaminBossan Nov 29, 2024
efaf2b0
More tests, docs, fixes
BenjaminBossan Dec 2, 2024
25d7a77
Make style, fix
BenjaminBossan Dec 2, 2024
88bd814
Merge branch 'main' into hotswap-allow-different-alpha-scalings
BenjaminBossan Jan 8, 2025
2a29370
Remove obsolete test
BenjaminBossan Jan 8, 2025
9b132ec
Merge branch 'main' into hotswap-allow-different-alpha-scalings
BenjaminBossan Jan 24, 2025
ddf1bad
Disable test skips for CI (only testing)
BenjaminBossan Jan 24, 2025
3251c6c
Clean up, documentation
BenjaminBossan Jan 28, 2025
9966c10
Reviewer feedback: Refactor _pad_lora_weights
BenjaminBossan Jan 30, 2025
01924ad
Reviewer feedback: Add clarifying comments
BenjaminBossan Jan 30, 2025
66b987e
Fix issue with bias term
BenjaminBossan Jan 31, 2025
d32d23f
Clarify docstring
BenjaminBossan Jan 31, 2025
8b7e602
Reviewer feedback:
BenjaminBossan Feb 3, 2025
d0a1fc5
Reviewer feedback: add clarifying comments
BenjaminBossan Feb 4, 2025
7772e2e
Reviewer feedback: address model.compile()
BenjaminBossan Feb 4, 2025
fe19da5
Revert testing
BenjaminBossan Feb 4, 2025
5bc0dbe
Simplify tests for recompilation
BenjaminBossan Feb 5, 2025
a412f7a
Reviewer feedback: Better handling of torch cache
BenjaminBossan Feb 5, 2025
0943519
Change fixture to clean up dynamo cache after test
BenjaminBossan Feb 5, 2025
242 changes: 144 additions & 98 deletions src/peft/utils/hotswap.py
@@ -71,122 +71,166 @@ def _convert_scalings_to_tensor(model):
)


def _pad_lora_weights(model, target_rank):
def _get_padded_linear(lora_module: torch.nn.Module, target_rank: int, is_lora_A: bool) -> torch.nn.Linear:
"""
Pad LoRA weights in a state dict to a target rank while preserving the original behavior.
Get a new Linear layer for LoRA with padded weights according to the target rank.

Args:
state_dict (dict): The state dict containing LoRA weights
target_rank (int): The target rank to pad to
lora_module (nn.Module):
The LoRA sub-module (e.g. module.lora_A[adapter_name]).
target_rank (int):
The desired rank to pad to.
is_lora_A (bool):
True if this is the LoRA A matrix, False if LoRA B.

Returns:
nn.Linear:
A newly created and padded Linear layer. If the rank already fit, the original layer is returned.
"""
weight = lora_module.weight
# For LoRA A, the "rank dimension" is weight.size(0) (out_features).
# For LoRA B, it is weight.size(1) (in_features).
original_rank = weight.size(0) if is_lora_A else weight.size(1)

# If no padding needed
if original_rank == target_rank:
return lora_module

if original_rank > target_rank:
raise ValueError(
f"Trying to pad the adapter to the target rank {target_rank}, but the original rank is larger "
f"({original_rank}). This is not possible."
)

out_features, in_features = weight.shape

if is_lora_A:
# LoRA A affects out_features
padded = torch.zeros(target_rank, in_features, device=weight.device, dtype=weight.dtype)
padded[:original_rank, :] = weight
new_layer = torch.nn.Linear(in_features, target_rank, bias=lora_module.bias)
else:
# LoRA B affects in_features
padded = torch.zeros(out_features, target_rank, device=weight.device, dtype=weight.dtype)
padded[:, :original_rank] = weight
new_layer = torch.nn.Linear(target_rank, out_features, bias=lora_module.bias)

# Sanity check
if new_layer.weight.shape != padded.shape:
raise ValueError(
"Something went wrong when trying to pad the LoRA Linear weights, the new shape should be "
f"{padded.shape} but {new_layer.weight.shape} was found. Please open an issue on PEFT "
"(https://github.com/huggingface/peft/issues) and report this error."
)

new_layer.weight.data = padded
# Copy bias if present
if lora_module.bias is not None:
new_layer.bias.data = lora_module.bias.data
Comment on lines +136 to +137 (Member):

Could potentially also leverage the copy_() method?

Member (Author):

Hmm, I tried but got:

torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(10, 1, 16)), Parameter(FakeTensor(..., size=(13, 16), requires_grad=True)), None), **{}):
Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu

Not sure why that is, since we take care of the device, but I reverted the change. I think in this specific situation, there is no practical benefit to .copy_ anyway, is there?

Member:

Fair enough. Thanks for looking into it.


Returns: new_state_dict: A new state dict with padded LoRA weights
return new_layer
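
For readers of this diff, the reason zero-padding is safe: padding LoRA A with extra zero rows and LoRA B with extra zero columns leaves the product B @ A, and hence the LoRA delta, unchanged. A minimal sketch with plain tensors (the shapes are made up, not taken from the PR's tests):

```python
import torch

# Hypothetical shapes: a rank-4 adapter padded to a target rank of 8.
in_features, out_features, rank, target_rank = 16, 32, 4, 8

lora_A = torch.randn(rank, in_features)   # maps input -> rank
lora_B = torch.randn(out_features, rank)  # maps rank -> output

# Pad A with zero rows and B with zero columns, mirroring what the padded layers hold.
A_padded = torch.zeros(target_rank, in_features)
A_padded[:rank] = lora_A
B_padded = torch.zeros(out_features, target_rank)
B_padded[:, :rank] = lora_B

# The extra rows/columns only ever multiply zeros, so the delta is unchanged.
assert torch.allclose(lora_B @ lora_A, B_padded @ A_padded)
```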


def _get_padded_conv2d(lora_module: torch.nn.Module, target_rank: int, is_lora_A: bool) -> torch.nn.Conv2d:
"""
for module in model.modules():
if not isinstance(module, (Conv2d, Linear)):
continue
Get a new Conv2d layer for LoRA with padded weights according to the target rank.

is_conv = isinstance(module, Conv2d)
Args:
lora_module (nn.Module):
The LoRA sub-module (e.g. module.lora_A[adapter_name]).
target_rank (int):
The desired rank to pad to.
is_lora_A (bool):
True if this is the LoRA A matrix, False if LoRA B.

Returns:
nn.Conv2d:
A newly created and padded Conv2d layer. If the rank already fit, the original layer is returned.
"""
weight = lora_module.weight
# For Conv2d: [out_channels, in_channels, kernel_height, kernel_width]
out_channels, in_channels, kh, kw = weight.shape
original_rank = out_channels if is_lora_A else in_channels

# LoRA A
for adapter_name, lora_module in module.lora_A.items():
weight = lora_module.weight
original_rank = weight.size(0)
if original_rank == target_rank:
return lora_module

if original_rank == target_rank:
continue
if original_rank > target_rank:
raise ValueError(
f"Trying to pad the adapter to the target rank {target_rank}, but the original rank is larger "
f"({original_rank}). This is not possible."
)

if original_rank > target_rank:
raise ValueError(
f"Trying to pad the adapter to the target rank {target_rank}, but the original rank is larger "
f"({original_rank}), which is not possible. Please choose a target rank that is greater or equal "
"to the largest rank of the adapter."
)
if is_lora_A:
# LoRA A affects out_channels
padded = torch.zeros(target_rank, in_channels, kh, kw, device=weight.device, dtype=weight.dtype)
padded[:out_channels, :, :, :] = weight
new_layer = torch.nn.Conv2d(
in_channels,
target_rank,
kernel_size=lora_module.kernel_size,
stride=lora_module.stride,
padding=lora_module.padding,
bias=lora_module.bias,
groups=lora_module.groups,
)
else:
# LoRA B affects in_channels
padded = torch.zeros(out_channels, target_rank, kh, kw, device=weight.device, dtype=weight.dtype)
padded[:, :in_channels, :, :] = weight
new_layer = torch.nn.Conv2d(
target_rank,
out_channels,
kernel_size=lora_module.kernel_size,
stride=lora_module.stride,
padding=lora_module.padding,
bias=lora_module.bias,
groups=lora_module.groups,
)

if is_conv:
padded = torch.zeros(
target_rank,
weight.size(1),
weight.size(2),
weight.size(3),
device=weight.device,
dtype=weight.dtype,
)
padded[:original_rank, :, :, :] = weight
new_layer = torch.nn.Conv2d(
weight.size(1),
target_rank,
kernel_size=lora_module.kernel_size,
stride=lora_module.stride,
padding=lora_module.padding,
bias=lora_module.bias,
)
else:
padded = torch.zeros(target_rank, weight.size(1), device=weight.device, dtype=weight.dtype)
padded[:original_rank, :] = weight
new_layer = torch.nn.Linear(weight.size(1), target_rank, bias=lora_module.bias)
# Sanity check
if new_layer.weight.shape != padded.shape:
raise ValueError(
"Something went wrong when trying to pad the LoRA weights, the new shape should be "
f"{padded.shape} but {new_layer.weight.shape} was found. Please open an issue on PEFT "
"(https://github.com/huggingface/peft/issues) and report this error."
)

if new_layer.weight.shape != padded.shape:
raise ValueError(
"Something went wrong when trying to pad the LoRA weights, the new shape should be "
f"{padded.shape} but {new_layer.weight.shape} was found. Please open an issue on PEFT "
"(https://github.com/huggingface/peft/issues) and report this error."
)
new_layer.weight.data = padded
# Copy bias if present
if lora_module.bias is not None:
new_layer.bias.data = lora_module.bias.data

new_layer.weight.data = padded
if lora_module.bias:
new_layer.bias.data = lora_module.bias.data
module.lora_A[adapter_name] = new_layer
return new_layer
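
The same invariant holds for the Conv2d case above: zero-padding A's out_channels and B's in_channels does not change the composed output. A small sketch (kernel sizes and shapes are illustrative, assuming the common LoRA convention of a 1x1 conv for B):

```python
import torch

# Hypothetical LoRA Conv2d pair: A uses the host layer's kernel size, B is a 1x1 conv.
in_ch, out_ch, rank, target_rank = 8, 16, 4, 8
x = torch.randn(1, in_ch, 8, 8)

lora_A = torch.nn.Conv2d(in_ch, rank, kernel_size=3, padding=1, bias=False)
lora_B = torch.nn.Conv2d(rank, out_ch, kernel_size=1, bias=False)

# Pad A's out_channels and B's in_channels with zeros up to the target rank.
A_pad = torch.nn.Conv2d(in_ch, target_rank, kernel_size=3, padding=1, bias=False)
B_pad = torch.nn.Conv2d(target_rank, out_ch, kernel_size=1, bias=False)
with torch.no_grad():
    A_pad.weight.zero_()
    A_pad.weight[:rank] = lora_A.weight
    B_pad.weight.zero_()
    B_pad.weight[:, :rank] = lora_B.weight

# The extra all-zero channels contribute nothing, so the composed output is unchanged.
assert torch.allclose(lora_B(lora_A(x)), B_pad(A_pad(x)), atol=1e-6)
```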

# LoRA B
for adapter_name, lora_module in module.lora_B.items():
weight = lora_module.weight
original_rank = weight.size(1)

if original_rank == target_rank:
continue
def _pad_lora_weights(model: torch.nn.Module, target_rank: int) -> None:
"""
Pad LoRA weights in a model to a target rank while preserving the original behavior.

if original_rank > target_rank:
# TODO: is this necessary or can we just continue???
raise ValueError(
f"Trying to pad the adapter to the target rank {target_rank}, but the original rank is larger "
f"({original_rank}), which is not possible. Please choose a target rank that is greater or equal "
"to the largest rank of the adapter."
)
Args:
model (nn.Module): The model containing LoRA modules (with lora_A and lora_B).
target_rank (int): The target rank to pad to.
"""

if is_conv:
padded = torch.zeros(
weight.size(0),
target_rank,
weight.size(2),
weight.size(3),
device=weight.device,
dtype=weight.dtype,
)
padded[:, :original_rank, :, :] = weight
new_layer = torch.nn.Conv2d(
target_rank,
weight.size(0),
kernel_size=lora_module.kernel_size,
stride=lora_module.stride,
padding=lora_module.padding,
bias=lora_module.bias,
)
new_layer.weight.data = padded
else:
padded = torch.zeros(weight.size(0), target_rank, device=weight.device, dtype=weight.dtype)
padded[:, :original_rank] = weight
new_layer = torch.nn.Linear(target_rank, weight.size(0), bias=lora_module.bias)
for module in model.modules():
# Decide which pad function to call based on module type
if isinstance(module, Linear):
pad_fn = _get_padded_linear
elif isinstance(module, Conv2d):
pad_fn = _get_padded_conv2d
else:
# Skip any other module types
continue

if new_layer.weight.shape != padded.shape:
raise ValueError(
"Something went wrong when trying to pad the LoRA weights, the new shape should be "
f"{padded.shape} but {new_layer.weight.shape} was found. Please open an issue on PEFT "
"(https://github.com/huggingface/peft/issues) and report this error."
)
# Pad LoRA A
for adapter_name, lora_A_module in module.lora_A.items():
new_layer = pad_fn(lora_A_module, target_rank=target_rank, is_lora_A=True)
module.lora_A[adapter_name] = new_layer

new_layer.weight.data = padded
if lora_module.bias:
new_layer.bias.data = lora_module.bias.data
# Pad LoRA B
for adapter_name, lora_B_module in module.lora_B.items():
new_layer = pad_fn(lora_B_module, target_rank=target_rank, is_lora_A=False)
module.lora_B[adapter_name] = new_layer


@@ -236,6 +280,7 @@ def prepare_model_for_compiled_hotswap(
return

if not isinstance(config, dict):
# config can be either a PeftConfig, or a dict of PeftConfigs like PeftModel.peft_config
config = {"dummy": config}

for lora_config in config.values():
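
As the new comment notes, config can be a single PeftConfig or a dict of configs such as PeftModel.peft_config. A small sketch of both call forms (a toy model; the target_rank/config keyword names are assumed from this diff rather than copied from the PR's tests):

```python
import torch
from peft import LoraConfig, get_peft_model
from peft.utils.hotswap import prepare_model_for_compiled_hotswap

base = torch.nn.Sequential(torch.nn.Linear(16, 16))
lora_config = LoraConfig(target_modules=["0"], r=8)
model = get_peft_model(base, lora_config)

# Passing a single PeftConfig: it is wrapped into {"dummy": config} internally.
prepare_model_for_compiled_hotswap(model, target_rank=16, config=lora_config)

# Alternatively, pass the dict of configs stored on the PeftModel itself:
# prepare_model_for_compiled_hotswap(model, target_rank=16, config=model.peft_config)
```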
@@ -311,6 +356,7 @@ def hotswap_adapter_from_state_dict(

# actual swapping
for key, new_val in state_dict.items():
# get LoRA parent module name by removing the 'lora_*.<adapter-name>.weight' part
module_name = ".".join(key.split(".")[:-3])
module = model.get_submodule(module_name)
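
To illustrate the new comment above, the parent module name is obtained by dropping the last three key components ("lora_*", adapter name, "weight"); for example, with a made-up key:

```python
key = "base_model.model.decoder.layers.0.self_attn.q_proj.lora_A.default.weight"
module_name = ".".join(key.split(".")[:-3])
print(module_name)  # base_model.model.decoder.layers.0.self_attn.q_proj
```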

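Putting the pieces of this PR together, the intended end-to-end flow is roughly the following. This is a hedged sketch: the model name, adapter paths, and target rank are placeholders, and it assumes the public hotswap_adapter helper alongside the prepare_model_for_compiled_hotswap function modified here.

```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel
from peft.utils.hotswap import hotswap_adapter, prepare_model_for_compiled_hotswap

base_model = AutoModelForCausalLM.from_pretrained("path/to/base-model")
model = PeftModel.from_pretrained(base_model, "path/to/adapter_0")

# Pad all LoRA weights up to the largest rank among the adapters that will be
# swapped in, and convert scalings to tensors, so that adapters with different
# ranks or alphas can be hot-swapped without triggering recompilation.
prepare_model_for_compiled_hotswap(model, target_rank=64)
model = torch.compile(model)

# Later: swap in a different adapter in-place, keeping the compiled graph.
hotswap_adapter(model, "path/to/adapter_1", adapter_name="default")
```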