From 467d3b90ec65331d01b72f11f6a15d78ee7e4060 Mon Sep 17 00:00:00 2001
From: Anke Tang
Date: Wed, 15 Jan 2025 15:09:04 +0800
Subject: [PATCH] update TSVM_utils

---
 .../task_singular_vector/utils/TSVM_utils.py | 184 +++++++++---------
 1 file changed, 91 insertions(+), 93 deletions(-)

diff --git a/fusion_bench/method/task_singular_vector/utils/TSVM_utils.py b/fusion_bench/method/task_singular_vector/utils/TSVM_utils.py
index fc0291b9..10dd9408 100644
--- a/fusion_bench/method/task_singular_vector/utils/TSVM_utils.py
+++ b/fusion_bench/method/task_singular_vector/utils/TSVM_utils.py
@@ -116,7 +116,10 @@ def sum_svd_dict(svd_dict, config):
 ###############
 ##### LOSSLESS Orthogonalization
-def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
+def compute_and_sum_svd_mem_reduction_lossless(
+    task_vectors: List[StateDictType],
+    accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
+):
     """
     Computes the Singular Value Decomposition (SVD) for each task vector and
     merges the results.
@@ -129,40 +132,38 @@ def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
     Args:
         task_vectors (list): A list of task vectors, where each task vector is a
             dictionary containing the vectors for each task.
-        config (object): A configuration object containing the device and dataset
-            information.
-
+        accelerator (torch.device): The device to use for the computation.
     Returns:
         dict: A dictionary containing the new vectors after summing the SVD
             components.
     """
     # be careful with ViT-L on 20 tasks: it does not fit on the GPU or in 64 GB of RAM (try without the last layer)
-    device = config.device
     print("Computing SVD...")
     with torch.no_grad():
         new_vector = {}
-        for key in task_vectors[0].vector:
+        for key in task_vectors[0]:
+            original_device = task_vectors[0][key].device
             new_vector[key] = {}
-            for i, (task_vector, dataset) in enumerate(
-                zip(task_vectors, config.DATASETS)
-            ):
-                vec = task_vector.vector[key].to(device)
+            for i, task_vector in enumerate(task_vectors):
+                vec = task_vector[key].to(accelerator)
 
-                if (
-                    len(task_vector.vector[key].shape) == 2
-                    and "text_projection" not in key
-                ):
+                if len(task_vector[key].shape) == 2 and "text_projection" not in key:
                     u, s, v = torch.linalg.svd(vec, full_matrices=False)
                     if i == 0:
                         print(f"Computed SVD for {key}...")
                         sum_u = torch.zeros(
-                            u.shape[0], u.shape[1] * config.num_tasks, device=device
+                            u.shape[0],
+                            u.shape[1] * len(task_vectors),
+                            device=accelerator,
                         )
                         sum_s = torch.zeros(
-                            s.shape[0] * config.num_tasks, device=device
+                            s.shape[0] * len(task_vectors), device=accelerator
                         )
                         sum_v = torch.zeros(
-                            v.shape[0] * config.num_tasks, v.shape[1], device=device
+                            v.shape[0] * len(task_vectors),
+                            v.shape[1],
+                            device=accelerator,
                         )
                         reduced_index_s = s.shape[0]
 
@@ -184,7 +185,7 @@ def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
                 else:
                     new_vector[key] += (vec - new_vector[key]) / (i + 1)
 
-            if len(task_vector.vector[key].shape) == 2 and "text_projection" not in key:
+            if len(task_vector[key].shape) == 2 and "text_projection" not in key:
                 u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False)
                 u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False)
@@ -197,13 +198,16 @@ def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
                         v_v,
                     )
                 )
-
+            new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
     return new_vector
 
 
 ###############
 ##### LOSSLESS EIGENDECOMP
-def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):
+def compute_and_sum_svd_mem_reduction_lossless_eigen(
+    task_vectors: List[StateDictType],
+    accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
+):
     """
     Computes the Singular Value Decomposition (SVD) for each task vector and
     merges the results.
@@ -216,40 +220,39 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):
     Args:
         task_vectors (list): A list of task vectors, where each task vector is a
             dictionary containing the vectors for each task.
-        config (object): A configuration object containing the device and dataset
-            information.
+        accelerator (torch.device): The device to use for the computation.
 
     Returns:
         dict: A dictionary containing the new vectors after merging the SVD
             components.
     """
     # be careful with ViT-L on 20 tasks: it does not fit on the GPU or in 64 GB of RAM (try without the last layer)
-    device = config.device
     print("Computing SVD...")
     with torch.no_grad():
         new_vector = {}
-        for key in task_vectors[0].vector:
+        for key in task_vectors[0]:
+            original_device = task_vectors[0][key].device
             new_vector[key] = {}
