# ECG Classification with MLP

**Sandya Subramanian**

**HST (Health Sciences and Technology), 2nd-year Ph.D.**

# Data Exploration

First, import the packages needed to load and model the data.

In [1]:
#Pkg.add("CSV")
#Pkg.add("Plots")
#Pkg.add("TimeSeries")
#Pkg.add("Images")
#Pkg.add("Plotly")
using CSV, Plots, Images #, TimeSeries
using Knet, AutoGrad
using Knet: sigm_dot, tanh_dot
using StatsBase



In [2]:
# Make GPU device 0 the active device for Knet/KnetArray computations
# (the cell echoes the selected device id, 0).
Knet.gpu(0)

0

Next, load all the helper code I wrote:

In [3]:
include("findpeaks.jl")
include("pan_tompkins.jl")
include("run_pan_tompkins.jl")
include("get_breaks.jl")
include("get_dataset.jl")
include("get_comb_dataset.jl")
include("get_traintest.jl")

get_traintest (generic function with 1 method)

# MLP for Beat Classification

I decided to start off with something simple. I know cardiologists can classify beats pretty easily by eye, and most students with a few hours of training can at least tell whether a beat is abnormal, so it can't be THAT hard a problem. However, I wasn't sure whether it would be easy enough for a network that isn't really "deep". I used an MLP and varied the number of hidden layers, starting at 1 and adding more as long as accuracy kept improving. In the end, I used 4 hidden layers (128, 256, 128, and 64 units).

In [4]:
"""
    predict(w, x)

Forward pass of the MLP.

For each (weight, bias) pair in `w`, compute `w[i] * x .+ w[i+1]`. ReLU is applied
after every layer except the last, so the return value is the raw scores of the
output layer.

# Arguments
- `w`: the flat parameter list built by `weights`, alternating weight matrices and
  bias columns.
- `x`: input batch. It is passed through `mat` first, which flattens it to a 2-D matrix.

A CPU `Array` input is moved to a `KnetArray` only when the weights themselves are
`KnetArray`s. This lets both the GPU ("KnetArray") and the CPU ("Array")
configurations of `main` work. Previously the input was always converted, which
broke the CPU configuration.
"""
function predict(w,x)
    x = mat(x)
    # While differentiating, AutoGrad wraps the weights; `getval` unwraps them so
    # that we inspect the underlying array type rather than the wrapper.
    if isa(x, Array) && isa(AutoGrad.getval(w[1]), KnetArray)
        x = KnetArray(x)
    end
    for i in 1:2:length(w)
        x = w[i] * x .+ w[i+1]
        # Hidden layers get ReLU; the final layer returns unnormalized scores.
        if i < length(w) - 1
            x = relu.(x)
        end
    end
    return x
end

# Training objective: Knet's negative log-likelihood `nll` of the gold labels
# `ygold`, given the scores `predict` computes for the inputs `x`.
function loss(w, x, ygold)
    scores = predict(w, x)
    return nll(scores, ygold)
end

# AutoGrad gradient of `loss` with respect to its first argument (the weights `w`).
lossgradient = grad(loss, 1)

"""
    train(w, dtrn; lr=.5, epochs=10)

Make `epochs` passes over the minibatches `(x, y)` in `dtrn`. For every minibatch,
compute `lossgradient` and apply Knet's in-place `update!` with learning rate `lr`.
Returns the (mutated) weights `w`.
"""
function train(w, dtrn; lr=.5, epochs=10)
    for ep in 1:epochs, (xb, yb) in dtrn
        update!(w, lossgradient(w, xb, yb); lr=lr)
    end
    return w
end

"""
    weights(h...; atype=Array{Float32}, winit=0.1, mode="binary", inputdim=2*360+1)

Initialize the parameters of an MLP whose hidden layer sizes are `h`.

Returns a flat `Any[]` vector that alternates weight matrices and bias columns:
`[W1, b1, W2, b2, ..., Wout, bout]`.
- Each weight matrix is `winit * randn(fanout, fanin)`, i.e. normal with standard
  deviation `winit`.
- Each bias is a zero column.
- Every array is converted to `atype`.

# Keywords
- `atype`: parameter array type, e.g. `Array{Float32}` or `KnetArray{Float32}`.
- `winit`: scale of the random weight initialization.
- `mode`: `"binary"` gives 2 output units; `"multiclass"` gives 8.
- `inputdim`: number of input features, i.e. the row count of the input matrix.
  The default is 721 (`2*360+1`), presumably a ±360-sample window around each
  beat — confirm against the dataset helpers.

# Throws
- `ArgumentError` if `mode` is neither `"binary"` nor `"multiclass"`.
"""
function weights(h...; atype=Array{Float32}, winit=0.1, mode="binary", inputdim=2*360+1)
    if mode == "binary"
        nout = 2
    elseif mode == "multiclass"
        nout = 8
    else
        # Reject typos instead of silently building an 8-output network.
        throw(ArgumentError("mode must be \"binary\" or \"multiclass\", got \"$mode\""))
    end
    w = Any[]
    fanin = inputdim
    for fanout in (h..., nout)
        push!(w, convert(atype, winit * randn(fanout, fanin)))
        push!(w, convert(atype, zeros(fanout, 1)))
        fanin = fanout
    end
    return w
end

weights (generic function with 1 method)

In [5]:
"""
    prep_dataset(subj_list, fs)

Build the beat datasets for the subjects in `subj_list` (sampling rate `fs`) with
`get_comb_dataset`, then print the label counts of the abnormal-beat (multiclass)
labels and of the binary labels. Finally, split both with `get_traintest`. The split
argument `0.1` is presumably the held-out fraction — confirm in get_traintest.jl.

Returns the 8-tuple
`(xtst_mc, ytst_mc, xtrn_mc, ytrn_mc, xtst_bin, ytst_bin, xtrn_bin, ytrn_bin)`.
"""
function prep_dataset(subj_list,fs)
    abn, full, abn_labels, bin_labels, _ = get_comb_dataset(subj_list, fs)

    # Class balance of each labeling.
    for labels in (abn_labels, bin_labels)
        println(countmap(labels))
    end

    mc_xtst, mc_ytst, mc_xtrn, mc_ytrn, bin_xtst, bin_ytst, bin_xtrn, bin_ytrn =
        get_traintest(abn, full, abn_labels, bin_labels, 0.1)
    return (mc_xtst, mc_ytst, mc_xtrn, mc_ytrn, bin_xtst, bin_ytst, bin_xtrn, bin_ytrn)
end

prep_dataset (generic function with 1 method)

In [6]:
"""
    main(numepochs, mode, arraytype; seed=-1, subjects=[...])

Build the dataset, initialize an MLP and train it for `numepochs` epochs. Unless
`fast` is set, print train/test accuracy before training and after every epoch.

# Arguments
- `numepochs`: number of training epochs.
- `mode`: `"binary"` or `"multiclass"`. It selects the output size and which of the
  two splits returned by `prep_dataset` is used.
- `arraytype`: `"Array"` for CPU arrays; any other value means `KnetArray{Float32}`.

# Keywords
- `seed`: a nonnegative value seeds the global RNG before the weights are created,
  making runs repeatable. A negative value (the default) leaves the RNG untouched.
- `subjects`: record numbers passed to `prep_dataset`.

# Returns
`(w, dtrn, dtst)`: the trained weights and the train/test minibatch iterators.

# Throws
- `ArgumentError` for an unknown `mode`.
"""
function main(numepochs, mode, arraytype;
              seed=-1,
              subjects=[207, 212, 203, 209, 201, 202, 205, 208, 210, 213, 220, 221,
                        222, 230, 111, 112, 113, 114, 115, 116, 117, 118, 119, 121,
                        122, 123, 124, 100, 101, 103, 104, 106, 108, 109, 232, 233, 234])
    # Reject unknown modes instead of silently treating them as binary below.
    if mode != "binary" && mode != "multiclass"
        throw(ArgumentError("mode must be \"binary\" or \"multiclass\", got \"$mode\""))
    end

    args = Dict{String,Any}()
    args["mode"] = mode;     #binary or multiclass
    args["seed"] = seed          #random number seed: use a nonnegative int for repeatable results
    args["batchsize"] = 50      #minibatch size
    args["epochs"] = numepochs          #number of epochs for training
    args["hidden"] = [128, 256, 128, 64]    #sizes of hidden layers
    args["lr"] = 0.1             #learning rate
    args["winit"] = 0.1          #w initialized with winit*randn()
    args["fast"] = false         #skip per-epoch accuracy reports for a faster run
    args["atype"] = arraytype == "Array" ? "Array{Float32}" : "KnetArray{Float32}"
    args["gcheck"] = 0           #check N random gradients per parameter

    if !args["fast"]
        println("opts=",[(k,v) for (k,v) in args]...)
    end
    # `>= 0` so that a seed of 0 also counts as "nonnegative", matching the comment above.
    args["seed"] >= 0 && srand(args["seed"])
    # Pick the array type directly instead of eval(parse(...)) on a string.
    atype = arraytype == "Array" ? Array{Float32} : KnetArray{Float32}
    w = weights(args["hidden"]...; atype=atype, winit=args["winit"], mode=args["mode"])

    xtst_mc, ytst_mc, xtrn_mc, ytrn_mc, xtst_bin, ytst_bin, xtrn_bin, ytrn_bin =
        prep_dataset(subjects, 360)

    if args["mode"] == "multiclass"
        println("Multi-class classification")
        dtrn = minibatch(xtrn_mc, ytrn_mc, args["batchsize"])
        dtst = minibatch(xtst_mc, ytst_mc, args["batchsize"])
    else
        println("Binary classification")
        dtrn = minibatch(xtrn_bin, ytrn_bin, args["batchsize"])
        dtst = minibatch(xtst_bin, ytst_bin, args["batchsize"])
    end
    report(epoch)=println((:epoch,epoch,:trn,accuracy(w,dtrn,predict),:tst,accuracy(w,dtst,predict)))
    if args["fast"]
        (train(w, dtrn; lr=args["lr"], epochs=args["epochs"]); gpu()>=0 && Knet.cudaDeviceSynchronize())
    else
        report(0)
        @time for epoch=1:args["epochs"]
            # `train` updates `w` in place, so `report` sees the new weights.
            train(w, dtrn; lr=args["lr"], epochs=1)
            report(epoch)
            if args["gcheck"] > 0
                gradcheck(loss, w, first(dtrn)...; gcheck=args["gcheck"], verbose=true)
            end
        end
    end
    return w, dtrn, dtst
end

main (generic function with 1 method)

**Multi-class classification**

In [7]:
# Train the 8-output (multiclass) model for 50 epochs with KnetArray (GPU) weights.
w, dtrn, dtst = main(50,"multiclass","KnetArray")

opts=("epochs", 50)("hidden", [128, 256, 128, 64])("lr", 0.1)("atype", "KnetArray{Float32}")("winit", 0.1)("mode", "multiclass")("fast", false)("seed", -1)("gcheck", 0)("batchsize", 50)
Dict(7=>4332,4=>1380,2=>3096,3=>3730,5=>4332,8=>1404,6=>1770,1=>4348)
Dict(2=>24392,1=>19264)
Multi-class classification
(:epoch, 0, :trn, 0.1784510250569476, :tst, 0.17083333333333334)
(:epoch, 1, :trn, 0.8286560364464692, :tst, 0.8270833333333333)
(:epoch, 2, :trn, 0.8806833712984055, :tst, 0.875)
(:epoch, 3, :trn, 0.9011845102505694, :tst, 0.89)
(:epoch, 4, :trn, 0.9140318906605922, :tst, 0.8983333333333333)
(:epoch, 5, :trn, 0.9210478359908884, :tst, 0.9075)
(:epoch, 6, :trn, 0.9227334851936219, :tst, 0.90625)
(:epoch, 7, :trn, 0.9345785876993167, :tst, 0.9145833333333333)
(:epoch, 8, :trn, 0.9402733485193622, :tst, 0.9170833333333334)
(:epoch, 9, :trn, 0.9486560364464692, :tst, 0.9291666666666667)
(:epoch, 10, :trn, 0.9462414578587699, :tst, 0.9270833333333334)
(:epoch, 11, :trn, 0.9537129840546698

(Any[Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206a00000, 369152, 0, nothing), (128, 721)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e00000, 512, 0, nothing), (128, 1)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010207000000, 131072, 0, nothing), (256, 128)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e00200, 1024, 0, nothing), (256, 1)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010207020000, 131072, 0, nothing), (128, 256)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e00600, 512, 0, nothing), (128, 1)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010207200000, 32768, 0, nothing), (64, 128)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e00800, 256, 0, nothing), (64, 1)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010207400000, 2048, 0, nothing), (8, 64)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e00a00, 32, 0, nothing),

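Woohoo! The accuracy is over 98.5% on the test set after just 50 epochs! For comparison, always guessing the most common class would give only about 18% (4,348 of the 24,392 abnormal beats).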

**Binary classification**

In [8]:
# Train the 2-output (binary) model for 50 epochs with KnetArray (GPU) weights;
# this overwrites w, dtrn and dtst from the multiclass run.
w, dtrn, dtst = main(50,"binary","KnetArray")

opts=("epochs", 50)("hidden", [128, 256, 128, 64])("lr", 0.1)("atype", "KnetArray{Float32}")("winit", 0.1)("mode", "binary")("fast", false)("seed", -1)("gcheck", 0)("batchsize", 50)
Dict(7=>4332,4=>1380,2=>3096,3=>3730,5=>4332,8=>1404,6=>1770,1=>4348)
Dict(2=>24392,1=>19264)
Binary classification
(:epoch, 0, :trn, 0.44028025477707006, :tst, 0.45011494252873563)
(:epoch, 1, :trn, 0.8762038216560509, :tst, 0.8726436781609196)
(:epoch, 2, :trn, 0.9302420382165605, :tst, 0.9303448275862068)
(:epoch, 3, :trn, 0.950624203821656, :tst, 0.9478160919540229)
(:epoch, 4, :trn, 0.9539617834394905, :tst, 0.9510344827586207)
(:epoch, 5, :trn, 0.957171974522293, :tst, 0.9533333333333334)
(:epoch, 6, :trn, 0.9614012738853503, :tst, 0.9537931034482758)
(:epoch, 7, :trn, 0.9671337579617835, :tst, 0.9629885057471265)
(:epoch, 8, :trn, 0.9683566878980892, :tst, 0.9634482758620689)
(:epoch, 9, :trn, 0.9693248407643312, :tst, 0.964367816091954)
(:epoch, 10, :trn, 0.9621146496815287, :tst, 0.953103448275862)

(Any[Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206b0e600, 369152, 0, nothing), (128, 721)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e91000, 512, 0, nothing), (128, 1)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010210880000, 131072, 0, nothing), (256, 128)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e8fc00, 1024, 0, nothing), (256, 1)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010210820000, 131072, 0, nothing), (128, 256)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e90000, 512, 0, nothing), (128, 1)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010210e0c800, 32768, 0, nothing), (64, 128)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e92400, 256, 0, nothing), (64, 1)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206e5aa00, 512, 0, nothing), (2, 64)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000010206ea7800, 8, 0, nothing), (

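Here as well, the test accuracy is over 97% after 50 epochs. For comparison, always predicting the larger class would give roughly 56% (24,392 of 43,656 beats). Both results are highly successful and promising.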