百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术文章 > 正文

TensorFlow 1.3的Datasets和Estimator知多少?谷歌大神来解答

itomcoil 2025-07-08 19:19 13 浏览

图:pixabay

原文来源:Google Developers Blog

作者:TensorFlow团队

「机器人圈」编译:嗯~阿童木呀、多啦A亮

在TensorFlow 1.3版本里面有两个重要的特征,你应该好好尝试一下:

o数据集(Datasets):一种创建输入流水线的全新方法(即将数据读取到程序中)。

o评估器(Estimator):一种创建TensorFlow模型的高级方法。评估器包括用于常见机器学习任务的预制模型,当然,你也可以使用它们来创建你的自定义模型。

接下来你将看到它们如何是如何适应TensorFlow架构的。如果将它们结合起来,它们将提供了一种创建TensorFlow模型并向其馈送数据的简单方法:

我们的示例模型

为了能够更好地对这些特征进行深一步探索,我们将构建一个模型并展示相关代码片段。点击此处链接,你将获得完整代码资源(
https://github.com/mhyttsten/Misc/blob/master/Blog_Estimators_DataSet.py),其中包含关于训练和测试文件的说明。有一点需要注意的是,此代码只是为了演示数据集和评估器在功能方面的有效性,因此并未针对最大性能进行优化。

一个经过训练的模型根据四种植物特征(萼片长度、萼片宽度、花瓣长度和花瓣宽度)对鸢尾花进行分类。因此,在推理过程中,你可以为这四个特征提供值,并且该模型将预测出该花是以下三种美丽的变体之一:

从左至右:Iris setosa(山鸢尾,Radomil,CC BY-SA 3.0),Iris versicolor(杂色鸢尾,Dlanglois,CC BY-SA 3.0)和Iris virginica(维吉尼亚鸢尾,Frank Mayfield,CC BY-SA 2.0)。

我们将用下面的结构对一个深度神经网络分类器进行训练。所有输入和输出值都将为float32,输出值的和为1(正如我们所预测的每个单独鸢尾花类型的概率):

例如,一个输出结果是Iris Setosa(山鸢尾)的概率为0.05,是Iris Versicolor(杂色鸢尾)的概率为0.9,是Iris Virginica(维吉尼亚鸢尾)的概率为0.05,这表明该花是Iris Versicolor(杂色鸢尾)的概率为90%。

好的!既然我们已经定义了这个模型,接下来就看一下该如何使用Datasets(数据集)和Estimator(评估器)对其进行训练并做出预测。

Datasets(数据集)的简介

Dataset是一种为TensorFlow模型创建输入流水线的新方式。相较于使用feed_dict或基于队列的流水线,这个API要好用得多,而且它更干净,更易于使用。虽然在1.3版本中,Datasets仍然位于tf.contrib.data中,但我们希望将该API移动到1.4版本中,所以现在是对其进行测试驱动器的时候了。

在高级别中,Dataset涵盖以下几级:

其中:

oDataset:包含创建和转换数据集方法的基类。还使得你能够对内存中或来自Python生成器的数据初始化数据集。

oTextLineDataset:从文本文件中读取行。

oTFRecordDataset:读取TFRecord文件中的记录。

oFixedLengthRecordDataset:从二进制文件读取固定大小的记录。

oIterator(迭代器):提供一种一次访问一个数据集元素的方法。

我们的数据集

首先,我们先来看看那些将用来馈送模型的数据集。我们将从CSV文件中读取数据,其中每行将包含五个值——四个输入值以及标签:

标签将是:

0为Iris Setosa(山鸢尾);

1为Versicolor(杂色鸢尾);

2为Virginica(维吉尼亚鸢尾)

表征数据集

为了描述我们的数据集,我们首先创建一个关于特征的列表:

feature_names = [
 'SepalLength',
 'SepalWidth',
 'PetalLength',
 'PetalWidth']

当训练模型时,我们需要一个读取输入文件并返回特征和标签数据的函数。Estimators(评估器)要求你按照以下格式创建一个函数:

def input_fn():
 ...<code>...
 return ({ 'SepalLength':[values], ..<etc>.., 'PetalWidth':[values] },
 [IrisFlowerType])

返回值必须是一个双元素元组,其组织如下:

o第一个元素必须是一个dict(命令),其中每个输入特征都是一个键,然后是训练批量的值列表。

o第二个元素是训练批量的标签列表。

由于我们返回了一批输入特征和训练标签,所以这意味着返回语句中的所有列表将具有相同的长度。从技术上说,每当我们在这里提到“列表”时,实际上指的是一个1-d TensorFlow张量。

为了使得能够重用input_fn,我们将添加一些参数。从而使得我们能够用不同的设置构建输入函数。这些配置是很简单的:

file_path:要读取的数据文件。

perform_shuffle:记录顺序是否应该是随机的。

repeat_count:迭代数据集中记录的次数。例如,如果我们指定1,则每个记录将被读取一次。如果我们指定None,则迭代将永远持续下去。

以下是使用Dataset API实现此函数的方法。我们将把它封装在一个“输入函数”中,它将与我们馈送评估器模型相适应。

def my_input_fn(file_path, perform_shuffle=False, repeat_count=1):
 def decode_csv(line):
 parsed_line = tf.decode_csv(line, [[0.], [0.], [0.], [0.], [0]])
 label = parsed_line[-1:] # Last element is the label
 del parsed_line[-1] # Delete last element
 features = parsed_line # Everything (but last element) are the features
 d = dict(zip(feature_names, features)), label return d
 dataset = (tf.contrib.data.TextLineDataset(file_path) # Read text file
 .skip(1) # Skip header row
 .map(decode_csv)) # Transform each elem by applying decode_csv fn
 if perform_shuffle:
 # Randomizes input using a window of 256 elements (read into memory)
 dataset = dataset.shuffle(buffer_size=256)
 dataset = dataset.repeat(repeat_count) # Repeats dataset this # times
 dataset = dataset.batch(32) # Batch size to use
 iterator = dataset.make_one_shot_iterator()
 batch_features, batch_labels = iterator.get_next()
 return batch_features, batch_labels

请注意以下事项:

TextLineDataset:当你使用其基于文件的数据集时,Dataset API将为你处理大量的内存管理。例如,你可以通过指定列表作为参数,读取比内存大得多的数据集文件或读入多个文件。

Shuffle(随机化):读取buffer_size记录,然后shuffle(随机化)其顺序。

Map(映射):将数据集中的每个元素调用decode_csv函数,作为参数(因为我们使用的是TextLineDataset,每个元素都将是一行CSV文本)。然后我们将decode_csv应用于每一行。

decode_csv:将每行拆分为字段,如有必要,提供默认值。然后返回一个带有字段键和字段值的dict(命令)。映射函数使用dict更新数据集中的每个elem(行)。

当然,以上只是对Datasets的粗略介绍!接下来,我们可以使用此函数打印第一个批次:

next_batch = my_input_fn(FILE, True) # Will return 32 random elements# Now let's try it out, retrieving and printing one batch of data.# Although this code looks strange, you don't need to understand# the details.with tf.Session() as sess:
 first_batch = sess.run(next_batch)print(first_batch)# Output({'SepalLength': array([ 5.4000001, ...<repeat to 32 elems>], dtype=float32),
 'PetalWidth': array([ 0.40000001, ...<repeat to 32 elems>], dtype=float32),
 ...
},
[array([[2], ...<repeat to 32 elems>], dtype=int32) # Labels)

实际上,我们需要从Dataset API中实现我们的模型。Datasets具有更多的功能,详情请看这篇文章的结尾,我们收集了更多的资源。

Estimators(评估器)的介绍

Estimator是一种高级API,在训练TensorFlow模型时,它可以减少以前需要编写的大量样板代码。Estimator也非常灵活,如果你对模型有特定要求,它使得你能够覆盖其默认行为。

下面介绍两种可能的方法,你可以用来用Estimator构建模型:

oPre-made Estimator(预制评估器)——这些是预定义的评估器,用于生成特定类型的模型。在这篇文章中,我们将使用DNNClassifier预制评估器。

oEstimator(基础级别)——通过使用model_fn函数,你可以完全控制如何创建模型。我们将在另一篇文章中对其详细介绍。

以下是评估器的类图:

我们希望在将来的版本中添加更多的预制评估器。

你可以看到,所有的评估器都使用input_fn来提供输入数据。在我们的示例中,我们将重用我们为此定义的my_input_fn。

以下代码实例化了预测鸢尾花类型的评估器:

# Create the feature_columns, which specifies the input to our model.# All our input features are numeric, so use numeric_column for each one.feature_columns = [tf.feature_column.numeric_column(k) for k in feature_names]# Create a deep neural network regression classifier.# Use the DNNClassifier pre-made estimatorclassifier = tf.estimator.DNNClassifier(
 feature_columns=feature_columns, # The input features to our model
 hidden_units=[10, 10], # Two layers, each with 10 neurons
 n_classes=3,
 model_dir=PATH) # Path to where checkpoints etc are stored

