Solve the [Lotka-Volterra](https://en.wikipedia.org/wiki/Lotka–Volterra_equations) problem, on which the original [NeuralNetDiffEq.jl](https://julialang.org/blog/2017/10/gsoc-NeuralNetDiffEq) **failed completely**.

In [1]:
using Plots
using DifferentialEquations
include("NN_solver.jl")

predict (generic function with 2 methods)

In [2]:
default(size = (400, 300), linewidth=4, markersize=5, 
        markerstrokewidth=0)

# ODE definition

In [3]:
function lotka_volterra!(t, u, du)
    du[1] = 1.5 .* u[1] - 1.0 .* u[1].*u[2]
    du[2] = -3 .* u[2] + u[1].*u[2]
end

y0_list = [1.0, 1.0]
tspan = (0.0,5.0)

(0.0, 5.0)
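
For reference, the system coded in `lotka_volterra!` can be written out as below, reading `u[1]` and `u[2]` as $y_1$ and $y_2$ (the labels used in the plots). This is only a transcription of the cell above, with the initial condition and time span taken from `y0_list` and `tspan`:

$$
\begin{aligned}
\frac{dy_1}{dt} &= 1.5\,y_1 - y_1 y_2, \\
\frac{dy_2}{dt} &= -3\,y_2 + y_1 y_2,
\end{aligned}
\qquad y_1(0) = y_2(0) = 1, \quad t \in [0, 5].
$$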

# Solve with DifferentialEquations.jl as reference

In [4]:
prob = ODEProblem(lotka_volterra!, y0_list, tspan)
sol = solve(prob, saveat=0.05, reltol=1e-6, abstol=1e-6);

In [5]:
plot(sol.t, sol[1,:], label="y1", legend = :topleft)
plot!(sol.t, sol[2,:], label="y2")
ylims!(0, 7)
title!("Lotka-Volterra reference solution")
xlabel!("t")

In [6]:
savefig("reference.svg")

# My ANN solver

## Initialize ANN

In [7]:
t = collect(linspace(tspan[1], tspan[2], 51))
t = reshape(t, 1, :) # training points

1×51 Array{Float64,2}:
 0.0  0.1  0.2  0.3  0.4  0.5  0.6  0.7  …  4.4  4.5  4.6  4.7  4.8  4.9  5.0

In [8]:
nn = init_nn(lotka_volterra!, t, y0_list, n_hidden = 20);
show(nn) # print basic info

Neural ODE Solver
Number of equations:       2
Initial condition y0:      [1.0, 1.0]
Numnber of hidden units:   20
Number of training points: 51


**Note: The subsequent training can be sensitive to the weight initialization. If it fails to converge, re-initialize the weights here and train again.**
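
The re-initialize-and-retry advice can be automated by looping over several seeds and keeping the best run. The sketch below is only an illustration under stated assumptions: `toy_loss`, `train_from_seed`, and `best_of_seeds` are hypothetical names, the loss is a toy one-dimensional function rather than the solver's real loss, and the code targets Julia 1.x (where the `Random` standard library replaces the `srand` call used in this notebook).

```julia
using Random

# Toy non-convex loss with two local minima of different depth.
# This is a stand-in, NOT the loss used by NN_solver.jl.
toy_loss(w) = (w^2 - 1)^2 + 0.3 * w + 0.5
toy_grad(w) = 4 * w * (w^2 - 1) + 0.3

# Hypothetical stand-in for "reset the weights, then train":
# draw a seeded random start and run fixed-step gradient descent.
function train_from_seed(seed; steps = 500, lr = 0.01)
    rng = MersenneTwister(seed)
    w = 4 * rand(rng) - 2              # initial weight in [-2, 2)
    for _ in 1:steps
        w -= lr * toy_grad(w)
    end
    return w, toy_loss(w)
end

# Try every seed and keep the run with the lowest final loss.
function best_of_seeds(seeds)
    best_seed, best_w, best_loss = first(seeds), NaN, Inf
    for seed in seeds
        w, loss = train_from_seed(seed)
        if loss < best_loss
            best_seed, best_w, best_loss = seed, w, loss
        end
    end
    return best_seed, best_w, best_loss
end

seed, w, loss = best_of_seeds(1:10)
println("best seed = $seed, final weight = $w, final loss = $loss")
```

With the real solver, the body of `train_from_seed` would presumably be replaced by seeding the generator, calling `reset_weights!(nn)`, and then calling `train!`, as in the cells that follow.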

In [9]:
srand(1234) # set random seeds for reproducibility
reset_weights!(nn)

In [10]:
function quickplot()
    y_pred_list, _ = predict(nn)
    plot(sol.t, sol[1,:], label="y1 true", legend = :topleft)
    plot!(sol.t, sol[2,:], label="y2 true")

    plot!(nn.t[:], y_pred_list[1][:], label="y1 NN", lw=0, marker=:circle)
    plot!(nn.t[:], y_pred_list[2][:], label="y2 NN", lw=0, marker=:circle)
    ylims!(0, 7)
    xlabel!("t")
end

quickplot()

## First try a first-order optimizer

The basic gradient descent method converges very slowly; after 1000 iterations the loss is still about 0.31 and the optimizer has not converged:

In [11]:
srand(1234) # reset random seeds
reset_weights!(nn)

In [12]:
@time res_GD = train!(nn, GradientDescent(); iterations=1000)

Any[(:iterations, 1000)]




 86.158986 seconds (71.93 M allocations: 57.234 GiB, 11.76% gc time)


Results of Optimization Algorithm
 * Algorithm: Gradient Descent
 * Starting Point: [0.8673472019512456,-0.9017438158568171, ...]
 * Minimizer: [0.8698494291925695,-0.9038568423154282, ...]
 * Minimum: 3.080401e-01
 * Iterations: 1000
 * Convergence: false
   * |x - x'| < 1.0e-32: false 
     |x - x'| = 6.22e-05 
   * |f(x) - f(x')| / |f(x)| < 1.0e-32: false
     |f(x) - f(x')| / |f(x)| = 1.40e-04 
   * |g(x)| < 1.0e-08: false 
     |g(x)| = 1.80e-01 
   * stopped by an increasing objective: false
   * Reached Maximum Number of Iterations: true
 * Objective Calls: 2505
 * Gradient Calls: 2505

In [13]:
plot(Optim.f_trace(res_GD), yscale = :log10)
ylims!(1e-2, 1e3)

In [14]:
quickplot()

In [18]:
writedlm("weights_GD.txt", Optim.x_trace(res_GD))
writedlm("loss_GD.txt", Optim.f_trace(res_GD))
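
One common reason plain gradient descent stalls like this is an ill-conditioned loss surface, where the step size is limited by the steepest direction while progress along the flat directions is tiny. Whether that is the dominant effect here is an assumption. The sketch below illustrates it on a toy quadratic (not the solver's loss), written for Julia 1.x, and uses a single exact Newton step as a stand-in for a curvature-aware method. BFGS is not Newton's method: it builds an approximation to the inverse Hessian from gradient differences.

```julia
using LinearAlgebra

# Toy ill-conditioned quadratic loss f(x) = 0.5 * x' * A * x.
# This is an illustrative assumption, not the loss from NN_solver.jl.
A = Diagonal([1.0, 1000.0])
f(x) = 0.5 * dot(x, A * x)
grad(x) = A * x

# Fixed-step gradient descent with step 1/L, where L = 1000 is the
# largest eigenvalue of A (the usual safe step for a quadratic).
function gradient_descent(x0; iterations = 1000, step = 1 / 1000)
    x = copy(x0)
    for _ in 1:iterations
        x -= step * grad(x)
    end
    return x
end

# One Newton step rescales the gradient by the inverse Hessian (here A).
newton_step(x) = x - A \ grad(x)

x0 = [1.0, 1.0]
x_gd = gradient_descent(x0)
x_newton = newton_step(x0)

println("GD loss after 1000 iterations: ", f(x_gd))
println("Newton loss after 1 step:      ", f(x_newton))
```

Along the flat direction, each gradient descent step shrinks the error by a factor of only 0.999, so a thousand iterations still leave a visible residual, while the curvature-aware step removes it at once.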

## Then use a second-order (quasi-Newton) optimizer

In [19]:
srand(1234) # set random seeds for reproducibility
reset_weights!(nn)

In [20]:
@time res_BFGS = train!(nn, BFGS(); iterations=1000)

Any[(:iterations, 1000)]
 70.970442 seconds (61.32 M allocations: 55.676 GiB, 14.17% gc time)


Results of Optimization Algorithm
 * Algorithm: BFGS
 * Starting Point: [0.8673472019512456,-0.9017438158568171, ...]
 * Minimizer: [3.5269674134417857,-1.7068049859059131, ...]
 * Minimum: 1.231741e-04
 * Iterations: 1000
 * Convergence: false
   * |x - x'| < 1.0e-32: false 
     |x - x'| = 3.08e-02 
   * |f(x) - f(x')| / |f(x)| < 1.0e-32: false
     |f(x) - f(x')| / |f(x)| = 1.62e-02 
   * |g(x)| < 1.0e-08: false 
     |g(x)| = 1.43e-02 
   * stopped by an increasing objective: false
   * Reached Maximum Number of Iterations: true
 * Objective Calls: 2447
 * Gradient Calls: 2447

In [21]:
plot(Optim.f_trace(res_BFGS), yscale = :log10)
ylims!(1e-5, 1e3)

In [22]:
quickplot()

In [23]:
writedlm("weights_BFGS.txt", Optim.x_trace(res_BFGS))
writedlm("loss_BFGS.txt", Optim.f_trace(res_BFGS))
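
The trace files written with `writedlm` can be read back later with `readdlm` for plotting or comparison. The sketch below uses a made-up miniature trace and hypothetical file names in a temporary directory, so it does not touch the files saved above. It assumes Julia 1.x, where both functions live in the `DelimitedFiles` standard library (in the Julia version used in this notebook they are available without an import).

```julia
using DelimitedFiles   # provides writedlm / readdlm in Julia 1.x

# Hypothetical miniature trace: two iterations of a three-weight vector.
x_trace = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
f_trace = [1.5, 0.25]

dir = mktempdir()
weights_file = joinpath(dir, "weights_demo.txt")
loss_file = joinpath(dir, "loss_demo.txt")

writedlm(weights_file, x_trace)    # one row per iteration
writedlm(loss_file, f_trace)       # one value per line

weights = readdlm(weights_file)    # 2×3 matrix, rows are iterations
losses = vec(readdlm(loss_file))   # flatten the 2×1 matrix to a vector

println(size(weights), " ", losses)
```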