-
-
Notifications
You must be signed in to change notification settings - Fork 114
Accumulation for ODEProblem #1036
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
236fb65
d50778b
584f8c3
53cc142
4aaae10
a03bb02
e3b0108
c4178b1
c85dfa2
296f43e
0c249ab
40ba20b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,13 +1,14 @@ | ||
| module SciMLBaseZygoteExt | ||
|
|
||
| using Zygote | ||
| using Zygote: @adjoint, pullback | ||
| import Zygote: literal_getproperty | ||
| using Zygote: @adjoint, pullback, @_adjoint_keepthunks, _project, pair | ||
| import Zygote: literal_getproperty, literal_getfield | ||
| import ChainRulesCore | ||
| using SciMLBase | ||
| using SciMLBase: ODESolution, remake, ODEFunction, | ||
| getobserved, build_solution, EnsembleSolution, | ||
| NonlinearSolution, AbstractTimeseriesSolution | ||
| NonlinearSolution, AbstractTimeseriesSolution, | ||
| ODEProblem | ||
| using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed, | ||
| observed, parameter_values, state_values, current_time | ||
| using RecursiveArrayTools | ||
|
|
@@ -299,4 +300,23 @@ end | |
| ∇responsible_map(__context__, f, args...) | ||
| end | ||
|
|
||
| @_adjoint_keepthunks function Zygote.literal_getfield(x::ODEProblem, ::Val{f}) where f | ||
| val = getfield(x, f) | ||
| function back(Δ) | ||
| # error() | ||
| Zygote.accum_param(__context__, val, Δ) === nothing && return | ||
| if isimmutable(x) | ||
| error() | ||
| dx = (; Zygote.nt_nothing(x)..., pair(Val(f), Δ, x)...) | ||
| (_project(x, dx), nothing) | ||
| else | ||
| dx = Zygote.grad_mut(__context__, x) | ||
| dx[] = (; dx[]..., pair(Val(f), Zygote.accum(getfield(dx[], f), Δ))...) | ||
| return (dx,nothing) | ||
| end | ||
| end | ||
| Zygote.unwrap(val), back | ||
| end | ||
| Zygote.accum(::Tuple{}, ::NamedTuple{}) = () | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't this type piracy?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is 😅
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah... FluxML/Zygote.jl#1574 is definitely the safer bet. Why are problem types |
||
|
|
||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -258,50 +258,58 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, | |
| initdata.update_initializeprob!(initprob, valp) | ||
| end | ||
| end | ||
| nlsol, success = solve_initialization(initdata, initprob, alg; reltol, abstol, nlsolve_alg ) | ||
|
|
||
| if is_trivial_initialization(initdata) | ||
| nlsol = initdata | ||
| success = true | ||
| else | ||
| nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) | ||
| if nlsolve_alg === nothing && state_values(initprob) !== nothing | ||
| throw(OverrideInitMissingAlgorithm()) | ||
| end | ||
| if alg.abstol !== nothing | ||
| _abstol = alg.abstol | ||
| elseif abstol !== nothing | ||
| _abstol = abstol | ||
| else | ||
| throw(OverrideInitNoTolerance(:abstol)) | ||
| end | ||
| if alg.reltol !== nothing | ||
| _reltol = alg.reltol | ||
| elseif reltol !== nothing | ||
| _reltol = reltol | ||
| else | ||
| throw(OverrideInitNoTolerance(:reltol)) | ||
| end | ||
| nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol, kwargs...) | ||
|
|
||
| success = if initprob isa NonlinearLeastSquaresProblem | ||
| # Do not accept StalledSuccess as a solution | ||
| # A good local minima is not a success | ||
| resid = nlsol.resid | ||
| normresid = norm(resid) | ||
| SciMLBase.successful_retcode(nlsol) && normresid <= abstol | ||
| else | ||
| SciMLBase.successful_retcode(nlsol) | ||
| end | ||
| end | ||
|
|
||
| nlsol2 = prob.f.initialization_data.initializeprob | ||
| if initdata.initializeprobmap !== nothing | ||
| u0 = initdata.initializeprobmap(choose_branch(nlsol)) | ||
| u02 = initdata.initializeprobmap(nlsol2) | ||
|
||
| end | ||
| if initdata.initializeprobpmap !== nothing | ||
| p = initdata.initializeprobpmap(valp, choose_branch(nlsol)) | ||
| p2 = initdata.initializeprobpmap(valp, nlsol) | ||
| end | ||
|
|
||
| return u0, p, success | ||
| u03 = isnothing(initdata.initializeprobmap) ? u0 : u02 | ||
| p3 = isnothing(initdata.initializeprobpmap) ? p : p2 | ||
| return u03, p3, success | ||
| end | ||
|
|
||
| function solve_initialization(initdata::OverrideInitData{<:AbstractNonlinearProblem{Nothing}}, initprob, alg; kwargs...) | ||
| nlsol = initprob | ||
| success = true | ||
| return nlsol, success | ||
| end | ||
|
|
||
| function solve_initialization(initdata, initprob, alg; reltol, abstol, nlsolve_alg) | ||
| nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) | ||
| if nlsolve_alg === nothing && state_values(initprob) !== nothing | ||
| throw(OverrideInitMissingAlgorithm()) | ||
| end | ||
| if alg.abstol !== nothing | ||
| _abstol = alg.abstol | ||
| elseif abstol !== nothing | ||
| _abstol = abstol | ||
| else | ||
| throw(OverrideInitNoTolerance(:abstol)) | ||
| end | ||
| if alg.reltol !== nothing | ||
| _reltol = alg.reltol | ||
| elseif reltol !== nothing | ||
| _reltol = reltol | ||
| else | ||
| throw(OverrideInitNoTolerance(:reltol)) | ||
| end | ||
| nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol, kwargs...) | ||
|
|
||
| success = if initprob isa NonlinearLeastSquaresProblem | ||
| # Do not accept StalledSuccess as a solution | ||
| # A good local minima is not a success | ||
| resid = nlsol.resid | ||
| normresid = norm(resid) | ||
| SciMLBase.successful_retcode(nlsol) && normresid <= abstol | ||
| else | ||
| SciMLBase.successful_retcode(nlsol) | ||
| end | ||
| return nlsol, success | ||
| end | ||
|
|
||
| """ | ||
|
|
@@ -314,21 +322,6 @@ function get_initial_values(prob, integrator, f, ::NoInit, iip; kwargs...) | |
| return state_values(integrator), parameter_values(integrator), true | ||
| end | ||
|
|
||
| is_trivial_initialization(::Nothing) = true | ||
|
|
||
| function is_trivial_initialization(initdata::OverrideInitData) | ||
| !(initdata.initializeprob isa NonlinearLeastSquaresProblem) && | ||
| state_values(initdata.initializeprob) === nothing | ||
| end | ||
|
|
||
| function is_trivial_initialization(f::AbstractSciMLFunction) | ||
| has_initialization_data(f) && is_trivial_initialization(f.initialization_data) | ||
| end | ||
|
|
||
| function is_trivial_initialization(prob::AbstractSciMLProblem) | ||
| is_trivial_initialization(prob.f) | ||
| end | ||
|
|
||
| @enum DETERMINED_STATUS OVERDETERMINED FULLY_DETERMINED UNDERDETERMINED | ||
|
|
||
| function initialization_status(prob::AbstractSciMLProblem) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the point of this block if the first line is
error()?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed in a03bb02