The following content is provided under a Creative Commons license. Your support will help MIT OpenCourseWare continue to offer high-quality educational resources for free. To make a donation or view additional materials from hundreds of MIT courses, visit MIT OpenCourseWare at ocw.mit.edu.

SURYA GANGULI: I'm going to talk about statistical physics of deep learning, essentially. So this is some ongoing work in my lab that was really motivated by trying to understand how neural networks and infants learn categories. And then it sort of led to a bunch of results in deep learning that involved statistical physics.

So I wanted to just introduce my lab a little bit. I'm an interloper from the Methods in Computational Neuroscience Summer School, where I tend to spend a month. And so the flavor of research we do there, and the flavor of research we do in our lab, is sort of drilling down into neural mechanisms underlying well-defined computations. And we've been working on that. You know, I spend a lot of time talking to neurophysiologists, especially at Stanford where I am. So we have a whole bunch of collaborations going on now involving understanding neural computation.

So, for example, the retina itself is actually a deep neural circuit already, because there's an intervening layer of neurons, the bipolar cells and amacrine cells, between the photoreceptors and the ganglion cells. And oftentimes what we do is we shine light on the photoreceptors and we measure the ganglion cells, but we have no clue what's going on in the interior. So we've developed computational methods that can successfully reconstruct what's going on in the interior of this putative deep neural network, even though we don't have access to these things. And we actually infer the existence and properties of intermediate neurons here.
And those properties are sort of similar to what's been previously recorded when people do directly record from, say, the bipolar cells.

In the Clandinin Lab, we've been unraveling the computations underlying motion vision. So you know when you swat a fly, it's really hard, because the fly can really quickly detect motion coming towards it and fly away. There's been lots of work on what kinds of algorithms might underlie motion estimation-- for example, the Reichardt correlator, the Barlow-Levick model, and so forth. We've been applying systems identification techniques to whole-brain calcium imaging data-- well, whole brain meaning the fly visual circuit. And we just literally identify the computation. And we find that it's sort of none of the above; it's a mixture of all previous approaches.

So grid cells-- we have some results on grid cells. These are the famous cells that resulted in a Nobel Prize recently. We can actually show that grid cells maintain their spatial coherence because the rat or the mouse is always interacting with the boundaries, and the boundaries actually correct the animal's internal estimate of position. Were it not for these interactions with the boundaries, the grid cells would rapidly decohere on a time scale of minutes, you know, like less than a minute. And it's actually quite interesting: we can show that whenever the rat encounters a boundary, it corrects its internal estimate of position perpendicular to the boundary, but not parallel to it. And it doesn't, because it can't: when it hits the boundary, it receives no information about where it is parallel to the boundary, but it does receive information in the perpendicular direction.

And then with the Shenoy Lab, we've been looking at what I think is a really major conceptual puzzle in neuroscience. Why can we record from 100 neurons in a circuit containing millions of neurons,
do dimensionality reduction, try to infer the state-space dynamics of the circuit, and claim that we've achieved success, right? We're doing dramatic undersampling, recording 100 neurons out of millions. How would the state-space dynamics that we infer change if we recorded more neurons? And we can show that, essentially, it will not change, because we've come up with a novel connection between the act of neural measurement and the act of random projection. So we can show that the act of recording from 100 neurons in the brain is like the act of measuring 100 random linear combinations of all neurons in the relevant brain circuit. And then we can apply random projection theory to give us a predictive theory of experimental design that tells us, given the complexity of the task, how many neurons you would need to record to correctly recover the state-space dynamics of the circuit.
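A minimal numerical sketch of that intuition (my illustration, not the lab's actual analysis; the Gaussian projection standing in for electrode sampling, the population sizes, and the smooth latent trajectory are all assumptions for the toy example): if the population activity really lives on a low-dimensional trajectory, then 100 random linear combinations of 10,000 neurons preserve the geometry of that trajectory almost perfectly, which is why random projection theory can predict how many recorded neurons suffice.

# Minimal sketch (illustrative assumptions: Gaussian random projections stand in for
# electrode sampling, and the population activity is a smooth 3-D latent trajectory
# embedded across 10,000 "neurons"). The trajectory's geometry survives projection
# down to 100 random linear combinations.
import numpy as np

rng = np.random.default_rng(0)
T, K, N, M = 500, 3, 10000, 100        # time points, latent dims, neurons, recorded channels

latent = np.cumsum(rng.standard_normal((T, K)), axis=0)   # smooth 3-D latent trajectory
embedding = rng.standard_normal((K, N)) / np.sqrt(K)      # embed it across N neurons
activity = latent @ embedding                             # T x N full population activity

projection = rng.standard_normal((N, M)) / np.sqrt(M)     # 100 random linear combinations
recorded = activity @ projection                          # what the "electrodes" would see

def pairwise_dists(X):
    # Euclidean distances between all pairs of time points
    sq = np.sum(X**2, axis=1)
    return np.sqrt(np.maximum(sq[:, None] + sq[None, :] - 2.0 * X @ X.T, 0.0))

iu = np.triu_indices(T, k=1)
d_full = pairwise_dists(activity)[iu]
d_rec = pairwise_dists(recorded)[iu]
print("correlation of trajectory geometries:", np.corrcoef(d_full, d_rec)[0, 1])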
And then in the Raymond Lab at Stanford, we've been looking at how enhancing synaptic plasticity can either enhance or impair learning depending on experience. So, for example, we think that synaptic plasticity underlies the very basis of our ability to learn and remember. So you might think that enhancing synaptic plasticity through various pharmacological or genetic modifications might enhance our ability to learn and remember. But previous results have been mixed: when you perturb synaptic plasticity, sometimes you enhance learning and sometimes you impair learning. I believe the Raymond Lab is the first to show that, in the same subject, enhancing synaptic plasticity, for example in the cerebellum, can either enhance or impair learning depending on the history of prior experience. And we can show that in order to explain the behavioral learning curves, you need much more complex synaptic dynamics than people naturally assume. We have to really promote our notion of what a synapse is from a single scalar, like a W_ij in a neural network, to an entire dynamical system in its own right. So this relates to VOR learning.

So this is sort of the low-level stuff that we've been doing. I know that in this course you guys study higher-level stuff, so I'm not going to talk about any of these; each of them could be a one-hour talk on its own. But I wanted to discuss some more high-level stuff that we've been doing.

Oh, sorry. The other sort of direction in our lab is actually the statistical mechanics of high-dimensional data analysis. This is, of course, very relevant in the age of the BRAIN Initiative and so on, where we're developing very large scale data sets and we'd like to extract theories from this so-called big data. The entire edifice of classical statistics is predicated on the assumption that you have many, many data points and a small number of features, right? So then it's very easy to see patterns in your data. But what's actually happening nowadays is that we have a large number of data points and a large number of features. For example, you can record from 100 neurons using electrophysiology, but maybe for only 100 trials of any given trial type in a monkey doing some task. So the ratio of the amount of data to the number of features is something that's order one. You know, the data sets are like three points in a three-dimensional space; that's the best we can visualize. So it's a significant challenge to do data analysis in this scenario.

It turns out there are beautiful connections between machine learning and data analysis on the one hand and the statistical physics of systems with quenched disorder on the other. What I mean by that is, in data analysis you want to learn statistical parameters by maximizing the log likelihood of the data given the parameters, and in statistical physics you often want to minimize an energy, so maximizing the likelihood can be viewed as energy minimization. And so there are beautiful connections between these, and we work on that.
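One standard way to write that correspondence down (a textbook identification, not anything specific to the work described here) is to treat the observed data $\mathcal{D}$ as quenched disorder and the parameters $\theta$ as the fluctuating degrees of freedom:

$$E(\theta;\,\mathcal{D}) = -\log P(\mathcal{D} \mid \theta), \qquad P(\theta \mid \mathcal{D}) \propto e^{-\beta E(\theta;\,\mathcal{D})} \ \text{at}\ \beta = 1,$$

so maximizing the log likelihood over $\theta$ is literally minimizing this energy, and the Bayesian posterior looks like a Boltzmann distribution over parameter space at unit temperature.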
So we've applied this to compressed sensing. We've applied this to the problem of what the optimal inference procedure is in the high-dimensional regime: we know that maximum likelihood is optimal in the classical regime of many data points per parameter, but something else is better in the high-dimensional regime, and we found what's better. It turns out to be a smoothed version of maximum likelihood. And, of course, we've applied this to a theory of neural dimensionality and measurement. So there are lots of beautiful interactions between physics, machine learning, and neuroscience, and this school is a lot about that. If you're interested, we wrote a 70-page review on the statistical mechanics of complex neural systems and high-dimensional data. We cover a whole bunch of things like spin glasses, the statistical mechanics of learning, random matrix theory, random dimensionality reduction, compressed sensing, and so on and so forth. It was our attempt to put some systematic order on this diversity of topics viewed through the lens of statistical physics.

But what do I want to talk about today, and why did I decide to branch out in the direction that I'm going to tell you about? Well, I think there are lots of motivations for the alliance between theoretical neuroscience and theoretical machine learning that lead to opportunities for physics, and math, and so on. So this is the question that should haunt all of us, right? What does it even mean to understand how the brain works, or how a neural circuit works? That's an open question that we really have to come to terms with. A more concrete version of this question, or a specification of it, might be: we will understand this when we understand how the connectivity and dynamics of a neural circuit give rise to behavior, and also how neural activity and synaptic learning rules conspire to self-organize useful connectivity that subserves behavior. OK?
So, you know, various BRAIN Initiatives are promising to give us recordings from large numbers of neurons and even the connectivity between those neurons. Now, what have theorists done? Often what theorists in computational neuroscience do is develop theories of random networks that have no function. But what we would like to do is understand engineered networks that do have function. The field of machine learning has generated a plethora of learned neural networks that accomplish interesting functions, yet we still don't have a meaningful understanding of how their connectivity, dynamics, learning rule, and developmental experience fit together. Basically, we can measure anything we want in these artificial neural networks, right? We can measure all the connectivity between all the neurons. We know the dynamics of all the neurons. We know the learning rule. We know the entire developmental experience of the network, because we know the training data that it was exposed to. Yet we still do not have a meaningful understanding of how they learn and how they work. And if we can't solve that problem, how are we ever going to understand the brain, in the form of this question? OK? So that was sort of what was motivating me to look into this.

So this is the outline of the talk. The original entry point was trying to understand category learning in neural networks. And then at the end of the day, we actually made several theoretical advances that led to advances in machine learning and applications to engineering. So, for example, we found random weight initializations that make a network dynamically critical and allow very, very rapid training of deep neural networks. We were able to understand and exploit the geometry of high-dimensional error surfaces to, again, speed up the training of deep neural networks.
And we were also able to exploit recent work in non-equilibrium thermodynamics to learn complex probabilistic generative models. So it's a diversity of topics, but I'll walk you through them. And, you know, you can relax, because almost everything I'm going to talk about is published.

OK. So let's start with the motivation, a mathematical theory of semantic development. I think this speaks to some of the high-level stuff that you guys think about. This part could be called the misadventures of an applied physicist who found himself lost in the psychology department. So I just sort of showed up at Stanford. Jay's a great guy. I was talking to him, and I learned about his work, and I realized I didn't understand it. And this is my attempt to understand that work with Andrew and Jay.

OK. So what is semantic cognition? A rough definition of this field is that human semantic cognition is our ability to learn, recognize, comprehend, and produce inferences about properties of objects and events in the world, especially properties that are not present in your current perceptual stimulus. So, for example, I can ask you, does a cat have fur, and do birds fly, and you can answer these questions correctly despite the fact that there's currently no cat or bird in the room, right? So our ability to do this likely relies on our ability to form internal representations of categories in the external world and associate properties with those categories, because we never see the same stimulus twice. Whenever we see a new stimulus or we try to recall information, we rapidly identify the relevant category that contains the information, and we use that categorical representation to guide future actions or give answers. So category formation is central to this, right?

So what are the kinds of psychophysical tasks that people use to probe semantic cognition?
This is a very rich field; psychologists have been working on it for a long time. One example is looking-time studies to ascertain at what age an infant can distinguish between two categories of objects. So, for example, they'll show a sequence of objects from category one, say, horses. The first time the infant sees a horse, the looking time will go up, and then it goes down over time; the infant gets bored. Then you show a cow. If the infant is old enough, the looking time will go up and then go back down, and from that we infer that the infant can distinguish between horses and cows. But if it's not old enough, the looking time will not go up. It turns out that as the infant gets older and older, it can make more and more fine-scale discriminations between categories.

Property verification tasks: you can ask, can a canary move? Can it sing? Certain questions are answered quickly and certain questions are answered late, which speaks to certain properties being central or peripheral to certain categories. Category membership queries: is a sparrow a bird, or is an ostrich a bird? Again, there are different latencies, and that suggests that there are typical and atypical category members. And also, very important to us is inductive generalization. We can generalize familiar properties to novel objects-- for example, a blick has feathers; does it fly, does it sing? And we can generalize novel properties to familiar objects: a bird has gene X; does a crocodile have gene X, does a dog have gene X? People have measured these patterns of inductive generalization, and there are various theories that try to explain all of this stuff.

So Jay has been working on this from a neural network perspective.
And he wrote a beautiful book called Semantic Cognition, where he uses neural networks to explain a whole variety of phenomena, especially, for example, the progressive differentiation of concepts. So let me just walk you through that. And so this was, you know, a first encounter with a deep neural network; they were doing deep neural networks before they became popular. What they asked was, can we model the development of, say, concepts in infants? So they had a toy data set with a bunch of objects, and each object had a whole bunch of properties. For example, a canary can grow, move, fly, and sing, right? And what they did was they exposed this deep neural network to training data of this form. They had a whole bunch of features and questions and objects, and they just exposed the network to the training data and trained it using back propagation. And they looked at the internal representations in the network, especially their evolution over developmental time, or training time.

And this is what they found. Initially, the network started with random weights, so there was no structure. So what did they do here? They looked at the distances between the internal representations in this space, and they did hierarchical clustering or multidimensional scaling, and they found these plots. What you see is that, early in developmental time, the network first makes a coarse-grained discrimination between animals and plants. Then later, it makes finer-scale discriminations. And then eventually, when it's fully trained, it learns the hierarchical structure that's implicit in the training data. And this is a multidimensional scaling plot where initially the animals move away from the plants, and then, you know, fish move away from birds and trees move away from flowers, and then finally, individual discriminations are learned.
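A minimal sketch of this kind of experiment (my toy version, not the original simulation: the eight items, the nested animal/plant and bird/fish/tree/flower property structure, the 16 tanh hidden units, the learning rate, and the checkpoints are all illustrative assumptions) shows the same progressive differentiation: distances between hidden representations open up first across the coarse animal/plant split, then across the finer splits, and last between individual items.

# Toy sketch (not the original simulation) of progressive differentiation: a one-hidden-layer
# tanh network, started from small random weights, is trained by gradient descent on a toy
# item -> property dataset whose properties carry nested (coarse / fine / item-specific)
# structure. Pairwise distances between hidden representations show the coarse split
# appearing before the finer ones.
import numpy as np

rng = np.random.default_rng(1)
X = np.eye(8)                                            # 8 items, one-hot inputs
coarse = np.repeat([[1., 0.], [0., 1.]], 4, axis=0)      # e.g. animal vs. plant properties
fine = np.kron(np.eye(4), np.ones((2, 1)))               # e.g. bird/fish/tree/flower properties
Y = np.hstack([3 * coarse, 2 * fine, np.eye(8)])         # stronger coarse, weaker fine signal

W1 = 0.01 * rng.standard_normal((8, 16))                 # input -> hidden
W2 = 0.01 * rng.standard_normal((16, Y.shape[1]))        # hidden -> properties
lr, n_items = 0.05, X.shape[0]

def report(step, H):
    d = np.linalg.norm(H[:, None, :] - H[None, :, :], axis=-1)
    coarse_d = d[:4, 4:].mean()                                   # across the animal/plant split
    fine_d = 0.5 * (d[:2, 2:4].mean() + d[4:6, 6:8].mean())      # across fine, within coarse
    item_d = np.mean([d[0, 1], d[2, 3], d[4, 5], d[6, 7]])       # within a fine category
    print(f"step {step:5d}  coarse {coarse_d:.2f}  fine {fine_d:.2f}  item {item_d:.2f}")

for step in range(1, 4001):
    H = np.tanh(X @ W1)                                  # hidden representations
    err = (H @ W2 - Y) / n_items                         # mean-squared-error gradient
    gW2 = H.T @ err
    gW1 = X.T @ ((err @ W2.T) * (1 - H**2))              # backprop through tanh
    W1 -= lr * gW1
    W2 -= lr * gW2
    if step in (200, 800, 4000):
        report(step, H)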
So when I learned about this, I was at once excited and also mystified. Because this is sort of qualitatively behaving the way an infant behaves, yet it's a really stupid neural network with like five layers, and I don't understand how it's doing this, right? So I wanted a theory of what's going on here.

Oh, and by the way, there are lots of reasons to believe that semantic relationships are encoded in the brain using relatively simple metrics, like Euclidean distance between neural representations of different objects. For example, this is a famous study which I'm sure you've seen. They showed a whole bunch of objects to both monkeys and humans, and they clustered the objects, or looked at the similarity matrix of the objects, measured using a Euclidean distance in neural space: firing rates of neurons in the monkey and voxel activity patterns in the human. They showed the same set of objects to monkey and human, and the matrices aligned, essentially. So basically, the similarity structure of internal representations is the same in both monkey and human. So we do seem to encode semantic information using this kind of similarity representation; this is the hierarchical clustering view. So this sort of thing seems to actually happen in real live animals and humans.

OK. There's actually something else that happens: different properties are learned on different time scales. So, for example, the network can learn that canaries can move much more quickly than it learns that a canary is yellow. Some properties are much easier to learn than others. And the properties that are easier to learn for the network are also easier to learn for the infant.

OK. So these are the theoretical questions we'd like to answer.
What are the mathematical principles underlying the hierarchical self-organization of internal representations in the network? This is a complex system, so what are the relative roles of the various ingredients? There's a non-linear input-output response. There's a learning rule, which is back propagation. There's the input statistics. Is the network somehow reaching into complex higher-order statistics of the training set, or can it really rely on just second-order statistics? What is a mathematical definition of something called category coherence, and how does it relate to the speed of category learning? What determines the speed at which we learn categories? Why are some properties learned more quickly than others? And how can we explain changing patterns of inductive generalization over these developmental timescales?

OK. So how do we get a theory? Well, it turns out that if you look at the activations of this network as it trains over time (these are sigmoidal units), the activations don't really hit the saturating regime that much during training, because you start from small weights. So we started with an audacious proposal: maybe even a linear neural network might exhibit this kind of learning dynamics. Now, it's not at all obvious that it should, because it's a simple linear neural network and this learning dynamics is highly non-linear. But it turns out that even in a linear neural network, the dynamics of learning in synaptic weight space is non-linear. And so there might be a hope that it might work and that we might be able to get a coherent theory.

OK. So what we did was we analyzed just a simple linear neural network that goes from an input layer to a hidden layer to an output layer, so the composite map is linear. We can write down dynamical equations in weight space for the learning dynamics. So this is the training data.
And we can adjust the weights using back propagation; these are the back propagation equations. And if we work in a limit where the learning is slow relative to the time it takes to cycle through the data set, you can take a continuous-time limit, and you essentially get a non-linear set of equations in weight space. The equations are cubic in the weights, and that's because the error is quartic in the weights: the error is the squared difference between the output and the product of the two weight matrices times the input. So the error is quartic in the weights, and if you differentiate with respect to the weights, the gradient descent equations on the right-hand side will be cubic in the weights.

But there is one simplification that happens. Because the network is linear, its learning dynamics is sensitive only to the second-order statistics of the data: in particular, the input-input covariance matrix and the input-output covariance matrix. So essentially, this network knows only about second-order statistics. In our work here, the input statistics is white, so it's really only the input-output statistics that drives learning.

So this is a set of coupled non-linear differential equations. They're, in general, hard to solve, but we found solutions to them. We can express the solutions in terms of the singular value decomposition of the input-output covariance matrix. Any rectangular matrix has a unique singular value decomposition. In this context, we can think of the input-output covariance matrix as a matrix that maps input objects to feature attributes. And the singular vectors have an interpretation: one set of singular vectors essentially maps objects into internal representations, the singular values amplify them, and then the columns of U are sort of feature synthesizers, modes in the output feature space. OK. So this is the SVD.
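Written out (this is the standard form of these equations, in notation similar to the published version of this work, up to conventions), the continuous-time gradient descent dynamics for the input-to-hidden weights $W^{21}$ and hidden-to-output weights $W^{32}$ are

$$\tau \frac{d W^{21}}{dt} = W^{32\,T}\left(\Sigma^{31} - W^{32} W^{21}\, \Sigma^{11}\right), \qquad \tau \frac{d W^{32}}{dt} = \left(\Sigma^{31} - W^{32} W^{21}\, \Sigma^{11}\right) W^{21\,T},$$

where $\Sigma^{11} = \langle x x^T \rangle$ is the input covariance and $\Sigma^{31} = \langle y x^T \rangle$ is the input-output covariance. With whitened inputs, $\Sigma^{11} = I$, only $\Sigma^{31}$ drives learning, and its singular value decomposition $\Sigma^{31} = U S V^T$ supplies the modes discussed next.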
But the question is, how does this drive the learning dynamics? What we did was find exact solutions to the learning dynamics in which the product of the layer-one-to-layer-two weights and the layer-two-to-layer-three weights takes this form. Essentially, what the composite system is doing is building up the singular value decomposition of the input-output covariance matrix mode by mode. And each mode alpha, associated with a singular value s_alpha in the training data, is being learned in a sigmoidal fashion. So at time zero, the effective strength A of each mode is small and random, some initial condition A_0. But over time, as training time goes to infinity, A approaches the actual singular value of that mode in the input-output covariance matrix.

So basically, this is the learning dynamic: nothing happens for a while, and then suddenly the strongest singular mode, defined by the largest singular value, gets learned. Then later on, a smaller singular mode gets learned, and later still an even smaller singular mode gets learned. And the time it takes to learn each mode is governed by one over its singular value. So just intuitively, stronger statistical structure, as quantified by the singular value, is learned first. That's the intuition. Oftentimes when we train neural networks, we see these plateaus in performance where the network does nothing and then the error suddenly drops, plateaus and drops. And this actually shows that: you can see very, very sharp transitions in learning. And you can actually show that the ratio of the transition period to the ignorance period can be arbitrarily small. Infants also seem to show these developmental transitions. So, yeah, you can have arbitrarily sharp transitions in this system. OK.
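To make the sigmoidal statement concrete (this is the standard single-mode reduction; the full matrix solution is in the published paper), the effective strength $a(t)$ of a single mode with singular value $s$ obeys, in the balanced case,

$$\tau \frac{da}{dt} = 2a\,(s - a), \qquad a(t) = \frac{s\, e^{2st/\tau}}{e^{2st/\tau} - 1 + s/a_0},$$

so a mode sits near its small initial value $a_0$ for a waiting time of order $\frac{\tau}{2s}\ln\frac{s}{a_0}$ and then rises sharply to $s$. Halving the singular value roughly doubles the waiting time, which is the "one over the singular value" statement above, and taking $a_0 \to 0$ makes the transition arbitrarily sharp relative to the waiting period.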
So the take-home message so far is that the network learns different modes of covariation between input and output on a time scale inversely proportional to the statistical strength of that covariation, and you can get these sudden transitions in learning. Now the question is, what does this have to do with the hierarchical differentiation of concepts? That's what we'd like to understand. So now we've come up with a general theory of learning, the non-linear dynamics of learning in these deep circuits, and we want to connect this back to hierarchical structure.

One of the issues with Jay's work is that they were working with toy data sets, and we didn't have any theoretical control over those toy data sets; we just understood implicitly that they have hierarchical structure. So we need a generative model of data, a controlled, mathematically well-defined generative model that encodes the notion of hierarchy. Can we move beyond specific data sets to general principles of what happens when a neural network is exposed to hierarchical structure? That's what we'd like to answer.

So we consider a hierarchical generative model. Essentially, what we want to do is connect the world of generative models to the world of neural networks. And, you know, that will connect the methods in computational neuroscience to this course eventually, right? So we have data generated by some generative model, we take that data and expose it to a neural network, and we'd like to understand how the dynamics of learning depends on the parameters of the generative model. A natural generative model for defining hierarchical structure is a branching diffusion process that essentially mimics the process of evolution, where properties diffuse down a tree and instantiate themselves across a set of items. So what do I mean by that?
OK. So basically, imagine that your items are at the leaves of a tree. And you can imagine that this is a process of evolution where there is some ancestral state for one property; we do one property at a time, and the properties are independent of each other. The ancestral state might be something like "can move." Then each time this property diffuses down a branch of the tree, there's a probability of flipping. So maybe in this lineage, which might correspond to animals, it doesn't flip, and in these lineages, corresponding to plants, it does flip. So these things cannot move, and these things can move. And then maybe it doesn't flip again, so all of these things inherit the property of moving; these are the animals. And these things cannot move; these are the plants. We do that for every single property independently, and we generate a set of feature vectors. So that's our generative model.

So what are the statistical properties of this generative model? Because we're analyzing these deep linear networks, and we know that the learning dynamics of such networks is driven only by the input-output covariance matrix, to understand the learning dynamics we just have to compute the singular values and singular vectors of hierarchically structured data generated in this fashion. And we did it. So here's what happened. Imagine a nice symmetric tree like this, where these are the objects. If we look at the similarity structure of objects, measured by dot products in the feature space generated under this branching diffusion process, we get this nice blocks-within-blocks similarity structure, where all the items on this branch, say this item and this item, are slightly similar, this item and this item are even more similar, and, of course, each item is most similar to itself.
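Here is a toy simulation of that process (the binary tree, the plus/minus one feature values, and the flip probability are illustrative assumptions, not the exact parameters used in the published work). The item-item similarity matrix comes out with the blocks-within-blocks structure just described, and the singular values of the item-by-feature matrix come in groups, one group per level of the tree, which is the structure exploited in the next step.

# Toy branching diffusion process: each property starts in an ancestral +/-1 state at the
# root and flips sign with small probability along each branch of a binary tree. The
# resulting item-item similarity matrix is blocks within blocks, and the singular values
# of the item-by-feature matrix group by tree level.
import numpy as np

rng = np.random.default_rng(0)
depth, n_props, eps = 3, 5000, 0.1          # 2**3 = 8 leaf items
n_items = 2 ** depth

def sample_property():
    vals = np.array([rng.choice([-1.0, 1.0])])     # ancestral state at the root
    for _ in range(depth):                          # diffuse down the tree, level by level
        vals = np.repeat(vals, 2)                   # each node has two children
        flips = rng.random(vals.size) < eps
        vals[flips] *= -1.0
    return vals                                     # one value per leaf item

F = np.stack([sample_property() for _ in range(n_props)], axis=1)   # items x properties

sim = F @ F.T / n_props                             # item-item similarity (dot products)
np.set_printoptions(precision=2, suppress=True)
print(sim)                                          # blocks within blocks

s = np.linalg.svd(F / np.sqrt(n_props), compute_uv=False)
print(np.sort(s)[::-1])                             # values grouped by tree level: 1, 1, 2, 4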
So you have this hierarchy of clusters that naturally arises because of the branching diffusion process. So what are the singular values and singular vectors of the associated input-output covariance matrix? Well, it turns out these are one set of singular vectors, the so-called object analyzers, which are functions across objects. There's another set of singular vectors that are functions across features, which I'm not showing you. But there's, of course, a duality, so you get a pair of singular vectors for each singular value.

So what's the singular vector associated with the largest singular value? Well, it's a uniform mode that's constant across all the objects. But the most interesting one, the next largest one, is the lowest frequency function, essentially: it's constant across all the descendants of this branch and takes a different constant across all the descendants of this other branch. So this singular vector essentially makes the coarsest-grained discrimination in this hierarchically structured data set. The next set of singular vectors (there's a pair of them) discriminates between this set of objects and this set of objects and doesn't know about these ones. And the next one discriminates between this set of objects and this set of objects. And then as you go down to the smaller singular values, you get individual object discriminations. So this is how the hierarchical structure is reflected in the second-order statistics of the data.

And these are the singular values. This is the theory for the singular values in a tree that has five levels of hierarchy, and you can see that the singular values decay with the hierarchy level of the singular vectors. So there's a general theory for this in which singular vectors are associated with levels of the tree. So now you can see the end of the story.
If you put the two together, you automatically get the results that we were trying to explain. Essentially, the general theory of learning says that the network learns input-output modes on a time scale given by one over the singular value. When the data is hierarchically structured, singular values of broader hierarchical distinctions are larger than singular values of finer distinctions, and the input-output modes correspond exactly to the hierarchical distinctions of the tree. So that essentially says the network must learn broad-scale discriminations before it can learn fine-scale discriminations.

So what we then did was analytically work out the dynamics of learning for hierarchically structured data, and we computed the multidimensional scaling. And this was theory; we never did a single simulation to get this plot. We generated a branching diffusion process that was essentially this one, we just labeled these nodes arbitrarily with these labels, and this is what we get. So we see the multidimensional scaling plot that we sort of saw before. And just to compare: that was done with a toy data set over which we had no theoretical control and a non-linear neural network over which we had no theoretical control, while this is a well-defined mathematical generative model and a linear neural network, and we qualitatively explain the results. So this is the difference between simulation and theory, right? Now we have a conceptual understanding of effectively what was going on in this circuit, and it's no longer a mystery. So now I think I understand what Jay and collaborators were doing. It would be lovely if, for all of the stuff that's going on in this course, we could obtain such a deep, rigorous understanding. It's much more challenging, but it's a goal worthy of pursuit, I think. OK.
So, conclusions: progressive differentiation of hierarchical structure is a general feature of learning in deep neural networks. It cannot be any other way. OK? Interestingly enough, deep, but not shallow, networks exhibit such stage-like transitions during learning. So if you have no hidden layers, you don't get this, actually. You need a hidden layer to do this. And somewhat surprisingly, it turns out that even the second order statistics of semantic properties alone provide powerful statistical signals that are sufficient to drive this non-linear learning dynamics, right? You don't need to look at the higher order statistics of the data to get this dynamic. Second order statistics suffice, which was not obvious before we started.

OK. So in ongoing work, we can explain a whole bunch of things, like illusory correlations early in learning. So, for example, at first infants don't even know that, say, pine trees don't have leaves. Then at an intermediate point, they think that pine trees have leaves. And then at a later point, they correctly know that pine trees don't have leaves, right? So we can explain these non-monotonic learning curves. We can explain familiarity and typicality effects. We can explain inductive property judgments analytically. We're looking at basic level effects. We have a theory of category coherence, and so on. But in the interest of moving forward, I wanted to give short shrift to this stuff.

And essentially, we can answer why some properties are learned faster. Basically, properties that are low frequency functions on the leaves of the tree get learned faster. Properties that have a larger inner product with the singular vectors that have larger singular values get learned faster. That's the story. Why are some items more typical? We have a theory for that. How is inductive generalization achieved by neural networks? We have a theory for that, and so on.
And, you know, what is a useful mathematical definition of category coherence? So, for example, there are some things that are just intuitively incoherent categories. The set of all things that are blue is a very incoherent category. In fact, it's so incoherent we don't have a name for such a category. The set of all things that are dogs seems to be a very coherent category. And it's so coherent that we have a well-known name for it. The name's quite short, actually, too. Actually, I wonder if there's a theory where shorter words correspond to more coherent categories, and that's like an informative or efficient representation of category structure. But anyway, we have a natural definition of a coherent category that's precise enough to prove a theorem that coherent categories are learned faster. And actually, this also relates to the size of the categories. So frequency effects show up. Anyway, there's a lot of stuff there. But that was sort of the entry point.

So now, what about a theory of learning in much deeper networks that have many, many layers? OK? So, again, I'm going to make a long story short, because it's all published, so you can read all the details. But I wanted to give you the spirit, or the essence, or the intuition behind the work.

OK. So the questions we'd like to answer are: how does training time scale with depth? How should the learning rate scale with depth? How do different weight initializations impact learning speed? And once we understand these theoretically, we'll find that certain weight initializations, corresponding to critical dynamics, which I'll define, can aid deep learning and generalization.

So the basic idea is that in a very, very deep neural network, you have a vanishing or exploding gradient problem. And that's one of the issues that makes deep neural network learning hard.
So if you're going to back propagate the error through multiple layers, the back propagation operation is a product of Jacobians from layer to layer. And that product of Jacobians is fundamentally a linear mapping, right? So essentially, if the singular values of the Jacobian in each layer are large, bigger than one, the product of such matrices will lead to a product matrix that has singular values that grow exponentially with depth. Similarly, if the singular values are less than one, they'll decay with depth, right? So that's a vanishing gradient in the latter case and an exploding gradient in the former case. That seems to be one of the major impediments to deep learning.

So what people often did was they tried to scale the matrices to avoid this problem, right? So what they often do is they initialize the weights randomly, so that W is a random matrix whose elements W_ij are IID Gaussian, with the scale factor chosen precisely so that the largest eigenvalue of the Jacobian, or the back propagation operator, is one. OK? So that's like scaling the system so that if you place a random error vector here, the desired output minus the actual output, and back propagate it through a random network, the error vector will preserve its norm as it's back propagated across. And this is the famous Glorot and Bengio initialization. And it works pretty well for depth four or five or whatever, right?

OK. So we would like a theory of the learning dynamics of that. And as I said, there's no hope for a complete theory at the moment with arbitrary non-linearities. OK. So what we're going to do is analyze the learning dynamics, and we'll just get rid of the non-linearities, right?
So, again, it might seem like we're throwing the baby out with the bathwater, but we're actually going to learn something that helps us train non-linear networks. OK. So the basic idea then is that we have a network which is linear. So y, the output, is a product of weight matrices applied to the input, right? OK. So then back propagation is, again, just a product of matrices, right? But the gradient dynamics is non-linear and coupled and non-convex. And actually, even in this linear network, you see plateaus and sudden transitions, right?

And interestingly enough, even in this very deep linear network, you see faster convergence from pre-trained initial conditions, right? So basically, if you start from random Gaussian initial conditions, you get slow learning for a while and then relatively sudden learning here. Whereas, if you pre-train the network using greedy unsupervised learning-- so this is the time it takes to pre-train-- you get sudden learning and a drop here. So remember, if you go back to the original Hinton paper, this was the phenomenon that started deep learning: greedy unsupervised pre-training allows you to rapidly train very, very deep neural networks. So the very empirical phenomenon that led to the genesis of deep learning was present already in deep linear neural networks, right?

So deep linear neural networks, in terms of their expressive power, are crappy, because the composition of linear operations is linear. They're not a good model for deep non-linear networks in terms of input-output mappings. But they're a surprisingly good theoretical toy model for the dynamics of learning in non-linear networks. OK? Because very important phenomena also arise in deep linear networks. And we're focusing on learning dynamics here. OK. So we can build intuitions for the non-linear case by analyzing the linear case.
OK? So we went through the three-layer dynamics already. What about the dynamics with many more layers? So, again, the product of Jacobians can vanish or explode, right? OK. So, again, I'm going to make a long story short, and I'll tell you the final result. What we find is a class of weight initializations that allow learning time to remain constant as the depth of the network goes to infinity. Now, I'm measuring learning time in units of learning epochs, right? So, obviously, to train a very, very deep neural network, it just takes longer to compute each gradient. So in terms of real time, of course, the time will scale with the depth of the network. But you might imagine that in terms of number of gradient evaluations, as the network gets deeper and deeper, it might take longer and longer to train it. And we show a class of initial conditions for which that's not true. As the network gets deeper and deeper, the number of gradient evaluations you need to train the network can remain constant, even as the depth goes to infinity, even in a non-linear network.

OK. So let me give you intuition for why. So, for example, the classical initialization, this Glorot and Bengio initialization, doesn't have that property. But our initialization does. So basically what we did was-- we'll start off with linear networks. We trained deep linear networks on MNIST. And we scaled the depth like this, right? And we started with random Gaussian initial conditions-- scaled random Gaussian initializations-- and then ran back propagation. And we found that the training time, as you might expect, grew with depth. This is training time measured in number of learning epochs, or number of gradient evaluations. But here, what we did was we initialized the weights using random orthogonal weights, right?
And then we found that the learning time didn't grow with depth. And also, if you pre-train it, it doesn't grow with depth. OK. So there's a dramatically different scaling of learning time between random Gaussian initialization and random orthogonal initialization. Why?

And the answer is the following. Let's think about the back propagation operator. Let's say you want to back propagate errors from the output to the input. The back propagation operator in a linear network is just the product of weights throughout the entire network. OK? So if you do a random Gaussian weight initialization here, then this is a product of random Gaussian matrices. So to understand the statistical properties of back propagation, you need to understand the statistical properties of the singular value spectrum of products of random Gaussian matrices. There isn't really a general theory for that, but we can look at it numerically and get intuition for it.

So the basic idea is, if you have one random Gaussian matrix, the squared singular values of W are the eigenvalues of W transpose W. Their distribution is a famous one called the Marchenko-Pastur distribution. And they vary in a range that's order one. OK? So if you back propagate through one layer, you're fine. You don't get vanishing or exploding gradients. OK. But if you look at the singular values of a product of five random Gaussian matrices, the singular value spectrum gets very distorted. You've got a large number of singular values that are close to zero and a long tail that extends up to four. OK. But if you do 100 layers, you get a very, very large number of singular values that are close to zero and a very much longer tail. OK?

Now, this is a product of random Gaussian matrices. So if you feed a random vector into this, on average, its norm will be preserved. The vector's length will not change.
But we know that preserving the length of a vector is not the same as preserving angles between all pairs of vectors. OK. So actually, the way that this product of random Gaussian matrices preserves the norm of the gradient is it does it in a very anisotropic way. It takes an error vector at the output, and it projects it into a low-dimensional space corresponding to the singular values that are large. And then it amplifies it in that space. So the length is preserved, but all error vectors get projected onto a low-dimensional space and amplified. So a lot of error information is lost in a product of random Gaussian matrices. OK. So that's why the Glorot and Bengio initial conditions work well up to depth five or six or seven, but they don't work well up to depth, say, 100, or in recurrent neural networks as well. OK?

So what can we do? Well, a simple thing we can do is we can replace these random Gaussian matrices with orthogonal matrices. OK? We know that all the singular values of an orthogonal matrix are one, every single one. And the product of orthogonal matrices is orthogonal. So therefore, the back propagation operator has all of its singular values equal to one. And there are generalizations of orthogonal matrices to rectangular versions when the layers don't have the same number of neurons in each layer. OK? So this is fantastic. So this works really well for linear networks.

OK. But how does this generalize to non-linear networks? Because then you have a product of Jacobians, right? So what happens here? OK. So what is the product of Jacobians? If we imagine how errors back propagate to the front, or how input perturbations propagate forward to the end, it's the same thing. So it's easier to think about forward propagation.
Imagine that you have an input and you perturb it slightly. How does the perturbation grow or decay? Well, what happens is there's a linear expansion or contraction due to W. And then the non-linearity is usually compressive. So there's a non-linear compression due to the diagonal Jacobian of passing through the point-wise non-linearity. And then, again, linear modification and non-linear compression, linear modification and non-linear compression. OK?

So what we could do is just simply choose these, again, to be random orthogonal matrices. And we scale the random orthogonal matrices by a scale factor to combat the non-linear compression. And then the dynamics of perturbations is like this: you rotate, linearly scale, non-linearly compress, rotate, linearly scale, non-linearly compress, and so on. That's essentially the type of dynamics that occurs in dynamically critical systems that are close to the edge of chaos. You get this alternating phase space expansion and compression that's in different dimensions at different times. OK?

So now you can just compute numerically, under that initialization, how the singular value spectrum of the product of Jacobians scales. And it scales beautifully. So this is the scale factor for the type of non-linearity that we use, the hyperbolic tangent non-linearity. The optimal scale factor in front of the random orthogonal matrix is one. And you see, when you choose that-- this was 100 layers, I believe-- even for 100 layers, the end-to-end Jacobian from the input to the output has a singular value spectrum that remains within a range of order one. If g is even slightly less than one, the singular values exponentially vanish with depth. If g is larger than one, the singular values grow, but actually not as quickly as you'd think.
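Here is a quick numerical sketch of that comparison (not the code from the paper; the width, depth, and gain values are arbitrary, and I draw random orthogonal matrices from a QR decomposition just for convenience). It accumulates the per-layer Jacobians D_l W_l of a deep tanh network and looks at the singular values of the end-to-end product, for 1/sqrt(n)-scaled Gaussian weights versus gain-scaled orthogonal weights:

```python
import numpy as np

rng = np.random.default_rng(0)
n, depth = 200, 100

def end_to_end_jacobian(weight_sampler, gain=1.0):
    """Forward-propagate one input and accumulate J = D_L W_L ... D_1 W_1."""
    h = rng.normal(0, 1, n)
    J = np.eye(n)
    for _ in range(depth):
        W = gain * weight_sampler()
        h = np.tanh(W @ h)
        D = np.diag(1.0 - h ** 2)            # Jacobian of the point-wise tanh
        J = D @ W @ J
    return J

gaussian = lambda: rng.normal(0, 1.0 / np.sqrt(n), (n, n))     # norm-preserving scale
orthogonal = lambda: np.linalg.qr(rng.normal(size=(n, n)))[0]  # random orthogonal

for name, sampler, g in [("gaussian,   g=1.0", gaussian, 1.0),
                         ("orthogonal, g=0.9", orthogonal, 0.9),
                         ("orthogonal, g=1.0", orthogonal, 1.0),
                         ("orthogonal, g=1.1", orthogonal, 1.1)]:
    sv = np.linalg.svd(end_to_end_jacobian(sampler, g), compute_uv=False)
    print(f"{name}   max {sv.max():9.2e}   median {np.median(sv):9.2e}")
```

With the Gaussian product, the maximum singular value runs away from the median, which is the anisotropy just described; with scaled orthogonal weights near the critical gain, the end-to-end spectrum stays far better conditioned at depth 100, while gains well below that value shrink it exponentially and larger gains blow it up.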
So this is the critically dynamical regime, which preserves not only the norm of back propagated gradients, but all angles between pairs of gradients, right? So it's an isotropic preservation of error information from the end of the network all the way to the beginning. OK? So does it work? It works better than other initializations, even in non-linear networks. So we trained 30-layer non-linear networks, and the initialization works better. And also, interestingly enough, at this critical scale factor you also achieve better generalization error. And we don't have a good theory for that, actually. The test error and the training error, of course, go down. OK. So that's an interesting situation where a theory of linear networks led to a practical training advantage for non-linear networks.

OK. So here's another question that we had. There's a whole world of convex optimization. We want our machine learning algorithms to correspond to convex optimization, so we can find the global minimum, and there are no local minima to impede us from finding the global minimum, right? That's conventional wisdom. Yet the deep neural network people ignore this conventional wisdom and train very, very deep neural networks and don't worry about the potential impediments from local minima. They seem to find pretty good solutions. Why? OK? Is the intuition that local minima are really an impediment to non-linear, non-convex optimization in high-dimensional spaces actually true? And you might think that it's not true for the following intuitive reason. So, again, it's often thought that local minima, at some high level of training error, stand as a major impediment to non-convex optimization.
And, you know, this is an example-- a two-dimensional caricature of a protein folding energy landscape. And it's very rough, so there are many, many local minima. And the global minimum might be hard to find. And that's true: if you draw random generic surfaces over low dimensions, those random surfaces will have many local minima. But, of course, our intuition about geometry, derived from our experience with a low-dimensional world, is woefully inadequate for thinking about geometry in high-dimensional spaces. So it turns out that for random non-convex error functions over high-dimensional spaces, local minima are exponentially rare in the dimensionality relative to saddle points.

Just intuitively, imagine you have an error function over 1,000 variables, say 1,000 synaptic weights in a deep network. That's a small deep network. But anyway, let's say there's a point at which the gradient in weight space vanishes. So now there are 1,000 directions in weight space you could move away from that critical point. What are the chances that every single direction you move has positive curvature, right? If it's a fairly generic landscape, the answer is exponentially small in the dimensionality. Some directions will have negative curvature. Some directions will have positive curvature. Unless your critical point is already near the bottom-- in that case, most directions will have positive curvature. Or unless your critical point is near the top-- then most directions will have negative curvature, right?

So statistical physicists have made this intuition very precise for random landscapes. And they've developed a theory for it. So this is a paper in Physical Review Letters by Bray and Dean. What they did was they imagined just a random Gaussian error landscape. So they looked at an error landscape that's a continuous function over n dimensions, but there are correlations.
It's correlated over some length scale. So it's a single draw from a random Gaussian process, where the kernel of the Gaussian process falls off with some length scale. So the error at one point is correlated with the error at another point over some length scale. And that correlation falls off smoothly. So it's a random, smooth landscape, and the correlations are local, essentially.

And then what they did was they asked the following question. Let x be a critical point, a point where the gradient vanishes. OK? We can plot every single critical point in a two-dimensional feature space. What is that feature space? Well, the horizontal axis is the error level of the critical point: how high on the error axis does this critical point sit? And then this f is the fraction of negative eigenvalues of the Hessian at that critical point. So it's the fraction of directions that curve downwards. OK.

So now, a priori, critical points could potentially sit anywhere in this two-dimensional feature space, right? It turns out they don't. They concentrate on a monotonically increasing curve that looks like this. So the higher the error level of the critical point, the more negative curvature directions you have. OK? And the probability of being an order one distance away from this curve is exponentially small in the dimensionality of the problem. OK? Now, what does that mean? It automatically implies that there are no local minima at high error, or at least they're exponentially rare relative to saddle points of a given index. OK? So basically, you typically never encounter local minima at high error, right? That would be stuff that sits here. And there's nothing here. OK?
Second, if you are a local minimum, which means on this axis you're at the bottom, then your error level must be very, very close to the global minimum. OK? So if you get stuck in a local minimum, you're already close in error to the global minimum.

AUDIENCE: Can you repeat this last element?

SURYA GANGULI: Yeah. So if you're a local minimum, your error level will be close to the error level of the global minimum.

AUDIENCE: Why?

SURYA GANGULI: Because what does it mean to be a local minimum? It means that f is zero: the fraction of negative curvature eigenvalues of the Hessian is zero. And this is the distribution of error levels of such critical points. They're strongly peaked at this value, which is the value of the global minimum. Essentially, there's nothing out here. OK?

All right, now in physics there is this well-known principle called universality. There are certain questions whose answers don't depend on the details. For example, certain critical exponents in the liquid-gas phase transition are exactly the same as critical exponents in the ferromagnetic phase transition, because the symmetry and dimensionality of the order parameter-- density in the case of the liquid and magnetization in the case of the ferromagnet-- are the same. So there are certain questions whose answers don't depend on the details. They only depend on the symmetry and dimensionality of the problem. So one might think that this qualitative prediction is true in generic high-dimensional landscapes.

Now, the computer scientists would say, no, no, no, no, no, no. Your random landscapes are a horrible model for our error landscapes of deep neural networks trained on MNIST and CIFAR-10 and so on and so forth. You're completely irrelevant to us, because we're doing something special. We're not doing something random. We have a lot of structure.
OK. The physicists might counter, well, you know, you just have a high-dimensional problem. The basic intuition that in high dimensions it's very hard to have all directions curve up at a critical point at high error should also hold true in your problem. OK? But, of course, we'll never get anywhere if we stop there, right? We have to move over into your land, which is also my land, and just simulate the system. So oftentimes, you know, biologists and computer scientists don't believe a theory until they see the simulation.

So what we'll do is we'll search for critical points in the error landscape of deep neural networks. And that's what we did. So what we did was we used Newton's method to find critical points. It turns out that Newton's method is attracted to saddles, right? So Newton's method will descend in the positive curvature directions, but it will ascend in the negative curvature directions, because Newton's method is gradient descent multiplied by the inverse of the Hessian. So if the Hessian has a negative eigenvalue, you take a negative gradient and multiply it by the inverse of that negative eigenvalue, and you turn back around and go uphill. So Newton's method, uncorrected, is attracted to saddle points. OK?

So what we did was we looked at the error landscape of deep neural networks trained on MNIST and CIFAR-10, and we just plotted the prediction of random landscape theory, right? And what we found was exactly, qualitatively, their prediction. We took each critical point and plotted it in this two-dimensional feature space. And we found that the critical points concentrated on a monotonically increasing curve, which, again, shows that there are no local minima at high error. And if you're a local minimum, your error is close to at least the lowest error minimum that we found.
We can't guarantee that the lowest error minimum we found is the global minimum. But qualitatively, this structure holds. OK?

Now, the issue is, what can we do about it? So what this is telling us is that even in these problems of practical interest, saddle points might stand as the major impediment to optimization, right? Because saddle points can trap you. You know, you might go down here, and then there might be a very slowly curving negative curvature direction that might take you a while to escape. In fact, in the learning dynamics that I showed, in those transitions in learning hierarchical structure, the thing controlling the transitions was the existence of saddle points in the weight space of the linear neural network. The period of no learning corresponded to slowly falling along this direction, and then the rapid learning corresponded to eventually coming out this way. OK.

So how do we deal with that? Well, what we can do is a simple modification of Newton's method, in which instead of dividing by the Hessian, we divide by the absolute value of the Hessian. And, again, I should say that this was done in collaboration with Yoshua Bengio's lab. And a set of fantastic graduate students in Yoshua Bengio's lab did all of this work on the training and testing of these predictions. OK. So what we suggested was: the offending thing is dividing by negative eigenvalues. So just take the absolute value of the Hessian, by which I mean take the Hessian, compute its eigenvalues, and replace each negative one with its absolute value. OK? So that will obviously get repelled by saddles, all right? And that actually works really, really well.
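As a minimal sketch of that step (this is not the implementation from the collaboration with Yoshua Bengio's lab; it forms and eigendecomposes the full Hessian explicitly, which is only sensible for tiny problems, and the damping constant is an arbitrary choice), the update is: eigendecompose H, take the absolute value of each eigenvalue, and move against the gradient rescaled by the inverse of that matrix.

```python
import numpy as np

def saddle_free_newton_step(grad, hessian, damping=1e-3):
    """One -|H|^-1 g step: descends along positive-curvature directions
    and is repelled from, rather than attracted to, saddle points."""
    eigvals, eigvecs = np.linalg.eigh(hessian)       # Hessian is symmetric
    inv_abs = 1.0 / (np.abs(eigvals) + damping)      # |lambda|^-1, lightly damped
    return -eigvecs @ (inv_abs * (eigvecs.T @ grad))

# Toy check on a pure saddle: f(x, y) = x^2 - y^2, saddle at the origin.
f_grad = lambda p: np.array([2.0 * p[0], -2.0 * p[1]])
f_hess = lambda p: np.array([[2.0, 0.0], [0.0, -2.0]])

p = np.array([0.1, 0.1])
for _ in range(5):
    p = p + saddle_free_newton_step(f_grad(p), f_hess(p))
    print(np.round(p, 4))
# A plain Newton step, -H^-1 g, would jump straight to the saddle at (0, 0).
# The saddle-free step instead minimizes along x while moving further away
# from the saddle along y, so the function value keeps decreasing.
```

The trust-region argument mentioned next is what justifies using this step even far from critical points; the sketch above only illustrates the local behavior at a saddle.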
And there's a way to derive this algorithm that makes sense even far from saddles, by minimizing a linear approximation to f within a trust region in which the linear and quadratic approximations agree. OK? So let me just show you first that it works. So this is the most dramatic plot. Basically what we did was we ran stochastic gradient descent for a while. And then it seemed like the error as a function of training time plateaued, both for a deep autoencoder and for a recurrent neural network problem. So when the error as a function of training time plateaus, that's usually interpreted as a sign that you're stuck in a local minimum, right? But actually, when we switched to this, what we call the saddle-free Newton method, the error suddenly drops again. So this was an illusory signature of a local minimum. It was actually a saddle point, with probably a very shallow negative curvature direction that was hard to escape. And when we switched to our algorithm, we could escape it. And what these curves show is that we also do better in the final training error.

So now, how do we train deep neural networks with thousands of layers? And actually, how do we model complex probability distributions? So we want to sample from very, very complex probability distributions and do complex distributional learning, right? So this was done by a fantastic post-doc of mine, Jascha Sohl-Dickstein. We were going to Berkeley for these non-equilibrium statistical mechanics meetings and things like that. And there have been lots of advances in non-equilibrium statistical mechanics where you can show that, roughly speaking, the second law of thermodynamics-- which says that things get more and more disordered with time-- can be transiently violated in small systems or over short periods of time, so you can spontaneously generate order. OK. So I'll just go through this, again, very quickly. So here's the basic idea.
Let's say you have a complicated probability distribution. Let's just destroy it. Let's feed that probability distribution through diffusion to turn it into a simple distribution, maybe an isotropic Gaussian. And we keep a record of that destruction of structure. And then we try to reverse time in that process, using deep neural networks, and essentially create structure from noise. And then you have a very, very simple way to sample from complex distributions, if you can train the neural network: you just sample noise, and you feed it through a deterministic neural network, and that constitutes a sample from the complex distribution. And this was inspired by recent results in non-equilibrium stat mech.

So the basic idea, again, is: imagine that you have a very complex distribution corresponding to this density of dye. You diffuse for a while. It becomes a simpler distribution. Eventually, it becomes uniform. You keep a [AUDIO OUT] Now, if you just reverse the process of diffusion, you'll never go from this back to this. But if you reverse the process with a neural network trained to do it, you might be able to do it. So that's the basic idea. And that's what we did. And I'll just show you some nice movies to show that it works.

This is the classical toy model; we'll go to more complex models. This is a sample distribution in two-dimensional space. And what we do is we just systematically have the points diffuse under Gaussian diffusion with a restoring force toward the origin. So the stationary distribution of that destructive process is an isotropic Gaussian. OK? And that's what happened. So that's our training data. The entire movie is the training data. OK? Then what we do is we train a neural network to reverse time in that movie. So it's a neural network with many, many layers-- hundreds and hundreds of layers, right?
So classically, when you train a network with hundreds of layers, you have the credit assignment problem, because you don't know what the intermediate neurons are supposed to do. Here you can circumvent the credit assignment problem, because each layer, going up to the next layer, just has to go from time t to time t minus 1 in the training data. So you have targets for all the intermediate layers. Therefore, you've circumvented the credit assignment problem. OK? So it's relatively easy to train such networks.

And so once you have such a network, what should you be able to do? You should be able to feed that neural network an isotropic Gaussian and have that Gaussian be turned into the data distribution. So that's what happens. This, on the right, is a different Gaussian. And we just feed it through the trained deterministic neural network, and out pops the structure. It's not perfect-- there are some data points that are over here-- but this is roughly the distribution that it learned, which is similar to what it was trained on. OK? So now we can look at slightly more complicated distributions. OK. So that's that.

So now we can train it on a toy model of natural images, right? A classic toy model of natural images is the dead leaves model, where the sampling process is that you just throw down circles of different radii. So you get a complex model of natural images that has long-range edges, occlusion, coherence over long length scales, and so on and so forth. So we can train the neural network on such distributions. We train it in a convolutional fashion by working on local image patches. And we convolve, so information will propagate over long ranges. And so, again, we take these natural images and turn them into noise, keep a record of that movie, and then reverse the flow of time using a deep neural network. OK.
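Here is a bare-bones sketch of that training setup in the spirit of the 2-D toy example (this is my own simplification, not Jascha Sohl-Dickstein's implementation: I share one small network across all time steps rather than giving every layer its own weights, I regress the conditional mean of the previous step rather than learning a stochastic reverse kernel, and the data, network sizes, number of diffusion steps, and noise schedule are all arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy 2-D data: points on a noisy ring.
def sample_data(n):
    theta = rng.uniform(0, 2 * np.pi, n)
    return np.stack([np.cos(theta), np.sin(theta)], axis=1) + 0.05 * rng.normal(size=(n, 2))

# Forward (destructive) process: shrink toward the origin and add Gaussian noise,
# so that after T steps the points are approximately an isotropic Gaussian.
T, keep = 40, 0.95
noise = np.sqrt(1 - keep ** 2)

def forward_trajectory(x0):
    xs = [x0]
    for _ in range(T):
        xs.append(keep * xs[-1] + noise * rng.normal(size=x0.shape))
    return xs                                     # x_0 ... x_T, each (n, 2)

# A small MLP mapping (x_t, t/T) -> estimate of x_{t-1}.  Because every step of
# the unrolled reverse chain has an explicit target from the recorded forward
# trajectory, there is no credit assignment problem.
W1 = 0.5 * rng.normal(size=(3, 64)); b1 = np.zeros(64)
W2 = 0.5 * rng.normal(size=(64, 2)); b2 = np.zeros(2)

def net(x, t):
    inp = np.concatenate([x, np.full((x.shape[0], 1), t / T)], axis=1)
    h = np.tanh(inp @ W1 + b1)
    return h @ W2 + b2, h, inp

lr = 0.05
for step in range(2000):
    xs = forward_trajectory(sample_data(128))
    t = int(rng.integers(1, T + 1))
    pred, h, inp = net(xs[t], t)
    err = pred - xs[t - 1]                        # target: the previous, cleaner point
    dh = (err @ W2.T) * (1 - h ** 2)              # manual backprop through the MLP
    W2 -= lr * h.T @ err / 128;  b2 -= lr * err.mean(0)
    W1 -= lr * inp.T @ dh / 128; b1 -= lr * dh.mean(0)

# Sampling: start from Gaussian noise and apply the learned reverse map T times.
x = rng.normal(size=(500, 2))
for t in range(T, 0, -1):
    x = net(x, t)[0]
print("mean radius of generated points:", np.round(np.linalg.norm(x, axis=1).mean(), 2))
```

As training improves, the generated cloud should pull in toward the ring (mean radius near one), which is the same noise-to-structure direction the movies show.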
1576 00:59:51,960 --> 00:59:57,360 So once we train that, we should be able to turn noise 1577 00:59:57,360 --> 01:00:00,120 into the network's best guess as to what a dead leaves 1578 01:00:00,120 --> 01:00:01,600 image would look like. 1579 01:00:01,600 --> 01:00:03,270 So this is what happens. 1580 01:00:03,270 --> 01:00:06,390 It's taking noise, and it turns it into a guess. 1581 01:00:06,390 --> 01:00:07,020 OK. 1582 01:00:07,020 --> 01:00:09,190 So it's not a perfect model. 1583 01:00:09,190 --> 01:00:12,240 But it turns out the log probability of dead leaves images under this 1584 01:00:12,240 --> 01:00:15,480 generative model, which consists of a 1,000-layer deep neural 1585 01:00:15,480 --> 01:00:19,840 network, is higher than under any other model so far. 1586 01:00:19,840 --> 01:00:21,940 So this is currently state of the art. 1587 01:00:21,940 --> 01:00:22,440 OK. 1588 01:00:22,440 --> 01:00:24,990 And as you can see, it gets long-range coherence 1589 01:00:24,990 --> 01:00:26,130 and sharp edges. 1590 01:00:26,130 --> 01:00:28,590 And moreover, it gets long-range coherence 1591 01:00:28,590 --> 01:00:30,570 in the orientation of edges, which is often hard 1592 01:00:30,570 --> 01:00:34,510 to do in generative models of natural images. 1593 01:00:34,510 --> 01:00:35,010 OK. 1594 01:00:35,010 --> 01:00:37,230 Now, we can actually do something somewhat practical 1595 01:00:37,230 --> 01:00:42,000 with this: we can sample from the conditional distribution. 1596 01:00:42,000 --> 01:00:43,810 So then we also trained it on textures. 1597 01:00:43,810 --> 01:00:44,310 OK? 1598 01:00:47,460 --> 01:00:50,010 So, for example, textures of bark, right? 1599 01:00:50,010 --> 01:00:52,720 And we can also sample from the conditional distribution. 1600 01:00:52,720 --> 01:00:56,730 So what we can do is we can clamp the pixels outside 1601 01:00:56,730 --> 01:00:58,440 of a certain region, and replace the interior 1602 01:00:58,440 --> 01:01:00,540 with white noise, so it's essentially blank. 1603 01:01:00,540 --> 01:01:03,150 And because the network operates in a convolutional fashion, 1604 01:01:03,150 --> 01:01:04,650 information from the boundary should 1605 01:01:04,650 --> 01:01:07,890 propagate into the interior and fill it in, right? 1606 01:01:07,890 --> 01:01:10,830 So if we look at that, so that's white noise. 1607 01:01:10,830 --> 01:01:13,110 And the network is filling it in. 1608 01:01:13,110 --> 01:01:17,291 And it fills in its best-guess image. 1609 01:01:17,291 --> 01:01:17,790 OK. 1610 01:01:17,790 --> 01:01:19,200 And so, again, it's not identical 1611 01:01:19,200 --> 01:01:21,570 to the original image, but it does get long-range edge 1612 01:01:21,570 --> 01:01:24,930 structure, coherence in the orientation of the edges, 1613 01:01:24,930 --> 01:01:26,454 and smooth structure as well. 1614 01:01:26,454 --> 01:01:28,620 And, again, this is like a 1,000-layer neural network. 1615 01:01:28,620 --> 01:01:29,911 Now, there are some lessons here. 1616 01:01:29,911 --> 01:01:30,870 OK? 1617 01:01:30,870 --> 01:01:33,810 Oftentimes when we model complex data distributions, what 1618 01:01:33,810 --> 01:01:38,030 we try to do is we try to create a stochastic process whose 1619 01:01:38,030 --> 01:01:42,300 stationary distribution is the complex distribution, right?
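Returning to the conditional-sampling procedure just described, here is a minimal sketch, in Python with NumPy, of the clamping mechanics only. The `reverse_step` function is an explicitly labeled placeholder (simple local smoothing plus noise) rather than the trained convolutional network, and the image, mask geometry, and iteration count are made up for illustration.

```python
import numpy as np

def reverse_step(img, rng):
    """Placeholder for one reverse step of the trained network.

    Local averaging plus a little noise, purely to show how information
    propagates inward from the clamped boundary; the real system applies
    a learned convolutional reverse-diffusion update instead.
    """
    padded = np.pad(img, 1, mode="edge")
    neighbors = (padded[:-2, 1:-1] + padded[2:, 1:-1] +
                 padded[1:-1, :-2] + padded[1:-1, 2:]) / 4.0
    return 0.5 * img + 0.5 * neighbors + 0.01 * rng.normal(size=img.shape)

rng = np.random.default_rng(0)
observed = rng.random((64, 64))           # placeholder "texture" image
mask = np.zeros((64, 64), dtype=bool)     # True where pixels are observed
mask[:16, :] = True
mask[-16:, :] = True
mask[:, :16] = True
mask[:, -16:] = True                      # known border, unknown interior

# Start with the interior replaced by white noise, then run the chain,
# clamping the observed pixels back to their true values at every step.
img = np.where(mask, observed, rng.normal(size=observed.shape))
for _ in range(200):
    img = reverse_step(img, rng)
    img[mask] = observed[mask]

# `img` now keeps the observed border exactly and has a filled-in interior.
```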
1620 01:01:42,300 --> 01:01:44,420 Now, if your distribution has multiple modes, 1621 01:01:44,420 --> 01:01:46,170 you're going to run into a mixing problem, 1622 01:01:46,170 --> 01:01:48,960 because it can take a stochastic process a long time 1623 01:01:48,960 --> 01:01:50,730 to jump over energy barriers that 1624 01:01:50,730 --> 01:01:52,470 separate the multiple modes. 1625 01:01:52,470 --> 01:01:53,970 So you always have a mixing problem. 1626 01:01:53,970 --> 01:01:57,030 And oftentimes when you train probabilistic models, 1627 01:01:57,030 --> 01:02:00,510 you have to sample and then use the samples to train the model. 1628 01:02:00,510 --> 01:02:05,320 So that makes training take a long time. 1629 01:02:05,320 --> 01:02:07,560 So in addition to circumventing 1630 01:02:07,560 --> 01:02:09,310 the credit assignment problem when training 1631 01:02:09,310 --> 01:02:11,310 very deep neural networks, we're also circumventing 1632 01:02:11,310 --> 01:02:14,370 the mixing problem in training the generative model. 1633 01:02:14,370 --> 01:02:16,360 Because we're not trying to model the data 1634 01:02:16,360 --> 01:02:18,120 distribution as the stationary distribution 1635 01:02:18,120 --> 01:02:19,179 of a stochastic process, 1636 01:02:19,179 --> 01:02:20,970 which would have to run for a very long time 1637 01:02:20,970 --> 01:02:22,720 to get to the stationary distribution. 1638 01:02:22,720 --> 01:02:25,200 We're demanding that during training the process 1639 01:02:25,200 --> 01:02:29,820 get to the data distribution in a finite amount of time, right? 1640 01:02:29,820 --> 01:02:32,400 So because during training we demand that we get to the data 1641 01:02:32,400 --> 01:02:34,230 distribution in a finite amount of time, 1642 01:02:34,230 --> 01:02:37,170 we're circumventing the mixing problem during training. 1643 01:02:37,170 --> 01:02:38,040 And that's the idea. 1644 01:02:38,040 --> 01:02:39,930 That's [AUDIO OUT] an idea. 1645 01:02:39,930 --> 01:02:41,940 There's lots of results now that show 1646 01:02:41,940 --> 01:02:45,150 that you can obtain information about stationary equilibrium 1647 01:02:45,150 --> 01:02:48,250 distributions from non-equilibrium trajectories. 1648 01:02:48,250 --> 01:02:48,750 OK. 1649 01:02:48,750 --> 01:02:50,860 So now, I'm done. 1650 01:02:50,860 --> 01:02:51,750 So let's see. 1651 01:02:51,750 --> 01:02:52,290 OK. 1652 01:02:52,290 --> 01:02:54,290 So there's that. 1653 01:02:54,290 --> 01:02:54,790 OK. 1654 01:02:54,790 --> 01:02:56,310 So, again, you can read about all 1655 01:02:56,310 --> 01:02:58,620 of this stuff in this set of papers. 1656 01:02:58,620 --> 01:03:01,350 Again, I'd like to thank my funding sources and just 1657 01:03:01,350 --> 01:03:02,220 the key players. 1658 01:03:02,220 --> 01:03:04,690 So Andrew Saxe, you know, did the work with me 1659 01:03:04,690 --> 01:03:07,530 on non-linear learning dynamics and learning 1660 01:03:07,530 --> 01:03:09,110 hierarchical category structure. 1661 01:03:09,110 --> 01:03:11,520 Jascha Sohl-Dickstein did the work 1662 01:03:11,520 --> 01:03:15,150 on deep learning using non-equilibrium thermodynamics. 1663 01:03:15,150 --> 01:03:18,195 And the work on saddle points was a nice collaboration 1664 01:03:18,195 --> 01:03:20,040 with Yoshua Bengio's lab. 1665 01:03:20,040 --> 01:03:22,000 And again-- fantastic graduate students 1666 01:03:22,000 --> 01:03:24,130 in Yoshua Bengio's lab. 1667 01:03:24,130 --> 01:03:24,630 OK.
1668 01:03:24,630 --> 01:03:26,160 So I think there's a lot more to do 1669 01:03:26,160 --> 01:03:28,020 in terms of unifying neuroscience, machine 1670 01:03:28,020 --> 01:03:30,360 learning, physics, math, statistics, all of that stuff. 1671 01:03:30,360 --> 01:03:33,260 It'll keep us busy for the next century.