Skip to content

Commit

Permalink
Add AOAI O1 Preview specific endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
ilmarinen committed Sep 25, 2024
1 parent 34c2fbc commit 52e89b5
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
3 changes: 2 additions & 1 deletion eureka_ml_insights/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
Phi3HF,
KeyBasedAuthentication,
EndpointModels,
RestEndpointModels
RestEndpointModels,
RestEndpointO1PreviewModelsAzure,
)

__all__ = [
Expand Down
48 changes: 48 additions & 0 deletions eureka_ml_insights/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,54 @@ def handle_request_error(self, e):
return None, False, False


@dataclass
class RestEndpointO1PreviewModelsAzure(EndpointModels):

do_sample: bool = True

def __post_init__(self):
self.bearer_token_provider = get_bearer_token_provider(AzureCliCredential(), "https://cognitiveservices.azure.com/.default")

def create_request(self, text_prompt, query_images=None, system_message=None):
data = {
"messages": [
{
"role": "user",
"content": text_prompt,
}
],
}
if system_message:
data["messages"]= [{"role": "system", "content": system_message}] + data["input_data"][
"input_string"
]
if query_images:
raise NotImplementedError("Images are not supported for GCR endpoints yet.")

body = str.encode(json.dumps(data))
# The azureml-model-deployment header will force the request to go to a specific deployment.
# Remove this header to have the request observe the endpoint traffic rules
headers = {
"Content-Type": "application/json",
"Authorization": ("Bearer " + self.bearer_token_provider()),
}

return urllib.request.Request(self.url, body, headers)

def get_response(self, request):
response = urllib.request.urlopen(request)
res = json.loads(response.read())
return res["choices"][0]["message"]["content"]

def handle_request_error(self, e):
if isinstance(e, urllib.error.HTTPError):
logging.info("The request failed with status code: " + str(e.code))
# Print the headers - they include the requert ID and the timestamp, which are useful for debugging.
logging.info(e.info())
logging.info(e.read().decode("utf8", "ignore"))
return None, False, False


@dataclass
class ServerlessAzureRestEndpointModels(EndpointModels, KeyBasedAuthentication):
"""This class can be used for serverless Azure model deployments."""
Expand Down

0 comments on commit 52e89b5

Please sign in to comment.