Autoencoder in PyTorch

This post was inspired by the article "Implementing an Autoencoder in PyTorch". The implementation in that article does not work with PyTorch 1.4.0, so I had to modify the original code slightly. The code is shared here for my future reference.

First, I prepared a conda environment for this example.

$ conda create --name latency python=3.6

And activate it.

$ conda activate latency

Install Spyder for the environment.

$ conda install spyder

Then, install pytorch and torchvision.

$ conda install -c pytorch pytorch torchvision

To plot, matplotlib is required.

$ conda install matplotlib

Once all the required packages, including torch and torchvision, have been installed successfully, write the following code in Spyder.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 25 10:14:03 2020

@author: jaerock
"""

import matplotlib.pyplot as plt
#import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

# Hyperparameters used by the training loop below.
batch_size = 512
epochs = 40
learning_rate = 1e-3

# Reproducibility: seed torch's RNG and ask cuDNN for deterministic kernels
# (benchmark mode is switched off as well).
seed = 42
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Image size: each sample is handled as a 28 x 28 image that gets flattened
# into a single vector of flatten_size values before entering the model.
img_size = (28, 28)
flatten_size = img_size[0] * img_size[1]

# Dataset: the MNIST training split, downloaded on first use, with every
# sample passed through ToTensor.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
train_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets",
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of batch_size samples.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Autoencoder
class AE(nn.Module):
    """Fully connected autoencoder built from four linear layers.

    The encoder maps ``input_shape -> hidden_features -> code_features`` and
    the decoder mirrors it back: ``code_features -> hidden_features ->
    input_shape``.

    Args:
        input_shape: Number of features in each (flattened) input sample.
        hidden_features: Width of the encoder and decoder hidden layers.
        code_features: Width of the code produced by the encoder.
        **kwargs: Accepted and ignored, so call sites that pass additional
            keyword arguments keep working.
    """

    def __init__(self, input_shape: int, hidden_features: int = 128,
                 code_features: int = 128, **kwargs):
        super().__init__()
        # Keep the layers in encoder-to-decoder creation order: parameter
        # initialisation draws from torch's global RNG, so the order decides
        # which initial weights each layer gets for a given seed.
        self.encoder_hidden_layer = nn.Linear(
            in_features=input_shape, out_features=hidden_features)
        self.encoder_output_layer = nn.Linear(
            in_features=hidden_features, out_features=code_features)
        self.decoder_hidden_layer = nn.Linear(
            in_features=code_features, out_features=hidden_features)
        self.decoder_output_layer = nn.Linear(
            in_features=hidden_features, out_features=input_shape)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Encode ``features`` and return the reconstruction.

        Args:
            features: Tensor whose last dimension has ``input_shape`` entries.

        Returns:
            Tensor with the same shape as ``features``. The final sigmoid
            keeps every value between 0 and 1.
        """
        # Encoder: linear + ReLU, then linear + sigmoid to produce the code.
        hidden = torch.relu(self.encoder_hidden_layer(features))
        code = torch.sigmoid(self.encoder_output_layer(hidden))

        # Decoder: linear + ReLU, then linear + sigmoid for the output.
        hidden = torch.relu(self.decoder_hidden_layer(code))
        return torch.sigmoid(self.decoder_output_layer(hidden))

# NOTE: `device` is selected here but never applied -- neither the model nor
# the batches are moved with .to(device), so everything runs on the default
# (CPU) device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AE(input_shape=flatten_size)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# Training: the autoencoder learns to reproduce its own (flattened) input, so
# the labels yielded by the loader are discarded.
for epoch in range(epochs):
    loss = 0
    for images, _ in train_loader:
        optimizer.zero_grad()

        flat_images = images.view(-1, flatten_size)
        reconstructed = model(flat_images)

        batch_loss = criterion(reconstructed, flat_images)
        batch_loss.backward()
        optimizer.step()

        # .item() extracts a plain Python number from the loss tensor.
        loss += batch_loss.item()

    # Report the mean of the per-batch losses for this epoch.
    loss /= len(train_loader)
    print("epoch: {}/{}, recon_loss = {:.8f}".format(epoch + 1, epochs, loss))

# Test Dataset: the MNIST test split, read in order in batches of 10.
test_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets",
    train=False,
    transform=transform,
    download=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=10,
    shuffle=False,
)
test_examples = None

# Reconstruct only the first test batch; gradients are not needed here.
with torch.no_grad():
    for images, _ in test_loader:
        test_examples = images.view(-1, flatten_size)
        reconstruction = model(test_examples)
        break

# Visualize: a 2 x number grid with the original test images in the top row
# and the model's reconstructions directly underneath them.
with torch.no_grad():
    number = 10
    plt.figure(figsize=(20, 4))
    image_rows = (test_examples, reconstruction)
    for index in range(number):
        for row, images in enumerate(image_rows):
            # Subplot slots are 1-based: the top row uses 1..number and the
            # bottom row uses number+1..2*number.
            ax = plt.subplot(2, number, index + 1 + row * number)
            picture = images[index].numpy().reshape(img_size[0], img_size[1])
            plt.imshow(picture)
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

    plt.show()

The output of the program is as follows.

epoch: 1/40, recon_loss = 0.08389360
epoch: 2/40, recon_loss = 0.06247949
epoch: 3/40, recon_loss = 0.05559721
epoch: 4/40, recon_loss = 0.04590546
epoch: 5/40, recon_loss = 0.04079535
epoch: 6/40, recon_loss = 0.03859406
epoch: 7/40, recon_loss = 0.03625911
epoch: 8/40, recon_loss = 0.03332380
epoch: 9/40, recon_loss = 0.03136062
epoch: 10/40, recon_loss = 0.02987017
epoch: 11/40, recon_loss = 0.02821913
epoch: 12/40, recon_loss = 0.02613089
epoch: 13/40, recon_loss = 0.02480429
epoch: 14/40, recon_loss = 0.02399249
epoch: 15/40, recon_loss = 0.02311944
epoch: 16/40, recon_loss = 0.02207562
epoch: 17/40, recon_loss = 0.02090794
epoch: 18/40, recon_loss = 0.01969131
epoch: 19/40, recon_loss = 0.01886172
epoch: 20/40, recon_loss = 0.01819751
epoch: 21/40, recon_loss = 0.01752722
epoch: 22/40, recon_loss = 0.01684558
epoch: 23/40, recon_loss = 0.01616531
epoch: 24/40, recon_loss = 0.01554577
epoch: 25/40, recon_loss = 0.01500371
epoch: 26/40, recon_loss = 0.01453820
epoch: 27/40, recon_loss = 0.01413630
epoch: 28/40, recon_loss = 0.01375301
epoch: 29/40, recon_loss = 0.01341729
epoch: 30/40, recon_loss = 0.01307167
epoch: 31/40, recon_loss = 0.01275160
epoch: 32/40, recon_loss = 0.01244124
epoch: 33/40, recon_loss = 0.01216159
epoch: 34/40, recon_loss = 0.01186636
epoch: 35/40, recon_loss = 0.01157315
epoch: 36/40, recon_loss = 0.01127241
epoch: 37/40, recon_loss = 0.01100121
epoch: 38/40, recon_loss = 0.01075159
epoch: 39/40, recon_loss = 0.01052394
epoch: 40/40, recon_loss = 0.01035362

Leave a Reply

Your email address will not be published. Required fields are marked *