Skip to content

Commit

Permalink
Merge pull request #973 from roboflow/use-precomputed-owl-embeddings
Browse files Browse the repository at this point in the history
Use precomputed owl embeddings
  • Loading branch information
grzegorz-roboflow authored Jan 24, 2025
2 parents de69fbd + c191b5f commit b6412f2
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
27 changes: 24 additions & 3 deletions inference/models/owlv2/owlv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
load_image_rgb,
)

CPU_IMAGE_EMBED_CACHE_SIZE = 10000

# TYPES
Hash = NewType("Hash", str)
PosNegKey = Literal["positive", "negative"]
Expand Down Expand Up @@ -319,7 +321,9 @@ def reset_cache(self):
# each entry should be on the order of 300*4KB, so 1000 is 400MB of CUDA memory
self.image_embed_cache = LimitedSizeDict(size_limit=OWLV2_IMAGE_CACHE_SIZE)
# CPU-side cache is bounded as well (CPU_IMAGE_EMBED_CACHE_SIZE entries) so host memory cannot grow without limit
self.cpu_image_embed_cache = dict()
self.cpu_image_embed_cache = LimitedSizeDict(
size_limit=CPU_IMAGE_EMBED_CACHE_SIZE
)
# each entry should be on the order of 10 bytes, so 1000 is 10KB
self.image_size_cache = LimitedSizeDict(size_limit=OWLV2_IMAGE_CACHE_SIZE)
# entry size will vary depending on the number of samples, but 10 should be safe
Expand Down Expand Up @@ -693,9 +697,24 @@ def serialize_training_data(
hf_id: str = f"google/{OWLV2_VERSION_ID}",
iou_threshold: float = 0.3,
save_dir: str = os.path.join(MODEL_CACHE_DIR, "owl-v2-serialized-data"),
previous_embeddings_file: str = None,
):
roboflow_id = hf_id.replace("google/", "owlv2/")
owlv2 = OwlV2(model_id=roboflow_id)
if previous_embeddings_file is not None:
if DEVICE == "cpu":
model_data = torch.load(previous_embeddings_file, map_location="cpu")
else:
model_data = torch.load(previous_embeddings_file)
class_names = model_data["class_names"]
train_data_dict = model_data["train_data_dict"]
huggingface_id = model_data["huggingface_id"]
roboflow_id = model_data["roboflow_id"]
# each model can have its own OwlV2 instance because we use a singleton
owlv2 = OwlV2(model_id=roboflow_id)
owlv2.cpu_image_embed_cache = model_data["image_embeds"]
else:
owlv2 = OwlV2(model_id=roboflow_id)

train_data_dict, image_embeds = owlv2.make_class_embeddings_dict(
training_data, iou_threshold, return_image_embeds=True
)
Expand Down Expand Up @@ -826,7 +845,9 @@ def draw_predictions(
def save_small_model_without_image_embeds(
self, save_dir: str = os.path.join(MODEL_CACHE_DIR, "owl-v2-serialized-data")
):
self.owlv2.cpu_image_embed_cache = dict()
self.owlv2.cpu_image_embed_cache = LimitedSizeDict(
size_limit=CPU_IMAGE_EMBED_CACHE_SIZE
)
return self.save_model(
self.huggingface_id,
self.roboflow_id,
Expand Down
2 changes: 1 addition & 1 deletion requirements/_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pandas>=2.0.0,<2.3.0
paho-mqtt~=1.6.1
pytest>=8.0.0,<9.0.0 # this is not a joke: the sam2 fork we are using depends on pytest, yet
# does not declare the dependency: https://github.com/SauravMaheshkar/samv2/blob/main/sam2/utils/download.py
tokenizers>=0.19.0,<=0.20.3
tokenizers>=0.19.0,<=0.21
slack-sdk~=3.33.4
twilio~=9.3.7
httpx>=0.25.1,<0.28.0 # must be pinned, as a breaking change in 0.28.0 causes Anthropic to fail
Expand Down

0 comments on commit b6412f2

Please sign in to comment.