Variational Autoencoder in TensorFlow

From Jan Hendrik Metzen, 2017-01-03, with small modifications by d. gannon

From Metzen's intro:

The main motivation for this post was that I wanted to get more experience with both Variational Autoencoders (VAEs) and TensorFlow. Thus, implementing the former in the latter sounded like a good idea for learning about both at the same time. This post summarizes the result.

Note: The post was updated on December 7th 2015:

  • a bug in the computation of latent_loss was fixed (removed an erroneous factor of 2). Thanks to Colin Fang for pointing this out.
  • Using a Bernoulli distribution rather than a Gaussian distribution in the generator network

Note: The post was updated on January 3rd 2017:

  • changes required to support TensorFlow v0.12 and Python 3

Notes on this version

From d. gannon: This version uses images courtesy of Maryana Alegro of the University of California, San Francisco. It is really interesting. Dr. Alegro tells us that the data consists of “images of neurons drawn from real immunofluorescence (IF) microscopy of post-mortem human brain tissue. These are part of a UCSF study that aims to understand Alzheimer’s disease (AD) development in brain stem nuclei, which are affected by the disease years before initial symptoms onset. The red cells are stained using a fluorescence marker called CP-13 that binds to TAU tangles, commonly associated with AD. Green cells are stained for aCasp-6, which binds to proteins present during apoptosis (cell death), and yellow cells have overlap of both markers.” The goal of their project is to determine whether quantifying the presence of these three classes in IF images can help understand if the presence of TAU is really causing cell death in brain stem nuclei. They have a very nice paper, “Automating cell detection and classification in human brain fluorescent microscopy images using dictionary learning and sparse coding,” in Journal of Neuroscience Methods 282 (2017) 20–33, that describes some of their work.

The article that describes the work here is https://cloud4scieng.org/manifold-learning-and-deep-autoencoders-in-science/

In [2]:
import numpy as np
import tensorflow as tf
import cv2
import random

import matplotlib.pyplot as plt
%matplotlib inline

np.random.seed(0)
tf.set_random_seed(0)
totalimg = 1032        # total number of cell images in the dataset
n_samples = totalimg

Download the data

The link to download a tar'd, gzipped copy of the data is http://www.imagexd.org/assets/cells.tar.gz . The following function expects the path to the untarred images and a file listing the name of each image, one per line. You can create this file with the command

ls > classlist
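
If you prefer to do this from Python, here is a minimal sketch (assuming a writable newcells directory, that the URL above is still live, and that the archive unpacks to a cells subdirectory) that downloads the archive, unpacks it, and writes the classlist file that read_images (defined below) expects.

In [ ]:
import os
import tarfile
import urllib.request

url = "http://www.imagexd.org/assets/cells.tar.gz"
os.makedirs("newcells", exist_ok=True)
archive, _ = urllib.request.urlretrieve(url, "newcells/cells.tar.gz")
with tarfile.open(archive, "r:gz") as tar:
    tar.extractall("newcells")   # assumed to create newcells/cells

# the equivalent of `ls > classlist`: one image file name per line
names = sorted(n for n in os.listdir("newcells/cells") if n != "classlist")
with open("newcells/cells/classlist", "w") as f:
    f.write("\n".join(names) + "\n")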

In [3]:
def read_images(size, listfile, datafile):
    classes = np.array([0]*size)
    data = np.array([np.zeros(28*28*3, dtype=np.float32)]*size)
    with open(listfile) as f:
        i = 0
        for line_of_text in f:
            if i == size:
                return classes, data
            # line_of_text[:-1] strips the trailing newline from the file name
            with open(datafile+"/"+line_of_text[:-1], 'rb') as infile:
                buf = infile.read()
            # use numpy to construct an array from the bytes
            x = np.fromstring(buf, dtype='uint8')
            # decode the array into an image and resize to 28x28
            img = cv2.imdecode(x, cv2.IMREAD_UNCHANGED)
            img2 = cv2.resize(img, (28, 28))
            img3 = img2.reshape(28*28*3)
            # normalize pixel values to [0, 1)
            data[i] = img3/256.0
            # the class label is the digit following "class" in the file name
            x = line_of_text.find("class")
            classes[i] = int(line_of_text[x+5:x+6])
            i = i+1
    if i != size:
        print("size error " + str(i))
    return classes, data

I have stored the untarred and gunzipped data in the directory newcells/cells. In addition to reading the full dataset, we save the class0, class1 and class2 cells in separate directories. We will use these later.

In [4]:
classes, data = read_images(totalimg, 'newcells/cells/classlist', 'newcells/cells')
c2class,c2data = read_images(300,'newcells/class2/classlist','newcells/class2')
c1class,c1data = read_images(300,'newcells/class1/classlist','newcells/class1')
c0class,c0data = read_images(310,'newcells/class0/classlist','newcells/class0')
In [5]:
c0data.shape
Out[5]:
(310, 2352)
In [6]:
plt.imshow(c0data[305].reshape(28, 28,3))
Out[6]:
<matplotlib.image.AxesImage at 0x7fcb3960a250>
In [7]:
def get_next_batch(size, classes, data):
    if size > totalimg:
        s = totalimg
    else:
        s = size
    c = np.array([0]*s)
    d = np.array([np.zeros(28*28*3, dtype=np.float32)]*s)
    for i in range(s):
        # sample with replacement; randint's upper bound is exclusive,
        # so use len(classes) to make every image reachable
        j = np.random.randint(0, len(classes))
        c[i] = classes[j]
        d[i] = data[j]
    return d, c
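
As a quick sanity check (a sketch, not in the original notebook), we can draw a small batch and confirm its shape.

In [ ]:
batch_x, batch_c = get_next_batch(10, classes, data)
print(batch_x.shape, batch_c.shape)   # expect (10, 2352) (10,)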

    
In [8]:
def xavier_init(fan_in, fan_out, constant=1): 
    """ Xavier initialization of network weights"""
    # https://stackoverflow.com/questions/33640581/how-to-do-xavier-initialization-on-tensorflow
    low = -constant*np.sqrt(6.0/(fan_in + fan_out)) 
    high = constant*np.sqrt(6.0/(fan_in + fan_out))
    return tf.random_uniform((fan_in, fan_out), 
                             minval=low, maxval=high, 
                             dtype=tf.float32)
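
As a quick check of the Xavier bound for the first encoder layer used below (fan_in = 28*28*3 = 2352, fan_out = 1000), the weights are drawn uniformly from roughly ±0.042:

In [ ]:
import numpy as np
print(np.sqrt(6.0 / (2352 + 1000)))   # ~0.0423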

Based on this, we now define a class "VariationalAutoencoder" with an sklearn-like interface that can be trained incrementally with mini-batches using partial_fit. The trained model can be used to reconstruct unseen input, to generate new samples, and to map inputs to the latent space.

$$ KL(P \| Q) = - \sum_x p(x)\log q(x) + \sum_x p(x)\log p(x) $$

$$ Z_{\mu},\ Z_{\ln(\sigma^2)} = enc(X) $$

$$ \epsilon \sim N(0,1) $$

$$ Z = Z_{\mu} + \epsilon \sqrt{\exp(Z_{\ln(\sigma^2)})} $$

$$ X_{recon\mu} = dec(Z) $$

The loss function is

$$ ReconLoss = -\sum \left( X \ln(X_{recon\mu}) + (1-X) \ln(1 - X_{recon\mu}) \right) $$

$$ LatentLoss = -\frac{1}{2} \sum \left( 1 + Z_{\ln(\sigma^2)} - Z_{\mu}^2 - \exp(Z_{\ln(\sigma^2)}) \right) $$
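
To make the two terms concrete, here is a small NumPy sketch (illustrative only, not part of the original notebook) of the same formulas. Note that the latent loss is exactly zero when the encoder output matches the N(0,1) prior.

In [ ]:
import numpy as np

def recon_loss(x, x_recon_mean, eps=1e-10):
    # negative Bernoulli log-likelihood, summed over pixels
    return -np.sum(x * np.log(eps + x_recon_mean)
                   + (1 - x) * np.log(eps + 1 - x_recon_mean))

def latent_loss(z_mean, z_log_sigma_sq):
    # KL( N(z_mean, sigma^2) || N(0, 1) ), summed over latent dimensions
    return -0.5 * np.sum(1 + z_log_sigma_sq
                         - z_mean**2 - np.exp(z_log_sigma_sq))

print(latent_loss(np.zeros(20), np.zeros(20)))   # 0.0
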
In [9]:
class VariationalAutoencoder(object):
    """ Variation Autoencoder (VAE) with an sklearn-like interface implemented using TensorFlow.
    
    This implementation uses probabilistic encoders and decoders using Gaussian 
    distributions and  realized by multi-layer perceptrons. The VAE can be learned
    end-to-end.
    
    See "Auto-Encoding Variational Bayes" by Kingma and Welling for more details.
    """
    def __init__(self, network_architecture, transfer_fct=tf.nn.softplus, 
                 learning_rate=0.001, batch_size=100):
        self.network_architecture = network_architecture
        self.transfer_fct = transfer_fct
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        
        # tf Graph input
        self.x = tf.placeholder(tf.float32, [None, network_architecture["n_input"]])
        
        # Create autoencoder network
        self._create_network()
        # Define the loss function based on the variational upper-bound and
        # the corresponding optimizer
        self._create_loss_optimizer()
        
        # Initialize the TensorFlow variables
        init = tf.initialize_all_variables()

        # Launch the session
        self.sess = tf.InteractiveSession()
        self.sess.run(init)
    
    def _create_network(self):
        # Initialize autoencoder network weights and biases
        network_weights = self._initialize_weights(**self.network_architecture)

        # Use recognition network to determine mean and 
        # (log) variance of Gaussian distribution in latent
        # space
        self.z_mean, self.z_log_sigma_sq = \
            self._recognition_network(network_weights["weights_recog"], 
                                      network_weights["biases_recog"])

        # Draw one sample z from Gaussian distribution
        n_z = self.network_architecture["n_z"]
        eps = tf.random_normal((self.batch_size, n_z), 0, 1, 
                               dtype=tf.float32)
        # z = mu + sigma*epsilon
        self.z = tf.add(self.z_mean, 
                        tf.mul(tf.sqrt(tf.exp(self.z_log_sigma_sq)), eps))

        # Use generator to determine mean of
        # Bernoulli distribution of reconstructed input
        self.x_reconstr_mean = \
            self._generator_network(network_weights["weights_gener"],
                                    network_weights["biases_gener"])
            
    def _initialize_weights(self, n_hidden_recog_1, n_hidden_recog_2, 
                            n_hidden_gener_1,  n_hidden_gener_2, 
                            n_input, n_z):
        all_weights = dict()
        all_weights['weights_recog'] = {
            'h1': tf.Variable(xavier_init(n_input, n_hidden_recog_1)),
            'h2': tf.Variable(xavier_init(n_hidden_recog_1, n_hidden_recog_2)),
            'out_mean': tf.Variable(xavier_init(n_hidden_recog_2, n_z)),
            'out_log_sigma': tf.Variable(xavier_init(n_hidden_recog_2, n_z))}
        all_weights['biases_recog'] = {
            'b1': tf.Variable(tf.zeros([n_hidden_recog_1], dtype=tf.float32)),
            'b2': tf.Variable(tf.zeros([n_hidden_recog_2], dtype=tf.float32)),
            'out_mean': tf.Variable(tf.zeros([n_z], dtype=tf.float32)),
            'out_log_sigma': tf.Variable(tf.zeros([n_z], dtype=tf.float32))}
        all_weights['weights_gener'] = {
            'h1': tf.Variable(xavier_init(n_z, n_hidden_gener_1)),
            'h2': tf.Variable(xavier_init(n_hidden_gener_1, n_hidden_gener_2)),
            'out_mean': tf.Variable(xavier_init(n_hidden_gener_2, n_input)),
            'out_log_sigma': tf.Variable(xavier_init(n_hidden_gener_2, n_input))}
        all_weights['biases_gener'] = {
            'b1': tf.Variable(tf.zeros([n_hidden_gener_1], dtype=tf.float32)),
            'b2': tf.Variable(tf.zeros([n_hidden_gener_2], dtype=tf.float32)),
            'out_mean': tf.Variable(tf.zeros([n_input], dtype=tf.float32)),
            'out_log_sigma': tf.Variable(tf.zeros([n_input], dtype=tf.float32))}
        return all_weights
            
    def _recognition_network(self, weights, biases):
        # Generate probabilistic encoder (recognition network), which
        # maps inputs onto a normal distribution in latent space.
        # The transformation is parametrized and can be learned.
        layer_1 = self.transfer_fct(tf.add(tf.matmul(self.x, weights['h1']), 
                                           biases['b1'])) 
        layer_2 = self.transfer_fct(tf.add(tf.matmul(layer_1, weights['h2']), 
                                           biases['b2'])) 
        z_mean = tf.add(tf.matmul(layer_2, weights['out_mean']),
                        biases['out_mean'])
        z_log_sigma_sq = \
            tf.add(tf.matmul(layer_2, weights['out_log_sigma']), 
                   biases['out_log_sigma'])
        return (z_mean, z_log_sigma_sq)

    def _generator_network(self, weights, biases):
        # Generate probabilistic decoder (decoder network), which
        # maps points in latent space onto a Bernoulli distribution in data space.
        # The transformation is parametrized and can be learned.
        layer_1 = self.transfer_fct(tf.add(tf.matmul(self.z, weights['h1']), 
                                           biases['b1'])) 
        layer_2 = self.transfer_fct(tf.add(tf.matmul(layer_1, weights['h2']), 
                                           biases['b2'])) 
        x_reconstr_mean = \
            tf.nn.sigmoid(tf.add(tf.matmul(layer_2, weights['out_mean']), 
                                 biases['out_mean']))
        return x_reconstr_mean
            
    def _create_loss_optimizer(self):
        # The loss is composed of two terms:
        # 1.) The reconstruction loss (the negative log probability
        #     of the input under the reconstructed Bernoulli distribution 
        #     induced by the decoder in the data space).
        #     This can be interpreted as the number of "nats" required
        #     for reconstructing the input when the activation in latent
        #     is given.
        # Adding 1e-10 to avoid evaluation of log(0.0)
        reconstr_loss = \
            -tf.reduce_sum(self.x * tf.log(1e-10 + self.x_reconstr_mean)
                           + (1-self.x) * tf.log(1e-10 + 1 - self.x_reconstr_mean),
                           1)
        # 2.) The latent loss, which is defined as the Kullback Leibler divergence 
        #     between the distribution in latent space induced by the encoder on 
        #     the data and some prior. This acts as a kind of regularizer.
        #     This can be interpreted as the number of "nats" required
        #     for transmitting the latent space distribution given
        #     the prior.
        latent_loss = -0.5 * tf.reduce_sum(1 + self.z_log_sigma_sq 
                                           - tf.square(self.z_mean) 
                                           - tf.exp(self.z_log_sigma_sq), 1)
        self.cost = tf.reduce_mean(reconstr_loss + latent_loss)   # average over batch
        # Use ADAM optimizer
        self.optimizer = \
            tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.cost)
        
    def partial_fit(self, X):
        """Train model based on mini-batch of input data.
        
        Return cost of mini-batch.
        """
        opt, cost = self.sess.run((self.optimizer, self.cost), 
                                  feed_dict={self.x: X})
        return cost
    
    def transform(self, X):
        """Transform data by mapping it into the latent space."""
        # Note: This maps to mean of distribution, we could alternatively
        # sample from Gaussian distribution
        return self.sess.run(self.z_mean, feed_dict={self.x: X})
    
    def generate(self, z_mu=None):
        """ Generate data by sampling from latent space.
        
        If z_mu is not None, data for this point in latent space is
        generated. Otherwise, z_mu is drawn from prior in latent 
        space.        
        """
        if z_mu is None:
            z_mu = np.random.normal(size=self.network_architecture["n_z"])
        # Note: This maps to mean of distribution, we could alternatively
        # sample from Gaussian distribution
        return self.sess.run(self.x_reconstr_mean, 
                             feed_dict={self.z: z_mu})
    
    def reconstruct(self, X):
        """ Use VAE to reconstruct given data. """
        return self.sess.run(self.x_reconstr_mean, 
                             feed_dict={self.x: X})

In general, implementing a VAE in TensorFlow is relatively straightforward (in particular since we do not need to code the gradient computation). What is potentially a bit confusing is that all the logic happens at initialization of the class (where the graph is generated), while the actual sklearn interface methods are very simple one-liners.

We can now define a simple function which trains the VAE using mini-batches:

In [10]:
def train(network_architecture, learning_rate=0.0001,
          batch_size=100, training_epochs=10, display_step=5):
    vae = VariationalAutoencoder(network_architecture, 
                                 learning_rate=learning_rate, 
                                 batch_size=batch_size)
    # Training cycle
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(n_samples / batch_size)
        # Loop over all batches
        for i in range(total_batch):
            batch_xs, _ = get_next_batch(batch_size, classes, data)

            # Fit training using batch data
            cost = vae.partial_fit(batch_xs)
            # Compute average loss
            avg_cost += cost / n_samples * batch_size

        # Display logs per epoch step
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch+1), 
                  "cost=", "{:.9f}".format(avg_cost))
    return vae

Illustrating reconstruction quality

We can now train a VAE on the cell images by just specifying the network topology. We start with training a VAE with a 20-dimensional latent space.

In [11]:
network_architecture = \
    dict(n_hidden_recog_1=1000, # 1st layer encoder neurons
         n_hidden_recog_2=1000, # 2nd layer encoder neurons
         n_hidden_gener_1=1000, # 1st layer decoder neurons
         n_hidden_gener_2=1000, # 2nd layer decoder neurons
         n_input=28*28*3, # cell image input (img shape: 28*28*3)
         n_z=20)  # dimensionality of latent space

vae = train(network_architecture, training_epochs=2500)
('Epoch:', '0001', 'cost=', '1503.554197799')
('Epoch:', '0006', 'cost=', '1175.129356680')
('Epoch:', '0011', 'cost=', '1142.124820680')
('Epoch:', '0016', 'cost=', '1117.732226941')
('Epoch:', '0021', 'cost=', '1123.777427969')
('Epoch:', '0026', 'cost=', '1102.219869924')
('Epoch:', '0031', 'cost=', '1109.888665251')
('Epoch:', '0036', 'cost=', '1120.647270735')
('Epoch:', '0041', 'cost=', '1117.311640303')
('Epoch:', '0046', 'cost=', '1099.075967951')
('Epoch:', '0051', 'cost=', '1098.661792371')
('Epoch:', '0056', 'cost=', '1112.151922182')
('Epoch:', '0061', 'cost=', '1113.598088701')
('Epoch:', '0066', 'cost=', '1115.497500397')
('Epoch:', '0071', 'cost=', '1100.419486586')
('Epoch:', '0076', 'cost=', '1090.798157685')
('Epoch:', '0081', 'cost=', '1096.583320559')
('Epoch:', '0086', 'cost=', '1084.205154301')
('Epoch:', '0091', 'cost=', '1088.264063162')
('Epoch:', '0096', 'cost=', '1085.455263123')
('Epoch:', '0101', 'cost=', '1106.583303999')
('Epoch:', '0106', 'cost=', '1079.089402783')
('Epoch:', '0111', 'cost=', '1084.171667764')
('Epoch:', '0116', 'cost=', '1095.720169895')
('Epoch:', '0121', 'cost=', '1092.160377207')
('Epoch:', '0126', 'cost=', '1085.466310959')
('Epoch:', '0131', 'cost=', '1073.384259838')
('Epoch:', '0136', 'cost=', '1069.471184782')
('Epoch:', '0141', 'cost=', '1090.811157227')
('Epoch:', '0146', 'cost=', '1077.252374693')
('Epoch:', '0151', 'cost=', '1077.455009046')
('Epoch:', '0156', 'cost=', '1066.611770512')
('Epoch:', '0161', 'cost=', '1098.691422810')
('Epoch:', '0166', 'cost=', '1060.950192001')
('Epoch:', '0171', 'cost=', '1071.339428332')
('Epoch:', '0176', 'cost=', '1051.293016774')
('Epoch:', '0181', 'cost=', '1062.926358955')
('Epoch:', '0186', 'cost=', '1076.825932939')
('Epoch:', '0191', 'cost=', '1061.015118739')
('Epoch:', '0196', 'cost=', '1065.352139732')
('Epoch:', '0201', 'cost=', '1082.817243236')
('Epoch:', '0206', 'cost=', '1065.598019149')
('Epoch:', '0211', 'cost=', '1070.870226483')
('Epoch:', '0216', 'cost=', '1062.725404251')
('Epoch:', '0221', 'cost=', '1061.772001991')
('Epoch:', '0226', 'cost=', '1061.619000102')
('Epoch:', '0231', 'cost=', '1079.729857925')
('Epoch:', '0236', 'cost=', '1072.764362660')
('Epoch:', '0241', 'cost=', '1065.612035944')
('Epoch:', '0246', 'cost=', '1071.844849106')
('Epoch:', '0251', 'cost=', '1051.146100658')
('Epoch:', '0256', 'cost=', '1068.212074457')
('Epoch:', '0261', 'cost=', '1066.001501749')
('Epoch:', '0266', 'cost=', '1057.091232418')
('Epoch:', '0271', 'cost=', '1078.148065611')
('Epoch:', '0276', 'cost=', '1049.921482293')
('Epoch:', '0281', 'cost=', '1056.803219817')
('Epoch:', '0286', 'cost=', '1065.449535754')
('Epoch:', '0291', 'cost=', '1041.735313475')
('Epoch:', '0296', 'cost=', '1065.079894546')
('Epoch:', '0301', 'cost=', '1078.156570316')
('Epoch:', '0306', 'cost=', '1062.148692996')
('Epoch:', '0311', 'cost=', '1071.216518195')
('Epoch:', '0316', 'cost=', '1068.040436737')
('Epoch:', '0321', 'cost=', '1048.305174362')
('Epoch:', '0326', 'cost=', '1067.764258570')
('Epoch:', '0331', 'cost=', '1051.598186641')
('Epoch:', '0336', 'cost=', '1058.553816921')
('Epoch:', '0341', 'cost=', '1062.394300357')
('Epoch:', '0346', 'cost=', '1060.058274380')
('Epoch:', '0351', 'cost=', '1078.140448045')
('Epoch:', '0356', 'cost=', '1066.755451528')
('Epoch:', '0361', 'cost=', '1069.124869413')
('Epoch:', '0366', 'cost=', '1048.890561836')
('Epoch:', '0371', 'cost=', '1054.650748793')
('Epoch:', '0376', 'cost=', '1074.355073677')
('Epoch:', '0381', 'cost=', '1059.240864598')
('Epoch:', '0386', 'cost=', '1057.436116536')
('Epoch:', '0391', 'cost=', '1056.364819431')
('Epoch:', '0396', 'cost=', '1060.746978050')
('Epoch:', '0401', 'cost=', '1061.035972418')
('Epoch:', '0406', 'cost=', '1046.577702012')
('Epoch:', '0411', 'cost=', '1052.258513695')
('Epoch:', '0416', 'cost=', '1060.163347111')
('Epoch:', '0421', 'cost=', '1057.843017578')
('Epoch:', '0426', 'cost=', '1066.255874042')
('Epoch:', '0431', 'cost=', '1058.433378944')
('Epoch:', '0436', 'cost=', '1072.903797238')
('Epoch:', '0441', 'cost=', '1070.758494296')
('Epoch:', '0446', 'cost=', '1042.763578430')
('Epoch:', '0451', 'cost=', '1074.257961539')
('Epoch:', '0456', 'cost=', '1067.147779834')
('Epoch:', '0461', 'cost=', '1072.042882165')
('Epoch:', '0466', 'cost=', '1057.442574908')
('Epoch:', '0471', 'cost=', '1057.722236574')
('Epoch:', '0476', 'cost=', '1066.477055513')
('Epoch:', '0481', 'cost=', '1059.133757362')
('Epoch:', '0486', 'cost=', '1057.106917034')
('Epoch:', '0491', 'cost=', '1056.071507654')
('Epoch:', '0496', 'cost=', '1061.782351945')
('Epoch:', '0501', 'cost=', '1063.693935187')
('Epoch:', '0506', 'cost=', '1045.273962686')
('Epoch:', '0511', 'cost=', '1069.033423135')
('Epoch:', '0516', 'cost=', '1068.126790838')
('Epoch:', '0521', 'cost=', '1051.295565820')
('Epoch:', '0526', 'cost=', '1060.181811429')
('Epoch:', '0531', 'cost=', '1035.211347240')
('Epoch:', '0536', 'cost=', '1046.359028188')
('Epoch:', '0541', 'cost=', '1055.279493702')
('Epoch:', '0546', 'cost=', '1070.519345306')
('Epoch:', '0551', 'cost=', '1049.731953939')
('Epoch:', '0556', 'cost=', '1063.511409316')
('Epoch:', '0561', 'cost=', '1042.310055282')
('Epoch:', '0566', 'cost=', '1052.045872415')
('Epoch:', '0571', 'cost=', '1054.840892230')
('Epoch:', '0576', 'cost=', '1035.965024963')
('Epoch:', '0581', 'cost=', '1052.356938798')
('Epoch:', '0586', 'cost=', '1048.199498376')
('Epoch:', '0591', 'cost=', '1058.892514724')
('Epoch:', '0596', 'cost=', '1060.809763827')
('Epoch:', '0601', 'cost=', '1044.971921462')
('Epoch:', '0606', 'cost=', '1062.086155618')
('Epoch:', '0611', 'cost=', '1048.526131090')
('Epoch:', '0616', 'cost=', '1051.192953420')
('Epoch:', '0621', 'cost=', '1061.876246726')
('Epoch:', '0626', 'cost=', '1051.335936554')
('Epoch:', '0631', 'cost=', '1047.706805089')
('Epoch:', '0636', 'cost=', '1059.891030955')
('Epoch:', '0641', 'cost=', '1053.166553586')
('Epoch:', '0646', 'cost=', '1057.221801521')
('Epoch:', '0651', 'cost=', '1042.698503834')
('Epoch:', '0656', 'cost=', '1064.478650943')
('Epoch:', '0661', 'cost=', '1041.213504289')
('Epoch:', '0666', 'cost=', '1060.324262279')
('Epoch:', '0671', 'cost=', '1035.933561103')
('Epoch:', '0676', 'cost=', '1057.921487971')
('Epoch:', '0681', 'cost=', '1050.480113658')
('Epoch:', '0686', 'cost=', '1055.289890969')
('Epoch:', '0691', 'cost=', '1048.493141352')
('Epoch:', '0696', 'cost=', '1061.470368851')
('Epoch:', '0701', 'cost=', '1051.261091602')
('Epoch:', '0706', 'cost=', '1045.854861237')
('Epoch:', '0711', 'cost=', '1048.757786714')
('Epoch:', '0716', 'cost=', '1051.508538298')
('Epoch:', '0721', 'cost=', '1050.878161053')
('Epoch:', '0726', 'cost=', '1048.395284017')
('Epoch:', '0731', 'cost=', '1042.873849795')
('Epoch:', '0736', 'cost=', '1036.434025728')
('Epoch:', '0741', 'cost=', '1067.051578492')
('Epoch:', '0746', 'cost=', '1051.790908695')
('Epoch:', '0751', 'cost=', '1058.927716396')
('Epoch:', '0756', 'cost=', '1058.698503361')
('Epoch:', '0761', 'cost=', '1055.173368232')
('Epoch:', '0766', 'cost=', '1049.037986578')
('Epoch:', '0771', 'cost=', '1042.386420198')
('Epoch:', '0776', 'cost=', '1052.964439688')
('Epoch:', '0781', 'cost=', '1051.007352134')
('Epoch:', '0786', 'cost=', '1054.580357278')
('Epoch:', '0791', 'cost=', '1050.721752551')
('Epoch:', '0796', 'cost=', '1052.144610974')
('Epoch:', '0801', 'cost=', '1064.326394251')
('Epoch:', '0806', 'cost=', '1045.380224184')
('Epoch:', '0811', 'cost=', '1043.466210180')
('Epoch:', '0816', 'cost=', '1037.712824622')
('Epoch:', '0821', 'cost=', '1055.726185880')
('Epoch:', '0826', 'cost=', '1040.517809964')
('Epoch:', '0831', 'cost=', '1031.067551014')
('Epoch:', '0836', 'cost=', '1058.457296209')
('Epoch:', '0841', 'cost=', '1043.528794873')
('Epoch:', '0846', 'cost=', '1060.347162291')
('Epoch:', '0851', 'cost=', '1047.793200589')
('Epoch:', '0856', 'cost=', '1045.012398653')
('Epoch:', '0861', 'cost=', '1060.609353236')
('Epoch:', '0866', 'cost=', '1054.033524861')
('Epoch:', '0871', 'cost=', '1051.954260538')
('Epoch:', '0876', 'cost=', '1046.859989610')
('Epoch:', '0881', 'cost=', '1048.421590642')
('Epoch:', '0886', 'cost=', '1043.957129190')
('Epoch:', '0891', 'cost=', '1033.928355136')
('Epoch:', '0896', 'cost=', '1053.439603111')
('Epoch:', '0901', 'cost=', '1046.105383348')
('Epoch:', '0906', 'cost=', '1048.865479462')
('Epoch:', '0911', 'cost=', '1057.642453216')
('Epoch:', '0916', 'cost=', '1039.487605132')
('Epoch:', '0921', 'cost=', '1040.378304415')
('Epoch:', '0926', 'cost=', '1060.862424392')
('Epoch:', '0931', 'cost=', '1030.585485651')
('Epoch:', '0936', 'cost=', '1053.744849863')
('Epoch:', '0941', 'cost=', '1042.829350908')
('Epoch:', '0946', 'cost=', '1049.746917015')
('Epoch:', '0951', 'cost=', '1042.572630653')
('Epoch:', '0956', 'cost=', '1047.338867188')
('Epoch:', '0961', 'cost=', '1038.588957084')
('Epoch:', '0966', 'cost=', '1037.286294153')
('Epoch:', '0971', 'cost=', '1049.652525436')
('Epoch:', '0976', 'cost=', '1044.007341252')
('Epoch:', '0981', 'cost=', '1043.277001196')
('Epoch:', '0986', 'cost=', '1067.717701520')
('Epoch:', '0991', 'cost=', '1054.190560274')
('Epoch:', '0996', 'cost=', '1040.707172719')
('Epoch:', '1001', 'cost=', '1057.008988728')
('Epoch:', '1006', 'cost=', '1041.497767249')
('Epoch:', '1011', 'cost=', '1061.183343932')
('Epoch:', '1016', 'cost=', '1047.555967819')
('Epoch:', '1021', 'cost=', '1038.022453840')
('Epoch:', '1026', 'cost=', '1049.705398914')
('Epoch:', '1031', 'cost=', '1063.362287181')
('Epoch:', '1036', 'cost=', '1053.188140633')
('Epoch:', '1041', 'cost=', '1049.107603324')
('Epoch:', '1046', 'cost=', '1044.478175437')
('Epoch:', '1051', 'cost=', '1038.982485986')
('Epoch:', '1056', 'cost=', '1051.774100370')
('Epoch:', '1061', 'cost=', '1047.102947383')
('Epoch:', '1066', 'cost=', '1034.228148941')
('Epoch:', '1071', 'cost=', '1048.204998637')
('Epoch:', '1076', 'cost=', '1054.204860953')
('Epoch:', '1081', 'cost=', '1053.489649573')
('Epoch:', '1086', 'cost=', '1040.908162908')
('Epoch:', '1091', 'cost=', '1037.230374832')
('Epoch:', '1096', 'cost=', '1058.770610011')
('Epoch:', '1101', 'cost=', '1048.532648604')
('Epoch:', '1106', 'cost=', '1043.634742914')
('Epoch:', '1111', 'cost=', '1036.068725586')
('Epoch:', '1116', 'cost=', '1052.069422995')
('Epoch:', '1121', 'cost=', '1057.181507672')
('Epoch:', '1126', 'cost=', '1044.902748285')
('Epoch:', '1131', 'cost=', '1043.998079522')
('Epoch:', '1136', 'cost=', '1062.292303041')
('Epoch:', '1141', 'cost=', '1045.957544608')
('Epoch:', '1146', 'cost=', '1037.250773112')
('Epoch:', '1151', 'cost=', '1042.835253338')
('Epoch:', '1156', 'cost=', '1018.554693414')
('Epoch:', '1161', 'cost=', '1019.031938656')
('Epoch:', '1166', 'cost=', '1035.155008006')
('Epoch:', '1171', 'cost=', '1047.987448522')
('Epoch:', '1176', 'cost=', '1048.250810490')
('Epoch:', '1181', 'cost=', '1047.141662125')
('Epoch:', '1186', 'cost=', '1054.426965048')
('Epoch:', '1191', 'cost=', '1043.819232320')
('Epoch:', '1196', 'cost=', '1043.731310941')
('Epoch:', '1201', 'cost=', '1045.924702726')
('Epoch:', '1206', 'cost=', '1056.874900640')
('Epoch:', '1211', 'cost=', '1041.803623170')
('Epoch:', '1216', 'cost=', '1051.539197818')
('Epoch:', '1221', 'cost=', '1040.222020112')
('Epoch:', '1226', 'cost=', '1041.522973822')
('Epoch:', '1231', 'cost=', '1034.828268835')
('Epoch:', '1236', 'cost=', '1037.677912749')
('Epoch:', '1241', 'cost=', '1027.086113035')
('Epoch:', '1246', 'cost=', '1026.662711210')
('Epoch:', '1251', 'cost=', '1043.274115038')
('Epoch:', '1256', 'cost=', '1048.729132127')
('Epoch:', '1261', 'cost=', '1032.644700634')
('Epoch:', '1266', 'cost=', '1047.126699048')
('Epoch:', '1271', 'cost=', '1020.661334844')
('Epoch:', '1276', 'cost=', '1038.917435047')
('Epoch:', '1281', 'cost=', '1046.598248149')
('Epoch:', '1286', 'cost=', '1060.174406776')
('Epoch:', '1291', 'cost=', '1035.082167988')
('Epoch:', '1296', 'cost=', '1056.046845192')
('Epoch:', '1301', 'cost=', '1020.681567525')
('Epoch:', '1306', 'cost=', '1042.292152819')
('Epoch:', '1311', 'cost=', '1044.285358754')
('Epoch:', '1316', 'cost=', '1031.687394963')
('Epoch:', '1321', 'cost=', '1045.063344083')
('Epoch:', '1326', 'cost=', '1042.799690897')
('Epoch:', '1331', 'cost=', '1036.551281833')
('Epoch:', '1336', 'cost=', '1053.853956119')
('Epoch:', '1341', 'cost=', '1037.114354806')
('Epoch:', '1346', 'cost=', '1028.927594562')
('Epoch:', '1351', 'cost=', '1049.003074705')
('Epoch:', '1356', 'cost=', '1048.313442496')
('Epoch:', '1361', 'cost=', '1037.364835517')
('Epoch:', '1366', 'cost=', '1036.550820521')
('Epoch:', '1371', 'cost=', '1034.191995074')
('Epoch:', '1376', 'cost=', '1019.914032692')
('Epoch:', '1381', 'cost=', '1052.661984466')
('Epoch:', '1386', 'cost=', '1045.655639412')
('Epoch:', '1391', 'cost=', '1049.142550683')
('Epoch:', '1396', 'cost=', '1049.042883585')
('Epoch:', '1401', 'cost=', '1061.294780406')
('Epoch:', '1406', 'cost=', '1036.100946471')
('Epoch:', '1411', 'cost=', '1038.826390939')
('Epoch:', '1416', 'cost=', '1040.325691164')
('Epoch:', '1421', 'cost=', '1049.721296443')
('Epoch:', '1426', 'cost=', '1024.913693214')
('Epoch:', '1431', 'cost=', '1049.694102679')
('Epoch:', '1436', 'cost=', '1035.238807146')
('Epoch:', '1441', 'cost=', '1038.063108459')
('Epoch:', '1446', 'cost=', '1037.799249693')
('Epoch:', '1451', 'cost=', '1028.679781182')
('Epoch:', '1456', 'cost=', '1047.029255712')
('Epoch:', '1461', 'cost=', '1042.586954989')
('Epoch:', '1466', 'cost=', '1040.352867186')
('Epoch:', '1471', 'cost=', '1031.121873486')
('Epoch:', '1476', 'cost=', '1027.521059495')
('Epoch:', '1481', 'cost=', '1040.707929744')
('Epoch:', '1486', 'cost=', '1039.413866147')
('Epoch:', '1491', 'cost=', '1054.486852838')
('Epoch:', '1496', 'cost=', '1043.718074828')
('Epoch:', '1501', 'cost=', '1038.290263331')
('Epoch:', '1506', 'cost=', '1032.630269842')
('Epoch:', '1511', 'cost=', '1038.947597770')
('Epoch:', '1516', 'cost=', '1044.973530141')
('Epoch:', '1521', 'cost=', '1039.704824048')
('Epoch:', '1526', 'cost=', '1045.937566240')
('Epoch:', '1531', 'cost=', '1041.502155629')
('Epoch:', '1536', 'cost=', '1035.960187099')
('Epoch:', '1541', 'cost=', '1046.256498588')
('Epoch:', '1546', 'cost=', '1028.541074058')
('Epoch:', '1551', 'cost=', '1040.911244237')
('Epoch:', '1556', 'cost=', '1029.399120715')
('Epoch:', '1561', 'cost=', '1027.928185278')
('Epoch:', '1566', 'cost=', '1035.910844433')
('Epoch:', '1571', 'cost=', '1046.104028983')
('Epoch:', '1576', 'cost=', '1051.756061879')
('Epoch:', '1581', 'cost=', '1038.623023218')
('Epoch:', '1586', 'cost=', '1046.170067602')
('Epoch:', '1591', 'cost=', '1039.433329974')
('Epoch:', '1596', 'cost=', '1025.646990399')
('Epoch:', '1601', 'cost=', '1038.134233342')
('Epoch:', '1606', 'cost=', '1051.823165066')
('Epoch:', '1611', 'cost=', '1036.276316088')
('Epoch:', '1616', 'cost=', '1043.750484969')
('Epoch:', '1621', 'cost=', '1039.709922140')
('Epoch:', '1626', 'cost=', '1018.524802748')
('Epoch:', '1631', 'cost=', '1038.463675329')
('Epoch:', '1636', 'cost=', '1034.947760530')
('Epoch:', '1641', 'cost=', '1034.831568991')
('Epoch:', '1646', 'cost=', '1043.784491960')
('Epoch:', '1651', 'cost=', '1042.380955423')
('Epoch:', '1656', 'cost=', '1061.785403703')
('Epoch:', '1661', 'cost=', '1029.997005019')
('Epoch:', '1666', 'cost=', '1050.857378346')
('Epoch:', '1671', 'cost=', '1018.957413075')
('Epoch:', '1676', 'cost=', '1024.952342898')
('Epoch:', '1681', 'cost=', '1020.705792331')
('Epoch:', '1686', 'cost=', '1048.640642979')
('Epoch:', '1691', 'cost=', '1042.326236695')
('Epoch:', '1696', 'cost=', '1026.072042303')
('Epoch:', '1701', 'cost=', '1026.625297606')
('Epoch:', '1706', 'cost=', '1037.645141838')
('Epoch:', '1711', 'cost=', '1032.026826134')
('Epoch:', '1716', 'cost=', '1041.557714181')
('Epoch:', '1721', 'cost=', '1026.417671618')
('Epoch:', '1726', 'cost=', '1039.911320413')
('Epoch:', '1731', 'cost=', '1026.761668597')
('Epoch:', '1736', 'cost=', '1037.791839126')
('Epoch:', '1741', 'cost=', '1021.622242299')
('Epoch:', '1746', 'cost=', '1021.618841600')
('Epoch:', '1751', 'cost=', '1028.728970077')
('Epoch:', '1756', 'cost=', '1040.896062333')
('Epoch:', '1761', 'cost=', '1041.742599842')
('Epoch:', '1766', 'cost=', '1038.169435013')
('Epoch:', '1771', 'cost=', '1032.077221538')
('Epoch:', '1776', 'cost=', '1041.642181633')
('Epoch:', '1781', 'cost=', '1031.138060814')
('Epoch:', '1786', 'cost=', '1031.380291133')
('Epoch:', '1791', 'cost=', '1008.324243117')
('Epoch:', '1796', 'cost=', '1035.597666659')
('Epoch:', '1801', 'cost=', '1042.178794210')
('Epoch:', '1806', 'cost=', '1032.130763327')
('Epoch:', '1811', 'cost=', '1043.085946963')
('Epoch:', '1816', 'cost=', '1024.037531180')
('Epoch:', '1821', 'cost=', '1038.433897033')
('Epoch:', '1826', 'cost=', '1032.008349988')
('Epoch:', '1831', 'cost=', '1035.705998147')
('Epoch:', '1836', 'cost=', '1009.132219655')
('Epoch:', '1841', 'cost=', '1038.926010723')
('Epoch:', '1846', 'cost=', '1041.334835319')
('Epoch:', '1851', 'cost=', '1024.753221615')
('Epoch:', '1856', 'cost=', '1030.102326149')
('Epoch:', '1861', 'cost=', '1026.419658809')
('Epoch:', '1866', 'cost=', '1025.312083636')
('Epoch:', '1871', 'cost=', '1039.843631715')
('Epoch:', '1876', 'cost=', '1025.228562466')
('Epoch:', '1881', 'cost=', '1032.517526316')
('Epoch:', '1886', 'cost=', '1034.229556535')
('Epoch:', '1891', 'cost=', '1035.945803620')
('Epoch:', '1896', 'cost=', '1025.264408792')
('Epoch:', '1901', 'cost=', '1034.710918101')
('Epoch:', '1906', 'cost=', '1040.455095158')
('Epoch:', '1911', 'cost=', '1026.159904539')
('Epoch:', '1916', 'cost=', '1037.678539661')
('Epoch:', '1921', 'cost=', '1042.327981402')
('Epoch:', '1926', 'cost=', '1034.020398754')
('Epoch:', '1931', 'cost=', '1044.181971587')
('Epoch:', '1936', 'cost=', '1017.152168215')
('Epoch:', '1941', 'cost=', '1031.705794224')
('Epoch:', '1946', 'cost=', '1026.731505875')
('Epoch:', '1951', 'cost=', '1022.419691455')
('Epoch:', '1956', 'cost=', '1038.056094147')
('Epoch:', '1961', 'cost=', '1028.498852161')
('Epoch:', '1966', 'cost=', '1028.946573420')
('Epoch:', '1971', 'cost=', '1028.631296084')
('Epoch:', '1976', 'cost=', '1028.664841763')
('Epoch:', '1981', 'cost=', '1029.726853482')
('Epoch:', '1986', 'cost=', '1037.415177693')
('Epoch:', '1991', 'cost=', '1031.792964492')
('Epoch:', '1996', 'cost=', '1018.542515954')
('Epoch:', '2001', 'cost=', '1031.625259754')
('Epoch:', '2006', 'cost=', '1037.571929222')
('Epoch:', '2011', 'cost=', '1038.733288669')
('Epoch:', '2016', 'cost=', '1030.163550562')
('Epoch:', '2021', 'cost=', '1025.200144450')
('Epoch:', '2026', 'cost=', '1030.501006370')
('Epoch:', '2031', 'cost=', '1030.896339121')
('Epoch:', '2036', 'cost=', '1048.426020423')
('Epoch:', '2041', 'cost=', '1038.044750598')
('Epoch:', '2046', 'cost=', '1024.737998312')
('Epoch:', '2051', 'cost=', '1033.817415459')
('Epoch:', '2056', 'cost=', '1021.561721683')
('Epoch:', '2061', 'cost=', '1036.822840964')
('Epoch:', '2066', 'cost=', '1022.192305927')
('Epoch:', '2071', 'cost=', '1016.722195647')
('Epoch:', '2076', 'cost=', '1043.891415855')
('Epoch:', '2081', 'cost=', '1032.135772705')
('Epoch:', '2086', 'cost=', '1041.342760426')
('Epoch:', '2091', 'cost=', '1026.039780018')
('Epoch:', '2096', 'cost=', '1036.427330786')
('Epoch:', '2101', 'cost=', '1036.109368376')
('Epoch:', '2106', 'cost=', '1036.841595080')
('Epoch:', '2111', 'cost=', '1040.192088046')
('Epoch:', '2116', 'cost=', '1021.616990437')
('Epoch:', '2121', 'cost=', '1029.659123384')
('Epoch:', '2126', 'cost=', '1031.917642993')
('Epoch:', '2131', 'cost=', '1033.590248758')
('Epoch:', '2136', 'cost=', '1031.172759773')
('Epoch:', '2141', 'cost=', '1024.504491525')
('Epoch:', '2146', 'cost=', '1037.033784852')
('Epoch:', '2151', 'cost=', '1024.925196448')
('Epoch:', '2156', 'cost=', '1036.605338163')
('Epoch:', '2161', 'cost=', '1033.936315729')
('Epoch:', '2166', 'cost=', '1024.377583348')
('Epoch:', '2171', 'cost=', '1042.919212164')
('Epoch:', '2176', 'cost=', '1044.907550664')
('Epoch:', '2181', 'cost=', '1042.324444675')
('Epoch:', '2186', 'cost=', '1034.255259906')
('Epoch:', '2191', 'cost=', '1047.244818636')
('Epoch:', '2196', 'cost=', '1023.406515195')
('Epoch:', '2201', 'cost=', '1016.731930518')
('Epoch:', '2206', 'cost=', '1034.564818153')
('Epoch:', '2211', 'cost=', '1032.920263719')
('Epoch:', '2216', 'cost=', '1037.974394569')
('Epoch:', '2221', 'cost=', '1033.916372846')
('Epoch:', '2226', 'cost=', '1015.605755000')
('Epoch:', '2231', 'cost=', '1031.039416882')
('Epoch:', '2236', 'cost=', '1031.899385674')
('Epoch:', '2241', 'cost=', '1034.492191049')
('Epoch:', '2246', 'cost=', '1033.057030966')
('Epoch:', '2251', 'cost=', '1029.066632884')
('Epoch:', '2256', 'cost=', '1032.985704999')
('Epoch:', '2261', 'cost=', '1017.638296674')
('Epoch:', '2266', 'cost=', '1029.977109451')
('Epoch:', '2271', 'cost=', '1042.027170344')
('Epoch:', '2276', 'cost=', '1018.648286568')
('Epoch:', '2281', 'cost=', '1012.665185263')
('Epoch:', '2286', 'cost=', '1032.815327016')
('Epoch:', '2291', 'cost=', '1048.242660641')
('Epoch:', '2296', 'cost=', '1031.403587400')
('Epoch:', '2301', 'cost=', '1031.696792721')
('Epoch:', '2306', 'cost=', '1018.553770790')
('Epoch:', '2311', 'cost=', '1029.610892599')
('Epoch:', '2316', 'cost=', '1029.773676673')
('Epoch:', '2321', 'cost=', '1027.237387960')
('Epoch:', '2326', 'cost=', '1039.073808064')
('Epoch:', '2331', 'cost=', '1024.211954516')
('Epoch:', '2336', 'cost=', '1017.689289418')
('Epoch:', '2341', 'cost=', '1027.440140599')
('Epoch:', '2346', 'cost=', '1033.969045240')
('Epoch:', '2351', 'cost=', '1032.294192055')
('Epoch:', '2356', 'cost=', '1041.503882593')
('Epoch:', '2361', 'cost=', '1040.223658362')
('Epoch:', '2366', 'cost=', '1010.803873225')
('Epoch:', '2371', 'cost=', '1035.523643789')
('Epoch:', '2376', 'cost=', '1026.667300675')
('Epoch:', '2381', 'cost=', '1036.832002152')
('Epoch:', '2386', 'cost=', '1043.892285251')
('Epoch:', '2391', 'cost=', '1035.848100056')
('Epoch:', '2396', 'cost=', '1017.077411977')
('Epoch:', '2401', 'cost=', '1024.107461382')
('Epoch:', '2406', 'cost=', '1033.671557996')
('Epoch:', '2411', 'cost=', '1030.301607117')
('Epoch:', '2416', 'cost=', '1008.808011787')
('Epoch:', '2421', 'cost=', '1028.895332277')
('Epoch:', '2426', 'cost=', '1026.565137760')
('Epoch:', '2431', 'cost=', '1025.939823121')
('Epoch:', '2436', 'cost=', '1017.291407622')
('Epoch:', '2441', 'cost=', '1037.182309646')
('Epoch:', '2446', 'cost=', '1031.215513954')
('Epoch:', '2451', 'cost=', '1037.840093568')
('Epoch:', '2456', 'cost=', '1033.903769560')
('Epoch:', '2461', 'cost=', '1035.031754841')
('Epoch:', '2466', 'cost=', '1030.555068615')
('Epoch:', '2471', 'cost=', '1026.557561593')
('Epoch:', '2476', 'cost=', '1014.456720870')
('Epoch:', '2481', 'cost=', '1016.269512324')
('Epoch:', '2486', 'cost=', '1045.191257684')
('Epoch:', '2491', 'cost=', '1021.212821783')
('Epoch:', '2496', 'cost=', '1024.323935102')

Based on this we can sample some test inputs and visualize how well the VAE can reconstruct them. In general the VAE does really well.

Let's grab some of the class0 images and show them alongside their reconstructions.

In [12]:
x_sample = get_next_batch(100, c0class, c0data)[0]
x_reconstruct = vae.reconstruct(x_sample)
In [17]:
plt.figure(figsize=(8, 12))
for i in range(5):
    plt.subplot(5, 2, 2*i + 1)
    plt.imshow(x_sample[i+20].reshape(28, 28,3))#, vmin=0, vmax=1, cmap="gray")
    plt.title("Test input")
    plt.colorbar()
    plt.subplot(5, 2, 2*i + 2)
    plt.imshow(x_reconstruct[i+20].reshape(28, 28,3))#, vmin=0, vmax=1, cmap="gray")
    plt.title("Reconstruction")
    plt.colorbar()
plt.tight_layout()

Now show a class1 example: first the reconstruction, then the original.

In [18]:
c1_recon = vae.reconstruct(c1data[:100])
In [22]:
plt.imshow(c1_recon[10].reshape(28, 28,3))
Out[22]:
<matplotlib.image.AxesImage at 0x7fca4062cc10>
In [21]:
plt.imshow(c1data[10].reshape(28, 28,3))
Out[21]:
<matplotlib.image.AxesImage at 0x7fca40590510>
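
Finally, since the article linked above is about manifold learning, a natural last step (a sketch, not in the original notebook) is to map each class into the latent space with transform and scatter-plot the first two of the 20 latent coordinates. This gives only a rough, partial view of class separation, since 18 dimensions are not shown.

In [ ]:
z0 = vae.transform(c0data[:100])
z1 = vae.transform(c1data[:100])
z2 = vae.transform(c2data[:100])
plt.figure(figsize=(6, 6))
for z, name in [(z0, "class0"), (z1, "class1"), (z2, "class2")]:
    plt.scatter(z[:, 0], z[:, 1], label=name, alpha=0.6)
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.legend()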