Skip to content

Commit

Permalink
added convo_starter
Browse files Browse the repository at this point in the history
  • Loading branch information
Gautam-Rajeev committed Nov 2, 2023
1 parent b879bb5 commit 481eadc
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/text_classification/convo_starter_orgbot/local/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Use an official Python runtime as a parent image
FROM python:3.9-slim

WORKDIR /app


#install requirements
COPY requirements.txt requirements.txt
RUN pip3 install -r requirements.txt

# Copy the rest of the application code to the working directory
COPY . /app/
EXPOSE 8000
# Set the entrypoint for the container
CMD ["hypercorn", "--bind", "0.0.0.0:8000", "api:app"]
21 changes: 21 additions & 0 deletions src/text_classification/convo_starter_orgbot/local/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
## Grievance classification:


### Purpose :
Model to classify grievances into 3 buckets :
- Label 0: 'General questions'
- Label 1: 'Starter: Hi, hello etc'



### Testing the model deployment :
To run for testing just the Hugging Face deployment for grievence recognition, you can follow the following steps :

- Git clone the repo
- Go to current folder location i.e. ``` cd /src/text_classification/flow_classification/local ```
- Create docker image file and test the api:
```
docker build -t testmodel .
docker run -p 8000:8000 testmodel
curl -X POST -H "Content-Type: application/json" -d '{"text": "Where is my money? "}' http://localhost:8000/
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .request import ModelRequest
from .request import Model
24 changes: 24 additions & 0 deletions src/text_classification/convo_starter_orgbot/local/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from model import Model
from request import ModelRequest
from quart import Quart, request
import aiohttp

app = Quart(__name__)

model = None

@app.before_serving
async def startup():
app.client = aiohttp.ClientSession()
global model
model = Model(app)

@app.route('/', methods=['POST'])
async def embed():
global model
data = await request.get_json()
req = ModelRequest(**data)
return await model.inference(req)

if __name__ == "__main__":
app.run()
24 changes: 24 additions & 0 deletions src/text_classification/convo_starter_orgbot/local/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
from request import ModelRequest

class Model():
def __new__(cls, context):
cls.context = context
if not hasattr(cls, 'instance'):
cls.instance = super(Model, cls).__new__(cls)
model_name = "GautamR/convo_beginner_orgbot"
cls.tokenizer = AutoTokenizer.from_pretrained(model_name)
cls.model = AutoModelForSequenceClassification.from_pretrained(model_name)
cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cls.model.to(cls.device)
return cls.instance


async def inference(self, request: ModelRequest):
inputs = self.tokenizer(request.text, return_tensors="pt")
inputs = {key: value.to(self.device) for key, value in inputs.items()}
with torch.no_grad():
logits = self.model(**inputs).logits
predicted_class_id = logits.argmax().item()
return self.model.config.id2label[predicted_class_id]
11 changes: 11 additions & 0 deletions src/text_classification/convo_starter_orgbot/local/request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import requests
import json


class ModelRequest():
def __init__(self, text):
self.text = text

def to_json(self):
return json.dumps(self, default=lambda o: o.__dict__,
sort_keys=True, indent=4)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
torch==2.0.1 --index-url https://download.pytorch.org/whl/cpu
transformers
quart
aiohttp

0 comments on commit 481eadc

Please sign in to comment.