In this blog we will see how to train a simple neural network to fit a linear function y = Mx + C, where M = 3 and C = 0.5, and then look at a quadratic function y = x².
First, we will train a neural network to learn the values of M and C, and at the end we will compare the model weights with M and C.
You can open this Google Colab Notebook in a web browser to follow along and train the neural network for the linear function.
1. Import required packages
2. Define a linear function
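As a minimal sketch of this step, the function can be written in plain Python. The name `linear_function` and the default arguments are illustrative choices, not taken from the notebook; only the values M = 3 and C = 0.5 come from the post.

```python
# Coefficients stated in the post: M = 3, C = 0.5.
M = 3.0
C = 0.5


def linear_function(x, m=M, c=C):
    """Return y = m * x + c for a single number x."""
    return m * x + c


# Quick check by hand: 3 * 2 + 0.5 = 6.5
print(linear_function(2.0))
```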
3. Let's generate some data for the above function and split it into training and testing sets. You can normalize the data for faster training.
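One possible way to do this with only the standard library is sketched below. The sample count (1000), the input range (-10 to 10), the 80/20 split, and the helper names `make_split` and `standardize` are all assumptions made for illustration; the notebook may use different values and tools.

```python
import random


def linear_function(x, m=3.0, c=0.5):
    """Target function from the post: y = 3x + 0.5."""
    return m * x + c


def make_split(n_samples=1000, low=-10.0, high=10.0, test_fraction=0.2, seed=0):
    """Sample x uniformly, compute y, and split into (train, test) lists."""
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    xs = [rng.uniform(low, high) for _ in range(n_samples)]
    pairs = [(x, linear_function(x)) for x in xs]
    n_test = int(n_samples * test_fraction)
    # The samples are already in random order, so slicing is a fair split.
    return pairs[n_test:], pairs[:n_test]


def standardize(values):
    """Optional normalization: shift to zero mean and scale to unit variance."""
    mean = sum(values) / len(values)
    std = (sum((v - mean) ** 2 for v in values) / len(values)) ** 0.5
    return [(v - mean) / std for v in values], mean, std


train, test = make_split()
train_x_scaled, x_mean, x_std = standardize([x for x, _ in train])
print(len(train), len(test))
```

If you normalize, reuse the training `x_mean` and `x_std` for the test inputs, and keep in mind that the learned weight and bias then refer to the scaled inputs rather than to M and C directly.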
4. Visualize the training and testing data using matplotlib
5. Let's create a simple neural network to learn the above function. This network has only an input layer and an output layer, with no hidden layer. Because there is no hidden layer, the relation between input and output is y = w*x + b, where w is the weight and b is the bias. The model has only two learnable parameters, which you can verify from the model summary.
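To make the "two parameters" point concrete, here is a framework-free sketch of such a network. The class name `SingleNeuron` and its methods are hypothetical and stand in for whatever model object the notebook builds; this is not the notebook's code.

```python
import random


class SingleNeuron:
    """One input, one output, no hidden layer: y = w * x + b."""

    def __init__(self, seed=0):
        rng = random.Random(seed)
        self.w = rng.uniform(-1.0, 1.0)  # weight, small random start
        self.b = 0.0                     # bias, starts at zero

    def predict(self, x):
        return self.w * x + self.b

    def num_parameters(self):
        # The only learnable values are the weight and the bias.
        return len((self.w, self.b))


model = SingleNeuron()
print(model.num_parameters())
```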
6. Now we will train the above network for 1500 epochs. If you add a hidden layer, the model can converge in fewer epochs.
The training loss becomes negligible.
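The training step can be sketched as full-batch gradient descent on the mean squared error. Only the 1500 epochs come from the post. The learning rate (0.01), the 200 samples, the input range, and the function name `train_linear` are assumptions, and the notebook's optimizer may well differ.

```python
import random


def train_linear(xs, ys, epochs=1500, learning_rate=0.01):
    """Fit y = w * x + b by full-batch gradient descent on the MSE."""
    w, b = 0.0, 0.0
    n = len(xs)
    loss = 0.0
    for _ in range(epochs):
        errors = [(w * x + b) - y for x, y in zip(xs, ys)]
        loss = sum(e * e for e in errors) / n
        # Gradients of the mean squared error with respect to w and b.
        grad_w = 2.0 * sum(e * x for e, x in zip(errors, xs)) / n
        grad_b = 2.0 * sum(errors) / n
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b, loss


rng = random.Random(0)
xs = [rng.uniform(-10.0, 10.0) for _ in range(200)]
ys = [3.0 * x + 0.5 for x in xs]

w, b, final_loss = train_linear(xs, ys)
print(f"w={w:.4f}, b={b:.4f}, loss={final_loss:.2e}")
```

With unscaled inputs in this range the learning rate has to stay small, otherwise the updates diverge. This is one reason normalizing the data helps.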
7. Evaluate the model on the test data.
The loss on the test data is similar to the training loss, which means the model is not overfitting.
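A rough sketch of this comparison, again in plain Python: fit on one part of the data, then compute the mean squared error on both parts. The helper names `fit` and `mse`, the 200/50 split, and the learning rate are illustrative assumptions.

```python
import random


def mse(w, b, xs, ys):
    """Mean squared error of y = w * x + b on the given points."""
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def fit(xs, ys, epochs=1500, lr=0.01):
    """Full-batch gradient descent for a single weight and bias."""
    w = b = 0.0
    n = len(xs)
    for _ in range(epochs):
        errs = [w * x + b - y for x, y in zip(xs, ys)]
        w -= lr * 2.0 * sum(e * x for e, x in zip(errs, xs)) / n
        b -= lr * 2.0 * sum(errs) / n
    return w, b


rng = random.Random(1)
xs = [rng.uniform(-10.0, 10.0) for _ in range(250)]
ys = [3.0 * x + 0.5 for x in xs]
train_x, test_x = xs[:200], xs[200:]
train_y, test_y = ys[:200], ys[200:]

w, b = fit(train_x, train_y)
train_loss = mse(w, b, train_x, train_y)
test_loss = mse(w, b, test_x, test_y)
print(f"train loss={train_loss:.2e}, test loss={test_loss:.2e}")
```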
8. Check how well the model generalizes.
Let's compare the model's predictions with the ground truth for some random data.
The model performs very well on data outside our training data range.
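The check can be sketched as follows: train on inputs from one range, then query points far outside it. The training range (-10 to 10) and the query points are assumptions chosen for illustration. Because the fitted model is itself a straight line, it extrapolates as well as it interpolates.

```python
import random


def fit(xs, ys, epochs=1500, lr=0.01):
    """Full-batch gradient descent for y = w * x + b."""
    w = b = 0.0
    n = len(xs)
    for _ in range(epochs):
        errs = [w * x + b - y for x, y in zip(xs, ys)]
        w -= lr * 2.0 * sum(e * x for e, x in zip(errs, xs)) / n
        b -= lr * 2.0 * sum(errs) / n
    return w, b


rng = random.Random(2)
train_x = [rng.uniform(-10.0, 10.0) for _ in range(200)]  # training range
train_y = [3.0 * x + 0.5 for x in train_x]
w, b = fit(train_x, train_y)

# Query points well outside the training range.
outside = [-100.0, 50.0, 1000.0]
predictions = [w * x + b for x in outside]
truths = [3.0 * x + 0.5 for x in outside]
for x, p, t in zip(outside, predictions, truths):
    print(f"x={x:8.1f}  predicted={p:10.4f}  truth={t:10.4f}")
```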
9. Visualize the predictions on the test data.
10. We have seen that the model performs very well. Let's check its learned weights and compare them with the coefficients of the original linear function.
As we know, this model has only two weights, and we can see that the learned weights are almost equal to our M and C values.
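A small sketch of this comparison in plain Python, using `math.isclose` with a tolerance. The tolerance of 1e-3, the data range, and the helper name `fit` are assumptions; in a framework you would read the weights from the trained model instead.

```python
import math
import random

M, C = 3.0, 0.5  # coefficients of the original function


def fit(xs, ys, epochs=1500, lr=0.01):
    """Full-batch gradient descent for y = w * x + b."""
    w = b = 0.0
    n = len(xs)
    for _ in range(epochs):
        errs = [w * x + b - y for x, y in zip(xs, ys)]
        w -= lr * 2.0 * sum(e * x for e, x in zip(errs, xs)) / n
        b -= lr * 2.0 * sum(errs) / n
    return w, b


rng = random.Random(3)
xs = [rng.uniform(-10.0, 10.0) for _ in range(200)]
ys = [M * x + C for x in xs]
w, b = fit(xs, ys)

weight_matches = math.isclose(w, M, abs_tol=1e-3)
bias_matches = math.isclose(b, C, abs_tol=1e-3)
print(f"w={w:.4f} (M={M}), b={b:.4f} (C={C})")
print(weight_matches, bias_matches)
```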
It is quite easy to train a neural network for a linear function, but considerably harder for a quadratic function.
Here is the Notebook for the above code. Open it in Google Colab, modify the neural network and the data for the quadratic function, and try to train it. You will run into challenges such as finding the optimal number of hidden layers and of units per hidden layer, among others.
Try training for the quadratic function yourself, or just check this Notebook for the same.
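To see why the network has to change, here is a sketch that fits the same no-hidden-layer model to y = x². A model of the form y = w*x + b can only draw a straight line, so the loss stays large no matter how long it trains. The data range, sample count, and helper name `fit_line` are assumptions, and this is not the code from the Notebook.

```python
import random


def fit_line(xs, ys, epochs=1500, lr=0.01):
    """Fit y = w * x + b by gradient descent and return (w, b, final MSE)."""
    w = b = 0.0
    n = len(xs)
    for _ in range(epochs):
        errs = [w * x + b - y for x, y in zip(xs, ys)]
        w -= lr * 2.0 * sum(e * x for e, x in zip(errs, xs)) / n
        b -= lr * 2.0 * sum(errs) / n
    loss = sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / n
    return w, b, loss


rng = random.Random(4)
xs = [rng.uniform(-10.0, 10.0) for _ in range(200)]
ys = [x * x for x in xs]  # quadratic target

w, b, quadratic_loss = fit_line(xs, ys)
print(f"w={w:.3f}, b={b:.3f}, loss={quadratic_loss:.1f}")
```

Fitting the curve requires at least one hidden layer with a non-linear activation, which is where the choices about layer count and units come in.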
That's all for this blog.
Thank you!