In [1]:
using Optim
using AutoGrad
using Plots
using DifferentialEquations

In [2]:
# Plots.jl global default: figures created after this call use a 300-by-200 canvas.
default(size = (300, 200)) # plot size

# Define ODE

In [3]:
"""
    f!(t, y, dy)

In-place right-hand side of the linear two-variable test system

    dy[1] = t + 998*y[1] + 1998*y[2]
    dy[2] =    -999*y[1] - 1999*y[2]

The argument order `(t, y, dy)` is the one this notebook hands to `ODEProblem`.
`t`, `y[1]` and `y[2]` may be scalars or equally shaped arrays, because only `+`, `-`
and scalar multiplication are applied to them.

The result is written into `dy`. The value of the final assignment is returned as a
by-product; `f_list` ignores it.
"""
function f!(t, y, dy)
    rate1 = t+998.0*y[1] + 1998.0*y[2]
    dy[1] = rate1

    rate2 = -999.0*y[1] - 1999.0*y[2]
    dy[2] = rate2
end

f! (generic function with 1 method)

In [4]:
"""
    f_list(t, y)

Out-of-place wrapper around `f!` for use by the neural-network code.

Allocates a fresh `Vector{Any}` with one slot per entry of `y`, lets `f!(t, y, dy)`
fill it, and returns it. The container is untyped on purpose: each slot may receive a
scalar or a whole array, depending on what `t` and the entries of `y` are.

# Returns
`Vector{Any}` of length `length(y)`; slots that `f!` does not assign stay `nothing`.
"""
function f_list(t, y)
    # `Vector{}(N)` (uninitialised-array constructor without `undef`) was removed in
    # Julia 1.0; a comprehension builds the same `Vector{Any}` on every Julia version
    # and never exposes undefined references.
    dy = Any[nothing for _ in 1:length(y)]
    f!(t, y, dy)

    return dy
end

f_list (generic function with 1 method)

In [5]:
# Batched test inputs: a 1-by-4 row of sample times, and a two-entry "state" list whose
# components are that row and twice that row.
test_t = [1.0 2.0 3.0 4.0]
test_y = [test_t, 2 * test_t]
test_t

1×4 Array{Float64,2}:
 1.0  2.0  3.0  4.0

In [6]:
# Evaluate the ODE right-hand side on the batched test inputs; the result holds one
# entry per state variable.
pred_f = f_list(test_t, test_y)

2-element Array{Any,1}:
 [4995.0 9990.0 14985.0 19980.0]    
 [-4997.0 -9994.0 -14991.0 -19988.0]

In [7]:
# Print every right-hand-side component immediately followed by its array size.
foreach(component -> println(component, size(component)), pred_f)

[4995.0 9990.0 14985.0 19980.0](1, 4)
[-4997.0 -9994.0 -14991.0 -19988.0](1, 4)


# Solve with DifferentialEquations.jl

In [8]:
# Initial state (one entry per state variable) and the integration interval
# (t_start, t_end); both are passed to ODEProblem below.
y0_list = [2.0, -1.0]
tspan = (0.0, 2.0) # has to be tuple

(0.0, 2.0)

In [9]:
# Reference solution: bundle the right-hand side, initial state and time span into an
# ODEProblem and integrate it. No algorithm is passed, so `solve` chooses one itself.
prob = ODEProblem(f!, y0_list, tspan)
sol = solve(prob);

In [10]:
# Plot both components of the reference solution against the solver's time points;
# `plot!` draws the second component into the same figure.
plot(sol.t, sol[1,:])
plot!(sol.t, sol[2,:])

# Build NN

In [11]:
"""
    init_weights(; n_in=1, n_hidden=10, n_out=1)

Create the parameters of a network with one hidden layer.

# Keywords
- `n_in`: number of input features.
- `n_hidden`: number of hidden units.
- `n_out`: number of outputs.

# Returns
A four-element vector `[W1, b1, W2, b2]` where `W1` is `n_hidden x n_in` and `W2` is
`n_out x n_hidden` (both drawn with `randn`, laid out for left multiplication `W * x`),
and `b1`, `b2` are zero vectors of length `n_hidden` and `n_out`.
"""
function init_weights(; n_in=1, n_hidden=10, n_out=1)
    # Random draws happen in the order W1, W2.
    hidden_weights = randn(n_hidden, n_in)
    hidden_bias = zeros(n_hidden)
    output_weights = randn(n_out, n_hidden)
    output_bias = zeros(n_out)

    return [hidden_weights, hidden_bias, output_weights, output_bias]
