Input Convex Neural Networks with Flux.jl
This tutorial shows how to embed an input convex neural network (ICNN) model from Flux.jl into JuMP.
Required packages
This tutorial requires the following packages:
using JuMP
import Flux
import HiGHS
import MathOptAI
import Plots
Building the ICNN
The following custom layer can be used to build ICNNs. This layer has two forward methods: one takes a single input vector, and the other takes a Tuple. Both return the result of the forward pass together with the original input.
"""
    InputConvex{T,F}

A single layer of an input convex neural network.

# Fields
- `weight_x`: weight matrix multiplied with the original network input `x`.
- `weight_z`: weight matrix multiplied with the previous layer's output `z`.
- `bias`: bias vector added before the activation.
- `σ`: activation function, applied element-wise.
"""
struct InputConvex{T,F}
    weight_x::Matrix{T}
    weight_z::Matrix{T}
    bias::Vector{T}
    σ::F
end

# Register the layer with Flux; the three arrays are the trainable parameters.
Flux.@layer(InputConvex, trainable = (weight_x, weight_z, bias))

"""
    InputConvex((in_z, in_x) => out, σ = identity; init = Flux.glorot_uniform)

Build an `InputConvex` layer with `out` outputs, where `in_z` is the length of
the previous layer's output and `in_x` is the length of the original input.

`init` is called as `init(out, in_x)`, `init(out, in_z)` and `init(out)` to
create `weight_x`, `weight_z` and `bias` respectively.
"""
function InputConvex(
    dims::Pair{Tuple{Int,Int},Int},
    σ = identity;
    init = Flux.glorot_uniform,
)
    (in_z, in_x), out = dims
    # Keep this call order so that a stateful `init` yields the same arrays.
    weight_x = init(out, in_x)
    weight_z = init(out, in_z)
    bias = init(out)
    return InputConvex(weight_x, weight_z, bias, σ)
end
# Forward pass when only the input vector `x` is given: the `weight_z` term is
# not used. Returns the activated output together with `x` itself.
function (c::InputConvex)(x::AbstractVector)
    pre_activation = c.weight_x * x .+ c.bias
    return c.σ.(pre_activation), x
end
# Forward pass for a `(z, x)` tuple. `weight_z` is passed element-wise through
# `Flux.softplus` before multiplying `z`. Returns the activated output together
# with the unchanged `x`.
function (c::InputConvex)(input::Tuple)
    z, x = input
    weight_z_pos = Flux.softplus.(c.weight_z)
    pre_activation = weight_z_pos * z .+ c.weight_x * x .+ c.bias
    return c.σ.(pre_activation), x
end
"""
    Base.show(io::IO, l::InputConvex)

Print `l` in the form accepted by the constructor,
`InputConvex((in_z, in_x) => out, σ)`. The activation is omitted when it is
`identity`.
"""
function Base.show(io::IO, l::InputConvex)
    # The constructor builds `weight_x` as `out × in_x` and `weight_z` as
    # `out × in_z`, so rows give the output size and columns the input sizes.
    out, in_x = size(l.weight_x)
    in_z = size(l.weight_z, 2)
    print(io, "InputConvex(($in_z, $in_x) => $out")
    if l.σ != identity
        print(io, ", ", l.σ)
    end
    # No `bias=false` suffix: `bias` is always a `Vector`, so it never equals
    # `false`.
    print(io, ")")
    return
end
Here's an example:
layer = InputConvex((8, 8) => 2, Flux.relu)
InputConvex((8, 2) => 8, relu) # 34 parameters
layer(rand(8))
([0.9203263775571092, 0.0], [0.8734036421124615, 0.4946274884450549, 0.6640874408688472, 0.4151062705961195, 0.09655691085597329, 0.2992489153990011, 0.6054244855029871, 0.6316504168847467])
Next, we define a custom Chain to build the ICNN.
"""
    InputConvexChain(layers...)

Wrapper around a `Flux.Chain`. Calling the wrapper forwards the input to the
wrapped chain and returns whatever the chain returns.
"""
struct InputConvexChain{T<:Flux.Chain}
    chain::T
end

# Convenience constructor: collect the given layers into a `Flux.Chain`.
function InputConvexChain(layers...)
    return InputConvexChain(Flux.Chain(layers))
end

# Evaluate the wrapped chain on `x`.
function (model::InputConvexChain)(x)
    return model.chain(x)
