Skip to content

Commit

Permalink
fix Aqua's reported piracies and method ambiguities (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
Omar-Elrefaei authored Nov 21, 2024
1 parent 5664440 commit dd7d140
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 10 deletions.
13 changes: 7 additions & 6 deletions src/QSymbolicsBase/basic_ops_homogeneous.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ arguments(x::SScaled) = [x.coeff,x.obj]
operation(x::SScaled) = *
head(x::SScaled) = :*
children(x::SScaled) = [:*,x.coeff,x.obj]
function Base.:(*)(c, x::Symbolic{T}) where {T<:QObj}
function Base.:(*)(c::U, x::Symbolic{T}) where {U<:Union{Number, Symbolic{<:Number}},T<:QObj}
if (isa(c, Number) && iszero(c)) || iszero(x)
SZero{T}()
elseif _isone(c)
Expand All @@ -40,9 +40,9 @@ function Base.:(*)(c, x::Symbolic{T}) where {T<:QObj}
SScaled{T}(c, x)
end
end
Base.:(*)(x::Symbolic{T}, c) where {T<:QObj} = c*x
Base.:(*)(x::Symbolic{T}, c::Number) where {T<:QObj} = c*x
Base.:(*)(x::Symbolic{T}, y::Symbolic{S}) where {T<:QObj,S<:QObj} = throw(ArgumentError("multiplication between $(typeof(x)) and $(typeof(y)) is not defined; maybe you are looking for a tensor product `tensor`"))
Base.:(/)(x::Symbolic{T}, c) where {T<:QObj} = iszero(c) ? throw(DomainError(c,"cannot divide QSymbolics expressions by zero")) : (1/c)*x
Base.:(/)(x::Symbolic{T}, c::Number) where {T<:QObj} = iszero(c) ? throw(DomainError(c,"cannot divide QSymbolics expressions by zero")) : (1/c)*x
basis(x::SScaled) = basis(x.obj)

const SScaledKet = SScaled{AbstractKet}
Expand Down Expand Up @@ -94,13 +94,13 @@ arguments(x::SAdd) = x._arguments_precomputed
operation(x::SAdd) = +
head(x::SAdd) = :+
children(x::SAdd) = [:+; x._arguments_precomputed]
function Base.:(+)(xs::Vararg{Symbolic{T},N}) where {T<:QObj,N}
function Base.:(+)(x::Symbolic{T}, xs::Vararg{Symbolic{T}, N}) where {T<:QObj, N}
xs = (x, xs...)
xs = collect(xs)
f = first(xs)
nonzero_terms = filter!(x->!iszero(x),xs)
isempty(nonzero_terms) ? f : SAdd{T}(countmap_flatten(nonzero_terms, SScaled{T}))
end
Base.:(+)(xs::Vararg{Symbolic{<:QObj},0}) = 0 # to avoid undefined type parameters issue in the above method
basis(x::SAdd) = basis(first(x.dict).first)

const SAddBra = SAdd{AbstractBra}
Expand Down Expand Up @@ -137,7 +137,8 @@ arguments(x::SMulOperator) = x.terms
operation(x::SMulOperator) = *
head(x::SMulOperator) = :*
children(x::SMulOperator) = [:*;x.terms]
function Base.:(*)(xs::Symbolic{AbstractOperator}...)
function Base.:(*)(x::Symbolic{AbstractOperator}, xs::Vararg{Symbolic{AbstractOperator}, N}) where {N}
xs = (x, xs...)
zero_ind = findfirst(x->iszero(x), xs)
if isnothing(zero_ind)
if any(x->!(samebases(basis(x),basis(first(xs)))),xs)
Expand Down
2 changes: 2 additions & 0 deletions src/QSymbolicsBase/basic_superops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ kraus(xs::Symbolic{AbstractOperator}...) = KrausRepr(collect(xs))
basis(x::KrausRepr) = basis(first(x.krausops))
Base.:(*)(sop::KrausRepr, op::Symbolic{AbstractOperator}) = (+)((i*op*dagger(i) for i in sop.krausops)...)
Base.:(*)(sop::KrausRepr, k::Symbolic{AbstractKet}) = (+)((i*SProjector(k)*dagger(i) for i in sop.krausops)...)
Base.:(*)(sop::KrausRepr, k::SZeroOperator) = SZeroOperator()
Base.:(*)(sop::KrausRepr, k::SZeroKet) = SZeroOperator()
Base.show(io::IO, x::KrausRepr) = print(io, "𝒦("*join([symbollabel(i) for i in x.krausops], ",")*")")

##
Expand Down
43 changes: 39 additions & 4 deletions test/test_aqua.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,42 @@
@testitem "Aqua" tags=[:aqua] begin
using Aqua
Aqua.test_all(QuantumSymbolics,
ambiguities=(;broken=true),
piracies=(;broken=true),
)

# Add any new types needed to QObj, or here if QObj if not appropriate.
# Add types from elsewhere in the ecosystem here or preferably to QObj
own_types = [Base.uniontypes(QObj)...,]
own_types_union = Union{SymQObj,}

Aqua.test_all(QuantumSymbolics, piracies=(;treat_as_own=own_types))

function normalize_arguments(method)
args = Base.unwrap_unionall(method.sig).types[2:end]
normalized_args = []
# handle few edge cases specific to our analysis
for arg in args
# mutation and order of if-conditions is intedtional here
if (arg isa UnionAll) && (arg.body <: Type) arg = arg.body.parameters[1] end
if (arg isa Core.TypeofVararg) arg = arg.T end
if (arg isa TypeVar) arg = arg.ub end
push!(normalized_args, arg)
end
return normalized_args
end

# Custom type-piracy detection, to catch uses of QuantumInterface types without a Symbolic
filtered_piracies = filter(Aqua.Piracy.hunt(QuantumSymbolics)) do m
!any(normalize_arguments(m) .<: own_types_union)
end

aqua_piracies = Aqua.Piracy.hunt(QuantumSymbolics, treat_as_own=own_types)
internally_detected_piracies = setdiff(filtered_piracies, aqua_piracies)
if !isempty(internally_detected_piracies)
printstyled(
stderr,
"Internally flagged possible type-piracy:\n";
color = Base.warn_color()
)
show(stderr, MIME"text/plain"(), internally_detected_piracies)
println(stderr, "\n")
end
@test isempty(internally_detected_piracies)
end

0 comments on commit dd7d140

Please sign in to comment.