In this notebook we demonstrate how to implement a *very simple* generative adversarial network (GAN) in fewer than 50 lines of code!

In [1]:
using Knet, Plots
gr()                         # GR plotting backend for Plots.jl
Knet.gpu(0)                  # make GPU device 0 the active Knet device
atype = KnetArray{Float32};  # array type used for all parameters and minibatches

In [2]:
# "Real" data: 1000 two-dimensional samples (one per column), obtained by applying a
# fixed linear map to standard-normal noise and shifting the mean to (1, 2).
xtrn = [1 2; -0.1 0.5]' * randn(2, 1000) .+ [1; 2];
# Every real sample carries class label 1.
ytrn = fill(0x01, 1, 1000);
# Shuffled minibatches of 4 columns; inputs are converted to `atype`.
dtrn = minibatch(xtrn, ytrn, 4; xtype=atype, shuffle=true);

In [3]:
"""
    initweights(d, hidden)

Build the parameters of a fully connected network with input dimension `d` and layer
widths `hidden` (the last entry is the output width).

Returns a flat `Vector{Any}` laid out as `[W1, b1, W2, b2, ...]`, where `Wk` is an
`H×X` matrix initialised with Knet's `xavier` and `bk` is an `H×1` column of zeros.
"""
function initweights(d, hidden)
    # Grown with push! instead of pre-sizing with `Vector{Any}(n)`: that constructor
    # was removed in Julia 1.0 (it now requires `undef`), while `Any[]` + push!
    # works on every Julia version.
    model = Any[]
    fanin = d
    for H in hidden
        push!(model, xavier(H, fanin))  # weight matrix: H outputs × fanin inputs
        push!(model, zeros(H, 1))       # bias column, broadcast over the batch
        fanin = H
    end
    return model
end

initweights (generic function with 1 method)

In [4]:
"""
    predict(w, x)

Forward pass of a multilayer perceptron whose parameters are the flat list
`w = [W1, b1, W2, b2, ...]`. Every layer except the last applies `tanh`; the last
layer is affine only, so the result holds raw (unsquashed) outputs.
"""
function predict(w, x)
    h = mat(x)  # Knet `mat`: reshape the input to a 2-D matrix
    hidden_end = length(w) - 2
    for k in 1:2:hidden_end
        W, b = w[k], w[k + 1]
        h = tanh.(W * h .+ b)
    end
    Wout, bout = w[end - 1], w[end]
    return Wout * h .+ bout
end

predict (generic function with 1 method)

In [5]:
"""
    loss(w, x, y; wd=0)

Knet `nll` of labels `y` under the scores of network `w` applied to `x`.

When a second parameter list `wd` is supplied, the output of `w` is first passed
through `wd` and `nll` is taken on `wd`'s scores. The GAN loop uses this for the
generator step, with `w` as the generator and `wd` as the discriminator.
"""
function loss(w, x, y; wd=0)
    scores = predict(w, x)
    if wd != 0
        scores = predict(wd, scores)
    end
    return nll(scores, y)
end
lossgradient  = grad(loss)

(::gradfun) (generic function with 1 method)

In [6]:
"""
    train(w, x, y, optim; o...)

Take one optimizer step on parameters `w` for the minibatch `(x, y)` and return `w`.
Extra keywords (e.g. `wd`) are forwarded to the loss gradient unchanged.
"""
function train(w, x, y, optim; o...)
    grads = lossgradient(w, x, y; o...)  # gradient with respect to `w`
    update!(w, grads, optim)             # apply the optimizer step to `w`
    return w
end

train (generic function with 1 method)

In [7]:
# Generator: a single affine layer mapping 2-D noise to 2-D samples.
wg = map(atype, initweights(2, [2]));
# Discriminator: 2 -> 5 -> 3 -> 2 MLP; class 1 = real data, class 2 = generated.
wd = map(atype, initweights(2, [5, 3, 2]));
optimg = optimizers(wg, Adam;  lr=0.001)
optimd = optimizers(wd, Adam;  lr=0.0025)

for epoch = 1:10
    for (x, y) in dtrn
        # Explicit so the loop also updates the globals when run as a plain script.
        global wd, wg
        # Discriminator step: the real batch (label 1) next to an equally sized batch
        # of generated samples (label 2).
        xgen = predict(wg, atype(randn(size(x))))
        ygen = fill(UInt8(2), size(y))
        wd = train(wd, hcat(x, xgen), hcat(y, ygen), optimd)
        # Generator step: the generator must be fed NOISE, the same kind of input it
        # receives when sampling below, not the real batch `x`. Only `wg` is updated.
        # The labels `y` (all 1 = "real") push it to fool the discriminator `wd`.
        z = atype(randn(size(x)))
        wg = train(wg, z, y, optimg; wd=wd)
    end
    # After each epoch, plot 100 generated samples over the training data.
    xfake = Array(predict(wg, atype(randn(2, 100))))
    scatter(xtrn[1, :], xtrn[2, :], label=:true_data)
    display(scatter!(xfake[1, :], xfake[2, :],  label=:synthetic_data, size=(400,300)))
end