end
# Print the chain as `InputConvexChain(`, then one tab-prefixed line per layer,
# then a closing `)` line.
function Base.show(io::IO, l::InputConvexChain)
    println(io, "InputConvexChain(")
    for layer in l.chain
        println(io, "\t", layer)
    end
    println(io, ")")
    return
end
Here's an example:
chain = InputConvexChain(
InputConvex((8, 8) => 2, Flux.relu),
InputConvex((2, 8) => 1, Flux.relu),
)
InputConvexChain(
InputConvex((8, 2) => 8, relu)
InputConvex((2, 1) => 8, relu)
)
chain(rand(8))
([0.0955121110107846], [0.3713511968949684, 0.9406057773177869, 0.4229177031800182, 0.1287728840390393, 0.5360862285574326, 0.2786557434582626, 0.16796047324478225, 0.8623354688619428])
Building the Predictor
We need to implement build_predictor and add_predictor for InputConvexChain in order to be able to embed this network into JuMP.
"""
    InputConvexChainPredictor

Predictor produced by `MathOptAI.build_predictor` for an `InputConvexChain`.
It stores the chain as a `MathOptAI.Pipeline` of alternating affine and
activation layers.
"""
struct InputConvexChainPredictor <: MathOptAI.AbstractPredictor
    p::MathOptAI.Pipeline
end

"""
    MathOptAI.build_predictor(
        predictor::InputConvexChain;
        config::Dict = Dict{Any,Any}(),
        kwargs...,
    )

Convert `predictor` into an `InputConvexChainPredictor`.

The first layer becomes an `Affine` over `weight_x` and `bias`. Every later
layer becomes an `Affine` over `[softplus.(weight_z) weight_x]`, so it must be
applied to the stacked input `[z; x]`. Each `Affine` is followed by the
predictor built from the layer's activation.

# Keywords
- `config`: forwarded to `MathOptAI.build_predictor` for each activation.
- `kwargs`: accepted and ignored.
"""
function MathOptAI.build_predictor(
    predictor::InputConvexChain;
    config::Dict = Dict{Any,Any}(),
    kwargs...,
)
    layer1 = first(predictor.chain)
    p = MathOptAI.Pipeline(
        MathOptAI.Affine(layer1.weight_x, layer1.bias),
        MathOptAI.build_predictor(layer1.σ; config),
    )
    for layer in predictor.chain[2:end]
        # Broadcast `softplus` element-wise, exactly as the layer's forward
        # pass does; calling the scalar activation on a whole matrix is not
        # portable. Columns are ordered `[z-part x-part]`.
        weights = hcat(Flux.softplus.(layer.weight_z), layer.weight_x)
        push!(p.layers, MathOptAI.Affine(weights, layer.bias))
        push!(p.layers, MathOptAI.build_predictor(layer.σ; config))
    end
    return InputConvexChainPredictor(p)
end
"""
    MathOptAI.add_predictor(
        model::JuMP.AbstractModel,
        predictor::InputConvexChainPredictor,
        x::Vector;
        kwargs...,
    )

Add every layer of `predictor` to `model`, starting from the input `x`.

The first layer is applied to `x`. After that, each `Affine` layer is applied
to the stacked vector `[z; x]`, where `z` is the previous layer's output, and
every other layer is applied to `z` alone.

Returns the final output `z` and a `MathOptAI.PipelineFormulation` holding the
formulation of each layer. `kwargs` are accepted and ignored.
"""
function MathOptAI.add_predictor(
    model::JuMP.AbstractModel,
    predictor::InputConvexChainPredictor,
    x::Vector;
    kwargs...,
)
    layers = predictor.p.layers
    z, first_inner = MathOptAI.add_predictor(model, first(layers), x)
    formulation = MathOptAI.PipelineFormulation(predictor, Any[first_inner])
    for layer in layers[2:end]
        # Affine layers see the original input again, stacked below `z`.
        input = layer isa MathOptAI.Affine ? [z; x] : z
        z, inner = MathOptAI.add_predictor(model, layer, input)
        push!(formulation.layers, inner)
    end
    return z, formulation