end

# One independently initialised network per state variable, plus the array shapes of a
# single network (both networks share the same architecture).
params_list = [init_weights(), init_weights()]
params1, params2 = params_list
sizes = map(size, params1)


4-element Array{Tuple{Int64,Vararg{Int64,N} where N},1}:
 (10, 1)
 (10,)  
 (1, 10)
 (1,)   

In [12]:
# 21 equally spaced sample times covering tspan, stored as a single row so that
# shape = [n_input, n_sample].
t = reshape(collect(linspace(tspan..., 21)), 1, :)

1×21 Array{Float64,2}:
 0.0  0.1  0.2  0.3  0.4  0.5  0.6  0.7  …  1.4  1.5  1.6  1.7  1.8  1.9  2.0

In [13]:
"""
    predict(params, t, y0, t0; act=tanh)

Trial solution of one state variable, built from a one-hidden-layer network.

The raw network output `N(t) = W2 * act.(W1 * t .+ b1) .+ b2` is wrapped as

    phi(t) = y0 + (t - t0) * N(t)

so that `phi(t0) == y0` holds for any parameter values.

# Arguments
- `params`: iterable yielding `W1, b1, W2, b2` in that order.
- `t`: input samples; the multiplication `W1 * t` expects one sample per column.
- `y0`: initial value the trial solution is pinned to.
- `t0`: time at which the initial value applies.

# Keywords
- `act`: activation applied element-wise to the hidden layer.
"""
function predict(params, t, y0, t0; act=tanh)
    W1, b1, W2, b2 = params

    # Plain feed-forward pass.
    hidden = act.(W1*t .+ b1)
    net_out = W2*hidden .+ b2

    # Force the initial condition: the network term vanishes at t == t0.
    return y0 .+ (t .- t0) .* net_out
end

predict (generic function with 1 method)

In [14]:
# Output of the untrained first network. The first sample is t0 itself, where the trial
# solution equals y0 by construction.
y_pred = predict(params1, t, y0_list[1], tspan[1])

1×21 Array{Float64,2}:
 2.0  1.98694  1.9499  1.8945  1.82791  …  1.29006  1.28479  1.28349  1.28591

In [15]:
# Plot the untrained trial solution; `[:]` flattens the 1-by-n rows into plain vectors.
plot(t[:], y_pred[:])

## Gradient of the NN w.r.t. t

In [16]:
# Scalar reduction of the trial solution, so that it can be handed to `grad`.
function predict_sum(params, t, y0, t0)
    return sum(predict(params, t, y0, t0))
end

# Gradient of `predict_sum` with respect to its 2nd positional argument, `t`.
# With one sample per column of `t`, each output sample depends only on its own input
# sample, so the gradient of the sum is the element-wise derivative d(phi)/dt.
predict_dt = grad(predict_sum, 2)

(::gradfun) (generic function with 1 method)

In [17]:
# Time derivative of the untrained first trial solution at every sample time.
dydt_pred = predict_dt(params1, t, y0_list[1], tspan[1])

1×21 Array{Float64,2}:
 0.0  -0.257522  -0.473576  -0.622108  …  -0.0324464  0.00599613  0.0419105

In [18]:
# Plot the derivative of the untrained trial solution against the sample times.
plot(t[:], dydt_pred[:])

# Loss function (the hardest part)

