只有正样本和无标记数据的半监督学习(PU Learning)
itomcoil 2025-05-23 17:46 15 浏览
作者:Alon Agmon
编译:ronghuaiyang
正文共:5411 字 6 图
预计阅读时间:16 分钟
导读
在实际业务场景中,可能只会收到正反馈,所以反映到数据上,就只有正样本,另外就是大量的没有标记的样本,那么如何给这些没有标记的样本打上标签呢?
当你只有几个正样本的时候,如何分类未标注的数据
假设您有一个支付事务数据集。其中一些交易被标记为欺诈,其余的被标记为真实交易,你需要设计一个模型来区分欺诈交易和真实交易。假设你有足够的数据和良好的特征,这似乎是一个简单的分类任务。但是,假设只有15%的数据有标注,并且标注的样本只属于一个类,因此你的训练集由15%标记为真实的样本组成,而其余的没有标记,可能是真实的,也可能是虚假的。你如何对它们进行分类?这样的需求是否只是将这个任务变成了一个无监督的学习问题?好吧,不一定。
这个问题 —— 通常被称为PU(正样本的和未标记的样本)的分类问题 —— 应该首先从两个相似且常见的“标注问题”中区分出来,这两个问题使许多分类任务复杂化。第一个也是最常见的标签问题是“小训练集”的问题。当你有相当数量的数据,但只有一小部分被标记时,它就会出现。这个问题有许多种类和相当多的具体训练方法。另一个常见的标记问题(通常与PU问题合并在一起)涉及的情况是,我们的训练数据集全都有标记,但它只包含一个类。例如,假设我们只有一个非欺诈事务的数据集,并且我们需要使用这个数据集来训练一个模型来区分(类似的)非欺诈事务和欺诈事务。这也是一个常见的问题,通常被视为无监督的离群点检测问题,尽管在ML领域中也有相当多的工具是专门设计来处理这些场景的(OneClassSVM可能是最著名的)。
相比之下,PU分类问题涉及到一个训练集,其中只有部分数据被标记为正,而其余数据未被标记,可能是正的,也可能是负的。例如,假设你的雇主是一家银行,它可以为你提供大量的事务性数据,但只能确认其中的一部分是100%真实的。我在这里使用的例子涉及到关于伪钞的类似场景。它包括了1200张纸币的数据集,其中大部分没有标记,只有一部分被确认为真实的。虽然PU问题也很常见,但是与前面提到的两个分类问题相比,它们的讨论要少得多,而且很少有实践的例子或库可以广泛使用。
本文的目的是提供一种可能的方法来解决PU问题,我最近在一个分类项目中使用了这种方法。它是基于Charles Elkan和Keith Noto写的论文“Learning classifiers from only positive and unlabeled data”(2008),以及由Alexandre Drouin写的一些代码。尽管在文章中有更多的PU学习方法(我打算在以后的文章中讨论另一种相当流行的方法),Elkan和Noto的(E&N)方法非常简单,可以很容易地在Python中实现。
一点点理论(请原谅)
E&N本质上声称,给定一个数据集,我们有正的和未标记的数据,某个样本标记为正的概率是 [ P(y=1|x)] 的概率等于样本被标记的概率 [P(s=1|x)] 除以我们的数据集中正样本被标记的概率[P(s=1|y=1)]。
如果这个断言是正确的,那么实现起来就相对容易了。这是因为虽然我们没有足够的数据来训练分类器来告诉我们样本是正的还是负的,在PU场景中我们确实有足够的标签数据告诉我们正样本是否可能被标记,根据E&N,这足以估计有多可能是正的。
更正式地说,给定一个未标记的数据集,其中只有一组标记为正的样本,如果我们估计P(s=1|x) / P(s=1|y=1),我们就可以估计未标记的样本x为正的概率。幸运的是,我们几乎可以使用任何基于sklearn的分类器,按照以下步骤来估计:
(1)在包含已标记和未标记数据的数据集上拟合一个分类器,同时使用isLabeled作为目标y。以这种方式拟合分类器,训练它预测给定样本x被标记的概率P(s=1|x)。
(2)使用分类器预测数据集中已知正样本被标记的概率,预测结果可以表示为正样本被标记的概率 P(s=1|y=1|x),计算这些预测概率的平均值,这就是我们的P(s=1|y=1)。有了P(s=1|y=1)的估计值,为了预测数据点k为正的概率,根据E&N,我们需要做的就是估计P(s=1|k)或它被标记的概率,这正是我们训练的分类器(1)知道如何做的。
(3)使用我们在(1)上训练的分类器来估计k被标记或P(s=1|k)的概率。
(4)一旦我们估算出P(s=1|k),我们就可以将这个概率除以P(s=1|y=1) ,这是在步骤(2)上估算出来的,这样就可以得到它属于这两类的实际概率。
我们现在写代码并进行测试
以上步骤1-4可按如下方式实施:
# prepare data
x_data = the training set
y_data = target var (1 for the positives and not-1 for the rest)
# fit the classifier and estimate P(s=1|y=1)
classifier, ps1y1 =
fit_PU_estimator(x_data, y_data, 0.2, Estimator())
# estimate the prob that x_data is labeled P(s=1|X)
predicted_s = classifier.predict_proba(x_data)
# estimate the actual probabilities that X is positive
# by calculating P(s=1|X) / P(s=1|y=1)
predicted_y = estimated_s / ps1y1
让我们从这里开始:fit_PU_estimator()方法。
fit_PU_estimator()方法完成了两个主要任务:它拟合一个分类器,你选择一个具有正样本和未标记样本的训练集,然后估计一个正样本被标记的概率。相应地,它返回拟合的分类器(学会估计给定样本被标记的概率)和估计的概率P(s=1|y=1)。之后,我们需要做的就是找到P(s=1|x)或者标记为x的概率。因为这就是我们训练的分类器要做的,我们只需要调用它的predict_proba()方法。最后,为了实际对样本x进行分类,我们只需要将结果除以我们已经找到的P(s=1|y=1)。这可以用代码表示为:
pu_estimator, probs1y1 = fit_PU_estimator(
x_train,
y_train,
0.2,
xgb.XGBClassifier())
predicted_s = pu_estimator.predict_proba(x_train)
predicted_s = predicted_s[:,1]
predicted_y = predicted_s / probs1y1
实现fit_PU_estimator()方法本身非常简单:
def fit_PU_estimator(X,y, hold_out_ratio, estimator):
# The training set will be divided into a fitting-set that will be used
# to fit the estimator in order to estimate P(s=1|X) and a held-out set of positive samples
# that will be used to estimate P(s=1|y=1)
# --------
# find the indices of the positive/labeled elements
assert (type(y) == np.ndarray), "Must pass np.ndarray rather than list as y"
positives = np.where(y == 1.)[0]
# hold_out_size = the *number* of positives/labeled samples
# that we will use later to estimate P(s=1|y=1)
hold_out_size = int(np.ceil(len(positives) * hold_out_ratio))
np.random.shuffle(positives)
# hold_out = the *indices* of the positive elements
# that we will later use to estimate P(s=1|y=1)
hold_out = positives[:hold_out_size]
# the actual positive *elements* that we will keep aside
X_hold_out = X[hold_out]
# remove the held out elements from X and y
X = np.delete(X, hold_out,0)
y = np.delete(y, hold_out)
# We fit the estimator on the unlabeled samples + (part of the) positive and labeled ones.
# In order to estimate P(s=1|X) or what is the probablity that an element is *labeled*
estimator.fit(X, y)
# We then use the estimator for prediction of the positive held-out set
# in order to estimate P(s=1|y=1)
hold_out_predictions = estimator.predict_proba(X_hold_out)
#take the probability that it is 1
hold_out_predictions = hold_out_predictions[:,1]
# save the mean probability
c = np.mean(hold_out_predictions)
return estimator, c
def predict_PU_prob(X, estimator, prob_s1y1):
prob_pred = estimator.predict_proba(X)
prob_pred = prob_pred[:,1]
return prob_pred / prob_s1y1
为了测试这一点,我使用了[Bank Note Authentication dataset](
http://archive.ics.uci.edu/ml/datasets/banknote+ Authentication),它基于从真钞和假钞图像中提取的4个数据点。第一次,我使用标记数据集上的分类器来设置一个基线,然后移除了75%的样本的标签,以测试在P&U数据集上执行的如何。如输出所示,这个的数据集不是最很难分类,但你可以看到,虽然PU分类器只是“知道”153个正样本,而其余1219个样本是没有标记的,它表现的和知道了所有的标记样本的分类器差不多。然而,它确实损失了17%的召回率,因此损失了相当多的正样本。不过无论怎样,相比于其他的方法,我相信这些结果是相当令人满意的。
===>> load data set <<===
data size: (1372, 5)
Target variable (fraud or not):
0 762
1 610
===>> create baseline classification results <<===
Classification results:
f1: 99.57%
roc: 99.57%
recall: 99.15%
precision: 100.00%
===>> classify on all the data set <<===
Target variable (labeled or not):
-1 1219
1 153
Classification results:
f1: 90.24%
roc: 91.11%
recall: 82.62%
precision: 99.41%
一些重点。首先,这种方法的性能在很大程度上取决于数据集的大小。在本例中,我使用了大约150个正样本和1200个未标记样本。这远不是这种方法的理想数据集。例如,如果我们只有100个样本,我们的分类器就会表现得很差。其次,正如所附的notebook所示,有一些变量需要调优(例如要设置的样本大小、用于分类的概率阈值等),但最重要的可能是所选的分类器及其参数。我选择使用XGBoost是因为它在具有很少特征的小型数据集上执行得相对较好,但需要注意的是,它并不是在所有场景中都执行得最好,测试正确的分类器非常重要。
代码在这里:
https://github.com/a-agmon/pu-learn/blob/master/PU_Learning_EN.ipynb
英文原文:
https://towardsdatascience.com/semi-supervised-classification-of-unlabeled-data-pu-learning-81f96e96f7cb
相关推荐
- 最强聚类模型,层次聚类 !!_层次聚类的优缺点
-
哈喽,我是小白~咱们今天聊聊层次聚类,这种聚类方法在后面的使用,也是非常频繁的~首先,聚类很好理解,聚类(Clustering)就是把一堆“东西”自动分组。这些“东西”可以是人、...
- python决策树用于分类和回归问题实际应用案例
-
决策树(DecisionTrees)通过树状结构进行决策,在每个节点上根据特征进行分支。用于分类和回归问题。实际应用案例:预测一个顾客是否会流失。决策树是一种基于树状结构的机器学习算法,用于解决分类...
- Python教程(四十五):推荐系统-个性化推荐算法
-
今日目标o理解推荐系统的基本概念和类型o掌握协同过滤算法(用户和物品)o学会基于内容的推荐方法o了解矩阵分解和深度学习推荐o掌握推荐系统评估和优化技术推荐系统概述推荐系统是信息过滤系统,用于...
- 简单学Python——NumPy库7——排序和去重
-
NumPy数组排序主要用sort方法,sort方法只能将数值按升充排列(可以用[::-1]的切片方式实现降序排序),并且不改变原数组。例如:importnumpyasnpa=np.array(...
- PyTorch实战:TorchVision目标检测模型微调完
-
PyTorch实战:TorchVision目标检测模型微调完整教程一、什么是微调(Finetuning)?微调(Finetuning)是指在已经预训练好的模型基础上,使用自己的数据对模型进行进一步训练...
- C4.5算法解释_简述c4.5算法的基本思想
-
C4.5算法是ID3算法的改进版,它在特征选择上采用了信息增益比来解决ID3算法对取值较多的特征有偏好的问题。C4.5算法也是一种用于决策树构建的算法,它同样基于信息熵的概念。C4.5算法的步骤如下:...
- Python中的数据聚类及可视化分析实践
-
探索如何通过聚类分析揭露糖尿病预测数据集的特征!我们将运用Python的强力工具,深入挖掘数据,以直观的可视化揭示不同特征间的关系。一同探索聚类分析在糖尿病预测中的实践!所有这些可视化都可以通过数据操...
- 用Python来统计大乐透号码的概率分布
-
用Python来统计大乐透号码的概率分布,可以按照以下步骤进行:导入所需的库:使用Python中的numpy库生成数字序列,使用matplotlib库生成概率分布图。读取大乐透历史数据:从网络上找到大...
- python:支持向量机监督学习算法用于二分类和多分类问题示例
-
监督学习-支持向量机(SVM)支持向量机(SupportVectorMachine,简称SVM)是一种常用的监督学习算法,用于解决分类和回归问题。SVM的目标是找到一个最优的超平面,将不同类别的...
- 25个例子学会Pandas Groupby 操作
-
groupby是Pandas在数据分析中最常用的函数之一。它用于根据给定列中的不同值对数据点(即行)进行分组,分组后的数据可以计算生成组的聚合值。如果我们有一个包含汽车品牌和价格信息的数据集,那么可以...
- 数据挖掘流程_数据挖掘流程主要有哪些步骤
-
数据挖掘流程1.了解需求,确认目标说一下几点思考方法:做什么?目的是什么?目标是什么?为什么要做?有什么价值和意义?如何去做?完整解决方案是什么?2.获取数据pandas读取数据pd.read.c...
- 使用Python寻找图像最常见的颜色_python 以图找图
-
如果我们知道图像或对象最常见的是哪种颜色,那么可以解决图像处理中的几个用例,例如在农业领域,我们可能需要确定水果的成熟度。我们可以简单地检查一下水果的颜色是否在预定的范围内,看看它是成熟的,腐烂的,还...
- 财务预算分析全网最佳实践:从每月分析到每天分析
-
原文链接如下:「链接」掌握本文的方法,你就掌握了企业预算精细化分析的能力,全网首发。数据模拟稍微有点问题,不要在意数据细节,先看下最终效果。在编制财务预算或业务预算的过程中,通常预算的所有数据都是按月...
- 常用数据工具去重方法_数据去重公式
-
在数据处理中,去除重复数据是确保数据质量和分析准确性的关键步骤。特别是在处理多列数据时,保留唯一值组合能够有效清理数据集,避免冗余信息对分析结果的干扰。不同的工具和编程语言提供了多种方法来实现多列去重...
- Python教程(四十):PyTorch深度学习-动态计算图
-
今日目标o理解PyTorch的基本概念和动态计算图o掌握PyTorch张量操作和自动求导o学会构建神经网络模型o了解PyTorch的高级特性o掌握模型训练和部署PyTorch概述PyTorc...
- 一周热门
- 最近发表
- 标签列表
-
- ps图案在哪里 (33)
- super().__init__ (33)
- python 获取日期 (34)
- 0xa (36)
- super().__init__()详解 (33)
- python安装包在哪里找 (33)
- linux查看python版本信息 (35)
- python怎么改成中文 (35)
- php文件怎么在浏览器运行 (33)
- eval在python中的意思 (33)
- python安装opencv库 (35)
- python div (34)
- sticky css (33)
- python中random.randint()函数 (34)
- python去掉字符串中的指定字符 (33)
- python入门经典100题 (34)
- anaconda安装路径 (34)
- yield和return的区别 (33)
- 1到10的阶乘之和是多少 (35)
- python安装sklearn库 (33)
- dom和bom区别 (33)
- js 替换指定位置的字符 (33)
- python判断元素是否存在 (33)
- sorted key (33)
- shutil.copy() (33)