Input Convex Neural Networks with Flux.jl

This tutorial shows how to embed an input convex neural network (ICNN) model from Flux.jl into JuMP.

Required packages

This tutorial requires the following packages:

using JuMP
import Flux
import HiGHS
import MathOptAI
import Plots

Building the ICNN

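The following custom layer can be used to build ICNNs. This layer has two forward methods: one takes a single input vector x, and the other takes a Tuple (z, x) containing the output of the previous layer and the original input. Both return the result of the forward pass together with the original input x, so that the input can be passed through to the next layer.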

# A single layer of an input convex neural network (ICNN).
#
# The forward methods defined for this type compute `σ.(weight_x * x .+ bias)` when
# called with a plain vector `x`, and
# `σ.(softplus.(weight_z) * z .+ weight_x * x .+ bias)` when called with a tuple
# `(z, x)`.
struct InputConvex{T,F}
    # Weights multiplying the original input `x`.
    weight_x::Matrix{T}
    # Weights multiplying the previous layer's output `z`; the tuple forward method
    # passes them through softplus before use.
    weight_z::Matrix{T}
    # Bias added to the pre-activation.
    bias::Vector{T}
    # Activation function, broadcast element-wise over the pre-activation.
    σ::F
end

# Register the layer with Flux and mark all three parameter arrays as trainable.
Flux.@layer InputConvex trainable = (weight_x, weight_z, bias)

# Build an `InputConvex` layer from a `(in_z, in_x) => out` size specification.
#
# `in_z` is the length of the previous layer's output, `in_x` the length of the
# original input, and `out` the number of outputs. `σ` is the activation
# (default `identity`) and `init` is called to create each parameter array.
function InputConvex(
    dims::Pair{Tuple{Int,Int},Int},
    σ = identity;
    init = Flux.glorot_uniform,
)
    (in_z, in_x), out = dims
    # Keep this creation order (weight_x, weight_z, bias): `init` may draw from a
    # random number generator, so reordering would change the initial values.
    weight_x = init(out, in_x)
    weight_z = init(out, in_z)
    bias = init(out)
    return InputConvex(weight_x, weight_z, bias, σ)
end

# Forward pass for a plain input vector `x`. `weight_z` is not used here.
# Returns `(output, x)` so the original input can be passed on to the next layer.
function (c::InputConvex)(x::AbstractVector)
    pre_activation = c.weight_x * x .+ c.bias
    return (c.σ.(pre_activation), x)
end

# Forward pass for a `(z, x)` tuple, where `z` is the previous layer's output and
# `x` is the original input. Returns `(output, x)`.
function (c::InputConvex)(input::Tuple)
    z, x = input
    # softplus is strictly positive, so `z` is always combined with positive weights.
    positive_weight_z = Flux.softplus.(c.weight_z)
    pre_activation = positive_weight_z * z .+ c.weight_x * x .+ c.bias
    return (c.σ.(pre_activation), x)
end

# Print `l` in the same `(in_z, in_x) => out` notation that the `InputConvex`
# constructor accepts, followed by the activation when it is not `identity`.
#
# For example, a layer built with `InputConvex((2, 8) => 1, Flux.relu)` is shown
# as `InputConvex((2, 8) => 1, relu)`.
function Base.show(io::IO, l::InputConvex)
    # The constructor creates `weight_x` as `out × in_x` and `weight_z` as
    # `out × in_z`, so the sizes are recovered from those dimensions.
    out, in_x = size(l.weight_x)
    in_z = size(l.weight_z, 2)
    print(io, "InputConvex((", in_z, ", ", in_x, ") => ", out)
    if l.σ != identity
        print(io, ", ", l.σ)
    end
    # No `bias=false` case is printed: `bias` is declared as a `Vector{T}`, so it
    # can never compare equal to `false`.
    print(io, ")")
    return
end

Here's an example:

# Build a layer with 8 `z` inputs, 8 `x` inputs, and 2 outputs, using ReLU as the
# activation.
layer = InputConvex((8, 8) => 2, Flux.relu)
InputConvex((8, 2) => 8, relu)  # 34 parameters
# Calling the layer on a plain vector uses the single-input forward method and
# returns the tuple `(output, input)`.
layer(rand(8))
([0.9203263775571092, 0.0], [0.8734036421124615, 0.4946274884450549, 0.6640874408688472, 0.4151062705961195, 0.09655691085597329, 0.2992489153990011, 0.6054244855029871, 0.6316504168847467])

Next, we define a custom Chain to build the ICNN.

# A thin wrapper around a `Flux.Chain` of `InputConvex` layers. Having a dedicated
# type allows methods such as `Base.show` and `MathOptAI.build_predictor` to be
# specialized for ICNNs.
struct InputConvexChain{T<:Flux.Chain}
    # The wrapped chain; calling an `InputConvexChain` forwards to this chain.
    chain::T
end

# Convenience constructor: collect the given layers into a `Flux.Chain` and wrap it.
function InputConvexChain(layers...)
    return InputConvexChain(Flux.Chain(layers))
end

# Evaluate the network by forwarding the input to the wrapped chain.
function (model::InputConvexChain)(x)
    return model.chain(x)
end

# Print the chain with one tab-indented layer per line between an opening
# `InputConvexChain(` line and a closing `)` line.
function Base.show(io::IO, l::InputConvexChain)
    println(io, "InputConvexChain(")
    for layer in l.chain
        println(io, "\t", layer)
    end
    println(io, ")")
    return
end

Here's an example:

# Stack two layers. The first maps the 8-dimensional input to 2 outputs; the second
# takes those 2 outputs as `z`, together with the original 8-dimensional input `x`,
# and produces a single output.
chain = InputConvexChain(
    InputConvex((8, 8) => 2, Flux.relu),
    InputConvex((2, 8) => 1, Flux.relu),
)
InputConvexChain(
	InputConvex((8, 2) => 8, relu)
	InputConvex((2, 1) => 8, relu)
)
# Evaluating the chain returns the tuple `(output, input)` from the last layer.
chain(rand(8))
([0.0955121110107846], [0.3713511968949684, 0.9406057773177869, 0.4229177031800182, 0.1287728840390393, 0.5360862285574326, 0.2786557434582626, 0.16796047324478225, 0.8623354688619428])

Building the Predictor

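To embed this network into JuMP, we need to implement build_predictor for InputConvexChain, which converts the network into a new predictor type, InputConvexChainPredictor, and add_predictor for that predictor type, which adds the corresponding variables and constraints to the model.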

# The predictor that `MathOptAI.build_predictor` returns for an `InputConvexChain`.
# Wrapping the pipeline in its own type lets `MathOptAI.add_predictor` dispatch to
# the ICNN-specific method, which feeds the original input to every affine layer.
struct InputConvexChainPredictor <: MathOptAI.AbstractPredictor
    # Alternating affine and activation predictors, one pair per `InputConvex` layer.
    p::MathOptAI.Pipeline
end

# Convert an `InputConvexChain` into an `InputConvexChainPredictor`.
#
# Each `InputConvex` layer becomes an affine predictor followed by the predictor
# for its activation. `config` is forwarded to `MathOptAI.build_predictor` for each
# activation to select the formulation. Additional keyword arguments are accepted
# and ignored.
function MathOptAI.build_predictor(
    predictor::InputConvexChain;
    config::Dict = Dict{Any,Any}(),
    kwargs...,
)
    # The first layer only sees the original input `x`, so only `weight_x` is used.
    layer1 = first(predictor.chain)
    p = MathOptAI.Pipeline(
        MathOptAI.Affine(layer1.weight_x, layer1.bias),
        MathOptAI.build_predictor(layer1.σ; config),
    )
    for layer in predictor.chain[2:end]
        # Subsequent layers act on `[z; x]`, so the weight blocks are concatenated
        # in that order. softplus is a scalar function and must be broadcast over
        # the matrix, exactly as in the layer's forward pass.
        weights = hcat(Flux.softplus.(layer.weight_z), layer.weight_x)
        push!(p.layers, MathOptAI.Affine(weights, layer.bias))
        push!(p.layers, MathOptAI.build_predictor(layer.σ; config))
    end
    return InputConvexChainPredictor(p)
end

# Add the variables and constraints of an `InputConvexChainPredictor` to `model`.
#
# Returns `(z, formulation)`, where `z` is the output of the last layer and
# `formulation` collects the formulation of every layer in the pipeline.
# Additional keyword arguments are accepted and ignored.
function MathOptAI.add_predictor(
    model::JuMP.AbstractModel,
    predictor::InputConvexChainPredictor,
    x::Vector;
    kwargs...,
)
    stages = predictor.p.layers
    # The first stage is applied to the original input only.
    z, first_formulation = MathOptAI.add_predictor(model, first(stages), x)
    formulation = MathOptAI.PipelineFormulation(predictor, Any[first_formulation])
    for stage in stages[2:end]
        # Affine stages take the previous output stacked on top of the original
        # input; every other stage (the activations) acts on the previous output.
        stage_input = stage isa MathOptAI.Affine ? [z; x] : z
        z, stage_formulation = MathOptAI.add_predictor(model, stage, stage_input)
        push!(formulation.layers, stage_formulation)
    end
    return z, formulation
end

With that, we are now ready to embed these networks into JuMP.

Embed ICNN into JuMP

Let us build a small ICNN first.

# A two-layer ICNN: 8 inputs, a hidden layer with 2 outputs, and a single output.
predictor = InputConvexChain(
    InputConvex((8, 8) => 2, Flux.relu),
    InputConvex((2, 8) => 1, Flux.relu),
)
InputConvexChain(
	InputConvex((8, 2) => 8, relu)
	InputConvex((2, 1) => 8, relu)
)

We can embed predictor into a JuMP model now.

# No optimizer is attached here; the model is only built and inspected.
model = Model()
@variable(model, x[1:8])
# `config` maps the Flux activation to the MathOptAI formulation used for it.
z, formulation = MathOptAI.add_predictor(
    model,
    predictor,
    x;
    config = Dict(Flux.relu => MathOptAI.ReLUSOS1),
);
# `z` holds the output variables of the last layer.
z
1-element Vector{JuMP.VariableRef}:
 moai_ReLU[1]
# `formulation` lists the variables and constraints added for each layer.
formulation
Affine(A, b) [input: 8, output: 2]
├ variables [2]
│ ├ moai_Affine[1]
│ └ moai_Affine[2]
└ constraints [2]
  ├ 0.1318863034248352 x[1] + 0.4718753397464752 x[2] + 0.3021894693374634 x[3] + 0.273065984249115 x[4] + 0.7072346806526184 x[5] - 0.3679642975330353 x[6] + 0.31136834621429443 x[7] - 0.20329298079013824 x[8] - moai_Affine[1] = 0.7275775074958801
  └ 0.08028249442577362 x[1] + 0.6980839371681213 x[2] - 0.5868808627128601 x[3] + 0.5454084277153015 x[4] + 0.35285863280296326 x[5] + 0.49926745891571045 x[6] + 0.5192384123802185 x[7] + 0.2657865285873413 x[8] - moai_Affine[2] = 0.8161755800247192
MathOptAI.ReLUSOS1()
├ variables [4]
│ ├ moai_ReLU[1]
│ ├ moai_ReLU[2]
│ ├ moai_z[1]
│ └ moai_z[2]
└ constraints [8]
  ├ moai_ReLU[1] ≥ 0
  ├ moai_z[1] ≥ 0
  ├ moai_Affine[1] - moai_ReLU[1] + moai_z[1] = 0
  ├ [moai_ReLU[1], moai_z[1]] ∈ MathOptInterface.SOS1{Float64}([1.0, 2.0])
  ├ moai_ReLU[2] ≥ 0
  ├ moai_z[2] ≥ 0
  ├ moai_Affine[2] - moai_ReLU[2] + moai_z[2] = 0
  └ [moai_ReLU[2], moai_z[2]] ∈ MathOptInterface.SOS1{Float64}([1.0, 2.0])
Affine(A, b) [input: 10, output: 1]
├ variables [1]
│ └ moai_Affine[1]
└ constraints [1]
  └ -0.35285982489585876 x[1] - 0.6026825904846191 x[2] - 0.45887860655784607 x[3] - 0.03224626183509827 x[4] - 0.7383406162261963 x[5] + 0.4853387773036957 x[6] - 0.1503416746854782 x[7] - 0.37024572491645813 x[8] + 1.546250343322754 moai_ReLU[1] + 1.4259251356124878 moai_ReLU[2] - moai_Affine[1] = -1.3824591636657715
MathOptAI.ReLUSOS1()
├ variables [2]
│ ├ moai_ReLU[1]
│ └ moai_z[1]
└ constraints [4]
  ├ moai_ReLU[1] ≥ 0
  ├ moai_z[1] ≥ 0
  ├ moai_Affine[1] - moai_ReLU[1] + moai_z[1] = 0
  └ [moai_ReLU[1], moai_z[1]] ∈ MathOptInterface.SOS1{Float64}([1.0, 2.0])

Epigraph formulations

The nice thing about ICNNs is that we can formulate their epigraph and avoid adding binary variables to the model. For that, we can use ReLUEpigraph.

# A small ICNN with a scalar input and a scalar output.
chain = InputConvexChain(
    InputConvex((1, 1) => 3, Flux.relu),
    InputConvex((3, 1) => 1, Flux.relu),
)
model = Model(HiGHS.Optimizer)
set_silent(model)
@variable(model, x[1:1])
# Use the epigraph formulation for every ReLU instead of a formulation with
# binary variables or SOS1 constraints.
config = Dict(Flux.relu => MathOptAI.ReLUEpigraph)
y, _ = MathOptAI.add_predictor(model, chain, x; config)
# Minimize the output so that the solver pushes `y` down towards the network's
# value at the fixed input.
# NOTE(review): assumes the epigraph relaxation is tight at the minimum — confirm.
@objective(model, Min, only(y))
# Sweep the input over a grid of integer points and record the optimal objective
# value at each one.
x_value, y_value = -20:20, Float64[]
for xi in x_value
    fix(x[1], xi)
    optimize!(model)
    assert_is_solved_and_feasible(model)
    push!(y_value, objective_value(model))
end
Plots.plot(x_value, y_value)
Example block output

As expected, the value of y is convex with respect to x.


This page was generated using Literate.jl.