Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inexact Krylov for linear response #1027

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
201 changes: 201 additions & 0 deletions examples/inexact_Krylov.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
# test inexact GMRES for linear response
using ASEconvert
using DFTK
using JLD2
using LinearMaps
using LinearAlgebra
using Printf
using Dates
using Random
using ForwardDiff

disable_threading()

println("------ Setting up model ... ------")
repeat = 2
mixing = KerkerMixing()
tol = 1e-12
Ecut =15
kgrid =(1, 3, 3)

system = ase.build.bulk("Al", cubic=true) * pytuple((repeat, 1, 1))
system = pyconvert(AbstractSystem, system)
system = attach_psp(system; Al="hgh/pbe/Al-q3")
model = model_PBE(system, temperature=0.001, symmetries=false)
basis = PlaneWaveBasis(model; Ecut=Ecut, kgrid=kgrid)
println(show(stdout, MIME("text/plain"), basis))
println("------ Running SCF ... ------")
DFTK.reset_timer!(DFTK.timer)
scfres = self_consistent_field(basis; tol=tol, mixing=mixing)
println(DFTK.timer)

println("------ Computing rhs ... ------")
ρ, ψ, ham, basis, occupation, εF, eigenvalues = scfres.ρ, scfres.ψ, scfres.ham, scfres.basis, scfres.occupation, scfres.εF, scfres.eigenvalues
num_kpoints = length(basis.kpoints)
positions = model.positions
lattice = model.lattice
atoms = model.atoms
R = [zeros(3) for pos in positions]
Random.seed!(1234)
for iR in 1:length(R)
R[iR] = -ones(3) + 2 * rand(3)
end
function V1(ε)
T = typeof(ε)
pos = positions + ε * R
modelV = Model(Matrix{T}(lattice), atoms, pos; model_name="potential",
terms=[DFTK.AtomicLocal(), DFTK.AtomicNonlocal()], symmetries=false)
basisV = PlaneWaveBasis(modelV; Ecut, kgrid)
jambon = Hamiltonian(basisV)
DFTK.total_local_potential(jambon)
end
δV = ForwardDiff.derivative(V1, 0.0)
println("||δV|| = ", norm(δV))
flush(stdout)
DFTK.reset_timer!(DFTK.timer)
δρ0 = apply_χ0(ham, ψ, occupation, εF, eigenvalues, δV; tol=1e-16)
println(DFTK.timer)
println("||δρ0|| = ", norm(δρ0))

# setup's for running inexact GMRES for linear response
adaptive=true # other options: "D10", "D100", "D10_n" for fixed CG tolerances
CG_tol_scale_choice="hdmd" # other options: "agr", "hdmd", "grt", with increasingly tighter CG tolerances (i.e., gives more accurate results)
precon = false
restart = 20
maxiter = 100
tol = 1e-9

apply_χ0_info = DFTK.get_apply_χ0_info(ham, ψ, occupation, εF, eigenvalues; CG_tol_type= (adaptive == true) ? CG_tol_scale_choice : "plain")
CG_tol_scale = apply_χ0_info.CG_tol_scale
Nocc_ks = [length(CG_tol_scale[ik]) for ik in 1:num_kpoints]
Nocc = sum(Nocc_ks)

# The actual computations are only several lines of code
# most of the code here is for printing, debugging, and saving intermediate results
# maybe they should be wrapped in chi0.jl and use a verbose flag to control printing
normδV_all = Float64[]
tol_sternheimer_all = Float64[]
CG_niters_all = Vector{Vector{Int64}}[]
CG_xs_all = Vector{Vector{Any}}[]

inds = [0, 1, 1] # i, ik, n

function sternheimer_callback(CG_niters, CG_xs)
function callback(info)
if inds[3] > Nocc_ks[inds[2]]
inds[3] = 1
inds[2] += 1
end
CG_niters[inds[2]][inds[3]] = info.res.n_iter
push!(CG_xs[inds[2]], info.res.x)
inds[3] += 1
end
end

