import tensorflow as tf
import tempfile
import zipfile
import os
import tensorboard
import numpy as np
from tensorflow_model_optimization.sparsity import keras as sparsity
## global parameters
# Minibatch size used for both the clean and the pruned training runs.
batch_size = 128
# MNIST has ten digit classes (0-9).
num_classes = 10
epochs = 10
# input image dimensions
img_rows, img_cols = 28, 28
# Fresh temp directory for TensorBoard / pruning summaries each run.
logdir = tempfile.mkdtemp()
print('Writing training logs to ' + logdir)
def prepare_trainval(img_rows, img_cols):
    """Load MNIST, reshape to the backend's image layout, scale pixels to
    [0, 1], and one-hot encode the labels.

    Returns (x_train, x_test, y_train, y_test).
    """
    # the data, shuffled and split between train and test sets
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    # The channel axis position depends on the backend's image data format.
    if tf.keras.backend.image_data_format() == 'channels_first':
        sample_shape = (1, img_rows, img_cols)
    else:
        sample_shape = (img_rows, img_cols, 1)
    x_train = x_train.reshape((x_train.shape[0],) + sample_shape)
    x_test = x_test.reshape((x_test.shape[0],) + sample_shape)
    # Scale uint8 pixel intensities [0, 255] down to float32 [0, 1].
    x_train = x_train.astype('float32') / 255
    x_test = x_test.astype('float32') / 255
    print('x_train shape:', x_train.shape)
    print(x_train.shape[0], 'train samples')
    print(x_test.shape[0], 'test samples')
    # convert class vectors to binary class matrices
    y_train = tf.keras.utils.to_categorical(y_train, num_classes)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes)
    return x_train, x_test, y_train, y_test
def build_prune_model(input_shape, end_step):
    """Build and compile a CNN whose Conv2D/Dense layers are wrapped for
    magnitude-based weight pruning.

    Sparsity ramps polynomially from 50% to 90% between optimizer step 2000
    and `end_step`, with masks updated every 100 steps.
    """
    layers = tf.keras.layers
    print('End step: ' + str(end_step))
    schedule = sparsity.PolynomialDecay(initial_sparsity=0.50,
                                        final_sparsity=0.90,
                                        begin_step=2000,
                                        end_step=end_step,
                                        frequency=100)
    pruning_params = {'pruning_schedule': schedule}
    prune = sparsity.prune_low_magnitude
    # Only Conv2D/Dense layers are wrapped; pooling, batch-norm and dropout
    # carry no prunable weights of interest here.
    pruned_model = tf.keras.Sequential([
        prune(layers.Conv2D(32, 5, padding='same', activation='relu'),
              input_shape=input_shape, **pruning_params),
        layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
        layers.BatchNormalization(),
        prune(layers.Conv2D(64, 5, padding='same', activation='relu'),
              **pruning_params),
        layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
        layers.Flatten(),
        prune(layers.Dense(1024, activation='relu'), **pruning_params),
        layers.Dropout(0.4),
        prune(layers.Dense(num_classes, activation='softmax'),
              **pruning_params),
    ])
    pruned_model.compile(loss=tf.keras.losses.categorical_crossentropy,
                         optimizer='adam',
                         metrics=['accuracy'])
    pruned_model.summary()
    return pruned_model
