您当前的位置: 首页 >  tensorflow

耐心的小黑

暂无认证

  • 0浏览

    0关注

    323博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

tensorflow2.x学习笔记十三:使用 tf.data下的API构建数据通道Dataset

耐心的小黑 发布时间:2020-08-30 21:52:35 ,浏览量:0

我们可以从以下七种数据结构中构建数据通道:

  • Numpy array,
  • Pandas DataFrame,
  • Python generator,
  • csv文件,
  • 文本文件,
  • 文件路径,
  • tfrecords文件

由于从tfrecord文件中构建数据通道比较复杂,所以接下来就只介绍前面六种情况。下面在介绍的时候也会指出使用了的tf.data下的五种API,其中 Numpy arrayPandas DataFrame使用的是同一种API。

一、从Numpy array构建数据管道(tf.data.Dataset.from_tensor_slices)
import tensorflow as tf
import numpy as np 
from sklearn import datasets 
iris = datasets.load_iris()


ds1 = tf.data.Dataset.from_tensor_slices((iris["data"],iris["target"]))

for features,label in ds1.take(5):
    print(features,label)
tf.Tensor([5.1 3.5 1.4 0.2], shape=(4,), dtype=float64) tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor([4.9 3.  1.4 0.2], shape=(4,), dtype=float64) tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor([4.7 3.2 1.3 0.2], shape=(4,), dtype=float64) tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor([4.6 3.1 1.5 0.2], shape=(4,), dtype=float64) tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor([5.  3.6 1.4 0.2], shape=(4,), dtype=float64) tf.Tensor(0, shape=(), dtype=int64)
二、从 Pandas DataFrame构建数据管道(tf.data.Dataset.from_tensor_slices)
import tensorflow as tf
from sklearn import datasets 
import pandas as pd
iris = datasets.load_iris()
dfiris = pd.DataFrame(iris["data"],columns = iris.feature_names)

##dfiris.to_dict("list")会将DataFrame转换成字典,
##生成的数据集中也会以字典为形式设置元素
ds2 = tf.data.Dataset.from_tensor_slices((dfiris.to_dict("list"),iris["target"]))

for features,label in ds2.take(1):
    print(features,label)
{'sepal length (cm)': , 
 'sepal width (cm)': , 
 'petal length (cm)': , 
 'petal width (cm)': 
 } 
 tf.Tensor(0, shape=(), dtype=int64)
三、从Python generator构建数据管道(tf.data.Dataset.from_generator)
# 从Python generator构建数据管道
import tensorflow as tf
from matplotlib import pyplot as plt 
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 定义一个从文件中读取图片的generator
image_generator = ImageDataGenerator(rescale=1.0/255)
					.flow_from_directory(
		                    "./data/cifar2/test/",
		                    target_size=(32, 32),
		                    batch_size=20,
		                    class_mode='binary')

classdict = image_generator.class_indices
print(classdict)

def generator():
    for features,label in image_generator:
        yield (features,label)

ds3 = tf.data.Dataset.from_generator(generator,output_types=(tf.float32,tf.int32))


Found 2000 images belonging to 2 classes.
{'airplane': 0, 'automobile': 1}
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
plt.figure(figsize=(6,6)) 
for i,(img,label) in enumerate(ds3.unbatch().take(9)):
    ax=plt.subplot(3,3,i+1)
    ax.imshow(img.numpy())
    ax.set_title("label = %d"%label)
    ax.set_xticks([])
    ax.set_yticks([]) 
plt.show()

在这里插入图片描述

四、从csv文件构建数据管道(tf.data.experimental.make_csv_dataset)
# 从csv文件构建数据管道
ds4 = tf.data.experimental.make_csv_dataset(
      file_pattern = ["./data/titanic/train.csv",
      				  "./data/titanic/test.csv"],
      batch_size=3, 
      label_name="Survived",
      na_value="",
      num_epochs=1,
      ignore_errors=True)

for data,label in ds4.take(1):
    print(data,label)
OrderedDict([
('PassengerId', ), 
('Pclass', ), 
('Name', ), 
('Sex', ), 
('Age', ), 
('SibSp', ), 
('Parch', ), 
('Ticket', ), 
('Fare', ), 
('Cabin', ), 
('Embarked', )
]) 

tf.Tensor([0 1 1], shape=(3,), dtype=int32)
五、从文本文件构建数据管道(tf.data.TextLineDataset)
ds5 = tf.data.TextLineDataset(
      filenames = ["./data/titanic/train.csv",
    			   "./data/titanic/test.csv"]).skip(1) #略去第一行header

for line in ds5.take(1):
    print(line)
tf.Tensor(b'493,0,1,"Molson, Mr. Harry Markland",
		  male,55.0,0,0,113787,30.5,C30,S', shape=(), dtype=string)
六、从文件路径构建数据管道(tf.data.Dataset.list_files)
ds6 = tf.data.Dataset.list_files("./data/cifar2/train/*/*.jpg")
for file in ds6.take(2):
    print(file)
tf.Tensor(b'./data/cifar2/train/airplane/4266.jpg', 
shape=(), dtype=string)

tf.Tensor(b'./data/cifar2/train/airplane/4131.jpg', 
shape=(), dtype=string)

参考链接:https://github.com/lyhue1991/eat_tensorflow2_in_30_days

关注
打赏
1640088279
查看更多评论
立即登录/注册

微信扫码登录

0.0382s