
Parameter Pruning and Quantization of Image Classification Network

This example shows how to prune the parameters of a trained neural network using two parameter score metrics: the magnitude score [1] and the Synaptic Flow (SynFlow) score [2].

In many applications where you use transfer learning to retrain an image classification network for a new task, or where you train a new network from scratch, the optimal network architecture is not known, and the network might be overparameterized. An overparameterized network has redundant connections. Pruning, also known as sparsification, is a compression technique that aims to identify redundant, unnecessary connections that you can remove without significantly affecting network accuracy. When you combine pruning with network quantization, you can reduce the inference time and memory footprint of the network, making it easier to deploy.

This example shows how to:

  • Perform post-training, iterative, unstructured pruning without the need for training data

  • Evaluate the performance of two different pruning algorithms

  • Investigate the layer-wise sparsity induced after pruning

  • Evaluate the impact of pruning on classification accuracy

  • Evaluate the impact of quantization on the classification accuracy of the pruned network

This example uses a simple convolutional neural network to classify handwritten digits from 0 to 9. For more information on setting up the data used for training and validation, see Create Simple Deep Learning Neural Network for Classification.

Load Pretrained Network and Data

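Load the training and validation data, and train a convolutional neural network for the classification task. The loadDigitDataset and trainDigitDataNetwork helper functions are defined at the end of this example.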

[imdsTrain, imdsValidation] = loadDigitDataset;
net = trainDigitDataNetwork(imdsTrain, imdsValidation);
trueLabels = imdsValidation.Labels;
classes = categories(trueLabels);

Create a minibatchqueue object containing the validation data. Set executionEnvironment to auto to evaluate the network on a GPU, if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).

executionEnvironment = "auto";
miniBatchSize = 128;
imdsValidation.ReadSize = miniBatchSize;
mbqValidation = minibatchqueue(imdsValidation,1,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFormat','SSCB',...
    'MiniBatchFcn',@preprocessMiniBatch,...
    'OutputEnvironment',executionEnvironment);

Neural Network Pruning

The goal of neural network pruning is to identify and remove unimportant connections to reduce the size of the network without significantly affecting network accuracy. In the following figure, the network on the left has connections that map each neuron to every neuron in the next layer. After pruning, the network on the right has fewer connections than the original network.

Figure: Network before pruning (left) and after pruning (right).

A pruning algorithm assigns a score to each parameter in the network. The score ranks the importance of each connection in the network. You can use one of two pruning approaches to achieve a target sparsity:

  • One-shot pruning - Remove a specified percentage of connections based on their scores in a single step. This method is prone to layer collapse, in which an entire layer is pruned and the network becomes disconnected, when you specify a high sparsity value.

  • Iterative pruning - Achieve the target sparsity in a series of steps. Use this method when the scores are sensitive to the network structure. Because the scores are reevaluated at every iteration, the network moves toward the target sparsity incrementally.
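The difference between the two approaches can be sketched in base MATLAB on a toy weight vector. The vector w, the four-step schedule, and the use of magnitude scores are assumptions for illustration only; the pruning loops later in this example operate on dlnetwork learnables instead.

```matlab
% Toy stand-in for the flattened network parameters (hypothetical values)
w = [0.3 -1.2 0.05 2.0 -0.7 0.9 -0.15 1.5 0.4 -0.02];
targetSparsity = 0.6;
n = numel(w);

% One-shot pruning: score once, then prune to the target sparsity in one step
scores = abs(w);                          % magnitude score
sortedScores = sort(scores);
k = round(targetSparsity*n);
maskOneShot = scores > sortedScores(k);   % keep scores above the k-th smallest

% Iterative pruning: rescore the currently pruned weights at every step
schedule = linspace(0,targetSparsity,4);  % 0, 0.2, 0.4, 0.6
maskIter = true(size(w));
for s = schedule
    scores = abs(w.*maskIter);            % pruned weights now score 0
    sortedScores = sort(scores);
    k = round(s*n);
    if k == 0
        thresh = 0;                       % same convention as calculateMask
    else
        thresh = sortedScores(k);
    end
    maskIter = scores > thresh;
end

disp(maskOneShot)
disp(maskIter)
```

Both masks keep the same four weights. With magnitude scores, the two approaches coincide because pruning a weight does not change the scores of the remaining weights. Iterative pruning makes a difference for scores such as SynFlow, which are recomputed on the pruned network.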

This example uses the iterative pruning method to achieve a target sparsity.

Iterative Pruning

Figure: Iterative pruning workflow.

Convert to dlnetwork Object

In this example, you use the Synaptic Flow algorithm, which requires that you create a custom cost function and compute its gradients with respect to the network parameters to calculate the parameter scores. To create a custom cost function, first convert the pretrained network to a dlnetwork object.

Convert the network to a layer graph and remove the layers used for classification using removeLayers.

lgraph = layerGraph(net.Layers);
lgraph = removeLayers(lgraph,["softmax","classoutput"]);
dlnet = dlnetwork(lgraph);

Use analyzeNetwork to analyze the network architecture and learnable parameters.

analyzeNetwork(dlnet)

Evaluate the accuracy of the network before pruning.

accuracyOriginalNet = evaluateAccuracy(dlnet,mbqValidation,classes,trueLabels)
accuracyOriginalNet = 0.9908

The layers with learnable parameters are the three convolutional layers and the fully connected layer. Each of these layers has a weight array and a bias vector, and in total the network initially has 21,578 learnable parameters.

numTotalParams = sum(cellfun(@numel,dlnet.Learnables.Value))
numTotalParams = 21578
numNonZeroPerParam = cellfun(@(w)nnz(extractdata(w)),dlnet.Learnables.Value)
numNonZeroPerParam = 8×1

          72
           8
        1152
          16
        4608
          32
       15680
          10
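As a check on these counts, the following sketch recomputes them from the layer sizes in trainDigitDataNetwork. It assumes 3-by-3 filters, 'same' padding, and two 2-by-2 max pooling layers with stride 2, so the fully connected layer receives a 7-by-7-by-32 input.

```matlab
% Learnable parameter counts for the network defined in trainDigitDataNetwork
filterSize = 3;
numChannels = [1 8 16 32];   % input channels, then filters per conv layer
convWeights = filterSize^2 * numChannels(1:end-1) .* numChannels(2:end);
convBias = numChannels(2:end);

% Two 2-by-2 max pooling layers with stride 2 reduce 28-by-28 to 7-by-7
spatialSize = 28/2/2;
numClasses = 10;
fcWeights = spatialSize^2 * numChannels(end) * numClasses;
fcBias = numClasses;

% Interleave weights and biases in the same order as dlnet.Learnables
perParam = [reshape([convWeights; convBias],1,[]), fcWeights, fcBias]
total = sum(perParam)   % total = 21578
```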

Sparsity is the fraction of network parameters that are equal to zero. Check the sparsity of the network.

initialSparsity = 1-(sum(numNonZeroPerParam)/numTotalParams)
initialSparsity = 0

Before pruning, the network has a sparsity of zero.

Create Iteration Scheme

To define an iterative pruning scheme, specify the target sparsity and number of iterations. For this example, use linearly spaced iterations to achieve the target sparsity.

numIterations = 10; 
targetSparsity = 0.90;
iterationScheme = linspace(0,targetSparsity,numIterations); 
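The linear scheme prunes 10% of the original parameters at every step, so the last steps remove a large share of the parameters that remain (the final step removes half of them). The following sketch prints the linear scheme next to a hypothetical geometric alternative, similar in spirit to the exponential schedule described in [2], in which the fraction of remaining parameters shrinks by a constant factor at each step. The names linearScheme and exponentialScheme are illustrative only.

```matlab
numIterations = 10;
targetSparsity = 0.90;

% Linear scheme used in this example: equal sparsity increments of 0.1
linearScheme = linspace(0,targetSparsity,numIterations);

% Hypothetical alternative: the density (fraction of remaining parameters)
% shrinks geometrically, so later iterations prune fewer parameters
density = (1-targetSparsity).^((0:numIterations-1)/(numIterations-1));
exponentialScheme = 1 - density;

disp([linearScheme; exponentialScheme])
```

Both schemes start at 0 (no pruning in the first iteration) and end at the target sparsity of 0.9.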

Pruning Loop

For each iteration, the custom pruning loop in this example performs the following steps:

  • Calculate the score for each connection.

  • Rank the scores for all connections in the network based on the selected pruning algorithm.

  • Determine the threshold for removing connections with the lowest scores.

  • Create the pruning mask using the threshold.

  • Apply the pruning mask to learnable parameters of the network.
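The following minimal sketch runs one iteration of these steps on plain numeric arrays, assuming magnitude scores and a hypothetical layer with a 2-by-2 weight matrix and a 2-element bias. As in the calculateMask function in this example, the scores of all parameter arrays are ranked together, so the threshold is global rather than per layer.

```matlab
% Hypothetical learnables: a weight matrix and a bias vector
params = {[0.5 -0.1; 0.02 -0.8], [0.3; -0.01]};
sparsity = 0.5;

% Step 1: compute a score for each connection (magnitude score here)
scores = cellfun(@abs, params, 'UniformOutput', false);

% Step 2: rank the scores of all parameter arrays together (global ranking)
flatScores = sort(cell2mat(cellfun(@(s) s(:), scores(:), 'UniformOutput', false)));

% Step 3: threshold that removes the requested fraction of connections
k = round(sparsity*numel(flatScores));
if k == 0
    thresh = 0;
else
    thresh = flatScores(k);
end

% Step 4: binary mask for each parameter array
masks = cellfun(@(s) s > thresh, scores, 'UniformOutput', false);

% Step 5: apply the masks; the array sizes do not change
prunedParams = cellfun(@(p,m) p.*m, params, masks, 'UniformOutput', false);
disp(prunedParams{1})
disp(prunedParams{2})
```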

Network Mask

Instead of setting entries in the weight arrays directly to zero, the pruning algorithm creates a binary mask for each learnable parameter that specifies whether a connection is pruned. The mask allows you to explore the behavior of the pruned network and try different pruning schemes without changing the underlying network structure.

For example, consider the following weights.

testWeight = [10.4 5.6 0.8 9];

Create a binary mask for each parameter in testWeight.

testMask = [1 0 1 0];

Apply the mask to testWeight to get the pruned weights.

testWeightsPruned = testWeight.*testMask
testWeightsPruned = 1×4

   10.4000         0    0.8000         0

In iterative pruning, you create a binary mask for each iteration that contains pruning information. Applying the mask to the weights array does not change either the size of the array or the structure of the neural network. Therefore, the pruning step does not directly result in any speedup during inference or compression of the network size on disk.

Initialize a plot that compares the accuracy of the pruned network to the original network.

figure
plot(100*iterationScheme([1,end]),100*accuracyOriginalNet*[1 1],'*-b','LineWidth',2,"Color","b")
ylim([0 100])
xlim(100*iterationScheme([1,end]))
xlabel("Sparsity (%)")
ylabel("Accuracy (%)")
legend("Original Accuracy","Location","southwest")
title("Pruning Accuracy")    
grid on

Magnitude Pruning

Magnitude pruning [1] assigns each parameter a score equal to its absolute value. This method assumes that the absolute value of a parameter corresponds to its relative importance to the accuracy of the trained network.

Initialize the mask. For the first iteration, you do not prune any parameters and the sparsity is 0%.

pruningMaskMagnitude = cell(1,numIterations); 
pruningMaskMagnitude{1} = dlupdate(@(p)true(size(p)), dlnet.Learnables);

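The following code implements magnitude pruning. Because the magnitude score of a parameter does not depend on which other parameters are pruned, the code computes the scores once, and in each iteration it prunes the original network to the next sparsity level in the scheme. Pruning to several sparsity levels lets you choose a pruned network based on its accuracy.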

lineAccuracyPruningMagnitude = animatedline('Color','g','Marker','o','LineWidth',1.5);
legend("Original Accuracy","Magnitude Pruning Accuracy","Location","southwest")

% Compute magnitude scores
scoresMagnitude = calculateMagnitudeScore(dlnet);

for idx = 1:numel(iterationScheme)

    prunedNetMagnitude = dlnet;
    
    % Update the pruning mask
    pruningMaskMagnitude{idx} = calculateMask(scoresMagnitude,iterationScheme(idx));
    
    % Check the number of zero entries in the pruning mask
    numPrunedParams = sum(cellfun(@(m)nnz(~extractdata(m)),pruningMaskMagnitude{idx}.Value));
    sparsity = numPrunedParams/numTotalParams;
    
    % Apply pruning mask to network parameters
    prunedNetMagnitude.Learnables = dlupdate(@(W,M)W.*M, prunedNetMagnitude.Learnables, pruningMaskMagnitude{idx});
    
    % Compute validation accuracy on pruned network
    accuracyMagnitude = evaluateAccuracy(prunedNetMagnitude,mbqValidation,classes,trueLabels);
    
    % Display the pruning progress
    addpoints(lineAccuracyPruningMagnitude,100*sparsity,100*accuracyMagnitude)
    drawnow
end

SynFlow Pruning

This section prunes the network using synaptic flow conservation (SynFlow) scores [2]. You can use this method to prune networks that use piecewise-linear activation functions, such as ReLU.

Initialize the mask. For the first iteration, no parameters are pruned, and the sparsity is 0%.

pruningMaskSynFlow = cell(1,numIterations); 
pruningMaskSynFlow{1} = dlupdate(@(p)true(size(p)),dlnet.Learnables);

To compute the scores, use a single input image in which every pixel is one, so the method does not require any training data. If you are using a GPU, convert the data to a gpuArray.

dlX = dlarray(ones(net.Layers(1).InputSize),'SSC');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlX = gpuArray(dlX);
end

The following loop implements iterative SynFlow pruning [2]. In each iteration, a custom cost function evaluates the SynFlow score of each parameter of the currently pruned network, so unlike the magnitude scores, the scores change as connections are removed.

lineAccuracyPruningSynflow = animatedline('Color','r','Marker','o','LineWidth',1.5);
legend("Original Accuracy","Magnitude Pruning Accuracy","Synaptic Flow Accuracy","Location","southwest")

prunedNetSynFlow = dlnet;

% Iteratively increase sparsity
for idx = 1:numel(iterationScheme)
    % Compute SynFlow scores
    scoresSynFlow = calculateSynFlowScore(prunedNetSynFlow,dlX);
    
    % Update the pruning mask
    pruningMaskSynFlow{idx} = calculateMask(scoresSynFlow,iterationScheme(idx));
    
    % Check the number of zero entries in the pruning mask
    numPrunedParams = sum(cellfun(@(m)nnz(~extractdata(m)),pruningMaskSynFlow{idx}.Value));
    sparsity = numPrunedParams/numTotalParams;
    
    % Apply pruning mask to network parameters
    prunedNetSynFlow.Learnables = dlupdate(@(W,M)W.*M, prunedNetSynFlow.Learnables, pruningMaskSynFlow{idx});
    
    % Compute validation accuracy on pruned network
    accuracySynFlow = evaluateAccuracy(prunedNetSynFlow,mbqValidation,classes,trueLabels);
     
    % Display the pruning progress
    addpoints(lineAccuracyPruningSynflow,100*sparsity,100*accuracySynFlow)
    drawnow
end

Investigate Structure of Pruned Network

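Choosing how much to prune a network is a trade-off between accuracy and sparsity. Use the sparsity versus accuracy plot to select the iteration with the desired sparsity level and acceptable accuracy. With the linear iteration scheme in this example, iteration 8 corresponds to a target sparsity of 70%.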

pruningMethod = "SynFlow";
selectedIteration = 8;

prunedDLNet = createPrunedNet(dlnet,selectedIteration,pruningMaskSynFlow,pruningMaskMagnitude,pruningMethod);

[sparsityPerLayer,prunedChannelsPerLayer,numOutChannelsPerLayer,layerNames] = pruningStatistics(prunedDLNet);

Earlier convolutional layers are typically pruned less. They contain fewer parameters, and they capture the low-level structure of the image, such as edges and corners, which is essential for interpreting the image.

Plot the sparsity per layer for the selected pruning method and iteration.

figure
bar(sparsityPerLayer*100)
title("Sparsity per layer")
xlabel("Layer")
ylabel("Sparsity (%)")
xticks(1:numel(sparsityPerLayer))
xticklabels(layerNames)
xtickangle(45)
set(gca,'TickLabelInterpreter','none')

The pruning algorithm prunes single connections when you specify a low target sparsity. When you specify a high target sparsity, the pruning algorithm can prune whole filters and neurons in convolutional or fully connected layers.

figure
bar([prunedChannelsPerLayer,numOutChannelsPerLayer-prunedChannelsPerLayer],"stacked")
xlabel("Layer")
ylabel("Number of filters")
title("Number of filters per layer")
xticks(1:(numel(layerNames)))
xticklabels(layerNames)
xtickangle(45)
legend("Pruned channels/neurons","Remaining channels/neurons","Location","southoutside")
set(gca,'TickLabelInterpreter','none')

Evaluate Network Accuracy

Compare the accuracy of the network before and after pruning.

YPredOriginal = modelPredictions(dlnet,mbqValidation,classes);
accOriginal = mean(YPredOriginal == trueLabels)
accOriginal = 0.9908
YPredPruned = modelPredictions(prunedDLNet,mbqValidation,classes);
accPruned = mean(YPredPruned == trueLabels)
accPruned = 0.9328

Create confusion matrix charts to compare the true class labels with the predicted class labels for the original and pruned networks.

figure
confusionchart(trueLabels,YPredOriginal);
title("Original Network")

The validation set of the digits data contains 250 images for each class, so if a network predicts the class of each image perfectly, all scores on the diagonal equal 250 and no values are outside of the diagonal.

figure
confusionchart(trueLabels,YPredPruned);
title("Pruned Network")

When pruning a network, compare the confusion charts of the original and pruned networks to check how the accuracy for each class changes at the selected sparsity level. If the diagonal values decrease roughly equally for all classes, pruning has not introduced a bias toward particular classes. If some classes lose much more accuracy than others, you might need to choose a pruned network from an earlier iteration by reducing the value of selectedIteration.
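One way to quantify the per-class change is to divide the diagonal of each confusion matrix by the number of images in each class. The following sketch uses hypothetical 3-class confusion matrices with 250 images per class; in this example, you could obtain the actual counts from, for instance, the NormalizedValues property of the chart that confusionchart returns.

```matlab
% Hypothetical confusion matrices (rows: true class, columns: predicted class)
Coriginal = [248 1 1; 2 247 1; 0 3 247];
Cpruned   = [245 3 2; 20 226 4; 1 5 244];

% Per-class accuracy: correct predictions divided by images per class
perClassOriginal = diag(Coriginal)./sum(Coriginal,2);
perClassPruned   = diag(Cpruned)./sum(Cpruned,2);

% Accuracy lost by each class after pruning
drop = perClassOriginal - perClassPruned
[~,worstClass] = max(drop)
```

Here the second class loses 8.4 percentage points while the others lose 1.2, which is the kind of unequal decrease that suggests choosing an earlier iteration.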

Quantize Pruned Network

Deep neural networks trained in MATLAB use single-precision floating-point data types. Even small networks require a considerable amount of memory and hardware capable of floating-point arithmetic. These requirements can inhibit the deployment of deep learning models to devices that have low computational power and limited memory. By using a lower precision to store the weights and activations, you can reduce the memory requirements of the network. You can use Deep Learning Toolbox in tandem with the Deep Learning Model Quantization Library support package to reduce the memory footprint of a deep neural network by quantizing the weights, biases, and activations of the convolution layers to 8-bit scaled integer data types.
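Scaled 8-bit integer quantization can be illustrated in base MATLAB. The following sketch assumes symmetric per-tensor scaling, in which the largest weight magnitude maps to 127; this is not necessarily the scaling scheme that dlquantizer uses. The weight values are hypothetical, and the zeros stand for pruned connections.

```matlab
% Hypothetical pruned weights (zeros are pruned connections)
w = single([0.52 0 -0.13 0 0.97 -0.41 0 0.08]);

% Symmetric per-tensor scaling: the largest magnitude maps to 127
scale = max(abs(w))/127;
wInt8 = int8(round(w/scale));      % stored 8-bit integers
wDequant = single(wInt8)*scale;    % approximate values seen at inference

disp(wInt8)
maxError = max(abs(w - wDequant))  % at most about scale/2

% Storage: 4 bytes per single-precision value vs 1 byte per int8 value
bytesSingle = numel(w)*4;
bytesInt8 = numel(w)*1;
```

Because zero maps exactly to the integer 0 under symmetric scaling, pruned connections remain pruned after quantization, while each nonzero weight is stored in one byte instead of four.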

Pruning a network impacts the range statistics of parameters and activations at each layer, so the accuracy of the quantized network can change. To explore this difference, quantize the pruned network and use the quantized network to perform inference.

Use 10% of the training images from each class, selected at random, as calibration data, and use the validation images as validation data.

calibrationDataStore = splitEachLabel(imdsTrain,0.1,'randomized');
validationDataStore = imdsValidation;

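Reassemble the pruned network with the softmax and classification output layers of the original network by using assembleNetwork. Then create a dlquantizer object and specify the pruned network as the network to quantize. This example targets the GPU execution environment, which requires a supported GPU device.

prunedNet = assembleNetwork([prunedDLNet.Layers; net.Layers(end-1:end)]);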

quantObjPrunedNetwork = dlquantizer(prunedNet,'ExecutionEnvironment','GPU'); 

Use the calibrate function to exercise the network with the calibration data and collect range statistics for the weights, biases, and activations at each layer.

calResults = calibrate(quantObjPrunedNetwork, calibrationDataStore)

Use the validate function to compare the results of the network before and after quantization using the validation data set.

valResults = validate(quantObjPrunedNetwork, validationDataStore);

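Examine the MetricResults.Result field of the validation output to see the accuracy of the quantized network, and the Statistics field to compare the memory used by the learnable parameters before and after quantization.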

valResults.MetricResults.Result
valResults.Statistics

Mini Batch Preprocessing Function

The preprocessMiniBatch function preprocesses a mini-batch of predictors by extracting the image data from the input cell array and concatenating it into a numeric array. For grayscale input, concatenating the data over the fourth dimension adds a third dimension to each image to use as a singleton channel dimension.

function X = preprocessMiniBatch(XCell)
% Extract image data from cell and concatenate.
X = cat(4,XCell{:});
end
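The following sketch shows the effect of this concatenation on three hypothetical 28-by-28 grayscale images: concatenating 2-D arrays along the fourth dimension yields an array of size 28-by-28-by-1-by-3, which matches the 'SSCB' (spatial, spatial, channel, batch) format used by the minibatchqueue.

```matlab
% Three hypothetical grayscale images stored as 2-D arrays in a cell array
XCell = {rand(28,28), rand(28,28), rand(28,28)};

% Concatenate along dimension 4; dimension 3 becomes a singleton channel
X = cat(4, XCell{:});
size(X)
```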

Model Accuracy Function

Evaluate the classification accuracy of the dlnetwork. Accuracy is the fraction of labels that the network classifies correctly.

function accuracy = evaluateAccuracy(dlnet,mbqValidation,classes,trueLabels)
YPred = modelPredictions(dlnet,mbqValidation,classes);
accuracy = mean(YPred == trueLabels);
end

SynFlow Score Function

The calculateSynFlowScore function calculates Synaptic Flow (SynFlow) scores. Synaptic saliency [2] is a class of gradient-based scores defined as the gradient of a loss with respect to a parameter multiplied by the parameter value:

$$\text{score}(\theta) = \frac{\partial\,\text{loss}}{\partial \theta} \odot \theta$$

The SynFlow score is a synaptic saliency score that uses the sum of all network outputs as a loss function:

$$\text{loss} = \mathbf{1}^\top f(|\theta|, X)$$

where f is the function represented by the neural network, θ are the parameters of the network, |θ| denotes their element-wise absolute values, and X is the input array to the network.

To compute the gradients of this loss with respect to the parameters, use dlfeval and a model gradients function.

function score = calculateSynFlowScore(dlnet,dlX)
dlnet.Learnables = dlupdate(@abs, dlnet.Learnables);
gradients = dlfeval(@modelGradients,dlnet,dlX);
score = dlupdate(@(g,w)g.*w, gradients, dlnet.Learnables);
end

Model Gradients for SynFlow Score

function gradients = modelGradients(dlNet,inputArray)
% Evaluate the gradients on a given input to the dlnetwork
dlYPred = predict(dlNet,inputArray);
pseudoloss = sum(dlYPred,'all');
gradients = dlgradient(pseudoloss,dlNet.Learnables);
end
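To see what calculateSynFlowScore and modelGradients compute, the following base-MATLAB sketch evaluates SynFlow scores by hand for a hypothetical two-layer fully connected ReLU network without biases and with an all-ones input. The weights W1 and W2 are made-up values, and the gradients are derived analytically instead of with dlgradient.

```matlab
% Hypothetical two-layer network without biases: y = W2*relu(W1*x)
W1 = [1 -2; 0.5 3; -1 0.2];   % 3 hidden units, 2 inputs
W2 = [2 -1 0.5];              % 1 output
x = ones(2,1);                % all-ones input, as in this example

% SynFlow evaluates the network on |theta|. With nonnegative weights and
% inputs, every ReLU is active, so the network acts as a linear map
A1 = abs(W1);
A2 = abs(W2);
h = A1*x;                     % hidden activations (all nonnegative)
R = sum(A2*h);                % loss: sum of all network outputs

% Analytic gradients of R with respect to |W1| and |W2|
dRdA2 = ones(size(A2,1),1)*h';          % activation feeding each weight
dRdA1 = (A2'*ones(size(A2,1),1))*x';    % downstream flow times input

% SynFlow score: gradient times parameter magnitude
score1 = dRdA1.*A1
score2 = dRdA2.*A2
```

In this sketch, the scores of each layer sum to the loss R (10.1 here), which illustrates the layer-wise conservation of synaptic flow described in [2].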

Magnitude Score Function

The calculateMagnitudeScore function returns the magnitude score, defined as the element-wise absolute value of the parameters.

function score = calculateMagnitudeScore(dlnet)
score = dlupdate(@abs, dlnet.Learnables);
end

Mask Generation Function

The calculateMask function returns a binary mask for the network parameters based on the given scores and the target sparsity.

function mask = calculateMask(scores,sparsity)
% Compute a binary mask from the parameter-wise scores such that the
% fraction of zeros in the mask approximately equals the specified sparsity.

% Flatten the cell array of scores into one long score vector
flattenedScores = cell2mat(cellfun(@(S)extractdata(gather(S(:))),scores.Value,'UniformOutput',false));
% Rank the scores and determine the threshold for removing connections
% for the given sparsity
flattenedScores = sort(flattenedScores);
k = round(sparsity*numel(flattenedScores));
if k==0
    thresh = 0;
else
    thresh = flattenedScores(k);
end
% Create a binary mask that keeps only connections scoring above the threshold
mask = dlupdate(@(S)S>thresh, scores);
end

Model Predictions Function

The modelPredictions function takes as input a dlnetwork object dlnet, a minibatchqueue of input data mbq, and the network classes, and computes the model predictions by iterating over all data in the minibatchqueue object. The function uses the onehotdecode function to find the predicted class with the highest score.

function predictions = modelPredictions(dlnet,mbq,classes)
predictions = [];
while hasdata(mbq)
    dlXTest = next(mbq);
    dlYPred = softmax(predict(dlnet,dlXTest));
    YPred = onehotdecode(dlYPred,classes,1)';
    predictions = [predictions; YPred];
end
reset(mbq)
end
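onehotdecode returns, for each observation, the class with the highest score. The following base-MATLAB sketch reproduces that decision with max on hypothetical scores for three classes, and shows that applying softmax first does not change the predicted class, because softmax preserves the ordering of the scores. It is an illustration, not a replacement for onehotdecode.

```matlab
% Hypothetical network outputs: 3 classes (rows) by 4 observations (columns)
scores = [2.0 0.1 -1.0 0.3; 0.5 1.7 0.2 0.3; -0.3 0.4 2.2 0.9];
classes = {'0','1','2'};

% Numerically stable softmax over the class dimension
expScores = exp(scores - max(scores,[],1));
probs = expScores./sum(expScores,1);

% Predicted class is the row with the highest value, with or without softmax
[~,idxRaw] = max(scores,[],1);
[~,idxProb] = max(probs,[],1);
predictions = classes(idxProb)'
```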

Apply Pruning Function

The createPrunedNet function returns the pruned dlnetwork for the specified pruning algorithm and iteration.

function prunedNet = createPrunedNet(dlnet,selectedIteration,pruningMaskSynFlow,pruningMaskMagnitude,pruningMethod)
switch pruningMethod
    case "Magnitude"
        prunedNet = dlupdate(@(W,M)W.*M, dlnet, pruningMaskMagnitude{selectedIteration});
    case "SynFlow"
        prunedNet = dlupdate(@(W,M)W.*M, dlnet, pruningMaskSynFlow{selectedIteration});
    otherwise
        error("Unknown pruning method: " + pruningMethod)
end
end

Pruning Statistics Function

The pruningStatistics function extracts detailed layer-level pruning statistics, such as the layer-level sparsity and the number of filters or neurons being pruned. The function returns:

  • sparsityPerLayer - fraction of parameters pruned in each layer

  • prunedChannelsPerLayer - number of channels/neurons in each layer that can be removed as a result of pruning

  • numOutChannelsPerLayer - number of channels/neurons in each layer

  • layerNames - names of the layers that have learnable parameters

function [sparsityPerLayer,prunedChannelsPerLayer,numOutChannelsPerLayer,layerNames] = pruningStatistics(dlnet)

layerNames = unique(dlnet.Learnables.Layer,'stable');
numLayers = numel(layerNames);
layerIDs = zeros(numLayers,1);
for idx = 1:numel(layerNames)
    layerIDs(idx) = find(layerNames(idx)=={dlnet.Layers.Name});
end

sparsityPerLayer = zeros(numLayers,1);
prunedChannelsPerLayer = zeros(numLayers,1);
numOutChannelsPerLayer = zeros(numLayers,1);

numParams = zeros(numLayers,1);
numPrunedParams = zeros(numLayers,1);
for idx = 1:numLayers
    layer = dlnet.Layers(layerIDs(idx));
    
    % Calculate the sparsity
    paramIDs = strcmp(dlnet.Learnables.Layer,layerNames(idx));
    paramValue = dlnet.Learnables.Value(paramIDs);
    for p = 1:numel(paramValue)
        numParams(idx) = numParams(idx) + numel(paramValue{p});
        numPrunedParams(idx) = numPrunedParams(idx) + nnz(extractdata(paramValue{p})==0);
    end
    sparsityPerLayer(idx) = numPrunedParams(idx)/numParams(idx);

    % Calculate channel statistics
    switch class(layer)
        case "nnet.cnn.layer.FullyConnectedLayer"
            numOutChannelsPerLayer(idx) = layer.OutputSize;
            prunedChannelsPerLayer(idx) = nnz(all(layer.Weights==0,2)&layer.Bias(:)==0);
        case "nnet.cnn.layer.Convolution2DLayer"
            numOutChannelsPerLayer(idx) = layer.NumFilters;
            prunedChannelsPerLayer(idx) = nnz(reshape(all(layer.Weights==0,[1,2,3]),[],1)&layer.Bias(:)==0);
        case "nnet.cnn.layer.GroupedConvolution2DLayer"
            numOutChannelsPerLayer(idx) = layer.NumGroups*layer.NumFiltersPerGroup;
            prunedChannelsPerLayer(idx) = nnz(reshape(all(layer.Weights==0,[1,2,3]),[],1)&layer.Bias(:)==0);
        otherwise
            error("Unknown layer: "+class(layer))
    end
end
end
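The channel statistics rely on a simple rule: a filter can be removed only if all of its weights and its bias are zero. The following sketch applies that rule to a hypothetical 3-by-3 convolution with 2 input channels and 4 filters. It flattens each filter with reshape, which is equivalent to all(W==0,[1,2,3]) in the function above.

```matlab
% Hypothetical 3-by-3 conv weights: 2 input channels, 4 filters
W = ones(3,3,2,4);
W(:,:,:,2) = 0;      % filter 2 fully pruned
W(:,:,1,4) = 0;      % filter 4 only partially pruned
b = [0.1; 0; 0.2; 0];

% A filter is removable only if all its weights and its bias are zero
numFilters = size(W,4);
weightsZero = all(reshape(W,[],numFilters) == 0, 1)';
removable = weightsZero & (b == 0);
prunedChannels = nnz(removable)

% Layer sparsity counts every zero weight and bias
sparsity = nnz([W(:); b] == 0)/numel([W(:); b])
```

Filter 2 can be removed entirely, while filter 4 is only partially pruned, so it contributes to the layer sparsity but not to the count of removable filters.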

Load Digits Data Set Function

The loadDigitDataset function loads the Digits data set and splits the data into training and validation data.

function [imdsTrain, imdsValidation] = loadDigitDataset()
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain, imdsValidation] = splitEachLabel(imds,0.75,"randomized");
end

Train Digit Recognition Network Function

The trainDigitDataNetwork function trains a convolutional neural network to classify digits in grayscale images.

function net = trainDigitDataNetwork(imdsTrain,imdsValidation)
layers = [
    imageInputLayer([28 28 1],"Normalization","rescale-zero-one")
    convolution2dLayer(3,8,'Padding','same')
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    reluLayer
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

% Specify the training options
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',10, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsValidation, ...
    'ValidationFrequency',30, ...
    'Verbose',false, ...
    'Plots','none',"ExecutionEnvironment","auto");

% Train network
net = trainNetwork(imdsTrain,layers,options);
end

References

[1] Song Han, Jeff Pool, John Tran, and William J. Dally. 2015. "Learning Both Weights and Connections for Efficient Neural Networks." Advances in Neural Information Processing Systems 28 (NIPS 2015): 1135–1143.

[2] Hidenori Tanaka, Daniel Kunin, Daniel L. K. Yamins, and Surya Ganguli. 2020. "Pruning Neural Networks Without Any Data by Iteratively Conserving Synaptic Flow." 34th Conference on Neural Information Processing Systems (NeurIPS 2020).
