2  Chapter 4: Under the Hood: Training a Digit Classifier

Download the MNIST dataset and separate it into training, validation, and test sets:

threes = sort(readdir(joinpath(data_path, "train", "3")))
sevens = sort(readdir(joinpath(data_path, "train", "7")))

threes
6131-element Vector{String}:
 "10001.png"
 "10012.png"
 "10032.png"
 "10035.png"
 "10043.png"
 "10053.png"
 "10075.png"
 "1008.png"
 "10092.png"
 "10094.png"
 "10098.png"
 "10100.png"
 "10117.png"
 ⋮
 "9905.png"
 "9914.png"
 "9916.png"
 "9917.png"
 "993.png"
 "9933.png"
 "9954.png"
 "9960.png"
 "9975.png"
 "9978.png"
 "999.png"
 "9992.png"
im3_path = joinpath(data_path, "train", "3", threes[1])
im3 = load(joinpath(im3_path))
im3_array = convert(Array{Int}, im3 * 255)
im3_array[4:10, 4:10]
7×7 Matrix{Int64}:
 0  0    0    0    0    0    0
 0  0    0    0    0    0    0
 0  0    0    0    0    0   29
 0  0    0    0   48  166  224
 0  0   93  244  249  253  187
 0  0  107  253  253  230   48
 0  0    3   20   20   15    0
seven_tensors = [load(joinpath(data_path, "train", "7", seven)) for seven in sevens]
three_tensors = [load(joinpath(data_path, "train", "3", three)) for three in threes]
size(seven_tensors), size(three_tensors)
((6265,), (6131,))
three_tensors[1]
size(three_tensors[1])
size(three_tensors)
(6131,)
stacked_sevens = MLUtils.stack(seven_tensors)
stacked_threes = MLUtils.stack(three_tensors)
size(stacked_sevens), size(stacked_threes)
((28, 28, 6265), (28, 28, 6131))

Alternatively, one can create stacked_sevens and stacked_threes directly from MLDatasets:

dataset = MLDatasets.MNIST(:train)

stacked_sevens = dataset.features[:, :, dataset.targets.==7]
stacked_threes = dataset.features[:, :, dataset.targets.==3]

size(stacked_sevens), size(stacked_threes)
((28, 28, 6265), (28, 28, 6131))
# length(size(stacked_threes))
ndims(stacked_threes)
3
### need to transpose the dimensions

stacked_sevens = permutedims(stacked_sevens, [2, 1, 3])
stacked_threes = permutedims(stacked_threes, [2, 1, 3])

convert(Array{Gray}, hcat(stacked_sevens[:, :, 1], stacked_threes[:, :, 1]))
mean3 = mean(stacked_threes, dims=3)
mean3 = mean3[:, :, 1]

convert(Array{Gray}, mean3)
mean7 = mean(stacked_sevens, dims=3)
mean7 = mean7[:, :, 1]

convert(Array{Gray}, mean7)
a_3 = stacked_threes[:, :, 1]
dist_3_abs = mean(abs.(a_3 .- mean3))
dist_3_sqr = sqrt(mean((a_3 .- mean3) .^ 2))
dist_3_abs, dist_3_sqr
(0.12612005f0, 0.23508778f0)
Flux.Losses.mae(a_3, mean3), sqrt(Flux.Losses.mse(a_3, mean3))
(0.12612005f0, 0.23508778f0)

2.1 Computing Metrics Using Broadcasting

Broadcasting rules in Julia differ from those in Python. In Python, dimensions are first aligned to the right and then broadcast; in Julia, dimensions are first aligned to the left and then broadcast. So in Python [1000, 28, 28] - [28, 28] is allowed, whereas in Julia we need [28, 28, 1000] - [28, 28]. Use permutedims to change the order of dimensions.

valid_threes = sort(readdir(joinpath(data_path, "test", "3")))
valid_3_tens = MLUtils.stack([load(joinpath(data_path, "test", "3", img)) for img in valid_threes])

valid_sevens = sort(readdir(joinpath(data_path, "test", "7")))
valid_7_tens = MLUtils.stack([load(joinpath(data_path, "test", "7", img)) for img in valid_sevens])

# valid_3_tens = permutedims(valid_3_tens, [3, 1, 2])