pack(δρ) = vec(δρ)
unpack(δρ) = reshape(δρ, size(ρ))

function operators_a(tol_sternheimer)
function eps_fun(δρ)
δρ = unpack(δρ)
δV = apply_kernel(basis, δρ; ρ)

inds[1] += 1
inds[2:3] = [1, 1]
push!(normδV_all, norm(DFTK.symmetrize_ρ(basis, δV)))
if adaptive == true && CG_tol_scale_choice == "grt"
tol_sternheimer = tol_sternheimer ./ (2*normδV_all[end])
end
push!(tol_sternheimer_all, tol_sternheimer)
CG_niters = [zeros(Int64, Nocc_ks[i]) for i in 1:num_kpoints]
CG_xs = [ [] for _ in 1:num_kpoints ]

println("---- τ_CG's used for each Sternheimer equation (row) of each k-point (column) ----")
τ_CG_table = [max.(0.5*eps(Float64), tol_sternheimer ./ CG_tol_scale[ik]) for ik in 1:num_kpoints]
@printf("| %-7s ", "k-point")
for n in 1:maximum(Nocc_ks)
@printf("| %-8d ", n)
end
@printf("|\n")
for (k, row) in enumerate(τ_CG_table)
@printf("| %-7d ", k)
for τ in row[1:end]
@printf("| %-8.2e ", τ)
end
@printf("|\n")
end
@printf("| %-10s | %-10s | %-10s | %-10s |\n", "τ_i", "min τ_CG", "mean τ_CG", "max τ_CG")
@printf("| %-10.3e | %-10.3e | %-10.3e | %-10.3e |\n\n", tol_sternheimer, minimum(reduce(vcat, τ_CG_table)), exp(sum(log.(reduce(vcat, τ_CG_table)))/Nocc), maximum(reduce(vcat, τ_CG_table)))
flush(stdout)

t1 = Dates.now()
χ0δV = apply_χ0(ham, ψ, occupation, εF, eigenvalues, δV; tol=tol_sternheimer, callback=sternheimer_callback(CG_niters,CG_xs), apply_χ0_info=apply_χ0_info)
t2 = Dates.now()

push!(CG_niters_all, CG_niters)
push!(CG_xs_all, CG_xs)
println("no. CG iters for each Sternheimer equation (row) of each k-point (column):")
@printf("| %-7s ", "k-point")
for n in 1:maximum(Nocc_ks)
@printf("| %-3d ", n)
end
@printf("|\n")
for (k, row) in enumerate(CG_niters)
@printf("| %-7d ", k)
for niters in row[1:end]
@printf("| %-3d ", niters)
end
@printf("|\n")
end
@printf("| %-10s | %-10s | %-10s | %-10s | %-10s |\n", "min", "mean", "max", "sum", "total")
@printf("| %-10d | %-10.3f | %-10d | %-10d | %-10d |\n\n", minimum(reduce(vcat, CG_niters)), sum(reduce(vcat, CG_niters)) / Nocc, maximum(reduce(vcat, CG_niters)), sum(reduce(vcat, CG_niters)), sum(sum.(sum(CG_niters_all))))
println("χ0Time = ", canonicalize(t2 - t1), ", time now: ", Dates.format(t2, "yyyy-mm-dd HH:MM:SS"))
flush(stdout)
#println("ave CG iters = ", sum(reduce(vcat, CG_niters)) / Nocc)

if precon
return pack(DFTK.mix_density(mixing, basis, δρ - χ0δV))
else
return pack(δρ - χ0δV)
end
end
return LinearMap(eps_fun, prod(size(δρ0)))
end

