百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术文章 > 正文

用于稀疏向量、独热编码数据的损失函数回顾和PyTorch实现

itomcoil 2024-12-15 13:58 51 浏览


在稀疏的、独热编码编码数据上构建自动编码器


自1986年[1]问世以来,在过去的30年里,通用自动编码器神经网络已经渗透到现代机器学习的大多数主要领域的研究中。在嵌入复杂数据方面,自动编码器已经被证明是非常有效的,它提供了简单的方法来将复杂的非线性依赖编码为平凡的向量表示。但是,尽管它们的有效性已经在许多方面得到了证明,但它们在重现稀疏数据方面常常存在不足,特别是当列像一个热编码那样相互关联时。

在本文中,我将简要地讨论一种热编码(OHE)数据和一般的自动编码器。然后,我将介绍使用在一个热门编码数据上受过训练的自动编码器所带来的问题的用例。最后,我将深入讨论稀疏OHE数据重构的问题,然后介绍我发现在这些条件下运行良好的3个损失函数:

· CosineEmbeddingLoss

· Sorenson-Dice Coefficient Loss

· Multi-Task Learning Losses of Individual OHE Components

-解决了上述挑战,包括在PyTorch中实现它们的代码。

热编码数据

热编码数据是一种最简单的,但在一般机器学习场景中经常被误解的数据预处理技术。该过程将具有"N"不同类别的分类数据二值化为二进制0和1的N列。第N个类别中出现1表示该观察属于该类别。这个过程在Python中很简单,使用Scikit-Learn OneHotEncoder模块:

from sklearn.preprocessing import OneHotEncoder
import numpy as np# Instantiate a column of 10 random integers from 5 classes
x = np.random.randint(5, size=10).reshape(-1,1)print(x)
>>> [[2][3][2][2][1][1][4][1][0][4]]# Instantiate OHE() + Fit/Transform the data
ohe_encoder = OneHotEncoder(categories="auto")
encoded = ohe_encoder.fit_transform(x).todense()print(encoded)
>>> matrix([[0., 1., 0., 0., 0.],
           [0., 0., 0., 1., 0.],
           [0., 0., 1., 0., 0.],
           [0., 0., 0., 1., 0.],
           [0., 0., 1., 0., 0.],
           [1., 0., 0., 0., 0.],
           [0., 0., 1., 0., 0.],
           [0., 0., 1., 0., 0.],
           [0., 0., 0., 1., 0.],
           [0., 0., 0., 0., 1.]])print(list(ohe_encoder.get_feature_names()))
>>> ["x0_0", "x0_1", "x0_2", "x0_3", "x0_4"]

但是,尽管这个技巧很简单,但如果不小心,它可能很快就会失效。它可以很容易地为数据添加多余的复杂性,并改变数据上某些分类方法的有效性。例如,转换成OHE向量的列现在是相互依赖的,这种交互使得在某些类型的分类器中有效地表示数据方面变得困难。例如,如果您有一个包含15个不同类别的列,那么就需要一个深度为15的决策树来处理该热编码列中的if-then模式(当然树形模型的数据处理是不需要进行独热编码的,这里只是举例)。类似地,由于列是相互依赖的,如果使用bagging (Bootstrap聚合)的分类策略并执行特性采样,则可能会完全错过单次编码的列,或者只考虑它的部分组件类。

Autoencoders

自动编码器是一种无监督的神经网络,其工作是将数据嵌入到一种有效的压缩格式。它利用编码和解码过程将数据编码为更小的格式,然后再将更小的格式解码为原始的输入表示。利用模型重构(译码)与原始数据之间的损失对模型进行训练。