-            for i, (task_vector, dataset) in enumerate(
-                zip(task_vectors, config.DATASETS)
-            ):
-                vec = task_vector.vector[key].to(device)
+            for i, task_vector in enumerate(task_vectors):
+                vec = task_vector[key].to(accelerator)
 
-                if (
-                    len(task_vector.vector[key].shape) == 2
-                    and "text_projection" not in key
-                ):
+                if len(task_vector[key].shape) == 2 and "text_projection" not in key:
                     u, s, v = torch.linalg.svd(vec, full_matrices=False)
                     if i == 0:
                         print(f"Computed SVD for {key}...")
                         sum_u = torch.zeros(
-                            u.shape[0], u.shape[1] * config.num_tasks, device=device
+                            u.shape[0],
+                            u.shape[1] * len(task_vectors),
+                            device=accelerator,
                        )
                         sum_s = torch.zeros(
-                            s.shape[0] * config.num_tasks, device=device
+                            s.shape[0] * len(task_vectors), device=accelerator
                         )
                         sum_v = torch.zeros(
-                            v.shape[0] * config.num_tasks, v.shape[1], device=device
+                            v.shape[0] * len(task_vectors),
+                            v.shape[1],
+                            device=accelerator,
                         )
                         reduced_index_s = s.shape[0]
 
@@ -271,7 +274,7 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):
                 else:
                     new_vector[key] += (vec - new_vector[key]) / (i + 1)
 
-            if len(task_vector.vector[key].shape) == 2 and "text_projection" not in key:
+            if len(task_vector[key].shape) == 2 and "text_projection" not in key:
                 sum_s, indices = torch.sort(sum_s, stable=True)
                 sum_u = torch.index_select(sum_u, 1, indices)
@@ -293,12 +296,14 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):
                 new_vector[key] = torch.linalg.multi_dot(
                     # bool_mask *
                     (
+                        sum_u,
                         u_orth,
                         torch.diag(sum_s),
                         v_orth,
+                        sum_v,
                     )
                 )
-
+            new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
     return new_vector
 
 
@@ -394,7 +399,10 @@ def compute_and_sum_svd_mem_reduction(
 ###############
 #### TSV Merge Eigendecomp
-def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
+def compute_and_sum_svd_mem_reduction_2(
+    task_vectors: List[StateDictType],
+    accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
+):
     """
     Computes the Singular Value Decomposition (SVD) for each vector in the
     task_vectors, reduces the dimensionality of the vectors based on the
     sv_reduction factor, and concatenates
@@ -404,36 +412,30 @@ def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
     Args:
         task_vectors (list): A list of task vector objects, where each object
             contains a dictionary of vectors.
-        config (object): Configuration object containing the following attributes:
-            - DATASETS (list): List of datasets.
-            - device (torch.device): The device to perform computations on.
+        accelerator (torch.device): The device to use for the computation.
 
     Returns:
         dict: A dictionary containing the new vectors after SVD computation and
             merging.
     """
-    sv_reduction = 1 / len(config.DATASETS)
-    device = config.device
+    sv_reduction = 1 / len(task_vectors)
+
     print("Computing SVD...")
     with torch.no_grad():
         new_vector = {}
-        for key in task_vectors[0].vector:
+        for key in task_vectors[0]:
+            original_device = task_vectors[0][key].device
             new_vector[key] = {}
-            for i, (task_vector, dataset) in enumerate(
-                zip(task_vectors, config.DATASETS)
-            ):
-                vec = task_vector.vector[key].to(device)
+            for i, task_vector in enumerate(task_vectors):
+                vec = task_vector[key].to(accelerator)
 
-                if (
-                    len(task_vector.vector[key].shape) == 2
-                    and "text_projection" not in key
-                ):
+                if len(task_vector[key].shape) == 2 and "text_projection" not in key:
                     u, s, v = torch.linalg.svd(vec, full_matrices=False)
                     if i == 0:
                         print(f"Computed SVD for {key}...")
-                        sum_u = torch.zeros_like(u, device=device)
-                        sum_s = torch.zeros_like(s, device=device)
-                        sum_v = torch.zeros_like(v, device=device)
+                        sum_u = torch.zeros_like(u, device=accelerator)
+                        sum_s = torch.zeros_like(s, device=accelerator)
+                        sum_v = torch.zeros_like(v, device=accelerator)
                         reduced_index_s = int(s.shape[0] * sv_reduction)
 
                     # select only the first reduced_index_s columns of u and place them
@@ -454,7 +456,7 @@ def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
                 else:
                     new_vector[key] += (vec - new_vector[key]) / (i + 1)
 
-            if len(task_vector.vector[key].shape) == 2 and "text_projection" not in key:
+            if len(task_vector[key].shape) == 2 and "text_projection" not in key:
                 sum_s, indices = torch.sort(sum_s, stable=True)
                 sum_u = torch.index_select(sum_u, 1, indices)
@@ -483,13 +485,17 @@ def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
                         sum_v,
                     )
                 )
+            new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
     return new_vector
 
 
 ###############
 #### Rank Reduction TV