println("----- running GMRES: tol=", tol, ", restart=", restart, ", adaptive=", adaptive, ", CG_tol_scale_choice=", CG_tol_scale_choice, " -----")
Pδρ0 = δρ0
if precon
Pδρ0 = DFTK.mix_density(mixing, basis, δρ0)
end
println("||Pδρ0|| = ", norm(Pδρ0))
# if the first argument of DFTK.gmres is a function, then each iteration the effective matrix is changed (here, due to inexactly computed mat-vec products)
# if the first argument of DFTK.gmres is a matrix, then the matrix is fixed
DFTK.reset_timer!(DFTK.timer)
if adaptive == "D10"
results_a = DFTK.gmres(operators_a(tol / 10), pack(Pδρ0); restart=restart, tol=tol, verbose=1, debug=true, maxiter=maxiter)
elseif adaptive == "D10_n"
results_a = DFTK.gmres(operators_a(tol / 10 / norm(Pδρ0)), pack(Pδρ0); restart=restart, tol=tol, verbose=1, debug=true, maxiter=maxiter)
elseif adaptive == "D100"
results_a = DFTK.gmres(operators_a(tol / 100), pack(Pδρ0); restart=restart, tol=tol, verbose=1, debug=true, maxiter=maxiter)
elseif adaptive == true
results_a = DFTK.gmres(operators_a, pack(Pδρ0); restart=restart, tol=tol, verbose=1, debug=true, maxiter=maxiter)
else
error("Invalid adaptive choice")
end
println(DFTK.timer)


# define the "exact" application of the dielectric adjoint operator
# by using very tight CG tolerances
# this is also how `eps_fun` should have looked like after removing all the printing and debugging code...
function eps_fun_exact(δρ)
normδρ = norm(δρ)
δρ = δρ ./ normδρ
δρ = unpack(δρ)
δV = apply_kernel(basis, δρ; ρ)

χ0δV = apply_χ0(ham, ψ, occupation, εF, eigenvalues, δV; tol=1e-16)
pack(δρ - χ0δV) .* normδρ
end

println("true residual = ", norm(eps_fun_exact(results_a.x[:, end]) - pack(δρ0)), "\n")
1 change: 1 addition & 0 deletions src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ include("postprocess/dos.jl")
export compute_χ0
export apply_χ0
include("response/cg.jl")
include("response/gmres.jl")
include("response/chi0.jl")
include("response/hessian.jl")
export compute_current
Expand Down
113 changes: 104 additions & 9 deletions src/response/chi0.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@
callback=identity, cg_callback=identity,
ψk_extra=zeros_like(ψk, size(ψk, 1), 0), εk_extra=zeros(0),
Hψk_extra=zeros_like(ψk, size(ψk, 1), 0), tol=1e-9)
# do not use tol smaller than eps(T)/2
tol = max(0.5*eps(eltype(ε)), tol)

basis = Hk.basis
kpoint = Hk.kpoint

