ECG Classification with MLP

Sandya Subramanian

HST (Health Sciences and Technology), 2nd year Ph.D.

Data Exploration

First to import the data....

In [1]:
#Pkg.add("CSV")
#Pkg.add("Plots")
#Pkg.add("TimeSeries")
#Pkg.add("Images")
#Pkg.add("Plotly")
using CSV, Plots, Images #, TimeSeries
using Knet, AutoGrad
using Knet: sigm_dot, tanh_dot
using StatsBase
WARNING: Method definition ==(Base.Nullable{S}, Base.Nullable{T}) in module Base at nullable.jl:238 overwritten in module NullableArrays at /home/nsrl/juliapro/JuliaPro-0.6.1.1/JuliaPro/pkgs-0.6.1.1/v0.6/NullableArrays/src/operators.jl:99.
In [2]:
Knet.gpu(0)
Out[2]:
0

And get all the helper code I wrote...

In [3]:
include("findpeaks.jl")
include("pan_tompkins.jl")
include("run_pan_tompkins.jl")
include("get_breaks.jl")
include("get_dataset.jl")
include("get_comb_dataset.jl")
include("get_traintest.jl")
Out[3]:
get_traintest (generic function with 1 method)

MLP for Beat Classification

I decided to start off with something simple. I know cardiologsts can classify beats pretty easily by eye, and most student with a few hours of training can at least tell if something is abnormal at the least, so it can't be THAT hard of a problem. However, I wasn't sure if it would be easy enough for a non-deep learning network. I used MLP, and played around with the number of layers, starting at 1 and increasing as I saw improving accuracy. In the end, I used 4 layers.

In [4]:
function predict(w,x)
    for i=1:2:length(w)
        #Check x type
        if isa(x,Array)
            x = w[i]*KnetArray(mat(x)) .+ w[i+1]
        else 
            x = w[i]*mat(x) .+ w[i+1]
        end
        if i<length(w)-1
            x = relu.(x) # max(0,x)
        end
    end
    return x
end

loss(w,x,ygold) = nll(predict(w,x),ygold)

lossgradient = grad(loss)

function train(w, dtrn; lr=.5, epochs=10)
    for epoch=1:epochs
        for (x,y) in dtrn
            g = lossgradient(w, x, y)
            update!(w,g;lr=lr)
        end
    end
    return w
end

function weights(h...; atype=Array{Float32}, winit=0.1, mode="binary")
    w = Any[]
    x = Int(2*360+1)
    if mode == "binary"
        sizes = [h..., 2]
    else 
        sizes = [h..., 8]
    end
    for y in sizes
        push!(w, convert(atype, winit*randn(y,x)))
        push!(w, convert(atype, zeros(y, 1)))
        x = y
    end
    return w
end
Out[4]:
weights (generic function with 1 method)
In [5]:
function prep_dataset(subj_list,fs)
    abn_dataset, full_dataset, abn_truth_cats, full_bin_cats, label_key = get_comb_dataset(subj_list,fs)
    println(countmap(abn_truth_cats))
    println(countmap(full_bin_cats))
    xtst_mc, ytst_mc, xtrn_mc, ytrn_mc, xtst_bin, ytst_bin, xtrn_bin, ytrn_bin = get_traintest(abn_dataset,full_dataset,abn_truth_cats,full_bin_cats,0.1)
    
    return xtst_mc, ytst_mc, xtrn_mc, ytrn_mc, xtst_bin, ytst_bin, xtrn_bin, ytrn_bin 
