PyTorch实战:TorchVision目标检测模型微调完
itomcoil 2025-08-21 03:16 1 浏览
PyTorch实战:TorchVision目标检测模型微调完整教程
一、什么是微调(Finetuning)?
微调(Finetuning)是指在已经预训练好的模型基础上,使用自己的数据对模型进行进一步训练,使之快速适应特定任务。
在深度学习领域,从头训练一个模型往往需要大量计算资源与数据。微调的出现极大降低了训练成本和门槛,帮助开发者在有限数据集上快速实现高性能模型。
在本教程中,我们将以目标检测(Object Detection)为例,带领大家用PyTorch和TorchVision完成微调过程。
二、PyTorch与TorchVision基础介绍
PyTorch是什么?
PyTorch是Facebook开发的开源深度学习框架,因其灵活易用的动态图机制,迅速成为了学术界与工业界的主流工具。
TorchVision是什么?
TorchVision是PyTorch官方提供的计算机视觉工具包,包含:
- o 常用的图像数据集(如COCO、CIFAR)
- o 预训练模型(如ResNet、VGG、Faster R-CNN)
- o 数据转换工具(如Transforms)
我们本次教程的核心便是利用TorchVision提供的预训练目标检测模型进行微调。
三、数据集介绍与准备
本教程使用的Penn-Fudan Pedestrian Dataset,包含170张图像,用于行人检测任务。每张图像都附有行人的语义掩膜(mask)标注。
数据下载链接:
Penn-Fudan Pedestrian Dataset下载地址
数据集结构说明
下载解压后,文件夹结构如下:
PennFudanPed/
├── Annotation # 标注信息(XML格式,此处未用到)
├── PedMasks # 行人掩膜(mask)
└── PNGImages # 原始图片
数据示例图:
数据集示例图(PyTorch官方教程)
(图片来源:PyTorch官方教程)
数据标注格式
掩膜(mask)图像以整数形式标注,每个像素表示所属对象的类别编号,背景为0,每个对象(如行人)依次编号(1,2,…)。
四、自定义Dataset类创建(含详细中文注释)
我们定义一个继承torch.utils.data.Dataset的数据加载类:
import os
import numpy as np
import torch
from PIL import Image
class PennFudanDataset(torch.utils.data.Dataset):
def __init__(self, root, transforms=None):
self.root = root
self.transforms = transforms
self.imgs = sorted(os.listdir(os.path.join(root, "PNGImages")))
self.masks = sorted(os.listdir(os.path.join(root, "PedMasks")))
def __getitem__(self, idx):
# 读取图像和掩膜
img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
img = Image.open(img_path).convert("RGB")
mask = Image.open(mask_path)
mask = np.array(mask)
obj_ids = np.unique(mask)[1:] # 去除背景编号0
masks = mask == obj_ids[:, None, None]
boxes = []
for i in range(len(obj_ids)):
pos = np.where(masks[i])
xmin, xmax = np.min(pos[1]), np.max(pos[1])
ymin, ymax = np.min(pos[0]), np.max(pos[0])
boxes.append([xmin, ymin, xmax, ymax])
# 转换为torch.Tensor类型
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.ones((len(obj_ids),), dtype=torch.int64) # 行人标签为1
image_id = torch.tensor([idx])
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
iscrowd = torch.zeros((len(obj_ids),), dtype=torch.int64)
target = {}
target["boxes"] = boxes
target["labels"] = labels
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd
if self.transforms:
img, target = self.transforms(img, target)
return img, target
def __len__(self):
return len(self.imgs)
五、模型选择与修改(预训练模型微调)
本教程使用TorchVision内置的Faster R-CNN模型:
Faster R-CNN网络结构:
Faster R-CNN结构图(PyTorch官方教程)
(图片来源:PyTorch官方教程)
使用预训练模型代码(含中文注释)
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
def get_model(num_classes):
# 加载在COCO上预训练的Faster R-CNN
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
# 获取分类器输入特征数
in_features = model.roi_heads.box_predictor.cls_score.in_features
# 替换分类头,适应我们的数据集(只有背景和行人两个类别)
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
return model
六、模型训练流程详解(代码详解)
1. 数据加载与数据增强
import utils
from engine import train_one_epoch, evaluate
import transforms as T
def get_transform(train):
transforms = [T.ToTensor()]
if train:
transforms.append(T.RandomHorizontalFlip(0.5))
return T.Compose(transforms)
2. 完整训练脚本
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
dataset_test = PennFudanDataset('PennFudanPed', get_transform(train=False))
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=utils.collate_fn)
data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, shuffle=False, collate_fn=utils.collate_fn)
model = get_model(num_classes=2)
model.to(device)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
num_epochs = 10
for epoch in range(num_epochs):
# 训练一个epoch
train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
# 在测试集上评估
evaluate(model, data_loader_test, device=device)
七、模型评估与测试(代码与指标分析)
模型训练后,使用以下代码进行预测与可视化:
import matplotlib.pyplot as plt
from torchvision.utils import draw_bounding_boxes
model.eval()
img, _ = dataset_test[0]
with torch.no_grad():
prediction = model([img.to(device)])
boxes = prediction[0]['boxes'].cpu()
scores = prediction[0]['scores'].cpu()
img = (img * 255).type(torch.uint8)
img_with_boxes = draw_bounding_boxes(img, boxes, labels=[f'{s:.2f}' for s in scores])
plt.figure(figsize=(12, 8))
plt.imshow(img_with_boxes.permute(1, 2, 0))
plt.axis('off')
plt.show()
预测效果图示例:
模型预测结果示例图(PyTorch官方教程)
(图片来源:PyTorch官方教程)
八、常见问题及排错技巧
- 1. 模型不收敛怎么办?
调整学习率,检查数据标注正确性。 - 2. GPU内存不足怎么办?
减小batch_size或裁剪图像尺寸。 - 3. 为什么模型预测框位置偏差较大?
检查训练数据标注框准确性,增加数据增强方法。
九、总结与延伸阅读
本教程详细介绍了使用PyTorch与TorchVision完成目标检测微调的全过程,帮助初中级开发者掌握核心技巧,推荐进一步阅读PyTorch官方教程深入学习。
原文链接:PyTorch TorchVision目标检测微调官方教程
相关推荐
- 最强聚类模型,层次聚类 !!_层次聚类的优缺点
-
哈喽,我是小白~咱们今天聊聊层次聚类,这种聚类方法在后面的使用,也是非常频繁的~首先,聚类很好理解,聚类(Clustering)就是把一堆“东西”自动分组。这些“东西”可以是人、...
- python决策树用于分类和回归问题实际应用案例
-
决策树(DecisionTrees)通过树状结构进行决策,在每个节点上根据特征进行分支。用于分类和回归问题。实际应用案例:预测一个顾客是否会流失。决策树是一种基于树状结构的机器学习算法,用于解决分类...
- Python教程(四十五):推荐系统-个性化推荐算法
-
今日目标o理解推荐系统的基本概念和类型o掌握协同过滤算法(用户和物品)o学会基于内容的推荐方法o了解矩阵分解和深度学习推荐o掌握推荐系统评估和优化技术推荐系统概述推荐系统是信息过滤系统,用于...
- 简单学Python——NumPy库7——排序和去重
-
NumPy数组排序主要用sort方法,sort方法只能将数值按升充排列(可以用[::-1]的切片方式实现降序排序),并且不改变原数组。例如:importnumpyasnpa=np.array(...
- PyTorch实战:TorchVision目标检测模型微调完
-
PyTorch实战:TorchVision目标检测模型微调完整教程一、什么是微调(Finetuning)?微调(Finetuning)是指在已经预训练好的模型基础上,使用自己的数据对模型进行进一步训练...
- C4.5算法解释_简述c4.5算法的基本思想
-
C4.5算法是ID3算法的改进版,它在特征选择上采用了信息增益比来解决ID3算法对取值较多的特征有偏好的问题。C4.5算法也是一种用于决策树构建的算法,它同样基于信息熵的概念。C4.5算法的步骤如下:...
- Python中的数据聚类及可视化分析实践
-
探索如何通过聚类分析揭露糖尿病预测数据集的特征!我们将运用Python的强力工具,深入挖掘数据,以直观的可视化揭示不同特征间的关系。一同探索聚类分析在糖尿病预测中的实践!所有这些可视化都可以通过数据操...
- 用Python来统计大乐透号码的概率分布
-
用Python来统计大乐透号码的概率分布,可以按照以下步骤进行:导入所需的库:使用Python中的numpy库生成数字序列,使用matplotlib库生成概率分布图。读取大乐透历史数据:从网络上找到大...
- python:支持向量机监督学习算法用于二分类和多分类问题示例
-
监督学习-支持向量机(SVM)支持向量机(SupportVectorMachine,简称SVM)是一种常用的监督学习算法,用于解决分类和回归问题。SVM的目标是找到一个最优的超平面,将不同类别的...
- 25个例子学会Pandas Groupby 操作
-
groupby是Pandas在数据分析中最常用的函数之一。它用于根据给定列中的不同值对数据点(即行)进行分组,分组后的数据可以计算生成组的聚合值。如果我们有一个包含汽车品牌和价格信息的数据集,那么可以...
- 数据挖掘流程_数据挖掘流程主要有哪些步骤
-
数据挖掘流程1.了解需求,确认目标说一下几点思考方法:做什么?目的是什么?目标是什么?为什么要做?有什么价值和意义?如何去做?完整解决方案是什么?2.获取数据pandas读取数据pd.read.c...
- 使用Python寻找图像最常见的颜色_python 以图找图
-
如果我们知道图像或对象最常见的是哪种颜色,那么可以解决图像处理中的几个用例,例如在农业领域,我们可能需要确定水果的成熟度。我们可以简单地检查一下水果的颜色是否在预定的范围内,看看它是成熟的,腐烂的,还...
- 财务预算分析全网最佳实践:从每月分析到每天分析
-
原文链接如下:「链接」掌握本文的方法,你就掌握了企业预算精细化分析的能力,全网首发。数据模拟稍微有点问题,不要在意数据细节,先看下最终效果。在编制财务预算或业务预算的过程中,通常预算的所有数据都是按月...
- 常用数据工具去重方法_数据去重公式
-
在数据处理中,去除重复数据是确保数据质量和分析准确性的关键步骤。特别是在处理多列数据时,保留唯一值组合能够有效清理数据集,避免冗余信息对分析结果的干扰。不同的工具和编程语言提供了多种方法来实现多列去重...
- Python教程(四十):PyTorch深度学习-动态计算图
-
今日目标o理解PyTorch的基本概念和动态计算图o掌握PyTorch张量操作和自动求导o学会构建神经网络模型o了解PyTorch的高级特性o掌握模型训练和部署PyTorch概述PyTorc...
- 一周热门
- 最近发表
- 标签列表
-
- ps图案在哪里 (33)
- super().__init__ (33)
- python 获取日期 (34)
- 0xa (36)
- super().__init__()详解 (33)
- python安装包在哪里找 (33)
- linux查看python版本信息 (35)
- python怎么改成中文 (35)
- php文件怎么在浏览器运行 (33)
- eval在python中的意思 (33)
- python安装opencv库 (35)
- python div (34)
- sticky css (33)
- python中random.randint()函数 (34)
- python去掉字符串中的指定字符 (33)
- python入门经典100题 (34)
- anaconda安装路径 (34)
- yield和return的区别 (33)
- 1到10的阶乘之和是多少 (35)
- python安装sklearn库 (33)
- dom和bom区别 (33)
- js 替换指定位置的字符 (33)
- python判断元素是否存在 (33)
- sorted key (33)
- shutil.copy() (33)