Skip to content

Commit

Permalink
Merge pull request #88 from cesmix-mit/ooc_learn
Browse files Browse the repository at this point in the history
Ooc learn
  • Loading branch information
emmanuellujan authored Oct 21, 2024
2 parents acb6d22 + 60a20c5 commit 4c7e70d
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PotentialLearning"
uuid = "82b0a93c-c2e3-44bc-a418-f0f89b0ae5c2"
authors = ["CESMIX Team"]
version = "0.2.6"
version = "0.2.7"

[deps]
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
Expand Down
96 changes: 92 additions & 4 deletions src/Learning/linear-learn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ Fit energies and forces using weighted least squares.
function learn!(
lp::CovariateLinearProblem,
ws::Vector,
int::Bool
int::Bool;
λ::Real=0.0
)
@views B_train = reduce(hcat, lp.B)'
@views dB_train = reduce(hcat, lp.dB)'
Expand All @@ -249,11 +250,11 @@ function learn!(

βs = Vector{Float64}()
try
βs = (A'*Q*A) \ (A'*Q*b)
βs = (A'*Q*A + λ*I) \ (A'*Q*b)
catch e
println(e)
println("Linear system will be solved using pinv.")
βs = pinv(A'*Q*A)*(A'*Q*b)
βs = pinv(A'*Q*A + λ*I)*(A'*Q*b)
end

# Update lp.
Expand All @@ -263,7 +264,7 @@ function learn!(
else
lp.β .= βs
end

end


Expand All @@ -282,4 +283,91 @@ function learn!(
return learn!(lp, ws, int)
end

function assemble_matrices(lp, ws)
@views B_train = reduce(hcat, lp.B)'
@views dB_train = reduce(hcat, lp.dB)'
@views e_train = lp.e
@views f_train = reduce(vcat, lp.f)

@views A = [B_train; dB_train]
@views b = [e_train; f_train]

W = Diagonal([ws[1] * ones(length(e_train));
ws[2] * ones(length(f_train))])

A, W, b
end

function ooc_learn!(
lb::InteratomicPotentials.LinearBasisPotential,
ds_train::PotentialLearning.DataSet;
ws = [30.0,1.0],
symmetrize::Bool = true,
λ::Union{Real,Nothing} = 0.01,
reg_style::Symbol = :default,
AtWA = nothing,
AtWb = nothing
)

basis_size = length(lb.basis)

if isnothing(AtWA) || isnothing(AtWb)
AtWA = zeros(basis_size,basis_size)
AtWb = zeros(basis_size)

W = zeros(1,1)

configs = get_system.(ds_train)

for config in ds_train
ref_energy = get_values(get_energy(config))
ref_forces = reduce(vcat,get_values(get_forces(config)))

sys = get_system(config)
global_descrs = reshape(sum(compute_local_descriptors(sys,lb.basis)),:,1)'
force_descrs = stack(reduce(vcat,compute_force_descriptors(sys,lb.basis)))'

A = [global_descrs; force_descrs]
b = [ref_energy; ref_forces]
if size(W)[1] != size(A)[1]
W = Diagonal( [ws[1]*ones(length(ref_energy));
ws[2]*ones(length(ref_forces))])
end

AtWA .+= A'*W*A
AtWb .+= A'*W*b
end
else
AtWA = deepcopy(AtWA)
AtWb = deepcopy(AtWb)
end

if symmetrize
AtWA = Symmetric(AtWA)
end

if !isnothing(λ)
if reg_style == :default
reg_matrix = λ*Diagonal(ones(size(AtWA)[1]))
AtWA += reg_matrix
end

if reg_style == :scale_thresh || reg_style == :scale
for i in 1:size(AtWA,1)
reg_elem = AtWA[i,i]*(1+λ)
if reg_style == :scale_thresh
reg_elem = max(reg_elem,λ)
end#
AtWA[i,i] = reg_elem
end
end

end

β = AtWA \ AtWb
println("condition number of AtWA: $(cond(AtWA))")

lb.β .= β

AtWA, AtWb
end

0 comments on commit 4c7e70d

Please sign in to comment.