一、The Functional API
tf.keras.Sequential
模型是层的简单堆叠,无法表示任意模型。使用 Keras 函数式 API 可以构建复杂的模型拓扑,例如:
- 多输入模型
- 多输出模型
- 具有共享层的模型(同一层被调用多次),
- 具有非序列数据流的模型(例如,残差连接)
使用函数式 API 构建的模型具有以下特征:
- 层实例可调用并返回张量。
- 输入张量和输出张量用于定义
tf.keras.Model
实例。 - 此模型的训练方式和
Sequential Model
一样。
下面的示例使用functional API构建一个简单的全连接网络
def buildComplexModel():
print("The Functional API")
# layer实列作用于一个tensor, 并返回一个tensor
input = tf.keras.Input(shape=(32,))
print(type(input)) #
x = layers.Dense(64, activation='relu')(input)
print(type(x)) #
x = layers.Dense(64, activation='relu')(x)
print(type(x)) #
predictions = layers.Dense(10, activation='softmax')(x)
print(type(predictions)) #
print("predictions: ", predictions) # predictions: Tensor("dense_2/Identity:0", shape=(None, 10), dtype=float32)
# 构建模型
model = tf.keras.Model(inputs=input, outputs=predictions)
# 编译模型
model.compile(
optimizer=tf.keras.optimizers.RMSprop(0.001),
loss='categorical_crossentropy',
metrics=['accuracy']
)
# 训练模型
# With Numpy arrays
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))
model.fit(data, labels, batch_size=32, epochs=5)
1.1、简单实现
def TestFunctionalAPI():
inputs = tf.keras.Input(shape=(784,)) # 784维的向量
print(inputs.shape, inputs.dtype)
img_inputs = tf.keras.Input(shape=(32, 32, 3))
# layer on input
from tensorflow.keras import layers
devse = layers.Dense(64, activation='relu')
# 添加更多的层
x = devse(inputs) # layer call inputs
x = layers.Dense(64, activation='relu')(x)
outputs = layers.Dense(10, activation='softmax')(x)
# 创建模型
model = tf.keras.Model(inputs=inputs, outputs=outputs)
print(model.summary())
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 784)] 0
_________________________________________________________________
dense (Dense) (None, 64) 50240
_________________________________________________________________
dense_1 (Dense) (None, 64) 4160
_________________________________________________________________
dense_2 (Dense) (None, 10) 650
=================================================================
Total params: 55,050
Trainable params: 55,050
Non-trainable params: 0
_________________________________________________________________
None
未完待续。。。