end
Out[5]:
prep_dataset (generic function with 1 method)
In [6]:
function main(numepochs,mode,arraytype)
    
    args = Dict{String,Any}()
    args["mode"] = mode;     #binary or multiclass
    args["seed"] = -1            #random number seed: use a nonnegative int for repeatable results
    args["batchsize"] = 50      #minibatch size   
    args["epochs"] = numepochs          #number of epochs for training
    args["hidden"] = [128, 256, 128, 64]    #sizes of hidden layers
    args["lr"] = 0.1             #learning rate
    args["winit"] = 0.1          #w initialized with winit*randn()
    args["fast"] = false         #skip loss printing for faster run
    if arraytype == "Array"
        args["atype"] = "Array{Float32}"
    else
        args["atype"] = "KnetArray{Float32}"
    end
    args["gcheck"] = 0           #check N random gradients per parameter
    
    if !args["fast"]
        println("opts=",[(k,v) for (k,v) in args]...)
    end
    args["seed"] > 0 && srand(args["seed"])
    atype = eval(parse(args["atype"]))
    w = weights(args["hidden"]...; atype=atype, winit=args["winit"], mode=args["mode"])
    
    #xtst_mc, ytst_mc, xtrn_mc, ytrn_mc, xtst_bin, ytst_bin, xtrn_bin, ytrn_bin = prep_dataset([207,212,203,209,201],360)
    xtst_mc, ytst_mc, xtrn_mc, ytrn_mc, xtst_bin, ytst_bin, xtrn_bin, ytrn_bin = prep_dataset([207, 212, 203, 209, 201, 202, 205, 208, 210, 213, 220, 221, 222, 230, 111, 112, 113, 114, 115, 116, 117, 118, 119, 121, 122, 123, 124, 100, 101, 103, 104, 106, 108, 109, 232, 233, 234],360)
    
    if args["mode"] == "multiclass"
        println("Multi-class classification")
        dtrn = minibatch(xtrn_mc, ytrn_mc, args["batchsize"])
        dtst = minibatch(xtst_mc, ytst_mc, args["batchsize"])
    else
        println("Binary classification")
        dtrn = minibatch(xtrn_bin, ytrn_bin, args["batchsize"])
        dtst = minibatch(xtst_bin, ytst_bin, args["batchsize"])
    end
    report(epoch)=println((:epoch,epoch,:trn,accuracy(w,dtrn,predict),:tst,accuracy(w,dtst,predict)))
    if args["fast"]
        (train(w, dtrn; lr=args["lr"], epochs=args["epochs"]); gpu()>=0 && Knet.cudaDeviceSynchronize())
    else
        report(0)
        @time for epoch=1:args["epochs"]
            train(w, dtrn; lr=args["lr"], epochs=1)
            report(epoch)
            if args["gcheck"] > 0
                gradcheck(loss, w, first(dtrn)...; gcheck=args["gcheck"], verbose=true)
            end
        end
    end
    return w, dtrn, dtst
end
Out[6]:
main (generic function with 1 method)

Multi-class classification

