import ast
import gc
import io
import os
import re
from concurrent.futures import ThreadPoolExecutor

import fitz
import img2pdf
import torch
from tqdm import tqdm
# Workaround for CUDA 11.8: point Triton at the matching ptxas binary.
if torch.version.cuda == '11.8':
    os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
# Force the vLLM V0 engine and pin the process to GPU 0.
os.environ['VLLM_USE_V1'] = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, SKIP_REPEAT, MAX_CONCURRENCY, NUM_WORKERS, CROP_MODE
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from deepseek_ocr2 import DeepseekOCR2ForCausalLM
from vllm.model_executor.models.registry import ModelRegistry
from vllm import LLM, SamplingParams
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCR2Processor
# Let vLLM resolve the custom architecture name to our model class.
ModelRegistry.register_model("DeepseekOCR2ForCausalLM", DeepseekOCR2ForCausalLM)
# Engine reboot interval: restart the LLM engine after this many pages
# to release memory that accumulates over long runs.
REBOOT_INTERVAL = 50
class Colors:
    """ANSI terminal escape sequences for colorized console logging."""

    RED = '\033[31m'
    GREEN = '\033[32m'
    YELLOW = '\033[33m'
    BLUE = '\033[34m'
    RESET = '\033[0m'
def init_llm():
    """Build and return the vLLM engine and its greedy sampling parameters.

    Returns:
        (llm, sampling_params): the LLM engine instance and a SamplingParams
        configured for deterministic OCR decoding.
    """
    print(f"{Colors.BLUE}初始化LLM引擎...{Colors.RESET}")
    llm = LLM(
        model=MODEL_PATH,
        # Route the HF architecture name to the custom class registered above.
        hf_overrides={"architectures": ["DeepseekOCR2ForCausalLM"]},
        block_size=256,
        enforce_eager=False,  # NOTE(review): False allows CUDA graph capture; the original comment claimed eager mode
        trust_remote_code=True,
        max_model_len=8192,
        swap_space=0,  # no CPU swap space for preempted sequences
        max_num_seqs=MAX_CONCURRENCY,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.9,
        disable_mm_preprocessor_cache=True
    )
    # Suppress degenerate 20-gram repetition within a 50-token window;
    # token ids 128821/128822 are exempt (presumably grounding-tag tokens
    # that legitimately repeat — confirm against the tokenizer).
    logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=20, window_size=50, whitelist_token_ids={128821, 128822})]
    sampling_params = SamplingParams(
        temperature=0.0,  # greedy decoding: deterministic OCR output
        max_tokens=8192,
        logits_processors=logits_processors,
        skip_special_tokens=False,  # keep markers like <|end▁of▁sentence|> for downstream checks
        include_stop_str_in_output=True,
    )
    return llm, sampling_params
def pdf_to_images_stream(pdf_path, dpi=144):
    """Yield (PIL image, page index) for each page of a PDF, one at a time.

    Streaming keeps only one rendered page in memory. The document handle is
    closed in a ``finally`` block, so it is released even when the consumer
    abandons the generator early (generator ``close()`` triggers the finally)
    or a page fails to render — the original only closed on full exhaustion.

    Args:
        pdf_path: path to the PDF file.
        dpi: render resolution; PDF native space is 72 dpi, hence the zoom.
    """
    pdf_document = fitz.open(pdf_path)
    try:
        zoom = dpi / 72.0
        matrix = fitz.Matrix(zoom, zoom)
        # Disable PIL's decompression-bomb guard once; pages at high dpi can
        # exceed the default pixel limit. (Hoisted out of the loop.)
        Image.MAX_IMAGE_PIXELS = None
        for page_num in range(pdf_document.page_count):
            page = pdf_document[page_num]
            pixmap = page.get_pixmap(matrix=matrix, alpha=False)
            img_data = pixmap.tobytes("png")
            img = Image.open(io.BytesIO(img_data))
            yield img, page_num
    finally:
        pdf_document.close()
