Get Started with Transfer Learning
This example shows how to use transfer learning to retrain SqueezeNet, a pretrained convolutional neural network, to classify a new set of images. Try this example to see how simple it is to get started with deep learning in MATLAB®.
Transfer learning is commonly used in deep learning applications. You can take a pretrained network and use it as a starting point to learn a new task. Fine-tuning a network with transfer learning is usually much faster and easier than training a network with randomly initialized weights from scratch. You can quickly transfer learned features to a new task using a smaller number of training images.
Extract Data
In the workspace, extract the MathWorks Merch data set. This is a small data set containing 75 images of MathWorks merchandise, belonging to five different classes (cap, cube, playing cards, screwdriver, and torch).
unzip("MerchData.zip");
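To confirm what was extracted, you can inspect the folder programmatically. The following is a minimal sketch, assuming the archive unpacks into a folder named MerchData with one subfolder per class; the variable names are illustrative only.

```matlab
% Sketch: inspect the extracted data set.
% Assumption: MerchData.zip unpacks into a folder named "MerchData"
% that contains one subfolder per class.
if ~isfolder("MerchData")
    unzip("MerchData.zip");
end

% Label each image with the name of the folder that contains it.
imds = imageDatastore("MerchData", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

% Summarize the number of images in each class.
labelCounts = countEachLabel(imds);
disp(labelCounts)

numImages = numel(imds.Files);
numClasses = numel(categories(imds.Labels));
fprintf("%d images in %d classes\n", numImages, numClasses);
```

If the counts do not match the 75 images and five classes described above, check that the archive was extracted into the current folder.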
Load Pretrained Network
Open Deep Network Designer.
deepNetworkDesigner
Select SqueezeNet from the list of pretrained networks and click Open.
Deep Network Designer displays a zoomed-out view of the whole network.
Explore the network plot. To zoom in with the mouse, use Ctrl+scroll wheel. To pan, use the arrow keys, or hold down the scroll wheel and drag the mouse. Select a layer to view its properties. Deselect all layers to view the network summary in the Properties pane.
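You can make the same inspection at the command line. This sketch assumes the squeezenet function is available in your installation (depending on the release, it might require a support package or be superseded by a newer function).

```matlab
% Sketch: load SqueezeNet and inspect it without the app.
% Assumption: the squeezenet function is available in this installation.
net = squeezenet;

% The first layer is the image input layer; its InputSize property
% gives the image size that the network expects.
inputSize = net.Layers(1).InputSize;
disp(inputSize)

% Show the last few layers, which include the layers that this
% example edits for transfer learning.
disp(net.Layers(end-4:end))

% Optionally, open an interactive summary of the whole network.
% analyzeNetwork(net)
```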
Import Data
To load the data into Deep Network Designer, on the Data tab, click Import Data > Import Image Classification Data.
In the Data source list, select Folder. Click Browse and select the extracted MerchData folder.
Divide the data into 70% training data and 30% validation data.
Specify augmentation operations to perform on the training images. Data augmentation helps prevent the network from overfitting and memorizing the exact details of the training images. For this example, apply a random reflection in the x-axis, a random rotation from the range [-90,90] degrees, and a random rescaling from the range [1,2].
Click Import to import the data into Deep Network Designer.
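The import settings above correspond roughly to the following command-line steps. This is a sketch rather than the exact code that the app generates; the variable names are hypothetical, and the 227-by-227 output size is an assumption based on the SqueezeNet input size used later in this example.

```matlab
% Sketch: command-line equivalent of the import settings.
imds = imageDatastore("MerchData", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

% Hold out 30% of the images in each class for validation.
[imdsTrain, imdsValidation] = splitEachLabel(imds, 0.7, "randomized");

% Augmentation applied to the training images only.
augmenter = imageDataAugmenter( ...
    RandXReflection=true, ...   % random reflection in the x-axis
    RandRotation=[-90 90], ...  % random rotation, in degrees
    RandScale=[1 2]);           % random rescaling factor

% Resize images to the network input size (assumed to be 227-by-227).
outputSize = [227 227];
augimdsTrain = augmentedImageDatastore(outputSize, imdsTrain, ...
    DataAugmentation=augmenter);
augimdsValidation = augmentedImageDatastore(outputSize, imdsValidation);

fprintf("%d training images, %d validation images\n", ...
    numel(imdsTrain.Files), numel(imdsValidation.Files));
```

The validation datastore only resizes the images, so validation metrics reflect unaugmented data.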
Edit Network for Transfer Learning
To retrain SqueezeNet to classify new images, edit the last 2-D convolutional layer and the final classification layer of the network. In SqueezeNet, these layers have the names 'conv10' and 'ClassificationLayer_predictions', respectively.
On the Designer pane, select the 'conv10' layer. At the bottom of the Properties pane, click Unlock Layer. In the warning dialog that appears, click Unlock Anyway. This unlocks the layer properties so that you can adapt them to your new task.
Before R2023b: To edit the layer properties, you must replace the layers instead of unlocking them. In the new convolutional 2-D layer, set the FilterSize to [1 1].
Set the NumFilters property to the new number of classes, in this example, 5.
Change the learning rates so that learning is faster in the new layer than in the transferred layers by setting WeightLearnRateFactor and BiasLearnRateFactor to 10.
Configure the output layer. Select the classification layer, ClassificationLayer_predictions, click Unlock Layer, and then click Unlock Anyway. For the unlocked output layer, you do not need to set the OutputSize. At training time, Deep Network Designer automatically sets the output classes of the layer from the data.
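For comparison, here is a command-line sketch of the same edits. It replaces the two layers rather than unlocking them, which is the approach described for releases before R2023b. The new layer names ("new_conv" and "new_classoutput") are hypothetical.

```matlab
% Sketch: replace the last learnable layer and the output layer.
% Assumption: the squeezenet function is available in this installation.
net = squeezenet;
lgraph = layerGraph(net);

numClasses = 5;

% New 1-by-1 convolution with one filter per class. The learn rate
% factors make this layer learn 10 times faster than the transferred
% layers.
newConv = convolution2dLayer([1 1], numClasses, ...
    Name="new_conv", ...
    WeightLearnRateFactor=10, ...
    BiasLearnRateFactor=10);
lgraph = replaceLayer(lgraph, "conv10", newConv);

% New classification layer. Its classes are set from the training
% data at training time.
newOutput = classificationLayer(Name="new_classoutput");
lgraph = replaceLayer(lgraph, "ClassificationLayer_predictions", newOutput);

% List the layer names to confirm the replacements.
layerNames = string({lgraph.Layers.Name});
disp(layerNames(end-4:end))
```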
Train Network
To choose the training options, select the Training tab and click Training Options. Set the initial learn rate to a small value to slow down learning in the transferred layers. In the previous step, you increased the learning rate factors for the 2-D convolutional layer to speed up learning in the new final layers. This combination of learning rate settings results in fast learning only in the new layers and slower learning in the other layers.
For this example, set InitialLearnRate to 0.0001, MaxEpochs to 8, and ValidationFrequency to 5. Because there are 55 training observations, set MiniBatchSize to 11 to divide the training data evenly and ensure the whole training set is used during each epoch.
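The arithmetic behind these values can be checked with a short calculation. This sketch assumes the 75 images are spread evenly over the five classes and that the 70% split is rounded per class, which is consistent with the 55 training observations stated above.

```matlab
% Sketch: check that the mini-batch size divides the training set evenly.
% Assumptions: equal class sizes, and a per-class split that is rounded.
numImages = 75;
numClasses = 5;
imagesPerClass = numImages / numClasses;            % 15 images per class

trainPerClass = round(70 * imagesPerClass / 100);   % 10.5 rounds to 11
numTrain = numClasses * trainPerClass;              % 55 training images

miniBatchSize = 11;
iterationsPerEpoch = floor(numTrain / miniBatchSize);   % 5
unusedPerEpoch = mod(numTrain, miniBatchSize);          % 0, nothing discarded

maxEpochs = 8;
totalIterations = iterationsPerEpoch * maxEpochs;       % 40

validationFrequency = 5;
numValidations = floor(totalIterations / validationFrequency);  % 8

fprintf("%d iterations per epoch, %d iterations in total\n", ...
    iterationsPerEpoch, totalIterations);
```

With these settings, validating every 5 iterations corresponds to validating once per epoch.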
To train the network with the specified training options, click OK and then click Train.
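At the command line, comparable options could be specified as in this sketch. The "sgdm" solver is an assumption, because this example does not name a solver. The commented-out training call uses hypothetical variable names for the datastores and layer graph.

```matlab
% Sketch: training options matching the values in this example.
% Assumption: the solver is "sgdm"; the example does not state it.
options = trainingOptions("sgdm", ...
    InitialLearnRate=1e-4, ...
    MaxEpochs=8, ...
    MiniBatchSize=11, ...
    ValidationFrequency=5, ...
    Plots="training-progress");

% Layers with a learn rate factor of 10 learn at 10 times the
% initial learn rate.
newLayerFactor = 10;
transferredRate = options.InitialLearnRate;
newLayerRate = options.InitialLearnRate * newLayerFactor;
fprintf("Transferred layers: %g, new layer: %g\n", ...
    transferredRate, newLayerRate);

% To train, also specify validation data and call trainNetwork.
% The variable names here are hypothetical:
%   options.ValidationData = augimdsValidation;
%   net = trainNetwork(augimdsTrain, lgraph, options);
```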
Deep Network Designer allows you to visualize and monitor the training progress. You can then edit the training options and retrain the network, if required.
Export Results and Generate MATLAB Code
To export the results from training, on the Training tab, select Export > Export Trained Network and Results. Deep Network Designer exports the trained network as the variable trainedNetwork_1 and the training info as the variable trainInfoStruct_1.
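After exporting, you can examine the training information in the workspace. The field names in this sketch (TrainingAccuracy and ValidationAccuracy) are assumptions based on the information structure that trainNetwork returns, so the code checks that each field exists before using it.

```matlab
% Sketch: plot the exported training information.
% Assumptions: trainInfoStruct_1 is in the workspace, and its field
% names follow the trainNetwork information structure.
info = trainInfoStruct_1;
disp(fieldnames(info))

figure
hold on
if isfield(info, "TrainingAccuracy")
    plot(info.TrainingAccuracy, DisplayName="Training")
end
if isfield(info, "ValidationAccuracy")
    % Validation values exist only at validation iterations, so plot
    % the iterations where a value was recorded.
    valIdx = find(~isnan(info.ValidationAccuracy));
    plot(valIdx, info.ValidationAccuracy(valIdx), "o-", ...
        DisplayName="Validation")
end
hold off
xlabel("Iteration")
ylabel("Accuracy (%)")
legend(Location="best")
```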
You can also generate MATLAB code, which recreates the network and the training options used. On the Training tab, select Export > Generate Code for Training. Examine the MATLAB code to learn how to programmatically prepare the data for training, create the network architecture, and train the network.
Classify New Image
Load a new image to classify using the trained network.
I = imread("MerchDataTest.jpg");
Resize the test image to match the network input size.
I = imresize(I, [227 227]);
Classify the test image using the trained network.
[YPred,probs] = classify(trainedNetwork_1,I);
imshow(I)
label = YPred;
title(string(label) + ", " + num2str(100*max(probs),3) + "%");
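To see how the network scores every class, not only the top prediction, you can tabulate the probabilities. This sketch assumes that trainedNetwork_1 is in the workspace and that its last layer is the classification output layer. It also reads the input size from the network instead of hard-coding it.

```matlab
% Sketch: list the predicted probability for every class.
% Assumptions: trainedNetwork_1 is in the workspace, and its last
% layer is the classification output layer.
inputSize = trainedNetwork_1.Layers(1).InputSize;

I = imread("MerchDataTest.jpg");
I = imresize(I, inputSize(1:2));

[YPred, probs] = classify(trainedNetwork_1, I);

% The class names are in the same order as the columns of probs.
classNames = trainedNetwork_1.Layers(end).Classes;

% Sort from most to least likely and show the results as a table.
[sortedProbs, order] = sort(probs, "descend");
results = table(classNames(order), 100 * sortedProbs(:), ...
    'VariableNames', {'Class', 'ProbabilityPercent'});
disp(results)
```

A large gap between the first and second rows suggests a confident prediction. Similar values suggest that the network cannot clearly separate those classes for this image.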
See Also
trainNetwork | trainingOptions | squeezenet | Deep Network Designer
Related Topics
- Try Deep Learning in 10 Lines of MATLAB Code
- Classify Image Using Pretrained Network
- Transfer Learning with Deep Network Designer
- Create Simple Image Classification Network Using Deep Network Designer
- Create Simple Image Classification Network
- Create Simple Sequence Classification Network Using Deep Network Designer
- Generate Experiment Using Deep Network Designer