Skip to content

Commit

Permalink
fix(parallel_pipeline.py): multimodal retrieval (#370)
Browse files Browse the repository at this point in the history
  • Loading branch information
tpoisonooo authored Aug 26, 2024
1 parent 79fa810 commit 88a1bf3
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 20 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ Our Web version has been released to [OpenXLab](https://openxlab.org.cn/apps/det
<b>Retrieval Method</b>
</td>
<td>
<b>Instant Messaging</b>
<b>Integration</b>
</td>
<td>
<b>Preprocessing</b>
Expand Down Expand Up @@ -126,8 +126,11 @@ Our Web version has been released to [OpenXLab](https://openxlab.org.cn/apps/det

<td>

- WeChat
- WeChat([android](./docs/add_wechat_accessibility_zh.md)/[wkteam](./docs/add_wechat_commercial_zh.md))
- Lark
- [OpenXLab Web](https://openxlab.org.cn/apps/detail/tpoisonooo/huixiangdou-web)
- [Gradio Demo](./huixiangdou/gradio.py)
- [HTTP Server](./huixiangdou/server.py)

</td>

Expand Down
13 changes: 8 additions & 5 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ Web 版视频教程见 [BiliBili](https://www.bilibili.com/video/BV1S2421N7mn)
<b>检索方法</b>
</td>
<td>
<b>即时通讯</b>
<b>接入方法</b>
</td>
<td>
<b>预处理</b>
Expand Down Expand Up @@ -125,8 +125,11 @@ Web 版视频教程见 [BiliBili](https://www.bilibili.com/video/BV1S2421N7mn)

<td>

- WeChat
- Lark
- 微信([android](./docs/add_wechat_accessibility_zh.md)/[wkteam](./docs/add_wechat_commercial_zh.md)
- 飞书
- [OpenXLab Web](https://openxlab.org.cn/apps/detail/tpoisonooo/huixiangdou-web)
- [Gradio Demo](./huixiangdou/gradio.py)
- [HTTP Server](./huixiangdou/server.py)

</td>

Expand Down Expand Up @@ -321,8 +324,8 @@ reranker_model_path = "BAAI/bge-reranker-v2-minicpm-layerwise"

需要注意:

- 要手动下载 [Visualized_m3.pth](https://huggingface.co/BAAI/bge-visualized/blob/main/Visualized_m3.pth)[bge-m3](https://huggingface.co/BAAI/bge-m3) 目录下
- FlagEmbedding 需要安装新版,我们做了 [bugfix](https://github.com/FlagOpen/FlagEmbedding/commit/3f84da0796d5badc3ad519870612f1f18ff0d1d3)[这里](https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/visual/eva_clip/bpe_simple_vocab_16e6.txt.gz)可以下载 BGE 打包漏掉的 `bpe_simple_vocab_16e6.txt.gz`
- 先下载 [bge-m3](https://huggingface.co/BAAI/bge-m3),然后把 [Visualized_m3.pth](https://huggingface.co/BAAI/bge-visualized/blob/main/Visualized_m3.pth) 放进 `bge-m3` 目录
- FlagEmbedding 需要安装 master 最新版,我们做了 [bugfix](https://github.com/FlagOpen/FlagEmbedding/commit/3f84da0796d5badc3ad519870612f1f18ff0d1d3)[这里](https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/visual/eva_clip/bpe_simple_vocab_16e6.txt.gz)可以下载 BGE 打包漏掉的 `bpe_simple_vocab_16e6.txt.gz`
- 安装 [requirments-multimodal.txt](./requirements-multimodal.txt)

运行 gradio 测试,图文检索效果见[这里](https://github.com/InternLM/HuixiangDou/pull/326).
Expand Down
2 changes: 2 additions & 0 deletions docs/architecture_en.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Code Structure Explanation

<img src="./figures/huixiangdou.png" width="400">

This document primarily explains the directory structure and functionalities of HuixiangDou. The documentation may not be updated in real-time with the code, but the definitions that are in place will no longer change.

## First Layer: Project Introduction
Expand Down
2 changes: 2 additions & 0 deletions docs/architecture_zh.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# 代码结构说明

<img src="./figures/huixiangdou.png" width="400">

本文主要解释豆哥(茴香豆)各目录和功能。文档可能无法随代码即时更新,但已有定义不会再变动。

## 第一层:项目介绍
Expand Down
Binary file added docs/figures/huixiangdou.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion huixiangdou/gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def predict(text:str, image:str):

if image is not None:
filename = 'image.png'
image_path = os.path.join(args.work_dir, filename)
image_path = os.path.join(main_args.work_dir, filename)
cv2.imwrite(image_path, image)
else:
image_path = None
Expand Down
6 changes: 3 additions & 3 deletions huixiangdou/primitive/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ def __str__(self) -> str:

formatted = ''
if self.text is not None:
formatted += f"text='{self.text}'"
formatted += f"text='{self.text}' "
if self.image is not None:
formatted += f"image='{self.image}'"
formatted += f"image='{self.image}' "
if self.audio is not None:
formatted += f"audio='{self.audio}'"
formatted += f"audio='{self.audio}' "
return formatted

def __repr__(self) -> str:
Expand Down
16 changes: 10 additions & 6 deletions huixiangdou/primitive/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,8 @@ def nested_split_markdown(filepath: str,
image_chunks = []

text_ref_splitter = MarkdownTextRefSplitter(chunk_size=chunksize)
ref_pattern = re.compile(r'\[([^\]]+)\]\(([a-zA-Z0-9:/._~#-]+)?\)')
md_image_pattern = re.compile(r'\[([^\]]+)\]\(([a-zA-Z0-9:/._~#-]+)?\)')
html_image_pattern = re.compile(r'<img\s+[^>]*?src=["\']([^"\']*)["\'][^>]*>')
file_opr = FileOperation()

for chunk in chunks:
Expand All @@ -592,12 +593,15 @@ def nested_split_markdown(filepath: str,
content = '{} {}'.format(header, chunk.content_or_path.lower())
text_chunks.append(Chunk(content, metadata))

# extract images
matches = ref_pattern.findall(chunk.content_or_path)
# extract images path
dirname = os.path.dirname(filepath)
for match in matches:
# target = match[0]
image_path = match[1]

image_paths = []
for match in md_image_pattern.findall(chunk.content_or_path):
image_paths.append(match[1])
for match in html_image_pattern.findall(chunk.content_or_path):
image_paths.append(match)
for image_path in image_paths:
if file_opr.get_type(image_path) != 'image':
continue

Expand Down
6 changes: 3 additions & 3 deletions huixiangdou/service/parallel_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, config: dict, llm: ChatClient, language: str):

def process(self, sess: Session) -> Generator[Session, None, None]:
# check input
if sess.query.text is None or len(sess.query.text) < 6:
if sess.query.text is None or len(sess.query.text) < 2:
sess.code = ErrorCode.QUESTION_TOO_SHORT
yield sess
return
Expand Down Expand Up @@ -127,7 +127,7 @@ async def process_coroutine(self, sess: Session) -> Session:
"""Try get reply with text2vec & rerank model."""

# retrieve from knowledge base
sess.parallel_chunks = await asyncio.to_thread(self.retriever.text2vec_retrieve, sess.query.text)
sess.parallel_chunks = await asyncio.to_thread(self.retriever.text2vec_retrieve, sess.query)
# sess.parallel_chunks = self.retriever.text2vec_retrieve(query=sess.query.text)
return sess

Expand Down Expand Up @@ -220,7 +220,7 @@ async def process(self, sess: Session) -> Generator[Session, None, None]:
else:
_, context_str, references = self.retriever.rerank_fuse(query=sess.query, chunks=sess.parallel_chunks, context_max_length=self.context_max_length)
sess.references = references
prompt = self.GENERATE_TEMPLATE.format(context_str, sess.query)
prompt = self.GENERATE_TEMPLATE.format(context_str, sess.query.text)
async for part in self.llm.chat_stream(prompt=prompt, history=history):
sess.delta = part
yield sess
Expand Down

0 comments on commit 88a1bf3

Please sign in to comment.