Using flatten/unflatten for Automatic Differentiation #27
Hmm interesting, this is definitely a problem we should address. IIRC the type constraints were introduced to make it possible to ensure that a particular precision was used uniformly throughout the parameter vector. I wonder whether it's necessary to enforce in the For example, removing the element type constraint here would probably do it, but would still ensure that all values are converted to the appropriate precision in
We've not focused too much on performance because in the use-cases we've considered so far, the parameter handling stuff has negligible cost relative to the other computations. Do you know if the place in which Zygote was taking a while was the parameter handling stuff, or your other code?
@willtebbutt I removed the type constraint in the line you suggested, but the problem persists. I also tried to remove most of the other type constraints (without success). At the moment, the whole framework builds on the fact that the constrained and unconstrained types stay the same, so you can infer the original type from the flattened vector. Do you remember the version before that constraint was introduced? Zygote time for the log_target function increases 3-fold, while for the ParameterHandling line it's 5-fold, but the time spent in ParameterHandling is negligible. I think it is just Zygote that is a bit slower than other AD frameworks at the moment.
Edit: I went back to version .2.1, and got it working for ForwardDiff and ReverseDiff by changing Vector{X} to AbstractVector{X} and changing NamedTuple to
Edit 2: While working with the older version, I noticed slightly different behavior for discrete values. In the .2.1 version, no Integers are flattened, while in the new version, scalar Integers are not flattened, but Vectors of Integers are flattened and converted to Float64s. MWE:
Sorry for the slow response. Hectic week.
Oooo this looks like a bug. Could you open a separate issue about it please?
Interesting. It makes sense to me that this would fix it. I wonder if there's a halfway house that doesn't lose us the niceness of being able to say "give me the Float32 representation" in flatten? Something like

    function flatten(::Type{T}, x::NamedTuple) where {T<:Real}
        x_vec, unflatten = flatten(T, values(x))
        function unflatten_to_NamedTuple(v::AbstractVector{<:Real})
            v_vec_vec = unflatten(v)
            #return typeof(x)(v_vec_vec)
            return NamedTuple{ keys(x) }(v_vec_vec)
        end
        return x_vec, unflatten_to_NamedTuple
    end

I'm assuming here that you don't need to AD through I think to make it work with the current version you'd probably need to also modify the

    function flatten(::Type{T}, x::R) where {T<:Real, R<:Real}
        v = T[x]
        unflatten_to_Real(v::AbstractVector{<:Real}) = only(v)
        return v, unflatten_to_Real
    end

I'd definitely be open to a PR which made these kinds of changes if you would be interested in making them. I'm pretty confident it would just be a case of working through I don't think this would be breaking, because provided that everything continues to be
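To make the suggestion above concrete, here is a minimal self-contained sketch of the relaxed signatures. It does not load ParameterHandling.jl: the `flatten` methods below are local definitions, the `Tuple` method is a hypothetical helper written only so the example runs on its own, and `Float32` is used purely as a stand-in for a dual-number element type.

```julia
# Scalars: store in precision T, accept any real-valued vector on the way back.
function flatten(::Type{T}, x::Real) where {T<:Real}
    v = T[x]
    unflatten_to_Real(v::AbstractVector{<:Real}) = only(v)
    return v, unflatten_to_Real
end

# Tuples: hypothetical helper for this sketch, not the package's implementation.
function flatten(::Type{T}, x::Tuple) where {T<:Real}
    pieces = map(xi -> flatten(T, xi), x)   # (vector, unflatten) per element
    vecs = map(first, pieces)
    backs = map(last, pieces)
    lens = map(length, vecs)
    stops = cumsum(collect(lens))           # last index of each element's slice
    function unflatten_to_Tuple(v::AbstractVector{<:Real})
        return ntuple(length(x)) do i
            start = stops[i] - lens[i] + 1
            backs[i](v[start:stops[i]])
        end
    end
    return reduce(vcat, vecs), unflatten_to_Tuple
end

# NamedTuples: rebuild from the keys only, so the value types are free to change.
function flatten(::Type{T}, x::NamedTuple) where {T<:Real}
    x_vec, unflatten = flatten(T, values(x))
    function unflatten_to_NamedTuple(v::AbstractVector{<:Real})
        return NamedTuple{keys(x)}(unflatten(v))
    end
    return x_vec, unflatten_to_NamedTuple
end

params = (a = 1.0, b = 2.0)
params_vec, unflatten = flatten(Float64, params)

# Feed back a vector with a different element type, as an AD package would.
restored = unflatten(Float32.(params_vec))
println(restored)
```

The flattened vector still has the requested precision (`Float64` here), while `unflatten` returns whatever element type it is given. Whether this covers every container the package supports is not something this sketch checks.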
Yes, just submitted the issue.
I was not able to both preserve the flatten type and make it AutoDiff friendly. I might try to define 2 functions to return (unflatten and unflatten_AutoDiffFriendly) to choose.
Yes, exactly. From the JuliaCon talk, I assume it is the other way around for problems where ParameterHandling.jl is currently used.
Hi there,
Really nice package!
I was wondering if one can adjust the flatten/unflatten functions, such that unflatten also works inside a closure when using Automatic Differentiation.
At the moment, it seems that the type constraints cannot handle Duals. MWE:
Zygote was actually working for this toy example, but breaks when used with more complex examples (and was not very fast).
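The failure mode described here can be illustrated without any AD package. The sketch below is not the original MWE: `ToyDual` is a hypothetical stand-in for the dual-number type a forward-mode AD package would pass in, and `strict_flatten` is a made-up function that mimics an `unflatten` restricted to `Vector{Float64}`.

```julia
# Hypothetical dual-like number; real AD packages define their own type.
struct ToyDual <: Real
    value::Float64
    partial::Float64
end

# Mimics a flatten whose unflatten only accepts Vector{Float64}.
function strict_flatten(x::Vector{Float64})
    unflatten(v::Vector{Float64}) = v
    return x, unflatten
end

x_vec, unflatten = strict_flatten([1.0, 2.0])

# An AD package would call the closure with dual-valued inputs like these.
duals = [ToyDual(1.0, 1.0), ToyDual(2.0, 0.0)]

# No method matches Vector{ToyDual}, so the call throws a MethodError.
err = try
    unflatten(duals)
    nothing
catch e
    e
end

println(err isa MethodError)
```

Since `Vector{ToyDual}` is not a `Vector{Float64}`, dispatch fails before any computation happens, which is consistent with the type constraints being the obstacle rather than the arithmetic.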