In [None]:
# This notebook loads the trained model weights and checks the network's
# output on a single validation image.

using JLD, Images, Knet, ImageView

# Restore the trained checkpoint saved with JLD.
# `weights` is read as a global by `get_color_distributions`; `epochs` is only
# reported here (presumably the epoch the checkpoint was saved at — TODO confirm).
d = load("net.jld")
weights, epochs = d["weights"], d["epochs"]

info("loading weights of epoch ", epochs)

In [None]:
# Side length, in pixels, of the square images fed to the network.
const IMG_DIM = 128

# Extent of the a* and b* axes covered by the color bins
# (presumably the range observed in the training data — TODO confirm).
const A_MIN = -166
const A_MAX = 141
const B_MIN = -132
const B_MAX = 147
const A_RANGE = A_MAX - A_MIN
const B_RANGE = B_MAX - B_MIN

# Each of the a*/b* axes is split into NUM_BINS intervals, so the network
# predicts over NUM_BINS^2 color bins.
const NUM_BINS = 18;

# Extract the lightness (L*) component of a Lab color, converted to Float32,
# which is the element type of the single-channel input built for the network.
lab2l(color::Lab)::Float32 = color.l

"""
    bin2lab(bin, color)

Map a 1-based quantized color-bin index back to a `Lab` color, keeping the
lightness of `color`.

Bins enumerate a `NUM_BINS × NUM_BINS` grid over the (a*, b*) plane: the
zero-based index is split into a row (a*) and a column (b*) with `divrem`, and
each is mapped to the lower edge of its interval in `[A_MIN, A_MAX)` /
`[B_MIN, B_MAX)`.

Throws `ArgumentError` if `bin` is outside `1:NUM_BINS^2`.
"""
function bin2lab(bin::Integer, color::Lab)::Lab
    1 <= bin <= NUM_BINS^2 ||
        throw(ArgumentError("bin must be in 1:$(NUM_BINS^2), got $bin"))
    # Integer split of the zero-based index. Plain `/` would yield a fractional
    # row, making a* vary continuously instead of snapping to a bin.
    a_index, b_index = divrem(bin - 1, NUM_BINS)
    # NOTE(review): this decodes to the lower edge of each bin; if the training
    # encoder used bin centers, add 0.5 to each index — confirm.
    a_val = A_MIN + A_RANGE * (a_index / NUM_BINS)
    b_val = B_MIN + B_RANGE * (b_index / NUM_BINS)
    return Lab(color.l, a_val, b_val)
end

"""
    get_color_distributions(x_lab, w=weights)

Run the network on the lightness channel of the Lab image `x_lab` and return
the raw output of `predict`.

The lightness values are reshaped into a single-image, single-channel batch of
size `(IMG_DIM, IMG_DIM, 1, 1)`, so `x_lab` must hold exactly `IMG_DIM^2`
pixels (`reshape` throws `DimensionMismatch` otherwise). `w` defaults to the
global `weights` loaded from the checkpoint. Passing it explicitly avoids
depending on that untyped global.
"""
function get_color_distributions(x_lab, w=weights)
    lightness = lab2l.(x_lab)  # Float32 per pixel
    batch = reshape(lightness, (IMG_DIM, IMG_DIM, 1, 1))
    # NOTE(review): if `w` holds GPU KnetArrays, `batch` probably has to be
    # converted to a KnetArray as well — confirm against how `net.jld` was saved.
    return predict(w, batch)
end

"""
    predict(w, x)

Forward pass of the colorization network.

`w` is a flat parameter list in which every layer stores its kernel at an odd
index `i` and its bias at `i + 1` (50 entries in total). `x` is a 4-D input
batch as built by `get_color_distributions`. The encoder applies three stride-2
downsampling convolutions. Two transposed convolutions then upsample by 2 and 4,
and a final 1x1 convolution produces one sigmoid score per color bin.
"""
function predict(w, x)
    # Padded conv + bias + ReLU with kernel w[i] and bias w[i + 1]; extra
    # keywords (e.g. `stride`) are forwarded to conv4.
    convrelu(h, i; opts...) = relu.(conv4(w[i], h; padding=1, opts...) .+ w[i + 1])

    h = convrelu(x, 1)             # Conv 1.1
    h = convrelu(h, 3; stride=2)   # Conv 1.2 (downsample)
    h = convrelu(h, 5)             # Conv 2.1
    h = convrelu(h, 7; stride=2)   # Conv 2.2 (downsample)
    h = convrelu(h, 9)             # Conv 3.1
    h = convrelu(h, 11)            # Conv 3.2
    h = convrelu(h, 13; stride=2)  # Conv 3.3 (downsample)

    # Conv 4.1 through Conv 7.3: twelve stride-1 layers, kernels at 15, 17, ..., 37.
    for i in 15:2:37
        h = convrelu(h, i)
    end

    h = relu.(deconv4(w[39], h; padding=1, stride=2) .+ w[40])  # upsample by 2

    # Conv 8.1 through Conv 8.3.
    for i in 41:2:45
        h = convrelu(h, i)
    end

    h = relu.(deconv4(w[47], h; stride=4) .+ w[48])  # upsample by 4

    # 1x1 conv; the sigmoid squashes each bin score into (0, 1) independently
    # (scores are not normalized across bins, but argmax is unaffected).
    return sigm.(conv4(w[49], h) .+ w[50])
end

In [None]:
# Load one validation image and feed it through the network.
# NOTE(review): the image is not resized here, but `get_color_distributions`
# requires IMG_DIM x IMG_DIM pixels — confirm the validation images match.
img = load("data/images/val/00000001.jpg")
x_lab = convert.(Lab, img)
# `Array` copies the scores to host memory (a no-op copy for CPU arrays) so
# the per-pixel scalar indexing below is valid even for GPU outputs.
probs = Array(get_color_distributions(x_lab))  # dry run: randn(IMG_DIM, IMG_DIM, 324, 1)

# Take the highest-scoring color bin at every pixel of every batch image.
# `view` avoids copying the bin-score vector at each pixel.
x = [indmax(view(probs, i, j, :, k))
     for i = 1:size(probs, 1), j = 1:size(probs, 2), k = 1:size(probs, 4)]

# Decode bins back to Lab (keeping each pixel's original lightness) and
# convert to RGB4 for display.
predicted = bin2lab.(x, x_lab)
imshow(convert.(RGB4, predicted))