我们现在有一个评估器,我们可以开始训练了。

训练模型

使用单行TensorFlow代码进行训练:

# Train our model, use the previously function my_input_fn# Input to training is a file with training example# Stop training after 8 iterations of train data (epochs)classifier.train(
 input_fn=lambda: my_input_fn(FILE_TRAIN, True, 8))

但等一下... "lambda: my_input_fn(FILE_TRAIN, True, 8)"这是什么东西?这就是我们用评估器连接数据集的地方!评估器需要数据来执行训练、评估和预测,并且使用input_fn来获取数据。评估器需要一个没有参数的input_fn,所以我们使用lambda创建一个没有参数的函数,它使用所需的参数调用input_fn:file_path、shuffle setting和repeat_count。在我们的示例中,我们使用my_input_fn,传递它:

oFILE_TRAIN,它是训练数据文件。

oTrue,这告诉评估器shuffle数据。

o8,它告诉评估器并重复数据集8次。

评估我们训练过的模型

好的,现在我们有一个训练过的模型。我们如何评估它的表现呢?幸运的是,每个评估器都包含一个评估方法:

# Evaluate our model using the examples contained in FILE_TEST# Return value will contain evaluation_metrics such as: loss & average_lossevaluate_result = estimator.evaluate(
 input_fn=lambda: my_input_fn(FILE_TEST, False, 4)print("Evaluation results")for key in evaluate_result:
 print(" {}, was: {}".format(key, evaluate_result[key]))