In [7]:
w, dtrn, dtst = main(50,"multiclass","KnetArray")
opts=("epochs", 50)("hidden", [128, 256, 128, 64])("lr", 0.1)("atype", "KnetArray{Float32}")("winit", 0.1)("mode", "multiclass")("fast", false)("seed", -1)("gcheck", 0)("batchsize", 50)
Dict(7=>4332,4=>1380,2=>3096,3=>3730,5=>4332,8=>1404,6=>1770,1=>4348)
Dict(2=>24392,1=>19264)
Multi-class classification
(:epoch, 0, :trn, 0.1784510250569476, :tst, 0.17083333333333334)
(:epoch, 1, :trn, 0.8286560364464692, :tst, 0.8270833333333333)
(:epoch, 2, :trn, 0.8806833712984055, :tst, 0.875)
(:epoch, 3, :trn, 0.9011845102505694, :tst, 0.89)
(:epoch, 4, :trn, 0.9140318906605922, :tst, 0.8983333333333333)
(:epoch, 5, :trn, 0.9210478359908884, :tst, 0.9075)
(:epoch, 6, :trn, 0.9227334851936219, :tst, 0.90625)
(:epoch, 7, :trn, 0.9345785876993167, :tst, 0.9145833333333333)
(:epoch, 8, :trn, 0.9402733485193622, :tst, 0.9170833333333334)
(:epoch, 9, :trn, 0.9486560364464692, :tst, 0.9291666666666667)
(:epoch, 10, :trn, 0.9462414578587699, :tst, 0.9270833333333334)
(:epoch, 11, :trn, 0.9537129840546698, :tst, 0.9345833333333333)
(:epoch, 12, :trn, 0.9608656036446469, :tst, 0.94625)
(:epoch, 13, :trn, 0.9631890660592255, :tst, 0.94125)
(:epoch, 14, :trn, 0.9617312072892938, :tst, 0.9466666666666667)
(:epoch, 15, :trn, 0.9681093394077449, :tst, 0.9508333333333333)
(:epoch, 16, :trn, 0.9701138952164009, :tst, 0.9525)
(:epoch, 17, :trn, 0.9760820045558086, :tst, 0.95875)
(:epoch, 18, :trn, 0.9803189066059226, :tst, 0.9658333333333333)
(:epoch, 19, :trn, 0.9800911161731207, :tst, 0.9645833333333333)
(:epoch, 20, :trn, 0.9834168564920274, :tst, 0.9695833333333334)
(:epoch, 21, :trn, 0.9818223234624146, :tst, 0.96625)
(:epoch, 22, :trn, 0.9776765375854214, :tst, 0.9608333333333333)
(:epoch, 23, :trn, 0.9774031890660593, :tst, 0.9625)
(:epoch, 24, :trn, 0.984510250569476, :tst, 0.96625)
(:epoch, 25, :trn, 0.981002277904328, :tst, 0.9641666666666666)
(:epoch, 26, :trn, 0.9792710706150342, :tst, 0.9633333333333334)
(:epoch, 27, :trn, 0.9230068337129841, :tst, 0.9070833333333334)
(:epoch, 28, :trn, 0.9547152619589977, :tst, 0.93375)
(:epoch, 29, :trn, 0.9629157175398634, :tst, 0.9445833333333333)
(:epoch, 30, :trn, 0.9626423690205012, :tst, 0.9429166666666666)
(:epoch, 31, :trn, 0.9777220956719818, :tst, 0.9641666666666666)
(:epoch, 32, :trn, 0.98874715261959, :tst, 0.9716666666666667)
(:epoch, 33, :trn, 0.9864236902050114, :tst, 0.9695833333333334)
(:epoch, 34, :trn, 0.9881548974943052, :tst, 0.9754166666666667)
(:epoch, 35, :trn, 0.9848291571753987, :tst, 0.9704166666666667)
(:epoch, 36, :trn, 0.9851936218678815, :tst, 0.9733333333333334)
(:epoch, 37, :trn, 0.9876537585421412, :tst, 0.9716666666666667)
(:epoch, 38, :trn, 0.9821867881548975, :tst, 0.9670833333333333)
(:epoch, 39, :trn, 0.9943963553530751, :tst, 0.97875)
(:epoch, 40, :trn, 0.9944874715261959, :tst, 0.9820833333333333)
(:epoch, 41, :trn, 0.988382687927107, :tst, 0.9779166666666667)
(:epoch, 42, :trn, 0.9875170842824601, :tst, 0.9754166666666667)
(:epoch, 43, :trn, 0.9906605922551253, :tst, 0.9795833333333334)
(:epoch, 44, :trn, 0.9903416856492028, :tst, 0.9770833333333333)
(:epoch, 45, :trn, 0.9964464692482916, :tst, 0.985)
(:epoch, 46, :trn, 0.9907061503416856, :tst, 0.98)
(:epoch, 47, :trn, 0.996628701594533, :tst, 0.9870833333333333)
(:epoch, 48, :trn, 0.986378132118451, :tst, 0.9725)
(:epoch, 49, :trn, 0.9886560364464693, :tst, 0.9770833333333333)
(:epoch, 50, :trn, 0.9938952164009112, :tst, 0.9854166666666667)
 22.757157 seconds (30.84 M allocations: 7.609 GiB, 4.31% gc time)
