While working on my bachelor project, I implemented a type `CyclicOptimiser`. It acts as a drop-in replacement for regular optimisers by adding a new method to `Flux.update!`.
I can see that I have reproduced a lot of the existing work here, so I simply wanted to offer up my implementation and let you decide whether any of my ideas might add value to this project.
using Flux
using Flux.Optimise: AbstractOptimiser
using Base.Iterators: Stateful, Cycle
import UnicodePlots
import Flux.update! # Learning rate is updated each time Flux.update! is called, allowing seamless drop-in replacement of normal optimisers with cyclic optimisers.
import Base: show
"""
    round3(x)

Round `x` to three significant digits (used for compact display of learning rates).
"""
round3(x) = round(x; sigdigits=3)
"""
    optimiser_to_string(opt::AbstractOptimiser)

Return a compact one-line string representation of `opt`, e.g. `"Descent(0.1)"`.
The learning-rate field `eta` is rounded to three significant digits, and any
`IdDict` state fields (momentum buffers etc.) are abbreviated as `"..."`.
"""
function optimiser_to_string(opt::AbstractOptimiser)
    parts = String[]
    for name in fieldnames(typeof(opt))
        # NOTE: the original shadowed `Base.fieldtypes` with a local; avoided here.
        value = getfield(opt, name)
        if value isa IdDict
            push!(parts, "...")
        elseif name == :eta
            push!(parts, string(round3(value)))
        else
            push!(parts, string(value))
        end
    end
    # `join` handles the zero-field case, which the original trailing-comma
    # slicing (`output[begin:end-2]`) would have broken on.
    return string(typeof(opt)) * "(" * join(parts, ", ") * ")"
end
"""
struct CycleMagnitude
len::Int
magfac::Float64
end
A type to be used as a functor with the purpose of
calculating a magnitude that is changed discretely
by a factor `magfac` each time `len` cycles are completed.
"""
struct CycleMagnitude
len::Int
magfac::Float64
end
"""
(cyc::CycleMagnitude)(x) = cyc.magfac ^ (x÷cyc.len)
Compute a magnitude that is multiplied by `cyc.magfac`
every time the input increases by cyc.len.
The input is intended to be the `taken` field of a
Cycle(Stateful(my_collection)).
Note that for the actual calculation, the learning rate
needs to be shifted so that the smallest value in the
cycle is 0 before scaling, and shifted back up after scaling.
"""
(cyc::CycleMagnitude)(x) = cyc.magfac ^ (x÷cyc.len)
# Interface: subtypes wrap a `Stateful{Cycle}` in a field named `cycle`.
abstract type AbstractCycler end

"""
    TriangleCycler(cycle)

A cycler backed by a `Stateful` wrapper around a `Cycle` of learning-rate
values. Parametrized on the concrete `Stateful` type so the field is
concretely typed (the original `cycle::Stateful{Cycle{A}} where {A}` field was
abstractly typed, which prevents specialization and boxes the field).
"""
struct TriangleCycler{S<:Stateful{<:Cycle{<:AbstractVector}}} <: AbstractCycler
    cycle::S
end
# Two-line summary of any cycler: its underlying values and how many steps taken.
function show(io::IO, cyc::AbstractCycler)
    values = cyc.cycle.itr.xs
    steps = cyc.cycle.taken
    return println(io, "Cycler with values $(values).\nCycled $(steps) times")
end
"""
    cycle!(cycler::AbstractCycler)

Advance `cycler` one step and return the next value in its cycle
(delegates to `popfirst!` on the wrapped `Stateful` iterator).
"""
function cycle!(cycler::AbstractCycler)
    return popfirst!(cycler.cycle)
end
"""
TriangleCycler(lower, upper, len)
Construct a TriangleCycler containing a set
of `len` values values that goes from `lower`
up to `upper` and back down again. Plotted against
its index, the returned set looks like
a triangle with 2 equal legs.
If the `len` is odd, the first and last point will
be the same, causing repetition when cycled.
"""
function TriangleCycler(lower, upper, len)
if len == 1 # Special case to avoid the error from range(a_number, another_number != a_number, length=1)
cycle = [(lower+upper)/2]
elseif iseven(len)
cycle = vcat(range(lower, upper; length=len÷2+1), reverse(range(lower, upper; length=len÷2+1))[begin+1:end-1])
else
cycle = vcat(range(lower, upper; length=len÷2+1), reverse(range(lower, upper; length=len÷2+1))[begin+1:end])
end
return TriangleCycler(cycle |> Cycle |> Stateful)
end
# One-line summary: value range (3 significant digits) and cycle length.
function show(io::IO, tricy::TriangleCycler)
    values = tricy.cycle.itr.xs
    lo = round3(minimum(values))
    hi = round3(maximum(values))
    return println(io, "TriangleCycler from $(lo) to $(hi) of cycle-length $(length(values))")
end
"""
    check_optimiser(opt)

Validate that `opt` can be wrapped by a `CyclicOptimiser`: it must be an
optimiser *instance* (not a type such as `Descent`) and must have an `eta`
field (i.e. a learning-rate parameter). Throws an error otherwise; returns
`nothing` on success.
"""
function check_optimiser(opt)
    # The argument is deliberately untyped: the original `::AbstractOptimiser`
    # signature made the "you passed a type" branch unreachable (a bare type
    # raised MethodError before the helpful message could be shown).
    # Check the type case first — `typeof(SomeType)` has no `eta` field either,
    # and would otherwise trigger the wrong message.
    opt isa Type && error("Tried to construct a CyclicOptimiser with an optimiser type (e.g. `Descent`). Try to use a concrete optimiser instead (e.g. `Descent()`)")
    hasfield(typeof(opt), :eta) || error("Tried to construct a CyclicOptimiser with $(opt), which has no field eta (e.g. no learningrate parameter).")
    return nothing