size(valid_3_tens), size(valid_7_tens)
((28, 28, 1010), (28, 28, 1028))
"""
    mnist_distance(a, b)

Return the mean absolute difference between `a` and `b`, averaged over the first
two dimensions.

The difference `a .- b` is broadcast, its absolute values are converted to
`Float32`, and the first two dimensions are reduced with `mean` and then dropped.
"""
function mnist_distance(a, b)
    abs_diff = Float32.(abs.(a .- b))
    per_image = mean(abs_diff; dims=(1, 2))
    return dropdims(per_image; dims=(1, 2))
end

mnist_distance(a_3, mean3)[1]
0.12612005f0
valid_3_dist = mnist_distance(valid_3_tens, mean3)
(size(valid_3_dist), valid_3_dist)
((1010,), Float32[0.12803426, 0.16230617, 0.12421783, 0.14686641, 0.120200165, 0.118782386, 0.16995831, 0.12664512, 0.13367367, 0.11829293  …  0.12947088, 0.1525875, 0.1556588, 0.13255748, 0.13654083, 0.17447978, 0.12789737, 0.15078008, 0.12630658, 0.12596397])
size(valid_3_tens .- mean3)
(28, 28, 1010)
"""
    is_3(x)

Return, elementwise, whether `x` is closer (by `mnist_distance`) to the global
`mean3` than to the global `mean7`.
"""
function is_3(x)
    dist_to_3 = mnist_distance(x, mean3)
    dist_to_7 = mnist_distance(x, mean7)
    return dist_to_3 .< dist_to_7
end
is_3 (generic function with 1 method)
is_3(a_3)
is_3(valid_3_tens[:, :, 1:10])

is_3(valid_7_tens[:, :, 1:10])
10-element BitVector:
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
accuracy_3s = mean(is_3(valid_3_tens))
accuracy_7s = mean(1 .- is_3(valid_7_tens))

accuracy_3s, accuracy_7s
(0.9168316831683169, 0.9854085603112841)

2.2 Calculating Gradients

2.2.1 Using Flux.jl

Taking gradients in Flux.jl is as simple as calling gradient on a function. For example, to take the gradient of f(x) = x^2 at x = 2, we can do the following:

f(x) = x^2
df(x) = gradient(f, x)[1]
df(2)
4.0

Below we implement and visualise gradient descent from scratch in Julia.

# Animate gradient descent on `f` for several learning rates, one subplot per rate.

# Base plot: `f` drawn over [-xmax, xmax] using `n` sample points.
xmax = 10
n = 100
plt = plot(
    range(-xmax, xmax, length=n), f;
    label="f(x)", lw=5, xlim=1.5 .* [-xmax, xmax],
    xlab="Parameter", ylab="Loss", legend=false
)

nsteps = 10
lrs = [0.05, 0.3, 0.975, 1.025]
# One gradient-descent update: move `x` against the derivative `df(x)`, scaled by `lr`.
descend(x; lr=0.1) = x - lr * df(x)
# Every learning rate starts from the same point, -0.75 * xmax.
x = [-0.75xmax]
x = repeat(x, length(lrs), 1)                             # repeat x for each learning rate
plts = [deepcopy(plt) for i in 1:length(lrs)]           # repeat plt for each learning rate
anim = @animate for j in 1:nsteps
    # Row i of `x` is the trajectory for `lrs[i]`; column j+1 receives this step's result.
    global x = hcat(x, zeros(size(x, 1)))                # add column of zeros to x
    for (i, lr) in enumerate(lrs)
        # NOTE(review): `scatter!`/`quiver!` below take no plot argument; presumably they
        # draw onto the plot just created here — confirm against Plots.jl.
        _plt = plot(plts[i], title="lr = $lr", ylims=(0, f(xmax)), legend=false)
        scatter!([x[i, j]], [f(x[i, j])]; label=nothing, ms=5, c=:red)    # plot current point
        x[i, j+1] = descend(x[i, j]; lr=lr)                               # descend
        # Change in parameter and in loss caused by this step.
        Δx = x[i, j+1] - x[i, j]
        Δy = f(x[i, j+1]) - f(x[i, j])
        quiver!([x[i, j]], [f(x[i, j])], quiver=([Δx], [0]), c=:red)          # horizontal arrow
        quiver!([x[i, j+1]], [f(x[i, j])], quiver=([0], [Δy]), c=:red)        # vertical arrow
        # Store the updated subplot so the next frame builds on it.
        plts[i] = _plt
    end
    # The frame for step j: all per-learning-rate subplots combined.
    plot(
        plts..., legend=false,
        plot_title="Step $j", margin=5mm,
        dpi=300,
    )
