Train Network in Parallel with Custom Training Loop
This example shows how to set up a custom training loop to train a network in parallel. In this example, parallel workers train on portions of the overall mini-batch. If you have a GPU, then training happens on the GPU. During training, a DataQueue object sends training progress information back to the MATLAB client.
Load Data Set
Load the digit data set and create an image datastore for the data set. Split the datastore into training and test datastores in a randomized way. Create an augmentedImageDatastore containing the training data and shuffle the data with the shuffle function.
digitDatasetPath = fullfile(toolboxdir("nnet"),"nndemos", ...
    "nndatasets","DigitDataset");
imds = imageDatastore(digitDatasetPath, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

[imdsTrain,imdsTest] = splitEachLabel(imds,0.9,"randomized");

inputSize = [28 28 1];
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
augimdsTrain = shuffle(augimdsTrain);
Determine the different classes in the training set.
classes = categories(imdsTrain.Labels);
numClasses = numel(classes);
Define Network
Define your network architecture. This network architecture includes batch normalization layers, which track the mean and variance statistics of the data set. When training in parallel, combine the statistics from all of the workers at the end of each iteration to ensure that the network state reflects the whole mini-batch. Otherwise, the network state can diverge across the workers. If you are training stateful recurrent neural networks (RNNs), for example, using sequence data that has been split into smaller sequences to train networks containing LSTM or GRU layers, you must also manage the state between the workers.
layers = [
imageInputLayer(inputSize,Normalization="none")
convolution2dLayer(5,20)
batchNormalizationLayer
reluLayer
convolution2dLayer(3,20,Padding=1)
batchNormalizationLayer
reluLayer
convolution2dLayer(3,20,Padding=1)
batchNormalizationLayer
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer];
Create a dlnetwork object from the layer array. dlnetwork objects allow for training with custom loops.
net = dlnetwork(layers)
net = 
  dlnetwork with properties:

         Layers: [12×1 nnet.cnn.layer.Layer]
    Connections: [11×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 1

  View summary with summary.
Set Up Parallel Environment
Determine if GPUs are available for MATLAB to use with the canUseGPU function.
If there are GPUs available, then train on the GPUs. Create a parallel pool with as many workers as GPUs.
If there are no GPUs available, then train on the CPUs. Create a parallel pool with the default number of workers.
if canUseGPU
    executionEnvironment = "gpu";
    numberOfGPUs = gpuDeviceCount("available");
    pool = parpool(numberOfGPUs);
else
    executionEnvironment = "cpu";
    pool = parpool;
end
Starting parallel pool (parpool) using the 'Processes' profile ...
Connected to parallel pool with 4 workers.
Get the number of workers in the parallel pool. Later in this example, you divide the workload according to this number.
numWorkers = pool.NumWorkers;
Train Model
Specify the training options.
numEpochs = 20;
miniBatchSize = 128;
velocity = [];
For GPU training, a recommended practice is to scale up the mini-batch size linearly with the number of GPUs, in order to keep the workload on each GPU constant. For more related advice, see Deep Learning with MATLAB on Multiple GPUs.
if executionEnvironment == "gpu"
    miniBatchSize = miniBatchSize .* numWorkers
end
miniBatchSize = 512
Calculate the mini-batch size for each worker by dividing the overall mini-batch size evenly among the workers. Distribute the remainder across the first workers.
workerMiniBatchSize = floor(miniBatchSize ./ repmat(numWorkers,1,numWorkers));
remainder = miniBatchSize - sum(workerMiniBatchSize);
workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,numWorkers-remainder)]
workerMiniBatchSize = 1×4
128 128 128 128
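Here, the mini-batch size of 512 divides evenly among the four workers. If it did not, the first workers would each receive one extra observation. For example, applying the same calculation to a hypothetical overall mini-batch size of 130 on four workers (values chosen for illustration only) distributes the remainder of 2 across the first two workers:

% Illustrative only: a hypothetical mini-batch size that does not divide evenly.
exampleMiniBatchSize = 130;
exampleNumWorkers = 4;
exampleWorkerMiniBatchSize = floor(exampleMiniBatchSize ./ repmat(exampleNumWorkers,1,exampleNumWorkers));
exampleRemainder = exampleMiniBatchSize - sum(exampleWorkerMiniBatchSize);
exampleWorkerMiniBatchSize = exampleWorkerMiniBatchSize + ...
    [ones(1,exampleRemainder) zeros(1,exampleNumWorkers-exampleRemainder)]
% exampleWorkerMiniBatchSize = [33 33 32 32]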
This network contains batch normalization layers that keep track of the mean and variance of the data the network is trained on. Since each worker processes a portion of each mini-batch during each iteration, the mean and variance must be aggregated across all the workers. Find the indices of the mean and variance state parameters of the batch normalization layers in the network state property.
batchNormLayers = arrayfun(@(l)isa(l,"nnet.cnn.layer.BatchNormalizationLayer"),net.Layers);
batchNormLayersNames = string({net.Layers(batchNormLayers).Name});
state = net.State;
isBatchNormalizationStateMean = ismember(state.Layer,batchNormLayersNames) & state.Parameter == "TrainedMean";
isBatchNormalizationStateVariance = ismember(state.Layer,batchNormLayersNames) & state.Parameter == "TrainedVariance";
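As a quick check (output not shown here), you can display the rows of the state table that these logical indices select and confirm that they correspond only to the TrainedMean and TrainedVariance parameters of the batch normalization layers:

% Quick check: display the layer and parameter names of the selected state rows.
state(isBatchNormalizationStateMean | isBatchNormalizationStateVariance,["Layer" "Parameter"])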
Initialize the TrainingProgressMonitor object. Because the timer starts when you create the monitor object, make sure that you create the object close to the training loop.
monitor = trainingProgressMonitor( ...
    Metrics="TrainingLoss", ...
    Info=["Epoch" "Workers"], ...
    XLabel="Iteration");
Create a DataQueue object on the workers to send a flag to stop training when the Stop button is clicked.
spmd
    stopTrainingEventQueue = parallel.pool.DataQueue;
end
stopTrainingQueue = stopTrainingEventQueue{1};
To send data back from the workers during training, create a DataQueue object. Use afterEach to set up a function, displayTrainingProgress, to call each time a worker sends data. displayTrainingProgress is a supporting function, defined at the end of this example, that updates the TrainingProgressMonitor object to show the training progress information that comes from the workers and sends a flag to the workers if the Stop button has been clicked.
dataQueue = parallel.pool.DataQueue;
displayFcn = @(x) displayTrainingProgress(x,numEpochs,numWorkers,monitor,stopTrainingQueue);
afterEach(dataQueue,displayFcn)
Train the model using a custom parallel training loop, as detailed in the following steps. To execute the code simultaneously on all the workers, use an spmd block. Within the spmd block, spmdIndex gives the index of the worker currently executing the code.
Before training, partition the datastore for each worker by using the partition function. Use the partitioned datastore to create a minibatchqueue on each worker. For each mini-batch:

Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) to normalize the data, convert the target classes to one-hot encoded variables, and determine the number of observations in the mini-batch.

Format the image data with the dimension labels "SSCB" (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not add a format to the target classes or the number of observations.

Train on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
For each epoch, shuffle the datastore with the shuffle function. For each iteration in the epoch:

Ensure that all workers have data available before beginning to process it in parallel by performing a global and operation using spmdReduce on the result of the hasdata function.

Read a mini-batch from the minibatchqueue by using the next function.

Compute the loss and the gradients of the network on each worker by calling dlfeval on the modelLoss function. The dlfeval function evaluates the helper function modelLoss with automatic differentiation enabled, so modelLoss can compute the gradients with respect to the loss in an automatic way. modelLoss is defined at the end of the example and returns the loss and gradients given a network, mini-batch of data, and targets.

To obtain the overall loss, aggregate the losses on all workers. This example uses cross entropy for the loss function, and the aggregated loss is the sum of all losses. Before aggregating, normalize each loss by multiplying by the proportion of the overall mini-batch that the worker is working on. Use spmdPlus to add all losses together and replicate the results across workers.

To aggregate and update the gradients of all workers, use the dlupdate function with the aggregateGradients function. aggregateGradients is a supporting function defined at the end of this example. This function uses spmdPlus to add together and replicate gradients across workers, following normalization according to the proportion of the overall mini-batch that each worker is working on.

Aggregate the state of the network on all workers using the aggregateState function. aggregateState is a supporting function defined at the end of this example. The batch normalization layers in the network track the mean and variance of the data. Because the complete mini-batch is spread across multiple workers, aggregate the network state after each iteration to compute the mean and variance of the whole mini-batch.

After computing the final gradients, update the network learnable parameters with the sgdmupdate function.
After each epoch, check whether the Stop button has been clicked and send training progress information back to the client using the send function with the DataQueue object. You only need to use one worker to send back data because all of the workers have the same loss information. To ensure that the data is on the CPU and that a client machine without a GPU can access it, use gather on the dlarray before sending it to the client. Because communication between the workers occurs after each epoch, clicking Stop stops training at the end of the current epoch. If you want the Stop button to stop training at the end of each iteration, you can check whether the Stop button has been clicked and send training progress information back to the client every iteration, at the cost of increased communication overhead.
spmd
    % Partition the datastore.
    workerImds = partition(augimdsTrain,numWorkers,spmdIndex);

    % Create minibatchqueue using partitioned datastore on each worker.
    workerMbq = minibatchqueue(workerImds,3,...
        MiniBatchSize=workerMiniBatchSize(spmdIndex),...
        MiniBatchFcn=@preprocessMiniBatch,...
        MiniBatchFormat=["SSCB" "" ""]);

    workerVelocity = velocity;

    epoch = 0;
    iteration = 0;
    stopRequest = false;

    while epoch < numEpochs && ~stopRequest
        epoch = epoch + 1;
        shuffle(workerMbq);

        % Loop over mini-batches.
        while spmdReduce(@and,hasdata(workerMbq)) && ~stopRequest
            iteration = iteration + 1;

            % Read a mini-batch of data.
            [workerX,workerT,workerNumObservations] = next(workerMbq);

            % Evaluate the model loss and gradients on the worker.
            [workerLoss,workerGradients,workerState] = dlfeval(@modelLoss,net,workerX,workerT);

            % Aggregate the losses on all workers.
            workerNormalizationFactor = workerMiniBatchSize(spmdIndex)./miniBatchSize;
            loss = spmdPlus(workerNormalizationFactor*extractdata(workerLoss));

            % Aggregate the network state on all workers.
            net.State = aggregateState(workerState,workerNormalizationFactor,...
                isBatchNormalizationStateMean,isBatchNormalizationStateVariance);

            % Aggregate the gradients on all workers.
            workerGradients.Value = dlupdate(@aggregateGradients,workerGradients.Value,{workerNormalizationFactor});

            % Update the network parameters using the SGDM optimizer.
            [net,workerVelocity] = sgdmupdate(net,workerGradients,workerVelocity);
        end

        % Stop training if the Stop button has been clicked.
        stopRequest = spmdPlus(stopTrainingEventQueue.QueueLength);

        % Send training progress information to the client.
        if spmdIndex == 1
            data = [epoch loss iteration];
            send(dataQueue,gather(data));
        end
    end
end
Test Model
After you train the network, you can test its accuracy.
Load the test images into memory by using readall on the test datastore, concatenate them, and normalize them.
XTest = readall(imdsTest);
XTest = cat(4,XTest{:});
XTest = single(XTest) ./ 255;
TTest = imdsTest.Labels;
After the training is complete, all workers have the same complete trained network. Retrieve any of them.
netFinal = net{1};
To classify images using a dlnetwork object, use the predict function on a dlarray.
YTest = predict(netFinal,dlarray(XTest,"SSCB"));
From the predicted scores, find the class with the highest score with the max function. Before you do that, extract the data from the dlarray with the extractdata function.
[~,idx] = max(extractdata(YTest),[],1);
YTest = classes(idx);
To obtain the classification accuracy of the model, compare the predictions on the test set against the true classes.
accuracy = mean(YTest==TTest)
accuracy = 0.9070
Mini-Batch Preprocessing Function
The preprocessMiniBatch function preprocesses a mini-batch of predictors and target classes using the following steps:

Determine the number of observations in the mini-batch.

Preprocess the images using the preprocessMiniBatchPredictors function.

Extract the target class data from the incoming cell array and concatenate it into a categorical array along the second dimension.

One-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.
function [X,Y,numObs] = preprocessMiniBatch(XCell,YCell)

numObs = numel(YCell);

% Preprocess predictors.
X = preprocessMiniBatchPredictors(XCell);

% Extract class data from cell and concatenate.
Y = cat(2,YCell{1:end});

% One-hot encode classes.
Y = onehotencode(Y,1);

end
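As a quick sanity check (not part of the training example), you can call this function directly on a small, hypothetical mini-batch of two grayscale images and confirm the output sizes. The variable names here are illustrative only.

% Illustrative only: two random 28-by-28 images with digit labels.
exampleXCell = {randi(255,[28 28]), randi(255,[28 28])};
exampleYCell = {categorical("3",string(0:9)), categorical("7",string(0:9))};
[exampleX,exampleY,exampleNumObs] = preprocessMiniBatch(exampleXCell,exampleYCell);
size(exampleX)   % 28 28 1 2 (spatial, spatial, channel, batch)
size(exampleY)   % 10 2 (classes-by-observations, matching the softmax output)
exampleNumObs    % 2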
Mini-Batch Predictors Preprocessing Function
The preprocessMiniBatchPredictors function preprocesses a mini-batch of predictors by extracting the image data from the input cell array and concatenating it into a numeric array. For grayscale input, concatenating over the fourth dimension adds a third dimension to each image, to use as a singleton channel dimension. The function then normalizes the data.
function X = preprocessMiniBatchPredictors(XCell)

% Concatenate.
X = cat(4,XCell{1:end});

% Normalize.
X = X ./ 255;

end
Model Loss Function
Define a function, modelLoss, to compute the gradients of the loss with respect to the learnable parameters of the network. This function computes the network outputs for a mini-batch X with forward and calculates the loss, given the targets T, using cross entropy. When you call this function with dlfeval, automatic differentiation is enabled, and dlgradient can compute the gradients of the loss with respect to the learnables automatically.
function [loss,gradients,state] = modelLoss(net,X,T)

[Y,state] = forward(net,X);

loss = crossentropy(Y,T);
gradients = dlgradient(loss,net.Learnables);

end
Display Training Progress Function
Define a function that displays training progress information that comes from the workers and checks whether the Stop button has been clicked. If the Stop button has been clicked, the function sends a flag to the workers to indicate that training should stop. The DataQueue object in this example calls this function every time a worker sends data.
function displayTrainingProgress(data,numEpochs,numWorkers,monitor,stopTrainingQueue)

epoch = data(1);
loss = data(2);
iteration = data(3);

recordMetrics(monitor,iteration,TrainingLoss=loss);
updateInfo(monitor,Epoch=epoch + " of " + numEpochs,Workers=numWorkers);
monitor.Progress = 100*epoch/numEpochs;

if monitor.Stop
    send(stopTrainingQueue,true);
end

end
Aggregate Gradients Function
Define a function that aggregates the gradients on all workers by adding them together. spmdPlus adds together and replicates all the gradients on the workers. Before adding them together, normalize them by multiplying them by a factor that represents the proportion of the overall mini-batch that the worker is working on. To retrieve the contents of a dlarray, use extractdata.
function gradients = aggregateGradients(gradients,factor)

gradients = extractdata(gradients);
gradients = spmdPlus(factor*gradients);

end
Aggregate State Function
Define a function that aggregates the network state on all workers. The network state contains the trained batch normalization statistics of the data set. Since each worker only sees a portion of the mini-batch, aggregate the network state so that the statistics are representative of the statistics across all the data. For each mini-batch, the combined mean is calculated as a weighted average of the mean across the workers for each iteration. The combined variance is calculated according to the following formula:

$$s_c^2 = \frac{1}{M}\sum_{j=1}^{N} m_j \left( s_j^2 + \left( \bar{x}_j - \bar{x}_c \right)^2 \right)$$

where $N$ is the total number of workers, $M$ is the total number of observations in a mini-batch, $m_j$ is the number of observations processed on the $j$th worker, $\bar{x}_j$ and $s_j^2$ are the mean and variance statistics calculated on that worker, and $\bar{x}_c$ is the combined mean across all workers.
function state = aggregateState(state,factor,...
    isBatchNormalizationStateMean,isBatchNormalizationStateVariance)

stateMeans = state.Value(isBatchNormalizationStateMean);
stateVariances = state.Value(isBatchNormalizationStateVariance);

for j = 1:numel(stateMeans)
    meanVal = stateMeans{j};
    varVal = stateVariances{j};

    % Calculate combined mean.
    combinedMean = spmdPlus(factor*meanVal);

    % Calculate combined variance terms to sum.
    varTerm = factor.*(varVal + (meanVal - combinedMean).^2);

    % Update state.
    stateMeans{j} = combinedMean;
    stateVariances{j} = spmdPlus(varTerm);
end

state.Value(isBatchNormalizationStateMean) = stateMeans;
state.Value(isBatchNormalizationStateVariance) = stateVariances;

end
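The following short check (illustrative only, and not part of the example) confirms that this weighted combination recovers the statistics of the pooled data, using hypothetical per-worker observation counts:

% Illustrative only: two hypothetical workers with unequal portions of a mini-batch.
x1 = randn(1,96);                     % observations processed on worker 1
x2 = randn(1,32);                     % observations processed on worker 2
M = numel(x1) + numel(x2);            % total observations in the mini-batch

m = [numel(x1) numel(x2)];            % per-worker observation counts
workerMeans = [mean(x1) mean(x2)];
workerVars = [var(x1,1) var(x2,1)];   % population variance, as used by batch normalization

combinedMean = sum(m/M .* workerMeans);
combinedVar = sum(m/M .* (workerVars + (workerMeans - combinedMean).^2));

% Compare against the statistics of the pooled data; the difference is close to zero.
pooled = [x1 x2];
max(abs([combinedMean combinedVar] - [mean(pooled) var(pooled,1)]))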
See Also
dlarray | dlnetwork | sgdmupdate | dlupdate | dlfeval | dlgradient | crossentropy | softmax | forward | predict