@dataclass
class RestEndpointO1PreviewModelsAzure(EndpointModels):
    """Endpoint model that sends chat-style requests authenticated with an Azure bearer token.

    The request body is a JSON object with a ``messages`` list, and the reply
    is read from ``choices[0].message.content`` of the JSON response. The
    bearer token is obtained through ``AzureCliCredential`` for the
    ``https://cognitiveservices.azure.com/.default`` scope.

    NOTE(review): ``self.url`` is read in ``create_request`` but is not
    declared in this class; presumably it comes from ``EndpointModels`` --
    confirm.
    """

    # Not read anywhere inside this class; kept so existing configurations
    # that pass ``do_sample`` keep working.
    do_sample: bool = True

    def __post_init__(self):
        """Create the callable that yields a bearer token for each request."""
        self.bearer_token_provider = get_bearer_token_provider(
            AzureCliCredential(), "https://cognitiveservices.azure.com/.default"
        )

    def create_request(self, text_prompt, query_images=None, system_message=None):
        """Build the HTTP request for a single prompt.

        Args:
            text_prompt: Content of the user message.
            query_images: Must be empty/None; images are not supported.
            system_message: Optional content for a system message that is
                placed before the user message.

        Returns:
            A ``urllib.request.Request`` carrying the JSON body and the
            ``Content-Type`` / ``Authorization`` headers.

        Raises:
            NotImplementedError: If ``query_images`` is provided.
        """
        if query_images:
            raise NotImplementedError("Images are not supported for GCR endpoints yet.")

        messages = [{"role": "user", "content": text_prompt}]
        if system_message:
            # Prepend the system message to the user message. The previous
            # code read data["input_data"]["input_string"], a key that does
            # not exist in this payload, so any system message raised KeyError.
            messages.insert(0, {"role": "system", "content": system_message})

        body = json.dumps({"messages": messages}).encode()
        # A fresh token is requested from the provider for every request.
        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.bearer_token_provider(),
        }

        return urllib.request.Request(self.url, body, headers)

    def get_response(self, request):
        """Send ``request`` and return the text of the first choice.

        Args:
            request: The ``urllib.request.Request`` to send.

        Returns:
            The value at ``choices[0].message.content`` of the JSON response.
        """
        # The context manager closes the HTTP response even if reading or
        # JSON decoding fails.
        with urllib.request.urlopen(request) as response:
            res = json.loads(response.read())
        return res["choices"][0]["message"]["content"]

    def handle_request_error(self, e):
        """Log details of an HTTP failure and return ``(None, False, False)``.

        Args:
            e: The exception raised while sending the request. Only
                ``urllib.error.HTTPError`` instances are logged.

        Returns:
            The tuple ``(None, False, False)``. NOTE(review): the meaning of
            each position is defined by the caller, which is not visible
            here -- confirm against ``EndpointModels``.
        """
        if isinstance(e, urllib.error.HTTPError):
            logging.info("The request failed with status code: %s", e.code)
            # Log the headers - they include the request ID and the timestamp,
            # which are useful for debugging.
            logging.info(e.info())
            logging.info(e.read().decode("utf8", "ignore"))
        return None, False, False