In [1]:
using DifferentialEquations
using Plots
include("NN_solver.jl")

predict (generic function with 2 methods)

In [2]:
# Notebook-wide plotting defaults applied to every subsequent figure.
default(
    size = (400, 300),
    linewidth = 3,
    markersize = 5,
    markerstrokewidth = 0,
)

In [3]:
# Right-hand side of the Lotka-Volterra system, written in place:
#   du[1] = 1.5*u[1] - u[1]*u[2]
#   du[2] = -3*u[2]  + u[1]*u[2]
# `t` is accepted but not used. The element-wise (dotted) operators are kept
# so that each u[i] may be either a scalar or an array of values.
function lotka_volterra!(t, u, du)
    prey, predator = u[1], u[2]
    interaction = prey .* predator   # coupling term shared by both equations
    du[1] = 1.5 .* prey - 1.0 .* interaction
    du[2] = -3 .* predator + interaction
end

# Initial condition and integration interval.
y0_list = [1.0, 1.0]
tspan = (0.0, 5.0)
# 51 evenly spaced points over tspan, stored as a 1×51 row matrix.
t = reshape(collect(linspace(tspan[1], tspan[2], 51)), 1, :) # training points

1×51 Array{Float64,2}:
 0.0  0.1  0.2  0.3  0.4  0.5  0.6  0.7  …  4.4  4.5  4.6  4.7  4.8  4.9  5.0

In [4]:
# Reference solution from DifferentialEquations, saved every 0.05 time units;
# it is used below as the "true" curve in the plots.
prob = ODEProblem(lotka_volterra!, y0_list, tspan)
sol = solve(prob; saveat = 0.05, reltol = 1e-6, abstol = 1e-6);

In [5]:
# Build the neural-network solver on the training points with 20 hidden units.
nn = init_nn(lotka_volterra!, t, y0_list; n_hidden = 20);
show(nn) # print basic info

Neural ODE Solver
Number of equations:       2
Initial condition y0:      [1.0, 1.0]
Numnber of hidden units:   20
Number of training points: 51


In [6]:
# Saved training histories: one row per snapshot (weights) and the matching
# loss values, for the BFGS and gradient-descent (GD) runs.
p_BFGS, p_GD = map(readdlm, ("weights_BFGS.txt", "weights_GD.txt"))
l_BFGS, l_GD = map(readdlm, ("loss_BFGS.txt", "loss_GD.txt"))

(size(p_BFGS), size(l_BFGS))

((1001, 122), (1001, 1))

In [17]:
"""
    quickplot(t, p)

Plot the network prediction for the weights stored in row `t` of the weight
history `p`, overlaid on the reference ODE solution.

Note that `t` here is a row index into `p`, not a time value.

This function depends on two notebook globals: `nn` (the solver object) and
`sol` (the reference solution). It overwrites `nn.params_list` with the
weights taken from `p[t, :]` as a side effect.

The value of the final `xlabel!` call is returned (presumably the current
plot, which is what lets a notebook cell display it).
"""
function quickplot(t, p)
    # Load the flattened weights of snapshot `t` into the network.
    nn.params_list = get_unflat(p[t,:], nn)
    # One forward pass is enough; the second return value is not needed here.
    y_pred_list, _ = predict(nn)

    # Network predictions at the training points, drawn as markers only.
    plot(nn.t[:], y_pred_list[1][:], label="y1 NN", lw=0, marker=:circle, legend = :topleft)
    plot!(nn.t[:], y_pred_list[2][:], label="y2 NN", lw=0, marker=:circle)

    # Reference solution drawn as lines.
    plot!(sol.t, sol[1,:], label="y1 true")
    plot!(sol.t, sol[2,:], label="y2 true")

    # Fixed y-range so that frames from different snapshots are comparable.
    ylims!(0, 8)
    xlabel!("t")
end

quickplot (generic function with 1 method)

In [18]:
# Show a single BFGS snapshot. The plotted row and the title are driven by
# the same index so they cannot disagree; row i holds iteration i-1, the
# same convention used in the time-series loops below.
i = 350
quickplot(i, p_BFGS)
title!(@sprintf("BFGS; iter=%d; loss=%.2e", i - 1, l_BFGS[i]))

## Time-series

In [19]:
# Save one figure per stored BFGS snapshot. The number of snapshots is taken
# from the weight history itself instead of being hard-coded.
mkpath("./figures")  # make sure the output directory exists before saving
for i in 1:size(p_BFGS, 1)
    if (i % 50 == 0) print(i, " ") end  # progress indicator
    quickplot(i, p_BFGS)
    # Row i holds the state after iteration i-1.
    title!(@sprintf("BFGS; iter=%d; loss=%.2e", i - 1, l_BFGS[i]))
    savefig("./figures/BFGS_" * lpad(i, 3, "0") * ".pdf")
end

50 100 150 200 250 300 350 400 450 500 550 600 650 700 750 800 850 900 950 1000 

In [20]:
# Save one figure per stored gradient-descent snapshot. The number of
# snapshots is taken from the weight history itself instead of being
# hard-coded.
mkpath("./figures")  # make sure the output directory exists before saving
for i in 1:size(p_GD, 1)
    if (i % 50 == 0) print(i, " ") end  # progress indicator
    quickplot(i, p_GD)
    # Row i holds the state after iteration i-1.
    title!(@sprintf("GD; iter=%d; loss=%.2e", i - 1, l_GD[i]))
    savefig("./figures/GD_" * lpad(i, 3, "0") * ".pdf")
end

50 100 150 200 250 300 350 400 450 500 550 600 650 700 750 800 850 900 950 1000 