In [None]:
# This notebook contains code shown during the presentation.

using Images

#=
Constants that describe how the a-b plane of the Lab color space is discretized.
The a and b ranges were taken from
https://pdfs.semanticscholar.org/cfe3/bcb885d922d169e156796266af41d1865fff.pdf
=#
const A_MIN, A_MAX = -166, 141   # lower / upper limit of the a axis
const B_MIN, B_MAX = -132, 147   # lower / upper limit of the b axis
const A_RANGE = A_MAX - A_MIN    # extent of the a axis
const B_RANGE = B_MAX - B_MIN    # extent of the b axis
const NUM_BINS = 18              # bins per axis; the ab plane has NUM_BINS * NUM_BINS bins in total

# Build a NUM_BINS x NUM_BINS grid of Lab colors at one fixed lightness; rows sweep the a axis
# and columns sweep the b axis, both evenly spaced from the axis minimum to its maximum.
L = 50
[Lab{Float32}(L, a, b) for a in linspace(A_MIN, A_MAX, NUM_BINS), b in linspace(B_MIN, B_MAX, NUM_BINS)]

In [None]:
# functions for converting between Lab color space and bin number
"""
    lab2bin(color::Lab)::Int32

Return the 1-based number of the ab bin that `color` falls into. Only `color.a` and `color.b`
are read; the lightness is ignored.

Each axis is split into `NUM_BINS` equal-width bins numbered `0 … NUM_BINS - 1`, and the result
is `NUM_BINS * a_bin + b_bin + 1`. Components outside `[A_MIN, A_MAX]` / `[B_MIN, B_MAX]` are
assigned to the nearest edge bin, so the result is always within `1 … NUM_BINS * NUM_BINS`.
"""
function lab2bin(color::Lab)::Int32
    # Clamp on both sides: the upper clamp covers color.a == A_MAX (which would otherwise land
    # in bin NUM_BINS), the lower clamp keeps values below A_MIN from producing a negative bin.
    a_bin = clamp(floor((color.a - A_MIN) / A_RANGE * NUM_BINS), 0, NUM_BINS - 1)
    b_bin = clamp(floor((color.b - B_MIN) / B_RANGE * NUM_BINS), 0, NUM_BINS - 1)
    # a_bin / b_bin are whole-valued floats; the declared return type converts the sum to Int32
    return NUM_BINS * a_bin + b_bin + 1
end

"""
    bin2lab(bin::Integer, color::Lab)::Lab

Inverse of `lab2bin`: turn the 1-based bin number `bin` back into a Lab color.

The returned color keeps the lightness `color.l` of `color`; its a and b components are set to
the lower edge of the bin along each axis.
"""
function bin2lab(bin::Integer, color::Lab)::Lab
    # lab2bin encodes the bin as NUM_BINS * a_bin + b_bin + 1, so the a index is the integer
    # quotient and the b index the remainder (float division would give a fractional a index)
    a_index, b_index = divrem(bin - 1, NUM_BINS)
    a_val = A_MIN + A_RANGE * (a_index / NUM_BINS)
    b_val = B_MIN + B_RANGE * (b_index / NUM_BINS)
    return Lab(color.l, a_val, b_val)
end

In [None]:
img = load("data/images/train/a/abbey/00000001.jpg")
info("img is of size ", size(img), " and type ", typeof(img))

# per-pixel conversion of the loaded image to the Lab color space
ground_truth = convert.(Lab, img)

# quantize every pixel to its ab bin number and map it straight back to a Lab color,
# which yields the discretized version of the image (lightness is taken from ground_truth)
lab = bin2lab.(lab2bin.(ground_truth), ground_truth)

In [None]:
#=
# code for counting the number of occurrences of each color bin

using StatsBase

imgs_trn_lab = [convert.(Lab, x) for x in imgs_trn]
imgs_trn_bin = [lab2bin.(x) for x in imgs_trn_lab]
imgs_trn_bin_counts = countmap.(imgs_trn_bin)

function merge_bin_counts(a,b)
    return merge(+, a, b)
end

bin_counts = reduce(merge_bin_counts, imgs_trn_bin_counts)
save("bin_counts.jld", "bin_counts", bin_counts)
=#

