连尚读书app
80.86MB · 2025-11-17
你是否遇到过这样的困境:想要识别图片中的文字,但大模型太"重"跑不动,小模型又经常认错字?比如把清晰的"一"识别成"二",把重要的表格数据搞得面目全非……
本篇笔记所对应的视频:www.bilibili.com/video/BV1hP…
别急,今天我要分享一个"化腐朽为神奇"的方法——通过微调技术,让仅有3B参数的DeepSeek-OCR小模型,变成识别准确率媲美大模型的"识字专家"。更重要的是,整个过程只需10分钟,还能用Google Colab的免费GPU完成!
DeepSeek-OCR是一款专门用于文字识别和文档理解的视觉模型,参数量只有3B。别看它"个头小",它有几个独特的优势:
但是,正如"万金油"往往哪里都不精通,DeepSeek-OCR作为通用模型,对各种语言都能识别,但每种语言的准确率都不够理想。这就是我们需要微调的原因。
用最简单的话来说,微调就是给模型"开小灶" 。
想象一下,你有一个什么都会的全能助手(通用模型),但让TA做中文会计报表时经常出错。这时你就给TA准备一本专门的中文会计教材,让TA集中学习这个领域的知识,慢慢地,TA就从"全能选手"变成了"中文会计专家"。
微调的过程就是这样:用特定领域的数据集训练模型,让它在你关注的场景下表现更出色。
根据Unsloth官方的测试数据,微调效果非常显著:
在20万样本的波斯文数据集上微调后,仅用60个训练步(批量大小为8):
在中文场景下的测试显示:
这样的提升,对于实际应用来说是质的飞跃。
整个微调流程比你想象的简单得多,核心步骤只有三步:
你需要准备两类数据:
数据集格式很简单,就是"图像-文本"对:
图像路径: images/doc001.jpg
对应文本: 这是图像中的完整文字内容,包括标点符号。
数据集来源:
制作自己的数据集: 使用提供的Python脚本,只需运行:
python create_dataset.py data.txt output.parquet
脚本会自动将你的图像和文本转换成标准的训练格式。
核心参数设置:
训练完成后,立即可以测试:
# 加载微调后的模型
model, tokenizer = FastVisionModel.from_pretrained("./fine_tuned_model")
# 测试识别
result = model.infer(tokenizer, prompt="<image>nFree OCR.", image_file="test.jpg")
print(result)
对比微调前后的识别结果,你会看到显著的改进。
微调使用的是**LoRA(低秩适应)**技术,这是一种参数高效的微调方法:
这也是为什么我们能用免费资源完成专业级微调的原因。
某档案馆有大量80年代的模糊扫描文件,通用OCR模型错误率高达30%。使用500个样本微调后,错误率降至5%以下,大大加速了数字化进程。
医院需要识别医生的手写处方。使用1000个标注样本微调后,识别准确率从60%提升到92%,显著减少了人工复核工作量。
跨国公司需要处理包含中英混排的合同文档。通过混合数据集微调,模型在中英混排场景下的准确率达到98%。
让我们算一笔账:
传统方案:
微调方案:
更重要的是,微调后的模型完全属于你,可以:
所有资源都已准备好:
微调不再是高深莫测的技术,它已经变得像安装软件一样简单。只要你有需求,有数据,就能动手实践。
在AI快速发展的今天,我们不仅要会"用"模型,更要学会"调"模型。微调技术让我们能够用较小的成本,获得针对性极强的AI能力。
DeepSeek-OCR的微调实战,只是一个开始。掌握了这个方法,你可以将它应用到:
技术的门槛在降低,创新的空间在扩大。现在,轮到你动手实践了!
资源链接:
小提示:
如果这篇文章对你有帮助,欢迎点赞、转发,让更多人了解AI微调的魅力!有任何问题也欢迎在评论区讨论交流。
#AI技术 #OCR识别 #模型微调 #DeepSeek #机器学习 #深度学习实战
colab.research.google.com/github/unsl…
huggingface.co/datasets/pr…
/Users/charlesqin/Desktop/img/1.jpg 剧情跌宕起伏,人
/Users/charlesqin/Desktop/img/2.jpg 好的,特效嘛,也算是良心了,演
/Users/charlesqin/Desktop/img/3.jpg 。。。,剧情逻辑有点不通啊啊。
/Users/charlesqin/Desktop/img/4.jpg 以看不了太烧脑的悬疑片
/Users/charlesqin/Desktop/img/5.jpg 。;这颗行星上存在
/Users/charlesqin/Desktop/img/6.jpg 磁场。不加外磁场时,原子在两个
/Users/charlesqin/Desktop/img/7.jpg 过外放的听歌确实比较不错第一
/Users/charlesqin/Desktop/img/8.jpg 快的,书也很整洁,但是我发现在
/Users/charlesqin/Desktop/img/9.jpg 为空间上的排列,有利于科学研究
/Users/charlesqin/Desktop/img/10.jpg 谁也不讨厌谁
#!/usr/bin/env python3
"""
安装: pip install datasets Pillow scikit-learn tqdm
从 content.md 创建 Parquet 格式的 OCR 数据集
使用方法:
python create_parquet_dataset.py content.md
或者自定义输出路径:
python create_parquet_dataset.py content.md --output my_dataset
"""
import os
import sys
from PIL import Image as PILImage
from datasets import Dataset, DatasetDict, Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm
def parse_content_md(file_path):
"""解析 content.md 文件"""
print(f" 读取文件: {file_path}")
data = []
with open(file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f, 1):
line = line.strip()
# 跳过空行和注释
if not line or line.startswith('#'):
continue
# 分割图像路径和文本
parts = line.split(None, 1)
if len(parts) == 2:
image_path, text = parts
data.append((image_path, text))
else:
print(f"️ 行 {line_num}: 格式不正确,已跳过")
print(f" 找到 {len(data)} 条记录")
return data
def create_dataset(data):
"""创建数据集"""
print(f"n 加载图像...")
images = []
texts = []
skipped = 0
for img_path, text in tqdm(data):
# 检查文件
if not os.path.exists(img_path):
print(f"️ 图像不存在: {img_path}")
skipped += 1
continue
try:
# 加载图像
img = PILImage.open(img_path).convert('RGB')
# 基本验证
if img.size[0] < 10 or img.size[1] < 10:
print(f"️ 图像太小: {img_path}")
skipped += 1
continue
if not text or text.strip() == '':
print(f"️ 文本为空: {img_path}")
skipped += 1
continue
images.append(img)
texts.append(text)
except Exception as e:
print(f"️ 加载失败 {img_path}: {e}")
skipped += 1
continue
print(f" 成功加载: {len(images)} 个样本")
if skipped > 0:
print(f"️ 跳过: {skipped} 个样本")
# 创建数据集
dataset = Dataset.from_dict({
'image': images,
'text': texts
})
dataset = dataset.cast_column('image', Image())
return dataset
def split_dataset(dataset):
"""分割数据集为训练/验证/测试集"""
print(f"n 分割数据集...")
indices = list(range(len(dataset)))
# 80% 训练,10% 验证,10% 测试
train_indices, temp_indices = train_test_split(
indices, train_size=0.8, random_state=42
)
val_indices, test_indices = train_test_split(
temp_indices, train_size=0.5, random_state=42
)
train_dataset = dataset.select(train_indices)
val_dataset = dataset.select(val_indices)
test_dataset = dataset.select(test_indices)
print(f" 训练集: {len(train_dataset)} 样本")
print(f" 验证集: {len(val_dataset)} 样本")
print(f" 测试集: {len(test_dataset)} 样本")
return DatasetDict({
'train': train_dataset,
'val': val_dataset,
'test': test_dataset
})
def save_parquet(dataset_dict, output_prefix):
"""保存为 Parquet 格式"""
print(f"n 保存为 Parquet 格式...")
for split_name, split_data in dataset_dict.items():
output_file = f"{output_prefix}_{split_name}.parquet"
split_data.to_parquet(output_file)
print(f" {split_name}: {output_file}")
def print_statistics(dataset_dict):
"""打印统计信息"""
print("n" + "=" * 70)
print(" 数据集统计")
print("=" * 70)
for split_name, split_data in dataset_dict.items():
print(f"n{split_name}:")
print(f" 样本数: {len(split_data)}")
# 文本长度
text_lengths = [len(ex['text']) for ex in split_data]
print(f" 文本长度: {min(text_lengths)}-{max(text_lengths)} "
f"(平均: {sum(text_lengths) / len(text_lengths):.1f})")
# 显示样例
if len(split_data) > 0:
sample_text = split_data[0]['text']
display_text = sample_text[:40] + "..." if len(sample_text) > 40 else sample_text
print(f" 样例: {display_text}")
def main():
# 参数解析
if len(sys.argv) < 2:
print("使用方法: python create_parquet_dataset.py content.md [--output 输出前缀]")
print("n示例:")
print(" python create_parquet_dataset.py content.md")
print(" python create_parquet_dataset.py content.md --output my_dataset")
sys.exit(1)
input_file = sys.argv[1]
# 输出路径
if len(sys.argv) >= 4 and sys.argv[2] == '--output':
output_prefix = sys.argv[3]
else:
output_prefix = "my_ocr_dataset"
print("=" * 70)
print(" 创建 Parquet 格式 OCR 数据集")
print("=" * 70)
print(f"输入文件: {input_file}")
print(f"输出前缀: {output_prefix}")
# 检查输入文件
if not os.path.exists(input_file):
print(f"n 错误: 文件不存在: {input_file}")
sys.exit(1)
try:
# 1. 解析文件
data = parse_content_md(input_file)
if len(data) == 0:
print("n 错误: 没有找到有效数据")
sys.exit(1)
# 2. 创建数据集
dataset = create_dataset(data)
if len(dataset) == 0:
print("n 错误: 没有成功加载任何样本")
sys.exit(1)
# 3. 分割数据集
dataset_dict = split_dataset(dataset)
# 4. 打印统计
print_statistics(dataset_dict)
# 5. 保存为 Parquet
save_parquet(dataset_dict, output_prefix)
# 完成
print("n" + "=" * 70)
print(" 完成!")
print("=" * 70)
print("n 生成的文件:")
print(f" - {output_prefix}_train.parquet")
print(f" - {output_prefix}_val.parquet")
print(f" - {output_prefix}_test.parquet")
print("n 如何使用:")
print(" from datasets import load_dataset")
print()
print(" # 加载训练集")
print(f" train_dataset = load_dataset('parquet', data_files='{output_prefix}_train.parquet')")
print()
print(" # 或加载所有分割")
print(f" dataset = load_dataset('parquet', data_files={{")
print(f" 'train': '{output_prefix}_train.parquet',")
print(f" 'val': '{output_prefix}_val.parquet',")
print(f" 'test': '{output_prefix}_test.parquet'")
print(f" }})")
print()
print(" # 用于微调")
print(" train_data = dataset['train']")
except Exception as e:
print(f"n 错误: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == '__main__':
main()