def pil_to_pdf_img2pdf(pil_images, output_path):
    """Write a list of PIL images to *output_path* as one PDF (best-effort).

    Empty input is a no-op; conversion failures are printed, not raised.
    """
    if not pil_images:
        return
    jpeg_blobs = []
    for picture in pil_images:
        rgb = picture if picture.mode == 'RGB' else picture.convert('RGB')
        buffer = io.BytesIO()
        rgb.save(buffer, format='JPEG', quality=95)
        jpeg_blobs.append(buffer.getvalue())
    try:
        pdf_bytes = img2pdf.convert(jpeg_blobs)
        with open(output_path, "wb") as fh:
            fh.write(pdf_bytes)
    except Exception as e:
        print(f"error: {e}")
def re_match(text):
    """Find all <|ref|>…<|/ref|><|det|>…<|/det|> spans in *text*.

    Returns a 3-tuple:
      - all matches as (full_span, ref_label, det_payload) tuples,
      - the full spans whose label is 'image',
      - the full spans for every other label.
    """
    pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    found = re.findall(pattern, text, re.DOTALL)
    image_spans = []
    other_spans = []
    for whole, _label, _coords in found:
        bucket = image_spans if '<|ref|>image<|/ref|>' in whole else other_spans
        bucket.append(whole)
    return found, image_spans, other_spans
def extract_coordinates_and_label(ref_text, image_width, image_height):
    """Parse one re_match tuple into (label, coordinate list).

    Args:
        ref_text: (full_span, label, coords_literal) as produced by re_match;
            coords_literal is expected to be a Python list literal like
            '[[x1, y1, x2, y2], ...]'.
        image_width, image_height: accepted for interface compatibility but
            unused here (scaling is done by the caller, draw_bounding_boxes).

    Returns:
        (label_type, cor_list), or None if the payload is malformed.
    """
    try:
        label_type = ref_text[1]
        # ast.literal_eval instead of eval: the payload is model output
        # (untrusted text), and eval would execute arbitrary code.
        cor_list = ast.literal_eval(ref_text[2])
    except Exception as e:
        print(e)
        return None
    return (label_type, cor_list)
def draw_bounding_boxes(image, refs, jdx, output_dir):
    """Render layout boxes from OCR grounding refs onto a copy of *image*.

    Args:
        image: PIL page image; it is copied, never modified in place.
        refs: tuples from re_match — (full_span, label, coords_literal) —
            whose coordinates are normalized to a 0-999 grid.
        jdx: page index, used to name cropped 'image' regions saved under
            ``output_dir/images/``.
        output_dir: per-PDF output directory (its images/ subdir must exist).

    Returns:
        The annotated copy of the page.

    Fix vs. original: the two bare ``except:`` clauses are narrowed to
    ``except Exception`` so KeyboardInterrupt/SystemExit are not swallowed.
    """
    image_width, image_height = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
    # Translucent fills are drawn on a separate RGBA overlay and pasted once.
    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)
    font = ImageFont.load_default()
    img_idx = 0
    images_dir = os.path.join(output_dir, 'images')
    for i, ref in enumerate(refs):
        try:
            result = extract_coordinates_and_label(ref, image_width, image_height)
            if result:
                label_type, points_list = result
                # One random (darkish) color per ref; alpha 20 for the fill.
                color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
                color_a = color + (20,)
                for points in points_list:
                    x1, y1, x2, y2 = points
                    # Scale 0-999 normalized coordinates into pixel space.
                    x1 = int(x1 / 999 * image_width)
                    y1 = int(y1 / 999 * image_height)
                    x2 = int(x2 / 999 * image_width)
                    y2 = int(y2 / 999 * image_height)
                    if label_type == 'image':
                        # Crop embedded figures from the original page image.
                        try:
                            cropped = image.crop((x1, y1, x2, y2))
                            cropped.save(f"{images_dir}/{jdx}_{img_idx}.jpg")
                        except Exception as e:
                            print(e)
                        img_idx += 1
                    try:
                        # Titles get a thicker outline; the translucent fill is
                        # identical for every label type.
                        outline_width = 4 if label_type == 'title' else 2
                        draw.rectangle([x1, y1, x2, y2], outline=color, width=outline_width)
                        draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
                        # Label text just above the box, clamped to the page top.
                        text_x = x1
                        text_y = max(0, y1 - 15)
                        text_bbox = draw.textbbox((0, 0), label_type, font=font)
                        text_width = text_bbox[2] - text_bbox[0]
                        text_height = text_bbox[3] - text_bbox[1]
                        draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
                                       fill=(255, 255, 255, 30))
                        draw.text((text_x, text_y), label_type, font=font, fill=color)
                    except Exception:
                        # Best-effort drawing: a bad box must not kill the page.
                        pass
        except Exception:
            continue
    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw
def process_single_pdf(pdf_path, pdf_output_dir, reboot_interval=REBOOT_INTERVAL):
    """OCR one PDF page-by-page, restarting the LLM engine periodically.

    Streams pages from *pdf_path* and writes into a per-PDF directory under
    *pdf_output_dir*:
      - ``<name>_det.mmd``     raw model output with grounding tags
      - ``<name>.mmd``         cleaned markdown (tags stripped)
      - ``<name>_layouts.pdf`` pages annotated with layout boxes
      - ``images/``            cropped figure regions

    Args:
        pdf_path: path to the input PDF.
        pdf_output_dir: root output directory.
        reboot_interval: pages between engine restarts (works around memory
            accumulation across many generate() calls).

    Returns:
        The number of successfully processed pages.

    Fix vs. original: the engine teardown now runs in a ``finally`` block, so
    an exception escaping this function (caught broadly by batch_process_pdfs)
    no longer leaks the engine and its GPU memory into the next PDF.
    """
    pdf_name = os.path.basename(pdf_path).replace('.pdf', '')
    print(f"\n{Colors.GREEN}开始处理: {pdf_name}.pdf{Colors.RESET}")
    # Per-PDF output directory, plus images/ for cropped figures.
    file_output_dir = os.path.join(pdf_output_dir, pdf_name)
    os.makedirs(file_output_dir, exist_ok=True)
    os.makedirs(os.path.join(file_output_dir, 'images'), exist_ok=True)
    # Output file paths.
    mmd_det_path = os.path.join(file_output_dir, f'{pdf_name}_det.mmd')
    mmd_path = os.path.join(file_output_dir, f'{pdf_name}.mmd')
    pdf_out_path = os.path.join(file_output_dir, f'{pdf_name}_layouts.pdf')
    contents_det = ''
    contents = ''
    draw_images = []
    page_idx = 0
    page_sep = '\n<--- Page Split --->\n'  # hoisted: loop-invariant
    # Initial engine start-up.
    llm, sampling_params = init_llm()
    try:
        for img, page_num in tqdm(pdf_to_images_stream(pdf_path), desc=f"Processing {pdf_name}.pdf"):
            try:
                # Periodic engine restart to release accumulated memory.
                if page_idx > 0 and page_idx % reboot_interval == 0:
                    print(f"\n{Colors.YELLOW}已处理 {page_idx} 页,重启LLM引擎释放内存...{Colors.RESET}")
                    del llm
                    gc.collect()
                    torch.cuda.empty_cache()
                    llm, sampling_params = init_llm()
                    print(f"{Colors.GREEN}引擎重启完成,继续处理{Colors.RESET}")
                # Single-page inference request.
                cache_item = {
                    "prompt": PROMPT,
                    "multi_modal_data": {"image": DeepseekOCR2Processor().tokenize_with_images(
                        images=[img], bos=True, eos=True, cropping=CROP_MODE)},
                }
                outputs = llm.generate([cache_item], sampling_params=sampling_params)
                content = outputs[0].outputs[0].text
                # A missing EOS marker means generation hit max_tokens (likely
                # degenerate repetition); optionally drop such pages.
                if '<|end▁of▁sentence|>' in content:
                    content = content.replace('<|end▁of▁sentence|>', '')
                elif SKIP_REPEAT:
                    continue
                contents_det += content + page_sep
                # Annotate layout boxes and crop embedded figures.
                matches_ref, matches_images, mathes_other = re_match(content)
                result_image = draw_bounding_boxes(img.copy(), matches_ref, page_idx, file_output_dir)
                draw_images.append(result_image)
                # NOTE(review): image refs are replaced by a bare newline, so
                # the crops saved under images/ are never linked from the .mmd —
                # confirm whether a markdown image link was intended here.
                for a_match_image in matches_images:
                    content = content.replace(a_match_image, '\n')
                for a_match_other in mathes_other:
                    content = content.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
                contents += content + page_sep
                page_idx += 1
            except Exception as e:
                print(f"{Colors.RED}处理第 {page_num+1} 页时出错: {e}{Colors.RESET}")
                continue
    finally:
        # Always release the engine, even on failure. llm may be unbound if a
        # mid-run restart deleted it and re-init failed, hence the guard.
        if 'llm' in locals():
            del llm
        gc.collect()
        torch.cuda.empty_cache()
    # Persist results.
    with open(mmd_det_path, 'w', encoding='utf-8') as f:
        f.write(contents_det)
    with open(mmd_path, 'w', encoding='utf-8') as f:
        f.write(contents)
    if draw_images:
        pil_to_pdf_img2pdf(draw_images, pdf_out_path)
    print(f"{Colors.GREEN}完成: {pdf_name}.pdf (共 {page_idx} 页){Colors.RESET}")
    return page_idx
