百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术文章 > 正文

用于稀疏向量、独热编码数据的损失函数回顾和PyTorch实现

itomcoil 2024-12-15 13:58 25 浏览


在稀疏的、独热编码编码数据上构建自动编码器


自1986年[1]问世以来,在过去的30年里,通用自动编码器神经网络已经渗透到现代机器学习的大多数主要领域的研究中。在嵌入复杂数据方面,自动编码器已经被证明是非常有效的,它提供了简单的方法来将复杂的非线性依赖编码为平凡的向量表示。但是,尽管它们的有效性已经在许多方面得到了证明,但它们在重现稀疏数据方面常常存在不足,特别是当列像一个热编码那样相互关联时。

在本文中,我将简要地讨论一种热编码(OHE)数据和一般的自动编码器。然后,我将介绍使用在一个热门编码数据上受过训练的自动编码器所带来的问题的用例。最后,我将深入讨论稀疏OHE数据重构的问题,然后介绍我发现在这些条件下运行良好的3个损失函数:

· CosineEmbeddingLoss

· Sorenson-Dice Coefficient Loss

· Multi-Task Learning Losses of Individual OHE Components

-解决了上述挑战,包括在PyTorch中实现它们的代码。

热编码数据

热编码数据是一种最简单的,但在一般机器学习场景中经常被误解的数据预处理技术。该过程将具有"N"不同类别的分类数据二值化为二进制0和1的N列。第N个类别中出现1表示该观察属于该类别。这个过程在Python中很简单,使用Scikit-Learn OneHotEncoder模块:

from sklearn.preprocessing import OneHotEncoder
import numpy as np# Instantiate a column of 10 random integers from 5 classes
x = np.random.randint(5, size=10).reshape(-1,1)print(x)
>>> [[2][3][2][2][1][1][4][1][0][4]]# Instantiate OHE() + Fit/Transform the data
ohe_encoder = OneHotEncoder(categories="auto")
encoded = ohe_encoder.fit_transform(x).todense()print(encoded)
>>> matrix([[0., 1., 0., 0., 0.],
           [0., 0., 0., 1., 0.],
           [0., 0., 1., 0., 0.],
           [0., 0., 0., 1., 0.],
           [0., 0., 1., 0., 0.],
           [1., 0., 0., 0., 0.],
           [0., 0., 1., 0., 0.],
           [0., 0., 1., 0., 0.],
           [0., 0., 0., 1., 0.],
           [0., 0., 0., 0., 1.]])print(list(ohe_encoder.get_feature_names()))
>>> ["x0_0", "x0_1", "x0_2", "x0_3", "x0_4"]

但是,尽管这个技巧很简单,但如果不小心,它可能很快就会失效。它可以很容易地为数据添加多余的复杂性,并改变数据上某些分类方法的有效性。例如,转换成OHE向量的列现在是相互依赖的,这种交互使得在某些类型的分类器中有效地表示数据方面变得困难。例如,如果您有一个包含15个不同类别的列,那么就需要一个深度为15的决策树来处理该热编码列中的if-then模式(当然树形模型的数据处理是不需要进行独热编码的,这里只是举例)。类似地,由于列是相互依赖的,如果使用bagging (Bootstrap聚合)的分类策略并执行特性采样,则可能会完全错过单次编码的列,或者只考虑它的部分组件类。

Autoencoders

自动编码器是一种无监督的神经网络,其工作是将数据嵌入到一种有效的压缩格式。它利用编码和解码过程将数据编码为更小的格式,然后再将更小的格式解码为原始的输入表示。利用模型重构(译码)与原始数据之间的损失对模型进行训练。


