diff --git a/src/atoms/IndexAtom.jl b/src/atoms/IndexAtom.jl index 49162e000..b83914929 100644 --- a/src/atoms/IndexAtom.jl +++ b/src/atoms/IndexAtom.jl @@ -109,20 +109,22 @@ function Base.getindex(x::AbstractExpr, rows::AbstractVector{<:Real}, col::Real) return getindex(x, rows, col:col) end -function Base.getindex( - x::AbstractExpr, - I::Union{AbstractMatrix{Bool},<:BitMatrix}, -) +function Base.getindex(x::AbstractExpr, I::AbstractMatrix{Bool}) return [xi for (xi, ii) in zip(x, I) if ii] end -function Base.getindex( - x::AbstractExpr, - I::Union{<:AbstractVector{Bool},<:BitVector}, -) +function Base.getindex(x::AbstractExpr, I::AbstractVector{Bool}) return [xi for (xi, ii) in zip(x, I) if ii] end +function Base.getindex(x::AbstractExpr, ind::BitVector) + return getindex(x, findall(ind)) +end + +function Base.getindex(x::AbstractExpr, ind::BitMatrix) + return getindex(x, LinearIndices(ind)[ind]) +end + # All rows and columns Base.getindex(x::AbstractExpr, ::Colon, ::Colon) = x diff --git a/src/atoms/VcatAtom.jl b/src/atoms/VcatAtom.jl index 7b97f79f9..adfe022d7 100644 --- a/src/atoms/VcatAtom.jl +++ b/src/atoms/VcatAtom.jl @@ -119,6 +119,10 @@ function Base.getindex(x::VcatAtom, inds::AbstractVector{<:Real}) return IndexAtom(remaining, inds) end +function Base.getindex(x::VcatAtom, inds::BitVector) + return getindex(x, findall(inds)) +end + function Base.getindex(x::VcatAtom, inds::AbstractVector{Bool}) - return getindex(x, first.(filter!(last, collect(enumerate(inds))))) + return getindex(x, convert(BitVector, inds)) end diff --git a/test/test_atoms.jl b/test/test_atoms.jl index 6e38b06a0..1243f8a3f 100644 --- a/test/test_atoms.jl +++ b/test/test_atoms.jl @@ -535,22 +535,25 @@ function test_IndexAtom() Convex.set_value!(x, [1 3; 2 4]) @test Convex.evaluate.(z) == [1, 2, 4] # Base.getindex(x::AbstractExpr, I::BitVector) - y = BitVector([true, false, true]) - x = Variable(3) - z = x[y] - @test string(z) == string([x[1], x[3]]) - @test z isa Vector{Convex.IndexAtom} - @test length(z) == 2 - Convex.set_value!(x, [1, 2, 3]) - @test Convex.evaluate.(z) == [1, 3] + target = """ + variables: x1, x2, x3 + minobjective: [1.0 * x1, 1.0 * x3] + """ + _test_atom(target) do context + x = Variable(3) + y = BitVector([true, false, true]) + return x[y] + end # Base.getindex(x::AbstractExpr, I::BitMatrix) - y = BitMatrix([true false; true true]) - x = Variable(2, 2) - z = x[y] - @test z isa Vector{Convex.IndexAtom} - @test length(z) == 3 - Convex.set_value!(x, [1 3; 2 4]) - @test Convex.evaluate.(z) == [1, 2, 4] + target = """ + variables: x1, x2, x3, x4 + minobjective: [1.0 * x1, 1.0 * x2, 1.0 * x4] + """ + _test_atom(target) do context + x = Variable(2, 2) + y = BitMatrix([true false; true true]) + return x[y] + end return end