diff --git a/examples/llamaindex_rag/app.py b/examples/llamaindex_rag/app.py index c7a4ca2..cc51e9a 100644 --- a/examples/llamaindex_rag/app.py +++ b/examples/llamaindex_rag/app.py @@ -1,10 +1,12 @@ import os import sys +import uuid import logging import click import uvicorn import fastapi import asyncio +import contextvars from sqlalchemy import URL from fastapi.responses import StreamingResponse, HTMLResponse from fastapi.staticfiles import StaticFiles @@ -14,9 +16,31 @@ from llama_index.vector_stores.tidbvector import TiDBVectorStore from llama_index.readers.web import SimpleWebPageReader + +# Setup logging logging.basicConfig(stream=sys.stdout, level=logging.INFO) logger = logging.getLogger() +# Setup in-memory cache +class InMemoryCache: + def __init__(self): + self.cache = {} + + def set(self, key, value): + self.cache[key] = value + + def get(self, key): + return self.cache.get(key) + + def delete(self, key): + if key in self.cache: + del self.cache[key] + + def clear(self): + self.cache.clear() + +cache = InMemoryCache() + logger.info("Initializing TiDB Vector Store....") tidb_connection_url = URL( @@ -63,6 +87,17 @@ async def astreamer(response: llamaStreamingResponse): app = fastapi.FastAPI() templates = Jinja2Templates(directory="templates") +# Setup contextvars +request_id_contextvar = contextvars.ContextVar('request_id', default=None) + +@app.middleware("http") +async def add_request_id_header(request: fastapi.Request, call_next): + request_id = str(uuid.uuid4()) + request_id_contextvar.set(request_id) + response = await call_next(request) + response.headers["X-Request-ID"] = request_id + return response + @app.get('/', response_class=HTMLResponse) def index(request: fastapi.Request): @@ -72,9 +107,16 @@ def index(request: fastapi.Request): @app.get('/ask') async def ask(q: str): response = query_engine.query(q) + request_id = request_id_contextvar.get() + cache.set(request_id, vars(response)) return StreamingResponse(astreamer(response), media_type='text/event-stream') +@app.get('/getResponseMeta/{request_id}') +async def response(request_id: str): + return cache.get(request_id) + + @click.group(context_settings={'max_content_width': 150}) def cli(): pass @@ -83,7 +125,7 @@ def cli(): @cli.command() @click.option('--host', default='127.0.0.1', help="Host, default=127.0.0.1") @click.option('--port', default=3000, help="Port, default=3000") -@click.option('--reload', is_flag=True, help="Enable auto-reload") +@click.option('--reload', is_flag=True, default=True, help="Enable auto-reload") def runserver(host, port, reload): uvicorn.run( "__main__:app", host=host, port=port, reload=reload, diff --git a/examples/llamaindex_rag/templates/index.html b/examples/llamaindex_rag/templates/index.html index 636b0df..2b713d0 100644 --- a/examples/llamaindex_rag/templates/index.html +++ b/examples/llamaindex_rag/templates/index.html @@ -6,39 +6,63 @@