Merge pull request #1049 from roboflow/revert-1046-revert-1026-lean/singleton-owlv2-model-compile

Add Background Compilation for OWLv2 Vision Model
PawelPeczek-Roboflow authored Feb 24, 2025
2 parents ae8ed64 + cc89580 commit f6289dd
Showing 3 changed files with 75 additions and 6 deletions.
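
In outline, the change moves torch.compile off the request path: the OWLv2 vision model is served eagerly while a daemon thread compiles it in the background and rebinds the reference once compilation finishes. A minimal sketch of that pattern (illustrative only; the holder dict and names below are placeholders, not the repository's API):

import threading

import torch


def compile_in_background(holder: dict, key: str = "module") -> None:
    """Rebind holder[key] to its torch.compile-d version off the main thread."""

    def _worker() -> None:
        # torch.compile returns a wrapped module; rebinding the dict slot is
        # atomic under the GIL, so readers see either the eager or compiled module.
        holder[key] = torch.compile(holder[key])

    # daemon=True so a still-running compilation never blocks process exit
    threading.Thread(target=_worker, daemon=True).start()


holder = {"module": torch.nn.Linear(4, 4).eval()}
compile_in_background(holder)              # returns immediately
y = holder["module"](torch.randn(1, 4))    # served eagerly until the swap lands
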
2 changes: 0 additions & 2 deletions inference/core/env.py
@@ -95,8 +95,6 @@
 # OWLv2 CPU image cache size, default is 1000
 OWLV2_CPU_IMAGE_CACHE_SIZE = int(os.getenv("OWLV2_CPU_IMAGE_CACHE_SIZE", 1000))
 
-COMPILE_OWLV2_MODEL = str2bool(os.getenv("COMPILE_OWLV2_MODEL", "True"))
-
 # Maximum batch size for GAZE, default is 8
 GAZE_MAX_BATCH_SIZE = int(os.getenv("GAZE_MAX_BATCH_SIZE", 8))
57 changes: 53 additions & 4 deletions inference/models/owlv2/owlv2.py
@@ -2,6 +2,7 @@
 import hashlib
 import os
 import pickle
+import threading
 import weakref
 from collections import defaultdict
 from typing import Any, Dict, List, Literal, NewType, Optional, Tuple, Union
@@ -13,6 +14,7 @@
 from transformers import Owlv2ForObjectDetection, Owlv2Processor
 from transformers.models.owlv2.modeling_owlv2 import box_iou
 
+from inference.core import logger
 from inference.core.cache.model_artifacts import save_bytes_in_cache
 from inference.core.entities.requests.inference import ObjectDetectionInferenceRequest
 from inference.core.entities.responses.inference import (
@@ -21,7 +23,6 @@
     ObjectDetectionPrediction,
 )
 from inference.core.env import (
-    COMPILE_OWLV2_MODEL,
     DEVICE,
     MAX_DETECTIONS,
     MODEL_CACHE_DIR,
@@ -87,22 +88,58 @@ def _check_size_limit(self):
             self.popitem(last=False)
 
 
+class OWLv2ModelManager:
+    _instances = {}
+    _lock = threading.Lock()
+
+    def __new__(cls, vision_model, huggingface_id: str):
+        if huggingface_id not in cls._instances:
+            with cls._lock:
+                if huggingface_id not in cls._instances:
+                    instance = super().__new__(cls)
+                    instance._vision_model = vision_model
+                    instance._start_compilation()
+                    cls._instances[huggingface_id] = instance
+        return cls._instances[huggingface_id]
+
+    def get_vision_model(self):
+        if self._vision_model is None:
+            raise ValueError("No vision_model has been initialized")
+        return self._vision_model
+
+    def _start_compilation(self):
+        logger.info("Compiling OWLv2 model")
+        compilation_thread = threading.Thread(target=self._compile_model)
+        compilation_thread.daemon = True
+        compilation_thread.start()
+
+    def _compile_model(self):
+        logger.info("Compiling OWLv2 model in thread")
+        self._vision_model = torch.compile(self._vision_model)
+        logger.info("OWLv2 model compiled in thread")
+
+
 class Owlv2Singleton:
     _instances = weakref.WeakValueDictionary()
 
     def __new__(cls, huggingface_id: str):
         if huggingface_id not in cls._instances:
+            logger.info(f"Creating new OWLv2 instance for {huggingface_id}")
             instance = super().__new__(cls)
             instance.huggingface_id = huggingface_id
+            # Load model directly in the instance
+            logger.info(f"Loading OWLv2 model from {huggingface_id}")
             model = (
                 Owlv2ForObjectDetection.from_pretrained(huggingface_id)
                 .eval()
                 .to(DEVICE)
             )
-            torch._dynamo.config.suppress_errors = True
-            if COMPILE_OWLV2_MODEL:
-                model.owlv2.vision_model = torch.compile(model.owlv2.vision_model)
+            logger.info(f"OWLv2 model loaded from {huggingface_id}")
+            owlv2_model_manager = OWLv2ModelManager(
+                vision_model=model.owlv2.vision_model, huggingface_id=huggingface_id
+            )
+            model.owlv2.vision_model = owlv2_model_manager.get_vision_model()
             instance.model = model
             cls._instances[huggingface_id] = instance
         return cls._instances[huggingface_id]
@@ -772,6 +809,7 @@ def download_model_artefacts_from_s3(self):
         raise NotImplementedError("Owlv2 not currently supported on hosted inference")
 
     def download_model_artifacts_from_roboflow_api(self):
+        logger.info("Downloading OWLv2 model artifacts")
         if self.version_id is not None:
             api_data = get_roboflow_model_data(
                 api_key=self.api_key,
@@ -784,10 +822,12 @@ def download_model_artifacts_from_roboflow_api(self):
                 raise ModelArtefactError(
                     "Could not find `model` key in roboflow API model description response."
                 )
+            logger.info(f"Downloading OWLv2 model weights from {api_data['model']}")
             model_weights_response = get_from_url(
                 api_data["model"], json_response=False
             )
         else:
+            logger.info(f"Getting OWLv2 model data for {self.endpoint}")
             api_data = get_roboflow_instant_model_data(
                 api_key=self.api_key,
                 model_id=self.endpoint,
@@ -800,6 +840,9 @@ def download_model_artifacts_from_roboflow_api(self):
                 raise ModelArtefactError(
                     "Could not find `modelFiles` key or `modelFiles`.`owlv2` or `modelFiles`.`owlv2`.`model` key in roboflow API model description response."
                 )
+            logger.info(
+                f"Downloading OWLv2 model weights from {api_data['modelFiles']['owlv2']['model']}"
+            )
             model_weights_response = get_from_url(
                 api_data["modelFiles"]["owlv2"]["model"], json_response=False
             )
@@ -808,8 +851,10 @@ def download_model_artifacts_from_roboflow_api(self):
             file=self.weights_file,
             model_id=self.endpoint,
         )
+        logger.info("OWLv2 model weights saved to cache")
 
     def load_model_artifacts_from_cache(self):
+        logger.info("Loading OWLv2 model artifacts from cache")
         if DEVICE == "cpu":
             self.model_data = torch.load(
                 self.cache_file(self.weights_file), map_location="cpu"
@@ -823,6 +868,7 @@ def load_model_artifacts_from_cache(self):
         # each model can have its own OwlV2 instance because we use a singleton
         self.owlv2 = OwlV2(model_id=self.roboflow_id)
         self.owlv2.cpu_image_embed_cache = self.model_data["image_embeds"]
+        logger.info("OWLv2 model artifacts loaded from cache")
 
     weights_file_path = "weights.pt"
 
@@ -833,13 +879,16 @@ def weights_file(self):
     def infer(
         self, image, confidence: float = 0.99, iou_threshold: float = 0.3, **kwargs
     ):
-        return self.owlv2.infer_from_embedding_dict(
+        logger.info("Inferring OWLv2 model")
+        result = self.owlv2.infer_from_embedding_dict(
             image,
             self.train_data_dict,
             confidence=confidence,
             iou_threshold=iou_threshold,
             **kwargs,
         )
+        logger.info("OWLv2 model inference complete")
+        return result
 
     def draw_predictions(
         self,
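
To illustrate how the two classes above cooperate, a hypothetical usage sketch (not code from this commit; the checkpoint id is the standard OWLv2 HuggingFace id):

hf_id = "google/owlv2-base-patch16-ensemble"

singleton_a = Owlv2Singleton(hf_id)  # loads the model, starts background compile
singleton_b = Owlv2Singleton(hf_id)  # same huggingface_id -> same cached instance
assert singleton_a is singleton_b

# One manager exists per huggingface_id; constructing it again returns the
# instance created inside Owlv2Singleton.__new__ above.
manager = OWLv2ModelManager(
    vision_model=singleton_a.model.owlv2.vision_model, huggingface_id=hf_id
)
# get_vision_model() returns whichever module is currently stored: the eager
# one at first, then whatever the daemon thread has swapped in.
vision_model = manager.get_vision_model()
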
22 changes: 22 additions & 0 deletions tests/inference/models_predictions_tests/test_owlv2.py
@@ -14,6 +14,7 @@
     OwlV2,
     Owlv2Singleton,
     SerializedOwlV2,
+    OWLv2ModelManager,
 )


@@ -427,5 +428,26 @@ def test_owlv2_model_unloaded_when_garbage_collected():
     assert len(Owlv2Singleton._instances) == 0
 
 
+@pytest.mark.slow
+def test_owlv2_model_manager_singleton():
+    owlv2 = OwlV2(model_id=f"owlv2/{OWLV2_VERSION_ID}")
+
+    manager1 = OWLv2ModelManager(
+        vision_model=owlv2.model.owlv2.vision_model,
+        huggingface_id=f"google/{OWLV2_VERSION_ID}",
+    )
+    manager2 = OWLv2ModelManager(
+        vision_model=owlv2.model.owlv2.vision_model,
+        huggingface_id=f"google/{OWLV2_VERSION_ID}",
+    )
+
+    assert manager1 is manager2
+    assert manager1._vision_model is manager2._vision_model
+    assert len(OWLv2ModelManager._instances) == 1
+
+
 if __name__ == "__main__":
     test_owlv2()
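
The new test is marked slow; assuming the project's usual pytest setup, it can be run in isolation with pytest -m slow -k test_owlv2_model_manager_singleton tests/inference/models_predictions_tests/test_owlv2.py.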
