# File: drc.py
# 
# drc stands for "Dose-Response Curve".
# 
# Copyright 2014 by Zak Fallows
# CC-BY license, please feel free to reuse this code, just give me credit.
# 
# This file is heavily modeled after the following:
#     ~/Documents/SP236/Dose response curves/matplotlib/dose_response.py


#============================ How to Use This File ============================#
# 
# # First, you must install Python 2.7, Matplotlib, and IPython.
# 
# # Second, cd into whatever directory holds drc.py:
# 
# mac> cd Dose_Response_Curves
# 
# # Third, start Pylab:
# 
# mac> ipython --pylab
# 
# py>
# 
# # Fourth, import drc:
# 
# py> import drc
# 
# # Finally, you can draw graphs:
# 
# py> drc.draw_test01()
# 
# # A Tk window should pop up with a linear graph for morphine, complete with
# # crosshairs at the EC50 (which this file sets to 10 mg for morphine).
# 
# # Draw another graph:
# 
# py> drc.draw_ocw_small(2)
# 
# # A second Tk window should pop up.


from matplotlib.pyplot import *
from matplotlib import rcParams
import numpy as np


# Set some graph appearance parameters. These assignments run once, when the
# module is imported, and change Matplotlib's global defaults, so they apply
# to every figure created afterwards in the same Python process:
rcParams['lines.linewidth'] = 3          # thick plotted lines
rcParams['lines.color'] = 'k'            # default line color: black
rcParams['xtick.major.width'] = 3        # thick major tick marks on x ...
rcParams['ytick.major.width'] = 3        # ... and on y
rcParams['xtick.labelsize'] = 'large'    # larger tick-label text on x ...
rcParams['ytick.labelsize'] = 'large'    # ... and on y


def response(dose, k, g=1.0, ie=100.0):
    """Given a dose and parameters, calculate the response
    
    This function is pure pharmacodynamics, no Matplotlib, no graphing. It
    evaluates the Hill equation
    
        response = ie / (1 + (k / dose) ** g)
    
    so a dose equal to k gives exactly half of ie.
    
    Arguments:
        dose    Float, a single dose (same units as k)
        k       Float, EC50, IC50, or Km (Michaelis-Menten constant)
        g       Float, gamma, Hill coefficient, the shape parameter
        ie      Float, intrinsic efficacy, or Vmax in the case of enzymes
    
    Returns:
        Float. A dose of exactly zero returns 0.0. That is the limit of the
        curve as the dose approaches zero (for g > 0), and it also avoids a
        division by zero.
    
    """
    
    if dose == 0.0:
        return 0.0
    
    # Multiply by 1.0 before dividing: under Python 2, int arguments would
    # otherwise truncate (e.g. 2 / 4 == 0) and produce the wrong response.
    # Unlike float(k), this also keeps working if k is a NumPy array.
    return ie * 1.0 / (1.0 + ((1.0 * k / dose) ** g))


