百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术文章 > 正文

LossVal:一种集成于损失函数的高效数据价值评估方法

itomcoil 2025-02-07 17:47 15 浏览

在机器学习领域,训练数据的价值并非均等:部分训练数据点对模型训练的影响显著高于其他数据点。评估单个数据点的影响程度通常需要反复重训练模型,计算效率低下。LossVal提出了一种创新方法,通过将数据价值评估过程直接集成到神经网络的损失函数中,实现了高效的数据价值评估。

现代机器学习模型通常依赖大规模数据集进行训练。在实际应用中,数据集中的训练样本对模型的信息贡献度存在显著差异。例如含噪声数据点或标注错误的样本往往对机器学习模型的学习过程贡献有限。在这篇研究的一个实验中,利用车辆碰撞测试数据集训练模型,目标是基于车辆参数预测碰撞对乘员的伤害程度。数据集中包含80年代和90年代的车辆数据,这些历史数据对现代车辆的碰撞预测可能具有较低的参考价值。

LossVal技术原理

LossVal的核心思想是在模型训练过程中同步学习样本重要性得分,这一过程与模型权重的学习机制类似。这种方法避免了传统方法中需要多次重训练模型的计算开销,同时也无需记录训练过程中的模型权重更新序列。

实现上述目标的关键在于对标准损失函数(如均方误差MSE和交叉熵损失)进行改进。通过在损失函数中引入实例级权重,并将其与加权分布距离函数相乘。LossVal损失函数的一般形式可表示为:

其中?表示加权目标损失(可以是加权MSE或交叉熵),OT代表最优传输的加权分布距离。这种改进后的损失函数可直接用于神经网络训练,其中权重w通过梯度下降方法在每个训练步骤中更新。

以下分别介绍回归任务和分类任务中LossVal的具体实现方法,随后详细探讨分布距离OT的计算原理。

回归任务中的LossVal实现

从最基础的MSE开始分析。标准MSE定义为模型预测值?与真实值y之间的平方差(n为训练样本索引):

LossVal对MSE进行了两步改进:首先引入样本权重w?,为每个训练实例分配权重;其次将加权MSE与分布距离函数相乘。

分类任务中的LossVal实现

标准交叉熵损失的表达式为:

对交叉熵损失的改进方式与MSE类似:

最优传输距离度量

最优传输距离反映了将一个分布转换为另一个分布所需的最小代价,也称为推土机距离(这一形象化的名称源自于描述将一堆土填入坑洞的最优路径问题)。其数学定义为:

其中c表示将点x?移动到x?的代价,γ代表可能的传输方案集合,定义了点的移动路径。最优传输方案γ*是指具有最小分布距离的传输方案。值得注意的是,论文通过联合分布Π(w, 1)将权重w整合到代价函数中。因此OT??实际上度量了训练集与验证集之间的加权距离。

从实际应用角度来看,通过优化权重以最小化OT??,算法会自动为与验证数据相似的训练数据点分配较高权重,而噪声样本则会获得较低权重。这种机制确保了模型能够更多地从高质量数据中学习。

算法实现

完整的实现代码和相关数据集已在GitHub平台开源。以下代码展示了LossVal在均方误差场景下的核心实现:

def LossVal_mse(train_X: torch.Tensor, 
train_y_true: torch.Tensor, train_y_pred: torch.Tensor, 
val_X: torch.Tensor, sample_ids: torch.Tensor 
weights: torch.Tensor, device: torch.device) -> torch.Tensor: 
weights = weights.index_select(0, sample_ids) # 根据sample_ids选择对应的权重

# 步骤1:计算加权均方误差损失
loss = torch.sum((train_y_true - train_y_pred) ** 2, dim=1) 
weighted_loss = torch.sum(weights @ loss) # loss为向量,weights为矩阵

# 步骤2:计算训练集与验证集之间的Sinkhorn距离
sinkhorn_distance = SamplesLoss(loss="sinkhorn") 
dist_loss = sinkhorn_distance(weights, train_X, torch.ones(val_X.shape[0], requires_grad=True).to(device), val_X) 

# 步骤3:组合MSE损失与Sinkhorn距离
return weighted_loss * dist_loss**2

该损失函数在PyTorch框架中的使用方式与标准损失函数类似,但需要注意以下特殊之处:函数参数中包含验证集、样本权重以及批次样本索引,这些参数对于计算批处理样本的加权损失至关重要。实现依赖PyTorch的自动微分机制,因此样本权重向量需要作为模型参数的一部分。这样设计使得权重优化可以直接利用Adam等优化器的优势。另外也可以通过手动计算损失对各权重i的梯度来更新权重。对于交叉熵损失的实现,架构相似,主要区别在于需要修改第8行的损失计算方式。

实验验证

噪声样本检测任务中各数据价值评估方法的性能对比。指标越高表示性能越好。

上图展示了各种数据价值评估方法在噪声样本检测任务中的性能对比。该任务基于OpenDataVal基准测试框架:首先在训练数据的p%样本中注入噪声,然后利用数据价值评估方法识别这些噪声样本。评估方法的性能通过其识别噪声样本的准确度(F1分数)来衡量。图中结果是在6个分类数据集和6个回归数据集上的平均表现。实验中考虑了三种噪声类型:标签噪声、特征噪声和混合噪声(其中混合噪声条件下,一半样本包含特征噪声,另一半包含标签噪声)。结果表明,在标签噪声和混合噪声场景下,LossVal的性能优于其他方法。但在特征噪声场景中,LAVA展现出更好的性能。

