从未如此简单,15分钟就上手的神经网络构建方法
itomcoil 2024-12-28 13:36 45 浏览
全文共2392字,预计学习时长11分钟
人工智能,深度学习,这些词是不是听起来就很高大上,充满了神秘气息?仿佛是只对数学博士开放的高级领域?
错啦!在B站已经变成学习网站的今天,还有什么样的教程是网上找不到的呢?深度学习从未如此好上手,至少实操部分是这样。
假如你只是了解人工神经网络基础理论,却从未踏足如何编写,跟着本文一起试试吧。你将会对如何在PyTorch 库中执行人工神经网络运算,以预测原先未见的数据有一个基本的了解。
这篇文章最多10分钟就能读完;如果要跟着代码一步步操作的话,只要已经安装了必要的库,那么也只需15分钟。相信我,它并不难。
长话短说,快开始吧!
导入语句和数据集
在这个简单的范例中将用到几个库:
· Pandas:用于数据加载和处理
· Matplotlib: 用于数据可视化处理
· PyTorch: 用于模型训练
· Scikit-learn: 用于拆分训练集和测试集
如果仅仅是想复制粘贴的话,以下几条导入语句可供参考:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split至于数据集,Iris数据集可以在这个URL上找到。下面演示如何把它直接导入
Pandas:
iris = pd.read_csv('https://raw.githubusercontent.com/pandas-dev/pandas/master/pandas/tests/data/iris.csv')
iris.head()前几行如下图所示:
现在需要将 Name列中鸢尾花的品种名称更改或者重映射为分类值。——也就是0、1、2。以下是步骤说明:
mappings = {
'Iris-setosa': 0,
'Iris-versicolor': 1,
'Iris-virginica': 2
}iris['Name'] = iris['Name'].apply(lambda x: mappings[x])执行上述代码得到的DataFrame如下:
这恭喜你,你已经成功地迈出了第一步!
拆分训练集和测试集
在此环节,将使用 Scikit-Learn库拆分训练集和测试集。随后, 将拆分过的数据由 Numpy arrays 转换为PyTorchtensors。
首先,需要将Iris 数据集划分为“特征”和“ 标签集” ——或者是x和y。Name列是因变量而其余的则是“特征”(或者说是自变量)。
接下来笔者也将使用随机种子,所以可以直接复制下面的结果。代码如下:
X = iris.drop('Name', axis=1).values
y = iris['Name'].valuesX_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.2, random_state=42)X_train = torch.FloatTensor(X_train)
X_test = torch.FloatTensor(X_test)
y_train = torch.LongTensor(y_train)
y_test = torch.LongTensor(y_test)如果从 X_train 开始检查前三行,会得到如下结果:
从 y_train开始则得到如下结果:
地基已经打好,下一环节将正式开始搭建神经网络。
定义神经网络模型
模型的架构很简单。重头戏在于神经网络的架构:
1.输入层 (4个输入特征(即X所含特征的数量),16个输出特征(随机))
2.全连接层 (16个输入特征(即输入层中输出特征的数量),12个输出特征(随机))
3.输出层(12个输入特征(即全连接层中输出特征的数量),3个输出特征(即不同品种的数量)
大致就是这样。除此之外还将使用ReLU 作为激活函数。下面展示如何在代码里执行这个激活函数。
class ANN(nn.Module):
def __init__(self):
super().__init__()
self.fc1 =nn.Linear(in_features=4, out_features=16)
self.fc2 =nn.Linear(in_features=16, out_features=12)
self.output =nn.Linear(in_features=12, out_features=3)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.output(x)
return xPyTorch使用的面向对象声明模型的方式非常直观。在构造函数中,需定义所有层及其架构,若使用forward(),则需定义正向传播。
接着创建一个模型实例,并验证其架构是否与上文所指的架构相匹配:
model = ANN()
model在训练模型之前,需注明以下几点:
· 评价标准:主要使用 CrossEntropyLoss来计算损失
· 优化器:使用学习率为0.01的Adam 优化算法
下面展示如何在代码中执行CrossEntropyLoss和Adam :
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)令人期盼已久的环节终于来啦——模型训练!
模型训练
这部分同样相当简单。模型训练将进行100轮, 持续追踪时间和损失。每10轮就向控制台输出一次当前状态——以指出目前所处的轮次和当前的损失。
代码如下:
%%timeepochs = 100
loss_arr = []for i in range(epochs):
y_hat = model.forward(X_train)
loss = criterion(y_hat, y_train)
loss_arr.append(loss)
if i % 10 == 0:
print(f'Epoch: {i} Loss: {loss}')
optimizer.zero_grad()
loss.backward()
optimizer.step()好奇最后三行是干嘛用的吗?答案很简单——反向传播——权重和偏置的更新使模型能真正地“学习”。
以下是上述代码的运行结果:
进度很快——但不要掉以轻心。
如果对纯数字真的不感冒,下图是损失曲线的可视化图(x轴为轮次编号,y轴为损失):
模型已经训练完毕,现在该干嘛呢?当然是模型评估。需要以某种方式在原先未见的数据上对这个模型进行评估。
模型评估
在评估过程中,欲以某种方式持续追踪模型做出的预测。需要迭代 X_test并进行预测,然后将预测结果与实际值进行比较。
这里将使用 torch.no_grad(),因为只是评估而已——无需更新权重和偏置。
总而言之,代码如下:
preds = []with torch.no_grad():
for val in X_test:
y_hat = model.forward(val)
preds.append(y_hat.argmax().item())现在预测结果被存储在 preds阵列。可以用下列三个值构建一个Pandas DataFrame。
· Y:实际值
· YHat: 预测值
· Correct:对角线,对角线的值为1表示Y和YHat相匹配,值为0则表示不匹配
代码如下:
df = pd.DataFrame({'Y': y_test, 'YHat':preds})df['Correct'] = [1 if corr == pred else 0 for corr, pred in zip(df['Y'],df['YHat'])]df 的前五行如下图所示:
下一个问题是,实际该如何计算精确度呢?
很简单——只需计算 Correct列的和再除以 df的长度:
df['Correct'].sum() / len(df)>>> 1.0此模型对原先未见数据的准确率为100%。但需注意这完全是因为Iris数据集非常易于归类,并不意味着对于Iris数据集来说,神经网络就是最好的算法。NN对于这类问题来讲有点大材小用,不过这都是以后讨论的话题了。
这可能是你写过最简单的神经网络,有着完美简洁的数据集、没有缺失值、层次最少、还有神经元!本文没有什么高级深奥的东西,相信你一定能够掌握它。
留言点赞关注
我们一起分享AI学习与发展的干货
如转载,请后台留言,遵守转载规范
相关推荐
-
- Python编程实现求解高次方程_python求次幂
-
#头条创作挑战赛#编程求解一元多次方程,一般情况下对于高次方程我们只求出近似解,较少的情况可以得到精确解。这里给出两种经典的方法,一种是牛顿迭代法,它是求解方程根的有效方法,通过若干次迭代(重复执行部分代码,每次使变量的当前值被计算出的新值...
-
2025-10-23 03:58 itomcoil
- python常用得内置函数解析——sorted()函数
-
接下来我们详细解析Python中非常重要的内置函数sorted()1.函数定义sorted()函数用于对任何可迭代对象进行排序,并返回一个新的排序后的列表。语法:sorted(iterabl...
- Python入门学习教程:第 6 章 列表
-
6.1什么是列表?在Python中,列表(List)是一种用于存储多个元素的有序集合,它是最常用的数据结构之一。列表中的元素可以是不同的数据类型,如整数、字符串、浮点数,甚至可以是另一个列表。列...
- Python之函数进阶-函数加强(上)_python怎么用函数
-
一.递归函数递归是一种编程技术,其中函数调用自身以解决问题。递归函数需要有一个或多个终止条件,以防止无限递归。递归可以用于解决许多问题,例如排序、搜索、解析语法等。递归的优点是代码简洁、易于理解,并...
- Python内置函数range_python内置函数int的作用
-
range类型表示不可变的数字序列,通常用于在for循环中循环指定的次数。range(stop)range(start,stop[,step])range构造器的参数必须为整数(可以是内...
- python常用得内置函数解析——abs()函数
-
大家号这两天主要是几个常用得内置函数详解详细解析一下Python中非常常用的内置函数abs()。1.函数定义abs(x)是Python的一个内置函数,用于返回一个数的绝对值。参数:x...
- 如何在Python中获取数字的绝对值?
-
Python有两种获取数字绝对值的方法:内置abs()函数返回绝对值。math.fabs()函数还返回浮点绝对值。abs()函数获取绝对值内置abs()函数返回绝对值,要使用该函数,只需直接调用:a...
- 贪心算法变种及Python模板_贪心算法几个经典例子python
-
贪心算法是一种在每一步选择中都采取当前状态下最优的选择,从而希望导致结果是全局最优的算法策略。以下是贪心算法的主要变种、对应的模板和解决的问题特点。1.区间调度问题问题特点需要从一组区间中选择最大数...
- Python倒车请注意!负步长range的10个高能用法,让代码效率翻倍
-
你是否曾遇到过需要倒着处理数据的情况?面对时间序列、日志文件或者矩阵操作,传统的遍历方式往往捉襟见肘。今天我们就来揭秘Python中那个被低估的功能——range的负步长操作,让你的代码优雅反转!一、...
- Python中while循环详解_python怎么while循环
-
Python中的`while`循环是一种基于条件判断的重复执行结构,适用于不确定循环次数但明确终止条件的场景。以下是详细解析:---###一、基本语法```pythonwhile条件表达式:循环体...
- 简单的python-核心篇-面向对象编程
-
在Python中,类本身也是对象,这被称为"元类"。这种设计让Python的面向对象编程具有极大的灵活性。classMyClass:"""一个简单的...
- 简单的python-python3中的不变的元组
-
golang中没有内置的元组类型,但是多值返回的处理结果模拟了元组的味道。因此,在golang中"元组”只是一个将多个值(可能是同类型的,也可能是不同类型的)绑定在一起的一种便利方法,通常,也...
- python中必须掌握的20个核心函数——sorted()函数
-
sorted()是Python的内置函数,用于对可迭代对象进行排序,返回一个新的排序后的列表,不修改原始对象。一、sorted()的基本用法1.1方法签名sorted(iterable,*,ke...
- 12 个 Python 高级技巧,让你的代码瞬间清晰、高效
-
在日常的编程工作中,我们常常追求代码的精简、优雅和高效。你可能已经熟练掌握了列表推导式(listcomprehensions)、f-string和枚举(enumerate)等常用技巧,但有时仍会觉...
- Python的10个进阶技巧:写出更快、更省内存、更优雅的代码
-
在Python的世界里,我们总是在追求效率和可读性的完美平衡。你不需要一个数百行的新框架来让你的代码变得优雅而快速。事实上,真正能带来巨大提升的,往往是那些看似微小、却拥有高杠杆作用的技巧。这些技巧能...
- 一周热门
- 最近发表
- 标签列表
-
- ps图案在哪里 (33)
- super().__init__ (33)
- python 获取日期 (34)
- 0xa (36)
- super().__init__()详解 (33)
- python安装包在哪里找 (33)
- linux查看python版本信息 (35)
- python怎么改成中文 (35)
- php文件怎么在浏览器运行 (33)
- eval在python中的意思 (33)
- python安装opencv库 (35)
- python div (34)
- sticky css (33)
- python中random.randint()函数 (34)
- python去掉字符串中的指定字符 (33)
- python入门经典100题 (34)
- anaconda安装路径 (34)
- yield和return的区别 (33)
- 1到10的阶乘之和是多少 (35)
- python安装sklearn库 (33)
- dom和bom区别 (33)
- js 替换指定位置的字符 (33)
- python判断元素是否存在 (33)
- sorted key (33)
- shutil.copy() (33)
