feat(huixiangdou/service): support reverted-indexer
tpoisonooo committed Sep 19, 2024
1 parent bbbb6f2 commit fc9d646
Showing 10 changed files with 174 additions and 109 deletions.
83 changes: 45 additions & 38 deletions evaluation/end2end/main.py
@@ -3,11 +3,13 @@
import json
import asyncio
import pdb
import os
from typing import List
from rouge import Rouge
from loguru import logger

assistant = ParallelPipeline(work_dir='/home/khj/hxd-ci/workdir', config_path='/home/khj/hxd-ci/config.ini')
config_path = '/home/data/khj/workspace/huixiangdou/config.ini'
assistant = ParallelPipeline(work_dir='/home/data/khj/workspace/huixiangdou/workdir', config_path=config_path)

def format_refs(refs: List[str]):
refs_filter = list(set(refs))
@@ -31,49 +33,54 @@ async def run(query_text: str):
refs = sess.references
return sentence, refs

gts = []
dts = []

output_filepath = 'out.jsonl'

finished_query = []
with open(output_filepath) as fin:
json_str = ""
for line in fin:
json_str += line
if __name__ == "__main__":
gts = []
dts = []

if '}\n' == line:
print(json_str)
json_obj = json.loads(json_str)
finished_query.append(json_obj['query'].strip())
# hybrid llm serve
print('evaluate ParallelPipeline precision, first `python3 -m huixiangdou.service.llm_server_hybrid`, then prepare your qa pair in `qa.jsonl`.')
output_filepath = 'out.jsonl'

finished_query = []
if os.path.exists(output_filepath):
with open(output_filepath) as fin:
json_str = ""
for line in fin:
json_str += line

if '}\n' == line:
print(json_str)
json_obj = json.loads(json_str)
finished_query.append(json_obj['query'].strip())
json_str = ""

with open('evaluation/end2end/qa.jsonl') as fin:
for json_str in fin:
json_obj = json.loads(json_str)
query = json_obj['query'].strip()
if query in finished_query:
continue

gt = json_obj['resp']
gts.append(gt)
with open('evaluation/end2end/qa.jsonl') as fin:
for json_str in fin:
json_obj = json.loads(json_str)
query = json_obj['query'].strip()
if query in finished_query:
continue
gt = json_obj['resp']
gts.append(gt)

loop = asyncio.get_event_loop()
dt, refs = loop.run_until_complete(run(query_text=query))
dts.append(dt)
loop = asyncio.get_event_loop()
dt, refs = loop.run_until_complete(run(query_text=query))
dts.append(dt)

distance = assistant.retriever.embedder.distance(text1=gt, text2=dt).tolist()
distance = assistant.retriever.embedder.distance(text1=gt, text2=dt).tolist()

rouge = Rouge()
scores = rouge.get_scores(gt, dt)
json_obj['distance'] = distance
json_obj['rouge_scores'] = scores
json_obj['dt'] = dt
json_obj['dt_refs'] = refs
rouge = Rouge()
scores = rouge.get_scores(dt, gt)
json_obj['distance'] = distance
json_obj['rouge_scores'] = scores
json_obj['dt'] = dt
json_obj['dt_refs'] = refs

out_json_str = json.dumps(json_obj, ensure_ascii=False, indent=2)
logger.info(out_json_str)
out_json_str = json.dumps(json_obj, ensure_ascii=False, indent=2)
logger.info(out_json_str)

with open(output_filepath, 'a') as fout:
fout.write(out_json_str)
fout.write('\n')
with open(output_filepath, 'a') as fout:
fout.write(out_json_str)
fout.write('\n')
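A note on the scoring change above: the `rouge` package's get_scores expects the hypothesis first and the reference second, which is why the call now passes the pipeline answer dt before the ground truth gt. A minimal sketch of the call as used here, with made-up example strings:

from rouge import Rouge

rouge = Rouge()
# hypothesis (model answer) first, reference (ground truth) second
scores = rouge.get_scores('install mmpose with mim', 'mmpose can be installed via mim install mmpose')
# scores is a list with one dict per pair:
# [{'rouge-1': {'r': ..., 'p': ..., 'f': ...}, 'rouge-2': {...}, 'rouge-l': {...}}]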
6 changes: 2 additions & 4 deletions huixiangdou/main.py
@@ -13,7 +13,6 @@

from .service import ErrorCode, SerialPipeline, build_reply_text, start_llm_server


def parse_args():
"""Parse args."""
parser = argparse.ArgumentParser(description='SerialPipeline.')
@@ -60,7 +59,6 @@ def check_env(args):


def show(assistant, fe_config: dict):

queries = ['请问如何安装 mmpose ?', '请问明天天气如何?']
print(colored('Running some examples..', 'yellow'))
for query in queries:
@@ -142,7 +140,7 @@ def lark_group_recv_and_send(assistant, fe_config: dict):
code, reply, refs = str(sess.code), sess.response, sess.references
if code == ErrorCode.SUCCESS:
json_obj['reply'] = build_reply_text(reply=reply,
references=references)
references=refs)
error, msg_id = send_to_lark_group(
json_obj=json_obj,
app_id=lark_group_config['app_id'],
@@ -169,7 +167,7 @@ async def api(request):
for sess in assistant.generate(query=query, history=[], groupname=''):
pass
code, reply, refs = str(sess.code), sess.response, sess.references
reply_text = build_reply_text(reply=reply, references=references)
reply_text = build_reply_text(reply=reply, references=refs)

return web.json_response({'code': int(code), 'reply': reply_text})

1 change: 1 addition & 0 deletions huixiangdou/primitive/__init__.py
@@ -15,3 +15,4 @@
nested_split_markdown, split_python_code)
from .limitter import RPM, TPM
from .bm250kapi import BM25Okapi
from .entity import NamedEntity2Chunk
46 changes: 35 additions & 11 deletions huixiangdou/primitive/entity.py
@@ -5,35 +5,54 @@

class NamedEntity2Chunk:
"""Save the relationship between Named Entity and Chunk to sqlite"""
def __init__(self, file_dir:str):
def __init__(self, file_dir:str, ignore_case=True):
self.file_dir = file_dir
# case sensitive
self.ignore_case = ignore_case
if not os.path.exists(file_dir):
os.makedirs(file_dir)
self.conn = sqlite3.connect(os.path.join(file_dir, 'entity2chunk.sql'))
self.cursor = self.conn.cursor()
self.cursor.execute('''
CREATE TABLE IF NOT EXISTS entities (
kid INTEGER PRIMARY KEY,
eid INTEGER PRIMARY KEY,
chunk_ids TEXT
)
''')
self.conn.commit()

self.entities = []
self.entity_path = os.path.join(self.file_dir, 'entities.json')
if os.path.exists(self.entity_path):
with open(self.entity_path) as f:
self.entities = json.load(f)
if self.ignore_case:
for id, value in enumerate(self.entities):
self.entities[id] = value.lower()

def insert_relation(self, kid: int, chunk_ids: List[int]):
def clean(self):
self.cursor.execute('''DROP TABLE entities;''')
self.cursor.execute('''
CREATE TABLE IF NOT EXISTS entities (
eid INTEGER PRIMARY KEY,
chunk_ids TEXT
)
''')
self.conn.commit()

def insert_relation(self, eid: int, chunk_ids: List[int]):
"""Insert the relationship between keywords id and List of chunk_id"""
chunk_ids_str = ','.join(map(str, chunk_ids))  # convert the list to a string
self.cursor.execute('INSERT INTO entities (kid, chunk_ids) VALUES (?, ?)', (kid, chunk_ids_str))
chunk_ids_str = ','.join(map(str, chunk_ids))
self.cursor.execute('INSERT INTO entities (eid, chunk_ids) VALUES (?, ?)', (eid, chunk_ids_str))
self.conn.commit()

def parse(self, text:str) -> List[int]:
if self.ignore_case:
text = text.lower()

if len(self.entities) < 1:
raise ValueError('entity list empty, please check feature_store init')
ret = []
for index, entity in self.entities:
for index, entity in enumerate(self.entities):
if entity in text:
ret.append(index)
return ret
@@ -42,7 +61,11 @@ def set_entity(self, entities: List[str]):
json_str = json.dumps(entities, ensure_ascii=False)
with open(self.entity_path, 'w') as f:
f.write(json_str)

self.entities = entities
if self.ignore_case:
for id, value in enumerate(self.entities):
self.entities[id] = value.lower()

def get_chunk_ids(self, entity_ids: Union[List, int]) -> Set:
"""Query by keywords ids"""
@@ -51,12 +74,13 @@ def get_chunk_ids(self, entity_ids: Union[List, int]) -> Set:

counter = dict()
for eid in entity_ids:
self.cursor.execute('SELECT chunk_ids FROM entities WHERE kid = ?', (eid,))
self.cursor.execute('SELECT chunk_ids FROM entities WHERE eid = ?', (eid,))
result = self.cursor.fetchone()
if result:
chunk_ids = result[0].split(',')
for chunk_id in chunk_ids:
if chunk_id in counter:
for chunk_id_str in chunk_ids:
chunk_id = int(chunk_id_str)
if chunk_id not in counter:
counter[chunk_id] = 1
else:
counter[chunk_id] += 1
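For orientation, a minimal usage sketch of the NamedEntity2Chunk class introduced above. It only relies on methods visible in this hunk; the directory, entity list and chunk ids are illustrative, and the exact return value of get_chunk_ids (annotated -> Set) lies outside the shown lines:

from huixiangdou.primitive import NamedEntity2Chunk

# build the entity -> chunk mapping (values are illustrative)
indexer = NamedEntity2Chunk(file_dir='workdir/entity2chunk', ignore_case=True)
indexer.clean()                                   # drop and re-create the sqlite table
indexer.set_entity(['mmpose', 'mmdetection'])     # persisted to entities.json, lowercased
indexer.insert_relation(eid=0, chunk_ids=[3, 7])  # entity 0 ('mmpose') occurs in chunks 3 and 7

# query time: map text to entity ids, then entity ids to candidate chunk ids
eids = indexer.parse('请问如何安装 mmpose ?')        # -> [0]
candidate_chunks = indexer.get_chunk_ids(eids)    # chunks counted by how many entities hit them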
9 changes: 5 additions & 4 deletions huixiangdou/primitive/splitter.py
@@ -618,11 +618,12 @@ def nested_split_markdown(filepath: str,
modal='image')
image_chunks.append(c)
else:
logger.error(
f'image cannot access. file: {filepath}, image path: {image_path}'
)
pass
# logger.error(
# f'image cannot access. file: {filepath}, image path: {image_path}'
# )

logger.info('{} text_chunks, {} image_chunks'.format(len(text_chunks), len(image_chunks)))
# logger.info('{} text_chunks, {} image_chunks'.format(len(text_chunks), len(image_chunks)))
return text_chunks + image_chunks

def split_python_code(filepath: str, text: str, metadata: dict = {}):
18 changes: 3 additions & 15 deletions huixiangdou/server.py
@@ -1,19 +1,9 @@
import argparse
import os
import time

import pytoml
import requests
from aiohttp import web
from loguru import logger
from termcolor import colored

from .service import ErrorCode, SerialPipeline, ParallelPipeline, start_llm_server
from .service import SerialPipeline, ParallelPipeline, start_llm_server
from .primitive import Query
import asyncio
from fastapi import FastAPI, APIRouter
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
import json
@@ -43,7 +33,6 @@ async def huixiangdou_inference(talk: Talk):
query = Query(talk.text, talk.image)

pipeline = {'step': []}
debug = dict()
if type(assistant) is SerialPipeline:
for sess in assistant.generate(query=query):
status = {
@@ -73,7 +62,6 @@ async def huixiangdou_stream(talk: Talk):
query = Query(talk.text, talk.image)

pipeline = {'step': []}
debug = dict()

def event_stream():
for sess in assistant.generate(query=query):
@@ -104,7 +92,7 @@ async def event_stream_async():

def parse_args():
"""Parse args."""
parser = argparse.ArgumentParser(description='SerialPipeline.')
parser = argparse.ArgumentParser(description='Serial or Parallel Pipeline.')
parser.add_argument('--work_dir',
type=str,
default='workdir',
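The server.py hunks only show the handler bodies, so as a hedged illustration, a client call against huixiangdou_inference might look like the sketch below. The Talk fields (text, image) come from the handler signature; the host, port and route path are assumptions, not taken from this diff:

import requests

# endpoint URL is an assumption for illustration; Talk carries 'text' and 'image'
resp = requests.post('http://127.0.0.1:23333/huixiangdou_inference',
                     json={'text': '请问如何安装 mmpose ?', 'image': ''})
print(resp.json())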
(Diffs for the remaining 4 changed files are not shown.)