实际上,用代码表示这个网络也很容易。我们从两个函数开始:编码器模型和解码器模型。这两个"模型"都被封装在一个叫做Network的类中,它将包含我们的培训和评估的整个系统。最后,我们定义了一个Forward函数,PyTorch将它用作进入网络的入口,用于包装数据的编码和解码。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimclass Network(nn.Module):
   def __init__(self, input_shape: int):
      super().__init__()
      self.encode1 = nn.Linear(input_shape, 500)
      self.encode2 = nn.Linear(500, 250)
      self.encode3 = nn.Linear(250, 50)
      
      self.decode1 = nn.Linear(50, 250)
      self.decode2 = nn.Linear(250, 500)
      self.decode3 = nn.Linear(500, input_shape)   def encode(self, x: torch.Tensor):
      x = F.relu(self.encode1(x))
      x = F.relu(self.encode2(x))
      x = F.relu(self.encode3(x))
      return x   def decode(self, x: torch.Tensor):
      x = F.relu(self.decode1(x))
      x = F.relu(self.decode2(x))
      x = F.relu(self.decode3(x))
      return x   def forward(self, x: torch.Tensor):
      x = encode(x)
      x = decode(x)
      return x
def train_model(data: pd.DataFrame):
   net = Network()
   optimizer = optim.Adagrad(net.parameters(), lr=1e-3, weight_decay=1e-4)
   losses = []   for epoch in range(250):
     for batch in get_batches(data)
        net.zero_grad()
        
        # Pass batch through 
        output = net(batch)
        
        # Get Loss + Backprop
        loss = loss_fn(output, batch).sum() # 
        losses.append(loss)
        loss.backward()
        optimizer.step()
     return net, losses

正如我们在上面看到的,我们有一个编码函数,它从输入数据的形状开始,然后随着它向下传播到形状为50而降低它的维数。从那里,解码层接受嵌入,然后将其扩展回原来的形状。在训练中,我们从译码器中取出重构的结果,并取出重构与原始输入的损失。

损失函数的问题

所以现在我们已经讨论了自动编码器的结构和一个热编码过程,我们终于可以讨论与使用一个热编码在自动编码器相关的问题,以及如何解决这个问题。当一个自动编码器比较重建到原始输入数据,必须有一些估值之间的距离提出重建和真实的价值。通常,在输出值被认为互不相干的情况下,将使用交叉熵损失或MSE损失。但在我们的一个热编码的情况下,有几个问题,使系统更复杂:

· 一列出现1意味着对应的OHE列必须有一个0。即列不是不相交的

· OHE向量输入的稀疏性会导致系统选择简单地将大多数列返回0以减少误差

这些问题结合起来导致上述两个损失(MSE,交叉熵)在重构稀疏OHE数据时无效。下面我将介绍三种损失,它们提供了一个解决方案,或上述问题,并在PyTorch实现它们的代码:

余弦嵌入损失

余弦距离是一种经典的向量距离度量,常用于NLP问题中比较字包表示。通过求两个向量之间的余弦来计算距离,计算方法为:


由于该方法能够考虑到各列中二进制值的偏差来评估两个向量之间的距离,因此在稀疏嵌入重构中,该方法能够很好地量化误差。这种损失是迄今为止在PyTorch中最容易实现的,因为它在 Torch.nn.CosineEmbeddingLoss中有一个预先构建的解决方案

loss_function = torch.nn.CosineEmbeddingLoss(reduction='none')# . . . Then during training . . . loss = loss_function(reconstructed, input_data).sum()
loss.backward()

Dice Loss

Dice Loss是一个实现S?rensen-Dice系数[2],这是非常受欢迎的计算机视觉领域的分割任务。简单地说,它是两个集合之间重叠的度量,并且与两个向量之间的Jaccard距离有关。骰子系数对向量中列值的差异高度敏感,利用这种敏感性有效地区分图像中像素的边缘,因此在图像分割中非常流行。Dice Loss为:


PyTorch没有内部实现的Dice Loss。但是在Kaggle上可以在其丢失函数库- Keras & PyTorch[3]中找到一个很好的实现:

