
Using flatten/unflatten for Automatic Differentiation #27

Open
paschermayr opened this issue Aug 13, 2021 · 4 comments

@paschermayr

Hi there,

Really nice package!

I was wondering if one can adjust the flatten/unflatten functions so that unflatten also works inside a closure for Automatic Differentiation.

At the moment, it seems that the type constraints cannot handle Duals. MWE:

using ParameterHandling, Distributions, ReverseDiff, ForwardDiff

# Get sample data and parameter
val = ( μ = 1., σ = 2. )
data = randn(100)

# Write down a logdensity with parameter and data as arguments
function log_density(val, data)
    return sum(Distributions.logpdf(Distributions.Normal(val.μ, val.σ), data[iter]) for iter in eachindex(data))
end
log_density(val, data)

# Closure for transforming θ Vector to NamedTuple
function get_log_target(val, data)
    _, unflatten = ParameterHandling.flatten(val)
    function log_target(θ::AbstractVector{T}) where {T}
        return log_density(unflatten(θ), data)
    end
    return log_target
end

# Check if it is working
lt = get_log_target(val, data)
theta = [0.1, 0.2]
lt(theta)

# Method Error for Dual numbers
ForwardDiff.gradient(lt, theta) # MethodError: no method matching (::ParameterHandling.var"#unflatten_to_NamedTuple#15...
ReverseDiff.gradient(lt, theta) # MethodError: no method matching (::ParameterHandling.var"#unflatten_to_NamedTuple#15"{Float64, NamedTuple{(:μ, :σ)
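The failure mode can be sketched in isolation (the closure annotations below are illustrative stand-ins, not ParameterHandling's actual internals): the generated unflatten closure constrains its argument to a concrete vector type, which a vector of ForwardDiff `Dual` numbers does not match, while a loose `AbstractVector{<:Real}` annotation does.

```julia
using ForwardDiff

# Illustrative stand-ins for the generated closures; not the package's code.
strict_unflatten(v::Vector{Float64}) = (μ = v[1], σ = v[2])          # rejects Duals
relaxed_unflatten(v::AbstractVector{<:Real}) = (μ = v[1], σ = v[2])  # Dual <: Real, so this works

d = ForwardDiff.Dual(0.1, 1.0)
# strict_unflatten([d, d])  # MethodError: Vector{Dual{...}} is not a Vector{Float64}
relaxed_unflatten([d, d])   # returns a NamedTuple of Duals
```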

Zygote actually works for this toy example, but it breaks on more complex examples (and was not very fast).

@willtebbutt (Member) commented Aug 13, 2021

Hmm interesting, this is definitely a problem we should address.

IIRC the type constraints were introduced to make it possible to ensure that a particular precision was used uniformly throughout the parameter vector. I wonder whether it's necessary to enforce it in unflatten as well as in flatten? @rofinn any thoughts?

For example, removing the element type constraint here would probably do it, but would still ensure that all values are converted to the appropriate precision in flatten (when there presumably aren't any Duals floating around). Would this be sufficient, @paschermayr?

> Zygote was actually working for this toy example, but breaks when used with more complex examples (and was not very fast).

We've not focused too much on performance because in the use-cases we've considered so far, the parameter handling stuff has negligible cost relative to the other computations. Do you know if the place in which Zygote was taking a while was the parameter handling stuff, or your other code?

@paschermayr (Author) commented Aug 13, 2021

@willtebbutt I removed the type constraint in the line you suggested, but the problem persists. I also tried removing most of the other type constraints, without success. At the moment, the whole framework builds on the fact that the constrained and unconstrained types stay the same, so the original type can be inferred from the flattened vector. Do you remember the version before that constraint was introduced?

Zygote time for the log_target function increases 3-fold, while for the ParameterHandling line it's 5-fold, but the time spent in ParameterHandling is negligible. I think Zygote is just a bit slower than other AD frameworks at the moment.

using BenchmarkTools
@btime Zygote.gradient($lt, $theta) # 3.455 ms (34676 allocations: 1.51 MiB)
@btime lt($theta)                   # 1.020 μs (3 allocations: 208 bytes)

function get_log_target2(data)
    function log_target2(θ::AbstractVector{T}) where {T}
        dist = Distributions.Normal(θ[1], θ[2])
        return sum(Distributions.logpdf(dist, data[iter]) for iter in eachindex(data))
    end
    return log_target2
end
lt2 = get_log_target2(data)
theta = [0.1, 0.2]
lt2(theta)
@btime Zygote.gradient($lt2, $theta) # 2.923 ms (29167 allocations: 1.28 MiB)
@btime lt2($theta)                   # 927.273 ns (1 allocation: 16 bytes)

Edit: I went back to version 0.2.1 and got it working for ForwardDiff and ReverseDiff by changing Vector{X} to AbstractVector{X} and changing the NamedTuple method to

function flatten(x::NamedTuple)
    x_vec, unflatten = flatten(values(x))
    function unflatten_to_NamedTuple(v::AbstractVector{<:Real})
        v_vec_vec = unflatten(v)
        #return typeof(x)(v_vec_vec)
        return NamedTuple{ keys(x) }(v_vec_vec)
    end
    return x_vec, unflatten_to_NamedTuple
end

Edit 2: While working with the older version, I noticed slightly different behavior for discrete values. In version 0.2.1, Integers are never flattened, while in the new version scalar Integers are not flattened but Vectors of Integers are flattened and converted to Float64s. MWE:

using ParameterHandling
val = ( a = 1., b = 2, c = [3., 4.], d = [5, 6] )
ParameterHandling.flatten(val) # Vector{Float64} with 5 elements -> [1., 3., 4., 5., 6.]

@willtebbutt (Member)

Sorry for the slow response. Hectic week.

> While working with the older version, I noticed slightly different behavior for discrete values. In version 0.2.1, Integers are never flattened, while in the new version scalar Integers are not flattened but Vectors of Integers are flattened and converted to Float64s.

Oooo, this looks like a bug. Could you open a separate issue about it, please?

> Edit: I went back to version 0.2.1 and got it working for ForwardDiff and ReverseDiff by changing Vector{X} to AbstractVector{X} and changing the NamedTuple method to

Interesting. It makes sense to me that this would fix it. I wonder if there's a half-way house that doesn't lose us the niceness of being able to say "give me the Float32 representation" in flatten? Something like

function flatten(::Type{T}, x::NamedTuple) where {T<:Real}
    x_vec, unflatten = flatten(T, values(x))
    function unflatten_to_NamedTuple(v::AbstractVector{<:Real})
        v_vec_vec = unflatten(v)
        #return typeof(x)(v_vec_vec)
        return NamedTuple{ keys(x) }(v_vec_vec)
    end
    return x_vec, unflatten_to_NamedTuple
end

I'm assuming here that you don't need to AD through flatten, just unflatten. Is that correct?

I think to make it work with the current version you'd probably also need to modify flatten for things like Real. Something like

function flatten(::Type{T}, x::R) where {T<:Real, R<:Real}
    v = T[x]
    unflatten_to_Real(v::AbstractVector{<:Real}) = only(v)
    return v, unflatten_to_Real
end

I'd definitely be open to a PR which made these kinds of changes if you would be interested in making them. I'm pretty confident it would just be a case of working through src/flatten.jl and relaxing the type-constraints inside the "unflatten" closures.

I don't think this would be breaking, because provided that everything continues to be flattened to the correct precision, it should automatically be unflattened to the correct precision too.
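That non-breaking argument can be sketched with a toy flatten (a hypothetical `flatten32`, assuming the real flatten fixes precision the same way):

```julia
# Toy version: precision is fixed at flatten time, so even a relaxed
# unflatten reproduces the same element type on the round trip.
flatten32(x::Real) = (Float32[x], v -> only(v))

v, unflatten = flatten32(1.0)
@assert eltype(v) == Float32
@assert unflatten(v) === 1.0f0  # round trip stays Float32 despite the loose unflatten
```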

@paschermayr (Author) commented Aug 23, 2021

> Oooo, this looks like a bug. Could you open a separate issue about it, please?

Yes, just submitted the issue.

> Interesting. It makes sense to me that this would fix it. I wonder if there's a half-way house that doesn't lose us the niceness of being able to say "give me the Float32 representation" in flatten?

> I think to make it work with the current version you'd probably also need to modify flatten for things like Real.

I was not able both to preserve the flattened type and to make it AD-friendly. I might try returning two functions (unflatten and unflatten_AutoDiffFriendly) so the caller can choose.
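A hypothetical sketch of that two-function idea (the names `flatten_both` and `unflatten_AutoDiffFriendly` and the structure are illustrative, not the package's API):

```julia
# flatten could return both a strict inverse (preserves the exact input types)
# and a relaxed, AD-friendly inverse; the caller chooses which to close over.
function flatten_both(x::NamedTuple)
    v = collect(Float64, values(x))
    ks = keys(x)
    unflatten(v2::Vector{Float64}) = NamedTuple{ks}(Tuple(v2))
    unflatten_AutoDiffFriendly(v2::AbstractVector{<:Real}) = NamedTuple{ks}(Tuple(v2))
    return v, unflatten, unflatten_AutoDiffFriendly
end

v, strict, relaxed = flatten_both((μ = 1.0, σ = 2.0))
relaxed([0.1, 0.2])  # accepts any Real eltype, e.g. vectors of ForwardDiff Duals
```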

> I'm assuming here that you don't need to AD through flatten, just unflatten. Is that correct?

Yes, exactly. From the JuliaCon talk, I assume it is the other way around for problems where ParameterHandling.jl is currently used.
