This blog post is part of a mini-series that discusses different aspects of building a PyTorch Deep Learning project using Variational Autoencoders.
In this section, we look at how we can reuse the code we wrote in the previous section to build a convolutional VAE. This VAE should be better at identifying important features in images and thus produce even better images.
Best of all, this new model can be built with minimal additional code thanks to PyTorch modules and class inheritance.
Convolution is an operation commonly used in image processing to extract features from an image. Images usually contain a lot of redundant information: if you zoom in on any pixel, the surrounding pixels are likely to have a very similar color. In convolutional neural networks (CNNs), many convolutional filters are learned automatically to extract features that are useful for classifying and identifying images. We simply borrow these principles and use convolutional layers to build a VAE.
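To make this concrete, here is a small sketch of a single convolutional layer applied to a batch of MNIST-sized images. The random input and the choice of 16 filters are illustrative assumptions; in a real model the filter weights are learned during training.

```python
import torch
import torch.nn as nn

# A batch of 8 single-channel 28 x 28 images (random data stands in for MNIST).
images = torch.randn(8, 1, 28, 28)

# 16 learnable 3 x 3 filters; padding=1 keeps the spatial size at 28 x 28.
conv = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
features = conv(images)

print(features.shape)  # torch.Size([8, 16, 28, 28])
```

Each of the 16 output channels is a feature map produced by one filter sliding over the image.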
By building a convolutional VAE, we aim to have a better feature extraction process. While we do not perform any classification or regression task, we want the latent representation to be as informative as possible. With more effective feature extraction, the decoder can then create more convincing data points.
While this new model uses a new architecture, we want to write the code efficiently. Good code follows the DRY (Don't Repeat Yourself) principle. To avoid unnecessary duplication of code, we use a powerful concept to build the model: inheritance.
Inheritance is a very powerful concept present in object-oriented programming (OOP) languages. It allows users to define objects and then build new objects that retain some of the functionality of the original object. Inheritance is a fairly broad topic, and there are things like multiple inheritance that I am not going to go into detail about. Learn more about OOP and inheritance in Python here.
Inheritance is, in fact, so common that we already used it in Part 1. Even if you didn't notice, inheritance is used widely in PyTorch, where every neural network inherits from the base class nn.Module. That is why we only need to define the __init__ and forward methods, and the base class does the rest. The model we are going to build takes this a step further and builds on the VAE from the previous section.
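As a minimal sketch of what this looks like (the TinyNet class is a hypothetical example, not part of the project), a subclass of nn.Module only defines __init__ and forward; calling the model and collecting its parameters both come from the base class.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # Hypothetical example network: we only write __init__ and forward.
    def __init__(self):
        super().__init__()  # lets nn.Module register layers and parameters
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

net = TinyNet()
out = net(torch.randn(3, 4))  # nn.Module.__call__ runs forward for us
n_params = sum(p.numel() for p in net.parameters())  # parameters() is inherited
```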
Inheritance allows us to build complex models in stages. The previous VAE model we built serves as a skeleton: it performs the reparameterization and implements the KL divergence loss.
We can then inherit this class and create a better model that uses an architecture that is better suited to the task.
This model can then be adapted by swapping out the encoder and decoder. The encoder performs representation learning and the decoder performs generation. These subnetworks can be simple linear layers or complex networks.
In our convolutional VAE, we want to change these components while keeping everything else identical. This can easily be done through inheritance.
This way we avoid much of the repetition in the code. Class methods such as train_loader are kept exactly the same, and inheritance copies them over automatically.
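A rough sketch of this pattern follows. The VAE class here is a hypothetical stand-in for the Part 1 model, and the method names reparameterize and kl_divergence are assumptions. The subclass replaces only the encoder and decoder (with placeholder convolutional layers) and inherits everything else.

```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    """Hypothetical skeleton standing in for the Part 1 VAE."""
    def __init__(self, latent_dim=2):
        super().__init__()
        self.latent_dim = latent_dim
        # Linear encoder/decoder working on flattened 28 x 28 images.
        self.encoder = nn.Linear(28 * 28, 2 * latent_dim)  # outputs mu and log-variance
        self.decoder = nn.Linear(latent_dim, 28 * 28)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps, with eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def kl_divergence(self, mu, logvar):
        # KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, averaged over the batch
        return (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)).mean()

class ConvVAE(VAE):
    def __init__(self, latent_dim=2):
        super().__init__(latent_dim)
        # Only the encoder and decoder are swapped (placeholder layers for illustration).
        self.encoder = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.decoder = nn.ConvTranspose2d(16, 1, kernel_size=3, padding=1)

model = ConvVAE()
mu, logvar = torch.zeros(4, 2), torch.zeros(4, 2)
kl = model.kl_divergence(mu, logvar)  # inherited; zero when q is exactly N(0, I)
z = model.reparameterize(mu, logvar)  # inherited
```

ConvVAE never defines reparameterize or kl_divergence itself; it gets them from VAE.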
If you look carefully at the previous model, the forward step flattened the input tensor before feeding it to the encoder. For the convolutional VAE, we don't want this flattening, because it prevents us from applying 2D convolutions.
It looks like before inheritance can work, we need to do some refactoring!
Basically, refactoring makes changes to the code while maintaining its external functionality. This means that the code still has the same behavior with respect to inputs and outputs. Refactoring can be done to speed up the code or, in our case, to restructure it so that we can reuse it elsewhere.
Instead of rewriting the entire forward step, we can modify our code so that the input tensor is flattened inside self.encoder and reshaped back to 28 x 28 inside self.decoder, instead of inside the forward function.
This makes the model more versatile because it can accommodate different encoders, such as convolutional ones, where we don't want to flatten the input tensor.
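The idea behind this refactoring can be sketched with a toy encoder and decoder (hypothetical layers, not the project's actual model). Both versions share the same weights, and the check at the end confirms that moving the reshaping into the modules leaves the output unchanged. For brevity, this sketch uses PyTorch's built-in nn.Flatten and nn.Unflatten.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
enc = nn.Linear(784, 2)  # toy encoder layer
dec = nn.Linear(2, 784)  # toy decoder layer

def old_forward(x):
    # Before refactoring: the reshaping lives in forward()
    x = x.view(-1, 784)
    out = dec(enc(x))
    return out.view(-1, 1, 28, 28)

# After refactoring: the reshaping lives inside the encoder and decoder modules
encoder = nn.Sequential(nn.Flatten(), enc)
decoder = nn.Sequential(dec, nn.Unflatten(1, (1, 28, 28)))

def new_forward(x):
    return decoder(encoder(x))

x = torch.rand(5, 1, 28, 28)
same = torch.allclose(old_forward(x), new_forward(x))  # behavior is preserved
```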
But hold on! Let's not refactor yet. The worst thing that can happen is that your model breaks when you make code changes. We want to ensure that the VAE model continues to do exactly the same thing after refactoring.
One good way to ensure that your project continues to work after the changes is to write unit tests.
Unit tests are simple scripts that you can use to make sure your code is working correctly. In the context of our model, we need to ensure that the model we build can still be trained and that gradients continue to flow properly.
To do this, we use pytest, which is a powerful library for writing unit tests that also includes useful debugging tools to find out why tests fail.
First, we create a folder named tests in the project directory. Within this folder, we create a file named test_model.py, which holds all the necessary unit tests.
Define a simple test:
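A minimal sketch of such a test is shown below. The make_model helper is a hypothetical stand-in; in the project you would import and build your VAE instead. The test runs one backward pass and checks that every parameter receives a finite gradient and that the output has the same shape as the input.

```python
import torch
import torch.nn as nn

def make_model():
    # Hypothetical stand-in for the project's VAE.
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 16),
        nn.ReLU(),
        nn.Linear(16, 784),
        nn.Unflatten(1, (1, 28, 28)),
    )

def test_model_trains():
    model = make_model()
    x = torch.rand(4, 1, 28, 28)
    out = model(x)
    assert out.shape == x.shape
    loss = nn.functional.mse_loss(out, x)
    loss.backward()
    # Every parameter should have received a finite gradient.
    for p in model.parameters():
        assert p.grad is not None
        assert torch.isfinite(p.grad).all()
```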
Another great feature of pytest is that it automatically searches the package for test functions. As long as the function name starts with test, pytest runs it as a test. By running pytest on the command line, we can confirm that the test passes.
Now that the tests are in place, we can begin making code changes.
One important thing to understand about PyTorch modules is that they are basically functions. When an input is passed to a PyTorch module, it simply performs some operations, and autograd tracks those operations so that gradients can be computed.
This means that even simple reshaping operations can be written as a PyTorch module.
We simply take the first line of the old forward function, which flattens the input, and turn it into a module. Placing this Flatten() module in the encoder performs the same operation.
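A minimal sketch of such a Flatten module might look like this. It keeps the batch dimension and flattens the rest, which for MNIST gives 784 features per image. PyTorch also ships an equivalent built-in, nn.Flatten.

```python
import torch
import torch.nn as nn

class Flatten(nn.Module):
    """Flattens every dimension except the batch dimension."""
    def forward(self, x):
        return x.view(x.size(0), -1)

x = torch.rand(3, 1, 28, 28)
flat = Flatten()(x)  # shape (3, 784)
```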
Next, we write the code for the Stack module. For the MNIST dataset, this module reshapes the tensor back to its original shape, which is (1, 28, 28).
In our case, because the images are black and white, there is only one channel, but we build a Stack module that can also work with color images.
To do this, we need to store information about the data in the module itself. We do this by passing the original shape of the data to the module as parameters: channels, height and width. The forward operation is then a view operation, similar to the flattening module.
To store these parameters, we use the __init__ function, which lets us save them as attributes of the module. Before that, we first initialize it as a PyTorch module by calling super().__init__().
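Putting this together, a sketch of the Stack module could look like the following. The attribute names are illustrative; the forward pass reshapes each flat vector back into a (channels, height, width) image, so the same class handles both grayscale and color data.

```python
import torch
import torch.nn as nn

class Stack(nn.Module):
    """Reshapes flat vectors back into (channels, height, width) images."""
    def __init__(self, channels, height, width):
        super().__init__()  # initialize as a PyTorch module first
        self.channels = channels
        self.height = height
        self.width = width

    def forward(self, x):
        return x.view(x.size(0), self.channels, self.height, self.width)

mnist_images = Stack(1, 28, 28)(torch.rand(2, 784))       # grayscale
color_images = Stack(3, 32, 32)(torch.rand(2, 3 * 32 * 32))  # color
```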
Now that we've wrapped these reshaping functions into their own modules, we can use nn.Sequential to define them as part of the encoder and decoder modules.
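Here is a sketch of how the two modules might be combined with nn.Sequential. The Flatten and Stack definitions are repeated so the block runs on its own, and the layer sizes and latent_dim = 2 are illustrative assumptions rather than the project's exact architecture. The convolutional encoder shows why the refactoring helps: it keeps the 2D structure and flattens only after the convolution.

```python
import torch
import torch.nn as nn

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class Stack(nn.Module):
    def __init__(self, channels, height, width):
        super().__init__()
        self.channels, self.height, self.width = channels, height, width

    def forward(self, x):
        return x.view(x.size(0), self.channels, self.height, self.width)

latent_dim = 2  # illustrative choice

# Linear VAE: the reshaping now lives inside the encoder and decoder.
encoder = nn.Sequential(Flatten(), nn.Linear(28 * 28, 2 * latent_dim))
decoder = nn.Sequential(nn.Linear(latent_dim, 28 * 28), nn.Sigmoid(), Stack(1, 28, 28))

# A convolutional encoder keeps the 2D structure and flattens only at the end.
conv_encoder = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),  # -> (N, 16, 14, 14)
    nn.ReLU(),
    Flatten(),                                              # -> (N, 16 * 14 * 14)
    nn.Linear(16 * 14 * 14, 2 * latent_dim),
)

x = torch.rand(4, 1, 28, 28)
h = encoder(x)                      # (4, 4): mean and log-variance
h_conv = conv_encoder(x)            # same output shape from the conv encoder
x_hat = decoder(h[:, :latent_dim])  # decode the mean back to (4, 1, 28, 28)
```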
And just like that, the reshaping operations are part of the encoder and decoder.
We run the unit tests to verify that the code is still working.
Nice! The test passes and the code executes as expected.