class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        #comment out if your model contains a sigmoid acitvation
        inputs = F.sigmoid(inputs)       
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice = (2.*intersection + smooth)/
               (inputs.sum() + targets.sum() + smooth)  
        
        return 1 - dice

不同OHE列的单个损失函数

最后,您可以将每个热编码列视为其自身的分类问题,并承担每个分类的损失。这是一个多任务学习问题的用例,其中autoencoder正在解决重构输入向量的各个分量的问题。当你有几个/所有的列在你的输入数据时,这个工作最好。例如,如果您有一个编码列,前7列是7个类别:您可以将其视为一个多类分类问题,并将损失作为子问题的交叉熵损失。然后,您可以将子问题的损失合并在一起,并将其作为整个批的损失向后传递。


下面您将看到这个过程的示例,其中示例有三个热编码的列,每个列有50个类别。

from torch.nn.modules import _Loss
from torch import argmaxclass CustomLoss(_Loss):
  def __init__(self):
    super(CustomLoss, self).__init__()  def forward(self, input, target):
    """ loss function called at runtime """
   
    # Class 1 - Indices [0:50]
    class_1_loss = F.nll_loss(
        F.log_softmax(input[:, 0:50], dim=1), 
        argmax(target[:, 0:50])
    )    # Class 2 - Indices [50:100]
    class_2_loss = F.nll_loss(
        F.log_softmax(input[:, 50:100], dim=1), 
        argmax(target[:, 50:100])
    )    # Class 3 - Indices [100:150]
    class_3_loss = F.nll_loss(
        F.log_softmax(input[:, 100:150], dim=1), 
        argmax(target[:, 100:150])
    )    return class_1_loss + class_2_loss + class_3_loss

在上面的代码中,您可以看到重构输出的子集是如何承受个体损失的,然后在最后将其合并为一个总和。这里我们使用了一个负对数似然损失(nll_loss),它是一个很好的损失函数用于多类分类方案,并与交叉熵损失有关。

总结

在本文中,我们浏览了一个独热编码分类变量的概念,以及自动编码器的一般结构和目标。我们讨论了一个热编码向量的缺点,以及在尝试训练稀疏的、一个独热编码数据的自编码器模型时的主要问题。最后,我们讨论了解决稀疏一热编码问题的3个损失函数。训练这些网络并没有更好或更坏的损失,在我所介绍的功能中,没有办法知道哪个是适合您的用例的,除非您尝试它们!

下面我提供了一些深入讨论上述主题的资源,以及一些我提供的关于丢失函数的资源。

资源

1. D.E. Rumelhart, G.E. Hinton, and R.J. Williams, "Learning internal representations by error propagation." Parallel Distributed Processing. Vol 1: Foundations. MIT Press, Cambridge, MA, 1986.

1. S?rensen, T. (1948). "A method of establishing groups of equal amplitude in plant sociology based on similarity of species and its application to analyses of the vegetation on Danish commons". Kongelige Danske Videnskabernes Selskab. 5 (4): 1–34. AND\ Dice, Lee R. (1945). "Measures of the Amount of Ecologic Association Between Species". Ecology. 26 (3): 297–302.

1. Kaggle's Loss Function Library: https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch


作者:Nick Hespe


deephub翻译组

相关推荐

Excel新函数TEXTSPLIT太强大了,轻松搞定数据拆分!

我是【桃大喵学习记】,欢迎大家关注哟~,每天为你分享职场办公软件使用技巧干货!最近我把WPS软件升级到了版本号:12.1.0.15990的最新版本,最版本已经支持文本拆分函数TEXTSPLIT了,并...

Excel超强数据拆分函数TEXTSPLIT,从入门到精通!

我是【桃大喵学习记】,欢迎大家关注哟~,每天为你分享职场办公软件使用技巧干货!今天跟大家分享的是Excel超强数据拆分函数TEXTSPLIT,带你从入门到精通!TEXTSPLIT函数真是太强大了,轻松...

看完就会用的C++17特性总结(c++11常用新特性)

