Skip to content

Commit

Permalink
Add routes for care scribe (#493)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ashesh3 authored Mar 28, 2024
1 parent f8bb89c commit bde9e60
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 14 deletions.
7 changes: 7 additions & 0 deletions ayushma/models/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ class STTEngine(IntegerChoices):
GOOGLE = 2
SELF_HOSTED = 3

@classmethod
def get_id_from_name(cls, name):
for member in cls:
if member.name.lower() == name.lower():
return member.value
return None


class TTSEngine(IntegerChoices):
OPENAI = (1, "openai")
Expand Down
2 changes: 1 addition & 1 deletion ayushma/utils/converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def converse_api(
if not open_ai_key:
open_ai_key = (
request.headers.get("OpenAI-Key")
or (chat.project.open_ai_key)
or (chat.project and chat.project.openai_key)
or (user.allow_key and settings.OPENAI_API_KEY)
)
noonce = request.data.get("noonce")
Expand Down
4 changes: 2 additions & 2 deletions ayushma/utils/openaiapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def converse(

prompt = chat.prompt or (chat.project and chat.project.prompt)

if documents or chat.project.model == ModelType.GPT_4_VISUAL:
if documents or (chat.project and chat.project.model == ModelType.GPT_4_VISUAL):
prompt = "Image Capabilities: Enabled\n" + prompt

# excluding the latest query since it is not a history
Expand All @@ -327,7 +327,7 @@ def converse(
elif message.messageType == ChatMessageType.AYUSHMA:
chat_history.append(AIMessage(content=f"Ayushma: {message.message}"))

tts_engine = chat.project.tts_engine
tts_engine = chat.project and chat.project.tts_engine

if not stream:
lang_chain_helper = LangChainHelper(
Expand Down
75 changes: 64 additions & 11 deletions ayushma/views/orphan.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
import time
from types import SimpleNamespace

import openai
from django.conf import settings
from django.http import StreamingHttpResponse
from drf_spectacular.utils import extend_schema, extend_schema_view, inline_serializer
from drf_spectacular.utils import extend_schema
from rest_framework import permissions, status
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.mixins import CreateModelMixin, ListModelMixin, RetrieveModelMixin
from rest_framework.parsers import MultiPartParser
from rest_framework.response import Response
from rest_framework.serializers import CharField, IntegerField

from ayushma.models import APIKey, Chat, ChatMessage, Project
from ayushma.models import APIKey, Chat
from ayushma.models.enums import STTEngine
from ayushma.serializers import ChatDetailSerializer, ConverseSerializer
from ayushma.utils.converse import converse_api
from ayushma.utils.language_helpers import translate_text
from ayushma.utils.openaiapi import converse
from ayushma.utils.speech_to_text import speech_to_text
from utils.views.base import BaseModelViewSet
from utils.views.mixins import PartialUpdateModelMixin

from .chat import ChatViewSet
PREDEFINED_CONFIGS = {
"ai_form_fill": {
"model": "gpt-4-turbo-preview",
"response_format": "json_object",
"max_tokens": 4096,
"temperature": 0,
},
}


class Struct:
Expand All @@ -34,7 +36,7 @@ def has_permission(self, request, view):
if request.headers.get("X-API-KEY"):
api_key = request.headers.get("X-API-KEY")
try:
key = APIKey.objects.get(key=api_key)
APIKey.objects.get(key=api_key)
return True
except APIKey.DoesNotExist:
return False
Expand Down Expand Up @@ -82,6 +84,7 @@ def create(self, request, *args, **kwargs):
response = converse_api(
request=self.request,
chat=chat,
is_thread=False,
)
return Response(
{
Expand Down Expand Up @@ -110,7 +113,57 @@ def converse(self, *args, **kwarg):
response = converse_api(
request=self.request,
chat=chat,
is_thread=False,
)
return response
except Exception as e:
return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST)

@action(detail=False, methods=["post"])
def transcribe(self, request, *args, **kwargs):
language = request.data.get("language") or "en"
audio = request.data.get("audio")
engine = request.data.get("engine")
if not audio or not engine:
raise ValidationError("audio and engine are required")
try:
engine_id = STTEngine.get_id_from_name(engine)
transcript = speech_to_text(engine_id, audio, language + "-IN")
return Response({"transcript": transcript})
except Exception as e:
return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST)

@action(detail=False, methods=["post"])
def completion(self, request, *args, **kwargs):
task = request.data.get("task")

if not task:
raise ValidationError("task is required")

config = PREDEFINED_CONFIGS.get(task)

if not config:
raise ValidationError("Invalid task")

model = config.get("model")
response_format = config.get("response_format")
max_tokens = config.get("max_tokens")
temperature = config.get("temperature")

messages = request.data.get("messages")
if not messages or len(messages) == 0:
raise ValidationError("audio and engine are required")
try:
client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
completion = client.chat.completions.create(
model=model,
temperature=temperature,
response_format={"type": response_format},
messages=messages,
max_tokens=max_tokens,
)

ai_response = completion.choices[0].message.content
return Response({"response": ai_response})
except Exception as e:
return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST)

0 comments on commit bde9e60

Please sign in to comment.