def make_x_data(semilog, min_dose=None, max_dose=None, steps=200):
    """Output a list of floats, each float representing a dose in mg
    
    The doses run from min_dose to max_dose inclusive and are split into
    `steps` equal intervals, giving steps + 1 points in total. On a semilog
    graph the intervals are equal in log space (a geometric sequence); on a
    linear graph they are equal in linear space.
    
    Arguments:
        semilog     True for log-spaced doses, False for evenly spaced doses
        min_dose    Float, first dose. Defaults to 0.01 (semilog) or 0.0
                    (linear). Must be strictly positive for a semilog graph.
        max_dose    Float, last dose. Defaults to 10000.0 (semilog) or 300.0
                    (linear). Must be greater than min_dose.
        steps       Int, number of intervals; must be at least 1
    
    Raises:
        ValueError  if semilog is not True/False, if the dose range is empty
                    or (for semilog) starts at or below zero, or if steps < 1
    
    """
    
    # `in` compares with ==, so 1/0 and NumPy booleans are still accepted,
    # exactly as the original `semilog == True` checks allowed.
    if semilog not in (True, False):
        raise ValueError("semilog must be either True or False, but instead "
                         "it was %s." % semilog)
    
    if semilog:
        # This is a semilog plot
        if min_dose is None:
            min_dose = 0.01
        if max_dose is None:
            max_dose = 10000.0
        if min_dose <= 0.0:
            raise ValueError(
                "When using a semilog graph, min_dose must be strictly "
                "greater than zero. It was %s." % min_dose)
    else:
        if min_dose is None:
            min_dose = 0.0
        if max_dose is None:
            max_dose = 300.0
    
    # An empty or reversed range used to loop forever (equal bounds) or
    # silently return a single point (reversed bounds).
    if max_dose <= min_dose:
        raise ValueError("max_dose (%s) must be greater than min_dose (%s)."
                         % (max_dose, min_dose))
    if steps < 1:
        raise ValueError("steps must be at least 1. It was %s." % steps)
    steps = int(steps)
    
    # Compute each point directly instead of accumulating. This avoids
    # round-off drift and never produces a point beyond max_dose.
    if semilog:
        # Constant ratio between neighbours: min * (max/min) ** (i/steps)
        ratio = 1.0 * max_dose / min_dose
        x_data = [min_dose * ratio ** (1.0 * i / steps)
                  for i in range(steps + 1)]
    else:
        span = 1.0 * max_dose - min_dose
        x_data = [min_dose + span * i / steps for i in range(steps + 1)]
    
    # Pin the final point to max_dose exactly, free of rounding error.
    x_data[-1] = 1.0 * max_dose
    return x_data


def make_y_data(x_data, k, g=1.0, ie=100.0):
    """Output a list of floats, usually in the range [0.0, 100.0]
    
    Each entry is response(dose, k, g, ie) for the matching dose in x_data,
    so the result has the same length and order as x_data.
    
    """
    
    return [response(dose, k, g=g, ie=ie) for dose in x_data]


def set_size(width, height, dpi=200):
    """Start a new figure with the given size
    
    Arguments:
        width   Float, figure width in inches
        height  Float, figure height in inches
        dpi     Int, dots per inch for the figure
    
    The dpi value changes how large the figure appears on screen, e.g. when
    using Tk and Pylab. Whether it also changes the pixel size of a saved
    image depends on the installed Matplotlib's 'savefig.dpi' setting.
    NOTE(review): the original note said dpi does NOT affect saved-image
    resolution; confirm against the Matplotlib version in use.
    
    """
    
    size_in_inches = (width, height)
    figure(figsize=size_in_inches, dpi=dpi)


def set_axis_limits(xmin, xmax, ymax=100.0):
    """Clamp the current axes to the plotted range to avoid extra whitespace
    
    The x axis is fixed to [xmin, xmax]. Only the top of the y axis is set
    (to ymax); the bottom limit is left unchanged.
    
    """
    
    # left/right/top are the standard names for what xmin/xmax/ymax alias.
    xlim(left=xmin, right=xmax)
    ylim(top=ymax)


def crosshairs(x_val, y_val, xmin=0.0, linestyle='dashed', color='k',
               alpha=1.0):
    """Mark the point (x_val, y_val) with a pair of guide lines
    
    A vertical segment rises from y = 0 up to the point, and a horizontal
    segment runs from x = xmin across to the point. Both segments share the
    same line style, color, and transparency.
    
    """
    
    style = dict(linestyle=linestyle, color=color, alpha=alpha)
    vlines(x=x_val, ymin=0, ymax=y_val, **style)
    hlines(y=y_val, xmin=xmin, xmax=x_val, **style)


