Skip to content

Commit

Permalink
Merge pull request #199 from chelseas/fix_binary
Browse files Browse the repository at this point in the history
Adding constraints so that binary variables reflect correct activation status
  • Loading branch information
tomerarnon authored Nov 16, 2021
2 parents 0d9be34 + 92d7128 commit 237f092
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/optimization/utils/constraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ function encode_layer!(::AbstractLinearProgram, model::Model, layer::Layer{Id},
nothing
end

# All ReLU layers pass through this
function encode_layer!(LP::AbstractLinearProgram, model::Model, layer::Layer{ReLU}, args...)
encode_relu.(LP, model, args...)
nothing
Expand All @@ -103,8 +104,15 @@ function encode_layer!(SLP::SlackLP, model::Model, layer::Layer{Id}, ẑᵢ, z
return nothing
end

# need to fix δᵢⱼ for BoundedMixedIntegerLP and possibly other types
function encode_layer!(::BoundedMixedIntegerLP, model::Model, layer::Layer{Id}, ẑᵢ, zᵢ, δᵢ, args...)
@constraint(model, zᵢ .== ẑᵢ)
@constraint(model, δᵢ .== 1)
return nothing
end

function encode_ij(LP, model, i, j)
# where is this function used? Needs documentation.
L = model[:network].layers[i]
params = model_params(LP, model, i)
if L.activation isa Id
Expand Down Expand Up @@ -133,8 +141,10 @@ end
function encode_relu(::BoundedMixedIntegerLP, model, ẑᵢⱼ, zᵢⱼ, δᵢⱼ, l̂ᵢⱼ, ûᵢⱼ)
if l̂ᵢⱼ >= 0.0
@constraint(model, zᵢⱼ == ẑᵢⱼ)
@constraint(model, δᵢⱼ == 1)
elseif ûᵢⱼ <= 0.0
@constraint(model, zᵢⱼ == 0.0)
@constraint(model, δᵢⱼ == 0)
else
@constraints(model, begin
zᵢⱼ >= 0.0
Expand Down Expand Up @@ -164,11 +174,11 @@ function encode_relu(::TriangularRelaxedLP, model, ẑᵢⱼ, zᵢⱼ, l̂ᵢⱼ
end

function encode_relu(::LinearRelaxedLP, model, ẑᵢⱼ, zᵢⱼ, δᵢⱼ)
@constraint(model, zᵢⱼ == (δᵢⱼ ? ẑᵢⱼ : 0.0))
@constraint(model, zᵢⱼ == (δᵢⱼ ? ẑᵢⱼ : 0.0)) # in LinearRelaxedLP δᵢⱼ is a constant not a variable
end

function encode_relu(::StandardLP, model, ẑᵢⱼ, zᵢⱼ, δᵢⱼ)
if δᵢⱼ
if δᵢⱼ # in StandardLP δᵢⱼ is a constant, not a variable
@constraint(model, ẑᵢⱼ >= 0.0)
@constraint(model, zᵢⱼ == ẑᵢⱼ)
else
Expand Down

0 comments on commit 237f092

Please sign in to comment.