In [19]:
"""
    loss_func(params_list, t, y0_list, t0)

Residual loss of the neural-network trial solutions for the ODE system.

For every state variable `k` the trial solution `predict(params_list[k], ...)` and its
time derivative `predict_dt(params_list[k], ...)` are evaluated at the sample times `t`.
The ODE right-hand side is then evaluated on the trial solutions via `f_list`, and the
loss is the sum over all variables of the mean squared residual `dydt - f`.

# Arguments
- `params_list`: one parameter set per state variable.
- `t`: sample times at which the residual is measured.
- `y0_list`: initial value of every state variable; its length sets the number of
  variables.
- `t0`: time at which the initial values apply.
"""
function loss_func(params_list, t, y0_list, t0)
    n_eq = length(y0_list)

    # Trial solution and its time derivative for every state variable.
    y_trial = Any[predict(params_list[k], t, y0_list[k], t0) for k in 1:n_eq]
    dydt_trial = Any[predict_dt(params_list[k], t, y0_list[k], t0) for k in 1:n_eq]

    # ODE right-hand side evaluated on the trial solutions (couples all variables).
    rhs = f_list(t, y_trial)

    total = 0.0
    for k in 1:n_eq
        total += mean(abs2, dydt_trial[k] - rhs[k])
    end

    return total
end

loss_func (generic function with 1 method)

In [20]:
# Loss of the two randomly initialised networks, before any training.
loss_func(params_list, t, y0_list, tspan[1])

# Flatten parameters

In [21]:
"""
    unflatten(params_flat, shapes=sizes)

Split a flat parameter vector back into a list of arrays.

Consecutive chunks of `params_flat` are taken in order, one per entry of `shapes`, and
each chunk is reshaped to the corresponding shape. The chunks are views, so no
parameter data is copied and the returned arrays alias `params_flat`.

# Arguments
- `params_flat`: flat vector holding all parameters back to back.
- `shapes`: iterable of dimension tuples, one per array to recover. Defaults to the
  global `sizes`, which keeps the original one-argument call working.

# Returns
`Vector{Any}` with one reshaped view per entry of `shapes`. Entries of `params_flat`
beyond the total length described by `shapes` are ignored; a vector that is too short
raises a `BoundsError` when the offending view is created.
"""
function unflatten(params_flat, shapes=sizes)
    params = Any[]
    first_idx = 1
    for shape in shapes
        last_idx = first_idx + prod(shape) - 1 # shape -> number of elements
        push!(params, reshape(view(params_flat, first_idx:last_idx), shape))
        first_idx = last_idx + 1
    end
    return params
end

unflatten (generic function with 1 method)

In [22]:
# Number of networks (one per state variable) and the number of scalar parameters in a
# single network; `sizes` comes from an earlier cell.
Nvar = 2
l_total = sum(prod(s) for s in sizes)
l_total, Nvar

(31, 2)

In [23]:
# Flatten every network into one vector (arrays in list order, each in column-major
# order), then stack the two networks back to back.
params_flat1 = reduce(vcat, map(vec, params1))
params_flat2 = reduce(vcat, map(vec, params2))
all_params_flat = vcat(params_flat1, params_flat2);

In [24]:
# View the stacked vector as an l_total-by-Nvar matrix: column i holds the flattened
# parameters of network i, because the networks were concatenated back to back.
all_params_reshape = reshape(all_params_flat, l_total, Nvar)
all_params_reshape[:,1];

In [25]:
"""
    unflatten_all(all_params_flat)

Recover the per-network parameter lists from the stacked flat vector.

The vector is viewed as an `l_total`-by-`Nvar` matrix (both read from global scope);
column `k` is copied out and handed to `unflatten`, which rebuilds the arrays of
network `k`.

# Returns
`Vector{Any}` with `Nvar` entries, one parameter list per network.
"""
function unflatten_all(all_params_flat)
    per_network = reshape(all_params_flat, l_total, Nvar)
    return Any[unflatten(per_network[:, k]) for k in 1:Nvar]
end

unflatten_all (generic function with 1 method)

In [26]:
# Round-trip check: unpack the stacked flat vector into per-network parameter lists.
unflatten_all(all_params_flat)

2-element Array{Any,1}:
 Any[[-0.0478109; -0.98554; … ; 1.02648; 0.385927], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [-0.0738517 1.38207 … 1.86756 0.833887], [0.0]]  
 Any[[-0.353488; -0.52281; … ; 0.430679; 0.946903], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.716115 -0.649736 … 0.276327 -0.511665], [0.0]]

In [27]:
"""
    loss_wrap(all_params_flat)

Single-argument objective for the optimiser: unpack the stacked flat parameter vector
and evaluate `loss_func` on it.

The sample times `t`, the initial values `y0_list` and the start time `tspan[1]` are
read from global scope for now.
"""
loss_wrap(all_params_flat) =
    loss_func(unflatten_all(all_params_flat), t, y0_list, tspan[1])

