In [1]:
using Optim
using AutoGrad

# API input

In [2]:
"""
    f!(t, y, dy)

In-place right-hand side of the two-variable ODE system, in the form expected
by an ODE solver: the derivatives are written into `dy[1]` and `dy[2]`.
"""
function f!(t, y, dy)
    y1 = y[1]
    y2 = y[2]
    dy[1] = t + 998.0 * y1 + 1998.0 * y2
    dy[2] = -999.0 * y1 - 1999.0 * y2
end

f! (generic function with 1 method)

In [3]:
# Training points as a row matrix, shape = [n_input, n_sample].
t = reshape(collect(0:0.1:2), 1, :)

1×21 Array{Float64,2}:
 0.0  0.1  0.2  0.3  0.4  0.5  0.6  0.7  …  1.4  1.5  1.6  1.7  1.8  1.9  2.0

In [4]:
# Initial conditions, one value per ODE variable.
y0_list = [2.0, -1.0]

2-element Array{Float64,1}:
  2.0
 -1.0

# Preprocessing

In [5]:
# Derived quantities: the initial time is the first training point, and there is
# one ODE variable per initial condition.
t0 = first(t)
Nvar = length(y0_list)

2

In [6]:
"""
    f_list(t, y)

Out-of-place wrapper around the in-place right-hand side `f!(t, y, dy)`, for
use by the neural-network code.

Allocates a buffer with one slot per entry of `y`, lets `f!` fill it, and
returns it. The buffer is a `Vector{Any}` so that each slot can hold whatever
`f!` writes (a scalar or an array).
"""
function f_list(t, y)
    # A typed comprehension replaces `Vector{}(N)`: that untyped, uninitialised
    # constructor only exists on old Julia versions and leaves the slots undefined.
    dy = Any[nothing for _ in 1:length(y)]
    f!(t, y, dy)

    return dy
end

f_list (generic function with 1 method)

# Building blocks

In [7]:
"""
    init_weights(; n_in=1, n_hidden=10, n_out=1)

Create the parameters of a network with a single hidden layer.

Returns the vector `[W1, b1, W2, b2]`. `W1` has size `(n_hidden, n_in)` and
`W2` has size `(n_out, n_hidden)`, both drawn with `randn`; the biases `b1` and
`b2` are zero vectors. The matrices are laid out for left multiplication,
i.e. `W1 * x`.
"""
function init_weights(; n_in=1, n_hidden=10, n_out=1)
    return [
        randn(n_hidden, n_in),   # W1
        zeros(n_hidden),         # b1
        randn(n_out, n_hidden),  # W2
        zeros(n_out)             # b2
    ]
end

"""
    sizes_and_length(params)

Describe the layout of a parameter list so it can be flattened and restored.

Returns a 3-tuple: the `size` of every array in `params`, the `length` of every
array, and the sum of those lengths.
"""
function sizes_and_length(params)
    lengths = map(length, params)
    return map(size, params), lengths, sum(lengths)
end

sizes_and_length (generic function with 1 method)

In [8]:
# Sanity check: sizes, element counts and total parameter count of a default network.
sizes_and_length(init_weights())

(Tuple{Int64,Vararg{Int64,N} where N}[(10, 1), (10,), (1, 10), (1,)], [10, 10, 10, 1], 31)

In [9]:
"""
    predict(params, t, y0, t0; act=tanh)

Trial solution of one ODE variable at the points `t`.

`params` is `[W1, b1, W2, b2]`. The raw network output is
`W2 * act.(W1 * t .+ b1) .+ b2`; it is multiplied by `(t - t0)` and shifted by
`y0`, so the result equals `y0` at `t = t0` whatever the weights are.
"""
function predict(params, t, y0, t0; act=tanh)
    W1 = params[1]
    b1 = params[2]
    W2 = params[3]
    b2 = params[4]

    # plain one-hidden-layer forward pass
    hidden = act.(W1 * t .+ b1)
    raw = W2 * hidden .+ b2

    # enforce the initial condition by construction
    return y0 .+ (t .- t0) .* raw
end

predict (generic function with 1 method)