def draw_test01():
    """Draw a simple linear plot for morphine
    
    The curve uses an EC50 (k) of 10 mg with the default Hill coefficient
    and intrinsic efficacy. Dashed crosshairs mark the EC50 point.
    
    """
    
    ec50 = 10.0    # An EC50 of 10 mg works well for morphine.
    lo, hi = 0.0, 100.0
    n_steps = 300
    
    doses = make_x_data(False, lo, hi, n_steps)
    effects = make_y_data(doses, ec50)
    
    set_size(4.0, 4.0)
    set_axis_limits(lo, hi)
    
    # plot() returns a list holding a single Line2D. The trailing comma in
    # "line, =" unpacks that list so `line` is the Line2D object itself.
    line, = plot(doses, effects, 'k-')
    
    # Crosshairs at the EC50 (half of the maximal response):
    crosshairs(ec50, response(ec50, ec50), xmin=lo)


def draw_test02():
    """Draw a simple semilog plot for morphine
    
    Same curve as the linear test (EC50 of 10 mg), but on a logarithmic dose
    axis from 0.1 mg to 1,000 mg, with one labelled tick per decade.
    
    """
    
    ec50 = 10.0    # An EC50 of 10 mg works well for morphine.
    lo, hi = 0.1, 1000.0
    n_steps = 300
    
    doses = make_x_data(True, lo, hi, n_steps)
    effects = make_y_data(doses, ec50)
    
    set_size(4.0, 4.0)
    set_axis_limits(lo, hi)
    line, = semilogx(doses, effects, 'k-')
    
    # Crosshairs at the EC50:
    crosshairs(ec50, response(ec50, ec50), xmin=lo)
    
    # One tick per decade, labelled in plain notation:
    decades = [0.1, 1.0, 10.0, 100.0, 1000.0]
    xticks(decades, ["0.1", "1", "10", "100", "1,000"])


def draw_test03():
    """Linear plot with custom dashes
    
    The morphine curve (EC50 of 10 mg) drawn with a custom dash pattern
    applied through Line2D.set_dashes().
    
    """
    
    ec50 = 10.0    # An EC50 of 10 mg works well for morphine.
    lo, hi = 0.0, 100.0
    n_steps = 300
    
    doses = make_x_data(False, lo, hi, n_steps)
    effects = make_y_data(doses, ec50)
    
    set_size(4.0, 4.0)
    set_axis_limits(lo, hi)
    line = plot(doses, effects, 'k--')[0]
    
    # A dash pattern lists alternating (ink, gap) lengths and must have an
    # even number of entries: 2, 4, or 6, but NOT 3. This one reads:
    # "8-wide dark patch, 2-wide gap, 4-wide dark patch, 6-wide gap"
    pattern = (8, 2, 4, 6)
    line.set_dashes(pattern)
    # WARNING (original author's observation): .set_dashes() cannot be used
    # on the output of hlines() or vlines(); doing so caused a hard crash,
    # not just an exception. Use it only on lines returned by plot(),
    # semilogx(), and similar.


def horizontal_line(y, xmin, xmax, semilog, linestyle='dashed', color='k',
                    alpha=1.0, dashes=(None, None)):
    """Draw a horizontal line from xmin to xmax at height y
    
    Uses plot()/semilogx() rather than hlines(), because only the Line2D
    returned by the former accepts a custom dash pattern via .set_dashes().
    
    Arguments:
        y           Float, height of the line
        xmin, xmax  Float, horizontal extent of the line
        semilog     Truthy to draw with semilogx(), falsy to draw with plot()
        linestyle, color, alpha
                    Passed straight through to the plotting call
        dashes      Sequence of (ink, gap) lengths handed to .set_dashes();
                    the default (None, None) gives a solid line
    
    """
    
    draw = semilogx if semilog else plot
    segment, = draw([xmin, xmax], [y, y], linestyle=linestyle, color=color,
                    alpha=alpha)
    segment.set_dashes(dashes)


