Part 1: What is Neural Memory?
- Neural networks have hidden layers. Normally, the state of your hidden layer is based ONLY on your input data, so a neural network’s information flow looks like this:

input -> hidden -> output
This is straightforward. Certain types of input create certain types of hidden layers. Certain types of hidden layers create certain types of output layers. It’s kind of a closed system. Memory changes this. Memory means that the hidden layer is a combination of your input data at the current timestep and the hidden layer of the previous timestep.
(input + prev_hidden) -> hidden -> output
- 4 timesteps with hidden layer recurrence:

(input + empty_hidden) -> hidden -> output
(input + prev_hidden) -> hidden -> output
(input + prev_hidden) -> hidden -> output
(input + prev_hidden) -> hidden -> output

…and 4 timesteps with input layer recurrence…

(input + empty_input) -> hidden -> output
(input + prev_input) -> hidden -> output
(input + prev_input) -> hidden -> output
(input + prev_input) -> hidden -> output

Focus on the last hidden layer (4th line). In the hidden layer recurrence, we see a presence of every input seen so far. In the input layer recurrence, the hidden layer is defined exclusively by the current and previous inputs. This is why we model hidden recurrence: hidden recurrence learns what to remember, whereas input recurrence is hard-wired to remember only the immediately previous datapoint.
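The contrast above can be sketched numerically. In this toy simulation (the scalar weights w_x, w_h and the input sequence are illustrative assumptions, not from the original), only the first input is non-zero; hidden recurrence still carries an echo of it at timestep 4, while input recurrence has forgotten it entirely:

```python
import numpy as np

inputs = [1.0, 0.0, 0.0, 0.0]  # only the FIRST input is non-zero
w_x, w_h = 0.5, 0.9            # illustrative scalar weights

# Hidden layer recurrence: h depends on the current input AND the previous h
h = 0.0
for x in inputs:
    h = np.tanh(w_x * x + w_h * h)
print(h)  # non-zero: the first input still echoes in the hidden state

# Input layer recurrence: h depends only on the current and previous inputs
prev_x = 0.0
for x in inputs:
    h2 = np.tanh(w_x * x + w_h * prev_x)
    prev_x = x
print(h2)  # exactly zero: inputs 3 and 4 are zero, so the first input is gone
```

The hidden-recurrence state decays gradually (memories "fading into the past"), while the input-recurrence state drops the first input the moment it leaves the two-step window.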
- Let’s say we were trying to predict the next word in a song given the previous one. The “input layer recurrence” would break down if the song accidentally had the same sequence of two words in multiple places. Think about it: if the song had the statements “I love you” and “I love carrots”, and the network was trying to predict the next word, how would it know what follows “I love”? It could be carrots. It could be you. The network REALLY needs to know more about what part of the song it’s in. However, the “hidden layer recurrence” doesn’t break down in this way. It subtly remembers everything it saw (with memories becoming more subtle as they fade into the past).
Part 2: RNN – Neural Network Memory
- Our input layer to the neural network is determined by our input dataset. Each row of input data is used to generate the hidden layer (via forward propagation).
- Each hidden layer is then used to populate the output layer (assuming only 1 hidden layer). As we just saw, memory means that the hidden layer is a combination of the input data and the previous hidden layer.
- How is this done? Well, much like every other propagation in neural networks, it’s done with a matrix. This matrix defines the relationship between the previous hidden layer and the current one.
- The gif above reflects the magic of recurrent networks and several very important properties. It depicts 4 timesteps. The first is exclusively influenced by the input data. The second is a mixture of the first and second inputs. This continues on. You should recognize that, in some way, network 4 is “full”. Presumably, timestep 5 would have to choose which memories to keep and which ones to overwrite. This is very real: it’s the notion of memory “capacity”. As you might expect, bigger layers can hold more memories for a longer period of time. This is also when the network learns to forget irrelevant memories and remember important ones.
Part 3: Backpropagation Through Time
So, how do recurrent neural networks learn? Check out this graphic. Black is the prediction, errors are bright yellow, derivatives are mustard colored.
They learn by fully propagating forward from 1 to 4 (through an entire sequence of arbitrary length), and then backpropagating all the derivatives from 4 back to 1.
As a working example, suppose we only had a vocabulary of four possible letters “helo”, and wanted to train an RNN on the training sequence “hello”. This training sequence is in fact a source of 4 separate training examples: 1. The probability of “e” should be likely given the context of “h”, 2. “l” should be likely in the context of “he”, 3. “l” should also be likely given the context of “hel”, and finally 4. “o” should be likely given the context of “hell”.
Concretely, we will encode each character into a vector using 1-of-k encoding (i.e. all zero except for a single one at the index of the character in the vocabulary), and feed them into the RNN one at a time with the step function. We will then observe a sequence of 4-dimensional output vectors (one dimension per character), which we interpret as the confidence the RNN currently assigns to each character coming next in the sequence. Here’s a diagram:
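The 1-of-k encoding over the “helo” vocabulary can be written out directly. A minimal sketch (the helper names `char_to_ix` and `one_hot` are illustrative choices, not from the original):

```python
import numpy as np

vocab = ['h', 'e', 'l', 'o']          # the four possible characters
char_to_ix = {ch: i for i, ch in enumerate(vocab)}

def one_hot(ch):
    """1-of-k encoding: all zeros except a single one at the character's index."""
    v = np.zeros(len(vocab))
    v[char_to_ix[ch]] = 1.0
    return v

# "hello" yields 5 one-hot vectors; the first 4 are inputs and
# characters 2..5 are the corresponding next-character targets.
xs = [one_hot(ch) for ch in "hello"]
print(xs[0])  # 'h' -> [1. 0. 0. 0.]
print(xs[1])  # 'e' -> [0. 1. 0. 0.]
```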
RNN computation. So how do these things work? At the core, RNNs have a deceptively simple API: They accept an input vector x and give you an output vector y. However, crucially this output vector’s contents are influenced not only by the input you just fed in, but also by the entire history of inputs you’ve fed in in the past. Written as a class, the RNN’s API consists of a single step function:

```python
rnn = RNN()
y = rnn.step(x)  # x is an input vector, y is the RNN's output vector
```
The RNN class has some internal state that it gets to update every time step is called. In the simplest case this state consists of a single hidden vector h. Here is an implementation of the step function in a Vanilla RNN:

```python
class RNN:
  # ...
  def step(self, x):
    # update the hidden state
    self.h = np.tanh(np.dot(self.W_hh, self.h) + np.dot(self.W_xh, x))
    # compute the output vector
    y = np.dot(self.W_hy, self.h)
    return y
```
The above specifies the forward pass of a vanilla RNN. This RNN’s parameters are the three matrices W_hh, W_xh, W_hy. The hidden state self.h is initialized with the zero vector. The np.tanh function implements a non-linearity that squashes the activations to the range [-1, 1]. Notice briefly how this works: There are two terms inside of the tanh: one is based on the previous hidden state and one is based on the current input. In numpy np.dot is matrix multiplication. The two intermediates interact with addition, and then get squashed by the tanh into the new state vector. If you’re more comfortable with math notation, we can also write the hidden state update as h_t = tanh(W_hh h_{t-1} + W_xh x_t), where tanh is applied elementwise.
We initialize the matrices of the RNN with random numbers, and the bulk of the work during training goes into finding the matrices that give rise to desirable behavior, as measured with some loss function that expresses your preference for what kinds of outputs y you’d like to see in response to your input sequences.
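The step snippet above omits the initialization it relies on; here is a self-contained sketch that fills it in. The layer sizes and the small random initialization are illustrative assumptions, not prescriptions:

```python
import numpy as np

class RNN:
    """Minimal vanilla RNN with the step() API described above.
    Sizes and initialization scale are illustrative assumptions."""
    def __init__(self, input_size, hidden_size, output_size, seed=0):
        rng = np.random.default_rng(seed)
        self.W_xh = rng.normal(0, 0.01, (hidden_size, input_size))
        self.W_hh = rng.normal(0, 0.01, (hidden_size, hidden_size))
        self.W_hy = rng.normal(0, 0.01, (output_size, hidden_size))
        self.h = np.zeros(hidden_size)   # hidden state starts at the zero vector

    def step(self, x):
        # h_t = tanh(W_hh h_{t-1} + W_xh x_t), then project to the output
        self.h = np.tanh(np.dot(self.W_hh, self.h) + np.dot(self.W_xh, x))
        return np.dot(self.W_hy, self.h)

rnn = RNN(input_size=4, hidden_size=8, output_size=4)
y = rnn.step(np.array([1.0, 0.0, 0.0, 0.0]))  # feed a one-hot input vector
print(y.shape)  # (4,)
```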
Going deep. RNNs are neural networks and everything works monotonically better (if done right) if you put on your deep learning hat and start stacking models up like pancakes. For instance, we can form a 2-layer recurrent network as follows:
```python
y1 = rnn1.step(x)
y = rnn2.step(y1)
```
In other words we have two separate RNNs: One RNN is receiving the input vectors and the second RNN is receiving the output of the first RNN as its input. Except neither of these RNNs knows or cares – it’s all just vectors coming in and going out, and some gradients flowing through each module during backpropagation.
Getting fancy. I’d like to briefly mention that in practice most of us use a slightly different formulation than what I presented above, called a Long Short-Term Memory (LSTM) network. The LSTM is a particular type of recurrent network that works slightly better in practice, owing to its more powerful update equation and some appealing backpropagation dynamics. I won’t go into details, but everything I’ve said about RNNs stays exactly the same, except the mathematical form for computing the update (the line self.h = ...) gets a little more complicated. From here on I will use the terms “RNN/LSTM” interchangeably, but all experiments in this post use an LSTM.
The Problems with Deep Backpropagation
Unlike traditional feed-forward nets, the feed-forward nets generated by unrolling RNNs can be enormously deep. This gives rise to a serious practical issue: gradients can vanish or explode as they propagate back through many time layers, making these networks obscenely difficult to train with the backpropagation through time approach.
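A rough way to see the difficulty: backpropagating through T timesteps multiplies the error signal by the recurrent weight roughly T times. A toy scalar sketch (the weights 0.9 and 1.1 and the 50-step horizon are arbitrary illustrative choices):

```python
def backprop_magnitude(w, steps):
    """Magnitude of a gradient after being multiplied by a scalar
    recurrent 'weight' once per unrolled timestep."""
    grad = 1.0
    for _ in range(steps):
        grad *= w
    return grad

print(backprop_magnitude(0.9, 50))  # ~0.005 -> vanishing gradient
print(backprop_magnitude(1.1, 50))  # ~117   -> exploding gradient
```

Even weights slightly below or above 1 drive the gradient toward zero or infinity over long sequences, which is exactly the regime an unrolled RNN operates in.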
Long Short-Term Memory
To address these problems, researchers proposed a modified architecture for recurrent neural networks to help bridge long time lags between forcing inputs and appropriate responses and protect against exploding gradients. The architecture forces constant error flow (thus, neither exploding nor vanishing) through the internal state of special memory units. This long short-term memory (LSTM) architecture utilized units that were structured as follows:
Structure of the basic LSTM unit
The LSTM unit consists of a memory cell which attempts to store information for extended periods of time. Access to this memory cell is protected by specialized gate neurons – the keep, write, and read gates – which are all logistic units. These gate cells, instead of sending their activities as inputs to other neurons, set the weights on edges connecting the rest of the neural net to the memory cell. The memory cell is a linear neuron that has a connection to itself. When the keep gate is turned on (with an activity of 1), the self connection has weight one and the memory cell writes its contents into itself. When the keep gate outputs a zero, the memory cell forgets its previous contents. The write gate allows the rest of the neural net to write into the memory cell when it outputs a 1 while the read gate allows the rest of the neural net to read from the memory cell when it outputs a 1.
So how exactly does this force a constant error flow through time to locally protect against exploding and vanishing gradients? To visualize this, let’s unroll the LSTM unit through time:
Unrolling the LSTM unit through the time domain
At first, the keep gate is set to 0 and the write gate is set to 1, which places 4.2 into the memory cell. This value is retained in the memory cell by a subsequent keep value of 1 and protected from read/write by values of 0. Finally, the cell is read and then cleared. Now we try to follow the backpropagation from the point of loading 4.2 into the memory cell to the point of reading 4.2 from the cell and its subsequent clearing. We realize that due to the linear nature of the memory neuron, the error derivative that we receive from the read point backpropagates with negligible change until the write point, because the connections carrying the memory cell’s contents through all the time layers have weights approximately equal to 1 (approximate because of the logistic output of the keep gate). As a result, we can locally preserve the error derivatives over hundreds of steps without having to worry about exploding or vanishing gradients. You can see the action of this method successfully reading cursive handwriting:
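The unrolled example can be traced with a few lines of arithmetic. This is a simplified sketch in which the gate activities are supplied directly as 0/1 values (an idealization of the logistic gates, which would really output values strictly between 0 and 1) rather than computed by gate neurons:

```python
# (keep, write, read) gate values at each timestep, following the
# unrolled example; x is the value offered by the rest of the net
steps = [
    dict(keep=0.0, write=1.0, read=0.0, x=4.2),  # load 4.2 into the cell
    dict(keep=1.0, write=0.0, read=0.0, x=0.0),  # retain; protected from read/write
    dict(keep=1.0, write=0.0, read=1.0, x=0.0),  # read the cell
    dict(keep=0.0, write=0.0, read=0.0, x=0.0),  # clear the cell
]

cell = 0.0
outputs = []
for s in steps:
    # linear memory neuron: the keep gate scales the self-connection,
    # the write gate admits new content from the rest of the net
    cell = s["keep"] * cell + s["write"] * s["x"]
    # the read gate controls what the rest of the net sees
    outputs.append(s["read"] * cell)

print(outputs)  # [0.0, 0.0, 4.2, 0.0] -- 4.2 surfaces only at the read step
```

Because the cell update between the write and the read is just multiplication by the keep gate's value of 1, the stored value (and, symmetrically, the backpropagated derivative) passes through those timesteps unchanged.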