diff --git a/src/PlaneWaveBasis.jl b/src/PlaneWaveBasis.jl index 8b322c8002..f9b325e26b 100644 --- a/src/PlaneWaveBasis.jl +++ b/src/PlaneWaveBasis.jl @@ -400,7 +400,8 @@ Returns nothing if outside the range of valid wave vectors. end end -function index_G_vectors(basis::PlaneWaveBasis, G::AbstractVector{<:Integer}) +# @inline is necessary here for the inner function to be inlined as well +@inline function index_G_vectors(basis::PlaneWaveBasis, G::AbstractVector{<:Integer}) index_G_vectors(basis.fft_size, G) end diff --git a/src/common/cis2pi.jl b/src/common/cis2pi.jl index a98596ad9a..178ddb21bc 100644 --- a/src/common/cis2pi.jl +++ b/src/common/cis2pi.jl @@ -1,5 +1,14 @@ """Function to compute exp(2π i x)""" cis2pi(x) = cispi(2x) +function cis2pi(x::T) where {T <: AbstractFloat} + # Special case when 2x is an integer, as exp(n*π*i) = +- 1. Saves expensive + # exponential evaluations when called repeatedly + if isinteger(2x) + return isodd(2x) ? -one(complex(T)) : one(complex(T)) + else + return cispi(2x) + end +end """Function to compute sin(2π x)""" sin2pi(x) = sinpi(2x) diff --git a/test/symmetry_issues.jl b/test/symmetry_issues.jl index e39926a733..9cdca2aa27 100644 --- a/test/symmetry_issues.jl +++ b/test/symmetry_issues.jl @@ -1,7 +1,7 @@ # This file collects examples, where we had issues with symmetries (symmetry determination, # k-point reduction, etc.) which are now resolved. Should make sure we don't reintroduce bugs. -@testitem "Symmetry issues" begin +@testitem "Symmetry issues" setup=[TestCases] begin using DFTK using DFTK: spglib_dataset using Unitful @@ -24,4 +24,51 @@ @test length(symmetries) == 48 @test spglib_dataset(system).spacegroup_number == 225 end + + @testset "Inlining" begin + # Test that the index_G_vectors function is properly inlined by comparing timing + # with a locally defined function known not to be inlined. Issue initially tackled + # in PR https://github.com/JuliaMolSim/DFTK.jl/pull/1025 + function index_G_vectors_slow(basis, G::AbstractVector{<:Integer}) + start = .- cld.(basis.fft_size .- 1, 2) + stop = fld.(basis.fft_size .- 1, 2) + lengths = stop .- start .+ 1 + + # FFTs store wavevectors as [0 1 2 3 -2 -1] (example for N=5) + function G_to_index(length, G) + G >= 0 && return 1 + G + return 1 + length + G + end + if all(start .<= G .<= stop) + CartesianIndex(Tuple(G_to_index.(lengths, G))) + else + nothing # Outside range of valid indices + end + end + + # This is a bare-bone version of the accumulate_over_symmetries() function, only + # keeping calls to the index_G_vectors() function for which we test inlining + function G_vectors_calls(basis, test_func) + for symop in basis.symmetries + invS = Mat3{Int}(inv(symop.S)) + for (ig, G) in enumerate(DFTK.G_vectors_generator(basis.fft_size)) + igired = test_func(basis, invS * G) + end + end + end + + silicon = TestCases.silicon + Si = ElementPsp(silicon.atnum, load_psp("hgh/lda/si-q4")) + atoms = [Si, Si] + model = model_DFT(silicon.lattice, atoms, silicon.positions; + functionals=[:lda_x, :lda_c_vwn]) + Ecut = 32 + kgrid = [1, 1, 1] + basis = PlaneWaveBasis(model; Ecut, kgrid) + + actual_alloc = @allocated G_vectors_calls(basis, DFTK.index_G_vectors) + slow_alloc = @allocated G_vectors_calls(basis, index_G_vectors_slow) + @test slow_alloc > actual_alloc + end end +