Skip to content

Commit

Permalink
Update clip_finetune.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tanganke committed Dec 24, 2024
1 parent 345f90c commit 023faec
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions fusion_bench/method/classification/clip_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,10 @@
from fusion_bench import print_parameters
from fusion_bench.compat.method import ModelFusionAlgorithm
from fusion_bench.compat.modelpool import to_modelpool
from fusion_bench.compat.modelpool.huggingface_clip_vision import (
HuggingFaceClipVisionPool,
)
from fusion_bench.dataset.clip_dataset import CLIPDataset
from fusion_bench.mixins import CLIPClassificationMixin
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
from fusion_bench.modelpool import CLIPVisionModelPool
from fusion_bench.models.hf_clip import HFCLIPClassifier
from fusion_bench.models.linearized.linearized_model_utils import LinearizedModelWraper
from fusion_bench.utils.data import InfiniteDataLoader
Expand Down Expand Up @@ -92,12 +91,12 @@ class ImageClassificationFineTuningForCLIP(
A class for fine-tuning CLIP models for image classification tasks.
"""

def run(self, modelpool: HuggingFaceClipVisionPool):
def run(self, modelpool: CLIPVisionModelPool):
"""
Executes the fine-tuning process.
Args:
modelpool (HuggingFaceClipVisionPool): The modelpool is responsible for loading the pre-trained model and training datasets.
modelpool (CLIPVisionModelPool): The modelpool is responsible for loading the pre-trained model and training datasets.
Returns:
VisionModel: The fine-tuned vision model.
Expand All @@ -109,9 +108,7 @@ def run(self, modelpool: HuggingFaceClipVisionPool):

L.seed_everything(config.seed)

task_names = [
dataset_config["name"] for dataset_config in modelpool.config.train_datasets
]
task_names = modelpool.train_dataset_names
with self.profile("setup model and optimizer"):
processor, classifier, optimizer, lr_scheduler = self.setup_model()

Expand All @@ -133,7 +130,7 @@ def run(self, modelpool: HuggingFaceClipVisionPool):

with self.profile("setup data"):
train_datasets = [
modelpool.get_train_dataset(task_name, processor)
CLIPDataset(modelpool.load_train_dataset(task_name), processor)
for task_name in task_names
]
train_dataloaders = [
Expand All @@ -157,6 +154,7 @@ def run(self, modelpool: HuggingFaceClipVisionPool):
range(config.num_steps),
desc=self.finetune_method,
disable=not self.fabric.is_global_zero,
dynamic_ncols=True,
):
optimizer.zero_grad()
loss = 0
Expand All @@ -183,7 +181,7 @@ def run(self, modelpool: HuggingFaceClipVisionPool):
save_path = os.path.join(
self.log_dir, "checkpoints", f"step={step_idx}.ckpt"
)
self.save_model(classifier, save_path, trainable_only=True)
self.save_model(classifier, save_path)

if config.state_dict_save_path is not None:
self.save_model(
Expand Down Expand Up @@ -232,9 +230,8 @@ def setup_model(self):
config = self.config
modelpool = self.modelpool

pretrained_model_config = modelpool.get_model_config("_pretrained_")
clip_model: CLIPModel = CLIPModel.from_pretrained(pretrained_model_config.path)
processor = CLIPProcessor.from_pretrained(pretrained_model_config.path)
clip_model: CLIPModel = modelpool.load_clip_model("_pretrained_")
processor = modelpool.load_processor()

self.finetune_method = "full fine-tune"
if config.use_lora or config.use_l_lora:
Expand Down

0 comments on commit 023faec

Please sign in to comment.