end
# Write the animation to `c4_gd.gif` under `www_path` at 0.5 frames per second.
gif(anim, joinpath(www_path, "c4_gd.gif"), fps=0.5)

Figure 2.1: Gradient descent for different learning rates

2.3 An End-to-End SGD Example

## is time a good variable name?
time = collect(range(start=0, stop=19))

speed = @. $rand(20) + 0.75 * (time - 9.5)^2 + 1

scatter(time, speed, legend=false, xlabel="time", ylabel="speed")
"""
    f(t, params)

Evaluate the quadratic `a * (t - b)^2 + c` elementwise over `t`, where `a`, `b`
and `c` are the first three elements of `params`.
"""
function f(t, params)
    scale, shift, offset = params
    return scale .* (t .- shift) .^ 2 .+ offset
end

"""
    mse(preds, targets)

Return the sum of squared elementwise differences between `preds` and `targets`,
divided by `length(preds)`.
"""
function mse(preds, targets)
    residuals = preds .- targets
    squared_error = sum(residuals .^ 2)
    return squared_error / length(preds)
end
mse (generic function with 1 method)
# Scatter-plot the global `speed` against the global `time`, then overlay `preds`
# at the same `time` values in red. The value of the `scatter!` call is returned.
# NOTE(review): `scatter!` takes no plot argument here, so it presumably adds to the
# plot just created by `scatter` — confirm against Plots.jl.
function show_preds(preds)
    scatter(time, speed)
    scatter!(time, preds, color="red")
end
show_preds (generic function with 1 method)
params = rand(3)
preds = f(time, params)

show_preds(preds)
loss = mse(preds, speed)
8296.82143644548
dloss(params) = gradient(params -> mse(f(time, params), speed), params)

grad = dloss(params)[1]

lr = 1e-5
params = params .- lr .* grad

preds = f(time, params)
mse(preds, speed)

show_preds(preds)
## params will be updated in place
"""
    apply_step!(params; lr=1e-5, prn=true)

Take one gradient-descent step on `params`, updating it in place, and return the
predictions `f(time, params)` computed with the updated parameters.

The gradient is `dloss(params)[1]`. When `prn` is true, the loss after the update,
the gradient that was applied, and the updated `params` are printed, in that order.
"""
function apply_step!(params; lr=1e-5, prn=true)
    step_grad = first(dloss(params))
    params .-= lr .* step_grad  # in-place update
    new_preds = f(time, params)
    if prn
        println(mse(new_preds, speed))
        println(step_grad)
        println(params)
    end
    return new_preds
end
apply_step! (generic function with 1 method)
params = rand(3)
plts = []

for i in range(1, 4)
    push!(plts, show_preds(apply_step!(params; lr=0.0001, prn=false)))
end

plot(
    plts..., legend=false,
    plot_title="First four steps", margin=5mm,
    dpi=300,
)
# Start again from random parameters and keep a plot of the initial fit.
params = rand(3)
preds = f(time, params)

plts = []
push!(plts, show_preds(preds))

# NOTE(review): `lr` is assigned here but is not passed to `apply_step!` below, so
# those calls run with the function's default `lr=1e-5`, not with 0.0001.
lr = 0.0001  ## how to adjust learning rate? takes a lot of time to learn
# `range(0, 60000)` yields 60001 iterations; the printing step below is one more.
for i in range(0, 60000)
    apply_step!(params, prn=false)
end

# One further step with printing enabled, then plot the final fit.
preds = apply_step!(params, prn=true);
push!(plts, show_preds(preds))

# Show the initial and the final fit together.
plot(
    plts..., legend=false,
    plot_title="After 60000 steps", margin=5mm,
    dpi=300,
)
10.90886132109632
[-0.0735022149927751, -0.0003901579811111944, 4.383430999602089]
[0.6672579211758306, 9.493901611682551, 6.352134871734056]

2.4 The MNIST Loss Function

train_x = cat(stacked_threes, stacked_sevens, dims=3) |> x -> reshape(x, 28 * 28, :) |> transpose;
train_y = vcat(repeat([1], size(stacked_threes)[3]), repeat([0], size(stacked_sevens)[3]));

