Skip to content

Commit

Permalink
Merge pull request #91 from cesmix-mit/ooc_learn
Browse files Browse the repository at this point in the history
Ooc learn
  • Loading branch information
swyant authored Nov 21, 2024
2 parents 25e48fd + 378f542 commit 807ff3c
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions src/Learning/linear-learn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,9 @@ function ooc_learn!(
λ::Union{Real,Nothing} = 0.01,
reg_style::Symbol = :default,
AtWA = nothing,
AtWb = nothing
AtWb = nothing,
pbar = true,
eweight_normalized = :squared # :squared, :standard or nothing
)

basis_size = length(lb.basis)
Expand All @@ -317,21 +319,37 @@ function ooc_learn!(

W = zeros(1,1)

configs = get_system.(ds_train)
if pbar
iter = ProgressBar(ds_train)
else
iter = ds_train
end

for config in ds_train
for config in iter
ref_energy = get_values(get_energy(config))
ref_forces = reduce(vcat,get_values(get_forces(config)))

sys = get_system(config)
natoms = length(sys)
global_descrs = reshape(sum(compute_local_descriptors(sys,lb.basis)),:,1)'
force_descrs = stack(reduce(vcat,compute_force_descriptors(sys,lb.basis)))'

A = [global_descrs; force_descrs]
b = [ref_energy; ref_forces]
if size(W)[1] != size(A)[1]
W = Diagonal( [ws[1]*ones(length(ref_energy));
ws[2]*ones(length(ref_forces))])

if isnothing(eweight_normalized)
we_norm = 1.0
elseif eweight_normalized == :standard
we_norm = 1/natoms
elseif eweight_normalized == :squared
we_norm = 1/natoms^2
else
error("eweight_normalized can only be nothing, :standard, or :squared")
end

W = Diagonal( [we_norm*ws[1];
ws[2]*ones(length(ref_forces))] )
end

AtWA .+= A'*W*A
Expand Down

0 comments on commit 807ff3c

Please sign in to comment.