In [10]:
# Sum of the trial solution over all points: a scalar that can be handed to `grad`.
function predict_sum(params, t, y0, t0)
    return sum(predict(params, t, y0, t0))
end

# Gradient of `predict_sum` with respect to its 2nd argument, `t`. In `predict`
# each output entry depends only on its own entry of `t`, so this gradient is
# the pointwise derivative dy/dt.
predict_dt = grad(predict_sum, 2)

(::gradfun) (generic function with 1 method)

# NN model type


## Basic info

In [11]:
# State of the neural-network ODE solver: the problem definition followed by the
# network weights in structured and flattened form. All fields are untyped.
# `pre_init_nn` sets the five weight-related fields to `nothing`;
# `reset_weights!` fills them in.
mutable struct NNModel
    Nvar     # number of ODE equations
    f_list   # out-of-place right-hand side, called as f_list(t, y)
    y0_list  # initial conditions, one per equation
    t        # training points
    t0       # initial t, i.e. t[1]
    n_hidden # number of hidden units

    params_list    # one parameter list per equation
    sizes   # array sizes of one network's parameters, used to unflatten
    lengths # array lengths of one network's parameters, used to unflatten
    total_l  # total number of parameters in one network, used to unflatten
    all_flattened_params # all networks' parameters in one flat vector, for Optim.jl
end

In [12]:
"""
    pre_init_nn(f!, t, y0_list; n_hidden=10)

Build an `NNModel` holding the problem definition only.

# Arguments
- `f!`: in-place right-hand side, called as `f!(t, y, dy)`.
- `t`: training points; the initial time is taken as `t[1]`.
- `y0_list`: initial conditions, one per equation.
- `n_hidden`: number of hidden units, stored on the model.

The five weight-related fields of the returned model are `nothing`.
"""
function pre_init_nn(f!, t, y0_list; n_hidden=10)
    Nvar = length(y0_list)
    t0 = t[1]

    # Out-of-place wrapper around the user's in-place `f!`.
    function f_list(t, y)
        # A typed comprehension replaces `Vector{}(n)`: that untyped, uninitialised
        # constructor only exists on old Julia versions and leaves the slots undefined.
        dy = Any[nothing for _ in 1:length(y)]
        f!(t, y, dy)
        return dy
    end

    return NNModel(Nvar, f_list, y0_list, t, t0, n_hidden,
                   nothing, nothing, nothing, nothing, nothing)
end

"""
    show(nn::NNModel)

Print a short summary of the model to standard output: number of equations,
initial conditions, number of hidden units and number of training points.
"""
function show(nn::NNModel)
    println("Neural ODE Solver")
    println("Number of equations:       ", nn.Nvar)
    println("Initial condition y0:      ", nn.y0_list)
    # label typo ("Numnber") fixed; padding adjusted so the values stay aligned
    println("Number of hidden units:    ", nn.n_hidden)
    println("Number of training points: ", length(nn.t))
end

show (generic function with 1 method)

In [13]:
# Build the model from the problem definition; the weight fields are still `nothing`.
nn = pre_init_nn(f!, t, y0_list)

NNModel(2, f_list, [2.0, -1.0], [0.0 0.1 … 1.9 2.0], 0.0, 10, nothing, nothing, nothing, nothing, nothing)

In [14]:
# Print the summary of the model.
show(nn)

Neural ODE Solver
Number of equations:       2
Initial condition y0:      [2.0, -1.0]
Numnber of hidden units:   10
Number of training points: 21


## Add weights information

In [15]:
# Flatten one network's parameter list into a single vector by concatenating
# the elements of each array in order.
flat_opt = function (p)
    return collect(Iterators.flatten(p))
end