size(train_x), size(train_y)
((12396, 784), (12396,))
dset = [(train_x[i, :], train_y[i]) for i in range(1, size(train_x)[1])]
x, y = dset[1]
size(dset), size(x), y
((12396,), (784,), 1)
valid_x = cat(valid_3_tens, valid_7_tens, dims=3) |> x -> reshape(x, 28 * 28, :) |> transpose;
valid_y = vcat(repeat([1], size(valid_3_tens)[3]), repeat([0], size(valid_7_tens)[3]));
valid_dset = zip(eachrow(valid_x), valid_y);

size(valid_x), size(valid_y), size(valid_dset)
((2038, 784), (2038,), (2038,))
"""
    init_params(size; std=1.0)

Return an array of standard-normal samples with dimensions `size`, scaled by `std`.
"""
function init_params(size; std=1.0)
    samples = randn(size)
    return samples * std
end

weights = init_params((28 * 28, 1))

bias = init_params(1)

size(weights), size(bias)
((784, 1), (1,))
train_x = convert(Array{Float32}, train_x)

train_x[1:1, :] * weights .+ bias
1×1 Matrix{Float64}:
 -3.501327342352872

A PyTorch tensor carries a flag (requires_grad) that indicates whether gradients should be computed for it. No such flag is needed in Flux: to get a gradient, simply call the gradient function.

gradient(weights -> sum(train_x[1:1, :] * weights), weights)
([0.0; 0.0; … ; 0.0; 0.0;;],)
linear1(xb) = xb * weights .+ bias
preds = linear1(train_x)
12396×1 Matrix{Float64}:
  -3.501327342352871
  -7.803989646297762
   0.6762347534331052
  -3.023752660788845
  -9.516878292448476
  -7.033978158406221
  -4.121724962898755
   7.818492504419868
  -2.4782311140758995
   3.779727107749914
  -3.408244480726173
  -7.060928958728613
  -2.1788888600992573
   ⋮
   9.819611612539923
   0.8538666355803137
   2.3249911800591567
  -3.8884379846078874
  -4.1295855377706445
  -0.14468002821284043
   1.5411110361966878
  -6.228736530914673
  -2.1027676354594456
 -11.00265384543037
  -0.504937167730418
   4.139620206286243
corrects = (preds .> 0.0) .=== Bool.(train_y)

mean(corrects)
0.6069699903194579
weights[1] *= 1.0001

preds = linear1(train_x)
mean((preds .> 0.0) .== Bool.(train_y))
0.6069699903194579
trgts = [1, 0, 1]
prds = [0.9, 0.4, 0.2]

"""
    mnist_loss(predictions, targets)

Return the mean distance between each prediction and its target label.

For a target equal to 1 the per-item loss is `1 - p`; for any other target it is
`p`. Targets are compared with `==`, so `1`, `1.0`, `true` and `Int32(1)` all
count as the positive class (with `===` only an `Int` 1 would match, and every
other encoding would silently be scored as the negative class).
"""
function mnist_loss(predictions, targets)
    return mean(t == 1 ? 1 - p : p for (p, t) in zip(predictions, targets))
end

mnist_loss(prds, trgts), mnist_loss([0.9, 0.4, 0.8], trgts)
(0.43333333333333335, 0.2333333333333333)
"""
    sigmoid(x)

Return the logistic function of `x`, i.e. `1 / (1 + exp(-x))`.
"""
function sigmoid(x)
    denominator = 1 + exp(-x)
    return inv(denominator)
end

print(sigmoid.(rand(10)))

plot(range(-5, 5, length=100), sigmoid)
[0.6410273776045075, 0.6105076166250384, 0.5783704340117538, 0.6426148235103056, 0.6852254892841372, 0.5984580626393032, 0.6717542924899141, 0.517005739780998, 0.5656790579343766, 0.6485622582089181]
"""
    mnist_loss(predictions, targets)

Pass `predictions` through `sigmoid` and return the mean distance to `targets`.

For a target equal to 1 the per-item loss is `1 - p`; for any other target it is
`p`, where `p` is the sigmoid of the prediction. Targets are compared with `==`,
so `1`, `1.0`, `true` and `Int32(1)` all count as the positive class (with `===`
only an `Int` 1 would match, and every other encoding would silently be scored as
the negative class).
"""
function mnist_loss(predictions, targets)
    probs = sigmoid.(predictions)
    return mean([t == 1 ? 1 - p : p for (p, t) in zip(probs, targets)])
end
mnist_loss (generic function with 1 method)

2.5 SGD and Mini-Batches

