-
Unknown A
Hi everyone. So today we are once again continuing our implementation of makemore. So far we've come up to multilayer perceptrons, and our neural net looked like this; we were implementing this over the last few lectures. Now, I'm sure everyone is very excited to go into recurrent neural networks and all of their variants and how they work. The diagrams look cool, it's very exciting and interesting, and we're going to get a better result. But unfortunately I think we have to remain here for one more lecture. The reason is that we've already trained this multilayer perceptron, we are getting a pretty good loss, and I think we have a pretty decent understanding of the architecture and how it works. But the line of code here that I take issue with is loss.backward(). That is, we are taking PyTorch autograd and using it to calculate all of our gradients along the way.
-
Unknown A
I would like to remove the use of loss.backward() and have us write our backward pass manually, on the level of tensors. I think this is a very useful exercise for the following reasons. I actually have an entire blog post on this topic, but I like to call backpropagation a leaky abstraction. What I mean by that is that backpropagation doesn't just make your neural networks magically work. It's not the case that you can stack up arbitrary Lego blocks of differentiable functions, cross your fingers, backpropagate, and everything is great. It is a leaky abstraction in the sense that you can shoot yourself in the foot if you do not understand its internals: it will magically not work, or not work optimally. And you will need to understand how it works under the hood if you're hoping to debug it and address it in your neural net.
-
Unknown A
This blog post from a while ago goes into some of those examples, and we've already covered some of them. For example, the flat tails of these functions and how you do not want to saturate them too much, because your gradients will die. The case of dead neurons, which I've already covered as well. The case of exploding or vanishing gradients in recurrent neural networks, which we are about to cover. And then you will also often come across examples in the wild. This is a snippet that I found in a random code base on the Internet, where they have a very subtle but pretty major bug in their implementation, and the bug points at the fact that the author of this code does not actually understand backpropagation. What they're doing here is clipping the loss at a certain maximum value, but what they actually want to do is clip the gradients to a maximum value.
-
Unknown A
Indirectly, they're basically causing some of the outliers to be ignored, because when you clip the loss of an outlier, you are setting its gradient to zero. So have a look through this and read through it, but there's basically a bunch of subtle issues that you're going to avoid if you actually know what you're doing. And that's why I don't think it's the case that because PyTorch or other frameworks offer autograd, it is okay for us to ignore how it works. Now, we've actually already covered autograd and we wrote micrograd, but micrograd was an autograd engine only on the level of individual scalars. So the atoms were single individual numbers, and I don't think that's enough. I'd like us to think about backpropagation on the level of tensors as well. So in summary, I think it's a good exercise.
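To make the loss-clipping bug concrete, here is a minimal sketch with made-up per-example losses; the loss values and the max_loss threshold are hypothetical and are not taken from the snippet shown in the video. Clamping the loss itself flattens the outlier, so its gradient becomes exactly zero; clipping the gradient instead keeps every example's contribution, only bounded in norm.

```python
import torch

# Hypothetical per-example losses; the last one plays the role of an outlier.
per_example = torch.tensor([0.5, 1.0, 8.0], requires_grad=True)
max_loss = 2.0  # hypothetical clipping threshold

# Buggy variant: clamp the loss itself. Above max_loss the clamp is flat,
# so the outlier's gradient becomes exactly zero and it is ignored.
torch.clamp(per_example, max=max_loss).mean().backward()
grad_loss_clamped = per_example.grad.clone()

# Intended variant: keep the loss as-is and clip the gradient instead.
per_example.grad = None
per_example.mean().backward()
torch.nn.utils.clip_grad_norm_([per_example], max_norm=0.1)
grad_norm_clipped = per_example.grad.clone()

print(grad_loss_clamped)   # the outlier's entry is 0
print(grad_norm_clipped)   # every entry nonzero, total norm at most 0.1
```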
-
Unknown A
I think it is very, very valuable. You're going to become better at debugging neural networks and making sure that you understand what you're doing. It's going to make everything fully explicit, so you're not going to be nervous about what is hidden away from you. And in general, we're going to emerge stronger. So let's get into it. A bit of a fun historical note here: today, writing your backward pass by hand is not recommended, and no one does it except for the purposes of exercise. But about 10 years ago in deep learning, this was fairly standard and in fact pervasive. At the time, everyone used to write their own backward pass by hand, including myself; it's just what you would do. So we used to write the backward pass by hand, and now everyone just calls loss.backward().
-
Unknown A
We've lost something. I wanted to give you a few examples of this. Here's a 2006 paper from Geoff Hinton and Ruslan Salakhutdinov in Science that was influential at the time. It was training architectures called restricted Boltzmann machines, and basically it's an autoencoder trained here. And this is from roughly 2010: I had a library for training restricted Boltzmann machines, and at the time it was written in MATLAB. Python was not used pervasively for deep learning; it was all MATLAB. MATLAB was this scientific computing package that everyone would use. So we would write MATLAB, which is barely a programming language, but it had a very convenient tensor class and it was this computing environment. It would all run on a CPU, of course, but you would have very nice plots to go with it and a built-in debugger.
-
Unknown A
And it was pretty nice. Now, the code in this package from 2010 that I wrote for fitting restricted Boltzmann machines is, to a large extent, recognizable. I'm creating the data in the X, Y batches, and I'm initializing the neural net, so it's got weights and biases just like we're used to. Then this is the training loop where we do the forward pass. At this time, people didn't even necessarily use backpropagation to train neural networks, so this in particular implements contrastive divergence, which estimates a gradient. Then we take that gradient and use it for a parameter update along the lines that we're used to. But you can see that people were meddling with these gradients directly, inline, themselves. It wasn't that common to use an autograd engine.
-
Unknown A
Here's one more example, from a paper of mine from 2014 called Deep Fragment Embeddings. What I was doing here is aligning images and text, so it's kind of like CLIP, if you're familiar with it. But instead of working on the level of entire images and entire sentences, it was working on the level of individual objects and little pieces of sentences. I was embedding them and then calculating a loss very much like a CLIP loss. I dug up the code from 2014 of how I implemented this, and it was already in NumPy and Python. Here I'm implementing the cost function, and it was standard to implement not just the cost but also the backward pass manually. So here I'm calculating the image embeddings, the sentence embeddings, and the scores, and this is the loss function. Once I have the loss function, I do the backward pass right here.
-
Unknown A
So I backward through the loss function and through the neural net, and I append regularization. Everything was done by hand, manually: you would write out the backward pass, and then you would use a gradient checker to make sure that your numerical estimate of the gradient agrees with the one you calculated during backpropagation. This was very standard for a long time. Today, of course, it is standard to use an autograd engine, but the manual approach was definitely useful, and I think people understood how these neural networks work on a very intuitive level. So I think it's a good exercise again, and this is where we want to be. Okay, so just as a reminder, this is the Jupyter notebook that we implemented in our previous lecture. We're going to keep everything the same: we're still going to have a two-layer multilayer perceptron with a batch normalization layer.
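The gradient-checker workflow mentioned above can be sketched in a few lines. This is a minimal illustration, not the code from the paper: the tiny function f (a sum of tanh(x @ w)) and its hand-derived gradient are hypothetical choices, and the numerical estimate uses central finite differences in float64.

```python
import torch

torch.manual_seed(0)

def f(w, x):
    # Hypothetical scalar "loss": sum of tanh(x @ w).
    return torch.tanh(x @ w).sum()

def manual_grad(w, x):
    # Hand-derived backward pass: d/dw sum(tanh(x @ w)) = x^T (1 - tanh(x @ w)^2).
    t = torch.tanh(x @ w)
    return x.T @ (1 - t**2)

def numerical_grad(w, x, eps=1e-6):
    # Central finite differences, nudging one parameter at a time.
    g = torch.zeros_like(w)
    flat_w, flat_g = w.view(-1), g.view(-1)
    for i in range(flat_w.numel()):
        old = flat_w[i].item()
        flat_w[i] = old + eps
        f_plus = f(w, x).item()
        flat_w[i] = old - eps
        f_minus = f(w, x).item()
        flat_w[i] = old  # restore the original value
        flat_g[i] = (f_plus - f_minus) / (2 * eps)
    return g

x = torch.randn(5, 4, dtype=torch.float64)
w = torch.randn(4, 3, dtype=torch.float64)
analytic = manual_grad(w, x)
numeric = numerical_grad(w, x)
rel_err = ((analytic - numeric).abs().max() / numeric.abs().max()).item()
print(f"max relative error: {rel_err:.2e}")
```

Double precision matters here: in float32 the finite-difference estimate is much noisier, and a correct gradient can look wrong.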
-
Unknown A
So the forward pass will be basically identical to that lecture, but here we're going to get rid of loss.backward() and instead write the backward pass manually. Here's the starter code for this lecture: we are becoming a backprop ninja in this notebook. The first few cells are identical to what we are used to. We are doing some imports, loading the dataset, and processing the dataset; none of this has changed. Now, here I'm introducing a utility function that we're going to use later to compare the gradients. In particular, we are going to have the gradients that we estimate manually ourselves and the gradients that PyTorch calculates, and we're going to check for correctness, assuming of course that PyTorch is correct. Then here we have the initialization that we are quite used to.
-
Unknown A
So we have our embedding table for the characters, the first layer, the second layer, and a batch normalization layer in between. Here's where we create all the parameters. Now, you will note that I changed the initialization a little bit to use small numbers. Normally you would set the biases to be all zero; here I'm setting them to small random numbers. I'm doing this because if your variables are initialized to exactly zero, that can sometimes mask an incorrect implementation of a gradient: when everything is zero, the expression simplifies and gives you a much simpler gradient than you would otherwise get. So by making them small numbers, I'm trying to unmask those potential errors in these calculations. You'll also notice that I'm using b1 in the first layer: I'm using a bias despite the batch normalization right afterwards.
-
Unknown A
This would typically not be what you do, because we talked about the fact that you don't need the bias. But I'm doing it here just for fun, because we're going to have a gradient with respect to it, and we can check that we are still calculating it correctly even though this bias is spurious. So here I'm calculating a single batch, and then here I am doing a forward pass. You'll notice that the forward pass is significantly expanded from what we are used to; previously, the forward pass was just this. The forward pass is longer for two reasons. Number one, before we just had F.cross_entropy, but here I am bringing back an explicit implementation of the loss function. And number two, I've broken up the implementation into manageable chunks, so we have a lot more intermediate tensors along the way in the forward pass.
-
Unknown A
That's because we are about to go backwards and calculate the gradients in this backpropagation from the bottom to the top, so we're going to go upwards. Just like we have, for example, the logprobs tensor in the forward pass, in the backward pass we're going to have dlogprobs, which is going to store the derivative of the loss with respect to the logprobs tensor. So we're going to be prepending d to every one of these tensors and calculating it along the way of this backpropagation. As an example, we have bnraw here, so we're going to be calculating dbnraw. Here I'm telling PyTorch that we want to retain the grad of all these intermediate values, because in exercise one we're going to calculate the backward pass. So we're going to calculate all these d variables and use the cmp function I've introduced above to check our correctness with respect to what PyTorch is telling us.
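By default, autograd only populates .grad on leaf tensors such as the parameters, which is why the intermediates need retain_grad(). A minimal sketch of the idea, using a hypothetical two-step computation rather than the lecture's MLP:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3, requires_grad=True)   # leaf tensor
h = torch.tanh(x)                           # intermediate (non-leaf) tensor
h.retain_grad()                             # ask autograd to keep h.grad after backward()
loss = (h**2).mean()
loss.backward()

# Manual gradient of the loss w.r.t. the intermediate h: d/dh mean(h^2) = 2h / numel.
dh = 2 * h.detach() / h.numel()
print(torch.allclose(dh, h.grad))
```

Without the retain_grad() call, h.grad would stay None after backward() and there would be nothing to compare dh against.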
-
Unknown A
This is going to be exercise one, where we backpropagate through this entire graph. Now, just to give you a very quick preview of what's going to happen in exercise two and below: in exercise one we have fully broken up the loss and backpropagated through it manually in all the little atomic pieces that make it up. In exercise two, we're going to collapse the loss into a single cross-entropy call, and instead we're going to analytically derive, using math and pen and paper, the gradient of the loss with respect to the logits. Instead of backpropagating through all of its little chunks one at a time, we're just going to analytically derive what that gradient is and implement that, which is much more efficient, as we'll see in a bit. Then we're going to do the exact same thing for batch normalization. Instead of breaking up batch norm into all the tiny components, we're going to use pen and paper, mathematics, and calculus to derive the gradient through the batch norm layer.
-
Unknown A
So we're going to calculate the backward pass through the batch norm layer in a much more efficient expression, instead of backpropagating through all of its little pieces independently. That's going to be exercise three. Then in exercise four, we're going to put it all together: this is the full code for training this two-layer MLP. We're going to insert our manual backprop and take out loss.backward(), and you will see that you can get all the same results using fully your own code. The only thing we're using from PyTorch is the torch tensor, to make the calculations efficient. But otherwise you will understand fully what it means to forward and backward the neural net and train it, and I think that'll be awesome. So let's get to it. Okay, so I ran all the cells of this notebook up to here, and I'm going to erase this and start implementing the backward pass, starting with dlogprobs.
-
Unknown A
So we want to understand what should go here to calculate the gradient of the loss with respect to all the elements of the logprobs tensor. Now, I'm going to give away the answer here, but I wanted to put in a quick note about what I think will be most pedagogically useful for you: go into the description of this video and find the link to this Jupyter notebook. You can find it on GitHub, and you can also find a Google Colab version of it, so you don't have to install anything; you just go to the Google Colab website, and you can try to implement these derivatives or gradients yourself. Then, if you are not able to, come to my video and see me do it. So work in tandem: try it first yourself and then see me give away the answer. I think that'll be most valuable to you.
-
Unknown A
And that's how I recommend you go through this lecture. So we are starting here with dlogprobs. dlogprobs will hold the derivative of the loss with respect to all the elements of logprobs. What is inside logprobs? Its shape is 32 by 27. So it's not going to surprise you that dlogprobs should also be an array of size 32 by 27, because we want the derivative of the loss with respect to all of its elements; the sizes of those are always going to be equal. Now, how does logprobs influence the loss? The loss is negative logprobs indexed with range(n) and Yb, and then the mean of that. Just as a reminder, Yb is basically an array of all the correct indices. So what we're doing here is taking the logprobs array of size 32 by 27, right?
-
Unknown A
And then we are going into every single row, and in each row we are plucking out the index 8, then 14, then 15, and so on. So we're going down the rows; that's the iterator range(n). And then we are always plucking out the index of the column specified by this tensor Yb. So in the 0th row we are taking the 8th column, in the first row we're taking the 14th column, and so on. So logprobs indexed this way plucks out the log probabilities of the correct next character in a sequence. That's what that does. The shape of this is, of course, 32, because our batch size is 32. These elements get plucked out, and then their mean, negated, becomes the loss. I always like to work with simpler examples to understand the numerical form of a derivative.
-
Unknown A
What's going on here is that once we've plucked out these examples, we're taking the mean and then the negative. So I can write the loss for, say, three numbers a, b, and c as loss = -(a + b + c) / 3. That's how we take the negative mean of three numbers, although we actually have 32 numbers here. So what is dloss/da? Well, if we simplify this expression mathematically, it is -1/3 a - 1/3 b - 1/3 c. So dloss/da is just -1/3. And you can see that if we don't just have a, b, and c but 32 numbers, then dloss by d of every one of those numbers is going to be -1/n.
-
Unknown A
More generally, n is the size of the batch, 32 in this case. So dloss by dlogprobs is -1/n in all these places. Now, what about the other elements inside logprobs? logprobs is a large array: its shape is 32 by 27, but only 32 of its elements participate in the loss calculation. So what's the derivative for all the other elements, the majority, which do not get plucked out here? Intuitively, their gradient is zero, because they did not participate in the loss. Most of the numbers inside this tensor do not feed into the loss, so if we were to change them, the loss wouldn't change, which is an equivalent way of saying that the derivative of the loss with respect to them is zero. They don't impact it.
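The same argument written out for a general batch of size n, using y_i as shorthand (introduced here) for the correct column Yb[i] of row i:

```latex
\mathcal{L} = -\frac{1}{n}\sum_{i=1}^{n} \mathrm{logprobs}_{i,\,y_i}
\qquad\Longrightarrow\qquad
\frac{\partial \mathcal{L}}{\partial\,\mathrm{logprobs}_{i,j}} =
\begin{cases}
-\dfrac{1}{n} & \text{if } j = y_i,\\[4pt]
0 & \text{otherwise.}
\end{cases}
```

With n = 3 and the plucked values a, b, c, this reduces to the toy case above: each of a, b, c gets -1/3, and every other entry gets 0.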
-
Unknown A
So here's a way to implement this derivative. We start out with torch.zeros of shape 32 by 27, or rather, because we don't want to hard-code numbers, torch.zeros_like(logprobs). This creates an array of zeros exactly in the shape of logprobs. Then we need to set the derivative of -1/n in exactly these locations. So here's what we can do: dlogprobs, indexed in the identical way, is set to -1.0/n, just like we derived. Let me erase all this reasoning; this is the candidate derivative for dlogprobs. Let's uncomment the first line and check that this is correct. Okay, so cmp ran. Let's go back up to cmp: what it's doing is checking whether the value calculated by us, which is dt, is exactly equal to t.grad as calculated by PyTorch.
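Here is a self-contained sketch of this step. Instead of the notebook's network, it uses random stand-in log-probabilities and random target indices (the seed and data are arbitrary); the sizes 32 and 27 match the lecture.

```python
import torch

torch.manual_seed(0)
n, vocab = 32, 27                   # batch size and vocabulary size used in the lecture
Yb = torch.randint(0, vocab, (n,))  # stand-in for the correct next-character indices
# Stand-in log-probabilities, made a leaf so autograd fills logprobs.grad.
logprobs = torch.log_softmax(torch.randn(n, vocab), dim=1).requires_grad_()

loss = -logprobs[range(n), Yb].mean()
loss.backward()

# Manual gradient: -1/n at each (row, correct column), zero everywhere else.
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0 / n

print(torch.equal(dlogprobs, logprobs.grad))
```

Indexing the zero tensor with the same (range(n), Yb) pair as the forward pass is what places the -1/n values in exactly the plucked positions.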
-
Unknown A
It makes sure that all of the elements are exactly equal and then converts this to a single Boolean value, because we don't want a Boolean tensor, just a Boolean value. Then, if they're not exactly equal, maybe they are approximately equal because of some floating-point issues but still very, very close. So here we are using torch.allclose, which has a little bit of wiggle room: if you use a slightly different calculation, floating-point arithmetic can give you a slightly different result, so this checks whether you get an approximately close result. And then here we are checking the maximum absolute difference between the two, that is, the element with the largest discrepancy.
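A sketch of a comparison helper that follows the three checks just described (exact, approximate, maximum absolute difference). It is written to match this description rather than copied from the notebook, and returning the three values is an addition for convenience.

```python
import torch

def cmp(s, dt, t):
    # Exact: every element of the manual gradient equals autograd's, as one Python bool.
    ex = torch.all(dt == t.grad).item()
    # Approximate: equal up to torch.allclose's default floating-point tolerance.
    app = torch.allclose(dt, t.grad)
    # Largest absolute elementwise discrepancy.
    maxdiff = (dt - t.grad).abs().max().item()
    print(f"{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}")
    return ex, app, maxdiff

# Usage: d/da sum(a * a) = 2a, so the correct manual gradient is 2 * a.
torch.manual_seed(0)
a = torch.randn(3, 4, requires_grad=True)
(a * a).sum().backward()
good = cmp("a", 2 * a.detach(), a)
bad = cmp("a (wrong)", 3 * a.detach(), a)
```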
-
Unknown A
So we are printing whether we have an exact equality, an approximate equality, and what the largest difference is. Here we see that we actually have exact equality, and therefore of course we also have approximate equality, and the maximum difference is exactly zero. So our dlogprobs is exactly equal to what PyTorch calculated to be logprobs.grad in its backpropagation. So far we're doing pretty well. Okay, let's continue our backpropagation. We have that logprobs depends on probs through a log: all the elements of probs have log applied to them element-wise. Now, if we want dprobs, then remember your micrograd training. We have a log node; it takes in probs and creates logprobs. dprobs will be the local derivative of that individual operation, log, times the derivative of the loss with respect to its output, which in this case is dlogprobs.
-
Unknown A
So what is the local derivative of this operation? We are taking log element-wise, and we can come here and see (WolframAlpha is your friend) that d/dx of log(x) is simply 1/x. In this case x is probs, so the local derivative is 1/probs. Then we want to chain it, so by the chain rule we multiply by dlogprobs. Let me uncomment this and run the cell in place, and we see that the derivative of probs as we calculated it here is exactly correct. Notice how this works: probs is going to be inverted and then element-wise multiplied here. If your probs is very, very close to one, meaning your network is currently predicting the character correctly, then this becomes 1/1 and dlogprobs just gets passed through.
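A small standalone check of this step, assuming a random 4 by 5 probability matrix and a random upstream gradient standing in for dlogprobs (both arbitrary choices for illustration):

```python
import torch

torch.manual_seed(0)
# Stand-in probabilities (rows sum to 1), made a leaf so autograd fills probs.grad.
probs = torch.softmax(torch.randn(4, 5), dim=1).requires_grad_()
logprobs = probs.log()

# Any upstream gradient works; a fixed random one stands in for dlogprobs.
dlogprobs = torch.randn(4, 5)
logprobs.backward(dlogprobs)

# Local derivative of log(x) is 1/x, chained with the upstream gradient.
dprobs = (1.0 / probs.detach()) * dlogprobs
print(torch.allclose(dprobs, probs.grad))
```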
-
Unknown A
But if your probabilities are incorrectly assigned, so if the correct character here is getting a very low probability, then 1.0 divided by it will boost this, and then we multiply by dlogprobs. So intuitively, what this line is doing is taking the examples that currently have a very low probability assigned and boosting their gradient. You can look at it that way. Next up is counts_sum_inv, and we want the derivative of this. Let me just pause here and introduce what's happening here in general, because I know it's a little bit confusing. We have the logits that come out of the neural net. What I'm doing is finding the maximum in each row and subtracting it, for the purpose of numerical stability. We talked about how, if you do not do this, you run into numerical issues when some of the logits take on values that are too large, because we end up exponentiating them.
-
Unknown A
So this is done just for numerical safety. Then here's the exponentiation of all the logits to create our counts. Then we want to take the sum of these counts and normalize, so that all of the probs sum to one. Now, instead of using 1/counts_sum, I raise it to the power of negative one. Mathematically they are identical; I just found that there's something wrong with the PyTorch implementation of the backward pass of division, and it gives a weird result, but that doesn't happen for the power of negative one, so I'm using this formula instead. But basically all that's happening here is that we have the logits, we want to exponentiate all of them, and we want to normalize the counts to create our probabilities. It's just happening across multiple lines. So now we want to first backpropagate into counts_sum_inv and then into counts as well.
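The multi-line forward pass described here can be sketched on its own and compared against F.cross_entropy. The random logits and targets are stand-ins for the MLP's output; the intermediate names follow the ones used in the lecture.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, vocab = 32, 27
logits = torch.randn(n, vocab) * 3      # stand-in for the network's output
Yb = torch.randint(0, vocab, (n,))      # stand-in targets

logit_maxes = logits.max(1, keepdim=True).values  # per-row max, for numerical stability
norm_logits = logits - logit_maxes                 # (n, 1) broadcast over (n, vocab)
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1                    # same value as 1.0 / counts_sum
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

ref = F.cross_entropy(logits, Yb)
print(loss.item(), ref.item())
```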
-
Unknown A
So what should dcounts_sum_inv be? We actually have to be careful here, because we have to scrutinize the shapes. counts.shape and counts_sum_inv.shape are different: counts is 32 by 27, but counts_sum_inv is 32 by 1. So in this multiplication there is also an implicit broadcast that PyTorch does, because it needs to take this column tensor of 32 numbers and replicate it horizontally 27 times to align the two tensors so that it can do an element-wise multiply. So really, what this looks like, using a toy example, is the following. What we really have here is just probs = counts * counts_sum_inv, so it's c = a * b. But a is 3 by 3 and b is just 3 by 1, a column tensor.
-
Unknown A
So PyTorch internally replicated the elements of b across all the columns. For example, b1, the first element of b, would be replicated here across all the columns in this multiplication. Now we're trying to backpropagate through this operation to counts_sum_inv. When we calculate this derivative, it's important to realize that this looks like a single operation but is actually two operations applied sequentially. The first operation that PyTorch did is take this column tensor and replicate it across all the columns, 27 times. So that's the first operation, a replication. The second operation is the multiplication. So let's first backpropagate through the multiplication. If these two arrays were the same size, with a and b both 3 by 3, how would we backpropagate through a multiplication?
-
Unknown A
If we just have scalars and not tensors, and c = a * b, then what is the derivative of c with respect to b? It's just a, so that's the local derivative. In our case, backpropagating through just the element-wise multiplication itself, the local derivative is simply counts, because counts is the a. So that's the local derivative, and then, by the chain rule, we multiply by dprobs. This gives the derivative, or the gradient, with respect to the replicated b. But we don't have a replicated b, we just have a single b column. So how do we now backpropagate through the replication? Intuitively, b1 is the same variable and it's just reused multiple times, so you can look at it as being equivalent to a case we've encountered in micrograd.
-
Unknown A
So here I'm just pulling out a random graph we used in micrograd. We had an example where a single node's output feeds into two branches of the graph, all the way until the last function. And we talked about how the correct thing to do in the backward pass is to sum all the gradients that arrive at any one node. So across these different branches the gradients sum: if a node is used multiple times, the gradients from all of its uses sum during backpropagation. Here b1 is used multiple times, in all of these columns, so the right thing to do is to sum horizontally, across the columns of each row. So we want to sum along dimension 1, but we want to retain this dimension, so that counts_sum_inv and its gradient are going to be exactly the same shape.
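The toy case c = a * b, with a of shape 3 by 3 and b a 3 by 1 column, can be checked directly against autograd. This is a sketch with random values and a random upstream gradient dc standing in for dprobs.

```python
import torch

torch.manual_seed(0)
a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 1, requires_grad=True)   # column, broadcast across the 3 columns
c = a * b
dc = torch.randn(3, 3)                      # stand-in upstream gradient
c.backward(dc)

# Gradient for a: local derivative is b (broadcast again), times dc.
da = b.detach() * dc
# Gradient for b: local derivative is a, times dc, then summed over the
# replicated copies of b (dimension 1), keeping b's column shape.
db = (a.detach() * dc).sum(1, keepdim=True)

print(torch.allclose(da, a.grad), torch.allclose(db, b.grad))
```

Dropping keepdim=True would give db shape (3,), which would then broadcast incorrectly against the (3, 1) b in later steps.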
-
Unknown A
So we want to make sure that we pass keepdim=True so we don't lose this dimension, and this will make dcounts_sum_inv exactly of shape 32 by 1. Uncommenting this comparison as well and running it, we see that we get an exact match, so this derivative is exactly correct. Let me erase this. Now let's also backpropagate into counts, which is the other variable used here to create probs. From probs to counts_sum_inv, we just did that; let's go into counts as well. dcounts will be dc by da, and dc by da is just b, so it's counts_sum_inv, and then, by the chain rule, times dprobs. Now, counts_sum_inv is 32 by 1 and dprobs is 32 by 27, so those will broadcast fine and give us dcounts. There's no additional summation required here. There will be a broadcast in this multiply, because counts_sum_inv needs to be replicated again to correctly multiply dprobs.
-
Unknown A
But that's going to give the correct result as far as this single operation is concerned. So we've backpropagated from probs to counts, but we can't actually check the derivative of counts yet; I have that check much later on. The reason is that counts_sum_inv depends on counts, so there's a second branch here that we have to finish: counts_sum_inv backpropagates into counts_sum, and counts_sum will backpropagate into counts. So counts is a node that is being used twice. It's used right here to create probs, and it goes through this other branch through counts_sum_inv. So even though we've calculated its first contribution, we still have to calculate the second contribution later. Okay, so we're continuing with this branch. We have the derivative for counts_sum_inv; now we want the derivative for counts_sum. So dcounts_sum equals: what is the local derivative of this operation?
-
Unknown A
This is basically an element-wise 1/counts_sum: counts_sum raised to the power of -1 is the same as 1 over counts_sum. If we go to WolframAlpha, we see that d/dx of x^-1 is -x^-2. Right? -1/x^2 is the same as -x^-2. So for dcounts_sum, the local derivative is going to be -counts_sum^-2, times, by the chain rule, dcounts_sum_inv. So that's dcounts_sum. Let's uncomment this and check that I am correct. Okay, so we have perfect equality, and there's no sketchiness going on here with any shapes, because these are of the same shape. Next up, we want to backpropagate through this line. We have that counts_sum is counts summed along the rows.
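The power-of-negative-one step in isolation, as a hedged sketch: counts_sum is a random positive 4 by 1 column and dcounts_sum_inv is a random stand-in upstream gradient.

```python
import torch

torch.manual_seed(0)
counts_sum = (torch.rand(4, 1) + 0.5).requires_grad_()  # positive stand-in values
counts_sum_inv = counts_sum**-1
dcounts_sum_inv = torch.randn(4, 1)                     # stand-in upstream gradient
counts_sum_inv.backward(dcounts_sum_inv)

# d/dx x^-1 = -x^-2; note ** binds tighter than unary minus, so this is -(x**-2).
dcounts_sum = -counts_sum.detach()**-2 * dcounts_sum_inv
print(torch.allclose(dcounts_sum, counts_sum.grad))
```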
-
Unknown A
So I wrote out a little bit of help here. We have to keep in mind that counts, of course, is 32 by 27 and counts_sum is 32 by 1. So in this backpropagation we need to take this column of derivatives and transform it into a two-dimensional array of derivatives. What is this operation doing? We're taking some kind of input, say a 3 by 3 matrix a, and summing up each row into a column tensor b with elements b1, b2, b3. That's basically this. Now we have the derivatives of the loss with respect to all the elements of b, and we want the derivative of the loss with respect to all these little a's. So how the b's depend on the a's is what we're after. What is the local derivative of this operation? Well, we can see here that b1 only depends on these elements here.
-
Unknown A
The derivative of b1 with respect to all of these elements down here is 0. But for these elements here, like a11, a12, and a13, the local derivative is 1, right? So db1/da11, for example, is 1; it's 1, 1, and 1. So when we have the derivative of the loss with respect to b1, the local derivative of b1 with respect to these inputs is zero here, but one on these guys. In the chain rule, we have the local derivative times the derivative of b1, and because the local derivative is one on these three elements, multiplying it by the derivative of b1 just gives the derivative of b1. So you can look at it as a router: addition is basically a router of gradient. Whatever gradient comes from above just gets routed equally to all the elements that participate in that addition.
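The "router" view of a row sum can be checked with the same 3 by 3 toy setup: each row's upstream gradient is copied to every element of that row. A sketch with random values, where db is a stand-in upstream gradient:

```python
import torch

torch.manual_seed(0)
a = torch.randn(3, 3, requires_grad=True)
b = a.sum(1, keepdim=True)      # b_i = a_i1 + a_i2 + a_i3, shape (3, 1)
db = torch.randn(3, 1)          # stand-in upstream gradient
b.backward(db)

# Local derivative is 1 for every element in the row, so db is just routed
# (replicated) across the columns; broadcasting does the replication.
da = torch.ones_like(a) * db
print(torch.allclose(da, a.grad))
```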
-
Unknown A
So in this case, the derivative of b1 will just flow equally to the derivatives of a11, a12, and a13. So if we have the derivative of all the elements of b in this column tensor, which is the dcounts_sum we've just calculated, what that amounts to is that all of these now flow to all the elements of a, and they do so horizontally. So what we want is to take dcounts_sum of size 32 by 1 and replicate it 27 times horizontally to create a 32 by 27 array. There are many ways to implement this operation. You could, of course, just replicate the tensor, but I think maybe one clean way is that dcounts is simply torch.ones_like(counts), just a two-dimensional array of ones in the shape of counts,
-
Unknown A
so 32 by 27, times dcounts_sum. This way we're letting the broadcasting implement the replication; you can look at it that way. But then we also have to be careful, because dcounts was already calculated earlier, and that was just the first branch; we're now finishing the second branch. So we need to make sure that these gradients add, so we use +=. Then let's uncomment the comparison and make sure, crossing fingers, that we have the correct result. PyTorch agrees with us on this gradient as well. Okay, hopefully we're getting the hang of this. Now, counts is an element-wise exp of norm_logits, so now we want dnorm_logits. Because it's an element-wise operation, everything is very simple. What is the local derivative of e^x? It's famously just e^x.
-
Unknown A
So that is the local derivative. We've already calculated it and it's inside counts, so we may as well just reuse counts. The local derivative times dcounts: funny as it looks, counts times dcounts is the derivative on norm_logits. Now let's erase this and verify, and it looks good. So that's dnorm_logits. Okay, so we are here on this line. We have dnorm_logits, and we're trying to calculate dlogits and dlogit_maxes, backpropagating through this line. We have to be careful here because the shapes, again, are not the same, so there's an implicit broadcasting happening. norm_logits has the shape 32 by 27, and so does logits, but logit_maxes is only 32 by 1. So there's a broadcast in the minus. Now here I tried to write out a toy example.
-
Unknown A
Again, we basically have c = a - b, and because of the shapes, a and c are 3 by 3, but b is just a column. For every element of c we have to look at how it came to be, and every element of c is just the corresponding element of a minus the associated element of b. So it's very clear that the derivative of every one of these c's with respect to its inputs is one for the corresponding a and negative one for the corresponding b. Therefore the derivatives on c flow equally to the corresponding a's and also to the corresponding b's. But in addition to that, the b's are broadcast, so we have to do the additional sum, just like we did before. And of course the derivatives for the b's undergo a minus, because the local derivative there is negative one.
-
Unknown A
So dc32/db3 is negative one. Let's just implement that. dlogits will be exactly a copy of the derivative on norm_logits, so dlogits = dnorm_logits, and I'll do a .clone() for safety, so we're just making a copy. Then dlogit_maxes will be the negative of dnorm_logits, because of the negative sign. And we have to be careful because logit_maxes is a column. Just like we saw before, because we keep replicating the same elements across all the columns, in the backward pass these are all separate branches of use of that one variable. So we have to do a sum along dimension 1 with keepdim=True so that we don't destroy this dimension, and then dlogit_maxes will be the same shape as logit_maxes. Now, we have to be careful because this dlogits is not the final dlogits: not only do we get gradient signal into logits through here, but logit_maxes is a function of logits, and that's a second branch into logits.
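The subtraction rule above can be checked in isolation with a small sketch. It assumes the lecture's shapes, random stand-in tensors, and, unlike the real network, treats logit_maxes as an independent leaf so that only this one line is tested.

```python
import torch

torch.manual_seed(0)

logits = torch.randn(32, 27, dtype=torch.float64, requires_grad=True)
# Treated as an independent leaf so only this one line is tested
logit_maxes = torch.randn(32, 1, dtype=torch.float64, requires_grad=True)

# Forward: the 32x1 column broadcasts across all 27 columns
norm_logits = logits - logit_maxes

# Hypothetical upstream gradient
dnorm_logits = torch.randn(32, 27, dtype=torch.float64)

# Local derivative +1 for logits: the gradient is copied through
dlogits = dnorm_logits.clone()
# Local derivative -1 for logit_maxes; the column was reused 27 times,
# so sum over dim 1 and keep the 32x1 shape
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)

norm_logits.backward(dnorm_logits)
print(torch.allclose(dlogits, logits.grad), torch.allclose(dlogit_maxes, logit_maxes.grad))  # True True
```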
-
Unknown A
So this is not yet our final derivative for logits; we'll come back later for the second branch. For now, dlogit_maxes is the final derivative. So let me uncomment this cmp here, run it, and check logit_maxes to see whether PyTorch agrees with us. So that was the derivative through this line. Now, before we move on, I want to pause briefly and look at these logit_maxes, and especially their gradients. We talked in the previous lecture about how the only reason we're doing this is the numerical stability of the softmax we're implementing here. And we talked about how, if you take these logits for any one of these examples, so one row of this logits tensor, and you add or subtract any value equally to all of its elements, the value of the probs will be unchanged. You're not changing the softmax.
-
Unknown A
The only thing this is doing is making sure that exp doesn't overflow. And the reason we're using a max is that then we are guaranteed that the highest number in each row of logits is zero, so this will be safe. That has repercussions: if changing logit_maxes does not change the probs, and therefore does not change the loss, then the gradient on logit_maxes should be zero, because saying those two things is the same. So we hope these are very, very small numbers; indeed, we hope they're zero. Because of floating-point wonkiness, this doesn't come out exactly zero (only in some of the rows it does), but we get extremely small values like 1e-9 or 1e-10. So this is telling us that the values of logit_maxes are not impacting the loss, as they shouldn't.
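A small sketch can illustrate both claims above: a per-row shift leaves the softmax unchanged, and the gradient arriving at logit_maxes is zero up to round-off. The targets Yb and all tensors are random stand-ins, and logit_maxes is detached into its own leaf purely so autograd reports its gradient; float64 is used here, so the residue is far smaller than in a float32 run.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 32
logits = torch.randn(n, 27, dtype=torch.float64)
Yb = torch.randint(0, 27, (n,))  # hypothetical targets

# Shifting every row by its own constant leaves the softmax unchanged
shift = torch.randn(n, 1, dtype=torch.float64)
same_probs = torch.allclose(F.softmax(logits, dim=1), F.softmax(logits - shift, dim=1))

# Make logit_maxes a leaf so autograd reports its gradient directly
logit_maxes = logits.max(1, keepdim=True).values.clone().requires_grad_()
norm_logits = logits - logit_maxes
counts = norm_logits.exp()
probs = counts / counts.sum(1, keepdim=True)
loss = -probs[torch.arange(n), Yb].log().mean()
loss.backward()

# Analytically zero; in floating point only tiny round-off remains
print(same_probs, logit_maxes.grad.abs().max().item())
```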
-
Unknown A
It feels kind of weird to backpropagate through this branch, honestly, because if you have an implementation of something like F.cross_entropy in PyTorch, where all of these elements are fused together and you're not doing the backpropagation piece by piece, you would probably assume that the derivative through here is exactly zero. So you would skip this branch, because it's only there for numerical stability. But it's interesting to see that even if you break everything up into its full atoms and still do the computation the way you'd like for numerical stability, the correct thing happens and you still get very, very small gradients here, reflecting the fact that the values of these do not matter for the final loss. Okay, so let's now continue backpropagating through this line here. We've just calculated dlogit_maxes, and now we want to backprop into logits through this second branch.
-
Unknown A
Here, of course, we took logits, took the max along all the rows, and then looked at its values. In PyTorch, max returns both the values and the indices at which those maximum values occurred. In the forward pass we only used the values, because that's all we needed, but in the backward pass it's extremely useful to know where the maximum values occurred, and we have the indices at which they did. This will help us do the backpropagation, because what should the backward pass be here? We have the logits tensor, which is 32 by 27, and in each row we find a maximum value, and that value gets plucked out into logit_maxes. So intuitively, the derivative flowing through here should be the local derivative, which is one for the entry that was plucked out (and zero everywhere else), times the global derivative on logit_maxes.
-
Unknown A
So really, if you think it through, we need to take dlogit_maxes and scatter it to the correct positions in logits, where the maximum values came from. I came up with one line of code that does that. Let me just erase a bunch of stuff here. You could do it very similarly to what we've done before, where we create a tensor of zeros and then populate the correct elements: we use the indices here and set them to one. But you can also use one_hot, so F.one_hot. I'm taking logits.max over the first dimension, taking its indices, and telling PyTorch that each of these one-hot vectors should have dimension 27. So let's see what this produces.
-
Unknown A
Okay, I apologize, this is crazy. Let's plt.imshow it. It's really just an array showing where the max came from in each row: that element is one and all the other elements are zero, so it's a one-hot vector in each row, and these indices populate a single one in the proper place. Then what I'm doing here is multiplying by dlogit_maxes. Keep in mind that this is a 32 by 1 column, so when I multiply by dlogit_maxes, it broadcasts and the column gets replicated. Then the element-wise multiply ensures that each of these gradients gets routed to whichever one of these bits is turned on. So that's another way to implement this kind of operation, and either of them can be used.
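The one_hot version of this scatter can be sketched as follows, under the assumption of random stand-in logits and a random upstream gradient. Because the one-hot mask is an integer tensor, multiplying by the float64 column promotes the result to float64.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

logits = torch.randn(32, 27, dtype=torch.float64, requires_grad=True)
logit_maxes = logits.max(1, keepdim=True).values  # 32x1

# Hypothetical upstream gradient on logit_maxes
dlogit_maxes = torch.randn(32, 1, dtype=torch.float64)

# one_hot puts a single 1 where each row's max came from; multiplying by the
# broadcast 32x1 column routes each row's gradient to that one position
max_idx = logits.max(1).indices
dlogits = F.one_hot(max_idx, num_classes=logits.shape[1]) * dlogit_maxes

logit_maxes.backward(dlogit_maxes)
print(torch.allclose(dlogits, logits.grad))  # True
```

In the full network this result would be added with plus-equals onto the dlogits from the first branch.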
-
Unknown A
I just thought I would show an equivalent way to do it. And I'm using plus-equals because we already calculated dlogits here, and this is now the second branch. So let's check dlogits and make sure it's correct, and we see that we have exactly the correct answer. Next up, we want to continue with logits here. It is the outcome of a matrix multiplication and a bias offset in this linear layer. I've printed out the shapes of all these intermediate tensors. We see that logits is of course 32 by 27, as we've just seen. Then h here is 32 by 64, so these are 64-dimensional hidden states. Then this W2 matrix projects those 64-dimensional vectors into 27 dimensions, and then there's a 27-dimensional offset, which is a one-dimensional vector. Now we should note that this plus actually broadcasts, because h multiplied by W2 gives us a 32 by 27.
-
Unknown A
And then b2, which we add, is a 27-dimensional vector. By the rules of broadcasting, this one-dimensional vector of 27 gets aligned with a padded dimension of 1 on the left, so it basically becomes a row vector, and then it gets replicated vertically 32 times to make it 32 by 27. Then there's an element-wise add. Now the question is, how do we backpropagate from logits to the hidden states h, the weight matrix W2 and the bias b2? You might think that we need to do some matrix calculus and look up the derivative of a matrix multiplication. But actually you don't have to do any of that: you can go back to first principles and derive this yourself on a piece of paper.
-
Unknown A
Specifically, what I like to do, and what I find works well for me, is to find a specific small example and write it out fully. Then, in the process of analyzing how that individual small example works, you will understand the broader pattern, and you'll be able to generalize and write out the full general formula for how the derivatives flow in an expression like this. So let's try that out. Pardon the low-budget production here, but what I've done is write it out on a piece of paper. What we are interested in is A multiplied by B, plus C, which creates D, and we have the derivative of the loss with respect to D. We'd like to know the derivative of the loss with respect to A, B and C. These here are little two-dimensional examples: a 2 by 2 matrix times a 2 by 2 matrix, plus a vector of just two elements, c1 and c2, gives me a 2 by 2.
-
Unknown A
Now notice that I have a bias vector here called c, with elements c1 and c2. But as I described over here, that bias vector becomes a row vector in the broadcasting and is replicated vertically. That's what's happening here as well: c1, c2 is replicated vertically, and we see two rows of c1, c2 as a result. Now, when I say write it out, I just mean like this: break up this matrix multiplication into the actual thing that's going on under the hood. As a result of how matrix multiplication works, d11 is the result of a dot product between the first row of A and the first column of B, so d11 = a11*b11 + a12*b21 + c1, and so on for all the other elements of D.
-
Unknown A
Once you actually write it out, it becomes obvious that this is just a bunch of multiplies and adds, and we know from micrograd how to differentiate multiplies and adds. So this is not scary anymore. It's not really matrix multiplication anymore; it's just tedious, unfortunately. But it is completely tractable. We have dL/dd for all of these, and we want dL/d for all these other little variables. So how do we get those gradients? Okay, the low-budget production continues here. Let's, for example, derive the derivative of the loss with respect to a11. We see that a11 occurs twice in our simple expression, right here and right here, and it influences d11 and d12. So what is dL/da11? Well, it's dL/dd11 times the local derivative of d11 with respect to a11, which in this case is just b11, because that's what's multiplying a11 here.
-
Unknown A
Likewise, the local derivative of d12 with respect to a11 is just b12, so in the chain rule b12 multiplies dL/dd12. And because a11 is used to produce both d11 and d12, we need to add up the contributions of both of those chains running in parallel. That's why we get a plus, adding up those two contributions, and that gives us dL/da11. We can do the exact same analysis for all the other elements of A. When you simply write it out, taking the gradients of expressions like this is super simple. Then look at the matrix dL/dA that we're after, with all of these derivatives arranged in the same shape as A, which is just a 2 by 2 matrix.
-
Unknown A
So dL/dA here will also be a tensor of the same shape, holding the derivatives dL/da11, etc. And we see that we can actually express what we've written out here as a matrix multiply: all of these formulas that we derived by taking gradients happen to be expressible as a matrix multiplication. In particular, it is the matrix multiplication of these two matrices: dL/dD matrix-multiplied with B, but actually B transpose. You see that b21 and b12 have changed places, whereas before we had, of course, b11, b12, b21, b22. So this other matrix, B, is transposed. Long story short, just by doing very simple reasoning, breaking up the expression in a very simple example, we find that dL/dA is simply equal to dL/dD matrix-multiplied with B transpose.
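Written out in the notation of the 2 by 2 example, with D = AB + c and the bias row (c1, c2) replicated down both rows, the algebra described above looks like this; it only restates the element-by-element derivation and its matrix form.

```latex
\begin{aligned}
d_{11} &= a_{11} b_{11} + a_{12} b_{21} + c_1, &\quad d_{12} &= a_{11} b_{12} + a_{12} b_{22} + c_2,\\
d_{21} &= a_{21} b_{11} + a_{22} b_{21} + c_1, &\quad d_{22} &= a_{21} b_{12} + a_{22} b_{22} + c_2,\\[4pt]
\frac{\partial L}{\partial a_{11}} &= \frac{\partial L}{\partial d_{11}} b_{11} + \frac{\partial L}{\partial d_{12}} b_{12}, &\quad
\frac{\partial L}{\partial a_{12}} &= \frac{\partial L}{\partial d_{11}} b_{21} + \frac{\partial L}{\partial d_{12}} b_{22},\\
\frac{\partial L}{\partial a_{21}} &= \frac{\partial L}{\partial d_{21}} b_{11} + \frac{\partial L}{\partial d_{22}} b_{12}, &\quad
\frac{\partial L}{\partial a_{22}} &= \frac{\partial L}{\partial d_{21}} b_{21} + \frac{\partial L}{\partial d_{22}} b_{22},\\[4pt]
\frac{\partial L}{\partial A} &=
\begin{pmatrix} \frac{\partial L}{\partial d_{11}} & \frac{\partial L}{\partial d_{12}} \\ \frac{\partial L}{\partial d_{21}} & \frac{\partial L}{\partial d_{22}} \end{pmatrix}
\begin{pmatrix} b_{11} & b_{21} \\ b_{12} & b_{22} \end{pmatrix}
= \frac{\partial L}{\partial D}\, B^{\top}.
\end{aligned}
```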
-
Unknown A
So that is what we have so far. Now we also want the derivatives with respect to B and C. For B, I'm not actually doing the full derivation, because honestly it's not deep, it's just annoying and exhausting. You can do this analysis yourself. If you take these expressions and differentiate with respect to B instead of A, you'll find that dL/dB is also a matrix multiplication: in this case you take the matrix A, transpose it, and matrix multiply that with dL/dD, and that gives you dL/dB. And for the offsets c1 and c2, if you again just differentiate with respect to c1, you'll find an expression like this, and for c2 an expression like this. Basically, because c1 and c2 are just offsetting these expressions, dL/dc turns out to be simple.
-
Unknown A
You just take dL/dD, the matrix of the derivatives of D, and sum down each column (that is, over the rows), and that gives you the derivatives for c. So, long story short, the backward pass of a matrix multiply is a matrix multiply. Just as in the scalar case, where we had d = a*b + c, we arrive at something very similar, but now with a matrix multiplication instead of a scalar multiplication: the derivative of the loss with respect to A is dL/dD matrix-multiplied with B transpose, and with respect to B it is A transpose matrix-multiplied with dL/dD. In both cases it's a matrix multiplication of the derivative with the other term in the multiplication, and for c it is a sum. Now I'll tell you a secret: I can never remember the formulas that we just derived for backpropagating through a matrix multiplication.
-
Unknown A
But I can backpropagate through these expressions just fine, and the reason this works is that the dimensions have to work out. Let me give you an example. Say I want to create dh. What should dh be? First, I know that the shape of dh must be the same as the shape of h, and the shape of h is 32 by 64. The other piece of information I know is that dh must be some kind of matrix multiplication of dlogits with W2, where dlogits is 32 by 27 and W2 is 64 by 27. There is only a single way to make the shapes work out in this case, and it is indeed the correct result. In particular, dh needs to be 32 by 64, and the only way to achieve that is to take dlogits and matrix multiply it with W2 transposed.
-
Unknown A
You see how I have to take W2, but I have to transpose it to make the dimensions work out. So W2 transpose: that's the only way to matrix multiply those two pieces so that the shapes work out, and it turns out to be the correct formula. If we come back to the derivation, we want dh, which plays the role of dA, and we see that dA is dL/dD matrix-multiplied with B transpose. Here dL/dD is dlogits and B is W2, so it's dlogits times W2 transpose, which is exactly what we have. So there's no need to remember these formulas. Similarly, if I want dW2, I know that it must be a matrix multiplication of dlogits and h, maybe with a transpose in there as well, and I don't know which way around it goes. So I look at W2 and see that its shape is 64 by 27, and that has to come from some matrix multiplication of these two.
-
Unknown A
To get a 64 by 27, I need to take h and transpose it, which makes it 64 by 32, and then matrix multiply it with the 32 by 27, which gives me a 64 by 27. So I need to matrix multiply h transpose with dlogits, just like that. That's the only way to make the dimensions work out using just matrix multiplication. And if we come back to the derivation, we see that's exactly what's there: A transpose times dL/dD is, for us, h transpose multiplied with dlogits. So that's dW2. Then db2 is just the vertical sum. And in the same way, there's only one way to make the shapes work out. I don't have to remember that it's a vertical sum along the zeroth axis, because that's the only way this makes sense: b2's shape is 27.
-
Unknown A
And dlogits here is 32 by 27, so knowing that db2 is just a sum over dlogits in some direction, that direction must be 0, because I need to eliminate this dimension. So it's this. That's kind of the hacky way. Let me copy, paste and delete that, and swing over here. This is, hopefully, our backward pass for the linear layer. Now let's uncomment these three checks for all three derivatives and run, and we see that dh, dW2 and db2 are all exactly correct. So we backpropagated through a linear layer. Next up, we already have the derivative for h, and we need to backpropagate through tanh into hpreact. So we want to derive dhpreact, and here we have to backpropagate through a tanh. We've already done this in micrograd, and we remember that tanh has a very simple backward formula.
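The shape-matching argument for the linear layer above can be checked with a short sketch. It uses the lecture's shapes (32 examples, 64 hidden units, 27 outputs) with random stand-in tensors and a random upstream gradient, and compares the three manual formulas against autograd.

```python
import torch

torch.manual_seed(0)

# Shapes from the lecture: 32 examples, 64 hidden units, 27 output characters
h = torch.randn(32, 64, dtype=torch.float64, requires_grad=True)
W2 = torch.randn(64, 27, dtype=torch.float64, requires_grad=True)
b2 = torch.randn(27, dtype=torch.float64, requires_grad=True)

logits = h @ W2 + b2  # b2 broadcasts as a 1x27 row replicated over 32 rows

# Hypothetical upstream gradient
dlogits = torch.randn(32, 27, dtype=torch.float64)

with torch.no_grad():
    dh = dlogits @ W2.T    # (32x27) @ (27x64) -> 32x64, the shape of h
    dW2 = h.T @ dlogits    # (64x32) @ (32x27) -> 64x27, the shape of W2
    db2 = dlogits.sum(0)   # sum over the 32 examples -> 27, the shape of b2

logits.backward(dlogits)
print(all(torch.allclose(a, b) for a, b in [(dh, h.grad), (dW2, W2.grad), (db2, b2.grad)]))  # True
```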
-
Unknown A
Now, unfortunately, if I just put d/dx of tanh(x) into WolframAlpha, it lets us down. It tells us that it's the hyperbolic secant function squared of x. That's not exactly helpful, but luckily Google Image search does not let us down and gives us the simpler formula. In particular, if a = tanh(z), then da/dz, backpropagating through tanh, is just 1 - a^2. Take note that in 1 - a^2, a is the output of the tanh, not the input z. So da/dz is formulated here in terms of the output of that tanh. Google Image search also has the full derivation here, if you want to take the actual definition of tanh and work through the math to arrive at 1 - tanh^2(z).
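A minimal sketch of the tanh rule above, assuming random stand-ins for hpreact and the upstream gradient dh; the point it illustrates is that the local derivative is computed from the tanh output h rather than from its input.

```python
import torch

torch.manual_seed(0)

hpreact = torch.randn(32, 64, dtype=torch.float64, requires_grad=True)
h = torch.tanh(hpreact)

# Hypothetical upstream gradient on h
dh = torch.randn(32, 64, dtype=torch.float64)

# Local derivative written in terms of the tanh OUTPUT: 1 - tanh(z)^2 = 1 - h^2
dhpreact = (1.0 - h.detach() ** 2) * dh

h.backward(dh)
print(torch.allclose(dhpreact, hpreact.grad))  # True
```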
-
Unknown A
So 1 - a^2 is the local derivative. In our case, that is 1 minus the output of the tanh squared, and the output here is h, so it's 1 - h^2. That is the local derivative, and then by the chain rule we multiply by dh. That's our candidate implementation. If we come here and uncomment this, let's hope for the best, and we have the right answer. Okay, next up we have dhpreact, and we want to backpropagate into bngain, bnraw and bnbias. These are the batch norm parameters, bngain and bnbias, which take bnraw, which is exactly unit Gaussian, and scale it and shift it. Here we have a multiplication, but it's worth noting that this multiply is very, very different from this matrix multiply.
-
Unknown A
The matrix multiply consists of dot products between rows and columns of the matrices involved, whereas this is an element-wise multiply, so things are quite a bit simpler. We do have to be careful with some of the broadcasting happening in this line of code, though. You see how bngain and bnbias are 1 by 64, but hpreact and bnraw are 32 by 64. So we have to be careful that all the shapes work out and that the broadcasting is correctly backpropagated. In particular, let's start with dbngain. Here, again, this is an element-wise multiply, and whenever we have a times b equals c, we saw that the local derivative with respect to a is just b, the other factor. So the local derivative is just bnraw, and by the chain rule we multiply it by dhpreact.
-
Unknown A
So bnraw times dhpreact is the candidate gradient. Again we have to be careful, because bngain is of size 1 by 64, but this here would be 32 by 64. bngain is a row vector of 64 numbers, and it gets replicated vertically in this operation. Because it's being replicated, the correct thing to do is to sum: all the gradients in each of the rows flowing backwards need to sum up into that same tensor, dbngain. So we have to sum across dimension 0, across all the examples, which is the direction in which this gets replicated. We also have to be careful because bngain is of shape 1 by 64, so in fact I need keepdim=True; otherwise I would just get 64.
-
Unknown A
Now, I don't actually remember why I made bngain and bnbias 1 by 64, while the biases b1 and b2 are just one-dimensional vectors rather than two-dimensional tensors. I can't recall exactly why I left the gain and the bias two-dimensional, but it doesn't really matter as long as you are consistent and keep it the same. In this case we want to keep the dimension so that the tensor shapes work. Next up we have bnraw. dbnraw will be bngain multiplying dhpreact; that's our chain rule. Now, what about the dimensions of this? We have to be careful, right? dhpreact is 32 by 64 and bngain is 1 by 64, so bngain will just get replicated to create this multiplication, which is the correct thing, because in the forward pass it also gets replicated in just the same way.
-
Unknown A
So in fact we don't need the brackets here; we're done, and the shapes are already correct. And finally, the bias is very similar. This bias is very similar to the bias we saw in the linear layer: the gradients from hpreact simply flow into the biases and add up, because these are just offsets. So we want this to be dhpreact, but summed along the right dimension. Similar to the gain, we need to sum across the zeroth dimension, the examples, because of the way the bias gets replicated vertically, and we also want keepdim=True. This takes dhpreact, sums it up, and gives us a 1 by 64. So this is the candidate implementation, and it makes all the shapes work. Let me bring it down here and uncomment these three lines to check that we get the correct result for all three tensors.
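The three gradients above can be sketched and checked in one go. The sketch assumes the lecture's shapes (1 by 64 parameters, a 32 by 64 batch) and random stand-ins for bnraw, the parameters and the upstream gradient dhpreact.

```python
import torch

torch.manual_seed(0)

bnraw = torch.randn(32, 64, dtype=torch.float64, requires_grad=True)
bngain = torch.randn(1, 64, dtype=torch.float64, requires_grad=True)
bnbias = torch.randn(1, 64, dtype=torch.float64, requires_grad=True)

# Element-wise; the 1x64 rows broadcast over the 32 examples
hpreact = bngain * bnraw + bnbias

# Hypothetical upstream gradient
dhpreact = torch.randn(32, 64, dtype=torch.float64)

with torch.no_grad():
    # gain and bias were replicated down the batch, so their gradients sum over dim 0
    dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
    dbnbias = dhpreact.sum(0, keepdim=True)
    # bngain broadcasts to 32x64 exactly as in the forward pass; no sum needed
    dbnraw = bngain * dhpreact

hpreact.backward(dhpreact)
```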
-
Unknown A
And indeed, we see that all of that got backpropagated correctly. Now we get to the batch norm layer. We see here that bngain and bnbias are parameters, so the backpropagation ends there, but bnraw is the output of the standardization. What I'm doing here, of course, is breaking the batch norm up into manageable pieces so we can backpropagate through each line individually. Basically, bnmeani is the mean, mu, which is 1 over n times the sum; I apologize for the variable naming. bndiff is x minus mu, and bndiff2 is (x minus mu) squared, here inside the variance. bnvar is the variance, sigma squared: it's basically the sum of squares of x minus mu, normalized. Now you'll notice one departure here.
-
Unknown A
In the paper, the variance is normalized by 1 over m, where m is the number of examples. Here I'm normalizing by 1 over n minus 1 instead. This is deliberate, and I'll come back to it when we're at this line; it is called Bessel's correction, but this is how I want it. Then bnvar_inv is basically bnvar plus epsilon, where epsilon is 1e-5, raised to the power of negative 0.5. That's the same as one over the square root, because the 0.5 power is the square root and the negative sign makes it one over. So bnvar_inv is one over this denominator here. Then we can see that bnraw, which is the x hat here, is equal to bndiff, the numerator, multiplied by bnvar_inv. And this line here that creates hpreact was the last piece.
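The forward decomposition described above can be sketched as follows. The variable names follow the ones read out in the lecture, while the input name hprebn and its random contents are assumptions for illustration. Because bnvar uses Bessel's correction, the result matches a direct standardization that uses torch.var's default (unbiased) estimate.

```python
import torch

torch.manual_seed(0)
n = 32
hprebn = torch.randn(n, 64, dtype=torch.float64)  # hypothetical pre-batchnorm activations

bnmeani = 1 / n * hprebn.sum(0, keepdim=True)       # mean over the batch, 1x64
bndiff = hprebn - bnmeani                           # x - mu, 32x64
bndiff2 = bndiff ** 2                               # (x - mu)^2
bnvar = 1 / (n - 1) * bndiff2.sum(0, keepdim=True)  # Bessel's correction: n - 1
bnvar_inv = (bnvar + 1e-5) ** -0.5                  # 1 / sqrt(var + eps)
bnraw = bndiff * bnvar_inv                          # x_hat, roughly unit Gaussian per column
```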
-
Unknown A
We've already backpropagated through it. So now we are here: we have dbnraw, and we first have to backpropagate into bndiff and bnvar_inv through this line. I've written out the shapes here, and indeed bnvar_inv has shape 1 by 64, so there is a broadcast happening that we have to be careful with. But it is just a simple element-wise multiplication, and by now we should be pretty comfortable with that. To get dbndiff, we know that it's just bnvar_inv multiplied with dbnraw. And conversely, to get dbnvar_inv, we take bndiff and multiply it by dbnraw. So this is the candidate, but of course we need to make sure that broadcasting is obeyed. In particular, bnvar_inv multiplying dbnraw will be fine and give us 32 by 64, as we expect.
-
Unknown A
But dbnvar_inv would be a 32 by 64 multiplied by a 32 by 64, so a 32 by 64, whereas bnvar_inv is only 1 by 64. So the second line needs a sum across the examples, and because of this dimension here, we need keepdim=True. So this is the candidate. Let's erase this, swing down here and implement it, and then uncomment the checks for dbnvar_inv and dbndiff. We'll actually notice that dbndiff is going to be incorrect: when I run this, bnvar_inv is correct but bndiff is not. This is expected, because we're not done with bndiff. In particular, when we slide down here, we see that bnraw is a function of bndiff, but bnvar_inv is a function of bnvar, which is a function of bndiff2, which is in turn a function of bndiff.
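A sketch of this one line, under the assumption that bndiff and bnvar_inv are independent random leaves. That isolates the first branch into bndiff; in the real network bnvar_inv also depends on bndiff, which is exactly why the full dbndiff check fails at this point.

```python
import torch

torch.manual_seed(0)

# Independent leaves, so only this one line (first branch into bndiff) is tested
bndiff = torch.randn(32, 64, dtype=torch.float64, requires_grad=True)
bnvar_inv = torch.rand(1, 64, dtype=torch.float64, requires_grad=True)
bnraw = bndiff * bnvar_inv

# Hypothetical upstream gradient
dbnraw = torch.randn(32, 64, dtype=torch.float64)

with torch.no_grad():
    dbndiff = bnvar_inv * dbnraw                         # 1x64 broadcasts -> 32x64
    dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)  # 32x64 reduced back to 1x64

bnraw.backward(dbnraw)
```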
-
Unknown A
So it comes back here. These variable names are crazy, I'm sorry. bndiff branches out into two branches, and we've only done one of them. We have to continue our backpropagation and eventually come back to bndiff, and then we'll be able to do a plus-equals and get the actual correct gradient. For now, it's good to verify that cmp also works: it doesn't just lie to us and tell us that everything is always correct. It can in fact detect when a gradient is not correct, so that's good to see as well. Okay, so now we have the derivative here, and we're trying to backpropagate through this line. Because we're raising to the power of negative 0.5, I brought up the power rule, and we see that for dbnvar we bring down the exponent, giving negative 0.5 times x, where x here is bnvar plus epsilon.
-
Unknown A
That x is now raised to the power of negative 0.5 minus 1, which is negative 1.5. We would also have to apply a small chain rule in our head, because we need the derivative of the expression inside the bracket, bnvar plus epsilon, with respect to bnvar. But because this is an element-wise operation and everything is fairly simple, that's just one, so there's nothing to do there. So this is the local derivative, and then we multiply by the global derivative, dbnvar_inv, to complete the chain rule. This is our candidate. Let me bring this down and uncomment the check, and we see that we have the correct result. Now, before we backpropagate through the next line, I want to briefly talk about the note here, where I'm using Bessel's correction, dividing by n - 1 instead of dividing by n.
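The power-rule step above can be sketched in isolation, assuming a random positive stand-in for bnvar, the lecture's epsilon of 1e-5, and a random upstream gradient dbnvar_inv.

```python
import torch

torch.manual_seed(0)
eps = 1e-5

# Positive stand-in for the variance row (1x64)
bnvar = (torch.rand(1, 64, dtype=torch.float64) + 0.5).requires_grad_()
bnvar_inv = (bnvar + eps) ** -0.5

# Hypothetical upstream gradient
dbnvar_inv = torch.randn(1, 64, dtype=torch.float64)

with torch.no_grad():
    # Power rule: d/dx x**-0.5 = -0.5 * x**-1.5; the inner derivative of (bnvar + eps) is 1
    dbnvar = -0.5 * (bnvar + eps) ** -1.5 * dbnvar_inv

bnvar_inv.backward(dbnvar_inv)
```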
-
Unknown A
The note concerns how I normalize the sum of squares here. You'll notice that this is a departure from the paper, which uses 1 over m, not 1 over m minus 1; their m is our n. It turns out that there are two ways of estimating the variance of an array. One is the biased estimate, which is 1 over n, and the other is the unbiased estimate, which is 1 over n minus 1. Confusingly, the paper does not describe this very clearly, and it's a detail that kind of matters. I think they use the biased version at training time, but later, when they talk about inference, they mention that they use the unbiased estimate, the n minus 1 version, basically for inference and to calibrate the running mean and running variance.
-
Unknown A
So they actually introduce a train/test mismatch: in training they use the biased version, and at test time they use the unbiased version. I find this extremely confusing. You can read more about Bessel's correction and why dividing by n minus 1 gives you a better estimate of the variance when your samples from a population are very small. That is indeed the case for us, because we are dealing with mini-batches, and these mini-batches are small samples of a larger population, which is the entire training set. It turns out that if you estimate the variance using 1 over n, you almost always underestimate it: it is a biased estimator, and it is advised that you use the unbiased version and divide by n minus 1. You can go through this article that I liked, which describes the full reasoning, and I'll link it in the video description.
-
Unknown A
Now, when you calculate torch.var, you'll notice that it takes an unbiased flag for whether you want to divide by n or n minus 1. Confusingly, the docs do not mention what the default for unbiased is, but I believe it is True by default. I'm not sure why the docs here don't say that. In BatchNorm1d, the documentation is again kind of wrong and confusing. It says that the standard deviation is calculated via the biased estimator, but this is not exactly right, and people have pointed that out in a number of issues since, because the rabbit hole is deeper: they follow the paper exactly, using the biased version for training, but when they estimate the running standard deviation, they use the unbiased version. So again, there's the train/test mismatch. Long story short, I'm not a fan of train/test discrepancies.
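A small sketch can compare the two estimators and check, empirically, which one torch.var returns when no flag is passed. It relies only on the default call, not on any particular keyword, and uses a random stand-in batch of 32 rows.

```python
import torch

torch.manual_seed(0)
n = 32
x = torch.randn(n, 64, dtype=torch.float64)

# Sum of squared deviations from the per-column mean
sq = ((x - x.mean(0, keepdim=True)) ** 2).sum(0, keepdim=True)

var_biased = sq / n          # 1/n: the biased estimate
var_unbiased = sq / (n - 1)  # 1/(n-1): Bessel's correction

# With no flag passed, torch.var returns the unbiased (n - 1) estimate
var_default = x.var(0, keepdim=True)
print(torch.allclose(var_default, var_unbiased))  # True
```

The two estimates differ only by the factor n / (n - 1), which is why the gap shrinks as the mini-batch grows.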
-
Unknown A
I basically consider the fact that we use the biased version at training time and the unbiased version at test time to be a bug, and I don't think there's a good reason for it. The paper doesn't really go into the reasoning behind it. That's why I prefer to use Bessel's correction in my own work. Unfortunately, BatchNorm1d does not take a keyword argument that lets you choose the unbiased or the biased version in both train and test. So, in my view, anyone using batch normalization basically has a bit of a bug in their code. This turns out to be much less of a problem if your mini-batch sizes are a bit larger, but still, I find it kind of unpalatable.
-
Unknown A
So maybe someone can explain why this is okay, but for now I prefer to use the unbiased version consistently, both during training and at test time. And that's why I'm using 1/(n - 1) here. Okay, so let's now actually backpropagate through this line. The first thing I always like to do is scrutinize the shapes. Looking at the shapes involved, I see that bnvar has shape 1 by 64, so it's a row vector, and bndiff2 has shape 32 by 64. So clearly we're doing a sum over the zeroth axis to squash the first dimension of the shape. That right away hints to me that there will be some kind of replication or broadcasting in the backward pass. And maybe you're noticing the pattern here: any time you have a sum in the forward pass, it turns into a replication or broadcasting in the backward pass along the same dimension.
-
Unknown A
And conversely, when we have a replication or broadcasting in the forward pass, that indicates variable reuse, and so in the backward pass it turns into a sum over the exact same dimension. Hopefully you're noticing that duality: the two are kind of opposites of each other in the forward and backward pass. Now, once we understand the shapes, the next thing I always like to do is look at a toy example in my head, just to understand roughly how the variable dependencies go in the mathematical formula. Here we have a two-dimensional array, bndiff2, which we scale by a constant and then sum vertically over the columns. So if we have a 2 by 2 matrix A and we sum over the columns and scale, we get a row vector b1, b2, where b1 is the scaled sum of the first column of A and b2 is the scaled sum of the second column.
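The sum/broadcast duality can be checked numerically on a tiny example. The following NumPy sketch is an illustration, not the lecture's code. The arrays a and r and the upstream gradients d_row and d_full are arbitrary random stand-ins, and each analytic gradient is compared against a central finite difference along a random direction.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 3))       # a tiny "batch": 4 rows, 3 columns
r = rng.normal(size=(1, 3))       # a row vector
d_row = rng.normal(size=(1, 3))   # stand-in upstream gradient for a (1, 3) output
d_full = rng.normal(size=(4, 3))  # stand-in upstream gradient for a (4, 3) output

def sum_loss(a):
    # forward: a sum over dim 0 squashes (4, 3) -> (1, 3)
    return (a.sum(0, keepdims=True) * d_row).sum()

def bcast_loss(r):
    # forward: r is broadcast (reused) down all 4 rows
    return ((a + r) * d_full).sum()

da = np.ones_like(a) * d_row        # backward of a sum: replicate the gradient down dim 0
dr = d_full.sum(0, keepdims=True)   # backward of a broadcast: sum the gradients over dim 0

def directional_fd(f, x, seed=1, h=1e-5):
    """Central difference of f along a random direction v; should match sum(grad * v)."""
    v = np.random.default_rng(seed).normal(size=x.shape)
    return (f(x + h * v) - f(x - h * v)) / (2 * h), v

num_a, va = directional_fd(sum_loss, a)
num_r, vr = directional_fd(bcast_loss, r)
print(np.isclose(num_a, (da * va).sum()), np.isclose(num_r, (dr * vr).sum()))
```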
-
Unknown A
So looking at this, what we want to do now is take the derivatives on b1 and b2 and backpropagate them into the a's. Just differentiating in your head, it's clear that the local derivative here is 1/(n - 1) times 1 for each one of these a's. So the derivative of b1 has to flow down through the first column of A, scaled by 1/(n - 1), and that's roughly what's happening here. So intuitively, the derivative flow tells us that dbndiff2 will be the local derivative of this operation times the upstream gradient. There are many ways to do this, by the way, but I like to do something like torch.ones_like(bndiff2). So I'll create a large two-dimensional array of ones and then scale it by 1.0 / (n - 1).
-
Unknown A
So this is an array of 1/(n - 1), and that's sort of the local derivative. Now, for the chain rule, I simply multiply it by dbnvar. And notice what's going to happen here: this is 32 by 64, and this is just 1 by 64. So I'm letting broadcasting do the replication. Internally in PyTorch, dbnvar, which is a 1 by 64 row vector, gets copied vertically in this multiplication until the two are the same shape, and then there is an element-wise multiply. So the broadcasting is doing the replication, and I end up with the derivatives dbndiff2 here. So this is the candidate solution. Let's bring it down here, uncomment the line where we check it, and hope for the best. And indeed we see that this is the correct formula.
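Here is a NumPy sketch of this one step, assuming the lecture's shapes (n = 32 rows, 64 columns). bndiff2 and dbnvar are random stand-ins for the real tensors, and np.ones_like plays the role of torch.ones_like. Broadcasting the (1, 64) dbnvar against the (32, 64) array of ones is what performs the replication.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 32
bndiff2 = rng.random((n, 64))        # stand-in for bndiff**2
dbnvar = rng.normal(size=(1, 64))    # stand-in upstream gradient on bnvar

def bnvar_of(bndiff2):
    # Bessel-corrected variance: a sum over dim 0 gives a (1, 64) row vector
    return 1 / (n - 1) * bndiff2.sum(0, keepdims=True)

def loss(bndiff2):
    # a scalar whose gradient with respect to bnvar is exactly dbnvar
    return (bnvar_of(bndiff2) * dbnvar).sum()

# backward: local derivative 1/(n-1) everywhere; broadcasting copies dbnvar down all 32 rows
dbndiff2 = 1.0 / (n - 1) * np.ones_like(bndiff2) * dbnvar

# directional finite-difference check
h = 1e-5
v = rng.normal(size=bndiff2.shape)
numeric = (loss(bndiff2 + h * v) - loss(bndiff2 - h * v)) / (2 * h)
print(dbndiff2.shape, np.isclose(numeric, (dbndiff2 * v).sum()))
```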
-
Unknown A
Next up, let's differentiate into bndiff. Here, bndiff is element-wise squared to create bndiff2. This is a relatively simple derivative because it's an element-wise operation, so it's like the scalar case: if this is x squared, then its derivative is 2x, right? So the local derivative is simply 2 * bndiff. Then, for the chain rule, since the two have the same shape, we just multiply by dbndiff2. So that's the backward pass for this variable. Let me bring that down here. And now we have to be careful, because we already calculated dbndiff, right? This is just the end of the other branch coming back to bndiff, because bndiff was already backpropagated to way over here from bnraw.
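A minimal NumPy sketch of why this step needs plus-equals: when a tensor feeds two branches, the gradients from both branches add. This toy setup is hypothetical. In the lecture the second use of bndiff is bnraw = bndiff * bnvar_inv, and here its incoming gradient is stood in by a random array w.

```python
import numpy as np

rng = np.random.default_rng(2)
bndiff = rng.normal(size=(32, 64))
dbndiff2 = rng.normal(size=(32, 64))   # upstream gradient on bndiff2 (the variance branch)
w = rng.normal(size=(32, 64))          # hypothetical stand-in for the gradient via the bnraw branch

def loss(bndiff):
    bndiff2 = bndiff ** 2
    # bndiff feeds two branches: the squared (variance) branch and a linear branch
    return (bndiff2 * dbndiff2).sum() + (bndiff * w).sum()

# branch 1: the other use of bndiff, already backpropagated earlier
dbndiff = w.copy()
# branch 2: element-wise square, local derivative 2x, times the upstream gradient
dbndiff += 2 * bndiff * dbndiff2       # += because gradients from both branches add

h = 1e-5
v = rng.normal(size=bndiff.shape)
numeric = (loss(bndiff + h * v) - loss(bndiff - h * v)) / (2 * h)
print(np.isclose(numeric, (dbndiff * v).sum()))
```

Replacing the += with a plain assignment would silently drop the first branch, which is exactly the "incorrect derivative for bndiff" seen before this line was added.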
-
Unknown A
So we have now completed the second branch, and that's why I have to do plus-equals. If you recall, we had an incorrect derivative for bndiff before, and I'm hoping that once we append this last missing piece, we get exact correctness. So let's run, and bndiff now shows the exact correct derivative. That's comforting. Okay, so let's now backpropagate through this line here. The first thing we do, of course, is check the shapes, and I wrote them out here. The shape of this is 32 by 64, hprebn is the same shape, but bnmeani is a row vector, 1 by 64. So this minus will actually do broadcasting, and we have to be careful with that. And as a hint, again because of the duality, broadcasting in the forward pass means variable reuse, and therefore there will be a sum in the backward pass.
-
Unknown A
So let's write out the backward pass here. Backpropagating into hprebn: because these are the same shape, the local derivative for each element here is just 1 for the corresponding element in here. So the gradient simply copies over; it's just like a variable assignment. I'm going to clone this tensor, just for safety, to create an exact copy of dbndiff. Then, to backpropagate into this one, what I'm inclined to do is this: dbnmeani is the local derivative, negative torch.ones_like(bndiff), times the derivative here, dbndiff. And this is the backpropagation for the replicated bnmeani, so I still have to backpropagate through the replication, the broadcasting.
-
Unknown A
And I do that by doing a sum. So I'm going to take this whole thing and sum over the 0th dimension, which was the replication. If you scrutinize this, by the way, you'll notice that this is the same shape as that. So what I'm doing here doesn't make that much sense, because it's just an array of ones multiplying dbndiff. In fact, I can just do this, and it is equivalent. So this is the candidate backward pass. Let me copy it here, then comment out this one and this one. Enter, and it's wrong. Damn. Actually, sorry, this is supposed to be wrong. It's supposed to be wrong because we are backpropagating from bndiff into hprebn, but we're not done: bnmeani depends on hprebn, and there will be a second portion of that derivative coming from this second branch.
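A NumPy sketch of the subtraction's backward pass, treating bnmeani as an independent input for the moment (its own dependence on hprebn is the second branch handled next). The tensors are random stand-ins, and NumPy's .copy() plays the role of PyTorch's .clone().

```python
import numpy as np

rng = np.random.default_rng(3)
hprebn = rng.normal(size=(32, 64))
bnmeani = rng.normal(size=(1, 64))     # treated as an independent input in this sketch
dbndiff = rng.normal(size=(32, 64))    # stand-in upstream gradient on bndiff

def loss(hprebn, bnmeani):
    bndiff = hprebn - bnmeani          # (32, 64) - (1, 64): bnmeani is broadcast down the rows
    return (bndiff * dbndiff).sum()

dhprebn = dbndiff.copy()                      # same shape, local derivative 1: just copy
dbnmeani = (-dbndiff).sum(0, keepdims=True)   # local derivative -1, then sum over the broadcast dim

h = 1e-5
v = rng.normal(size=hprebn.shape)
u = rng.normal(size=bnmeani.shape)
num_h = (loss(hprebn + h * v, bnmeani) - loss(hprebn - h * v, bnmeani)) / (2 * h)
num_m = (loss(hprebn, bnmeani + h * u) - loss(hprebn, bnmeani - h * u)) / (2 * h)
print(np.isclose(num_h, (dhprebn * v).sum()), np.isclose(num_m, (dbnmeani * u).sum()))
```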
-
Unknown A
So we're not done yet, and we expect it to be incorrect. So there you go. Let's now backpropagate from bnmeani into hprebn. Here again we have to be careful, because there's a sum along the 0th dimension, so this will turn into broadcasting in the backward pass. I'm going to go a little bit faster on this line because it is very similar to the line we had before, and to multiple lines in the past, in fact. The gradient on dbnmeani is going to be scaled by 1/n, and then it's going to flow down all the rows and deposit itself into dhprebn. So what we want is this thing scaled by 1/n. Let me put the constant up front here.
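Putting both branches together, here is a NumPy sketch of backpropagating through mean-centering. The mean branch is scaled by 1/n, broadcast down every row, and accumulated with +=. The random data is a stand-in. The identity asserted at the end (the upstream gradient minus its column mean) is my own observation about this toy case, not something stated in the lecture.

```python
import numpy as np

rng = np.random.default_rng(4)
n = 32
hprebn = rng.normal(size=(n, 64))
dbndiff = rng.normal(size=(n, 64))     # stand-in upstream gradient on bndiff

def loss(hprebn):
    bnmeani = 1 / n * hprebn.sum(0, keepdims=True)   # (1, 64)
    bndiff = hprebn - bnmeani                         # hprebn is used twice
    return (bndiff * dbndiff).sum()

# branch 1: the direct path through bndiff
dhprebn = dbndiff.copy()
dbnmeani = -dbndiff.sum(0, keepdims=True)
# branch 2: through the mean; scale by 1/n, broadcast down every row, accumulate
dhprebn += 1.0 / n * np.ones_like(hprebn) * dbnmeani

# together the two branches amount to "subtract the column mean of the upstream gradient"
assert np.allclose(dhprebn, dbndiff - dbndiff.mean(0, keepdims=True))

h = 1e-5
v = rng.normal(size=hprebn.shape)
numeric = (loss(hprebn + h * v) - loss(hprebn - h * v)) / (2 * h)
print(np.isclose(numeric, (dhprebn * v).sum()))
```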
-
Unknown A
So scale down the gradient. Now we need to replicate it across all the rows here. I like to do that with torch.ones_like(hprebn), and I'll let broadcasting do the work of replication. So like that. This is dhprebn, and hopefully we can plus-equals that. So this here is the broadcasting and this is the scaling, and this should be correct. Okay, so that completes the backpropagation of the batch norm layer. Now let's backpropagate through the first linear layer. Because everything is getting a little vertically crazy, I copy-pasted the line here, so let's just backpropagate through this one line. First, of course, we inspect the shapes: this is 32 by 64, embcat is 32 by 30, W1 is 30 by 64, and b1 is just 64.
-
Unknown A
As I mentioned, backpropagating through linear layers is fairly easy just by matching the shapes, so let's do that. dembcat should be some matrix multiplication of dhprebn with W1, with one transpose thrown in there. To make dembcat 32 by 30, I need to take dhprebn, which is 32 by 64, and multiply it by W1 transpose. To get dW1, I need to end up with 30 by 64, so I take embcat transpose and multiply that by dhprebn. And finally, for db1: this is an addition, and we saw that I just need to sum the elements of dhprebn along some dimension. To make the dimensions work out, I need to sum along the zeroth axis to eliminate that dimension. We do not keep dims, because we want just a single one-dimensional vector of 64.
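Here is the shape-matching recipe for the linear layer as a NumPy sketch, using the lecture's shapes (embcat 32×30, W1 30×64, b1 of length 64). All values are random stand-ins. Each gradient is checked with a directional finite difference while the other inputs are held fixed.

```python
import numpy as np

rng = np.random.default_rng(5)
embcat = rng.normal(size=(32, 30))
W1 = rng.normal(size=(30, 64))
b1 = rng.normal(size=(64,))
dhprebn = rng.normal(size=(32, 64))     # stand-in upstream gradient on hprebn

def loss(embcat, W1, b1):
    hprebn = embcat @ W1 + b1           # b1, shape (64,), is broadcast over the 32 rows
    return (hprebn * dhprebn).sum()

# backward, found by matching shapes
dembcat = dhprebn @ W1.T                # (32, 64) @ (64, 30) -> (32, 30)
dW1 = embcat.T @ dhprebn                # (30, 32) @ (32, 64) -> (30, 64)
db1 = dhprebn.sum(0)                    # sum over the broadcast dim, no keepdims -> (64,)

# directional finite-difference check of each gradient, holding the other inputs fixed
params = [embcat, W1, b1]
grads = [dembcat, dW1, db1]
h = 1e-5
checks = []
for i, (p, g) in enumerate(zip(params, grads)):
    v = rng.normal(size=p.shape)
    plus, minus = list(params), list(params)
    plus[i] = p + h * v
    minus[i] = p - h * v
    numeric = (loss(*plus) - loss(*minus)) / (2 * h)
    checks.append(bool(np.isclose(numeric, (g * v).sum())))
print(dembcat.shape, dW1.shape, db1.shape, checks)
```

Every gradient comes out with the same shape as the tensor it belongs to. That shape rule is what makes the transpose placement unambiguous.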
-
Unknown A
So these are the claimed derivatives. Let me put them here, uncomment three lines, and cross our fingers. Everything is great. Okay, so we continue; we're almost there. We have the derivative of embcat, and we want to backpropagate into emb. I again copied this line over here: this is the forward pass, and these are the shapes. Remember that the shape here was 32 by 30, and the original shape of emb was 32 by 3 by 10. So this layer in the forward pass, as you recall, did the concatenation of these three ten-dimensional character vectors, and now we just want to undo that. This is actually a relatively straightforward operation, because a view is just a re-representation of the array. It's just a logical form of how you interpret the array.
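Because a view only reinterprets memory, its backward pass is just the reverse reinterpretation. In the following NumPy sketch, reshape stands in for PyTorch's view, and emb and dembcat are random stand-ins with the lecture's shapes.

```python
import numpy as np

rng = np.random.default_rng(6)
emb = rng.normal(size=(32, 3, 10))
embcat = emb.reshape(emb.shape[0], -1)   # forward: concatenate the three 10-d vectors -> (32, 30)
dembcat = rng.normal(size=(32, 30))      # stand-in upstream gradient on embcat

demb = dembcat.reshape(emb.shape)        # backward: view the gradient in the original shape

# each gradient entry lands on the same element it came from
assert np.array_equal(demb.reshape(32, 30), dembcat)
# e.g. columns 10..19 of embcat are the second character of each example, i.e. emb[:, 1]
assert np.array_equal(embcat[:, 10:20], emb[:, 1])
print(demb.shape)
```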
-
Unknown A
So let's just reinterpret it to be what it was before. In other words, demb is not 32 by 30; it is basically dembcat, but viewed in the original shape, emb.shape. You can pass tuples into view. So this should just be okay: we re-represent that view, then uncomment this line here, and hopefully... yeah, the derivative of emb is correct. In this case we just have to re-represent the shape of those derivatives in the original view. Now we are at the final line, and the only thing left to backpropagate through is this indexing operation here, C[Xb]. As I did before, I copy-pasted this line here; let's look at the shapes of everything involved and remind ourselves how this worked.
-
Unknown A
So emb.shape was 32 by 3 by 10: 32 examples, three characters each, and each character has a 10-dimensional embedding. This was achieved by taking the lookup table C, which has 27 possible characters, each 10-dimensional, and looking up the rows specified inside this tensor Xb. Xb is 32 by 3, and for each example it gives us the identity, or index, of each character that is part of that example. Here I'm showing the first five rows of Xb. We can see, for example, that for the first example in this batch, the characters with indices 1, 1 and 4 come into the neural net, and we want to predict the next character in the sequence after 1, 1, 4.
-
Unknown A
So basically, there are integers inside Xb, and each of these integers specifies which row of C we want to pluck out, right? Then we arrange the rows we've plucked out into a 32 by 3 by 10 tensor; we just package them into that tensor. And now we have demb. So for every one of these plucked-out rows we now have its gradient, but the gradients are arranged inside this 32 by 3 by 10 tensor. All we have to do now is route this gradient backwards through this assignment. We need to find which row of C each of these ten-dimensional embeddings came from, and then deposit the gradients into dC. So we just need to undo the indexing. And of course, if any row of C was used multiple times, which is almost certainly the case (row 1, for example, was used multiple times), then we have to remember that the gradients that arrive there have to add.
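Here is a NumPy sketch of the embedding-lookup backward pass, with random stand-ins for C, Xb and demb. It shows the explicit loop described next, plus one vectorized alternative. The alternative is NumPy's np.add.at, which performs unbuffered in-place addition so that repeated indices accumulate; that choice is mine, not the lecture's. The last lines show that a naive fancy-index += is not equivalent when rows repeat.

```python
import numpy as np

rng = np.random.default_rng(7)
C = rng.normal(size=(27, 10))             # lookup table: 27 characters, 10-d embeddings
Xb = rng.integers(0, 27, size=(32, 3))    # character indices: 32 examples of 3 characters
emb = C[Xb]                               # forward: (32, 3, 10)
demb = rng.normal(size=emb.shape)         # stand-in upstream gradient on emb

# backward with explicit loops
dC = np.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k, j]
        dC[ix] += demb[k, j]              # += : a row used several times collects several gradients

# vectorized alternative: np.add.at accumulates over repeated indices instead of overwriting
dC_vec = np.zeros_like(C)
np.add.at(dC_vec, Xb, demb)
assert np.allclose(dC, dC_vec)

# pitfall: buffered fancy-index assignment keeps only one contribution per repeated row
dC_wrong = np.zeros_like(C)
dC_wrong[Xb] += demb
print(np.allclose(dC, dC_wrong))          # False whenever some row of C was used twice
```

With 96 lookups into only 27 rows, repeats are guaranteed here. In PyTorch, Tensor.index_add_ on flattened indices serves a similar accumulate-by-index purpose.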
-
Unknown A
So each occurrence contributes an addition. Let's now write this out. Unfortunately, I don't know of a much better way to do this in Python than a for loop. Maybe someone can come up with a vectorized, efficient operation, but for now let's just use for loops. Let me create torch.zeros_like(C) to initialize a 27 by 10 tensor of all zeros. Then, honestly, for k in range(Xb.shape[0]) (maybe someone has a better way to do this), and for j in range(Xb.shape[1]). This iterates over all the elements of Xb, all these integers. Then let's get the index at this position: ix is Xb[k, j], something like 11 or 14 and so on. Now, in the forward pass, we took the row of C at index ix and deposited it into emb[k, j].
-
Unknown A
That's what happened; that's where they got packaged. So now we need to go backwards and route demb[k, j]. We have these derivatives for each position, each one 10-dimensional, and we just need to send them into the correct row of C. So dC[ix] gets this, but with plus-equals, because there could be multiple occurrences: the same row could have been used many, many times. All of those derivatives go backwards through the indexing and add up. So this is my candidate solution. Let's copy it here, uncomment this, and cross our fingers. Yay. So that's it: we've backpropagated through this entire beast. So there we go. Totally makes sense. Now we come to exercise 2. It turns out that in this first exercise we were doing way too much work. We were backpropagating way too much, and it was all good practice and so on.
-
Unknown A
But it's not what you would do in practice. The reason is that, for example, I separated this loss calculation over multiple lines and broke it up into its smallest atomic pieces, and we backpropagated through each of them individually. But if you just look at the mathematical expression for the loss, you can do the differentiation with pen and paper. A lot of terms cancel and simplify, and the expression you end up with can be significantly shorter and easier to implement than backpropagating through all the little pieces. Before, we had this complicated forward pass going from logits to the loss. But in PyTorch everything can just be glued together into a single call to F.cross_entropy: you pass in the logits and the labels, and you get the exact same loss, as I verify here.
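To see that the long chain and the glued-together version compute the same number, here is a NumPy sketch. It does not use F.cross_entropy itself (an assumption for portability). The "fast" version is written as the mean over examples of logsumexp(logits) minus the logit of the correct label, which is a standard way to express cross-entropy in a single numerically stable step. The logits and labels are random stand-ins.

```python
import numpy as np

rng = np.random.default_rng(8)
n = 32
logits = rng.normal(size=(n, 27)) * 3
Yb = rng.integers(0, 27, size=n)

# step by step, roughly mirroring the lecture's chain of small operations
logit_maxes = logits.max(1, keepdims=True)
norm_logits = logits - logit_maxes          # subtract the row max for numerical stability
counts = np.exp(norm_logits)
counts_sum = counts.sum(1, keepdims=True)
probs = counts / counts_sum
logprobs = np.log(probs)
loss_slow = -logprobs[np.arange(n), Yb].mean()

# the same loss as one expression: mean over examples of logsumexp(logits) - logits[y]
lse = logit_maxes[:, 0] + np.log(np.exp(norm_logits).sum(1))
loss_fast = (lse - logits[np.arange(n), Yb]).mean()

print(loss_slow, loss_fast)
```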
-
Unknown A
So our previous loss and the fast loss, which computes the whole chunk of operations as a single mathematical expression, are the same. But the fast version is much, much faster in the forward pass, and also much, much faster in the backward pass. The reason is that if you look at the mathematical form of this and differentiate, you end up with a very small and short expression. So that's what we want to do here: go directly to dlogits in a single step. We need to implement dlogits as a function of logits and Yb, and it will be significantly shorter than what we did here, where to get to dlogits we had to go all the way through here. All of this work can be replaced by a much, much simpler mathematical expression that you can implement here.
-
Unknown A
So you can give it a shot yourself: work out exactly what the mathematical expression for the loss is, and differentiate it with respect to the logits. Let me give you a hint. You can of course try it fully yourself, but if not, here is how to get started mathematically. We have logits, and then the softmax takes the logits and gives you probabilities. Then we use the identity of the correct next character to pluck out one probability from each row. We take its negative log to get the negative log probability, and then we average all the negative log probabilities to get our loss. So for a single individual example, the loss is the negative log probability, -log p_y, where p is the vector of all the probabilities.
-
Unknown A
We index it at position y, where y is the label, and p of course is the softmax. The i-th component of this probability vector is just the softmax function: exponentiate all the logits and normalize so everything sums to one. Now, when you write out p_y here, you can just write out the softmax. What we're interested in is the derivative of the loss with respect to the i-th logit. So it's d/dl_i of this expression. The numerator is e to the l_y, where l is indexed by the specific label y, and the denominator is a sum over j of e to the l_j, with the negative log of the whole thing. So give it a shot with pen and paper and see if you can derive the expression for dloss/dl_i.
-
Unknown A
And then we're going to implement it here. Okay, so I'm going to give away the result. This is some of the math I did to derive the gradients analytically. I'm just applying the rules of calculus from the first or second year of a bachelor's degree, if you took it, and the expressions simplify quite a bit. You have to split the analysis into two cases: the i-th index you're interested in inside the logits is either equal to the label or not. The expressions simplify and cancel in slightly different ways in each case, and what we end up with is something very, very simple. We get either p_i, where p is again the vector of probabilities after a softmax, or p_i minus 1, where we simply subtract a 1.
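For reference, here is one compact way to write out the derivation just described. It first rewrites -log p_y as -l_y plus the log of the sum of exponentials (my choice of route), so that both cases come out of a single line. The last line adds the 1/n from averaging over a batch of n examples.

```latex
\text{loss} = -\log p_y, \qquad p_i = \frac{e^{l_i}}{\sum_j e^{l_j}}
\quad\Longrightarrow\quad
\text{loss} = -l_y + \log \sum_j e^{l_j}

\frac{\partial\,\text{loss}}{\partial l_i}
  = -\mathbb{1}[i = y] + \frac{e^{l_i}}{\sum_j e^{l_j}}
  = \begin{cases} p_i - 1, & i = y \\ p_i, & i \neq y \end{cases}

\text{batch mean over } n \text{ examples:}\qquad
\frac{\partial\,\text{loss}}{\partial l_{b,i}} = \frac{1}{n}\bigl(p_{b,i} - \mathbb{1}[i = y_b]\bigr)
```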
-
Unknown A
But in any case, we just need to calculate the softmax p and then subtract 1 in the correct dimension; that's the form the gradient takes analytically. So let's implement this. We have to keep in mind that the derivation is for a single example, but here we are working with batches of examples, so we have to be careful. The loss for a batch is the average loss over all the examples: the per-example losses summed up and then divided by n. We have to backpropagate through that as well. So dlogits starts as F.softmax: PyTorch has a softmax function you can call, and we want to apply it to the logits along dimension 1.
-
Unknown A
So we want to do the softmax along the rows of these logits. Then, at the correct positions, we need to subtract a 1: dlogits[range(n), Yb] -= 1, which iterates over all the rows and indexes into the columns given by the correct labels in Yb. Finally, the loss is an average, and the average multiplies the sum of all the losses by 1/n, so we need to backpropagate through that division too. The gradient has to be scaled down by n as well, because of the mean, and otherwise this should be the result. Now, if we verify this, we don't get an exact match, but the maximum difference between PyTorch's gradient and our dlogits is on the order of 5e-9. So it's a tiny, tiny number. Because of floating-point wonkiness we don't get the exact bitwise result, but we basically get the correct answer.
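Here are the three steps (softmax, subtract 1 at the labels, divide by n) as a NumPy sketch. A hand-written softmax replaces F.softmax, and the logits and labels are random stand-ins. Instead of comparing against PyTorch, the result is checked against a finite difference of the batch-mean cross-entropy.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(1, keepdims=True))   # row-wise, shifted by the row max for stability
    return e / e.sum(1, keepdims=True)

def cross_entropy(logits, Yb):
    n = logits.shape[0]
    return -np.log(softmax(logits)[np.arange(n), Yb]).mean()

rng = np.random.default_rng(9)
n = 32
logits = rng.normal(size=(n, 27))
Yb = rng.integers(0, 27, size=n)

dlogits = softmax(logits)                 # start from the probabilities
dlogits[np.arange(n), Yb] -= 1            # subtract 1 at the correct label in every row
dlogits /= n                              # the loss is a mean over the batch

# directional finite-difference check
h = 1e-5
v = rng.normal(size=logits.shape)
numeric = (cross_entropy(logits + h * v, Yb) - cross_entropy(logits - h * v, Yb)) / (2 * h)
print(np.isclose(numeric, (dlogits * v).sum()))
# each row sums to (numerically) zero: the pull-down on wrong characters balances the push-up on the right one
print(np.abs(dlogits.sum(1)).max())
```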
-
Unknown A
Approximately. Now, I'd like to pause briefly before we move on to the next exercise, because I'd like us to get an intuitive sense of what dlogits is; it has a beautiful and very simple explanation, honestly. Here I'm taking dlogits and visualizing it, and we can see a batch of 32 examples by 27 characters. So what is dlogits intuitively? dlogits is just the probabilities matrix from the forward pass, except that these black squares are the positions of the correct indices, where we subtracted a 1. And what is this doing? These are the derivatives on the logits. Let's look at just the first row here. I'm calculating the probabilities of these logits and taking just the first row. This is the probability row, and this is the first row of dlogits multiplied by n, just so that we don't have the scaling by n in here and everything is more interpretable.
-
Unknown A
We see that it's exactly equal to the probability, of course, except that the position of the correct index has a 1 subtracted. And notice that if you take dlogits[0] and sum it, it actually sums to zero. So you should think of the gradients at each cell as a force. We are pulling down on the probabilities of the incorrect characters and pulling up on the probability at the correct index, and that's what's happening in each row. The amount of push and pull is exactly equalized, because the sum is zero: the amount by which we pull down the probabilities of the incorrect characters equals the amount by which we push up the probability of the correct character. So the repulsion and the attraction are equal.
-
Unknown A
And now think of the neural net as a massive pulley system or something like that. We're up here on top of the logits, pulling down the probabilities of the incorrect characters and pulling up the probability of the correct one. Everything in this complicated pulley system is mathematically determined, so just think of this tension as translating through the mechanism until eventually we get a tug on the weights and the biases. In each update we just tug in the direction we'd like for each of these elements, and the parameters slowly give in to the tug. That's roughly what training a neural net looks like at a high level. So I think the forces of push and pull in these gradients are actually very intuitive.
-
Unknown A
Here we're pushing and pulling on the correct answer and the incorrect answers, and the amount of force we're applying is proportional to the probabilities that came out in the forward pass. For example, if our probabilities came out exactly correct, with zero everywhere except a one at the correct position, then dlogits would be a row of all zeros for that example, and there would be no push and pull. So the amount by which your prediction is incorrect is exactly the amount of pull or push you get in that dimension. If you have, for example, a very confidently mispredicted element here, that element is going to be pulled down very heavily, the correct answer is going to be pushed up by a correspondingly large amount, and the other characters are not going to be influenced much.
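The following small NumPy sketch shows the push/pull picture for single rows, using the per-example gradient p minus the one-hot label. The probability vectors are hypothetical, chosen only for illustration: one perfect prediction, and one that confidently predicts the wrong class.

```python
import numpy as np

def dlogits_row(p, y):
    """Per-example gradient of -log p[y] w.r.t. the logits: p minus the one-hot label."""
    g = np.array(p, dtype=float)
    g[y] -= 1.0
    return g

y = 0                                       # the correct class
perfect = [1.0, 0.0, 0.0, 0.0]              # all probability mass on the correct class
confident_wrong = [0.05, 0.9, 0.03, 0.02]   # hypothetical: confidently predicts class 1

print(dlogits_row(perfect, y))          # [0. 0. 0. 0.]  no push or pull at all
print(dlogits_row(confident_wrong, y))  # strong pull down on class 1, push up on class 0
```

In the second row, the push on class 0 (-0.95) equals the total pull on the other classes (0.9 + 0.03 + 0.02), so the row sums to zero. Classes 2 and 3, which were barely predicted, are barely touched.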
-
Unknown A
So the amount by which you mispredict is proportional to the strength of the pull, and that happens independently in all the dimensions of this tensor. It's very intuitive and easy to think through, and that's basically the magic of the cross-entropy loss and what it does dynamically in the backward pass of the neural net. Now we get to exercise number three, which is a very fun exercise, depending on your definition of fun. We are going to do for batch normalization exactly what we did for the cross-entropy loss in exercise number two. We'll treat it as a single glued-together mathematical expression and derive a much simpler formula for its backward pass, so we can backpropagate through it very efficiently. And we're going to do that using pen and paper.
-
Unknown A
Previously, we broke batch normalization up into all of its little intermediate pieces and atomic operations and backpropagated through them one by one. Now we have just a single forward pass of batch norm, all glued together, and we see that we get the exact same result as before. For the backward pass, we'd also like to implement a single formula that backpropagates through this entire batch normalization operation. In the forward pass we took hprebn, the hidden states before batch normalization, and created hpreact, the hidden states just before the activation. In the batch normalization paper, hprebn is x and hpreact is y. So in the backward pass, we'd like to take dhpreact and produce dhprebn, and we'd like to do that very efficiently.
-
Unknown A
So that's the name of the game: calculate dhprebn given dhpreact. For the purposes of this exercise we're going to ignore gamma and beta and their derivatives, because they take a very simple form, very similar to what we did up above. To help you a little, like I did before, I started the derivation with pen and paper; it took two sheets of paper to derive the mathematical formulas for the backward pass. To set up the problem, write out mu, sigma squared (the variance), x_i hat and y_i exactly as in the paper, except for the Bessel correction. Then, in the backward pass, we have the derivative of the loss with respect to all the elements of y. And remember that y is a vector; there are multiple numbers here.
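As I read the setup, these are the forward definitions being differentiated. They are the paper's, except that the variance divides by m - 1 (Bessel's correction), where m is the mini-batch size and everything is per feature column.

```latex
\mu = \frac{1}{m}\sum_{i=1}^{m} x_i
\qquad
\sigma^2 = \frac{1}{m-1}\sum_{i=1}^{m} (x_i - \mu)^2
\qquad
\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}
\qquad
y_i = \gamma\,\hat{x}_i + \beta
```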
-
Unknown A
So we have all the derivatives with respect to all the y's, and then there are gamma and beta. This is roughly the compute graph: gamma and beta, then x hat, then mu and sigma squared, and x. We have dL/dy_i, and we want dL/dx_i for all the i's in these vectors. You have to be careful here, because, as I'm trying to note, these are vectors: there are many nodes inside x, x hat and y, but mu and sigma squared are just individual scalars, single numbers. You have to imagine there are multiple nodes here, or you're going to get your math wrong. So I would suggest going in the following order, 1, 2, 3, 4, in terms of the backpropagation.
-
Unknown A
Backpropagate into x hat, then into sigma squared, then into mu, and then into x. Just like in a topological sort in micrograd, we go from right to left; you're doing the exact same thing, except with symbols on a piece of paper. For number one, I'm not giving away too much. To get dL/dx_i hat, you just take dL/dy_i and multiply it by gamma, because any individual y_i is just gamma times x_i hat plus beta. That doesn't help you too much, but it gives you the derivatives for all the x hats. Now try to go through this computational graph and derive dL/dsigma squared, then dL/dmu, and then eventually dL/dx.
-
Unknown A
So give it a go, and I'm going to reveal the answer one piece at a time. Okay, to get dL/dsigma squared, we have to remember again that there are many x hats here, and that sigma squared is just a single individual number. So when we write the expression for dL/dsigma squared, we have to consider all the possible paths: there are many x hats, and they all depend on sigma squared. So sigma squared has a large fan-out: lots of arrows go from sigma squared into all the x hats, and there's a backpropagating signal from each x hat into sigma squared. That's why we need to sum over all the i's from 1 to m of dL/dx_i hat, the global gradient, times dx_i hat/dsigma squared, the local gradient of this operation here.
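Written out with the forward definitions above, the fan-out sum and its simplification look like this. The local gradient is the derivative of x_i hat with respect to sigma squared, holding x_i and mu fixed.

```latex
\frac{\partial L}{\partial \hat{x}_i} = \gamma\,\frac{\partial L}{\partial y_i}
\qquad
\frac{\partial L}{\partial \sigma^2}
  = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\,\frac{\partial \hat{x}_i}{\partial \sigma^2}
  = -\frac{1}{2}\,(\sigma^2+\epsilon)^{-3/2} \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\,(x_i - \mu)
```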
-
Unknown A
Then mathematically I'm just working it out and simplifying, and you get a certain expression for dL/dsigma squared. We're going to use this expression when we backpropagate into mu and then eventually into x. So let's continue our backpropagation into mu. What is dL/dmu? Again, be careful: mu influences x hat, and x hat is actually many values. For example, if our mini-batch size is 32, as it is in the example we've been working on, there are 32 numbers and 32 arrows going back to mu. Then mu going to sigma squared is just a single arrow, because sigma squared is a scalar. So in total there are 33 arrows emanating from mu, and all of them bring gradients into mu that need to be summed up.
-
Unknown A
That's why, in the expression for dL/dmu, I am summing over all the gradients dL/dx_i hat times dx_i hat/dmu. That's this arrow, which stands for 32 arrows, plus the one arrow from here, which is dL/dsigma squared times dsigma squared/dmu. Now we have to work out that expression, so let me reveal the rest of it. Simplifying the first term is not complicated; you just get an expression here. For the second term, though, something really interesting happens when we look at dsigma squared/dmu and simplify. Suppose we plug in the special case where mu is actually the average of the x_i's, as it is here. Then this gradient vanishes and becomes exactly zero, and that makes the entire second term cancel.
-
Unknown A
So if you just have a mathematical expression like this and look at dsigma squared/dmu, you get some formula for how mu impacts sigma squared. But in the special case where mu is actually equal to the average, as it is in batch normalization, that gradient vanishes and becomes zero. So the whole term cancels, and we get a fairly straightforward expression for dL/dmu. Okay, now we get to the craziest part: deriving dL/dx_i, which is ultimately what we're after. First, let's count how many numbers there are inside x. As I mentioned, there are 32 numbers, 32 little x_i's. Now let's count the arrows emanating from each x_i: there's an arrow going to mu, an arrow going to sigma squared, and an arrow going to x hat.
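In symbols, the vanishing term is the derivative of sigma squared with respect to mu. It is proportional to the sum of deviations from mu, and that sum is zero exactly when mu is the mean of the x_i.

```latex
\frac{\partial L}{\partial \mu}
  = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\,\frac{-1}{\sqrt{\sigma^2+\epsilon}}
  + \frac{\partial L}{\partial \sigma^2}\,\frac{\partial \sigma^2}{\partial \mu},
\qquad
\frac{\partial \sigma^2}{\partial \mu} = \frac{-2}{m-1}\sum_{i=1}^{m}(x_i - \mu) = 0
\quad\text{when } \mu = \frac{1}{m}\sum_{i=1}^{m} x_i

\Longrightarrow\quad
\frac{\partial L}{\partial \mu} = -\frac{1}{\sqrt{\sigma^2+\epsilon}} \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}
```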
-
Unknown A
But let's scrutinize this last arrow a little. Each x_i hat is a function of x_i and the scalars mu and sigma squared only, so x_i hat depends on x_i and none of the other x's. This single arrow therefore stands for 32 arrows, but those 32 arrows run exactly in parallel: they don't interfere, each just goes from x_i to its own x_i hat. You can look at it that way. So how many arrows emanate from each x_i? Three: to mu, to sigma squared, and to the associated x hat. In backpropagation, we now need to apply the chain rule and add up those three contributions. Here's what that looks like written out: we're chaining through mu, sigma squared and x hat, and those three terms are just here.
-
Unknown A
Now, we already have three of these factors: dL/dx̂_i, dL/dμ, which we derived here, and dL/dσ², which we derived here. But we need three other terms here: this one, this one, and this one. So I invite you to try to derive them. It's not that complicated: you're just looking at these expressions here and differentiating with respect to x_i. So give it a shot. But here's the result, or at least what I got. I'm just differentiating each of these expressions with respect to x_i, and honestly, I don't think there's anything too tricky here; it's basic calculus. Where it gets a little more tricky is that we now have to plug everything together, multiplying all of these terms with all of these terms and adding them up according to this formula, and that gets a little bit hairy.
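A minimal sketch of plugging the three local derivatives into the three-path chain rule for one neuron, checked against autograd. The batch of 32 random inputs, the values of γ and β, the upstream gradient dy, and the Bessel-corrected variance are all illustrative assumptions, not values from the lecture.

```python
import torch

torch.manual_seed(0)
m, eps = 32, 1e-5
x = torch.randn(m, dtype=torch.float64, requires_grad=True)
gamma = torch.tensor(1.7, dtype=torch.float64)
beta = torch.tensor(-0.3, dtype=torch.float64)
dy = torch.randn(m, dtype=torch.float64)  # stand-in for dL/dy_i coming from the layers above

# Forward pass for a single neuron over a batch of m examples.
mu = x.mean()
var = ((x - mu) ** 2).sum() / (m - 1)  # Bessel-corrected (assumption)
xhat = (x - mu) / torch.sqrt(var + eps)
y = gamma * xhat + beta
y.backward(dy)  # autograd reference, stored in x.grad

with torch.no_grad():
    inv_std = (var + eps) ** -0.5
    dxhat = gamma * dy                                              # dL/dxhat_i
    dvar = (dxhat * (x - mu)).sum() * -0.5 * (var + eps) ** -1.5    # dL/dsigma^2
    dmu = (dxhat * -inv_std).sum()                                  # dL/dmu (sigma^2 path vanishes)
    # The three local derivatives with respect to x_i:
    dxhat_dx = inv_std                  # dxhat_i/dx_i, holding mu and sigma^2 fixed
    dmu_dx = 1.0 / m                    # dmu/dx_i
    dvar_dx = 2.0 / (m - 1) * (x - mu)  # dsigma^2/dx_i
    # Sum the three contributions, one per arrow leaving x_i.
    dx = dxhat * dxhat_dx + dmu * dmu_dx + dvar * dvar_dx

print((dx - x.grad).abs().max().item())  # tiny, at float64 round-off level
```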
-
Unknown A
What ends up happening is that you get a large expression. The thing to be very careful with here, of course, is that we are working with dL/dx_i for a specific i. But when we plug in some of these terms, say this term here, dL/dσ², you see that its expression is itself a sum over little i's. I can't use i as the variable when I plug it in here, because that i is different from this i. The i inside that sum is just a placeholder, like the loop variable of a for loop. So when I plug it in, you'll notice that I rename that i to j, because I need to make sure that this j is not confused with this i.
-
Unknown A
This j is like a little local iterator over the 32 terms, so you have to be careful with that. When you're plugging in the expressions from here to here, you may have to rename i's into j's, and you have to be very careful about what is actually the i of dL/dx_i. So some of these are j's and some of these are i's. And then we simplify this expression. I guess the big thing to notice here is that a bunch of terms come out to the front and you can refactor them. There's a (σ² + ε) raised to the power of negative three over two, and it can actually be separated into three factors, each of which is (σ² + ε) to the power of negative one half. The three of them multiplied together give back the original.
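One way to write out the plugging-in and refactoring just described, as a sketch under the same single-neuron assumptions (batch size m, Bessel-corrected σ², y_i = γx̂_i + β, so dL/dx̂_j = γ dL/dy_j). The inner sums run over j, and the last step splits (σ²+ε)^(−3/2) into three factors of (σ²+ε)^(−1/2): one moves to the front, one combines with (x_i − μ) into x̂_i, and one combines with (x_j − μ) into x̂_j.

```latex
\frac{\partial \hat{x}_i}{\partial x_i} = (\sigma^2+\epsilon)^{-1/2},\qquad
\frac{\partial \mu}{\partial x_i} = \frac{1}{m},\qquad
\frac{\partial \sigma^2}{\partial x_i} = \frac{2}{m-1}(x_i-\mu)

\frac{\partial L}{\partial \mu} = -\gamma(\sigma^2+\epsilon)^{-1/2}\sum_{j}\frac{\partial L}{\partial y_j},\qquad
\frac{\partial L}{\partial \sigma^2} = -\frac{\gamma}{2}(\sigma^2+\epsilon)^{-3/2}\sum_{j}\frac{\partial L}{\partial y_j}(x_j-\mu)

\frac{\partial L}{\partial x_i}
= \gamma(\sigma^2+\epsilon)^{-1/2}\frac{\partial L}{\partial y_i}
- \frac{\gamma}{m}(\sigma^2+\epsilon)^{-1/2}\sum_{j}\frac{\partial L}{\partial y_j}
- \frac{\gamma}{m-1}(\sigma^2+\epsilon)^{-3/2}(x_i-\mu)\sum_{j}\frac{\partial L}{\partial y_j}(x_j-\mu)

\frac{\partial L}{\partial x_i}
= \frac{\gamma(\sigma^2+\epsilon)^{-1/2}}{m}\left[m\frac{\partial L}{\partial y_i}
- \sum_{j}\frac{\partial L}{\partial y_j}
- \frac{m}{m-1}\,\hat{x}_i\sum_{j}\frac{\partial L}{\partial y_j}\hat{x}_j\right]
```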
-
Unknown A
Those three factors can then go to different places because of the multiplication: one of them comes out to the front and ends up outside, one joins up with this term, and one joins up with this other term. When you simplify the expression, you'll notice that some of the terms that come out are just the x̂_i's, so you can simplify further by rewriting them that way. What we end up with at the end is a fairly simple mathematical expression over here that I cannot simplify further. You'll notice that it only uses the stuff we have and derives the thing we need. We have dL/dy_i for all the i's, and those are used plenty of times here. In addition, we're using the x̂_i's and x̂_j's, which just come from the forward pass. Otherwise, this is a simple expression, and it gives us dL/dx_i for all the i's, which is ultimately what we're interested in.
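To make the i/j bookkeeping concrete, here is a deliberately unvectorized sketch of the final single-neuron expression, with an outer loop over i and inner sums over j, compared against autograd. The random inputs, γ, the upstream gradient, and the Bessel-corrected variance are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
m, eps, gamma = 32, 1e-5, 1.7
x = torch.randn(m, dtype=torch.float64, requires_grad=True)
dy = torch.randn(m, dtype=torch.float64)  # dL/dy_i from upstream (random stand-in)

# Forward pass for one neuron (Bessel-corrected variance assumed).
mu = x.mean()
var = ((x - mu) ** 2).sum() / (m - 1)
xhat = (x - mu) / torch.sqrt(var + eps)
(gamma * xhat).backward(dy)  # beta does not affect dL/dx, so it is omitted here

with torch.no_grad():
    inv_std = (var + eps) ** -0.5
    dx = torch.empty(m, dtype=torch.float64)
    for i in range(m):
        # The sums use their own index j so they are never confused with the outer i.
        sum_dy = sum(dy[j] for j in range(m))
        sum_dy_xhat = sum(dy[j] * xhat[j] for j in range(m))
        dx[i] = gamma * inv_std / m * (m * dy[i] - sum_dy - m / (m - 1) * xhat[i] * sum_dy_xhat)

print((dx - x.grad).abs().max().item())  # tiny, at float64 round-off level
```

The two sums do not depend on i, so a vectorized version computes them once instead of inside the loop.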
-
Unknown A
So that's the end of the batch norm backward pass, analytically. Let's now implement this final result. Okay, so I implemented the expression in a single line of code here, and you can see that the maxdiff is tiny, so this is a correct implementation of the formula. Now, I'll just tell you that getting this line of code from the mathematical expression was not trivial. There's a lot packed into this one formula, and it's a whole exercise by itself, because the mathematical formula is for a single neuron and a batch of 32 examples, whereas here we actually have 64 neurons. So this line has to evaluate the batch norm backward pass for all 64 neurons in parallel and independently. Basically, it has to happen in every single column of the inputs here.
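A sketch of what a one-line, all-neurons version can look like for a (32, 64) batch, checked against autograd. The variable names (hprebn, bngain, bnbias, dhpreact, and so on), the random data, and the Bessel-corrected variance are assumptions chosen for illustration; each column is one neuron, and every per-neuron sum becomes a sum over dimension 0.

```python
import torch

torch.manual_seed(0)
n, neurons, eps = 32, 64, 1e-5
hprebn = torch.randn(n, neurons, dtype=torch.float64, requires_grad=True)  # pre-batchnorm activations
bngain = torch.randn(1, neurons, dtype=torch.float64)   # gamma, one per neuron
bnbias = torch.randn(1, neurons, dtype=torch.float64)   # beta, one per neuron
dhpreact = torch.randn(n, neurons, dtype=torch.float64)  # upstream gradient dL/dy (random stand-in)

# Forward pass: each column (neuron) is normalized independently over the batch.
bnmean = hprebn.mean(0, keepdim=True)
bnvar = ((hprebn - bnmean) ** 2).sum(0, keepdim=True) / (n - 1)  # Bessel-corrected (assumption)
bnvar_inv = (bnvar + eps) ** -0.5
bnraw = (hprebn - bnmean) * bnvar_inv
hpreact = bngain * bnraw + bnbias
hpreact.backward(dhpreact)  # autograd reference in hprebn.grad

with torch.no_grad():
    # One expression evaluating the per-neuron formula for all 64 columns at once.
    dhprebn = bngain * bnvar_inv / n * (
        n * dhpreact - dhpreact.sum(0) - n / (n - 1) * bnraw * (dhpreact * bnraw).sum(0)
    )

print(dhprebn.shape, (dhprebn - hprebn.grad).abs().max().item())
```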
-
Unknown A
In addition to that, you see how there are a bunch of sums here, and we need to make sure that those sums broadcast correctly onto everything else that's here. So getting this expression is highly nontrivial, and I invite you to look through it and step through it; it's a whole exercise to make sure that it checks out. But once all the shapes agree and you've convinced yourself that it's correct, you can also verify that PyTorch gets the same answer. That gives you a lot of peace of mind that the mathematical formula is implemented correctly here, broadcast correctly, and replicated in parallel for all 64 neurons inside this batch norm layer. Okay, and finally, exercise number four asks you to put it all together, and here we have a redefinition of the entire problem.
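A small sketch of why the axis of those sums matters. For batch norm the per-neuron sums must run over dimension 0 (the batch); with or without keepdim=True they broadcast back across the 32 rows and give the same result. Summing over the wrong axis with keepdim=True also broadcasts, silently, so the shapes look fine while the values are wrong. The tensor g is a random stand-in for an upstream gradient.

```python
import torch

torch.manual_seed(0)
n, neurons = 32, 64
g = torch.randn(n, neurons)  # random stand-in for an upstream gradient

col_sum = g.sum(0)                   # shape (64,): one sum per neuron
col_sum_kd = g.sum(0, keepdim=True)  # shape (1, 64): same numbers, explicit batch dim
row_sum_kd = g.sum(1, keepdim=True)  # shape (32, 1): the wrong axis for batch norm

a = n * g - col_sum     # broadcasts (64,) across the 32 rows
b = n * g - col_sum_kd  # broadcasts (1, 64) across the 32 rows, identical result
c = n * g - row_sum_kd  # also broadcasts to (32, 64), but subtracts per-example sums

print(col_sum.shape, col_sum_kd.shape, row_sum_kd.shape)
print(torch.equal(a, b), torch.allclose(a, c))  # the wrong-axis version does not match
```

Comparing against autograd, as described above, is what catches the silent wrong-axis case.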
-
Unknown A
You see that we reinitialize the neural net from scratch and everything. Then, instead of calling loss.backward(), we want the manual backpropagation as we derived it above. So go up, copy and paste all the chunks of code that we've already derived, put them here, derive your own gradients, and then optimize this neural net using your own gradients, all the way through calibrating the batch norm and evaluating the loss. I was able to achieve a pretty good loss, basically the same loss you would achieve before. That shouldn't be surprising, because all we've done is go into loss.backward(), pull out all the code, and insert it here. The gradients are identical, everything is identical, and the results are identical. It's just that we have full visibility into exactly what goes on under the hood of loss.backward() in this specific case.
-
Unknown A
Okay, and this is all of our code. This is the full backward pass, using the simplified backward passes for the cross-entropy loss and the batch normalization: backpropagating through the cross entropy, the second layer, the tanh nonlinearity, the batch normalization, the first layer, and the embedding. You see that this is only maybe 20 lines of code or so, and that's what gives us the gradients. Now we can potentially erase loss.backward(). The way I have the code set up, you should be able to run this entire cell once you fill this in; it will run for only 100 iterations and then break. It breaks to give you an opportunity to check your gradients against PyTorch. Here we see that our gradients are not exactly equal; they are approximately equal, and the differences are tiny.
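A compact sketch of such a full manual backward pass, checked against autograd with a small comparison helper. Everything here is an assumption for illustration: the layer sizes (27-token vocabulary, context of 3, 10-dimensional embeddings, 64 hidden units, batch of 32), the random data standing in for a real batch, the initialization scales, the Bessel-corrected batch variance, and the names, including the hypothetical helper cmp.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)
vocab, block, emb_dim, hidden, n = 27, 3, 10, 64, 32
C = torch.randn(vocab, emb_dim, requires_grad=True)
W1 = (torch.randn(block * emb_dim, hidden) * 0.2).requires_grad_()
b1 = (torch.randn(hidden) * 0.1).requires_grad_()
W2 = (torch.randn(hidden, vocab) * 0.1).requires_grad_()
b2 = (torch.randn(vocab) * 0.1).requires_grad_()
bngain = (1 + 0.1 * torch.randn(1, hidden)).requires_grad_()
bnbias = (0.1 * torch.randn(1, hidden)).requires_grad_()
params = [C, W1, b1, W2, b2, bngain, bnbias]
Xb = torch.randint(0, vocab, (n, block))  # random stand-in for a batch of contexts
Yb = torch.randint(0, vocab, (n,))        # random stand-in for the targets

# Forward pass
emb = C[Xb]                                # (n, block, emb_dim)
embcat = emb.view(n, -1)                   # (n, block * emb_dim)
hprebn = embcat @ W1 + b1                  # (n, hidden)
bnmean = hprebn.mean(0, keepdim=True)
bnvar = ((hprebn - bnmean) ** 2).sum(0, keepdim=True) / (n - 1)
bnvar_inv = (bnvar + 1e-5) ** -0.5
bnraw = (hprebn - bnmean) * bnvar_inv
hpreact = bngain * bnraw + bnbias
h = torch.tanh(hpreact)
logits = h @ W2 + b2
loss = F.cross_entropy(logits, Yb)
loss.backward()  # reference gradients from autograd, used only for checking

with torch.no_grad():
    # Cross entropy (mean over batch): softmax minus one-hot, divided by n.
    dlogits = F.softmax(logits, 1)
    dlogits[torch.arange(n), Yb] -= 1
    dlogits /= n
    # Second linear layer.
    dh = dlogits @ W2.T
    dW2 = h.T @ dlogits
    db2 = dlogits.sum(0)
    # tanh: d tanh(u)/du = 1 - tanh(u)^2.
    dhpreact = (1.0 - h ** 2) * dh
    # Batch norm scale and shift, then the simplified batch norm backward.
    dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
    dbnbias = dhpreact.sum(0, keepdim=True)
    dhprebn = bngain * bnvar_inv / n * (
        n * dhpreact - dhpreact.sum(0) - n / (n - 1) * bnraw * (dhpreact * bnraw).sum(0))
    # First linear layer.
    dembcat = dhprebn @ W1.T
    dW1 = embcat.T @ dhprebn
    db1 = dhprebn.sum(0)
    # Embedding: add each example's gradient back into the rows of C it looked up.
    demb = dembcat.view(emb.shape)
    dC = torch.zeros_like(C)
    dC.index_add_(0, Xb.view(-1), demb.view(-1, emb_dim))

grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias]

def cmp(name, manual, t):
    # Report exact equality, approximate equality, and the largest absolute difference.
    exact = torch.equal(manual, t.grad)
    approx = torch.allclose(manual, t.grad, atol=1e-6)
    maxdiff = (manual - t.grad).abs().max().item()
    print(f"{name:7s} | exact: {str(exact):5s} | approx: {str(approx):5s} | maxdiff: {maxdiff:.2e}")

for name, g, p in zip(["C", "W1", "b1", "W2", "b2", "bngain", "bnbias"], grads, params):
    cmp(name, g, p)
```

In float32, some rows typically report exact: False with a tiny maxdiff, which is the "approximately equal" situation described above.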
-
Unknown A
They're on the order of 1e-9, and to be honest, I don't exactly know where they're coming from. Once we have some confidence that the gradients are basically correct, we can take out the gradient checking, disable this break statement, and then disable loss.backward(). We don't need it anymore. It feels amazing to see that. Then, when we do the update, we're not going to use p.grad. That was the old way, with PyTorch autograd; we don't have it anymore because we're not calling backward. Instead we use this update: I've arranged the grads to be in the same order as the parameters, and I'm zipping the parameters and the gradients together into p and grad. Then I step with just the grad that we derived manually.
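A minimal sketch of that update step, assuming plain SGD with a hypothetical learning rate lr and two stand-in parameters whose gradients were computed by hand. The only requirement is that grads lists the gradients in the same order as parameters, so that zip pairs each parameter with its own gradient.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-ins: two parameters and their manually derived gradients.
parameters = [torch.randn(3, 4), torch.randn(4)]
grads = [torch.ones(3, 4), torch.full((4,), 2.0)]
lr = 0.1  # hypothetical learning rate

before = [p.clone() for p in parameters]
with torch.no_grad():
    for p, grad in zip(parameters, grads):
        p += -lr * grad  # in-place SGD step that never touches p.grad
```

Usage: after the loop, parameters[0] equals before[0] - 0.1 and parameters[1] equals before[1] - 0.2.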
-
Unknown A
The last piece is that none of this now requires gradients from PyTorch. So one thing you can do here is put this whole code block under a with torch.no_grad() block. What you're really saying is, hey PyTorch, I'm not going to call backward on any of this, and that allows PyTorch to be a bit more efficient. Then we should be able to just run this. It's running, you see that loss.backward() is commented out, and we're optimizing. So we'll let this run and hopefully get a good result. Okay, so I allowed the neural net to finish optimization. Then here I calibrate the batch norm parameters, because I did not keep track of the running mean and variance in the training loop. Then here I evaluate the loss, and you see that we actually obtained a pretty good loss, very similar to what we've achieved before.
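A sketch of that calibration step, under torch.no_grad(): pass the whole training set through the layer once, record the per-neuron mean and variance, and use those fixed statistics at evaluation time. The random inputs and weights are hypothetical stand-ins for the trained model and the training data, and the helper bn_eval is an illustrative name, not an existing API.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-ins for the flattened training embeddings and the trained first layer.
Xtr_embcat = torch.randn(1000, 30)
W1 = torch.randn(30, 64) * 0.2
b1 = torch.randn(64) * 0.1

def bn_eval(x, mean, var, gain, bias, eps=1e-5):
    # At evaluation time, fixed calibrated statistics replace the per-batch statistics.
    return gain * (x - mean) / torch.sqrt(var + eps) + bias

with torch.no_grad():
    # One pass over the whole training set; record per-neuron statistics.
    hpreact = Xtr_embcat @ W1 + b1
    bnmean = hpreact.mean(0, keepdim=True)  # shape (1, 64)
    bnvar = hpreact.var(0, keepdim=True)    # shape (1, 64), unbiased (n - 1) by default
    # Usage: normalizing the training preactivations with the calibrated statistics.
    out = bn_eval(hpreact, bnmean, bnvar, torch.ones(1, 64), torch.zeros(1, 64))
```

With unit gain and zero bias, each column of out has roughly zero mean and unit standard deviation over the training set.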
-
Unknown A
And here I'm sampling from the model, and we see some of the name-like gibberish that we're sort of used to. So basically the model worked and produces pretty decent samples compared to what we're used to. Everything is the same. But of course, the big deal is that we did not use loss.backward(), we did not use PyTorch autograd, and we computed our gradients ourselves, by hand. So hopefully you're looking at the backward pass of this neural net and thinking to yourself, actually, that's not too complicated. Each one of these layers is about three lines of code or so, and most of it is fairly straightforward, with the notable exception of the batch normalization backward pass. Otherwise, it's pretty good. Okay, and that's everything I wanted to cover for this lecture, so hopefully you found it interesting.
-
Unknown A
What I liked about it, honestly, is that it gave us a very nice diversity of layers to backpropagate through, and I think it gives a pretty comprehensive sense of how these backward passes are implemented and how they work. You'd be able to derive them yourself, but of course, in practice you probably don't want to, and you'll use PyTorch autograd. Hopefully, though, you now have some intuition about how gradients flow backwards through the neural net, starting at the loss, through all the variables and all the intermediate results. If you understood a good chunk of it and have a sense of that, then you can count yourself as one of these buff doges on the left instead of the doges on the right here. Now, in the next lecture, we're actually going to go to recurrent neural nets, LSTMs, and all the other variants of RNNs, and we're going to start to complexify the architecture and achieve better log likelihoods.
-
Unknown A
And so I'm really looking forward to that, and I'll see you then.