Skip to content

Commit

Permalink
feat: download superpoint weights from github
Browse files Browse the repository at this point in the history
  • Loading branch information
ccrutchf committed Sep 15, 2024
1 parent 0b82100 commit e3eda9e
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions pyfishsensedev/library/homography/models/superpoint_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Authors: Rémi Pautrat, Paul-Edouard Sarlin
"""
from collections import OrderedDict
from pathlib import Path
from types import SimpleNamespace

import numpy as np
Expand All @@ -14,6 +15,7 @@
from kornia.color import rgb_to_grayscale

from pyfishsensedev.library.homography.utils import Extractor
from pyfishsensedev.library.online_ml_model import OnlineMLModel


def sample_descriptors(keypoints, descriptors, s: int = 8):
Expand Down Expand Up @@ -74,7 +76,7 @@ def __init__(self, c_in, c_out, kernel_size, relu=True):
)


class SuperPoint(Extractor):
class SuperPoint(Extractor, OnlineMLModel):
default_conf = {
"nms_radius": 4,
"max_num_keypoints": None,
Expand Down Expand Up @@ -112,9 +114,17 @@ def __init__(self, **conf):
)

# Load weights
weights_path = "weights/superpoint_v6_from_tf.pth"
weights_path = self.download_model()
self.load_state_dict(torch.load(weights_path))

@property
def _model_path(self) -> Path:
return self._model_cache_path / "superpoint_v6_from_tf.pth"

@property
def _model_url(self) -> str:
return "https://github.com/rpautrat/SuperPoint/raw/master/weights/superpoint_v6_from_tf.pth"

def forward(self, data):
image = data["image"]
if image.shape[1] == 3: # RGB to gray
Expand Down

0 comments on commit e3eda9e

Please sign in to comment.