使用PyTorch实操分步指南:针对稻米进行分类
itomcoil 2024-12-28 13:34 34 浏览
在快节奏的农业世界中,能够快速准确地对不同水稻品种进行分类可能会改变游戏规则。但是,我们如何利用机器学习来完成像水稻分类这样小众的事情呢?好吧,这就是强大的深度学习库 PyTorch 发挥作用的地方。今天,我将指导您使用 PyTorch 创建卷积神经网络 (CNN),以根据图像对水稻品种进行分类。本动手教程适用于对 Python 有基本了解的任何人,我将引导您完成代码的每个步骤,以便您轻松跟上。
先决条件
在开始之前,请确保已安装必要的库。运行以下命令安装任何缺少的依赖项:
pip install torch torchvision pandas numpy seaborn matplotlib splitfolders tabulate termcolor scikit-learn安装这些后,您就可以开始编码了!
设置数据集
我们将使用大米图像数据集来训练我们的模型。您可以在网上找到各种数据集,但为了简单起见和保持一致性,最好使用 Kaggle。如果您有 Kaggle 帐户,您可以直接将数据集导入笔记本,而无需在本地下载。
如果您在 Kaggle 笔记本中工作,只需确保将数据集直接上传到环境或使用 Kaggle 的内置数据集。对于本地用户,请下载数据集并将其解压到名为的文件夹中Rice_Image_Dataset。
我将使用来自 kaggle 的数据集
www.kaggle.com 水稻图像数据集,五种不同的大米图像数据集。Arborio、Basmati、Ipsala、Jasmine、Karacadag。
代码分步解释
现在,让我们将代码分解为易于理解的部分,并看看每个部分的作用。
1. 导入库并设置随机种子
import warnings
warnings.filterwarnings('ignore')
import os
import time
import torch
import random
import pathlib
import torchvision
import numpy as np
import pandas as pd
import splitfolders
import torch.nn as nn
import seaborn as sns
import torch.utils.data
from tabulate import tabulate
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import torchvision.transforms as transforms# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)在这里,我们导入了数据处理、可视化和模型构建所需的所有库。我们还设置了随机种子,以确保每次运行代码时都能获得一致的结果。
2. 分割数据集
# 数据集概述
dir1 = 'Rice_Image_Dataset'
splitfolders.ratio( input =dir1, output = 'imgs' , seed= 42 , ratio=( 0.7 , 0.15 , 0.15 ))
dir2 = pathlib.Path( 'imgs' )使用splitfolders,我们将数据分成训练、验证和测试集,比例为 70%、15% 和 15%。这有助于构建我们的数据集,以便有效地训练和测试我们的模型。您将看到创建了一个名为imgs 的文件夹,其中包含train、test和val作为子目录。
3.定义数据转换
transform = transforms.Compose([
transforms.Resize((250, 250)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])我们定义一系列变换来预处理图像。这里:
- Resize((250, 250))将每个图像的大小调整为 250x250 像素。
- ToTensor()将图像转换为 PyTorch 张量。
- Normalize()缩放像素值以使模型训练更加稳定。
4.加载数据
batch_size = 32
train_ds = torchvision.datasets.ImageFolder(os.path.join(dir2,'train' ), transform =transform)
val_ds = torchvision.datasets.ImageFolder(os.path.join(dir2,'val'),transform=transform)
test_ds = torchvision.datasets.ImageFolder(os.path.join(dir2,'test'),transform=transform)train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=True)我们在应用转换后加载数据集,并分批准备进行训练、验证和测试。批处理有助于提高模型训练期间的内存效率。
5. 可视化数据分布
class_counts = [0] * len(train_ds.classes)
for _, label in train_ds:
class_counts[label] += 1class_distribution = pd.DataFrame({
'Class': train_ds.classes,
'Count': class_counts
})plt.figure(figsize=(10, 6))
sns.barplot(x='Class', y='Count', data=class_distribution)
plt.title('Class Distribution in Training Dataset')
plt.xticks(rotation=45)
plt.ylabel('Number of Images')
plt.xlabel('Classes')
plt.show()这部分代码绘制了图像在各个类别中的分布情况,让我们可以快速检查类别平衡情况。这在分类问题中至关重要,因为它会影响模型的泛化能力。
6.定义 CNN 模型
class CNN(nn.Module):
def __init__(self, unique_classes):
super(CNN, self).__init__()
self.conv_layers = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(32, 64, kernel_size=3),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(64, 128, kernel_size=3),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
self.dense_layers = nn.Sequential(
nn.Linear(128 * 29 * 29, 128),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(128, unique_classes)
)
def forward(self, X):
out = self.conv_layers(X)
out = out.view(out.size(0), -1)
out = self.dense_layers(out)
return out我们的 CNN 模型由用于特征提取的卷积层和用于分类的密集层组成。每个卷积层后面都有 ReLU 激活和最大池化以降低维度。最后,全连接层预测每幅图像的类别。
7.训练模型
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 5
num_epochs = 5
train_losses, val_losses, train_accs, val_accs = [], [], [], []
for epoch in range(num_epochs):
model.train()
train_loss, n_correct_train, n_total_train = 0, 0, 0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
y_pred = model(images)
loss = criterion(y_pred, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted_labels = torch.max(y_pred, 1)
n_correct_train += (predicted_labels == labels).sum().item()
n_total_train += labels.size(0)
train_losses.append(train_loss / len(train_loader))
train_accs.append(n_correct_train / n_total_train)
# Validation phase
model.eval()
val_loss, n_correct_val, n_total_val = 0, 0, 0
with torch.no_grad():
for images, labels in val_loader:
images, labels = images.to(device), labels.to(device)
y_pred = model(images)
loss = criterion(y_pred, labels)
val_loss += loss.item()
_, predicted_labels = torch.max(y_pred, 1)
n_correct_val += (predicted_labels == labels).sum().item()
n_total_val += labels.size(0)
val_losses.append(val_loss / len(val_loader))
val_accs.append(n_correct_val / n_total_val)
# Print metrics for the current epoch
print(f'Epoch [{epoch+1}/{num_epochs}]')
print('-' * 50)
print(f'Train Loss: {train_losses[-1]:.4f} | Train Accuracy: {train_accs[-1]:.4f}')
print(f'Validation Loss: {val_losses[-1]:.4f} | Validation Accuracy: {val_accs[-1]:.4f}')
print('-' * 50)我们定义了训练循环来帮助模型从数据中学习并在每个时期提高其性能。
- 损失函数和优化器:我们首先设置nn.CrossEntropyLoss()损失函数,它非常适合多类分类任务。优化器torch.optim.Adam根据反向传播期间计算的梯度更新模型的参数,使用学习率为0.001。
- 时期和跟踪:我们将训练时期(整个数据集上的迭代次数)的数量指定为num_epochs = 5。我们还初始化列表以跟踪训练和验证损失和准确度随时间的变化。
- 训练阶段:在主循环中,模型设置为train模式,这允许它调整其权重。对于每一批图像:
- 我们用optimizer.zero_grad() 重置梯度。
- 将图像传递给模型以获得预测(y_pred)。
- 通过将预测与真实标签进行比较来计算损失。
- 反向传播损失(loss.backward())并用optimizer.step()更新模型的权重。
- 跟踪当前时期的累积训练损失和准确度。
4.验证阶段:在当前时期进行训练后,我们在验证集上评估模型(不更新权重):
- 该模型设置为eval模式以禁用dropout和批量标准化层。
- 我们计算每个批次的验证损失和准确度,类似于训练循环,但没有反向传播(使用torch.no_grad())。
5. 记录结果:在每个时期结束时,我们都会打印训练和验证损失和准确率,以监控模型的进度。这有助于我们了解模型是否正在改进或是否存在过度拟合的可能性。
每个时期都会让我们了解模型的学习效果,记录的结果可以帮助我们在必要时对模型进行微调。
8.评估模型性能
model.eval()
test_loss, total_correct, total_samples = 0, 0, 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
predictions = model(images)
loss = criterion(predictions, labels)
test_loss += loss.item()
_, predicted_classes = torch.max(predictions, 1)
total_correct += (predicted_classes == labels).sum().item()
total_samples += labels.size(0)
avg_test_loss = test_loss / len(test_loader)
test_accuracy = total_correct / total_samples
print(f'Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')训练完模型后,必须评估其在测试集(由未见过的数据组成)上的表现。此步骤可让我们更好地了解模型在训练和验证数据之外的泛化能力。
以下是评估代码的细目:
- 将模型设置为评估模式:通过调用model.eval(),我们将模型设置为评估模式。这将停用某些层(如 dropout),确保输出一致且测试结果可靠。
- 初始化跟踪变量:我们将test_loss、total_correct和初始化total_samples为零。这些变量将帮助我们计算所有批次的总体测试损失和准确度。
- 禁用梯度计算:通过torch.no_grad(),我们禁用梯度计算,这可以减少内存使用并加快进程,因为我们在测试期间不会更新任何模型参数。
- 循环测试数据:针对每个批次test_loader:
- 将图像和标签移动到设备(GPU 或 CPU)。
- 将图像传递给模型以获得预测。
- 计算预测和真实标签之间的损失,并将其添加到test_loss。
- 用于torch.max(predictions, 1)获取预测的类别标签。
- predicted_classes通过与真实值进行比较来计算正确的预测labels,并将结果添加到total_correct。
- 增加total_samples当前批次中的图像数量,以跟踪测试样本的总数。
5.计算平均测试损失和准确率:
- avg_test_loss是通过将累计test_loss除以批次数(len(test_loader))来计算的。
- test_accuracy计算为total_correct预测值与total_samples的比率,给出模型在测试集上的整体准确度。
6. 打印结果:最后,我们打印测试损失和准确率,总结模型在未见数据上的表现。此指标可以帮助您判断模型是否已准备好部署或是否需要进一步调整。
9.保存模型
model_save_path = 'cnn_rice_classifier.pth'
torch.save(model.state_dict(), model_save_path)
print(f'Model saved to {model_save_path}')最后,我们保存模型,以便可以重复使用而无需重新训练。
总结
就这样!您已成功在 PyTorch 中构建并训练了一个 CNN 模型来对水稻品种进行分类。此过程涵盖了数据预处理、构建 CNN、训练和评估模型。只需进行一些调整,您就可以调整此模型来对其他类型的图像进行分类。祝您编码愉快!
相关推荐
-
- Python编程实现求解高次方程_python求次幂
-
#头条创作挑战赛#编程求解一元多次方程,一般情况下对于高次方程我们只求出近似解,较少的情况可以得到精确解。这里给出两种经典的方法,一种是牛顿迭代法,它是求解方程根的有效方法,通过若干次迭代(重复执行部分代码,每次使变量的当前值被计算出的新值...
-
2025-10-23 03:58 itomcoil
- python常用得内置函数解析——sorted()函数
-
接下来我们详细解析Python中非常重要的内置函数sorted()1.函数定义sorted()函数用于对任何可迭代对象进行排序,并返回一个新的排序后的列表。语法:sorted(iterabl...
- Python入门学习教程:第 6 章 列表
-
6.1什么是列表?在Python中,列表(List)是一种用于存储多个元素的有序集合,它是最常用的数据结构之一。列表中的元素可以是不同的数据类型,如整数、字符串、浮点数,甚至可以是另一个列表。列...
- Python之函数进阶-函数加强(上)_python怎么用函数
-
一.递归函数递归是一种编程技术,其中函数调用自身以解决问题。递归函数需要有一个或多个终止条件,以防止无限递归。递归可以用于解决许多问题,例如排序、搜索、解析语法等。递归的优点是代码简洁、易于理解,并...
- Python内置函数range_python内置函数int的作用
-
range类型表示不可变的数字序列,通常用于在for循环中循环指定的次数。range(stop)range(start,stop[,step])range构造器的参数必须为整数(可以是内...
- python常用得内置函数解析——abs()函数
-
大家号这两天主要是几个常用得内置函数详解详细解析一下Python中非常常用的内置函数abs()。1.函数定义abs(x)是Python的一个内置函数,用于返回一个数的绝对值。参数:x...
- 如何在Python中获取数字的绝对值?
-
Python有两种获取数字绝对值的方法:内置abs()函数返回绝对值。math.fabs()函数还返回浮点绝对值。abs()函数获取绝对值内置abs()函数返回绝对值,要使用该函数,只需直接调用:a...
- 贪心算法变种及Python模板_贪心算法几个经典例子python
-
贪心算法是一种在每一步选择中都采取当前状态下最优的选择,从而希望导致结果是全局最优的算法策略。以下是贪心算法的主要变种、对应的模板和解决的问题特点。1.区间调度问题问题特点需要从一组区间中选择最大数...
- Python倒车请注意!负步长range的10个高能用法,让代码效率翻倍
-
你是否曾遇到过需要倒着处理数据的情况?面对时间序列、日志文件或者矩阵操作,传统的遍历方式往往捉襟见肘。今天我们就来揭秘Python中那个被低估的功能——range的负步长操作,让你的代码优雅反转!一、...
- Python中while循环详解_python怎么while循环
-
Python中的`while`循环是一种基于条件判断的重复执行结构,适用于不确定循环次数但明确终止条件的场景。以下是详细解析:---###一、基本语法```pythonwhile条件表达式:循环体...
- 简单的python-核心篇-面向对象编程
-
在Python中,类本身也是对象,这被称为"元类"。这种设计让Python的面向对象编程具有极大的灵活性。classMyClass:"""一个简单的...
- 简单的python-python3中的不变的元组
-
golang中没有内置的元组类型,但是多值返回的处理结果模拟了元组的味道。因此,在golang中"元组”只是一个将多个值(可能是同类型的,也可能是不同类型的)绑定在一起的一种便利方法,通常,也...
- python中必须掌握的20个核心函数——sorted()函数
-
sorted()是Python的内置函数,用于对可迭代对象进行排序,返回一个新的排序后的列表,不修改原始对象。一、sorted()的基本用法1.1方法签名sorted(iterable,*,ke...
- 12 个 Python 高级技巧,让你的代码瞬间清晰、高效
-
在日常的编程工作中,我们常常追求代码的精简、优雅和高效。你可能已经熟练掌握了列表推导式(listcomprehensions)、f-string和枚举(enumerate)等常用技巧,但有时仍会觉...
- Python的10个进阶技巧:写出更快、更省内存、更优雅的代码
-
在Python的世界里,我们总是在追求效率和可读性的完美平衡。你不需要一个数百行的新框架来让你的代码变得优雅而快速。事实上,真正能带来巨大提升的,往往是那些看似微小、却拥有高杠杆作用的技巧。这些技巧能...
- 一周热门
- 最近发表
- 标签列表
-
- ps图案在哪里 (33)
- super().__init__ (33)
- python 获取日期 (34)
- 0xa (36)
- super().__init__()详解 (33)
- python安装包在哪里找 (33)
- linux查看python版本信息 (35)
- python怎么改成中文 (35)
- php文件怎么在浏览器运行 (33)
- eval在python中的意思 (33)
- python安装opencv库 (35)
- python div (34)
- sticky css (33)
- python中random.randint()函数 (34)
- python去掉字符串中的指定字符 (33)
- python入门经典100题 (34)
- anaconda安装路径 (34)
- yield和return的区别 (33)
- 1到10的阶乘之和是多少 (35)
- python安装sklearn库 (33)
- dom和bom区别 (33)
- js 替换指定位置的字符 (33)
- python判断元素是否存在 (33)
- sorted key (33)
- shutil.copy() (33)
