Classification with LSTM

Sandya Subramanian

HST (Health Sciences and Technology), 2nd year Ph.D.

Setup

In [1]:
using CSV, Plots, Images
using Knet, AutoGrad
using Knet: sigm_dot, tanh_dot
using StatsBase
WARNING: Method definition ==(Base.Nullable{S}, Base.Nullable{T}) in module Base at nullable.jl:238 overwritten in module NullableArrays at /home/nsrl/juliapro/JuliaPro-0.6.1.1/JuliaPro/pkgs-0.6.1.1/v0.6/NullableArrays/src/operators.jl:99.
In [2]:
# Presumably selects GPU device 0 for KnetArray allocations (Out[2] echoes 0) — confirm in the Knet docs.
Knet.gpu(0)
Out[2]:
0
In [3]:
include("findpeaks.jl")
include("pan_tompkins.jl")
include("run_pan_tompkins.jl")
include("get_breaks.jl")
include("get_dataset.jl")
include("get_comb_dataset.jl")
include("get_traintest.jl")
Out[3]:
get_traintest (generic function with 1 method)

LSTM

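After seeing that classification was very successful with just an MLP, I decided to try an LSTM to see whether explicitly encoding time dependencies made a difference. Since the ECG is a very regular signal in time, one would expect temporal dependencies to matter. (Note that in this implementation each 721-sample ECG segment is fed to the LSTM as a single input vector at one time step, so the recurrence over the samples within a segment is not actually exercised.)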

In [21]:
#Task 1: Create dictionary
# Build the label "vocabulary": an identity map from each class index to itself.
# mode == "binary" gives the 2 binary labels; any other mode gives the 8 multiclass labels.
function createVocabulary(mode)
    nclasses = mode == "binary" ? 2 : 8
    return Dict{Int64,Int64}(c => c for c in 1:nclasses)
end

#LSTM Network Function
# One LSTM step for a single layer.
#   weight       : (I+H) x 4H matrix; its column blocks are [forget | input | output | candidate]
#   bias         : row added to every row of the pre-activations (1 x 4H as built by initweights)
#   hidden, cell : batch x H state from the previous step
#   input        : batch x I activations for this step
# Returns the updated (hidden, cell) pair; the arguments are not modified.
function lstm(weight,bias,hidden,cell,input)
    H = size(hidden, 2)
    z = hcat(input, hidden) * weight .+ bias   # batch x 4H pre-activations
    fgate = sigm_dot(z[:, 1:H])                # forget gate
    igate = sigm_dot(z[:, H+1:2H])             # input gate
    ogate = sigm_dot(z[:, 2H+1:3H])            # output gate
    cand  = tanh_dot(z[:, 3H+1:end])           # candidate cell values
    nextcell   = cell .* fgate + igate .* cand
    nexthidden = ogate .* tanh_dot(nextcell)
    return (nexthidden, nextcell)
end

#Task 2: Create Initial Weights
# Create the parameter vector for a stack of LSTM layers followed by a linear output layer.
#   lenhidden : hidden size of each LSTM layer, in order
#   lenvocab  : number of output classes
#   embed     : length of each input vector fed to the first layer (721 in this notebook)
#   arraytype : "KnetArray" puts every parameter on the GPU; any other value keeps CPU Arrays
# Layout of the result: [W1, b1, ..., Wn, bn, Wout, bout], where
#   Wk   :: (insize + lenhidden[k]) x 4*lenhidden[k]   (xavier init; insize is the previous
#           layer's hidden size, or `embed` for the first layer)
#   bk   :: 1 x 4*lenhidden[k]                          (zeros)
#   Wout :: lenhidden[end] x lenvocab, bout :: 1 x lenvocab
function initweights(lenhidden, lenvocab, embed, arraytype)
    nparams = 2*length(lenhidden) + 2
    params = Vector{Array{Float32}}(nparams)
    insize = embed
    for (k, H) in enumerate(lenhidden)
        params[2k-1] = xavier(insize + H, 4H)   # input and recurrent weights stacked row-wise
        params[2k]   = zeros(1, 4H)
        insize = H                              # next layer consumes this layer's hidden state
    end
    params[end-1] = xavier(lenhidden[end], lenvocab)
    params[end]   = zeros(1, lenvocab)

    if arraytype != "KnetArray"
        return copy(params)
    end
    gpuparams = Vector{KnetArray{Float32}}(nparams)
    for k in eachindex(params)
        gpuparams[k] = KnetArray{Float32}(params[k])
    end
    return gpuparams
end

#Task 3: Create Initial State
# Zero state matrix on the CPU, with the same element type as the CPU weight `w`.
_zerostate(w::Array, dims...) = zeros(eltype(w), dims...)
# Any non-Array weight (a KnetArray) gets a Float32 zero matrix on the GPU.
_zerostate(w, dims...) = KnetArray{Float32}(zeros(Float32, dims...))

"""
    initstate(model, batch)

Return the zero initial LSTM state for `model` (as built by `initweights`) and
`batch` samples: a `Vector{Any}` `[h1, c1, h2, c2, ...]` with one `batch x H`
hidden matrix and one `batch x H` cell matrix per layer, where `H` is that layer's
hidden size (its bias has `4H` columns).

The state lives on the same device as the weights. The device is chosen by
dispatching on `model[1]`. The previous check compared against `Array{Float64,2}`,
which never matched the Float32 CPU model from `initweights` and so always produced
GPU state. Hidden and cell are separate buffers rather than one aliased matrix.
"""
function initstate(model, batch)
    nlayers = div(length(model)-2, 2)
    state = Any[]
    for k = 1:nlayers
        H = div(size(model[2k], 2), 4)
        push!(state, _zerostate(model[1], batch, H))   # hidden
        push!(state, _zerostate(model[1], batch, H))   # cell
    end
    return state
end

#Task 4: Create Predict function
# Put the transposed input batch on the same device as the weight `w`:
# CPU weights keep a CPU input, and any other weight type (KnetArray) gets a GPU copy.
_predict_input(w::Array, x) = x'
_predict_input(w, x) = KnetArray(x')

"""
    predict(model, state, input; pdrop=0)

Run one step of the stacked LSTM followed by the linear output layer.

- `model`: parameter vector `[W1, b1, ..., Wn, bn, Wout, bout]` from `initweights`.
- `state`: `[h1, c1, ..., hn, cn]` from `initstate`. It is not modified.
- `input`: one minibatch with samples in columns. It is transposed to
  `batch x features` and moved to the device of the weights.
- `pdrop`: accepted for interface compatibility but not used (no dropout is applied).

Returns `(output, newstate)`, where `output` is the `batch x nclasses` score matrix
and `newstate` holds the per-layer hidden and cell matrices after this step.
"""
function predict(model, state, input; pdrop=0)
    # Under `grad`, model[1] is an AutoGrad-tracked value; unwrap it to pick the device.
    x = _predict_input(AutoGrad.getval(model[1]), input)
    nlayers = div(length(model)-2, 2)
    newstate = similar(state)
    for k = 1:nlayers
        h, c = lstm(model[2k-1], model[2k], state[2k-1], state[2k], x)
        newstate[2k-1] = h
        newstate[2k]   = c
        x = h   # this layer's hidden output feeds the next layer
    end
    output = x * model[end-1] .+ model[end]
    return output, newstate
end

"""
    accuracy_LSTM(model, state, data, p; average=true)

Score `model` on every `(x, y)` minibatch in `data` with Knet's `accuracy`.

`p(model, state, x)` must return a tuple whose first element is the
`batch x nclasses` score matrix, as `predict` does. That matrix is transposed to
`nclasses x batch` before scoring. The same `state` is used for every minibatch.

Returns the mean of the per-batch accuracies when `average` is true (NaN for
empty `data`); otherwise returns their sum (0.0 for empty `data`).
"""
function accuracy_LSTM(model,state,data,p; average=true)
    # Float accumulator from the start, so the local has one type and does not shadow Base.sum.
    total = 0.0
    nbatches = 0
    for (x, y) in data
        scores = p(model, state, x)[1]
        total += accuracy(scores', y)
        nbatches += 1
    end
    return average ? total / nbatches : total
end

#Task 5: Create Loss Function
"""
    loss(model, state, x_batch, y; newstate=nothing, pdrop=0)

Negative log-likelihood (Knet `nll`) of the labels `y` under the scores
`predict(model, state, x_batch)`. The scores are transposed to `nclasses x batch`
first.

If `newstate` is a vector, the LSTM state produced by this forward pass is copied
into it, with AutoGrad tracking removed via `AutoGrad.getval`, so that a caller can
carry the state forward. `pdrop` is accepted for interface compatibility but not used.
"""
function loss(model, state, x_batch, y; newstate=nothing, pdrop=0)
    # Keep the forward-pass state under its own name so the caller's `newstate`
    # buffer is not shadowed (otherwise the buffer is never written).
    preds, nextstate = predict(model, state, x_batch)
    if newstate !== nothing
        copy!(newstate, map(AutoGrad.getval, nextstate))
    end
    return nll(preds', y)
end

#Task 6: Create Train Function
"""
    train(model, dtrn, optim, lossgradient, options; pdrop=0)

Run one epoch of optimization. For each `(x, y)` minibatch in `dtrn`, compute
`lossgradient(model, state, x, y)` and apply it with `update!(model, grads, optim)`.
`options["batchsize"]` sizes the initial state, and `pdrop` is forwarded to the
gradient function. Mutates `model` in place and also returns it.
"""
function train(model, dtrn, optim, lossgradient, options; pdrop=0)
    # Every minibatch starts from the same zero state. Rows of consecutive minibatches
    # are unrelated samples, each with its own label, and accuracy_LSTM scores every
    # batch from the zero state. Carrying hidden state across batches would therefore
    # mismatch evaluation, so no `newstate` buffer is requested from the loss.
    state = initstate(model, options["batchsize"])
    for (x, y) in dtrn
        grads = lossgradient(model, state, x, y; pdrop=pdrop)
        update!(model, grads, optim)
    end
    return model
end
Out[21]:
train (generic function with 1 method)
In [7]:
"""
    prep_dataset(subj_list, fs; splitfrac=0.1)

Load the combined dataset for the records in `subj_list` with `get_comb_dataset(subj_list, fs)`.
Print the label counts of `abn_truth_cats` (multiclass) and `full_bin_cats` (binary),
then split the data with `get_traintest`.

`splitfrac` is passed as the last argument of `get_traintest`. It is presumably the
held-out test fraction (about 10% of the samples end up in the test batches in the
runs below); confirm this in get_traintest.jl. The default 0.1 reproduces the original
hard-coded value.

Returns the 8-tuple
`(xtst_mc, ytst_mc, xtrn_mc, ytrn_mc, xtst_bin, ytst_bin, xtrn_bin, ytrn_bin)`.
"""
function prep_dataset(subj_list,fs; splitfrac=0.1)
    abn_dataset, full_dataset, abn_truth_cats, full_bin_cats, label_key = get_comb_dataset(subj_list, fs)
    println(countmap(abn_truth_cats))
    println(countmap(full_bin_cats))
    xtst_mc, ytst_mc, xtrn_mc, ytrn_mc, xtst_bin, ytst_bin, xtrn_bin, ytrn_bin =
        get_traintest(abn_dataset, full_dataset, abn_truth_cats, full_bin_cats, splitfrac)
    return xtst_mc, ytst_mc, xtrn_mc, ytrn_mc, xtst_bin, ytst_bin, xtrn_bin, ytrn_bin
end
Out[7]:
prep_dataset (generic function with 1 method)
In [8]:
"""
    setup(numepochs, lenhidden, mode, arraytype; subjects=<37 default records>, fs=360)

Prepare everything needed for a training run:

- build the options dictionary (`numepochs` becomes `options1["epochs"]`);
- load and split the data for `subjects` at rate `fs` via `prep_dataset`;
- build the label vocabulary for `mode` (2 labels for "binary", 8 otherwise);
- initialize the parameters with `initweights`, as KnetArrays when
  `arraytype == "KnetArray"` and as CPU Arrays otherwise.

`subjects` and `fs` default to the record list and rate this notebook was run with.
The record numbers presumably refer to the MIT-BIH Arrhythmia Database; confirm this.

Returns `(options1, lenhidden, signal, vocab, model)`, where `signal` is the
8-element vector `[xtst_mc, ytst_mc, xtrn_mc, ytrn_mc, xtst_bin, ytst_bin, xtrn_bin, ytrn_bin]`.
"""
function setup(numepochs,lenhidden,mode,arraytype;
               subjects=[207, 212, 203, 209, 201, 202, 205, 208, 210, 213, 220, 221,
                         222, 230, 111, 112, 113, 114, 115, 116, 117, 118, 119, 121,
                         122, 123, 124, 100, 101, 103, 104, 106, 108, 109, 232, 233, 234],
               fs=360)
    options1 = Dict{String,Any}()
    options1["togenerate"] = Int64(4000)       # not read by the code in this notebook
    options1["epochs"]     = Int64(numepochs)  # number of training epochs
    options1["embed"]      = Int64(721)        # length of each input vector (first-layer input size)
    options1["batchsize"]  = Int64(50)         # samples per minibatch
    options1["seed"]       = -1                # not read by the code in this notebook
    options1["lr"]         = 0.001             # Adam learning rate
    options1["gclip"]      = 3.0               # gradient-norm clipping threshold for Adam
    options1["dpout"]      = 0.0               # dropout probability, passed to train as `pdrop`

    xtst_mc, ytst_mc, xtrn_mc, ytrn_mc, xtst_bin, ytst_bin, xtrn_bin, ytrn_bin =
        prep_dataset(subjects, fs)
    signal = [xtst_mc, ytst_mc, xtrn_mc, ytrn_mc, xtst_bin, ytst_bin, xtrn_bin, ytrn_bin]

    vocab = createVocabulary(mode)
    info("$(length(vocab)) unique values.")   # number of class labels for this mode

    model = initweights(lenhidden, length(vocab), options1["embed"], arraytype)
    return options1, lenhidden, signal, vocab, model
end

# Wrap the train/test split for `mode` in Knet minibatch iterators of
# options1["batchsize"] samples each. `signal` is the 8-element vector built by `setup`:
# entries 1-4 are the multiclass (test x, test y, train x, train y) arrays and
# entries 5-8 are the binary ones. Returns (dtrn, dtst).
function getbatches(options1,signal,mode)
    bs  = options1["batchsize"]
    ofs = mode == "binary" ? 4 : 0   # the binary arrays sit right after the multiclass ones
    dtrn = minibatch(signal[ofs+3], signal[ofs+4], bs)
    dtst = minibatch(signal[ofs+1], signal[ofs+2], bs)
    return dtrn, dtst
end

# Prepare training:
#   1. build the gradient of `loss` with respect to its first argument (the model);
#   2. print the epoch-0 train/test accuracy of the untrained model, scored from a zero state;
#   3. create one Adam optimizer per parameter array, using options1["lr"] and options1["gclip"].
# Returns (lossgradient, optim).
function setuptrain(options1, dtrn, dtst, model)
    lossgradient = grad(loss, 1)

    s0 = initstate(model, options1["batchsize"])
    trnacc = accuracy_LSTM(model, s0, dtrn, predict)
    tstacc = accuracy_LSTM(model, s0, dtst, predict)
    println((:epoch, 0, :train, trnacc, :test, tstacc))

    optim = [Adam(lr=options1["lr"], gclip=options1["gclip"]) for w in model]
    return lossgradient, optim
end

# Main training loop. For each of options1["epochs"] epochs, run one timed pass of
# `train` over dtrn, then print the train/test accuracy scored from a zero initial state.
# `vocab` is accepted for interface compatibility but not used. Returns the (mutated) model.
function trainloop(options1, vocab, dtrn, dtst, lossgradient, optim, model)
    evalstate = initstate(model, options1["batchsize"])
    report(ep) = println((:epoch, ep,
                          :train, accuracy_LSTM(model, evalstate, dtrn, predict),
                          :test, accuracy_LSTM(model, evalstate, dtst, predict)))
    for ep in 1:options1["epochs"]
        @time train(model, dtrn, optim, lossgradient, options1; pdrop=options1["dpout"])
        report(ep)
    end
    return model
end
Out[8]:
trainloop (generic function with 1 method)
In [9]:
# End-to-end run: setup -> minibatching -> epoch-0 report and optimizers -> training.
# Returns the trained parameter vector.
function main(numepochs,lenhidden,mode,arraytype)
    opts, hidden, signal, vocab, model = setup(numepochs, lenhidden, mode, arraytype)
    dtrn, dtst = getbatches(opts, signal, mode)
    lossgradient, optim = setuptrain(opts, dtrn, dtst, model)
    return trainloop(opts, vocab, dtrn, dtst, lossgradient, optim, model)
end
Out[9]:
main (generic function with 1 method)

Binary Classification

Interestingly, I was able to achieve the same high accuracy as with the MLP using far fewer layers: a single LSTM layer of 128 units. Perhaps the gating structure of the LSTM reduces the number of layers needed.

In [16]:
# 100 epochs, one LSTM layer of 128 hidden units, binary labels, parameters stored as KnetArrays
main(100,[128],"binary","KnetArray")
Dict(7=>4332,4=>1380,2=>3096,3=>3730,5=>4332,8=>1404,6=>1770,1=>4348)
Dict(2=>24392,1=>19264)
INFO: 2 unique values.
(:epoch, 0, :train, 0.5580891719745236, :test, 0.5645977011494253)
  1.179427 seconds (1.05 M allocations: 256.373 MiB, 3.07% gc time)
(:epoch, 1, :train, 0.895694267515923, :test, 0.8979310344827585)
  1.120322 seconds (1.06 M allocations: 256.546 MiB, 3.44% gc time)
(:epoch, 2, :train, 0.9309554140127395, :test, 0.9289655172413791)
  1.122957 seconds (1.06 M allocations: 256.517 MiB, 3.42% gc time)
(:epoch, 3, :train, 0.9430573248407673, :test, 0.9402298850574707)
  1.122412 seconds (1.06 M allocations: 256.489 MiB, 3.42% gc time)
(:epoch, 4, :train, 0.9470573248407692, :test, 0.945057471264367)
  1.121317 seconds (1.05 M allocations: 256.426 MiB, 3.16% gc time)
(:epoch, 5, :train, 0.9539872611465025, :test, 0.9514942528735626)
  1.117789 seconds (1.06 M allocations: 256.538 MiB, 3.48% gc time)
(:epoch, 6, :train, 0.9576305732484135, :test, 0.9554022988505741)
  1.118980 seconds (1.06 M allocations: 256.509 MiB, 3.43% gc time)
(:epoch, 7, :train, 0.9600254777070121, :test, 0.9565517241379305)
  1.115834 seconds (1.06 M allocations: 256.447 MiB, 3.24% gc time)
(:epoch, 8, :train, 0.9621656050955476, :test, 0.9583908045977007)
  1.116610 seconds (1.05 M allocations: 256.418 MiB, 3.17% gc time)
(:epoch, 9, :train, 0.9632101910828085, :test, 0.9616091954022986)
  1.235988 seconds (1.06 M allocations: 256.530 MiB, 12.84% gc time)
(:epoch, 10, :train, 0.9637961783439555, :test, 0.9622988505747123)
  1.169724 seconds (1.06 M allocations: 256.501 MiB, 3.28% gc time)
(:epoch, 11, :train, 0.9645605095541466, :test, 0.9625287356321836)
  1.281197 seconds (1.06 M allocations: 256.439 MiB, 2.83% gc time)
(:epoch, 12, :train, 0.9659108280254847, :test, 0.9629885057471261)
  1.392455 seconds (1.06 M allocations: 256.551 MiB, 2.84% gc time)
(:epoch, 13, :train, 0.9674140127388607, :test, 0.9627586206896548)
  1.113776 seconds (1.06 M allocations: 256.522 MiB, 3.44% gc time)
(:epoch, 14, :train, 0.9686878980891793, :test, 0.9634482758620687)
  1.244877 seconds (1.06 M allocations: 256.494 MiB, 13.05% gc time)
(:epoch, 15, :train, 0.9690700636942746, :test, 0.9632183908045971)
  1.148228 seconds (1.06 M allocations: 256.431 MiB, 3.17% gc time)
(:epoch, 16, :train, 0.9704458598726191, :test, 0.9652873563218386)
  1.112645 seconds (1.06 M allocations: 256.543 MiB, 3.47% gc time)
(:epoch, 17, :train, 0.9701910828025555, :test, 0.9643678160919535)
  1.113293 seconds (1.06 M allocations: 256.514 MiB, 3.45% gc time)
(:epoch, 18, :train, 0.97289171974523, :test, 0.9680459770114941)
  1.113641 seconds (1.06 M allocations: 256.485 MiB, 3.42% gc time)
(:epoch, 19, :train, 0.9737834394904529, :test, 0.9678160919540229)
  1.132586 seconds (1.05 M allocations: 256.423 MiB, 3.11% gc time)
(:epoch, 20, :train, 0.9751082802547842, :test, 0.9682758620689651)
  1.116786 seconds (1.06 M allocations: 256.535 MiB, 3.46% gc time)
(:epoch, 21, :train, 0.9751337579617902, :test, 0.9685057471264367)
  1.244251 seconds (1.06 M allocations: 256.506 MiB, 12.80% gc time)
(:epoch, 22, :train, 0.9757452229299435, :test, 0.9680459770114942)
  1.106218 seconds (1.06 M allocations: 256.444 MiB, 3.22% gc time)
(:epoch, 23, :train, 0.9760000000000068, :test, 0.9673563218390804)
  1.125400 seconds (1.06 M allocations: 256.555 MiB, 3.45% gc time)
(:epoch, 24, :train, 0.9767388535031917, :test, 0.9691954022988507)
  1.249622 seconds (1.06 M allocations: 256.527 MiB, 13.18% gc time)
(:epoch, 25, :train, 0.9765605095541466, :test, 0.9691954022988505)
  1.135282 seconds (1.06 M allocations: 256.498 MiB, 3.41% gc time)
(:epoch, 26, :train, 0.9767388535031916, :test, 0.9694252873563218)
  1.125558 seconds (1.06 M allocations: 256.436 MiB, 3.20% gc time)
(:epoch, 27, :train, 0.9770191082802611, :test, 0.9698850574712645)
  1.121484 seconds (1.06 M allocations: 256.547 MiB, 3.45% gc time)
(:epoch, 28, :train, 0.9769426751592418, :test, 0.9701149425287355)
  1.356009 seconds (1.06 M allocations: 256.519 MiB, 13.19% gc time)
(:epoch, 29, :train, 0.9766624203821719, :test, 0.9698850574712643)
  1.177600 seconds (1.06 M allocations: 256.490 MiB, 3.20% gc time)
(:epoch, 30, :train, 0.9762547770700699, :test, 0.9687356321839079)
  1.184564 seconds (1.05 M allocations: 256.428 MiB, 3.01% gc time)
(:epoch, 31, :train, 0.9746242038216624, :test, 0.9682758620689654)
  1.153683 seconds (1.06 M allocations: 256.540 MiB, 3.33% gc time)
(:epoch, 32, :train, 0.9739363057324902, :test, 0.9673563218390803)
  1.113711 seconds (1.06 M allocations: 256.511 MiB, 3.43% gc time)
(:epoch, 33, :train, 0.9734522292993691, :test, 0.9673563218390803)
  1.112886 seconds (1.06 M allocations: 256.482 MiB, 3.43% gc time)
(:epoch, 34, :train, 0.9739108280254842, :test, 0.9666666666666662)
  1.212631 seconds (1.05 M allocations: 256.420 MiB, 3.01% gc time)
(:epoch, 35, :train, 0.9700127388535101, :test, 0.9650574712643677)
  1.440063 seconds (1.06 M allocations: 256.532 MiB, 13.98% gc time)
(:epoch, 36, :train, 0.97365605095542, :test, 0.9675862068965513)
  1.134302 seconds (1.06 M allocations: 256.503 MiB, 3.40% gc time)
(:epoch, 37, :train, 0.9726878980891784, :test, 0.9680459770114938)
  1.138720 seconds (1.06 M allocations: 256.441 MiB, 3.16% gc time)
(:epoch, 38, :train, 0.9743949044586047, :test, 0.9655172413793098)
  1.152571 seconds (1.06 M allocations: 256.552 MiB, 3.33% gc time)
(:epoch, 39, :train, 0.9716178343949113, :test, 0.9636781609195397)
  1.136366 seconds (1.06 M allocations: 256.524 MiB, 3.36% gc time)
(:epoch, 40, :train, 0.9748280254777133, :test, 0.9666666666666662)
  1.119668 seconds (1.06 M allocations: 256.495 MiB, 3.43% gc time)
(:epoch, 41, :train, 0.975770700636949, :test, 0.9673563218390799)
  1.105196 seconds (1.06 M allocations: 256.433 MiB, 3.16% gc time)
(:epoch, 42, :train, 0.9767643312101973, :test, 0.9678160919540222)
  1.102391 seconds (1.06 M allocations: 256.544 MiB, 3.48% gc time)
(:epoch, 43, :train, 0.9761783439490513, :test, 0.9673563218390803)
  1.104241 seconds (1.06 M allocations: 256.516 MiB, 3.43% gc time)
(:epoch, 44, :train, 0.9780127388535086, :test, 0.96919540229885)
  1.222507 seconds (1.06 M allocations: 256.488 MiB, 12.95% gc time)
(:epoch, 45, :train, 0.9781656050955467, :test, 0.9680459770114936)
  1.101484 seconds (1.05 M allocations: 256.425 MiB, 3.17% gc time)
(:epoch, 46, :train, 0.9781656050955472, :test, 0.9678160919540227)
  1.104310 seconds (1.06 M allocations: 256.536 MiB, 3.43% gc time)
(:epoch, 47, :train, 0.9803312101910885, :test, 0.9726436781609191)
  1.102162 seconds (1.06 M allocations: 256.508 MiB, 3.43% gc time)
(:epoch, 48, :train, 0.9805605095541461, :test, 0.9703448275862068)
  1.105866 seconds (1.06 M allocations: 256.446 MiB, 3.20% gc time)
(:epoch, 49, :train, 0.9810700636942732, :test, 0.9708045977011489)
  1.102455 seconds (1.06 M allocations: 256.557 MiB, 3.47% gc time)
(:epoch, 50, :train, 0.9813503184713434, :test, 0.9735632183908043)
  1.110510 seconds (1.06 M allocations: 256.529 MiB, 3.43% gc time)
(:epoch, 51, :train, 0.9819363057324889, :test, 0.9726436781609189)
  1.217508 seconds (1.06 M allocations: 256.500 MiB, 13.01% gc time)
(:epoch, 52, :train, 0.9809936305732544, :test, 0.9719540229885051)
  1.099577 seconds (1.06 M allocations: 256.438 MiB, 3.19% gc time)
(:epoch, 53, :train, 0.9802802547770758, :test, 0.9689655172413789)
  1.099410 seconds (1.06 M allocations: 256.549 MiB, 3.47% gc time)
(:epoch, 54, :train, 0.981961783439496, :test, 0.9728735632183905)
  1.243922 seconds (1.06 M allocations: 256.521 MiB, 12.93% gc time)
(:epoch, 55, :train, 0.9834904458598778, :test, 0.9698850574712643)
  1.124619 seconds (1.06 M allocations: 256.492 MiB, 3.45% gc time)
(:epoch, 56, :train, 0.9827770700636996, :test, 0.9719540229885053)
  1.106666 seconds (1.05 M allocations: 256.430 MiB, 3.22% gc time)
(:epoch, 57, :train, 0.9824968152866292, :test, 0.9721839080459767)
  1.111991 seconds (1.06 M allocations: 256.541 MiB, 3.49% gc time)
(:epoch, 58, :train, 0.9815796178344002, :test, 0.9696551724137924)
  1.237680 seconds (1.06 M allocations: 256.513 MiB, 12.97% gc time)
(:epoch, 59, :train, 0.9831082802547821, :test, 0.9703448275862069)
  1.110172 seconds (1.06 M allocations: 256.484 MiB, 3.39% gc time)
(:epoch, 60, :train, 0.9827261146496868, :test, 0.969655172413793)
  1.119681 seconds (1.05 M allocations: 256.422 MiB, 3.16% gc time)
(:epoch, 61, :train, 0.982624203821662, :test, 0.9698850574712643)
  1.235869 seconds (1.06 M allocations: 256.533 MiB, 12.93% gc time)
(:epoch, 62, :train, 0.9848917197452277, :test, 0.9721839080459769)
  1.109420 seconds (1.06 M allocations: 256.505 MiB, 3.46% gc time)
(:epoch, 63, :train, 0.9840254777070113, :test, 0.9724137931034482)
  1.104185 seconds (1.06 M allocations: 256.442 MiB, 3.23% gc time)
(:epoch, 64, :train, 0.9824968152866294, :test, 0.9682758620689654)
  1.109334 seconds (1.06 M allocations: 256.554 MiB, 3.48% gc time)
(:epoch, 65, :train, 0.9842292993630618, :test, 0.9712643678160918)
  1.232618 seconds (1.06 M allocations: 256.526 MiB, 12.97% gc time)
(:epoch, 66, :train, 0.9828280254777122, :test, 0.9698850574712643)
  1.104491 seconds (1.06 M allocations: 256.496 MiB, 3.45% gc time)
(:epoch, 67, :train, 0.9853757961783484, :test, 0.9717241379310341)
  1.106468 seconds (1.06 M allocations: 256.435 MiB, 3.22% gc time)
(:epoch, 68, :train, 0.9845605095541449, :test, 0.9712643678160922)
  1.107807 seconds (1.06 M allocations: 256.546 MiB, 3.49% gc time)
(:epoch, 69, :train, 0.9835923566879026, :test, 0.9705747126436781)
  1.113861 seconds (1.06 M allocations: 256.517 MiB, 3.46% gc time)
(:epoch, 70, :train, 0.9853503184713425, :test, 0.9728735632183907)
  1.113985 seconds (1.06 M allocations: 256.489 MiB, 3.49% gc time)
(:epoch, 71, :train, 0.9844840764331263, :test, 0.9728735632183907)
  1.264111 seconds (1.05 M allocations: 256.426 MiB, 2.96% gc time)
(:epoch, 72, :train, 0.9857070063694313, :test, 0.973103448275862)
  1.111501 seconds (1.06 M allocations: 256.538 MiB, 3.48% gc time)
(:epoch, 73, :train, 0.9859872611465015, :test, 0.9719540229885056)
  1.111924 seconds (1.06 M allocations: 256.509 MiB, 3.47% gc time)
(:epoch, 74, :train, 0.9845605095541451, :test, 0.970574712643678)
  1.172707 seconds (1.06 M allocations: 256.447 MiB, 3.12% gc time)
(:epoch, 75, :train, 0.9849426751592404, :test, 0.9717241379310344)
  1.112037 seconds (1.05 M allocations: 256.418 MiB, 3.22% gc time)
(:epoch, 76, :train, 0.9855796178343996, :test, 0.9710344827586206)
  1.242315 seconds (1.06 M allocations: 256.531 MiB, 12.95% gc time)
(:epoch, 77, :train, 0.9849426751592405, :test, 0.9714942528735628)
  1.119232 seconds (1.06 M allocations: 256.501 MiB, 3.43% gc time)
(:epoch, 78, :train, 0.9842547770700687, :test, 0.9726436781609191)
  1.105979 seconds (1.06 M allocations: 256.439 MiB, 3.26% gc time)
(:epoch, 79, :train, 0.9837707006369478, :test, 0.9717241379310342)
  1.120450 seconds (1.06 M allocations: 256.551 MiB, 3.46% gc time)
(:epoch, 80, :train, 0.9857579617834443, :test, 0.9721839080459767)
  1.147029 seconds (1.06 M allocations: 256.522 MiB, 3.42% gc time)
(:epoch, 81, :train, 0.9864203821656095, :test, 0.9728735632183907)
  1.245702 seconds (1.06 M allocations: 256.494 MiB, 12.92% gc time)
(:epoch, 82, :train, 0.9832611464968204, :test, 0.9701149425287352)
  1.107292 seconds (1.06 M allocations: 256.431 MiB, 3.21% gc time)
(:epoch, 83, :train, 0.9839745222929986, :test, 0.9712643678160915)
  1.112074 seconds (1.06 M allocations: 256.543 MiB, 3.50% gc time)
(:epoch, 84, :train, 0.9848152866242076, :test, 0.9710344827586203)
  1.117421 seconds (1.06 M allocations: 256.514 MiB, 3.47% gc time)
(:epoch, 85, :train, 0.9854267515923613, :test, 0.973563218390804)
  1.110333 seconds (1.06 M allocations: 256.485 MiB, 3.44% gc time)
(:epoch, 86, :train, 0.9865477707006414, :test, 0.9735632183908042)
  1.109744 seconds (1.05 M allocations: 256.423 MiB, 3.23% gc time)
(:epoch, 87, :train, 0.9867770700636983, :test, 0.9731034482758618)
  1.125551 seconds (1.06 M allocations: 256.535 MiB, 3.47% gc time)
(:epoch, 88, :train, 0.9872611464968192, :test, 0.973333333333333)
  1.320754 seconds (1.06 M allocations: 256.506 MiB, 12.36% gc time)
(:epoch, 89, :train, 0.9880000000000041, :test, 0.9724137931034477)
  1.106054 seconds (1.06 M allocations: 256.444 MiB, 3.20% gc time)
(:epoch, 90, :train, 0.9829808917197501, :test, 0.9694252873563217)
  1.108282 seconds (1.06 M allocations: 256.555 MiB, 3.50% gc time)
(:epoch, 91, :train, 0.9867770700636987, :test, 0.973333333333333)
  1.235227 seconds (1.06 M allocations: 256.527 MiB, 12.92% gc time)
(:epoch, 92, :train, 0.9859108280254826, :test, 0.9728735632183907)
  1.104389 seconds (1.06 M allocations: 256.498 MiB, 3.48% gc time)
(:epoch, 93, :train, 0.988152866242042, :test, 0.9744827586206892)
  1.182455 seconds (1.06 M allocations: 256.436 MiB, 3.06% gc time)
(:epoch, 94, :train, 0.9858343949044631, :test, 0.9728735632183905)
  1.104176 seconds (1.06 M allocations: 256.547 MiB, 3.53% gc time)
(:epoch, 95, :train, 0.9879490445859918, :test, 0.9721839080459765)
  1.233747 seconds (1.06 M allocations: 256.519 MiB, 13.06% gc time)
(:epoch, 96, :train, 0.9865732484076473, :test, 0.972413793103448)
  1.174078 seconds (1.06 M allocations: 256.490 MiB, 3.28% gc time)
(:epoch, 97, :train, 0.9880764331210234, :test, 0.975172413793103)
  1.108525 seconds (1.05 M allocations: 256.428 MiB, 3.22% gc time)
(:epoch, 98, :train, 0.9876433121019146, :test, 0.9728735632183905)
  1.110418 seconds (1.06 M allocations: 256.540 MiB, 3.49% gc time)
(:epoch, 99, :train, 0.9873375796178385, :test, 0.9719540229885056)
  1.194032 seconds (1.06 M allocations: 256.511 MiB, 3.27% gc time)
(:epoch, 100, :train, 0.9867006369426796, :test, 0.9737931034482754)
Out[16]:
4-element Array{Knet.KnetArray{Float32,N} where N,1}:
 Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010218200000, 1738752, 0, nothing), (849, 512))
 Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x000001024ec0e800, 2048, 0, nothing), (1, 512))     
 Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010207060000, 1024, 0, nothing), (128, 2))     
 Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x00000102070bb800, 8, 0, nothing), (1, 2))          

Multi-class classification

Here I needed 3 LSTM layers, but still achieved the same high accuracy as with the MLP. Perhaps the time dependencies that distinguish the different types of arrhythmias are less straightforward, which makes sense.

In [17]:
# 100 epochs, three LSTM layers (128, 256, 64 hidden units), 8 class labels, parameters stored as KnetArrays
main(100,[128,256,64],"multiclass","KnetArray")
Dict(7=>4332,4=>1380,2=>3096,3=>3730,5=>4332,8=>1404,6=>1770,1=>4348)
Dict(2=>24392,1=>19264)
INFO: 8 unique values.
(:epoch, 0, :train, 0.16756264236902055, :test, 0.16458333333333333)
  1.284043 seconds (1.35 M allocations: 172.399 MiB, 2.45% gc time)
(:epoch, 1, :train, 0.7764009111617307, :test, 0.7870833333333334)
  1.238067 seconds (1.36 M allocations: 172.573 MiB, 2.57% gc time)
(:epoch, 2, :train, 0.8395899772209562, :test, 0.8391666666666667)
  1.238087 seconds (1.35 M allocations: 172.485 MiB, 2.50% gc time)
(:epoch, 3, :train, 0.8577676537585408, :test, 0.8541666666666666)
  1.238547 seconds (1.36 M allocations: 172.557 MiB, 2.56% gc time)
(:epoch, 4, :train, 0.8883826879271054, :test, 0.8774999999999998)
  1.229925 seconds (1.34 M allocations: 172.362 MiB, 2.30% gc time)
(:epoch, 5, :train, 0.9061958997722074, :test, 0.8995833333333332)
  1.229643 seconds (1.35 M allocations: 172.528 MiB, 2.59% gc time)
(:epoch, 6, :train, 0.9079726651480616, :test, 0.9004166666666666)
  1.229925 seconds (1.36 M allocations: 172.607 MiB, 2.62% gc time)
(:epoch, 7, :train, 0.9131662870159434, :test, 0.9041666666666668)
  1.240510 seconds (1.35 M allocations: 172.500 MiB, 2.55% gc time)
(:epoch, 8, :train, 0.9234168564920263, :test, 0.9174999999999999)
  1.239722 seconds (1.36 M allocations: 172.576 MiB, 2.61% gc time)
(:epoch, 9, :train, 0.927471526195898, :test, 0.9187500000000003)
  1.235794 seconds (1.35 M allocations: 172.382 MiB, 2.31% gc time)
(:epoch, 10, :train, 0.928747152619588, :test, 0.9187500000000003)
  1.265289 seconds (1.36 M allocations: 172.547 MiB, 2.57% gc time)
(:epoch, 11, :train, 0.9296127562642343, :test, 0.9191666666666666)
  1.239153 seconds (1.36 M allocations: 172.629 MiB, 2.62% gc time)
(:epoch, 12, :train, 0.9336674259681078, :test, 0.92)
  1.233210 seconds (1.35 M allocations: 172.517 MiB, 2.58% gc time)
(:epoch, 13, :train, 0.9280637813211828, :test, 0.92)
  1.227946 seconds (1.36 M allocations: 172.598 MiB, 2.63% gc time)
(:epoch, 14, :train, 0.9359908883826862, :test, 0.9195833333333338)
  1.240616 seconds (1.35 M allocations: 172.490 MiB, 2.55% gc time)
(:epoch, 15, :train, 0.9524829157175384, :test, 0.9420833333333337)
  1.239443 seconds (1.36 M allocations: 172.567 MiB, 2.59% gc time)
(:epoch, 16, :train, 0.943872437357629, :test, 0.9329166666666668)
  1.235514 seconds (1.34 M allocations: 172.371 MiB, 2.31% gc time)
(:epoch, 17, :train, 0.9496127562642354, :test, 0.9395833333333333)
  1.359527 seconds (1.36 M allocations: 172.540 MiB, 11.37% gc time)
(:epoch, 18, :train, 0.9613667425968098, :test, 0.945)
  1.338635 seconds (1.36 M allocations: 172.618 MiB, 2.47% gc time)
(:epoch, 19, :train, 0.9584965831435079, :test, 0.945)
  1.299630 seconds (1.35 M allocations: 172.506 MiB, 2.44% gc time)
(:epoch, 20, :train, 0.9538041002277903, :test, 0.9408333333333334)
  1.309519 seconds (1.36 M allocations: 172.588 MiB, 2.48% gc time)
(:epoch, 21, :train, 0.9629157175398626, :test, 0.9549999999999997)
  1.239411 seconds (1.35 M allocations: 172.480 MiB, 2.53% gc time)
(:epoch, 22, :train, 0.9651480637813216, :test, 0.9483333333333334)
  1.288231 seconds (1.36 M allocations: 172.556 MiB, 2.48% gc time)
(:epoch, 23, :train, 0.9713439635535307, :test, 0.9562499999999999)
  1.503316 seconds (1.34 M allocations: 172.360 MiB, 1.93% gc time)
(:epoch, 24, :train, 0.9727107061503422, :test, 0.9558333333333332)
  1.473055 seconds (1.35 M allocations: 172.527 MiB, 2.30% gc time)
(:epoch, 25, :train, 0.9626423690205, :test, 0.9462499999999997)
  1.237378 seconds (1.36 M allocations: 172.607 MiB, 2.62% gc time)
(:epoch, 26, :train, 0.9775854214123012, :test, 0.9637499999999998)
  1.259629 seconds (1.35 M allocations: 172.500 MiB, 2.54% gc time)
(:epoch, 27, :train, 0.9725284738041005, :test, 0.9616666666666666)
  1.284289 seconds (1.36 M allocations: 172.576 MiB, 2.49% gc time)
(:epoch, 28, :train, 0.9700683371298404, :test, 0.9587499999999999)
  1.297709 seconds (1.35 M allocations: 172.381 MiB, 2.22% gc time)
(:epoch, 29, :train, 0.9784054669703887, :test, 0.9641666666666667)
  1.321732 seconds (1.36 M allocations: 172.547 MiB, 2.42% gc time)
(:epoch, 30, :train, 0.9781321184510255, :test, 0.9658333333333333)
  1.241976 seconds (1.36 M allocations: 172.628 MiB, 2.64% gc time)
(:epoch, 31, :train, 0.9789066059225521, :test, 0.9633333333333333)
  1.235823 seconds (1.35 M allocations: 172.516 MiB, 2.61% gc time)
(:epoch, 32, :train, 0.97243735763098, :test, 0.9591666666666665)
  1.237106 seconds (1.36 M allocations: 172.597 MiB, 2.70% gc time)
(:epoch, 33, :train, 0.9705694760820053, :test, 0.9575000000000004)
  1.252345 seconds (1.35 M allocations: 172.490 MiB, 2.53% gc time)
(:epoch, 34, :train, 0.9830523917995457, :test, 0.9662499999999999)
  1.239125 seconds (1.36 M allocations: 172.565 MiB, 2.60% gc time)
(:epoch, 35, :train, 0.9846924829157186, :test, 0.9704166666666666)
  1.234364 seconds (1.34 M allocations: 172.371 MiB, 2.31% gc time)
(:epoch, 36, :train, 0.9774943052391805, :test, 0.9620833333333333)
  1.358643 seconds (1.36 M allocations: 172.541 MiB, 11.34% gc time)
(:epoch, 37, :train, 0.9838724373576317, :test, 0.9691666666666664)
  1.240362 seconds (1.36 M allocations: 172.617 MiB, 2.60% gc time)
(:epoch, 38, :train, 0.9882004555808672, :test, 0.9708333333333333)
  1.236872 seconds (1.35 M allocations: 172.506 MiB, 2.58% gc time)
(:epoch, 39, :train, 0.9830979498861043, :test, 0.9708333333333331)
  1.238240 seconds (1.36 M allocations: 172.587 MiB, 2.59% gc time)
(:epoch, 40, :train, 0.987517084282461, :test, 0.9716666666666663)
  1.241821 seconds (1.35 M allocations: 172.480 MiB, 2.53% gc time)
(:epoch, 41, :train, 0.9821867881548986, :test, 0.9691666666666663)
  1.237626 seconds (1.36 M allocations: 172.556 MiB, 2.63% gc time)
(:epoch, 42, :train, 0.9745330296127566, :test, 0.9562500000000002)
  1.237098 seconds (1.34 M allocations: 172.361 MiB, 2.28% gc time)
(:epoch, 43, :train, 0.9883826879271079, :test, 0.9724999999999997)
  1.238110 seconds (1.35 M allocations: 172.525 MiB, 2.60% gc time)
(:epoch, 44, :train, 0.9829157175398643, :test, 0.9691666666666664)
  1.234479 seconds (1.36 M allocations: 172.608 MiB, 2.66% gc time)
(:epoch, 45, :train, 0.9781321184510262, :test, 0.9641666666666665)
  1.240334 seconds (1.35 M allocations: 172.500 MiB, 2.54% gc time)
(:epoch, 46, :train, 0.9825056947608202, :test, 0.9670833333333332)
  1.271670 seconds (1.36 M allocations: 172.575 MiB, 2.56% gc time)
(:epoch, 47, :train, 0.9892482915717543, :test, 0.9758333333333334)
  1.238359 seconds (1.35 M allocations: 172.381 MiB, 2.32% gc time)
(:epoch, 48, :train, 0.9895216400911169, :test, 0.9762499999999994)
  1.235717 seconds (1.36 M allocations: 172.546 MiB, 2.64% gc time)
(:epoch, 49, :train, 0.9878815489749438, :test, 0.9725)
  1.242105 seconds (1.36 M allocations: 172.628 MiB, 2.62% gc time)
(:epoch, 50, :train, 0.9848291571753993, :test, 0.9720833333333333)
  1.246834 seconds (1.35 M allocations: 172.515 MiB, 2.69% gc time)
(:epoch, 51, :train, 0.983234624145787, :test, 0.964583333333333)
  1.448165 seconds (1.36 M allocations: 172.597 MiB, 2.31% gc time)
(:epoch, 52, :train, 0.9840091116173122, :test, 0.9649999999999999)
  1.423729 seconds (1.35 M allocations: 172.491 MiB, 2.30% gc time)
(:epoch, 53, :train, 0.9844646924829172, :test, 0.9720833333333331)
  1.320251 seconds (1.36 M allocations: 172.565 MiB, 2.45% gc time)
(:epoch, 54, :train, 0.9852847380410026, :test, 0.9679166666666666)
  1.317594 seconds (1.34 M allocations: 172.370 MiB, 2.22% gc time)
(:epoch, 55, :train, 0.993530751708429, :test, 0.9770833333333332)
  1.360741 seconds (1.36 M allocations: 172.540 MiB, 11.32% gc time)
(:epoch, 56, :train, 0.9841913439635545, :test, 0.9708333333333333)
  1.236329 seconds (1.36 M allocations: 172.617 MiB, 2.62% gc time)
(:epoch, 57, :train, 0.990159453302962, :test, 0.9754166666666665)
  1.356357 seconds (1.35 M allocations: 172.505 MiB, 2.37% gc time)
(:epoch, 58, :train, 0.9917084282460147, :test, 0.9774999999999999)
  1.240657 seconds (1.36 M allocations: 172.587 MiB, 2.67% gc time)
(:epoch, 59, :train, 0.9754897494305245, :test, 0.9587500000000002)
  1.267707 seconds (1.35 M allocations: 172.480 MiB, 2.48% gc time)
(:epoch, 60, :train, 0.988519362186789, :test, 0.9737499999999999)
  1.328506 seconds (1.36 M allocations: 172.556 MiB, 2.41% gc time)
(:epoch, 61, :train, 0.9734851936218673, :test, 0.9545833333333333)
  1.226200 seconds (1.34 M allocations: 172.360 MiB, 2.30% gc time)
(:epoch, 62, :train, 0.9907061503416862, :test, 0.9820833333333332)
  1.233775 seconds (1.35 M allocations: 172.526 MiB, 2.57% gc time)
(:epoch, 63, :train, 0.9889749430523925, :test, 0.9741666666666666)
  1.233276 seconds (1.36 M allocations: 172.606 MiB, 2.61% gc time)
(:epoch, 64, :train, 0.9892027334851944, :test, 0.9779166666666667)
  1.232770 seconds (1.35 M allocations: 172.499 MiB, 2.56% gc time)
(:epoch, 65, :train, 0.9866970387243739, :test, 0.9724999999999998)
  1.232299 seconds (1.36 M allocations: 172.576 MiB, 2.65% gc time)
(:epoch, 66, :train, 0.9911161731207295, :test, 0.9783333333333334)
  1.232592 seconds (1.35 M allocations: 172.380 MiB, 2.33% gc time)
(:epoch, 67, :train, 0.9912528473804111, :test, 0.977083333333333)
  1.263242 seconds (1.36 M allocations: 172.546 MiB, 2.63% gc time)
(:epoch, 68, :train, 0.9955808656036453, :test, 0.9841666666666665)
  1.255372 seconds (1.36 M allocations: 172.627 MiB, 2.61% gc time)
(:epoch, 69, :train, 0.9878359908883834, :test, 0.9779166666666668)
  1.237327 seconds (1.35 M allocations: 172.516 MiB, 2.61% gc time)
(:epoch, 70, :train, 0.992892938496584, :test, 0.9841666666666664)
  1.234784 seconds (1.36 M allocations: 172.597 MiB, 2.64% gc time)
(:epoch, 71, :train, 0.9898405466970404, :test, 0.9779166666666667)
  1.233364 seconds (1.35 M allocations: 172.489 MiB, 2.55% gc time)
(:epoch, 72, :train, 0.9851025056947621, :test, 0.9766666666666662)
  1.233416 seconds (1.36 M allocations: 172.566 MiB, 2.60% gc time)
(:epoch, 73, :train, 0.9928929384965837, :test, 0.9758333333333332)
  1.222778 seconds (1.34 M allocations: 172.370 MiB, 2.35% gc time)
(:epoch, 74, :train, 0.9959453302961283, :test, 0.9820833333333332)
  1.346282 seconds (1.36 M allocations: 172.539 MiB, 11.41% gc time)
(:epoch, 75, :train, 0.9872437357630991, :test, 0.9733333333333332)
  1.230048 seconds (1.36 M allocations: 172.618 MiB, 2.62% gc time)
(:epoch, 76, :train, 0.9867881548974952, :test, 0.9720833333333331)
  1.233303 seconds (1.35 M allocations: 172.505 MiB, 2.59% gc time)
(:epoch, 77, :train, 0.9933940774487474, :test, 0.9758333333333331)
  1.229738 seconds (1.36 M allocations: 172.586 MiB, 2.60% gc time)
(:epoch, 78, :train, 0.9903416856492039, :test, 0.9737500000000002)
  1.224086 seconds (1.35 M allocations: 172.480 MiB, 2.54% gc time)
(:epoch, 79, :train, 0.9926195899772214, :test, 0.97625)
  1.229309 seconds (1.36 M allocations: 172.555 MiB, 2.60% gc time)
(:epoch, 80, :train, 0.993257403189067, :test, 0.9829166666666665)
  1.221175 seconds (1.34 M allocations: 172.362 MiB, 2.31% gc time)
(:epoch, 81, :train, 0.9903416856492034, :test, 0.9812499999999996)
  1.231572 seconds (1.35 M allocations: 172.526 MiB, 2.59% gc time)
(:epoch, 82, :train, 0.9962642369020502, :test, 0.9829166666666665)
  1.227229 seconds (1.36 M allocations: 172.606 MiB, 2.65% gc time)
(:epoch, 83, :train, 0.9896583143507982, :test, 0.9762499999999998)
  1.228169 seconds (1.35 M allocations: 172.499 MiB, 2.56% gc time)
(:epoch, 84, :train, 0.9953530751708433, :test, 0.9824999999999998)
  1.233592 seconds (1.36 M allocations: 172.575 MiB, 2.60% gc time)
(:epoch, 85, :train, 0.9944419134396364, :test, 0.98)
  1.226147 seconds (1.35 M allocations: 172.381 MiB, 2.33% gc time)
(:epoch, 86, :train, 0.9800000000000004, :test, 0.9662499999999999)
  1.226082 seconds (1.36 M allocations: 172.546 MiB, 2.62% gc time)
(:epoch, 87, :train, 0.9940774487471531, :test, 0.9829166666666663)
  1.231788 seconds (1.36 M allocations: 172.627 MiB, 2.63% gc time)
(:epoch, 88, :train, 0.9810022779043283, :test, 0.9749999999999996)
  1.251932 seconds (1.35 M allocations: 172.515 MiB, 2.55% gc time)
(:epoch, 89, :train, 0.9849658314350804, :test, 0.9729166666666663)
  1.232861 seconds (1.36 M allocations: 172.597 MiB, 2.63% gc time)
(:epoch, 90, :train, 0.9909339407744876, :test, 0.9795833333333334)
  1.230214 seconds (1.35 M allocations: 172.489 MiB, 2.55% gc time)
(:epoch, 91, :train, 0.9917995444191354, :test, 0.9795833333333331)
  1.233392 seconds (1.36 M allocations: 172.566 MiB, 2.61% gc time)
(:epoch, 92, :train, 0.9933029612756272, :test, 0.9791666666666665)
  1.226757 seconds (1.34 M allocations: 172.371 MiB, 2.32% gc time)
(:epoch, 93, :train, 0.9969476082004561, :test, 0.9879166666666666)
  1.482199 seconds (1.36 M allocations: 172.539 MiB, 10.45% gc time)
(:epoch, 94, :train, 0.9956719817767659, :test, 0.9845833333333331)
  1.400889 seconds (1.36 M allocations: 172.618 MiB, 2.35% gc time)
(:epoch, 95, :train, 0.9752619589977221, :test, 0.9620833333333335)
  1.297955 seconds (1.35 M allocations: 172.505 MiB, 2.47% gc time)
(:epoch, 96, :train, 0.9959908883826883, :test, 0.984583333333333)
  1.380960 seconds (1.36 M allocations: 172.587 MiB, 2.44% gc time)
(:epoch, 97, :train, 0.995170842824602, :test, 0.9841666666666664)
  1.261671 seconds (1.35 M allocations: 172.479 MiB, 2.49% gc time)
(:epoch, 98, :train, 0.9937129840546703, :test, 0.9820833333333333)
  1.375738 seconds (1.36 M allocations: 172.556 MiB, 2.36% gc time)
(:epoch, 99, :train, 0.9970842824601373, :test, 0.9862500000000002)
  1.298061 seconds (1.34 M allocations: 172.360 MiB, 2.21% gc time)
(:epoch, 100, :train, 0.9913895216400923, :test, 0.9804166666666666)
Out[17]:
8-element Array{Knet.KnetArray{Float32,N} where N,1}:
 Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010243200000, 1738752, 0, nothing), (849, 512)) 
 Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e33800, 2048, 0, nothing), (1, 512))      
 Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010274e00000, 1572864, 0, nothing), (384, 1024))
 Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x000001024ed75800, 4096, 0, nothing), (1, 1024))     
 Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010256580000, 327680, 0, nothing), (320, 256))  
 Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x00000102071a0c00, 1024, 0, nothing), (1, 256))      
 Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206f91000, 2048, 0, nothing), (64, 8))       
 Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x000001020715c800, 32, 0, nothing), (1, 8))