-
Notifications
You must be signed in to change notification settings - Fork 36
Make run_ad
return both primal and gradient time
#1002
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: breaking
Are you sure you want to change the base?
Conversation
("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false), | ||
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), | ||
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), | ||
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), | ||
("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), | ||
("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), | ||
("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true), | ||
("Multivariate 1k", multivariate1k, :typed, :mooncake, true), | ||
("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true), | ||
("Multivariate 10k", multivariate10k, :typed, :mooncake, true), | ||
("Dynamic", Models.dynamic(), :typed, :mooncake, true), | ||
("Submodel", Models.parent(randn(rng)), :typed, :mooncake, true), | ||
("LDA", lda_instance, :typed, :reversediff, true), | ||
("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff), | ||
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff), | ||
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff), | ||
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff), | ||
("Smorgasbord", smorgasbord_instance, :typed, :reversediff), | ||
("Smorgasbord", smorgasbord_instance, :typed, :mooncake), | ||
("Loop univariate 1k", loop_univariate1k, :typed, :mooncake), | ||
("Multivariate 1k", multivariate1k, :typed, :mooncake), | ||
("Loop univariate 10k", loop_univariate10k, :typed, :mooncake), | ||
("Multivariate 10k", multivariate10k, :typed, :mooncake), | ||
("Dynamic", Models.dynamic(), :typed, :mooncake), | ||
("Submodel", Models.parent(randn(rng)), :typed, :mooncake), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR also removes the option to set varinfo to be linked or unlinked. Going forward everything is linked. In practice there really aren't any cases where AD is run with unlinked varinfo (indeed running with unlinked is a recipe for bugs when you have constraints like in Dirichlet or LKJCholesky, see e.g. TuringLang/ADTests#7) so I don't think that we should do it here.
DynamicPPL.jl documentation for PR #1002 is available at: |
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" | ||
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note Chairmarks is already a strong dep of DPPL
retvals = model(rng) | ||
vns = [VarName{k}() for k in keys(retvals)] | ||
SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals)))) | ||
vi = DynamicPPL.typed_varinfo(rng, model) | ||
vals = DynamicPPL.values_as(vi, Dict) | ||
SimpleVarInfo{Float64}(vals) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The old code only works if the model explicitly returns a NamedTuple with only plain Symbols (which they happen to do). I am aware that this pattern is peppered all over the code base (e.g. with demo models too) but I think we should try to avoid having magic return values and instead rely on functionality that is designed to work on all models.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## breaking #1002 +/- ##
============================================
- Coverage 81.91% 81.87% -0.05%
============================================
Files 38 38
Lines 4025 4027 +2
============================================
Hits 3297 3297
- Misses 728 730 +2 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
run_ad
would calculate both and then return only the ratio, which seemed a bit wasteful.With this change we also get to use
run_ad
in the CI benchmarking workflow.