import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dense, Dropout
# Load MNIST from a local .npz archive and normalize pixels to [0, 1].
path = "/home/qjm/Downloads/mnist.npz"
f = np.load(path)
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

# One-hot encode the labels (10 classes).
yy_train = np.zeros((60000, 10))
yy_test = np.zeros((10000, 10))
for i in range(60000):
    yy_train[i, y_train[i]] = 1
for i in range(10000):
    yy_test[i, y_test[i]] = 1
y_train = yy_train
y_test = yy_test
print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)
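# Note (addition, not in the original post): the explicit one-hot loops above can be
# replaced with the standard Keras utility, which produces the same (N, 10) arrays:
# y_train = tf.keras.utils.to_categorical(f['y_train'], num_classes=10)
# y_test = tf.keras.utils.to_categorical(f['y_test'], num_classes=10)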
# LeNet-style CNN for 28x28x1 MNIST digits.
model = tf.keras.models.Sequential([
    Conv2D(filters=6, kernel_size=5, strides=(1, 1), padding='same',
           activation='relu', use_bias=False, input_shape=(28, 28, 1)),
    MaxPool2D(pool_size=(3, 3), strides=2, padding="same"),
    Conv2D(filters=16, kernel_size=5, strides=(1, 1), padding='same',
           activation='relu', use_bias=False),
    MaxPool2D(pool_size=(3, 3), strides=2, padding="same"),
    Flatten(),  # feature map is 7x7x16 at this point
    Dense(120, activation='relu'),
    Dense(84, activation='relu'),
    Dropout(0.2),
    Dense(10, activation='softmax')
])
# Ramp sparsity from 0% to 50% between training steps 200 and 4000,
# then wrap the whole model with pruning wrappers.
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0, final_sparsity=0.5,
    begin_step=200, end_step=4000)
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=pruning_schedule)
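# Optional alternative (addition, not in the original post): tfmot can also wrap only
# selected layers instead of the whole model, using the clone_function pattern from the
# TFMOT guides. The helper below is only a sketch and is defined but not used here,
# since the code above prunes every layer.
def apply_pruning_to_dense(layer):
    # Wrap only Dense layers with the pruning wrapper; leave everything else unchanged.
    if isinstance(layer, tf.keras.layers.Dense):
        return tfmot.sparsity.keras.prune_low_magnitude(layer, pruning_schedule=pruning_schedule)
    return layer
# Example usage (not executed):
# model_for_pruning = tf.keras.models.clone_model(model, clone_function=apply_pruning_to_dense)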
model_for_pruning.compile(optimizer='adam',
                          loss='categorical_crossentropy',
                          metrics=['accuracy'])
log_dir = '/home/qjm/Desktop/model'
callbacks = [
    # Required: updates the pruning step and masks during training.
    tfmot.sparsity.keras.UpdatePruningStep(),
    # Log sparsity and other metrics in TensorBoard.
    tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir),
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
]
# EarlyStopping monitors val_loss, so hold out part of the training set for validation.
model_for_pruning.fit(x_train, y_train, epochs=10, validation_split=0.1, callbacks=callbacks)
model_for_pruning.evaluate(x_test, y_test, verbose=2)
model_for_pruning.summary()
# Inspect sparsity of the two pruned Conv2D layers (indices 0 and 2 in the wrapped model).
# get_weights() on a pruning wrapper returns the kernel together with the pruning mask
# and related bookkeeping variables.
for idx in (0, 2):
    weight = model_for_pruning.get_layer(index=idx).get_weights()
    for w in weight:
        print(w.shape)
        print(1 - 1.0 * np.count_nonzero(w) / w.size)  # fraction of zeroed entries
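# Follow-up sketch (addition, not in the original post): before exporting, strip the
# pruning wrappers so the saved model keeps only the sparse weights. strip_pruning is
# part of the public tfmot.sparsity.keras API; the output filename is just a placeholder.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
model_for_export.save('pruned_mnist.h5')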