This post was inspired by Implementing an Autoencoder in PyTorch. The implementation in that article does not work with PyTorch 1.4.0, so I had to slightly modify the original code. The code is shared here for my future reference.
First, I prepared a conda environment for this example.
$ conda create --name latency python=3.6
Then, activate it.
$ conda activate latency
Install Spyder in the environment.
$ conda install spyder
Then, install pytorch and torchvision.
$ conda install -c pytorch pytorch torchvision
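To make sure the installation went through, you can run a quick sanity check from a Python prompt. This snippet is an optional addition of mine, not part of the original article.

# Optional sanity check (not in the original article): confirm the
# installed versions before running the autoencoder script.
import torch
import torchvision

print(torch.__version__)          # should print 1.4.0 for this post
print(torchvision.__version__)    # the matching torchvision release
print(torch.cuda.is_available())  # True only if a CUDA GPU is usable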
To plot the results, matplotlib is required.
$ conda install matplotlib
After successfully installing all the required packages, including torch and torchvision, write the following code in Spyder.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 25 10:14:03 2020

@author: jaerock
"""

import matplotlib.pyplot as plt
#import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

# Setup
seed = 42
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

batch_size = 512
epochs = 40
learning_rate = 1e-3

# Dataset
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()])

train_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=True,
    transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

# Image size
img_size = (28, 28)
flatten_size = img_size[0]*img_size[1]

# Autoencoder
class AE(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.encoder_hidden_layer = nn.Linear(
            in_features=kwargs["input_shape"], out_features=128)
        self.encoder_output_layer = nn.Linear(
            in_features=128, out_features=128)
        self.decoder_hidden_layer = nn.Linear(
            in_features=128, out_features=128)
        self.decoder_output_layer = nn.Linear(
            in_features=128, out_features=kwargs["input_shape"])

    def forward(self, features):
        activation = self.encoder_hidden_layer(features)
        activation = torch.relu(activation)
        code = self.encoder_output_layer(activation)
        code = torch.sigmoid(code)
        activation = self.decoder_hidden_layer(code)
        activation = torch.relu(activation)
        activation = self.decoder_output_layer(activation)
        reconstructed = torch.sigmoid(activation)
        return reconstructed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AE(input_shape=flatten_size) #.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# Training
for epoch in range(epochs):
    loss = 0
    for batch_features, _ in train_loader:
        batch_features = batch_features.view(-1, flatten_size) #.to(device)
        optimizer.zero_grad()
        outputs = model(batch_features)
        train_loss = criterion(outputs, batch_features)
        train_loss.backward()
        optimizer.step()
        loss += train_loss.item()
    loss = loss / len(train_loader)
    print("epoch: {}/{}, recon_loss = {:.8f}".format(epoch + 1, epochs, loss))

# Test Dataset
test_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=False, transform=transform, download=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=10, shuffle=False)

test_examples = None
with torch.no_grad():
    for batch_features in test_loader:
        batch_features = batch_features[0]
        test_examples = batch_features.view(-1, flatten_size)
        reconstruction = model(test_examples)
        break

# Visualize
with torch.no_grad():
    number = 10
    plt.figure(figsize=(20, 4))
    for index in range(number):
        ax = plt.subplot(2, number, index + 1)
        plt.imshow(test_examples[index].numpy().reshape(
            img_size[0], img_size[1]))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        ax = plt.subplot(2, number, index + 1 + number)
        plt.imshow(reconstruction[index].numpy().reshape(
            img_size[0], img_size[1]))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
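Note that forward only returns the reconstruction and discards the 128-dimensional code computed by the encoder. If you also want the latent code itself, a small helper like the encode function below can recompute it from the trained model's layers. This is a sketch I added for reference, not part of the original article.

# A sketch (not in the original code): recompute the latent code
# from the trained model's encoder layers.
def encode(model, features):
    with torch.no_grad():
        activation = torch.relu(model.encoder_hidden_layer(features))
        return torch.sigmoid(model.encoder_output_layer(activation))

# Example: latent codes for the ten test images visualized above.
codes = encode(model, test_examples)
print(codes.shape)  # torch.Size([10, 128])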
The output of the program is as follows.
epoch: 1/40, recon_loss = 0.08389360
epoch: 2/40, recon_loss = 0.06247949
epoch: 3/40, recon_loss = 0.05559721
epoch: 4/40, recon_loss = 0.04590546
epoch: 5/40, recon_loss = 0.04079535
epoch: 6/40, recon_loss = 0.03859406
epoch: 7/40, recon_loss = 0.03625911
epoch: 8/40, recon_loss = 0.03332380
epoch: 9/40, recon_loss = 0.03136062
epoch: 10/40, recon_loss = 0.02987017
epoch: 11/40, recon_loss = 0.02821913
epoch: 12/40, recon_loss = 0.02613089
epoch: 13/40, recon_loss = 0.02480429
epoch: 14/40, recon_loss = 0.02399249
epoch: 15/40, recon_loss = 0.02311944
epoch: 16/40, recon_loss = 0.02207562
epoch: 17/40, recon_loss = 0.02090794
epoch: 18/40, recon_loss = 0.01969131
epoch: 19/40, recon_loss = 0.01886172
epoch: 20/40, recon_loss = 0.01819751
epoch: 21/40, recon_loss = 0.01752722
epoch: 22/40, recon_loss = 0.01684558
epoch: 23/40, recon_loss = 0.01616531
epoch: 24/40, recon_loss = 0.01554577
epoch: 25/40, recon_loss = 0.01500371
epoch: 26/40, recon_loss = 0.01453820
epoch: 27/40, recon_loss = 0.01413630
epoch: 28/40, recon_loss = 0.01375301
epoch: 29/40, recon_loss = 0.01341729
epoch: 30/40, recon_loss = 0.01307167
epoch: 31/40, recon_loss = 0.01275160
epoch: 32/40, recon_loss = 0.01244124
epoch: 33/40, recon_loss = 0.01216159
epoch: 34/40, recon_loss = 0.01186636
epoch: 35/40, recon_loss = 0.01157315
epoch: 36/40, recon_loss = 0.01127241
epoch: 37/40, recon_loss = 0.01100121
epoch: 38/40, recon_loss = 0.01075159
epoch: 39/40, recon_loss = 0.01052394
epoch: 40/40, recon_loss = 0.01035362
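The reconstruction loss decreases steadily over the 40 epochs. If you want to see the trend as a curve rather than a printed log, a small change to the training loop can record the per-epoch losses for plotting. Below is a sketch of such a change, assuming the same variable names as the script above; the losses list is my addition.

# A sketch: record the average loss of each epoch in a list,
# then plot the curve after training finishes.
losses = []
for epoch in range(epochs):
    loss = 0
    for batch_features, _ in train_loader:
        batch_features = batch_features.view(-1, flatten_size)
        optimizer.zero_grad()
        outputs = model(batch_features)
        train_loss = criterion(outputs, batch_features)
        train_loss.backward()
        optimizer.step()
        loss += train_loss.item()
    loss = loss / len(train_loader)
    losses.append(loss)

plt.plot(range(1, epochs + 1), losses)
plt.xlabel("epoch")
plt.ylabel("recon_loss")
plt.show()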