From 631acd0fdf88e388cafdc07bb5471670ed491809 Mon Sep 17 00:00:00 2001 From: mrfakename Date: Sat, 19 Aug 2023 15:32:47 -0700 Subject: [PATCH] Add MPS support --- basaran/model.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/basaran/model.py b/basaran/model.py index 470a449d..2829becf 100644 --- a/basaran/model.py +++ b/basaran/model.py @@ -23,9 +23,14 @@ class StreamModel: def __init__(self, model, tokenizer): super().__init__() - self.model = model self.tokenizer = tokenizer - self.device = "cuda" if torch.cuda.is_available() else "cpu" + if torch.backends.mps.is_available(): + self.device = 'mps' + elif torch.cuda.is_available(): + self.device = 'cuda' + else: + self.device = 'cpu' + self.model = model.to(self.device) def __call__( self,