coll = range(1, 15)

dl = DataLoader((coll), batchsize=5, shuffle=true)

collect(dl)
3-element Vector{Vector{Int64}}:
 [2, 4, 8, 1, 10]
 [12, 14, 7, 15, 3]
 [6, 9, 11, 5, 13]
lowercase_alphabets = 'a':'z' ## [Char(i) for i in 97:122]

ds = [ (i, v) for (i, v) in enumerate(lowercase_alphabets)]

dl = DataLoader(ds, batchsize=5, shuffle=true)
collect(dl)
6-element Vector{Vector{Tuple{Int64, Char}}}:
 [(13, 'm'), (24, 'x'), (2, 'b'), (4, 'd'), (26, 'z')]
 [(20, 't'), (25, 'y'), (12, 'l'), (22, 'v'), (18, 'r')]
 [(15, 'o'), (7, 'g'), (5, 'e'), (3, 'c'), (11, 'k')]
 [(23, 'w'), (16, 'p'), (1, 'a'), (6, 'f'), (21, 'u')]
 [(14, 'n'), (8, 'h'), (9, 'i'), (19, 's'), (10, 'j')]
 [(17, 'q')]

Does dataloader work with files and directories?

weights = init_params((28*28,1))
bias = init_params(1)
size(weights), size(bias)
((784, 1), (1,))
"""
    reformat_dl(d1)

Split a batch of `(x, y)` pairs into a feature array and a label array.

The `x` items are collected and passed to `MLUtils.stack` with `dims=1`; each `y`
is wrapped in a one-element vector and stacked the same way. Returns `(xb, yb)`.
"""
function reformat_dl(d1)
    features = map(((x, _),) -> x, d1)
    labels = map(((_, y),) -> [y], d1)
    return MLUtils.stack(features; dims=1), MLUtils.stack(labels; dims=1)
end

dl = DataLoader(dset, batchsize=256, shuffle=true)

d1 = first(dl)
length(d1)

xb, yb = reformat_dl(d1)

size(xb), size(yb)
((256, 784), (256, 1))
valid_x = convert(Array{Float32}, valid_x)

valid_dset = [(valid_x[i, :], valid_y[i]) for i in range(1, size(valid_x)[1])]

valid_dl = DataLoader(valid_dset, batchsize=256, shuffle=true)
8-element DataLoader(::Vector{Tuple{Vector{Float32}, Int64}}, shuffle=true, batchsize=256)
  with first element:
  256-element Vector{Tuple{Vector{Float32}, Int64}}
batch = train_x[1:4, :]
size(batch)

preds = linear1(batch)

loss = mnist_loss(preds, train_y[1:4])

## redefine linear layer to include weights and bias as parameters

linear1(xb, weights, bias) = xb * weights .+ bias
preds = linear1(batch, weights, bias)

curr_gr = gradient(weights, bias) do weights, bias
    preds = linear1(batch, weights, bias)
    mnist_loss(preds, train_y[1:4])
end
([0.0; 0.0; … ; 0.0; 0.0;;], [-0.00019200529295718371])
# using dictionary to store parameters

params = Dict("weights" => weights, "bias" => bias)

linear1(xb, params) = xb * params["weights"] .+ params["bias"]

curr_gr = gradient(params) do params
    preds = linear1(batch, params)
    mnist_loss(preds, train_y[1:4])
end
(Dict{Any, Any}("weights" => [0.0; 0.0; … ; 0.0; 0.0;;], "bias" => [-0.00019200529295718371]),)
lr = 1e-4
"""
    calc_grad(xb, yb, model, weights, bias)

Return the gradient of `mnist_loss(model(xb, weights, bias), yb)` with respect to
`weights` and `bias`, as the tuple produced by `gradient`.
"""
function calc_grad(xb, yb, model, weights, bias)
    # The loss is evaluated inside the `gradient` call; a separate forward pass
    # beforehand would only repeat that work and discard the result.
    return gradient(weights, bias) do w, b
        mnist_loss(model(xb, w, b), yb)
    end
end
calc_grad (generic function with 1 method)

Using params dictionary.

"""
    calc_grad(xb, yb, model, params)

Return the gradient of `mnist_loss(model(xb, params), yb)` with respect to
`params`, as the one-element tuple produced by `gradient`.
"""
function calc_grad(xb, yb, model, params)
    # The loss is evaluated inside the `gradient` call; a separate forward pass
    # beforehand would only repeat that work and discard the result.
    return gradient(params) do p
        mnist_loss(model(xb, p), yb)
    end
