Commit a983ea2
Merge pull request #70 from tanganke/develop
update TSVM_utils
tanganke authored Jan 17, 2025
2 parents bfa663a + 467d3b9 commit a983ea2
Showing 1 changed file with 91 additions and 93 deletions.
fusion_bench/method/task_singular_vector/utils/TSVM_utils.py
@@ -116,7 +116,10 @@ def sum_svd_dict(svd_dict, config):

 ###############
 ##### LOSSLESS Orthogonalization
-def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
+def compute_and_sum_svd_mem_reduction_lossless(
+    task_vectors: List[StateDictType],
+    accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
+):
     """
     Computes the Singular Value Decomposition (SVD) for each task vector and merge the results.
@@ -129,40 +132,38 @@ def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
     Args:
         task_vectors (list): A list of task vectors, where each task vector is a dictionary containing the vectors for each task.
-        config (object): A configuration object containing the device and dataset information.
+        accelerator (torch.device): The device to use for the computation.
 
     Returns:
         dict: A dictionary containing the new vectors after summing the SVD components.
     """
-    # becareful wit vit-l on 20 task it does not fit in GPU or in 64 GB RAM (try without last layer)
-    device = config.device
     print("Computing SVD...")
     with torch.no_grad():
         new_vector = {}
-        for key in task_vectors[0].vector:
+        for key in task_vectors[0]:
+            original_device = task_vectors[0][key].device
             new_vector[key] = {}
-            for i, (task_vector, dataset) in enumerate(
-                zip(task_vectors, config.DATASETS)
-            ):
-                vec = task_vector.vector[key].to(device)
+            for i, task_vector in enumerate(task_vectors):
+                vec = task_vector[key].to(accelerator)
 
-                if (
-                    len(task_vector.vector[key].shape) == 2
-                    and "text_projection" not in key
-                ):
+                if len(task_vector[key].shape) == 2 and "text_projection" not in key:
 
                     u, s, v = torch.linalg.svd(vec, full_matrices=False)
 
                     if i == 0:
                         print(f"Computed SVD for {key}...")
                         sum_u = torch.zeros(
-                            u.shape[0], u.shape[1] * config.num_tasks, device=device
+                            u.shape[0],
+                            u.shape[1] * len(task_vectors),
+                            device=accelerator,
                         )
                         sum_s = torch.zeros(
-                            s.shape[0] * config.num_tasks, device=device
+                            s.shape[0] * len(task_vectors), device=accelerator
                         )
                         sum_v = torch.zeros(
-                            v.shape[0] * config.num_tasks, v.shape[1], device=device
+                            v.shape[0] * len(task_vectors),
+                            v.shape[1],
+                            device=accelerator,
                         )
                         reduced_index_s = s.shape[0]

@@ -184,7 +185,7 @@ def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
                 else:
                     new_vector[key] += (vec - new_vector[key]) / (i + 1)
 
-            if len(task_vector.vector[key].shape) == 2 and "text_projection" not in key:
+            if len(task_vector[key].shape) == 2 and "text_projection" not in key:
                 u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False)
                 u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False)

@@ -197,13 +198,16 @@ def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
                         v_v,
                     )
                 )
+
+            new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
     return new_vector
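
A usage sketch for the refactored entry point (illustrative, not part of the commit): it assumes StateDictType is a plain Dict[str, torch.Tensor], that task vectors are the usual fine-tuned-minus-pretrained parameter deltas, and that the module is importable from the file path shown above; the toy checkpoints are hypothetical.

import torch

from fusion_bench.method.task_singular_vector.utils.TSVM_utils import (
    compute_and_sum_svd_mem_reduction_lossless,
)

# Toy checkpoints: two 2-D weights plus a 1-D bias (non-2-D tensors take the
# running-average branch rather than the SVD branch).
base = {"w1": torch.randn(16, 16), "w2": torch.randn(16, 16), "b": torch.randn(16)}
finetuned = [
    {k: v + 0.01 * torch.randn_like(v) for k, v in base.items()} for _ in range(4)
]

# A task vector is the element-wise difference between two checkpoints.
task_vectors = [{k: ft[k] - base[k] for k in base} for ft in finetuned]
merged = compute_and_sum_svd_mem_reduction_lossless(task_vectors, accelerator="cpu")
print({k: tuple(v.shape) for k, v in merged.items()})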


 ###############
 ##### LOSSLESS EIGENDECOMP
-def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):
+def compute_and_sum_svd_mem_reduction_lossless_eigen(
+    task_vectors: List[StateDictType],
+    accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
+):
     """
     Computes the Singular Value Decomposition (SVD) for each task vector and merge the results.
@@ -216,40 +220,39 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):
     Args:
         task_vectors (list): A list of task vectors, where each task vector is a dictionary containing the vectors for each task.
-        config (object): A configuration object containing the device and dataset information.
+        accelerator (torch.device): The device to use for the computation.
 
     Returns:
         dict: A dictionary containing the new vectors after merging the SVD components.
     """
-    # becareful wit vit-l on 20 task it does not fit in GPU or in 64 GB RAM (try without last layer)
-    device = config.device
     print("Computing SVD...")
     with torch.no_grad():
         new_vector = {}
-        for key in task_vectors[0].vector:
+        for key in task_vectors[0]:
+            original_device = task_vectors[0][key].device
             new_vector[key] = {}
-            for i, (task_vector, dataset) in enumerate(
-                zip(task_vectors, config.DATASETS)
-            ):
-                vec = task_vector.vector[key].to(device)
+            for i, task_vector in enumerate(task_vectors):
+                vec = task_vector[key].to(accelerator)
 
-                if (
-                    len(task_vector.vector[key].shape) == 2
-                    and "text_projection" not in key
-                ):
+                if len(task_vector[key].shape) == 2 and "text_projection" not in key:
 
                     u, s, v = torch.linalg.svd(vec, full_matrices=False)
 
                     if i == 0:
                         print(f"Computed SVD for {key}...")
                         sum_u = torch.zeros(
-                            u.shape[0], u.shape[1] * config.num_tasks, device=device
+                            u.shape[0],
+                            u.shape[1] * len(task_vectors),
+                            device=accelerator,
                         )
                         sum_s = torch.zeros(
-                            s.shape[0] * config.num_tasks, device=device
+                            s.shape[0] * len(task_vectors), device=accelerator
                         )
                         sum_v = torch.zeros(
-                            v.shape[0] * config.num_tasks, v.shape[1], device=device
+                            v.shape[0] * len(task_vectors),
+                            v.shape[1],
+                            device=accelerator,
                        )
                         reduced_index_s = s.shape[0]

@@ -271,7 +274,7 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):
                 else:
                     new_vector[key] += (vec - new_vector[key]) / (i + 1)
 
-            if len(task_vector.vector[key].shape) == 2 and "text_projection" not in key:
+            if len(task_vector[key].shape) == 2 and "text_projection" not in key:
                 sum_s, indices = torch.sort(sum_s, stable=True)
 
                 sum_u = torch.index_select(sum_u, 1, indices)
@@ -293,12 +296,14 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):

                 new_vector[key] = torch.linalg.multi_dot(  # bool_mask *
                     (
                         sum_u,
                         u_orth,
                         torch.diag(sum_s),
                         v_orth,
                         sum_v,
                     )
                 )
+
+            new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
     return new_vector
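
The u_orth and v_orth factors above are computed by lines outside the shown hunks. As background, one standard way to obtain such eigendecomposition-based orthogonal factors is symmetric whitening, X(X^T X)^(-1/2); the sketch below illustrates that construction under this assumption (nearest_orthogonal is a hypothetical helper, not part of this module).

import torch

def nearest_orthogonal(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # Orthogonalize the columns of x via the inverse square root of its Gram
    # matrix: x @ (x^T x)^(-1/2), using a symmetric eigendecomposition.
    gram = x.T @ x
    eigvals, eigvecs = torch.linalg.eigh(gram)
    inv_sqrt = eigvecs @ torch.diag(eigvals.clamp_min(eps).rsqrt()) @ eigvecs.T
    return x @ inv_sqrt

q = nearest_orthogonal(torch.randn(64, 16))
print(torch.allclose(q.T @ q, torch.eye(16), atol=1e-4))  # True: columns orthonormal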


@@ -394,7 +399,10 @@ def compute_and_sum_svd_mem_reduction(

 ###############
 #### TSV Merge Eigendecomp
-def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
+def compute_and_sum_svd_mem_reduction_2(
+    task_vectors: List[StateDictType],
+    accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
+):
"""
Computes the Singular Value Decomposition (SVD) for each vector in the task_vectors,
reduces the dimensionality of the vectors based on the sv_reduction factor, and concatenate
@@ -404,36 +412,30 @@ def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
     Args:
         task_vectors (list): A list of task vector objects, where each object contains a
             dictionary of vectors.
-        config (object): Configuration object containing the following attributes:
-            - DATASETS (list): List of datasets.
-            - device (torch.device): The device to perform computations on.
+        accelerator (torch.device): The device to use for the computation.
 
     Returns:
         dict: A dictionary containing the new vectors after SVD computation and merging.
     """
-    sv_reduction = 1 / len(config.DATASETS)
-    device = config.device
+    sv_reduction = 1 / len(task_vectors)
 
     print("Computing SVD...")
     with torch.no_grad():
         new_vector = {}
-        for key in task_vectors[0].vector:
+        for key in task_vectors[0]:
+            original_device = task_vectors[0][key].device
             new_vector[key] = {}
-            for i, (task_vector, dataset) in enumerate(
-                zip(task_vectors, config.DATASETS)
-            ):
-                vec = task_vector.vector[key].to(device)
+            for i, task_vector in enumerate(task_vectors):
+                vec = task_vector[key].to(accelerator)
 
-                if (
-                    len(task_vector.vector[key].shape) == 2
-                    and "text_projection" not in key
-                ):
+                if len(task_vector[key].shape) == 2 and "text_projection" not in key:
                    u, s, v = torch.linalg.svd(vec, full_matrices=False)
 
                     if i == 0:
                         print(f"Computed SVD for {key}...")
-                        sum_u = torch.zeros_like(u, device=device)
-                        sum_s = torch.zeros_like(s, device=device)
-                        sum_v = torch.zeros_like(v, device=device)
+                        sum_u = torch.zeros_like(u, device=accelerator)
+                        sum_s = torch.zeros_like(s, device=accelerator)
+                        sum_v = torch.zeros_like(v, device=accelerator)
                         reduced_index_s = int(s.shape[0] * sv_reduction)
 
                         # select only the first reduced_index_s columns of u and place them
@@ -454,7 +456,7 @@ def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
                 else:
                     new_vector[key] += (vec - new_vector[key]) / (i + 1)
 
-            if len(task_vector.vector[key].shape) == 2 and "text_projection" not in key:
+            if len(task_vector[key].shape) == 2 and "text_projection" not in key:
                 sum_s, indices = torch.sort(sum_s, stable=True)
 
                 sum_u = torch.index_select(sum_u, 1, indices)
@@ -483,13 +485,17 @@ def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
                         sum_v,
                     )
                 )
+            new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
 
     return new_vector
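
For intuition about sv_reduction = 1 / len(task_vectors): each task keeps only its top rank // num_tasks singular triplets, and task i writes them into its own slice of the concatenated factors, so sum_u, sum_s, and sum_v stay the size of a single task's SVD. A self-contained sketch of the slicing the collapsed lines perform, with illustrative sizes:

import torch

num_tasks, dim = 8, 768
sv_reduction = 1 / num_tasks
sum_u, sum_s, sum_v = torch.zeros(dim, dim), torch.zeros(dim), torch.zeros(dim, dim)

for i in range(num_tasks):
    vec = torch.randn(dim, dim)  # stand-in for one task's 2-D weight delta
    u, s, v = torch.linalg.svd(vec, full_matrices=False)
    reduced = int(s.shape[0] * sv_reduction)  # 768 * 1/8 = 96 triplets per task
    sl = slice(i * reduced, (i + 1) * reduced)
    sum_u[:, sl] = u[:, :reduced]  # task i's top left singular vectors
    sum_s[sl] = s[:reduced]        # task i's top singular values
    sum_v[sl, :] = v[:reduced, :]  # task i's top right singular vectors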


 ###############
 #### Rank Reduction TV
-def compute_and_sum_svd_mem_reduction_rank_reduction(task_vectors, config):
+def compute_and_sum_svd_mem_reduction_rank_reduction(
+    task_vectors: List[StateDictType],
+    accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
+):
     """
     Compute and sum the Singular Value Decomposition (SVD) of task vectors with rank reduction.
@@ -499,36 +505,29 @@ def compute_and_sum_svd_mem_reduction_rank_reduction(task_vectors, config):
     Args:
         task_vectors (list): A list of task vector objects. Each object should have a `vector` attribute
             which is a dictionary where keys are vector names and values are tensors.
-        config (object): Configuration object containing the following attributes:
-            - DATASETS (list): List of datasets.
-            - device (torch.device): The device to perform computations on.
+        accelerator (torch.device): The device to use for the computation.
 
     Returns:
         dict: A dictionary containing the new vectors after SVD computation and summation.
     """
-    sv_reduction = 1 / len(config.DATASETS)
-    device = config.device
+    sv_reduction = 1 / len(task_vectors)
     print("Computing SVD...")
     with torch.no_grad():
         new_vector = {}
-        for key in task_vectors[0].vector:
+        for key in task_vectors[0]:
+            original_device = task_vectors[0][key].device
             new_vector[key] = {}
-            for i, (task_vector, dataset) in enumerate(
-                zip(task_vectors, config.DATASETS)
-            ):
-                vec = task_vector.vector[key].to(device)
+            for i, task_vector in enumerate(task_vectors):
+                vec = task_vector[key].to(accelerator)
 
-                if (
-                    len(task_vector.vector[key].shape) == 2
-                    and "text_projection" not in key
-                ):
+                if len(task_vector[key].shape) == 2 and "text_projection" not in key:
                     u, s, v = torch.linalg.svd(vec, full_matrices=False)
 
                     if i == 0:
                         print(f"Computed SVD for {key}...")
-                        sum_u = torch.zeros_like(u, device=device)
-                        sum_s = torch.zeros_like(s, device=device)
-                        sum_v = torch.zeros_like(v, device=device)
+                        sum_u = torch.zeros_like(u, device=accelerator)
+                        sum_s = torch.zeros_like(s, device=accelerator)
+                        sum_v = torch.zeros_like(v, device=accelerator)
                         reduced_index_s = int(s.shape[0] * sv_reduction)
 
                         # select only the first reduced_index_s columns of u and place them
@@ -549,34 +548,37 @@ def compute_and_sum_svd_mem_reduction_rank_reduction(task_vectors, config):
                 else:
                     new_vector[key] += (vec - new_vector[key]) / (i + 1)
 
-            if len(task_vector.vector[key].shape) == 2 and "text_projection" not in key:
+            if len(task_vector[key].shape) == 2 and "text_projection" not in key:
                 new_vector[key] = torch.linalg.multi_dot(
                     (
                         sum_u,
                         torch.diag(sum_s),
                         sum_v,
                     )
                 )
+
+            new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
     return new_vector
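
Since each task occupies a disjoint slice of the concatenated factors, the multi_dot above reduces to the sum of the per-task rank-reduced reconstructions. A quick numerical check of that identity, with illustrative sizes:

import torch

num_tasks, dim = 4, 64
k = dim // num_tasks  # singular triplets kept per task
sum_u, sum_s, sum_v = torch.zeros(dim, dim), torch.zeros(dim), torch.zeros(dim, dim)
expected = torch.zeros(dim, dim)

for i in range(num_tasks):
    m = torch.randn(dim, dim)  # stand-in for one task's weight delta
    u, s, v = torch.linalg.svd(m, full_matrices=False)
    sl = slice(i * k, (i + 1) * k)
    sum_u[:, sl], sum_s[sl], sum_v[sl, :] = u[:, :k], s[:k], v[:k, :]
    expected += u[:, :k] @ torch.diag(s[:k]) @ v[:k, :]  # rank-k reconstruction

merged = torch.linalg.multi_dot((sum_u, torch.diag(sum_s), sum_v))
print(torch.allclose(merged, expected, atol=1e-4))  # True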


-def compute_and_sum_svd_mem_reduction_dummy(task_vectors, config):
+def compute_and_sum_svd_mem_reduction_dummy(
+    task_vectors: List[StateDictType],
+    accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
+):
"""To perform dummy operations."""
sv_reduction = 1 / 8
sv_reduction = 1 / len(task_vectors)
print("Computing SVD...")
with torch.no_grad():
new_vector = {}
for key in task_vectors[0].vector:
for key in task_vectors[0]:
original_device = task_vectors[0][key].device
new_vector[key] = {}
for i in range(0, 8):
if (
len(task_vectors[0].vector[key].shape) == 2
and "text_projection" not in key
):
for i, task_vector in enumerate(task_vectors):
vec = task_vector[key].to(accelerator)

if len(task_vector[key].shape) == 2 and "text_projection" not in key:
if i == 0:
u, s, v = torch.linalg.svd(
task_vectors[0].vector[key], full_matrices=False
)
u, s, v = torch.linalg.svd(vec, full_matrices=False)
reduced_index_s = int(s.shape[0] * sv_reduction)

print(f"Computed SVD for {key}...")
Expand Down Expand Up @@ -620,16 +622,11 @@ def compute_and_sum_svd_mem_reduction_dummy(task_vectors, config):

                 else:
                     if i == 0:
-                        new_vector[key] = task_vectors[0].vector[key]
-                    # else:
-                    #     new_vector[key] += (
-                    #         task_vector.vector[key] - new_vector[key]
-                    #     ) / (i + 1)
+                        new_vector[key] = vec.clone()
+                    else:
+                        new_vector[key] += (vec - new_vector[key]) / (i + 1)
 
-            if (
-                len(task_vectors[0].vector[key].shape) == 2
-                and "text_projection" not in key
-            ):
+            if len(task_vector[key].shape) == 2 and "text_projection" not in key:
 
                 new_vector[key] = torch.linalg.multi_dot(
                     (
@@ -639,4 +636,5 @@ def compute_and_sum_svd_mem_reduction_dummy(task_vectors, config):
                     )
                 )
 
+            new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
     return new_vector
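
A smoke-test sketch for the refactored signature, with illustrative shapes (the accelerator argument accepts anything Tensor.to understands, so a plain "cpu" string works here); the import path follows the file's repository location.

import torch

from fusion_bench.method.task_singular_vector.utils.TSVM_utils import (
    compute_and_sum_svd_mem_reduction_dummy,
)

task_vectors = [
    {"mlp.weight": torch.randn(32, 32), "mlp.bias": torch.randn(32)}
    for _ in range(4)
]
merged = compute_and_sum_svd_mem_reduction_dummy(task_vectors, accelerator="cpu")
assert merged["mlp.weight"].shape == (32, 32)  # 2-D weight: SVD branch
assert merged["mlp.bias"].shape == (32,)       # 1-D bias: running-average branch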