end
With that, we are now ready to embed these networks into JuMP.
Embed ICNN into JuMP
Let us build a small ICNN first.
predictor = InputConvexChain(
InputConvex((8, 8) => 2, Flux.relu),
InputConvex((2, 8) => 1, Flux.relu),
)
InputConvexChain(
InputConvex((8, 2) => 8, relu)
InputConvex((2, 1) => 8, relu)
)
We can embed predictor into a JuMP model now.
model = Model()
@variable(model, x[1:8])
z, formulation = MathOptAI.add_predictor(
model,
predictor,
x;
config = Dict(Flux.relu => MathOptAI.ReLUSOS1),
);
z
1-element Vector{JuMP.VariableRef}:
 moai_ReLU[1]
formulation
Affine(A, b) [input: 8, output: 2]
├ variables [2]
│ ├ moai_Affine[1]
│ └ moai_Affine[2]
└ constraints [2]
├ 0.1318863034248352 x[1] + 0.4718753397464752 x[2] + 0.3021894693374634 x[3] + 0.273065984249115 x[4] + 0.7072346806526184 x[5] - 0.3679642975330353 x[6] + 0.31136834621429443 x[7] - 0.20329298079013824 x[8] - moai_Affine[1] = 0.7275775074958801
└ 0.08028249442577362 x[1] + 0.6980839371681213 x[2] - 0.5868808627128601 x[3] + 0.5454084277153015 x[4] + 0.35285863280296326 x[5] + 0.49926745891571045 x[6] + 0.5192384123802185 x[7] + 0.2657865285873413 x[8] - moai_Affine[2] = 0.8161755800247192
MathOptAI.ReLUSOS1()
├ variables [4]
│ ├ moai_ReLU[1]
│ ├ moai_ReLU[2]
│ ├ moai_z[1]
│ └ moai_z[2]
└ constraints [8]
├ moai_ReLU[1] ≥ 0
├ moai_z[1] ≥ 0
├ moai_Affine[1] - moai_ReLU[1] + moai_z[1] = 0
├ [moai_ReLU[1], moai_z[1]] ∈ MathOptInterface.SOS1{Float64}([1.0, 2.0])
├ moai_ReLU[2] ≥ 0
├ moai_z[2] ≥ 0
├ moai_Affine[2] - moai_ReLU[2] + moai_z[2] = 0
└ [moai_ReLU[2], moai_z[2]] ∈ MathOptInterface.SOS1{Float64}([1.0, 2.0])
Affine(A, b) [input: 10, output: 1]
├ variables [1]
│ └ moai_Affine[1]
└ constraints [1]
└ -0.35285982489585876 x[1] - 0.6026825904846191 x[2] - 0.45887860655784607 x[3] - 0.03224626183509827 x[4] - 0.7383406162261963 x[5] + 0.4853387773036957 x[6] - 0.1503416746854782 x[7] - 0.37024572491645813 x[8] + 1.546250343322754 moai_ReLU[1] + 1.4259251356124878 moai_ReLU[2] - moai_Affine[1] = -1.3824591636657715
MathOptAI.ReLUSOS1()
├ variables [2]
│ ├ moai_ReLU[1]
│ └ moai_z[1]
└ constraints [4]
├ moai_ReLU[1] ≥ 0
├ moai_z[1] ≥ 0
├ moai_Affine[1] - moai_ReLU[1] + moai_z[1] = 0
└ [moai_ReLU[1], moai_z[1]] ∈ MathOptInterface.SOS1{Float64}([1.0, 2.0])
Epigraph formulations
The nice thing about ICNNs is that we can formulate their epigraph and avoid adding binary variables to the model. For that, we can use ReLUEpigraph.
# A small ICNN with a scalar input and a scalar output.
chain = InputConvexChain(
    InputConvex((1, 1) => 3, Flux.relu),
    InputConvex((3, 1) => 1, Flux.relu),
)

model = Model(HiGHS.Optimizer)
set_silent(model)
@variable(model, x[1:1])

# Formulate every ReLU with `MathOptAI.ReLUEpigraph`.
config = Dict(Flux.relu => MathOptAI.ReLUEpigraph)
y, _ = MathOptAI.add_predictor(model, chain, x; config)
@objective(model, Min, only(y))

# Fix the input at each grid point, re-solve, and record the objective value.
x_value = -20:20
y_value = map(x_value) do xi
    fix(x[1], xi)
    optimize!(model)
    assert_is_solved_and_feasible(model)
    return objective_value(model)
end
Plots.plot(x_value, y_value)
As expected, the value of y is convex with respect to x.
This page was generated using Literate.jl.