在我们的示例中,准确度能达到93%。当然有各种各样的方式来提高这个准确性。一种方法是一遍又一遍地运行程序。由于模型的状态是持久的(在上面的model_dir = PATH中),模型将会改进你对其进行的迭代次数的更改,直到它稳定为止。另一种方法是调整隐藏层数或每个隐藏层中的节点数。随意尝试一下,但请注意,当你进行更改时,你需要删除model_dir = PATH中指定的目录,因为你正在更改DNNClassifier的结构。

使用我们训练过模型进行预测

就是这样!我们现在有一个训练过的模型,如果我们对评估结果感到满意,我们可以使用它来基于一些输入来预测鸢尾花。与训练和评估一样,我们使用单个函数调用进行预测:

# Predict the type of some Iris flowers.# Let's predict the examples in FILE_TEST, repeat only once.predict_results = classifier.predict(
 input_fn=lambda: my_input_fn(FILE_TEST, False, 1))print("Predictions on test file")for prediction in predict_results:
 # Will print the predicted class, i.e: 0, 1, or 2 if the prediction
 # is Iris Sentosa, Vericolor, Virginica, respectively.
 print prediction["class_ids"][0]

在内存中对数据进行预测

前面的代码指定了FILE_TEST以对存储在文件中的数据进行预测,但是我们如何对驻留在其他来源的数据进行预测,例如在内存中?你可能会猜到,这并不需要改变我们的预测调用。相反,我们将Dataset API配置为使用记忆结构,如下所示:

