Neural Networks gone wild! They can sample from discrete distributions now!



Guest blog by Yoel Zeldes.

This post describes:

  • what the Gumbel distribution is
  • how it is used for sampling from a discrete distribution
  • how the weights that affect the distribution’s parameters can be trained
  • how to use all of that in a toy example (with code)

In this post you will learn what the Gumbel-softmax trick is. Using this trick, you can sample from a discrete distribution and still let the gradients propagate to the weights that affect the distribution's parameters. This trick opens the door to many interesting applications. For a start, you can find an example of text generation in the paper GANS for Sequences of Discrete Elements with the Gumbel-softmax Distribution.
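To make the idea concrete, here is a minimal sketch of drawing one relaxed sample, assuming the standard formulation: add independent Gumbel(0, 1) noise to the log-probabilities and apply a softmax with a temperature `tau`. The function names (`sample_gumbel`, `gumbel_softmax`, `gumbel_max`) and the temperature values are illustrative choices, not taken from the full article. It uses only the Python standard library, so it shows the forward computation only; in a real model the same arithmetic would be written in an autodiff framework so that gradients reach the logits.

```python
import math
import random


def sample_gumbel(rng: random.Random) -> float:
    """Draw one Gumbel(0, 1) sample by inverse-CDF: g = -log(-log(u))."""
    u = rng.random()  # uniform in [0, 1)
    while u == 0.0:   # log(0) is undefined, so redraw in that (rare) case
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_softmax(logits, tau, rng):
    """Relaxed one-hot sample: softmax((logits + Gumbel noise) / tau)."""
    perturbed = [(l + sample_gumbel(rng)) / tau for l in logits]
    m = max(perturbed)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in perturbed]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_max(logits, rng):
    """Exact discrete sample (Gumbel-max trick): argmax of logits + Gumbel noise."""
    perturbed = [l + sample_gumbel(rng) for l in logits]
    return max(range(len(logits)), key=lambda i: perturbed[i])


if __name__ == "__main__":
    rng = random.Random(0)
    probs = [0.1, 0.6, 0.3]
    logits = [math.log(p) for p in probs]
    print(gumbel_softmax(logits, tau=1.0, rng=rng))  # a point on the probability simplex
    print(gumbel_softmax(logits, tau=0.1, rng=rng))  # much closer to a one-hot vector
    n = 20000
    counts = [0, 0, 0]
    for _ in range(n):
        counts[gumbel_max(logits, rng)] += 1
    print([c / n for c in counts])  # empirical frequencies, roughly matching probs
```

The argmax of the relaxed sample is always the Gumbel-max sample drawn with the same noise; lowering `tau` pushes the relaxed vector toward that one-hot vector, while raising it makes samples smoother.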

Gumbel distribution


Training deep neural networks usually boils down to defining your model’s architecture and a loss function, and watching the gradients propagate.

However, sometimes it's not that simple: some architectures incorporate a random component. The forward pass is then no longer a deterministic function of the input and weights; the random component introduces stochasticity because some node in the network samples from a distribution.

When would that happen, you ask? Typically, when we want to approximate an intractable sum or integral: instead of computing it exactly, we draw samples and form a Monte Carlo estimate. A good example is the variational autoencoder. Basically, it's an autoencoder on steroids: the encoder's job is to learn a distribution over the latent space. The loss function contains an intractable expectation over that distribution, so we estimate it by sampling from that distribution.
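As a minimal illustration of a Monte Carlo estimate (a toy stand-in for the VAE's expectation, not its actual loss), the sketch below approximates E[f(z)] for z ~ N(mu, sigma^2) by averaging f over samples. The helper name `mc_expectation` and the choice f(z) = z^2 are assumptions made for illustration; that f is used only because its exact expectation, mu^2 + sigma^2, is known and lets us check the estimate.

```python
import random
import statistics


def mc_expectation(f, mu, sigma, n_samples, rng):
    """Monte Carlo estimate of E[f(z)] for z ~ N(mu, sigma^2): average f over samples."""
    return statistics.fmean(f(rng.gauss(mu, sigma)) for _ in range(n_samples))


if __name__ == "__main__":
    rng = random.Random(0)
    mu, sigma = 1.5, 0.5
    estimate = mc_expectation(lambda z: z * z, mu, sigma, 100_000, rng)
    print(estimate, mu**2 + sigma**2)  # the estimate should be close to the exact value
```

More samples shrink the estimator's variance, but each sample is a random draw, which is exactly where the gradient problem described next comes from.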

As with any architecture, the gradients need to propagate to the weights of the model. Some of the weights are responsible for transforming the input into the parameters of the distribution from which we sample. Here we face a problem: the gradients can’t propagate through random nodes! Hence, these weights won’t be updated.

One solution to the problem is the reparameterization trick: you substitute the sampled random variable with a deterministic parameterized transformation of a parameterless random variable.

If you don't know this trick, I highly encourage you to read about it. I'll demonstrate it with the Gaussian case. For many types of continuous distributions you can apply the reparameterization trick. But what do you do if you need a distribution over a discrete set of values?
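In the Gaussian case the trick takes the form z = mu + sigma * eps with eps ~ N(0, 1): mu and sigma now enter through a deterministic expression, so the derivatives dz/dmu = 1 and dz/dsigma = eps exist. The sketch below is a minimal stdlib-only illustration under these assumptions; it applies the chain rule by hand to get Monte Carlo gradients of the hypothetical toy loss L = E[z^2], whereas in practice an autodiff framework would compute them. The function names are illustrative. Since L = mu^2 + sigma^2, the exact gradients are 2 * mu and 2 * sigma, which lets us check the estimates.

```python
import random
import statistics


def reparameterized_sample(mu, sigma, rng):
    """Return (z, eps) with z = mu + sigma * eps, where eps ~ N(0, 1) has no parameters."""
    eps = rng.gauss(0.0, 1.0)
    return mu + sigma * eps, eps


def pathwise_gradients(mu, sigma, n_samples, rng):
    """Monte Carlo estimates of dL/dmu and dL/dsigma for the toy loss L = E[z**2].

    With z = mu + sigma * eps, the chain rule gives
    d(z**2)/dmu = 2 * z * 1 and d(z**2)/dsigma = 2 * z * eps.
    """
    grads_mu, grads_sigma = [], []
    for _ in range(n_samples):
        z, eps = reparameterized_sample(mu, sigma, rng)
        grads_mu.append(2.0 * z)
        grads_sigma.append(2.0 * z * eps)
    return statistics.fmean(grads_mu), statistics.fmean(grads_sigma)


if __name__ == "__main__":
    rng = random.Random(0)
    g_mu, g_sigma = pathwise_gradients(mu=1.5, sigma=0.5, n_samples=100_000, rng=rng)
    print(g_mu, g_sigma)  # should be close to the exact values 2 * 1.5 and 2 * 0.5
```

The randomness lives only in `eps`, which does not depend on any parameter. A categorical sample, by contrast, is not a differentiable function of its probabilities, which is why the discrete case needs a different tool.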

Read the full article with source code, here
