diff --git a/basaran/model.py b/basaran/model.py index 470a449d..2829becf 100644 --- a/basaran/model.py +++ b/basaran/model.py @@ -23,9 +23,14 @@ class StreamModel: def __init__(self, model, tokenizer): super().__init__() - self.model = model self.tokenizer = tokenizer - self.device = "cuda" if torch.cuda.is_available() else "cpu" + if torch.backends.mps.is_available(): + self.device = 'mps' + elif torch.cuda.is_available(): + self.device = 'cuda' + else: + self.device = 'cpu' + self.model = model.to(self.device) def __call__( self,