Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ITensors] Improve type stability of svdMPO and qn_svdMPO #1183

Merged
14 changes: 9 additions & 5 deletions src/physics/autompo/opsum_to_mpo.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
function svdMPO(os::OpSum{C}, sites; kwargs...)::MPO where {C}
function svdMPO(ValType::Type{<:Number}, os::OpSum{C}, sites; kwargs...)::MPO where {C}
terasakisatoshi marked this conversation as resolved.
Show resolved Hide resolved
mindim::Int = get(kwargs, :mindim, 1)
maxdim::Int = get(kwargs, :maxdim, 10000)
cutoff::Float64 = get(kwargs, :cutoff, 1E-15)

N = length(sites)

ValType = determineValType(terms(os))

Vs = [Matrix{ValType}(undef, 1, 1) for n in 1:N]
Vs = Matrix{ValType}[Matrix{ValType}(undef, 1, 1) for n in 1:N]
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
terasakisatoshi marked this conversation as resolved.
Show resolved Hide resolved
tempMPO = [MatElem{Scaled{C,Prod{Op}}}[] for n in 1:N]

function crosses_bond(t::Scaled{C,Prod{Op}}, n::Int) where {C}
Expand Down Expand Up @@ -117,7 +115,7 @@ function svdMPO(os::OpSum{C}, sites; kwargs...)::MPO where {C}
end

#
# Special handling of starting and
# Special handling of starting and
# ending identity operators:
#
idM = zeros(ValType, dim(ll), dim(rl))
Expand All @@ -138,3 +136,9 @@ function svdMPO(os::OpSum{C}, sites; kwargs...)::MPO where {C}

return H
end #svdMPO

function svdMPO(os::OpSum{C}, sites; kwargs...)::MPO where {C}
# Function barrier to improve type stability
ValType = determineValType(terms(os))
return svdMPO(ValType, os, sites; kwargs...)
end
12 changes: 8 additions & 4 deletions src/physics/autompo/opsum_to_mpo_qn.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
function qn_svdMPO(os::OpSum{C}, sites; kwargs...)::MPO where {C}
function qn_svdMPO(ValType::Type{<:Number}, os::OpSum{C}, sites; kwargs...)::MPO where {C}
terasakisatoshi marked this conversation as resolved.
Show resolved Hide resolved
mindim::Int = get(kwargs, :mindim, 1)
maxdim::Int = get(kwargs, :maxdim, typemax(Int))
cutoff::Float64 = get(kwargs, :cutoff, 1E-15)

N = length(sites)

ValType = determineValType(terms(os))

Vs = [Dict{QN,Matrix{ValType}}() for n in 1:(N + 1)]
Vs = Dict{QN,Matrix{ValType}}[Dict{QN,Matrix{ValType}}() for n in 1:(N + 1)]
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
terasakisatoshi marked this conversation as resolved.
Show resolved Hide resolved
sparse_MPO = [QNMatElem{Scaled{C,Prod{Op}}}[] for n in 1:N]

function crosses_bond(t::Scaled{C,Prod{Op}}, n::Int)
Expand Down Expand Up @@ -251,3 +249,9 @@ function qn_svdMPO(os::OpSum{C}, sites; kwargs...)::MPO where {C}

return H
end #qn_svdMPO

function qn_svdMPO(os::OpSum{C}, sites; kwargs...)::MPO where {C}
# Function barrier to improve type stability
ValType = determineValType(terms(os))
return qn_svdMPO(ValType, os, sites; kwargs...)
end