实际上,用代码表示这个网络也很容易。我们从两个函数开始:编码器模型和解码器模型。这两个"模型"都被封装在一个叫做Network的类中,它将包含我们的培训和评估的整个系统。最后,我们定义了一个Forward函数,PyTorch将它用作进入网络的入口,用于包装数据的编码和解码。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimclass Network(nn.Module):
   def __init__(self, input_shape: int):
      super().__init__()
      self.encode1 = nn.Linear(input_shape, 500)
      self.encode2 = nn.Linear(500, 250)
      self.encode3 = nn.Linear(250, 50)
      
      self.decode1 = nn.Linear(50, 250)
      self.decode2 = nn.Linear(250, 500)
      self.decode3 = nn.Linear(500, input_shape)   def encode(self, x: torch.Tensor):
      x = F.relu(self.encode1(x))
      x = F.relu(self.encode2(x))
      x = F.relu(self.encode3(x))
      return x   def decode(self, x: torch.Tensor):
      x = F.relu(self.decode1(x))
      x = F.relu(self.decode2(x))
      x = F.relu(self.decode3(x))
      return x   def forward(self, x: torch.Tensor):
      x = encode(x)
      x = decode(x)
      return x
def train_model(data: pd.DataFrame):
   net = Network()
   optimizer = optim.Adagrad(net.parameters(), lr=1e-3, weight_decay=1e-4)
   losses = []   for epoch in range(250):
     for batch in get_batches(data)
        net.zero_grad()
        
        # Pass batch through 
        output = net(batch)
        
        # Get Loss + Backprop
        loss = loss_fn(output, batch).sum() # 
        losses.append(loss)
        loss.backward()
        optimizer.step()
     return net, losses

正如我们在上面看到的,我们有一个编码函数,它从输入数据的形状开始,然后随着它向下传播到形状为50而降低它的维数。从那里,解码层接受嵌入,然后将其扩展回原来的形状。在训练中,我们从译码器中取出重构的结果,并取出重构与原始输入的损失。

损失函数的问题

所以现在我们已经讨论了自动编码器的结构和一个热编码过程,我们终于可以讨论与使用一个热编码在自动编码器相关的问题,以及如何解决这个问题。当一个自动编码器比较重建到原始输入数据,必须有一些估值之间的距离提出重建和真实的价值。通常,在输出值被认为互不相干的情况下,将使用交叉熵损失或MSE损失。但在我们的一个热编码的情况下,有几个问题,使系统更复杂:

· 一列出现1意味着对应的OHE列必须有一个0。即列不是不相交的

· OHE向量输入的稀疏性会导致系统选择简单地将大多数列返回0以减少误差

这些问题结合起来导致上述两个损失(MSE,交叉熵)在重构稀疏OHE数据时无效。下面我将介绍三种损失,它们提供了一个解决方案,或上述问题,并在PyTorch实现它们的代码:

余弦嵌入损失

余弦距离是一种经典的向量距离度量,常用于NLP问题中比较字包表示。通过求两个向量之间的余弦来计算距离,计算方法为:


由于该方法能够考虑到各列中二进制值的偏差来评估两个向量之间的距离,因此在稀疏嵌入重构中,该方法能够很好地量化误差。这种损失是迄今为止在PyTorch中最容易实现的,因为它在 Torch.nn.CosineEmbeddingLoss中有一个预先构建的解决方案

loss_function = torch.nn.CosineEmbeddingLoss(reduction='none')# . . . Then during training . . . loss = loss_function(reconstructed, input_data).sum()
loss.backward()

Dice Loss

Dice Loss是一个实现S?rensen-Dice系数[2],这是非常受欢迎的计算机视觉领域的分割任务。简单地说,它是两个集合之间重叠的度量,并且与两个向量之间的Jaccard距离有关。骰子系数对向量中列值的差异高度敏感,利用这种敏感性有效地区分图像中像素的边缘,因此在图像分割中非常流行。Dice Loss为:


PyTorch没有内部实现的Dice Loss。但是在Kaggle上可以在其丢失函数库- Keras & PyTorch[3]中找到一个很好的实现:

class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        #comment out if your model contains a sigmoid acitvation
        inputs = F.sigmoid(inputs)       
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice = (2.*intersection + smooth)/
               (inputs.sum() + targets.sum() + smooth)  
        
        return 1 - dice

不同OHE列的单个损失函数

最后,您可以将每个热编码列视为其自身的分类问题,并承担每个分类的损失。这是一个多任务学习问题的用例,其中autoencoder正在解决重构输入向量的各个分量的问题。当你有几个/所有的列在你的输入数据时,这个工作最好。例如,如果您有一个编码列,前7列是7个类别:您可以将其视为一个多类分类问题,并将损失作为子问题的交叉熵损失。然后,您可以将子问题的损失合并在一起,并将其作为整个批的损失向后传递。


下面您将看到这个过程的示例,其中示例有三个热编码的列,每个列有50个类别。

from torch.nn.modules import _Loss
from torch import argmaxclass CustomLoss(_Loss):
  def __init__(self):
    super(CustomLoss, self).__init__()  def forward(self, input, target):
    """ loss function called at runtime """
   
    # Class 1 - Indices [0:50]
    class_1_loss = F.nll_loss(
        F.log_softmax(input[:, 0:50], dim=1), 
        argmax(target[:, 0:50])
    )    # Class 2 - Indices [50:100]
    class_2_loss = F.nll_loss(
        F.log_softmax(input[:, 50:100], dim=1), 
        argmax(target[:, 50:100])
    )    # Class 3 - Indices [100:150]
    class_3_loss = F.nll_loss(
        F.log_softmax(input[:, 100:150], dim=1), 
        argmax(target[:, 100:150])
    )    return class_1_loss + class_2_loss + class_3_loss

在上面的代码中,您可以看到重构输出的子集是如何承受个体损失的,然后在最后将其合并为一个总和。这里我们使用了一个负对数似然损失(nll_loss),它是一个很好的损失函数用于多类分类方案,并与交叉熵损失有关。

总结

在本文中,我们浏览了一个独热编码分类变量的概念,以及自动编码器的一般结构和目标。我们讨论了一个热编码向量的缺点,以及在尝试训练稀疏的、一个独热编码数据的自编码器模型时的主要问题。最后,我们讨论了解决稀疏一热编码问题的3个损失函数。训练这些网络并没有更好或更坏的损失,在我所介绍的功能中,没有办法知道哪个是适合您的用例的,除非您尝试它们!

下面我提供了一些深入讨论上述主题的资源,以及一些我提供的关于丢失函数的资源。

资源

1. D.E. Rumelhart, G.E. Hinton, and R.J. Williams, "Learning internal representations by error propagation." Parallel Distributed Processing. Vol 1: Foundations. MIT Press, Cambridge, MA, 1986.

1. S?rensen, T. (1948). "A method of establishing groups of equal amplitude in plant sociology based on similarity of species and its application to analyses of the vegetation on Danish commons". Kongelige Danske Videnskabernes Selskab. 5 (4): 1–34. AND\ Dice, Lee R. (1945). "Measures of the Amount of Ecologic Association Between Species". Ecology. 26 (3): 297–302.

1. Kaggle's Loss Function Library: https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch


作者:Nick Hespe


deephub翻译组

相关推荐

《Queendom》宣布冠军!女团MAMAMOO四人激动落泪

网易娱乐11月1日报道据台湾媒体报道,南韩女团竞争回归的生死斗《Queendom》昨(10/31)晚播出大决赛,并以直播方式进行,6组女团、女歌手皆演唱新歌,并加总前三轮的赛前赛、音源成绩与直播现场投...

正确复制、重写别人的代码,不算抄袭

我最近在一篇文章提到,工程师应该怎样避免使用大量的库、包以及其他依赖关系。我建议的另一种方案是,如果你没有达到重用第三方代码的阈值时,那么你就可以自己编写代码。在本文中,我将讨论一个在重用和从头开始编...

HTML DOM tr 对象_html event对象

