使用tf.data进行数据集处理

描述

在进行AI模型训练过程前,需要对数据集进行处理, Tensorflow提供了tf.data数据集处理模块,通过该接口能够轻松实现数据集预处理。tf.data支持对数据集进行大量处理,如图片裁剪、图片打乱、图片分批次处理等操作。

 

数据集加载介绍

通过tf.data能够实现数据集加载,加载的数据格式包括:

●使用NumPy数组数据
●使用python生成器数据
●使用TFRecords格式数据
●使用文本格式数据
●使用CSV文件格式数据

tf.data常见数据格式加载示例

●使用Numpy数组数据

通过numpy构建数据,将构建的数据传递到tf.data的Dataset中。

 

import tensorflow as tf
import numpy as np
# 通过numpy构建数据个数
input_data = np.arange(4)
# 将数据传递到Dataset
dataset = tf.data.Dataset.from_tensor_slices(input_data)
for data in dataset:
    # 打印数据集,转换数据集tensor格式
    print(data)
输出为tensor数据集:
tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(3, shape=(), dtype=int64)

 

  ●读取文本中数据

通过准备的文本文件file.txt,将文本文件中的内容读取到tf.data,文件内容为:

Tf dataset load numpy data
Tf dataset load txt file data
Tf dateset load CSV file data

 

加载文本文件代码:

 

import tensorflow as tf
# 通过TextLineDataset进行加载文本文件内容
dataset = tf.data.TextLineDataset("file.txt")
for line in dataset:
    print(line)

 

文本加载数据输出(输出的Tensor中已包含了文件文件中的数据):

 

tf.Tensor(b'Tf dataset load numpy data', shape=(), dtype=string)
tf.Tensor(b'Tf dataset load txt file data', shape=(), dtype=string)
tf.Tensor(b'Tf dateset load CSV file data', shape=(), dtype=string)

 

●读取csv文本中数据

准备csv文件file.csv,文件内容为:

Data  

加载文本文件代码:

  import tensorflow as tf
import pandas as pd
# 使用pandas读取csv文本中数据
data = pd.read_csv('date.csv')
# 将读取的data数据传递到dataset中
f_slices = tf.data.Dataset.from_tensor_slices(dict(data))
for d in f_slices:
    print (d)

 

csv文本加载数据输出(输出的Tensor中已包含了文件文件中的数据):

 

{'Year': , 'Month': , 'Day': , 'Hour': }
{'Year': , 'Month': , 'Day': , 'Hour': }
{'Year': , 'Month': , 'Day': , 'Hour': }

 

●利用python迭代构建数据

通过python构建迭代器方式,将数据传递到tf.data, 示例代码如下:

 

# 迭代函数,通过传递的stop数据进行迭代
def build_data(stop):
  i = 0
  while i
示例代码输出(迭代5次的Tensor):
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)

 

tf.data常见数据处理

tf.data常用以下操作对数据完成预处理过程,操作包括: repeat、batch、shuffle、map等。

●tf.data数据repeat操作

通过调用repeat操作,将原数据进行重复构建,重复构建根据传递的repeat(x)次数决定。

●tf.data数据batch操作

通过调用batch操作将数据进行分批次执行,每批次数量根据batch(x)的值决定。

●tf.data 数据shuffle操作,打乱数据顺序

shuffle操作常用于预处理数据集时,将数据集中的顺序打乱,shuffle支持配置(buffer_size=x)将数据放置在缓冲区,通过缓冲区方式将数据打乱。

●tf.data 数据map操作

map操作能够将数组中的元素重构,同时能够实现读取图片,对图片进行旋转操作。

示例:

import tensorflow as tf
import numpy as np
# 使用numpy构建12个数据
input_data = np.arange(12)
# 将构建数据传递到dataset,传递中添加shuffle(10个缓冲区数据), batch分批次执行(每次4个数据), repeat重复构建数据2次
dataset = tf.data.Dataset.from_tensor_slices(input_data).shuffle(buffer_size=10).batch(4).repeat(2)
for data in dataset:
    print(data)
