[SQUEAKING] [RUSTLING] [CLICKING]

DAVID SONTAG: OK, so today's lecture is going to be about data set shift, specifically how one can be robust to data set shift. This is a topic we've been alluding to throughout the semester. And the setting I want you to be thinking about is as follows. You're a data scientist working at, let's say, Mass General Hospital, and you've been very careful in setting up your machine learning task: the data is well specified, and the labels you're trying to predict are well specified. You train on your training data, you test on a held-out set, you see that the model generalizes well, you do chart review to make sure that what you're predicting is actually what you think you're predicting, and you even do a prospective deployment where you let your machine learning algorithm drive some clinical decision support, and you see that things are working great. Now what? What happens after this stage, when you go to deployment? What happens when that same model is going to be used not just tomorrow but also next week, the following week, next year? What happens if your model, which is working well at this one hospital, is then wanted by another institution, say Brigham and Women's Hospital, or UCSF, or some rural hospital in the United States? Will it keep working in that near-term future, or in a new institution? That's the question we're going to be talking about in today's lecture. We'll be talking about how one can deal with data set shift of two different varieties. The first variety is adversarial perturbations to the data, and the second is data that changes for natural reasons.
Now, the reason it's not at all obvious that your machine learning algorithm should still work in that setting is that the number one assumption we make when we do machine learning is that your training data is drawn from the same distribution as your test data. So if you now go to a setting where your data distribution has changed, even if you've computed your accuracy on your held-out data and it looks good, there's no reason it should continue to look good in the new setting where the data distribution has changed. A simple example of what it means for a data distribution to change might be as follows. Suppose that we have input data x, and we're trying to predict some label y, where y means something like whether a patient has, or will be newly diagnosed with, type 2 diabetes. This is an example we talked about when we introduced risk stratification. You learn a model to predict y from x. Now suppose you go to a new institution where the definition of what type 2 diabetes means has changed. For example, maybe they don't actually have type 2 diabetes coded in their data; maybe they only have diabetes coded, which lumps together both type 1 and type 2 diabetes, type 1 usually being juvenile diabetes, which is actually a very distinct disease from type 2 diabetes. So now the notion of what diabetes is is different. Maybe the use case is also slightly different. And there's no reason, obviously, that your model, which was trained to predict type 2 diabetes, would work for that new label.
Now, this is an example of a type of data set shift where it's perhaps obvious to you that nothing should work, because here the conditional distribution p(y | x) changes: even for the same individual, the distribution p_0(y | x) at, let's say, one institution and the distribution p_1(y | x) at another are now two different distributions if the meaning of the label has changed. So for the same person, there might be a different distribution over what y is. That's one type of data set shift. A very different type of data set shift is where we assume that those two conditional distributions are equal, which would, for example, rule out the kind of shift I just described, but rather what changes is p(x) from location 1 to location 2. This is the type of data set shift we'll focus on in today's lecture, and it goes by the name of covariate shift. Let's look at two different examples of that. The first example is an adversarial perturbation. You've all seen the use of convolutional neural networks for image classification problems; this is just one illustration of such an architecture. With such an architecture, one could attempt all sorts of object classification or image classification tasks. You could take as input this picture of a dog, which is clearly a dog, and you could modify it just a little bit, adding a very small amount of noise. What I'm going to do is create a new image from that original image where, for every single pixel, I add a very small epsilon in the direction of that noise. What you get out is this new image, which you could stare at however long you want and not be able to tell the difference. To the human eye, these two look exactly identical.
Except that when you take your machine learning classifier, which was trained on the original unperturbed data, and apply it to this new image, it's classified as an ostrich. This observation was published in a 2014 paper called "Intriguing properties of neural networks," and it really kickstarted a huge surge of interest in the machine learning community in adversarial perturbations to machine learning: asking, if you were to perturb inputs just a little bit, how does that change your classifier's output? Could that be used to attack machine learning algorithms? And how can one defend against it? By the way, as an aside, this is actually a very old area of research; even back in the land of linear classifiers these questions had been studied, although I won't get into that in this course. So this is a type of data set shift in the sense that what we want is for this image to still be classified as a dog. The actual label hasn't changed. We would like the distribution over the labels, given the perturbed input, to be unchanged, but the distribution of inputs is now a little bit different, because we're allowing some noise to be added to each of the inputs. And in this case the noise isn't random, it's adversarial. Towards the end of today's lecture, I'll give you an example of how one can actually generate an adversarial image that changes the classifier's output. Now, the reason we should care about these types of things in this course is that I expect this type of data set shift, which is not at all natural, it's adversarial, is also going to start showing up in both computer vision and non-computer-vision problems in the medical domain. There was a nice recent paper by Sam Finlayson, Andy Beam, and Isaac Kohane, which presented several different case studies of where these problems could arise in health care.
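To make the kind of perturbation described above concrete, here is a minimal sketch in the spirit of the fast gradient sign method: every pixel is moved by a small epsilon in the direction that increases the classifier's loss. This is an illustrative construction, not necessarily the exact one used in the paper; the model, the labels, and the epsilon value are assumptions for the sake of the example.

```python
import torch
import torch.nn.functional as F

def fgsm_perturb(model, x, y_true, epsilon=0.007):
    """Move every pixel by +/- epsilon in the direction that increases the
    classifier's loss; the change is imperceptible to the human eye."""
    x = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), y_true)
    loss.backward()
    x_adv = x + epsilon * x.grad.sign()
    return x_adv.clamp(0.0, 1.0).detach()

# usage sketch (model, dog_image, and dog_label are assumed to exist,
# with dog_image a batched tensor):
# x_adv = fgsm_perturb(model, dog_image, dog_label)
# model(x_adv).argmax(dim=1)   # may now predict a different class on an
#                              # image that still looks unperturbed
```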
So, for example, here what we're looking at is an image classification problem arising from dermatology. You're given an image as input. You would like this image to be classified as an individual having a particular type of skin disorder, a nevus, and this other image as melanoma. And what one can see is that with a small perturbation of the input, one can completely swap the label that would be assigned from one to the other. In this paper, which we're going to post as optional reading for today's course, they talk about how one could maliciously exploit these algorithms for personal benefit. For example, imagine that a health insurance company decides that, in order to reimburse for an expensive biopsy of a patient's skin, a clinician or a nurse must first take a picture of the disorder and submit that picture together with the bill for the procedure. And imagine that the insurance company has a machine learning algorithm act as an automatic check: was this procedure actually reasonable for this condition? If it isn't, it might be flagged. Now, a malicious user could perturb the input such that, despite the patient having perhaps completely normal-looking skin, the image could nonetheless be classified by the machine learning algorithm as being abnormal in some way, and thus perhaps the procedure could get reimbursed. Obviously this is an example of a nefarious setting, where we would hope that such an individual would be caught by the police and sent to jail. But nonetheless, what we would like to do is build checks and balances into the system so that this couldn't even happen, because to a human it's obvious that you shouldn't be able to trick anyone with such a minor perturbation. So how do you build algorithms that can't be tricked as easily, just as humans wouldn't be tricked? AUDIENCE: Can I ask a question? DAVID SONTAG: Yeah.
AUDIENCE: For any of these samples, did the attacker need access to the network? Is there a way to [? attack it? ?] DAVID SONTAG: So the question is whether the attacker needs to know something about the function that's being used for classifying. There are examples of both what are called white-box and black-box attacks, where in one setting you have access to the function and in the other setting you don't. Both have been studied in the literature, and there are results showing that one can attack in either setting. Sometimes you might need to know a little bit more. For example, sometimes you need the ability to query the function a certain number of times. So even if you don't know exactly what the function is, say you don't know the weights of the neural network, as long as you can query it sufficiently many times, you'll be able to construct adversarial examples. That would be one approach. Another approach would be, maybe we don't know the function, but we know something about the training data. So there are ways to go about doing this even if you don't perfectly know the function. Does that answer your question? So what about a natural perturbation? This figure is pulled from lecture 5, when we talked about non-stationarity in the context of risk stratification. Just to remind you, here the x-axis is time, the y-axis is different types of laboratory test results that might be ordered, and the color denotes how many of those laboratory tests were ordered in a certain population at that point in time. What we would expect to see if the data were stationary is that every row would be a homogeneous color. But instead what we see is that there are points in time, for example a few-month interval over here, when suddenly it looks like some of the laboratory tests were never performed. That's most likely due to a data problem; perhaps the feed of data from that laboratory test provider got lost, or there was some systems problem.
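Checks like the one this figure illustrates can be automated. A minimal sketch, assuming a long-format pandas table of lab orders with hypothetical columns lab_name and order_time (order_time already a datetime): count orders per lab per month, and flag labs whose counts suddenly drop to zero, which is the signature of a broken data feed.

```python
import pandas as pd

def lab_order_counts(orders: pd.DataFrame) -> pd.DataFrame:
    """Rows = lab tests, columns = months, values = number of orders.
    A stationary feed gives roughly homogeneous rows; a feed that drops
    out shows up as a run of zeros."""
    orders = orders.assign(month=orders["order_time"].dt.to_period("M"))
    counts = (orders.groupby(["lab_name", "month"])
                    .size()
                    .unstack(fill_value=0))
    return counts

def flag_dropouts(counts: pd.DataFrame, window: int = 3) -> pd.Series:
    """Flag labs that used to be ordered regularly but have zero orders in
    the most recent `window` months: a possible broken data feed."""
    recent = counts.iloc[:, -window:]
    earlier = counts.iloc[:, :-window]
    return (earlier.mean(axis=1) > 0) & (recent.sum(axis=1) == 0)
```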
But there are also going to be settings where, for example, a laboratory test is never used until suddenly it is, and that might be because it's a new test that was just invented or approved for reimbursement at that point in time. So this is an example of non-stationarity. And, of course, this could also result in changes in your data distribution over time, such as what I described over there. The third example is when you go across institutions, where, of course, the language itself might change. You might think of a hospital in the United States versus a hospital in China, where the clinical notes will be written in completely different languages; that would be an extreme case. A less extreme case might be two different hospitals in Boston where the acronyms or the shorthand they use for some clinical terms might be different because of local practices. So, what do we do? This is all a setup. For the rest of the lecture, I'll talk first, very briefly, about how one can build in population-level checks for whether something has changed. And then for the bulk of today's lecture, we'll be talking about how to develop transfer learning algorithms and how one could think about defenses to adversarial attacks. Before I show you the first slide for bullet one, I want to have a bit of discussion. You've done all of that: you've learned a machine learning algorithm at your institution, and you want to know, will this algorithm work at some other institution? You pick up the phone, you call up your collaborating data scientist at another institution. What are the questions you should ask them to try to understand whether your algorithm will work there as well? Yeah. AUDIENCE: What kind of lab test information they collect [INAUDIBLE].
DAVID SONTAG: So, what type of data do they have on their patients, and do they have similar data types or features available for their patient population? Other ideas? Someone who hasn't spoken in the last two lectures, maybe someone in the far back there, people who have their computers out. Maybe you with your hand on your mouth right there, yeah, you with your glasses on. Ideas? [STUDENT LAUGHS] AUDIENCE: Sorry, can you repeat the question? DAVID SONTAG: You want me to repeat the question? The question was as follows. You learn your machine learning algorithm at some institution, and you want to apply it in a new institution. What questions should you ask of that new institution to try to assess whether your algorithm will generalize there? AUDIENCE: I guess it depends on the problem you're looking at, like whether you're trying to learn possible differences in your population, or if you're acquiring data with particular [INAUDIBLE]. So I'd envision that you'd want to ask, are your machines calibrated [INAUDIBLE]? What techniques do they use to acquire the data? DAVID SONTAG: All right, so let's break down each of the answers that you gave. The first answer was, are there differences in the population? Someone else now: what would be an example of a difference in a population? Yep. AUDIENCE: Age distribution. You might have younger people in Boston versus, like, Massachusetts [INAUDIBLE]. DAVID SONTAG: So you might have younger people in Boston versus older people in Central Massachusetts. How might a change in age distribution affect the ability of your algorithms to generalize? Yep. AUDIENCE: [? Possibly ?] health patterns, where young people are very different from [INAUDIBLE] who have some diseases that are clearly more prevalent in populations that are older [? than you. ?] DAVID SONTAG: Thank you.
So sometimes we might expect a different set of diseases to occur in a younger population versus an older population. Take type 2 diabetes and hypertension: these are diseases that are often diagnosed when individuals are in their 40s, 50s, and older. If you have people who are in their 20s, you don't typically see those diseases in that younger population. And so what that means is that if your model, for example, was trained on a population of very young individuals, it might not generalize. Suppose you're doing something like predicting future cost, something which is not directly tied to the disease itself; the features that are predictive of future cost in a very young population might be very different from the predictors of cost in a much older population, because of the differences in the conditions that those individuals have. Now, the second answer that was given had to do with calibration of instruments. Can you elaborate a bit on that? AUDIENCE: Yeah, so I was thinking [? clearly ?] in the colonoscopy space. In that space, you're collecting videos of colons. And so you can have machines that are calibrated very differently, let's say different light exposure, different camera settings. But you also have that the GIs and physicians have different techniques for how they explore the colon. So the video data itself is going to be very different. DAVID SONTAG: So the example that was given was of colonoscopies and the data that might be collected as part of that. The data that is collected could be different for two reasons. One, because the actual instruments collecting the data, for example imaging data, might be calibrated a little bit differently. And a second reason might be that the procedures used to perform that diagnostic test might be different in each institution.
Each will result in slightly different biases in the data, and it's not clear that an algorithm trained on one type of procedure or one type of instrument would generalize to another. So these are all great examples. When one reads a paper from the clinical community on developing a new risk stratification tool, what you will always see in that paper is what's known as "Table 1." Table 1 looks a little bit like this. Here I pulled one of my own papers, published in JAMA Cardiology in 2016, where we looked at how to find patients with heart failure who are hospitalized. I'm just going to walk through what this table is. This table describes the population that was used in the study. At the very top, it says these are characteristics of 47,000 hospitalized patients. Then what we've done is, using our domain knowledge, we know that this is a heart failure population, and we know that there are a number of different axes that differentiate hospitalized patients who have heart failure. So we enumerate many of the features that we think are critical to characterizing the population, and we give descriptive statistics on each one of those features. You always start with things like age, gender, and race. So here, for example, the average age was 61 years old (this was, by the way, NYU Medical School), 50.8% were female, 11.2% were Black or African-American, and 17.6% of individuals were on Medicaid, which is state-provided health insurance for disabled or lower-income individuals. Then we looked at quantities like what types of medications patients were on: 42% of inpatients were on something called beta blockers, and 31.6% of outpatients were on beta blockers. We then looked at things like laboratory test results, so one can look at the average creatinine values and the average sodium values of this patient population. In this way, the table describes the population that's being studied.
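Much of a Table 1 like this can be generated mechanically from the study cohort. A minimal sketch, assuming a pandas dataframe with one row per patient and hypothetical numeric columns such as age, female, on_beta_blocker, and creatinine: report percentages for binary indicators and mean (standard deviation) for continuous values.

```python
import pandas as pd

def table_one(cohort: pd.DataFrame) -> pd.DataFrame:
    """Descriptive statistics of a study population, in the spirit of a
    clinical paper's Table 1: percentages for binary columns, mean (SD)
    for continuous ones. Assumes all columns are numeric."""
    rows = {}
    for col in cohort.columns:
        series = cohort[col].dropna()
        if set(series.unique()) <= {0, 1}:      # binary indicator
            rows[col] = f"{100 * series.mean():.1f}%"
        else:                                   # continuous value
            rows[col] = f"{series.mean():.1f} ({series.std():.1f})"
    return pd.Series(rows, name="summary").to_frame()

# usage sketch:
# print(table_one(cohort[["age", "female", "on_beta_blocker", "creatinine"]]))
```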
Then, when you go to the new institution, that institution receives not just the algorithm but also this Table 1, which describes the population on which the algorithm was learned. They can use that, together with some domain knowledge, to think through questions like the ones I elicited from you in our discussion: does it actually make sense that this model will generalize to this new institution? Are there reasons why it might not? And you can do that even before doing any prospective evaluation on the new population. So almost all of you should have something like Table 1 in your project write-ups, because an important part of any study in this field is describing what population you're doing your study on. You agree with me, Pete? PETER SZOLOVITS: Yeah. I would just add that in Table 1, if you're doing a case-control study, you will have two columns that show the distributions in the two populations, and then a p-value for how significant those differences are. And if you leave that out, you can't get your paper published. DAVID SONTAG: I'll just repeat Pete's answer for the recording. This table is for a predictive problem. But if you're thinking about a causal inference type of problem, where there's a notion of different intervention groups, then you'd be expected to report the same sorts of things, but for both the case population, the people who received, let's say, treatment one, and the control population of people who received treatment zero. And then you would be looking at differences between those populations, at the individual feature level as well, as part of the descriptive statistics for that study. Now, this-- yeah. AUDIENCE: Is this to identify [? individually ?] [? between ?] those peoples? [INAUDIBLE] institutions to do, like, t-tests on those tables? DAVID SONTAG: To see if they're different? No; they're always going to be different.
You go to a new institution; it's always going to look different. So just checking whether something has changed isn't the point, because the answer is always going to be yes. But it enables a conversation. You might use some of the techniques that Pete's going to talk about next week on interpretability to understand what the model is actually using. Then you might say, OK, the model is using this thing, which makes sense in this population but might not make sense in another population. And it's those two things together that make the conversation. Now, this question has really come to the forefront in recent years, in close connection to the topic that Pete discussed last week on fairness in machine learning. Because you might ask, if a classifier is built on some population, is it going to generalize to another population if the population it was learned on was very biased? For example, it might have been all white people. You might ask, is that classifier going to work well in another population that perhaps includes people of different ethnicities? That has led to a concept that was recently published in a working draft, from just a few weeks ago, whose abstract I'm showing here, called "Datasheets for data sets." The goal is to standardize the process of eliciting the information about what it is about the data set that really played into your model. So I'm going to walk you through, very briefly, just a couple of elements of what an example datasheet for a data set might look like. This is too small for you to read, but I'll blow up one section in just a second. This is a datasheet for a data set for studying face recognition in an unconstrained environment, so it's a computer vision problem. There are a number of questions, which the paper I'm pointing you to outlines.
You, as the model developer, go through that questionnaire and fill out the answers, including things about the motivation for the data set's creation, its composition, and so on. In this particular instance, this data set, called Labeled Faces in the Wild, was created to provide images for studying face recognition in unconstrained [INAUDIBLE] settings, where image characteristics such as pose, illumination, resolution, and focus cannot be controlled. So it's intended to reflect real-world settings. Now, one of the most interesting sections of this report, which one should release with the data set, has to do with how the data was preprocessed or cleaned. For example, for this data set, it walks through the following process. First, raw images were obtained; these consisted of images and the captions found together with each image in news articles or around the web. Then a face detector was run on the data set, and here were the parameters of the face detector that were used. And remember, the goal here is to study face detection, so one has to know how the labels were determined, and how one would, for example, eliminate images in which there was no face. So there they describe how a face was detected and how a region was determined not to be a face in the case that it wasn't. Finally, it describes how duplicates were removed. If you think back to the examples we had earlier in the semester from medical imaging, for example in pathology and radiology, similar data set constructions had to be done there. For example, one would go to the PACS system where radiology images are stored, one would decide which images are going to be pulled out, and one would go to the radiology reports to figure out how to extract the relevant findings from each image, which would give the labels for that learning task.
Each step there will incur some bias, which one then needs to describe carefully in order to understand what the bias of the learned classifier might be. I won't go into more detail on this now, but this will also be one of the suggested readings for today's lecture. It's a fast read, and I encourage you to go through it to get some intuition for the questions we might want to ask about the data sets that we create. For the rest of the lecture today, I'm now going to move on to some more technical issues. So here's where we are: we're doing machine learning, the populations might be different, and what do we do about it? Can we change the learning algorithm in the hope that the algorithm might transfer better to a new institution? Or, if we get a little bit of data from that new institution, could we use that small amount of data, from the new institution or from a future time point, to retrain our model to do well on that slightly different distribution? That's the whole field of transfer learning. So you have data drawn from one distribution p(x, y), and maybe we have a little bit of data drawn from a different distribution q(x, y). Under the covariate shift assumption, I'm assuming that q(x, y) = q(x) p(y | x), namely that the conditional distribution of y given x hasn't changed; the only thing that might have changed is your distribution over x. That's what the covariate shift assumption says. So suppose that we have some small amount of data drawn from the new distribution q. How could we use that to retrain our classifier to do well for that new institution? I'll walk through four different approaches. I'll start with linear models, which are the simplest to understand, and then I'll move on to deep models.
The first approach is something that you've seen already several times in this course. We're going to think about transfer as a multi-task learning problem, where one of the tasks has much less data than the other. If you remember when we talked about disease progression modeling, I introduced the notion of regularizing weight vectors so that they stay close to one another. At that time, we were talking about weight vectors predicting disease progression at different time points in the future. We can use exactly the same idea here. You take your linear classifier that was trained on a really large corpus, and I'm going to call the weights of that classifier w_old. Then I'm going to solve a new optimization problem: minimize over the weights w the quantity loss(w; D) + lambda * ||w - w_old||^2. This is where your new training data come in: I'm going to assume that the new training set D is drawn from the q distribution, and I'm adding a regularization term that asks that w stay close to w_old. Now, if D, the data from that new institution, were very large, you wouldn't need this at all, because you could ignore the classifier you learned previously and just refit everything to the new institution's data. Where something like this is particularly valuable is when there is a small amount of data set shift and you only have a very small amount of labeled data from the new institution; then this allows you to change your weight vector just a little bit. If the coefficient lambda is very large, it says that the new w can't be too far from the old w. So it allows you to shift things a little bit in order to do well on the small amount of data that you have.
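A minimal sketch of this first approach, assuming a logistic regression and plain gradient descent (both are choices made for illustration, not prescribed by the lecture): refit on the small target-institution data set D drawn from q, with the penalty lambda * ||w - w_old||^2 keeping the new weights close to the source-institution weights.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def transfer_logreg(X_new, y_new, w_old, lam=1.0, lr=0.05, n_steps=5000):
    """Minimize   mean logistic loss on the small target data set
                + lam * ||w - w_old||^2,
    so the weights move away from w_old only as far as the new data
    justifies (large lam = stay very close to the source model)."""
    w = w_old.copy()
    n = X_new.shape[0]
    for _ in range(n_steps):
        p = sigmoid(X_new @ w)              # predicted probabilities
        grad = X_new.T @ (p - y_new) / n    # gradient of the loss term
        grad += 2.0 * lam * (w - w_old)     # gradient of the penalty term
        w -= lr * grad
    return w
```

When lam is very large, the new w can barely move from w_old; when the target data set is large, you can let lam go to zero and simply refit from scratch, exactly as discussed above.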
So, for example, if there is a feature that was previously predictive, but that feature is no longer present in the new data set, say it's identically zero, then of course the new weight for that feature is going to be set to 0, and you can think of that weight as being redistributed to some of the other features. Does this make sense? Any questions? So this is the simplest approach to transfer learning, and before you ever try anything more complicated, always try this. Yep. So the second approach is also with a linear model, but here we're no longer going to assume that the features are still useful. When you go from your first institution, let's say MGH on the left, where you learn your model, and you apply it to some new institution, let's say UCSF on the right, it could be that there is some really big change in the feature set, such that the original features are not at all useful for the new data set. A really extreme example of that is the setting I gave earlier: your model is trained on English, and you're testing it on Chinese. If you use a bag-of-words model, that would be an example where your model obviously wouldn't generalize at all, because your features are completely different. So what would you do in that setting? What's the simplest thing you might do? You're taking a text classifier learned in English, and you want to apply it in a setting where the language is Chinese. What would you do? AUDIENCE: Translate them. DAVID SONTAG: Translate, you said. And there was another answer. AUDIENCE: Or try to train an RNN. DAVID SONTAG: Train an RNN to do what? AUDIENCE: To translate. DAVID SONTAG: Train an RNN, oh, OK.
So assume that you have some ability to do machine translation; you translate from Chinese to English. It has to be that direction, because the original classifier was trained on English. Then your new function is the composition of the translation and the original function, right? And then you could imagine doing some fine-tuning if you had a small amount of data. Now, the simplest translation function might be to just use a dictionary: you look up a word, and if that word has an analog in the other language, you say, OK, this is the translation. But there are always going to be some words in your language that don't have a very good translation. So you might imagine that the simplest approach would be to translate, but then to drop the words that don't have a good analog and force your classifier to work with, let's say, just the shared vocabulary. Everything we're talking about here is an example of a manually chosen decision: we're going to manually choose a new representation for the data such that we have some amount of shared features between the source and target data sets. So let's talk about electronic health record 1 and electronic health record 2. By the way, the slides I'll be presenting here are from a paper published in KDD by Jen, Tristan, your instructor Pete, and John Guttag. So you have two electronic health records, electronic health record 1 and electronic health record 2. How can things change? Well, it could be that the same concept in electronic health record 1 is mapped to a different encoding in electronic health record 2; that's like an English-to-Spanish type of translation. Another example of a change might be that some concepts are removed; maybe you have laboratory test results in electronic health record 1 but not in electronic health record 2, and that's why you see an edge to nowhere. Another change might be that there are new concepts.
749 00:37:07,925 --> 00:37:10,050 So the new institution might have new types of data 750 00:37:10,050 --> 00:37:12,220 that the old institution didn't have. 751 00:37:12,220 --> 00:37:14,070 So what do you do in that setting? 752 00:37:14,070 --> 00:37:17,880 Well, one approach would be to say, OK, we 753 00:37:17,880 --> 00:37:20,160 have some small amount of data from electronic health 754 00:37:20,160 --> 00:37:21,360 record 2. 755 00:37:21,360 --> 00:37:25,890 We could just train using that and throw away 756 00:37:25,890 --> 00:37:28,900 your original data from electronic health record 1. 757 00:37:28,900 --> 00:37:30,960 Now, of course, if you only had a small amount 758 00:37:30,960 --> 00:37:33,758 of data from the target distribution, then 759 00:37:33,758 --> 00:37:36,300 that's going to be a very poor approach because you might not 760 00:37:36,300 --> 00:37:37,717 have enough data to actually learn 761 00:37:37,717 --> 00:37:39,990 a reasonable enough model. 762 00:37:39,990 --> 00:37:41,520 A second obvious approach would be, 763 00:37:41,520 --> 00:37:47,250 OK, we're going to just train on electronic health record 1 764 00:37:47,250 --> 00:37:48,180 and apply it. 765 00:37:48,180 --> 00:37:52,930 And for those concepts that aren't present anymore, 766 00:37:52,930 --> 00:37:54,330 so be it. 767 00:37:54,330 --> 00:37:56,028 Maybe things won't work very well. 768 00:37:56,028 --> 00:37:58,320 A third approach, which we were alluding to before when 769 00:37:58,320 --> 00:37:59,850 we talked about translation, would 770 00:37:59,850 --> 00:38:02,430 be to learn a model just in the intersection of the two 771 00:38:02,430 --> 00:38:03,780 feature sets. 772 00:38:03,780 --> 00:38:06,750 And what this work does, as they say, 773 00:38:06,750 --> 00:38:09,300 is to manually redefine the feature set 774 00:38:09,300 --> 00:38:12,938 in order to try to find as much common ground as possible. 775 00:38:12,938 --> 00:38:14,730 And this is something which really involves 776 00:38:14,730 --> 00:38:17,250 a lot of domain knowledge. 777 00:38:17,250 --> 00:38:20,190 And I'm going to be using this as a point of contrast 778 00:38:20,190 --> 00:38:23,550 from what I'll be talking about in 10 or 15 minutes, where 779 00:38:23,550 --> 00:38:25,855 I talk about how one could do this without that domain 780 00:38:25,855 --> 00:38:27,480 knowledge that we're going to use here. 781 00:38:31,060 --> 00:38:33,640 So the setting that they looked at 782 00:38:33,640 --> 00:38:37,720 is one of predicting outcomes, such as in-hospital mortality 783 00:38:37,720 --> 00:38:40,330 or length of stay. 784 00:38:40,330 --> 00:38:43,840 The model which is going to be used is a bag-of-events model. 785 00:38:43,840 --> 00:38:47,770 So we will take a patient's longitudinal history up 786 00:38:47,770 --> 00:38:49,330 until the time of prediction. 787 00:38:49,330 --> 00:38:52,390 We'll look at different events that occurred. 788 00:38:52,390 --> 00:38:57,280 And this study was done using PhysioNet. 789 00:38:57,280 --> 00:39:01,120 And in MIMIC, for example, events are encoded with some number, 790 00:39:01,120 --> 00:39:04,810 like 5814 might correspond to a CVP alarm, 791 00:39:04,810 --> 00:39:07,630 1046 might correspond to pain being present, 792 00:39:07,630 --> 00:39:12,620 25 might correspond to the drug heparin being given and so on.
793 00:39:12,620 --> 00:39:15,580 So we're going to create one feature for every event which 794 00:39:15,580 --> 00:39:16,480 has some number-- 795 00:39:16,480 --> 00:39:18,310 which is encoded with some number. 796 00:39:18,310 --> 00:39:21,250 And we'll just say 1 if that event 797 00:39:21,250 --> 00:39:22,830 has occurred, 0 otherwise. 798 00:39:22,830 --> 00:39:26,960 So that's the representation for a patient. 799 00:39:26,960 --> 00:39:30,350 Now, because when one goes to this new institution, 800 00:39:30,350 --> 00:39:34,790 EHR2, the way that events are encoded 801 00:39:34,790 --> 00:39:36,560 might be completely different, 802 00:39:36,560 --> 00:39:38,810 one won't be able to just use the original feature 803 00:39:38,810 --> 00:39:40,090 representation. 804 00:39:40,090 --> 00:39:42,020 And that's the English-to-Spanish example 805 00:39:42,020 --> 00:39:43,448 that I gave. 806 00:39:43,448 --> 00:39:44,990 But instead, what one could try to do 807 00:39:44,990 --> 00:39:48,770 is come up with a new feature set where that feature 808 00:39:48,770 --> 00:39:53,250 set could be derived from each of the different data sets. 809 00:39:53,250 --> 00:39:57,680 So, for example, since each one of the events in MIMIC 810 00:39:57,680 --> 00:40:00,140 has some text description that goes 811 00:40:00,140 --> 00:40:03,350 with it, event 1 corresponds to ischemic stroke, 812 00:40:03,350 --> 00:40:06,920 event 2, hemorrhagic stroke, and so on, 813 00:40:06,920 --> 00:40:08,420 one could attempt to map-- 814 00:40:08,420 --> 00:40:12,560 use that English description of the feature 815 00:40:12,560 --> 00:40:15,350 to come up with a way to map it into a common language. 816 00:40:15,350 --> 00:40:17,120 In this case, the common language 817 00:40:17,120 --> 00:40:20,990 is the UMLS, the Unified Medical Language 818 00:40:20,990 --> 00:40:23,700 System that Pete talked about a few lectures ago. 819 00:40:23,700 --> 00:40:26,750 So we're going to now say, OK, we have a much larger feature 820 00:40:26,750 --> 00:40:31,580 set where we've now encoded ischemic stroke 821 00:40:31,580 --> 00:40:34,700 as this concept, which is actually 822 00:40:34,700 --> 00:40:36,920 the same ischemic stroke, but also 823 00:40:36,920 --> 00:40:40,460 as this concept and that concept, 824 00:40:40,460 --> 00:40:43,850 which are more general versions of that original one. 825 00:40:43,850 --> 00:40:46,310 So this is just general stroke, and it 826 00:40:46,310 --> 00:40:49,100 could be multiple different types of strokes. 827 00:40:49,100 --> 00:40:53,510 And the hope is that even if in-- 828 00:40:53,510 --> 00:40:55,970 even if the model doesn't-- 829 00:40:55,970 --> 00:40:57,680 even if some of these more specific ones 830 00:40:57,680 --> 00:40:59,900 don't show up in the new institution's data, 831 00:40:59,900 --> 00:41:04,432 perhaps some of the more general concepts do show up there. 832 00:41:04,432 --> 00:41:05,890 And then what you're going to do is 833 00:41:05,890 --> 00:41:11,570 you're going to learn your model now on this expanded translated 834 00:41:11,570 --> 00:41:14,923 vocabulary, and then transfer it. 835 00:41:14,923 --> 00:41:16,340 And at the new institution, you'll 836 00:41:16,340 --> 00:41:18,600 also be using that same common data model. 837 00:41:18,600 --> 00:41:20,810 And that way one hopes to have much more overlap 838 00:41:20,810 --> 00:41:23,540 in your feature set.
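As a concrete illustration of the bag-of-events representation and the UMLS translation just described, here is a minimal sketch in Python. The event descriptions, the description-to-concept mapping, and the concept identifiers below are made-up placeholders for illustration, not real UMLS lookups; the real mapping in the paper is derived deterministically from each event's text description.

import numpy as np

# Hypothetical mapping from an event's text description to UMLS concepts,
# including more general ancestor concepts; these identifiers are placeholders.
description_to_concepts = {
    "ischemic stroke":    {"C_ISCHEMIC_STROKE", "C_STROKE", "C_CEREBROVASCULAR_DISEASE"},
    "hemorrhagic stroke": {"C_HEMORRHAGIC_STROKE", "C_STROKE", "C_CEREBROVASCULAR_DISEASE"},
    "heparin given":      {"C_HEPARIN", "C_ANTICOAGULANT"},
}

# Shared concept vocabulary, fixed once and used at both institutions.
concept_vocab = sorted(set().union(*description_to_concepts.values()))
concept_index = {c: j for j, c in enumerate(concept_vocab)}

def featurize(event_descriptions):
    # Binary bag-of-events vector in the shared concept space:
    # 1 if any event mapping to that concept occurred for this patient, 0 otherwise.
    x = np.zeros(len(concept_vocab))
    for desc in event_descriptions:
        for concept in description_to_concepts.get(desc, ()):  # unmapped events are dropped
            x[concept_index[concept]] = 1.0
    return x

# A model trained on EHR1 patients featurized this way can be applied directly to
# EHR2 patients, as long as EHR2's event descriptions are mapped the same way.
x_patient = featurize(["ischemic stroke", "heparin given"])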
839 00:41:23,540 --> 00:41:27,770 And so to evaluate this, the authors 840 00:41:27,770 --> 00:41:32,210 looked at two different time points within MIMIC. 841 00:41:32,210 --> 00:41:36,350 One time point was when the Beth Israel Deaconess Medical 842 00:41:36,350 --> 00:41:39,320 Center was using an electronic health record called CareVue. 843 00:41:39,320 --> 00:41:41,655 And the second time point was when that hospital 844 00:41:41,655 --> 00:41:43,280 was using a different electronic health 845 00:41:43,280 --> 00:41:45,560 record called MetaVision. 846 00:41:45,560 --> 00:41:49,670 So this is an example actually of non-stationarity. 847 00:41:49,670 --> 00:41:54,200 Now, because of them using two different electronic health 848 00:41:54,200 --> 00:41:55,790 records, the encodings were different. 849 00:41:55,790 --> 00:41:58,733 And that's why this problem arose. 850 00:41:58,733 --> 00:42:00,400 And so we're going to use this approach, 851 00:42:00,400 --> 00:42:02,270 and we're going to then learn a linear model 852 00:42:02,270 --> 00:42:06,410 on top of this new encoding that I just described. 853 00:42:06,410 --> 00:42:12,140 And we're going to compare the results by looking at how much 854 00:42:12,140 --> 00:42:16,790 performance was lost due to using this new encoding, 855 00:42:16,790 --> 00:42:19,110 and how well we generalize from one-- 856 00:42:19,110 --> 00:42:26,111 from one-- from the source task to the target task. 857 00:42:26,111 --> 00:42:28,170 And so here's the first question, 858 00:42:28,170 --> 00:42:32,330 which is, how much do we lose by using this new encoding? 859 00:42:32,330 --> 00:42:34,220 So as a comparison point for looking 860 00:42:34,220 --> 00:42:36,973 at predicting in-hospital mortality, we'll look at, 861 00:42:36,973 --> 00:42:38,390 what is the predictive performance 862 00:42:38,390 --> 00:42:42,920 if you were to just use an existing, very simple risk 863 00:42:42,920 --> 00:42:45,050 score called the SAPS score? 864 00:42:45,050 --> 00:42:48,260 And that's this red line where that y-axis here 865 00:42:48,260 --> 00:42:51,500 is the area under the ROC curve, and the x-axis 866 00:42:51,500 --> 00:42:53,480 is how much time in advance you're 867 00:42:53,480 --> 00:42:56,090 predicting, so the prediction gap. 868 00:42:56,090 --> 00:43:01,520 So using this very simple score, SAPS gets somewhere between 0.75 869 00:43:01,520 --> 00:43:04,260 and 0.80, area under the ROC curve. 870 00:43:04,260 --> 00:43:08,900 But if you were to use all of the events data, which 871 00:43:08,900 --> 00:43:11,360 is much, much richer than what went into that simple SAPS 872 00:43:11,360 --> 00:43:16,310 score, you would get the purple curve, which is-- 873 00:43:16,310 --> 00:43:20,180 the purple curve, which is SAPS plus the event data, 874 00:43:20,180 --> 00:43:22,372 or the blue curve, which is just the events data. 875 00:43:22,372 --> 00:43:24,080 And you can see you can get substantially 876 00:43:24,080 --> 00:43:25,670 better predictive performance by using 877 00:43:25,670 --> 00:43:28,638 that much richer feature set. 878 00:43:28,638 --> 00:43:30,680 The SAPS score has the advantage that it's easier 879 00:43:30,680 --> 00:43:34,580 to generalize because it's so simple: those feature elements 880 00:43:34,580 --> 00:43:38,600 one could trivially translate to any new EHR, either manually 881 00:43:38,600 --> 00:43:43,220 or automatically, and thus it'll always be a viable route.
882 00:43:43,220 --> 00:43:45,230 Whereas this blue curve, although it 883 00:43:45,230 --> 00:43:46,780 gets better predictive performance, 884 00:43:46,780 --> 00:43:49,490 you have to really worry about these generalization questions. 885 00:43:52,689 --> 00:43:56,030 And the same story happens in both the source 886 00:43:56,030 --> 00:43:58,160 task and the target task. 887 00:43:58,160 --> 00:44:00,980 Now the second question to ask is, well, 888 00:44:00,980 --> 00:44:03,650 how much do you lose when you use the new representation 889 00:44:03,650 --> 00:44:05,520 of the data? 890 00:44:05,520 --> 00:44:09,290 And so here, looking at, again, both of the two-- 891 00:44:09,290 --> 00:44:13,790 both EHRs, what we see first in red 892 00:44:13,790 --> 00:44:17,733 is the same red curve-- is the same as the blue curve 893 00:44:17,733 --> 00:44:18,650 on the previous slide. 894 00:44:18,650 --> 00:44:23,550 It's using SAPS plus the item IDs, so using all of the data. 895 00:44:23,550 --> 00:44:26,270 And then the blue curve here, which is a bit hard to see, 896 00:44:26,270 --> 00:44:29,150 but it's right there, is substantially lower. 897 00:44:29,150 --> 00:44:31,360 So that's what happens if you now 898 00:44:31,360 --> 00:44:33,760 use this new representation. 899 00:44:33,760 --> 00:44:36,340 And you see that you do lose something 900 00:44:36,340 --> 00:44:39,940 by trying to find a common vocabulary. 901 00:44:39,940 --> 00:44:44,780 The performance does get hit a bit. 902 00:44:44,780 --> 00:44:46,570 But what's particularly interesting is 903 00:44:46,570 --> 00:44:52,450 when you attempt to generalize, you start to see a swap. 904 00:44:52,450 --> 00:44:54,970 So if we now-- 905 00:44:54,970 --> 00:44:59,860 so now the colors are going to be quite similar. 906 00:44:59,860 --> 00:45:02,930 Red here was at the very top before. 907 00:45:02,930 --> 00:45:07,480 So red is using the original representation of the data. 908 00:45:07,480 --> 00:45:10,600 Before, it was at the very top. 909 00:45:10,600 --> 00:45:15,640 Shown here is the training performance on this institution, CareVue. 910 00:45:15,640 --> 00:45:17,938 You see, there's so much rich information 911 00:45:17,938 --> 00:45:19,480 in the original feature set that it's 912 00:45:19,480 --> 00:45:21,313 able to get very good predictive performance. 913 00:45:21,313 --> 00:45:24,370 But once you attempt to translate it, 914 00:45:24,370 --> 00:45:28,270 so you train on CareVue, but you test on MetaVision, 915 00:45:28,270 --> 00:45:32,170 then the test performance shown here by this solid red line 916 00:45:32,170 --> 00:45:34,190 is actually the worst of all of the systems. 917 00:45:34,190 --> 00:45:36,550 So there's a substantial drop in performance 918 00:45:36,550 --> 00:45:39,070 because not all of these features 919 00:45:39,070 --> 00:45:41,230 are present in the new EHR. 920 00:45:41,230 --> 00:45:44,680 On the other hand, the translated version, 921 00:45:44,680 --> 00:45:49,540 despite the fact that it's a little bit worse when 922 00:45:49,540 --> 00:45:52,930 evaluated on the source, generalizes much better. 923 00:45:52,930 --> 00:45:56,170 And so you see a significantly better performance 924 00:45:56,170 --> 00:45:59,320 that's shown by this blue curve here when you 925 00:45:59,320 --> 00:46:01,048 use this translated vocabulary. 926 00:46:01,048 --> 00:46:01,840 There's a question. 927 00:46:01,840 --> 00:46:04,384 AUDIENCE: So would you train with full features?
928 00:46:04,384 --> 00:46:08,430 So how do you apply [? with ?] them if the other [? full ?] 929 00:46:08,430 --> 00:46:10,810 features are-- you just [INAUDIBLE].. 930 00:46:10,810 --> 00:46:14,860 DAVID SONTAG: So, you assume that you have come up 931 00:46:14,860 --> 00:46:19,480 with a mapping from the features in both of the EHRs 932 00:46:19,480 --> 00:46:23,995 to this common feature vocabulary of CUIs. 933 00:46:23,995 --> 00:46:26,620 And the way that this mapping is going to be done in this paper 934 00:46:26,620 --> 00:46:29,188 is based on the text of the-- 935 00:46:32,980 --> 00:46:34,480 of the events. 936 00:46:34,480 --> 00:46:36,700 So you take the text-based description of the event, 937 00:46:36,700 --> 00:46:38,533 and you come up with a deterministic mapping 938 00:46:38,533 --> 00:46:43,090 to this new UMLS-based representation. 939 00:46:43,090 --> 00:46:44,630 And then that's what's being used. 940 00:46:44,630 --> 00:46:46,075 There's no fine-tuning being done 941 00:46:46,075 --> 00:46:47,200 in this particular example. 942 00:46:51,110 --> 00:46:56,530 So I consider this to be a very naive application of transfer. 943 00:46:56,530 --> 00:46:59,770 The results are exactly what you would expect the results to be. 944 00:46:59,770 --> 00:47:03,820 And, obviously, a lot of work had to go into doing this. 945 00:47:03,820 --> 00:47:06,565 And there's a bit of creativity in thinking that you should use 946 00:47:06,565 --> 00:47:08,440 the English-based description of the features 947 00:47:08,440 --> 00:47:10,023 to come up with the automatic mapping, 948 00:47:10,023 --> 00:47:13,250 but the story ends there. 949 00:47:13,250 --> 00:47:16,480 And so a question which all of you 950 00:47:16,480 --> 00:47:18,520 might have is, how could you try to do 951 00:47:18,520 --> 00:47:20,500 such an approach automatically? 952 00:47:20,500 --> 00:47:23,020 How could we automatically find representations-- new 953 00:47:23,020 --> 00:47:24,520 representations of the data that are 954 00:47:24,520 --> 00:47:26,470 likely to generalize from, let's say, 955 00:47:26,470 --> 00:47:29,970 a source distribution to a target distribution? 956 00:47:29,970 --> 00:47:31,630 And so to talk about that, we're going 957 00:47:31,630 --> 00:47:34,270 to now start thinking through representation 958 00:47:34,270 --> 00:47:37,060 learning-based approaches, of which deep models are 959 00:47:37,060 --> 00:47:39,730 particularly capable. 960 00:47:39,730 --> 00:47:43,930 So the simplest approach to try to do transfer learning 961 00:47:43,930 --> 00:47:47,860 in the context of, let's say, deep neural networks, 962 00:47:47,860 --> 00:47:52,330 would be to just chop off part of the network and reuse that-- 963 00:47:52,330 --> 00:47:56,810 some internal representation of the data in this new location. 964 00:47:56,810 --> 00:47:59,990 So the picture looks a little bit like this. 965 00:47:59,990 --> 00:48:02,600 So the data might feed in at the bottom. 966 00:48:02,600 --> 00:48:04,600 There might be a number of convolutional layers, 967 00:48:04,600 --> 00:48:05,782 some fully connected layers.
968 00:48:05,782 --> 00:48:07,240 And what you decide to do is you're 969 00:48:07,240 --> 00:48:10,660 going to take this model that's trained in one institution, 970 00:48:10,660 --> 00:48:14,930 you chop it at some layer, it might be, for example, 971 00:48:14,930 --> 00:48:17,920 prior to the last fully connected layer, 972 00:48:17,920 --> 00:48:20,890 and then you're going to take that-- 973 00:48:20,890 --> 00:48:23,690 take the new representation of your data, 974 00:48:23,690 --> 00:48:25,210 now the representation of the data 975 00:48:25,210 --> 00:48:29,590 is what you would get out after doing some convolutions 976 00:48:29,590 --> 00:48:32,020 followed by a single fully connected layer, 977 00:48:32,020 --> 00:48:36,160 and then you're going to take your target distribution's 978 00:48:36,160 --> 00:48:38,740 data, which you might only have a small amount of, 979 00:48:38,740 --> 00:48:41,660 and you learn a simple model on top of that new representation. 980 00:48:41,660 --> 00:48:43,570 So, for example, you might learn a shallow classifier 981 00:48:43,570 --> 00:48:45,112 using a support vector machine on top 982 00:48:45,112 --> 00:48:46,270 of that new representation. 983 00:48:46,270 --> 00:48:50,160 Or you might add in some more-- 984 00:48:50,160 --> 00:48:52,660 a couple more layers of a deep neural network, and then fine 985 00:48:52,660 --> 00:48:54,220 tune the whole thing end to end. 986 00:48:54,220 --> 00:48:56,410 So all of these have been tried. 987 00:48:56,410 --> 00:49:00,050 And in some cases, one works better than another. 988 00:49:00,050 --> 00:49:05,590 And we saw already one example of this notion in this course. 989 00:49:05,590 --> 00:49:09,700 And that was when Adam Yala spoke in lecture 13 990 00:49:09,700 --> 00:49:14,440 about breast cancer and mammography, 991 00:49:14,440 --> 00:49:19,660 where in his approach he said that he had tried both 992 00:49:19,660 --> 00:49:25,030 taking a randomly initialized classifier 993 00:49:25,030 --> 00:49:28,030 and comparing that to what would happen if you initialized 994 00:49:28,030 --> 00:49:32,560 with a well-known ImageNet-based deep neural 995 00:49:32,560 --> 00:49:34,360 network for the problem. 996 00:49:34,360 --> 00:49:37,150 And he had a really interesting story that he gave. 997 00:49:37,150 --> 00:49:42,190 In his case, he had enough data that he actually 998 00:49:42,190 --> 00:49:47,080 didn't need to initialize using this pre-trained model 999 00:49:47,080 --> 00:49:48,250 from ImageNet. 1000 00:49:48,250 --> 00:49:52,060 If he had just done a random initialization, eventually-- 1001 00:49:52,060 --> 00:49:53,590 and this x-axis, I can't remember, 1002 00:49:53,590 --> 00:49:57,850 it might be hours of training or epochs, I don't remember, 1003 00:49:57,850 --> 00:49:58,600 it's time-- 1004 00:49:58,600 --> 00:50:00,160 eventually the right initialization 1005 00:50:00,160 --> 00:50:02,170 gets to a very similar performance. 1006 00:50:02,170 --> 00:50:04,540 But for his particular case, if you 1007 00:50:04,540 --> 00:50:08,740 were to do a initialization with ImageNet and then fine tune, 1008 00:50:08,740 --> 00:50:10,940 you get there much, much quicker. 1009 00:50:10,940 --> 00:50:13,090 And so it was for the computational reason 1010 00:50:13,090 --> 00:50:14,870 that he found it to be useful. 
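A minimal sketch of the chop-and-fine-tune recipe just described, assuming PyTorch and torchvision are available; the choice of ResNet-18, the two-class head, and the freezing schedule are illustrative assumptions, not the specifics of any of the models discussed in the lecture.

import torch.nn as nn
from torchvision import models

# Initialize from an ImageNet-pretrained network rather than random weights.
model = models.resnet18(pretrained=True)

# "Chop" the network: freeze the convolutional trunk so its filters are reused as-is...
for param in model.parameters():
    param.requires_grad = False

# ...and replace the final fully connected layer with a new head for the target task,
# which is the only part trained on the (possibly small) target data set.
model.fc = nn.Linear(model.fc.in_features, 2)  # e.g., a binary clinical outcome

# Optionally, once the new head has converged, unfreeze everything and fine-tune
# the whole network end to end with a small learning rate.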
1011 00:50:14,870 --> 00:50:17,290 But in many other applications in medical imaging, 1012 00:50:17,290 --> 00:50:19,660 the same tricks become essential because you just 1013 00:50:19,660 --> 00:50:22,020 don't have enough data in the new test case. 1014 00:50:22,020 --> 00:50:25,660 And so one makes use of, for example, the filters 1015 00:50:25,660 --> 00:50:29,170 which one learns from an ImageNet's task, which 1016 00:50:29,170 --> 00:50:33,010 is dramatically different from the medical imaging problems, 1017 00:50:33,010 --> 00:50:34,870 and then using those same filters together 1018 00:50:34,870 --> 00:50:37,330 with a new top layer, set of top layers 1019 00:50:37,330 --> 00:50:41,265 in order to fine tune it for the problem that you care about. 1020 00:50:41,265 --> 00:50:42,640 So this would be the simplest way 1021 00:50:42,640 --> 00:50:46,990 to try to hope for a common representation for transfer 1022 00:50:46,990 --> 00:50:49,240 in a deep architecture. 1023 00:50:49,240 --> 00:50:52,480 But you might ask, how would you do the same sort of thing 1024 00:50:52,480 --> 00:50:56,200 with temporal data, not image data, maybe data 1025 00:50:56,200 --> 00:51:00,010 that's from language, or data from time series of health 1026 00:51:00,010 --> 00:51:01,113 insurance claims? 1027 00:51:01,113 --> 00:51:02,530 And for that you really want to be 1028 00:51:02,530 --> 00:51:05,420 thinking about recurrent neural networks. 1029 00:51:05,420 --> 00:51:08,050 So just to remind you, recurrent neural network 1030 00:51:08,050 --> 00:51:10,030 is a recurrent architecture where 1031 00:51:10,030 --> 00:51:11,852 you take as input some vector. 1032 00:51:11,852 --> 00:51:13,810 For example, if you're doing language modeling, 1033 00:51:13,810 --> 00:51:16,393 that vector might be encoding, just a one-hot encoding of what 1034 00:51:16,393 --> 00:51:17,810 is the word at that location. 1035 00:51:17,810 --> 00:51:20,027 So, for example, this vector might be all zeros, 1036 00:51:20,027 --> 00:51:21,610 except for the fourth dimension, which 1037 00:51:21,610 --> 00:51:24,970 is a 1, denoting that this word is the word, quote, "class." 1038 00:51:24,970 --> 00:51:28,780 And then it's fed into a recurrent unit, which 1039 00:51:28,780 --> 00:51:32,117 takes the previous hidden state, combined it 1040 00:51:32,117 --> 00:51:34,450 with the current input, and gets you a new hidden state. 1041 00:51:34,450 --> 00:51:37,990 And in this way, you read in-- you encode the full input. 1042 00:51:37,990 --> 00:51:39,580 And then you might predict-- 1043 00:51:39,580 --> 00:51:41,680 make a classification based on the hidden state 1044 00:51:41,680 --> 00:51:43,055 of the last time [? step. ?] That 1045 00:51:43,055 --> 00:51:44,720 would be a common approach. 1046 00:51:44,720 --> 00:51:47,800 And here would be a very simple example of a recurrent unit. 1047 00:51:47,800 --> 00:51:49,750 Here I'm using S to denote in a state. 1048 00:51:49,750 --> 00:51:52,780 Often you will see H used to denote the hidden state. 1049 00:51:52,780 --> 00:51:54,400 This is a particularly simple example, 1050 00:51:54,400 --> 00:51:56,560 where there's just a single non-linearity. 1051 00:51:56,560 --> 00:51:58,630 So you take your previous hidden state, 1052 00:51:58,630 --> 00:52:06,130 you hit it with some matrix Ws,s and you add that to the input 1053 00:52:06,130 --> 00:52:08,770 being hit by a different matrix. 
1054 00:52:08,770 --> 00:52:11,500 You now have a combination of the input 1055 00:52:11,500 --> 00:52:12,833 plus the previous hidden state. 1056 00:52:12,833 --> 00:52:14,500 You apply a non-linearity to that, and you 1057 00:52:14,500 --> 00:52:15,750 get your new hidden state out. 1058 00:52:15,750 --> 00:52:18,940 So that would be an example of a typical recurrent unit, 1059 00:52:18,940 --> 00:52:20,950 a very simple recurrent unit. 1060 00:52:20,950 --> 00:52:23,200 Now, the reason why I'm going through these details is 1061 00:52:23,200 --> 00:52:28,570 to point out that the dimension of that Ws,x matrix is 1062 00:52:28,570 --> 00:52:32,164 the dimension of the hidden state, so the dimension of s, 1063 00:52:32,164 --> 00:52:36,430 by the vocabulary size if you're using a one-hot encoding 1064 00:52:36,430 --> 00:52:38,060 of the input. 1065 00:52:38,060 --> 00:52:42,850 So if you have a huge vocabulary, that matrix, Ws,x, 1066 00:52:42,850 --> 00:52:45,250 is also going to be equally large. 1067 00:52:45,250 --> 00:52:47,110 And the challenge that that presents 1068 00:52:47,110 --> 00:52:52,360 is that it would lead to overfitting on rare words 1069 00:52:52,360 --> 00:52:54,560 very quickly. 1070 00:52:54,560 --> 00:52:57,220 And so that's a problem that could be addressed by instead 1071 00:52:57,220 --> 00:53:03,400 using a low-rank representation of that Ws,x matrix. 1072 00:53:03,400 --> 00:53:06,880 In particular, you could think about introducing 1073 00:53:06,880 --> 00:53:11,620 a lower dimensional bottleneck, which in this picture 1074 00:53:11,620 --> 00:53:17,740 I'm denoting as xt prime, which is your original xt 1075 00:53:17,740 --> 00:53:19,930 input, which is the one-hot encoding, 1076 00:53:19,930 --> 00:53:21,970 multiplied by a new matrix We. 1077 00:53:24,550 --> 00:53:28,360 And then your recurrent unit only takes 1078 00:53:28,360 --> 00:53:30,220 inputs of that hidden-- 1079 00:53:30,220 --> 00:53:34,840 of that xt prime's dimension, which 1080 00:53:34,840 --> 00:53:39,340 is k, which might be dramatically smaller than v. 1081 00:53:39,340 --> 00:53:41,290 And you can even think about each column 1082 00:53:41,290 --> 00:53:44,590 of that intermediate representation, We, 1083 00:53:44,590 --> 00:53:46,600 as a word embedding. 1084 00:53:46,600 --> 00:53:49,180 It's a way of-- 1085 00:53:49,180 --> 00:53:51,363 and this is something that Pete talked quite a bit 1086 00:53:51,363 --> 00:53:53,530 about when we were thinking about natural language-- 1087 00:53:53,530 --> 00:53:56,110 when we were talking about natural language processing. 1088 00:53:56,110 --> 00:53:58,690 And many of you would have heard about it 1089 00:53:58,690 --> 00:54:02,470 in the context of things like Word2Vec. 1090 00:54:02,470 --> 00:54:08,650 So if one wanted to take a setting, for example, 1091 00:54:08,650 --> 00:54:14,335 one institution's data where you had a huge amount of data, 1092 00:54:14,335 --> 00:54:16,690 learn a recurrent neural network on that institution's 1093 00:54:16,690 --> 00:54:19,510 data, and then generalize it to a new institution, 1094 00:54:19,510 --> 00:54:22,630 one way of trying to do that, if you think about, 1095 00:54:22,630 --> 00:54:25,958 what is the thing that you chop, one answer might be, all you do 1096 00:54:25,958 --> 00:54:27,250 is you keep the word embedding. 1097 00:54:27,250 --> 00:54:28,875 So you might say, OK, I'm going to keep 1098 00:54:28,875 --> 00:54:32,980 the We's, I'm going to translate it back to my new institution.
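A minimal sketch of the simple recurrent unit and the low-rank input factorization just described, in PyTorch; the vocabulary, embedding, and hidden sizes are arbitrary illustrative numbers. The embedding matrix We is the piece one might keep when moving to a new institution, as discussed next.

import torch
import torch.nn as nn

V, k, h = 30000, 128, 256          # vocabulary size v, embedding size k, hidden size (illustrative)

W_e  = nn.Embedding(V, k)          # word embedding: the low-rank factor of W_{s,x}
W_sx = nn.Linear(k, h, bias=False) # input-to-hidden weights, now only h-by-k
W_ss = nn.Linear(h, h, bias=False) # hidden-to-hidden recurrent weights

def step(s_prev, x_t):
    # s_t = tanh(W_ss s_{t-1} + W_sx (W_e x_t)), where x_t is a word index
    # standing in for the one-hot input vector.
    return torch.tanh(W_ss(s_prev) + W_sx(W_e(x_t)))

# Transfer idea: keep W_e when moving institutions, and let W_ss and W_sx be relearned there.
s = torch.zeros(1, h)
for word_idx in [torch.tensor([4]), torch.tensor([17]), torch.tensor([256])]:
    s = step(s, word_idx)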
1099 00:54:32,980 --> 00:54:35,830 But I'm going to let the recurrent unit parameters-- 1100 00:54:35,830 --> 00:54:37,490 the recurrent parameters, for example, 1101 00:54:37,490 --> 00:54:41,380 that Ws,s you might allow it to be relearned for each new 1102 00:54:41,380 --> 00:54:43,035 institution. 1103 00:54:43,035 --> 00:54:44,410 And so that might be one approach 1104 00:54:44,410 --> 00:54:46,720 of how to use the same idea that we 1105 00:54:46,720 --> 00:54:53,530 had from feed forward networks within a recurrent setting. 1106 00:54:53,530 --> 00:54:57,110 Now, all of this is very general. 1107 00:54:57,110 --> 00:54:59,890 And what I want to do next is to instantiate it 1108 00:54:59,890 --> 00:55:05,080 a bit in the context of health care. 1109 00:55:05,080 --> 00:55:09,610 So since the time that Pete presented 1110 00:55:09,610 --> 00:55:15,190 the extensions of Word2Vec such as BERT and ELMo, 1111 00:55:15,190 --> 00:55:16,120 and I'm not going to-- 1112 00:55:16,120 --> 00:55:17,537 I'm not going to go into them now, 1113 00:55:17,537 --> 00:55:20,290 but you can go back to Pete's lecture from a few weeks 1114 00:55:20,290 --> 00:55:22,975 ago to remind yourselves what those were, since the time 1115 00:55:22,975 --> 00:55:24,850 he presented that lecture, there are actually 1116 00:55:24,850 --> 00:55:26,650 three new papers that actually tried 1117 00:55:26,650 --> 00:55:30,370 to apply this in the health care context, one of which 1118 00:55:30,370 --> 00:55:32,680 was from MIT. 1119 00:55:32,680 --> 00:55:36,490 And so these papers all have the same sort of idea. 1120 00:55:36,490 --> 00:55:39,640 They're going to take some data set-- 1121 00:55:39,640 --> 00:55:43,480 and these papers all use MIMIC. 1122 00:55:43,480 --> 00:55:45,370 They're going to take that text data, 1123 00:55:45,370 --> 00:55:48,850 they're going to learn some word embeddings 1124 00:55:48,850 --> 00:55:50,500 or some low-dimensional representations 1125 00:55:50,500 --> 00:55:52,300 of all words in the vocabulary. 1126 00:55:52,300 --> 00:55:54,460 In this case, they're not learning 1127 00:55:54,460 --> 00:55:56,290 a static representation for each word. 1128 00:55:56,290 --> 00:55:59,140 Instead these BERT and ELMo approaches 1129 00:55:59,140 --> 00:56:00,640 are going to be learning-- well, you 1130 00:56:00,640 --> 00:56:02,330 can think of it as dynamic representations. 1131 00:56:02,330 --> 00:56:04,080 They're going to be a function of the word 1132 00:56:04,080 --> 00:56:06,713 and their context on the left and right-hand sides. 1133 00:56:06,713 --> 00:56:08,380 And then what they'll do is they'll then 1134 00:56:08,380 --> 00:56:10,930 take those representations and attempt to use them 1135 00:56:10,930 --> 00:56:13,120 for a completely new task. 1136 00:56:13,120 --> 00:56:17,210 Those new tasks might be on MIMIC data. 1137 00:56:17,210 --> 00:56:20,410 So, for example, these two tasks are classification problems 1138 00:56:20,410 --> 00:56:21,310 on MIMIC. 1139 00:56:21,310 --> 00:56:23,210 But they might also be on non-MIMIC data. 1140 00:56:23,210 --> 00:56:27,490 So these two tasks are from classification problems 1141 00:56:27,490 --> 00:56:30,830 on clinical text that didn't even come from MIMIC at all. 1142 00:56:30,830 --> 00:56:32,578 So it's really an example of translating 1143 00:56:32,578 --> 00:56:34,120 what you learned from one institution 1144 00:56:34,120 --> 00:56:35,650 to another institution. 1145 00:56:35,650 --> 00:56:37,450 These two data sets were super small. 
1146 00:56:37,450 --> 00:56:40,475 Actually, all of these data sets were really, really small 1147 00:56:40,475 --> 00:56:42,100 compared to the original size of MIMIC. 1148 00:56:42,100 --> 00:56:44,725 So there might be some hope that one could learn something that 1149 00:56:44,725 --> 00:56:46,660 really improves generalization. 1150 00:56:46,660 --> 00:56:48,890 And indeed, that's what plays out. 1151 00:56:48,890 --> 00:56:53,450 So all these tasks are looking at a concept detection task. 1152 00:56:53,450 --> 00:56:59,240 Given a clinical note, identify the segments of text 1153 00:56:59,240 --> 00:57:01,280 within a note that refer to, for example, 1154 00:57:01,280 --> 00:57:04,280 a disorder, or a treatment, or something else, which 1155 00:57:04,280 --> 00:57:08,030 you then in a second stage might normalize to the UMLS. 1156 00:57:10,790 --> 00:57:13,940 So what's really striking about these results 1157 00:57:13,940 --> 00:57:18,590 is what happens when you go from the left to the right 1158 00:57:18,590 --> 00:57:20,280 column, which I'll explain in a second, 1159 00:57:20,280 --> 00:57:22,520 and what happens when you go top to bottom 1160 00:57:22,520 --> 00:57:24,810 across each one of these different tasks. 1161 00:57:24,810 --> 00:57:27,630 So the left column are the results. 1162 00:57:27,630 --> 00:57:33,230 And these results are an F score 1163 00:57:33,230 --> 00:57:39,365 if you were to use embeddings trained on a non-clinical data 1164 00:57:39,365 --> 00:57:42,260 set, or said differently, not on MIMIC but on some other more 1165 00:57:42,260 --> 00:57:44,427 general data set. 1166 00:57:44,427 --> 00:57:46,010 The second column is what would happen 1167 00:57:46,010 --> 00:57:49,730 if you trained those embeddings on a clinical data set, 1168 00:57:49,730 --> 00:57:51,440 in this case, MIMIC. 1169 00:57:51,440 --> 00:57:54,230 And you see pretty big improvements 1170 00:57:54,230 --> 00:57:58,550 from the general embeddings to the MIMIC-based embeddings. 1171 00:57:58,550 --> 00:58:01,040 What's even more striking is the improvements 1172 00:58:01,040 --> 00:58:04,190 that happen as you get better and better embeddings. 1173 00:58:04,190 --> 00:58:07,040 So the first row are the results if you 1174 00:58:07,040 --> 00:58:09,380 were to use just Word2Vec embeddings. 1175 00:58:09,380 --> 00:58:15,470 And so, for example, for the I2B2 Challenge in 2010, 1176 00:58:15,470 --> 00:58:20,720 you get an 82.65 F score using Word2Vec embeddings. 1177 00:58:20,720 --> 00:58:23,300 And if you use a very large BERT embedding, 1178 00:58:23,300 --> 00:58:28,010 you get a 90.25 F score-- 1179 00:58:28,010 --> 00:58:31,850 F measure, which is substantially higher. 1180 00:58:31,850 --> 00:58:34,200 And the same findings were found time and time again 1181 00:58:34,200 --> 00:58:36,810 across different tasks. 1182 00:58:36,810 --> 00:58:39,740 Now, what I find really striking about these results 1183 00:58:39,740 --> 00:58:43,160 is that I had tried many of these things a couple of years 1184 00:58:43,160 --> 00:58:46,040 ago, not using BERT or ELMo, but using Word2Vec, 1185 00:58:46,040 --> 00:58:48,320 and GloVe, and fastText.
1186 00:58:48,320 --> 00:58:52,550 And what I found is that using word embedding approaches 1187 00:58:52,550 --> 00:58:54,530 for these problems didn't-- 1188 00:58:54,530 --> 00:58:57,440 even if you threw that in as additional features on top 1189 00:58:57,440 --> 00:59:03,110 of other state-of-the-art approaches to this concept 1190 00:59:03,110 --> 00:59:06,470 extraction problem, it did not improve predictive performance 1191 00:59:06,470 --> 00:59:09,050 above the existing state of the art. 1192 00:59:09,050 --> 00:59:11,240 However, in this paper, they 1193 00:59:11,240 --> 00:59:13,820 use the simplest possible algorithm. 1194 00:59:13,820 --> 00:59:15,710 They used a recurrent neural network 1195 00:59:15,710 --> 00:59:18,080 fed into a conditional random field 1196 00:59:18,080 --> 00:59:21,680 for the purpose of classifying each word into each 1197 00:59:21,680 --> 00:59:23,240 of these categories. 1198 00:59:23,240 --> 00:59:25,160 And the feature represent-- the features 1199 00:59:25,160 --> 00:59:28,320 that they used are just these embedding features. 1200 00:59:28,320 --> 00:59:30,370 So with just the Word2Vec embedding features, 1201 00:59:30,370 --> 00:59:31,370 the performance is crap. 1202 00:59:31,370 --> 00:59:33,740 You don't get anywhere close to the state of the art. 1203 00:59:33,740 --> 00:59:37,070 But with the better embeddings, they actually obtain-- 1204 00:59:37,070 --> 00:59:39,440 actually, they improved on the state of the art 1205 00:59:39,440 --> 00:59:43,560 for every single one of these tasks. 1206 00:59:43,560 --> 00:59:46,010 And that is without any of the manual feature 1207 00:59:46,010 --> 00:59:48,110 engineering which we have been using 1208 00:59:48,110 --> 00:59:50,670 in the field for the last decade. 1209 00:59:50,670 --> 00:59:54,620 So I find this to be extremely promising. 1210 00:59:54,620 --> 00:59:59,090 Now you might ask, well, that is for one problem, which 1211 00:59:59,090 --> 01:00:04,700 is classification of concepts-- or identification of concepts. 1212 01:00:04,700 --> 01:00:06,900 What about for a predictive problem? 1213 01:00:06,900 --> 01:00:09,500 So a different paper also published-- 1214 01:00:09,500 --> 01:00:13,670 what month is it now, May-- so last month in April, 1215 01:00:13,670 --> 01:00:16,700 looked at a predictive problem of 30-day readmission prediction 1216 01:00:16,700 --> 01:00:18,100 using discharge summaries. 1217 01:00:18,100 --> 01:00:20,600 This also was evaluated on MIMIC. 1218 01:00:20,600 --> 01:00:23,960 And their evaluation looked at the area 1219 01:00:23,960 --> 01:00:26,610 under the ROC curve of two different approaches. 1220 01:00:26,610 --> 01:00:29,270 The first approach, which is using a bag-of-words model, 1221 01:00:29,270 --> 01:00:32,090 like what you did in your homework assignment, 1222 01:00:32,090 --> 01:00:35,360 and the second approach, which is the top row there, 1223 01:00:35,360 --> 01:00:40,640 which is using BERT embeddings, which they call Clinical BERT. 1224 01:00:40,640 --> 01:00:43,340 And this, again, is something which I had 1225 01:00:43,340 --> 01:00:44,760 tackled for quite a long time. 1226 01:00:44,760 --> 01:00:46,970 So I worked on these types of readmission problems. 1227 01:00:46,970 --> 01:00:48,887 And a bag-of-words model is really hard to beat. 1228 01:00:48,887 --> 01:00:54,090 In fact, did any of you beat it in your homework assignment?
1229 01:00:54,090 --> 01:00:56,070 If you remember, there was an extra question, 1230 01:00:56,070 --> 01:00:57,860 which is, oh, well, maybe if we used 1231 01:00:57,860 --> 01:00:59,870 a deep learning-based approach for this problem, 1232 01:00:59,870 --> 01:01:01,495 maybe you could get better performance. 1233 01:01:01,495 --> 01:01:03,870 Did anyone get better performance? 1234 01:01:03,870 --> 01:01:04,723 No. 1235 01:01:04,723 --> 01:01:06,140 How many of you actually tried it? 1236 01:01:06,140 --> 01:01:08,790 Raise your hand. 1237 01:01:08,790 --> 01:01:11,280 OK, so one-- a couple of people who are afraid to 1238 01:01:11,280 --> 01:01:11,970 say, but yeah. 1239 01:01:11,970 --> 01:01:13,887 So a couple of people who tried, but not many. 1240 01:01:16,020 --> 01:01:19,387 But I think the reason why it's very challenging to do better 1241 01:01:19,387 --> 01:01:21,470 with, let's say, a recurrent neural network versus 1242 01:01:21,470 --> 01:01:25,500 a bag-of-words model is because there is-- 1243 01:01:25,500 --> 01:01:30,600 a lot of the subtlety in understanding the text 1244 01:01:30,600 --> 01:01:32,845 is in terms of understanding the context of the text. 1245 01:01:32,845 --> 01:01:35,220 And that's something that using these newer embeddings is 1246 01:01:35,220 --> 01:01:37,137 actually really good at because they can get-- 1247 01:01:37,137 --> 01:01:39,410 they could use the context of words 1248 01:01:39,410 --> 01:01:42,910 to better represent what each word actually means. 1249 01:01:42,910 --> 01:01:44,560 And they see substantial improvement 1250 01:01:44,560 --> 01:01:47,250 in performance using this approach. 1251 01:01:47,250 --> 01:01:48,790 What about for non-text data? 1252 01:01:48,790 --> 01:01:54,220 So you might ask when we have health insurance claims, 1253 01:01:54,220 --> 01:01:56,373 we have longitudinal data across time. 1254 01:01:56,373 --> 01:01:57,540 There's no language in this. 1255 01:01:57,540 --> 01:01:58,860 It's a time series data set. 1256 01:01:58,860 --> 01:02:02,050 You have ICD-9 codes at each point in time, 1257 01:02:02,050 --> 01:02:04,300 you have maybe lab test results, medication records. 1258 01:02:04,300 --> 01:02:06,300 And this is very similar to the market scan data 1259 01:02:06,300 --> 01:02:08,580 that you used in your homework assignment. 1260 01:02:08,580 --> 01:02:12,270 Could one learn embeddings for this type of data, which 1261 01:02:12,270 --> 01:02:15,220 is also useful for transfer? 1262 01:02:15,220 --> 01:02:20,760 So one goal might be to say, OK, let's take every ICD-9, ICD-10 1263 01:02:20,760 --> 01:02:23,890 code, every medication, every laboratory test result, 1264 01:02:23,890 --> 01:02:28,562 and embed those event types into some lower dimensional space. 1265 01:02:28,562 --> 01:02:30,270 And so here's an example of an embedding. 1266 01:02:30,270 --> 01:02:32,740 And you see how-- this is just a sketch, by the way-- 1267 01:02:32,740 --> 01:02:35,400 you see how you might hope that diagnosis 1268 01:02:35,400 --> 01:02:37,110 codes for autoimmune conditions might 1269 01:02:37,110 --> 01:02:39,420 be all near each other in some lower dimensional 1270 01:02:39,420 --> 01:02:42,660 space, diagnosis codes for medications 1271 01:02:42,660 --> 01:02:45,000 that treat some conditions should be near each other, 1272 01:02:45,000 --> 01:02:45,500 and so on. 
1273 01:02:45,500 --> 01:02:47,970 So you might hope that such structure might be discovered 1274 01:02:47,970 --> 01:02:50,730 by an unsupervised learning algorithm that could then 1275 01:02:50,730 --> 01:02:53,585 be used within a transfer learning approach. 1276 01:02:53,585 --> 01:02:54,960 And indeed, that's what we found. 1277 01:02:54,960 --> 01:03:00,270 So I wrote a paper on this in 2015/16. 1278 01:03:00,270 --> 01:03:05,888 And here's one of the results from that paper. 1279 01:03:05,888 --> 01:03:07,680 So this is just a look at nearest neighbors 1280 01:03:07,680 --> 01:03:12,870 to give you some sense of whether the embedding's 1281 01:03:12,870 --> 01:03:14,795 actually capturing the structure of the data. 1282 01:03:14,795 --> 01:03:16,170 So we looked at nearest neighbors 1283 01:03:16,170 --> 01:03:22,930 of the ICD-9 diagnosis code 710.0, which is lupus. 1284 01:03:22,930 --> 01:03:25,590 And what you find is that another diagnosis code, also 1285 01:03:25,590 --> 01:03:28,890 for lupus, is the first closest result, followed 1286 01:03:28,890 --> 01:03:31,930 by connective tissue disorder, Sicca syndrome, 1287 01:03:31,930 --> 01:03:34,587 which is Sjogren's disease, Raynaud's syndrome, 1288 01:03:34,587 --> 01:03:35,920 and other autoimmune conditions. 1289 01:03:35,920 --> 01:03:37,380 So that makes a lot of sense. 1290 01:03:37,380 --> 01:03:39,210 You can also go across data types, 1291 01:03:39,210 --> 01:03:42,390 like ask, what is the nearest neighbor from this diagnosis 1292 01:03:42,390 --> 01:03:44,790 code to laboratory tests? 1293 01:03:44,790 --> 01:03:47,055 And since we've embedded lab tests and diagnosis 1294 01:03:47,055 --> 01:03:48,930 codes all in the same space, you can actually 1295 01:03:48,930 --> 01:03:49,860 get an answer to that. 1296 01:03:49,860 --> 01:03:52,465 And what you see is that these lab tests, which by the way 1297 01:03:52,465 --> 01:03:54,090 are exactly the lab tests that are commonly 1298 01:03:54,090 --> 01:04:00,090 used to understand progression in this autoimmune condition, 1299 01:04:00,090 --> 01:04:01,890 are the closest neighbors. 1300 01:04:01,890 --> 01:04:07,110 Similarly, you can ask the same question about drugs and so on. 1301 01:04:07,110 --> 01:04:11,730 And by the way, we have made all of these embeddings publicly 1302 01:04:11,730 --> 01:04:14,970 available on my lab's GitHub. 1303 01:04:14,970 --> 01:04:17,040 And since the time that I wrote this paper, 1304 01:04:17,040 --> 01:04:18,930 there have been a number of other papers, 1305 01:04:18,930 --> 01:04:21,150 that I give citations to at the bottom here, 1306 01:04:21,150 --> 01:04:22,890 tackling a very similar problem. 1307 01:04:22,890 --> 01:04:27,780 This last one also made their embeddings publicly available, 1308 01:04:27,780 --> 01:04:31,890 and is much larger than the one that we had. So these things, 1309 01:04:31,890 --> 01:04:34,020 I think, would also be very useful as one starts 1310 01:04:34,020 --> 01:04:37,860 to think about how one can transfer knowledge learned 1311 01:04:37,860 --> 01:04:40,062 at one institution to another institution 1312 01:04:40,062 --> 01:04:41,520 where you might have much less data 1313 01:04:41,520 --> 01:04:42,831 than that other institution.
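A minimal sketch of learning such embeddings, assuming gensim (4.x parameter names) is available: each patient's timeline of codes is treated as a "sentence" and fed to a skip-gram model, and nearest neighbors are then queried the way the lupus example above was. The code strings and toy sequences are fabricated for illustration, and with only a few toy patients the neighbors are of course not meaningful.

from gensim.models import Word2Vec

# Each patient's chronological sequence of event codes (diagnoses, labs, drugs)
# is treated like a sentence of words; these sequences are made up.
patient_code_sequences = [
    ["ICD9_710.0", "LAB_ANA_TITER", "RX_HYDROXYCHLOROQUINE", "ICD9_710.2"],
    ["ICD9_250.00", "LAB_HBA1C", "RX_METFORMIN"],
    ["ICD9_710.0", "LAB_ANTI_DSDNA", "ICD9_443.0"],
]

# Skip-gram embeddings over co-occurring codes, analogous to Word2Vec on text.
model = Word2Vec(patient_code_sequences, vector_size=100, window=5,
                 min_count=1, sg=1, epochs=50)

# Nearest neighbors of the lupus code, across all data types in the same space.
print(model.wv.most_similar("ICD9_710.0"))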
1314 01:04:45,480 --> 01:04:48,810 So finally I want to return back to the question 1315 01:04:48,810 --> 01:04:50,820 that I raised in bullet two here, 1316 01:04:50,820 --> 01:04:52,290 where we looked at a linear model 1317 01:04:52,290 --> 01:04:54,180 with a manually chosen representation, 1318 01:04:54,180 --> 01:04:57,300 and ask, could we-- 1319 01:04:57,300 --> 01:05:01,770 instead of just naively chopping your deep neural network 1320 01:05:01,770 --> 01:05:04,680 at some layer and then fine tuning, 1321 01:05:04,680 --> 01:05:07,320 could one have learned a representation of your data 1322 01:05:07,320 --> 01:05:11,580 specifically for the purpose of encouraging good generalization 1323 01:05:11,580 --> 01:05:13,680 to a new institution? 1324 01:05:13,680 --> 01:05:17,220 And there has been some really exciting work 1325 01:05:17,220 --> 01:05:23,030 in this field that goes by the name of Unsupervised Domain 1326 01:05:23,030 --> 01:05:23,690 Adaptation. 1327 01:05:34,030 --> 01:05:36,510 So the setting that's considered here 1328 01:05:36,510 --> 01:05:38,490 is where you have data from-- you 1329 01:05:38,490 --> 01:05:42,190 have data from some first institution, which is x 1330 01:05:42,190 --> 01:05:44,850 comma y. 1331 01:05:44,850 --> 01:05:48,810 But then you want to do prediction 1332 01:05:48,810 --> 01:05:51,390 from a new institution where all you have access 1333 01:05:51,390 --> 01:05:55,260 to at training time is x. 1334 01:05:55,260 --> 01:05:57,510 So as opposed to the transfer settings 1335 01:05:57,510 --> 01:06:00,300 that I talked about earlier, now for this new institution, 1336 01:06:00,300 --> 01:06:03,002 you might have a ton of unlabeled data. 1337 01:06:03,002 --> 01:06:04,710 Whereas before I was talking about having 1338 01:06:04,710 --> 01:06:06,198 just a small amount of labeled data, 1339 01:06:06,198 --> 01:06:07,740 but I never talked about the possibility 1340 01:06:07,740 --> 01:06:09,950 of having a large amount of unlabeled data. 1341 01:06:09,950 --> 01:06:11,460 And so you might ask, how could you 1342 01:06:11,460 --> 01:06:13,890 use that large amount of unlabeled data 1343 01:06:13,890 --> 01:06:16,410 from that second institution in order 1344 01:06:16,410 --> 01:06:19,110 to learn a representation that actually encourages 1345 01:06:19,110 --> 01:06:21,955 similarity from one institution to the other? 1346 01:06:21,955 --> 01:06:24,330 And that's exactly what these domain adversarial training 1347 01:06:24,330 --> 01:06:25,890 approaches will do. 1348 01:06:25,890 --> 01:06:28,170 What they do is they add a second term 1349 01:06:28,170 --> 01:06:29,740 to the loss function. 1350 01:06:29,740 --> 01:06:31,590 So they're going to minimize-- 1351 01:06:31,590 --> 01:06:34,055 the intuition is you're going to minimize-- 1352 01:06:34,055 --> 01:06:35,680 you're going to try to learn parameters 1353 01:06:35,680 --> 01:06:45,150 that minimize your loss function evaluated on data set 1. 1354 01:06:45,150 --> 01:06:49,530 But intuitively, you're going to ask that there also 1355 01:06:49,530 --> 01:06:55,350 be a small distance, which I'll just note as d here, 1356 01:06:55,350 --> 01:07:00,105 between D1 and D2. 1357 01:07:00,105 --> 01:07:02,850 And so I'm being a little bit loose with notation here, 1358 01:07:02,850 --> 01:07:04,840 but when I calculate distance here, 1359 01:07:04,840 --> 01:07:07,467 I'm referring to distance in representation space.
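A minimal sketch of this two-term objective in PyTorch, in the style of domain-adversarial training with the gradient reversal trick described just below: the supervised loss is computed on institution 1's labeled data, while a domain classifier on the shared representation stands in for the distance d between D1 and D2, and the reversed gradient pushes the encoder to make the two institutions indistinguishable. The network sizes, the input dimension of 100, and the weight lam are illustrative assumptions.

import torch
import torch.nn as nn

class GradReverse(torch.autograd.Function):
    # Identity on the forward pass; flips the sign of the gradient on the backward
    # pass, so the encoder is trained to fool the domain classifier.
    @staticmethod
    def forward(ctx, x):
        return x
    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

encoder    = nn.Sequential(nn.Linear(100, 64), nn.ReLU())  # shared representation layer
classifier = nn.Linear(64, 2)   # label head, trained only on institution 1's labels
domain_clf = nn.Linear(64, 2)   # predicts which institution an example came from

def objective(x1, y1, x2, lam=1.0):
    ce = nn.CrossEntropyLoss()
    z1, z2 = encoder(x1), encoder(x2)
    task_loss = ce(classifier(z1), y1)                        # loss on data set 1
    z_all = torch.cat([GradReverse.apply(z1), GradReverse.apply(z2)])
    domains = torch.cat([torch.zeros(len(x1)), torch.ones(len(x2))]).long()
    domain_loss = ce(domain_clf(z_all), domains)              # surrogate "distance" d(D1, D2)
    return task_loss + lam * domain_loss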
1360 01:07:07,467 --> 01:07:09,300 So you might imagine taking the middle layer 1361 01:07:09,300 --> 01:07:10,740 of your deep neural network, so taking, 1362 01:07:10,740 --> 01:07:12,690 let's say, this layer, which we're going to call the feature 1363 01:07:12,690 --> 01:07:15,240 layer, or the representation layer, and you're going to say, 1364 01:07:15,240 --> 01:07:20,700 I want that my data under the first institution 1365 01:07:20,700 --> 01:07:22,740 should look very similar to the data 1366 01:07:22,740 --> 01:07:23,950 under the second institution. 1367 01:07:23,950 --> 01:07:26,490 So the first few layers of your deep neural network 1368 01:07:26,490 --> 01:07:29,310 are going to attempt to equalize the two data sets 1369 01:07:29,310 --> 01:07:34,140 so that they look similar to another, at least in x space. 1370 01:07:34,140 --> 01:07:36,480 And we're going to attempt to find representations 1371 01:07:36,480 --> 01:07:38,730 of your model that get good predictive performance 1372 01:07:38,730 --> 01:07:41,340 on the data set for which you actually have the labels 1373 01:07:41,340 --> 01:07:44,010 and for which the induced representations, let's 1374 01:07:44,010 --> 01:07:47,670 say, the middle layer look very similar across the two data 1375 01:07:47,670 --> 01:07:48,420 sets. 1376 01:07:48,420 --> 01:07:51,480 And one way to do that is just to try to predict for each-- 1377 01:07:51,480 --> 01:07:52,800 you now get a-- 1378 01:07:52,800 --> 01:07:54,570 for each data point, you might actually 1379 01:07:54,570 --> 01:07:56,760 say, well, which data set did it come from, 1380 01:07:56,760 --> 01:07:58,650 data set 1 or data set 2? 1381 01:07:58,650 --> 01:08:00,900 And what you want is that your model should not 1382 01:08:00,900 --> 01:08:03,100 be able to distinguish which data set it came from. 1383 01:08:03,100 --> 01:08:05,645 So that's what it says, gradient reverse layer 1384 01:08:05,645 --> 01:08:07,020 you want to be able to-- you want 1385 01:08:07,020 --> 01:08:09,943 to ensure that predicting which data set that data came from, 1386 01:08:09,943 --> 01:08:11,985 you want to perform badly on that loss functions. 1387 01:08:11,985 --> 01:08:14,939 It's like taking the minus of that loss. 1388 01:08:14,939 --> 01:08:17,189 And so we're not going to go into the details of that, 1389 01:08:17,189 --> 01:08:18,897 but I just wanted to give you a reference 1390 01:08:18,897 --> 01:08:20,472 to that approach in the bottom. 1391 01:08:20,472 --> 01:08:21,930 And what I want to do is just spend 1392 01:08:21,930 --> 01:08:24,720 one minute at the very end talking now about defenses 1393 01:08:24,720 --> 01:08:26,189 to adversarial attacks. 1394 01:08:26,189 --> 01:08:27,750 And conceptually this is very simple. 1395 01:08:27,750 --> 01:08:31,569 And that's why I can actually do it in one minute. 1396 01:08:31,569 --> 01:08:36,720 So we talked about how one could easily modify an image in order 1397 01:08:36,720 --> 01:08:40,566 to turn the prediction from, let's say, pig to airliner. 1398 01:08:40,566 --> 01:08:42,899 But how could we change your learning algorithm actually 1399 01:08:42,899 --> 01:08:44,550 to make sure that, despite the fact that you 1400 01:08:44,550 --> 01:08:47,130 do this perturbation, you still get the right prediction out, 1401 01:08:47,130 --> 01:08:48,359 pig? 1402 01:08:48,359 --> 01:08:50,609 Well, to think through that, we have to think through, 1403 01:08:50,609 --> 01:08:51,960 how do we do machine learning? 
1404 01:08:51,960 --> 01:08:53,793 Well, a typical approach to machine learning 1405 01:08:53,793 --> 01:08:57,450 is to learn some parameters theta that minimize 1406 01:08:57,450 --> 01:08:59,670 your empirical loss. 1407 01:08:59,670 --> 01:09:01,140 Often we use deep neural networks, 1408 01:09:01,140 --> 01:09:02,390 which look a little like this. 1409 01:09:02,390 --> 01:09:04,290 And we do gradient descent where we attempt 1410 01:09:04,290 --> 01:09:08,670 to minimize some loss surface, finding some parameters theta that have 1411 01:09:08,670 --> 01:09:11,890 as low loss as possible. 1412 01:09:11,890 --> 01:09:14,410 Now, when you think about an adversarial example 1413 01:09:14,410 --> 01:09:16,950 and where they come from, typically one 1414 01:09:16,950 --> 01:09:19,745 finds an adversarial example in the following way. 1415 01:09:19,745 --> 01:09:21,120 You take your same loss function, 1416 01:09:21,120 --> 01:09:23,939 now for a specific input x, and you 1417 01:09:23,939 --> 01:09:26,260 try to find some perturbation delta 1418 01:09:26,260 --> 01:09:29,880 to x, an additive perturbation, for example, such 1419 01:09:29,880 --> 01:09:34,109 that you increase the loss as much as possible with respect 1420 01:09:34,109 --> 01:09:36,540 to the correct label y. 1421 01:09:36,540 --> 01:09:38,790 And so if you've increased the loss with respect 1422 01:09:38,790 --> 01:09:40,672 to the correct label y, intuitively 1423 01:09:40,672 --> 01:09:42,630 then when you try to see, well, what should you 1424 01:09:42,630 --> 01:09:44,850 predict for this new perturbed input, 1425 01:09:44,850 --> 01:09:47,646 there's going to be a lower loss for some alternative label, 1426 01:09:47,646 --> 01:09:49,979 which is why the prediction-- the class that's predicted 1427 01:09:49,979 --> 01:09:51,540 actually changes. 1428 01:09:51,540 --> 01:09:54,690 So now one can try to find these adversarial examples using 1429 01:09:54,690 --> 01:09:57,090 the same type of gradient-based learning algorithms 1430 01:09:57,090 --> 01:10:01,040 that one uses for learning in the first place. 1431 01:10:01,040 --> 01:10:03,750 But what one can do is you can use a gradient descent 1432 01:10:03,750 --> 01:10:05,190 method now-- 1433 01:10:05,190 --> 01:10:07,980 instead of gradient descent, gradient ascent. 1434 01:10:07,980 --> 01:10:11,910 So you take this optimization problem for a given input x, 1435 01:10:11,910 --> 01:10:14,227 and you try to maximize that loss for that input x 1436 01:10:14,227 --> 01:10:15,810 with this vector delta, and you're now 1437 01:10:15,810 --> 01:10:18,300 doing gradient ascent. 1438 01:10:18,300 --> 01:10:20,280 And so what types of delta should you consider? 1439 01:10:20,280 --> 01:10:22,260 You can imagine small perturbations, 1440 01:10:22,260 --> 01:10:25,890 for example, delta that have very small maximum values. 1441 01:10:25,890 --> 01:10:28,270 That would be an example of an L-infinity norm. 1442 01:10:28,270 --> 01:10:30,400 Or you could say that the sum of the perturbations 1443 01:10:30,400 --> 01:10:33,160 across, let's say, all of the dimensions has to be small. 1444 01:10:33,160 --> 01:10:36,430 That would correspond to an L1 or an L2 norm bound 1445 01:10:36,430 --> 01:10:38,377 on what delta should be. 1446 01:10:38,377 --> 01:10:40,210 So now we've got everything we need actually 1447 01:10:40,210 --> 01:10:42,880 to think about defenses to this type 1448 01:10:42,880 --> 01:10:46,460 of adversarial perturbation.
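A minimal sketch of finding such a perturbation by gradient ascent under an L-infinity bound, assuming PyTorch is available; eps, the number of steps, and the step-size heuristic are illustrative, and clipping x + delta back to a valid input range is omitted for brevity.

import torch

def linf_attack(model, loss_fn, x, y, eps=0.01, steps=10):
    # Gradient ascent on the loss with respect to an additive perturbation delta,
    # projecting delta back into the L-infinity ball of radius eps after each step.
    delta = torch.zeros_like(x, requires_grad=True)
    step_size = 2.5 * eps / steps
    for _ in range(steps):
        loss = loss_fn(model(x + delta), y)
        loss.backward()
        with torch.no_grad():
            delta += step_size * delta.grad.sign()
            delta.clamp_(-eps, eps)
        delta.grad.zero_()
    return delta.detach()

# The defense described next wraps this inner maximization in the usual outer
# minimization over theta: for each training batch, compute delta with the attack
# above and then take a gradient step on loss_fn(model(x + delta), y).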
1449 01:10:46,460 --> 01:10:49,728 So instead of minimizing your typical empirical loss, what 1450 01:10:49,728 --> 01:10:51,520 we're going to do is we're going to attempt 1451 01:10:51,520 --> 01:10:55,510 to minimize an adversarial robust loss function. 1452 01:10:55,510 --> 01:10:57,100 What we'll do is we'll say, OK, we 1453 01:10:57,100 --> 01:11:01,660 want to be sure that no matter what the perturbation is 1454 01:11:01,660 --> 01:11:05,500 that one adds the input, the true label y 1455 01:11:05,500 --> 01:11:07,680 still has low loss. 1456 01:11:07,680 --> 01:11:10,840 So you want to find parameters theta which 1457 01:11:10,840 --> 01:11:12,820 minimize this new quantity. 1458 01:11:12,820 --> 01:11:14,980 So I'm saying that we should still 1459 01:11:14,980 --> 01:11:21,375 do well even for the worst-case adversarial perturbation. 1460 01:11:21,375 --> 01:11:23,500 And so now this would be the following new learning 1461 01:11:23,500 --> 01:11:26,110 objective, where we're going to minimize over 1462 01:11:26,110 --> 01:11:28,580 theta with respect to the maximum of our delta. 1463 01:11:28,580 --> 01:11:31,780 And you have to restrict the family that these perturbations 1464 01:11:31,780 --> 01:11:34,087 could live in, so if that delta that you're 1465 01:11:34,087 --> 01:11:35,920 maximizing with respect to is the empty set, 1466 01:11:35,920 --> 01:11:38,170 you get back the original learning problem. 1467 01:11:38,170 --> 01:11:41,470 If you let it be, let's say, all L-infinity 1468 01:11:41,470 --> 01:11:46,437 bounded perturbations of maximum size of 0.01, 1469 01:11:46,437 --> 01:11:48,520 then you're saying we're going to allow for a very 1470 01:11:48,520 --> 01:11:49,930 small amount of perturbations. 1471 01:11:49,930 --> 01:11:51,388 And the learning algorithm is going 1472 01:11:51,388 --> 01:11:53,800 to find parameters theta such that for every input, even 1473 01:11:53,800 --> 01:11:57,340 with a small perturbation to it, adversarially chosen, 1474 01:11:57,340 --> 01:11:59,590 you still get good predictive performance. 1475 01:11:59,590 --> 01:12:02,470 And this is now a new optimization problem 1476 01:12:02,470 --> 01:12:04,300 that one can solve. 1477 01:12:04,300 --> 01:12:06,040 And we've now reduced the problem 1478 01:12:06,040 --> 01:12:09,070 of finding an adversarial robust model to a new optimization 1479 01:12:09,070 --> 01:12:09,940 problem. 1480 01:12:09,940 --> 01:12:11,410 And what the field has been doing 1481 01:12:11,410 --> 01:12:13,810 in the last couple of years is coming up 1482 01:12:13,810 --> 01:12:15,970 with new optimization approaches to try 1483 01:12:15,970 --> 01:12:18,880 to solve those problems fast. 1484 01:12:18,880 --> 01:12:22,690 So, for example, this paper published an ICML in 2018 1485 01:12:22,690 --> 01:12:24,940 by Zico Kolter and his student-- 1486 01:12:24,940 --> 01:12:28,792 Zico just visited MIT a few weeks ago-- 1487 01:12:28,792 --> 01:12:30,250 what it did is it said, we're going 1488 01:12:30,250 --> 01:12:34,300 to use a convex relaxation to the rectified linear unit, 1489 01:12:34,300 --> 01:12:37,028 which is used in many deep neural network architectures. 
1490 01:12:37,028 --> 01:12:39,070 And what it's going to do is then say, 1491 01:12:39,070 --> 01:12:41,440 OK, we're going to think about how 1492 01:12:41,440 --> 01:12:45,790 a small perturbation to the input 1493 01:12:45,790 --> 01:12:49,000 would be propagated through the network-- that is, how much it could 1494 01:12:49,000 --> 01:12:50,300 actually change the output. 1495 01:12:50,300 --> 01:12:52,630 And if one can bound, layer 1496 01:12:52,630 --> 01:12:55,420 by layer, how much a small perturbation affects 1497 01:12:55,420 --> 01:12:57,250 the output of that layer, then one 1498 01:12:57,250 --> 01:12:59,350 can propagate from the very bottom 1499 01:12:59,350 --> 01:13:01,300 all the way to the loss function at the top 1500 01:13:01,300 --> 01:13:05,260 to try to bound how much the loss function itself changes. 1501 01:13:05,260 --> 01:13:08,780 And a picture of what you would expect to get out is as follows. 1502 01:13:08,780 --> 01:13:11,770 On the left-hand side here, you have data points, 1503 01:13:11,770 --> 01:13:14,800 red and blue, and the decision boundary 1504 01:13:14,800 --> 01:13:19,420 that's learned if you don't use this robust learning algorithm. 1505 01:13:19,420 --> 01:13:21,700 On the right, 1506 01:13:21,700 --> 01:13:25,030 you'll notice a small square around each data point. 1507 01:13:25,030 --> 01:13:28,600 That corresponds to the set of perturbations 1508 01:13:28,600 --> 01:13:29,793 up to some limited size. 1509 01:13:29,793 --> 01:13:31,960 And now you notice how the decision boundary doesn't 1510 01:13:31,960 --> 01:13:33,520 cross any one of those squares. 1511 01:13:33,520 --> 01:13:35,950 That's what would be found by this learning algorithm. 1512 01:13:35,950 --> 01:13:38,200 Interestingly, one can look at the filters that 1513 01:13:38,200 --> 01:13:40,483 are learned by a convolutional neural network using 1514 01:13:40,483 --> 01:13:41,650 this new learning algorithm, 1515 01:13:41,650 --> 01:13:44,980 and you find that they're much more sparse. 1516 01:13:44,980 --> 01:13:48,340 And so this is a very fast-moving field. 1517 01:13:48,340 --> 01:13:52,930 Every time 1518 01:13:52,930 --> 01:13:55,660 a new adversarial defense mechanism comes up, 1519 01:13:55,660 --> 01:13:57,243 someone comes up with a different type 1520 01:13:57,243 --> 01:13:58,930 of attack which breaks it. 1521 01:13:58,930 --> 01:14:01,730 And usually that's for one of two reasons. 1522 01:14:01,730 --> 01:14:05,883 One, the defense mechanism isn't provable-- 1523 01:14:05,883 --> 01:14:08,300 a provable defense would come with a theorem which says, 1524 01:14:08,300 --> 01:14:10,390 OK, as long as you don't perturb 1525 01:14:10,390 --> 01:14:14,710 by more than some amount, these are the results you should expect. 1526 01:14:14,710 --> 01:14:16,960 The flip side of the coin is, even if you come up 1527 01:14:16,960 --> 01:14:18,668 with some provable guarantee, there might 1528 01:14:18,668 --> 01:14:20,180 be other types of attacks. 1529 01:14:20,180 --> 01:14:22,180 So, for example, you might imagine a rotation 1530 01:14:22,180 --> 01:14:25,480 of the input instead of an additive L-infinity-bounded 1531 01:14:25,480 --> 01:14:26,530 perturbation. 1532 01:14:26,530 --> 01:14:29,200 And so for every new type of attack model, 1533 01:14:29,200 --> 01:14:31,750 you have to think through new defense mechanisms.
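[Editor's note: to make the layer-by-layer bound propagation idea above more concrete, here is a much-simplified interval-bound sketch in the same spirit. It is not the convex ReLU relaxation from the Wong and Kolter paper, only a looser illustration; the layers list of (weight, bias) pairs and the input shapes are assumptions.]

```python
import torch

def interval_bounds(x, epsilon, layers):
    """Propagate elementwise lower/upper bounds on the activations through a
    stack of linear layers with ReLU in between, for any ||delta||_inf <= epsilon."""
    lower, upper = x - epsilon, x + epsilon
    for i, (W, b) in enumerate(layers):
        center = (lower + upper) / 2
        radius = (upper - lower) / 2
        # A linear layer maps the box's center through W and scales its
        # half-width by |W| (the worst case over the box).
        new_center = center @ W.t() + b
        new_radius = radius @ W.t().abs()
        lower, upper = new_center - new_radius, new_center + new_radius
        if i < len(layers) - 1:
            # ReLU is monotone, so applying it to the bounds keeps them valid.
            lower, upper = torch.relu(lower), torch.relu(upper)
    return lower, upper  # bounds on the final-layer outputs (e.g. logits)
```

Comparing the lower bound of the true class's logit against the upper bounds of the other logits then tells you whether any perturbation inside the box could flip the prediction, which is the spirit of the certified guarantee discussed above, though the actual paper obtains tighter bounds via its convex relaxation.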
1534 01:14:31,750 --> 01:14:34,930 And so you should expect to see some iteration in this space. 1535 01:14:34,930 --> 01:14:38,200 And there's a website called robust-ml.org, 1536 01:14:38,200 --> 01:14:41,643 where many of these attacks and defenses are being published 1537 01:14:41,643 --> 01:14:43,810 to allow the academic community to make progress 1538 01:14:43,810 --> 01:14:44,440 here. 1539 01:14:44,440 --> 01:14:47,260 And with that, I'll finish today's lecture.