tr对象tr对象代表了HTML表格的行。HTML文档中出现一个<tr>标签,就会创建一个tr对象。tr对象集合W3C:W3C标签。集合描述W3Ccells返回...

JS 打造动态表格_js如何动态改变表格内容

后台列表页最常见的需求:点击表头排序+一键全选。本文用原生js代码实现零依赖方案,涵盖DOM查询、排序算法、事件代理三大核心技能。效果速览一、核心思路事件入口:为每个<th>绑...

连肝7个晚上,总结了66条计算机网络的知识点

作者|哪吒来源|程序员小灰(ID:chengxuyuanxiaohui)计算机网络知识是面试常考的内容,在实际工作中经常涉及。最近,我总结了66条计算机网络相关的知识点。1、比较http0....

Vue 中 强制组件重新渲染的正确方法

作者:MichaelThiessen译者:前端小智来源:hackernoon有时候,依赖Vue响应方式来更新数据是不够的,相反,我们需要手动重新渲染组件来更新数据。或者,我们可能只想抛开当前的...

为什么100个前端只有1人能说清?浏览器重排/重绘深度解析

面试现场的"致命拷问""你的项目里做过哪些性能优化?能具体讲讲重排和重绘的区别吗?"作为面试官,我在秋招季连续面试过100多位前端候选人,这句提问几乎成了必考题。但令...

HTML DOM 介绍_dom4j html

HTMLDOM(文档对象模型)是一种基于文档的编程接口,它是HTML和XML文档的编程接口。它可以让开发人员通过JavaScript或其他脚本语言来访问和操作HTML和XML文档...

JavaScript 事件——“事件流和事件处理程序”的注意要点

事件流事件流描述的是从页面中接收事件的顺序。IE的事件流是事件冒泡流,而NetscapeCommunicator的事件流是事件捕获流。事件冒泡即事件开始时由最具体的元素接收,然后逐级向上传播到较为不...

探秘 Web 水印技术_水印制作网页

作者:fransli,腾讯PCG前端开发工程师Web水印技术在信息安全和版权保护等领域有着广泛的应用,对防止信息泄露或知识产品被侵犯有重要意义。水印根据可见性可分为可见水印和不可见水印(盲水印)...

国外顶流网红为流量拍摄性侵女学生?仅被封杀三月,回归仍爆火

曾经的油管之王,顶流网红DavidDobrik复出了。一切似乎都跟他因和成员灌酒性侵女学生被骂到退网之前一样:住在950万美元的豪宅,开着20万美元的阿斯顿马丁,每条视频都有数百万观看...人们仿佛...

JavaScript 内存泄漏排查方法_js内存泄漏及解决方法

一、概述本文主要介绍了如何通过Devtools的Memory内存工具排查JavaScript内存泄漏问题。先介绍了一些相关概念,说明了Memory内存工具的使用方式,然后介绍了堆快照的...

外贸独立站,网站优化的具体内容_外贸独立站,网站优化的具体内容有哪些

Wordpress网站优化,是通过优化代码、数据库、缓存、CSS/JS等内容,提升网站加载速度、交互性和稳定性。网站加载速度,是Google搜索引擎的第一权重,也是SEO优化的前提。1.优化渲染阻塞。...

这8个CSS工具可以提升编程速度_css用什么编译器

下面为大家推荐的这8个CSS工具,有提供函数的,有提供类的,有提取代码的,还有收集CSS的统计数据的……请花费两分钟的时间看完这篇文章,或许你会找到意外的惊喜,并且为你的编程之路打开了一扇新的大门。1...

vue的理解-vue源码 历史 简介 核心特性 和jquery区别 和 react对比

一、从历史说起Web是WorldWideWeb的简称,中文译为万维网我们可以将它规划成如下的几个时代来进行理解石器时代文明时代工业革命时代百花齐放时代石器时代石器时代指的就是我们的静态网页,可以欣...