def draw_test04():
    """Linear plot for morphine with horizontal line
    
    Draws the morphine curve (EC50 of 10 mg), then two horizontal lines:
    - a translucent full-width line at 40 with a (7, 3) dash pattern;
    - a line at 60 spanning x = 20 to 70 with the default dashes.
    
    """
    
    ec50 = 10.0    # An EC50 of 10 mg works well for morphine.
    lo, hi = 0.0, 100.0
    n_steps = 300
    
    doses = make_x_data(False, lo, hi, n_steps)
    effects = make_y_data(doses, ec50)
    
    set_size(4.0, 4.0)
    set_axis_limits(lo, hi)
    line, = plot(doses, effects, 'k-')
    
    # Horizontal reference lines:
    horizontal_line(40.0, lo, hi, False, linestyle='dashed', color='k',
                    alpha=0.5, dashes=(7, 3))
    # The default dashes=(None, None) renders this one as a solid line.
    horizontal_line(60.0, 20.0, 70.0, False, linestyle='dashed')


def draw_ocw_small(mode=1):
    """Draw the Course Home graphic for OCW
    
    OCW's Course Home must have an image that is exactly 320px wide and at 
    least 240px tall.
    
    Requirements:
        .jpg
        320 px wide exactly
        At least 240 px tall
    
    mode == 1:
        Just draw the 3 drug curves.
    
    mode == 2:
        Also draw the 5 isoeffective lines.
    
    mode == 3:
        Draw 3 drug curves and just 3 isoeffective lines.
    
    Raises:
        ValueError  if mode is not 1, 2, or 3. The check runs before any
                    figure is created.
    
    NOTE: We must draw the isoeffective lines BEFORE the drug curves, that way 
    the isoeffective lines will be behind the drug curves (z-index).
    
    """
    
    # Reject unknown modes up front. Without this check any other value
    # silently falls through to the mode-1 graphic.
    if mode not in (1, 2, 3):
        raise ValueError("Invalid mode argument, it was %s." % mode)
    
    min_dose = 0.01
    max_dose = 200.0
    steps = 300
    
    x_data = make_x_data(True, min_dose, max_dose, steps)
    ## The following makes it big, you can take a screen shot:
    set_size(4.0, 2.5, 240)
    ## The following makes it roughly final size:
    # set_size(4.0, 2.5, 100)
    set_axis_limits(min_dose, max_dose)
    
    ld = min_dose
    hd = max_dose
    hln = horizontal_line
    
    # Each isoeffective height is a response on the IV-morphine curve
    # (k = 10 mg). Lines are listed from bottom (low dose) to top.
    
    if mode == 2:
        # Draw the 5 isoeffective lines.
        
        # One Vicodin tablet (normal size, i.e. 5/500). 5 mg of oral
        # hydrocodone is isoeffective with 1.7 mg of IV morphine:
        one_vicodin = response(1.7, 10.0)
        hln(one_vicodin, ld, hd, True, color='r', dashes=(7, 7))
        
        # In the hospital after a car accident (50, i.e. 10 mg morphine):
        hln(50, ld, hd, True, color='b', dashes=(10, 5, 3, 5))
        
        # Powerful euphoria (200/3, i.e. 20 mg morphine):
        hln(200.0/3, ld, hd, True, color='y', dashes=(7, 7))
        
        # Total apnea, fatal overdose (80, i.e. 40 mg morphine):
        hln(80, ld, hd, True, color='r', dashes=(10, 5, 3, 5))
        
        # Surgical-depth anesthesia. 60 mg of IV morphine is a reasonable
        # dose for opioid-only anesthesia:
        surgical = response(60.0, 10.0)
        hln(surgical, ld, hd, True, color='b', dashes=(7, 7))
    
    elif mode == 3:
        # Draw just 3 isoeffective lines.
        
        # One Vicodin tablet (normal size, i.e. 5/500):
        one_vicodin = response(1.7, 10.0)
        hln(one_vicodin, ld, hd, True, color='y', dashes=(7, 7))
        
        # Standard dose (10 mg IV morphine):
        hln(50, ld, hd, True, color='b', dashes=(10, 5, 3, 5))
        
        # Total apnea (fatal overdose):
        hln(80, ld, hd, True, color='r', dashes=(7, 7))
    
    # Draw the 3 drug curves (after any isoeffective lines, so they sit on
    # top of them):
    
    # Fentanyl:
    fentanyl_data = make_y_data(x_data, k=0.1)
    line1, = semilogx(x_data, fentanyl_data, 'r-')
    
    # Buprenorphine (a partial agonist, so its ceiling is ie = 70):
    bupe_data = make_y_data(x_data, k=0.12, ie=70.0)
    line2, = semilogx(x_data, bupe_data, 'b-')
    
    # Heroin:
    heroin_data = make_y_data(x_data, k=7.0)
    line3, = semilogx(x_data, heroin_data, 'k-')
    
    # Make x tick marks:
    xticks([0.1, 1.0, 10.0, 100.0],
           ["0.1", "1", "10", "100"])
    
    # Make y tick marks (none):
    yticks([], [])