Out[7]:
(Any[Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206a00000, 369152, 0, nothing), (128, 721)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e00000, 512, 0, nothing), (128, 1)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010207000000, 131072, 0, nothing), (256, 128)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e00200, 1024, 0, nothing), (256, 1)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010207020000, 131072, 0, nothing), (128, 256)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e00600, 512, 0, nothing), (128, 1)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010207200000, 32768, 0, nothing), (64, 128)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e00800, 256, 0, nothing), (64, 1)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010207400000, 2048, 0, nothing), (8, 64)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e00a00, 32, 0, nothing), (8, 1))], Knet.MB(Float32[0.0386972 0.061168 … 0.0340182 0.0449909; 0.0165626 0.0791549 … 0.0380538 0.0458972; … ; 0.0305965 0.0161014 … 0.0137556 0.5456; 0.0240328 0.0478658 … 0.0297357 0.530711], [4 5 … 3 3], 50, 21953, false, 1:21953, [721, 50], [50], Array{Float32,2}, Array{Int64,1}), Knet.MB(Float32[0.0814601 0.0 … 0.0 0.0188742; 0.0376155 0.0 … 0.0 0.0418536; … ; 0.0592601 0.0 … 0.0 0.00364112; 0.0445953 0.0 … 0.0 0.0185715], [7 2 … 2 8], 50, 2439, false, 1:2439, [721, 50], [50], Array{Float32,2}, Array{Int64,1}))

Whoohoo! The accuracy is over 98.5% on the test set after just 50 epochs!

Binary classification