def train_prune_model(x_train, x_test, y_train, y_test, epochs, prune_model_file):
    """Train the pruned MNIST classifier and save it to `prune_model_file`.

    The model is saved with its optimizer state (include_optimizer=True) so
    pruning can be resumed; test loss/accuracy are printed after evaluation.
    """
    input_shape = (img_rows, img_cols, 1)
    num_train_samples = x_train.shape[0]
    # One pruning step == one optimizer step; schedule the sparsity ramp to
    # finish exactly at the end of training.
    end_step = np.ceil(1.0 * num_train_samples / batch_size).astype(np.int32) * epochs
    pruned_model = build_prune_model(input_shape, end_step)
    # Add a pruning step callback to peg the pruning step to the optimizer's
    # step. Also add a callback to add pruning summaries to tensorboard
    callbacks = [
        sparsity.UpdatePruningStep(),
        sparsity.PruningSummaries(log_dir=logdir, profile_batch=0)
    ]
    # BUG FIX: `epochs` was hard-coded to 10 here while `end_step` above was
    # computed from the `epochs` argument, so the pruning schedule and the
    # actual training length disagreed for any epochs != 10.
    pruned_model.fit(x_train, y_train,
                     batch_size=batch_size,
                     epochs=epochs,
                     verbose=1,
                     callbacks=callbacks,
                     validation_data=(x_test, y_test))
    score = pruned_model.evaluate(x_test, y_test, verbose=0)
    print('Saving pruned model to: ', prune_model_file)
    # save_model() sets include_optimizer to True by default. Spelling it out
    # here to highlight.
    tf.keras.models.save_model(pruned_model, prune_model_file, include_optimizer=True)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])
def build_clean_model(input_shape):
    """Build and compile the baseline (unpruned) MNIST CNN.

    Same architecture as the pruned model, minus the pruning wrappers.
    """
    layers = tf.keras.layers
    model = tf.keras.Sequential([
        layers.Conv2D(32, 5, padding='same', activation='relu',
                      input_shape=input_shape),
        layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
        layers.BatchNormalization(),
        layers.Conv2D(64, 5, padding='same', activation='relu'),
        layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
        layers.Flatten(),
        layers.Dense(1024, activation='relu'),
        layers.Dropout(0.4),
        layers.Dense(num_classes, activation='softmax'),
    ])
    model.compile(loss=tf.keras.losses.categorical_crossentropy,
                  optimizer='adam',
                  metrics=['accuracy'])
    model.summary()
    return model
def train_clean_model(x_train, x_test, y_train, y_test, epochs, keras_file):
    """Train the baseline classifier, print test metrics, and save the model
    (without optimizer state) to `keras_file`."""
    tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=logdir,
                                                    profile_batch=0)
    model = build_clean_model((img_rows, img_cols, 1))
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              callbacks=[tensorboard_cb],
              validation_data=(x_test, y_test))
    score = model.evaluate(x_test, y_test, verbose=0)
    print('Saving model to: ', keras_file)
    # Optimizer state is deliberately dropped for the final baseline artifact.
    tf.keras.models.save_model(model, keras_file, include_optimizer=False)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])
keras_file = "./ori_mnist_classifier.h5"
prune_model_file = "./prune_mnist_classifier.h5"


def report_compressed_size(model_file, label):
    """Zip `model_file` with DEFLATE and print its size before and after
    compression.  Returns the path of the temporary zip archive.

    Pruned models compress far better because their weight tensors are
    mostly zeros.
    """
    _, zip_path = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(model_file)
    print("Size of the %s model before compression: %.2f Mb" %
          (label, os.path.getsize(model_file) / float(2**20)))
    print("Size of the %s model after compression: %.2f Mb" %
          (label, os.path.getsize(zip_path) / float(2**20)))
    return zip_path


if __name__ == "__main__":
    # Uncomment to (re)generate the model files before measuring their sizes.
    # x_train, x_test, y_train, y_test = prepare_trainval(img_rows, img_cols)
    # train_clean_model(x_train, x_test, y_train, y_test, epochs, keras_file)
    # train_prune_model(x_train, x_test, y_train, y_test, epochs, prune_model_file)

    # BUG FIX: previously this crashed with FileNotFoundError when the model
    # files had not been produced (the training calls above are commented
    # out); now missing files are reported and skipped.
    for model_path, label in ((keras_file, "unpruned"),
                              (prune_model_file, "pruned")):
        if os.path.exists(model_path):
            report_compressed_size(model_path, label)
        else:
            print("Skipping %s model: %s not found (run training first)" %
                  (label, model_path))