From 5c887ce05e64d6304c0268b40695c9ddb024dbc1 Mon Sep 17 00:00:00 2001 From: Bobby Date: Tue, 19 Nov 2024 23:58:02 -0500 Subject: [PATCH 1/3] implement truncation_error keyword arg for truncate! --- src/abstractmps.jl | 12 ++++++++++-- test/base/test_mps.jl | 10 ++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/abstractmps.jl b/src/abstractmps.jl index 4670be5..1580efe 100644 --- a/src/abstractmps.jl +++ b/src/abstractmps.jl @@ -1672,13 +1672,18 @@ provided as keyword arguments. Keyword arguments: * `site_range`=1:N - only truncate the MPS bonds between these sites +* `truncation_error` - if provided, will store the truncation error from all SVDs performed in a single call to `truncate!`. This should be a `Ref` type, for example `truncation_error = Ref{Float64}()`. It should be initialized to some value (likely 0.0, e.g., `truncation_error[] = 0.0`). """ function truncate!(M::AbstractMPS; alg="frobenius", kwargs...) return truncate!(Algorithm(alg), M; kwargs...) end function truncate!( - ::Algorithm"frobenius", M::AbstractMPS; site_range=1:length(M), kwargs... + ::Algorithm"frobenius", + M::AbstractMPS; + site_range=1:length(M), + truncation_error=nothing, + kwargs..., ) N = length(M) @@ -1690,7 +1695,10 @@ function truncate!( for j in reverse((first(site_range) + 1):last(site_range)) rinds = uniqueinds(M[j], M[j - 1]) ltags = tags(commonind(M[j], M[j - 1])) - U, S, V = svd(M[j], rinds; lefttags=ltags, kwargs...) + U, S, V, spec = svd(M[j], rinds; lefttags=ltags, kwargs...) + if !isnothing(truncation_error) + truncation_error[] += spec.truncerr + end M[j] = U M[j - 1] *= (S * V) setrightlim!(M, j) diff --git a/test/base/test_mps.jl b/test/base/test_mps.jl index 65eed0e..ff75ef1 100644 --- a/test/base/test_mps.jl +++ b/test/base/test_mps.jl @@ -755,6 +755,16 @@ end truncate!(M; site_range=3:7, maxdim=2) @test linkdims(M) == [2, 4, 2, 2, 2, 2, 8, 4, 2] end + + @testset "truncate! with truncation_error" begin + M = basicRandomMPS(10; dim=10) + truncation_error = Ref{Float64}() + truncation_error[] = 0.0 + truncate!(M, maxdim=3, cutoff=1E-3, truncation_error=truncation_error) + @test truncation_error[] > 0.0 + end + + end @testset "Other MPS methods" begin From 00f1764fa6ea21b7824ba66121a05fc9037d6725 Mon Sep 17 00:00:00 2001 From: Bobby Date: Wed, 1 Jan 2025 20:56:28 -0500 Subject: [PATCH 2/3] Refactor truncate! to return truncation error from each bond This commit refactors the original implemenation. `truncate!` now expects the user to pass a pointer to a vector of floats with as many elements as there are bonds in the MPS. It will then store the truncation error of each bond in the vector. The corresponding test was updated. The package was re-tested with the same results described in the PR (75165 passing and 33 broken for total of 75198 tests). The docstring of `truncate!` was updated to reflect the new behavior. --- src/abstractmps.jl | 10 +++++----- test/base/test_mps.jl | 12 +++++++----- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/abstractmps.jl b/src/abstractmps.jl index 1580efe..318e473 100644 --- a/src/abstractmps.jl +++ b/src/abstractmps.jl @@ -1672,7 +1672,7 @@ provided as keyword arguments. Keyword arguments: * `site_range`=1:N - only truncate the MPS bonds between these sites -* `truncation_error` - if provided, will store the truncation error from all SVDs performed in a single call to `truncate!`. This should be a `Ref` type, for example `truncation_error = Ref{Float64}()`. It should be initialized to some value (likely 0.0, e.g., `truncation_error[] = 0.0`). +* `truncation_errors` - if provided, will store the truncation error from each SVD performed in a single call to `truncate!`. This should be a `Ref` type, for example `truncation_errors = Ref{Vector{Float64}}()`. It should be initialized to some value (likely 0.0, e.g., `truncation_errors[] = zeros(nbonds)`). """ function truncate!(M::AbstractMPS; alg="frobenius", kwargs...) return truncate!(Algorithm(alg), M; kwargs...) @@ -1682,7 +1682,7 @@ function truncate!( ::Algorithm"frobenius", M::AbstractMPS; site_range=1:length(M), - truncation_error=nothing, + truncation_errors=nothing, kwargs..., ) N = length(M) @@ -1692,12 +1692,12 @@ function truncate!( orthogonalize!(M, last(site_range)) # Perform truncations in a right-to-left sweep - for j in reverse((first(site_range) + 1):last(site_range)) + for (i,j) in enumerate(reverse((first(site_range) + 1):last(site_range))) rinds = uniqueinds(M[j], M[j - 1]) ltags = tags(commonind(M[j], M[j - 1])) U, S, V, spec = svd(M[j], rinds; lefttags=ltags, kwargs...) - if !isnothing(truncation_error) - truncation_error[] += spec.truncerr + if !isnothing(truncation_errors) + truncation_errors[][i] = spec.truncerr end M[j] = U M[j - 1] *= (S * V) diff --git a/test/base/test_mps.jl b/test/base/test_mps.jl index ff75ef1..87c1452 100644 --- a/test/base/test_mps.jl +++ b/test/base/test_mps.jl @@ -756,12 +756,14 @@ end @test linkdims(M) == [2, 4, 2, 2, 2, 2, 8, 4, 2] end - @testset "truncate! with truncation_error" begin + @testset "truncate! with truncation_errors" begin + N = 10 + nbonds = N - 1 M = basicRandomMPS(10; dim=10) - truncation_error = Ref{Float64}() - truncation_error[] = 0.0 - truncate!(M, maxdim=3, cutoff=1E-3, truncation_error=truncation_error) - @test truncation_error[] > 0.0 + truncation_errors = Ref{Vector{Float64}}() + truncation_errors[] = fill(-1.0, nbonds) # set to something other than zero for test. + truncate!(M, maxdim=3, cutoff=1E-3, truncation_errors=truncation_errors) + @test all(truncation_errors[] .>= 0.0) end From 4a298393cf3f721515061429385d4469834ccaf8 Mon Sep 17 00:00:00 2001 From: NuclearPowerNerd <58567518+NuclearPowerNerd@users.noreply.github.com> Date: Mon, 13 Jan 2025 14:08:01 -0500 Subject: [PATCH 3/3] Apply suggestions from code review Change new kwarg name to `truncation_errors!` from `truncation_errors` Remove usage of `enumerate` Update docstring Co-authored-by: Matt Fishman --- src/abstractmps.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/abstractmps.jl b/src/abstractmps.jl index 318e473..6e12600 100644 --- a/src/abstractmps.jl +++ b/src/abstractmps.jl @@ -1672,7 +1672,7 @@ provided as keyword arguments. Keyword arguments: * `site_range`=1:N - only truncate the MPS bonds between these sites -* `truncation_errors` - if provided, will store the truncation error from each SVD performed in a single call to `truncate!`. This should be a `Ref` type, for example `truncation_errors = Ref{Vector{Float64}}()`. It should be initialized to some value (likely 0.0, e.g., `truncation_errors[] = zeros(nbonds)`). +* `truncation_errors!` - if provided, will store the truncation error from each SVD performed in a single call to `truncate!`. This should be a `Ref` type, for example `truncation_errors! = Ref{Vector{Float64}}()`, which will be overwritten in the function. """ function truncate!(M::AbstractMPS; alg="frobenius", kwargs...) return truncate!(Algorithm(alg), M; kwargs...) @@ -1682,7 +1682,7 @@ function truncate!( ::Algorithm"frobenius", M::AbstractMPS; site_range=1:length(M), - truncation_errors=nothing, + (truncation_errors!)=nothing, kwargs..., ) N = length(M) @@ -1692,7 +1692,9 @@ function truncate!( orthogonalize!(M, last(site_range)) # Perform truncations in a right-to-left sweep - for (i,j) in enumerate(reverse((first(site_range) + 1):last(site_range))) + js = reverse((first(site_range) + 1):last(site_range)) + for i in eachindex(js) + j = js[i] rinds = uniqueinds(M[j], M[j - 1]) ltags = tags(commonind(M[j], M[j - 1])) U, S, V, spec = svd(M[j], rinds; lefttags=ltags, kwargs...)