In [8]:
w, dtrn, dtst = main(50,"binary","KnetArray")
opts=("epochs", 50)("hidden", [128, 256, 128, 64])("lr", 0.1)("atype", "KnetArray{Float32}")("winit", 0.1)("mode", "binary")("fast", false)("seed", -1)("gcheck", 0)("batchsize", 50)
Dict(7=>4332,4=>1380,2=>3096,3=>3730,5=>4332,8=>1404,6=>1770,1=>4348)
Dict(2=>24392,1=>19264)
Binary classification
(:epoch, 0, :trn, 0.44028025477707006, :tst, 0.45011494252873563)
(:epoch, 1, :trn, 0.8762038216560509, :tst, 0.8726436781609196)
(:epoch, 2, :trn, 0.9302420382165605, :tst, 0.9303448275862068)
(:epoch, 3, :trn, 0.950624203821656, :tst, 0.9478160919540229)
(:epoch, 4, :trn, 0.9539617834394905, :tst, 0.9510344827586207)
(:epoch, 5, :trn, 0.957171974522293, :tst, 0.9533333333333334)
(:epoch, 6, :trn, 0.9614012738853503, :tst, 0.9537931034482758)
(:epoch, 7, :trn, 0.9671337579617835, :tst, 0.9629885057471265)
(:epoch, 8, :trn, 0.9683566878980892, :tst, 0.9634482758620689)
(:epoch, 9, :trn, 0.9693248407643312, :tst, 0.964367816091954)
(:epoch, 10, :trn, 0.9621146496815287, :tst, 0.953103448275862)
(:epoch, 11, :trn, 0.9703949044585988, :tst, 0.9625287356321839)
(:epoch, 12, :trn, 0.965656050955414, :tst, 0.9597701149425287)
(:epoch, 13, :trn, 0.9715414012738853, :tst, 0.9648275862068966)
(:epoch, 14, :trn, 0.9761528662420382, :tst, 0.9705747126436781)
(:epoch, 15, :trn, 0.9675923566878981, :tst, 0.9588505747126437)
(:epoch, 16, :trn, 0.9799745222929936, :tst, 0.9735632183908046)
(:epoch, 17, :trn, 0.9681273885350319, :tst, 0.9627586206896551)
(:epoch, 18, :trn, 0.9790828025477707, :tst, 0.9735632183908046)
(:epoch, 19, :trn, 0.9808917197452229, :tst, 0.9733333333333334)
(:epoch, 20, :trn, 0.9727898089171975, :tst, 0.964367816091954)
(:epoch, 21, :trn, 0.9801019108280254, :tst, 0.9763218390804598)
(:epoch, 22, :trn, 0.9822165605095542, :tst, 0.9726436781609196)
(:epoch, 23, :trn, 0.9749554140127389, :tst, 0.9664367816091954)
(:epoch, 24, :trn, 0.9692738853503184, :tst, 0.9602298850574713)
(:epoch, 25, :trn, 0.9699108280254777, :tst, 0.9613793103448276)
(:epoch, 26, :trn, 0.9846624203821656, :tst, 0.975632183908046)
(:epoch, 27, :trn, 0.984, :tst, 0.9760919540229885)
(:epoch, 28, :trn, 0.9795668789808917, :tst, 0.9724137931034482)
(:epoch, 29, :trn, 0.9667770700636943, :tst, 0.9581609195402299)
(:epoch, 30, :trn, 0.9826751592356688, :tst, 0.9749425287356321)
(:epoch, 31, :trn, 0.9827006369426752, :tst, 0.9742528735632184)
(:epoch, 32, :trn, 0.9822165605095542, :tst, 0.9728735632183908)
(:epoch, 33, :trn, 0.9817070063694268, :tst, 0.9714942528735632)
(:epoch, 34, :trn, 0.9837961783439491, :tst, 0.9751724137931035)
(:epoch, 35, :trn, 0.9829554140127389, :tst, 0.975632183908046)
(:epoch, 36, :trn, 0.982828025477707, :tst, 0.9763218390804598)
(:epoch, 37, :trn, 0.98028025477707, :tst, 0.9726436781609196)
(:epoch, 38, :trn, 0.9847898089171975, :tst, 0.9754022988505747)
(:epoch, 39, :trn, 0.9801783439490446, :tst, 0.9696551724137931)
(:epoch, 40, :trn, 0.9830063694267516, :tst, 0.9749425287356321)
(:epoch, 41, :trn, 0.9830318471337579, :tst, 0.9737931034482759)
(:epoch, 42, :trn, 0.9834904458598727, :tst, 0.9747126436781609)
(:epoch, 43, :trn, 0.9846624203821656, :tst, 0.9751724137931035)
(:epoch, 44, :trn, 0.983515923566879, :tst, 0.9740229885057471)
(:epoch, 45, :trn, 0.9821401273885351, :tst, 0.972183908045977)
(:epoch, 46, :trn, 0.9846624203821656, :tst, 0.9747126436781609)
(:epoch, 47, :trn, 0.9618853503184713, :tst, 0.9583908045977011)
(:epoch, 48, :trn, 0.9814012738853504, :tst, 0.9767816091954022)
(:epoch, 49, :trn, 0.9758726114649682, :tst, 0.968735632183908)
(:epoch, 50, :trn, 0.9806369426751592, :tst, 0.9710344827586207)
 35.223313 seconds (52.83 M allocations: 13.473 GiB, 4.77% gc time)
Out[8]:
(Any[Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206b0e600, 369152, 0, nothing), (128, 721)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e91000, 512, 0, nothing), (128, 1)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010210880000, 131072, 0, nothing), (256, 128)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e8fc00, 1024, 0, nothing), (256, 1)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010210820000, 131072, 0, nothing), (128, 256)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e90000, 512, 0, nothing), (128, 1)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010210e0c800, 32768, 0, nothing), (64, 128)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e92400, 256, 0, nothing), (64, 1)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e5aa00, 512, 0, nothing), (2, 64)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206ea7800, 8, 0, nothing), (2, 1))], Knet.MB(Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [1 1 … 2 1], 50, 39290, false, 1:39290, [721, 50], [50], Array{Float32,2}, Array{Int64,1}), Knet.MB(Float32[0.0 0.0755302 … 0.00194841 0.00219122; 0.0 0.048578 … 0.0241835 0.026606; … ; 0.0 0.0890895 … 0.0188818 0.0354896; 0.0 0.113184 … 0.0135281 0.0215813], [1 2 … 2 2], 50, 4366, false, 1:4366, [721, 50], [50], Array{Float32,2}, Array{Int64,1}))

Here as well, the accuracy is over 97% after 50 epochs. Both results are highly successful and promising.