WEBVTT

00:00:00.000 --> 00:00:01.952
[SQUEAKING]

00:00:01.952 --> 00:00:03.904
[RUSTLING]

00:00:03.904 --> 00:00:04.880
[CLICKING]

00:00:15.128 --> 00:00:18.070
DAVID SONTAG: OK, so
then today's lecture

00:00:18.070 --> 00:00:21.640
is going to be about data
set shifts, specifically

00:00:21.640 --> 00:00:25.120
how one can be robust
to data set shift.

00:00:25.120 --> 00:00:27.220
Now, this is the topic
that we've been alluding to

00:00:27.220 --> 00:00:30.250
throughout the semester.

00:00:30.250 --> 00:00:33.402
And the setting that I want
you to be thinking about

00:00:33.402 --> 00:00:33.985
is as follows.

00:00:37.150 --> 00:00:40.180
You're a data scientist
working at, let's say,

00:00:40.180 --> 00:00:44.740
Mass General
Hospital, and you've

00:00:44.740 --> 00:00:47.530
been very careful in setting
up your machine learning task

00:00:47.530 --> 00:00:51.100
to make sure that the
data is well specified,

00:00:51.100 --> 00:00:54.280
the labels that you're trying
to predict are well specified.

00:00:54.280 --> 00:00:55.480
You train on your training data,

00:00:55.480 --> 00:00:58.808
you test it on a held-out set,

00:00:58.808 --> 00:01:00.475
you see that the model
generalizes well,

00:01:00.475 --> 00:01:02.350
you do chart review to
make sure what you're

00:01:02.350 --> 00:01:05.440
predicting is actually what
you think you're predicting,

00:01:05.440 --> 00:01:08.320
and you even do prospective
deployment where you then

00:01:08.320 --> 00:01:10.242
let your machine
learning algorithm drive

00:01:10.242 --> 00:01:11.950
some clinical decision
support, and you

00:01:11.950 --> 00:01:14.860
see things are working great.

00:01:14.860 --> 00:01:17.020
Now what?

00:01:17.020 --> 00:01:21.590
What happens after this stage
when you go to deployment?

00:01:21.590 --> 00:01:25.000
What happens when
your same model

00:01:25.000 --> 00:01:27.040
is going to be used
not just tomorrow

00:01:27.040 --> 00:01:30.220
but also next week, the
following week, the next year?

00:01:30.220 --> 00:01:32.620
What happens if your model,
which is working well

00:01:32.620 --> 00:01:36.610
at this one hospital, is then
used by another institution--

00:01:36.610 --> 00:01:38.545
say, maybe Brigham
and Women's Hospital,

00:01:38.545 --> 00:01:41.440
or maybe UCSF,
or some rural hospital

00:01:41.440 --> 00:01:43.540
in the United States that
wants to use the same model?

00:01:43.540 --> 00:01:46.000
Will it keep working

00:01:46.000 --> 00:01:49.390
in this "short term into
the future" time period,

00:01:49.390 --> 00:01:50.685
or in a new institution?

00:01:50.685 --> 00:01:52.810
That's the question which
we're going to be talking

00:01:52.810 --> 00:01:54.010
about in today's lecture.

00:01:54.010 --> 00:01:55.810
And we'll be talking
about how one

00:01:55.810 --> 00:01:59.990
can deal with data set shift
of two different varieties.

00:01:59.990 --> 00:02:03.510
The first variety is adversarial
perturbations to data,

00:02:03.510 --> 00:02:06.280
and the second variety is
data that

00:02:06.280 --> 00:02:09.250
changes for natural reasons.

00:02:09.250 --> 00:02:11.500
Now, the reason why
it's not at all obvious

00:02:11.500 --> 00:02:13.600
that your machine learning
algorithm should still

00:02:13.600 --> 00:02:16.480
work in the setting is because
the number one assumption

00:02:16.480 --> 00:02:18.310
we make when we do
machine learning

00:02:18.310 --> 00:02:20.680
is that your training
distribution, your training

00:02:20.680 --> 00:02:24.593
data, is drawn from the same
distribution as your test data.

00:02:24.593 --> 00:02:27.010
So if you now go to a setting
where your data distribution

00:02:27.010 --> 00:02:33.190
has changed, even if you've
computed your accuracy using

00:02:33.190 --> 00:02:35.807
your held-out data
and it looks good,

00:02:35.807 --> 00:02:37.390
there's no reason
that should continue

00:02:37.390 --> 00:02:40.030
to look good in this new
setting, where the data

00:02:40.030 --> 00:02:42.580
distribution has changed.

00:02:42.580 --> 00:02:45.305
A simple example of what it
means for a data distribution

00:02:45.305 --> 00:02:46.555
to change might be as follows.

00:02:51.510 --> 00:02:57.620
Suppose that we
have x as input data,

00:02:57.620 --> 00:03:03.980
and we're trying to
predict some label y, which

00:03:03.980 --> 00:03:12.060
maybe means something like
whether a patient has--

00:03:12.060 --> 00:03:14.880
or will be newly diagnosed
with type 2 diabetes,

00:03:14.880 --> 00:03:18.650
and this is an
example which we

00:03:18.650 --> 00:03:23.180
talked about when we
introduced risk stratification.

00:03:23.180 --> 00:03:27.470
You learn a model
to predict y from x.

00:03:27.470 --> 00:03:30.080
And now suppose you go
to a new institution

00:03:30.080 --> 00:03:33.560
where their definition of
what type 2 diabetes means

00:03:33.560 --> 00:03:35.486
has changed.

00:03:35.486 --> 00:03:41.870
For example, maybe they don't
actually have type 2 diabetes

00:03:41.870 --> 00:03:45.800
coded in their data, maybe
they only have diabetes

00:03:45.800 --> 00:03:49.010
coded in their data,
which is lumping together

00:03:49.010 --> 00:03:51.320
both type 1 and type
2 diabetes, type 1

00:03:51.320 --> 00:03:56.330
being what's usually
juvenile diabetes

00:03:56.330 --> 00:03:59.790
and is actually a very distinct
disease from type 2 diabetes.

00:03:59.790 --> 00:04:02.450
So now the notion of what
diabetes is is different.

00:04:02.450 --> 00:04:04.495
Maybe the use case is
also slightly different.

00:04:04.495 --> 00:04:05.870
And there's no
reason, obviously,

00:04:05.870 --> 00:04:08.480
that your model, which was used
to predict type 2 diabetes,

00:04:08.480 --> 00:04:10.820
would work for that new label.

00:04:10.820 --> 00:04:12.710
Now, this is an example
of a type of data set

00:04:12.710 --> 00:04:16.730
shift where it is perhaps
obvious to you that

00:04:16.730 --> 00:04:19.730
nothing should work
in this setting,

00:04:19.730 --> 00:04:29.540
because here the distribution
p(y | x) changes,

00:04:29.540 --> 00:04:32.840
meaning even if you have
the same individual,

00:04:32.840 --> 00:04:35.900
your distribution p_0(y | x)
at, let's say,

00:04:35.900 --> 00:04:40.325
one institution and the
distribution p_1(y | x)

00:04:40.325 --> 00:04:42.700
at another institution--
these now are

00:04:42.700 --> 00:04:45.110
two different
distributions

00:04:45.110 --> 00:04:47.220
if the meaning of the
label has changed.

00:04:47.220 --> 00:04:50.720
So for the same person, there
might be a different distribution

00:04:50.720 --> 00:04:52.800
over what y is.

00:04:52.800 --> 00:04:54.545
So this is one
type of data set shift.

00:04:54.545 --> 00:04:55.920
And a very different
type of data

00:04:55.920 --> 00:04:59.313
set shift is where we assume
that these two are equal.

00:04:59.313 --> 00:05:00.980
And so that would,
for example, rule out

00:05:00.980 --> 00:05:03.150
this type of data set shift.

00:05:03.150 --> 00:05:11.240
But rather, what changes is
p(x) from location 1

00:05:11.240 --> 00:05:14.270
to location 2.

00:05:14.270 --> 00:05:18.010
And this is the type of data set
shift that we'll focus on

00:05:18.010 --> 00:05:18.860
in today's lecture.

00:05:18.860 --> 00:05:21.200
It goes by the name
of covariate shift.
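
To make the distinction concrete, here is one way to write the two kinds
of shift in LaTeX, using p_0 for one institution and p_1 for the other
(a sketch of the board notation, not a formula from the slides):

\text{first example (label definition changes):}\quad p_0(y \mid x) \neq p_1(y \mid x)

\text{covariate shift:}\quad p_0(x) \neq p_1(x), \qquad p_0(y \mid x) = p_1(y \mid x)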

00:05:27.260 --> 00:05:31.190
And let's look at two
different examples of that.

00:05:31.190 --> 00:05:34.730
The first example would be of
an adversarial perturbation.

00:05:34.730 --> 00:05:39.080
And so you've all seen
the use of convolutional neural

00:05:39.080 --> 00:05:41.540
networks for image
classification problems.

00:05:41.540 --> 00:05:44.360
This is just one illustration
of such an architecture.

00:05:44.360 --> 00:05:45.890
And with such an
architecture, one

00:05:45.890 --> 00:05:48.140
could then attempt to do all
sorts of different object

00:05:48.140 --> 00:05:50.900
classification or image
classification tasks.

00:05:50.900 --> 00:05:54.260
You could take as input
this picture of a dog, which

00:05:54.260 --> 00:05:57.980
is clearly a dog.

00:05:57.980 --> 00:06:01.220
And you could modify
it just a little bit.

00:06:01.220 --> 00:06:05.258
Just add in a very
small amount of noise.

00:06:05.258 --> 00:06:06.800
What I'm going to
do is now I'm going

00:06:06.800 --> 00:06:11.880
to create a new image which
is that original image.

00:06:11.880 --> 00:06:13.700
Now with every single
pixel, I'm going

00:06:13.700 --> 00:06:17.840
to add a very small epsilon in
the direction of that noise.

00:06:17.840 --> 00:06:20.760
And what you get out
is this new image,

00:06:20.760 --> 00:06:22.718
which you could stare at
however long you want,

00:06:22.718 --> 00:06:24.718
you're not going to be able
to tell the difference.

00:06:24.718 --> 00:06:26.420
Basically to the
human eye, these two

00:06:26.420 --> 00:06:29.240
look exactly identical.

00:06:29.240 --> 00:06:33.680
Except when you take your
machine learning classifier,

00:06:33.680 --> 00:06:36.980
which is trained on
original unperturbed data,

00:06:36.980 --> 00:06:39.330
and now apply it
to this new image,

00:06:39.330 --> 00:06:40.580
it's classified as an ostrich.
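
As a concrete sketch of how such a perturbation can be constructed--this
is the fast gradient sign method, one standard attack, though not
necessarily the one used to produce this particular image--assuming a
differentiable classifier in PyTorch:

import torch
import torch.nn.functional as F

def fgsm_perturb(model, x, y, epsilon=0.007):
    # model: trained classifier; x: batch of images; y: true labels.
    x = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    # Move every pixel a small step epsilon in the direction (the sign of
    # the gradient) that increases the loss -- the "very small epsilon in
    # the direction of that noise" described above.
    x_adv = x + epsilon * x.grad.sign()
    return x_adv.detach()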

00:06:43.490 --> 00:06:46.760
And this observation
was published

00:06:46.760 --> 00:06:49.940
in a paper in 2014 called
"Intriguing properties

00:06:49.940 --> 00:06:52.280
of neural networks."

00:06:52.280 --> 00:06:58.255
And it really kickstarted
a huge surge of interest

00:06:58.255 --> 00:06:59.630
in the machine
learning community

00:06:59.630 --> 00:07:04.350
on adversarial perturbations
to machine learning.

00:07:04.350 --> 00:07:07.730
So asking questions, if you
were to perturb inputs just

00:07:07.730 --> 00:07:09.680
a little bit, how
does that change

00:07:09.680 --> 00:07:11.090
your classifier's output?

00:07:11.090 --> 00:07:15.140
And could that be used to attack
machine learning algorithms?

00:07:15.140 --> 00:07:18.032
And how can one
defend against it?

00:07:18.032 --> 00:07:19.740
By the way, as an
aside, this is actually

00:07:19.740 --> 00:07:22.580
a very old area of research.

00:07:22.580 --> 00:07:26.005
And even back in the land
of linear classifiers,

00:07:26.005 --> 00:07:27.380
these questions
had been studied.

00:07:27.380 --> 00:07:30.870
Although I won't get
into it in this course.

00:07:30.870 --> 00:07:32.960
So this is a type of data
set shift in the sense

00:07:32.960 --> 00:07:36.920
that what we want is that this
should still be classified

00:07:36.920 --> 00:07:40.190
as an ostrich-- as a dog.

00:07:40.190 --> 00:07:42.210
So the actual label
hasn't changed.

00:07:42.210 --> 00:07:44.720
We would like this distribution
over the labels, given

00:07:44.720 --> 00:07:46.850
the perturbed input,
to stay the same,

00:07:46.850 --> 00:07:49.670
except that now the
distribution of inputs

00:07:49.670 --> 00:07:51.620
is a little bit
different because we're

00:07:51.620 --> 00:07:55.537
allowing for some noise to be
added to each of the inputs.

00:07:55.537 --> 00:07:57.620
And in this case, the noise
actually isn't random,

00:07:57.620 --> 00:07:58.310
it's adversarial.

00:07:58.310 --> 00:07:59.935
And towards the end
of today's lecture,

00:07:59.935 --> 00:08:02.540
I'll give you an example
of how one can actually

00:08:02.540 --> 00:08:04.550
generate the
adversarial image, which

00:08:04.550 --> 00:08:06.680
can change the classifier.

00:08:06.680 --> 00:08:08.780
Now, the reason
why we should care

00:08:08.780 --> 00:08:11.420
about these types of
things in this course

00:08:11.420 --> 00:08:14.580
is because I expect
that this type of data

00:08:14.580 --> 00:08:18.020
set shift, which is not at
all natural, it's adversarial,

00:08:18.020 --> 00:08:21.950
is also going to start
showing up in both computer

00:08:21.950 --> 00:08:26.570
vision and non-computer vision
problems in the medical domain.

00:08:26.570 --> 00:08:32.585
There was a nice paper by Sam
Finlayson, Andy Beam, and Isaac

00:08:32.585 --> 00:08:37.039
Kohane recently, which
presented several different case

00:08:37.039 --> 00:08:40.370
studies of where these
problems could really

00:08:40.370 --> 00:08:42.658
arise in health care.

00:08:42.658 --> 00:08:44.450
So, for example, here
what we're looking at

00:08:44.450 --> 00:08:45.920
is an image
classification problem

00:08:45.920 --> 00:08:47.780
arising from dermatology.

00:08:47.780 --> 00:08:52.130
You're given as input an image.

00:08:52.130 --> 00:08:57.500
For example, you would like
that this image be classified

00:08:57.500 --> 00:09:01.490
as an individual having
a particular type of skin

00:09:01.490 --> 00:09:05.300
disorder, a nevus, and
this other image, melanoma.

00:09:05.300 --> 00:09:09.230
And what one can see is that
with a small perturbation

00:09:09.230 --> 00:09:13.730
of the input, one
can completely swap

00:09:13.730 --> 00:09:17.080
the label that would be assigned
to it from one to the other.

00:09:19.460 --> 00:09:20.960
And in this paper,
which we're going

00:09:20.960 --> 00:09:24.290
to post as optional
readings for today's course,

00:09:24.290 --> 00:09:27.650
they talk about how one
could maliciously use

00:09:27.650 --> 00:09:31.320
these algorithms for their own benefit.

00:09:31.320 --> 00:09:36.890
So, for example, imagine that
a health insurance company now

00:09:36.890 --> 00:09:44.330
decides that, in order to reimburse
for an expensive biopsy

00:09:44.330 --> 00:09:50.600
of a patient's skin,
a clinician or a nurse

00:09:50.600 --> 00:09:55.970
must first take a
picture of the disorder

00:09:55.970 --> 00:10:00.470
and submit that picture
together with the bill

00:10:00.470 --> 00:10:02.450
for the procedure.

00:10:02.450 --> 00:10:04.970
And imagine now that the
insurance company were

00:10:04.970 --> 00:10:07.895
to have a machine
learning algorithm be

00:10:07.895 --> 00:10:12.200
an automatic check: was this
procedure actually reasonable

00:10:12.200 --> 00:10:15.110
for this condition?

00:10:15.110 --> 00:10:19.760
And if it isn't, it
might be flagged.

00:10:19.760 --> 00:10:25.660
Now, a malicious user could
perturb the input such

00:10:25.660 --> 00:10:28.630
that, despite the
patient having perhaps even

00:10:28.630 --> 00:10:32.560
completely normal-looking
skin, it would nonetheless

00:10:32.560 --> 00:10:34.930
be classified by a
machine learning algorithm

00:10:34.930 --> 00:10:37.300
as being abnormal in
some way, and thus

00:10:37.300 --> 00:10:41.080
perhaps get reimbursed
for that procedure.

00:10:41.080 --> 00:10:44.560
Now, obviously
this is an example

00:10:44.560 --> 00:10:47.320
of a nefarious setting
where we would then

00:10:47.320 --> 00:10:51.220
hope that such an
individual would be caught

00:10:51.220 --> 00:10:53.793
by the police and sent to jail.

00:10:53.793 --> 00:10:55.960
But nonetheless, what we
would like to be able to do

00:10:55.960 --> 00:10:58.930
is build checks and balances
into the system such

00:10:58.930 --> 00:11:01.780
that that couldn't even
happen because to a human

00:11:01.780 --> 00:11:05.860
it's obvious that you
shouldn't be able to trick--

00:11:05.860 --> 00:11:09.022
trick anyone with such a
very minor perturbation.

00:11:09.022 --> 00:11:10.480
So how do you build
algorithms that

00:11:10.480 --> 00:11:12.440
can't be tricked
as easily,

00:11:12.440 --> 00:11:13.690
just as humans wouldn't be?

00:11:13.690 --> 00:11:15.160
AUDIENCE: Can I ask a question?

00:11:15.160 --> 00:11:15.460
DAVID SONTAG: Yeah.

00:11:15.460 --> 00:11:17.140
AUDIENCE: For any
of these samples,

00:11:17.140 --> 00:11:20.785
did the attacker need
access to the network?

00:11:20.785 --> 00:11:22.202
Is there a way to
[? attack it? ?]

00:11:22.202 --> 00:11:24.660
DAVID SONTAG: So the question
is whether the attacker needs

00:11:24.660 --> 00:11:26.950
to know something about
the function that's

00:11:26.950 --> 00:11:29.620
being used for classifying.

00:11:29.620 --> 00:11:33.190
There are examples of both what
are called white box and black

00:11:33.190 --> 00:11:38.860
box attacks, where in one
setting you have access

00:11:38.860 --> 00:11:43.340
to the function, and in
other settings you don't.

00:11:43.340 --> 00:11:45.550
And so both have been
studied in the literature,

00:11:45.550 --> 00:11:47.470
and there are results
showing that one

00:11:47.470 --> 00:11:49.630
can attack in either setting.

00:11:49.630 --> 00:11:51.970
Sometimes you might need
to know a little bit more.

00:11:51.970 --> 00:11:53.512
Like, for example,
sometimes you need

00:11:53.512 --> 00:11:55.960
to have the ability to query
the function a certain number

00:11:55.960 --> 00:11:56.698
of times.

00:11:56.698 --> 00:11:58.990
So even if you don't know
exactly what the function is,

00:11:58.990 --> 00:12:01.450
like you don't know the
weights of the neural network,

00:12:01.450 --> 00:12:04.630
as long as you can query
it sufficiently many times,

00:12:04.630 --> 00:12:07.685
you'll be able to construct
adversarial examples.

00:12:07.685 --> 00:12:08.810
That would be one approach.

00:12:08.810 --> 00:12:10.060
Another approach
would be, oh, maybe we

00:12:10.060 --> 00:12:11.050
don't know the function,
but we know something

00:12:11.050 --> 00:12:12.650
about the training data.

00:12:12.650 --> 00:12:16.060
So there are ways to go about
doing this even if you don't

00:12:16.060 --> 00:12:17.870
perfectly know the function.

00:12:17.870 --> 00:12:19.162
Does that answer your question?

00:12:21.950 --> 00:12:25.092
So what about a
natural perturbation?

00:12:25.092 --> 00:12:26.800
So this figure is just
pulled from lecture 5

00:12:26.800 --> 00:12:28.360
when we talked about
non-stationarity

00:12:28.360 --> 00:12:30.550
in the context of risk
stratification, that's

00:12:30.550 --> 00:12:34.450
just to remind you. Here the
x-axis is time, the y-axis is

00:12:34.450 --> 00:12:37.000
different types of
laboratory test results

00:12:37.000 --> 00:12:41.755
that might be ordered,
and the color denotes

00:12:41.755 --> 00:12:44.020
how many of those
laboratory tests

00:12:44.020 --> 00:12:47.630
were ordered in a certain
population at a point in time.

00:12:47.630 --> 00:12:50.980
So what we would expect to
see if the data was stationary

00:12:50.980 --> 00:12:54.010
is that every row would
be a homogeneous color.

00:12:54.010 --> 00:12:56.440
But instead what we see is
that there are points in time,

00:12:56.440 --> 00:12:58.780
for example, a few month-long
intervals over here,

00:12:58.780 --> 00:13:02.890
when suddenly it looks like, for
some of the laboratory tests,

00:13:02.890 --> 00:13:05.800
they were never performed.

00:13:05.800 --> 00:13:08.590
That's most likely
due to a data problem,

00:13:08.590 --> 00:13:11.800
or perhaps the feed of data from
that laboratory test provider

00:13:11.800 --> 00:13:14.147
got lost, there were
some systems problem.

00:13:14.147 --> 00:13:15.980
But there are also going
to be settings where,

00:13:15.980 --> 00:13:18.160
for example, a
laboratory test is never

00:13:18.160 --> 00:13:19.660
used until it's suddenly used.

00:13:19.660 --> 00:13:21.100
And that might be because
it's a new test that

00:13:21.100 --> 00:13:23.170
was just invented or
approved for reimbursement

00:13:23.170 --> 00:13:24.560
at that point in time.

00:13:24.560 --> 00:13:26.410
So this is an example
of non-stationarity.

00:13:26.410 --> 00:13:28.540
And, of course, this
could also result

00:13:28.540 --> 00:13:31.130
in changes in your
data distribution,

00:13:31.130 --> 00:13:35.380
such as what I described
over there, over time.
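
As a sketch of how one might check for this kind of non-stationarity in
code (the file and column names here are hypothetical):

import pandas as pd

# One row per lab order, with columns: lab_name, order_time.
labs = pd.read_csv("lab_orders.csv", parse_dates=["order_time"])

# Count orders of each lab test per month; each row of this matrix
# corresponds to one row of the heatmap described above.
counts = (labs.groupby(["lab_name", labs["order_time"].dt.to_period("M")])
              .size()
              .unstack(fill_value=0))

# Flag tests that are sometimes used but vanish for whole months -- a
# likely data-feed problem, or a test entering or leaving practice.
suspicious = counts[(counts == 0).any(axis=1) & (counts > 0).any(axis=1)]
print(suspicious.index.tolist())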

00:13:35.380 --> 00:13:37.090
And the third example
is when you then

00:13:37.090 --> 00:13:39.400
go across institutions,
wherein, of course,

00:13:39.400 --> 00:13:41.650
both the language that might
be used-- you might think

00:13:41.650 --> 00:13:43.420
of a hospital in
the United States

00:13:43.420 --> 00:13:46.120
versus a hospital in China, the
clinical notes will be written

00:13:46.120 --> 00:13:48.560
in completely different
languages, that'll

00:13:48.560 --> 00:13:49.810
would be an extreme case.

00:13:49.810 --> 00:13:53.080
And a less extreme case might
be two different hospitals

00:13:53.080 --> 00:13:56.450
in Boston where the
acronyms or the shorthand

00:13:56.450 --> 00:13:58.600
they use for some clinical
terms might actually

00:13:58.600 --> 00:14:03.140
be different because
of local practices.

00:14:03.140 --> 00:14:04.490
So, what do we do?

00:14:04.490 --> 00:14:05.380
This is all a setup.

00:14:05.380 --> 00:14:07.810
And for the rest of the
lecture, what I'll talk about

00:14:07.810 --> 00:14:10.810
is first, very
briefly, how one can

00:14:10.810 --> 00:14:15.110
build in population-level checks
for whether something has changed.

00:14:15.110 --> 00:14:17.930
And then for the bulk
of today's lecture,

00:14:17.930 --> 00:14:20.620
we'll be talking about how
to develop transfer learning

00:14:20.620 --> 00:14:23.847
algorithms and how one
could think about defenses

00:14:23.847 --> 00:14:24.805
to adversarial attacks.

00:14:29.580 --> 00:14:33.690
So before I show you that
first slide for bullet one,

00:14:33.690 --> 00:14:35.497
I want to have a
bit of discussion.

00:14:38.910 --> 00:14:41.160
You've now done that
thing of learning a machine

00:14:41.160 --> 00:14:43.480
learning algorithm
in your institution,

00:14:43.480 --> 00:14:47.610
and you want to know,
will this algorithm

00:14:47.610 --> 00:14:50.820
work at some other institution?

00:14:50.820 --> 00:14:54.870
You pick up the phone, you
call up your collaborating data

00:14:54.870 --> 00:14:56.850
scientist at another
institution. What

00:14:56.850 --> 00:14:58.600
are the questions that
you should ask them

00:14:58.600 --> 00:15:00.120
when you're trying to
understand, will your algorithm

00:15:00.120 --> 00:15:00.912
work there as well?

00:15:07.860 --> 00:15:08.410
Yeah.

00:15:08.410 --> 00:15:10.680
AUDIENCE: What kind of
lab test information

00:15:10.680 --> 00:15:13.270
they collect [INAUDIBLE].

00:15:13.270 --> 00:15:14.940
DAVID SONTAG: So
what type of data

00:15:14.940 --> 00:15:17.130
do they have on their
patients, and do they

00:15:17.130 --> 00:15:20.460
have similar data types
or features available

00:15:20.460 --> 00:15:22.910
for their patient population?

00:15:22.910 --> 00:15:25.160
Other ideas, someone who
hasn't spoken in the last two

00:15:25.160 --> 00:15:29.262
lectures, maybe someone
in the far back there,

00:15:29.262 --> 00:15:30.720
people who have
their computer out.

00:15:30.720 --> 00:15:32.303
Maybe you with your
hand in your mouth

00:15:32.303 --> 00:15:34.250
right there, yeah, you
with your glasses on.

00:15:34.250 --> 00:15:34.750
Ideas.

00:15:34.750 --> 00:15:35.417
[STUDENT LAUGHS]

00:15:35.417 --> 00:15:37.590
AUDIENCE: Sorry, can
you repeat the question?

00:15:37.590 --> 00:15:40.350
DAVID SONTAG: You want me
to repeat the question?

00:15:40.350 --> 00:15:42.690
The question was as follows.

00:15:42.690 --> 00:15:45.930
You learn your machine learning
algorithm at some institution,

00:15:45.930 --> 00:15:48.672
and you want to apply it
now in a new institution.

00:15:48.672 --> 00:15:50.880
What questions should you
ask of that new institution

00:15:50.880 --> 00:15:52.500
to try to assess whether your
algorithm will generalize

00:15:52.500 --> 00:15:54.137
in that new institution?

00:15:54.137 --> 00:15:56.950
AUDIENCE: I guess it depends on
your problem you're looking at,

00:15:56.950 --> 00:15:58.880
like whether you're
trying to learn

00:15:58.880 --> 00:16:00.748
possible differences
in your population,

00:16:00.748 --> 00:16:04.887
if you're requiring data with
particular [INAUDIBLE] use.

00:16:04.887 --> 00:16:06.470
So I'd envision it
that you'd want to,

00:16:06.470 --> 00:16:08.790
like are your machines
calibrated [INAUDIBLE]?

00:16:08.790 --> 00:16:11.070
Do they use techniques
to acquire the data?

00:16:11.070 --> 00:16:12.070
DAVID SONTAG: All right.

00:16:12.070 --> 00:16:14.740
So let's break down each of
the answers that you gave.

00:16:14.740 --> 00:16:16.420
The first answer
that you gave was,

00:16:16.420 --> 00:16:18.100
are there differences
in the population?

00:16:20.783 --> 00:16:22.450
What would be an exa--
someone else now,

00:16:22.450 --> 00:16:24.742
what would be an example of a
difference in a population?

00:16:28.660 --> 00:16:29.160
Yep.

00:16:29.160 --> 00:16:30.660
AUDIENCE: Age
distribution. You might

00:16:30.660 --> 00:16:32.310
have younger people
in maybe Boston

00:16:32.310 --> 00:16:33.600
versus like a
Massachusetts [INAUDIBLE].

00:16:33.600 --> 00:16:35.517
DAVID SONTAG: So you
might have younger people

00:16:35.517 --> 00:16:38.650
in Boston versus
older people who

00:16:38.650 --> 00:16:40.680
are in Central Massachusetts.

00:16:40.680 --> 00:16:42.900
How might a change
in age distribution

00:16:42.900 --> 00:16:47.562
affect your ability of your
algorithms to generalize?

00:16:47.562 --> 00:16:48.062
Yep.

00:16:48.062 --> 00:16:49.770
AUDIENCE: [? Possibly ?]
health patterns,

00:16:49.770 --> 00:16:52.270
where young people are very
different from [INAUDIBLE] who

00:16:52.270 --> 00:16:54.740
have some diseases that
are clearly more prevalent

00:16:54.740 --> 00:16:56.150
in populations that are
older [? than you. ?]

00:16:56.150 --> 00:16:57.150
DAVID SONTAG: Thank you.

00:16:57.150 --> 00:17:01.030
So sometimes we might just
expect a different set of diseases

00:17:01.030 --> 00:17:03.700
to occur for a younger
population versus an older

00:17:03.700 --> 00:17:04.700
population.

00:17:04.700 --> 00:17:07.810
So type 2 diabetes,
hypertension,

00:17:07.810 --> 00:17:13.000
these are diseases that are
often diagnosed when patients--

00:17:13.000 --> 00:17:17.185
when individuals are in
their 40s, 50s, and older.

00:17:17.185 --> 00:17:19.300
If you have people
who are in their 20s,

00:17:19.300 --> 00:17:21.190
you don't typically
see those diseases

00:17:21.190 --> 00:17:23.210
in a younger population.

00:17:23.210 --> 00:17:27.460
And so what that means is
if your model, for example,

00:17:27.460 --> 00:17:32.710
was trained on a population
of very young individuals,

00:17:32.710 --> 00:17:36.730
then it might not be able
to-- and suppose you're doing

00:17:36.730 --> 00:17:40.000
something like
predicting future cost,

00:17:40.000 --> 00:17:42.250
so something which is not
directly tied to the disease

00:17:42.250 --> 00:17:45.950
itself, the features that
are predictive of future cost

00:17:45.950 --> 00:17:49.120
in a very young population
might be very different from

00:17:49.120 --> 00:17:50.680
features--

00:17:50.680 --> 00:17:53.230
the predictors of cost in a
much older population because

00:17:53.230 --> 00:17:56.635
of the differences in conditions
that those individuals have.

00:17:56.635 --> 00:17:58.570
Now the second
answer that was given

00:17:58.570 --> 00:18:01.660
had to do with calibration
of instruments.

00:18:01.660 --> 00:18:03.370
Can you elaborate
a bit about that?

00:18:03.370 --> 00:18:03.995
AUDIENCE: Yeah.

00:18:03.995 --> 00:18:07.060
So I was thinking [? clearly ?]
in the colonoscopy space.

00:18:07.060 --> 00:18:09.902
But if you're collecting--
so in that space,

00:18:09.902 --> 00:18:11.360
you're collecting
videos of colons.

00:18:11.360 --> 00:18:14.020
And so you can
have machines that

00:18:14.020 --> 00:18:15.823
are calibrated very
differently, let's say

00:18:15.823 --> 00:18:18.470
different light exposure,
different camera settings.

00:18:18.470 --> 00:18:21.450
But you also have that
the GIs and physicians

00:18:21.450 --> 00:18:23.980
have different techniques as
to how they explore the colon.

00:18:23.980 --> 00:18:26.560
So the video data itself is
going to be very different.

00:18:26.560 --> 00:18:28.540
DAVID SONTAG: So the
example that was given

00:18:28.540 --> 00:18:31.262
was of colonoscopies
and data that might

00:18:31.262 --> 00:18:32.470
be collected as part of that.

00:18:35.088 --> 00:18:37.630
And the data
that could be collected

00:18:37.630 --> 00:18:39.505
could be different for
two different reasons.

00:18:39.505 --> 00:18:43.967
One, because
the actual instruments

00:18:43.967 --> 00:18:46.300
that are collecting the data,
for example, imaging data,

00:18:46.300 --> 00:18:47.680
might be calibrated a
little bit differently.

00:18:47.680 --> 00:18:50.013
And a second reason might be
because the procedures that

00:18:50.013 --> 00:18:53.200
are used to perform that
diagnostic test might be

00:18:53.200 --> 00:18:54.680
different in each institution.

00:18:54.680 --> 00:18:57.800
Each one will result in slightly
different biases to the data,

00:18:57.800 --> 00:18:59.920
and it's not clear that
an algorithm trained

00:18:59.920 --> 00:19:02.230
on one type of procedure
or one type of instrument

00:19:02.230 --> 00:19:04.580
would generalize to another.

00:19:04.580 --> 00:19:06.540
So these are all great examples.

00:19:06.540 --> 00:19:11.620
And so when one reads a paper
from the clinical community

00:19:11.620 --> 00:19:16.600
on developing a new risk
stratification tool, what

00:19:16.600 --> 00:19:19.960
you will always
see in this paper

00:19:19.960 --> 00:19:23.050
is what's known as "Table 1."

00:19:23.050 --> 00:19:25.360
Table 1 looks a
little bit like this.

00:19:25.360 --> 00:19:27.550
Here I pulled one
of my own papers

00:19:27.550 --> 00:19:29.975
that was published in
JAMA Cardiology in 2016

00:19:29.975 --> 00:19:32.350
where we looked at how to try
to find patients with heart

00:19:32.350 --> 00:19:34.900
failure who are hospitalized.

00:19:34.900 --> 00:19:37.510
And I'm just going to walk
through what this table is.

00:19:37.510 --> 00:19:40.030
So this table is
describing the population

00:19:40.030 --> 00:19:42.730
that was used in the study.

00:19:42.730 --> 00:19:46.150
At the very top, it says these
are characteristics of 47,000

00:19:46.150 --> 00:19:47.980
hospitalized patients.

00:19:47.980 --> 00:19:53.480
Then what we've done is,
using our domain knowledge,

00:19:53.480 --> 00:19:55.870
we know that this is a
heart failure population,

00:19:55.870 --> 00:19:58.240
and we know that there are
a number of different axes

00:19:58.240 --> 00:20:01.420
that differentiate patients
who are hospitalized

00:20:01.420 --> 00:20:02.960
that have heart failure.

00:20:02.960 --> 00:20:07.180
And so we enumerate over
many of the features

00:20:07.180 --> 00:20:10.840
that we think are critical to
characterizing the population,

00:20:10.840 --> 00:20:12.880
and we give
descriptive statistics

00:20:12.880 --> 00:20:14.620
on each one of those features.

00:20:14.620 --> 00:20:19.530
You always start with things
like age, gender, and race.

00:20:19.530 --> 00:20:22.990
And so here, for example, the
average age was 61 years old,

00:20:22.990 --> 00:20:32.080
this was, by the way, NYU
Medical School, 50.8% female,

00:20:32.080 --> 00:20:37.480
11.2% Black, African-American,
17.6% of individuals

00:20:37.480 --> 00:20:41.290
were on Medicaid, which
is a state-provided health

00:20:41.290 --> 00:20:46.960
insurance for either disabled
or lower-income individuals.

00:20:46.960 --> 00:20:51.310
And then we looked at quantities
like what types of medications

00:20:51.310 --> 00:20:52.480
were patients on.

00:20:52.480 --> 00:20:57.970
41%-- 42% of
inpatients

00:20:57.970 --> 00:20:59.770
were on something
called beta blockers.

00:20:59.770 --> 00:21:03.970
31.6% of outpatients
were on beta blockers.

00:21:03.970 --> 00:21:09.080
We then looked at things
like laboratory test results.

00:21:09.080 --> 00:21:12.400
So one can look at the
average creatinine values,

00:21:12.400 --> 00:21:16.900
the average sodium values
of this patient population.

00:21:16.900 --> 00:21:18.670
And in this way,
it described what

00:21:18.670 --> 00:21:21.173
is the population
that's being studied.
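
As a sketch of how such a table might be computed from a cohort data
frame (the column names here are hypothetical, not those of the paper):

import pandas as pd

cohort = pd.read_csv("cohort.csv")  # one row per hospitalized patient

table1 = {
    "N": len(cohort),
    "Age, mean (SD)": f"{cohort['age'].mean():.1f} ({cohort['age'].std():.1f})",
    "Female, %": f"{100 * (cohort['sex'] == 'F').mean():.1f}",
    "On beta blockers, %": f"{100 * cohort['beta_blocker'].mean():.1f}",
    "Creatinine, mean (SD)": f"{cohort['creatinine'].mean():.2f} "
                             f"({cohort['creatinine'].std():.2f})",
}
for name, value in table1.items():
    print(f"{name}: {value}")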

00:21:21.173 --> 00:21:22.840
Then when you go to
the new institution,

00:21:22.840 --> 00:21:26.440
that new institution receives
not just the algorithm,

00:21:26.440 --> 00:21:29.170
but they also
receive this Table 1

00:21:29.170 --> 00:21:31.900
that describes the population
that the algorithm was learned

00:21:31.900 --> 00:21:32.890
on.

00:21:32.890 --> 00:21:36.040
And they could use that together
with some domain knowledge

00:21:36.040 --> 00:21:38.745
to think through questions like
what

00:21:38.745 --> 00:21:40.793
I elicited from you
in our discussion,

00:21:40.793 --> 00:21:42.460
so that we could
think:

00:21:42.460 --> 00:21:44.260
does it make sense that
this model will generalize

00:21:44.260 --> 00:21:45.400
to this new institution?

00:21:45.400 --> 00:21:47.920
Are there reasons
why it might not?

00:21:47.920 --> 00:21:49.840
And you could do that
even before doing

00:21:49.840 --> 00:21:54.710
any prospective evaluation
on the new population.

00:21:54.710 --> 00:21:58.540
So almost all of you should
have something like Table 1

00:21:58.540 --> 00:22:02.380
in your project
write-ups, because

00:22:02.380 --> 00:22:06.520
an important part of any study
in this field is describing:

00:22:06.520 --> 00:22:09.430
what is the population that
you're doing your study on?

00:22:09.430 --> 00:22:10.493
You agree with me, Pete?

00:22:10.493 --> 00:22:11.410
PETER SZOLOVITS: Yeah.

00:22:11.410 --> 00:22:16.300
I would just add that in Table 1,
if you're doing a case control

00:22:16.300 --> 00:22:19.420
study, you will have
two columns that

00:22:19.420 --> 00:22:24.910
show the distributions
in the two populations,

00:22:24.910 --> 00:22:28.840
and then a p-value of how
likely those differences are

00:22:28.840 --> 00:22:30.440
to be significant.

00:22:30.440 --> 00:22:33.630
And if you leave that out, you
can't get your paper published.

00:22:33.630 --> 00:22:35.740
DAVID SONTAG: I'll just
repeat Pete's answer

00:22:35.740 --> 00:22:37.570
for the recording.

00:22:37.570 --> 00:22:44.300
If you are-- this table is
for a predictive problem.

00:22:44.300 --> 00:22:48.100
But if you're thinking about a
causal inference type problem,

00:22:48.100 --> 00:22:52.250
where there's a notion of
different intervention groups,

00:22:52.250 --> 00:22:55.450
then you'd be expected to
report the same sorts of things,

00:22:55.450 --> 00:22:57.200
but for both the
case population,

00:22:57.200 --> 00:22:58.750
the people who received,
let's say, treatment one,

00:22:58.750 --> 00:23:00.250
and the control
population of people

00:23:00.250 --> 00:23:02.107
who receive treatment zero.

00:23:02.107 --> 00:23:03.940
And then you would be
looking at differences

00:23:03.940 --> 00:23:06.910
between those populations as
well at the individual feature

00:23:06.910 --> 00:23:13.080
level as part of the descriptive
statistics for that study.

00:23:13.080 --> 00:23:16.614
Now, this-- yeah.

00:23:16.614 --> 00:23:19.030
AUDIENCE: Is this to
identify [? individually ?]

00:23:19.030 --> 00:23:20.710
[? between ?] those peoples?

00:23:20.710 --> 00:23:24.225
[INAUDIBLE] institutions to do
like t-tests on those tables--

00:23:24.225 --> 00:23:25.975
DAVID SONTAG: To see
if they're different?

00:23:25.975 --> 00:23:27.730
No, so they're always
going to be different.

00:23:27.730 --> 00:23:29.105
You go to a new
institution, it's

00:23:29.105 --> 00:23:31.340
always going to look different.

00:23:31.340 --> 00:23:34.600
And so just looking to see
how something changed is not--

00:23:34.600 --> 00:23:37.750
the answer's always
going to be yes.

00:23:37.750 --> 00:23:42.190
But it enables a conversation
to think through, OK, this,

00:23:42.190 --> 00:23:43.330
and then you might look--

00:23:43.330 --> 00:23:44.440
you might use some
of the techniques

00:23:44.440 --> 00:23:46.982
that Pete's going to talk about
next week on interpretability

00:23:46.982 --> 00:23:49.330
to understand, well, what
is the model actually using?

00:23:49.330 --> 00:23:51.520
Then you might
ask, oh, OK, well,

00:23:51.520 --> 00:23:53.020
the model is using
this thing, which

00:23:53.020 --> 00:23:55.480
makes sense in this population
but might not make sense

00:23:55.480 --> 00:23:56.530
in another population.

00:23:56.530 --> 00:23:58.210
And it's these two
things together

00:23:58.210 --> 00:23:59.693
that make the conversation.

00:24:04.130 --> 00:24:07.550
Now, this question
has really come

00:24:07.550 --> 00:24:13.820
to the forefront in recent
years in close connection

00:24:13.820 --> 00:24:16.865
to the topic that Pete
discussed last week on fairness

00:24:16.865 --> 00:24:18.350
in machine learning.

00:24:18.350 --> 00:24:20.495
Because you might ask
if a classifier is built

00:24:20.495 --> 00:24:22.370
in some population, is
it going to generalize

00:24:22.370 --> 00:24:24.870
to another population if the
population it was learned on

00:24:24.870 --> 00:24:26.540
was very biased, for
example, it might

00:24:26.540 --> 00:24:27.770
have been all white people.

00:24:27.770 --> 00:24:29.145
You might ask, is
that classifier

00:24:29.145 --> 00:24:31.730
going to work well in another
population that might perhaps

00:24:31.730 --> 00:24:34.760
include people of
different ethnicities?

00:24:34.760 --> 00:24:41.810
And so that has led to a concept
which was recently published.

00:24:41.810 --> 00:24:44.900
This working draft that I'm
showing the abstract from

00:24:44.900 --> 00:24:50.330
came out just a few weeks ago;
it's called "Datasheets for Datasets."

00:24:50.330 --> 00:24:52.280
And the goal here
is to standardize

00:24:52.280 --> 00:24:54.830
the process of
describing-- of eliciting

00:24:54.830 --> 00:24:58.130
the information about what it is
about the data set that really

00:24:58.130 --> 00:25:03.630
played into your model?

00:25:03.630 --> 00:25:05.845
And so I'm going to walk
you through very briefly

00:25:05.845 --> 00:25:07.220
just through a
couple of elements

00:25:07.220 --> 00:25:13.440
of what an example datasheet
for a data set might look like.

00:25:13.440 --> 00:25:14.930
This is too small
for you to read,

00:25:14.930 --> 00:25:18.180
but I'll blow up one
section in just a second.

00:25:18.180 --> 00:25:21.620
So this is a
datasheet for a data

00:25:21.620 --> 00:25:25.700
set called Studying Face
Recognition in an Unconstrained

00:25:25.700 --> 00:25:26.280
Environment.

00:25:26.280 --> 00:25:28.738
So it's for a computer
vision problem.

00:25:28.738 --> 00:25:30.780
There are going to be a
number of questionnaires,

00:25:30.780 --> 00:25:33.590
which this paper that I
point you to outlines.

00:25:33.590 --> 00:25:38.330
And you as the model developer
go through that questionnaire

00:25:38.330 --> 00:25:41.630
and fill out the
answers to it, so

00:25:41.630 --> 00:25:43.850
including things about
motivation for the data

00:25:43.850 --> 00:25:46.920
set creation,
composition, and so on.

00:25:46.920 --> 00:25:51.170
So in this particular instance,
this data set called Labeled

00:25:51.170 --> 00:25:54.740
Faces in the Wild was created to
provide images for studying face

00:25:54.740 --> 00:25:57.470
recognition in unconstrained
[INAUDIBLE] settings,

00:25:57.470 --> 00:26:02.060
where image characteristics
such as pose, illumination,

00:26:02.060 --> 00:26:05.450
resolution, focus
cannot be controlled.

00:26:05.450 --> 00:26:10.130
So it's intended to reflect
real-world settings.

00:26:10.130 --> 00:26:11.930
Now, one of the most
interesting sections

00:26:11.930 --> 00:26:16.730
of this report that one should
release with the data set

00:26:16.730 --> 00:26:20.060
has to do with how was the
data preprocessed or cleaned?

00:26:20.060 --> 00:26:21.560
So, for example,
for this data set,

00:26:21.560 --> 00:26:23.480
it walks through the
following process.

00:26:23.480 --> 00:26:28.030
First, raw images were
obtained for the data set,

00:26:28.030 --> 00:26:31.010
and it consisted of
images and captions that

00:26:31.010 --> 00:26:35.090
were found together with
that image in news articles

00:26:35.090 --> 00:26:37.160
or around the web.

00:26:37.160 --> 00:26:42.080
Then there was a face detector
that was run on the data set.

00:26:42.080 --> 00:26:46.420
Here were the parameters of the
face detector that were used.

00:26:46.420 --> 00:26:50.240
And then remember, the goal
here is to study face detection.

00:26:50.240 --> 00:26:56.210
And so one has
to know--

00:26:56.210 --> 00:27:00.620
how were the labels determined?

00:27:00.620 --> 00:27:03.110
And how would one, for
example, determine if there

00:27:03.110 --> 00:27:05.070
was no face in this image?

00:27:05.070 --> 00:27:09.230
And so there they described
how a face was detected and how

00:27:09.230 --> 00:27:11.960
a region was determined to not
be a face in the case that it

00:27:11.960 --> 00:27:12.800
wasn't.

00:27:12.800 --> 00:27:16.568
And finally, it describes
how duplicates were removed.

00:27:16.568 --> 00:27:18.110
And if you think
back to the examples

00:27:18.110 --> 00:27:22.340
we had earlier in the
semester from medical imaging,

00:27:22.340 --> 00:27:25.370
for example in
pathology and radiology,

00:27:25.370 --> 00:27:28.670
similar data set constructions
had to be done there.

00:27:28.670 --> 00:27:30.500
For example, one would
go to the PACS system

00:27:30.500 --> 00:27:33.590
where radiology images
are stored, one would

00:27:33.590 --> 00:27:37.850
decide which images
are going to be pulled out,

00:27:37.850 --> 00:27:39.830
one would go to
radiography reports

00:27:39.830 --> 00:27:42.895
to figure out how to
extract the relevant findings

00:27:42.895 --> 00:27:44.270
from that image,
which would give

00:27:44.270 --> 00:27:48.410
the labels for that predictive--
for that learning task.

00:27:48.410 --> 00:27:52.490
And each step there will
incur some bias,

00:27:52.490 --> 00:27:55.580
which one then needs to
describe carefully in order

00:27:55.580 --> 00:27:57.470
to understand what
might the bias

00:27:57.470 --> 00:28:00.050
be of the learned classifier.

00:28:00.050 --> 00:28:03.680
So I won't go into more
detail on this now,

00:28:03.680 --> 00:28:05.633
but this will also be
one of the suggested

00:28:05.633 --> 00:28:06.800
readings for today's course.

00:28:06.800 --> 00:28:08.330
And it's a fast read.

00:28:08.330 --> 00:28:11.330
I encourage you to go through
it to get some intuition for what

00:28:11.330 --> 00:28:13.910
are questions we might want
to be asking about data sets

00:28:13.910 --> 00:28:14.600
that we create.

00:28:18.172 --> 00:28:19.630
And for the rest
of this semester--

00:28:19.630 --> 00:28:21.910
for the rest of the
lecture today, I'm

00:28:21.910 --> 00:28:25.390
now going to move on to
some more technical issues.

00:28:25.390 --> 00:28:29.530
So we have to do it.

00:28:29.530 --> 00:28:32.170
We're doing machine
learning now.

00:28:32.170 --> 00:28:34.210
The populations
might be different.

00:28:34.210 --> 00:28:35.510
What do we do about it?

00:28:35.510 --> 00:28:37.150
Can we change the
learning algorithm

00:28:37.150 --> 00:28:40.420
in order to hope that your
algorithm might transfer better

00:28:40.420 --> 00:28:41.470
to a new institution?

00:28:41.470 --> 00:28:44.350
Or if we get a little bit of
data from that new institution,

00:28:44.350 --> 00:28:46.840
could we use that
small amount of data

00:28:46.840 --> 00:28:50.350
from the new institution or a
future time point

00:28:50.350 --> 00:28:54.010
to retrain our model to
do well in that slightly

00:28:54.010 --> 00:28:56.060
different distribution?

00:28:56.060 --> 00:28:59.030
So that's the whole field
of transfer learning.

00:28:59.030 --> 00:29:02.770
So you have data drawn from one
distribution p(x, y),

00:29:02.770 --> 00:29:05.950
and maybe we have a little bit
of data drawn from a different

00:29:05.950 --> 00:29:08.500
distribution q(x, y).

00:29:08.500 --> 00:29:11.260
And under the covariate
shift assumption,

00:29:11.260 --> 00:29:23.650
I'm assuming that q(x, y) is
equal to q(x) times p(y | x),

00:29:23.650 --> 00:29:26.740
namely that the
conditional distribution of y

00:29:26.740 --> 00:29:28.277
given x hasn't changed.

00:29:28.277 --> 00:29:29.860
The only thing that
might have changed

00:29:29.860 --> 00:29:31.660
is your distribution over x.

00:29:31.660 --> 00:29:35.550
So that's what the covariate
shift assumption would assume.

00:29:40.270 --> 00:29:42.850
So suppose that we
have some small amount

00:29:42.850 --> 00:29:46.780
of data drawn from the
new distribution q.

00:29:46.780 --> 00:29:48.670
How could we then
use that in order

00:29:48.670 --> 00:29:52.720
to perhaps retrain our
classifier to do well

00:29:52.720 --> 00:29:55.390
for that new institution?

00:29:55.390 --> 00:29:59.060
So I'll walk through four
different approaches to do so.

00:29:59.060 --> 00:30:02.020
I'll start with
linear models, which

00:30:02.020 --> 00:30:04.360
are the simplest to
understand, and then I'll

00:30:04.360 --> 00:30:09.740
move on to deep models.

00:30:09.740 --> 00:30:12.280
The first approach is something
that you've seen already

00:30:12.280 --> 00:30:14.740
several times in this course.

00:30:14.740 --> 00:30:17.770
We're going to
think about transfer

00:30:17.770 --> 00:30:22.450
as a multi-task learning
problem, where one of the tasks

00:30:22.450 --> 00:30:26.012
has much less data
than the other task.

00:30:26.012 --> 00:30:27.970
So if you remember when
we talked about disease

00:30:27.970 --> 00:30:31.750
progression modeling,
I introduced

00:30:31.750 --> 00:30:34.990
this notion of
regularizing the weight

00:30:34.990 --> 00:30:37.947
vectors so that they could
be close to one another.

00:30:37.947 --> 00:30:40.030
At that time, we were
talking about weight vectors

00:30:40.030 --> 00:30:42.030
predicting disease
progression in different time

00:30:42.030 --> 00:30:43.030
points in the future.

00:30:43.030 --> 00:30:46.730
We could use exactly
the same idea here,

00:30:46.730 --> 00:30:50.950
where you take your classifier,
your linear classifier that

00:30:50.950 --> 00:30:53.200
was trained on a
really large corpus,

00:30:53.200 --> 00:30:54.850
I'm going to call that--

00:30:54.850 --> 00:30:57.460
I'm going to call the weights
of that classifier w old,

00:30:57.460 --> 00:31:01.990
and then I'm going to solve a
new optimization problem, which

00:31:01.990 --> 00:31:08.030
is minimizing over the weights
w some loss.

00:31:08.030 --> 00:31:11.560
So this is where your training--
your new training data come in.

00:31:22.690 --> 00:31:26.740
So I'm going to assume that
the new training set D is

00:31:26.740 --> 00:31:29.140
drawn from the q distribution.

00:31:32.880 --> 00:31:38.360
And I'm going to add on a
regularization that asks that w

00:31:38.360 --> 00:31:40.870
should stay close to w old.
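
Written out, the objective is

\min_w \sum_{(x, y) \in D} \mathrm{loss}(y, w \cdot x) + \lambda \, \|w - w_{\mathrm{old}}\|_2^2

and here is a minimal sketch of it for logistic regression (the variable
names are illustrative; lam controls how close w must stay to w_old):

import numpy as np
from scipy.optimize import minimize

def transfer_logistic(X_new, y_new, w_old, lam=1.0):
    # Fit on the small new-institution data set D = (X_new, y_new), with a
    # penalty that keeps w near the source-institution weights w_old.
    def objective(w):
        z = X_new @ w
        # Numerically stable logistic loss for labels y in {0, 1}.
        nll = np.mean(np.logaddexp(0.0, z) - y_new * z)
        return nll + lam * np.sum((w - w_old) ** 2)
    return minimize(objective, x0=w_old.copy(), method="L-BFGS-B").x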

00:31:44.390 --> 00:31:47.990
Now, if the amount of
data you have-- if D,

00:31:47.990 --> 00:31:51.710
the data from that new
institution, was very large,

00:31:51.710 --> 00:31:58.510
then you wouldn't need this at
all because you would be able

00:31:58.510 --> 00:31:59.695
to just--

00:31:59.695 --> 00:32:02.590
you would be able to ignore
the classifier that you learned

00:32:02.590 --> 00:32:04.600
previously and just
refit everything

00:32:04.600 --> 00:32:06.400
to that new institution's data.

00:32:06.400 --> 00:32:08.830
Where something like this
is particularly valuable

00:32:08.830 --> 00:32:12.280
is if there was a small
amount of data set shift,

00:32:12.280 --> 00:32:15.790
and you only have a very
small amount of labeled data

00:32:15.790 --> 00:32:17.860
from that new
institution, then this

00:32:17.860 --> 00:32:20.650
would allow you to
change your weight

00:32:20.650 --> 00:32:22.280
vector just a little bit.

00:32:22.280 --> 00:32:24.340
So if this coefficient
was very large,

00:32:24.340 --> 00:32:26.530
it would say that
the new w can't

00:32:26.530 --> 00:32:28.745
be too far from the old w.

00:32:28.745 --> 00:32:31.300
So it'll allow you to
shift things a little bit

00:32:31.300 --> 00:32:35.360
in order to do well on the small
amount of data that you have.

00:32:35.360 --> 00:32:38.350
So, for example, if there is
a feature which was previously

00:32:38.350 --> 00:32:40.360
predictive, but that
feature is no longer

00:32:40.360 --> 00:32:42.450
present in the new data
set, so, for example,

00:32:42.450 --> 00:32:45.290
it's all identically zero,
then, of course, the new weight

00:32:45.290 --> 00:32:45.790
vect--

00:32:45.790 --> 00:32:48.892
the new weight for that feature
is going to be set to 0,

00:32:48.892 --> 00:32:50.350
and that weight
you can think about

00:32:50.350 --> 00:32:53.346
as being redistributed to
some of the other features.

00:32:53.346 --> 00:32:54.530
Does this make sense?

00:32:54.530 --> 00:32:55.552
Any questions?

00:32:59.000 --> 00:33:02.200
So this is the simplest
approach to transfer learning.

00:33:02.200 --> 00:33:04.510
And before you ever try
anything more complicated,

00:33:04.510 --> 00:33:05.200
always try this.

00:33:12.074 --> 00:33:13.511
Uh, yep.

00:33:16.160 --> 00:33:25.130
So the second approach is
also with a linear model,

00:33:25.130 --> 00:33:28.900
but here we're no longer going
to assume that the features are

00:33:28.900 --> 00:33:30.830
still useful.

00:33:30.830 --> 00:33:35.440
So there might--
when you go from--

00:33:35.440 --> 00:33:37.540
when you go from a--

00:33:37.540 --> 00:33:41.190
your first institution, let's
say, MGH on the left,

00:33:41.190 --> 00:33:42.940
you learn your model,
and you can apply it

00:33:42.940 --> 00:33:46.660
to some new institution,
let's say, UCSF on the right,

00:33:46.660 --> 00:33:49.960
it could be that there
is some really big change

00:33:49.960 --> 00:33:53.380
in the feature set

00:33:53.380 --> 00:33:56.110
such that the original
features are not at all

00:33:56.110 --> 00:33:59.680
useful for the new feature set.

00:33:59.680 --> 00:34:01.570
And a really extreme
example of that

00:34:01.570 --> 00:34:04.283
might be the setting that
I gave earlier when I said,

00:34:04.283 --> 00:34:06.700
your model's trained on English,
and you're testing it out

00:34:06.700 --> 00:34:08.116
in Chinese.

00:34:08.116 --> 00:34:09.199
That would be an example--

00:34:09.199 --> 00:34:10.741
if you use a bag of
words model, that

00:34:10.741 --> 00:34:14.290
would be an example where
your model, obviously,

00:34:14.290 --> 00:34:18.670
wouldn't generalize at all
because your features are

00:34:18.670 --> 00:34:21.219
completely different.

00:34:21.219 --> 00:34:23.346
So what would you
do in that setting?

00:34:23.346 --> 00:34:25.179
What's the simplest
thing that you might do?

00:34:30.300 --> 00:34:33.510
So you're taking a text
classifier learned in English,

00:34:33.510 --> 00:34:35.100
and you want to
apply it in a setting

00:34:35.100 --> 00:34:36.840
where that language is Chinese.

00:34:36.840 --> 00:34:38.424
What would you do?

00:34:38.424 --> 00:34:39.424
AUDIENCE: Train on them.

00:34:39.424 --> 00:34:41.220
DAVID SONTAG:
Translate, you said.

00:34:41.220 --> 00:34:42.645
And there was another answer.

00:34:42.645 --> 00:34:45.258
AUDIENCE: Or try to train an RNN.

00:34:45.258 --> 00:34:46.800
DAVID SONTAG: Train
an RNN to do what?

00:34:46.800 --> 00:34:48.075
AUDIENCE: To translate.

00:34:48.075 --> 00:34:49.770
DAVID SONTAG: Train
an RNN-- oh, OK.

00:34:49.770 --> 00:34:53.190
So assume that you
have some ability

00:34:53.190 --> 00:34:56.610
to do machine translation, you
translate from English to--

00:34:56.610 --> 00:34:57.610
from Chinese to English.

00:34:57.610 --> 00:35:00.068
It has to be that direction
because the original classifier

00:35:00.068 --> 00:35:01.140
was trained in English.

00:35:01.140 --> 00:35:04.770
And then your new function
is the composition

00:35:04.770 --> 00:35:08.640
of the translation and the
original function, right?

00:35:08.640 --> 00:35:11.070
And then you can
imagine doing some fine

00:35:11.070 --> 00:35:14.100
tuning if you had a
small amount of data.

00:35:14.100 --> 00:35:19.380
Now, the simplest
translation function

00:35:19.380 --> 00:35:21.070
might be to just use a dictionary.

00:35:21.070 --> 00:35:23.610
So you look up a
word, and if that word

00:35:23.610 --> 00:35:25.680
has an analogy in
another language,

00:35:25.680 --> 00:35:27.620
you say, OK, this
is the translation.

00:35:27.620 --> 00:35:30.120
But there are always going to
be some words in your language

00:35:30.120 --> 00:35:33.835
which don't have a
very good translation.

00:35:33.835 --> 00:35:36.210
And so you might imagine that
the simplest approach would

00:35:36.210 --> 00:35:38.970
be to translate, but
then to just drop out

00:35:38.970 --> 00:35:43.170
words that don't
have a good analog

00:35:43.170 --> 00:35:46.950
and force your classifier
to work with, let's say,

00:35:46.950 --> 00:35:49.752
just the shared vocabulary.

00:35:49.752 --> 00:35:51.210
Everything we're
talking about here

00:35:51.210 --> 00:35:54.340
is an example of a
manually chosen decision.

00:35:54.340 --> 00:35:56.910
So we're going to manually
choose a new representation

00:35:56.910 --> 00:36:01.740
for the data such that we
have some amount of shared

00:36:01.740 --> 00:36:05.550
features between the source
and target data sets.

00:36:08.320 --> 00:36:10.648
So let's talk about
electronic health record 1

00:36:10.648 --> 00:36:11.940
and electronic health record 2.

00:36:11.940 --> 00:36:14.340
By the way, the slides that
I'll be presenting here

00:36:14.340 --> 00:36:17.300
are from a paper
published in KDD

00:36:17.300 --> 00:36:21.570
by Jan, Tristan, your
instructor, Pete,

00:36:21.570 --> 00:36:23.580
and John Guttag.

00:36:23.580 --> 00:36:25.530
So you have
two electronic health

00:36:25.530 --> 00:36:27.210
records, electronic
health record 1,

00:36:27.210 --> 00:36:28.800
electronic health record 2.

00:36:28.800 --> 00:36:30.390
How can things change?

00:36:30.390 --> 00:36:36.900
Well, it could be that the same
concept in electronic health

00:36:36.900 --> 00:36:41.762
record 1 might be mapped
to a different encoding,

00:36:41.762 --> 00:36:43.470
so that's like an
English-to-Spanish type

00:36:43.470 --> 00:36:47.220
translation, in electronic
health record 2.

00:36:47.220 --> 00:36:48.810
Another example
of a change might

00:36:48.810 --> 00:36:55.410
be to say that some concepts
are removed, like maybe you

00:36:55.410 --> 00:36:58.260
have laboratory test results
in electronic health record 1

00:36:58.260 --> 00:37:00.270
but not in electronic
health record 2.

00:37:00.270 --> 00:37:03.120
So that's why you see
an edge to nowhere.

00:37:03.120 --> 00:37:07.925
Another change might be
there might be new concepts.

00:37:07.925 --> 00:37:10.050
So the new institution
might have new types of data

00:37:10.050 --> 00:37:12.220
that the old
institution didn't have.

00:37:12.220 --> 00:37:14.070
So what do you do
in that setting?

00:37:14.070 --> 00:37:17.880
Well, one approach
would be to say, OK, we

00:37:17.880 --> 00:37:20.160
have some small amount of
data from electronic health

00:37:20.160 --> 00:37:21.360
record 2.

00:37:21.360 --> 00:37:25.890
We could just train
using that and throw away

00:37:25.890 --> 00:37:28.900
your original data from
electronic health record 1.

00:37:28.900 --> 00:37:30.960
Now, of course, if you
only had a small amount

00:37:30.960 --> 00:37:33.758
of data from the target
distribution, then

00:37:33.758 --> 00:37:36.300
that's going to be a very poor
approach because you might not

00:37:36.300 --> 00:37:37.717
have enough data
to actually learn

00:37:37.717 --> 00:37:39.990
a reasonable enough model.

00:37:39.990 --> 00:37:41.520
A second obvious
approach would be,

00:37:41.520 --> 00:37:47.250
OK, we're going to just train
on electronic health record 1

00:37:47.250 --> 00:37:48.180
and apply it.

00:37:48.180 --> 00:37:52.930
And for those concepts that
aren't present anymore,

00:37:52.930 --> 00:37:54.330
so be it.

00:37:54.330 --> 00:37:56.028
Maybe things won't
work very well.

00:37:56.028 --> 00:37:58.320
A third approach, which we
were alluding to before when

00:37:58.320 --> 00:37:59.850
we talked about
translation, would

00:37:59.850 --> 00:38:02.430
be to learn a model just in
the intersection of the two

00:38:02.430 --> 00:38:03.780
features.

00:38:03.780 --> 00:38:06.750
And what this work
does is say,

00:38:06.750 --> 00:38:09.300
we're going to manually
redefine the feature set

00:38:09.300 --> 00:38:12.938
in order to try to find as
much common ground as possible.

00:38:12.938 --> 00:38:14.730
And this is something
which really involves

00:38:14.730 --> 00:38:17.250
a lot of domain knowledge.
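
To make that third approach concrete, here is a minimal Python sketch of training on just the shared vocabulary; the feature names and toy records are hypothetical, not the paper's actual code:

    # Minimal sketch: train only on the features common to both EHRs.
    # Feature names and toy records below are hypothetical.
    from sklearn.linear_model import LogisticRegression

    ehr1_features = {"heart_rate_high", "heparin_given", "lactate_drawn"}
    ehr2_features = {"heart_rate_high", "heparin_given", "cvp_alarm"}
    shared = sorted(ehr1_features & ehr2_features)  # the intersection

    def to_matrix(records):
        # records: list of dicts mapping feature name -> 0/1
        return [[r.get(f, 0) for f in shared] for r in records]

    ehr1_records = [{"heart_rate_high": 1, "lactate_drawn": 1},
                    {"heparin_given": 1}]
    y1 = [1, 0]                                    # labels from EHR 1 only

    model = LogisticRegression().fit(to_matrix(ehr1_records), y1)

    ehr2_records = [{"heart_rate_high": 1, "cvp_alarm": 1}]
    print(model.predict_proba(to_matrix(ehr2_records))[:, 1])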

00:38:17.250 --> 00:38:20.190
And I'm going to be using
this as a point of contrast

00:38:20.190 --> 00:38:23.550
from what I'll be talking about
in 10 or 15 minutes, where

00:38:23.550 --> 00:38:25.855
I talk about how one could
do this without that domain

00:38:25.855 --> 00:38:27.480
knowledge that we're
going to use here.

00:38:31.060 --> 00:38:33.640
So the setting
that they looked at

00:38:33.640 --> 00:38:37.720
is one of predicting outcomes,
such as in-hospital mortality

00:38:37.720 --> 00:38:40.330
or length of stay.

00:38:40.330 --> 00:38:43.840
The model that's going to be
used is a bag-of-events model.

00:38:43.840 --> 00:38:47.770
So we will take a patient's
longitudinal history up

00:38:47.770 --> 00:38:49.330
until the time of prediction.

00:38:49.330 --> 00:38:52.390
We'll look at different
events that occurred.

00:38:52.390 --> 00:38:57.280
And this study was
done using PhysioNet

00:38:57.280 --> 00:39:01.120
and MIMIC, where, for example, events
are encoded with some number,

00:39:01.120 --> 00:39:04.810
like 5814 might
correspond to a CVP alarm,

00:39:04.810 --> 00:39:07.630
1046 might correspond
to pain being present,

00:39:07.630 --> 00:39:12.620
25 might correspond to the drug
heparin being given and so on.

00:39:12.620 --> 00:39:15.580
So we're going to create one
feature for every event which

00:39:15.580 --> 00:39:16.480
is encoded

00:39:16.480 --> 00:39:18.310
with some number.

00:39:18.310 --> 00:39:21.250
And we'll just say
1 if that event

00:39:21.250 --> 00:39:22.830
has occurred, 0 otherwise.

00:39:22.830 --> 00:39:26.960
So that's the representation
for a patient.
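
As a minimal Python sketch, using the example event IDs just mentioned (5814, 1046, 25), the bag-of-events representation might look like this:

    # Minimal sketch of the bag-of-events representation.
    import numpy as np

    vocab = [5814, 1046, 25]             # event IDs seen in training
    index = {e: i for i, e in enumerate(vocab)}

    def bag_of_events(patient_events):
        """1 if an event occurred before prediction time, 0 otherwise."""
        x = np.zeros(len(vocab))
        for e in patient_events:
            if e in index:
                x[index[e]] = 1.0
        return x

    print(bag_of_events([25, 5814, 5814]))   # -> [1. 0. 1.]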

00:39:26.960 --> 00:39:30.350
Now, when one goes
to this new institution,

00:39:30.350 --> 00:39:34.790
EHR2, the way that
events are encoded

00:39:34.790 --> 00:39:36.560
might be completely different.

00:39:36.560 --> 00:39:38.810
One won't be able to just
use the original feature

00:39:38.810 --> 00:39:40.090
representation.

00:39:40.090 --> 00:39:42.020
And that's the
English-to-Spanish example

00:39:42.020 --> 00:39:43.448
that I gave.

00:39:43.448 --> 00:39:44.990
But instead, what
one could try to do

00:39:44.990 --> 00:39:48.770
is come up with a new feature
set where that feature

00:39:48.770 --> 00:39:53.250
set could be derived from each
of the different data sets.

00:39:53.250 --> 00:39:57.680
So, for example, since each
one of the events in MIMIC

00:39:57.680 --> 00:40:00.140
has some text
description that goes

00:40:00.140 --> 00:40:03.350
with it, event one corresponds
to ischemic stroke,

00:40:03.350 --> 00:40:06.920
event 2, hemorrhagic
stroke, and so on,

00:40:06.920 --> 00:40:08.420
one could attempt to

00:40:08.420 --> 00:40:12.560
use that English
description of the feature

00:40:12.560 --> 00:40:15.350
to come up with a way to map
it into a common language.

00:40:15.350 --> 00:40:17.120
In this case, the
common language

00:40:17.120 --> 00:40:20.990
is the UMLS, the
Unified Medical Language

00:40:20.990 --> 00:40:23.700
System that Pete talked
about a few lectures ago.

00:40:23.700 --> 00:40:26.750
So we're going to now say, OK,
we have a much larger feature

00:40:26.750 --> 00:40:31.580
set where we've now
encoded ischemic stroke

00:40:31.580 --> 00:40:34.700
as this concept,
which is actually

00:40:34.700 --> 00:40:36.920
the same ischemic
stroke, but also

00:40:36.920 --> 00:40:40.460
as this concept
and that concept,

00:40:40.460 --> 00:40:43.850
which are more general
versions of that original one.

00:40:43.850 --> 00:40:46.310
So this is just
general stroke, and it

00:40:46.310 --> 00:40:49.100
could be multiple
different types of strokes.

00:40:49.100 --> 00:40:53.510
And the hope is

00:40:53.510 --> 00:40:55.970
that even if

00:40:55.970 --> 00:40:57.680
some of these
more specific ones

00:40:57.680 --> 00:40:59.900
don't show up in the
new institution's data,

00:40:59.900 --> 00:41:04.432
perhaps some of the more general
concepts do show up there.

00:41:04.432 --> 00:41:05.890
And then what you're
going to do is

00:41:05.890 --> 00:41:11.570
you're going to learn your model
now on this expanded translated

00:41:11.570 --> 00:41:14.923
vocabulary, and
then transfer it.

00:41:14.923 --> 00:41:16.340
And at the new
institution, you'll

00:41:16.340 --> 00:41:18.600
also be using that
same common data model.

00:41:18.600 --> 00:41:20.810
And that way one hopes
to have much more overlap

00:41:20.810 --> 00:41:23.540
in your feature set.
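
A minimal sketch of that expansion step, with placeholder concept IDs (C0, C1, C2 stand in for real UMLS CUIs, and the hierarchy is illustrative):

    # Minimal sketch: expand each event's text description into UMLS
    # concepts, including more general ancestor concepts.
    umls_map = {"ischemic stroke": ["C1"],      # placeholder CUIs
                "hemorrhagic stroke": ["C2"]}
    ancestors = {"C1": ["C0"],                  # C0 = general "stroke"
                 "C2": ["C0"]}

    def translate(description):
        cuis = set(umls_map.get(description, []))
        for c in list(cuis):
            cuis.update(ancestors.get(c, []))
        return cuis

    # The two specific events share the general concept C0, so a model
    # trained on this expanded vocabulary can still use that feature at
    # an institution whose data lacks the specific codes.
    print(translate("ischemic stroke"))         # {'C1', 'C0'}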

00:41:23.540 --> 00:41:27.770
And so to evaluate
this, the authors

00:41:27.770 --> 00:41:32.210
looked at two different
time points within MIMIC.

00:41:32.210 --> 00:41:36.350
One time point was when the
Beth Israel Deaconess Medical

00:41:36.350 --> 00:41:39.320
Center was using an electronic
health record called CareVue.

00:41:39.320 --> 00:41:41.655
And the second time point
was when that hospital

00:41:41.655 --> 00:41:43.280
was using a different
electronic health

00:41:43.280 --> 00:41:45.560
record called MetaVision.

00:41:45.560 --> 00:41:49.670
So this is an example
actually of non-stationarity.

00:41:49.670 --> 00:41:54.200
Now, because the hospital used two
different electronic health

00:41:54.200 --> 00:41:55.790
records, the encodings
were different.

00:41:55.790 --> 00:41:58.733
And that's why
this problem arose.

00:41:58.733 --> 00:42:00.400
And so we're going
to use this approach,

00:42:00.400 --> 00:42:02.270
and we're going to then
learn a linear model

00:42:02.270 --> 00:42:06.410
on top of this new encoding
that I just described.

00:42:06.410 --> 00:42:12.140
And we're going to compare the
results by looking at how much

00:42:12.140 --> 00:42:16.790
performance was lost due
to using this new encoding,

00:42:16.790 --> 00:42:19.110
and how well we
generalize

00:42:19.110 --> 00:42:26.111
from the source
task to the target task.

00:42:26.111 --> 00:42:28.170
And so here's the
first question,

00:42:28.170 --> 00:42:32.330
which is, how much do we lose
by using this new encoding?

00:42:32.330 --> 00:42:34.220
So as a comparison
point for looking

00:42:34.220 --> 00:42:36.973
at predicting in-hospital
mortality, we'll look at,

00:42:36.973 --> 00:42:38.390
what is the
predictive performance

00:42:38.390 --> 00:42:42.920
if you were to just use an
existing, very simple risk

00:42:42.920 --> 00:42:45.050
score called the SAPS score?

00:42:45.050 --> 00:42:48.260
And that's this red line
where that y-axis here

00:42:48.260 --> 00:42:51.500
is the area under the
ROC curve, and the x-axis

00:42:51.500 --> 00:42:53.480
is how much time
in advance you're

00:42:53.480 --> 00:42:56.090
predicting, so the
prediction gap.

00:42:56.090 --> 00:43:01.520
So using this very simple score,
SAPS gets somewhere between 0.75

00:43:01.520 --> 00:43:04.260
and 0.80, area
under the ROC curve.

00:43:04.260 --> 00:43:08.900
But if you were to use all
of the events data, which

00:43:08.900 --> 00:43:11.360
is much, much richer than what
went into that simple SAPS

00:43:11.360 --> 00:43:16.310
score, you would get the

00:43:16.310 --> 00:43:20.180
purple curve, which is
SAPS plus the event data,

00:43:20.180 --> 00:43:22.372
or the blue curve, which
is just the events data.

00:43:22.372 --> 00:43:24.080
And you can see you
can get substantially

00:43:24.080 --> 00:43:25.670
better predictive
performance by using

00:43:25.670 --> 00:43:28.638
that much richer feature set.

00:43:28.638 --> 00:43:30.680
The SAPS score has the
advantage that it's easier

00:43:30.680 --> 00:43:34.580
to generalize: because it's so
simple, those feature elements

00:43:34.580 --> 00:43:38.600
could be trivially translated
to any new EHR, either manually

00:43:38.600 --> 00:43:43.220
or automatically, and thus
it'll always be a viable route.

00:43:43.220 --> 00:43:45.230
Whereas this blue
curve, although it

00:43:45.230 --> 00:43:46.780
gets better predictive
performance,

00:43:46.780 --> 00:43:49.490
you have to really worry about
these generalization questions.

00:43:52.689 --> 00:43:56.030
And the same story happens
in both the source

00:43:56.030 --> 00:43:58.160
task and the target task.

00:43:58.160 --> 00:44:00.980
Now the second question
to ask is, well,

00:44:00.980 --> 00:44:03.650
how much do you lose when you
use the new representation

00:44:03.650 --> 00:44:05.520
of the data?

00:44:05.520 --> 00:44:09.290
And so here, looking
again at both EHRs,

00:44:09.290 --> 00:44:13.790
what we
see first in red

00:44:13.790 --> 00:44:17.733
is the same as the blue curve

00:44:17.733 --> 00:44:18.650
on the previous slide.

00:44:18.650 --> 00:44:23.550
It's using SAPS plus the item
IDs, so using all of the data.

00:44:23.550 --> 00:44:26.270
And then the blue curve here,
which is a bit hard to see,

00:44:26.270 --> 00:44:29.150
but it's right there,
it's substantially lower.

00:44:29.150 --> 00:44:31.360
So that's what
happens if you now

00:44:31.360 --> 00:44:33.760
use this new representation.

00:44:33.760 --> 00:44:36.340
And you see that you
do lose something

00:44:36.340 --> 00:44:39.940
by trying to find a
common vocabulary.

00:44:39.940 --> 00:44:44.780
The performance
does get hit a bit.

00:44:44.780 --> 00:44:46.570
But what's particularly
interesting is

00:44:46.570 --> 00:44:52.450
when you attempt to generalize,
you start to see a swap.

00:44:52.450 --> 00:44:54.970
So now--

00:44:54.970 --> 00:44:59.860
the colors are
going to be quite similar.

00:44:59.860 --> 00:45:02.930
Red here was at the
very top before.

00:45:02.930 --> 00:45:07.480
So red is using the original
representation of the data.

00:45:07.480 --> 00:45:10.600
Before it was at the very top.

00:45:10.600 --> 00:45:15.640
Shown here is the training error
on this institution, CareVue.

00:45:15.640 --> 00:45:17.938
You see, there's so
much rich information

00:45:17.938 --> 00:45:19.480
in the original
feature set that it's

00:45:19.480 --> 00:45:21.313
able to get very good
predictive performance.

00:45:21.313 --> 00:45:24.370
But once you attempt
to translate it,

00:45:24.370 --> 00:45:28.270
so you train on CareVue,
but you test on MetaVision,

00:45:28.270 --> 00:45:32.170
then the test performance shown
here by this solid red line

00:45:32.170 --> 00:45:34.190
is actually the worst
of all of the systems.

00:45:34.190 --> 00:45:36.550
So there's a substantial
drop in performance

00:45:36.550 --> 00:45:39.070
because not all
of these features

00:45:39.070 --> 00:45:41.230
are present in the new EHR.

00:45:41.230 --> 00:45:44.680
On the other hand,
the translated version,

00:45:44.680 --> 00:45:49.540
despite the fact that it's
a little bit worse when

00:45:49.540 --> 00:45:52.930
evaluated on the source,
generalizes much better.

00:45:52.930 --> 00:45:56.170
And so you see a significantly
better performance

00:45:56.170 --> 00:45:59.320
that's shown by this
blue curve here when you

00:45:59.320 --> 00:46:01.048
use this translated vocabulary.

00:46:01.048 --> 00:46:01.840
There's a question.

00:46:01.840 --> 00:46:04.384
AUDIENCE: So would you
train with full features?

00:46:04.384 --> 00:46:08.430
So how do you apply [? with ?]
them if the other [? full ?]

00:46:08.430 --> 00:46:10.810
features are-- you
just [INAUDIBLE]..

00:46:10.810 --> 00:46:14.860
DAVID SONTAG: So, you
assume that you have come up

00:46:14.860 --> 00:46:19.480
with a mapping from the
features in both of the EHRs

00:46:19.480 --> 00:46:23.995
to this common feature
vocabulary of CUIs.

00:46:23.995 --> 00:46:26.620
And the way that this mapping is
going to be done in this paper

00:46:26.620 --> 00:46:29.188
is based on the text

00:46:32.980 --> 00:46:34.480
of the events.

00:46:34.480 --> 00:46:36.700
So you take the text-based
description of the event,

00:46:36.700 --> 00:46:38.533
and you come up with a
deterministic mapping

00:46:38.533 --> 00:46:43.090
to this new UMLS-based
representation.

00:46:43.090 --> 00:46:44.630
And then that's
what's being used.

00:46:44.630 --> 00:46:46.075
There's no fine
tuning being done

00:46:46.075 --> 00:46:47.200
in this particular example.

00:46:51.110 --> 00:46:56.530
So I consider this to be a very
naive application of transfer.

00:46:56.530 --> 00:46:59.770
The results are exactly what you
would expect the results to be.

00:46:59.770 --> 00:47:03.820
And, obviously, a lot of work
had to go into doing this.

00:47:03.820 --> 00:47:06.565
And there's a bit of creativity
in thinking that you should use

00:47:06.565 --> 00:47:08.440
the English-based
description of the features

00:47:08.440 --> 00:47:10.023
to come up with the
automatic mapping,

00:47:10.023 --> 00:47:13.250
but the story ends there.

00:47:13.250 --> 00:47:16.480
And so a question
which all of you

00:47:16.480 --> 00:47:18.520
might have is, how
could you try to do

00:47:18.520 --> 00:47:20.500
such an approach automatically?

00:47:20.500 --> 00:47:23.020
How could we automatically
find representations-- new

00:47:23.020 --> 00:47:24.520
representations of
the data that are

00:47:24.520 --> 00:47:26.470
likely to generalize
from, let's say,

00:47:26.470 --> 00:47:29.970
a source distribution to
a target distribution?

00:47:29.970 --> 00:47:31.630
And so to talk about
that, we're going

00:47:31.630 --> 00:47:34.270
to now start thinking
through representation

00:47:34.270 --> 00:47:37.060
learning-based approaches,
of which deep models are

00:47:37.060 --> 00:47:39.730
particularly capable.

00:47:39.730 --> 00:47:43.930
So the simplest approach to
try to do transfer learning

00:47:43.930 --> 00:47:47.860
in the context of, let's
say, deep neural networks,

00:47:47.860 --> 00:47:52.330
would be to just chop off part
of the network and reuse that--

00:47:52.330 --> 00:47:56.810
some internal representation of
the data in this new location.

00:47:56.810 --> 00:47:59.990
So the picture looks a
little bit like this.

00:47:59.990 --> 00:48:02.600
So the data might
feed in the bottom.

00:48:02.600 --> 00:48:04.600
There might be a number
of convolutional layers,

00:48:04.600 --> 00:48:05.782
some fully connected layers.

00:48:05.782 --> 00:48:07.240
And what you decide
to do is you're

00:48:07.240 --> 00:48:10.660
going to take this model that's
trained in one institution,

00:48:10.660 --> 00:48:14.930
you chop it at some layer,
it might be, for example,

00:48:14.930 --> 00:48:17.920
prior to the last
fully connected layer,

00:48:17.920 --> 00:48:20.890
and then you're
going to take that--

00:48:20.890 --> 00:48:23.690
take the new representation
of your data,

00:48:23.690 --> 00:48:25.210
now the representation
of the data

00:48:25.210 --> 00:48:29.590
is what you would get out
after doing some convolutions

00:48:29.590 --> 00:48:32.020
followed by a single
fully connected layer,

00:48:32.020 --> 00:48:36.160
and then you're going to take
your target distribution's

00:48:36.160 --> 00:48:38.740
data, which you might only
have a small amount of,

00:48:38.740 --> 00:48:41.660
and you learn a simple model on
top of that new representation.

00:48:41.660 --> 00:48:43.570
So, for example, you might
learn a shallow classifier

00:48:43.570 --> 00:48:45.112
using a support
vector machine on top

00:48:45.112 --> 00:48:46.270
of that new representation.

00:48:46.270 --> 00:48:50.160
Or you might add in

00:48:50.160 --> 00:48:52.660
a couple more layers of a deep
neural network, and then fine

00:48:52.660 --> 00:48:54.220
tune the whole thing end to end.

00:48:54.220 --> 00:48:56.410
So all of these have been tried.

00:48:56.410 --> 00:49:00.050
And in some cases, one
works better than another.
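
As a minimal PyTorch sketch of the chop-and-reuse idea (the backbone, layer choice, and sizes here are illustrative, not from any particular paper):

    # Minimal sketch: reuse a pretrained CNN's internal representation.
    import torch
    import torch.nn as nn
    from torchvision import models

    backbone = models.resnet18(weights="IMAGENET1K_V1")
    backbone.fc = nn.Identity()         # chop off the final classifier

    for p in backbone.parameters():     # option 1: freeze the backbone and
        p.requires_grad = False         # train only a shallow head on top

    head = nn.Linear(512, 2)            # new task-specific layer
    optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)

    x = torch.randn(4, 3, 224, 224)     # a dummy batch of target data
    logits = head(backbone(x))          # (4, 2) predictions

    # Option 2: leave requires_grad=True and fine-tune end to end,
    # typically with a smaller learning rate for the backbone.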

00:49:00.050 --> 00:49:05.590
And we saw already one example
of this notion in this course.

00:49:05.590 --> 00:49:09.700
And that was when Adam
Yala spoke in lecture 13

00:49:09.700 --> 00:49:14.440
about breast cancer
and mammography,

00:49:14.440 --> 00:49:19.660
where in his approach he
said that he had tried

00:49:19.660 --> 00:49:25.030
taking a randomly
initialized classifier

00:49:25.030 --> 00:49:28.030
and comparing that to what
would happen if you initialized

00:49:28.030 --> 00:49:32.560
with a well-known
ImageNet-based deep neural

00:49:32.560 --> 00:49:34.360
network for the problem.

00:49:34.360 --> 00:49:37.150
And he had a really
interesting story that he gave.

00:49:37.150 --> 00:49:42.190
In his case, he had enough
data that he actually

00:49:42.190 --> 00:49:47.080
didn't need to initialize
using this pre-trained model

00:49:47.080 --> 00:49:48.250
from ImageNet.

00:49:48.250 --> 00:49:52.060
If he had just done a random
initialization, eventually--

00:49:52.060 --> 00:49:53.590
and this x-axis,
I can't remember,

00:49:53.590 --> 00:49:57.850
it might be hours of training
or epochs, but

00:49:57.850 --> 00:49:58.600
it's time--

00:49:58.600 --> 00:50:00.160
eventually the
right initialization

00:50:00.160 --> 00:50:02.170
gets to a very
similar performance.

00:50:02.170 --> 00:50:04.540
But for his particular
case, if you

00:50:04.540 --> 00:50:08.740
were to do an initialization with
ImageNet and then fine tune,

00:50:08.740 --> 00:50:10.940
you get there
much, much quicker.

00:50:10.940 --> 00:50:13.090
And so it was for the
computational reason

00:50:13.090 --> 00:50:14.870
that he found it to be useful.

00:50:14.870 --> 00:50:17.290
But in many other applications
in medical imaging,

00:50:17.290 --> 00:50:19.660
the same tricks become
essential because you just

00:50:19.660 --> 00:50:22.020
don't have enough data
in the new test case.

00:50:22.020 --> 00:50:25.660
And so one makes use of,
for example, the filters

00:50:25.660 --> 00:50:29.170
which one learns from an
ImageNet task, which

00:50:29.170 --> 00:50:33.010
is dramatically different from
the medical imaging problems,

00:50:33.010 --> 00:50:34.870
and then using those
same filters together

00:50:34.870 --> 00:50:37.330
with a new top layer,
or set of top layers,

00:50:37.330 --> 00:50:41.265
in order to fine tune it for
the problem that you care about.

00:50:41.265 --> 00:50:42.640
So this would be
the simplest way

00:50:42.640 --> 00:50:46.990
to try to hope for a common
representation for transfer

00:50:46.990 --> 00:50:49.240
in a deep architecture.

00:50:49.240 --> 00:50:52.480
But you might ask, how would
you do the same sort of thing

00:50:52.480 --> 00:50:56.200
with temporal data, not
image data, maybe data

00:50:56.200 --> 00:51:00.010
that's from language, or data
from time series of health

00:51:00.010 --> 00:51:01.113
insurance claims?

00:51:01.113 --> 00:51:02.530
And for that you
really want to be

00:51:02.530 --> 00:51:05.420
thinking about recurrent
neural networks.

00:51:05.420 --> 00:51:08.050
So just to remind you,
recurrent neural network

00:51:08.050 --> 00:51:10.030
is a recurrent
architecture where

00:51:10.030 --> 00:51:11.852
you take as input some vector.

00:51:11.852 --> 00:51:13.810
For example, if you're
doing language modeling,

00:51:13.810 --> 00:51:16.393
that vector might be encoding,
just a one-hot encoding of what

00:51:16.393 --> 00:51:17.810
is the word at that location.

00:51:17.810 --> 00:51:20.027
So, for example, this
vector might be all zeros,

00:51:20.027 --> 00:51:21.610
except for the fourth
dimension, which

00:51:21.610 --> 00:51:24.970
is a 1, denoting that this word
is the word, quote, "class."

00:51:24.970 --> 00:51:28.780
And then it's fed into
a recurrent unit, which

00:51:28.780 --> 00:51:32.117
takes the previous
hidden state, combines it

00:51:32.117 --> 00:51:34.450
with the current input, and
gets you a new hidden state.

00:51:34.450 --> 00:51:37.990
And in this way, you read in--
you encode the full input.

00:51:37.990 --> 00:51:39.580
And then you might predict--

00:51:39.580 --> 00:51:41.680
make a classification
based on the hidden state

00:51:41.680 --> 00:51:43.055
of the last time
[? step. ?] That

00:51:43.055 --> 00:51:44.720
would be a common approach.

00:51:44.720 --> 00:51:47.800
And here would be a very simple
example of a recurrent unit.

00:51:47.800 --> 00:51:49.750
Here I'm using S to
denote the state.

00:51:49.750 --> 00:51:52.780
Often you will see H used
to denote the hidden state.

00:51:52.780 --> 00:51:54.400
This is a particularly
simple example,

00:51:54.400 --> 00:51:56.560
where there's just a
single non-linearity.

00:51:56.560 --> 00:51:58.630
So you take your
previous hidden state,

00:51:58.630 --> 00:52:06.130
you hit it with some matrix Ws,s
and you add that to the input

00:52:06.130 --> 00:52:08.770
being hit by a different matrix.

00:52:08.770 --> 00:52:11.500
You now have a
combination of the input

00:52:11.500 --> 00:52:12.833
plus the previous hidden state.

00:52:12.833 --> 00:52:14.500
You apply non-linearity
to that, and you

00:52:14.500 --> 00:52:15.750
get your new hidden state out.

00:52:15.750 --> 00:52:18.940
So that would be an example
of a typical recurrent unit,

00:52:18.940 --> 00:52:20.950
a very simple recurrent unit.
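
In symbols, the unit just described is s_t = tanh(Ws,s s_{t-1} + Ws,x x_t); here is a minimal numpy sketch (the sizes are illustrative):

    # Minimal sketch of the simple recurrent unit described above.
    import numpy as np

    V, H = 10000, 128                    # vocabulary size, hidden size
    W_ss = np.random.randn(H, H) * 0.01
    W_sx = np.random.randn(H, V) * 0.01  # H x V: huge when V is huge

    def step(s_prev, x_t):
        return np.tanh(W_ss @ s_prev + W_sx @ x_t)

    s = np.zeros(H)
    for word_id in [4, 17, 42]:          # a toy sentence of word indices
        x = np.zeros(V)
        x[word_id] = 1.0                 # one-hot encoding of the word
        s = step(s, x)                   # s encodes the prefix read so far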

00:52:20.950 --> 00:52:23.200
Now, the reason why I'm going
through these details is

00:52:23.200 --> 00:52:28.570
to point out that the dimension
of that Ws,x matrix is

00:52:28.570 --> 00:52:32.164
the dimension of the hidden
state, so the dimension of s,

00:52:32.164 --> 00:52:36.430
by the vocabulary size if
you're using a one-hot encoding

00:52:36.430 --> 00:52:38.060
of the input.

00:52:38.060 --> 00:52:42.850
So if you have a huge
vocabulary, that matrix, Ws,x,

00:52:42.850 --> 00:52:45.250
is also going to
be equally large.

00:52:45.250 --> 00:52:47.110
And the challenge
that that presents

00:52:47.110 --> 00:52:52.360
is that it would lead to
overfitting on rare words

00:52:52.360 --> 00:52:54.560
very quickly.

00:52:54.560 --> 00:52:57.220
And so that's a problem that
could be addressed by instead

00:52:57.220 --> 00:53:03.400
using a low-rank representation
of that Ws,x matrix.

00:53:03.400 --> 00:53:06.880
In particular, you could
think about introducing

00:53:06.880 --> 00:53:11.620
a lower dimensional bottleneck,
which in this picture

00:53:11.620 --> 00:53:17.740
I'm denoting as xt prime,
which is your original xt

00:53:17.740 --> 00:53:19.930
input, which is the
one-hot encoding,

00:53:19.930 --> 00:53:21.970
multiplied by a new matrix We.

00:53:24.550 --> 00:53:28.360
And then your recurrent
unit only takes

00:53:28.360 --> 00:53:30.220
inputs

00:53:30.220 --> 00:53:34.840
of that xt prime's
dimension, which

00:53:34.840 --> 00:53:39.340
is k, which might be
dramatically smaller than v.

00:53:39.340 --> 00:53:41.290
And you can even think
about each column

00:53:41.290 --> 00:53:44.590
of that intermediate
representation, We,

00:53:44.590 --> 00:53:46.600
as a word embedding.
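
A minimal sketch of that bottleneck, continuing the numpy example above (the value of k is illustrative):

    # Minimal sketch: factor the H x V input matrix through a
    # k-dimensional embedding, so rare words share statistical strength.
    import numpy as np

    V, H, k = 10000, 128, 50
    W_e = np.random.randn(k, V) * 0.01        # columns are word embeddings
    W_sx_small = np.random.randn(H, k) * 0.01
    W_ss = np.random.randn(H, H) * 0.01

    def step(s_prev, x_t):
        x_prime = W_e @ x_t                   # k-dimensional bottleneck
        return np.tanh(W_ss @ s_prev + W_sx_small @ x_prime)

    # H*k + k*V parameters instead of H*V.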

00:53:46.600 --> 00:53:49.180
It's a way of--

00:53:49.180 --> 00:53:51.363
and this is something that
Pete talked quite a bit

00:53:51.363 --> 00:53:53.530
about when we were

00:53:53.530 --> 00:53:56.110
talking about
natural language processing.

00:53:56.110 --> 00:53:58.690
And many of you would
have heard about it

00:53:58.690 --> 00:54:02.470
in the context of
things like Word2Vec.

00:54:02.470 --> 00:54:08.650
So if one wanted to take
a setting, for example,

00:54:08.650 --> 00:54:14.335
one institution's data where
you had a huge amount of data,

00:54:14.335 --> 00:54:16.690
learn a recurrent neural
network on that institution's

00:54:16.690 --> 00:54:19.510
data, and then generalize
it to a new institution,

00:54:19.510 --> 00:54:22.630
one way of trying to do
that, if you think about,

00:54:22.630 --> 00:54:25.958
what is the thing that you chop,
one answer might be, all you do

00:54:25.958 --> 00:54:27.250
is you keep the word embedding.

00:54:27.250 --> 00:54:28.875
So you might say,
OK, I'm going to keep

00:54:28.875 --> 00:54:32.980
the We's, I'm going to translate
it back to my new institution.

00:54:32.980 --> 00:54:35.830
But I'm going to let the
recurrent parameters--

00:54:35.830 --> 00:54:37.490
for example,
that Ws,s--

00:54:37.490 --> 00:54:41.380
be relearned for each new

00:54:41.380 --> 00:54:43.035
institution.

00:54:43.035 --> 00:54:44.410
And so that might
be one approach

00:54:44.410 --> 00:54:46.720
of how to use the
same idea that we

00:54:46.720 --> 00:54:53.530
had from feed forward networks
within a recurrent setting.
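
A minimal PyTorch sketch of that choice, with a random tensor standing in for the embeddings learned at the source institution:

    # Minimal sketch: transfer the word embedding, re-learn the recurrence.
    import torch
    import torch.nn as nn

    V, k, H = 10000, 50, 128
    source_We = torch.randn(V, k)         # stands in for the trained W_e

    emb = nn.Embedding(V, k)
    emb.weight.data.copy_(source_We)      # keep the source embeddings
    emb.weight.requires_grad = False      # and freeze them

    rnn = nn.RNN(input_size=k, hidden_size=H, batch_first=True)
    optimizer = torch.optim.Adam(rnn.parameters(), lr=1e-3)
    # Only the recurrent parameters are re-learned at the new institution.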

00:54:53.530 --> 00:54:57.110
Now, all of this
is very general.

00:54:57.110 --> 00:54:59.890
And what I want to do
next is to instantiate it

00:54:59.890 --> 00:55:05.080
a bit in the context
of health care.

00:55:05.080 --> 00:55:09.610
So since the time
that Pete presented

00:55:09.610 --> 00:55:15.190
the extensions of Word2Vec
such as BERT and ELMo,

00:55:15.190 --> 00:55:16.120
and I'm not going

00:55:16.120 --> 00:55:17.537
to
go into them now,

00:55:17.537 --> 00:55:20.290
but you can go back to Pete's
lecture from a few weeks

00:55:20.290 --> 00:55:22.975
ago to remind yourselves what
those were, since the time

00:55:22.975 --> 00:55:24.850
he presented that lecture,
there are actually

00:55:24.850 --> 00:55:26.650
three new papers
that tried

00:55:26.650 --> 00:55:30.370
to apply this in the health
care context, one of which

00:55:30.370 --> 00:55:32.680
was from MIT.

00:55:32.680 --> 00:55:36.490
And so these papers all
have the same sort of idea.

00:55:36.490 --> 00:55:39.640
They're going to
take some data set--

00:55:39.640 --> 00:55:43.480
and these papers all use MIMIC.

00:55:43.480 --> 00:55:45.370
They're going to
take that text data,

00:55:45.370 --> 00:55:48.850
they're going to learn
some word embeddings

00:55:48.850 --> 00:55:50.500
or some low-dimensional
representations

00:55:50.500 --> 00:55:52.300
of all words in the vocabulary.

00:55:52.300 --> 00:55:54.460
In this case,
they're not learning

00:55:54.460 --> 00:55:56.290
a static representation
for each word.

00:55:56.290 --> 00:55:59.140
Instead these BERT
and ELMo approaches

00:55:59.140 --> 00:56:00.640
are going to be
learning-- well, you

00:56:00.640 --> 00:56:02.330
can think of it as
dynamic representations.

00:56:02.330 --> 00:56:04.080
They're going to be a
function of the word

00:56:04.080 --> 00:56:06.713
and their context on the
left and right-hand sides.

00:56:06.713 --> 00:56:08.380
And then what they'll
do is they'll then

00:56:08.380 --> 00:56:10.930
take those representations
and attempt to use them

00:56:10.930 --> 00:56:13.120
for a completely new task.

00:56:13.120 --> 00:56:17.210
Those new tasks might
be on MIMIC data.

00:56:17.210 --> 00:56:20.410
So, for example, these two tasks
are classification problems

00:56:20.410 --> 00:56:21.310
on MIMIC.

00:56:21.310 --> 00:56:23.210
But they might also
be on non-MIMIC data.

00:56:23.210 --> 00:56:27.490
So these two tasks are from
classification problems

00:56:27.490 --> 00:56:30.830
on clinical text that didn't
even come from MIMIC at all.

00:56:30.830 --> 00:56:32.578
So it's really an
example of translating

00:56:32.578 --> 00:56:34.120
what you learned
from one institution

00:56:34.120 --> 00:56:35.650
to another institution.

00:56:35.650 --> 00:56:37.450
These two data sets
were super small.

00:56:37.450 --> 00:56:40.475
Actually, all of these data
sets were really, really small

00:56:40.475 --> 00:56:42.100
compared to the
original size of MIMIC.

00:56:42.100 --> 00:56:44.725
So there might be some hope that
one could learn something that

00:56:44.725 --> 00:56:46.660
really improves generalization.

00:56:46.660 --> 00:56:48.890
And indeed, that's
what plays out.

00:56:48.890 --> 00:56:53.450
So all these tasks are looking
at a concept detection task.

00:56:53.450 --> 00:56:59.240
Given a clinical note,
identify the segments of text

00:56:59.240 --> 00:57:01.280
within a note that
refer to, for example,

00:57:01.280 --> 00:57:04.280
a disorder, or a treatment,
or something else, which

00:57:04.280 --> 00:57:08.030
you then in a second stage
might normalize to the UMLS.

00:57:10.790 --> 00:57:13.940
So what's really striking
about these results

00:57:13.940 --> 00:57:18.590
is what happens when you go
from the left to the right

00:57:18.590 --> 00:57:20.280
column, which I'll
explain in a second,

00:57:20.280 --> 00:57:22.520
and what happens when
you go top to bottom

00:57:22.520 --> 00:57:24.810
across each one of
these different tasks.

00:57:24.810 --> 00:57:27.630
So the left column
are the results.

00:57:27.630 --> 00:57:33.230
And these results are
an F score

00:57:33.230 --> 00:57:39.365
if you were to use embeddings
trained on a non-clinical data

00:57:39.365 --> 00:57:42.260
set, or, said differently, not
on MIMIC but on some other more

00:57:42.260 --> 00:57:44.427
general data set.

00:57:44.427 --> 00:57:46.010
The second column
is what would happen

00:57:46.010 --> 00:57:49.730
if you trained those embeddings
on a clinical data set,

00:57:49.730 --> 00:57:51.440
in this case, MIMIC.

00:57:51.440 --> 00:57:54.230
And you see pretty
big improvements

00:57:54.230 --> 00:57:58.550
from the general embeddings
to the MIMIC-based embeddings.

00:57:58.550 --> 00:58:01.040
What's even more striking
is the improvements

00:58:01.040 --> 00:58:04.190
that happen as you get
better and better embeddings.

00:58:04.190 --> 00:58:07.040
So the first row are
the results if you

00:58:07.040 --> 00:58:09.380
were to use just
Word2Vec embeddings.

00:58:09.380 --> 00:58:15.470
And so, for example, for
the I2B2 Challenge in 2010,

00:58:15.470 --> 00:58:20.720
you get 82.65 F score
using Word2Vec embeddings.

00:58:20.720 --> 00:58:23.300
And if you use a very
large BERT embedding,

00:58:23.300 --> 00:58:28.010
you get a 90.25

00:58:28.010 --> 00:58:31.850
F measure, which is
substantially higher.

00:58:31.850 --> 00:58:34.200
And the same findings were
found time and time again

00:58:34.200 --> 00:58:36.810
across different tasks.

00:58:36.810 --> 00:58:39.740
Now, what I find really
striking about these results

00:58:39.740 --> 00:58:43.160
is that I had tried many of
these things a couple of years

00:58:43.160 --> 00:58:46.040
ago, not using BERT or
ELMo, but using Word2Vec,

00:58:46.040 --> 00:58:48.320
and GloVe, and fastText.

00:58:48.320 --> 00:58:52.550
And what I found is that using
word embedding approaches

00:58:52.550 --> 00:58:54.530
for these problems,

00:58:54.530 --> 00:58:57.440
even if you threw them in as
additional features on top

00:58:57.440 --> 00:59:03.110
of other state-of-the-art
approaches to this concept

00:59:03.110 --> 00:59:06.470
extraction problem, did not
improve predictive performance

00:59:06.470 --> 00:59:09.050
above the existing
state of the art.

00:59:09.050 --> 00:59:11.240
However, in this
paper, here they

00:59:11.240 --> 00:59:13.820
use the simplest
possible algorithm.

00:59:13.820 --> 00:59:15.710
They used a recurrent
neural network

00:59:15.710 --> 00:59:18.080
fed into a conditional
random field

00:59:18.080 --> 00:59:21.680
for the purpose of classifying
each word into each

00:59:21.680 --> 00:59:23.240
of these categories.

00:59:23.240 --> 00:59:25.160
And the
features

00:59:25.160 --> 00:59:28.320
that they used are just
these embedding features.

00:59:28.320 --> 00:59:30.370
So with just the Word2Vec
embedding features,

00:59:30.370 --> 00:59:31.370
the performance is crap.

00:59:31.370 --> 00:59:33.740
You don't get anywhere
close to the state of art.

00:59:33.740 --> 00:59:37.070
But with the better embeddings,
they

00:59:37.070 --> 00:59:39.440
actually improved
on the state of the art

00:59:39.440 --> 00:59:43.560
for every single
one of these tasks.

00:59:43.560 --> 00:59:46.010
And that is without any
of the manual feature

00:59:46.010 --> 00:59:48.110
engineering which
we have been using

00:59:48.110 --> 00:59:50.670
in the field for
the last decade.

00:59:50.670 --> 00:59:54.620
So I find this to be
extremely promising.

00:59:54.620 --> 00:59:59.090
Now you might ask, well, that
is for one problem, which

00:59:59.090 --> 01:00:04.700
is classification of concepts--
or identification of concepts.

01:00:04.700 --> 01:00:06.900
What about for a
predictive problem?

01:00:06.900 --> 01:00:09.500
So a different paper
also published--

01:00:09.500 --> 01:00:13.670
what month is it now, May--
so last month in April,

01:00:13.670 --> 01:00:16.700
looked at a prediction problem
of 30-day readmission

01:00:16.700 --> 01:00:18.100
using discharge summaries.

01:00:18.100 --> 01:00:20.600
This also was evaluated on MIMIC.

01:00:20.600 --> 01:00:23.960
And their evaluation
looked at the area

01:00:23.960 --> 01:00:26.610
under the ROC curve of
two different approaches.

01:00:26.610 --> 01:00:29.270
The first approach, which is
using a bag-of-words model,

01:00:29.270 --> 01:00:32.090
like what you did in
your homework assignment,

01:00:32.090 --> 01:00:35.360
and the second approach,
which is the top row there,

01:00:35.360 --> 01:00:40.640
which is using BERT embeddings,
which they call Clinical BERT.

01:00:40.640 --> 01:00:43.340
And this, again, is
something which I had

01:00:43.340 --> 01:00:44.760
tackled for quite a long time.

01:00:44.760 --> 01:00:46.970
So I worked on these types
of readmission problems.

01:00:46.970 --> 01:00:48.887
And a bag-of-words model
is really hard to beat.

01:00:48.887 --> 01:00:54.090
In fact, did any of you beat
it in your homework assignment?

01:00:54.090 --> 01:00:56.070
If you remember, there
was an extra question,

01:00:56.070 --> 01:00:57.860
which is, oh, well,
maybe if we used

01:00:57.860 --> 01:00:59.870
a deep learning-based
approach for this problem,

01:00:59.870 --> 01:01:01.495
maybe you could get
better performance.

01:01:01.495 --> 01:01:03.870
Did anyone get
better performance?

01:01:03.870 --> 01:01:04.723
No.

01:01:04.723 --> 01:01:06.140
How many of you
actually tried it?

01:01:06.140 --> 01:01:08.790
Raise your hand.

01:01:08.790 --> 01:01:11.280
OK, so a couple of
people who were afraid to

01:01:11.280 --> 01:01:11.970
say, but yeah.

01:01:11.970 --> 01:01:13.887
So a couple of people
who tried, but not many.

01:01:16.020 --> 01:01:19.387
But I think the reason why it's
very challenging to do better

01:01:19.387 --> 01:01:21.470
with, let's say, a recurrent
neural network versus

01:01:21.470 --> 01:01:25.500
a bag-of-words model
is because

01:01:25.500 --> 01:01:30.600
a lot of the subtlety in
understanding the text

01:01:30.600 --> 01:01:32.845
is in terms of understanding
the context of the text.

01:01:32.845 --> 01:01:35.220
And that's something that
using these newer embeddings is

01:01:35.220 --> 01:01:37.137
actually really good at
because they can

01:01:37.137 --> 01:01:39.410
use the
context of words

01:01:39.410 --> 01:01:42.910
to better represent what
each word actually means.

01:01:42.910 --> 01:01:44.560
And they see
substantial improvement

01:01:44.560 --> 01:01:47.250
in performance
using this approach.

01:01:47.250 --> 01:01:48.790
What about for non-text data?

01:01:48.790 --> 01:01:54.220
So you might ask when we
have health insurance claims,

01:01:54.220 --> 01:01:56.373
we have longitudinal
data across time.

01:01:56.373 --> 01:01:57.540
There's no language in this.

01:01:57.540 --> 01:01:58.860
It's a time series data set.

01:01:58.860 --> 01:02:02.050
You have ICD-9 codes
at each point in time,

01:02:02.050 --> 01:02:04.300
you have maybe lab test
results, medication records.

01:02:04.300 --> 01:02:06.300
And this is very similar
to the MarketScan data

01:02:06.300 --> 01:02:08.580
that you used in your
homework assignment.

01:02:08.580 --> 01:02:12.270
Could one learn embeddings
for this type of data, which

01:02:12.270 --> 01:02:15.220
is also useful for transfer?

01:02:15.220 --> 01:02:20.760
So one goal might be to say, OK,
let's take every ICD-9, ICD-10

01:02:20.760 --> 01:02:23.890
code, every medication,
every laboratory test result,

01:02:23.890 --> 01:02:28.562
and embed those event types into
some lower dimensional space.

01:02:28.562 --> 01:02:30.270
And so here's an
example of an embedding.

01:02:30.270 --> 01:02:32.740
And you see how-- this is
just a sketch, by the way--

01:02:32.740 --> 01:02:35.400
you see how you might
hope that diagnosis

01:02:35.400 --> 01:02:37.110
codes for autoimmune
conditions might

01:02:37.110 --> 01:02:39.420
be all near each other
in some lower dimensional

01:02:39.420 --> 01:02:42.660
space, and
codes for medications

01:02:42.660 --> 01:02:45.000
that treat some conditions
should be near each other,

01:02:45.000 --> 01:02:45.500
and so on.

01:02:45.500 --> 01:02:47.970
So you might hope that such
structure might be discovered

01:02:47.970 --> 01:02:50.730
by an unsupervised learning
algorithm that could then

01:02:50.730 --> 01:02:53.585
be used within a transfer
learning approach.

01:02:53.585 --> 01:02:54.960
And indeed, that's
what we found.

01:02:54.960 --> 01:03:00.270
So I wrote a paper
on this in 2015/16.

01:03:00.270 --> 01:03:05.888
And here's one of the
results from that paper.

01:03:05.888 --> 01:03:07.680
So this is just a look
at nearest neighbors

01:03:07.680 --> 01:03:12.870
to give you some sense of
whether the embedding's

01:03:12.870 --> 01:03:14.795
actually capturing the
structure of the data.

01:03:14.795 --> 01:03:16.170
So we looked at
nearest neighbors

01:03:16.170 --> 01:03:22.930
of the ICD-9 diagnosis
code 710.0, which is lupus.

01:03:22.930 --> 01:03:25.590
And what you find is that
another diagnosis code, also

01:03:25.590 --> 01:03:28.890
for lupus, is the first
closest result, followed

01:03:28.890 --> 01:03:31.930
by connective tissue
disorder, or Sicca syndrome,

01:03:31.930 --> 01:03:34.587
which is Sjogren's disease,
Raynaud's syndrome,

01:03:34.587 --> 01:03:35.920
and other autoimmune conditions.

01:03:35.920 --> 01:03:37.380
So that makes a lot of sense.

01:03:37.380 --> 01:03:39.210
You can also go
across data types,

01:03:39.210 --> 01:03:42.390
like ask, what is the nearest
neighbor from this diagnosis

01:03:42.390 --> 01:03:44.790
code to laboratory tests?

01:03:44.790 --> 01:03:47.055
And since we've embedded
lab tests and diagnosis

01:03:47.055 --> 01:03:48.930
codes all in the same
space, you can actually

01:03:48.930 --> 01:03:49.860
get an answer to that.

01:03:49.860 --> 01:03:52.465
And what you see is that these
lab tests, which by the way

01:03:52.465 --> 01:03:54.090
are exactly lab tests
that are commonly

01:03:54.090 --> 01:04:00.090
used to understand progression
in this autoimmune condition,

01:04:00.090 --> 01:04:01.890
are the closest neighbors.

01:04:01.890 --> 01:04:07.110
Similarly, you can ask the same
question about drugs and so on.
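
Those nearest-neighbor queries amount to cosine similarity in the shared embedding space; here is a minimal sketch with toy codes and random vectors standing in for the learned embeddings:

    # Minimal sketch of nearest-neighbor queries over concept embeddings.
    import numpy as np

    codes = ["ICD9_710.0", "ICD9_710.2", "LAB_ANA", "RX_hydroxychloroquine"]
    E = np.random.randn(len(codes), 300)             # toy embedding matrix
    E /= np.linalg.norm(E, axis=1, keepdims=True)    # unit-normalize rows

    def neighbors(code, topn=3):
        sims = E @ E[codes.index(code)]              # cosine similarities
        order = np.argsort(-sims)
        return [(codes[i], float(sims[i])) for i in order[1:topn + 1]]

    print(neighbors("ICD9_710.0"))  # with real embeddings: lupus's neighbors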

01:04:07.110 --> 01:04:11.730
And by the way, we have made
all of these embeddings publicly

01:04:11.730 --> 01:04:14.970
available on my lab's GitHub.

01:04:14.970 --> 01:04:17.040
And since the time that
I wrote this paper,

01:04:17.040 --> 01:04:18.930
there have been a
number of other papers,

01:04:18.930 --> 01:04:21.150
that I give citations
to at the bottom here,

01:04:21.150 --> 01:04:22.890
tackling a very similar problem.

01:04:22.890 --> 01:04:27.780
This last one also made their
embeddings publicly available,

01:04:27.780 --> 01:04:31.890
and is much larger than the one
that we had. So these things,

01:04:31.890 --> 01:04:34.020
I think, would also be
very useful as one starts

01:04:34.020 --> 01:04:37.860
to think about how one can
transfer knowledge learned

01:04:37.860 --> 01:04:40.062
on one institution to
another institution

01:04:40.062 --> 01:04:41.520
where you might
have much less data

01:04:41.520 --> 01:04:42.831
than that other institution.

01:04:45.480 --> 01:04:48.810
So finally I want to
return back to the question

01:04:48.810 --> 01:04:50.820
that I raised in
bullet two here,

01:04:50.820 --> 01:04:52.290
where we looked
at a linear model

01:04:52.290 --> 01:04:54.180
with a manually
chosen representation,

01:04:54.180 --> 01:04:57.300
and ask, could we--

01:04:57.300 --> 01:05:01.770
instead of just naively chopping
your deep neural network

01:05:01.770 --> 01:05:04.680
at some layer and
then fine tuning,

01:05:04.680 --> 01:05:07.320
could one have learned a
representation of your data

01:05:07.320 --> 01:05:11.580
specifically for the purpose of
encouraging good generalization

01:05:11.580 --> 01:05:13.680
to a new institution?

01:05:13.680 --> 01:05:17.220
And there has been some
really exciting work

01:05:17.220 --> 01:05:23.030
in this field that goes by the
name of Unsupervised Domain

01:05:23.030 --> 01:05:23.690
Adaptation.

01:05:34.030 --> 01:05:36.510
So the setting that's
considered here

01:05:36.510 --> 01:05:38.490
is where you

01:05:38.490 --> 01:05:42.190
first have data from some
institution, which is x

01:05:42.190 --> 01:05:44.850
comma y.

01:05:44.850 --> 01:05:48.810
But then you want
to do prediction

01:05:48.810 --> 01:05:51.390
from a new institution
where all you have access

01:05:51.390 --> 01:05:55.260
to at training time is x.

01:05:55.260 --> 01:05:57.510
So as opposed to the
transfer settings

01:05:57.510 --> 01:06:00.300
that I talked about earlier,
now for this new institution,

01:06:00.300 --> 01:06:03.002
you might have a ton
of unlabeled data.

01:06:03.002 --> 01:06:04.710
Whereas before I was
talking about having

01:06:04.710 --> 01:06:06.198
just a small amount
of labeled data,

01:06:06.198 --> 01:06:07.740
but I never talked
about the possibility

01:06:07.740 --> 01:06:09.950
of having a large amount
of unlabeled data.

01:06:09.950 --> 01:06:11.460
And so you might
ask, how could you

01:06:11.460 --> 01:06:13.890
use that large amount
of unlabeled data

01:06:13.890 --> 01:06:16.410
from that second
institution in order

01:06:16.410 --> 01:06:19.110
to learn a representation
that actually encourages

01:06:19.110 --> 01:06:21.955
similarities from one
institution to the other?

01:06:21.955 --> 01:06:24.330
And that's exactly what these
domain adversarial training

01:06:24.330 --> 01:06:25.890
approaches will do.

01:06:25.890 --> 01:06:28.170
What they do is they
add a second term

01:06:28.170 --> 01:06:29.740
to the loss function.

01:06:29.740 --> 01:06:31.590
The intuition is that

01:06:31.590 --> 01:06:34.055
they're going to

01:06:34.055 --> 01:06:35.680
try
to learn parameters

01:06:35.680 --> 01:06:45.150
that minimize your loss function
evaluated on data set 1.

01:06:45.150 --> 01:06:49.530
But intuitively, you're
going to ask that there also

01:06:49.530 --> 01:06:55.350
be a small distance, which
I'll just denote as d here,

01:06:55.350 --> 01:07:00.105
between D1 and D2.

01:07:00.105 --> 01:07:02.850
And so I'm being a little
bit loose with notation here,

01:07:02.850 --> 01:07:04.840
but when I calculate
distance here,

01:07:04.840 --> 01:07:07.467
I'm referring to distance
in representation space.

01:07:07.467 --> 01:07:09.300
So you might imagine
taking the middle layer

01:07:09.300 --> 01:07:10.740
of your deep neural
network, so taking,

01:07:10.740 --> 01:07:12.690
let's say, this layer, which
we're going to call the feature

01:07:12.690 --> 01:07:15.240
layer, or the representation
layer, and you're going to say,

01:07:15.240 --> 01:07:20.700
I want that my data under
the first institution

01:07:20.700 --> 01:07:22.740
should look very
similar to the data

01:07:22.740 --> 01:07:23.950
under the second institution.

01:07:23.950 --> 01:07:26.490
So the first few layers of
your deep neural network

01:07:26.490 --> 01:07:29.310
are going to attempt to
equalize the two data sets

01:07:29.310 --> 01:07:34.140
so that they look similar to
one another, at least in x space.

01:07:34.140 --> 01:07:36.480
And we're going to attempt
to find representations

01:07:36.480 --> 01:07:38.730
of your model that get
good predictive performance

01:07:38.730 --> 01:07:41.340
on the data set for which
you actually have the labels

01:07:41.340 --> 01:07:44.010
and for which the induced
representations, let's

01:07:44.010 --> 01:07:47.670
say, the middle layer look very
similar across the two data

01:07:47.670 --> 01:07:48.420
sets.

01:07:48.420 --> 01:07:51.480
And one way to do that is just
to try to predict--

01:07:51.480 --> 01:07:52.800
that is,

01:07:52.800 --> 01:07:54.570
for each data point,
you might actually

01:07:54.570 --> 01:07:56.760
say, well, which data
set did it come from,

01:07:56.760 --> 01:07:58.650
data set 1 or data set 2?

01:07:58.650 --> 01:08:00.900
And what you want is that
your model should not

01:08:00.900 --> 01:08:03.100
be able to distinguish
which data set it came from.

01:08:03.100 --> 01:08:05.645
So that's what this gradient
reversal layer says:

01:08:05.645 --> 01:08:07.020
you want

01:08:07.020 --> 01:08:09.943
to ensure that, when predicting
which data set the data came from,

01:08:09.943 --> 01:08:11.985
you perform badly
on that loss function.

01:08:11.985 --> 01:08:14.939
It's like taking the
minus of that loss.

01:08:14.939 --> 01:08:17.189
And so we're not going to
go into the details of that,

01:08:17.189 --> 01:08:18.897
but I just wanted to
give you a reference

01:08:18.897 --> 01:08:20.472
to that approach in the bottom.
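
A minimal PyTorch sketch of a gradient reversal layer in the style of domain-adversarial training (the surrounding architecture and sizes are illustrative):

    # Minimal sketch of domain-adversarial training with gradient reversal.
    import torch
    import torch.nn as nn

    class GradReverse(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, lam):
            ctx.lam = lam
            return x.view_as(x)                  # identity going forward

        @staticmethod
        def backward(ctx, grad_output):
            return -ctx.lam * grad_output, None  # reversed going backward

    features = nn.Sequential(nn.Linear(100, 64), nn.ReLU())  # shared layers
    label_head = nn.Linear(64, 2)    # trained on labeled source data
    domain_head = nn.Linear(64, 2)   # predicts which data set x came from

    def total_loss(x_src, y_src, x_tgt, lam=1.0):
        ce = nn.CrossEntropyLoss()
        f_src, f_tgt = features(x_src), features(x_tgt)
        task = ce(label_head(f_src), y_src)
        f_all = torch.cat([f_src, f_tgt])
        d = torch.cat([torch.zeros(len(x_src)),
                       torch.ones(len(x_tgt))]).long()
        domain = ce(domain_head(GradReverse.apply(f_all, lam)), d)
        return task + domain   # feature layers see the reversed gradient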

01:08:20.472 --> 01:08:21.930
And what I want to
do is just spend

01:08:21.930 --> 01:08:24.720
one minute at the very end
talking now about defenses

01:08:24.720 --> 01:08:26.189
to adversarial attacks.

01:08:26.189 --> 01:08:27.750
And conceptually
this is very simple.

01:08:27.750 --> 01:08:31.569
And that's why I can
actually do it in one minute.

01:08:31.569 --> 01:08:36.720
So we talked about how one could
easily modify an image in order

01:08:36.720 --> 01:08:40.566
to turn the prediction from,
let's say, pig to airliner.

01:08:40.566 --> 01:08:42.899
But how could we change your
learning algorithm actually

01:08:42.899 --> 01:08:44.550
to make sure that,
despite the fact that you

01:08:44.550 --> 01:08:47.130
do this perturbation, you still
get the right prediction out,

01:08:47.130 --> 01:08:48.359
pig?

01:08:48.359 --> 01:08:50.609
Well, to think through that,
we have to think through,

01:08:50.609 --> 01:08:51.960
how do we do machine learning?

01:08:51.960 --> 01:08:53.793
Well, a typical approach
to machine learning

01:08:53.793 --> 01:08:57.450
is to learn some
parameters theta that minimize

01:08:57.450 --> 01:08:59.670
your empirical loss.

01:08:59.670 --> 01:09:01.140
Often we use deep
neural networks,

01:09:01.140 --> 01:09:02.390
which look a little like this.

01:09:02.390 --> 01:09:04.290
And we do gradient
descent where we attempt

01:09:04.290 --> 01:09:08.670
to minimize some loss surface, to
find some parameters theta that have

01:09:08.670 --> 01:09:11.890
as low loss as possible.

01:09:11.890 --> 01:09:14.410
Now, when you think about
an adversarial example

01:09:14.410 --> 01:09:16.950
and where they come
from, typically one

01:09:16.950 --> 01:09:19.745
finds an adversarial example
in the following way.

01:09:19.745 --> 01:09:21.120
You take your same
loss function,

01:09:21.120 --> 01:09:23.939
now for a specific
input x, and you

01:09:23.939 --> 01:09:26.260
try to find some
perturbation delta

01:09:26.260 --> 01:09:29.880
to x, an additive perturbation
for example, such

01:09:29.880 --> 01:09:34.109
that you increase the loss as
much as possible with respect

01:09:34.109 --> 01:09:36.540
to the correct label y.

01:09:36.540 --> 01:09:38.790
And so if you've increased
the loss with respect

01:09:38.790 --> 01:09:40.672
to the correct
label y, intuitively

01:09:40.672 --> 01:09:42.630
then when you try to see,
well, what should you

01:09:42.630 --> 01:09:44.850
predict for this
new perturbed input,

01:09:44.850 --> 01:09:47.646
there's going to be a lower
loss for some alternative label,

01:09:47.646 --> 01:09:49.979
which is why the prediction--
the class that's predicted

01:09:49.979 --> 01:09:51.540
actually changes.

01:09:51.540 --> 01:09:54.690
So now one can try to find
these adversarial examples using

01:09:54.690 --> 01:09:57.090
the same type of gradient-based
learning algorithms

01:09:57.090 --> 01:10:01.040
that one uses for learning
in the first place.

01:10:01.040 --> 01:10:03.750
But what one can do is use
the same gradient-based

01:10:03.750 --> 01:10:05.190
method, now--

01:10:05.190 --> 01:10:07.980
instead of gradient
descent, gradient ascent.

01:10:07.980 --> 01:10:11.910
So you take this optimization
problem for a given input x,

01:10:11.910 --> 01:10:14.227
and you try to maximize
that loss for that input x

01:10:14.227 --> 01:10:15.810
with this vector
delta, and you're now

01:10:15.810 --> 01:10:18.300
doing gradient ascent.

01:10:18.300 --> 01:10:20.280
And so what types of
delta should you consider?

01:10:20.280 --> 01:10:22.260
You can imagine
small perturbations,

01:10:22.260 --> 01:10:25.890
for example, delta that have
very small maximum values.

01:10:25.890 --> 01:10:28.270
That would be an example
of an L-infinity norm.

01:10:28.270 --> 01:10:30.400
Or you could say that the
sum of the perturbations

01:10:30.400 --> 01:10:33.160
across, let's say, all of the
dimensions has to be small.

01:10:33.160 --> 01:10:36.430
That would be corresponding to
like an L1 or an L2 norm bound

01:10:36.430 --> 01:10:38.377
on what delta should be.
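
A minimal PyTorch sketch of that inner gradient ascent under an L-infinity bound (the model, data, and step sizes are placeholders):

    # Minimal sketch: find delta by projected gradient ascent on the loss.
    import torch
    import torch.nn as nn

    def pgd_attack(model, x, y, eps=0.01, step=0.003, iters=10):
        loss_fn = nn.CrossEntropyLoss()
        delta = torch.zeros_like(x, requires_grad=True)
        for _ in range(iters):
            loss = loss_fn(model(x + delta), y)    # loss w.r.t. true label
            loss.backward()
            with torch.no_grad():
                delta += step * delta.grad.sign()  # ascend, not descend
                delta.clamp_(-eps, eps)            # project onto L-inf ball
            delta.grad.zero_()
        return (x + delta).detach()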

01:10:38.377 --> 01:10:40.210
So now we've got
everything we need actually

01:10:40.210 --> 01:10:42.880
to think about
defenses to this type

01:10:42.880 --> 01:10:46.460
of adversarial perturbation.

01:10:46.460 --> 01:10:49.728
So instead of minimizing your
typical empirical loss, what

01:10:49.728 --> 01:10:51.520
we're going to do is
we're going to attempt

01:10:51.520 --> 01:10:55.510
to minimize an adversarially
robust loss function.

01:10:55.510 --> 01:10:57.100
What we'll do is
we'll say, OK, we

01:10:57.100 --> 01:11:01.660
want to be sure that no matter
what the perturbation is

01:11:01.660 --> 01:11:05.500
that one adds to the
input, the true label y

01:11:05.500 --> 01:11:07.680
still has low loss.

01:11:07.680 --> 01:11:10.840
So you want to find
parameters theta which

01:11:10.840 --> 01:11:12.820
minimize this new quantity.

01:11:12.820 --> 01:11:14.980
So I'm saying that
we should still

01:11:14.980 --> 01:11:21.375
do well even for the worst-case
adversarial perturbation.

01:11:21.375 --> 01:11:23.500
And so now this would be
the following new learning

01:11:23.500 --> 01:11:26.110
objective, where we're
going to minimize over

01:11:26.110 --> 01:11:28.580
theta the maximum
with respect to delta.

01:11:28.580 --> 01:11:31.780
And you have to restrict the
family that these perturbations

01:11:31.780 --> 01:11:34.087
could live in, so if
the set of deltas you're

01:11:34.087 --> 01:11:35.920
maximizing over
is empty,

01:11:35.920 --> 01:11:38.170
you get back the original
learning problem.

01:11:38.170 --> 01:11:41.470
If you let it be, let's
say, all L-infinity

01:11:41.470 --> 01:11:46.437
bounded perturbations
of maximum size of 0.01,

01:11:46.437 --> 01:11:48.520
then you're saying we're
going to allow for a very

01:11:48.520 --> 01:11:49.930
small amount of perturbations.

01:11:49.930 --> 01:11:51.388
And the learning
algorithm is going

01:11:51.388 --> 01:11:53.800
to find parameters theta such
that for every input, even

01:11:53.800 --> 01:11:57.340
with a small perturbation
to it, adversarially chosen,

01:11:57.340 --> 01:11:59.590
you still get good
predictive performance.
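
A minimal sketch of that min-max objective, reusing the pgd_attack function sketched earlier as the inner maximization:

    # Minimal sketch of adversarial training: min over theta of the
    # max over delta of the loss, approximated by an inner attack.
    import torch
    import torch.nn.functional as F

    def adversarial_training_step(model, optimizer, x, y, eps=0.01):
        x_adv = pgd_attack(model, x, y, eps=eps)   # inner maximization
        loss = F.cross_entropy(model(x_adv), y)
        optimizer.zero_grad()
        loss.backward()                            # outer minimization
        optimizer.step()
        return loss.item()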

01:11:59.590 --> 01:12:02.470
And this is now a new
optimization problem

01:12:02.470 --> 01:12:04.300
that one can solve.

01:12:04.300 --> 01:12:06.040
And we've now
reduced the problem

01:12:06.040 --> 01:12:09.070
of finding an adversarially robust
model to a new optimization

01:12:09.070 --> 01:12:09.940
problem.

01:12:09.940 --> 01:12:11.410
And what the field
has been doing

01:12:11.410 --> 01:12:13.810
in the last couple
of years is coming up

01:12:13.810 --> 01:12:15.970
with new optimization
approaches to try

01:12:15.970 --> 01:12:18.880
to solve those problems fast.

01:12:18.880 --> 01:12:22.690
So, for example, this paper
published at ICML in 2018

01:12:22.690 --> 01:12:24.940
by Zico Kolter and his student--

01:12:24.940 --> 01:12:28.792
Zico just visited
MIT a few weeks ago--

01:12:28.792 --> 01:12:30.250
what it did is it
said, we're going

01:12:30.250 --> 01:12:34.300
to use a convex relaxation
to the rectified linear unit,

01:12:34.300 --> 01:12:37.028
which is used in many deep
neural network architectures.

01:12:37.028 --> 01:12:39.070
And what it's then
going to do is say,

01:12:39.070 --> 01:12:41.440
OK, we're going
to think about how

01:12:41.440 --> 01:12:45.790
a small perturbation
to the input

01:12:45.790 --> 01:12:49.000
would be propagated, in terms
of how much that could

01:12:49.000 --> 01:12:50.300
actually change the output.

01:12:50.300 --> 01:12:52.630
And if one can
bound, layer

01:12:52.630 --> 01:12:55.420
by layer, how much a small
perturbation affects

01:12:55.420 --> 01:12:57.250
the output of that
layer, then one

01:12:57.250 --> 01:12:59.350
could propagate
from the very bottom

01:12:59.350 --> 01:13:01.300
all the way to the loss
function at the top

01:13:01.300 --> 01:13:05.260
to try to bound how much the
loss function itself changes.

01:13:05.260 --> 01:13:08.780
And a picture of what you
would expect out is as follows.

01:13:08.780 --> 01:13:11.770
On the left-hand side here,
you have a data point,

01:13:11.770 --> 01:13:14.800
red and blue, and
the decision boundary

01:13:14.800 --> 01:13:19.420
that's learned if you didn't do
this robust learning algorithm.

01:13:19.420 --> 01:13:21.700
On the right, you have now--

01:13:21.700 --> 01:13:25.030
you'll notice a small square
around each data point.

01:13:25.030 --> 01:13:28.600
That corresponds to a
maximum perturbation

01:13:28.600 --> 01:13:29.793
of some limited amount.

01:13:29.793 --> 01:13:31.960
And now you notice how the
decision boundary doesn't

01:13:31.960 --> 01:13:33.520
cross any one of those squares.

01:13:33.520 --> 01:13:35.950
And that's what would be found
by this learning algorithm.

01:13:35.950 --> 01:13:38.200
Interestingly, one can
look at the filters that

01:13:38.200 --> 01:13:40.483
are learned by convolutional
neural network using

01:13:40.483 --> 01:13:41.650
this new learning algorithm.

01:13:41.650 --> 01:13:44.980
And you find that
they're much more sparse.

01:13:44.980 --> 01:13:48.340
And so this is a very
fast moving field.

01:13:48.340 --> 01:13:52.930
Every time a new

01:13:52.930 --> 01:13:55.660
adversarial
defense mechanism comes up,

01:13:55.660 --> 01:13:57.243
someone comes up
with a different type

01:13:57.243 --> 01:13:58.930
of attack, which breaks it.

01:13:58.930 --> 01:14:01.730
And usually that's for
one of two reasons.

01:14:01.730 --> 01:14:05.883
One, because the defense
mechanism isn't provable,

01:14:05.883 --> 01:14:08.300
and so one could try to come
up with a theorem which says,

01:14:08.300 --> 01:14:10.390
OK, as long as you
don't perturb

01:14:10.390 --> 01:14:14.710
more than some amount, these are
the results you should expect.

01:14:14.710 --> 01:14:16.960
The other side of the coin
is, even if you come up

01:14:16.960 --> 01:14:18.668
with some provable
guarantee, there might

01:14:18.668 --> 01:14:20.180
be other types of attacks.

01:14:20.180 --> 01:14:22.180
So, for example, you
might imagine a rotation

01:14:22.180 --> 01:14:25.480
to the input instead of
an L-infinity bounded norm

01:14:25.480 --> 01:14:26.530
that you add to it.

01:14:26.530 --> 01:14:29.200
And so for every new
type of attack model,

01:14:29.200 --> 01:14:31.750
you have to think through
new defense mechanisms.

01:14:31.750 --> 01:14:34.930
And so you should expect to see
some iteration in the space.

01:14:34.930 --> 01:14:38.200
And there's a website
called robust-ml.org,

01:14:38.200 --> 01:14:41.643
where many of these attacks and
defenses are being published

01:14:41.643 --> 01:14:43.810
to allow for the academic
community to make progress

01:14:43.810 --> 01:14:44.440
here.

01:14:44.440 --> 01:14:47.260
And with that, I'll
finish today's lecture.