# Let create a memory dataset for prediction.# We've taken the first 3 examples in FILE_TEST.prediction_input = [[5.9, 3.0, 4.2, 1.5], # -> 1, Iris Versicolor
 [6.9, 3.1, 5.4, 2.1], # -> 2, Iris Virginica
 [5.1, 3.3, 1.7, 0.5]] # -> 0, Iris Sentosadef new_input_fn():
 def decode(x):
 x = tf.split(x, 4) # Need to split into our 4 features
 # When predicting, we don't need (or have) any labels
 return dict(zip(feature_names, x)) # Then build a dict from them
 # The from_tensor_slices function will use a memory structure as input
 dataset = tf.contrib.data.Dataset.from_tensor_slices(prediction_input)
 dataset = dataset.map(decode)
 iterator = dataset.make_one_shot_iterator()
 next_feature_batch = iterator.get_next()
 return next_feature_batch, None # In prediction, we have no labels# Predict all our prediction_inputpredict_results = classifier.predict(input_fn=new_input_fn)# Print resultsprint("Predictions on memory data")for idx, prediction in enumerate(predict_results):
 type = prediction["class_ids"][0] # Get the predicted class (index)
 if type == 0:
 print("I think: {}, is Iris Sentosa".format(prediction_input[idx]))
 elif type == 1:
 print("I think: {}, is Iris Versicolor".format(prediction_input[idx]))
 else:
 print("I think: {}, is Iris Virginica".format(prediction_input[idx])


Dataset.from_tensor_slides()专为适合内存的小型数据集而设计。当我们使用TextLineDataset进行训练和评估时,你可以拥有任意大的文件,只要你的内存可以管理随机缓冲区和批量大小。

使用像DNNClassifier这样的预制评估器提供了很多价值。除了易于使用,预制评估器还提供内置的评估指标,并创建可在TensorBoard中看到的概要。要查看此报告,请从你的命令行启动TensorBoard,如下所示:

# Replace PATH with the actual path passed as model_dir argument when the# DNNRegressor estimator was created.tensorboard --logdir=PATH

下面的图显示了一些tensorboard将提供数据:

概要

在这篇文章中,我们探讨了数据集和评估器。这些是用于定义输入数据流和创建模型的重要API,因此投入时间来学习它们是绝对值得的!

有关更多详情,请务必查看:

o此文中使用的完整源代码可在此处获取。(https://goo.gl/PdGCRx)

oJosh Gordon的Jupyter notebook的出色使用。(
https://github.com/tensorflow/workshops/blob/master/notebooks/07_structured_data.ipynb)使用Jupyter notebook,你将学习如何运行一个更广泛的例子,其具有许多不同类型的特征(输入)。从我们的模型来看,我们只使用了数值特征。

o有关数据集,请参阅程序员指南(
https://www.tensorflow.org/programmers_guide/datasets)和参考文档(
https://www.tensorflow.org/api_docs/python/tf/contrib/data)中的新章节。

o有关评估器,请参阅程序员指南(
https://www.tensorflow.org/programmers_guide/estimators)和参考文档(
https://www.tensorflow.org/versions/master/api_docs/python/tf/estimator)中的新章节。

相关推荐

Filter函数在WPS里的正确用法,官方教程里都没有说......

Filter函数是office365新增的筛选函数,WPS也紧跟添加了它。但在二个软件中的使用方法却完全不同。office365有单元格溢出功能,只需要输入一个Filter公式即可完成数据筛选。但在W...

跳过VLOOKUP天坑!FILTER函数10个招式让同事以为你开了外挂?

还在为VLOOKUP的"一对多"限制头疼?是否还在为INDEX+MATCH的嵌套抓狂?今天教你用Excel新晋顶流——FILTER函数,10个高能用法让你秒变数据操控大师!用法1:精准...

Filter函数的三种用法,比用VLOOKUP一对多查询,更加灵活方便

文章最后有彩蛋!好礼相送!Excel秘籍大全,正文开始FILTER函数可以基于定义的条件筛选一系列数据。在没有filter函数之前,如果实现一对多查询,常见的是构建辅助列,然后使用VLOOKUP+R...

Filter函数公式,快速实现订单核对,1分钟学会

举个例子,我们有一份公司所有的订单源数据表格,这里我们只用两列信息来模拟,实际可能有很多列数据,几百行数据然后我们有另外一个表,里面有部分已经处理过的订单数据,如下所示,这里举例是4个,实际可能有上百...

FILTER函数结合及经典用法2:一对多筛选

FILTER经典用法2:一对多筛选。FILTER函数的经典用法2:一对多的筛选。比如左边这个表格,需要根据部门筛选出每个部门的人员,应该怎样做?·把鼠标放在单元格内,在编辑栏输入等于FILTER。·第...

干掉VLOOKUP,FILTER函数9大用法全解析!

1.单条件基础筛选场景:筛选销量>5000的记录公式:=VSTACK(A1:D1,FILTER(A2:D9,D2:D9>5500))解析:A2:D9为需要筛选的数据区域,D2:D9&...

Excel新函数公式Filter,秒杀VLOOKUP,人人必学

以前VLOOKUP公式是必学的公式,自从新版本更新之后,VLOOKUP已经变得可有可无了,但是新出来的Filter函数公式,你必须学会,它非常的强大,工作中用到非常频繁1、Filter公式背景在学会这...

第一讲:filter的基本用法及拓展_filter详解

全能查找函数filter的基本用法及拓展初学者,务必观看。进阶者,可互相学习,欢迎在回复中补充新用法。首次撰写此函数相关内容,若有不足之处,请予以指教,请勿诋毁,多谢。提示:以下内容以WPS最新版本为...

测一测你是什么粒子?_测测你是什么质

大亚湾实验。|图片来源:RoyKaltschmidt,LawrenceBerkeleyNationalLaboratory/WikimediaCommons2020年12月12日,大亚湾...

SpringBoot如何处理配置文件的密文

在SpringBoot应用中,直接在配置文件(如application.yml或application.properties)中明文存储数据库密码、API密钥等敏感信息是严重的安全风险,...

大语言模型解释Python的 类装饰器

一、什么是类装饰器?在Python中,装饰器(Decorator)是一种高阶函数,它接受另一个对象(通常是函数或类),并返回一个经“增强”处理后的新对象。我们常见的是对函数进行装饰:@my_dec...

Thymeleaf_thymeleaf属于前端吗

一、Thymeleaf简介Thymeleaf是用来开发Web和独立环境项目的服务器端的Java模版引擎Spring官方支持的服务的渲染模板中,并不包含jsp。而是Thymeleaf和Freemarke...

Win9去哪了?Win10避讳Windows95、98

10月1日,微软在旧金山发布了新一代操作系统预览版。但不是名为Windows9,而是win10,有业内人士猜测,跳过9而取10为命名是为了预示十全十美。可是小编还觉得9还代表长长久久呢!恐怕这里又说...

仓颉编程练习-字符串操作_仓颉编译器

main.cj:importstd.convert.Parsablemain():Int64{//字符串比较lets1:String="abc"...

一课译词:断断续续_一课译词:断断续续的意思

PhotobyMikefromPexels“断断续续”,或“时断时续”,意思是时而中断,时而继续地接连下去(continuefromtimetotime)。与英文惯用语“fitsan...