end
calc_grad (generic function with 2 methods)
curr_grad = calc_grad(batch, train_y[1:4], linear1, weights, bias)
dict_grad = calc_grad(batch, train_y[1:4], linear1, params)[1]
## weights.grad.mean(),bias.grad

mean(curr_grad[1]), mean(curr_grad[2])
mean(dict_grad["weights"]), mean(dict_grad["bias"])
(-3.7275223954992516e-5, -0.00019200529295718371)
"""
    train_epoch(model, lr, params)

Run one pass over the global data loader `dl`, applying a gradient-descent update
to every entry of `params` after each batch.

`params` is mutated in place: for each key, `lr` times the matching gradient entry
is subtracted from the stored array.
"""
function train_epoch(model, lr, params)
    for batch_items in dl
        xb, yb = reformat_dl(batch_items)
        batch_grad = first(calc_grad(xb, yb, model, params))
        # No gradient reset is done here (PyTorch would call `p.grad.zero_()`);
        # `calc_grad` is called afresh for every batch.
        for (name, value) in params
            value .-= lr .* batch_grad[name]
        end
    end
end

train_epoch(linear1, lr, params)
(preds .> 0.0) == Bool.(train_y[1:4])
false
"""
    batch_accuracy(xb, yb)

Return the fraction of entries for which `sigmoid.(xb) .> 0.5` agrees with `yb`.
"""
function batch_accuracy(xb, yb)
    predicted_positive = sigmoid.(xb) .> 0.5
    return mean(predicted_positive .== yb)
end

batch_accuracy(linear1(batch, params), train_y[1:4])
1.0
"""
    validate_epoch(model)

Return the mean per-batch accuracy of `model` over the global `valid_dl`, using
the global `params`, rounded to four decimal places.
"""
function validate_epoch(model)
    batch_accs = zeros(length(valid_dl))
    for (i, batch_items) in enumerate(valid_dl)
        xb, yb = reformat_dl(batch_items)
        batch_accs[i] = batch_accuracy(model(xb, params), yb)
    end
    return round(mean(batch_accs); digits=4)
end

"""
    train_accuracy(model)

Return the mean per-batch accuracy of `model` over the global `dl`, using the
global `params`, rounded to four decimal places.
"""
function train_accuracy(model)
    batch_accs = zeros(length(dl))
    for (i, batch_items) in enumerate(dl)
        xb, yb = reformat_dl(batch_items)
        batch_accs[i] = batch_accuracy(model(xb, params), yb)
    end
    return round(mean(batch_accs); digits=4)
end
train_accuracy (generic function with 1 method)
lr = 1

weights = init_params((28 * 28, 1))
bias = init_params(1)

params = Dict("weights" => weights, "bias" => bias)

train_epoch(linear1, lr, params)

validate_epoch(linear1)
0.9275
for i in range(1, 20)
    train_epoch(linear1, lr, params)
    println((i, validate_epoch(linear1), train_accuracy(linear1)))
end
(1, 0.9499, 0.9512)
(2, 0.9583, 0.9591)
(3, 0.9612, 0.9639)
(4, 0.9632, 0.9657)
(5, 0.9638, 0.9678)
(6, 0.9656, 0.9694)
(7, 0.9656, 0.9705)
(8, 0.9681, 0.9716)
(9, 0.9686, 0.9722)
(10, 0.9706, 0.9731)
(11, 0.9711, 0.9737)
(12, 0.973, 0.9745)
(13, 0.9735, 0.9756)
(14, 0.9735, 0.9762)
(15, 0.974, 0.9771)
(16, 0.9741, 0.9774)
(17, 0.9744, 0.9776)
(18, 0.9749, 0.9789)
(19, 0.9749, 0.9787)
(20, 0.9749, 0.9794)

2.6 Creating an Optimizer

A Flux based implementation

model = Chain(
    Dense(28 * 28 => 1),
    Flux.sigmoid  ## or σ
)

optim = Flux.setup(Flux.Adam(1.0), model)

losses = []