-def compute_and_sum_svd_mem_reduction_rank_reduction(task_vectors, config):
+def compute_and_sum_svd_mem_reduction_rank_reduction(
+    task_vectors: List[StateDictType],
+    accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
+):
     """
     Compute and sum the Singular Value Decomposition (SVD) of task vectors with
     rank reduction.
@@ -499,36 +505,29 @@ def compute_and_sum_svd_mem_reduction_rank_reduction(task_vectors, config):
     Args:
         task_vectors (list): A list of task vector objects. Each object should have a
            `vector` attribute which is a dictionary where keys are vector names and
            values are tensors.
-        config (object): Configuration object containing the following attributes:
-            - DATASETS (list): List of datasets.
-            - device (torch.device): The device to perform computations on.
+        accelerator (torch.device): The device to use for the computation.
 
     Returns:
         dict: A dictionary containing the new vectors after SVD computation and
            summation.
""" - sv_reduction = 1 / len(config.DATASETS) - device = config.device + sv_reduction = 1 / len(task_vectors) print("Computing SVD...") with torch.no_grad(): new_vector = {} - for key in task_vectors[0].vector: + for key in task_vectors[0]: + original_device = task_vectors[0][key].device new_vector[key] = {} - for i, (task_vector, dataset) in enumerate( - zip(task_vectors, config.DATASETS) - ): - vec = task_vector.vector[key].to(device) + for i, task_vector in enumerate(task_vectors): + vec = task_vector[key].to(accelerator) - if ( - len(task_vector.vector[key].shape) == 2 - and "text_projection" not in key - ): + if len(task_vector[key].shape) == 2 and "text_projection" not in key: u, s, v = torch.linalg.svd(vec, full_matrices=False) if i == 0: print(f"Computed SVD for {key}...") - sum_u = torch.zeros_like(u, device=device) - sum_s = torch.zeros_like(s, device=device) - sum_v = torch.zeros_like(v, device=device) + sum_u = torch.zeros_like(u, device=accelerator) + sum_s = torch.zeros_like(s, device=accelerator) + sum_v = torch.zeros_like(v, device=accelerator) reduced_index_s = int(s.shape[0] * sv_reduction) # select only the first reduced_index_s columns of u and place them @@ -549,7 +548,7 @@ def compute_and_sum_svd_mem_reduction_rank_reduction(task_vectors, config): else: new_vector[key] += (vec - new_vector[key]) / (i + 1) - if len(task_vector.vector[key].shape) == 2 and "text_projection" not in key: + if len(task_vector[key].shape) == 2 and "text_projection" not in key: new_vector[key] = torch.linalg.multi_dot( ( sum_u, @@ -557,26 +556,29 @@ def compute_and_sum_svd_mem_reduction_rank_reduction(task_vectors, config): sum_v, ) ) + + new_vector[key] = new_vector[key].to(original_device, non_blocking=True) return new_vector -def compute_and_sum_svd_mem_reduction_dummy(task_vectors, config): +def compute_and_sum_svd_mem_reduction_dummy( + task_vectors: List[StateDictType], + accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu", +): """To perform dummy operations.""" - sv_reduction = 1 / 8 + sv_reduction = 1 / len(task_vectors) print("Computing SVD...") with torch.no_grad(): new_vector = {} - for key in task_vectors[0].vector: + for key in task_vectors[0]: + original_device = task_vectors[0][key].device new_vector[key] = {} - for i in range(0, 8): - if ( - len(task_vectors[0].vector[key].shape) == 2 - and "text_projection" not in key - ): + for i, task_vector in enumerate(task_vectors): + vec = task_vector[key].to(accelerator) + + if len(task_vector[key].shape) == 2 and "text_projection" not in key: if i == 0: - u, s, v = torch.linalg.svd( - task_vectors[0].vector[key], full_matrices=False - ) + u, s, v = torch.linalg.svd(vec, full_matrices=False) reduced_index_s = int(s.shape[0] * sv_reduction) print(f"Computed SVD for {key}...") @@ -620,16 +622,11 @@ def compute_and_sum_svd_mem_reduction_dummy(task_vectors, config): else: if i == 0: - new_vector[key] = task_vectors[0].vector[key] - # else: - # new_vector[key] += ( - # task_vector.vector[key] - new_vector[key] - # ) / (i + 1) + new_vector[key] = vec.clone() + else: + new_vector[key] += (vec - new_vector[key]) / (i + 1) - if ( - len(task_vectors[0].vector[key].shape) == 2 - and "text_projection" not in key - ): + if len(task_vector[key].shape) == 2 and "text_projection" not in key: new_vector[key] = torch.linalg.multi_dot( ( @@ -639,4 +636,5 @@ def compute_and_sum_svd_mem_reduction_dummy(task_vectors, config): ) ) + new_vector[key] = new_vector[key].to(original_device, non_blocking=True) return new_vector