Skip to content

Commit

Permalink
swap beta models to make 3 compartment the main diameter model
Browse files Browse the repository at this point in the history
  • Loading branch information
Tinggong committed Jul 9, 2024
1 parent 8444c6f commit 64a2619
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 4 deletions.
117 changes: 113 additions & 4 deletions src/estimators_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Noisemodel(Microstructure.logp_rician, 0.02, (0.005, 0.1), Normal{Float64}(μ=0.
Base.@kwdef struct Noisemodel
logpdf::Function = logp_gauss
sigma_start::Float64 = 0.01
sigma_range::Tuple{Float64,Float64} = (0.005, 0.1)
sigma_range::Tuple{Float64,Float64} = (0.001, 0.1)
proposal::Distribution = Normal(0, 0.005)
end

Expand Down Expand Up @@ -123,7 +123,7 @@ function draw_samples!(
end

"""
draw_samples(sampler::Sampler, noise::Noisemodel = Noisemodel())
draw_samples(sampler::Sampler, noise::Noisemodel , container::String)
Generate pertubations used in MCMC for tissue parameters and sigma using the proposals
"""
Expand Down Expand Up @@ -167,6 +167,49 @@ function draw_samples(sampler::Sampler, noise::Noisemodel, container::String)
return pertubations
end

"""
draw_samples(sampler::Sampler, container::String)
Generate pertubations used in MCMC for tissue parameters
"""
function draw_samples(sampler::Sampler, container::String)

if container == "vec"

pertubations = [
Vector{Any}(undef, sampler.nsamples) for i in 1:length(sampler.params)
]

@inbounds for (i, para) in enumerate(sampler.params)
pertubation = rand(sampler.proposal[i], sampler.nsamples)
if pertubation isa Vector
pertubations[i] = pertubation
else
pertubations[i] = [vec(pertubation[:, i]) for i in 1:(sampler.nsamples)]
end
end

elseif container == "dict"

pertubations = Dict()
@inbounds for (i, para) in enumerate(sampler.params)
# pertubation could be a vector or a matrix from multi-variant proposal
pertubation = rand(sampler.proposal[i], sampler.nsamples)
if pertubation isa Vector
push!(pertubations, para => pertubation)
else
push!(
pertubations, para => [vec(pertubation[:, i]) for i in 1:(sampler.nsamples)]
) # for vector fracs
end
end
else
error("use vec or dict")
end

return pertubations
end

"""
Define a subsampler sampling a subset of parameters in the sampler
using index vector for keeping parameters
Expand Down Expand Up @@ -201,8 +244,8 @@ function Sampler(model::BiophysicalModel)
Normal(0, 0.25e-6),
Normal(0, 0.025e-9),
Normal(0, 0.05),
MvNormal([0.0025 0 0; 0 0.0001 0; 0 0 0.0001]),
) #; equal to (Normal(0,0.05),Normal(0,0.01),Normal(0,0.01)) for fracs
MvNormal([0.0025 0;0 0.0001]), # 3-compartment model with 2 free fraction parameters
) #; equal to (Normal(0,0.05),Normal(0,0.01)) for fracs
nsamples = 70000
burnin = 20000
# setup sampler and noise model
Expand Down Expand Up @@ -481,6 +524,72 @@ function mcmc!(
return nothing
end

## testing: mcmc with given noise level sigma
function mcmc!(
estimates::BiophysicalModel,
meas::Vector{Float64},
protocol::Protocol,
sampler::Sampler,
sigma::Float64,
rng::Int64=1,
)
Random.seed!(rng)

# create chain and pertubations
chain = create_chain(sampler, "dict")
pertubations = draw_samples(sampler, "dict")

# get logp_start from the start model and sigma
logp_start = noise.logpdf(meas, model_signals(estimates, protocol), sigma)

@inbounds for i in 1:(sampler.nsamples::Int)

# get current pertubation
pertubation = Tuple(para => pertubations[para][i] for para in sampler.params)

# get the next sample location and check if it is within prior ranges
outliers = increment!(estimates, pertubation, sampler.prior_range)

if iszero(outliers)

# update linked parameters in model
update!(estimates, sampler.paralinks)

# update logp
logp_next = noise.logpdf(meas, model_signals(estimates, protocol), sigma)

# acception ratio
if rand(Float64) < min(1, exp(logp_next - logp_start))
move = 1
logp_start = copy(logp_next)
else
move = 0
# move estimates back to previous location
decrement!(estimates, pertubation)
update!(estimates, sampler.paralinks)
end
else
move = 0
# move next back to current location
decrement!(estimates, pertubation)
end

record_chain!(chain, estimates, sampler.params, i, move, sigma, logp_start)
end

#update model object as the mean values of selected samples
update!(
estimates,
Tuple(
para => mean(chain[para][(sampler.burnin):(sampler.thinning):end]) for
para in sampler.params
),
)
update!(estimates, sampler.paralinks)

return chain
end

function record_chain!(
chain::Dict{String,Vector{Any}},
estimates::BiophysicalModel,
Expand Down
3 changes: 3 additions & 0 deletions src/estimators_nn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ end
NetworkArg(model, protocol,params,paralinks,tissuetype,sigma,noise_type,dropoutp=0.2)
Use the inputs related to biophysical models to determine network architecture and number of training samples
return a full defined NetworkArg struct
Reference for adjusting the number of training samples:
Shwartz-Ziv, R., Goldblum, M., Bansal, A., Bruss, C.B., LeCun, Y., & Wilson, A.G. (2024). Just How Flexible are Neural Networks in Practice?
"""
function NetworkArg(
model::BiophysicalModel,
Expand Down

0 comments on commit 64a2619

Please sign in to comment.