feature:chat api

yuehuazhang committed Jan 17, 2025
1 parent 7c0760d commit 63f08b7
Showing 5 changed files with 45 additions and 420 deletions.
56 changes: 0 additions & 56 deletions chatchat-server/chatchat/server/agent/graphs_factory/testinvoke.py

This file was deleted.

34 changes: 19 additions & 15 deletions chatchat-server/chatchat/server/api_server/api_schemas.py
@@ -30,26 +30,30 @@ class Config:
         extra = "allow"


-class OpenAIChatInput(OpenAIBaseInput):
+class AgentChatInput(BaseModel):
     messages: List[ChatCompletionMessageParam]
     model: str = get_default_llm()
-    frequency_penalty: Optional[float] = None
-    function_call: Optional[completion_create_params.FunctionCall] = None
-    functions: List[completion_create_params.Function] = None
-    logit_bias: Optional[Dict[str, int]] = None
-    logprobs: Optional[bool] = None
-    max_tokens: Optional[int] = None
-    n: Optional[int] = None
-    presence_penalty: Optional[float] = None
-    response_format: completion_create_params.ResponseFormat = None
-    seed: Optional[int] = None
-    stop: Union[Optional[str], List[str]] = None
-    stream: Optional[bool] = None
+    graph: str
+    thread_id: int
     temperature: Optional[float] = Settings.model_settings.TEMPERATURE
+    max_completion_tokens: Optional[int] = None
     tool_choice: Optional[Union[ChatCompletionToolChoiceOptionParam, str]] = None
     tools: List[Union[ChatCompletionToolParam, str]] = None
-    top_logprobs: Optional[int] = None
-    top_p: Optional[float] = None
+    stream: Optional[bool] = True
+    stream_method: Optional[Literal["streamlog", "node", "invoke"]] = "streamlog"
+    # frequency_penalty: Optional[float] = None
+    # function_call: Optional[completion_create_params.FunctionCall] = None
+    # functions: List[completion_create_params.Function] = None
+    # logit_bias: Optional[Dict[str, int]] = None
+    # logprobs: Optional[bool] = None
+    # max_tokens: Optional[int] = None
+    # n: Optional[int] = None
+    # presence_penalty: Optional[float] = None
+    # response_format: completion_create_params.ResponseFormat = None
+    # seed: Optional[int] = None
+    # stop: Union[Optional[str], List[str]] = None
+    # top_logprobs: Optional[int] = None
+    # top_p: Optional[float] = None


 class OpenAIEmbeddingsInput(OpenAIBaseInput):
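For orientation, a minimal request body that satisfies the new AgentChatInput schema might look like the sketch below. This is only my reading of the diff: the exact semantics of graph (which compiled graph to run) and thread_id (conversation threading) are inferred, and the graph name is hypothetical.

# Hypothetical example payload for AgentChatInput; field meanings inferred
# from this diff, and "chatbot" is a made-up graph name.
payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "model": "hunyuan-turbo",      # optional; falls back to get_default_llm()
    "graph": "chatbot",            # required; hypothetical name
    "thread_id": 1,                # required
    "stream": True,                # now defaults to True
    "stream_method": "streamlog",  # one of "streamlog" | "node" | "invoke"
}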
@@ -1,39 +1,36 @@
 import asyncio
 import json
-import logging
-from typing import TypedDict, Annotated

-from langgraph.graph.state import CompiledStateGraph
-from langgraph.prebuilt import tools_condition, ToolNode
-from sse_starlette.sse import EventSourceResponse
+from typing import Annotated
+from typing_extensions import TypedDict
+import rich
+from fastapi import APIRouter
+from langgraph.graph import add_messages
+from langgraph.graph.state import CompiledStateGraph, StateGraph
+from langgraph.prebuilt import ToolNode, tools_condition
+from sse_starlette import EventSourceResponse

-from fastapi import FastAPI, Request
-from langgraph.graph import StateGraph
-from langgraph.graph.message import add_messages
+from chatchat.server.agent.tools_factory import search_internet, search_youtube
+from chatchat.server.api_server.api_schemas import AgentChatInput
+from chatchat.server.utils import create_agent_models
+from chatchat.utils import build_logger

-from chatchat.server.agent.tools_factory import search_internet
-from chatchat.server.utils import create_agent_models, add_tools_if_not_exists

-app = FastAPI()
-logger = logging.getLogger("uvicorn.error")
+logger = build_logger()


-class ClientDisconnectException(Exception):
-    pass
+chat_router = APIRouter(prefix="/v1", tags=["Agent 对话接口"])


 def get_chatbot() -> CompiledStateGraph:
     class State(TypedDict):
         messages: Annotated[list, add_messages]

     llm = create_agent_models(configs=None,
-                              model="qwen2.5-instruct",
+                              model="hunyuan-turbo",
                               max_tokens=None,
                               temperature=0,
                               stream=True)

-    tools = add_tools_if_not_exists(tools_provides=[], tools_need_append=[search_internet])
+    tools = [search_internet, search_youtube]
     llm_with_tools = llm.bind_tools(tools)

     def chatbot(state: State):
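The body of get_chatbot between this hunk and the next is collapsed in the diff view. Judging from the imports above (StateGraph, ToolNode, tools_condition), the hidden wiring is presumably the standard LangGraph tool-calling loop; the sketch below is an assumption, not the commit's actual code.

# Hypothetical reconstruction of the collapsed section -- NOT from the commit.
def chatbot(state: State):
    # Let the tool-bound LLM respond, appending its message to state.
    return {"messages": [llm_with_tools.invoke(state["messages"])]}

builder = StateGraph(State)
builder.add_node("chatbot", chatbot)
builder.add_node("tools", ToolNode(tools=tools))
builder.add_conditional_edges("chatbot", tools_condition)  # branch to tools on tool calls
builder.add_edge("tools", "chatbot")
builder.set_entry_point("chatbot")
graph = builder.compile()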
@@ -57,16 +54,17 @@ def chatbot(state: State):
     return graph


-@app.post("/stream")
-async def openai_stream_output(request: Request):
+@chat_router.post("/chat/completions")
+async def openai_stream_output(
+    body: AgentChatInput
+):
+    rich.print(body)

     async def generator():
         graph = get_chatbot()
-        inputs = {"role": "user", "content": "Please introduce Trump based on the Internet search results."}
         try:
-            async for event in graph.astream(input={"messages": inputs}, stream_mode="updates"):
-                disconnected = await request.is_disconnected()
-                if disconnected:
-                    raise ClientDisconnectException("Client disconnected")
+            # async for event in graph.astream(input={"messages": inputs}, stream_mode="updates"):
+            async for event in graph.astream(input={"messages": body.messages}, stream_mode="updates"):
                 yield str(event)
         except asyncio.exceptions.CancelledError:
             logger.warning("Streaming progress has been interrupted by user.")
Expand All @@ -77,7 +75,3 @@ async def generator():
return

return EventSourceResponse(generator())

if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="127.0.0.1", port=8000)
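A sketch of how a client might consume the new SSE endpoint. The host and port are assumptions (7861 is chatchat's usual API port), as is the graph name; note the handler as written streams raw graph events rather than OpenAI-style chunks.

# Client-side sketch; host/port and graph name are assumptions.
import asyncio
import httpx

async def main():
    url = "http://127.0.0.1:7861/v1/chat/completions"
    payload = {
        "messages": [{"role": "user", "content": "Introduce yourself."}],
        "graph": "chatbot",  # hypothetical graph name
        "thread_id": 1,
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json=payload) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("data:"):
                    print(line[len("data:"):].strip())

asyncio.run(main())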