Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enhance batch processing in BatchAnalyze with layout and OCR timing logs #1284

Merged
merged 3 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 62 additions & 10 deletions magic_pdf/model/batch_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,36 +34,80 @@ def __init__(self, model: CustomPEKModel, batch_ratio: int):
self.batch_ratio = batch_ratio

def __call__(self, images: list) -> list:
images_layout_res = []

layout_start_time = time.time()
if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
# layoutlmv3
images_layout_res = []
for image in images:
layout_res = self.model.layout_model(image, ignore_catids=[])
images_layout_res.append(layout_res)
elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo
images_layout_res = self.model.layout_model.batch_predict(
images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
layout_images = []
modified_images = []
for image_index, image in enumerate(images):
pil_img = Image.fromarray(image)
width, height = pil_img.size
if height > width:
input_res = {"poly": [0, 0, width, 0, width, height, 0, height]}
new_image, useful_list = crop_img(
input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
)
layout_images.append(new_image)
modified_images.append([image_index, useful_list])
else:
layout_images.append(pil_img)

images_layout_res += self.model.layout_model.batch_predict(
layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
)

for image_index, useful_list in modified_images:
for res in images_layout_res[image_index]:
for i in range(len(res["poly"])):
if i % 2 == 0:
res["poly"][i] = (
res["poly"][i] - useful_list[0] + useful_list[2]
)
else:
res["poly"][i] = (
res["poly"][i] - useful_list[1] + useful_list[3]
)
logger.info(
f"layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}"
)

if self.model.apply_formula:
# 公式检测
mfd_start_time = time.time()
images_mfd_res = self.model.mfd_model.batch_predict(
images, self.batch_ratio * MFD_BASE_BATCH_SIZE
)
logger.info(
f"mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}"
)

# 公式识别
mfr_start_time = time.time()
images_formula_list = self.model.mfr_model.batch_predict(
images_mfd_res,
images,
batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
)
for image_index in range(len(images)):
images_layout_res[image_index] += images_formula_list[image_index]
logger.info(
f"mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}"
)

# 清理显存
clean_vram(self.model.device, vram_threshold=8)

ocr_time = 0
ocr_count = 0
table_time = 0
table_count = 0
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
for index in range(len(images)):
layout_res = images_layout_res[index]
Expand Down Expand Up @@ -99,12 +143,8 @@ def __call__(self, images: list) -> list:
if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
layout_res.extend(ocr_result_list)

ocr_cost = round(time.time() - ocr_start, 2)
if self.model.apply_ocr:
logger.info(f"ocr time: {ocr_cost}")
else:
logger.info(f"det time: {ocr_cost}")
ocr_time += time.time() - ocr_start
ocr_count += len(ocr_res_list)

# 表格识别 table recognition
if self.model.apply_table:
Expand Down Expand Up @@ -146,7 +186,17 @@ def __call__(self, images: list) -> list:
logger.warning(
"table recognition processing fails, not get html return"
)
logger.info(f"table time: {round(time.time() - table_start, 2)}")
table_time += time.time() - table_start
table_count += len(table_res_list)

if self.model.apply_ocr:
logger.info(f"ocr time: {round(ocr_time, 2)}, image num: {ocr_count}")
else:
logger.info(f"det time: {round(ocr_time, 2)}, image num: {ocr_count}")
if self.model.apply_table:
logger.info(f"table time: {round(table_time, 2)}, image num: {table_count}")

return images_layout_res


def doc_batch_analyze(
Expand Down Expand Up @@ -223,6 +273,8 @@ def doc_batch_analyze(
model_json.append(page_dict)

# TODO: clean memory when gpu memory is not enough
clean_memory_start_time = time.time()
clean_memory()
logger.info(f"clean memory time: {round(time.time() - clean_memory_start_time, 2)}")

return InferenceResult(model_json, dataset)
19 changes: 11 additions & 8 deletions magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,17 @@ def predict(self, image):
def batch_predict(self, images: list, batch_size: int) -> list:
images_layout_res = []
for index in range(0, len(images), batch_size):
doclayout_yolo_res = self.model.predict(
images[index : index + batch_size],
imgsz=1024,
conf=0.25,
iou=0.45,
verbose=True,
device=self.device,
).cpu()
doclayout_yolo_res = [
image_res.cpu()
for image_res in self.model.predict(
images[index : index + batch_size],
imgsz=1024,
conf=0.25,
iou=0.45,
verbose=True,
device=self.device,
)
]
for image_res in doclayout_yolo_res:
layout_res = []
for xyxy, conf, cla in zip(
Expand Down
19 changes: 11 additions & 8 deletions magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@ def predict(self, image):
def batch_predict(self, images: list, batch_size: int) -> list:
images_mfd_res = []
for index in range(0, len(images), batch_size):
mfd_res = self.mfd_model.predict(
images[index : index + batch_size],
imgsz=1888,
conf=0.25,
iou=0.45,
verbose=True,
device=self.device,
).cpu()
mfd_res = [
image_res.cpu()
for image_res in self.mfd_model.predict(
images[index : index + batch_size],
imgsz=1888,
conf=0.25,
iou=0.45,
verbose=True,
device=self.device,
)
]
for image_res in mfd_res:
images_mfd_res.append(image_res)
return images_mfd_res
Loading