# A sensitivity-enabled ODE solver for Julia DiffEqs

## Sensitivity analysis

Given a model, sensitivity analysis (SA) quantifies the relationship between
model parameters and model outputs (not only the solutions).

It is used in design optimization, parameter estimation, optimal control, and
experimental design.

## Problem

Assume that we have the following initial-value problem for an ordinary
differential equation system:

$$
\begin{eqnarray}
\dot{y} = f(t, y, p) \\
y(t_{0})=y_{0}(p)
\end{eqnarray}
$$

where $y\in\mathbb{R}^{N}$, $p\in\mathbb{R}^{N_{p}}$, $N$ is the number of
state variables, and $N_{p}$ is the number of parameters.

## Forward Sensitivity Analysis

The goal is to compute the sensitivity of the solution itself; this is the
simpler of the two analyses to perform. Taking the derivative of the model with
respect to the parameters:

$$
\begin{eqnarray}
\frac{\partial \dot{y}}{\partial p_{i}} =
\frac{\partial f}{\partial y}\frac{\partial y}{\partial p_{i}} +
\frac{\partial f}{\partial p_{i}} \\
\frac{\partial y(t_{0})}{\partial p_{i}} =
\frac{\partial y_{0}(p)}{\partial p_{i}}
\end{eqnarray}
$$

By defining $s_{i}\equiv\frac{\partial y}{\partial p_{i}}$, we have

$$
\begin{eqnarray}
\dot{s_{i}} =
\frac{\partial f}{\partial y}s_{i} +
\frac{\partial f}{\partial p_{i}} \\
s_{i}(t_{0}) = \frac{\partial y_{0}(p)}{\partial p_{i}}
\end{eqnarray}
$$

Then the original ODE problem is extended to a larger problem that includes
differential equations for the sensitivities. The number of equations grows from
$N$ to $N(N_{q}+1)$, where $N_{q}\le N_{p}$ is the number of parameters with
respect to which we want to calculate the sensitivity.
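
As a small illustration (a sketch under stated assumptions, not part of the solver developed below), take the hypothetical scalar model $\dot{y}=-py$, $y(0)=1$. Its sensitivity equation is $\dot{s}=-ps-y$ with $s(0)=0$, whose closed-form solution is $s(t)=-te^{-pt}$. The snippet integrates the augmented state $[y, s]$ with a hand-written fixed-step Runge-Kutta routine; `rk4` and `forward_demo` are illustrative names, and the code assumes Julia 1.x with Base only.

```julia
# Classical fixed-step fourth-order Runge-Kutta integrator for u' = f(t, u).
function rk4(f, u0, t0, t1, n)
    h = (t1 - t0) / n
    u, t = copy(u0), t0
    for _ in 1:n
        k1 = f(t, u)
        k2 = f(t + h/2, u .+ (h/2) .* k1)
        k3 = f(t + h/2, u .+ (h/2) .* k2)
        k4 = f(t + h, u .+ h .* k3)
        u = u .+ (h/6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += h
    end
    return u
end

# Augmented right-hand side for y' = -p*y: the state is [y, s],
# i.e. the original equation plus one sensitivity equation.
function forward_demo(p, T; n=1000)
    aug(t, u) = [-p * u[1],          # model equation
                 -p * u[2] - u[1]]   # (∂f/∂y)*s + ∂f/∂p
    return rk4(aug, [1.0, 0.0], 0.0, T, n)
end

y_T, s_T = forward_demo(2.0, 1.0)
# closed forms for comparison: y(T) = exp(-p*T), s(T) = -T*exp(-p*T)
```

Both components are advanced with the same step, which is the same idea `ODEForwardSensFunction` relies on when it stacks the state and the sensitivities into one vector.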

## Adjoint sensitivity analysis

Often we are interested not in the sensitivity of the solution, but in the
sensitivity of a functional of the solution. In this case, it may be
computationally easier to perform an adjoint sensitivity analysis. Suppose
that we want to calculate the sensitivity of the following functional

$$
G(p) = \int^{t_{f}}_{t_{0}} dt g(t, y, p)
$$

The sensitivity is

$$
\frac{\partial G}{\partial p_{i}} = \int^{t_{f}}_{t_{0}} dt
\left(\frac{\partial g}{\partial p_{i}} + \frac{\partial g}{\partial y}s_{i}\right)
$$

Note that $G(p)$ does not depend explicitly on the solution $y(t)$.

To avoid having to calculate $s_{i}$, we want to replace it with another
quantity that is simpler to calculate. We start from the differential equation
for the sensitivities

$$
\begin{eqnarray}
\dot{s_{i}} =
\frac{\partial f}{\partial y}s_{i} +
\frac{\partial f}{\partial p_{i}}
\end{eqnarray}
$$

Multiplying by a smooth function $\lambda(t)$ and applying the product rule
(the differential form of integration by parts) gives

$$
\begin{eqnarray}
\lambda\frac{\partial f}{\partial p_{i}} =
\lambda\dot{s_{i}} - \lambda\frac{\partial f}{\partial y}s_{i} = \\
\frac{d}{dt}\left(\lambda s_{i}\right) - \dot{\lambda}s_{i}
- \lambda\frac{\partial f}{\partial y}s_{i}
\end{eqnarray}
$$

Now we *define* $\lambda(t)$ such that it satisfies

$$
\dot{\lambda}=-\frac{\partial f}{\partial y}\lambda-\frac{\partial g}{\partial y}
$$

and therefore

$$
\lambda\frac{\partial f}{\partial p_{i}} -\frac{d}{dt}\left(\lambda s_{i}\right)
= \frac{\partial g}{\partial y}s_{i}
$$
and finally

$$
\frac{\partial G}{\partial p_{i}} = \int^{t_{f}}_{t_{0}} dt
\left(\frac{\partial g}{\partial p_{i}} + \lambda\frac{\partial f}{\partial p_{i}}\right)
+\lambda(t_{0})s_{i}(t_{0})
$$
where we have chosen $\lambda(t_{f})=0$ as the "initial" condition for
$\lambda(t)$.
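
To see the formula at work, take the hypothetical scalar model $\dot{y}=-py$, $y(0)=1$, $t_{0}=0$, with $g=y$, so that $G(p)=(1-e^{-pt_{f}})/p$. The adjoint equation becomes $\dot{\lambda}=p\lambda-1$ with $\lambda(t_{f})=0$, and because $s(t_{0})=0$ and $\partial g/\partial p=0$, only the integral of $\lambda\,\partial f/\partial p=-\lambda y$ remains. The sketch below (Julia 1.x, Base only; `rk4` and `adjoint_demo` are illustrative names) integrates the adjoint backward in time. As a simplification, it recomputes $y(t)$ backward alongside $\lambda$ instead of interpolating a stored forward solution, which is an assumption of this example and not how the notebook code proceeds.

```julia
# Classical fixed-step fourth-order Runge-Kutta integrator for u' = f(t, u).
# A negative step (t1 < t0) integrates backward in time.
function rk4(f, u0, t0, t1, n)
    h = (t1 - t0) / n
    u, t = copy(u0), t0
    for _ in 1:n
        k1 = f(t, u)
        k2 = f(t + h/2, u .+ (h/2) .* k1)
        k3 = f(t + h/2, u .+ (h/2) .* k2)
        k4 = f(t + h, u .+ h .* k3)
        u = u .+ (h/6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += h
    end
    return u
end

function adjoint_demo(p, T; n=1000)
    # forward pass: only the final state y(T) is kept
    yT = rk4((t, u) -> [-p * u[1]], [1.0], 0.0, T, n)[1]
    # backward pass over [y, λ, q], integrated from t = T down to t = 0:
    #   y' = -p*y      (retraces the forward solution in reverse)
    #   λ' = p*λ - 1   (adjoint equation with λ(T) = 0)
    #   q' = λ*y       (equals -λ*∂f/∂p, so q(0) is the integral of λ*∂f/∂p over [0, T])
    back(t, u) = [-p * u[1], p * u[2] - 1.0, u[2] * u[1]]
    return rk4(back, [yT, 0.0, 0.0], T, 0.0, n)[3]
end

dGdp = adjoint_demo(2.0, 1.0)
# closed form for comparison: T*exp(-p*T)/p - (1 - exp(-p*T))/p^2
```

The cost of the backward pass does not depend on how many parameters there are beyond the size of the accumulated integral, which is why the adjoint route can pay off when $N_{p}$ is large.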

Note that we have to solve one additional equation to calculate $\lambda(t)$.
We can calculate the Jacobians with respect to the parameters,
$\frac{\partial f}{\partial p_{i}}$ and 
$\frac{\partial g}{\partial p_{i}}$, using automatic differentiation.
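
The notebook delegates this to ReverseDiff and ForwardDiff. Purely to illustrate the idea behind forward-mode automatic differentiation, here is a toy dual-number type; `Dual` and `f_demo` are hypothetical names for this sketch and are not the API of either package (Julia 1.x, Base only).

```julia
# A dual number carries a value and the derivative of that value
# with respect to one chosen ("seeded") input.
struct Dual
    val::Float64
    der::Float64
end

Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:-(a::Dual) = Dual(-a.val, -a.der)
# product rule
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)

# Hypothetical right-hand side f(t, y, p) = -p*y + p^2
f_demo(t, y, p) = -p * y + p * p

# Seed p with derivative 1 to read off ∂f/∂p = -y + 2p at y = 3, p = 2
dfdp = f_demo(0.0, Dual(3.0, 0.0), Dual(2.0, 1.0)).der

# Seed y with derivative 1 to read off ∂f/∂y = -p
dfdy = f_demo(0.0, Dual(3.0, 1.0), Dual(2.0, 0.0)).der
```

Each seeded evaluation yields one column of a Jacobian without writing the derivative by hand, which is what makes the sensitivity right-hand sides cheap to assemble for a user-supplied `f`.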

## Problem Setup

In [1]:
using DifferentialEquations, ReverseDiff, ForwardDiff



In [2]:
struct ODESensProblem{F,F0,G,pType,tType}
    f::F
    f0::F0
    g::G
    p0::pType
    tspan::Tuple{tType,tType}
end

## Forward Sensitivities

In [3]:
struct ODEForwardSensFunction{F}
    f::F
    numparams::Int
    numindvar::Int
end

function ODEForwardSensFunction(f,u0,p0)
    numparams = length(p0)
    numindvar = length(u0)
    ODEForwardSensFunction(f,numparams,numindvar)
end

function (S::ODEForwardSensFunction)(t,u,p)
    # u stacks the state y followed by one sensitivity vector s_i per parameter
    du = Vector{eltype(u)}(S.numindvar*(S.numparams+1))
    y = @view u[1:S.numindvar]
    du[1:S.numindvar] = S.f(t,y,p)
    # Jf[2] is ∂f/∂y and Jf[3] is ∂f/∂p, both evaluated at (t,y,p)
    Jf = ReverseDiff.jacobian(S.f,([t],y,p))
    for i in 1:S.numparams
        Si = @view u[S.numindvar*i+1:S.numindvar*(i+1)]
        # sensitivity equation: ds_i/dt = (∂f/∂y)*s_i + ∂f/∂p_i
        du[S.numindvar*i+1:S.numindvar*(i+1)] = 
            Jf[2]*Si+Jf[3][:,i]
    end
    return du
end

In [4]:
function ForwardSens(sp::ODESensProblem)
    # solve the ODE problem and sensitivity ODE problem at the same time
    u0      = sp.f0(sp.p0)
    Jf0     = ForwardDiff.jacobian(sp.f0,sp.p0)
    u0_sens = hcat(u0,Jf0)
    u0_sens = reshape(u0_sens,(length(u0_sens)))
    f_sens  = ODEForwardSensFunction(sp.f,u0,sp.p0)
    fp_sens = ParameterizedFunction(f_sens,sp.p0)
    ODEProb = ODEProblem(fp_sens,u0_sens,sp.tspan)    
    ODESol  = solve(ODEProb)
    
    # compute the final sensitivities
    uf      = @view ODESol.u[end][1:length(u0)]
    Uf      = reshape((@view ODESol.u[end][length(u0)+1:end]),(length(u0),length(sp.p0))) 
    Jg      = ReverseDiff.jacobian(sp.g,([sp.tspan[2]],uf,sp.p0))
    dgdp    = Jg[2]*Uf + Jg[3]
    
    return dgdp
end 

ForwardSens (generic function with 1 method)
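
Reading the last lines of `ForwardSens`, the returned quantity appears to be the chain rule applied at the final time. Writing $S(t_{f})$ for the matrix whose columns are the sensitivities $s_{i}(t_{f})$ (`Uf` in the code), it can be summarized as:

$$
\frac{dg}{dp} =
\frac{\partial g}{\partial u}\bigg|_{t_{f}} S(t_{f}) +
\frac{\partial g}{\partial p}\bigg|_{t_{f}}
$$

Here $\partial g/\partial u$ corresponds to `Jg[2]` and $\partial g/\partial p$ to `Jg[3]`. As written, the code evaluates $g$ only at the final time, so the result is the sensitivity of $g(t_{f}, u(t_{f}), p)$, not of the integral $G(p)$.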

## Adjoint Sensitivities

In [5]:
struct ODEAdjointSensFunction{F,ODESol}
    f::F
    numparams::Int
    numindvar::Int
    numdepvar::Int
    ODESolution::ODESol
end

function ODEAdjointSensFunction(f,p0,u0,g0,ODESolution)
    numparams = length(p0)
    numindvar = length(u0)
    numdepvar = length(g0)
    ODEAdjointSensFunction(f,numparams,numindvar,numdepvar,ODESolution)
end

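function (S::ODEAdjointSensFunction)(t,u,p)
    # u stacks the adjoint λ followed by the running integral of λ'*(∂f/∂p)
    λ  = reshape((@view u[1:S.numdepvar*S.numindvar]),(S.numindvar,S.numdepvar))
    y  = S.ODESolution(t) # interpolate the stored forward solution at time t
    Jf = ReverseDiff.jacobian(S.f,([t],y,p))
    dλ = -Jf[2]'*λ        # adjoint equation: dλ/dt = -(∂f/∂y)'*λ
    dλint = -λ'*Jf[3]     # integrand of the parameter term, integrated backward in time
    du = [reshape(dλ,(length(dλ)));reshape(dλint,(length(dλint)))]
    return du
end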

In [6]:
function AdjointSens(sp::ODESensProblem)
    # first solve the ODE problem
    u0      = sp.f0(sp.p0)
    fp      = ParameterizedFunction(sp.f,sp.p0)
    ODEProb = ODEProblem(fp,u0,sp.tspan)
    ODESol  = solve(ODEProb)
    
    # next solve the adjoint sensitivities backward in time,
    # starting from λ(t_f) = (∂g/∂u)' and an integral term initialized to ∂g/∂p
    Jg      = ReverseDiff.jacobian(sp.g,([ODESol.t[end]],ODESol.u[end],sp.p0))
    λT      = [reshape(Jg[2]',length(Jg[2]));reshape(Jg[3],length(Jg[3]))]
    Tspan   = (sp.tspan[2],sp.tspan[1])
    f_adj   = ODEAdjointSensFunction(sp.f,sp.p0,u0,sp.g(sp.tspan[2],u0,sp.p0),ODESol)
    fp_adj  = ParameterizedFunction(f_adj,sp.p0)
    AdjProb = ODEProblem(fp_adj,λT,Tspan)
    AdjSol  = solve(AdjProb)
    
    # finally, compute the final sensitivities
    Jf0     = ForwardDiff.jacobian(sp.f0,sp.p0)
    λ0      = reshape(AdjSol.u[end][1:length(Jg[2])],size(Jg[2]'))
    λint    = reshape(AdjSol.u[end][length(Jg[2])+1:end],size(Jg[3]))
    dgdp    = λint .+ λ0'*Jf0
    
    return dgdp
end 

AdjointSens (generic function with 1 method)
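
As far as the code shows, `AdjointSens` treats $g$ as a function of the final state, so the terminal condition sits on $\lambda$ and not inside the adjoint equation. Under that reading, the system it integrates from $t_{f}$ back to $t_{0}$ and the quantity it returns can be written as:

$$
\begin{eqnarray}
\dot{\lambda} = -\left(\frac{\partial f}{\partial u}\right)^{T}\lambda, \quad
\lambda(t_{f}) = \left(\frac{\partial g}{\partial u}\right)^{T}\bigg|_{t_{f}} \\
\frac{dg}{dp} = \frac{\partial g}{\partial p}\bigg|_{t_{f}}
+ \int^{t_{f}}_{t_{0}} dt\, \lambda^{T}\frac{\partial f}{\partial p}
+ \lambda(t_{0})^{T}\frac{\partial u_{0}}{\partial p}
\end{eqnarray}
$$

In the code, the first two terms of $dg/dp$ are accumulated in `λint`, and the last term is `λ0'*Jf0`. The identity follows from $\frac{d}{dt}\left(\lambda^{T}s_{i}\right)=\lambda^{T}\frac{\partial f}{\partial p_{i}}$, which holds once $\lambda$ satisfies the adjoint equation above.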

## Examples

### Linear Example

$$
\begin{eqnarray}
\dot{u} = 
\begin{bmatrix}
    -p_1 & 0 \\
    p_1 & -p_2
\end{bmatrix}
u \\
u(0) = \begin{bmatrix} 1 & 1 \end{bmatrix}^T \\
g(t,u,p)= ||u||^2 \\
p = \begin{bmatrix} 1 & 2 \end{bmatrix}^T
\end{eqnarray}
$$
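
For reference, the Jacobians that automatic differentiation has to reproduce for this example can be written by hand. Taking the right-hand side as implemented in `f_linear` below, $f = \begin{bmatrix} -p_1u_1 & p_1u_1-p_2u_2 \end{bmatrix}^T$, they are:

$$
\begin{eqnarray}
\frac{\partial f}{\partial u} =
\begin{bmatrix}
    -p_1 & 0 \\
    p_1 & -p_2
\end{bmatrix} \\
\frac{\partial f}{\partial p} =
\begin{bmatrix}
    -u_1 & 0 \\
    u_1 & -u_2
\end{bmatrix} \\
\frac{\partial g}{\partial u} = \begin{bmatrix} 2u_1 & 2u_2 \end{bmatrix}
\end{eqnarray}
$$

Since $u(0)$ and $g$ do not depend on $p$ here, $\partial u_{0}/\partial p$ and $\partial g/\partial p$ are both zero.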

In [7]:
f_linear(t,u,p)=[-p[1]*u[1];p[1]*u[1]-p[2]*u[2]]
f0(p)=[1.0;1.0]
g(t,u,p)=[u[1]*u[1]+u[2]*u[2]]

u0 = [1.0;1.0]
p0_linear = [1.0;2.0]
tspan = (0.0,1.0)

sp_linear = ODESensProblem(f_linear,f0,g,p0_linear,tspan)

ODESensProblem{#f_linear,#f0,#g,Array{Float64,1},Float64}(f_linear, f0, g, [1.0, 2.0], (0.0, 1.0))

In [8]:
ForwardSens(sp_linear)

1×2 Array{Float64,2}:
 -0.199149  -0.171096

In [9]:
AdjointSens(sp_linear)

1×2 Array{Float64,2}:
 -0.199149  -0.171096

In [10]:
# Check Answer using Analytical Solution
function linear_sol(t,u0,p)
    ut = [u0[1]*exp.(-p[1]*t);
          -u0[1]*p[1]/(p[1]-p[2])*exp.(-p[1]*t)+
          u0[1]*p[1]/(p[1]-p[2])*exp.(-p[2]*t)+
          u0[2]*exp.(-p[2]*t)]
    
    g= [ut[1]*ut[1]+ut[2]*ut[2]]
end
ReverseDiff.jacobian(linear_sol,([1.0],u0,p0_linear))[3]

1×2 Array{Float64,2}:
 -0.199148  -0.171096
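
A further check that does not rely on any differentiation package is a central finite difference of the analytical solution. The sketch below assumes Julia 1.x and Base only; `g_linear` and `fd_gradient` are illustrative names, and `g_linear` restates the closed-form solution from the cell above as a function of the parameters.

```julia
# Analytical g(u(t)) = u1^2 + u2^2 for the linear example, as a function of p.
function g_linear(p; t=1.0, u0=[1.0, 1.0])
    c  = p[1] / (p[1] - p[2])
    u1 = u0[1] * exp(-p[1] * t)
    u2 = -u0[1] * c * exp(-p[1] * t) +
          u0[1] * c * exp(-p[2] * t) +
          u0[2] * exp(-p[2] * t)
    return u1^2 + u2^2
end

# Central finite difference of a scalar function with respect to each parameter.
function fd_gradient(g, p; h=1e-6)
    grad = zeros(length(p))
    for i in eachindex(p)
        e = zeros(length(p))
        e[i] = h                      # perturb only the i-th parameter
        grad[i] = (g(p .+ e) - g(p .- e)) / (2h)
    end
    return grad
end

fd_gradient(g_linear, [1.0, 2.0])
```

Finite differences need two function evaluations per parameter and are sensitive to the step size `h`, so they serve better as a sanity check than as a replacement for the sensitivity equations.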

In [11]:
# Also check using symbolic differentiation methods
pf_linear = @ode_def_nohes LinTest begin
    dx = -a*x
    dy =  a*x - b*y
    end a=>1 b=>2
prob_linear = ODELocalSensitivityProblem(pf_linear,u0,tspan)
sol_linear = solve(prob_linear);
solend=sol_linear[end]
2*solend[1]*solend[3:2:5]'+2*solend[2]*solend[4:2:6]'

1×2 RowVector{Float64,Array{Float64,1}}:
 -0.199149  -0.171096

In [12]:
# Time comparison (the first @time call of each function also measures compilation)
function test1()
    f_linear(t,u,p)=[-p[1]*u[1];p[1]*u[1]-p[2]*u[2]]
    f0(p)=[1.0;1.0]
    g(t,u,p)=[u[1]*u[1]+u[2]*u[2]]

    u0 = [1.0;1.0]
    p0_linear = [1.0;2.0]
    tspan = (0.0,1.0)

    sp_linear = ODESensProblem(f_linear,f0,g,p0_linear,tspan)
    ForwardSens(sp_linear)
end

function test2()
    f_linear(t,u,p)=[-p[1]*u[1];p[1]*u[1]-p[2]*u[2]]
    f0(p)=[1.0;1.0]
    g(t,u,p)=[u[1]*u[1]+u[2]*u[2]]

    u0 = [1.0;1.0]
    p0_linear = [1.0;2.0]
    tspan = (0.0,1.0)

    sp_linear = ODESensProblem(f_linear,f0,g,p0_linear,tspan)
    AdjointSens(sp_linear)
end

pf_linear = @ode_def_nohes LinTest begin
        dx = -a*x
        dy =  a*x - b*y
    end a=>1 b=>2

function test3()
    u0 = [1.0;1.0]
    tspan = (0.0,1.0)
    prob_linear = ODELocalSensitivityProblem(pf_linear,u0,tspan)
    sol_linear = solve(prob_linear);
    solend=sol_linear[end]
    2*solend[1]*solend[3:2:5]'+2*solend[2]*solend[4:2:6]'
end

test3 (generic function with 1 method)

In [13]:
@time test1()
@time test2()
@time test3()

  6.809529 seconds (845.39 k allocations: 117.158 MiB, 0.73% gc time)
 14.176764 seconds (1.72 M allocations: 236.757 MiB, 0.62% gc time)
  5.024735 seconds (729.04 k allocations: 110.375 MiB, 1.15% gc time)


1×2 RowVector{Float64,Array{Float64,1}}:
 -0.199149  -0.171096

In [14]:
@time test1()
@time test2()
@time test3()

  0.010698 seconds (6.70 k allocations: 380.917 KiB)
  0.019767 seconds (9.70 k allocations: 578.584 KiB)
  0.002419 seconds (1.35 k allocations: 91.854 KiB)


1×2 RowVector{Float64,Array{Float64,1}}:
 -0.199149  -0.171096

### Nonlinear Example (Lotka-Volterra)
$$
\begin{eqnarray}
\dot{u} = 
\begin{bmatrix}
    p_1u_1-p_2u_1u_2 \\
    -p_3u_2+p_4u_1u_2
\end{bmatrix} \\
u(0) = \begin{bmatrix} 1 & 1 \end{bmatrix}^T \\
g(t,u,p)= ||u||^2 \\
p = \begin{bmatrix} 1.5 & 1.0 & 3.0 & 1.0 \end{bmatrix}^T
\end{eqnarray}
$$
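
For this model the Jacobians depend on the state, which is why the sensitivity equations have to be integrated together with the solution. Written out by hand from the right-hand side above, they are:

$$
\begin{eqnarray}
\frac{\partial f}{\partial u} =
\begin{bmatrix}
    p_1-p_2u_2 & -p_2u_1 \\
    p_4u_2 & -p_3+p_4u_1
\end{bmatrix} \\
\frac{\partial f}{\partial p} =
\begin{bmatrix}
    u_1 & -u_1u_2 & 0 & 0 \\
    0 & 0 & -u_2 & u_1u_2
\end{bmatrix}
\end{eqnarray}
$$

These are the matrices that `ReverseDiff.jacobian` is expected to return as `Jf[2]` and `Jf[3]` when applied to `f_nonlinear`.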

In [15]:
f_nonlinear(t,u,p)=[p[1]*u[1]-p[2]*u[1]*u[2];-p[3]*u[2]+p[4]*u[1]*u[2]]
f0(p)=[1.0;1.0]
g(t,u,p)=[u[1]*u[1]+u[2]*u[2]]

u0 = [1.0;1.0]
p0_nonlinear = [1.5;1.0;3.0;1.0]
tspan = (0.0,5.0)

sp_nonlinear = ODESensProblem(f_nonlinear,f0,g,p0_nonlinear,tspan)

ODESensProblem{#f_nonlinear,#f0,#g,Array{Float64,1},Float64}(f_nonlinear, f0, g, [1.5, 1.0, 3.0, 1.0], (0.0, 5.0))

In [16]:
ForwardSens(sp_nonlinear)

1×4 Array{Float64,2}:
 156.039  -21.2792  48.4524  -21.8133

In [17]:
AdjointSens(sp_nonlinear)

1×4 Array{Float64,2}:
 155.669  -21.1979  48.4611  -22.085

In [18]:
# Compare to using symbolic differentiation package
pf_nonlinear = @ode_def_nohes LVSens begin
    dx = a*x - b*x*y
    dy = -c*y + d*x*y
    end a=>1.5 b=>1 c=>3 d=>1
prob_nonlinear = ODELocalSensitivityProblem(pf_nonlinear,u0,tspan)
sol_nonlinear = solve(prob_nonlinear);
solnl_end=sol_nonlinear[end]
2*solnl_end[1]*solnl_end[3:2:9]'+2*solnl_end[2]*solnl_end[4:2:10]'

1×4 RowVector{Float64,Array{Float64,1}}:
 156.039  -21.2792  48.4524  -21.8133

In [19]:
# Timing results (the first @time call of each function also measures compilation)
function test4()
    f_nonlinear(t,u,p)=[p[1]*u[1]-p[2]*u[1]*u[2];-p[3]*u[2]+p[4]*u[1]*u[2]]
    f0(p)=[1.0;1.0]
    g(t,u,p)=[u[1]*u[1]+u[2]*u[2]]

    u0 = [1.0;1.0]
    p0_nonlinear = [1.5;1.0;3.0;1.0]
    tspan = (0.0,5.0)

    sp_nonlinear = ODESensProblem(f_nonlinear,f0,g,p0_nonlinear,tspan)
    ForwardSens(sp_nonlinear)
end

function test5()
    f_nonlinear(t,u,p)=[p[1]*u[1]-p[2]*u[1]*u[2];-p[3]*u[2]+p[4]*u[1]*u[2]]
    f0(p)=[1.0;1.0]
    g(t,u,p)=[u[1]*u[1]+u[2]*u[2]]

    u0 = [1.0;1.0]
    p0_nonlinear = [1.5;1.0;3.0;1.0]
    tspan = (0.0,5.0)

    sp_nonlinear = ODESensProblem(f_nonlinear,f0,g,p0_nonlinear,tspan)
    AdjointSens(sp_nonlinear)
end

pf_nonlinear = @ode_def_nohes LVSens begin
    dx = a*x - b*x*y
    dy = -c*y + d*x*y
    end a=>1.5 b=>1 c=>3 d=>1

function test6()
    u0 = [1.0;1.0]
    tspan = (0.0,5.0)
    prob_nonlinear = ODELocalSensitivityProblem(pf_nonlinear,u0,tspan)
    sol_nonlinear = solve(prob_nonlinear);
    solnl_end=sol_nonlinear[end]
    2*solnl_end[1]*solnl_end[3:2:9]'+2*solnl_end[2]*solnl_end[4:2:10]'
end

test6 (generic function with 1 method)

In [20]:
@time test4()
@time test5()
@time test6()

  6.978251 seconds (877.00 k allocations: 118.773 MiB, 0.73% gc time)
 14.314572 seconds (1.75 M allocations: 238.326 MiB, 0.81% gc time)
  4.843474 seconds (698.59 k allocations: 108.463 MiB, 0.86% gc time)


1×4 RowVector{Float64,Array{Float64,1}}:
 156.039  -21.2792  48.4524  -21.8133

In [21]:
@time test4()
@time test5()
@time test6()

  0.036617 seconds (25.96 k allocations: 1.390 MiB)
  0.028101 seconds (28.42 k allocations: 1.578 MiB)
  0.004719 seconds (4.29 k allocations: 289.026 KiB)


1×4 RowVector{Float64,Array{Float64,1}}:
 156.039  -21.2792  48.4524  -21.8133