Expand Down Expand Up @@ -293,8 +296,8 @@
For phonon, `δHψ[ik]` is ``δH·ψ_{k-q}``, expressed in `basis.kpoints[ik]`.
"""
function compute_δψ!(δψ, basis::PlaneWaveBasis{T}, H, ψ, εF, ε, δHψ, ε_minus_q=ε;
ψ_extra=[zeros_like(ψk, size(ψk,1), 0) for ψk in ψ],
q=zero(Vec3{T}), kwargs_sternheimer...) where {T}
ψ_extra=[zeros_like(ψk, size(ψk,1), 0) for ψk in ψ], q=zero(Vec3{T}),
CG_tol_scale=nothing, kwargs_sternheimer...) where {T}
# We solve the Sternheimer equation
# (H_k - ε_{n,k-q}) δψ_{n,k} = - (1 - P_{k}) δHψ_{n, k-q},
# where P_{k} is the projector on ψ_{k} and with the conventions:
Expand All @@ -307,6 +310,10 @@
smearing = model.smearing
filled_occ = filled_occupation(model)

flag = !isnothing(CG_tol_scale)
if flag
tol_sternheimer = kwargs_sternheimer[:tol]
end
# Compute δψnk band per band
for ik = 1:length(ψ)
Hk = H[ik]
Expand All @@ -333,7 +340,10 @@
δψk[:, n] .+= ψk[:, m] .* αmn .* dot(ψk[:, m], δHψ[ik][:, n])
end

# Sternheimer contribution
# Sternheimer contribution with adaptive CG tolerance
if flag
kwargs_sternheimer = merge(kwargs_sternheimer, Dict(:tol => tol_sternheimer / CG_tol_scale[ik][n]))
end
δψk[:, n] .+= sternheimer_solver(Hk, ψk, εk_minus_q[n], δHψ[ik][:, n]; ψk_extra,
εk_extra, Hψk_extra, kwargs_sternheimer...)
end
Expand Down Expand Up @@ -390,12 +400,91 @@
(; δψ, δoccupation, δεF)
end

function get_apply_χ0_info(ham, ψ, occupation, εF::T, eigenvalues;
occupation_threshold=default_occupation_threshold(T),
q=zero(Vec3{eltype(ham.basis)}),CG_tol_type="hdmd") where {T}

CG_tol_type = lowercase(string(CG_tol_type))

basis = ham.basis
num_kpoints = length(basis.kpoints)
k_to_k_minus_q = k_to_kpq_permutation(basis, -q)

mask_occ = map(occk -> findall(occnk -> abs(occnk) ≥ occupation_threshold, occk), occupation)
mask_extra = map(occk -> findall(occnk -> abs(occnk) < occupation_threshold, occk), occupation)

ψ_occ = [ψ[ik][:, maskk] for (ik, maskk) in enumerate(mask_occ)]
ψ_extra = [ψ[ik][:, maskk] for (ik, maskk) in enumerate(mask_extra)]
ε_occ = [eigenvalues[ik][maskk] for (ik, maskk) in enumerate(mask_occ)]

ε_minus_q_occ = [eigenvalues[k_to_k_minus_q[ik]][mask_occ[k_to_k_minus_q[ik]]]
for ik = 1:num_kpoints]

Nocc_ks = [length(ε_occ[ik]) for ik in 1:num_kpoints]
Nocc = sum(Nocc_ks)

# compute CG_tol_scale
fn_occ = [occupation[ik][maskk] for (ik, maskk) in enumerate(mask_occ)]
if CG_tol_type == "hdmd"
CG_tol_scale = [fn_occ[ik] * basis.kweights[ik] for ik in 1:num_kpoints] * Nocc * sqrt(prod(basis.fft_size)) / basis.model.unit_cell_volume
elseif CG_tol_type == "grt"
kcoef = zeros(num_kpoints)
for k in 1:num_kpoints
accum = zeros(basis.fft_size)
for n in 1:Nocc_ks[k]
accum += (abs2.(real.(ifft(basis, basis.kpoints[k], ψ[k][:, n]))))
end
kcoef[k] = sqrt(maximum(accum)) * basis.kweights[k]
end

Check warning on line 438 in src/response/chi0.jl

View check run for this annotation

Codecov / codecov/patch

src/response/chi0.jl#L430-L438

Added lines #L430 - L438 were not covered by tests

CG_tol_scale = [fn_occ[ik] * kcoef[ik] for ik in 1:num_kpoints] * sqrt(Nocc) * sqrt(prod(basis.fft_size)) / sqrt(basis.model.unit_cell_volume)

Check warning on line 440 in src/response/chi0.jl

View check run for this annotation

Codecov / codecov/patch

src/response/chi0.jl#L440

Added line #L440 was not covered by tests
else
CG_tol_scale = [[1.0 for _ in 1:Nocc_ks[ik]] for ik in 1:num_kpoints]
if !occursin(CG_tol_type, "agrplain1.0")
@warn("CG_tol_type is not recognized, set CG_tol_scale to 1.0 for all bands")

Check warning on line 444 in src/response/chi0.jl

View check run for this annotation

Codecov / codecov/patch

src/response/chi0.jl#L442-L444

Added lines #L442 - L444 were not covered by tests
end
end

(; k_to_k_minus_q, mask_occ, ψ_occ, ψ_extra, ε_occ, ε_minus_q_occ, CG_tol_scale)
end

@views @timing function apply_χ0_4P(ham, occupation, εF, δHψ, apply_χ0_info::NamedTuple;

Check warning on line 451 in src/response/chi0.jl

View check run for this annotation

Codecov / codecov/patch

src/response/chi0.jl#L451

Added line #L451 was not covered by tests
q=zero(Vec3{eltype(ham.basis)}), kwargs_sternheimer...)

basis = ham.basis
mask_occ = apply_χ0_info.mask_occ
k_to_k_minus_q = apply_χ0_info.k_to_k_minus_q
ψ_occ = apply_χ0_info.ψ_occ
ε_occ = apply_χ0_info.ε_occ

δHψ_minus_q_occ = [δHψ[ik][:, mask_occ[k_to_k_minus_q[ik]]] for ik = 1:length(basis.kpoints)]

δoccupation = zero.(occupation)
if iszero(q)
δocc_occ = [δoccupation[ik][maskk] for (ik, maskk) in enumerate(mask_occ)]
(; δεF) = compute_δocc!(δocc_occ, basis, ψ_occ, εF, ε_occ, δHψ_minus_q_occ)
else
# When δH is not periodic, δH ψnk is a Bloch wave at k+q and ψnk at k,
# so that δεnk = <ψnk|δH|ψnk> = 0 and there is no occupation shift
δεF = zero(εF)

Check warning on line 469 in src/response/chi0.jl

View check run for this annotation

Codecov / codecov/patch

src/response/chi0.jl#L469

Added line #L469 was not covered by tests
end

# Then we compute δψ (again in-place into a zero-padded array) with elements of
# `basis.kpoints` that are equivalent to `k+q`.
δψ = zero.(δHψ)
δψ_occ = [δψ[ik][:, maskk] for (ik, maskk) in enumerate(mask_occ[k_to_k_minus_q])]
compute_δψ!(δψ_occ, ham.basis, ham.blocks, ψ_occ, εF, ε_occ, δHψ_minus_q_occ, apply_χ0_info.ε_minus_q_occ;
apply_χ0_info.ψ_extra, q, apply_χ0_info.CG_tol_scale, kwargs_sternheimer...)

(; δψ, δoccupation, δεF)
end

"""
Get the density variation δρ corresponding to a potential variation δV.
"""
function apply_χ0(ham, ψ, occupation, εF::T, eigenvalues, δV::AbstractArray{TδV};
occupation_threshold=default_occupation_threshold(TδV),
q=zero(Vec3{eltype(ham.basis)}), kwargs_sternheimer...) where {T, TδV}
occupation_threshold=default_occupation_threshold(TδV), q=zero(Vec3{eltype(ham.basis)}),
apply_χ0_info=nothing, kwargs_sternheimer...) where {T, TδV}

basis = ham.basis

Expand All @@ -414,15 +503,21 @@
# For phonon calculations, assemble
# δHψ_k = δV_{q} · ψ_{k-q}.
δHψ = multiply_ψ_by_blochwave(basis, ψ, δV, q)
(; δψ, δoccupation) = apply_χ0_4P(ham, ψ, occupation, εF, eigenvalues, δHψ;
occupation_threshold, q, kwargs_sternheimer...)
if isnothing(apply_χ0_info)
(; δψ, δoccupation) = apply_χ0_4P(ham, ψ, occupation, εF, eigenvalues, δHψ;
occupation_threshold, q, kwargs_sternheimer...)
else
(; δψ, δoccupation) = apply_χ0_4P(ham, occupation, εF, δHψ, apply_χ0_info;
q, kwargs_sternheimer...)
end

δρ = compute_δρ(basis, ψ, δψ, occupation, δoccupation; occupation_threshold, q)

δρ * normδV
end

function apply_χ0(scfres, δV; kwargs_sternheimer...)
function apply_χ0(scfres, δV; apply_χ0_info=nothing, kwargs_sternheimer...)
apply_χ0(scfres.ham, scfres.ψ, scfres.occupation, scfres.εF, scfres.eigenvalues, δV;
scfres.occupation_threshold, kwargs_sternheimer...)
occupation_threshold=scfres.occupation_threshold,
apply_χ0_info=apply_χ0_info, kwargs_sternheimer...)
end
Loading
Loading