Parameter Pruning and Quantization of Image Classification Network
This example shows how to prune the parameters of a trained neural network using two parameter score metrics: the magnitude score [1] and the synaptic flow (SynFlow) score [2].
In many applications where transfer learning is used to retrain an image classification network for a new task or where a new network is trained from scratch, the optimal network architecture is not known, and the network might be overparameterized. An overparameterized network has redundant connections. Unstructured pruning, also known as sparsification, is a compression technique that aims to identify redundant, unnecessary connections that you can remove without significantly affecting the network accuracy. When you use pruning in combination with network quantization, you can reduce the inference time and memory footprint of the network, making it easier to deploy.
This example shows how to:
Perform post-training, iterative, unstructured pruning without the need for training data
Evaluate the performance of two different pruning algorithms
Investigate the layer-wise sparsity induced after pruning
Evaluate the impact of pruning on classification accuracy
Evaluate the impact of quantization on the classification accuracy of the pruned network
This example uses a simple convolutional neural network to classify handwritten digits from 0 to 9. For more information on setting up the data used for training and validation, see Create Simple Deep Learning Neural Network for Classification.
Load Pretrained Network and Data
Load the training and validation data. Train a convolutional neural network for the classification task.
[imdsTrain, imdsValidation] = loadDigitDataset;
net = trainDigitDataNetwork(imdsTrain, imdsValidation);
trueLabels = imdsValidation.Labels;
classes = categories(trueLabels);
Create a minibatchqueue object containing the validation data. Set executionEnvironment to auto to evaluate the network on a GPU, if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
executionEnvironment = "auto";
miniBatchSize = 128;
imdsValidation.ReadSize = miniBatchSize;
mbqValidation = minibatchqueue(imdsValidation,1,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFormat','SSCB',...
    'MiniBatchFcn',@preprocessMiniBatch,...
    'OutputEnvironment',executionEnvironment);
Neural Network Pruning
The goal of neural network pruning is to identify and remove unimportant connections to reduce the size of the network without significantly affecting network accuracy. In the following figure, on the left, the network has connections that map each neuron to the neurons of the next layer. After pruning, the network has fewer connections than the original network.
A pruning algorithm assigns a score to each parameter in the network. The score ranks the importance of each connection in the network. You can use one of two pruning approaches to achieve a target sparsity:
One-shot pruning - Remove a specified percentage of connections based on their score in one step. This method is prone to layer collapse when you specify a high sparsity value.
Iterative pruning - Achieve the target sparsity in a series of iterative steps. You can use this method when evaluated scores are sensitive to network structure. Scores are reevaluated at every iteration, so using a series of steps allows the network to move toward sparsity incrementally.
This example uses the iterative pruning method to achieve a target sparsity.
Iterative Pruning
Convert to dlnetwork Object
In this example, you use the Synaptic Flow algorithm, which requires that you create a custom cost function and evaluate the gradients with respect to the cost function to calculate the parameter score. To create a custom cost function, first convert the pretrained network to a dlnetwork.
Convert the network to a layer graph and remove the layers used for classification using removeLayers.
lgraph = layerGraph(net.Layers);
lgraph = removeLayers(lgraph,["softmax","classoutput"]);
dlnet = dlnetwork(lgraph);
Use analyzeNetwork to analyze the network architecture and learnable parameters.
analyzeNetwork(dlnet)
Evaluate the accuracy of the network before pruning.
accuracyOriginalNet = evaluateAccuracy(dlnet,mbqValidation,classes,trueLabels)
accuracyOriginalNet = 0.9908
The layers with learnable parameters are the three convolutional layers and the one fully connected layer. The network initially has a total of 21578 learnable parameters.
numTotalParams = sum(cellfun(@numel,dlnet.Learnables.Value))
numTotalParams = 21578
numNonZeroPerParam = cellfun(@(w)nnz(extractdata(w)),dlnet.Learnables.Value)
numNonZeroPerParam = 8×1
72
8
1152
16
4608
32
15680
10
Sparsity is defined as the percentage of parameters in the network with a value of zero. Check the sparsity of the network.
initialSparsity = 1-(sum(numNonZeroPerParam)/numTotalParams)
initialSparsity = 0
Before pruning, the network has a sparsity of zero.
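As a quick sanity check on the sparsity definition, the following minimal sketch applies the same counting to two small arrays. The arrays are hypothetical stand-ins for layer parameters and use plain numeric arrays rather than the dlarray values stored in dlnet.Learnables.

```matlab
% Hypothetical parameters: two small arrays standing in for layer weights.
params = {[0.5 0 -1.2; 0 0 0.3], [0; 2; 0; -4]};

numTotal   = sum(cellfun(@numel, params));   % total number of entries
numNonZero = sum(cellfun(@nnz, params));     % number of nonzero entries
sparsity   = 1 - numNonZero/numTotal;        % fraction of zero-valued entries

fprintf('Sparsity: %.0f%%\n', 100*sparsity); % prints "Sparsity: 50%"
```

Here 5 of the 10 entries are zero, so the sparsity is 0.5.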
Create Iteration Scheme
To define an iterative pruning scheme, specify the target sparsity and number of iterations. For this example, use linearly spaced iterations to achieve the target sparsity.
numIterations = 10; targetSparsity = 0.90; iterationScheme = linspace(0,targetSparsity,numIterations);
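To see what this schedule looks like, the sketch below lists the sparsity targeted at each iteration together with the approximate number of parameters that would be zero at that point. It reuses the parameter count reported earlier in this example; the variable numToPrune is introduced here only for illustration.

```matlab
numIterations  = 10;
targetSparsity = 0.90;
numTotalParams = 21578;   % parameter count reported earlier in this example

iterationScheme = linspace(0, targetSparsity, numIterations);

% Cumulative number of parameters that each step aims to have removed.
numToPrune = round(iterationScheme * numTotalParams);

% Column 1: sparsity in percent. Column 2: parameters pruned so far.
disp([100*iterationScheme(:), numToPrune(:)])
```

Because linspace includes both endpoints, the first iteration corresponds to 0% sparsity (no pruning) and each later iteration adds 10 percentage points.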
Pruning Loop
For each iteration, the custom pruning loop in this example performs the following steps:
Calculate the score for each connection.
Rank the scores for all connections in the network based on the selected pruning algorithm.
Determine the threshold for removing connections with the lowest scores.
Create the pruning mask using the threshold.
Apply the pruning mask to learnable parameters of the network.
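The five steps above can be sketched on a toy weight vector. The weights and the 50% target are hypothetical, the magnitude score is used only because it is the simplest to compute, and the sketch assumes the target removes at least one connection.

```matlab
% Toy weights (hypothetical values) and a target sparsity of 50%.
weights  = [0.8 -0.1 2.5 0.05 -1.7 0.3];
sparsity = 0.5;

% 1. Score each connection (magnitude score used here for simplicity).
scores = abs(weights);

% 2. Rank the scores in ascending order.
sortedScores = sort(scores);

% 3. Threshold is the k-th smallest score, where k is the number to remove.
k = round(sparsity * numel(sortedScores));
threshold = sortedScores(k);

% 4. The mask keeps only connections that score above the threshold.
mask = scores > threshold;

% 5. Apply the mask to the weights.
prunedWeights = weights .* mask;

disp(prunedWeights)
```

The three smallest-magnitude weights (-0.1, 0.05, and 0.3) are set to zero, and the other three are unchanged.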
Network Mask
Instead of setting entries in the weight arrays directly to zero, the pruning algorithm creates a binary mask for each learnable parameter that specifies whether a connection is pruned. The mask allows you to explore the behavior of the pruned network and try different pruning schemes without changing the underlying network structure.
For example, consider the following weights.
testWeight = [10.4 5.6 0.8 9];
Create a binary mask for each parameter in testWeight.
testMask = [1 0 1 0];
Apply the mask to testWeight to get the pruned weights.
testWeightsPruned = testWeight.*testMask
testWeightsPruned = 1×4
10.4000 0 0.8000 0
In iterative pruning, you create a binary mask for each iteration that contains pruning information. Applying the mask to the weights array does not change either the size of the array or the structure of the neural network. Therefore, the pruning step does not directly result in any speedup during inference or compression of the network size on disk.
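The following sketch illustrates why masking alone does not shrink the network: the masked array keeps its shape and element count, and only the number of zero entries changes. The weight array and the threshold are hypothetical.

```matlab
W  = magic(4) / 10;      % hypothetical 4-by-4 weight array
M  = abs(W) > 0.8;       % hypothetical mask that keeps only larger weights
Wp = W .* M;             % masked (pruned) weights

sameSize  = isequal(size(W), size(Wp));   % true: the array shape is unchanged
sameCount = numel(W) == numel(Wp);        % true: every entry is still stored
numZeros  = nnz(Wp == 0);                 % number of entries set to zero

fprintf('Same size: %d, zero entries: %d of %d\n', sameSize, numZeros, numel(Wp));
```

A dense array stores its zeros like any other value, so reducing storage or inference cost would require a further step, such as removing whole filters or quantizing the network.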
Initialize a plot that compares the accuracy of the pruned network to the original network.
figure
plot(100*iterationScheme([1,end]),100*accuracyOriginalNet*[1 1],'*-b','LineWidth',2,"Color","b")
ylim([0 100])
xlim(100*iterationScheme([1,end]))
xlabel("Sparsity (%)")
ylabel("Accuracy (%)")
legend("Original Accuracy","Location","southwest")
title("Pruning Accuracy")
grid on
Magnitude Pruning
Magnitude pruning [1] assigns a score to each parameter equal to its absolute value. It is assumed that the absolute value of a parameter corresponds to its relative importance to the accuracy of the trained network.
Initialize the mask. For the first iteration, you do not prune any parameters and the sparsity is 0%.
pruningMaskMagnitude = cell(1,numIterations);
pruningMaskMagnitude{1} = dlupdate(@(p)true(size(p)), dlnet.Learnables);
Below is an implementation of magnitude pruning. The network is pruned to various target sparsities in a loop to provide the flexibility to choose a pruned network based on its accuracy.
lineAccuracyPruningMagnitude = animatedline('Color','g','Marker','o','LineWidth',1.5);
legend("Original Accuracy","Magnitude Pruning Accuracy","Location","southwest")

% Compute magnitude scores
scoresMagnitude = calculateMagnitudeScore(dlnet);

for idx = 1:numel(iterationScheme)

    prunedNetMagnitude = dlnet;

    % Update the pruning mask
    pruningMaskMagnitude{idx} = calculateMask(scoresMagnitude,iterationScheme(idx));

    % Check the number of zero entries in the pruning mask
    numPrunedParams = sum(cellfun(@(m)nnz(~extractdata(m)),pruningMaskMagnitude{idx}.Value));
    sparsity = numPrunedParams/numTotalParams;

    % Apply pruning mask to network parameters
    prunedNetMagnitude.Learnables = dlupdate(@(W,M)W.*M, prunedNetMagnitude.Learnables, pruningMaskMagnitude{idx});

    % Compute validation accuracy on pruned network
    accuracyMagnitude = evaluateAccuracy(prunedNetMagnitude,mbqValidation,classes,trueLabels);

    % Display the pruning progress
    addpoints(lineAccuracyPruningMagnitude,100*sparsity,100*accuracyMagnitude)
    drawnow
end
SynFlow Pruning
Synaptic flow conservation (SynFlow) [2] scores are used for pruning. You can use this method to prune networks that use piecewise-linear activation functions such as ReLU.
Initialize the mask. For the first iteration, no parameters are pruned, and the sparsity is 0%.
pruningMaskSynFlow = cell(1,numIterations);
pruningMaskSynFlow{1} = dlupdate(@(p)true(size(p)),dlnet.Learnables);
The input data you use to compute the scores is a single image containing ones. If you are using a GPU, convert the data to a gpuArray.
dlX = dlarray(ones(net.Layers(1).InputSize),'SSC');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlX = gpuArray(dlX);
end
The following loop implements iterative pruning with the synaptic flow score [2]. At each iteration, a custom cost function evaluates the SynFlow score of each parameter, and the scores are recomputed on the network pruned in the previous iteration.
lineAccuracyPruningSynflow = animatedline('Color','r','Marker','o','LineWidth',1.5);
legend("Original Accuracy","Magnitude Pruning Accuracy","Synaptic Flow Accuracy","Location","southwest")

prunedNetSynFlow = dlnet;

% Iteratively increase sparsity
for idx = 1:numel(iterationScheme)
    % Compute SynFlow scores
    scoresSynFlow = calculateSynFlowScore(prunedNetSynFlow,dlX);

    % Update the pruning mask
    pruningMaskSynFlow{idx} = calculateMask(scoresSynFlow,iterationScheme(idx));

    % Check the number of zero entries in the pruning mask
    numPrunedParams = sum(cellfun(@(m)nnz(~extractdata(m)),pruningMaskSynFlow{idx}.Value));
    sparsity = numPrunedParams/numTotalParams;

    % Apply pruning mask to network parameters
    prunedNetSynFlow.Learnables = dlupdate(@(W,M)W.*M, prunedNetSynFlow.Learnables, pruningMaskSynFlow{idx});

    % Compute validation accuracy on pruned network
    accuracySynFlow = evaluateAccuracy(prunedNetSynFlow,mbqValidation,classes,trueLabels);

    % Display the pruning progress
    addpoints(lineAccuracyPruningSynflow,100*sparsity,100*accuracySynFlow)
    drawnow
end
Investigate Structure of Pruned Network
Choosing how much to prune a network is a trade-off between accuracy and sparsity. Use the sparsity versus accuracy plot to select the iteration with the desired sparsity level and acceptable accuracy.
pruningMethod = "SynFlow";
selectedIteration = 8;

prunedDLNet = createPrunedNet(dlnet,selectedIteration,pruningMaskSynFlow,pruningMaskMagnitude,pruningMethod);

[sparsityPerLayer,prunedChannelsPerLayer,numOutChannelsPerLayer,layerNames] = pruningStatistics(prunedDLNet);
Earlier convolutional layers are typically pruned less since they contain more relevant information about the core low-level structure of the image (e.g. edges and corners) which are essential for interpreting the image.
Plot the sparsity per layer for the selected pruning method and iteration.
figure
bar(sparsityPerLayer*100)
title("Sparsity per layer")
xlabel("Layer")
ylabel("Sparsity (%)")
xticks(1:numel(sparsityPerLayer))
xticklabels(layerNames)
xtickangle(45)
set(gca,'TickLabelInterpreter','none')
The pruning algorithm prunes single connections when you specify a low target sparsity. When you specify a high target sparsity, the pruning algorithm can prune whole filters and neurons in convolutional or fully connected layers.
figure
bar([prunedChannelsPerLayer,numOutChannelsPerLayer-prunedChannelsPerLayer],"stacked")
xlabel("Layer")
ylabel("Number of filters")
title("Number of filters per layer")
xticks(1:(numel(layerNames)))
xticklabels(layerNames)
xtickangle(45)
legend("Pruned number of channels/neurons" , "Original number of channels/neurons","Location","southoutside")
set(gca,'TickLabelInterpreter','none')
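To illustrate how a whole filter can count as pruned, the sketch below checks which filters of a hypothetical convolution weight array are entirely zero. It mirrors the test that the pruningStatistics function in this example applies, but uses a reshape so that it runs on plain numeric arrays. The array sizes and values are assumptions made for illustration.

```matlab
% Hypothetical 3-by-3 convolution weights: 2 input channels, 4 filters.
W    = reshape(1:3*3*2*4, [3 3 2 4]);
bias = [0.1 0 0.2 0];

% Zero filter 2 completely and only one input channel of filter 3.
W(:,:,:,2) = 0;
W(:,:,1,3) = 0;

% Flatten each filter to one column, then test whether the column is all zero.
numFilters   = size(W,4);
filterIsZero = all(reshape(W, [], numFilters) == 0, 1);

% A filter is counted as removable only if its bias is zero as well.
removable    = filterIsZero & (bias == 0);
numRemovable = nnz(removable);

fprintf('Removable filters: %d of %d\n', numRemovable, numFilters);
```

Only filter 2 qualifies. Filter 3 still has nonzero weights, and filter 4 has a zero bias but nonzero weights.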
Evaluate Network Accuracy
Compare the accuracy of the network before and after pruning.
YPredOriginal = modelPredictions(dlnet,mbqValidation,classes);
accOriginal = mean(YPredOriginal == trueLabels)
accOriginal = 0.9908
YPredPruned = modelPredictions(prunedDLNet,mbqValidation,classes);
accPruned = mean(YPredPruned == trueLabels)
accPruned = 0.9328
Create a confusion matrix chart to compare the true class labels with the predicted class labels for the original and pruned networks.
figure
confusionchart(trueLabels,YPredOriginal);
title("Original Network")
The validation set of the digits data contains 250 images for each class, so if a network predicts the class of each image perfectly, all scores on the diagonal equal 250 and no values are outside of the diagonal.
figure
confusionchart(trueLabels,YPredPruned);
title("Pruned Network")
When pruning a network, compare the confusion chart of the original network and the pruned network to check how the accuracy for each class label changes for the selected sparsity level. If all numbers on the diagonal decrease roughly equally, no bias is present. However, if the decreases are not equal, you might need to choose a pruned network from an earlier iteration by reducing the value of the variable selectedIteration.
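The comparison of diagonal entries can also be done numerically. The following sketch builds two small confusion matrices with accumarray and reports the per-class drop in correctly classified samples. The labels are hypothetical and are assumed to be encoded as integer class indices rather than the categorical labels used in this example.

```matlab
% Hypothetical labels for a 3-class problem, encoded as integers 1..3.
trueIdx      = [1 1 1 1 2 2 2 2 3 3 3 3]';
predOriginal = [1 1 1 1 2 2 2 2 3 3 3 1]';
predPruned   = [1 1 1 2 2 2 2 2 3 1 1 1]';

numClasses = 3;

% Confusion matrix: rows are true classes, columns are predicted classes.
confMat = @(t,p) accumarray([t p], 1, [numClasses numClasses]);

correctOriginal = diag(confMat(trueIdx, predOriginal));
correctPruned   = diag(confMat(trueIdx, predPruned));

% Drop in correctly classified samples for each class.
dropPerClass = correctOriginal - correctPruned;

disp(dropPerClass')
```

In this toy case the drops are 1, 0, and 2, so class 3 is affected most, which is the kind of uneven change the paragraph above warns about.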
Quantize Pruned Network
Deep neural networks trained in MATLAB use single-precision floating-point data types. Even networks that are small require a considerable amount of memory and hardware to perform floating-point arithmetic operations. These restrictions can inhibit deployment of deep learning models to devices that have low computational power and limited memory resources. By using a lower precision to store the weights and activations, you can reduce the memory requirements of the network. You can use Deep Learning Toolbox in tandem with the Deep Learning Model Quantization Library support package to reduce the memory footprint of a deep neural network by quantizing the weights, biases, and activations of the convolution layers to 8-bit scaled integer data types.
Pruning a network impacts the range statistics of parameters and activations at each layer, so the accuracy of the quantized network can change. To explore this difference, quantize the pruned network and use the quantized network to perform inference.
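To make the idea of an 8-bit scaled integer concrete, here is a minimal sketch that quantizes a few hypothetical weights with a single symmetric scale factor derived from their range. This scheme is an assumption chosen for illustration and is not necessarily the scaling that dlquantizer selects.

```matlab
% Hypothetical weights; two of them have already been pruned to zero.
w = single([0.42 0 -1.30 0 0.07 0.95]);

% Symmetric 8-bit scaling based on the value range (an assumption here).
scale = max(abs(w)) / 127;
q     = int8(round(w / scale));   % stored 8-bit integer codes
wHat  = single(q) * scale;        % values reconstructed from the codes

maxError = max(abs(w - wHat));    % bounded by about half the scale factor

fprintf('scale = %.5f, max error = %.5f\n', scale, maxError);
```

Because the scale depends on the largest magnitude present, removing or keeping large weights during pruning changes the resolution available to the remaining ones. Pruned weights map exactly to the integer code 0.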
Split the data into calibration and validation data sets.
calibrationDataStore = splitEachLabel(imdsTrain,0.1,'randomized');
validationDataStore = imdsValidation;
Create a dlquantizer object and specify the pruned network as the network to quantize.
prunedNet = assembleNetwork([prunedDLNet.Layers ; net.Layers(end-1:end)]);
quantObjPrunedNetwork = dlquantizer(prunedNet,'ExecutionEnvironment','GPU');
Use the calibrate function to exercise the network with the calibration data and collect range statistics for the weights, biases, and activations at each layer.
calResults = calibrate(quantObjPrunedNetwork, calibrationDataStore)
Use the validate function to compare the results of the network before and after quantization using the validation data set.
valResults = validate(quantObjPrunedNetwork, validationDataStore);
Examine the MetricResults.Result field of the validation output to see the accuracy of the quantized network.
valResults.MetricResults.Result
valResults.Statistics
Mini Batch Preprocessing Function
The preprocessMiniBatch function preprocesses a mini-batch of predictors by extracting the image data from the input cell array and concatenating it into a numeric array. For grayscale input, concatenating the data over the fourth dimension adds a third dimension to each image to use as a singleton channel dimension.
function X = preprocessMiniBatch(XCell)
% Extract image data from cell and concatenate.
X = cat(4,XCell{:});
end
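The effect of concatenating along the fourth dimension can be checked on a small cell array. The random images below are hypothetical stand-ins for a mini-batch of 28-by-28 grayscale digits.

```matlab
% Three hypothetical 28-by-28 grayscale images in a cell array.
XCell = {rand(28), rand(28), rand(28)};

% Same operation as in preprocessMiniBatch.
X = cat(4, XCell{:});

disp(size(X))   % 28 28 1 3
```

The result has a singleton third dimension, which matches the 'SSCB' (spatial, spatial, channel, batch) format specified for the minibatchqueue.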
Model Accuracy Function
Evaluate the classification accuracy of the dlnetwork. Accuracy is the percentage of labels correctly classified by the network.
function accuracy = evaluateAccuracy(dlnet,mbqValidation,classes,trueLabels)
YPred = modelPredictions(dlnet,mbqValidation,classes);
accuracy = mean(YPred == trueLabels);
end
SynFlow Score Function
The calculateSynFlowScore function calculates Synaptic Flow (SynFlow) scores. Synaptic saliency [2] is described as the class of gradient-based scores defined by the product of the gradient of the loss and the parameter value:

$$\mathrm{synFlow} = \frac{\partial \mathcal{L}}{\partial \theta} \odot \theta$$

The SynFlow score is a synaptic saliency score that uses the sum of all network outputs as a loss function:

$$\mathcal{L} = \sum f(|\theta|, X)$$

where:
$f$ is the function represented by the neural network
$\theta$ are the parameters of the network
$X$ is the input array to the network
To compute parameter gradients with respect to this loss function, use dlfeval and a model gradients function.
function score = calculateSynFlowScore(dlnet,dlX)
dlnet.Learnables = dlupdate(@abs, dlnet.Learnables);
gradients = dlfeval(@modelGradients,dlnet,dlX);
score = dlupdate(@(g,w)g.*w, gradients, dlnet.Learnables);
end
Model Gradients for SynFlow Score
function gradients = modelGradients(dlNet,inputArray)
% Evaluate the gradients on a given input to the dlnetwork
dlYPred = predict(dlNet,inputArray);
pseudoloss = sum(dlYPred,'all');
gradients = dlgradient(pseudoloss,dlNet.Learnables);
end
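For intuition about what these two functions compute, the sketch below evaluates the SynFlow score in closed form for a hypothetical two-layer linear network without biases, where the gradients can be written by hand instead of using dlfeval and dlgradient. The weights are made-up values, and the closed-form gradients hold only for this simplified linear case.

```matlab
% Hypothetical two-layer linear network y = W2*(W1*x), no biases.
W1 = [0.5 -1.0; 2.0 0.1; -0.3 0.4];   % 3-by-2
W2 = [1.0 -0.5 0.2; -0.7 0.3 0.9];    % 2-by-3
x  = ones(2,1);                        % all-ones input, as in this example

% Use absolute values of the parameters, as calculateSynFlowScore does.
A1 = abs(W1);
A2 = abs(W2);

% Pseudo-loss: sum of all outputs of the absolute-valued network.
R = sum(A2 * (A1 * x));

% Gradients of R with respect to A1 and A2 (closed form for a linear net).
gradA1 = (A2' * ones(2,1)) * x';       % 3-by-2
gradA2 = ones(2,1) * (A1 * x)';        % 2-by-3

% SynFlow score: gradient times parameter, element-wise.
score1 = gradA1 .* A1;
score2 = gradA2 .* A2;

% Compare the pseudo-loss with the total score of each layer.
disp([R, sum(score1(:)), sum(score2(:))])
```

In this linear case the scores in each layer sum to the same value as the pseudo-loss, which is the layer-wise conservation property that gives the method its name [2]. All scores are nonnegative because the parameters and the input are nonnegative.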
Magnitude Score Function
The calculateMagnitudeScore function returns the magnitude score, defined as the element-wise absolute value of the parameters.
function score = calculateMagnitudeScore(dlnet)
score = dlupdate(@abs, dlnet.Learnables);
end
Mask Generation Function
The calculateMask function returns a binary mask for the network parameters based on the given scores and the target sparsity.
function mask = calculateMask(scoresMagnitude,sparsity)
% Compute a binary mask based on the parameter-wise scores such that the
% mask contains a percentage of zeros as specified by sparsity.

% Flatten the cell array of scores into one long score vector
flattenedScores = cell2mat(cellfun(@(S)extractdata(gather(S(:))),scoresMagnitude.Value,'UniformOutput',false));

% Rank the scores and determine the threshold for removing connections for
% the given sparsity
flattenedScores = sort(flattenedScores);
k = round(sparsity*numel(flattenedScores));
if k==0
    thresh = 0;
else
    thresh = flattenedScores(k);
end

% Create a binary mask
mask = dlupdate( @(S)S>thresh, scoresMagnitude);
end
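One consequence of thresholding with a strict comparison is worth seeing on a small case: when several scores are tied at the threshold, all of them are masked, so the achieved sparsity can exceed the requested value. The sketch below reproduces the thresholding logic on a plain vector of hypothetical scores.

```matlab
scores   = [1 1 1 2 3];   % hypothetical scores with a three-way tie
sparsity = 0.2;           % request that 20% of the entries be pruned

sorted = sort(scores);
k      = round(sparsity * numel(sorted));   % k = 1
thresh = sorted(k);                         % thresh = 1
mask   = scores > thresh;                   % all tied entries are masked

achievedSparsity = nnz(~mask) / numel(mask);

fprintf('Requested %.1f, achieved %.1f\n', sparsity, achievedSparsity);
```

Here three entries are masked instead of one. This is likely why the pruning loops in this example recompute the sparsity from the mask rather than assuming it equals the requested value.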
Model Predictions Function
The modelPredictions function takes as input a dlnetwork object dlnet, a minibatchqueue of input data mbq, and the network classes, and computes the model predictions by iterating over all data in the minibatchqueue object. The function uses the onehotdecode function to find the predicted class with the highest score.
function predictions = modelPredictions(dlnet,mbq,classes)
predictions = [];
while hasdata(mbq)
    dlXTest = next(mbq);
    dlYPred = softmax(predict(dlnet,dlXTest));
    YPred = onehotdecode(dlYPred,classes,1)';
    predictions = [predictions; YPred];
end
reset(mbq)
end
Apply Pruning Function
The createPrunedNet function returns the pruned dlnetwork for the specified pruning algorithm and iteration.
function prunedNet = createPrunedNet(dlnet,selectedIteration,pruningMaskSynFlow,pruningMaskMagnitude,pruningMethod)
switch pruningMethod
    case "Magnitude"
        prunedNet = dlupdate(@(W,M)W.*M, dlnet, pruningMaskMagnitude{selectedIteration});
    case "SynFlow"
        prunedNet = dlupdate(@(W,M)W.*M, dlnet, pruningMaskSynFlow{selectedIteration});
end
end
Pruning Statistics Function
The pruningStatistics function extracts detailed layer-level pruning statistics such as the layer-level sparsity and the number of filters or neurons being pruned.
sparsityPerLayer - percentage of parameters pruned in each layer
prunedChannelsPerLayer - number of channels/neurons in each layer that can be removed as a result of pruning
numOutChannelsPerLayer - number of channels/neurons in each layer
function [sparsityPerLayer,prunedChannelsPerLayer,numOutChannelsPerLayer,layerNames] = pruningStatistics(dlnet)

layerNames = unique(dlnet.Learnables.Layer,'stable');
numLayers = numel(layerNames);
layerIDs = zeros(numLayers,1);
for idx = 1:numel(layerNames)
    layerIDs(idx) = find(layerNames(idx)=={dlnet.Layers.Name});
end

sparsityPerLayer = zeros(numLayers,1);
prunedChannelsPerLayer = zeros(numLayers,1);
numOutChannelsPerLayer = zeros(numLayers,1);

numParams = zeros(numLayers,1);
numPrunedParams = zeros(numLayers,1);
for idx = 1:numLayers
    layer = dlnet.Layers(layerIDs(idx));

    % Calculate the sparsity
    paramIDs = strcmp(dlnet.Learnables.Layer,layerNames(idx));
    paramValue = dlnet.Learnables.Value(paramIDs);
    for p = 1:numel(paramValue)
        numParams(idx) = numParams(idx) + numel(paramValue{p});
        numPrunedParams(idx) = numPrunedParams(idx) + nnz(extractdata(paramValue{p})==0);
    end

    % Calculate channel statistics
    sparsityPerLayer(idx) = numPrunedParams(idx)/numParams(idx);
    switch class(layer)
        case "nnet.cnn.layer.FullyConnectedLayer"
            numOutChannelsPerLayer(idx) = layer.OutputSize;
            prunedChannelsPerLayer(idx) = nnz(all(layer.Weights==0,2)&layer.Bias(:)==0);
        case "nnet.cnn.layer.Convolution2DLayer"
            numOutChannelsPerLayer(idx) = layer.NumFilters;
            prunedChannelsPerLayer(idx) = nnz(reshape(all(layer.Weights==0,[1,2,3]),[],1)&layer.Bias(:)==0);
        case "nnet.cnn.layer.GroupedConvolution2DLayer"
            numOutChannelsPerLayer(idx) = layer.NumGroups*layer.NumFiltersPerGroup;
            prunedChannelsPerLayer(idx) = nnz(reshape(all(layer.Weights==0,[1,2,3]),[],1)&layer.Bias(:)==0);
        otherwise
            error("Unknown layer: "+class(layer))
    end
end
end
Load Digits Data set Function
The loadDigitDataset function loads the Digits data set and splits the data into training and validation data.
function [imdsTrain, imdsValidation] = loadDigitDataset()
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain, imdsValidation] = splitEachLabel(imds,0.75,"randomized");
end
Train Digit Recognition Network Function
The trainDigitDataNetwork function trains a convolutional neural network to classify digits in grayscale images.
function net = trainDigitDataNetwork(imdsTrain,imdsValidation)
layers = [
    imageInputLayer([28 28 1],"Normalization","rescale-zero-one")
    convolution2dLayer(3,8,'Padding','same')
    reluLayer
    maxPooling2dLayer(2,'Stride',2)

    convolution2dLayer(3,16,'Padding','same')
    reluLayer
    maxPooling2dLayer(2,'Stride',2)

    convolution2dLayer(3,32,'Padding','same')
    reluLayer

    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

% Specify the training options
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',10, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsValidation, ...
    'ValidationFrequency',30, ...
    'Verbose',false, ...
    'Plots','none',"ExecutionEnvironment","auto");

% Train network
net = trainNetwork(imdsTrain,layers,options);
end
References
[1] Song Han, Jeff Pool, John Tran, and William J. Dally. 2015. "Learning Both Weights and Connections for Efficient Neural Networks." Advances in Neural Information Processing Systems 28 (NIPS 2015): 1135–1143.
[2] Hidenori Tanaka, Daniel Kunin, Daniel L. K. Yamins, and Surya Ganguli. 2020. "Pruning Neural Networks Without Any Data by Iteratively Conserving Synaptic Flow." 34th Conference on Neural Information Processing Systems (NeurIPS 2020).
See Also
Functions
dlarray | dlquantizer | predict | dlnetwork