In [None]:
"""
    weights(; atype=KnetArray{Float32})

Create the initial parameters of the network as a flat list
`[filter_1, bias_1, filter_2, bias_2, …]`, one filter/bias pair per layer in forward order.

Filters are initialized with `xavier` and have size `(k, k, in, out)`; biases are drawn with
`randn` and have size `(1, 1, out, 1)`. The last layer has one output channel per ab bin
(`NUM_BINS * NUM_BINS`). Every array is converted to `atype` before it is returned.
"""
function weights(;atype=KnetArray{Float32})
    # (kernel size, input channels, output channels) of every layer, in forward order
    layer_specs = [
        (3, 1, 64), (3, 64, 64),                        # Conv 1.1 - 1.2
        (3, 64, 128), (3, 128, 128),                    # Conv 2.1 - 2.2
        (3, 128, 256), (3, 256, 256), (3, 256, 256),    # Conv 3.1 - 3.3
        (3, 256, 512), (3, 512, 512), (3, 512, 512),    # Conv 4.1 - 4.3
        (3, 512, 512), (3, 512, 512), (3, 512, 512),    # Conv 5.1 - 5.3
        (3, 512, 512), (3, 512, 512), (3, 512, 512),    # Conv 6.1 - 6.3
        (3, 512, 256), (3, 256, 256), (3, 256, 256),    # Conv 7.1 - 7.3
        (4, 256, 256),                                  # deconv layer
        (3, 256, 128), (3, 128, 128), (3, 128, 128),    # Conv 8.1 - 8.3
        (4, 128, 128),                                  # deconv layer
        (1, 128, NUM_BINS * NUM_BINS),                  # 1x1 conv layer, one channel per ab bin
    ]

    w = Any[]
    # filter first, then bias, so the random draws happen in the same order as the list layout
    for (ksize, nin, nout) in layer_specs
        push!(w, xavier(Float32, ksize, ksize, nin, nout))
        push!(w, randn(Float32, 1, 1, nout, 1))
    end
    return map(a->convert(atype,a), w)
end

In [None]:
"""
    predict(w, x)

Run the forward pass of the network on `x` using the parameter list `w`, which is read as
consecutive filter/bias pairs `w[1], w[2], …, w[49], w[50]`.

The pass consists of nineteen padded convolutions with ReLU (Conv 1.1 – 7.3, three of them with
stride 2), a stride-2 deconvolution, three more padded convolutions (Conv 8.1 – 8.3), a stride-4
deconvolution, and a final convolution followed by an element-wise sigmoid.
"""
function predict(w,x)
    # padded convolution + bias + ReLU; `k` is the position of the filter in `w`, bias is w[k+1]
    conv_relu(k, input, stride) = relu.(conv4(w[k],input; padding=1,stride=stride) .+ w[k+1])

    # strides of Conv 1.1 … Conv 7.3; the stride-2 layers are Conv 1.2, 2.2 and 3.3
    encoder_strides = (1, 2,  1, 2,  1, 1, 2,  1, 1, 1,  1, 1, 1,  1, 1, 1,  1, 1, 1)
    for (layer, stride) in enumerate(encoder_strides)
        x = conv_relu(2layer - 1, x, stride)   # filters sit at the odd positions 1, 3, …, 37
    end

    x = relu.(deconv4(w[39],x; padding=1, stride=2) .+ w[40]) # upsample by 2

    for k in 41:2:45                           # Conv 8.1 – 8.3
        x = conv_relu(k, x, 1)
    end

    x = relu.(deconv4(w[47],x; stride=4) .+ w[48]) # upsample by 4

    # 1x1 conv layer to get probability distribution
    # NOTE(review): sigm is applied element-wise, so the values along the channel dimension are
    # not normalized to sum to 1 — confirm that this is intended.
    return sigm.(conv4(w[49],x) .+ w[50])
end

In [None]:
"""
    select_prob(p, truth::Array{Int32}, i, j, k)

Look up the entry of `p` whose index is the value stored at `truth[i, j, 1, k]`.
"""
function select_prob(p, truth::Array{Int32}, i, j, k)
    target_bin = truth[i, j, 1, k]
    return p[target_bin]
end

# Lookup table from color bin number to its loss weight. It is filled with the negative log
# of the frequency of each color (the code that does so is not shown here).
color_weights = Dict{Int32,Float32}()

# Return the weight stored for color bin `c`; indexing the Dict throws a KeyError if `c` is absent.
color_weight(c::Int32) = color_weights[c]

"""
    crb_loss(x, truth)

Weighted negative log-likelihood of the true color bins.

For every position `(i, j, k)` the value of `x[i, j, :, k]` at the index given by
`truth[i, j, 1, k]` is selected; its logarithm is multiplied by the `color_weight` of the
corresponding entry of `truth`, and the negated sum over all positions is returned.
"""
function crb_loss(x, truth)
    height, width, batch = size(x,1), size(x,2), size(x,4)
    # value predicted for the true bin at every position
    picked = [select_prob(x[i,j,:,k], truth, i, j, k) for i=1:height, j=1:width, k=1:batch]
    # re-insert the singleton third dimension so that the shape lines up with `truth`
    picked = reshape(picked, (height, width, 1, batch))
    pixel_weights = map(color_weight, truth)
    return -sum(pixel_weights .* log.(picked))
end