WEBVTT

00:00:00.500 --> 00:00:01.956
[SQUEAKING]

00:00:01.956 --> 00:00:04.401
[RUSTLING]

00:00:04.401 --> 00:00:05.868
[CLICKING]

00:00:21.482 --> 00:00:22.940
DAVID SONTAG: So
today's lecture is

00:00:22.940 --> 00:00:25.520
going to continue on
the lecture that you

00:00:25.520 --> 00:00:27.950
saw on Tuesday, which
was introducing you

00:00:27.950 --> 00:00:29.300
to causal inference.

00:00:29.300 --> 00:00:33.410
So the causal inference
setting, which

00:00:33.410 --> 00:00:35.960
we're studying in this course,
is a really simplistic one

00:00:35.960 --> 00:00:37.490
from a causal
graphs perspective.

00:00:37.490 --> 00:00:41.030
There are three sets of
variables of interest--

00:00:41.030 --> 00:00:43.490
everything you know about an
individual or patient, which

00:00:43.490 --> 00:00:50.180
we're calling x over here;
an intervention or action--

00:00:50.180 --> 00:00:52.160
which for today's
lecture, we're going

00:00:52.160 --> 00:00:56.510
to suppose that it's either 0
or 1, so a binary intervention.

00:00:56.510 --> 00:00:58.940
You either take it or don't--

00:00:58.940 --> 00:01:01.280
and an outcome y.

00:01:01.280 --> 00:01:04.040
And what makes this
problem of understanding

00:01:04.040 --> 00:01:08.840
the impact of the intervention
on the outcome challenging

00:01:08.840 --> 00:01:10.700
is that we have to
make that inference

00:01:10.700 --> 00:01:14.757
from observational data, where
we don't have the ability--

00:01:14.757 --> 00:01:16.340
at least not in
medicine, we typically

00:01:16.340 --> 00:01:20.750
don't have the ability to
make active interventions.

00:01:20.750 --> 00:01:23.930
And the goal of what we will
be discussing in this course

00:01:23.930 --> 00:01:26.570
is about how to take
data that was collected

00:01:26.570 --> 00:01:29.600
from a practice of medicine
where actions or interventions

00:01:29.600 --> 00:01:32.510
were taken, and then use
that to infer something

00:01:32.510 --> 00:01:34.088
about the causal effect.

00:01:34.088 --> 00:01:36.380
And obviously, there are also
randomized controlled trials

00:01:36.380 --> 00:01:39.830
where one intentionally
does randomize,

00:01:39.830 --> 00:01:41.927
but the focus of
today's lecture is

00:01:41.927 --> 00:01:44.510
going to be using observational
data, or already collected data,

00:01:44.510 --> 00:01:47.580
to try to make
these conclusions.

00:01:47.580 --> 00:01:51.260
So we introduced the language of
potential outcomes on Tuesday.

00:01:51.260 --> 00:01:54.200
Potential outcomes is the
mathematical framework

00:01:54.200 --> 00:01:56.330
for trying to answer
these questions.

00:01:56.330 --> 00:01:59.030
Then with that definition
of potential outcomes,

00:01:59.030 --> 00:02:02.060
we can define the conditional
average treatment effect,

00:02:02.060 --> 00:02:07.280
which is the difference
between Y1 and Y0

00:02:07.280 --> 00:02:09.860
for the individual Xi.

00:02:09.860 --> 00:02:13.130
So you'll notice here
that I have an expectation,

00:02:13.130 --> 00:02:16.010
so treating the
potential outcome

00:02:16.010 --> 00:02:18.620
as a random variable
in case there

00:02:18.620 --> 00:02:19.880
might be some stochasticity.

00:02:19.880 --> 00:02:23.000
So sometimes, maybe if you were
to give someone a treatment,

00:02:23.000 --> 00:02:25.500
it works, and
sometimes it doesn't.

00:02:25.500 --> 00:02:28.850
So that's what the
expectation is accounting for.

00:02:28.850 --> 00:02:30.350
Any questions before I move on?

00:02:34.450 --> 00:02:37.275
So with respect
to this definition

00:02:37.275 --> 00:02:39.150
of conditional average
treatment effect, then

00:02:39.150 --> 00:02:41.275
you could ask, well, what
would happen in aggregate

00:02:41.275 --> 00:02:43.290
for the population?

00:02:43.290 --> 00:02:44.990
And you can compute
that by taking

00:02:44.990 --> 00:02:48.230
the average of the conditional
average treatment effect

00:02:48.230 --> 00:02:50.210
over all of the individuals.

00:02:50.210 --> 00:02:53.810
So that's just this expectation
with respect to, now, p of x.

00:02:53.810 --> 00:02:57.290
Now, critically, this
distribution, p of x, you

00:02:57.290 --> 00:03:00.830
should think about as the
distribution of everyone

00:03:00.830 --> 00:03:04.020
that exists in your data.

00:03:04.020 --> 00:03:06.200
So some of those individuals
might have received

00:03:06.200 --> 00:03:07.610
treatment 1 in the past.

00:03:07.610 --> 00:03:10.498
Some of them might have
received treatment 0.

00:03:10.498 --> 00:03:13.040
But when we ask this question
about average treatment effect,

00:03:13.040 --> 00:03:15.833
we're asking, for both of
those populations, what

00:03:15.833 --> 00:03:17.000
would have been the effect--

00:03:17.000 --> 00:03:20.050
what would have been the
difference had

00:03:20.050 --> 00:03:26.080
they received treatment 1 minus
had they received treatment 0?

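NOTE
A minimal sketch in Python (standard library only) of the two definitions above. It assumes a made-up potential-outcome model that is not from the lecture. cate(x) is the difference of expected potential outcomes for covariates x. Averaging noisy draws for one individual recovers it, because the zero-mean noise drops out. The ATE then averages cate(x) over everyone in the data, whichever treatment they actually received. The functions mu, cate, and sample_y and all numbers are hypothetical.
```python
import random
# Hypothetical potential-outcome model (an illustrative assumption, not from the lecture):
# Y_t(x) = mu(t, x) + noise, where the noise has mean 0.
def mu(t, x):
    return 0.5 * x + (2.0 - 0.01 * x) * t
def cate(x):
    # CATE(x) = E[Y_1 - Y_0 | x]; the zero-mean noise drops out of the expectation.
    return mu(1, x) - mu(0, x)
def sample_y(t, x, rng):
    # One noisy draw of the potential outcome Y_t(x).
    return mu(t, x) + rng.gauss(0.0, 1.0)
rng = random.Random(0)
# Monte Carlo check of the expectation for a single 30-year-old.
n_draws = 20000
mc_cate_30 = sum(sample_y(1, 30, rng) - sample_y(0, 30, rng) for _ in range(n_draws)) / n_draws
# p(x): everyone in the data, whichever treatment they actually received.
ages = [rng.randint(20, 80) for _ in range(1000)]
ate = sum(cate(x) for x in ages) / len(ages)
print(round(cate(30), 3), round(mc_cate_30, 3), round(ate, 3))
```
In a real data set only one potential outcome per person is observed, so cate cannot be computed this directly. The sketch only illustrates what the quantities mean.
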
00:03:26.080 --> 00:03:29.050
Now, I wanted to
take this opportunity

00:03:29.050 --> 00:03:32.410
to start thinking a little
bit bigger picture about how

00:03:32.410 --> 00:03:38.500
causal inference can be
important in a variety

00:03:38.500 --> 00:03:42.640
of societal
questions, and so I'd

00:03:42.640 --> 00:03:46.060
like to now spend just a
couple of minutes thinking

00:03:46.060 --> 00:03:48.820
with you about what some
causal questions might

00:03:48.820 --> 00:03:51.580
be that we urgently need to
answer about the COVID-19

00:03:51.580 --> 00:03:52.660
pandemic.

00:03:52.660 --> 00:03:55.060
And as you try to think
through these questions,

00:03:55.060 --> 00:03:57.850
I want you to have this
causal graph in mind.

00:03:57.850 --> 00:04:01.240
So there is the
general population.

00:04:01.240 --> 00:04:04.720
There is some action
that you want to perform,

00:04:04.720 --> 00:04:09.130
and the whole notion
of causal inference is

00:04:09.130 --> 00:04:13.520
assessing the effect of that action
on some outcome of interest.

00:04:13.520 --> 00:04:16.570
So in trying to give
the answer to my--

00:04:16.570 --> 00:04:18.250
various answers to
my questions of what

00:04:18.250 --> 00:04:20.250
are some causal inference
questions of relevance

00:04:20.250 --> 00:04:22.800
to the current
pandemic, I want you

00:04:22.800 --> 00:04:27.550
to try to frame your answers in
terms of these Xs, Ts, and Ys.

00:04:27.550 --> 00:04:30.850
These are also, obviously,
very hard to answer

00:04:30.850 --> 00:04:33.190
using the types of techniques
that we will be discussing

00:04:33.190 --> 00:04:36.520
in this course, and partly
because the techniques that I'm

00:04:36.520 --> 00:04:40.270
focusing on are very much
data-driven techniques.

00:04:40.270 --> 00:04:44.890
That said, consider the general framework
that I introduced on Tuesday

00:04:44.890 --> 00:04:50.110
for covariate adjustment
(come up with a model

00:04:50.110 --> 00:04:52.850
and use that model
to make a prediction),

00:04:52.850 --> 00:04:57.358
and the assumptions that
underlie it in terms of,

00:04:57.358 --> 00:04:59.650
well, where's that model
coming from? If you're fitting

00:04:59.650 --> 00:05:03.400
the parameters from data, you need
to have common support in order

00:05:03.400 --> 00:05:09.250
to be able to have any trust
in the downstream conclusions.

00:05:09.250 --> 00:05:11.530
Those underlying assumptions
and the general premises

00:05:11.530 --> 00:05:13.247
will still hold,
but here, obviously,

00:05:13.247 --> 00:05:15.330
when it comes to something
like social distancing,

00:05:15.330 --> 00:05:18.520
there are complicated
network effects.

00:05:18.520 --> 00:05:21.220
And so whereas up
until now, we've

00:05:21.220 --> 00:05:24.270
been making the assumption
of what was called SUTVA--

00:05:24.270 --> 00:05:29.170
it was an assumption that I
probably didn't even talk about

00:05:29.170 --> 00:05:30.850
in Tuesday's lecture.

00:05:30.850 --> 00:05:34.780
But intuitively, what
the SUTVA assumption says

00:05:34.780 --> 00:05:36.940
is that each of your
training examples

00:05:36.940 --> 00:05:39.040
is independent of the others: one person's
treatment doesn't affect another person's outcome.

00:05:39.040 --> 00:05:41.140
And that might make sense
when you think about,

00:05:41.140 --> 00:05:44.122
give a patient a
medication or not,

00:05:44.122 --> 00:05:45.580
but it certainly
doesn't make sense

00:05:45.580 --> 00:05:48.520
when you think about social
distancing type measures,

00:05:48.520 --> 00:05:50.650
where if some people
social distance,

00:05:50.650 --> 00:05:53.440
but other people don't,
it has obviously a very

00:05:53.440 --> 00:05:56.480
different impact on society.

00:05:56.480 --> 00:05:59.267
So one needs a different
class of models

00:05:59.267 --> 00:06:00.850
to try to think about
that, which have

00:06:00.850 --> 00:06:03.850
to relax that SUTVA assumption.

00:06:03.850 --> 00:06:07.540
So those were all really
good answers to my question,

00:06:07.540 --> 00:06:12.120
and in some sense, now--

00:06:12.120 --> 00:06:14.260
so there are the
epidemiological-type questions

00:06:14.260 --> 00:06:16.330
that we last spoke about.

00:06:16.330 --> 00:06:18.070
But the first few sets
of questions, about,

00:06:18.070 --> 00:06:23.830
really, how one should treat
patients who have COVID,

00:06:23.830 --> 00:06:26.110
are the types of questions
that only now can we really

00:06:26.110 --> 00:06:27.655
start to answer,
unfortunately,

00:06:27.655 --> 00:06:30.030
because we're starting to get
a lot of data in the United

00:06:30.030 --> 00:06:31.960
States and internationally.

00:06:31.960 --> 00:06:35.020
And so, for example, in my own
research group,

00:06:35.020 --> 00:06:36.430
we're starting to
really scale up

00:06:36.430 --> 00:06:38.970
our research on these
types of questions.

00:06:38.970 --> 00:06:43.030
Now, one very simplified
example that I

00:06:43.030 --> 00:06:46.630
wanted to give of how a causal
inference lens can be useful

00:06:46.630 --> 00:06:50.470
here is by trying to
understand case fatality rates.

00:06:50.470 --> 00:06:53.260
So for example, in
Italy, it was reported

00:06:53.260 --> 00:06:59.890
that 4.3% of individuals
who had this condition

00:06:59.890 --> 00:07:02.260
passed away,
whereas in China, it

00:07:02.260 --> 00:07:04.870
was reported that 2.3%
of individuals who

00:07:04.870 --> 00:07:07.660
had this condition passed away.

00:07:07.660 --> 00:07:10.960
Now, you might ask, based
on just those two numbers,

00:07:10.960 --> 00:07:12.500
is something
different about China?

00:07:12.500 --> 00:07:16.720
For example, might
it be that the way

00:07:16.720 --> 00:07:22.180
that COVID is being managed in
China is better than in Italy?

00:07:22.180 --> 00:07:26.380
You might also wonder if
the strain of the disease

00:07:26.380 --> 00:07:31.290
might be different
between China and Italy?

00:07:31.290 --> 00:07:37.930
So perhaps there were some
mutations since it left Wuhan.

00:07:37.930 --> 00:07:40.700
But if you dig a
little bit deeper,

00:07:40.700 --> 00:07:45.895
you see that, if you plot case
fatality rates by age group,

00:07:45.895 --> 00:07:47.770
you get this plot that
I'm showing over here.

00:07:47.770 --> 00:07:49.353
And you see that if
you compare Italy,

00:07:49.353 --> 00:07:53.800
which is the orange, to
China, which is blue, now

00:07:53.800 --> 00:07:58.440
stratified by age range, you
see that for every single age

00:07:58.440 --> 00:08:04.620
range, the percentage
of deaths is lower

00:08:04.620 --> 00:08:07.260
in Italy than in
China, which would

00:08:07.260 --> 00:08:10.445
seem to be a contradiction
with what we saw--

00:08:10.445 --> 00:08:11.820
with the aggregate
numbers, where

00:08:11.820 --> 00:08:15.630
we see that the case
fatality rate in Italy

00:08:15.630 --> 00:08:17.730
is higher than in China.

00:08:17.730 --> 00:08:23.760
And so the reason why this can
happen has to do with the fact

00:08:23.760 --> 00:08:26.320
that the populations
are very different.

00:08:26.320 --> 00:08:29.700
And by the way, this
paradox goes by the name

00:08:29.700 --> 00:08:33.059
of Simpson's paradox.

00:08:33.059 --> 00:08:34.980
So if you dig a bit
deeper, you see then

00:08:34.980 --> 00:08:36.840
that, if you were
to look at, well,

00:08:36.840 --> 00:08:39.539
what is the distribution of
individuals in China and Italy

00:08:39.539 --> 00:08:45.480
that have been
reported to have COVID,

00:08:45.480 --> 00:08:49.110
you see that, in Italy,
it's much more highly

00:08:49.110 --> 00:08:53.730
weighted towards
these older ages.

00:08:53.730 --> 00:09:00.130
And if you then combine that
with the total number of cases,

00:09:00.130 --> 00:09:03.240
you arrive at
these discrepancies,

00:09:03.240 --> 00:09:06.100
so it now fully explains
these two numbers

00:09:06.100 --> 00:09:07.640
and the plot that you see.

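NOTE
A minimal sketch of the Simpson's paradox pattern just described. It uses invented counts: hypothetical, not real COVID-19 data, and not to be read into, in the lecture's own spirit. In this toy data Italy has the lower fatality rate in every age stratum but the higher crude rate, because its cases skew older. standardized_rate, a helper introduced here, applies one country's age-specific rates to another country's age mix of cases. That compares the two countries at a common age distribution, i.e., it adjusts for age.
```python
# Hypothetical counts per age stratum: (cases, deaths). NOT real COVID-19 data.
data = {
    "China": {"young": (1000, 10), "old": (100, 20)},
    "Italy": {"young": (200, 1), "old": (800, 120)},
}
def stratum_rates(country):
    # Case fatality rate within each age stratum.
    return {age: deaths / cases for age, (cases, deaths) in data[country].items()}
def crude_rate(country):
    # Aggregate case fatality rate, ignoring age.
    cases = sum(c for c, _ in data[country].values())
    deaths = sum(d for _, d in data[country].values())
    return deaths / cases
def standardized_rate(country, reference):
    # Apply `country`'s age-specific rates to `reference`'s age distribution of cases.
    ref = data[reference]
    total = sum(c for c, _ in ref.values())
    rates = stratum_rates(country)
    return sum(rates[age] * cases / total for age, (cases, _) in ref.items())
print({c: {a: round(r, 3) for a, r in stratum_rates(c).items()} for c in data})
print(round(crude_rate("China"), 4), round(crude_rate("Italy"), 4))
print(round(standardized_rate("Italy", "China"), 4))
```
With China's age mix, the toy Italy rate falls below China's crude rate, matching the stratified comparison rather than the aggregate one.
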
00:09:07.640 --> 00:09:10.140
Now, if we were to try to think
about this a bit more formally,

00:09:10.140 --> 00:09:11.760
we would try to
formalize it in terms

00:09:11.760 --> 00:09:14.680
of the following causal graph.

00:09:14.680 --> 00:09:19.690
And so here, we have the
same notions of X, T, and Y,

00:09:19.690 --> 00:09:23.110
where X is the age
of an individual who

00:09:23.110 --> 00:09:26.230
has been diagnosed with COVID.

00:09:26.230 --> 00:09:30.700
T is now country, so we're going
to think about the intervention

00:09:30.700 --> 00:09:35.720
here as transporting
ourselves from China to Italy,

00:09:35.720 --> 00:09:39.130
so thinking about changing
the environment altogether.

00:09:39.130 --> 00:09:42.480
And Y is the outcome on
an individual-level basis.

00:09:42.480 --> 00:09:44.140
And so the formal
question that one

00:09:44.140 --> 00:09:47.470
might want to ask is about
the causal impact of changing

00:09:47.470 --> 00:09:51.130
the country on the outcome Y.

00:09:51.130 --> 00:09:54.245
Now, for this particular
causal question,

00:09:54.245 --> 00:09:56.620
this causal graph that I'm
drawing here is the wrong one,

00:09:56.620 --> 00:09:59.470
and in fact, the right
causal graph probably

00:09:59.470 --> 00:10:04.390
has an edge that
goes from T to X.

00:10:04.390 --> 00:10:09.700
In particular, the distribution
of individuals in the country

00:10:09.700 --> 00:10:11.350
is obviously a function
of the country,

00:10:11.350 --> 00:10:14.403
not the other way around.

00:10:14.403 --> 00:10:15.820
But despite the
fact that there is

00:10:15.820 --> 00:10:17.637
that difference
in directionality,

00:10:17.637 --> 00:10:19.720
all of the techniques that
we've been teaching you

00:10:19.720 --> 00:10:21.730
in this course are still
applicable for trying

00:10:21.730 --> 00:10:25.960
to ask a causal question about
the impact of intervening

00:10:25.960 --> 00:10:34.370
on a country, and that's
really because, in some sense,

00:10:34.370 --> 00:10:37.180
these two graphs,
at an observational level,

00:10:37.180 --> 00:10:39.167
are equivalent.

00:10:39.167 --> 00:10:41.750
And if you want to dig a little
bit deeper into this example--

00:10:41.750 --> 00:10:44.800
and I want to stress this is
just for educational purposes.

00:10:47.350 --> 00:10:50.560
Don't read anything
into these numbers--

00:10:50.560 --> 00:10:54.810
I would go to this Colab
notebook after the lecture.

00:10:54.810 --> 00:10:57.360
So all of this was
just a little bit

00:10:57.360 --> 00:11:01.860
of setup to help frame where
causal inference shows up

00:11:01.860 --> 00:11:04.110
and some things that
we've been thinking about

00:11:04.110 --> 00:11:07.080
and have been really very worried and
stressed about ourselves,

00:11:07.080 --> 00:11:09.580
personally, recently.

00:11:09.580 --> 00:11:12.570
And I want to now shift
gears and start to get back

00:11:12.570 --> 00:11:15.300
to the course material,
and in particular, I

00:11:15.300 --> 00:11:18.270
want to start the more
theoretical part

00:11:18.270 --> 00:11:20.670
of today's lecture by returning
to covariate adjustment,

00:11:20.670 --> 00:11:22.610
which is where we ended on Tuesday.

00:11:22.610 --> 00:11:25.650
In covariate adjustment,

00:11:25.650 --> 00:11:30.180
we use a machine learning
approach to learn some model,

00:11:30.180 --> 00:11:34.200
which I'll call F. So you could
imagine a black box machine

00:11:34.200 --> 00:11:38.730
learning algorithm, which
takes as input both X and T.

00:11:38.730 --> 00:11:42.270
So X are your covariates of
the individual that are going

00:11:42.270 --> 00:11:45.570
to receive the treatment, and
T is that treatment decision,

00:11:45.570 --> 00:11:49.320
which for today's lecture, you
can just assume is binary, 0 or 1,

00:11:49.320 --> 00:11:53.590
and uses those together now
to predict the outcome Y.

00:11:53.590 --> 00:11:57.380
Now, what we showed on Tuesday
was that, under ignorability,

00:11:57.380 --> 00:11:59.510
where ignorability,
remember, was

00:11:59.510 --> 00:12:01.910
the assumption of no
hidden confounding,

00:12:01.910 --> 00:12:05.000
then the conditional
average treatment effect

00:12:05.000 --> 00:12:09.260
could be computed as
just a difference:

00:12:09.260 --> 00:12:15.110
the expectation of Y1

00:12:15.110 --> 00:12:17.540
now conditioned on
T equals 1, so this

00:12:17.540 --> 00:12:19.190
is the piece that
I've added in here,

00:12:19.190 --> 00:12:24.830
minus the expectation of Y0
now conditioned on T equals 0.

00:12:24.830 --> 00:12:27.380
And it's that conditioning
which is really important,

00:12:27.380 --> 00:12:33.080
because that's what enables you
to estimate Y1 from data where

00:12:33.080 --> 00:12:36.020
treatment 1 was observed,
whereas you never

00:12:36.020 --> 00:12:42.030
get to observe Y1 in data when
treatment 0 was performed.

00:12:42.030 --> 00:12:46.670
So we have this formula, and
after fitting that model F,

00:12:46.670 --> 00:12:49.250
one could then use it
to try to estimate CATE

00:12:49.250 --> 00:12:52.190
by just taking that
learned function,

00:12:52.190 --> 00:12:59.420
plugging in the number 1
for the treatment variable

00:12:59.420 --> 00:13:02.330
in order to get your
estimate of this expectation,

00:13:02.330 --> 00:13:06.380
and then plugging in the number
0 for the treatment variable

00:13:06.380 --> 00:13:09.932
when you want to get your
estimate of this expectation.

00:13:09.932 --> 00:13:11.390
Taking the difference
between those

00:13:11.390 --> 00:13:13.480
then gives you your estimate
of the conditional average

00:13:13.480 --> 00:13:14.210
treatment effect.

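NOTE
A minimal sketch of the plug-in procedure above, on a hypothetical simulated data set. X is a discrete age group in {0, 1, 2}, and older groups are treated more often, so X confounds T. As a stand-in for the black-box learner F, fit_f uses the mean outcome in each (x, t) cell. This is a saturated model, chosen only to keep the example dependency-free. Any regressor that takes (x, t) could replace it. simulate, fit_f, and all parameters are illustrative assumptions.
```python
import random
from collections import defaultdict
rng = random.Random(0)
def simulate(n):
    # Hypothetical data-generating process: X is an age group in {0, 1, 2},
    # older groups are treated more often (confounding), and the true CATE(x) is 1 + x.
    rows = []
    for _ in range(n):
        x = rng.choice([0, 1, 2])
        t = 1 if rng.random() < (0.2, 0.5, 0.8)[x] else 0
        y = x + t * (1 + x) + rng.gauss(0.0, 0.5)
        rows.append((x, t, y))
    return rows
def fit_f(rows):
    # Stand-in for a black-box learner F(x, t): the mean observed outcome in each (x, t) cell.
    sums, counts = defaultdict(float), defaultdict(int)
    for x, t, y in rows:
        sums[(x, t)] += y
        counts[(x, t)] += 1
    return lambda x, t: sums[(x, t)] / counts[(x, t)]
rows = simulate(30000)
f = fit_f(rows)
# Plug in T = 1 and T = 0 for the same x and take the difference.
cate_hat = {x: f(x, 1) - f(x, 0) for x in (0, 1, 2)}
# ATE estimate: average the estimated CATE over everyone's covariates.
ate_hat = sum(cate_hat[x] for x, _, _ in rows) / len(rows)
# For contrast, the naive difference in observed means ignores confounding by X.
treated_y = [y for _, t, y in rows if t == 1]
control_y = [y for _, t, y in rows if t == 0]
naive = sum(treated_y) / len(treated_y) - sum(control_y) / len(control_y)
print({x: round(v, 2) for x, v in cate_hat.items()}, round(ate_hat, 2), round(naive, 2))
```
Here the naive comparison overstates the effect, because treated patients are older. Plugging T = 1 and T = 0 into F for the same x recovers CATE(x) close to 1 + x.
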
00:13:17.220 --> 00:13:21.490
So that's the approach, and what
we didn't talk about so much

00:13:21.490 --> 00:13:26.740
was the modeling choice of what
your function class should be.

00:13:26.740 --> 00:13:30.280
So this is going to turn
out to be really important,

00:13:30.280 --> 00:13:33.100
and really, the punchline
of the next several slides

00:13:33.100 --> 00:13:35.950
is going to be a major
difference in philosophy

00:13:35.950 --> 00:13:38.800
between machine
learning and statistics,

00:13:38.800 --> 00:13:43.670
and between prediction
and causal inference.

00:13:43.670 --> 00:13:47.440
So let's now consider the
following simple model, where

00:13:47.440 --> 00:13:53.500
I'm going to assume that the
ground truth in the real world

00:13:53.500 --> 00:13:58.120
has that the potential outcome
Y_T(X), where T, again,

00:13:58.120 --> 00:14:04.840
is the treatment, is equal
to some simple linear model

00:14:04.840 --> 00:14:08.770
involving the covariates
X and the treatment

00:14:08.770 --> 00:14:13.840
T. So in
this very simple setting,

00:14:13.840 --> 00:14:16.900
I'm going to assume that we
just have a single feature

00:14:16.900 --> 00:14:19.850
or covariate for the
individual, which is their age.

00:14:19.850 --> 00:14:22.810
I'm going to assume
that this model doesn't

00:14:22.810 --> 00:14:25.660
have any terms with an
interaction between X and T,

00:14:25.660 --> 00:14:30.500
so it's fully linear in X and T.

00:14:30.500 --> 00:14:36.890
So this is an assumption about
the true potential outcomes,

00:14:36.890 --> 00:14:39.980
and what we'll do over
the next couple of slides

00:14:39.980 --> 00:14:44.210
is think about what would happen
if you now modeled Y of T,

00:14:44.210 --> 00:14:46.970
so modeling it with
some function F, where

00:14:46.970 --> 00:14:50.060
F was, let's say, a linear
function versus a nonlinear

00:14:50.060 --> 00:14:55.112
function, if F took this
form or a different form.

00:14:55.112 --> 00:14:56.570
And by the way,
I'm going to assume

00:14:56.570 --> 00:15:00.560
that the noise here,
epsilon t, can be arbitrary,

00:15:00.560 --> 00:15:04.480
but that it has 0 mean.

00:15:04.480 --> 00:15:06.960
So let's get started by
trying to compute what

00:15:06.960 --> 00:15:09.810
the true CATE is, or
Conditional Average Treatment

00:15:09.810 --> 00:15:14.550
Effect, for this
potential outcome model.

00:15:14.550 --> 00:15:16.000
Well, just by
definition, the CATE

00:15:16.000 --> 00:15:19.050
is the expectation
of Y1 minus Y0.

00:15:19.050 --> 00:15:22.200
We're going to
take this formula,

00:15:22.200 --> 00:15:28.770
and we're going to plug it
in for the first term using

00:15:28.770 --> 00:15:32.640
T equals 1, and that's why
you get this term over here

00:15:32.640 --> 00:15:34.590
with gamma.

00:15:34.590 --> 00:15:37.633
And the gamma is because,
again, T is equal to 1.

00:15:37.633 --> 00:15:39.300
We're also going to
take this, and we're

00:15:39.300 --> 00:15:42.900
going to plug it in for,
now, this term over here,

00:15:42.900 --> 00:15:44.700
where T is equal to 0.

00:15:44.700 --> 00:15:48.600
And when T is equal to 0, then
the gamma term just disappears,

00:15:48.600 --> 00:15:53.370
and so you just get
beta X plus epsilon 0.

00:15:53.370 --> 00:15:58.620
So all I've done so far
is plug in the Y1 and Y0

00:15:58.620 --> 00:16:03.180
according to the assumed form,
but notice now that there are

00:16:03.180 --> 00:16:05.760
some terms that cancel
out-- in particular,

00:16:05.760 --> 00:16:08.310
the beta X term over
here cancels out

00:16:08.310 --> 00:16:10.840
with a beta X term over here.

00:16:10.840 --> 00:16:17.730
And because epsilon 1 has
0 mean, and epsilon 0 also

00:16:17.730 --> 00:16:19.200
has 0 mean,

00:16:19.200 --> 00:16:22.860
the only thing left
is that gamma term,

00:16:22.860 --> 00:16:25.630
and the expectation of a constant
is obviously that constant.

00:16:25.630 --> 00:16:27.510
And so what we
conclude from this

00:16:27.510 --> 00:16:30.510
is that the CATE value is gamma.

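NOTE
The derivation just described, written out as a worked equation. It uses the lecture's assumed model: one covariate x, binary treatment t, zero-mean noise, and no interaction between x and t.
```latex
\begin{aligned}
Y_t(x) &= \beta x + \gamma t + \epsilon_t, \qquad \mathbb{E}[\epsilon_t] = 0, \\
\mathrm{CATE}(x) &= \mathbb{E}[Y_1(x) - Y_0(x)]
  = \mathbb{E}[(\beta x + \gamma + \epsilon_1) - (\beta x + \epsilon_0)] \\
  &= \gamma + \mathbb{E}[\epsilon_1] - \mathbb{E}[\epsilon_0] = \gamma, \\
\mathrm{ATE} &= \mathbb{E}_{x \sim p(x)}[\mathrm{CATE}(x)] = \gamma .
\end{aligned}
```
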
00:16:33.020 --> 00:16:35.670
Now, the average
treatment effect,

00:16:35.670 --> 00:16:38.450
which is the average of
CATE over all individuals X,

00:16:38.450 --> 00:16:41.430
will then also be
gamma, obviously.

00:16:41.430 --> 00:16:43.430
So we've done something
pretty interesting here.

00:16:43.430 --> 00:16:45.560
We've started from
the assumption

00:16:45.560 --> 00:16:48.980
that the true potential
outcome model is linear,

00:16:48.980 --> 00:16:51.920
and what we concluded is that
the average treatment effect is

00:16:51.920 --> 00:16:56.210
precisely the coefficient
of the treatment

00:16:56.210 --> 00:16:58.580
variable in this linear model.

00:17:01.330 --> 00:17:04.770
So what that means is
that, if what you're

00:17:04.770 --> 00:17:08.190
interested in is
causal inference,

00:17:08.190 --> 00:17:10.980
and we
were lucky enough

00:17:10.980 --> 00:17:12.930
to know that the true
model were linear,

00:17:12.930 --> 00:17:15.810
and so we attempted to fit some
function F, which had precisely

00:17:15.810 --> 00:17:22.619
the same form, then we get some
beta hats and some gamma hats

00:17:22.619 --> 00:17:26.040
out of the learning
algorithm, and all we need to do

00:17:26.040 --> 00:17:28.650
is look at that
gamma hat in order

00:17:28.650 --> 00:17:31.470
to conclude something about
the average treatment effect.

00:17:31.470 --> 00:17:33.870
No need to do this
complicated thing of plugging

00:17:33.870 --> 00:17:36.950
in to estimate CATEs.

00:17:36.950 --> 00:17:39.270
And again, the reason it's
such a trivial conclusion

00:17:39.270 --> 00:17:42.520
is because of our
assumption of linearity.

00:17:42.520 --> 00:17:45.760
Now, what that also
means is that, if you

00:17:45.760 --> 00:17:48.820
have errors in
learning-- in particular,

00:17:48.820 --> 00:17:52.040
suppose, for example, that
you are estimating your gamma

00:17:52.040 --> 00:17:54.400
hat wrongly, then
that means you're also

00:17:54.400 --> 00:17:57.100
going to be getting
wrong your estimates

00:17:57.100 --> 00:18:00.085
of your conditional and
average treatment effects.

00:18:02.958 --> 00:18:05.000
There's a question here,
which I was lucky enough

00:18:05.000 --> 00:18:07.280
to see, that says, what
does gamma represent

00:18:07.280 --> 00:18:09.530
in terms of the medication?

00:18:09.530 --> 00:18:11.820
Thank you for that question.

00:18:11.820 --> 00:18:18.340
So gamma is--
literally speaking,

00:18:18.340 --> 00:18:21.730
gamma tells you the conditional
average treatment effect,

00:18:21.730 --> 00:18:26.850
meaning if you were to give
the treatment versus not

00:18:26.850 --> 00:18:29.100
giving the treatment, how
that affects the outcome.

00:18:29.100 --> 00:18:31.230
Think about the outcome
of interest being

00:18:31.230 --> 00:18:32.940
the patient's blood
pressure, there

00:18:32.940 --> 00:18:36.120
being a potential confounding
factor of the patient's age,

00:18:36.120 --> 00:18:40.140
and T being one of two different
blood pressure medications.

00:18:40.140 --> 00:18:43.830
If gamma is positive,
then it means

00:18:43.830 --> 00:18:46.740
that treatment 1 is more--

00:18:46.740 --> 00:18:49.770
treatment 1 increases the
patient's blood pressure

00:18:49.770 --> 00:18:51.540
relative to treatment 0.

00:18:51.540 --> 00:18:53.670
And if gamma is
negative, it means

00:18:53.670 --> 00:18:56.760
that treatment 1 decreases
the patient's blood pressure

00:18:56.760 --> 00:18:58.140
relative to treatment 0.

00:19:04.120 --> 00:19:06.140
So in machine learning--

00:19:06.140 --> 00:19:08.800
oh, sorry, there's another chat.

00:19:08.800 --> 00:19:10.240
Thank you, good.

00:19:10.240 --> 00:19:14.680
So in machine learning, I
typically tell my students,

00:19:14.680 --> 00:19:18.330
don't attempt to interpret
your coefficients.

00:19:18.330 --> 00:19:20.360
At least, don't
interpret them too much.

00:19:20.360 --> 00:19:22.430
Don't put too much
weight into them,

00:19:22.430 --> 00:19:24.460
and that's because,
when you're learning

00:19:24.460 --> 00:19:26.140
very high-dimensional
models, there

00:19:26.140 --> 00:19:29.220
can be a lot of redundancy
between your features.

00:19:29.220 --> 00:19:30.960
But when you talk
to statisticians,

00:19:30.960 --> 00:19:32.532
often they pay really
close attention

00:19:32.532 --> 00:19:33.990
to their coefficients,
and they try

00:19:33.990 --> 00:19:36.330
to interpret those coefficients
often through a causal lens.

00:19:36.330 --> 00:19:38.220
And when I first got
started in this field,

00:19:38.220 --> 00:19:40.470
I couldn't understand why
they were paying attention

00:19:40.470 --> 00:19:41.785
to those coefficients so much.

00:19:41.785 --> 00:19:44.160
Why are they coming up with
these causal hypotheses based

00:19:44.160 --> 00:19:45.577
on which coefficients
are positive

00:19:45.577 --> 00:19:46.770
and which are negative?

00:19:46.770 --> 00:19:48.930
And this is the answer.

00:19:48.930 --> 00:19:52.890
It really comes down
to an interpretation

00:19:52.890 --> 00:19:54.900
of the prediction
problem in terms

00:19:54.900 --> 00:19:57.840
of the feature of
relevance being

00:19:57.840 --> 00:20:01.650
a treatment, that treatment
being linear with respect

00:20:01.650 --> 00:20:04.590
to the potential
outcome, and then

00:20:04.590 --> 00:20:06.780
looking at the coefficient
of the treatment

00:20:06.780 --> 00:20:09.300
as telling you something
about the average treatment

00:20:09.300 --> 00:20:12.570
effect of that
intervention or treatment.

00:20:12.570 --> 00:20:16.200
Moreover, that also tells us
why it's often very important

00:20:16.200 --> 00:20:21.780
to look at confidence intervals.
Suppose

00:20:21.780 --> 00:20:28.350
we have some small data set, and we
get some estimate of gamma hat,

00:20:28.350 --> 00:20:31.740
but what if you had
a different data set?

00:20:31.740 --> 00:20:36.510
So what happens if you had a
new sample of 100 data points?

00:20:36.510 --> 00:20:38.227
How would your estimated
gamma hat vary?

00:20:38.227 --> 00:20:40.060
And so you might be
interested, for example,

00:20:40.060 --> 00:20:42.518
in confidence intervals, like
a 95% confidence interval that

00:20:42.518 --> 00:20:50.820
says that gamma is
between, let's say, 0.5

00:20:50.820 --> 00:20:58.170
and, let's say maybe, 1,
with 95% confidence.

00:20:58.170 --> 00:21:00.840
That'll be an example
of a confidence

00:21:00.840 --> 00:21:02.820
interval around gamma hat.

00:21:02.820 --> 00:21:04.377
And such a confidence
interval

00:21:04.377 --> 00:21:06.210
around the coefficient gamma

00:21:06.210 --> 00:21:07.440
then gives you

00:21:07.440 --> 00:21:09.870
confidence intervals around
the average treatment

00:21:09.870 --> 00:21:13.040
effect via this analysis.

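NOTE
A minimal sketch of reading the ATE off the treatment coefficient and putting a confidence interval around it. Assumptions, all hypothetical: data simulated from a correctly specified linear model with one covariate (age); treatment more likely for older patients; ols is a small hand-written least-squares solver used to avoid dependencies. The interval uses the percentile bootstrap, one common choice among several; analytic standard errors are another. It refits gamma_hat on resampled copies of the data, which mirrors the "what if you had a different data set?" question.
```python
import random
def ols(X, y):
    # Least squares via the normal equations (X^T X) b = X^T y,
    # solved with Gaussian elimination and partial pivoting.
    k = len(X[0])
    A = [[sum(r[i] * r[j] for r in X) for j in range(k)] for i in range(k)]
    b = [sum(r[i] * yi for r, yi in zip(X, y)) for i in range(k)]
    for col in range(k):
        piv = max(range(col, k), key=lambda r: abs(A[r][col]))
        A[col], A[piv] = A[piv], A[col]
        b[col], b[piv] = b[piv], b[col]
        for r in range(col + 1, k):
            m = A[r][col] / A[col][col]
            for c in range(col, k):
                A[r][c] -= m * A[col][c]
            b[r] -= m * b[col]
    coef = [0.0] * k
    for r in reversed(range(k)):
        coef[r] = (b[r] - sum(A[r][c] * coef[c] for c in range(r + 1, k))) / A[r][r]
    return coef
rng = random.Random(0)
BETA, GAMMA = 0.05, -1.5  # hypothetical "true" coefficients
def simulate(n):
    rows = []
    for _ in range(n):
        age = rng.uniform(20.0, 80.0)
        t = 1 if rng.random() < (age - 20.0) / 60.0 else 0  # older patients treated more often
        y = BETA * age + GAMMA * t + rng.gauss(0.0, 1.0)  # correctly specified linear model
        rows.append((age, t, y))
    return rows
def gamma_hat(rows):
    # Fit Y ~ 1 + age + T and return the coefficient on T, the ATE estimate here.
    coef = ols([(1.0, age, t) for age, t, _ in rows], [y for _, _, y in rows])
    return coef[2]
data = simulate(200)
g = gamma_hat(data)
# Percentile bootstrap: refit on resampled data sets to see how gamma_hat would vary.
B = 500
boot = sorted(gamma_hat([rng.choice(data) for _ in data]) for _ in range(B))
ci = (boot[int(0.025 * B)], boot[int(0.975 * B)])
print(round(g, 3), tuple(round(v, 3) for v in ci))
```
This reading of gamma hat as the ATE relies on the linear model being correct. The bootstrap interval only captures sampling variability, not misspecification.
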
00:21:13.040 --> 00:21:16.520
So the second
observation is what

00:21:16.520 --> 00:21:18.290
happens if the true
model isn't linear,

00:21:18.290 --> 00:21:20.510
but we hadn't realized
that as a modeler,

00:21:20.510 --> 00:21:24.975
and we had just assumed that,
well, the linear model's

00:21:24.975 --> 00:21:25.850
probably good enough?

00:21:25.850 --> 00:21:28.370
And maybe even, the linear model
gets pretty good prediction

00:21:28.370 --> 00:21:30.100
performance?

00:21:30.100 --> 00:21:33.690
Well, let's look at the
extreme example of this.

00:21:33.690 --> 00:21:37.110
Let's now assume that the
true data generating process,

00:21:37.110 --> 00:21:40.680
instead of being just
beta X plus gamma T,

00:21:40.680 --> 00:21:45.733
we're going to add in now a new
term, delta times X squared.

00:21:45.733 --> 00:21:50.508
Now, this is the
most naive extension

00:21:50.508 --> 00:21:52.050
of the original
linear model that you

00:21:52.050 --> 00:21:55.710
could imagine, because I'm not
even adding any interaction

00:21:55.710 --> 00:22:00.750
terms like 10 times XT.

00:22:00.750 --> 00:22:03.600
So no interaction
terms involving

00:22:03.600 --> 00:22:05.720
treatment and covariate.

00:22:05.720 --> 00:22:07.650
Treatment is still--
the potential outcome

00:22:07.650 --> 00:22:09.570
is still linear in treatment.

00:22:09.570 --> 00:22:11.670
We're just adding a
single nonlinear term

00:22:11.670 --> 00:22:15.500
involving one of the features.

00:22:15.500 --> 00:22:17.570
Now, if you compute
the average treatment

00:22:17.570 --> 00:22:19.340
effect via the same
analysis we did

00:22:19.340 --> 00:22:24.190
before, you'll again find that
the average treatment effect is gamma.

00:22:24.190 --> 00:22:26.740
Let's suppose now
that we hadn't known

00:22:26.740 --> 00:22:29.480
that there was that delta
X squared term in there,

00:22:29.480 --> 00:22:33.670
and we hypothesized that the
potential outcome was given

00:22:33.670 --> 00:22:36.310
to you by this linear
model involving X and T.

00:22:36.310 --> 00:22:39.190
And I'm going to use Y
hat to denote that that's

00:22:39.190 --> 00:22:42.580
going to be the function family
that we're going to be fitting.

00:22:42.580 --> 00:22:45.590
So we now fit that
beta hat and gamma hat,

00:22:45.590 --> 00:22:49.300
and if you had infinite data
drawn from this true generating

00:22:49.300 --> 00:22:52.750
process, which is, again,
unknown, what one can show

00:22:52.750 --> 00:22:55.840
is that the gamma hat
that you would estimate

00:22:55.840 --> 00:22:59.890
using any reasonable estimator,
like a least squares estimator,

00:22:59.890 --> 00:23:04.600
is actually equal to
gamma, the true ATE value,

00:23:04.600 --> 00:23:08.370
plus delta times this term.

00:23:08.370 --> 00:23:14.500
And notice that this term does
not depend on beta or gamma.

00:23:14.500 --> 00:23:18.290
What this means is,
depending on delta,

00:23:18.290 --> 00:23:21.070
your gamma hat could be
made arbitrarily large

00:23:21.070 --> 00:23:22.400
or arbitrarily small.

00:23:22.400 --> 00:23:25.480
So for example, if
delta is very large,

00:23:25.480 --> 00:23:28.090
gamma hat might
become positive when

00:23:28.090 --> 00:23:29.620
gamma might have been negative.

00:23:29.620 --> 00:23:32.560
And so your conclusions about
the average treatment effect

00:23:32.560 --> 00:23:36.910
could be completely wrong,
and this should scare you.

00:23:36.910 --> 00:23:41.380
This is the thing which makes
using covariate adjustments so

00:23:41.380 --> 00:23:43.030
dangerous, which
is that if you're

00:23:43.030 --> 00:23:46.300
making the wrong assumptions
about the true potential

00:23:46.300 --> 00:23:51.740
outcomes, you could get
very, very wrong conclusions.

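NOTE
A quick numerical check of this failure mode. This is an illustrative sketch, not the lecture's own code. It assumes X ~ N(0, 1), a made-up confounded policy P(T = 1 | X) = sigmoid(X^2 - 1), no outcome noise, and an intercept in the fitted model. It fits the misspecified model Y ~ b0 + beta X + gamma T by least squares. It then compares gamma hat with gamma + delta * Cov(T_res, X^2) / Var(T_res), where T_res is T after regressing out (1, X). This is the omitted-variable form given by the Frisch-Waugh-Lovell theorem, and it involves only X and T, consistent with the remark that the term does not depend on beta or gamma.
```python
import numpy as np
def simulate(n=100_000, beta=2.0, gamma=-1.0, delta=3.0, seed=0):
    """True outcome: beta*x + gamma*t + delta*x**2 (noise-free for clarity).
    Hypothetical confounded policy: P(T=1 | x) = sigmoid(x**2 - 1)."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=n)
    p_treat = 1.0 / (1.0 + np.exp(-(x**2 - 1.0)))
    t = (rng.random(n) < p_treat).astype(float)
    y = beta * x + gamma * t + delta * x**2
    return x, t, y
def fit_misspecified(x, t, y):
    """Least squares for y ~ b0 + beta*x + gamma*t (the x**2 term is omitted)."""
    design = np.column_stack([np.ones_like(x), x, t])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef  # [b0_hat, beta_hat, gamma_hat]
def omitted_term(x, t):
    """Cov(t_res, x**2) / Var(t_res), with t_res = t minus its projection on (1, x)."""
    base = np.column_stack([np.ones_like(x), x])
    proj, *_ = np.linalg.lstsq(base, t, rcond=None)
    t_res = t - base @ proj
    return (t_res @ x**2) / (t_res @ t_res)
gamma, delta = -1.0, 3.0
x, t, y = simulate(gamma=gamma, delta=delta)
gamma_hat = fit_misspecified(x, t, y)[2]
print("true gamma:", gamma)
print("gamma_hat:", gamma_hat)
print("gamma + delta * term:", gamma + delta * omitted_term(x, t))
```
With these made-up values, the fitted gamma hat comes out positive even though the true gamma is -1. This is the sign flip described above.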
00:23:51.740 --> 00:23:55.960
So because of
that, one typically

00:23:55.960 --> 00:23:57.640
wants to live in
a world where you

00:23:57.640 --> 00:24:00.473
don't have to make many
assumptions about the form,

00:24:00.473 --> 00:24:02.890
so that you could try to fit
the data as well as possible.

00:24:02.890 --> 00:24:05.740
So here, you see that there
is this nonlinear term.

00:24:05.740 --> 00:24:09.670
Well, obviously, if you had
used some nonlinear modeling

00:24:09.670 --> 00:24:12.280
algorithm, like a neural network
or maybe a random forest,

00:24:12.280 --> 00:24:15.070
then it would have the potential
to fit that nonlinear function,

00:24:15.070 --> 00:24:19.570
and then maybe we wouldn't
get caught in this same trap.

00:24:19.570 --> 00:24:21.820
And there are a variety of
machine learning algorithms

00:24:21.820 --> 00:24:24.550
that have been applied to
causal inference, everything

00:24:24.550 --> 00:24:28.060
from random forests and
Bayesian additive regression

00:24:28.060 --> 00:24:30.760
trees to algorithms
like Gaussian processes

00:24:30.760 --> 00:24:32.090
and deep neural networks.

00:24:32.090 --> 00:24:35.100
I'll just briefly
highlight the last two.

00:24:35.100 --> 00:24:37.170
So Gaussian processes
are very often

00:24:37.170 --> 00:24:41.910
used to model continuous
valued potential outcomes,

00:24:41.910 --> 00:24:44.310
and there are a couple of ways
in which they can be done.

00:24:44.310 --> 00:24:46.680
So for example,
one class of models

00:24:46.680 --> 00:24:52.650
might treat Y1 and Y0 as two
separate Gaussian processes

00:24:52.650 --> 00:24:56.580
and fit those to the data.

00:24:56.580 --> 00:24:58.950
A different approach,
shown on the right here,

00:24:58.950 --> 00:25:09.540
would be to treat T as
an additional covariate,

00:25:09.540 --> 00:25:13.320
so now you have X and
T as your features

00:25:13.320 --> 00:25:19.090
and fit a Gaussian process
for that joint model.

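NOTE
The two Gaussian-process variants can be sketched with a hand-written GP posterior mean in numpy. This is an illustrative sketch, not code from the lecture. The RBF kernel, its length scale of 1, the noise variance, and the simulated data (a true effect of 1.5 for everyone) are all assumptions. A real analysis would also fit the kernel hyperparameters and use the posterior variance.
```python
import numpy as np
def rbf(a, b, length_scale=1.0):
    """Squared-exponential kernel between rows of a (n, d) and b (m, d)."""
    sq = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-0.5 * sq / length_scale**2)
def gp_mean(x_train, y_train, x_test, noise_var=0.01):
    """Posterior mean of a zero-mean GP regression model."""
    k = rbf(x_train, x_train) + noise_var * np.eye(len(x_train))
    alpha = np.linalg.solve(k, y_train)
    return rbf(x_test, x_train) @ alpha
rng = np.random.default_rng(0)
n = 200
x = rng.uniform(-2, 2, size=(n, 1))
t = rng.integers(0, 2, size=n)
y = np.sin(x[:, 0]) + 1.5 * t + 0.1 * rng.normal(size=n)  # true effect = 1.5
grid = np.linspace(-1.5, 1.5, 7)[:, None]
# Variant 1: two separate GPs, one fit on treated rows (Y1), one on control rows (Y0).
cate_separate = (gp_mean(x[t == 1], y[t == 1], grid)
                 - gp_mean(x[t == 0], y[t == 0], grid))
# Variant 2: one joint GP with T appended as an extra feature.
xt = np.column_stack([x, t])
grid1 = np.column_stack([grid, np.ones(len(grid))])
grid0 = np.column_stack([grid, np.zeros(len(grid))])
cate_joint = gp_mean(xt, y, grid1) - gp_mean(xt, y, grid0)
print(np.round(cate_separate, 2))
print(np.round(cate_joint, 2))
```
The separate-GP version cannot share information between arms. The joint version can, through the kernel's correlation across the T dimension.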
00:25:19.090 --> 00:25:20.650
When it comes to
neural networks,

00:25:20.650 --> 00:25:23.680
neural networks had been used
in causal inference going back

00:25:23.680 --> 00:25:31.270
about 20, 30 years, but really
started catching on a few years

00:25:31.270 --> 00:25:35.810
ago with a paper that
my group wrote, which was

00:25:35.810 --> 00:25:37.700
one of
the earliest papers

00:25:37.700 --> 00:25:40.220
from this recent generation
of using neural networks

00:25:40.220 --> 00:25:42.060
for causal inference.

00:25:42.060 --> 00:25:46.100
And one of the things that we
found to work very effectively

00:25:46.100 --> 00:25:52.110
is to use a joint model for
predicting the causal effect,

00:25:52.110 --> 00:25:56.330
so we're going to be
learning a model that takes--

00:25:56.330 --> 00:26:05.090
an F that takes, as input,
X and T and has to predict

00:26:05.090 --> 00:26:08.720
Y. And the advantage
of that is that it's

00:26:08.720 --> 00:26:12.810
going to allow us to share
parameters across your T

00:26:12.810 --> 00:26:15.110
equals 1 and T equals 0 samples.

00:26:15.110 --> 00:26:18.770
But rather than
feeding in X and T

00:26:18.770 --> 00:26:21.020
in your first layer of
your neural network,

00:26:21.020 --> 00:26:25.377
we're only going to feed
in X in the initial layer

00:26:25.377 --> 00:26:26.960
of the neural network,
and we're going

00:26:26.960 --> 00:26:28.850
to learn a shared
representation, which

00:26:28.850 --> 00:26:30.590
is going to be used
for both predicting

00:26:30.590 --> 00:26:33.020
T equals 0 and T equals 1.

00:26:33.020 --> 00:26:38.570
And then for predicting
when T is equal to 0,

00:26:38.570 --> 00:26:43.730
we use a different head
than for predicting T equals 1.

00:26:43.730 --> 00:26:48.500
So F0 is a function that
concatenates these shared

00:26:48.500 --> 00:26:50.750
layers with several
new layers used

00:26:50.750 --> 00:26:53.710
to predict for when
T is equal to 0

00:26:53.710 --> 00:26:55.190
and same analogously for 1.

00:26:55.190 --> 00:26:59.270
And we found that architecture
worked substantially better

00:26:59.270 --> 00:27:00.782
than the naive
architectures when

00:27:00.782 --> 00:27:02.990
doing causal inference on
several different benchmark

00:27:02.990 --> 00:27:03.490
data sets.

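NOTE
A minimal numpy sketch of the kind of architecture described above: shared layers that see only X, followed by separate heads f0 and f1. It covers only the forward pass and a factual loss. The layer sizes, ReLU activations, and random initialization are assumptions, and training (normally done with an autodiff framework) is omitted. The key detail is that each sample's loss uses only the head matching its observed treatment, while both heads read the same shared representation.
```python
import numpy as np
rng = np.random.default_rng(0)
def init_layer(n_in, n_out):
    """Random weights and zero bias for one dense layer (hypothetical init)."""
    return rng.normal(scale=1.0 / np.sqrt(n_in), size=(n_in, n_out)), np.zeros(n_out)
def dense(h, layer, relu=True):
    w, b = layer
    z = h @ w + b
    return np.maximum(z, 0.0) if relu else z
d, hidden = 5, 16
shared = [init_layer(d, hidden), init_layer(hidden, hidden)]  # input is X only
head0 = [init_layer(hidden, hidden), init_layer(hidden, 1)]   # predicts Y when T = 0
head1 = [init_layer(hidden, hidden), init_layer(hidden, 1)]   # predicts Y when T = 1
def representation(x):
    h = x
    for layer in shared:
        h = dense(h, layer)
    return h
def head(h, layers):
    h = dense(h, layers[0])
    return dense(h, layers[1], relu=False)[:, 0]
def predict_potential_outcomes(x):
    """Return (f0(x), f1(x)): both heads applied to the shared representation."""
    h = representation(x)
    return head(h, head0), head(h, head1)
def factual_loss(x, t, y):
    """Mean squared error using only the head for each sample's observed treatment."""
    y0_hat, y1_hat = predict_potential_outcomes(x)
    y_hat = np.where(t == 1, y1_hat, y0_hat)
    return np.mean((y - y_hat) ** 2)
x = rng.normal(size=(8, d))
t = rng.integers(0, 2, size=8)
y = rng.normal(size=8)
y0_hat, y1_hat = predict_potential_outcomes(x)
print("estimated CATE per sample:", np.round(y1_hat - y0_hat, 3))
print("factual loss:", factual_loss(x, t, y))
```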
00:27:06.170 --> 00:27:08.980
Now, the last thing
I want to talk about

00:27:08.980 --> 00:27:11.340
for covariate
adjustment, before I

00:27:11.340 --> 00:27:13.940
move on to a new
set of techniques,

00:27:13.940 --> 00:27:17.030
is a method called
matching, that

00:27:17.030 --> 00:27:21.260
is intuitively very pleasing.

00:27:21.260 --> 00:27:24.530
It's a very-- would seem to
be a really natural approach

00:27:24.530 --> 00:27:27.680
to do causal inference,
and at first glance,

00:27:27.680 --> 00:27:31.160
may look like it has nothing
to do with covariate adjustment

00:27:31.160 --> 00:27:32.693
techniques.

00:27:32.693 --> 00:27:34.860
What I'll do now is I'm
going to first introduce you

00:27:34.860 --> 00:27:36.650
to the matching
technique, and then I

00:27:36.650 --> 00:27:38.550
will show you that it
actually is precisely

00:27:38.550 --> 00:27:41.010
identical to
covariate adjustment

00:27:41.010 --> 00:27:42.600
with a particular
assumption of what

00:27:42.600 --> 00:27:44.550
the functional family for F is.

00:27:44.550 --> 00:27:47.370
So not Gaussian processes,
not deep neural networks,

00:27:47.370 --> 00:27:49.150
but it'll be something else.

00:27:49.150 --> 00:27:52.110
So before I get into
that, what is matching as

00:27:52.110 --> 00:27:54.390
a technique for
causal inference?

00:27:54.390 --> 00:27:56.070
Well, the key idea
of matching is

00:27:56.070 --> 00:28:00.870
to use each
individual's twin to try

00:28:00.870 --> 00:28:03.630
to get some intuition about what
their potential outcome might

00:28:03.630 --> 00:28:04.320
have been.

00:28:04.320 --> 00:28:08.520
So I created these
slides a few years ago

00:28:08.520 --> 00:28:10.500
when President
Obama was in office,

00:28:10.500 --> 00:28:14.520
and you might imagine this
is the actual President

00:28:14.520 --> 00:28:16.500
Obama who did go to law school.

00:28:16.500 --> 00:28:22.260
And you might imagine who might
have been that other president?

00:28:22.260 --> 00:28:24.480
What would President Obama
have been like had he not

00:28:24.480 --> 00:28:28.512
gone to law school, but let's
say, gone to business school?

00:28:28.512 --> 00:28:30.970
So if you can now imagine trying
to find, in your data set,

00:28:30.970 --> 00:28:35.330
someone else who looks
just like Barack Obama,

00:28:35.330 --> 00:28:38.010
but who, instead of
going to law school,

00:28:38.010 --> 00:28:40.010
went to business school,
and you would then

00:28:40.010 --> 00:28:41.360
ask the following question.

00:28:41.360 --> 00:28:45.260
For example, would
this individual

00:28:45.260 --> 00:28:47.840
have gone on to
become president had

00:28:47.840 --> 00:28:50.510
he gone to law school versus
had he gone to business school?

00:28:50.510 --> 00:28:52.070
If you find someone
else who's just

00:28:52.070 --> 00:28:55.100
like Barack Obama who went to
business school, and look to see

00:28:55.100 --> 00:28:57.350
whether that person eventually
became president,

00:28:57.350 --> 00:29:00.880
that would in essence give
you that counterfactual.

00:29:00.880 --> 00:29:02.620
Obviously, this is
a contrived example

00:29:02.620 --> 00:29:07.090
because you would never get
the sample size to see that.

00:29:07.090 --> 00:29:10.440
So that's the general idea,
and now, I'll show it to you

00:29:10.440 --> 00:29:12.010
in a picture.

00:29:12.010 --> 00:29:16.530
So here now, we have two
covariates or features--

00:29:16.530 --> 00:29:21.630
a patient's age and their
Charlson comorbidity index.

00:29:21.630 --> 00:29:25.350
This is some measure
of how many--

00:29:25.350 --> 00:29:27.272
what types of conditions
or comorbidities

00:29:27.272 --> 00:29:28.230
the patient might have.

00:29:28.230 --> 00:29:32.410
Do they have diabetes, do they
have hypertension, and so on?

00:29:32.410 --> 00:29:34.930
And notably, what
I'm not showing

00:29:34.930 --> 00:29:38.710
you here is the
outcome Y. All I'm

00:29:38.710 --> 00:29:40.960
showing you are the
original data points

00:29:40.960 --> 00:29:43.150
and what treatment
did they receive.

00:29:43.150 --> 00:29:46.750
So blue are the individuals who
received the control treatment,

00:29:46.750 --> 00:29:49.405
or T equals 0, and red
are the individuals

00:29:49.405 --> 00:29:53.260
who received treatment 1.

00:29:53.260 --> 00:29:55.480
So you can imagine trying
to find nearest neighbors.

00:29:55.480 --> 00:29:57.400
For example, the
nearest neighbor

00:29:57.400 --> 00:30:02.050
to this data point over here
is this blue point over here,

00:30:02.050 --> 00:30:06.790
and so, well,
we observed

00:30:06.790 --> 00:30:15.370
some Y1 for this individual, and
we observed some Y0

00:30:15.370 --> 00:30:17.460
for this individual.

00:30:17.460 --> 00:30:19.360
And if you wanted
to know, well, what

00:30:19.360 --> 00:30:22.690
would have happened
to this individual

00:30:22.690 --> 00:30:26.085
if they had received treatment
0 instead of treatment 1,

00:30:26.085 --> 00:30:27.460
well, you could
just look at what

00:30:27.460 --> 00:30:29.470
happened to this
blue point and say,

00:30:29.470 --> 00:30:31.553
that's what would have
happened to this red point,

00:30:31.553 --> 00:30:34.360
because they're very
close to each other.

00:30:34.360 --> 00:30:35.968
Any questions
about what matching

00:30:35.968 --> 00:30:37.510
would do before I
define it formally?

00:30:50.572 --> 00:30:55.090
Here, I'll-- yeah,
good, one question.

00:30:55.090 --> 00:30:57.760
What happens if the nearest
neighbor is extremely far away?

00:30:57.760 --> 00:30:59.180
That's a great question.

00:30:59.180 --> 00:31:06.640
So you can imagine that you have
one red data point over here

00:31:06.640 --> 00:31:10.013
and no blue data points nearby.

00:31:10.013 --> 00:31:11.930
The matching approach
wouldn't work very well.

00:31:11.930 --> 00:31:13.750
So this data point,
the nearest neighbor,

00:31:13.750 --> 00:31:17.740
is this blue point over here,
which intuitively, is very far

00:31:17.740 --> 00:31:19.030
from this red point.

00:31:19.030 --> 00:31:23.650
And so if we were to estimate
this red point's counterfactual

00:31:23.650 --> 00:31:25.690
using that blue point,
we're likely to get

00:31:25.690 --> 00:31:27.460
a very bad estimate,
and in fact, that

00:31:27.460 --> 00:31:29.590
is going to be one of the
challenges of matching

00:31:29.590 --> 00:31:30.550
based approaches.

00:31:30.550 --> 00:31:33.610
It's going to work really well
in a high dimensional setting

00:31:33.610 --> 00:31:37.690
where you can imagine--
sorry, in a large--

00:31:37.690 --> 00:31:40.270
it's going to work very well in
a large sample setting, where

00:31:40.270 --> 00:31:44.710
you can hope that you're likely
to observe a counterfactual

00:31:44.710 --> 00:31:45.970
for every individual.

00:31:45.970 --> 00:31:48.407
And it won't work well when you
have very limited data,

00:31:48.407 --> 00:31:49.990
and of course, all
this is going to be

00:31:49.990 --> 00:31:53.638
subject to the assumption
of common support.

00:31:53.638 --> 00:31:55.180
So one question is,
how does that

00:31:55.180 --> 00:31:56.472
translate into high dimensions?

00:31:56.472 --> 00:31:58.120
The short answer--
not very well.

00:31:58.120 --> 00:32:01.200
We'll get back to
that in a moment.

00:32:01.200 --> 00:32:05.200
Can a single data point
appear in multiple matchings?

00:32:05.200 --> 00:32:10.840
Yes, and I will define, in
just a moment, how and why.

00:32:10.840 --> 00:32:14.160
It won't be a strict matching.

00:32:14.160 --> 00:32:15.870
Are we trying to
find a counterfactual

00:32:15.870 --> 00:32:19.740
for each treated observation,
or one for each control

00:32:19.740 --> 00:32:20.490
observation?

00:32:20.490 --> 00:32:22.860
I'll answer that
in just a second.

00:32:22.860 --> 00:32:25.050
And finally, is it common
for medical data sets

00:32:25.050 --> 00:32:26.610
to find such matching pairs?

00:32:26.610 --> 00:32:29.430
I'm going to reinterpret
that question as saying,

00:32:29.430 --> 00:32:32.190
is this technique used
often in medicine?

00:32:32.190 --> 00:32:34.140
And the answer
is, yes, it's used

00:32:34.140 --> 00:32:38.340
all the time in clinical
research despite the fact

00:32:38.340 --> 00:32:43.102
that biostatisticians,
for quite a few years now,

00:32:43.102 --> 00:32:45.060
have been trying to argue
that folks should not

00:32:45.060 --> 00:32:48.820
use this technique for
reasons that you'll see shortly.

00:32:48.820 --> 00:32:50.310
So it's widely used.

00:32:50.310 --> 00:32:52.830
It's very intuitive, which
is why I'm teaching it.

00:32:52.830 --> 00:32:55.170
And it's going to fit into
a very general framework,

00:32:55.170 --> 00:32:56.880
as you'll see in just a
moment, which will give you

00:32:56.880 --> 00:32:58.530
the natural solution
for the problems

00:32:58.530 --> 00:33:00.030
that I'm going to raise.

00:33:00.030 --> 00:33:02.010
So moving on, and
then I'll return

00:33:02.010 --> 00:33:05.340
to any remaining questions.

00:33:05.340 --> 00:33:12.320
So here, I'll define one way of
doing counterfactual inference

00:33:12.320 --> 00:33:14.517
using matching, and it's
going to start, of course,

00:33:14.517 --> 00:33:16.100
by assuming that we
have some distance

00:33:16.100 --> 00:33:19.440
metric d between individuals.

00:33:19.440 --> 00:33:22.830
Then we're going to say,
for each individual i,

00:33:22.830 --> 00:33:28.830
let's let j of i be the
other individual j, obviously

00:33:28.830 --> 00:33:35.580
different from i, who is closest
to i, but critically, closest

00:33:35.580 --> 00:33:37.620
but has a different treatment.

00:33:37.620 --> 00:33:43.950
So where Ti is different
from Tj, and again,

00:33:43.950 --> 00:33:49.100
I'm assuming binary,
so Tj is either 0 or 1.

00:33:49.100 --> 00:33:51.080
With that definition
then, we're going

00:33:51.080 --> 00:33:57.560
to define our estimate of the
conditional average treatment

00:33:57.560 --> 00:34:01.790
effect for an individual as
whatever their actual observed

00:34:01.790 --> 00:34:04.700
outcome was.

00:34:04.700 --> 00:34:06.680
This, I'm going to give
for an individual that

00:34:06.680 --> 00:34:10.504
actually received treatment 1,
so it's Y1, and the reason--

00:34:10.504 --> 00:34:17.810
it's Yi minus the imputed
counterfactual corresponding

00:34:17.810 --> 00:34:19.550
to T is equal to 0.

00:34:19.550 --> 00:34:22.040
And the way we get that
imputed counterfactual

00:34:22.040 --> 00:34:25.550
is by trying to find
that nearest neighbor who

00:34:25.550 --> 00:34:28.130
received treatment 0
instead of treatment 1

00:34:28.130 --> 00:34:30.440
and looking at their Y.

00:34:30.440 --> 00:34:33.679
Analogously, if T is
equal to 0, then we're

00:34:33.679 --> 00:34:37.340
going to use the
observed Yi, now

00:34:37.340 --> 00:34:41.480
over here instead of over there
because it corresponds to Y0.

00:34:41.480 --> 00:34:45.560
And where we need to impute Y1--

00:34:45.560 --> 00:34:47.690
capital Y1, potential
outcome Y1--

00:34:47.690 --> 00:34:51.320
we're going to use the observed
outcome from the nearest

00:34:51.320 --> 00:34:55.940
neighbor of individual i who
received treatment 1 instead

00:34:55.940 --> 00:34:57.900
of 0.

00:34:57.900 --> 00:35:01.850
So this, mathematically, is what
I mean by our matching based

00:35:01.850 --> 00:35:05.330
estimator, and this
also should answer

00:35:05.330 --> 00:35:07.680
one of the questions which
was raised, which is,

00:35:07.680 --> 00:35:09.380
do you really need
to have a one-to-one matching,

00:35:09.380 --> 00:35:13.953
or could a data point be matched
to multiple other data points?

00:35:13.953 --> 00:35:16.370
And indeed, here, you see the
answer to that last question

00:35:16.370 --> 00:35:19.650
is yes, because you could have
a setting where, for example,

00:35:19.650 --> 00:35:22.970
there are two red points here.

00:35:22.970 --> 00:35:25.220
And I can't draw
blue, but I'll just

00:35:25.220 --> 00:35:28.160
use a square for what I
would have drawn as blue.

00:35:28.160 --> 00:35:30.350
And then everything
else very far away,

00:35:30.350 --> 00:35:32.300
and for both of
these red points,

00:35:32.300 --> 00:35:35.530
this blue point is
the closest neighbor.

00:35:35.530 --> 00:35:40.070
So both of the counterfactual
estimates for these two points

00:35:40.070 --> 00:35:42.230
would be using the
same blue point,

00:35:42.230 --> 00:35:44.990
so that's the answer
to that question.

00:35:44.990 --> 00:35:47.300
Now, I'm just going to
rewrite this in a little bit

00:35:47.300 --> 00:35:48.290
more convenient form.

00:35:48.290 --> 00:35:51.020
So I'll take this
formula, shown over here,

00:35:51.020 --> 00:35:56.360
and you can rewrite
that as Yi minus Yji,

00:35:56.360 --> 00:35:58.010
but you have to flip
the sign depending

00:35:58.010 --> 00:36:00.080
on whether Ti is
equal to 1 or 0,

00:36:00.080 --> 00:36:02.600
and so that's what this
term is going to do.

00:36:02.600 --> 00:36:06.110
If Ti is equal to 1,
then this evaluates to 1.

00:36:06.110 --> 00:36:08.650
If Ti is equal to 0, this
evaluates to minus 1.

00:36:08.650 --> 00:36:10.220
Flips the sign.

00:36:10.220 --> 00:36:13.640
So now that we have
the definition of CATE,

00:36:13.640 --> 00:36:17.690
we can now easily estimate
the average treatment effect

00:36:17.690 --> 00:36:20.390
by just averaging these CATEs
over all of the individuals

00:36:20.390 --> 00:36:22.880
in your data set.

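NOTE
Written out, with d the distance metric and j(i) the nearest individual to i who received the other treatment, the one-nearest-neighbor matching estimates described above are:
```latex
j(i) = \arg\min_{j \,:\, T_j \neq T_i} d(x_i, x_j), \qquad
\widehat{\mathrm{CATE}}(x_i) = (2T_i - 1)\bigl(y_i - y_{j(i)}\bigr), \qquad
\widehat{\mathrm{ATE}} = \frac{1}{n}\sum_{i=1}^{n} \widehat{\mathrm{CATE}}(x_i)
```
For T_i = 1 the first factor is +1, giving observed Y1 minus imputed Y0. For T_i = 0 it is -1, giving imputed Y1 minus observed Y0.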
00:36:22.880 --> 00:36:26.210
So this is now the
definition of how

00:36:26.210 --> 00:36:29.390
to do one nearest
neighbor matching.

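NOTE
A minimal sketch of one-nearest-neighbor matching as defined above. It assumes Euclidean distance for d and uses a tiny made-up data set. Matching is with replacement, so one control can serve as the match for several treated individuals.
```python
import numpy as np
def matching_ate(x, t, y):
    """One-nearest-neighbor matching (with replacement), Euclidean distance.
    For each i, j(i) is the closest individual with the opposite treatment, and
    CATE_i = (2 t_i - 1) * (y_i - y_j(i)); the ATE estimate is their mean."""
    x = np.asarray(x, dtype=float)
    t = np.asarray(t)
    y = np.asarray(y, dtype=float)
    dist = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    dist[t[:, None] == t[None, :]] = np.inf  # only match across treatment groups
    j = dist.argmin(axis=1)
    cate = (2 * t - 1) * (y - y[j])
    return cate.mean(), cate, j
x = [[0.0], [0.1], [5.0], [5.2]]
t = [1, 0, 1, 0]
y = [3.0, 1.0, 10.0, 7.0]
ate, cate, j = matching_ate(x, t, y)
print(j)     # [1 0 3 2]
print(cate)  # [2. 2. 3. 3.]
print(ate)   # 2.5
```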
00:36:29.390 --> 00:36:30.170
Any questions?

00:36:36.390 --> 00:36:39.770
So one question is, do we ever
use the metric d to weight

00:36:39.770 --> 00:36:42.530
how much we would, quote,
unquote, "trust" the matching?

00:36:45.402 --> 00:36:46.360
That's a good question.

00:36:46.360 --> 00:36:53.190
So what Hannah's
asking is, what happens

00:36:53.190 --> 00:36:55.740
if you have, for example,
very many nearest

00:36:55.740 --> 00:36:58.380
neighbors, or
analogously, what happens

00:36:58.380 --> 00:37:00.510
if you have some
nearest neighbors that

00:37:00.510 --> 00:37:02.700
are really close, some
that are really far?

00:37:02.700 --> 00:37:06.390
You might imagine trying to
weight your nearest neighbors

00:37:06.390 --> 00:37:09.540
by the distance
from the data point,

00:37:09.540 --> 00:37:12.318
and you could imagine
even doing that--

00:37:12.318 --> 00:37:14.610
you can even imagine coming
up with an estimator, which

00:37:14.610 --> 00:37:17.460
might discount certain data
points if they don't have

00:37:17.460 --> 00:37:19.610
nearest neighbors
near them at all

00:37:19.610 --> 00:37:21.300
by the corresponding
weighting factor.

00:37:21.300 --> 00:37:22.860
Yes, that's a good idea.

00:37:22.860 --> 00:37:25.200
Yes, you can come up with
a consistent estimator

00:37:25.200 --> 00:37:28.620
of the average treatment
effect through such an idea.

00:37:28.620 --> 00:37:31.650
There are probably a few
hundred papers written about it,

00:37:31.650 --> 00:37:34.030
and that's all I
have to say about it.

00:37:34.030 --> 00:37:37.470
So there are lots of variants
of this, and they all end

00:37:37.470 --> 00:37:40.500
up having the same
theoretical justification

00:37:40.500 --> 00:37:43.240
that I'm about to give
in the next slide.

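NOTE
One hypothetical way to act on Hannah's suggestion is to impute each counterfactual from the k nearest opposite-treatment neighbors, weighted by inverse distance. The function name, the choice of k, and the 1 / (d + eps) weights are assumptions for illustration. This is one variant among many, not a specific published estimator. With k = 1 it reduces to plain one-nearest-neighbor matching.
```python
import numpy as np
def weighted_matching_ate(x, t, y, k=3, eps=1e-8):
    x = np.asarray(x, dtype=float)
    t = np.asarray(t)
    y = np.asarray(y, dtype=float)
    dist = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    dist[t[:, None] == t[None, :]] = np.inf  # candidates must have the other treatment
    cate = np.empty(len(y))
    for i in range(len(y)):
        nbrs = np.argsort(dist[i])[:k]           # k closest opposite-treatment points
        nbrs = nbrs[np.isfinite(dist[i, nbrs])]  # drop same-treatment padding if k is large
        w = 1.0 / (dist[i, nbrs] + eps)          # closer neighbors count more
        y_cf = np.sum(w * y[nbrs]) / np.sum(w)   # imputed counterfactual outcome
        cate[i] = (2 * t[i] - 1) * (y[i] - y_cf)
    return cate.mean(), cate
x = [[0.0], [0.1], [0.3], [5.0], [5.2]]
t = [1, 0, 0, 1, 0]
y = [3.0, 1.0, 2.0, 10.0, 7.0]
print(weighted_matching_ate(x, t, y, k=1)[0])
print(weighted_matching_ate(x, t, y, k=2)[0])
```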
00:37:43.240 --> 00:37:46.980
So one of the
advantages of matching

00:37:46.980 --> 00:37:48.960
is that you get some
interpretability.

00:37:48.960 --> 00:37:51.150
So if I was to ask
you, well, what's

00:37:51.150 --> 00:37:54.630
the reason why you tell
me that this treatment is

00:37:54.630 --> 00:37:56.370
going to work for John?

00:37:56.370 --> 00:37:57.900
Well, someone can respond--

00:37:57.900 --> 00:38:02.100
well, I used this
technique, and I

00:38:02.100 --> 00:38:09.030
found that the nearest
neighbor to John was Anna.

00:38:09.030 --> 00:38:11.430
And Anna took this other
treatment from John,

00:38:11.430 --> 00:38:12.990
and this is what
happened for Anna.

00:38:12.990 --> 00:38:15.990
And that's why I
conjecture that, for John,

00:38:15.990 --> 00:38:18.870
the difference between
Y1 and Y0 is as follows.

00:38:18.870 --> 00:38:21.000
And so then, that
can be criticized.

00:38:21.000 --> 00:38:23.820
So for example, a clinician
who has some domain expertise,

00:38:23.820 --> 00:38:27.940
can look at Anna, look at John,
and say, oh, wait a second,

00:38:27.940 --> 00:38:30.720
these two individuals are really
different from one another.

00:38:30.720 --> 00:38:32.640
Let's say the
treatment, for example,

00:38:32.640 --> 00:38:35.820
had to do with something
which was gender specific.

00:38:39.487 --> 00:38:41.820
Two individuals
of different genders

00:38:41.820 --> 00:38:44.070
are obviously not going to
be comparable to one another,

00:38:44.070 --> 00:38:46.020
and so then the
domain expert would

00:38:46.020 --> 00:38:48.720
be able to reject that
conclusion and say,

00:38:48.720 --> 00:38:50.760
nuh-uh, I don't trust
any of these statistics.

00:38:50.760 --> 00:38:52.710
Go back to the drawing board.

00:38:52.710 --> 00:38:58.570
And so this type of interpretability
is very attractive.

00:38:58.570 --> 00:39:01.660
The second aspect of this,
which is very attractive

00:39:01.660 --> 00:39:03.790
is that it's a
non-parametric method,

00:39:03.790 --> 00:39:07.270
non-parametric in the same
way that neural networks

00:39:07.270 --> 00:39:08.890
or random forests
are non-parametric.

00:39:08.890 --> 00:39:13.930
So this does not rely
on any strong assumption

00:39:13.930 --> 00:39:18.702
about the parametric form
of the potential outcomes.

00:39:18.702 --> 00:39:20.160
On the other hand,
this approach is

00:39:20.160 --> 00:39:22.560
very reliant on the
underlying metric.

00:39:22.560 --> 00:39:25.110
If your distance function
is a poor distance function,

00:39:25.110 --> 00:39:27.960
then it's going to
give poor results.

00:39:27.960 --> 00:39:31.440
And moreover, it could
be very much misled

00:39:31.440 --> 00:39:34.920
by features that don't
affect the outcome, which

00:39:34.920 --> 00:39:37.980
is not necessarily a
property that we want.

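NOTE
A tiny made-up illustration of this sensitivity to the metric. The outcome is assumed to depend only on feature 0. Feature 1 is irrelevant but measured on a much larger scale, so it dominates the Euclidean distance and changes which control patient gets matched.
```python
import numpy as np
def nearest_control(x_treated, x_controls):
    """Index of the control closest to x_treated under Euclidean distance."""
    return int(np.argmin(np.linalg.norm(x_controls - x_treated, axis=1)))
# Columns: [relevant feature, irrelevant large-scale feature]; made-up values.
treated = np.array([0.0, 0.0])
controls = np.array([[0.1, 50.0],   # control A: nearly identical on the relevant feature
                     [3.0, 1.0]])   # control B: very different on the relevant feature
print(nearest_control(treated[:1], controls[:, :1]))  # relevant feature only: 0 (A)
print(nearest_control(treated, controls))             # both features: 1 (B)
```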
00:39:37.980 --> 00:39:41.010
Now, here's that final slide
that makes the connection.

00:39:41.010 --> 00:39:46.130
Matching is equivalent
to covariate adjustment.

00:39:46.130 --> 00:39:47.410
It's exactly the same.

00:39:47.410 --> 00:39:50.790
It's an instantiation
of covariate adjustment

00:39:50.790 --> 00:39:53.670
with a particular
functional family for F.

00:39:53.670 --> 00:39:55.500
So rather than assuming
that your function

00:39:55.500 --> 00:39:58.290
F, that black box, is a linear
function or a neural network

00:39:58.290 --> 00:40:02.310
or a random forest or a
Bayesian regression tree,

00:40:02.310 --> 00:40:04.080
we're going to
assume that function

00:40:04.080 --> 00:40:07.100
takes the form of a nearest
neighbor classifier.

00:40:07.100 --> 00:40:12.630
In particular, we'll
say that Y hat of 1,

00:40:12.630 --> 00:40:16.830
the function for predicting
the potential outcome Y hat 1,

00:40:16.830 --> 00:40:23.160
is given to you by finding the
nearest neighbor of the data

00:40:23.160 --> 00:40:28.350
point X according to the
data set of individuals

00:40:28.350 --> 00:40:34.600
that received treatment 1,
and same thing for Y hat 0.

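NOTE
A sketch of matching written as covariate adjustment. Y hat 1 and Y hat 0 are one-nearest-neighbor regressors fit separately on the treated and control individuals, and the ATE is the average of Y hat 1 (x) minus Y hat 0 (x) over everyone. Each individual's nearest neighbor within its own group is itself (distance 0, assuming no duplicate points), so this reproduces the one-nearest-neighbor matching estimate. The data and function names are made up for illustration.
```python
import numpy as np
def one_nn_regressor(x_fit, y_fit):
    """Return a predictor f(x_query) = y of the nearest fitted point (Euclidean)."""
    x_fit = np.asarray(x_fit, dtype=float)
    y_fit = np.asarray(y_fit, dtype=float)
    def predict(x_query):
        x_query = np.asarray(x_query, dtype=float)
        dist = np.linalg.norm(x_query[:, None, :] - x_fit[None, :, :], axis=-1)
        return y_fit[dist.argmin(axis=1)]
    return predict
def covariate_adjustment_ate(x, t, y):
    x = np.asarray(x, dtype=float)
    t = np.asarray(t)
    y = np.asarray(y, dtype=float)
    y1_hat = one_nn_regressor(x[t == 1], y[t == 1])  # fit on treated individuals only
    y0_hat = one_nn_regressor(x[t == 0], y[t == 0])  # fit on control individuals only
    cate = y1_hat(x) - y0_hat(x)
    return cate.mean(), cate
x = [[0.0], [0.1], [5.0], [5.2]]
t = [1, 0, 1, 0]
y = [3.0, 1.0, 10.0, 7.0]
ate, cate = covariate_adjustment_ate(x, t, y)
print(cate)  # [2. 2. 3. 3.]
print(ate)   # 2.5
```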
00:40:34.600 --> 00:40:39.430
And so that then allows
us to actually prove

00:40:39.430 --> 00:40:43.220
some properties of matching.

00:40:43.220 --> 00:40:46.600
So for example, if
you remember from--

00:40:46.600 --> 00:40:49.870
I think I mentioned
in Tuesday's lecture

00:40:49.870 --> 00:40:52.870
that this covariate
adjustment approach,

00:40:52.870 --> 00:40:57.520
under the assumptions of overlap
and under the assumptions

00:40:57.520 --> 00:41:04.240
of no hidden confounding,
and that your function

00:41:04.240 --> 00:41:07.240
family for potential outcome
is sufficiently rich that you

00:41:07.240 --> 00:41:10.570
can actually fit the
underlying model,

00:41:10.570 --> 00:41:12.340
then you're going to
get correct estimates

00:41:12.340 --> 00:41:16.700
of your conditional
average treatment effect.

00:41:16.700 --> 00:41:26.450
Now, one can show that a nearest
neighbor algorithm is not,

00:41:26.450 --> 00:41:28.080
generally, an
unbiased estimator.

00:41:28.080 --> 00:41:29.538
And what that means
is that, if you

00:41:29.538 --> 00:41:32.570
have a small number
of samples, you're

00:41:32.570 --> 00:41:34.520
going to be getting
a biased estimate.

00:41:34.520 --> 00:41:38.870
Your function F might, in
general, be a biased estimate.

00:41:38.870 --> 00:41:41.300
Now, we can conclude
from that, that if we

00:41:41.300 --> 00:41:43.390
were to use one nearest
neighbor matching

00:41:43.390 --> 00:41:46.910
for inferring average treatment
effect, that in general,

00:41:46.910 --> 00:41:48.740
it could give us
a biased estimate

00:41:48.740 --> 00:41:50.690
of the average treatment effect.

00:41:50.690 --> 00:41:55.120
However, in the limit
of infinite data,

00:41:55.120 --> 00:41:56.930
one nearest neighbor
algorithms are

00:41:56.930 --> 00:42:01.850
guaranteed to be able to fit
the underlying function family.

00:42:01.850 --> 00:42:04.360
That is to say,
that bias goes to 0

00:42:04.360 --> 00:42:06.530
in the limit of a
large amount of data,

00:42:06.530 --> 00:42:09.980
and thus, we can immediately
draw from that literature

00:42:09.980 --> 00:42:12.080
in causal inference--
sorry, from that literature

00:42:12.080 --> 00:42:14.890
in machine learning to
obtain theoretical results

00:42:14.890 --> 00:42:18.900
for matching for
causal inference.

00:42:18.900 --> 00:42:20.910
And so that's all I want
to say about matching

00:42:20.910 --> 00:42:24.120
and its connection to
covariate adjustment.

00:42:24.120 --> 00:42:27.300
And really, the
punchline is, think

00:42:27.300 --> 00:42:30.060
about matching just as another
type of covariate adjustment,

00:42:30.060 --> 00:42:33.390
one which uses a nearest
neighbor function family,

00:42:33.390 --> 00:42:37.020
and thus should be compared
to other approaches

00:42:37.020 --> 00:42:43.110
to covariate adjustments, such
as, for example, using machine

00:42:43.110 --> 00:42:47.630
learning algorithms that are
designed to be interpretable.

00:42:47.630 --> 00:42:49.880
So the last part
of this lecture is

00:42:49.880 --> 00:42:56.930
going to be introducing a
second approach for inferring

00:42:56.930 --> 00:43:00.020
average treatment effect that
is known as the propensity score

00:43:00.020 --> 00:43:03.420
method, and this is
going to be a real shift.

00:43:03.420 --> 00:43:05.250
It's going to be a
different estimator

00:43:05.250 --> 00:43:08.610
from the covariate adjustment.

00:43:08.610 --> 00:43:11.568
So as I mentioned, it's going
to be used for estimating

00:43:11.568 --> 00:43:12.610
average treatment effect.

00:43:12.610 --> 00:43:14.257
In problem set 4,
you're going to see

00:43:14.257 --> 00:43:16.090
how you can use the
same sorts of techniques

00:43:16.090 --> 00:43:18.060
I'll tell you about
now for also estimating

00:43:18.060 --> 00:43:19.890
conditional average
treatment effect,

00:43:19.890 --> 00:43:23.090
but that won't be obvious
just from today's lecture.

00:43:23.090 --> 00:43:27.315
So the key intuition for
propensity score method

00:43:27.315 --> 00:43:29.440
is to think back to what
would have happened if you

00:43:29.440 --> 00:43:31.030
had a randomized control trial.

00:43:31.030 --> 00:43:33.430
In a randomized control
trial, again, you

00:43:33.430 --> 00:43:37.480
get choice over what treatment
to give each individual,

00:43:37.480 --> 00:43:39.875
so you might imagine
flipping a coin.

00:43:39.875 --> 00:43:41.500
If it's heads, giving
them treatment 1.

00:43:41.500 --> 00:43:43.980
If it's tails, giving
them treatment 0.

00:43:43.980 --> 00:43:46.320
So given data from a
randomized control trial,

00:43:46.320 --> 00:43:49.230
then there's a really
simple estimator shown here

00:43:49.230 --> 00:43:51.060
for the average
treatment effect.

00:43:51.060 --> 00:43:56.460
You just sum up the values
of Y for the individuals

00:43:56.460 --> 00:43:59.093
that receive treatment
1, divided by n1,

00:43:59.093 --> 00:44:00.510
which is the number
of individuals

00:44:00.510 --> 00:44:01.593
that received treatment 1.

00:44:01.593 --> 00:44:04.240
So this is the average
outcome for all people who

00:44:04.240 --> 00:44:06.390
got treatment 1, and
you just subtract

00:44:06.390 --> 00:44:09.300
from that the average outcome
for all individuals who

00:44:09.300 --> 00:44:11.040
received treatment 0.

00:44:11.040 --> 00:44:14.820
And that can be easily shown
to be an unbiased estimator

00:44:14.820 --> 00:44:16.650
of the average
treatment effect had

00:44:16.650 --> 00:44:19.910
your data come from a
randomized controlled trial.

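NOTE
A minimal sketch of this difference-in-means estimator, checked on simulated randomized data. The outcome model and its true ATE of 1.0 are assumptions. Because T is a fair coin flip independent of x, the simple difference in group means is unbiased.
```python
import numpy as np
def difference_in_means(t, y):
    """ATE estimate for randomized data: mean(y | t=1) minus mean(y | t=0)."""
    t = np.asarray(t)
    y = np.asarray(y, dtype=float)
    return y[t == 1].mean() - y[t == 0].mean()
rng = np.random.default_rng(0)
n = 100_000
x = rng.normal(size=n)                       # a covariate that affects the outcome
t = rng.integers(0, 2, size=n)               # fair coin flip, independent of x
y = 2.0 * x + 1.0 * t + rng.normal(size=n)   # assumed model with true ATE = 1.0
print(difference_in_means(t, y))
```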
00:44:19.910 --> 00:44:22.070
So the key idea of a
propensity score method

00:44:22.070 --> 00:44:25.550
is to turn an observational
study into something

00:44:25.550 --> 00:44:29.810
that looks like a randomized
control trial via re-weighting

00:44:29.810 --> 00:44:31.590
of the data points.

00:44:31.590 --> 00:44:34.130
So here's the picture I
want you to have in mind.

00:44:34.130 --> 00:44:37.700
Again, here, I am not
showing you outcomes.

00:44:37.700 --> 00:44:40.340
I'm just showing
you the features X--

00:44:40.340 --> 00:44:41.840
that's what the
data points are--

00:44:41.840 --> 00:44:45.830
and the treatments that
were given to them, the Ts.

00:44:45.830 --> 00:44:47.750
And the Ts, in this
case, are being

00:44:47.750 --> 00:44:51.920
denoted by the color of the
dots, so red is T equals 1.

00:44:51.920 --> 00:44:53.730
Blue is T equals 0.

00:44:53.730 --> 00:44:56.700
And my apologies in advance
for anyone who's color blind.

00:44:56.700 --> 00:45:00.980
So the key challenge
when working

00:45:00.980 --> 00:45:03.500
with an observational
study is that there

00:45:03.500 --> 00:45:07.610
might be a bias in terms of who
receives treatment 0 versus who

00:45:07.610 --> 00:45:09.230
receives treatment 1.

00:45:09.230 --> 00:45:11.190
If this was a randomized
control trial,

00:45:11.190 --> 00:45:13.700
then you would expect to
see the reds and the blues

00:45:13.700 --> 00:45:16.380
all intermixed equally
with one another,

00:45:16.380 --> 00:45:18.900
but as you can see
here, in this data set,

00:45:18.900 --> 00:45:20.360
there are many more
young people who

00:45:20.360 --> 00:45:23.990
received treatment

00:45:23.990 --> 00:45:26.210
0 than received treatment 1.

00:45:26.210 --> 00:45:29.900
Said differently, if you look
at the distribution over X

00:45:29.900 --> 00:45:32.330
conditioned on T
equals 0 in the data,

00:45:32.330 --> 00:45:34.550
it's different from
the distribution

00:45:34.550 --> 00:45:39.780
over X conditioned on the
people who received treatment 1.

00:45:39.780 --> 00:45:42.285
So what the propensity score
method is going to do is

00:45:42.285 --> 00:45:44.730
it's going to recognize that
there is a difference between

00:45:44.730 --> 00:45:48.210
these two distributions, and
it's going to re-weight data

00:45:48.210 --> 00:45:52.150
points so that, in aggregate, it
looks like, in any one region--

00:45:52.150 --> 00:45:54.830
so for example, imagine
looking at this region--

00:45:54.830 --> 00:45:56.590
that there's roughly
the same number

00:45:56.590 --> 00:45:59.320
of red and blue data points.

00:45:59.320 --> 00:46:02.718
If you think about blowing
up this red data point-- here,

00:46:02.718 --> 00:46:04.260
I've made it very
big-- you can think

00:46:04.260 --> 00:46:06.410
about it being many,
many red data points

00:46:06.410 --> 00:46:08.230
of the corresponding weight.

00:46:08.230 --> 00:46:13.440
If you look over here, you see
again roughly the same amount

00:46:13.440 --> 00:46:15.550
of red and blue mass as well.

00:46:15.550 --> 00:46:20.202
So if we can find some way
to increase or decrease

00:46:20.202 --> 00:46:22.410
the weight associated with
each data point such that,

00:46:22.410 --> 00:46:26.460
the two distributions,
those

00:46:26.460 --> 00:46:28.970
who received treatment 1 and
those who received treatment 0,

00:46:28.970 --> 00:46:30.930
now have

00:46:30.930 --> 00:46:33.180
the same
weighted distribution,

00:46:33.180 --> 00:46:34.680
then we're going
to be in business.

00:46:34.680 --> 00:46:36.597
So we're going to search
for those weights, w,

00:46:36.597 --> 00:46:39.522
that have that property.

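NOTE One way to see which weights have this property, sketched in Python with assumed toy numbers: if each treated point is weighted by 1/e(x) and each control point by 1/(1 - e(x)), where e(x) = P(T = 1 | X = x), the weighted red and blue mass in every region comes out equal. The two regions (young, old), their propensities, and the function name are hypothetical.
```python
from collections import defaultdict
def weighted_mass(regions, t, e):
    """Total weight of treated (T=1) and control (T=0) points in each region.
    Treated points get weight 1/e(x); control points get weight 1/(1 - e(x))."""
    mass = defaultdict(float)
    for r, ti in zip(regions, t):
        w = 1.0 / e[r] if ti == 1 else 1.0 / (1.0 - e[r])
        mass[(r, ti)] += w
    return dict(mass)
# Hypothetical data: 100 young people (20 treated), 100 old people (80 treated).
regions = ["young"] * 100 + ["old"] * 100
t = [1] * 20 + [0] * 80 + [1] * 80 + [0] * 20
e = {"young": 0.2, "old": 0.8}  # assumed true P(T=1 | region)
for key, m in sorted(weighted_mass(regions, t, e).items()):
    print(key, round(m, 6))  # every (region, treatment) cell has weighted mass 100
```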
00:46:39.522 --> 00:46:40.980
So to do that, we
need to introduce

00:46:40.980 --> 00:46:43.920
one new concept, which is
known as the propensity score.

00:46:43.920 --> 00:46:47.700
The propensity score is given
to you by the probability

00:46:47.700 --> 00:46:51.930
that T equals 1
given X. Here, again,

00:46:51.930 --> 00:46:53.490
we're going to use
machine learning.

00:46:53.490 --> 00:46:54.948
Whereas in covariate
adjustment, we

00:46:54.948 --> 00:47:01.170
used machine learning to predict
Y conditioned on X comma T--

00:47:01.170 --> 00:47:03.170
that's what covariate
adjustment did--

00:47:03.170 --> 00:47:06.900
here, we're going to be
ignoring Y altogether.

00:47:06.900 --> 00:47:09.000
We're just going
to take X as input,

00:47:09.000 --> 00:47:10.980
and we're going to
be predicting T.

00:47:10.980 --> 00:47:13.620
So you can imagine using
logistic regression, given

00:47:13.620 --> 00:47:16.230
your covariates, to predict
which treatment any given data

00:47:16.230 --> 00:47:17.190
point came from.

00:47:17.190 --> 00:47:19.590
Here, you're using the
full data set, of course,

00:47:19.590 --> 00:47:21.420
to make that
prediction, so we're

00:47:21.420 --> 00:47:25.140
looking at both data
points where T equals 1

00:47:25.140 --> 00:47:26.000
and T equals 0.

00:47:26.000 --> 00:47:28.578
T is your label for this.

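NOTE A minimal sketch of step 1 in plain Python, assuming a single covariate x and a logistic model P(T=1 | x) = sigmoid(a + b x) fit by gradient ascent on the log-likelihood. T is the label and Y is never used. The function name, learning rate, step count, and data are illustrative assumptions; in practice one would use a standard logistic regression library.
```python
import math
def fit_propensity(x, t, lr=0.1, steps=5000):
    """Fit e(x) = P(T=1 | x) = sigmoid(a + b*x) by gradient ascent on the
    average log-likelihood. Only (x, t) pairs are used; outcomes Y are ignored."""
    a, b = 0.0, 0.0
    n = len(x)
    for _ in range(steps):
        grad_a = grad_b = 0.0
        for xi, ti in zip(x, t):
            p = 1.0 / (1.0 + math.exp(-(a + b * xi)))
            grad_a += ti - p           # derivative of the log-likelihood w.r.t. a
            grad_b += (ti - p) * xi    # derivative of the log-likelihood w.r.t. b
        a += lr * grad_a / n
        b += lr * grad_b / n
    return lambda xi: 1.0 / (1.0 + math.exp(-(a + b * xi)))
# Hypothetical data: x = age in decades, t = treatment received.
x = [2, 2, 3, 3, 4, 5, 6, 6, 7, 7]
t = [0, 0, 0, 1, 0, 1, 0, 1, 1, 1]
e = fit_propensity(x, t)
print(round(e(2), 3), round(e(7), 3))  # older people get higher propensity scores
```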
00:47:28.578 --> 00:47:30.120
Then what we're
going to do is, given

00:47:30.120 --> 00:47:33.390
that learned propensity score--
so we take your data set.

00:47:33.390 --> 00:47:35.610
You first learn
the propensity score.

00:47:35.610 --> 00:47:38.190
Then we're going to re-weight
the data points according

00:47:38.190 --> 00:47:40.640
to the inverse of
the propensity score.

00:47:40.640 --> 00:47:43.570
And you might say,
this looks familiar.

00:47:43.570 --> 00:47:46.410
This whole notion of
re-weighting data points,

00:47:46.410 --> 00:47:50.100
this whole notion of trying
to figure out which, quote,

00:47:50.100 --> 00:47:52.155
unquote, "data set" a
data point came from,

00:47:52.155 --> 00:47:54.780
the data set of individuals who
receive treatment 1 or the data

00:47:54.780 --> 00:47:57.420
set of individuals who
receive treatment 0--

00:47:57.420 --> 00:47:58.620
that sounds really familiar.

00:47:58.620 --> 00:48:01.380
And it's because it's exactly
what you saw in lecture 10,

00:48:01.380 --> 00:48:03.140
when we talked about
data set shift.

00:48:03.140 --> 00:48:05.310
In fact, this whole
entire method,

00:48:05.310 --> 00:48:08.310
as you'll develop
in problem set 4,

00:48:08.310 --> 00:48:14.350
is a special case of learning
under data set shift.

00:48:14.350 --> 00:48:17.850
So here, now, is the
propensity score algorithm.

00:48:17.850 --> 00:48:23.080
We take our data set, which
has samples of X, T, and Y,

00:48:23.080 --> 00:48:27.390
where Y, of course, tells
you the potential outcome

00:48:27.390 --> 00:48:30.413
corresponding to
the treatment T.

00:48:30.413 --> 00:48:32.330
We're going to use any
machine learning method

00:48:32.330 --> 00:48:36.150
in order to estimate
this model that

00:48:36.150 --> 00:48:38.850
can give you a probability
of treatment given X.

00:48:38.850 --> 00:48:42.727
Now, critically, we need
a probability for this.

00:48:42.727 --> 00:48:44.310
We're not trying to
do classification.

00:48:44.310 --> 00:48:46.727
We need an actual probability,
and so if you remember back

00:48:46.727 --> 00:48:50.670
to previous lectures where
we spoke about calibration,

00:48:50.670 --> 00:48:53.820
about the ability to accurately
predict probabilities,

00:48:53.820 --> 00:48:55.930
that is going to be
really important here.

00:48:55.930 --> 00:48:58.430
And so for example, if you were
to use a deep neural network

00:48:58.430 --> 00:49:02.310
in order to estimate
the propensity scores,

00:49:02.310 --> 00:49:06.120
deep networks are well known
to often be poorly calibrated.

00:49:06.120 --> 00:49:09.030
And so one would have to use
one of a number of new methods

00:49:09.030 --> 00:49:10.410
that have been
recently developed

00:49:10.410 --> 00:49:12.850
to make the outputs of
deep learning calibrated

00:49:12.850 --> 00:49:15.730
in order to use this
type of technique.

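NOTE Because the estimator divides by these numbers, it helps to check calibration before using them. A minimal binned check, sketched in Python: within each bin of predicted scores, the mean predicted P(T=1 | X) should be close to the fraction of people in that bin who actually received treatment 1. The bin count, function name, and scores are assumptions, not from the lecture.
```python
def calibration_table(probs, labels, n_bins=5):
    """For each non-empty bin of predicted propensity scores, return
    (bin index, count, mean predicted score, observed fraction treated)."""
    bins = [[] for _ in range(n_bins)]
    for p, label in zip(probs, labels):
        k = min(int(p * n_bins), n_bins - 1)  # a score of exactly 1.0 goes in the last bin
        bins[k].append((p, label))
    rows = []
    for k, items in enumerate(bins):
        if items:
            mean_pred = sum(p for p, _ in items) / len(items)
            frac_treated = sum(label for _, label in items) / len(items)
            rows.append((k, len(items), mean_pred, frac_treated))
    return rows
# Hypothetical scores from some propensity model, and the treatments actually given.
probs = [0.1, 0.15, 0.5, 0.55, 0.85, 0.9]
labels = [0, 0, 1, 0, 1, 1]
for row in calibration_table(probs, labels):
    print(row)  # large gaps between the last two entries suggest miscalibration
```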
00:49:15.730 --> 00:49:17.700
So after finishing
step 1, now that you

00:49:17.700 --> 00:49:20.460
have a model that can allow
you to estimate the propensity

00:49:20.460 --> 00:49:23.130
score for every
data point X, we now

00:49:23.130 --> 00:49:27.270
can take those and estimate
your average treatment effect

00:49:27.270 --> 00:49:29.450
with the following formula.

00:49:29.450 --> 00:49:34.790
It's 1 over n times the sum,
over the data points

00:49:34.790 --> 00:49:39.080
that received
treatment 1, of Yi--

00:49:39.080 --> 00:49:41.570
that part is
identical to before.

00:49:41.570 --> 00:49:43.970
But what you see now is
we're going to divide it

00:49:43.970 --> 00:49:46.610
by the propensity score,
and so this denominator,

00:49:46.610 --> 00:49:48.440
that's the new piece here.

00:49:48.440 --> 00:49:50.360
The inverse of
the propensity score

00:49:50.360 --> 00:49:53.940
is precisely the weighting that
we were referring to earlier,

00:49:53.940 --> 00:49:57.560
and the same thing happens
over here for Ti equals 0,
where we divide by 1 minus
the propensity score.

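NOTE A minimal Python sketch of this inverse propensity weighted (IPW) estimator, assuming propensity scores e_i = P(T=1 | x_i) have already been estimated. For T = 0 the weight is 1 / (1 - e_i), the inverse probability of receiving treatment 0. The function name and toy data (young and old groups with a true effect of 2) are hypothetical.
```python
def ipw_ate(t, y, e):
    """(1/n) * [sum over T=1 of y_i / e_i  -  sum over T=0 of y_i / (1 - e_i)]."""
    n = len(t)
    treated = sum(yi / ei for ti, yi, ei in zip(t, y, e) if ti == 1)
    control = sum(yi / (1.0 - ei) for ti, yi, ei in zip(t, y, e) if ti == 0)
    return (treated - control) / n
# Hypothetical data: Y0 = 1 for young, 5 for old; treatment adds 2 for everyone.
# Young people are rarely treated (e = 0.2), old people usually are (e = 0.8).
regions = ["young"] * 100 + ["old"] * 100
t = [1] * 20 + [0] * 80 + [1] * 80 + [0] * 20
base = {"young": 1.0, "old": 5.0}
y = [base[r] + 2.0 * ti for r, ti in zip(regions, t)]
e = [0.2 if r == "young" else 0.8 for r in regions]
naive = (sum(yi for ti, yi in zip(t, y) if ti == 1) / t.count(1)
         - sum(yi for ti, yi in zip(t, y) if ti == 0) / t.count(0))
print(round(naive, 3), round(ipw_ate(t, y, e), 3))  # naive is biased (about 4.4); IPW gives about 2
```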
00:49:57.560 --> 00:50:01.700
Now, let's try to get some
intuition about this formula,

00:50:01.700 --> 00:50:03.410
and I like trying
to get intuition

00:50:03.410 --> 00:50:05.842
by looking at a special case.

00:50:05.842 --> 00:50:08.300
So the simplest special case
that we might be familiar with

00:50:08.300 --> 00:50:10.850
is that of a randomized
control trial, where

00:50:10.850 --> 00:50:14.090
because you're flipping a coin,
and each data point either

00:50:14.090 --> 00:50:16.940
gets treatment 0 or treatment
1, then the propensity

00:50:16.940 --> 00:50:21.120
score is precisely,
deterministically equal to 0.5.

00:50:21.120 --> 00:50:23.270
So let's take this now.

00:50:23.270 --> 00:50:25.490
No machine learning done here.

00:50:25.490 --> 00:50:28.250
Let's just plug it in to see if
we get back the formula that I

00:50:28.250 --> 00:50:30.500
showed you earlier
for the estimate

00:50:30.500 --> 00:50:33.190
of the average treatment effect
in a randomized control trial.

00:50:33.190 --> 00:50:35.330
So we plug that in over there.

00:50:35.330 --> 00:50:41.840
This now becomes 0.5, and
plug that in over here.

00:50:41.840 --> 00:50:45.785
This also becomes 0.5.

00:50:45.785 --> 00:50:47.660
And then what we're
going to do is we're just

00:50:47.660 --> 00:50:49.250
going to take that 0.5.

00:50:49.250 --> 00:50:52.430
We're going to bring that out,
and this is going to become a 2

00:50:52.430 --> 00:50:56.030
over here, and
same, a 2 over here.

00:50:56.030 --> 00:50:58.520
And you get to the
following formula, which

00:50:58.520 --> 00:51:01.220
is-- if you were to compare to
the formula from a few slides

00:51:01.220 --> 00:51:04.790
ago, it's almost identical,
except that a few slides

00:51:04.790 --> 00:51:14.560
ago over here, I had 1 over n1,
and over here, I had 1 over n0.

00:51:17.260 --> 00:51:20.032
Now, these two are two different
estimators for the same thing,

00:51:20.032 --> 00:51:22.240
and the reason why you can
say they're the same thing

00:51:22.240 --> 00:51:25.527
is that, in a randomized
control trial,

00:51:25.527 --> 00:51:27.610
the number of individuals
that receive treatment 1

00:51:27.610 --> 00:51:29.447
is, on average, n over 2.

00:51:29.447 --> 00:51:31.780
Similarly, the number of
individuals receiving treatment

00:51:31.780 --> 00:51:33.610
0 is, on average, n over 2.

00:51:33.610 --> 00:51:36.850
So, roughly speaking,

00:51:36.850 --> 00:51:39.850
that n over 2 cancelling
out with this 2 over n

00:51:39.850 --> 00:51:41.980
is what gets you a
correct estimator.

00:51:41.980 --> 00:51:44.590
So this is a slightly
different estimator,

00:51:44.590 --> 00:51:49.000
but nearly identical to the
one that I showed you earlier,

00:51:49.000 --> 00:51:52.990
and by this argument, is
a consistent estimator

00:51:52.990 --> 00:51:57.460
of the average treatment effect
in a randomized control trial.

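NOTE A quick numerical check of this special case, sketched in Python with simulated data (the seed, sample size, and true effect of 2 are arbitrary assumptions): with every propensity score fixed at 0.5, the IPW formula reduces to (2/n) times the difference of the outcome sums, and in a simulated coin-flip trial it lands close to the difference-in-means estimate.
```python
import random
def ipw_ate_constant(t, y, p=0.5):
    """IPW estimate when every propensity score equals p (coin-flip assignment)."""
    n = len(t)
    treated = sum(yi / p for ti, yi in zip(t, y) if ti == 1)
    control = sum(yi / (1.0 - p) for ti, yi in zip(t, y) if ti == 0)
    return (treated - control) / n
def difference_in_means(t, y):
    y1 = [yi for ti, yi in zip(t, y) if ti == 1]
    y0 = [yi for ti, yi in zip(t, y) if ti == 0]
    return sum(y1) / len(y1) - sum(y0) / len(y0)
rng = random.Random(0)
n = 10000
t = [rng.randint(0, 1) for _ in range(n)]         # fair coin, so e(x) = 0.5
y = [2.0 * ti + rng.gauss(0.0, 1.0) for ti in t]  # true average treatment effect = 2
two_over_n = (2.0 / n) * (sum(yi for ti, yi in zip(t, y) if ti == 1)
                          - sum(yi for ti, yi in zip(t, y) if ti == 0))
print(ipw_ate_constant(t, y), two_over_n, difference_in_means(t, y))
```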
00:51:57.460 --> 00:52:01.840
So any questions before I try
to derive this formula for you?

00:52:20.180 --> 00:52:24.350
So one student asks,
so the propensity score

00:52:24.350 --> 00:52:26.120
is the, quote,
unquote, "bias" of how

00:52:26.120 --> 00:52:32.870
likely people are assigned
to T equals 1 or T equals 0?

00:52:32.870 --> 00:52:34.850
Yes, that's exactly right.

00:52:34.850 --> 00:52:43.640
So if you were to imagine
taking an individual where

00:52:43.640 --> 00:52:45.620
this probability
for that individual

00:52:45.620 --> 00:52:48.890
is, let's say,
very close to 0, it

00:52:48.890 --> 00:52:51.840
means that there are very
few other people like them in the data

00:52:51.840 --> 00:52:54.470
set who received treatment 1.

00:52:54.470 --> 00:53:01.820
They're a red data point in
a sea of blue data points.

00:53:01.820 --> 00:53:04.340
And by dividing by
that, we're going

00:53:04.340 --> 00:53:08.270
to be trying to remove that
bias, and that's exactly right.

00:53:08.270 --> 00:53:09.440
Thank you for that question.

00:53:09.440 --> 00:53:10.784
Are there other questions?

00:53:20.560 --> 00:53:24.668
I really appreciate the
questions via the chat window,

00:53:24.668 --> 00:53:25.210
so thank you.

00:53:28.660 --> 00:53:31.780
So let's now try to
derive this formula.

00:53:34.480 --> 00:53:37.497
Recall the definition of
average treatment effect,

00:53:37.497 --> 00:53:39.580
and for those who are
paying very close attention,

00:53:39.580 --> 00:53:42.945
you might notice that I removed
the expectation over Y1.

00:53:42.945 --> 00:53:45.070
And for this derivation
that I'm going to give you,

00:53:45.070 --> 00:53:46.172
I'm going to suppose--

00:53:46.172 --> 00:53:48.380
I'm going to assume that the
potential outcomes are all

00:53:48.380 --> 00:53:51.130
deterministic, because it
makes the math easier,

00:53:51.130 --> 00:53:54.280
but it is without
loss of generality.

00:53:54.280 --> 00:53:55.990
So the average
treatment effect is

00:53:55.990 --> 00:53:59.020
the expectation, with
respect to all individuals,

00:53:59.020 --> 00:54:01.680
of the potential outcome
Y1 minus the expectation

00:54:01.680 --> 00:54:04.330
with respect to all individuals
of the potential outcome Y0.

00:54:04.330 --> 00:54:09.460
So this term over here is going
to be our estimate of that,

00:54:09.460 --> 00:54:12.310
and this term over here is
going to be our estimate

00:54:12.310 --> 00:54:14.550
of this expectation.

00:54:14.550 --> 00:54:18.480
So naively, if you were to
just take the observed data,

00:54:18.480 --> 00:54:20.100
it would allow you to compute--

00:54:20.100 --> 00:54:23.670
if you, for example, just
averaged the values of Y

00:54:23.670 --> 00:54:25.635
for the individuals who
received treatment 1,

00:54:25.635 --> 00:54:27.510
that would give you this
expectation that I'm

00:54:27.510 --> 00:54:29.070
showing on the bottom here.

00:54:29.070 --> 00:54:31.830
I want you to compare
that to the one that's

00:54:31.830 --> 00:54:33.845
actually needed in the
average treatment effect.

00:54:33.845 --> 00:54:35.970
Whereas over here, it's an
expectation with respect

00:54:35.970 --> 00:54:39.610
to individuals that received
treatment 1, up here,

00:54:39.610 --> 00:54:43.080
this was an expectation with
respect to all individuals.

00:54:43.080 --> 00:54:44.580
But the thing inside
the expectation

00:54:44.580 --> 00:54:46.830
is exactly identical,
and that's the key point

00:54:46.830 --> 00:54:48.420
that we're going
to work with, which

00:54:48.420 --> 00:54:50.460
is that we want an
expectation with respect

00:54:50.460 --> 00:54:54.690
to a different distribution than
the one that we actually have.

00:54:54.690 --> 00:54:56.730
And again, this
should ring bells,

00:54:56.730 --> 00:54:58.680
because this sounds
very, very familiar

00:54:58.680 --> 00:55:00.750
to the data set shift
story that we talked

00:55:00.750 --> 00:55:03.070
about a few lectures ago.

00:55:03.070 --> 00:55:07.910
So I'm going to show you how
to derive an estimator for just

00:55:07.910 --> 00:55:09.980
this first term, and the
second term is obviously

00:55:09.980 --> 00:55:11.710
going to be identical.

00:55:11.710 --> 00:55:14.060
So let's start out
with the following.

00:55:14.060 --> 00:55:18.950
We know that p of X
given T times p of T

00:55:18.950 --> 00:55:23.690
is equal to p of X
times p of T given X.

00:55:23.690 --> 00:55:28.790
So what I've just done here
is use two different factorizations

00:55:28.790 --> 00:55:36.340
of the joint distribution,
and then I've

00:55:36.340 --> 00:55:38.390
divided by p of T
given X in order

00:55:38.390 --> 00:55:41.175
to get the formula that I
showed you a second ago.

00:55:41.175 --> 00:55:42.800
I'm not going to
attempt to erase that.

00:55:42.800 --> 00:55:44.900
I'll leave it up there.

00:55:44.900 --> 00:55:48.520
So the next thing we're going
to do is we're going to say,

00:55:48.520 --> 00:55:52.000
if we were to compute an
expectation with respect

00:55:52.000 --> 00:55:57.730
to p of X given T equals 1,
and if we were to now take

00:55:57.730 --> 00:56:02.800
the value that we observe, Y1,
which we can observe

00:56:02.800 --> 00:56:04.960
for all the individuals
who received treatment 1,

00:56:04.960 --> 00:56:10.300
and if we were to re-weight
this observation by this ratio,

00:56:10.300 --> 00:56:16.780
where remember,
this ratio showed up

00:56:16.780 --> 00:56:21.640
in the previous bullet point,
then what I'm going to show you

00:56:21.640 --> 00:56:25.570
in just a moment is that
this is equal to the quantity

00:56:25.570 --> 00:56:28.670
that we actually wanted.

00:56:28.670 --> 00:56:30.200
Well, why is that?

00:56:30.200 --> 00:56:36.500
Well, if you expand
this expectation,

00:56:36.500 --> 00:56:38.930
this expectation is an
integral with respect

00:56:38.930 --> 00:56:49.610
to p of X conditioned
on T equals 1 times

00:56:49.610 --> 00:56:55.130
the thing inside the brackets,
and because we know

00:56:55.130 --> 00:56:57.730
from up here
that p of X conditioned on T

00:56:57.730 --> 00:57:01.510
equals 1 times p of T
equals 1 divided by p of T

00:57:01.510 --> 00:57:05.090
equals 1 conditioned on
X is equal to p of X,

00:57:05.090 --> 00:57:06.590
this whole thing
is just going to be

00:57:06.590 --> 00:57:14.750
equal to an integral of p of
X times Y1, which is precisely

00:57:14.750 --> 00:57:17.180
the definition of
expectation that we want.

00:57:17.180 --> 00:57:19.310
So this was a very
simple derivation

00:57:19.310 --> 00:57:21.320
to show you that the
re-weighting gets you

00:57:21.320 --> 00:57:22.700
what you need.

00:57:22.700 --> 00:57:25.700
Now, we can estimate this
expectation empirically

00:57:25.700 --> 00:57:28.160
as follows.
We're

00:57:28.160 --> 00:57:29.780
going to sum
over all data points

00:57:29.780 --> 00:57:31.143
that received treatment 1.

00:57:31.143 --> 00:57:32.810
We're going to take
an average, so we're

00:57:32.810 --> 00:57:34.580
dividing by the
number of data points

00:57:34.580 --> 00:57:36.020
that received treatment 1.

00:57:36.020 --> 00:57:38.390
For p of T equals
1, we're just going

00:57:38.390 --> 00:57:41.852
to use the empirical estimate
of how many individuals received

00:57:41.852 --> 00:57:43.310
treatment 1 in the
data set divided

00:57:43.310 --> 00:57:45.435
by the total number of
individuals in the data set.

00:57:45.435 --> 00:57:46.780
That's n1 divided by n.

00:57:46.780 --> 00:57:49.700
And for the denominator, p of
T equals 1 conditioned on X,

00:57:49.700 --> 00:57:52.340
we just plug in, now, the
propensity score, which

00:57:52.340 --> 00:57:53.740
we had previously estimated.

00:57:53.740 --> 00:57:54.990
And we're done.

00:57:54.990 --> 00:57:57.380
And so that, now,
is our estimate

00:57:57.380 --> 00:58:00.630
for the first term in the
average treatment effect,

00:58:00.630 --> 00:58:02.630
and you can do that likewise
for Ti equals 0.

00:58:02.630 --> 00:58:04.010
And I've shown
you the core of the proof

00:58:04.010 --> 00:58:09.330
of why this is an unbiased
estimator of the average treatment

00:58:09.330 --> 00:58:11.030
effect, assuming the propensity scores are correct.

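NOTE The derivation for the first term, written out in LaTeX. It assumes deterministic potential outcomes Y_1(x) as stated above and writes \hat{e}(x) for the estimated propensity score (notation chosen here, not from the slides). The last step shows that n_1 cancels, recovering the 1/n form of the estimator; the T = 0 term is the same with p(T = 0 | x) = 1 - e(x) in the denominator.
```latex
\begin{aligned}
p(x \mid T=1)\,p(T=1) &= p(x)\,p(T=1 \mid x)
  \;\Longrightarrow\; p(x) = \frac{p(x \mid T=1)\,p(T=1)}{p(T=1 \mid x)} \\
\mathbb{E}_{x \sim p(x)}\big[Y_1(x)\big]
  &= \int p(x)\,Y_1(x)\,dx
   = \int p(x \mid T=1)\,\frac{p(T=1)}{p(T=1 \mid x)}\,Y_1(x)\,dx \\
  &= \mathbb{E}_{x \sim p(x \mid T=1)}\!\left[\frac{p(T=1)}{p(T=1 \mid x)}\,Y_1(x)\right] \\
  &\approx \frac{1}{n_1}\sum_{i:\,t_i=1} \frac{n_1/n}{\hat{e}(x_i)}\,y_i
   = \frac{1}{n}\sum_{i:\,t_i=1} \frac{y_i}{\hat{e}(x_i)}
\end{aligned}
```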
00:58:11.030 --> 00:58:14.180
So I'm going to be concluding
now, in the next few minutes.

00:58:14.180 --> 00:58:17.060
First, I just wanted to
comment on what we just saw.

00:58:17.060 --> 00:58:22.100
So we saw a different way to
estimate the average treatment

00:58:22.100 --> 00:58:25.760
effect, which only required
estimating the propensity

00:58:25.760 --> 00:58:26.360
score.

00:58:26.360 --> 00:58:29.750
In particular, we never
had to use a model

00:58:29.750 --> 00:58:32.510
to predict Y in this
approach for estimating

00:58:32.510 --> 00:58:35.560
the average treatment
effect, and that's

00:58:35.560 --> 00:58:37.550
a good thing and a bad thing.

00:58:37.550 --> 00:58:40.480
It's a good thing
because, if you

00:58:40.480 --> 00:58:44.350
had errors in
estimating your model of Y,

00:58:44.350 --> 00:58:46.970
as I showed you in the very
beginning of today's lecture,

00:58:46.970 --> 00:58:48.638
that could have
a very big impact

00:58:48.638 --> 00:58:50.680
on your estimate of the
average treatment effect.

00:58:50.680 --> 00:58:52.790
And so that doesn't
show up here.

00:58:52.790 --> 00:58:56.090
On the other hand, this
has its own disadvantages.

00:58:56.090 --> 00:59:00.620
So for example, the propensity
score is going to be really,

00:59:00.620 --> 00:59:03.527
really affected by
lack of overlap,

00:59:03.527 --> 00:59:05.110
because when you
have lack of overlap,

00:59:05.110 --> 00:59:07.360
it means there's some data
points where the propensity

00:59:07.360 --> 00:59:09.870
score is very close to
0 or very close to 1.

00:59:09.870 --> 00:59:12.160
And that really leads
to very large variance

00:59:12.160 --> 00:59:14.120
in your estimators.

00:59:14.120 --> 00:59:16.270
And a very common
trick which is used

00:59:16.270 --> 00:59:17.650
to try to address
that concern is

00:59:17.650 --> 00:59:20.140
known as clipping, where you
simply clip the propensity

00:59:20.140 --> 00:59:22.810
scores so that they're always
bounded away from 0 and 1.

00:59:22.810 --> 00:59:24.760
But that's really
just a heuristic,

00:59:24.760 --> 00:59:28.270
and it can, of course, then
lead to biased estimates

00:59:28.270 --> 00:59:30.430
of the average treatment effect.

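NOTE A minimal sketch of clipping in Python. The threshold eps = 0.05 and the function name are arbitrary illustrative choices, not recommendations from the lecture; any threshold trades lower variance against possible bias.
```python
def clip_propensity(e, eps=0.05):
    """Clip propensity scores into [eps, 1 - eps], so that no inverse
    propensity weight 1/e or 1/(1 - e) can exceed 1/eps."""
    return [min(max(ei, eps), 1.0 - eps) for ei in e]
scores = [0.001, 0.3, 0.7, 0.999]
clipped = clip_propensity(scores)
weights_before = [1.0 / ei for ei in scores]   # treated-side weights, up to about 1000
weights_after = [1.0 / ei for ei in clipped]   # now at most 1/eps = 20
print(clipped)
print(max(weights_before), max(weights_after))
```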
00:59:30.430 --> 00:59:33.900
So there's a whole family of
causal inference algorithms

00:59:33.900 --> 00:59:37.920
that attempt to use ideas
from both covariate adjustment

00:59:37.920 --> 00:59:40.770
and inverse
propensity weighting.

00:59:40.770 --> 00:59:42.420
For example, there's
a method called

00:59:42.420 --> 00:59:44.490
doubly robust
estimators, and we'll

00:59:44.490 --> 00:59:48.810
try to provide a citation for
those estimators in the Scribe

00:59:48.810 --> 00:59:49.610
notes.

00:59:49.610 --> 00:59:51.127
And these doubly
robust estimators

00:59:51.127 --> 00:59:53.460
are a different family of
estimators that actually bring

00:59:53.460 --> 00:59:55.410
in both of these
techniques together,

00:59:55.410 --> 00:59:56.980
and they have a
really nice property,

00:59:56.980 --> 00:59:59.190
which is that if either
one of them fails, but not both,

00:59:59.190 --> 01:00:03.700
you still get valid estimates
of average treatment effect.

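NOTE The lecture defers the details to a citation. As one standard member of this family, here is a minimal Python sketch of the augmented inverse propensity weighted (AIPW) estimator, which combines an outcome model (mu1, mu0, as in covariate adjustment) with propensity scores e. The function name and the young/old toy data with a true effect of 2 are hypothetical; the example shows the estimate staying correct when either the propensity scores or the outcome model is wrong, but not both.
```python
def aipw_ate(t, y, e, mu1, mu0):
    """Augmented IPW (doubly robust) ATE estimate.
    e[i]:   estimated propensity score P(T=1 | x_i)
    mu1[i]: outcome-model prediction of Y under T=1 at x_i
    mu0[i]: outcome-model prediction of Y under T=0 at x_i"""
    n = len(t)
    total = 0.0
    for ti, yi, ei, m1, m0 in zip(t, y, e, mu1, mu0):
        total += (m1 - m0                           # covariate-adjustment part
                  + ti * (yi - m1) / ei             # IPW correction, treated
                  - (1 - ti) * (yi - m0) / (1.0 - ei))  # IPW correction, control
    return total / n
# Hypothetical data: Y0 = 1 for young, 5 for old; treatment adds 2 for everyone.
regions = ["young"] * 100 + ["old"] * 100
t = [1] * 20 + [0] * 80 + [1] * 80 + [0] * 20
base = {"young": 1.0, "old": 5.0}
y = [base[r] + 2.0 * ti for r, ti in zip(regions, t)]
e_true = [0.2 if r == "young" else 0.8 for r in regions]
e_wrong = [0.5] * len(t)                   # misspecified propensity model
mu1_true = [base[r] + 2.0 for r in regions]
mu0_true = [base[r] for r in regions]
zeros = [0.0] * len(t)                     # useless outcome model
print(aipw_ate(t, y, e_wrong, mu1_true, mu0_true))  # wrong propensities, right outcome model
print(aipw_ate(t, y, e_true, zeros, zeros))         # right propensities, wrong outcome model
```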
01:00:03.700 --> 01:00:08.470
I'm going to skip this and just
jump to the summary now, which

01:00:08.470 --> 01:00:11.320
is that we've presented
two different approaches

01:00:11.320 --> 01:00:13.450
for causal inference
from observational data--

01:00:13.450 --> 01:00:17.690
covariate adjustment and
propensity score based methods.

01:00:17.690 --> 01:00:20.080
And both of these,
I need to stress,

01:00:20.080 --> 01:00:23.470
are only going to
give you valid results

01:00:23.470 --> 01:00:25.248
under the assumptions
we outlined

01:00:25.248 --> 01:00:26.290
in the previous lecture--

01:00:26.290 --> 01:00:29.410
for example, that your
causal graph is correct;

01:00:29.410 --> 01:00:32.650
critically, that there's no
unobserved confounding; and

01:00:32.650 --> 01:00:37.720
second, that you have overlap
between your two treatment

01:00:37.720 --> 01:00:38.980
classes.

01:00:38.980 --> 01:00:46.840
In particular, if you're using
a non-parametric regression

01:00:46.840 --> 01:00:49.810
approach, overlap is
extremely important,

01:00:49.810 --> 01:00:52.990
because without
overlap, your model's

01:00:52.990 --> 01:00:56.050
undefined in some regions of covariate space.

01:00:56.050 --> 01:00:59.110
And thus, as a result,
you have no way

01:00:59.110 --> 01:01:03.490
of verifying if your
extrapolations are correct,

01:01:03.490 --> 01:01:07.780
and so one has to simply trust
the model, which is not

01:01:07.780 --> 01:01:09.670
something we really like.

01:01:09.670 --> 01:01:14.110
And in propensity score methods,
overlap is very important

01:01:14.110 --> 01:01:16.850
because if you don't have that,
you get inverse propensity

01:01:16.850 --> 01:01:19.330
scores that are very large

01:01:19.330 --> 01:01:23.180
or even infinite and lead
to extremely high variance

01:01:23.180 --> 01:01:25.060
estimators.

01:01:25.060 --> 01:01:28.120
So at the end of these
slides, which are already

01:01:28.120 --> 01:01:30.400
posted online, I
include some references

01:01:30.400 --> 01:01:33.130
that I strongly encourage
folks to follow up on.

01:01:33.130 --> 01:01:35.460
First, references to two
recent workshops that

01:01:35.460 --> 01:01:37.960
have been held in the machine
learning community so that you

01:01:37.960 --> 01:01:41.170
can get a sense of what the
latest and greatest in terms

01:01:41.170 --> 01:01:43.810
of research in
causal inference is,

01:01:43.810 --> 01:01:46.030
two different books
on causal inference

01:01:46.030 --> 01:01:48.508
that you can download
for free from MIT,

01:01:48.508 --> 01:01:50.050
and finally, some
papers that I think

01:01:50.050 --> 01:01:52.092
are really interesting,
particularly of interest,

01:01:52.092 --> 01:01:54.100
potentially, to course projects.

01:01:54.100 --> 01:01:57.130
So we are at time now.

01:01:57.130 --> 01:01:59.860
I will hang around for a
few minutes after lecture,

01:01:59.860 --> 01:02:02.230
as I would normally.

01:02:02.230 --> 01:02:06.270
But I'm going to stop the
recording of the lecture.