def draw_ocw_lines(mode=1):
    """Draw lines to be cut up and reused for captions
    
    mode == 1:
        Draw 3 solid lines representing the drugs (red, blue, and black, the
        colors of the fentanyl, buprenorphine, and heroin curves)
    
    mode == 2:
        Draw 5 dashed lines representing the isoeffective lines
    
    Raises:
        ValueError  if mode is not 1 or 2. The check runs before any figure
                    is created, so an invalid mode does not leave an empty
                    figure behind.
    
    """
    
    if mode not in (1, 2):
        raise ValueError("Invalid mode argument, it was %s." % mode)
    
    # Same dose range as the OCW Course Home graphic. No dose data is needed
    # here: only flat lines are drawn.
    min_dose = 0.01
    max_dose = 200.0
    
    set_size(4.0, 2.5, 240)
    set_axis_limits(min_dose, max_dose)
    ylim(ymin=0.0, ymax=100.0)
    
    ld = min_dose
    hd = max_dose
    
    if mode == 1:
        # Fentanyl-colored line:
        hlines(y=90, xmin=ld, xmax=hd, colors='r')
        
        # Buprenorphine-colored line:
        hlines(y=70, xmin=ld, xmax=hd, colors='b')
        
        # Heroin-colored line:
        hlines(y=50, xmin=ld, xmax=hd, colors='k')
    
    else:  # mode == 2
        hln = horizontal_line
        
        # From bottom (low dose) to top:
        
        # One Vicodin:
        hln(10, ld, hd, True, color='r', dashes=(7, 7))
        
        # In the hospital after a car accident:
        hln(30, ld, hd, True, color='b', dashes=(10, 5, 3, 5))
        
        # Powerful euphoria:
        hln(50, ld, hd, True, color='y', dashes=(7, 7))
        
        # Total apnea (fatal overdose):
        hln(70, ld, hd, True, color='r', dashes=(10, 5, 3, 5))
        
        # Surgical-depth anesthesia:
        hln(90, ld, hd, True, color='b', dashes=(7, 7))


def draw_empty_box():
    """Draw an empty box for the legends
    
    Produces a framed axes with no tick marks and nothing visible inside.
    A single white segment is plotted at x = 300-310, outside the visible
    x range (which ends at 200), so it shows nothing.
    
    """
    
    lo, hi = 0.01, 200.0
    
    # Legend for figure 1 (just 3 lines tall) used set_size(4.0, 1.2, 240).
    # Legend for figure 2 (5 lines tall):
    set_size(4.0, 2.0, 240)
    set_axis_limits(lo, hi)
    ylim(ymin=0.0, ymax=100.0)
    
    line = plot([300.0, 310.0], [50.0, 60.0], 'w-')[0]
    
    # No tick marks on either axis (x first, then y):
    for set_ticks in (xticks, yticks):
        set_ticks([], [])


