Skip to content

Commit

Permalink
Projected density of states (#1002)
Browse files Browse the repository at this point in the history
Co-authored-by: xquan818 <[email protected]>
  • Loading branch information
xuequan818 and xquan818 authored Oct 17, 2024
1 parent 5b5e214 commit d8a5454
Show file tree
Hide file tree
Showing 7 changed files with 266 additions and 44 deletions.
31 changes: 31 additions & 0 deletions examples/dos.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# # Densities of states (DOS)
# In this example, we'll plot the DOS, local DOS, and projected DOS of Silicon.
# DOS computation only supports finite temperature.

using DFTK
using Unitful
using Plots
using LazyArtifacts

## Define the geometry and pseudopotential
a = 10.26 # Silicon lattice constant in Bohr
lattice = a / 2 * [[0 1 1.0];
[1 0 1.0];
[1 1 0.0]]
Si = ElementPsp(:Si; psp=load_psp(artifact"pd_nc_sr_lda_standard_0.4.1_upf", "Si.upf"))
atoms = [Si, Si]
positions = [ones(3) / 8, -ones(3) / 8]

## Run SCF
model = model_LDA(lattice, atoms, positions; temperature=5e-3)
basis = PlaneWaveBasis(model; Ecut=15, kgrid=[4, 4, 4], symmetries_respect_rgrid=true)
scfres = self_consistent_field(basis, tol=1e-8)

## Plot the DOS
plot_dos(scfres)

## Plot the local DOS along one direction
plot_ldos(scfres; n_points=100, ldos_xyz=[:, 10, 10])

## Plot the projected DOS
plot_pdos(scfres; εrange=(-0.3, 0.5))
86 changes: 85 additions & 1 deletion ext/DFTKPlotsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module DFTKPlotsExt
using Brillouin: KPath
using DFTK
using DFTK: is_metal, data_for_plotting, spin_components, default_band_εrange
import DFTK: plot_dos, plot_bandstructure
import DFTK: plot_dos, plot_bandstructure, plot_ldos, plot_pdos
using Plots
using Unitful
using UnitfulAtomic
Expand Down Expand Up @@ -108,4 +108,88 @@ function plot_dos(basis, eigenvalues; εF=nothing, unit=u"hartree",
end
plot_dos(scfres; kwargs...) = plot_dos(scfres.basis, scfres.eigenvalues; scfres.εF, kwargs...)

function plot_ldos(basis, eigenvalues, ψ; εF=nothing, unit=u"hartree",
temperature=basis.model.temperature,
smearing=basis.model.smearing,
εrange=default_band_εrange(eigenvalues; εF),
n_points=1000, ldos_xyz=[:, 1, 1], kwargs...)
eshift = something(εF, 0.0)
εs = range(austrip.(εrange)..., length=n_points)

# Constant to convert from AU to the desired unit
to_unit = ustrip(auconvert(unit, 1.0))

# LDε has three dimensions (x, y, z)
# map on a single axis to plot the variation with εs
LDεs = dropdims.(compute_ldos.(εs, Ref(basis), Ref(eigenvalues), Ref(ψ); smearing, temperature); dims=4)
LDεs_slice = similar(LDεs[1], n_points, length(LDεs[1][ldos_xyz...]))
for (i, LDε) in enumerate(LDεs)
LDεs_slice[i, :] = LDε[ldos_xyz...]
end
p = heatmap(1:size(LDεs_slice, 2), (εs .- eshift) .* to_unit, LDεs_slice; kwargs...)
if !isnothing(εF)
Plots.hline!(p, [0.0], label="εF", color=:green, lw=1.5)
end

if isnothing(εF)
Plots.ylabel!(p, "eigenvalues ($(string(unit)))")
else
Plots.ylabel!(p, "eigenvalues -ε_F ($(string(unit)))")
end
p
end
plot_ldos(scfres; kwargs...) = plot_ldos(scfres.basis, scfres.eigenvalues, scfres.ψ; scfres.εF, kwargs...)

function plot_pdos(basis, eigenvalues, ψ, i, l,
psp, position, el::Symbol;
εF=nothing, unit=u"hartree",
temperature=basis.model.temperature,
smearing=basis.model.smearing,
εrange=default_band_εrange(eigenvalues; εF),
n_points=1000, p=nothing, kwargs...)
eshift = something(εF, 0.0)
εs = range(austrip.(εrange)..., length=n_points)

# Constant to convert from AU to the desired unit
to_unit = ustrip(auconvert(unit, 1.0))

# Calculate the projections of the atom with given i and l,
# and sum all angular momentums m=-l:l
pdos = dropdims(sum(compute_pdos(εs, basis, eigenvalues, ψ, i, l,
psp, position; temperature, smearing), dims=2); dims=2)
label = String(el) * "-" * psp.pswfc_labels[l+1][i]

# Plot pdos
p = something(p, Plots.plot(; kwargs...))
Plots.plot!(p, (εs .- eshift) .* to_unit, pdos; label)

p
end

function plot_pdos(scfres; kwargs...)
# Plot DOS
p = plot_dos(scfres; scfres.εF, kwargs...)

# TODO do the symmetrization instead of unfolding
scfres_unfold = DFTK.unfold_bz(scfres)
basis = scfres_unfold.basis
psp_groups = [group for group in basis.model.atom_groups
if basis.model.atoms[first(group)] isa ElementPsp]

# Plot PDOS for the first atom of each atom group
for group in psp_groups
psp = basis.model.atoms[first(group)].psp
position = basis.model.positions[first(group)]
el = basis.model.atoms[first(group)].symbol
for l = 0:psp.lmax
for i = 1:DFTK.count_n_pswfc_radial(psp, l)
plot_pdos(basis, scfres_unfold.eigenvalues, scfres_unfold.ψ,
i, l, psp, position, el; scfres.εF, p, kwargs...)
end
end
end

p
end

end
3 changes: 3 additions & 0 deletions src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,10 @@ export compute_stresses_cart
include("postprocess/stresses.jl")
export compute_dos
export compute_ldos
export compute_pdos
export plot_dos
export plot_ldos
export plot_pdos
include("postprocess/dos.jl")
export compute_χ0
export apply_χ0
Expand Down
70 changes: 70 additions & 0 deletions src/postprocess/dos.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
#
# LDOS (local density of states)
# LD(ε) = sum_n f_n' |ψn|^2 = sum_n δ(ε - ε_n) |ψn|^2
#
# PD(ε) = sum_n f_n' |<ψn,ϕ>|^2
# ϕ = ∑_R ϕilm(r-pos-R) is the periodized atomic wavefunction, obtained from the pseudopotential
# This is computed for a given (i,l) (eg i=2,l=2 for the 3p) and summed over all m

"""
Total density of states at energy ε
Expand Down Expand Up @@ -63,7 +67,73 @@ function compute_ldos(scfres::NamedTuple; ε=scfres.εF, kwargs...)
compute_ldos(ε, scfres.basis, scfres.eigenvalues, scfres.ψ; kwargs...)
end

"""
Projected density of states at energy ε for an atom with given i and l.
"""
function compute_pdos(ε, basis, eigenvalues, ψ, i, l, psp, position;
smearing=basis.model.smearing,
temperature=basis.model.temperature)
if (temperature == 0) || smearing isa Smearing.None
error("compute_pdos only supports finite temperature")
end
filled_occ = filled_occupation(basis.model)

projs = compute_pdos_projs(basis, eigenvalues, ψ, i, l, psp, position)

D = zeros(typeof(ε[1]), length(ε), 2l+1)
for (i, iε) in enumerate(ε)
for (ik, projk) in enumerate(projs)
@views for (iband, εnk) in enumerate(eigenvalues[ik])
enred = (εnk - iε) / temperature
D[i, :] .-= (filled_occ .* basis.kweights[ik] .* projk[iband, :]
./ temperature
.* Smearing.occupation_derivative(smearing, enred))
end
end
end
D = mpi_sum(D, basis.comm_kpts)
end

function compute_pdos(scfres::NamedTuple, iatom, i, l; ε=scfres.εF, kwargs...)
psp = scfres.basis.model.atoms[iatom].psp
position = scfres.basis.model.positions[iatom]
# TODO do the symmetrization instead of unfolding
scfres = unfold_bz(scfres)
compute_pdos(ε, scfres.basis, scfres.eigenvalues, scfres.ψ, i, l, psp, position; kwargs...)
end

# Build atomic orbitals projections projs[ik][iband,m] = |<ψnk, ϕ>|^2 for a single atom.
# TODO optimization ? accept a range of positions, in case we want to compute the PDOS
# for all atoms of the same type (and reuse the computation of the atomic orbitals)
function compute_pdos_projs(basis, eigenvalues, ψ, i, l, psp::NormConservingPsp, position)
# Precompute the form factors on all kpoints at once so we can better use the cache (memory-compute tradeoff).
# Revisit this (pass the cache around explicitly) if RAM becomes an issue.
G_plus_k_all = [Gplusk_vectors(basis, basis.kpoints[ik])
for ik = 1:length(basis.kpoints)]
G_plus_k_all_cart = [map(recip_vector_red_to_cart(basis.model), gpk)
for gpk in G_plus_k_all]

# Build form factors of pseudo-wavefunctions centered at 0.
fun(p) = eval_psp_pswfc_fourier(psp, i, l, p)
# form_factors_all[ik][p,m]
form_factors_all = build_form_factors(fun, l, G_plus_k_all_cart)

projs = Vector{Matrix}(undef, length(basis.kpoints))
for (ik, ψk) in enumerate(ψ)
structure_factor = [cis2pi(-dot(position, p)) for p in G_plus_k_all[ik]]
# TODO orthogonalize pseudo-atomic wave functions?
proj_vectors = structure_factor .* form_factors_all[ik] ./ sqrt(basis.model.unit_cell_volume)
projs[ik] = abs2.(ψk' * proj_vectors) # contract on p -> projs[ik][iband,m]
end

projs
end

"""
Plot the density of states over a reasonable range. Requires to load `Plots.jl` beforehand.
"""
function plot_dos end

function plot_ldos end

function plot_pdos end
39 changes: 27 additions & 12 deletions src/pseudo/NormConservingPsp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,24 @@ abstract type NormConservingPsp end
# description::String # Descriptive string

#### Methods:
# charge_ionic(psp::NormConservingPsp)
# has_valence_density(psp:NormConservingPsp)
# has_core_density(psp:NormConservingPsp)
# eval_psp_projector_real(psp::NormConservingPsp, i, l, r::Real)
# eval_psp_projector_fourier(psp::NormConservingPsp, i, l, p::Real)
# eval_psp_local_real(psp::NormConservingPsp, r::Real)
# eval_psp_local_fourier(psp::NormConservingPsp, p::Real)
# eval_psp_energy_correction(T::Type, psp::NormConservingPsp, n_electrons::Integer)
# charge_ionic(psp)
# has_valence_density(psp)
# has_core_density(psp)
# eval_psp_projector_real(psp, i, l, r::Real)
# eval_psp_projector_fourier(psp, i, l, p::Real)
# eval_psp_local_real(psp, r::Real)
# eval_psp_local_fourier(psp, p::Real)
# eval_psp_energy_correction(T::Type, psp, n_electrons::Integer)

#### Optional methods:
# eval_psp_density_valence_real(psp::NormConservingPsp, r::Real)
# eval_psp_density_valence_fourier(psp::NormConservingPsp, p::Real)
# eval_psp_density_core_real(psp::NormConservingPsp, r::Real)
# eval_psp_density_core_fourier(psp::NormConservingPsp, p::Real)
# eval_psp_density_valence_real(psp, r::Real)
# eval_psp_density_valence_fourier(psp, p::Real)
# eval_psp_density_core_real(psp, r::Real)
# eval_psp_density_core_fourier(psp, p::Real)
# eval_psp_pswfc_real(psp, i::Int, l::Int, p::Real)
# eval_psp_pswfc_fourier(psp, i::Int, l::Int, p::Real)
# count_n_pswfc(psp, l::Integer)
# count_n_pswfc_radial(psp, l::Integer)

"""
eval_psp_projector_real(psp, i, l, r)
Expand Down Expand Up @@ -203,3 +207,14 @@ function count_n_proj(psps, psp_positions)
sum(count_n_proj(psp) * length(positions)
for (psp, positions) in zip(psps, psp_positions))
end

count_n_pswfc_radial(psp::NormConservingPsp, l) = error("Pseudopotential $psp does not implement atomic wavefunctions.")

function count_n_pswfc_radial(psp::NormConservingPsp)
sum(l -> count_n_pswfc_radial(psp, l), 0:psp.lmax; init=0)::Int
end

count_n_pswfc(psp::NormConservingPsp, l) = count_n_pswfc_radial(psp, l) * (2l + 1)
function count_n_pswfc(psp::NormConservingPsp)
sum(l -> count_n_pswfc(psp, l), 0:psp.lmax; init=0)::Int
end
2 changes: 2 additions & 0 deletions src/pseudo/PspUpf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ function eval_psp_projector_fourier(psp::PspUpf, i, l, p::T)::T where {T<:Real}
hankel(rgrid, r2_proj, l, p)
end

count_n_pswfc_radial(psp::PspUpf, l) = length(psp.r2_pswfcs[l+1])

function eval_psp_pswfc_real(psp::PspUpf, i, l, r::T)::T where {T<:Real}
psp.r2_pswfcs_interp[l+1][i](r) / r^2
end
Expand Down
79 changes: 48 additions & 31 deletions src/terms/nonlocal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ end
# We compute the forces from the irreductible BZ; they are symmetrized later.
G_plus_k = Gplusk_vectors(basis, kpt)
G_plus_k_cart = to_cpu(Gplusk_vectors_cart(basis, kpt))
form_factors = build_form_factors(element.psp, G_plus_k_cart)
form_factors = build_projector_form_factors(element.psp, G_plus_k_cart)
for idx in group
r = model.positions[idx]
structure_factors = [cis2pi(-dot(p, r)) for p in G_plus_k]
Expand Down Expand Up @@ -173,7 +173,7 @@ function build_projection_vectors(basis::PlaneWaveBasis{T}, kpt::Kpoint,
for (psp, positions) in zip(psps, psp_positions)
# Compute position-independent form factors
G_plus_k_cart = to_cpu(Gplusk_vectors_cart(basis, kpt))
form_factors = build_form_factors(psp, G_plus_k_cart)
form_factors = build_projector_form_factors(psp, G_plus_k_cart)

# Combine with structure factors
for r in positions
Expand All @@ -193,47 +193,64 @@ function build_projection_vectors(basis::PlaneWaveBasis{T}, kpt::Kpoint,
end

"""
Build form factors (Fourier transforms of projectors) for an atom centered at 0.
Build form factors (Fourier transforms of projectors) for all orbitals of an atom centered at 0.
"""
function build_form_factors(psp, G_plus_k::AbstractVector{Vec3{TT}}) where {TT}
function build_projector_form_factors(psp::NormConservingPsp,
G_plus_k::AbstractVector{Vec3{TT}}) where {TT}
G_plus_ks = [G_plus_k]

n_proj = count_n_proj(psp)
form_factors = zeros(Complex{TT}, length(G_plus_k), n_proj)
for l = 0:psp.lmax,
n_proj_l = count_n_proj_radial(psp, l)
offset = sum(x -> count_n_proj(psp, x), 0:l-1; init=0) .+
n_proj_l .* (collect(1:2l+1) .- 1) # offset about m for a given l
for i = 1:n_proj_l
proj_li(p) = eval_psp_projector_fourier(psp, i, l, p)
form_factors_li = build_form_factors(proj_li, l, G_plus_ks)
@views form_factors[:, offset.+i] = form_factors_li[1]
end
end

form_factors
end

"""
Build Fourier transform factors of an atomic function centered at 0 for a given l.
"""
function build_form_factors(fun::Function, l::Int,
G_plus_ks::AbstractVector{<:AbstractVector{Vec3{TT}}}) where {TT}
# TODO this function can be generally useful, should refactor to a separate file eventually
T = real(TT)

# Pre-compute the radial parts of the non-local projectors at unique |p| to speed up
# Pre-compute the radial parts of the non-local atomic functions at unique |p| to speed up
# the form factor calculation (by a lot). Using a hash map gives O(1) lookup.

# Maximum number of projectors over angular momenta so that form factors
# for a given `p` can be stored in an `nproj x (lmax + 1)` matrix.
n_proj_max = maximum(l -> count_n_proj_radial(psp, l), 0:psp.lmax; init=0)

radials = IdDict{T,Matrix{T}}() # IdDict for Dual compatibility
for p in G_plus_k
p_norm = norm(p)
if !haskey(radials, p_norm)
radials_p = Matrix{T}(undef, n_proj_max, psp.lmax + 1)
for l = 0:psp.lmax, iproj_l = 1:count_n_proj_radial(psp, l)
# TODO This might be faster if we do this in batches of l
# (i.e. make the inner loop run over k-points and G_plus_k)
# and did recursion over l to compute the spherical bessels
radials_p[iproj_l, l+1] = eval_psp_projector_fourier(psp, iproj_l, l, p_norm)
radials = IdDict{T,T}() # IdDict for Dual compatibility
for G_plus_k in G_plus_ks
for p in G_plus_k
p_norm = norm(p)
if !haskey(radials, p_norm)
radials_p = fun(p_norm)
radials[p_norm] = radials_p
end
radials[p_norm] = radials_p
end
end

form_factors = Matrix{Complex{T}}(undef, length(G_plus_k), count_n_proj(psp))
for (ip, p) in enumerate(G_plus_k)
radials_p = radials[norm(p)]
count = 1
for l = 0:psp.lmax, m = -l:l
# see "Fourier transforms of centered functions" in the docs for the formula
angular = (-im)^l * ylm_real(l, m, p)
for iproj_l = 1:count_n_proj_radial(psp, l)
form_factors[ip, count] = radials_p[iproj_l, l+1] * angular
count += 1
form_factors = Vector{Matrix{Complex{T}}}(undef, length(G_plus_ks))
for (ik, G_plus_k) in enumerate(G_plus_ks)
form_factors_ik = Matrix{Complex{T}}(undef, length(G_plus_k), 2l + 1)
for (ip, p) in enumerate(G_plus_k)
radials_p = radials[norm(p)]
for m = -l:l
# see "Fourier transforms of centered functions" in the docs for the formula
angular = (-im)^l * ylm_real(l, m, p)
form_factors_ik[ip, m+l+1] = radials_p * angular
end
end
@assert count == count_n_proj(psp) + 1
form_factors[ik] = form_factors_ik
end

form_factors
end

Expand Down

0 comments on commit d8a5454

Please sign in to comment.