9.8 模型训练、评估与上线¶
模型训练是将数据集、算法、算力转化为可实际使用的智能系统的核心过程;评估确保模型在真实场景下表现可信;上线则完成从实验室到生产环境的最后一公里。
一、训练流程总览¶
数据集准备(9.7)
↓ 模型搭建(选骨干 + 任务头)
↓ 训练配置(学习率调度 / 损失函数 / 优化器)
↓ 训练循环(前向 → 损失 → 反向 → 更新)
↓ 验证集评估(每 N epoch)→ 保存最优 checkpoint
↓ 测试集最终评估
↓ 模型导出(ONNX / TorchScript / TensorRT)
↓ 线上部署
二、训练策略¶
2.1 迁移学习(Transfer Learning)¶
利用在大数据集(ImageNet、COCO、OpenImages)上预训练的权重作为起点,在目标数据集上微调:
| 策略 | 适用场景 | 做法 |
|---|---|---|
| 全量微调(Fine-Tune) | 数据量较大(>5k) | 解冻所有层,小学习率全局更新 |
| 冻结骨干微调 | 数据量较小(<1k) | 冻结骨干,只训练任务头 |
| 线性探测(Linear Probe) | 快速验证特征质量 | 骨干全冻,只训练一个 FC 层 |
2.2 学习率调度¶
- Warmup 防止大学习率初期梯度爆炸
- 余弦退火(Cosine Annealing)比阶梯衰减更平滑,最终性能通常更好
2.3 混合精度训练(AMP)¶
使用 FP16 前向/反向,FP32 参数更新,显存节省约 50%,速度提升 1.5–3×,PyTorch 用 torch.cuda.amp.autocast() 一键开启。
2.4 分布式训练¶
| 策略 | 原理 | 适用场景 |
|---|---|---|
| DataParallel(DP) | 单进程,多 GPU 复制模型 | 单机多卡,简单但低效 |
| DistributedDataParallel(DDP) | 多进程,梯度 AllReduce | 单机/多机首选 |
| ZeRO(DeepSpeed) | 优化器状态 / 梯度 / 参数分片 | 超大模型(LLM) |
三、评估方法¶
3.1 评估原则¶
- 评估集严格隔离:test 集禁止用于任何决策(调参、选模型),只做一次性最终报告
- 多指标联合报告:不能只看单一指标(如只看准确率会掩盖类别不均衡问题)
- 置信区间:小测试集(<1000 样本)需报告标准差或置信区间
3.2 常见评估场景¶
| 任务 | 主评指标 | 辅助指标 |
|---|---|---|
| 分类 | Top-1 Accuracy | F1、AUC-ROC |
| 目标检测 | mAP@0.5:0.95 | AR@100、推理延迟 |
| 分割 | mIoU | Dice、边界 F1 |
| 去噪/超分辨率 | PSNR + SSIM | LPIPS |
3.3 消融实验(Ablation Study)¶
逐步增删模型组件,量化每个设计决策的贡献。典型格式:
| 配置 | mAP | 说明 |
|---|---|---|
| 基线 | 72.3 | ResNet50 + FPN |
| + CBAM | 73.8 | +1.5,注意力增益 |
| + 数据增广 | 75.1 | +1.3,Mosaic + CutMix |
| + 大分辨率输入 | 76.4 | +1.3,640→1280 |
四、模型导出与优化¶
4.1 导出格式¶
| 格式 | 场景 | 工具 |
|---|---|---|
| ONNX | 跨框架部署,中间格式 | torch.onnx.export |
| TorchScript | PyTorch 原生,服务端 | torch.jit.trace/script |
| TensorRT | NVIDIA GPU 最快推理 | TensorRT(TRT)优化 |
| RKNN | 瑞芯微 NPU(RK3588 等) | RKNN-Toolkit |
| CoreML | Apple 芯片 | coremltools |
4.2 量化(Quantization)¶
将浮点权重转为 INT8 / INT4,减少显存和延迟:
- 训练后量化(PTQ):用 100–1000 张校准图,速度快,精度损失 1–3%
- 量化感知训练(QAT):训练阶段模拟量化误差,精度损失 <1%,但耗时
4.3 剪枝(Pruning)¶
删除不重要的权重/通道:
- 结构化剪枝(整通道删除)→ 直接得到小模型,推理加速明显
- 非结构化剪枝(稀疏权重)→ 需稀疏计算库支持才能加速
五、上线部署¶
5.1 部署形态¶
| 形态 | 特点 | 代表场景 |
|---|---|---|
| 云端 API | 弹性伸缩,GPU 充足 | SaaS、在线识别服务 |
| 边缘服务器 | 低延迟,本地化 | 工业相机、安防视频流 |
| 嵌入式端(NPU/MCU) | 极低功耗 | 机载、手持设备 |
5.2 推理服务化¶
5.3 上线检查清单¶
- [ ] 测试集最终指标已记录(与预期基线对比)
- [ ] 推理延迟、吞吐已压测(P99 满足 SLA)
- [ ] 模型版本与数据集版本绑定记录
- [ ] 灰度发布或 A/B 测试策略就绪
- [ ] 监控告警(精度漂移检测)已配置
六、线上监控与再训练¶
模型上线后需持续监控:
- 数据漂移(Data Drift):输入数据分布偏离训练分布,可通过特征统计量检测
- 概念漂移(Concept Drift):真实标签分布变化(如新产品外观变更)
- 定期再训练:收集线上难样本 → 回流标注 → 增量或全量再训练 → 灰度发布
参考资料¶
- PyTorch 文档:https://pytorch.org/docs/stable/
- NVIDIA TensorRT 文档:https://docs.nvidia.com/deeplearning/tensorrt/
- Sculley et al., \"Hidden Technical Debt in Machine Learning Systems\", NeurIPS, 2015
更新时间¶
2026-03-03