Hotswap allow different alpha scalings and ranks #2177

Conversation

@BenjaminBossan (Member) commented Oct 24, 2024

See huggingface/diffusers#9453 for context.

Hotswapping of LoRA adapters is already implemented, but when alpha scalings or ranks differ, this triggers recompilation if the model is compiled, which is inefficient. Users can now call prepare_model_for_compiled_hotswap to prevent recompilation in many cases (see the doc update for caveats).
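
For readers of this thread, here is a rough usage sketch of the intended flow, loosely based on the doc snippet quoted later in the review; base_model, max_rank, and the adapter paths are placeholders, and the exact hotswap_adapter signature may differ from the final docs.

    import torch
    from peft import PeftModel
    from peft.utils.hotswap import hotswap_adapter, prepare_model_for_compiled_hotswap

    # load the first LoRA adapter onto the base model
    model = PeftModel.from_pretrained(base_model, "path/to/adapter-0")

    # pad LoRA weights to a common rank and turn scalings into tensors so that a
    # later swap with a different rank/alpha does not change shapes or guards
    prepare_model_for_compiled_hotswap(model, target_rank=max_rank)

    model = torch.compile(model)
    # ... run inference with adapter 0 ...

    # swap in the second adapter in place, without triggering recompilation
    hotswap_adapter(model, "path/to/adapter-1", adapter_name="default")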

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan (Member Author):

not stale

@BenjaminBossan changed the title from "Hotswap allow different alpha scalings" to "Hotswap allow different alpha scalings and ranks" on Dec 11, 2024
@BenjaminBossan (Member Author):

@sayakpaul I had some time to brush up this PR. How do we proceed? As you can see, I have added diffusers tests, but since huggingface/diffusers#9453 is not merged, it won't work yet (i.e. there are recompiles). That PR is awaiting this PR, so there is a mutual dependency.

Also, for now Conv2d is not fully implemented, but to confirm that it works, we could test with only Linear, right?

@sayakpaul (Member):

> Also, for now Conv2d is not fully implemented, but to confirm that it works, we could test with only Linear, right?

That should work for now. I would add a note about this somewhere.

> @sayakpaul I had some time to brush up this PR. How do we proceed? As you can see, I have added diffusers tests, but since huggingface/diffusers#9453 is not merged, it won't work yet (i.e. there are recompiles).

To make progress and help unblock this PR, how about we merge this PR with sufficient testing (which I believe we already have) and mark the diffusers test with an xfail?

@BenjaminBossan marked this pull request as ready for review January 28, 2025 15:07
@githubnemo (Collaborator) left a comment


LGTM with some minor comments

@BenjaminBossan (Member Author) left a comment


Your comments should be addressed; please check again. I split the update into 2 commits, one for the refactor and one for the rest.

@githubnemo (Collaborator) left a comment


The changes look good. While inspecting _get_padded_* I found potentially missing bias handling during padding. Can you take a look?

@sayakpaul (Member) commented Jan 31, 2025 via email

@BenjaminBossan (Member Author) left a comment


Thanks @githubnemo for catching the potential issue with the bias. I have addressed this, as well as an earlier comment that I missed last time.

@sayakpaul (Member) left a comment


Thanks a lot, Benjamin! I have left some comments. LMK if they make sense.

Comment on lines +49 to +53
# load lora 0
model = PeftModel.from_pretrained(model, <path-adapter-0>)
# Prepare the model to allow hotswapping even if ranks/scalings of 2nd adapter differ.
# You can skip this step if all ranks and scalings are identical.
prepare_model_for_compiled_hotswap(model, target_rank=max_rank)
Member:

Would it make sense to allow a util like infer_max_rank_from_state_dict()? It can be difficult for users to know max_rank from the get-go. WDYT?

Member Author:

I was thinking about how an application would use this in practice for the reason you mentioned. If we break it down like this:

  1. Some business use case of serving different LoRAs: They can probably determine the max rank ahead of time and set it correctly.
  2. An image generation app like Comfy: This is hard as they cannot know in advance what LoRAs the user plans on loading. Theoretically, all available LoRAs could be scanned (which might be expensive) and the largest rank be chosen (which might be total overkill and damage runtime performance).

If we had an infer_max_rank_from_state_dict(), it would be for the 2nd case, which has the issues mentioned above, but correct me if I'm wrong. But regardless of that, where would the state dicts come from? Since we have multiple LoRA formats in the diffusion ecosystem and diffusers knows how to deal with them, would that function not be better added to diffusers?

Member:

Fair enough.
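
For illustration, a hypothetical sketch of such a helper (not part of this PR), assuming a PEFT-format LoRA state dict where each lora_A weight has shape (rank, in_features):

    def infer_max_rank_from_state_dict(state_dict: dict) -> int:
        # The rank of a LoRA module is the first dimension of its lora_A weight,
        # so the maximum rank is simply the largest such dimension.
        max_rank = 0
        for key, tensor in state_dict.items():
            if "lora_A" in key and key.endswith(".weight"):
                max_rank = max(max_rank, tensor.shape[0])
        return max_rank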

- The adapters must be compatible (e.g. same LoRA alpha, same target modules).
- If you use `torch.compile` and want to avoid recompilation, the LoRA rank must be the same.
- It only works for the same PEFT method, so no swapping LoRA and LoHa, for example.
- The adapters must be compatible (e.g. same target modules).
Member:

If the first adapter already contains a subset of the second adapter's target_modules, for example, would it still be a problem?

This limitation seems to be quite critical to me as a user, though.

Member Author:

We have 2 scenarios when the target modules are not identical:

  1. The 2nd adapter targets a subset of the 1st adapter: We can theoretically swap in the 2nd adapter, since there will be a LoRA layer in place where it is needed. However, the results will be wrong because we will also have leftover LoRA layers (e.g. LoRA 1 targets foo and bar and LoRA 2 only targets bar, then foo is still present). Therefore, we need to ensure that the missing weights are zeroed out. I added logic and tests for this.
  2. The 2nd adapter targets a superset of layers, or the layers are partly disjoint: In this case, we would need to create new LoRA layers. This logic is not implemented (yet). But even if it were added, it would require recompilation and is thus not very useful.

This limits what the user can do with this feature, but I don't have a good proposal for how to solve it, except for asking users to swap in the LoRA that targets the most layers first.
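
To make scenario 1 concrete, here is an illustrative sketch (not the PR's actual implementation) of zeroing out leftover LoRA weights for modules the incoming adapter does not target; targeted_module_names is a hypothetical set of module names.

    def zero_untargeted_lora_modules(model, targeted_module_names, adapter_name="default"):
        # A LoRA delta is lora_B @ lora_A, so zeroing lora_B turns a leftover layer
        # into a no-op while keeping its shape (and the compiled graph) intact.
        for name, module in model.named_modules():
            if hasattr(module, "lora_B") and name not in targeted_module_names:
                module.lora_B[adapter_name].weight.data.zero_()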

scaling = module.scaling
for key, val in scaling.items():
    if isinstance(val, float):
        scaling[key] = torch.tensor(val, device=module.weight.device)
Member:

Do we have to consider the dtype, too?

Member Author:

I did some quick testing of matrix x scalar multiplication and it seems like torch always coerces to the dtype of the matrix, so we should be fine.

Member:

I think having that explicitly mentioned might just be more readable for maintainers. But no strong opinions.

Member Author:

Added.
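
As a quick illustration of the dtype point above (not code from the PR): multiplying a half-precision matrix by a 0-dim float32 tensor keeps the matrix dtype, since 0-dim tensors do not drive type promotion within the same dtype category.

    import torch

    weight = torch.randn(4, 8, dtype=torch.float16)
    scaling = torch.tensor(0.5)  # 0-dim float32 tensor, like the converted scaling above

    out = weight * scaling
    print(out.dtype)  # torch.float16, i.e. the matrix dtype wins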

Comment on lines +107 to +116
if is_lora_A:
    # LoRA A affects out_features
    padded = torch.zeros(target_rank, in_features, device=weight.device, dtype=weight.dtype)
    padded[:original_rank, :] = weight
    new_layer = torch.nn.Linear(in_features, target_rank, bias=lora_module.bias is not None)
else:
    # LoRA B affects in_features
    padded = torch.zeros(out_features, target_rank, device=weight.device, dtype=weight.dtype)
    padded[:, :original_rank] = weight
    new_layer = torch.nn.Linear(target_rank, out_features, bias=lora_module.bias is not None)
Member:

I would also add a comment noting that we might have to revisit this if the corresponding layers come from a quantization backend.

Member Author:

Not sure if I understand. The LoRA A and B layers are always nn.Linear, no matter if the base layer is quantized or not.

@sayakpaul (Member) Feb 4, 2025:

I am thinking about functions like this:

def dispatch_bnb_4bit(target: torch.nn.Module, adapter_name: str, **kwargs):

Or do we simply use these for injection into the base model layers and NOT the adapter layers?

(Apologies in advance for not thoroughly reading through the corresponding code).

Member Author:

So if the model has a linear layer that is targeted with LoRA, we check the type of linear layer and replace it with a normal LoRA Linear, with a bnb Linear4bit, etc. These PEFT layers have lora_A and lora_B sub-modules, but those are always nn.Linear.

In this padding function, we pad the lora_A and lora_B layers, thus we know that they're always nn.Linear, regardless of what the type of the original targeted layer was.

I added small comments for that.

Comment on lines +134 to +135
if lora_module.bias is not None:
    new_layer.bias.data = lora_module.bias.data
Member:

Could potentially also leverage the copy_() method?

Member Author:

Hmm, I tried but got:

torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(10, 1, 16)), Parameter(FakeTensor(..., size=(13, 16), requires_grad=True)), None), **{}):
Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu

Not sure why that is, since we take care of the device, but I reverted the change. I think in this specific situation, there is no practical benefit to .copy_ anyway, is there?

Member:

Fair enough. Thanks for looking into it.

Comment on lines +281 to +283
is_compiled = hasattr(model, "_orig_mod")
if is_compiled:
    raise ValueError("Call prepare_model_for_compiled_hotswap *before* compiling the model")
Member:

Well, _orig_mod is present when we do torch.compile(), but compilation can also be triggered with model.compile().

Member Author:

Not sure if we can do anything about that. Any suggestion?

Member:

We could check with torch.compiler.is_compiling().

Member Author:

You mean:

- if is_compiled:
+ if is_compiled or torch.compiler.is_compiling():

IIUC, this would not help. We can call this function inside, say, a forward method of an nn.Module to check if the module is being compiled. However, prepare_model_for_compiled_hotswap is not called inside the module, so torch.compiler.is_compiling() would always return False (except if someone called it inside forward and then compiled that module, but that would be incorrect usage).

Member:

Okay. Then maybe having a comment to mention the .compile() use case and that we don't know how to deal with that yet would be nice?

Member Author:

Okay, I did find a way to check for model.compile(), namely the _compiled_call_impl attribute: https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/module.py#L2991

It is private, so I used getattr to ensure this doesn't break easily. I also tested this and it seems to work, although I'm not sure when users should use model.compile instead of torch.compile(model).
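
Putting the two checks from this thread together, a sketch of what the guard could look like (an assumption about the final code, relying on the private _compiled_call_impl attribute mentioned above):

    def ensure_not_compiled(model):
        # torch.compile(model) wraps the module and stores the original under _orig_mod,
        # while model.compile() sets the private _compiled_call_impl attribute instead.
        is_compiled = hasattr(model, "_orig_mod") or getattr(model, "_compiled_call_impl", None) is not None
        if is_compiled:
            raise ValueError("Call prepare_model_for_compiled_hotswap *before* compiling the model")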

Comment on lines +295 to +296
# config can be either a PeftConfig, or a dict of PeftConfigs like PeftModel.peft_config
config = {"dummy": config}
Member:

This config won't have any repercussions because it's never assigned to a PeftModel, right?

Member Author:

The {"dummy": config} is not assigned, but the passed in config argument is still mutated, so the intended side effect is triggered.

As to why I wrap this into {"dummy": config}, it's just to have the same code below for configs in dicts.
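
A minimal sketch of that wrapping pattern (the helper name is hypothetical, not the PR's code):

    from peft import LoraConfig, PeftConfig

    def as_config_dict(config):
        # Wrap a single config under a throwaway key so downstream code can always
        # iterate over a dict, just like PeftModel.peft_config.
        if isinstance(config, PeftConfig):
            return {"dummy": config}
        return config

    configs = as_config_dict(LoraConfig(r=8))
    for name, cfg in configs.items():
        print(name, cfg.r)  # dummy 8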

# TODO conv2d
raise NotImplementedError
if old_val.shape[0] > new_val.shape[0]:
    old_val.data.fill_(0)
Member:

(nit): Would it make sense to have 0 assigned as a PAD_VALUE var and use it?

Member Author:

We would never change that value, right? It must be 0.

@@ -4172,18 +4173,19 @@ def test_hotswapping_compiled_model_does_not_trigger_recompilation(self):
assert "__recompiles" in stderr.decode()

@pytest.mark.xfail(strict=True, reason="Requires hotswap to be implemented in diffusers")
def test_hotswapping_compiled_diffusion_model_does_not_trigger_recompilation(self):
@pytest.mark.parametrize("ranks", ["7,13", "13,7"]) # the ranks of the 2 LoRAs as str
Member:

Should we also check for different alphas?

Member Author:

In run_compiled_model_hotswap.py, we set:

    config0 = LoraConfig(init_lora_weights=False, r=rank0, lora_alpha=alpha0, target_modules=["q_proj", "v_proj"])
    config1 = LoraConfig(init_lora_weights=False, r=rank1, lora_alpha=alpha1, target_modules=["q_proj"])

so the alphas are indeed different.

elif "lora_A" in name and name.endswith(".bias"):
assert False, "LoRA A should not have a bias term"
elif "lora_B" in name and name.endswith(".bias"):
assert param.shape[0] == 10 # output shape of conv layer
Member:

Should we also add a test for checking if the correct config was updated?

Member Author:

We have test_prepare_model_for_compiled_hotswap_scalings_update_config above, is there something more that should be tested with regard to configs?

Member:

Perhaps also for alphas? IIUC it only checks for ranks.

Member Author:

Hmm, but alpha values are not changed, only converted to tensors. So there is nothing to update in the config, is there?

- Allow 2nd adapter to target subset of 1st
- Update tests to reflect this
- Improve docs
@BenjaminBossan (Member Author) left a comment

Thanks for the feedback, Sayak. LMK if I addressed all your points.


@sayakpaul (Member) left a comment

Thanks, Benjamin! Just a few last comments, but nothing merge-blocking.

@BenjaminBossan (Member Author) left a comment

@sayakpaul I replied to your comments, please check again.


@sayakpaul (Member) left a comment

Just a single comment. Thanks for all the hard work and iterations!

The recompilation tests so far called a subprocess and then checked the
torch logs. This is cumbersome. Now, the tests have been simplified by
setting the torch._dynamo.config.error_on_recompile flag to True. This
way, we can check for presence/absence of this error instead of having
to check the logs.

As a consequence, the standalone scripts could be removed and all the
code now resides within the test class.

Unfortunately, a side effect of this change is that all tests now run in the
same process. This is problematic because torch.compile caches compile
artifacts within a process, which leads to errors when we use the same
compiled model across multiple tests. I verified that this is the cause by
running the tests individually, which avoided the compilation error.

I tried a few things to mitigate this, but none helped:

- Setting torch._inductor.config.force_disable_caches = True
- Setting TORCHINDUCTOR_FORCE_DISABLE_CACHES=1
- Setting a separate TORCHINDUCTOR_CACHE_DIR for each test

Thus I resorted to using a different model for each test. There is a
function to check for this and raise an error if the same model is used
twice.
@githubnemo (Collaborator) left a comment


I'm OK with having different models for different scenarios to work around the dynamo cache, but there seems to be infrastructure in place that could be used to work around this.

I think that the recompilation detection itself utilizes the caching, so disabling caching will not work either way.

Comment on lines 4163 to 4173
def raise_error_on_recompile(self):
    """Raise an error when torch recompiles in the context.

    Raises a torch._dynamo.exc.RecompileError error.
    """
    prev_value = torch._dynamo.config.error_on_recompile
    torch._dynamo.config.error_on_recompile = True
    try:
        yield
    finally:
        torch._dynamo.config.error_on_recompile = prev_value
Collaborator:

I think you can replace this with torch._dynamo.config.patch(error_on_recompile=True) if you want. See https://github.com/pytorch/pytorch/blob/83c0763dda1f93c6cf552ba88260a0dc7a3ecb70/torch/utils/_config_module.py#L226
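
For reference, a sketch of how the suggested patch helper reads in a test (compiled_model and inputs are placeholders):

    import torch

    # torch._dynamo.config.patch(...) acts as a context manager that restores the
    # previous value on exit, replacing the hand-rolled try/finally above.
    with torch._dynamo.config.patch(error_on_recompile=True):
        compiled_model(inputs)  # raises torch._dynamo.exc.RecompileError if this recompiles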

# LLM #
#######

def check_hotswap(self, do_hotswap, model_id, ranks, alpha_scalings):
Collaborator:

Maybe introducing a context manager that calls torch._dynamo.reset() to reset the caches would work around the need for different models?
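
A sketch of what such a helper could look like (hypothetical name, assuming torch._dynamo.reset() clears the relevant caches):

    import contextlib

    import torch

    @contextlib.contextmanager
    def fresh_dynamo_cache():
        # Clear dynamo's compilation caches before and after the block so the same
        # model can be compiled from scratch in each test.
        torch._dynamo.reset()
        try:
            yield
        finally:
            torch._dynamo.reset()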

@BenjaminBossan (Member Author):

@githubnemo Very nice finds, both of your suggestions work and are much better than the workarounds I came up with.

... instead of at the start of the test.
@BenjaminBossan merged commit eaab05e into huggingface:main on Feb 5, 2025
14 checks passed
@BenjaminBossan deleted the hotswap-allow-different-alpha-scalings branch February 5, 2025 17:04