end
"""
struct CyclicOptimiser{T} <: AbstractOptimiser where {T<:AbstractOptimiser}
current_optimiser::T
learningrate::AbstractCycler
cycle_magnitude::CycleMagnitude
end
"""
struct CyclicOptimiser{T} <: AbstractOptimiser where {T<:AbstractOptimiser}
current_optimiser::T
learningrate::AbstractCycler
cycle_magnitude::CycleMagnitude
function CyclicOptimiser(opt, learningrate::AbstractCycler, cycmag::CycleMagnitude)
check_optimiser(opt)
@assert length(learningrate.cycle.itr.xs) == cycmag.len "Length og learningrate cycle does not match the length of the internal CycleMagnitude."
return new{typeof(opt)}(opt, learningrate, cycmag)
end
end
"""
CyclicOptimiser(opt::AbstractOptimiser, lower, upper, len; cycler::AbstractCycler=TriangleCycler, magfac=1)
Construct a CyclicOptimiser. The optimiser whose learning rate is cycled is
`opt`, the first positional argument. `lower`, `upper` and `len` are passed on
to `cycler`, constructing an `AbstractCycler` and defaulting to TriangleCycler.
A final keyword argument `magfac` sets the magnitude-controlling factor that
is applied after a full cycle is completed. So if `magfac` is set to 0.5, then
the span of the cycle is halved each cycle. The lower limit is pinned,
so `magfac` only effects the upper limit, to ensure that the learningrate
decreases each cycle (assuming magfac ≤ 1, which is checked for).
"""
function CyclicOptimiser(opt::AbstractOptimiser, lower, upper, len; cycler=TriangleCycler, magfac=1)
check_optimiser(opt)
return CyclicOptimiser(opt, cycler(lower, upper, len), CycleMagnitude(len, magfac))
end
"""
    plot(cycopt::CyclicOptimiser, n_cycles=3)

Return a `UnicodePlots` scatterplot of the learning rates that `cycopt` would
produce over `n_cycles` cycles, starting from the beginning of the cycle.
Works on a `deepcopy`, so the caller's iteration state is untouched.
"""
function plot(cycopt::CyclicOptimiser, n_cycles=3)
    xs = 1:cycopt.cycle_magnitude.len * n_cycles
    cycopt = deepcopy(cycopt)  # never mutate the caller's optimiser
    Iterators.reset!(cycopt.learningrate.cycle)
    lower_bound = minimum(cycopt.learningrate.cycle.itr.xs)
    # Mirror cycle! exactly: magnitude is computed from `taken` *before* each
    # pop (so it runs 0..N-1, not 1..N), and the value is shifted so the cycle
    # minimum stays pinned at `lower_bound` while only the span is scaled.
    # The original multiplied the raw values by cycle_magnitude.(1:N), which
    # plotted neither the shift nor the correct magnitude index.
    ys = [cycopt.cycle_magnitude(cycopt.learningrate.cycle.taken) *
          (cycle!(cycopt.learningrate) - lower_bound) + lower_bound for _ in xs]
    return UnicodePlots.scatterplot(xs, ys, xlabel="Iteration", ylabel="Learningrate",
                                    title="Learningrate for $n_cycles cycles")
end
# Pretty-print a CyclicOptimiser: the wrapped optimiser (compact form), the
# learningrate cycler's type and value range, and the cycle parameters.
# NOTE(review): the exact indentation of the triple-quoted literal determines
# the printed output, so the body is left byte-identical on purpose.
function show(io::IO, cycopt::CyclicOptimiser)
print(io,
"""
CyclicOptimiser with following properties:
Current optimiser = $(cycopt.current_optimiser|>optimiser_to_string)
Learningrate = $(typeof(cycopt.learningrate)) from $(cycopt.learningrate.cycle.itr.xs|>minimum|>round3) to $(cycopt.learningrate.cycle.itr.xs|>maximum|>round3)
Cyclelength = $(cycopt.cycle_magnitude.len). Magfac = $(cycopt.cycle_magnitude.magfac)""")
end
"""
    cycle!(co::CyclicOptimiser)

Advance the learning-rate cycle one step, write the resulting learning rate
into the wrapped optimiser's `eta`, and return that optimiser.
"""
function cycle!(co::CyclicOptimiser)
    # Magnitude is taken from `taken` BEFORE popping, so the first step after a
    # reset uses an unscaled span.
    scale = co.cycle_magnitude(co.learningrate.cycle.taken)
    floor_rate = minimum(co.learningrate.cycle.itr.xs)
    raw_rate = cycle!(co.learningrate)
    # Shift so the cycle minimum is 0, scale the span, then shift back: the
    # lower limit stays pinned while the upper limit shrinks each cycle.
    co.current_optimiser.eta = scale * (raw_rate - floor_rate) + floor_rate
    return co.current_optimiser
end
"""
    Flux.update!(cycopt::CyclicOptimiser, xs::Params, gs)

Advance the learning-rate cycle one step, then delegate to the regular
`Flux.update!` with the wrapped (freshly updated) optimiser — this is what
makes a `CyclicOptimiser` a drop-in replacement in a training loop.
"""
function Flux.update!(cycopt::CyclicOptimiser, xs::Params, gs)
    return Flux.update!(cycle!(cycopt), xs, gs)
end