Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix sequence2txt error and usage total token issue #2961

Merged
merged 2 commits on Oct 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/apps/conversation_app.py
Original file line number Diff line number Diff line change
@@ -26,7 +26,6 @@
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
from api.settings import RetCode, retrievaler
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from graphrag.mind_map_extractor import MindMapExtractor
@@ -187,6 +186,7 @@ def stream():
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e:
traceback.print_exc()
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
3 changes: 2 additions & 1 deletion api/db/services/llm_service.py
Original file line number Diff line number Diff line change
@@ -133,7 +133,8 @@ def model_instance(cls, tenant_id, llm_type,
if model_config["llm_factory"] not in Seq2txtModel:
return
return Seq2txtModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"], lang,
key=model_config["api_key"], model_name=model_config["llm_name"],
lang=lang,
base_url=model_config["api_base"]
)
if llm_type == LLMType.TTS:
2 changes: 2 additions & 0 deletions api/utils/file_utils.py
Original file line number Diff line number Diff line change
@@ -197,6 +197,7 @@ def thumbnail_img(filename, blob):
pass
return None


def thumbnail(filename, blob):
img = thumbnail_img(filename, blob)
if img is not None:
@@ -205,6 +206,7 @@ def thumbnail(filename, blob):
else:
return ''


def traversal_files(base):
for root, ds, fs in os.walk(base):
for f in fs:
18 changes: 10 additions & 8 deletions rag/llm/chat_model.py
Original file line number Diff line number Diff line change
@@ -67,14 +67,16 @@ def chat_streamly(self, system, history, gen_conf):
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
total_tokens = (
(
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
)
if not hasattr(resp, "usage") or not resp.usage
else resp.usage.get("total_tokens", total_tokens)
)
total_tokens += 1
if not hasattr(resp, "usage") or not resp.usage:
total_tokens = (
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
)
elif isinstance(resp.usage, dict):
total_tokens = resp.usage.get("total_tokens", total_tokens)
else: total_tokens = resp.usage.total_tokens

if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
2 changes: 1 addition & 1 deletion rag/llm/sequence2txt_model.py
Original file line number Diff line number Diff line change
@@ -87,7 +87,7 @@ def __init__(self, key, model_name, lang="Chinese", **kwargs):


class XinferenceSeq2txt(Base):
def __init__(self,key,model_name="whisper-small",**kwargs):
def __init__(self, key, model_name="whisper-small", **kwargs):
self.base_url = kwargs.get('base_url', None)
self.model_name = model_name
self.key = key