for epoch in 1:20
    for dd in dl
        xb, yb = reformat_dl(dd)
        loss, grads = Flux.withgradient(model) do m
            # Evaluate model and loss inside gradient context:
            y_hat = m(xb')
            Flux.binarycrossentropy(y_hat, yb')  # mnist_loss(y_hat', yb)
        end
        Flux.update!(optim, model, grads[1])
        push!(losses, loss)  # logging, outside gradient context
    end
end

optim # parameters, momenta and output have all changed

# Accuracy of the trained model on the first validation batch.
xb, yb = reformat_dl(first(valid_dl))

out2 = model(xb')  # `model` is a single-output Dense layer followed by a sigmoid, so there is one row

# Outputs above 0.5 are compared against the labels in `yb`.
mean((out2[1, :] .> 0.5) .== yb)
0.9921875

Show examples of predicting seven and three.

xb, yb = reformat_dl(collect(valid_dl)[end])

seven_examples = rand(findall(y -> y == 0, yb[:]), 9)

convert(Array{Gray}, mosaic(map(i -> reshape(xb[i, :], 28, 28), seven_examples), ncol=3))

[b > 0.5 ? "three" : "seven" for b in model(xb[seven_examples, :]')]
1×9 Matrix{String}:
 "three"  "seven"  "seven"  "seven"  …  "seven"  "seven"  "seven"  "seven"
three_examples = rand(findall(y -> y == 1, yb[:]), 9)
convert(Array{Gray}, mosaic(map(i -> reshape(xb[i, :], 28, 28), three_examples), ncol=3))
[b > 0.5 ? "three" : "seven" for b in model(xb[three_examples, :]')]
1×9 Matrix{String}:
 "three"  "three"  "three"  "three"  …  "three"  "three"  "three"  "three"

2.7 Adding a Nonlinearity

"""
    simple_net1(xb)

Apply a two-layer network to `xb`: a linear layer, a ReLU, and a second linear
layer, using the global parameters `w1`, `b1`, `w2` and `b2`.
"""
function simple_net1(xb)
    hidden = xb * w1 .+ b1'
    # Non-mutating ReLU: negative activations become zero without writing into
    # `hidden`. In-place masking is array mutation, which Zygote-based `gradient`
    # does not support.
    activated = max.(hidden, 0)
    return activated * w2 .+ b2
end

w1 = init_params((28 * 28, 30))
b1 = init_params(30)
w2 = init_params((30, 1))
b2 = init_params(1)

simple_net1(train_x[1:4, :])
4×1 Matrix{Float64}:
 -126.8325425611682
  -91.884731723257
 -106.8205913364877
 -138.07959754797585
plot(range(-5, 5), Flux.relu)
simple_net_flux = Chain(
    Flux.Dense(28 * 28, 30),
    Flux.relu,
    Flux.Dense(30, 1)
)

Flux.params(simple_net_flux[1])[1] .= w1'
Flux.params(simple_net_flux[1])[2] .= b1

Flux.params(simple_net_flux[3])[1] .= w2'
Flux.params(simple_net_flux[3])[2] .= b2

simple_net_flux(train_x[1:4, :]')
1×4 Matrix{Float32}:
 -126.833  -91.8847  -106.821  -138.08

2.8 Training a Digit Classifier

The MNIST dataset can be loaded in Julia as follows:

# Data
X, y = MLDatasets.MNIST(:train)[:]
y_enc = Flux.onehotbatch(y, 0:9)
Xtest, ytest = MLDatasets.MNIST(:test)[:]
ytest_enc = onehotbatch(ytest, 0:9)
mosaic(map(i -> convert2image(MNIST, X[:, :, i]), rand(1:60000, 100)), ncol=10)

We can preprocess the data as follows:

i_train, i_val = [], []
for (k, v) in group_indices(y)
    _i_train, _i_val = splitobs(v, at=0.7)
    push!(i_train, _i_train...)
    push!(i_val, _i_val...)
end
Xtrain, ytrain = X[:, :, i_train], y_enc[:, i_train]
Xval, yval = X[:, :, i_val], y_enc[:, i_val]
([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; … ;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Bool[0 0 … 0 0; 0 0 … 1 1; … ; 0 0 … 0 0; 0 0 … 0 0])

Next, we define a data loader:

batchsize = 128
train_set = DataLoader((Xtrain, ytrain), batchsize=batchsize, shuffle=true)
val_set = DataLoader((Xval, yval), batchsize=batchsize)
141-element DataLoader(::Tuple{Array{Float32, 3}, OneHotMatrix{UInt32, Vector{UInt32}}}, batchsize=128)
  with first element:
  (28×28×128 Array{Float32, 3}, 10×128 OneHotMatrix(::Vector{UInt32}) with eltype Bool,)

We can now define a model, based on how we preprocessed the data:

model = Chain(
    Flux.flatten,
    Dense(28^2, 32, relu),
    Dense(32, 10),
    softmax
)
Chain(
  Flux.flatten,
  Dense(784 => 32, relu),               # 25_120 parameters
  Dense(32 => 10),                      # 330 parameters
  NNlib.softmax,
)                   # Total: 4 arrays, 25_450 parameters, 99.664 KiB.

Finally, what’s left to do is to define a loss function and an optimiser:

loss(y_hat, y) = Flux.Losses.crossentropy(y_hat, y)
opt_state = Flux.setup(Adam(), model)

Before we start training, we define some helper functions:

# Callbacks:
"""
    accuracy(model, data::DataLoader)

Return the average, over the batches of `data`, of the fraction of samples whose
`onecold` prediction matches the `onecold` label.
"""
function accuracy(model, data::DataLoader)
    batch_hit_rate((x, y)) = count(onecold(model(x)) .== onecold(y)) / size(y, 2)
    return sum(batch_hit_rate, data; init=0) / length(data)
end

"""
    avg_loss(model, data::DataLoader)

Return the average, over the batches of `data`, of `loss(model(x), y)`.
"""
function avg_loss(model, data::DataLoader)
    batch_loss((x, y)) = loss(model(x), y)[1]
    return sum(batch_loss, data; init=0) / length(data)
end

As a very last step, we set up our training logs:

# Final setup:
nepochs = 100
acc_train, acc_val = accuracy(model, train_set), accuracy(model, val_set)
loss_train, loss_val = avg_loss(model, train_set), avg_loss(model, val_set)

log = DataFrame(
    epoch=0,
    acc_train=acc_train,
    acc_val=acc_val,
    loss_train=loss_train,
    loss_val=loss_val
)

Below we finally train our model:

# Training loop:
# Train `model` for `nepochs` epochs. After every epoch the train/validation
# accuracy and loss are appended to `log` and the loss curves are drawn in the
# terminal.
for epoch in 1:nepochs

    for (i, data) in enumerate(train_set)

        # Extract data:
        input, label = data

        # Compute loss and gradient:
        val, grads = Flux.withgradient(model) do m
            result = m(input)
            loss(result, label)
        end

        # Detect loss of Inf or NaN. Print a warning, and then skip update!
        if !isfinite(val)
            @warn "loss is $val on item $i" epoch
            continue
        end

        Flux.update!(opt_state, model, grads[1])

    end

    # Monitor progress:
    acc_train, acc_val = accuracy(model, train_set), accuracy(model, val_set)
    loss_train, loss_val = avg_loss(model, train_set), avg_loss(model, val_set)
    results = Dict(
        :epoch => epoch,
        :acc_train => acc_train,
        :acc_val => acc_val,
        :loss_train => loss_train,
        :loss_val => loss_val
    )
    push!(log, results)

    # Print progress:
    # The metrics are stored in `log`. Its first row is the pre-training (epoch 0)
    # entry, so rows 2:end are exactly epochs 1:epoch and match the x-axis below.
    vals = Matrix(log[2:end, [:loss_train, :loss_val]])
    plt = UnicodePlots.lineplot(1:epoch, vals;
        name=["Train","Validation"], title="Loss in epoch $epoch", xlim=(1,nepochs))
    UnicodePlots.display(plt)

end

Figure 2.2 shows the training and validation loss and accuracy over epochs. The model is overfitting, as the validation loss increases after bottoming out at around epoch 20.

output = DataFrame(log)
output = output[2:end, :]

anim = @animate for epoch in 1:maximum(output.epoch)
    p_loss = plot(output[1:epoch, :epoch], Matrix(output[1:epoch, [:loss_train, :loss_val]]),
        label=["Train" "Validation"], title="Loss", legend=:topleft)
    p_acc = plot(output[1:epoch, :epoch], Matrix(output[1:epoch, [:acc_train, :acc_val]]),
        label=["Train" "Validation"], title="Accuracy", legend=:topleft)
    plot(p_loss, p_acc, layout=(1, 2), dpi=300, margin=5mm, size=(800, 400))
end
gif(anim, joinpath(www_path, "c4_mnist.gif"), fps=5)

Figure 2.2: Training and validation loss and accuracy