作者:taoklin,腾讯WXG后台开发一、简单特性1.namespace嵌套C++17使我们可以更加简洁使用命名空间:2.std::variant升级版的C语言Union在C++17之前,通...

plsql字符串分割浅谈(plsql字符集设置)

工作之中遇到的小问题,在此抛出问题,并给出解决方法。一方面是为了给自己留下深刻印象,另一方面给遇到相似问题的同学一个解决思路。如若其中有写的不好或者不对的地方也请不加不吝赐教,集思广益,共同进步。遇到...

javascript如何分割字符串(javascript切割字符串)

javascript如何分割字符串在JavaScript中,您可以使用字符串的`split()`方法来将一个字符串分割成一个数组。`split()`方法接收一个参数,这个参数指定了分割字符串的方式。如...

TextSplit函数的使用方法(入门+进阶+高级共八种用法10个公式)

在Excel和WPS新增的几十个函数中,如果按实用性+功能性排名,textsplit排第二,无函数敢排第一。因为它不仅使用简单,而且解决了以前用超复杂公式才能搞定的难题。今天小编用10个公式,让你彻底...

Python字符串split()方法使用技巧

在Python中,字符串操作可谓是基础且关键的技能,而今天咱们要重点攻克的“堡垒”——split()方法,它能将看似浑然一体的字符串,按照我们的需求进行拆分,极大地便利了数据处理与文本解析工作。基本语...

go语言中字符串常用的系统函数(golang 字符串)

最近由于工作比较忙,视频有段时间没有更新了,在这里跟大家说声抱歉了,我尽快抽些时间整理下视频今天就发一篇关于go语言的基础知识吧!我这我工作中用到的一些常用函数,汇总出来分享给大家,希望对...

无规律文本拆分,这些函数你得会(没有分隔符没规律数据拆分)

今天文章来源于表格学员训练营群内答疑,混合文本拆分。其实拆分不难,只要规则明确就好办。就怕规则不清晰,或者规则太多。那真是,Oh,mygod.如上图所示进行拆分,文字表达实在是有点难,所以小熊变身灵...

Python之文本解析:字符串格式化的逆操作?

引言前面的文章中,提到了关于Python中字符串中的相关操作,更多地涉及到了字符串的格式化,有些地方也称为字符串插值操作,本质上,就是把多个字符串拼接在一起,以固定的格式呈现。关于字符串的操作,其实还...

忘记【分列】吧,TEXTSPLIT拆分文本好用100倍

函数TEXTSPLIT的作用是:按分隔符将字符串拆分为行或列。仅ExcelM365版本可用。基本应用将A2单元格内容按逗号拆分。=TEXTSPLIT(A2,",")第二参数设置为逗号...

Excel365版本新函数TEXTSPLIT,专攻文本拆分

Excel中字符串的处理,拆分和合并是比较常见的需求。合并,当前最好用的函数非TEXTJOIN不可。拆分,Office365于2022年3月更新了一个专业函数:TEXTSPLIT语法参数:【...

站长在线Python精讲使用正则表达式的split()方法分割字符串详解

欢迎你来到站长在线的站长学堂学习Python知识,本文学习的是《在Python中使用正则表达式的split()方法分割字符串详解》。使用正则表达式分割字符串在Python中使用正则表达式的split(...

Java中字符串分割的方法(java字符串切割方法)

技术背景在Java编程中,经常需要对字符串进行分割操作,例如将一个包含多个信息的字符串按照特定的分隔符拆分成多个子字符串。常见的应用场景包括解析CSV文件、处理网络请求参数等。实现步骤1.使用Str...

因为一个函数strtok踩坑,我被老工程师无情嘲笑了

在用C/C++实现字符串切割中,strtok函数经常用到,其主要作用是按照给定的字符集分隔字符串,并返回各子字符串。但是实际上,可不止有strtok(),还有strtok、strtok_s、strto...