Skip to content

Commit

Permalink
add base angle
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanLee97 committed Jul 26, 2024
1 parent ddd97b6 commit dcb8c76
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
3 changes: 2 additions & 1 deletion angle_emb/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from peft.tuners.lora import LoraLayer

from .base import AngleBase
from .utils import logger
from .evaluation import CorrelationEvaluator

Expand Down Expand Up @@ -994,7 +995,7 @@ def __call__(self,
return loss


class AnglE:
class AnglE(AngleBase):
"""
AnglE. Everything is here👋
Expand Down
12 changes: 12 additions & 0 deletions angle_emb/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from abc import ABCMeta, abstractmethod


class AngleBase(metaclass=ABCMeta):

@abstractmethod
def encode(self):
raise NotImplementedError

@abstractmethod
def fit(self):
raise NotImplementedError
4 changes: 3 additions & 1 deletion angle_emb/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)
from scipy.stats import pearsonr, spearmanr

from .base import AngleBase


class CorrelationEvaluator(object):
def __init__(
Expand All @@ -27,7 +29,7 @@ def __init__(
self.labels = labels
self.batch_size = batch_size

def __call__(self, model, **kwargs) -> dict:
def __call__(self, model: AngleBase, **kwargs) -> dict:
""" Evaluate the model on the given dataset.
:param model: AnglE, the model to evaluate.
Expand Down

0 comments on commit dcb8c76

Please sign in to comment.