
Commit

Update chatglm3.py
shadowcz007 committed Dec 18, 2023
1 parent 598b27c commit 8fbf8ec
Showing 1 changed file with 10 additions and 7 deletions.
chatglm3.py (17 changes: 10 additions & 7 deletions)
@@ -143,7 +143,6 @@ class ChatCompletionResponse(BaseModel):
 
 
 
-
 @app.on_event("startup")
 async def startup_event():
     global pipeline
@@ -158,7 +157,9 @@ async def startup_event():
         "role":"system", "content":CHAT_SYSTEM_PROMPT
     })
     messages_with_system += messages
+    print("--------")
     print(messages_with_system)
+    print("--------")
     res=pipeline.chat(messages_with_system,max_length=2048,
                 max_context_length=2048,
                 do_sample=0.8 > 0,
@@ -170,7 +171,8 @@ async def startup_event():
                 stream=False,)
 
     print(res)
-    logging.info("End Loading chatglm model")
+    print("--------")
+    print("End Loading chatglm model")
 
 
 
@@ -208,9 +210,9 @@ async def stream_chat_event_publisher(history, body):
             await asyncio.sleep(0)  # yield control back to event loop for cancellation check
             output += chunk.choices[0].delta.content or ""
             yield chunk.model_dump_json(exclude_unset=True)
-        logging.info(f'prompt: "{history[-1]}", stream response: "{output}"')
+        print(f'prompt: "{history[-1]}", stream response: "{output}"')
     except asyncio.CancelledError as e:
-        logging.info(f'prompt: "{history[-1]}", stream response (partial): "{output}"')
+        print(f'prompt: "{history[-1]}", stream response (partial): "{output}"')
         raise e


@@ -221,7 +223,8 @@ async def create_chat_completion(body: ChatCompletionRequest) -> ChatCompletionResponse:
         raise HTTPException(status.HTTP_400_BAD_REQUEST, "empty messages")
 
     messages = [chatglm_cpp.ChatMessage(role=msg.role, content=msg.content) for msg in body.messages]
-
+    print('Message length',len(messages))
+    print('------')
     if body.stream:
         generator = stream_chat_event_publisher(messages, body)
         return EventSourceResponse(generator)
@@ -235,7 +238,7 @@ async def create_chat_completion(body: ChatCompletionRequest) -> ChatCompletionResponse:
         top_p=body.top_p,
         temperature=body.temperature,
     )
-    logging.info(f'prompt: "{messages[-1].content}", sync response: "{output.content}"')
+    print(f'prompt: "{messages[-1].content}", sync response: "{output.content}"')
     prompt_tokens = len(pipeline.tokenizer.encode_messages(messages, max_context_length))
     completion_tokens = len(pipeline.tokenizer.encode(output.content, body.max_tokens))

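The synchronous branch above returns an OpenAI-style ChatCompletionResponse. As a quick smoke test of the changed logging, a minimal client call, assuming the handler is mounted at the usual OpenAI-compatible route /v1/chat/completions and the server listens on the port chosen by start() (both the route and port 8000 are assumptions, not shown in this diff):

import requests  # third-party HTTP client

# Hypothetical smoke test: POST one message and print the reply.
# Route and port are assumptions; adjust to the running server.
resp = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "hello"}],
        "stream": False,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])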
@@ -279,7 +282,7 @@ def start():
     # Example usage
     end_port = 9000
     available_port = find_available_port(port, end_port)
-    print("Available port:", available_port)
+    print("##Available port:", available_port)
 
     uvicorn.run(app, host=settings.host, port=available_port)

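The hunk above calls find_available_port without showing its definition, which lies outside the diff. A minimal sketch of what such a helper typically looks like, assuming a plain socket bind probe over the port range (the body below is an assumption, not the file's actual implementation):

import socket

def find_available_port(start_port: int, end_port: int) -> int:
    # Probe each candidate by trying to bind it; the first successful
    # bind means the port is free. Assumed implementation, for
    # illustration only.
    for candidate in range(start_port, end_port):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("0.0.0.0", candidate))
                return candidate
            except OSError:
                continue
    raise RuntimeError(f"no available port in [{start_port}, {end_port})")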