def batch_process_pdfs(input_dir, output_dir):
    """Run process_single_pdf over every PDF in *input_dir* and summarize.

    A failure in one file is reported and does not stop the batch.
    """
    pdf_files = [name for name in os.listdir(input_dir) if name.lower().endswith('.pdf')]
    if not pdf_files:
        print(f"{Colors.RED}在目录 {input_dir} 中没有找到PDF文件{Colors.RESET}")
        return
    print(f"{Colors.BLUE}找到 {len(pdf_files)} 个PDF文件,开始批量处理...{Colors.RESET}")
    print(f"{Colors.BLUE}引擎重启间隔: {REBOOT_INTERVAL} 页{Colors.RESET}")
    total_pages = 0
    success_count = 0
    for pdf_file in pdf_files:
        try:
            total_pages += process_single_pdf(os.path.join(input_dir, pdf_file), output_dir)
            success_count += 1
        except Exception as e:
            print(f"{Colors.RED}处理 {pdf_file} 时出错: {e}{Colors.RESET}")
    print(f"\n{Colors.GREEN}{'='*50}{Colors.RESET}")
    print(f"{Colors.GREEN}批量处理完成!{Colors.RESET}")
    print(f"{Colors.GREEN}成功处理: {success_count}/{len(pdf_files)} 个文件{Colors.RESET}")
    print(f"{Colors.GREEN}总页数: {total_pages} 页{Colors.RESET}")
    print(f"{Colors.GREEN}输出目录: {output_dir}{Colors.RESET}")
    print(f"{Colors.GREEN}{'='*50}{Colors.RESET}")
if __name__ == "__main__":
    # Resolve the input directory: INPUT_PATH itself when it is a directory,
    # otherwise its parent (falling back to the current directory).
    if os.path.isdir(INPUT_PATH):
        INPUT_DIR = INPUT_PATH
    else:
        INPUT_DIR = os.path.dirname(INPUT_PATH) or '.'
    # Output root, with a default when OUTPUT_PATH is unset/empty.
    OUTPUT_DIR = OUTPUT_PATH if OUTPUT_PATH else './batch_output'
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print(f"{Colors.BLUE}输入目录: {INPUT_DIR}{Colors.RESET}")
    print(f"{Colors.BLUE}输出目录: {OUTPUT_DIR}{Colors.RESET}")
    batch_process_pdfs(INPUT_DIR, OUTPUT_DIR)