loss_wrap (generic function with 1 method)

In [28]:
# Evaluate the wrapped objective at the initial flat parameter vector.
loss_wrap(all_params_flat)

# Optim.jl

In [29]:
# Wrap the objective for Optim. No gradient function is supplied; `autodiff = :forward`
# requests that it be obtained by forward-mode automatic differentiation.
od = OnceDifferentiable(loss_wrap, all_params_flat; autodiff =:forward);
typeof(od)

NLSolversBase.OnceDifferentiable{Float64,Array{Float64,1},Array{Float64,1},Val{false}}

In [30]:
# Stop after at most 500 iterations and print the trace on every 50th iteration.
option = Optim.Options(iterations = 500, show_trace=true, show_every=50)

Optim.Options{Float64,Void}(1.0e-32, 1.0e-32, 1.0e-8, 0, 0, 0, false, 500, false, true, false, 50, nothing, NaN)

In [31]:
# Minimise the loss with BFGS, starting from the random initial parameters; @time
# reports the elapsed time and allocations of the whole run.
@time opt = optimize(od, all_params_flat, BFGS(), option)



Iter     Function value   Gradient norm 
     0     5.776944e+05     4.419892e+06
    50     3.051153e-01     7.188397e+02
   100     2.366292e-02     3.281797e+01
   150     6.717213e-03     9.314632e+00
   200     5.216245e-03     5.184082e+00
   250     3.115397e-03     1.432958e+01
   300     1.750149e-03     1.606706e+01
   350     1.445951e-03     5.806129e+00
   400     1.236017e-03     1.024962e+00
   450     1.043306e-03     1.598668e+01
   500     8.370421e-04     5.661081e-01
 15.680441 seconds (22.61 M allocations: 4.578 GiB, 6.04% gc time)


Results of Optimization Algorithm
 * Algorithm: BFGS
 * Starting Point: [-0.047810877287854596,-0.9855401422759466, ...]
 * Minimizer: [0.19058455546287376,-0.7789795141559013, ...]
 * Minimum: 8.370421e-04
 * Iterations: 500
 * Convergence: false
   * |x - x'| < 1.0e-32: false 
     |x - x'| = 2.85e-03 
   * |f(x) - f(x')| / |f(x)| < 1.0e-32: false
     |f(x) - f(x')| / |f(x)| = 1.44e-03 
   * |g(x)| < 1.0e-08: false 
     |g(x)| = 5.66e-01 
   * stopped by an increasing objective: false
   * Reached Maximum Number of Iterations: true
 * Objective Calls: 1503
 * Gradient Calls: 1503

In [32]:
# Loss at the parameters returned by the optimiser.
loss_wrap(opt.minimizer)

In [33]:
# Unpack the optimised flat vector back into one parameter list per network.
opt_params_list = unflatten_all(opt.minimizer)

2-element Array{Any,1}:
 Any[[0.190585; -0.77898; … ; 0.743805; 0.322602], [-0.477079, -0.129497, -0.360529, -0.2488, 0.450666, -0.209943, 0.175051, -0.199491, -0.0535283, 0.201541], [0.0700667 1.56625 … 2.43266 1.28518], [-1.42449]]       
 Any[[-0.218437; -0.696553; … ; 0.630682; 1.28474], [0.0208482, 0.0828659, 0.39399, -0.248002, 0.195414, 0.056064, 0.0879515, -0.269222, -0.383348, 0.0811448], [0.802731 -0.937241 … -0.00209753 -1.17578], [0.441589]]

In [34]:
# Trained trial solution of the first state variable at the sample times.
y1_pred = predict(opt_params_list[1], t, y0_list[1], tspan[1])

1×21 Array{Float64,2}:
 2.0  1.81623  1.66632  1.55067  …  2.14731  2.2722  2.40203  2.53609

In [35]:
# Compare the reference ODE solution for y1 (line) with the network prediction, which is
# drawn as markers only (line width 0) at the sample times.
plot(sol.t, sol[1,:], label="y1 true")
plot!(t[:], y1_pred[:], label="y1 predict", lw=0, marker=:circle, markerstrokewidth = 0)