diff --git a/ai_x/train_resnet_classification.ipynb b/ai_x/train_resnet_classification.ipynb index 8de3a50..8eec1e8 100644 --- a/ai_x/train_resnet_classification.ipynb +++ b/ai_x/train_resnet_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "530389db", + "id": "0", "metadata": { "pycharm": { "name": "#%% md\n" @@ -29,7 +29,7 @@ } }, "cell_type": "markdown", - "id": "68406d8c-b9df-449e-9c12-bb89143c9c46", + "id": "1", "metadata": { "tags": [] }, @@ -44,7 +44,7 @@ }, { "cell_type": "markdown", - "id": "137be2b4", + "id": "2", "metadata": {}, "source": [ "## 环境准备\n", @@ -59,7 +59,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f557f629-ec73-47be-884f-58346fd4bf12", + "id": "3", "metadata": { "tags": [] }, @@ -71,7 +71,7 @@ }, { "cell_type": "markdown", - "id": "731bb59a-52ff-463a-af06-9ada7f4ded31", + "id": "4", "metadata": {}, "source": [ "如果你在如[昇思大模型平台](https://xihe.mindspore.cn/training-projects)、[华为云ModelArts](https://www.huaweicloud.com/product/modelarts.html)、[启智社区](https://openi.pcl.ac.cn/)等算力平台的Jupyter在线编程环境中运行本案例,可取消如下代码的注释,进行依赖库安装:" @@ -80,7 +80,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3eab2c4a", + "id": "5", "metadata": { "tags": [] }, @@ -95,7 +95,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a87f1302-0261-4d93-a716-6579ee09473f", + "id": "6", "metadata": { "tags": [] }, @@ -120,7 +120,7 @@ }, { "cell_type": "markdown", - "id": "13b77923", + "id": "7", "metadata": {}, "source": [ "其他场景可参考[MindSpore安装指南](https://www.mindspore.cn/install)进行环境搭建。" @@ -128,7 +128,7 @@ }, { "cell_type": "markdown", - "id": "707edc02-9007-45de-8a50-1c02672a6506", + "id": "8", "metadata": {}, "source": [ "## 数据加载与预处理\n", @@ -138,7 +138,7 @@ }, { "cell_type": "markdown", - "id": "f3b5736c-9d1e-4b14-a5ef-180ae665279e", + "id": "9", "metadata": {}, "source": [ "#### **数据下载**" @@ -147,7 +147,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fd60d946-63e4-4364-82f6-cfaca6e17ad3", + "id": 
"10", "metadata": { "pycharm": { "name": "#%%\n" @@ -164,7 +164,7 @@ }, { "cell_type": "markdown", - "id": "bd962595-1e6d-4b85-a79a-2a356788de42", + "id": "11", "metadata": {}, "source": [ "#### **数据裁剪**\n", @@ -174,7 +174,7 @@ { "cell_type": "code", "execution_count": null, - "id": "231fad30-b2d5-4541-8015-1f482fa014e0", + "id": "12", "metadata": { "tags": [] }, @@ -204,7 +204,7 @@ }, { "cell_type": "markdown", - "id": "e3cd1102-4242-4203-9099-99e1fcfb0b5c", + "id": "13", "metadata": {}, "source": [ "#### **数据集划分**\n", @@ -214,7 +214,7 @@ { "cell_type": "code", "execution_count": null, - "id": "68211057-19b4-4ac4-a38f-ba685acf4414", + "id": "14", "metadata": { "tags": [] }, @@ -280,7 +280,7 @@ }, { "cell_type": "markdown", - "id": "68c55815-d08a-4b24-a26b-8deb75f01bea", + "id": "15", "metadata": {}, "source": [ "#### **定义数据加载方式**\n", @@ -291,7 +291,7 @@ { "cell_type": "code", "execution_count": null, - "id": "72837cc1-602a-4c7c-8fe3-7a39d6abf952", + "id": "16", "metadata": { "tags": [] }, @@ -341,7 +341,7 @@ }, { "cell_type": "markdown", - "id": "35c20baa-42cf-41e3-a220-2381fbe1c27f", + "id": "17", "metadata": {}, "source": [ "#### **加载数据**\n", @@ -351,7 +351,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f75b6284-80d4-4c74-981c-c07641786906", + "id": "18", "metadata": { "tags": [] }, @@ -397,7 +397,7 @@ }, { "cell_type": "markdown", - "id": "6c979ccd-477b-491a-8b4c-39bee552ca70", + "id": "19", "metadata": {}, "source": [ "#### **类别标签说明**\n", @@ -419,7 +419,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9e3afe8f-8f23-42f4-8536-94e1c5893c05", + "id": "20", "metadata": { "tags": [] }, @@ -439,7 +439,7 @@ }, { "cell_type": "markdown", - "id": "c48cf89d", + "id": "21", "metadata": {}, "source": [ "#### **数据可视化**" @@ -448,7 +448,7 @@ { "cell_type": "code", "execution_count": null, - "id": "baf82d45-7453-4214-b710-6b6b0643202f", + "id": "22", "metadata": { "tags": [] }, @@ -463,7 +463,7 @@ { "cell_type": "code", "execution_count": null, - 
"id": "b3cc1899-836f-43f7-9ac2-156eea0e8f41", + "id": "23", "metadata": { "tags": [] }, @@ -485,7 +485,7 @@ }, { "cell_type": "markdown", - "id": "740dbcee", + "id": "24", "metadata": {}, "source": [ "## 模型构建\n", @@ -496,7 +496,7 @@ }, { "cell_type": "markdown", - "id": "c146acee-2cb8-4120-bc18-21b8f250123d", + "id": "25", "metadata": {}, "source": [ "残差结构是ResNet网络中最重要的结构,由两个分支构成:一个主分支,一个shortcuts。主分支通过堆叠一系列的卷积操作得到,shortcuts从输入直接到输出,主分支的输出与shortcuts的输出相加后通过Relu激活函数后即为残差网络最后的输出。\n", @@ -506,7 +506,7 @@ }, { "cell_type": "markdown", - "id": "11d4ec89-b80b-47ad-b993-58481fe173c5", + "id": "26", "metadata": {}, "source": [ "#### **定义 Building Block**\n", @@ -522,7 +522,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1a99595c-3dc3-4733-8bd1-377edfada070", + "id": "27", "metadata": { "tags": [] }, @@ -570,7 +570,7 @@ }, { "cell_type": "markdown", - "id": "41bec72d-540e-4e17-bcfb-d482ae5d1b3a", + "id": "28", "metadata": {}, "source": [ "#### **定义 Bottleneck**\n", @@ -588,7 +588,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a767e952", + "id": "29", "metadata": { "tags": [] }, @@ -636,7 +636,7 @@ }, { "cell_type": "markdown", - "id": "b45270a9-e34f-4a1e-909f-d2d332a5dab9", + "id": "30", "metadata": {}, "source": [ "#### **构建ResNet网络**\n", @@ -646,7 +646,7 @@ { "cell_type": "code", "execution_count": null, - "id": "40671719-6a91-4a60-af68-79e53ca770f9", + "id": "31", "metadata": { "tags": [] }, @@ -673,7 +673,7 @@ }, { "cell_type": "markdown", - "id": "4f301fd5-d351-4d73-9e2f-30a964aca681", + "id": "32", "metadata": {}, "source": [ "实现典型的ResNet架构,包括多个残差层、卷积层、池化层、全连接层等。" @@ -682,7 +682,7 @@ { "cell_type": "code", "execution_count": null, - "id": "77f3f87a-72c9-4d05-b91c-664983a1da07", + "id": "33", "metadata": { "tags": [] }, @@ -726,7 +726,7 @@ }, { "cell_type": "markdown", - "id": "b1a57ff0-09f9-48cf-8025-61220b306616", + "id": "34", "metadata": {}, "source": [ "使用函数resnet50和辅助函数_resnet,来加载一个预训练的ResNet-50模型,或者返回一个未预训练的ResNet-50模型。" @@ -735,7 
+735,7 @@ { "cell_type": "code", "execution_count": null, - "id": "261e7463-1e4c-4d09-939a-f9520749d559", + "id": "35", "metadata": { "tags": [] }, @@ -762,7 +762,7 @@ }, { "cell_type": "markdown", - "id": "716057a3-f5aa-4d39-9cd1-1b582c281c85", + "id": "36", "metadata": {}, "source": [ "#### **ResNet分类模型初始化**\n", @@ -772,7 +772,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1ee91e4d-0d21-43f4-a72d-7ebbd76188a7", + "id": "37", "metadata": { "tags": [] }, @@ -787,7 +787,7 @@ }, { "cell_type": "markdown", - "id": "5d80b693", + "id": "38", "metadata": {}, "source": [ "## 模型训练\n", @@ -806,7 +806,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9794104f", + "id": "39", "metadata": { "tags": [] }, @@ -827,7 +827,7 @@ }, { "cell_type": "markdown", - "id": "54792cd0", + "id": "40", "metadata": {}, "source": [ "#### **定义训练推理函数**\n", @@ -837,7 +837,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2a228b1a", + "id": "41", "metadata": { "tags": [] }, @@ -885,7 +885,7 @@ }, { "cell_type": "markdown", - "id": "7cb94f32", + "id": "42", "metadata": {}, "source": [ "#### **开始训练**\n", @@ -896,7 +896,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2378a76f", + "id": "43", "metadata": { "tags": [] }, @@ -930,7 +930,7 @@ }, { "cell_type": "markdown", - "id": "07c4a56f-638c-48e6-a515-aa31a492224f", + "id": "44", "metadata": {}, "source": [ "#### **结果可视化展示**\n", @@ -940,7 +940,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8defcb02", + "id": "45", "metadata": { "tags": [] }, @@ -973,7 +973,7 @@ }, { "cell_type": "markdown", - "id": "fbf2898d", + "id": "46", "metadata": {}, "source": [ "## 模型推理\n", @@ -983,7 +983,7 @@ }, { "cell_type": "markdown", - "id": "86dce5f6-0709-4196-b8a5-d31c860d0ea3", + "id": "47", "metadata": {}, "source": [ "#### **加载模型**\n", @@ -993,7 +993,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a5b7a0f2", + "id": "48", "metadata": { "tags": [] }, @@ -1013,7 +1013,7 @@ }, { 
"cell_type": "markdown", - "id": "f43d219a-0ecb-4fcd-9d51-62bdad33e48d", + "id": "49", "metadata": {}, "source": [ "#### **通过传入测试数据集进行推理**\n", @@ -1023,7 +1023,7 @@ { "cell_type": "code", "execution_count": null, - "id": "196160df-c480-4722-859f-578cb7bf422c", + "id": "50", "metadata": { "tags": [] }, @@ -1062,7 +1062,7 @@ }, { "cell_type": "markdown", - "id": "49db7c27-794a-4478-8615-63315a0210ea", + "id": "51", "metadata": {}, "source": [ "展示模型的预测结果与真实标签的对比" @@ -1071,7 +1071,7 @@ { "cell_type": "code", "execution_count": null, - "id": "54e90bbd-cc6f-4776-b1ea-810fe9724947", + "id": "52", "metadata": { "tags": [] }, @@ -1082,7 +1082,7 @@ }, { "cell_type": "markdown", - "id": "cb0416a0-b6ce-44c7-b7d2-06b3319223cc", + "id": "53", "metadata": {}, "source": [ "### **参考文献**\n", diff --git a/cv/README.md b/cv/README.md index d8edae0..c225229 100644 --- a/cv/README.md +++ b/cv/README.md @@ -8,7 +8,7 @@ This directory contains ready-to-use Computer Vision application notebooks built | :-- | :---- | :-------------------------------- | | 1 | [ResNet](./resnet/) | Includes notebooks for ResNet finetuning on tasks such as chinese herbal classification | | 2 | [U-Net](./unet/) | Includes notebooks for U-Net training on tasks such as segmentation | -| 3 | [SAM](./sam/) | Includes notebooks for using SAM to inference | +| 3 | [OCR](./ocr/) | Includes notebooks for OCR inference on tasks such as deepseek-ocr demo | ## Contributing New CV Applications diff --git a/cv/ocr/app.py b/cv/ocr/app.py new file mode 100644 index 0000000..93fb73b --- /dev/null +++ b/cv/ocr/app.py @@ -0,0 +1,480 @@ +""" +DeepSeek-OCR MindSpore DEMO +基于 MindSpore 2.7.0 + MindNLP 0.5.1 的文本识别与结构化解析交互式 DEMO +支持流式生成、token 时间统计和性能优化 +""" + +import os +import math +import time +import types +import tempfile +from threading import Thread +from typing import Optional + +from PIL import Image, ImageOps + +import mindspore as ms +ms.set_context(device_target="Ascend", device_id=0) + +import mindnlp # noqa: F401 — 
patches transformers for MindSpore +import mindtorch +import torch +import torch.nn.functional as F + +from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer + +import gradio as gr + +# ============================================================ +# 全局配置 +# ============================================================ +MODEL_NAME = "lvyufeng/DeepSeek-OCR" +IMAGE_TOKEN = "<image>" +IMAGE_TOKEN_ID = 128815 +PATCH_SIZE = 16 +DOWNSAMPLE_RATIO = 4 +BOS_ID = 0 +STOP_STR = "<|end▁of▁sentence|>" + +# 分辨率预设 +RESOLUTION_PRESETS = { + "Tiny (512, 快速)": {"base_size": 512, "image_size": 512, "crop_mode": False}, + "Small (640)": {"base_size": 640, "image_size": 640, "crop_mode": False}, + "Base (1024)": {"base_size": 1024, "image_size": 1024, "crop_mode": False}, + "Large (1280)": {"base_size": 1280, "image_size": 1280, "crop_mode": False}, + "Gundam (推荐)": {"base_size": 1024, "image_size": 640, "crop_mode": True}, +} + +# 任务类型 +TASK_PROMPTS = { + "Free OCR": "<image>\nFree OCR. ", + "转换为 Markdown": "<image>\n<|grounding|>Convert the document to markdown. ", + "解析图表": "<image>\nParse the figure. ", + "文本定位": "<image>\n<|grounding|>Find \"{ref_text}\". 
", +} + +# ============================================================ +# 模型加载(从模型文件中导入辅助函数) +# ============================================================ +print("=" * 60) +print("正在加载 DeepSeek-OCR 模型...") +print(f"模型: {MODEL_NAME}") +print("=" * 60) + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) +model = AutoModel.from_pretrained( + MODEL_NAME, + _attn_implementation="eager", + trust_remote_code=True, + use_safetensors=True, + device_map="auto", +) +model = model.eval() + +print("正在合并 MoE 权重 (combine_moe)...") +model.combine_moe() + +# 修复 NPU 不支持 scatter_add 的问题:用 one_hot + 矩阵乘法替代 +def _patched_forward_for_moe(self, hidden_states): + batch_size, sequence_length, hidden_dim = hidden_states.shape + selected_experts, routing_weights = self.gate(hidden_states) + n_experts = self.config.n_routed_experts + routing_weights = routing_weights.to(hidden_states.dtype) + # 用 one_hot 替代 scatter_add + one_hot = F.one_hot(selected_experts, n_experts).to(routing_weights.dtype) + router_scores = (one_hot * routing_weights.unsqueeze(-1)).sum(dim=1) + hidden_states = hidden_states.view(-1, hidden_dim) + if self.config.n_shared_experts is not None: + shared_expert_output = self.shared_experts(hidden_states) + hidden_w1 = torch.matmul(hidden_states, self.w1) + hidden_w3 = torch.matmul(hidden_states, self.w3) + hidden_states = self.act(hidden_w1) * hidden_w3 + hidden_states = torch.bmm(hidden_states, self.w2) * torch.transpose(router_scores, 0, 1).unsqueeze(-1) + final_hidden_states = hidden_states.sum(dim=0, dtype=hidden_states.dtype) + if self.config.n_shared_experts is not None: + hidden_states = final_hidden_states + shared_expert_output + return hidden_states.view(batch_size, sequence_length, hidden_dim) + +# 对所有 MoE 层应用修复后的 forward +for layer in model.model.layers: + if hasattr(layer.mlp, 'w1'): # combine_moe 已处理的层 + layer.mlp.forward = types.MethodType(_patched_forward_for_moe, layer.mlp) + +print("模型加载完成!") +print("=" * 60) + +# 从模型的 
trust_remote_code 模块中获取辅助函数 +# 这些函数通过 trust_remote_code=True 加载后可在模块中找到 +_model_module = type(model).__module__ +import importlib + +_mod = importlib.import_module(_model_module) +format_messages = _mod.format_messages +load_pil_images = _mod.load_pil_images +text_encode = _mod.text_encode +BasicImageTransform = _mod.BasicImageTransform +dynamic_preprocess = _mod.dynamic_preprocess +re_match = _mod.re_match +process_image_with_refs = _mod.process_image_with_refs + + +# ============================================================ +# 图像预处理(从 model.infer() 方法中抽取) +# ============================================================ +def prepare_inputs(prompt_text: str, image_file: str, base_size: int, image_size: int, crop_mode: bool): + """ + 从 model.infer() 方法 (modeling_deepseekocr.py:732-937) 中抽取的图像预处理逻辑。 + 构建 conversation -> format_messages -> 图像 token 化 -> 返回模型输入张量。 + """ + # 1. 构建 conversation + conversation = [ + { + "role": "<|User|>", + "content": prompt_text, + "images": [image_file], + }, + {"role": "<|Assistant|>", "content": ""}, + ] + + # 2. format_messages 转换 prompt + formatted_prompt = format_messages(conversations=conversation, sft_format="plain", system_prompt="") + + # 3. 加载图片 + images = load_pil_images(conversation) + image_draw = images[0].copy() + + # 4. 
图像 token 化 + image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True) + + text_splits = formatted_prompt.split(IMAGE_TOKEN) + + images_list, images_crop_list, images_seq_mask = [], [], [] + tokenized_str = [] + images_spatial_crop = [] + + for text_sep, image in zip(text_splits, images): + tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False) + tokenized_str += tokenized_sep + images_seq_mask += [False] * len(tokenized_sep) + + if crop_mode: + if image.size[0] <= 640 and image.size[1] <= 640: + crop_ratio = [1, 1] + else: + images_crop_raw, crop_ratio = dynamic_preprocess(image) + + # 全局视图 + global_view = ImageOps.pad( + image, (base_size, base_size), + color=tuple(int(x * 255) for x in image_transform.mean), + ) + images_list.append(image_transform(global_view).to(model.dtype)) + + width_crop_num, height_crop_num = crop_ratio + images_spatial_crop.append([width_crop_num, height_crop_num]) + + if width_crop_num > 1 or height_crop_num > 1: + for i in range(len(images_crop_raw)): + images_crop_list.append(image_transform(images_crop_raw[i]).to(model.dtype)) + + num_queries = math.ceil((image_size // PATCH_SIZE) / DOWNSAMPLE_RATIO) + num_queries_base = math.ceil((base_size // PATCH_SIZE) / DOWNSAMPLE_RATIO) + + # 图像 token 序列 + tokenized_image = ([IMAGE_TOKEN_ID] * num_queries_base + [IMAGE_TOKEN_ID]) * num_queries_base + tokenized_image += [IMAGE_TOKEN_ID] + if width_crop_num > 1 or height_crop_num > 1: + tokenized_image += ( + [IMAGE_TOKEN_ID] * (num_queries * width_crop_num) + [IMAGE_TOKEN_ID] + ) * (num_queries * height_crop_num) + tokenized_str += tokenized_image + images_seq_mask += [True] * len(tokenized_image) + else: + if image_size <= 640: + image = image.resize((image_size, image_size)) + global_view = ImageOps.pad( + image, (image_size, image_size), + color=tuple(int(x * 255) for x in image_transform.mean), + ) + images_list.append(image_transform(global_view).to(model.dtype)) + + width_crop_num, 
height_crop_num = 1, 1 + images_spatial_crop.append([width_crop_num, height_crop_num]) + + num_queries = math.ceil((image_size // PATCH_SIZE) / DOWNSAMPLE_RATIO) + + tokenized_image = ([IMAGE_TOKEN_ID] * num_queries + [IMAGE_TOKEN_ID]) * num_queries + tokenized_image += [IMAGE_TOKEN_ID] + tokenized_str += tokenized_image + images_seq_mask += [True] * len(tokenized_image) + + # 最后一段文本 + tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False) + tokenized_str += tokenized_sep + images_seq_mask += [False] * len(tokenized_sep) + + # 添加 BOS token + tokenized_str = [BOS_ID] + tokenized_str + images_seq_mask = [False] + images_seq_mask + + # 转为张量 + input_ids = torch.LongTensor(tokenized_str) + images_seq_mask_t = torch.tensor(images_seq_mask, dtype=torch.bool) + + if len(images_list) == 0: + images_ori = torch.zeros((1, 3, image_size, image_size)) + images_spatial_crop_t = torch.zeros((1, 2), dtype=torch.long) + images_crop = torch.zeros((1, 3, base_size, base_size)) + else: + images_ori = torch.stack(images_list, dim=0) + images_spatial_crop_t = torch.tensor(images_spatial_crop, dtype=torch.long) + if images_crop_list: + images_crop = torch.stack(images_crop_list, dim=0) + else: + images_crop = torch.zeros((1, 3, base_size, base_size)) + + return { + "input_ids": input_ids.unsqueeze(0).cuda(), + "images": [(images_crop.cuda(), images_ori.cuda())], + "images_seq_mask": images_seq_mask_t.unsqueeze(0).cuda(), + "images_spatial_crop": images_spatial_crop_t, + "image_draw": image_draw, + } + + +# ============================================================ +# 后处理:标注图生成 +# ============================================================ +def postprocess_output(raw_text: str, image_draw: Image.Image): + """处理模型输出,生成带标注的图像。""" + if raw_text.endswith(STOP_STR): + raw_text = raw_text[: -len(STOP_STR)] + raw_text = raw_text.strip() + + matches_ref, matches_images, matches_other = re_match(raw_text) + + annotated_image = None + if matches_ref: + with 
tempfile.TemporaryDirectory() as tmp_dir: + os.makedirs(os.path.join(tmp_dir, "images"), exist_ok=True) + annotated_image = process_image_with_refs(image_draw, matches_ref, tmp_dir) + + # 无标注时返回原图 + if annotated_image is None: + annotated_image = image_draw + + # 清理特殊标记,保留可读文本 + # matches_ref 是元组列表: [(full_match, ref_text, det_coords), ...] + display_text = raw_text + for full_match, ref_text, det_coords in matches_ref: + if ref_text == "image": + display_text = display_text.replace(full_match, "[图片区域]") + else: + # 仅去除定位标签,保留引用文本内容 + display_text = display_text.replace(full_match, ref_text) + display_text = display_text.replace("\\coloneqq", ":=").replace("\\eqqcolon", "=:") + + return display_text, annotated_image + + +# ============================================================ +# 流式推理 + 时间统计 +# ============================================================ +def format_metrics(ttft: Optional[float], token_count: int, t_start: float) -> str: + """格式化性能指标。""" + elapsed = time.time() - t_start + lines = [] + lines.append(f"**首 Token 延迟 (TTFT)**: {ttft:.3f}s" if ttft else "**首 Token 延迟 (TTFT)**: 等待中...") + lines.append(f"**已生成 Token 数**: {token_count}") + lines.append(f"**总耗时**: {elapsed:.2f}s") + if token_count > 0 and elapsed > 0: + tokens_per_sec = token_count / elapsed + lines.append(f"**生成速度**: {tokens_per_sec:.2f} tokens/s") + if token_count > 1 and ttft: + decode_time = elapsed - ttft + decode_speed = (token_count - 1) / decode_time if decode_time > 0 else 0 + lines.append(f"**解码速度** (不含首 token): {decode_speed:.2f} tokens/s") + return "\n\n".join(lines) + + +def stream_ocr(image, resolution, task_type, ref_text): + """ + 流式 OCR 推理函数。 + 使用 TextIteratorStreamer 实现流式 token 输出。 + """ + if image is None: + yield "请上传图片", None, "请先上传一张图片" + return + + # 获取分辨率参数 + preset = RESOLUTION_PRESETS[resolution] + base_size = preset["base_size"] + image_size = preset["image_size"] + crop_mode = preset["crop_mode"] + + # 构建 prompt + prompt_template = TASK_PROMPTS[task_type] + 
if "{ref_text}" in prompt_template: + if not ref_text or not ref_text.strip(): + yield "请输入要定位的文本", None, "「文本定位」模式需要输入引用文本" + return + prompt_text = prompt_template.format(ref_text=ref_text.strip()) + else: + prompt_text = prompt_template + + # 保存临时图片文件供模型使用 + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: + tmp_path = tmp.name + Image.fromarray(image).save(tmp_path) + + try: + # 1. 准备输入 + model.disable_torch_init() + inputs = prepare_inputs(prompt_text, tmp_path, base_size, image_size, crop_mode) + image_draw = inputs.pop("image_draw") + + # 2. 创建 streamer + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False) + + # 3. 后台线程运行 generate + generate_kwargs = dict( + input_ids=inputs["input_ids"], + images=inputs["images"], + images_seq_mask=inputs["images_seq_mask"], + images_spatial_crop=inputs["images_spatial_crop"], + temperature=0.0, + eos_token_id=tokenizer.eos_token_id, + streamer=streamer, + max_new_tokens=8192, + no_repeat_ngram_size=20, + use_cache=True, + ) + + thread = Thread(target=_generate_with_no_grad, kwargs=generate_kwargs) + + # 4. 流式输出 + 时间统计 + t_start = time.time() + thread.start() + first_token_time = None + token_count = 0 + full_text = "" + + for new_text in streamer: + if first_token_time is None: + first_token_time = time.time() - t_start + token_count += 1 + full_text += new_text + # 流式 yield:显示文本、暂无标注图、实时指标 + display = full_text.replace(STOP_STR, "").strip() + yield display, None, format_metrics(first_token_time, token_count, t_start) + + thread.join() + + # 5. 
最终后处理 + display_text, annotated_image = postprocess_output(full_text, image_draw) + final_metrics = format_metrics(first_token_time, token_count, t_start) + yield display_text, annotated_image, final_metrics + + finally: + os.unlink(tmp_path) + + +def _generate_with_no_grad(**kwargs): + """在 no_grad 上下文中运行 model.generate。""" + with torch.no_grad(): + model.generate(**kwargs) + + +# ============================================================ +# Gradio UI +# ============================================================ +def toggle_ref_text(task_type): + """根据任务类型切换引用文本输入框可见性。""" + return gr.update(visible=(task_type == "文本定位")) + + +DESCRIPTION = """ +# DeepSeek-OCR MindSpore DEMO + +基于 **MindSpore 2.7.0 + MindNLP 0.5.1** 的文本识别与结构化解析交互式演示。 + +**模型**: DeepSeek-OCR | **硬件**: Ascend NPU 910B | **优化**: MoE 权重合并 + KV Cache + +### 性能优化说明 +| 优化项 | 说明 | +|--------|------| +| `combine_moe()` | 合并 MoE 专家权重,减少内存访问开销 | +| `scatter_add` 适配 | 用 `one_hot` + 矩阵乘法替代 NPU 不支持的 `scatter_add` | +| `use_cache=True` | 启用 KV Cache,避免重复计算注意力 | +| `no_repeat_ngram_size=20` | 控制重复生成,提升有效 token 效率 | +| `eager` attention | Ascend NPU 上兼容性最佳的注意力实现 | +| `float32` 精度 | 保证 OCR 输出质量(float16 存在精度退化)| + +### 优化前后对比(Gundam 模式,Ascend 910B,256 tokens) +| 配置 | TTFT | 生成速度 | 解码速度 | 加速比 | +|------|------|----------|----------|--------| +| **全部优化** | 9.757s | 7.95 tok/s | **11.34 tok/s** | **基线** | +| 关闭 MoE 合并 | 10.805s | 1.68 tok/s | 2.29 tok/s | **4.95x 慢** | + +### 不同分辨率模式对比(256 tokens) +| 模式 | TTFT | 生成速度 | 解码速度 | 适用场景 | +|------|------|----------|----------|----------| +| Tiny (512) | **0.214s** | **11.00 tok/s** | 11.06 tok/s | 快速预览 | +| Small (640) | 0.257s | 10.76 tok/s | 10.83 tok/s | 一般文档 | +| **Gundam (推荐)** | 9.757s | 7.95 tok/s | 11.34 tok/s | **精度最佳** | +""" + +with gr.Blocks(title="DeepSeek-OCR MindSpore DEMO") as demo: + gr.Markdown(DESCRIPTION) + + with gr.Row(): + # 左侧:输入区 + with gr.Column(scale=1): + input_image = gr.Image(label="上传图片", type="numpy", height=400) + + resolution = 
gr.Dropdown( + choices=list(RESOLUTION_PRESETS.keys()), + value="Gundam (推荐)", + label="分辨率模式", + info="Gundam 模式在精度和速度之间取得最佳平衡", + ) + + task_type = gr.Dropdown( + choices=list(TASK_PROMPTS.keys()), + value="Free OCR", + label="任务类型", + ) + + ref_text_input = gr.Textbox( + label="引用文本(仅「文本定位」模式)", + placeholder="输入要定位的文本...", + visible=False, + ) + + run_btn = gr.Button("开始识别", variant="primary", size="lg") + + # 右侧:输出区 + with gr.Column(scale=1): + output_text = gr.Textbox( + label="OCR 识别结果", + lines=15, + max_lines=30, + buttons=["copy"], + ) + output_image = gr.Image(label="标注结果图", height=300) + metrics_display = gr.Markdown(label="性能统计", value="等待推理...") + + # 事件绑定 + task_type.change(fn=toggle_ref_text, inputs=task_type, outputs=ref_text_input) + + run_btn.click( + fn=stream_ocr, + inputs=[input_image, resolution, task_type, ref_text_input], + outputs=[output_text, output_image, metrics_display], + ) + + +if __name__ == "__main__": + demo.queue() + # theme 应在 gr.Blocks(theme=...) 中设置,Blocks.launch() 不接受 theme 参数 + demo.launch(server_name="0.0.0.0", server_port=7860, share=False) diff --git a/cv/ocr/inference_deepSeekorc_demo.ipynb b/cv/ocr/inference_deepSeekorc_demo.ipynb new file mode 100644 index 0000000..74b47b9 --- /dev/null +++ b/cv/ocr/inference_deepSeekorc_demo.ipynb @@ -0,0 +1,520 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# DeepSeek-OCR MindSpore DEMO\n", + "\n", + "基于 **MindSpore 2.7.0 + MindNLP 0.5.1** 的文本识别与结构化解析演示。\n", + "\n", + "## 环境要求\n", + "\n", + "| 组件 | 版本 |\n", + "|------|------|\n", + "| Python | 3.10 |\n", + "| MindSpore | 2.7.0 |\n", + "| MindNLP | 0.5.1 |\n", + "| transformers | 4.57.3 |\n", + "| Gradio | 6.1.0 |\n", + "| 硬件 | Ascend NPU 910B (65536MB HBM) |\n", + "| CANN | 8.2.RC2 |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 1: 环境检查\n", + "import mindspore as ms\n", + "ms.set_context(device_target=\"Ascend\", device_id=0)\n", + "print(f\"MindSpore version: 
{ms.__version__}\")\n", + "\n", + "import mindnlp\n", + "print(f\"MindNLP available\")\n", + "\n", + "import transformers\n", + "print(f\"transformers version: {transformers.__version__}\")\n", + "\n", + "import gradio as gr\n", + "print(f\"Gradio version: {gr.__version__}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 模型加载\n", + "\n", + "使用 MindNLP 的 transformers 兼容接口加载 DeepSeek-OCR 模型。\n", + "\n", + "**关键参数说明**:\n", + "- `_attn_implementation='eager'`: Ascend NPU 上兼容性最佳的注意力实现\n", + "- `trust_remote_code=True`: 加载模型自定义代码\n", + "- `use_safetensors=True`: 使用安全张量格式" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 2: 模型加载\n", + "import types\n", + "import mindnlp\n", + "import mindtorch\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from transformers import AutoModel, AutoTokenizer\n", + "\n", + "model_name = 'lvyufeng/DeepSeek-OCR'\n", + "\n", + "print(\"加载 tokenizer...\")\n", + "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n", + "\n", + "print(\"加载模型 (float32)...\")\n", + "model = AutoModel.from_pretrained(\n", + " model_name,\n", + " _attn_implementation='eager',\n", + " trust_remote_code=True,\n", + " use_safetensors=True,\n", + " device_map='auto'\n", + ")\n", + "model = model.eval()\n", + "\n", + "print(\"合并 MoE 权重...\")\n", + "model.combine_moe()\n", + "\n", + "# NPU 不支持 scatter_add 用 one_hot 替代\n", + "def _patched_forward_for_moe(self, hidden_states):\n", + " batch_size, sequence_length, hidden_dim = hidden_states.shape\n", + " selected_experts, routing_weights = self.gate(hidden_states)\n", + " n_experts = self.config.n_routed_experts\n", + " routing_weights = routing_weights.to(hidden_states.dtype)\n", + " one_hot = F.one_hot(selected_experts, n_experts).to(routing_weights.dtype)\n", + " router_scores = (one_hot * routing_weights.unsqueeze(-1)).sum(dim=1)\n", + " hidden_states = 
hidden_states.view(-1, hidden_dim)\n", + " if self.config.n_shared_experts is not None:\n", + " shared_expert_output = self.shared_experts(hidden_states)\n", + " hidden_w1 = torch.matmul(hidden_states, self.w1)\n", + " hidden_w3 = torch.matmul(hidden_states, self.w3)\n", + " hidden_states = self.act(hidden_w1) * hidden_w3\n", + " hidden_states = torch.bmm(hidden_states, self.w2) * torch.transpose(router_scores, 0, 1).unsqueeze(-1)\n", + " final_hidden_states = hidden_states.sum(dim=0, dtype=hidden_states.dtype)\n", + " if self.config.n_shared_experts is not None:\n", + " final_hidden_states = final_hidden_states + shared_expert_output\n", + " return final_hidden_states.view(batch_size, sequence_length, hidden_dim)\n", + "\n", + "for layer in model.model.layers:\n", + " if hasattr(layer.mlp, 'w1'):\n", + " layer.mlp.forward = types.MethodType(_patched_forward_for_moe, layer.mlp)\n", + "\n", + "print(\"模型加载完成!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 单张图片推理示例(非流式)\n", "\n", "使用 `model.infer()` 方法进行标准推理,支持多种分辨率模式:\n", "\n", "| 模式 | base_size | image_size | crop_mode | 适用场景 |\n", "|------|-----------|------------|-----------|----------|\n", "| Tiny | 512 | 512 | False | 快速预览 |\n", "| Small | 640 | 640 | False | 一般文档 |\n", "| Base | 1024 | 1024 | False | 高质量 |\n", "| Large | 1280 | 1280 | False | 超高分辨率 |\n", "| **Gundam** | **1024** | **640** | **True** | **推荐:精度速度最佳平衡** |" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Cell 3: 单张图片推理(非流式)\n", "import time\n", "import os\n", "\n", "# 准备测试图片\n", "# 如果没有测试图片,可以从 HuggingFace 下载\n", "image_file = 'image_ocr.jpg'\n", "if not os.path.exists(image_file):\n", " import urllib.request\n", " url = 'https://hf-mirror.com/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg'\n", " print(f\"下载测试图片: {url}\")\n", " urllib.request.urlretrieve(url, image_file)\n", " 
print(\"下载完成\")\n", + "\n", + "prompt = \"\\nFree OCR. \"\n", + "output_path = './output'\n", + "os.makedirs(output_path, exist_ok=True)\n", + "\n", + "print(\"开始推理 (Gundam 模式)...\")\n", + "t0 = time.time()\n", + "\n", + "with mindtorch.no_grad():\n", + " res = model.infer(\n", + " tokenizer,\n", + " prompt=prompt,\n", + " image_file=image_file,\n", + " output_path=output_path,\n", + " base_size=1024,\n", + " image_size=640,\n", + " crop_mode=True,\n", + " save_results=True,\n", + " test_compress=True,\n", + " )\n", + "\n", + "elapsed = time.time() - t0\n", + "print(f\"\\n推理完成,总耗时: {elapsed:.2f}s\")\n", + "\n", + "# 显示结果\n", + "if os.path.exists(f'{output_path}/result.mmd'):\n", + " with open(f'{output_path}/result.mmd', 'r') as f:\n", + " print(\"\\n识别结果:\")\n", + " print(f.read())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 流式生成 + 时间统计\n", + "\n", + "使用 `TextIteratorStreamer` 实现流式 token 输出,可以在生成过程中实时查看结果。\n", + "\n", + "**核心思路**:\n", + "1. 从 `model.infer()` 中抽取图像预处理逻辑为独立函数\n", + "2. 用 `TextIteratorStreamer` 替换原始 `NoEOSTextStreamer`\n", + "3. 在独立线程中运行 `model.generate()`\n", + "4. 
主线程通过 streamer 迭代获取 token 并统计时间" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Cell 4: 流式生成 + 时间统计\n", "import math\n", "import importlib\n", "from threading import Thread\n", "from PIL import Image, ImageOps\n", "from transformers import TextIteratorStreamer\n", "\n", "# 导入模型辅助函数\n", "_mod = importlib.import_module(type(model).__module__)\n", "format_messages = _mod.format_messages\n", "load_pil_images = _mod.load_pil_images\n", "text_encode = _mod.text_encode\n", "BasicImageTransform = _mod.BasicImageTransform\n", "dynamic_preprocess = _mod.dynamic_preprocess\n", "\n", "IMAGE_TOKEN = '<image>'\n", "IMAGE_TOKEN_ID = 128815\n", "PATCH_SIZE = 16\n", "DOWNSAMPLE_RATIO = 4\n", "BOS_ID = 0\n", "\n", "\n", "def prepare_inputs(prompt_text, image_file, base_size, image_size, crop_mode):\n", " \"\"\"从 model.infer() 中抽取的图像预处理逻辑。\"\"\"\n", " conversation = [\n", " {\"role\": \"<|User|>\", \"content\": prompt_text, \"images\": [image_file]},\n", " {\"role\": \"<|Assistant|>\", \"content\": \"\"},\n", " ]\n", " formatted_prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='')\n", " images = load_pil_images(conversation)\n", "\n", " image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)\n", " text_splits = formatted_prompt.split(IMAGE_TOKEN)\n", "\n", " images_list, images_crop_list, images_seq_mask = [], [], []\n", " tokenized_str = []\n", " images_spatial_crop = []\n", "\n", " for text_sep, image in zip(text_splits, images):\n", " tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)\n", " tokenized_str += tokenized_sep\n", " images_seq_mask += [False] * len(tokenized_sep)\n", "\n", " if crop_mode:\n", " if image.size[0] <= 640 and image.size[1] <= 640:\n", " crop_ratio = [1, 1]\n", " else:\n", " images_crop_raw, crop_ratio = 
dynamic_preprocess(image)\n", + "\n", + " global_view = ImageOps.pad(image, (base_size, base_size),\n", + " color=tuple(int(x * 255) for x in image_transform.mean))\n", + " images_list.append(image_transform(global_view).to(model.dtype))\n", + " width_crop_num, height_crop_num = crop_ratio\n", + " images_spatial_crop.append([width_crop_num, height_crop_num])\n", + "\n", + " if width_crop_num > 1 or height_crop_num > 1:\n", + " for i in range(len(images_crop_raw)):\n", + " images_crop_list.append(image_transform(images_crop_raw[i]).to(model.dtype))\n", + "\n", + " num_queries = math.ceil((image_size // PATCH_SIZE) / DOWNSAMPLE_RATIO)\n", + " num_queries_base = math.ceil((base_size // PATCH_SIZE) / DOWNSAMPLE_RATIO)\n", + "\n", + " tokenized_image = ([IMAGE_TOKEN_ID] * num_queries_base + [IMAGE_TOKEN_ID]) * num_queries_base\n", + " tokenized_image += [IMAGE_TOKEN_ID]\n", + " if width_crop_num > 1 or height_crop_num > 1:\n", + " tokenized_image += ([IMAGE_TOKEN_ID] * (num_queries * width_crop_num) + [IMAGE_TOKEN_ID]) * (\n", + " num_queries * height_crop_num)\n", + " tokenized_str += tokenized_image\n", + " images_seq_mask += [True] * len(tokenized_image)\n", + " else:\n", + " if image_size <= 640:\n", + " image = image.resize((image_size, image_size))\n", + " global_view = ImageOps.pad(image, (image_size, image_size),\n", + " color=tuple(int(x * 255) for x in image_transform.mean))\n", + " images_list.append(image_transform(global_view).to(model.dtype))\n", + " images_spatial_crop.append([1, 1])\n", + "\n", + " num_queries = math.ceil((image_size // PATCH_SIZE) / DOWNSAMPLE_RATIO)\n", + " tokenized_image = ([IMAGE_TOKEN_ID] * num_queries + [IMAGE_TOKEN_ID]) * num_queries\n", + " tokenized_image += [IMAGE_TOKEN_ID]\n", + " tokenized_str += tokenized_image\n", + " images_seq_mask += [True] * len(tokenized_image)\n", + "\n", + " tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)\n", + " tokenized_str += tokenized_sep\n", + " images_seq_mask += 
[False] * len(tokenized_sep)\n", + " tokenized_str = [BOS_ID] + tokenized_str\n", + " images_seq_mask = [False] + images_seq_mask\n", + "\n", + " input_ids = torch.LongTensor(tokenized_str)\n", + " images_seq_mask_t = torch.tensor(images_seq_mask, dtype=torch.bool)\n", + "\n", + " if len(images_list) == 0:\n", + " images_ori = torch.zeros((1, 3, image_size, image_size))\n", + " images_spatial_crop_t = torch.zeros((1, 2), dtype=torch.long)\n", + " images_crop = torch.zeros((1, 3, base_size, base_size))\n", + " else:\n", + " images_ori = torch.stack(images_list, dim=0)\n", + " images_spatial_crop_t = torch.tensor(images_spatial_crop, dtype=torch.long)\n", + " images_crop = torch.stack(images_crop_list, dim=0) if images_crop_list else torch.zeros((1, 3, base_size, base_size))\n", + "\n", + " return {\n", + " 'input_ids': input_ids.unsqueeze(0).cuda(),\n", + " 'images': [(images_crop.cuda(), images_ori.cuda())],\n", + " 'images_seq_mask': images_seq_mask_t.unsqueeze(0).cuda(),\n", + " 'images_spatial_crop': images_spatial_crop_t,\n", + " }\n", + "\n", + "\n", + "# 流式推理\n", + "prompt_text = \"<image>\\nFree OCR. 
\"\n", + "\n", + "model.disable_torch_init()\n", + "inputs = prepare_inputs(prompt_text, image_file, base_size=1024, image_size=640, crop_mode=True)\n", + "\n", + "streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)\n", + "\n", + "generate_kwargs = dict(\n", + " input_ids=inputs['input_ids'],\n", + " images=inputs['images'],\n", + " images_seq_mask=inputs['images_seq_mask'],\n", + " images_spatial_crop=inputs['images_spatial_crop'],\n", + " temperature=0.0,\n", + " eos_token_id=tokenizer.eos_token_id,\n", + " streamer=streamer,\n", + " max_new_tokens=8192,\n", + " no_repeat_ngram_size=20,\n", + " use_cache=True,\n", + ")\n", + "\n", + "\n", + "def run_generate():\n", + " with torch.no_grad():\n", + " model.generate(**generate_kwargs)\n", + "\n", + "\n", + "thread = Thread(target=run_generate)\n", + "t_start = time.time()\n", + "thread.start()\n", + "\n", + "first_token_time = None\n", + "token_count = 0\n", + "full_text = \"\"\n", + "STOP_STR = '<|end▁of▁sentence|>'\n", + "\n", + "print(\"流式生成中...\")\n", + "print(\"=\" * 50)\n", + "for new_text in streamer:\n", + " if first_token_time is None:\n", + " first_token_time = time.time() - t_start\n", + " token_count += 1\n", + " full_text += new_text\n", + " # 实时输出\n", + " clean = new_text.replace(STOP_STR, '')\n", + " if clean:\n", + " print(clean, end='', flush=True)\n", + "\n", + "thread.join()\n", + "total_time = time.time() - t_start\n", + "\n", + "print(\"\\n\" + \"=\" * 50)\n", + "print(f\"\\n性能统计:\")\n", + "print(f\" 首 Token 延迟 (TTFT): {first_token_time:.3f}s\")\n", + "print(f\" 总 Token 数: {token_count}\")\n", + "print(f\" 总耗时: {total_time:.2f}s\")\n", + "print(f\" 生成速度: {token_count / total_time:.2f} tokens/s\")\n", + "if token_count > 1 and first_token_time:\n", + " decode_time = total_time - first_token_time\n", + " print(f\" 解码速度 (不含首 token): {(token_count - 1) / decode_time:.2f} tokens/s\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 
性能优化方案说明\n", + "\n", + "### 已实施的优化\n", + "\n", + "| # | 优化项 | 方法 | 效果 |\n", + "|---|--------|------|------|\n", + "| 1 | **MoE 权重合并** | `model.combine_moe()` | 将分散的专家权重合并为矩阵运算,减少内存访问次数,加速前向传播 |\n", + "| 2 | **scatter_add NPU 适配** | `F.one_hot` + 矩阵乘法 | 替换 NPU 不支持的 `scatter_add_ext` 算子,保证 MoE 合并后推理正确性 |\n", + "| 3 | **KV Cache** | `use_cache=True` | 缓存已计算的 Key/Value,避免自回归生成时重复计算所有位置的注意力 |\n", + "| 4 | **N-gram 去重** | `no_repeat_ngram_size=20` | 防止模型生成重复文本,提升有效 token 效率 |\n", + "| 5 | **Eager Attention** | `_attn_implementation='eager'` | 在 Ascend NPU 上比 Flash Attention 兼容性更好,避免算子不支持的问题 |\n", + "\n", + "### 优化前后实测数据对比(Ascend 910B, Gundam 模式, 256 tokens)\n", + "\n", + "| 配置 | TTFT | 生成速度 | 解码速度 | 加速比 |\n", + "|------|------|----------|----------|--------|\n", + "| **全部优化** (combine_moe + KV Cache) | 9.757s | **7.95 tok/s** | **11.34 tok/s** | 基线 |\n", + "| 关闭 MoE 合并 (无 combine_moe) | 10.805s | 1.68 tok/s | 2.29 tok/s | **4.95x 慢** |\n", + "\n", + "> **结论**: `combine_moe()` 是最关键的优化项,使解码速度提升约 **5 倍**(2.29 → 11.34 tok/s)。\n", + "\n", + "### 不同分辨率模式实测对比(256 tokens, 全部优化)\n", + "\n", + "| 模式 | TTFT | 总耗时 | 生成速度 | 解码速度 | 适用场景 |\n", + "|------|------|--------|----------|----------|----------|\n", + "| Tiny (512) | **0.214s** | **23.36s** | **11.00 tok/s** | 11.06 tok/s | 快速预览、低分辨率 |\n", + "| Small (640) | 0.257s | 23.89s | 10.76 tok/s | 10.83 tok/s | 一般文档 |\n", + "| **Gundam (推荐)** | 9.757s | 32.33s | 7.95 tok/s | **11.34 tok/s** | 高精度 OCR |\n", + "\n", + "> **说明**: Gundam 模式 TTFT 较高是因为 crop 模式需要处理多个图像切片(全局视图+局部切片),但解码速度与其他模式持平。Tiny 模式 TTFT 极低(0.2s),适合对延迟敏感的场景。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 5: 不同分辨率模式对比(可选运行)\n", + "import time\n", + "\n", + "results = {}\n", + "modes = {\n", + " 'Tiny': {'base_size': 512, 'image_size': 512, 'crop_mode': False},\n", + " 'Small': {'base_size': 640, 'image_size': 640, 'crop_mode': False},\n", + " 'Gundam': {'base_size': 1024, 'image_size': 640, 
'crop_mode': True},\n", + "}\n", + "\n", + "for mode_name, params in modes.items():\n", + " print(f\"\\n{'='*50}\")\n", + " print(f\"测试模式: {mode_name} (base={params['base_size']}, img={params['image_size']}, crop={params['crop_mode']})\")\n", + " print(f\"{'='*50}\")\n", + "\n", + " inputs = prepare_inputs(prompt_text, image_file, **params)\n", + " streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)\n", + "\n", + " gen_kwargs = dict(\n", + " input_ids=inputs['input_ids'],\n", + " images=inputs['images'],\n", + " images_seq_mask=inputs['images_seq_mask'],\n", + " images_spatial_crop=inputs['images_spatial_crop'],\n", + " temperature=0.0,\n", + " eos_token_id=tokenizer.eos_token_id,\n", + " streamer=streamer,\n", + " max_new_tokens=4096,\n", + " no_repeat_ngram_size=20,\n", + " use_cache=True,\n", + " )\n", + "\n", + " def _gen():\n", + " with torch.no_grad():\n", + " model.generate(**gen_kwargs)\n", + "\n", + " thread = Thread(target=_gen)\n", + " t0 = time.time()\n", + " thread.start()\n", + "\n", + " ttft = None\n", + " n_tokens = 0\n", + " for text in streamer:\n", + " if ttft is None:\n", + " ttft = time.time() - t0\n", + " n_tokens += 1\n", + " thread.join()\n", + " total = time.time() - t0\n", + "\n", + " results[mode_name] = {'ttft': ttft, 'tokens': n_tokens, 'total': total, 'tps': n_tokens / total}\n", + " print(f\" TTFT: {ttft:.3f}s | Tokens: {n_tokens} | Total: {total:.2f}s | Speed: {n_tokens/total:.2f} tok/s\")\n", + "\n", + "print(f\"\\n{'='*60}\")\n", + "print(\"对比汇总:\")\n", + "print(f\"{'模式':<10} {'TTFT':>8} {'Tokens':>8} {'总耗时':>8} {'速度':>12}\")\n", + "print(f\"{'-'*50}\")\n", + "for name, r in results.items():\n", + " print(f\"{name:<10} {r['ttft']:>7.3f}s {r['tokens']:>8} {r['total']:>7.2f}s {r['tps']:>8.2f} tok/s\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Gradio 交互 DEMO\n", + "\n", + "启动完整的 Gradio Web 界面,支持:\n", + "- 图片上传\n", + "- 多种分辨率模式选择\n", + "- 多种任务类型(Free OCR / 
Markdown / 图表解析 / 文本定位)\n", + "- 流式文本输出\n", + "- 实时性能统计\n", + "\n", + "魔乐社区链接:(待更新)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 相关链接\n", + "\n", + "- **魔乐社区 (Modelers)**: [DeepSeek-OCR 模型页](https://modelers.cn/)\n", + "- **HuggingFace 模型**: [lvyufeng/DeepSeek-OCR](https://huggingface.co/lvyufeng/DeepSeek-OCR)\n", + "- **MindNLP 项目**: [GitHub - mindnlp](https://github.com/mindspore-lab/mindnlp)\n", + "- **MindSpore 官网**: [mindspore.cn](https://www.mindspore.cn/)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/cv/sam/inference_sam_segmentation.ipynb b/cv/sam/inference_sam_segmentation.ipynb index 278e9dd..95ce0eb 100644 --- a/cv/sam/inference_sam_segmentation.ipynb +++ b/cv/sam/inference_sam_segmentation.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "8285b580", + "id": "0", "metadata": {}, "source": [ "# 基于MindSpore 和 MindSpore NLP 的 Segment Anything Model(SAM)通用图像分割推理任务\n", @@ -26,7 +26,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1b550a2d-7114-4a32-bcb6-4f6a8504da19", + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -36,7 +36,7 @@ }, { "cell_type": "markdown", - "id": "f5a0708f", + "id": "2", "metadata": {}, "source": [ "如果你在如昇思大模型平台、华为云ModelArts、启智社区等算力平台的Jupyter在线编程环境中运行本案例,可取消如下代码的注释,进行依赖库安装:" @@ -45,7 +45,7 @@ { "cell_type": "code", "execution_count": null, - "id": "359d8d32-e0b4-484c-a0b5-0c4b2956253f", + "id": "3", "metadata": {}, "outputs": [], "source": [ @@ -62,7 +62,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f78fd9fb-2b63-4372-adeb-a27331a805c3", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -80,7 +80,7 @@ }, { "cell_type": "markdown", - "id": "d735217e", + "id": "5", "metadata": {}, "source": [ "## 数据加载\n", @@ -94,7 +94,7 
@@ { "cell_type": "code", "execution_count": null, - "id": "610eb8c7-73b3-47d3-ad43-698b7d9cfe1e", + "id": "6", "metadata": {}, "outputs": [], "source": [ @@ -127,7 +127,7 @@ }, { "cell_type": "markdown", - "id": "36ab043f-076b-4e0b-b47d-8d8ceed8e875", + "id": "7", "metadata": {}, "source": [ "#### **数据加载**\n", @@ -141,7 +141,7 @@ { "cell_type": "code", "execution_count": null, - "id": "45a22bc6-9399-4044-9897-a61dc2f158ad", + "id": "8", "metadata": {}, "outputs": [], "source": [ @@ -169,7 +169,7 @@ }, { "cell_type": "markdown", - "id": "108b98f7-a054-4fa6-b4f9-2c139ca17355", + "id": "9", "metadata": {}, "source": [ "## 模型推理\n", @@ -180,7 +180,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d06aee8c", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "9fa86db6", + "id": "11", "metadata": {}, "source": [ "#### **传入图像进行推理**\n", @@ -221,7 +221,7 @@ { "cell_type": "code", "execution_count": null, - "id": "411f4369", + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -231,7 +231,7 @@ }, { "cell_type": "markdown", - "id": "59162075-e83d-46cb-b208-8cf70a92511e", + "id": "13", "metadata": {}, "source": [ "## 结果可视化展示\n", @@ -241,7 +241,7 @@ { "cell_type": "code", "execution_count": null, - "id": "88964116-f7f7-4d6d-b0f7-9f05494e7421", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -270,7 +270,7 @@ }, { "cell_type": "markdown", - "id": "a73bb2f6", + "id": "15", "metadata": {}, "source": [ "#### **最佳掩码与原图叠加的可视化展示**\n", @@ -283,7 +283,7 @@ { "cell_type": "code", "execution_count": null, - "id": "77b7e2f7", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -309,7 +309,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4e908d71", + "id": "17", "metadata": {}, "outputs": [], "source": [ diff --git a/llm/distilgpt2/finetune_distilgpt2_language_modeling.ipynb b/llm/distilgpt2/finetune_distilgpt2_language_modeling.ipynb index 3c1a8eb..32cbda7 100644 --- 
a/llm/distilgpt2/finetune_distilgpt2_language_modeling.ipynb +++ b/llm/distilgpt2/finetune_distilgpt2_language_modeling.ipynb @@ -93,7 +93,7 @@ { "cell_type": "code", "execution_count": null, - "id": "66945dbe", + "id": "6", "metadata": {}, "outputs": [], "source": [ @@ -115,7 +115,7 @@ }, { "cell_type": "markdown", - "id": "6", + "id": "7", "metadata": {}, "source": [ "#### **设置 MindSpore 上下文**\n", @@ -126,7 +126,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "8", "metadata": {}, "outputs": [], "source": [ @@ -140,7 +140,7 @@ }, { "cell_type": "markdown", - "id": "8", + "id": "9", "metadata": {}, "source": [ "## 数据加载与预处理\n", @@ -153,7 +153,7 @@ }, { "cell_type": "markdown", - "id": "9", + "id": "10", "metadata": {}, "source": [ "#### **加载 Wikitext-2-raw-v1 数据集**" @@ -162,7 +162,7 @@ { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -189,7 +189,7 @@ }, { "cell_type": "markdown", - "id": "11", + "id": "12", "metadata": {}, "source": [ "## 模型构建\n", @@ -200,7 +200,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -224,7 +224,7 @@ }, { "cell_type": "markdown", - "id": "13", + "id": "14", "metadata": {}, "source": [ "#### **文本预处理与语言建模数据构造**\n", @@ -247,7 +247,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -351,7 +351,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -398,7 +398,7 @@ }, { "cell_type": "markdown", - "id": "16", + "id": "17", "metadata": {}, "source": [ "### 模型训练\n", @@ -421,7 +421,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -563,7 +563,7 @@ }, { "cell_type": "markdown", - "id": "18", + "id": "19", "metadata": {}, "source": [ "#### **开始训练**\n", @@ 
-579,7 +579,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -635,7 +635,7 @@ }, { "cell_type": "markdown", - "id": "20", + "id": "21", "metadata": {}, "source": [ "## 模型推理\n", @@ -652,7 +652,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -719,7 +719,7 @@ }, { "cell_type": "markdown", - "id": "22", + "id": "23", "metadata": {}, "source": [ "## 模型保存与加载(可选)\n", @@ -731,7 +731,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "24", "metadata": {}, "outputs": [], "source": [ diff --git a/llm/esmforproteinfolding/inference_esmforproteinfolding_prediction.ipynb b/llm/esmforproteinfolding/inference_esmforproteinfolding_prediction.ipynb index 608b357..6fe9df3 100644 --- a/llm/esmforproteinfolding/inference_esmforproteinfolding_prediction.ipynb +++ b/llm/esmforproteinfolding/inference_esmforproteinfolding_prediction.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "829cca7c", + "id": "0", "metadata": {}, "source": [ "# 基于 MindSpore NLP 实现 ESMFold 蛋白质结构预测\n", @@ -25,7 +25,7 @@ }, { "cell_type": "markdown", - "id": "60fe0cd1", + "id": "1", "metadata": {}, "source": [ "## 2. 环境准备\n", @@ -42,7 +42,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f95cc806-1cfb-4308-931a-4291ae0b1ecc", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -58,7 +58,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fc1c20ef", + "id": "3", "metadata": {}, "outputs": [], "source": [ @@ -74,7 +74,7 @@ }, { "cell_type": "markdown", - "id": "f4014443", + "id": "4", "metadata": {}, "source": [ "## 3. 加载 ESMFold 模型\n", @@ -88,7 +88,7 @@ { "cell_type": "code", "execution_count": null, - "id": "41b5efbe", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -109,7 +109,7 @@ }, { "cell_type": "markdown", - "id": "d0205a80", + "id": "6", "metadata": {}, "source": [ "## 4. 
蛋白质结构预测 \n", @@ -120,7 +120,7 @@ { "cell_type": "code", "execution_count": null, - "id": "66d4c37a", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -168,7 +168,7 @@ }, { "cell_type": "markdown", - "id": "f954f054", + "id": "8", "metadata": {}, "source": [ "## 5. (进阶) 性能优化:混合精度推理 (AMP)\n", @@ -181,7 +181,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07e4019c", + "id": "9", "metadata": {}, "outputs": [], "source": [ @@ -237,7 +237,7 @@ }, { "cell_type": "markdown", - "id": "74411a0d-0696-42c8-8138-c38257064e69", + "id": "10", "metadata": {}, "source": [ "说明:本环境下 `infer_pdb` 的导出阶段涉及 Python 字符串格式化,放在 autocast 里可能踩到兼容性问题。\n", @@ -247,7 +247,7 @@ }, { "cell_type": "markdown", - "id": "42971375", + "id": "11", "metadata": {}, "source": [ "## 6. (进阶) 多链复合物预测 (Multimer / Complex)\n", @@ -260,7 +260,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c2ba34af", + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -315,7 +315,7 @@ }, { "cell_type": "markdown", - "id": "c293e375", + "id": "13", "metadata": {}, "source": [ "## 7. 保存 PDB 文件\n", @@ -326,7 +326,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a27c3a22", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -340,7 +340,7 @@ }, { "cell_type": "markdown", - "id": "38112517-e67a-400e-86b3-74d73b4e7c28", + "id": "15", "metadata": { "vscode": { "languageId": "ini" @@ -355,7 +355,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0c78ef1f-fe40-43d3-95af-24414cf8d91f", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -376,7 +376,7 @@ }, { "cell_type": "markdown", - "id": "61f1c056", + "id": "17", "metadata": {}, "source": [ "## 9. 模型微调演示 \n", @@ -390,7 +390,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0700e509", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -456,7 +456,7 @@ }, { "cell_type": "markdown", - "id": "61653ce6", + "id": "19", "metadata": {}, "source": [ "## 10. 
总结\n", diff --git a/llm/t5/finetune_t5_daily_email_summarization.ipynb b/llm/t5/finetune_t5_daily_email_summarization.ipynb index b59dfdd..c3c28f3 100644 --- a/llm/t5/finetune_t5_daily_email_summarization.ipynb +++ b/llm/t5/finetune_t5_daily_email_summarization.ipynb @@ -33,7 +33,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6d9eee91", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -44,7 +44,7 @@ }, { "cell_type": "markdown", - "id": "e9a4b33f", + "id": "3", "metadata": {}, "source": [ "如果你在如[昇思大模型平台](https://xihe.mindspore.cn/training-projects)、[华为云ModelArts](https://www.huaweicloud.com/product/modelarts.html)、[启智社区](https://openi.pcl.ac.cn/)等算力平台的Jupyter在线编程环境中运行本案例,可取消如下代码的注释,进行依赖库安装:" @@ -53,7 +53,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -69,7 +69,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3bb7b07e", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -89,7 +89,7 @@ }, { "cell_type": "markdown", - "id": "31311e88", + "id": "6", "metadata": {}, "source": [ "其他场景可参考[MindSpore安装指南](https://www.mindspore.cn/install)进行环境搭建。" @@ -97,7 +97,7 @@ }, { "cell_type": "markdown", - "id": "3", + "id": "7", "metadata": {}, "source": [ "## 版本检查\n", @@ -108,7 +108,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4", + "id": "8", "metadata": {}, "outputs": [], "source": [ @@ -118,7 +118,7 @@ }, { "cell_type": "markdown", - "id": "5", + "id": "9", "metadata": {}, "source": [ "## 版本显示差异说明(MindSpore NLP 0.5.1 安装后显示为 0.5.0rc2)\n", @@ -130,7 +130,7 @@ }, { "cell_type": "markdown", - "id": "6", + "id": "10", "metadata": {}, "source": [ "## 数据集加载\n", @@ -141,7 +141,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "8", + "id": "12", "metadata": {}, "source": [ "## 分词器与编码设置\n", @@ -170,7 +170,7 @@ { 
"cell_type": "code", "execution_count": null, - "id": "9", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -184,7 +184,7 @@ }, { "cell_type": "markdown", - "id": "10", + "id": "14", "metadata": {}, "source": [ "## 生成式数据集\n", @@ -201,7 +201,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -223,7 +223,7 @@ }, { "cell_type": "markdown", - "id": "12", + "id": "16", "metadata": {}, "source": [ "## 构建数据处理流水线\n", @@ -238,7 +238,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -255,7 +255,7 @@ }, { "cell_type": "markdown", - "id": "14", + "id": "18", "metadata": {}, "source": [ "## 构建模型与训练函数\n", @@ -273,7 +273,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -324,7 +324,7 @@ }, { "cell_type": "markdown", - "id": "16", + "id": "20", "metadata": {}, "source": [ "## 训练\n", @@ -341,7 +341,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -365,7 +365,7 @@ }, { "cell_type": "markdown", - "id": "18", + "id": "22", "metadata": {}, "source": [ "## 推理\n", @@ -376,7 +376,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -401,7 +401,7 @@ }, { "cell_type": "markdown", - "id": "20", + "id": "24", "metadata": {}, "source": [ "## 保存模型与分词器\n", @@ -417,7 +417,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "25", "metadata": {}, "outputs": [], "source": [