Convolutional Neural Network (CNN) in MATLAB
What Is a CNN and Why Is It Widely Used?
A CNN is a deep learning algorithm used mostly for image and video analysis. It is a special type of deep neural network. I assume you already know what a deep neural network is; I will write a post on neural networks later. This post focuses on building a CNN in MATLAB and explaining how it works. So what makes CNNs so special?
Firstly, a CNN learns features on its own. During training it automatically works out which features of an image matter most for recognizing it, instead of relying on hand-crafted features.
Secondly, a CNN keeps only the features that matter. Convolutional filters share the same small set of weights across the whole image, and pooling layers shrink the feature maps. As a result, much of the data that is unimportant to the network is discarded, which saves computing time and power.
Last but not least, training a CNN is relatively easy, and the input images do not need much pre-processing. It works with both RGB and grayscale images.
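To put a rough number on the savings from weight sharing, the sketch below counts learnable parameters for the first convolutional layer used later in this post (8 filters of size 3x3 on a 1-channel 28x28 image). It compares that count with a hypothetical fully connected layer that maps the same input to an output of the same size (28x28x8). The comparison layer exists only for illustration, and all variable names are made up. The script is plain MATLAB and needs no toolbox.

```matlab
% Parameter count: convolutional layer vs. a hypothetical fully connected
% layer with the same input (28x28x1) and output (28x28x8) sizes.
filterSize  = 3;
numFilters  = 8;
numChannels = 1;
imageSize   = [28 28];

% Convolution: each filter is reused at every pixel position, so the
% weight count does not depend on the image size.
convParams = filterSize*filterSize*numChannels*numFilters + numFilters;  % weights + biases

% Fully connected: every input value connects to every output unit.
numInputs  = prod(imageSize)*numChannels;   % 784
numOutputs = prod(imageSize)*numFilters;    % 6272
fcParams   = numInputs*numOutputs + numOutputs;

fprintf('Convolutional layer:   %d parameters\n', convParams);
fprintf('Fully connected layer: %d parameters\n', fcParams);
```

The convolutional layer needs 80 parameters, while the fully connected layer needs 4,923,520.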
Why Is CNN Implementation in MATLAB Getting Popular?
MATLAB is a great tool for data exploration, analysis and visualization. It combines many high-quality tools and features that matter for scientific research. Researchers often use it to design systems ranging from simple to complicated and to simulate them to analyze how the model performs. Additionally, MATLAB offers the Deep Learning Toolbox, which makes designing and training deep neural networks easy enough that many researchers choose MATLAB over Python. Another reason is that performing deep learning (DL) operations in MATLAB keeps the whole workflow compact. Previously, neural network training and system simulation were done in two separate environments.
If you are familiar with the MATLAB environment, you know that the MATLAB language is readable and easy to learn. A few lines of code are enough to get your model ready.
Training a Convolutional Neural Network (CNN) in MATLAB
I have written the following code to train a convolutional neural network on MATLAB's handwritten digit dataset (DigitDataset), an MNIST-like set of 28x28 grayscale digit images. You don't need to download the dataset, because it ships with the Deep Learning Toolbox.
%% Load the handwritten digit dataset that ships with the Deep Learning Toolbox
dataset = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', ...
    'nndatasets', 'DigitDataset');
%% Convert the dataset to an imageDatastore object, using folder names as labels
imds = imageDatastore(dataset, ...
    'IncludeSubfolders', true, 'LabelSource', 'foldernames');
%% Check that it loaded correctly (uncomment to display one image)
%imshow(imds.Files{1005});
%% Our dataset has 10 classes and each class has 1000 images
labelCount = countEachLabel(imds)
%% Split the data: 70% of each class for training, the rest for testing
[training_data, test_data] = splitEachLabel(imds, 0.7, 'randomize');
%% Define the layers of the CNN
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3, 8, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2, 'Stride', 2)
    convolution2dLayer(3, 16, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2, 'Stride', 2)
    convolution2dLayer(3, 32, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
%% Set the training options
options = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.01, ...
    'MaxEpochs', 4, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', test_data, ...
    'ValidationFrequency', 30, ...
    'Verbose', false, ...
    'Plots', 'training-progress');
%% Train the network
net = trainNetwork(training_data, layers, options);
save('net.mat', 'net');   % save only the trained network to net.mat
%% Test the accuracy
predicted_labels = classify(net, test_data);
actual_labels = test_data.Labels;
accuracy = sum(predicted_labels == actual_labels)/numel(actual_labels)
I have commented the purpose of each code section, so just run the script in MATLAB to see the output. In this post we focus on the layer definitions and the training parameters. The network is defined in the layers array, and the first layer specifies the input image size.
imageInputLayer([28 28 1]): 28x28 is the image size in pixels and 1 is the number of channels. All images in this dataset are grayscale, so there is a single channel. An RGB image would need 3.
The convolution2dLayer() function is called three times, creating three hidden convolutional layers.
convolution2dLayer(3,8,'Padding','same'): 3 is the filter (kernel) size, so each filter is 3x3. 8 is the number of filters, so the layer outputs 8 feature maps. The remaining arguments are name-value pairs: first the property name, then its value. 'Padding' adds padding around the borders of the input. With the value 'same', the software computes the padding so that, with a stride of 1, the output has the same height and width as the input.
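The following plain-MATLAB sketch shows what a single 3x3 filter does with 'Padding','same'. It needs no toolbox. The image is random, and the filter values are a hypothetical vertical-edge detector; in the real network the 8 filters are learned during training. The sketch uses filter2, which computes a 2-D correlation (no kernel flip). That is how most deep learning frameworks implement "convolution" layers, and because the weights are learned, the flip makes no practical difference.

```matlab
% One 3x3 filter applied with 'same' padding (illustrative values only).
img = rand(28, 28);              % stand-in for one 28x28 grayscale digit
k   = [1 0 -1; 2 0 -2; 1 0 -1];  % hypothetical 3x3 filter (vertical-edge detector)

% filter2 slides k over img. 'same' zero-pads the border so the output
% keeps the input size, as 'Padding','same' does with a stride of 1.
featureMap = filter2(k, img, 'same');

% Equivalent view: pad 1 pixel of zeros on every side, then keep only the
% positions where the 3x3 window fits entirely inside ('valid').
padded = zeros(30, 30);
padded(2:29, 2:29) = img;
featureMapManual = filter2(k, padded, 'valid');

disp(size(featureMap))                                 % 28 28
disp(max(abs(featureMap(:) - featureMapManual(:))))    % ~0
```

For a 3x3 filter with stride 1, 'same' therefore works out to one pixel of padding on each side.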
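The maxPooling2dLayer() function defines a max pooling layer. It downsamples each feature map by keeping only the maximum value in each pooling region. This reduces the spatial size of the data and the computation needed in later layers.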
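maxPooling2dLayer(2,'Stride',2): the first 2 is the pool size, so each pooling region is 2x2. If you use 3, the pool size will be 3x3. 'Stride' is the step size for moving the pooling window across the input horizontally and vertically, and the following 2 sets that step size. With a 2x2 pool and a stride of 2, the regions do not overlap, so each pooling layer halves the height and width of the feature maps.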
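To make the pooling step concrete, here is a small sketch of 2x2 max pooling with stride 2, written as explicit loops in plain MATLAB. It applies the same operation that maxPooling2dLayer(2,'Stride',2) performs on every feature map, here on a 4x4 example matrix. The loop code is only for illustration, not how the toolbox implements the layer.

```matlab
% 2x2 max pooling with stride 2 on a small example matrix.
X = magic(4);            % [16 2 3 13; 5 11 10 8; 9 7 6 12; 4 14 15 1]
poolSize = 2;
stride   = 2;

[h, w] = size(X);
outH = floor((h - poolSize)/stride) + 1;   % 2
outW = floor((w - poolSize)/stride) + 1;   % 2
Y = zeros(outH, outW);

for i = 1:outH
    for j = 1:outW
        rows = (i-1)*stride + (1:poolSize);   % rows covered by this window
        cols = (j-1)*stride + (1:poolSize);   % columns covered by this window
        region = X(rows, cols);
        Y(i, j) = max(region(:));             % keep only the largest value
    end
end

disp(Y)   % [16 13; 14 15]
```

Each output value is the maximum of one non-overlapping 2x2 block, so the 4x4 input shrinks to 2x2.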
The last layer with learnable weights, fullyConnectedLayer(number_of_classes), produces the network's output scores. Its output size must therefore equal the number of classes the network is trained on. In our case, number_of_classes = 10, because the handwritten digit dataset has 10 classes (the digits 0 to 9).
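It can help to trace how the feature-map size changes through the layers defined above, because that size determines how many inputs the fully connected layer receives. The sketch below applies the standard output-size formula, out = floor((in + 2*pad - window)/stride) + 1, along one spatial dimension. It assumes 'same' padding works out to 1 pixel per side for a 3x3 filter with stride 1. The anonymous function outSize is a hypothetical helper, not a toolbox function.

```matlab
% Output size along one spatial dimension for conv/pool layers.
outSize = @(in, window, stride, pad) floor((in + 2*pad - window)/stride) + 1;

h = 28;                     % imageInputLayer([28 28 1])
h = outSize(h, 3, 1, 1);    % conv 3x3, 'same'         -> 28 (8 channels)
h = outSize(h, 2, 2, 0);    % max pool 2x2, stride 2   -> 14
h = outSize(h, 3, 1, 1);    % conv 3x3, 'same'         -> 14 (16 channels)
h = outSize(h, 2, 2, 0);    % max pool 2x2, stride 2   -> 7
h = outSize(h, 3, 1, 1);    % conv 3x3, 'same'         -> 7 (32 channels)

numChannels = 32;
numClasses  = 10;
fcInputs = h*h*numChannels;                   % values fed to fullyConnectedLayer
fcParams = fcInputs*numClasses + numClasses;  % weights + biases

fprintf('Final feature map: %dx%dx%d\n', h, h, numChannels);
fprintf('fullyConnectedLayer(10): %d inputs, %d parameters\n', fcInputs, fcParams);
```

The last convolutional block outputs 7x7x32 = 1568 values. The fully connected layer maps these to 10 class scores using 15,690 learnable parameters.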
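After that, we add a softmaxLayer, which converts those scores into class probabilities, and a classificationLayer, which computes the cross-entropy loss during training and assigns the predicted label. That completes the network definition.

Next, we set the training options. The first argument of trainingOptions() selects the optimizer. We use 'sgdm', stochastic gradient descent with momentum. The remaining arguments are again name-value pairs, with the name first and the value second:
- 'InitialLearnRate',0.01 sets the learning rate to 0.01.
- 'MaxEpochs',4 trains for four full passes over the training data.
- 'Shuffle','every-epoch' reshuffles the training data before each epoch.
- 'ValidationData' and 'ValidationFrequency' evaluate the network on test_data every 30 iterations.
- 'Verbose',false turns off the progress output in the Command Window.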
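To see what 'sgdm' does with a learning rate of 0.01, here is a minimal sketch of the momentum update on a toy one-dimensional loss, E(w) = (w - 3)^2, whose minimum is at w = 3. The update has the form described in MATLAB's trainingOptions documentation: the new weight is the old weight minus the learning rate times the gradient, plus the momentum times the previous step. The momentum value 0.9 is the trainingOptions default, since the script above does not set it. The toy loss and the variable names are illustrative only; a real network applies the same update to every weight, using gradients computed from mini-batches.

```matlab
% SGDM on a toy loss E(w) = (w - 3)^2.
% Update: w_next = w - alpha*gradE(w) + gamma*(w - w_prev)
alpha = 0.01;                 % 'InitialLearnRate' used in this post
gamma = 0.9;                  % momentum (default value for 'sgdm')
gradE = @(w) 2*(w - 3);       % gradient of the toy loss

w = 0;
wPrev = 0;
for iter = 1:500
    wNext = w - alpha*gradE(w) + gamma*(w - wPrev);  % gradient step + momentum
    wPrev = w;
    w = wNext;
end

fprintf('w after 500 iterations: %.6f\n', w);
```

The momentum term keeps part of the previous step. This lets the weight keep moving in a consistent direction even when each individual gradient step is small.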
If you set the 'Plots','training-progress' pair, MATLAB plots the accuracy and loss curves in real time while the network trains. The window also shows other useful information, as you can see in the image below:
Once training is completed, the last section of the script prints the final accuracy on the test data.
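The training-progress window counts iterations (mini-batches) rather than images or epochs. A quick worked calculation shows how many iterations this run performs. It assumes the trainingOptions default 'MiniBatchSize' of 128, because the script does not set one. It also assumes the documented behavior of discarding training data that does not fill a final complete mini-batch in each epoch. The variable names are illustrative.

```matlab
% Iteration count for the training run in this post.
imagesPerClass = 1000;
numClasses     = 10;
trainFraction  = 0.7;    % splitEachLabel(imds, 0.7, 'randomize')
miniBatchSize  = 128;    % assumed trainingOptions default (not set in the script)
maxEpochs      = 4;      % 'MaxEpochs', 4

numTrain = round(trainFraction*imagesPerClass)*numClasses;   % 700 per class
numTest  = imagesPerClass*numClasses - numTrain;

% An incomplete final mini-batch is dropped each epoch, so round down.
itersPerEpoch = floor(numTrain/miniBatchSize);
totalIters    = itersPerEpoch*maxEpochs;

fprintf('%d training / %d test images\n', numTrain, numTest);
fprintf('%d iterations per epoch, %d in total\n', itersPerEpoch, totalIters);
```

Under these assumptions, the run uses 7000 training and 3000 test images and performs 54 iterations per epoch, 216 in total. With 'ValidationFrequency' set to 30, validation therefore runs several times during training.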
You can also export the learning and loss curves. The following video might help you with this.
That’s all for today. I’ve discussed only the important things to keep this post short and less confusing. Thank you for reading. :)