示例代码输出(输出中可以看到Tensor每次4个数据,每个数据重复出现2次,每次数据乱序输出):
tf.Tensor([8 3 9 1], shape=(4,), dtype=int64)
tf.Tensor([2 0 4 5], shape=(4,), dtype=int64)
tf.Tensor([ 7 11 10  6], shape=(4,), dtype=int64)
tf.Tensor([6 8 5 4], shape=(4,), dtype=int64)
tf.Tensor([ 7 10  2 11], shape=(4,), dtype=int64)
tf.Tensor([3 1 0 9], shape=(4,), dtype=int64)

 

图片旋转示例,示例代码如下:

 

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np


(train_data, train_label), (_, _) = tf.keras.datasets.mnist.load_data()
train_data = np.expand_dims(train_data.astype(np.float32) / 255.0, axis=-1)
mnist_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))
# 构建旋转函数,通过tensorflow的image.rot90完成90度旋转
def rot90(image, label):
    image = tf.image.rot90(image)
    return image, label
# 通过map方式调用构建的旋转函数
mnist_dataset = mnist_dataset.map(rot90)
for image, label in mnist_dataset.take(1):
    #添加图片抬头标签
    plt.title(label.numpy())
    plt.imshow(image.numpy()[:, :])
    plt.show()

 

正常加载图片输出:

Data

示例代码运行后,图片旋转输出:

Data

 

mnist数据集预处理

利用TensorFlow Datasets 提供了一系列可以和 TensorFlow 配合使用的数据集。下载和准备数据,以及构建tf.data.Dataset。

示例代码需要: 

 

python3.6版本环境
安装tensorflow==1.14.0版本(pip3 install tensorflow==2.1.0)
安装tensorflow_datasets==4.4.0(pip3 install tensorflow-datasets==4.4.0)
示例代码:
import tensorflow as tf
import tensorflow_datasets as tfds


#数据集通过Tensorflow Eager模式执行
tf.compat.v1.enable_eager_execution()


# 加载 MNIST 训练数据。这个步骤会下载并准备好该数据,除非你显式指定 `download=False` ,值得注意的是,一旦该数据准备好了,后续的  `load`  命令便不会重新下载,可以重复使用准备好的数据。你可以通过指定  `data_dir=`  (默认是  `~/tensorflow_datasets/` ) 来自定义数据保存/加载的路径。
mnist_train = tfds.load(name="mnist", split="train")
assert isinstance(mnist_train, tf.data.Dataset)


mnist_builder = tfds.builder("mnist")
mnist_builder.download_and_prepare()
mnist_train = mnist_builder.as_dataset(split="train")
# 对数据集进行重复使用,并对数据进行打乱,分批次处理
mnist_train = mnist_train.repeat().shuffle(1024).batch(32)
# prefetch 将使输入流水线可以在模型训练时异步获取批处理
mnist_train = mnist_train.prefetch(tf.data.experimental.AUTOTUNE)
info = mnist_builder.info
print(info.features["label"].names)
mnist_test, info = tfds.load("mnist", split="test", with_info=True)
print(info)
# 通过tfds.show_examples可视化数据样本
fig = tfds.show_examples(info, mnist_test)

 

代码示例输出:

 

# 数据集label名称
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
# 数据集信息
tfds.core.DatasetInfo(
    name='mnist',
    full_name='mnist/3.0.1',
    description="""
    The MNIST database of handwritten digits.
    """,
    homepage='http://yann.lecun.com/exdb/mnist/',
    data_path='/home/fabian/tensorflow_datasets/mnist/3.0.1',
    download_size=11.06 MiB,
    dataset_size=21.00 MiB,
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': ,
        'train': ,
    },
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
)
可视化样本数据图片:

Data

作者介绍:陈远斌,本科毕业于南开大学,海云捷迅研发工程师,熟悉OpenStack,Kubernetes技术,曾参与社区代码贡献,在OpenStack云计算技术上有一定的开发经验。

  审核编辑:汤梓红

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分