Variational Autoencoder in TensorFlow

From Jan Hendrik Metzen, 2017-01-03, with small mods by d. gannon

From Metzen's intro:

The main motivation for this post was that I wanted to get more experience with both Variational Autoencoders (VAEs) and with Tensorflow. Thus, implementing the former in the latter sounded like a good idea for learning about both at the same time. This post summarizes the result.

Note: The post was updated on December 7th 2015:

  • a bug in the computation of the latent_loss was fixed (removed an erroneous factor 2). Thanks Colin Fang for pointing this out.
  • Using a Bernoulli distribution rather than a Gaussian distribution in the generator network

Note: The post was updated on January 3rd 2017:

  • changes required for supporting TensorFlow v0.12 and Python 3 support

Notes on this version

From d. gannon: This version uses images of galaxies collected from the web. It is a very small database (fewer than 100 images), but the autoencoder has no problem with it. This is a version of the code that we modified from another project on autoencoders. The article that describes that work is https://cloud4scieng.org/manifold-learning-and-deep-autoencoders-in-science/

In [203]:
#!pip install opencv-python
In [1]:
import numpy as np
import tensorflow as tf
import cv2
import random

import matplotlib.pyplot as plt
%matplotlib inline

np.random.seed(0)
tf.set_random_seed(0)
totalimg = 90 #1032
n_samples = totalimg

Download the data

The link to download a tar'd, gzipped copy of the data is here: http://www.imagexd.org/assets/cells.tar.gz . The following function expects the path to the un-tar'd images and a file with the name of each image, one per line. You can create this file with the command

ls > classlist
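
If you prefer to do this step from Python rather than the shell, the cell below is a minimal sketch (not part of the original notebook). It assumes the archive unpacks into a directory named ./galaxies, which is the path used later in this notebook; adjust the names if your data lands somewhere else.

import os
import tarfile
import urllib.request

url = 'http://www.imagexd.org/assets/cells.tar.gz'
archive = 'cells.tar.gz'

# download the tarball once and un-tar it into the current directory
if not os.path.exists(archive):
    urllib.request.urlretrieve(url, archive)
with tarfile.open(archive, 'r:gz') as tar:
    tar.extractall('.')

# equivalent of `ls > classlist` for the image directory
image_dir = './galaxies'
with open('./classlist', 'w') as f:
    for name in sorted(os.listdir(image_dir)):
        f.write(name + '\n')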

In [3]:
def test(size, listfile, datafile):
    classes = np.array([0]*size)
    data = np.array([np.zeros(128*128*3, dtype=np.float32)]*size )
    print('data', data.shape)
    with open(listfile) as f:
        i = 0
        for line_of_text in f:
            # read the image named on this line and resize it to 128x128
            x = cv2.imread(datafile+"/"+line_of_text[:-1])
            x = cv2.resize(x, (128,128))
            # scale pixel values into [0, 1)
            x = x/256.0
            data[i] = x.reshape(128*128*3)
            # the class is encoded in the first letter of the file name:
            # 'e' = elliptical, 'b' = barred, anything else = spiral
            c = line_of_text[0]
            if c=='e':
                classes[i]= 0
            elif c == 'b':
                classes[i]= 1
            else:
                classes[i]= 2
            i = i+1
            if i == size:
                break
    return classes, data

The function test above loads the data, which contains 3 classes of galaxies: elliptical, barred, and spiral.

The extract_class function can be used to pull out the subset of images belonging to each galaxy class.

In [9]:
def extract_class(i, bigclasses,bigdata ):
    size = 0
    for x in bigclasses:
        if x==i:
            size = size+1
    print('size =', size)
    data = np.array([np.zeros(128*128*3, dtype=np.float32)]*size )
    j = 0
    for k in range(len(bigclasses)):
        if bigclasses[k] == i:
            data[j] = bigdata[k]
            j = j+1
    return data
In [64]:
classes, data = test(totalimg, './classlist', 
                            './galaxies')
c0data = extract_class(0, classes, data)
c1data = extract_class(1, classes, data)
c2data = extract_class(2, classes, data)
c0class = np.zeros(len(c0data), dtype=np.float32)
c1class = np.zeros(len(c1data), dtype=np.float32)+1
c2class = np.zeros(len(c2data), dtype=np.float32)+2
data (90, 49152)
size = 29
size = 30
size = 31
In [11]:
data[3]
Out[11]:
array([0.01953125, 0.0234375 , 0.015625  , ..., 0.015625  , 0.01171875,
       0.01953125], dtype=float32)
In [12]:
plt.imshow(data[55].reshape((128,128,3)))
Out[12]:
<matplotlib.image.AxesImage at 0x7fe38ec9e278>
In [13]:
def get_next_batch(size, classes, data):
    # return a random batch of images and their labels, sampled with replacement
    if size > totalimg:
        s = totalimg
    else:
        s = size
    c = np.array([0]*s)
    d = np.array([np.zeros(128*128*3, dtype=np.float32)]*s )
    for i in range(s):
        # randint's upper bound is exclusive, so this covers every index
        j = np.random.randint(0, len(classes))
        c[i] = classes[j]
        d[i] = data[j]
    return d, c


    
In [14]:
print('totalimg', totalimg)
totalimg 90
In [15]:
dd, cc = get_next_batch(totalimg, classes, data)
In [16]:
def xavier_init(fan_in, fan_out, constant=1): 
    """ Xavier initialization of network weights"""
    # https://stackoverflow.com/questions/33640581/how-to-do-xavier-initialization-on-tensorflow
    low = -constant*np.sqrt(6.0/(fan_in + fan_out)) 
    high = constant*np.sqrt(6.0/(fan_in + fan_out))
    return tf.random_uniform((fan_in, fan_out), 
                             minval=low, maxval=high, 
                             dtype=tf.float32)
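
For reference, with constant=1 the bound computed above is the standard Glorot/Xavier uniform initialization:

$$ W_{ij} \sim U\left[ -\sqrt{\frac{6}{fan_{in}+fan_{out}}},\ \sqrt{\frac{6}{fan_{in}+fan_{out}}} \right] $$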

Based on this, we now define a class "VariationalAutoencoder" with an sklearn-like interface that can be trained incrementally with mini-batches using partial_fit. The trained model can be used to reconstruct unseen input, to generate new samples, and to map inputs to the latent space.

$$ KL(P \| Q) = - \sum_x p(x)\log q(x) + \sum_x p(x)\log p(x) $$

The encoder, the reparameterization step, and the decoder are

$$ Z_{\mu} , Z_{\ln(\sigma^2)} = enc(X) $$ $$ \epsilon \sim N(0,1) $$ $$ Z = Z_{\mu} + \epsilon \sqrt{\exp(Z_{\ln(\sigma^2)})} $$ $$ X_{recon\mu} = dec(Z) $$

The loss function is the sum of the reconstruction loss and the latent (KL) loss:

$$ ReconLoss = -\sum{ \left(X\ln(X_{recon\mu}) + (1-X)\ln(1 - X_{recon\mu})\right)} $$ $$ LatentLoss = -\frac{1}{2}\sum{ \left(1+Z_{\ln(\sigma^2)} - {Z_{\mu}}^2 - \exp(Z_{\ln(\sigma^2)})\right) } $$
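
To make these formulas concrete before reading the TensorFlow class, here is a small NumPy sketch (not part of the original notebook) of the reparameterization step and the two loss terms for a single toy input; the variable names mirror the ones used in the class below.

import numpy as np

# toy dimensions: 6-dimensional input, 2-dimensional latent space
x = np.random.rand(6)                    # an "image" with values in [0,1]
z_mean = np.random.randn(2)              # encoder output: Z_mu
z_log_sigma_sq = np.random.randn(2)      # encoder output: Z_ln(sigma^2)

# reparameterization: z = mu + sigma * epsilon, with epsilon ~ N(0,1)
eps = np.random.randn(2)
z = z_mean + np.sqrt(np.exp(z_log_sigma_sq)) * eps

# pretend the decoder turned z into this Bernoulli mean
x_reconstr_mean = np.random.rand(6)

# reconstruction loss: negative Bernoulli log-likelihood (in nats)
reconstr_loss = -np.sum(x * np.log(1e-10 + x_reconstr_mean)
                        + (1 - x) * np.log(1e-10 + 1 - x_reconstr_mean))

# latent loss: KL divergence between N(mu, sigma^2) and the N(0,1) prior
latent_loss = -0.5 * np.sum(1 + z_log_sigma_sq
                            - np.square(z_mean)
                            - np.exp(z_log_sigma_sq))

cost = reconstr_loss + latent_loss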

In [17]:
class VariationalAutoencoder(object):
    """ Variation Autoencoder (VAE) with an sklearn-like interface implemented using TensorFlow.
    
    This implementation uses probabilistic encoders and decoders using Gaussian 
    distributions and  realized by multi-layer perceptrons. The VAE can be learned
    end-to-end.
    
    See "Auto-Encoding Variational Bayes" by Kingma and Welling for more details.
    """
    def __init__(self, network_architecture, transfer_fct=tf.nn.softplus, 
                 learning_rate=0.001, batch_size=20):
        self.network_architecture = network_architecture
        self.transfer_fct = transfer_fct
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        
        # tf Graph input
        self.x = tf.placeholder(tf.float32, [None, network_architecture["n_input"]])
        
        # Create autoencoder network
        self._create_network()
        # Define loss function based on the variational upper-bound and 
        # the corresponding optimizer
        self._create_loss_optimizer()
        
        # Initialize the TensorFlow variables
        init = tf.global_variables_initializer()

        # Launch the session
        self.sess = tf.InteractiveSession()
        self.sess.run(init)
    
    def _create_network(self):
        # Initialize autoencode network weights and biases
        network_weights = self._initialize_weights(**self.network_architecture)

        # Use recognition network to determine mean and 
        # (log) variance of Gaussian distribution in latent
        # space
        self.z_mean, self.z_log_sigma_sq = \
            self._recognition_network(network_weights["weights_recog"], 
                                      network_weights["biases_recog"])

        # Draw one sample z from Gaussian distribution
        n_z = self.network_architecture["n_z"]
        eps = tf.random_normal((self.batch_size, n_z), 0, 1, 
                               dtype=tf.float32)
        # z = mu + sigma*epsilon
        self.z = tf.add(self.z_mean, 
                        tf.multiply(tf.sqrt(tf.exp(self.z_log_sigma_sq)), eps))

        # Use generator to determine mean of
        # Bernoulli distribution of reconstructed input
        self.x_reconstr_mean = \
            self._generator_network(network_weights["weights_gener"],
                                    network_weights["biases_gener"])
            
    def _initialize_weights(self, n_hidden_recog_1, n_hidden_recog_2, 
                            n_hidden_gener_1,  n_hidden_gener_2, 
                            n_input, n_z):
        all_weights = dict()
        all_weights['weights_recog'] = {
            'h1': tf.Variable(xavier_init(n_input, n_hidden_recog_1)),
            'h2': tf.Variable(xavier_init(n_hidden_recog_1, n_hidden_recog_2)),
            'out_mean': tf.Variable(xavier_init(n_hidden_recog_2, n_z)),
            'out_log_sigma': tf.Variable(xavier_init(n_hidden_recog_2, n_z))}
        all_weights['biases_recog'] = {
            'b1': tf.Variable(tf.zeros([n_hidden_recog_1], dtype=tf.float32)),
            'b2': tf.Variable(tf.zeros([n_hidden_recog_2], dtype=tf.float32)),
            'out_mean': tf.Variable(tf.zeros([n_z], dtype=tf.float32)),
            'out_log_sigma': tf.Variable(tf.zeros([n_z], dtype=tf.float32))}
        all_weights['weights_gener'] = {
            'h1': tf.Variable(xavier_init(n_z, n_hidden_gener_1)),
            'h2': tf.Variable(xavier_init(n_hidden_gener_1, n_hidden_gener_2)),
            'out_mean': tf.Variable(xavier_init(n_hidden_gener_2, n_input)),
            'out_log_sigma': tf.Variable(xavier_init(n_hidden_gener_2, n_input))}
        all_weights['biases_gener'] = {
            'b1': tf.Variable(tf.zeros([n_hidden_gener_1], dtype=tf.float32)),
            'b2': tf.Variable(tf.zeros([n_hidden_gener_2], dtype=tf.float32)),
            'out_mean': tf.Variable(tf.zeros([n_input], dtype=tf.float32)),
            'out_log_sigma': tf.Variable(tf.zeros([n_input], dtype=tf.float32))}
        return all_weights
            
    def _recognition_network(self, weights, biases):
        # Generate probabilistic encoder (recognition network), which
        # maps inputs onto a normal distribution in latent space.
        # The transformation is parametrized and can be learned.
        layer_1 = self.transfer_fct(tf.add(tf.matmul(self.x, weights['h1']), 
                                           biases['b1'])) 
        layer_2 = self.transfer_fct(tf.add(tf.matmul(layer_1, weights['h2']), 
                                           biases['b2'])) 
        z_mean = tf.add(tf.matmul(layer_2, weights['out_mean']),
                        biases['out_mean'])
        z_log_sigma_sq = \
            tf.add(tf.matmul(layer_2, weights['out_log_sigma']), 
                   biases['out_log_sigma'])
        return (z_mean, z_log_sigma_sq)

    def _generator_network(self, weights, biases):
        # Generate probabilistic decoder (decoder network), which
        # maps points in latent space onto a Bernoulli distribution in data space.
        # The transformation is parametrized and can be learned.
        layer_1 = self.transfer_fct(tf.add(tf.matmul(self.z, weights['h1']), 
                                           biases['b1'])) 
        layer_2 = self.transfer_fct(tf.add(tf.matmul(layer_1, weights['h2']), 
                                           biases['b2'])) 
        x_reconstr_mean = \
            tf.nn.sigmoid(tf.add(tf.matmul(layer_2, weights['out_mean']), 
                                 biases['out_mean']))
        return x_reconstr_mean
            
    def _create_loss_optimizer(self):
        # The loss is composed of two terms:
        # 1.) The reconstruction loss (the negative log probability
        #     of the input under the reconstructed Bernoulli distribution 
        #     induced by the decoder in the data space).
        #     This can be interpreted as the number of "nats" required
        #     for reconstructing the input when the activation in latent
        #     space is given.
        # Adding 1e-10 to avoid evaluation of log(0.0)
        reconstr_loss = \
            -tf.reduce_sum(self.x * tf.log(1e-10 + self.x_reconstr_mean)
                           + (1-self.x) * tf.log(1e-10 + 1 - self.x_reconstr_mean),
                           1)
        # 2.) The latent loss, which is defined as the Kullback Leibler divergence 
        ##    between the distribution in latent space induced by the encoder on 
        #     the data and some prior. This acts as a kind of regularizer.
        #     This can be interpreted as the number of "nats" required
        #     for transmitting the latent space distribution given
        #     the prior.
        latent_loss = -0.5 * tf.reduce_sum(1 + self.z_log_sigma_sq 
                                           - tf.square(self.z_mean) 
                                           - tf.exp(self.z_log_sigma_sq), 1)
        self.cost = tf.reduce_mean(reconstr_loss + latent_loss)   # average over batch
        # Use ADAM optimizer
        self.optimizer = \
            tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.cost)
        
    def partial_fit(self, X):
        """Train model based on mini-batch of input data.
        
        Return cost of mini-batch.
        """
        opt, cost = self.sess.run((self.optimizer, self.cost), 
                                  feed_dict={self.x: X})
        return cost
    
    def transform(self, X):
        """Transform data by mapping it into the latent space."""
        # Note: This maps to mean of distribution, we could alternatively
        # sample from Gaussian distribution
        return self.sess.run(self.z_mean, feed_dict={self.x: X})
    
    def generate(self, z_mu=None):
        """ Generate data by sampling from latent space.
        
        If z_mu is not None, data for this point in latent space is
        generated. Otherwise, z_mu is drawn from prior in latent 
        space.        
        """
        if z_mu is None:
            z_mu = np.random.normal(size=self.network_architecture["n_z"])
        # Note: This maps to mean of distribution, we could alternatively
        # sample from Gaussian distribution
        return self.sess.run(self.x_reconstr_mean, 
                             feed_dict={self.z: z_mu})
    
    def reconstruct(self, X):
        """ Use VAE to reconstruct given data. """
        return self.sess.run(self.x_reconstr_mean, 
                             feed_dict={self.x: X})

In general, implementing a VAE in TensorFlow is relatively straightforward (in particular since we do not need to code the gradient computation). What is potentially a bit confusing is that all the logic happens at initialization of the class (where the graph is generated), while the actual sklearn interface methods are very simple one-liners.
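
As a quick illustration of that interface, here is a tiny sketch (not part of the original notebook, and not needed for the galaxy run below) that builds a small VAE on random data just to show how the four methods are called:

toy_arch = dict(n_hidden_recog_1=32, n_hidden_recog_2=32,
                n_hidden_gener_1=32, n_hidden_gener_2=32,
                n_input=64, n_z=4)
toy_vae = VariationalAutoencoder(toy_arch, batch_size=5)
toy_x = np.random.rand(5, 64).astype(np.float32)

cost = toy_vae.partial_fit(toy_x)       # one gradient step, returns the batch cost
z = toy_vae.transform(toy_x)            # latent means, shape (5, 4)
x_hat = toy_vae.reconstruct(toy_x)      # encode then decode, shape (5, 64)
x_gen = toy_vae.generate(np.random.normal(size=(5, 4)))  # decode sampled latent codes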

We can now define a simple function which trains the VAE using mini-batches:

In [18]:
def train(network_architecture, learning_rate=0.0001,
          batch_size=20, training_epochs=10, display_step=5):
    vae = VariationalAutoencoder(network_architecture, 
                                 learning_rate=learning_rate, 
                                 batch_size=batch_size)
    # Training cycle
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(n_samples / batch_size)
        # Loop over all batches
        for i in range(total_batch):
            batch_xs, _ = get_next_batch(batch_size, classes, data)

            # Fit training using batch data
            cost = vae.partial_fit(batch_xs)
            # Compute average loss
            avg_cost += cost / n_samples * batch_size

        # Display logs per epoch step
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch+1), 
                  "cost=", "{:.9f}".format(avg_cost))
    return vae

Illustrating reconstruction quality

We can now train a VAE by just specifying the network topology. We start with training a VAE with a 20-dimensional latent space.

In [19]:
network_architecture = \
    dict(n_hidden_recog_1=1000, # 1st layer encoder neurons
         n_hidden_recog_2=1000, # 2nd layer encoder neurons
         n_hidden_gener_1=1000, # 1st layer decoder neurons
         n_hidden_gener_2=1000, # 2nd layer decoder neurons
         n_input=128*128*3, # galaxy image input (img shape: 128*128*3)
         n_z=20)  # dimensionality of latent space

vae = train(network_architecture, training_epochs=2500)
Epoch: 0001 cost= 28840.174479167
Epoch: 0006 cost= 19031.012586806
Epoch: 0011 cost= 17604.420138889
Epoch: 0016 cost= 18578.544704861
Epoch: 0021 cost= 17523.484375000
Epoch: 0026 cost= 17781.375868056
Epoch: 0031 cost= 18291.717881944
Epoch: 0036 cost= 17007.068142361
Epoch: 0041 cost= 18223.212673611
Epoch: 0046 cost= 18067.176215278
Epoch: 0051 cost= 17922.154079861
Epoch: 0056 cost= 18394.543836806
Epoch: 0061 cost= 17180.345486111
Epoch: 0066 cost= 16884.414496528
Epoch: 0071 cost= 17854.503038194
Epoch: 0076 cost= 17817.343750000
Epoch: 0081 cost= 17637.635850694
Epoch: 0086 cost= 17281.582465278
Epoch: 0091 cost= 17794.785590278
Epoch: 0096 cost= 16319.256510417
Epoch: 0101 cost= 17799.138888889
Epoch: 0106 cost= 17698.367187500
Epoch: 0111 cost= 16731.486545139
Epoch: 0116 cost= 17043.723958333
Epoch: 0121 cost= 16960.702690972
Epoch: 0126 cost= 18120.866319444
Epoch: 0131 cost= 16518.418402778
Epoch: 0136 cost= 17203.845486111
Epoch: 0141 cost= 16112.808159722
Epoch: 0146 cost= 17067.257378472
Epoch: 0151 cost= 16703.901475694
Epoch: 0156 cost= 16944.108506944
Epoch: 0161 cost= 17324.052517361
Epoch: 0166 cost= 16615.122395833
Epoch: 0171 cost= 16096.480034722
Epoch: 0176 cost= 16514.127604167
Epoch: 0181 cost= 17004.260416667
Epoch: 0186 cost= 16697.056423611
Epoch: 0191 cost= 16838.878906250
Epoch: 0196 cost= 15935.458767361
Epoch: 0201 cost= 17383.253472222
Epoch: 0206 cost= 16509.929687500
Epoch: 0211 cost= 16539.016059028
Epoch: 0216 cost= 16864.768663194
Epoch: 0221 cost= 17677.800781250
Epoch: 0226 cost= 16250.740885417
Epoch: 0231 cost= 16181.250434028
Epoch: 0236 cost= 16415.381076389
Epoch: 0241 cost= 17073.576822917
Epoch: 0246 cost= 16048.416666667
Epoch: 0251 cost= 15968.373263889
Epoch: 0256 cost= 15667.450086806
Epoch: 0261 cost= 16842.475260417
Epoch: 0266 cost= 16322.367187500
Epoch: 0271 cost= 16662.580295139
Epoch: 0276 cost= 15931.948350694
Epoch: 0281 cost= 16317.053385417
Epoch: 0286 cost= 16933.537760417
Epoch: 0291 cost= 16422.004774306
Epoch: 0296 cost= 15381.491753472
Epoch: 0301 cost= 16256.547309028
Epoch: 0306 cost= 17077.939670139
Epoch: 0311 cost= 16688.310329861
Epoch: 0316 cost= 16533.190104167
Epoch: 0321 cost= 16297.553819444
Epoch: 0326 cost= 15892.516927083
Epoch: 0331 cost= 15780.644965278
Epoch: 0336 cost= 15824.844618056
Epoch: 0341 cost= 15877.611979167
Epoch: 0346 cost= 15697.395833333
Epoch: 0351 cost= 16171.762586806
Epoch: 0356 cost= 15943.956163194
Epoch: 0361 cost= 15902.065972222
Epoch: 0366 cost= 16957.545572917
Epoch: 0371 cost= 15308.323133681
Epoch: 0376 cost= 16745.470920139
Epoch: 0381 cost= 16523.680555556
Epoch: 0386 cost= 16036.400390625
Epoch: 0391 cost= 15799.812065972
Epoch: 0396 cost= 15510.680772569
Epoch: 0401 cost= 15287.471571181
Epoch: 0406 cost= 16625.402343750
Epoch: 0411 cost= 16416.235677083
Epoch: 0416 cost= 17154.807725694
Epoch: 0421 cost= 15314.470052083
Epoch: 0426 cost= 15490.281467014
Epoch: 0431 cost= 15801.716145833
Epoch: 0436 cost= 16050.029079861
Epoch: 0441 cost= 16283.814236111
Epoch: 0446 cost= 16605.355468750
Epoch: 0451 cost= 16347.652777778
Epoch: 0456 cost= 15731.038628472
Epoch: 0461 cost= 16587.036024306
Epoch: 0466 cost= 15477.838541667
Epoch: 0471 cost= 15662.475260417
Epoch: 0476 cost= 15263.150173611
Epoch: 0481 cost= 15867.213758681
Epoch: 0486 cost= 16430.001736111
Epoch: 0491 cost= 14886.850694444
Epoch: 0496 cost= 16623.927083333
Epoch: 0501 cost= 15730.987847222
Epoch: 0506 cost= 15711.533637153
Epoch: 0511 cost= 15355.798177083
Epoch: 0516 cost= 16069.349826389
Epoch: 0521 cost= 15544.365885417
Epoch: 0526 cost= 15536.296440972
Epoch: 0531 cost= 15435.029079861
Epoch: 0536 cost= 16410.739583333
Epoch: 0541 cost= 15833.997829861
Epoch: 0546 cost= 16031.910156250
Epoch: 0551 cost= 15923.396050347
Epoch: 0556 cost= 16153.213541667
Epoch: 0561 cost= 15456.763671875
Epoch: 0566 cost= 14967.770616319
Epoch: 0571 cost= 16074.850260417
Epoch: 0576 cost= 16535.595920139
Epoch: 0581 cost= 16076.526909722
Epoch: 0586 cost= 16165.127170139
Epoch: 0591 cost= 15410.311631944
Epoch: 0596 cost= 15651.730902778
Epoch: 0601 cost= 16572.694010417
Epoch: 0606 cost= 15695.222656250
Epoch: 0611 cost= 16542.186197917
Epoch: 0616 cost= 15823.003472222
Epoch: 0621 cost= 14885.219618056
Epoch: 0626 cost= 15617.280381944
Epoch: 0631 cost= 14832.931423611
Epoch: 0636 cost= 14404.384331597
Epoch: 0641 cost= 15533.767361111
Epoch: 0646 cost= 15625.672309028
Epoch: 0651 cost= 16186.404947917
Epoch: 0656 cost= 15847.280164931
Epoch: 0661 cost= 16707.793836806
Epoch: 0666 cost= 15976.004123264
Epoch: 0671 cost= 14773.956814236
Epoch: 0676 cost= 15082.963975694
Epoch: 0681 cost= 15877.440972222
Epoch: 0686 cost= 16340.460069444
Epoch: 0691 cost= 15748.605034722
Epoch: 0696 cost= 15448.057074653
Epoch: 0701 cost= 15843.820746528
Epoch: 0706 cost= 15083.006727431
Epoch: 0711 cost= 15628.255859375
Epoch: 0716 cost= 15325.751519097
Epoch: 0721 cost= 15984.244357639
Epoch: 0726 cost= 16740.977430556
Epoch: 0731 cost= 15068.658203125
Epoch: 0736 cost= 15748.534722222
Epoch: 0741 cost= 15402.148871528
Epoch: 0746 cost= 16370.703125000
Epoch: 0751 cost= 15680.200737847
Epoch: 0756 cost= 15999.031250000
Epoch: 0761 cost= 15287.985026042
Epoch: 0766 cost= 15009.791015625
Epoch: 0771 cost= 15589.638454861
Epoch: 0776 cost= 15872.957465278
Epoch: 0781 cost= 15983.239149306
Epoch: 0786 cost= 15313.299479167
Epoch: 0791 cost= 15391.838975694
Epoch: 0796 cost= 15233.024305556
Epoch: 0801 cost= 15956.710503472
Epoch: 0806 cost= 15435.410590278
Epoch: 0811 cost= 15566.144097222
Epoch: 0816 cost= 15878.977864583
Epoch: 0821 cost= 16329.916666667
Epoch: 0826 cost= 15631.745225694
Epoch: 0831 cost= 16530.667100694
Epoch: 0836 cost= 16724.144531250
Epoch: 0841 cost= 15400.772352431
Epoch: 0846 cost= 15435.082899306
Epoch: 0851 cost= 15360.632595486
Epoch: 0856 cost= 15929.900173611
Epoch: 0861 cost= 14529.122829861
Epoch: 0866 cost= 15283.594618056
Epoch: 0871 cost= 16477.843750000
Epoch: 0876 cost= 15584.115885417
Epoch: 0881 cost= 15222.365668403
Epoch: 0886 cost= 14791.812065972
Epoch: 0891 cost= 15611.644748264
Epoch: 0896 cost= 15946.345486111
Epoch: 0901 cost= 15706.883246528
Epoch: 0906 cost= 14403.358940972
Epoch: 0911 cost= 15565.815104167
Epoch: 0916 cost= 16362.740451389
Epoch: 0921 cost= 14830.655164931
Epoch: 0926 cost= 15807.617187500
Epoch: 0931 cost= 15013.864583333
Epoch: 0936 cost= 16403.090277778
Epoch: 0941 cost= 16050.911458333
Epoch: 0946 cost= 15144.014756944
Epoch: 0951 cost= 15669.813802083
Epoch: 0956 cost= 15261.877604167
Epoch: 0961 cost= 15824.849392361
Epoch: 0966 cost= 15773.828125000
Epoch: 0971 cost= 16409.272569444
Epoch: 0976 cost= 14720.042100694
Epoch: 0981 cost= 15533.814670139
Epoch: 0986 cost= 15366.313585069
Epoch: 0991 cost= 15895.554253472
Epoch: 0996 cost= 16721.046006944
Epoch: 1001 cost= 14841.590928819
Epoch: 1006 cost= 16024.634548611
Epoch: 1011 cost= 15302.706597222
Epoch: 1016 cost= 16115.976996528
Epoch: 1021 cost= 15473.321180556
Epoch: 1026 cost= 16181.732638889
Epoch: 1031 cost= 14933.835720486
Epoch: 1036 cost= 15679.352430556
Epoch: 1041 cost= 15718.683376736
Epoch: 1046 cost= 15762.429253472
Epoch: 1051 cost= 14347.155598958
Epoch: 1056 cost= 17032.856336806
Epoch: 1061 cost= 15626.519531250
Epoch: 1066 cost= 14275.175998264
Epoch: 1071 cost= 15950.962239583
Epoch: 1076 cost= 16085.137152778
Epoch: 1081 cost= 14783.549913194
Epoch: 1086 cost= 15651.002604167
Epoch: 1091 cost= 16057.259331597
Epoch: 1096 cost= 15536.607204861
Epoch: 1101 cost= 16515.369791667
Epoch: 1106 cost= 15564.346788194
Epoch: 1111 cost= 14858.960720486
Epoch: 1116 cost= 16573.859809028
Epoch: 1121 cost= 15320.070746528
Epoch: 1126 cost= 14649.675564236
Epoch: 1131 cost= 15033.369791667
Epoch: 1136 cost= 15485.304470486
Epoch: 1141 cost= 14817.481987847
Epoch: 1146 cost= 14726.206163194
Epoch: 1151 cost= 15510.894097222
Epoch: 1156 cost= 14962.381727431
Epoch: 1161 cost= 15205.320312500
Epoch: 1166 cost= 16276.861545139
Epoch: 1171 cost= 15784.200086806
Epoch: 1176 cost= 16044.858940972
Epoch: 1181 cost= 16096.375000000
Epoch: 1186 cost= 15124.996527778
Epoch: 1191 cost= 15309.229600694
Epoch: 1196 cost= 15796.595052083
Epoch: 1201 cost= 15183.100694444
Epoch: 1206 cost= 15634.029947917
Epoch: 1211 cost= 16117.921440972
Epoch: 1216 cost= 15224.677083333
Epoch: 1221 cost= 15986.023437500
Epoch: 1226 cost= 16233.084201389
Epoch: 1231 cost= 16810.348958333
Epoch: 1236 cost= 15927.556423611
Epoch: 1241 cost= 15481.944010417
Epoch: 1246 cost= 15498.469835069
Epoch: 1251 cost= 15269.464626736
Epoch: 1256 cost= 15528.737847222
Epoch: 1261 cost= 15699.188151042
Epoch: 1266 cost= 14850.009114583
Epoch: 1271 cost= 15723.745659722
Epoch: 1276 cost= 15598.565104167
Epoch: 1281 cost= 17366.572048611
Epoch: 1286 cost= 15587.766493056
Epoch: 1291 cost= 15006.193359375
Epoch: 1296 cost= 15953.669270833
Epoch: 1301 cost= 15138.610243056
Epoch: 1306 cost= 15850.892578125
Epoch: 1311 cost= 15726.017795139
Epoch: 1316 cost= 15443.205078125
Epoch: 1321 cost= 15407.407335069
Epoch: 1326 cost= 14905.136718750
Epoch: 1331 cost= 15177.516059028
Epoch: 1336 cost= 15937.000434028
Epoch: 1341 cost= 14506.340928819
Epoch: 1346 cost= 15950.305555556
Epoch: 1351 cost= 15221.126302083
Epoch: 1356 cost= 17045.373697917
Epoch: 1361 cost= 16819.125434028
Epoch: 1366 cost= 15713.704861111
Epoch: 1371 cost= 16295.185763889
Epoch: 1376 cost= 15868.142795139
Epoch: 1381 cost= 14987.942274306
Epoch: 1386 cost= 16262.848524306
Epoch: 1391 cost= 15738.012586806
Epoch: 1396 cost= 15389.999565972
Epoch: 1401 cost= 14975.054036458
Epoch: 1406 cost= 15931.837673611
Epoch: 1411 cost= 15241.211805556
Epoch: 1416 cost= 14702.252821181
Epoch: 1421 cost= 16755.736545139
Epoch: 1426 cost= 15440.311631944
Epoch: 1431 cost= 15852.828125000
Epoch: 1436 cost= 15270.374131944
Epoch: 1441 cost= 16363.072048611
Epoch: 1446 cost= 15644.578993056
Epoch: 1451 cost= 15258.314236111
Epoch: 1456 cost= 15179.885633681
Epoch: 1461 cost= 15261.356770833
Epoch: 1466 cost= 15386.348524306
Epoch: 1471 cost= 15050.888237847
Epoch: 1476 cost= 16116.278645833
Epoch: 1481 cost= 16142.502170139
Epoch: 1486 cost= 15197.445963542
Epoch: 1491 cost= 15496.765407986
Epoch: 1496 cost= 15738.003472222
Epoch: 1501 cost= 15526.238498264
Epoch: 1506 cost= 14949.957031250
Epoch: 1511 cost= 15163.950954861
Epoch: 1516 cost= 15659.492621528
Epoch: 1521 cost= 15240.648654514
Epoch: 1526 cost= 14935.817491319
Epoch: 1531 cost= 15091.196614583
Epoch: 1536 cost= 15098.367621528
Epoch: 1541 cost= 15101.361762153
Epoch: 1546 cost= 15529.722656250
Epoch: 1551 cost= 14961.253038194
Epoch: 1556 cost= 14408.342013889
Epoch: 1561 cost= 16054.362413194
Epoch: 1566 cost= 16294.401041667
Epoch: 1571 cost= 16352.948350694
Epoch: 1576 cost= 16140.829427083
Epoch: 1581 cost= 14927.062065972
Epoch: 1586 cost= 15591.800781250
Epoch: 1591 cost= 15971.815104167
Epoch: 1596 cost= 15506.117621528
Epoch: 1601 cost= 15093.574218750
Epoch: 1606 cost= 14563.998480903
Epoch: 1611 cost= 16643.719184028
Epoch: 1616 cost= 15806.802083333
Epoch: 1621 cost= 15487.332031250
Epoch: 1626 cost= 15494.845052083
Epoch: 1631 cost= 15567.907118056
Epoch: 1636 cost= 14775.688802083
Epoch: 1641 cost= 14981.224609375
Epoch: 1646 cost= 14991.221788194
Epoch: 1651 cost= 16116.137152778
Epoch: 1656 cost= 15510.988715278
Epoch: 1661 cost= 15736.465711806
Epoch: 1666 cost= 15949.360243056
Epoch: 1671 cost= 16119.377170139
Epoch: 1676 cost= 15465.193576389
Epoch: 1681 cost= 15584.028645833
Epoch: 1686 cost= 15436.595920139
Epoch: 1691 cost= 16550.630208333
Epoch: 1696 cost= 15786.145833333
Epoch: 1701 cost= 15886.181857639
Epoch: 1706 cost= 15242.954427083
Epoch: 1711 cost= 15874.965711806
Epoch: 1716 cost= 16553.207465278
Epoch: 1721 cost= 15710.133897569
Epoch: 1726 cost= 15627.160373264
Epoch: 1731 cost= 16037.253038194
Epoch: 1736 cost= 15146.009331597
Epoch: 1741 cost= 15385.211805556
Epoch: 1746 cost= 15482.138454861
Epoch: 1751 cost= 14674.360894097
Epoch: 1756 cost= 15539.644965278
Epoch: 1761 cost= 15834.611111111
Epoch: 1766 cost= 14976.759548611
Epoch: 1771 cost= 15665.084852431
Epoch: 1776 cost= 15915.502604167
Epoch: 1781 cost= 15023.055121528
Epoch: 1786 cost= 15271.296875000
Epoch: 1791 cost= 15763.449435764
Epoch: 1796 cost= 15738.618923611
Epoch: 1801 cost= 15785.073784722
Epoch: 1806 cost= 15959.746744792
Epoch: 1811 cost= 15788.366102431
Epoch: 1816 cost= 15087.789713542
Epoch: 1821 cost= 15210.580729167
Epoch: 1826 cost= 17045.042968750
Epoch: 1831 cost= 14860.406684028
Epoch: 1836 cost= 14663.517795139
Epoch: 1841 cost= 15252.060980903
Epoch: 1846 cost= 15881.200086806
Epoch: 1851 cost= 15448.911675347
Epoch: 1856 cost= 14891.531250000
Epoch: 1861 cost= 15559.797526042
Epoch: 1866 cost= 15827.694878472
Epoch: 1871 cost= 15484.562934028
Epoch: 1876 cost= 15750.637586806
Epoch: 1881 cost= 16593.620442708
Epoch: 1886 cost= 16076.913194444
Epoch: 1891 cost= 15745.107204861
Epoch: 1896 cost= 15643.075737847
Epoch: 1901 cost= 14882.996310764
Epoch: 1906 cost= 15619.387152778
Epoch: 1911 cost= 15163.862630208
Epoch: 1916 cost= 16007.298611111
Epoch: 1921 cost= 15518.523003472
Epoch: 1926 cost= 16745.198350694
Epoch: 1931 cost= 15304.340277778
Epoch: 1936 cost= 16106.516493056
Epoch: 1941 cost= 15621.708550347
Epoch: 1946 cost= 16066.757378472
Epoch: 1951 cost= 15443.883246528
Epoch: 1956 cost= 16537.971354167
Epoch: 1961 cost= 15337.107638889
Epoch: 1966 cost= 15451.966796875
Epoch: 1971 cost= 14895.346788194
Epoch: 1976 cost= 15603.504340278
Epoch: 1981 cost= 15943.692708333
Epoch: 1986 cost= 15195.282769097
Epoch: 1991 cost= 15916.763888889
Epoch: 1996 cost= 16082.940972222
Epoch: 2001 cost= 15176.236111111
Epoch: 2006 cost= 15435.772135417
Epoch: 2011 cost= 16115.955295139
Epoch: 2016 cost= 15818.113715278
Epoch: 2021 cost= 15278.912326389
Epoch: 2026 cost= 15333.633029514
Epoch: 2031 cost= 15846.574218750
Epoch: 2036 cost= 15525.987196181
Epoch: 2041 cost= 15534.358506944
Epoch: 2046 cost= 15976.773003472
Epoch: 2051 cost= 15752.037543403
Epoch: 2056 cost= 15218.504557292
Epoch: 2061 cost= 15670.388888889
Epoch: 2066 cost= 15774.121961806
Epoch: 2071 cost= 15674.822916667
Epoch: 2076 cost= 16058.871961806
Epoch: 2081 cost= 16394.874565972
Epoch: 2086 cost= 16032.221788194
Epoch: 2091 cost= 15876.728732639
Epoch: 2096 cost= 15857.551649306
Epoch: 2101 cost= 15972.689670139
Epoch: 2106 cost= 16009.328559028
Epoch: 2111 cost= 16304.927517361
Epoch: 2116 cost= 16189.036458333
Epoch: 2121 cost= 15000.961371528
Epoch: 2126 cost= 14417.955295139
Epoch: 2131 cost= 14254.083333333
Epoch: 2136 cost= 15744.819444444
Epoch: 2141 cost= 15146.777126736
Epoch: 2146 cost= 15247.738064236
Epoch: 2151 cost= 15131.171006944
Epoch: 2156 cost= 15678.994357639
Epoch: 2161 cost= 16317.479166667
Epoch: 2166 cost= 15775.261718750
Epoch: 2171 cost= 15792.545572917
Epoch: 2176 cost= 15282.832031250
Epoch: 2181 cost= 15843.942708333
Epoch: 2186 cost= 15003.796657986
Epoch: 2191 cost= 16313.785156250
Epoch: 2196 cost= 15446.582682292
Epoch: 2201 cost= 15666.940104167
Epoch: 2206 cost= 15232.085503472
Epoch: 2211 cost= 16243.274305556
Epoch: 2216 cost= 15565.167534722
Epoch: 2221 cost= 15155.738281250
Epoch: 2226 cost= 15436.213541667
Epoch: 2231 cost= 15476.478732639
Epoch: 2236 cost= 15241.529079861
Epoch: 2241 cost= 15349.078125000
Epoch: 2246 cost= 15126.438368056
Epoch: 2251 cost= 16297.749565972
Epoch: 2256 cost= 14855.444661458
Epoch: 2261 cost= 15911.165798611
Epoch: 2266 cost= 15498.292968750
Epoch: 2271 cost= 14608.004123264
Epoch: 2276 cost= 15215.793836806
Epoch: 2281 cost= 14484.308810764
Epoch: 2286 cost= 16426.986111111
Epoch: 2291 cost= 16001.542534722
Epoch: 2296 cost= 15040.688802083
Epoch: 2301 cost= 15146.635416667
Epoch: 2306 cost= 16032.718315972
Epoch: 2311 cost= 15581.417100694
Epoch: 2316 cost= 15887.167968750
Epoch: 2321 cost= 16128.678385417
Epoch: 2326 cost= 15479.073350694
Epoch: 2331 cost= 15641.638454861
Epoch: 2336 cost= 16129.910156250
Epoch: 2341 cost= 15953.890190972
Epoch: 2346 cost= 16723.968315972
Epoch: 2351 cost= 15893.334852431
Epoch: 2356 cost= 15715.620442708
Epoch: 2361 cost= 15667.273654514
Epoch: 2366 cost= 15633.814670139
Epoch: 2371 cost= 16001.891059028
Epoch: 2376 cost= 15223.713324653
Epoch: 2381 cost= 15387.595052083
Epoch: 2386 cost= 15850.988281250
Epoch: 2391 cost= 15818.870225694
Epoch: 2396 cost= 14643.782769097
Epoch: 2401 cost= 15578.932074653
Epoch: 2406 cost= 15574.729817708
Epoch: 2411 cost= 15748.767361111
Epoch: 2416 cost= 15685.569444444
Epoch: 2421 cost= 15738.786024306
Epoch: 2426 cost= 15482.898003472
Epoch: 2431 cost= 14928.416666667
Epoch: 2436 cost= 16567.221354167
Epoch: 2441 cost= 15527.588541667
Epoch: 2446 cost= 16282.500868056
Epoch: 2451 cost= 15336.701822917
Epoch: 2456 cost= 15067.723524306
Epoch: 2461 cost= 15377.094184028
Epoch: 2466 cost= 15365.740885417
Epoch: 2471 cost= 15740.294270833
Epoch: 2476 cost= 15765.306423611
Epoch: 2481 cost= 15241.011284722
Epoch: 2486 cost= 15907.166666667
Epoch: 2491 cost= 15812.263454861
Epoch: 2496 cost= 14764.913628472

Based on this we can sample some test inputs and visualize how well the VAE can reconstruct those. In general the VAE does really well.

Let's grab a random batch of images and show them alongside their reconstructions.

In [199]:
x_sample =get_next_batch(20, classes, data)[0]
x_reconstruct = vae.reconstruct(x_sample)
In [201]:
x_sample.shape
Out[201]:
(20, 49152)
In [84]:
plt.figure(figsize=(8, 12))
for i in range(5):
    plt.subplot(5, 2, 2*i + 1)
    plt.imshow(x_sample[i].reshape(128,128,3))#, vmin=0, vmax=1, cmap="gray")
    plt.title("Test input")
    plt.colorbar()
    plt.subplot(5, 2, 2*i + 2)
    plt.imshow(x_reconstruct[i].reshape(128,128,3))#, vmin=0, vmax=1, cmap="gray")
    plt.title("Reconstruction")
    plt.colorbar()
plt.tight_layout()

Now take a class 1 example, map it into the latent space, add Gaussian noise to its latent vector, and decode the perturbed code with the generator to produce an "imaginary" galaxy. (The perturbed latent vector is replicated 20 times below because the generator expects a full batch of size 20.)

In [202]:
X= c1data[18]
plt.imshow(X.reshape((128,128,3)))
Out[202]:
<matplotlib.image.AxesImage at 0x7fe37477aac8>
In [126]:
Z = vae.transform(X.reshape((1,49152)))
In [139]:
Z = Z+np.random.normal(size=20)+2.0
In [140]:
#Zex = np.array([np.zeros(128*128*3, dtype=np.float32)]*20 )
Zex = np.array([Z]*20)
Zex = Zex.reshape(20,20)
In [141]:
Zex.shape
Out[141]:
(20, 20)
In [142]:
z_mu = np.random.normal(size=20) #vae.network_architecture["n_z"])
z_mu= z_mu.reshape((1,20))
z_mu.shape
Out[142]:
(1, 20)
In [143]:
c1_imaginary = vae.generate(Zex)
In [144]:
plt.figure(figsize=(8, 12))
for i in range(0,10,2):
    plt.subplot(5, 2, i + 1)
    plt.imshow(c1_imaginary[i].reshape(128,128,3))#, vmin=0, vmax=1, cmap="gray")
    plt.title("Test input")
    plt.colorbar()
    plt.subplot(5, 2, i + 2)
    plt.imshow(c1_imaginary[i+1].reshape(128,128,3))#, vmin=0, vmax=1, cmap="gray")
    plt.title("Reconstruction")
    plt.colorbar()
plt.tight_layout()