(::#9) (generic function with 1 method)

In [16]:
"""
    reset_weights!(nn::NNModel)

Re-initialise the weights of `nn` in place and return `nothing`.

One network is created per equation with `init_weights`, using `nn.n_hidden`
hidden units. The layout fields (`sizes`, `lengths`, `total_l`) are taken from
the first network, and `all_flattened_params` is rebuilt by flattening every
network and concatenating the results.
"""
function reset_weights!(nn::NNModel)
    # structured weights: one parameter list per equation
    nn.params_list = [init_weights(n_hidden = nn.n_hidden) for _ in 1:nn.Nvar]

    # layout of a single network, read off the first one
    layout = sizes_and_length(nn.params_list[1])
    nn.sizes = layout[1]
    nn.lengths = layout[2]
    nn.total_l = layout[3]

    # flat copy of all weights
    per_network = map(flat_opt, nn.params_list)
    nn.all_flattened_params = vcat(per_network...)
    return nothing
end

reset_weights! (generic function with 1 method)

In [17]:
# Draw random weights for every network and store their flattened form on the model.
reset_weights!(nn)

In [18]:
# Size of the flat parameter vector that covers all networks.
size(nn.all_flattened_params)

(62,)

## Put together

In [19]:
"""
    init_nn(f!, t, y0_list; n_hidden=10)

Create a ready-to-train `NNModel`: build it with `pre_init_nn`, then draw
initial weights with `reset_weights!`.

# Arguments
- `f!`: in-place right-hand side, called as `f!(t, y, dy)`.
- `t`: training points.
- `y0_list`: initial conditions, one per equation.
- `n_hidden`: number of hidden units in each network.
"""
function init_nn(f!, t, y0_list; n_hidden=10)
    # forward `n_hidden`; without this the keyword was silently ignored and the
    # model always used the default of `pre_init_nn`
    nn = pre_init_nn(f!, t, y0_list; n_hidden=n_hidden)
    reset_weights!(nn)
    return nn
end

init_nn (generic function with 1 method)

In [20]:
# One-step construction: problem definition plus freshly initialised weights.
nn = init_nn(f!, t, y0_list);
show(nn)

Neural ODE Solver
Number of equations:       2
Initial condition y0:      [2.0, -1.0]
Numnber of hidden units:   10
Number of training points: 21


In [21]:
# The flat parameter vector is populated by `init_nn` as well.
size(nn.all_flattened_params)

(62,)

# Unflatten functions (do not operate on the NN directly)

## Single NN

In [22]:
"""
    unflatten(params_flat, sizes, lengths)

Split a flat parameter vector back into arrays.

Consecutive chunks of `params_flat`, `lengths[j]` elements each, are reshaped to
`sizes[j]` and collected in order. The chunks are views into `params_flat`, not
copies.
"""
function unflatten(params_flat, sizes, lengths)
    params = []
    offset = 0  # number of elements consumed so far
    for (j, dims) in enumerate(sizes)
        len = lengths[j]
        chunk = view(params_flat, offset+1:offset+len)
        push!(params, reshape(chunk, dims))
        offset += len
    end
    return params
end

unflatten (generic function with 1 method)

In [23]:
# Round-trip fixture: a default single network and its layout.
params_test = init_weights()
# Only the first two of the three returned values are kept; the total length is dropped.
sizes_test, lengths_test = sizes_and_length(params_test)

(Tuple{Int64,Vararg{Int64,N} where N}[(10, 1), (10,), (1, 10), (1,)], [10, 10, 10, 1], 31)

In [24]:
# Flatten the fixture into a single vector.
params_flat_test = collect(Iterators.flatten(params_test));

In [25]:
# Unflattening the flat vector must reproduce the original parameters.
unflatten(params_flat_test, sizes_test, lengths_test) == params_test

true

## Multiple NNs

In [26]:
"""
    unflatten_all(all_flattened_params, Nvar, total_l, sizes, lengths)

Restore the parameter lists of `Nvar` networks from one flat vector.

The vector is reshaped to `total_l` rows by `Nvar` columns; each column is
copied out and passed to `unflatten` with the shared `sizes` and `lengths`.
Returns a `Vector{Any}` with one parameter list per network.
"""
function unflatten_all(all_flattened_params, Nvar, total_l, sizes, lengths)
    per_network = reshape(all_flattened_params, total_l, Nvar)
    return Any[unflatten(per_network[:, k], sizes, lengths) for k in 1:Nvar]
end

unflatten_all (generic function with 1 method)

In [27]:
# Recover the per-network parameter lists from the model's flat vector.
params_recover = unflatten_all(nn.all_flattened_params, nn.Nvar, nn.total_l, 
                               nn.sizes, nn.lengths)

2-element Array{Any,1}:
 Any[[0.575811; -0.0853388; … ; 0.0205407; -0.558314], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [2.10233 -2.17488 … 1.07334 0.0343746], [0.0]]
 Any[[1.52646; -0.174612; … ; -0.284359; 0.442893], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [-1.65078 -1.48589 … -0.891213 -1.14649], [0.0]] 

In [28]:
# Round-trip check against the structured weights stored on the model.
params_recover == nn.params_list

true

# Loss function

In [29]:
"""
    loss_func(params_list, nn::NNModel)

Residual loss of the ODE system for the candidate weights `params_list`.

For every equation the trial solution and its time derivative are evaluated at
the training points `nn.t`. The right-hand side `nn.f_list` is then evaluated
on the predicted solutions, and the mean squared difference between predicted
derivative and right-hand side is summed over all equations.

`params_list` is a separate argument (rather than `nn.params_list`) so that an
optimizer can evaluate the loss at arbitrary weights. Everything else,
including the initial conditions and the initial time, is read from `nn`.
"""
function loss_func(params_list, nn::NNModel)
    # shortcuts; all taken from the model so the loss does not depend on
    # same-named global variables
    Nvar = nn.Nvar
    f_list = nn.f_list
    t = nn.t
    t0 = nn.t0
    y0_list = nn.y0_list

    y_pred_list = []
    dydt_pred_list = []

    for i = 1:Nvar
        params = params_list[i]
        y0 = y0_list[i]

        push!(y_pred_list, predict(params, t, y0, t0))
        push!(dydt_pred_list, predict_dt(params, t, y0, t0))
    end

    # the right-hand side needs all predicted variables at once
    f_pred_list = f_list(t, y_pred_list)

    loss_total = 0.0
    for i = 1:Nvar
        loss_total += mean(abs2, dydt_pred_list[i] - f_pred_list[i])
    end

    return loss_total
end

loss_func (generic function with 1 method)

In [30]:
# Loss at the model's current weights; uncomment the reset to re-randomise them first.
# reset_weights!(nn)
loss_func(nn.params_list, nn)

2.1932848766712155e6

# Training

In [31]:
"""
    get_unflat(all_flattened_params, nn::NNModel)

Unflatten `all_flattened_params` using the layout stored on `nn` (`Nvar`,
`total_l`, `sizes`, `lengths`). `nn` itself is not modified.
"""
get_unflat(all_flattened_params, nn::NNModel) =
    unflatten_all(all_flattened_params, nn.Nvar, nn.total_l, nn.sizes, nn.lengths)

get_unflat (generic function with 1 method)

In [32]:
# Check that get_unflat recovers the structured weights from the model's flat vector.
get_unflat(nn.all_flattened_params, nn) == nn.params_list

true

In [33]:
"""
    train!(nn::NNModel; method=BFGS(), iterations=500, show_every=50)

Minimise `loss_func` over the weights of `nn` with Optim.jl and store the result
on the model.

# Keywords
- `method`: Optim.jl optimisation method.
- `iterations`: iteration limit passed to `Optim.Options`.
- `show_every`: trace printing interval passed to `Optim.Options`.

Both `nn.all_flattened_params` and `nn.params_list` are overwritten with the
minimiser. Returns the updated `nn.params_list`.
"""
function train!(nn::NNModel; method=BFGS(), iterations=500, show_every=50)
    # objective as Optim.jl sees it: flat vector in, scalar loss out
    objective = flat -> loss_func(get_unflat(flat, nn), nn)

    # optimisation setup; the gradient comes from forward-mode autodiff
    x0 = nn.all_flattened_params
    od = OnceDifferentiable(objective, x0; autodiff = :forward)
    option = Optim.Options(iterations=iterations, show_trace=true,
                           show_every=show_every)

    # training
    opt = optimize(od, x0, method, option)

    # write the minimiser back in both flattened and structured form
    nn.all_flattened_params = opt.minimizer
    nn.params_list = get_unflat(opt.minimizer, nn)
    return nn.params_list
end

train! (generic function with 1 method)

In [34]:
# Start from fresh random weights, then train with the default settings and time the run.
reset_weights!(nn)
@time train!(nn)



Iter     Function value   Gradient norm 
     0     6.154519e+05     8.318763e+06
    50     4.128740e+00     5.727815e+02
   100     6.610481e-01     4.820127e+02
   150     2.147895e-01     3.749863e+01
   200     2.814772e-02     1.009211e+02
   250     4.915779e-03     1.420258e+01
   300     2.795057e-03     2.081087e+01
   350     1.886847e-03     1.677503e+01
   400     8.645246e-04     2.625430e+00
   450     3.416373e-04     2.016820e+00
   500     1.870863e-04     1.338760e+00
 22.623613 seconds (30.86 M allocations: 4.998 GiB, 3.36% gc time)


2-element Array{Any,1}:
 Any[[-0.462421; -0.0663361; … ; -0.727093; 1.00037], [-0.0254224, -1.71344, -0.18786, 1.22524, 0.0269862, 0.408243, 0.333453, -0.0153271, 0.0902273, -1.66902], [-2.14424 2.48715 … 0.334147 0.133116], [-1.48216]]
 Any[[-1.55256; -0.377941; … ; 1.03184; -1.23321], [0.676434, 0.0066661, -0.68498, -0.224742, -0.297432, -0.665359, 0.639513, -1.05162, -1.16149, -2.63728], [-1.74321 1.28149 … 0.170286 0.603559], [-0.361049]]   

# Prediction

In [35]:
"""
    predict(nn::NNModel; t=nothing)

Evaluate the networks of `nn` at the points `t`.

When `t` is `nothing` (the default) the training points `nn.t` are used.
Returns `(y_pred_list, dydt_pred_list)`: the predicted solution and the
predicted time derivative for each equation, in equation order. The initial
conditions and the initial time are taken from `nn`.
"""
function predict(nn::NNModel; t=nothing)
    if t === nothing
        t = nn.t # predict on training points by default
    end

    # shortcuts; all taken from the model so the prediction does not depend on
    # same-named global variables
    Nvar = nn.Nvar
    params_list = nn.params_list
    y0_list = nn.y0_list
    t0 = nn.t0

    y_pred_list = []
    dydt_pred_list = []

    for i = 1:Nvar
        params = params_list[i]
        y0 = y0_list[i]

        push!(y_pred_list, predict(params, t, y0, t0))
        push!(dydt_pred_list, predict_dt(params, t, y0, t0))
    end

    return y_pred_list, dydt_pred_list
end

predict (generic function with 2 methods)

In [36]:
# Solution and derivative predictions at the training points (the default `t`).
y_pred_list, dydt_pred_list = predict(nn)

(Any[[2.0 1.81951 … 2.40151 2.54], [-1.0 -0.909705 … -1.19982 -1.26902]], Any[[-2.00303 -1.61563 … 1.36681 1.40126], [1.00105 0.808584 … -0.683033 -0.699888]])

In [37]:
# Load the plotting package and set a smaller default figure size.
using Plots
default(size = (300, 200)) # plot size

In [38]:
# Plot the first solution component at the training points; `[:]` turns the
# row matrices into plain vectors.
plot(nn.t[:], y_pred_list[1][:])

# Predict new points

In [39]:
# Finer evaluation grid as a row matrix, shape = [n_input, n_sample].
t_test = reshape(collect(0.0:0.01:2), 1, :)

1×201 Array{Float64,2}:
 0.0  0.01  0.02  0.03  0.04  0.05  …  1.95  1.96  1.97  1.98  1.99  2.0

In [40]:
# Evaluate the model on the finer grid; the derivative predictions are discarded.
y_test, _ = predict(nn, t=t_test)

(Any[[2.0 1.98018 … 2.526 2.54], [-1.0 -0.990091 … -1.26202 -1.26902]], Any[[-2.00303 -1.96183 … 1.39828 1.40126], [1.00105 0.980799 … -0.698455 -0.699888]])

In [41]:
# Plot the first solution component on the finer test grid.
plot(t_test[:], y_test[1][:])