Python实现随机&批量梯度下降算法
itomcoil 2025-09-04 07:46 6 浏览
一.概述
梯度下降属于迭代法的一种,可以用于求解最小二乘问题。在求解机器学习算法的模型参数时,梯度下降(Gradient Descent)是最常采用的方法之一,另一种常用的方法是最小二乘法。在求解损失函数的最小值时,可以通过梯度下降法来一步步的迭代求解,得到最小化的损失函数和模型参数值。反过来,如果我们需要求解损失函数的最大值,这时就需要用梯度上升法。在机器学习中,基于基本的梯度下降法发展了两种梯度下降方法,分别为随机梯度下降法和批量梯度下降法。
- 随机梯度下降:随机选取一条训练数据作为训练样本计算最小化的损失函数和模型参数;
- 批量梯度下降:应用全部训练数据作为训练样本计算最小化的损失函数和模型参数;
- Mini-batch:应用小批量的训练数据;这是随机梯度下降和批量梯度下降的中间产物;
二.随机梯度下降算法代码实现
import numpy as np
import math
__author__ = 'aaa'
# 生成测试数据
x = 2 * np.random.rand(100, 1) # 随机生成100*1的二维数组,值分别在0~2之间
y = 4 + 3 * x + np.random.randn(100, 1) # 随机生成100*1的二维数组,值分别在4~11之间
x_b = np.c_[np.ones((100, 1)), x]
theta = np.random.randn(2, 1)
n_epochs = 100
t0, t1 = 1, 10
m = n_epochs
def learn_step(t): # 模拟实现动态修改步长
return t0 / (t + t1)
for epoch in range(n_epochs):
for i in range(m):
random_index = np.random.randint(m) # 生成随机下标,获取随机训练样本数据
x_i = x_b[random_index:random_index+1]
y_i = y[random_index:random_index+1]
gradients = 2 * x_i.T.dot(x_i.dot(theta)-y_i) # 调用解析解函数
learning_rate = learn_step(epoch * m + i)
theta = theta - learning_rate * gradients
print("最终结果:\n{}".format(theta))
# 计算误差
error = math.sqrt(math.pow((theta[0][0] - 4), 2) + math.pow((theta[1][0] - 3), 2))
print("误差:\n{}".format(error))
执行结果:
三.批量梯度下降算法代码实现
import numpy as np
import math
# 定义基础变量
learning_rate = 0.1
n_iterations = 1000
m = 100
x = 2 * np.random.rand(m, 1) # 生成一组100*1的二维矩阵,该矩阵数据服从0~1均匀分布,下同
y = 4 + 3 * x + np.random.randn(m, 1) # 正态分布
x_b = np.c_[np.ones((m, 1)), x] # np.((100, 1)):表示生成100行1列的矩阵,内部填充为1
# 设置阈值
threshold = 0.2
# 初始化theta
theta = np.random.randn(2, 1)
count = 0
before_value = 1
# 设置阈值、超参数和迭代次数;迭代完次数或者满足阈值,就认为收敛
for iteration in range(n_iterations):
count += 1
# 求梯度gradient
gradients = 1/m * x_b.T.dot(x_b.dot(theta)-y) # 求平均梯度
# 应用公式调整theta值
theta = theta - learning_rate * gradients
# 判断是否满足阈值
mid = math.sqrt(math.pow((theta[0][0] - 4), 2) + math.pow((theta[1][0] - 3), 2))
# 满足阈值,结束循环
err = math.fabs(mid - before_value)
if threshold >= mid: # 前后两次的计算结果差值极小时,可认为已经接近收敛
break
else:
if err < 0.01:
print('多次迭代仍不能满足阈值,请修改阈值或完善程序!')
break
before_value = mid # 暂时保存上一次的中间结果,用于计算差值
print('结果:\n x is : {}\n y is : {}\n 误差 : {}'.format(theta[0][0], theta[1][0], before_value))
执行结果:
四.总结
不管是随机梯度还是批量梯度,它们的区别在于每次迭代计算应用的数据多少;随机一条优点是计算快,但是容易造成抖动,且有很大的随机性【管中窥豹,可见一斑】;批量全部优点是计算准确,可以稳步下降,但每次迭代计算时间长,资源消耗大;中庸之道在算法领域也是一个很重要的思想,Mini-batch就是其中的产物,每次计算取一小批数据进行迭代计算,即减低了异常数据带来的抖动,也降低了每次迭代计算的数据计算量;对于一般的应用场景,Mini-batch会是一个比较好的选择!
相关推荐
- Filter函数在WPS里的正确用法,官方教程里都没有说......
-
Filter函数是office365新增的筛选函数,WPS也紧跟添加了它。但在二个软件中的使用方法却完全不同。office365有单元格溢出功能,只需要输入一个Filter公式即可完成数据筛选。但在W...
- 跳过VLOOKUP天坑!FILTER函数10个招式让同事以为你开了外挂?
-
还在为VLOOKUP的"一对多"限制头疼?是否还在为INDEX+MATCH的嵌套抓狂?今天教你用Excel新晋顶流——FILTER函数,10个高能用法让你秒变数据操控大师!用法1:精准...
- Filter函数的三种用法,比用VLOOKUP一对多查询,更加灵活方便
-
文章最后有彩蛋!好礼相送!Excel秘籍大全,正文开始FILTER函数可以基于定义的条件筛选一系列数据。在没有filter函数之前,如果实现一对多查询,常见的是构建辅助列,然后使用VLOOKUP+R...
- Filter函数公式,快速实现订单核对,1分钟学会
-
举个例子,我们有一份公司所有的订单源数据表格,这里我们只用两列信息来模拟,实际可能有很多列数据,几百行数据然后我们有另外一个表,里面有部分已经处理过的订单数据,如下所示,这里举例是4个,实际可能有上百...
- FILTER函数结合及经典用法2:一对多筛选
-
FILTER经典用法2:一对多筛选。FILTER函数的经典用法2:一对多的筛选。比如左边这个表格,需要根据部门筛选出每个部门的人员,应该怎样做?·把鼠标放在单元格内,在编辑栏输入等于FILTER。·第...
- 干掉VLOOKUP,FILTER函数9大用法全解析!
-
1.单条件基础筛选场景:筛选销量>5000的记录公式:=VSTACK(A1:D1,FILTER(A2:D9,D2:D9>5500))解析:A2:D9为需要筛选的数据区域,D2:D9&...
- Excel新函数公式Filter,秒杀VLOOKUP,人人必学
-
以前VLOOKUP公式是必学的公式,自从新版本更新之后,VLOOKUP已经变得可有可无了,但是新出来的Filter函数公式,你必须学会,它非常的强大,工作中用到非常频繁1、Filter公式背景在学会这...
- 第一讲:filter的基本用法及拓展_filter详解
-
全能查找函数filter的基本用法及拓展初学者,务必观看。进阶者,可互相学习,欢迎在回复中补充新用法。首次撰写此函数相关内容,若有不足之处,请予以指教,请勿诋毁,多谢。提示:以下内容以WPS最新版本为...
- 测一测你是什么粒子?_测测你是什么质
-
大亚湾实验。|图片来源:RoyKaltschmidt,LawrenceBerkeleyNationalLaboratory/WikimediaCommons2020年12月12日,大亚湾...
- SpringBoot如何处理配置文件的密文
-
在SpringBoot应用中,直接在配置文件(如application.yml或application.properties)中明文存储数据库密码、API密钥等敏感信息是严重的安全风险,...
- 大语言模型解释Python的 类装饰器
-
一、什么是类装饰器?在Python中,装饰器(Decorator)是一种高阶函数,它接受另一个对象(通常是函数或类),并返回一个经“增强”处理后的新对象。我们常见的是对函数进行装饰:@my_dec...
- Thymeleaf_thymeleaf属于前端吗
-
一、Thymeleaf简介Thymeleaf是用来开发Web和独立环境项目的服务器端的Java模版引擎Spring官方支持的服务的渲染模板中,并不包含jsp。而是Thymeleaf和Freemarke...
- Win9去哪了?Win10避讳Windows95、98
-
10月1日,微软在旧金山发布了新一代操作系统预览版。但不是名为Windows9,而是win10,有业内人士猜测,跳过9而取10为命名是为了预示十全十美。可是小编还觉得9还代表长长久久呢!恐怕这里又说...
- 仓颉编程练习-字符串操作_仓颉编译器
-
main.cj:importstd.convert.Parsablemain():Int64{//字符串比较lets1:String="abc"...
- 一课译词:断断续续_一课译词:断断续续的意思
-
PhotobyMikefromPexels“断断续续”,或“时断时续”,意思是时而中断,时而继续地接连下去(continuefromtimetotime)。与英文惯用语“fitsan...
- 一周热门
- 最近发表
- 标签列表
-
- ps图案在哪里 (33)
- super().__init__ (33)
- python 获取日期 (34)
- 0xa (36)
- super().__init__()详解 (33)
- python安装包在哪里找 (33)
- linux查看python版本信息 (35)
- python怎么改成中文 (35)
- php文件怎么在浏览器运行 (33)
- eval在python中的意思 (33)
- python安装opencv库 (35)
- python div (34)
- sticky css (33)
- python中random.randint()函数 (34)
- python去掉字符串中的指定字符 (33)
- python入门经典100题 (34)
- anaconda安装路径 (34)
- yield和return的区别 (33)
- 1到10的阶乘之和是多少 (35)
- python安装sklearn库 (33)
- dom和bom区别 (33)
- js 替换指定位置的字符 (33)
- python判断元素是否存在 (33)
- sorted key (33)
- shutil.copy() (33)