def draw_morphine(semilog=True):
    """Draw a plot for morphine with crosshairs at 5 points
    
    The curve uses k = 10 mg. Translucent crosshairs mark five doses (1.7,
    10, 20, 40, and 60 mg), and the dose and effect at each are printed.
    The y axis is labelled "A" to "E" at those five effect levels.
    
    Arguments:
        semilog     True for a log dose axis (0.05 to 2,000 mg), False for
                    a linear dose axis (0 to 80 mg)
    
    """
    
    if semilog:
        min_dose = 0.05
        max_dose = 2000.0
    else:
        min_dose = 0.0
        max_dose = 80.0
    steps = 300
    
    x_data = make_x_data(semilog, min_dose, max_dose, steps)
    y_data = make_y_data(x_data, k=10.0)
    
    set_size(7.0, 5.0, 100)
    set_axis_limits(min_dose, max_dose)
    
    # Draw the plot:
    if semilog:
        line1, = semilogx(x_data, y_data, 'k-')
    else:
        line1, = plot(x_data, y_data, 'k-')
    
    # Draw the crosshairs (from bottom to top):
    
    def morphine_crosshairs(dose):
        """Mark `dose` on the curve, print dose and effect, return effect"""
        effect = response(dose, 10.0)
        crosshairs(dose, effect, min_dose, alpha=0.42)
        # print() with a single argument prints the same text under Python 2
        # and Python 3; a print statement would be a SyntaxError on Python 3.
        print("Dose:   %4.1f mg" % dose)
        print("Effect: %.1f" % effect)
        return effect
    
    # One Vicodin tablet (5/500):
    e1 = morphine_crosshairs(1.7)
    # Dose:    1.7 mg
    # Effect: 14.5
    
    # Powerful hospital pain relief:
    e2 = morphine_crosshairs(10.0)
    # Dose:   10.0 mg
    # Effect: 50.0
    
    # Strong euphoria:
    e3 = morphine_crosshairs(20.0)
    # Dose:   20.0 mg
    # Effect: 66.7
    
    # Total apnea (fatal overdose):
    e4 = morphine_crosshairs(40.0)
    # Dose:   40.0 mg
    # Effect: 80.0
    
    # Surgical-depth anesthesia:
    e5 = morphine_crosshairs(60.0)
    # Dose:   60.0 mg
    # Effect: 85.7
    
    # Set titles:
    
    fig = line1.figure
    fig.suptitle("Morphine Dose-Response Curve", fontsize=16)
    # xlabel("Dose (mg)", fontsize=14)
    # ylabel("Response", fontsize=14)
    
    # Decade ticks only make sense on the log axis; the linear axis keeps
    # Matplotlib's default ticks.
    if semilog:
        xticks([0.1, 1.0, 10.0, 100.0, 1000.0],
               ["0.1", "1", "10", "100", "1,000"])
    
    # Label the y axis at the five marked effect levels:
    yticks([e1, e2, e3, e4, e5],
           ["A", "B", "C", "D", "E"])

#================================ Do at Import ================================#

# Make Figure 1, just 3 drug curves:
# draw_ocw_small(1)

# Make Figure 2, 3 drug curves plus 5 isoeffective lines:
# draw_ocw_small(2)

# Make lines for the 3 drugs, to be cut and used in the legend:
# draw_ocw_lines(1)

# Make the 5 isoeffective lines, to be cut and used in the legend:
# draw_ocw_lines(2)

# Make the empty box, to be used in the legend:
# draw_empty_box()

# Draw 3 drug curves with just 3 isoeffective lines:
# draw_ocw_small(3)

# Draw a plot for morphine with crosshairs at 5 points:
# draw_morphine(False)
# draw_morphine(True)