数据点移除实验(如下图所示)采用了类似的实验设计。该实验的目标是评估移除高价值数据点对模型性能的影响。理论上,更准确的数据价值评估方法会优先识别出更重要的数据点,因此移除这些点会导致模型性能更快下降。实验结果显示,LossVal在此任务上与当前最先进的方法达到相当的性能水平。

高价值数据点移除实验中各方法的性能对比。指标越低表示性能越好。

总结

LossVal方法的技术创新在于:通过梯度下降方法优化每个数据点的权重,从而量化数据点的重要性。

实验结果表明,LossVal在OpenDataVal基准测试中达到了领先性能水平。相比其他基于模型的方法,LossVal具有更低的时间复杂度,并在不同类型的噪声和任务场景下展现出更稳定的性能。

综上所述,LossVal为神经网络的数据价值评估提供了一种高效且有效的新方法。

相关推荐

Python Qt GUI设计:将UI文件转换Python文件三种妙招(基础篇—2)

在开始本文之前提醒各位朋友,Python记得安装PyQt5库文件,Python语言功能很强,但是Python自带的GUI开发库Tkinter功能很弱,难以开发出专业的GUI。好在Python语言的开放...

Connect 2.0来了,还有Nuke和Maya新集成

ftrackConnect2.0现在可以下载了--重新设计的桌面应用程序,使用户能够将ftrackStudio与创意应用程序集成,发布资产等。这个新版本的发布中还有两个Nuke和Maya新集成,...

Magicgui:不会GUI编程也能轻松构建Python GUI应用

什么是MagicguiMagicgui是一个Python库,它允许开发者仅凭简单的类型注解就能快速构建图形用户界面(GUI)应用程序。这个库基于Napari项目,利用了Python的强大类型系统,使得...

Python入坑系列:桌面GUI开发之Pyside6

阅读本章之后,你可以掌握这些内容:Pyside6的SignalsandSlots、Envents的作用,如何使用?PySide6的Window、DialogsandAlerts、Widgets...

Python入坑系列-一起认识Pyside6 designer可拖拽桌面GUI

通过本文章,你可以了解一下内容:如何安装和使用Pyside6designerdesigner有哪些的特性通过designer如何转成python代码以前以为Pyside6designer需要在下载...

pyside2的基础界面(pyside2显示图片)

今天我们来学习pyside2的基础界面没有安装过pyside2的小伙伴可以看主页代码效果...

Python GUI开发:打包PySide2应用(python 打包pyc)

之前的文章我们介绍了怎么使用PySide2来开发一个简单PythonGUI应用。这次我们来将上次完成的代码打包。我们使用pyinstaller。注意,pyinstaller默认会将所有安装的pack...

使用PySide2做窗体,到底是怎么个事?看这个能不能搞懂

PySide2是Qt框架的Python绑定,允许你使用Python创建功能强大的跨平台GUI应用程序。PySide2的基本使用方法:安装PySide2pipinstallPy...

pycharm中conda解释器无法配置(pycharm安装的解释器不能用)

之前用的好好的pycharm正常配置解释器突然不能用了?可以显示有这个环境然后确认后可以conda正在配置解释器,但是进度条结束后还是不成功!!试过了pycharm重启,pycharm重装,anaco...

Conda使用指南:从基础操作到Llama-Factory大模型微调环境搭建

Conda虚拟环境在Linux下的全面使用指南:从基础操作到Llama-Factory大模型微调环境搭建在当今的AI开发与数据分析领域,conda虚拟环境已成为Linux系统下管理项目依赖的标配工具。...

Python操作系统资源管理与监控(python调用资源管理器)

在现代计算环境中,对操作系统资源的有效管理和监控是确保应用程序性能和系统稳定性的关键。Python凭借其丰富的标准库和第三方扩展,提供了强大的工具来实现这一目标。本文将探讨Python在操作系统资源管...

本地部署开源版Manus+DeepSeek创建自己的AI智能体

1、下载安装Anaconda,设置conda环境变量,并使用conda创建python3.12虚拟环境。2、从OpenManus仓库下载代码,并安装需要的依赖。3、使用Ollama加载本地DeepSe...

一文教会你,搭建AI模型训练与微调环境,包学会的!

一、硬件要求显卡配置:需要Nvidia显卡,至少配备8G显存,且专用显存与共享显存之和需大于20G。二、环境搭建步骤1.设置文件存储路径非系统盘存储:建议将非安装版的环境文件均存放在非系统盘(如E盘...

使用scikit-learn为PyTorch 模型进行超参数网格搜索

scikit-learn是Python中最好的机器学习库,而PyTorch又为我们构建模型提供了方便的操作,能否将它们的优点整合起来呢?在本文中,我们将介绍如何使用scikit-learn中的网格搜...

如何Keras自动编码器给极端罕见事件分类

全文共7940字,预计学习时长30分钟或更长本文将以一家造纸厂的生产为例,介绍如何使用自动编码器构建罕见事件分类器。现实生活中罕见事件的数据集:背景1.什么是极端罕见事件?在罕见事件问题中,数据集是...