diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 45a8e0f63..bfe8a9f9a 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -1,13 +1,14 @@ module SciMLBaseZygoteExt using Zygote -using Zygote: @adjoint, pullback -import Zygote: literal_getproperty +using Zygote: @adjoint, pullback, @_adjoint_keepthunks, _project, pair +import Zygote: literal_getproperty, literal_getfield import ChainRulesCore using SciMLBase using SciMLBase: ODESolution, remake, ODEFunction, getobserved, build_solution, EnsembleSolution, - NonlinearSolution, AbstractTimeseriesSolution + NonlinearSolution, AbstractTimeseriesSolution, + ODEProblem using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed, observed, parameter_values, state_values, current_time using RecursiveArrayTools @@ -299,4 +300,15 @@ end ∇responsible_map(__context__, f, args...) end +@_adjoint_keepthunks function Zygote.literal_getfield(x::SciMLBase.AbstractSciMLProblem, ::Val{f}) where f + val = getfield(x, f) + function back(Δ) + Zygote.accum_param(__context__, val, Δ) === nothing && return + dx = (; Zygote.nt_nothing(x)..., pair(Val(f), Δ, x)...) + (_project(x, dx), nothing) + end + Zygote.unwrap(val), back +end +Zygote.accum(::Tuple{}, ::NamedTuple{}) = () + end diff --git a/src/initialization.jl b/src/initialization.jl index ef056cd3e..e98a0e829 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -258,41 +258,7 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, initdata.update_initializeprob!(initprob, valp) end end - - if is_trivial_initialization(initdata) - nlsol = initprob - success = true - else - nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) - if nlsolve_alg === nothing && state_values(initprob) !== nothing - throw(OverrideInitMissingAlgorithm()) - end - if alg.abstol !== nothing - _abstol = alg.abstol - elseif abstol !== nothing - _abstol = abstol - else - throw(OverrideInitNoTolerance(:abstol)) - end - if alg.reltol !== nothing - _reltol = alg.reltol - elseif reltol !== nothing - _reltol = reltol - else - throw(OverrideInitNoTolerance(:reltol)) - end - nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol, kwargs...) - - success = if initprob isa NonlinearLeastSquaresProblem - # Do not accept StalledSuccess as a solution - # A good local minima is not a success - resid = nlsol.resid - normresid = norm(resid) - SciMLBase.successful_retcode(nlsol) && normresid <= abstol - else - SciMLBase.successful_retcode(nlsol) - end - end + nlsol, success = solve_initialization(initdata, initprob, alg; reltol, abstol, nlsolve_alg, kwargs...) if initdata.initializeprobmap !== nothing u02 = initdata.initializeprobmap(nlsol) @@ -301,14 +267,50 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, p2 = initdata.initializeprobpmap(valp, nlsol) end - # specifically needs to be written this way for Zygote - # See https://github.com/SciML/ModelingToolkit.jl/pull/3585#issuecomment-2883919162 u03 = isnothing(initdata.initializeprobmap) ? u0 : u02 p3 = isnothing(initdata.initializeprobpmap) ? p : p2 - return u03, p3, success end +function solve_initialization(initdata::OverrideInitData{<:AbstractNonlinearProblem{Nothing}}, initprob, alg; kwargs...) + nlsol = initprob + success = true + return nlsol, success +end + +function solve_initialization(initdata, initprob, alg; reltol, abstol, nlsolve_alg, kwargs...) + nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) + if nlsolve_alg === nothing && state_values(initprob) !== nothing + throw(OverrideInitMissingAlgorithm()) + end + if alg.abstol !== nothing + _abstol = alg.abstol + elseif abstol !== nothing + _abstol = abstol + else + throw(OverrideInitNoTolerance(:abstol)) + end + if alg.reltol !== nothing + _reltol = alg.reltol + elseif reltol !== nothing + _reltol = reltol + else + throw(OverrideInitNoTolerance(:reltol)) + end + nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol, kwargs...) + + success = if initprob isa NonlinearLeastSquaresProblem + # Do not accept StalledSuccess as a solution + # A good local minima is not a success + resid = nlsol.resid + normresid = norm(resid) + SciMLBase.successful_retcode(nlsol) && normresid <= abstol + else + SciMLBase.successful_retcode(nlsol) + end + return nlsol, success +end + """ $(TYPEDSIGNATURES) @@ -319,21 +321,6 @@ function get_initial_values(prob, integrator, f, ::NoInit, iip; kwargs...) return state_values(integrator), parameter_values(integrator), true end -is_trivial_initialization(::Nothing) = true - -function is_trivial_initialization(initdata::OverrideInitData) - !(initdata.initializeprob isa NonlinearLeastSquaresProblem) && - state_values(initdata.initializeprob) === nothing -end - -function is_trivial_initialization(f::AbstractSciMLFunction) - has_initialization_data(f) && is_trivial_initialization(f.initialization_data) -end - -function is_trivial_initialization(prob::AbstractSciMLProblem) - is_trivial_initialization(prob.f) -end - @enum DETERMINED_STATUS OVERDETERMINED FULLY_DETERMINED UNDERDETERMINED function initialization_status(prob::AbstractSciMLProblem) diff --git a/src/remake.jl b/src/remake.jl index 45098ccdd..ccd70e383 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -1207,7 +1207,7 @@ function maybe_eager_initialize_problem(prob::AbstractSciMLProblem, initializati lazy_initialization = !is_trivial_initialization(initialization_data) end if initialization_data !== nothing && !lazy_initialization && - (!is_time_dependent(prob) || current_time(prob) !== nothing) + (!is_time_dependent(prob) || current_time(prob) !== nothing) u0, p, _ = get_initial_values( prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) if u0 !== nothing && eltype(u0) == Any && isempty(u0) @@ -1229,3 +1229,20 @@ function remake(thing::AbstractEnsembleProblem; kwargs...) en_kwargs = [k for k in kwargs if first(k) ∈ fieldnames(T)] T(remake(thing.prob; setdiff(kwargs, en_kwargs)...); en_kwargs...) end + +is_trivial_initialization(::Nothing) = true + +is_trivial_initialization(::OverrideInitData{<:NonlinearLeastSquaresProblem}) = false + +is_trivial_initialization(::OverrideInitData{<:NonlinearProblem{Nothing}}) = true + +is_trivial_initialization(::OverrideInitData) = false + +function is_trivial_initialization(f::AbstractSciMLFunction) + # has_initialization_data(f) && is_trivial_initialization(f.initialization_data) + is_trivial_initialization(f.initialization_data) +end + +function is_trivial_initialization(prob::AbstractSciMLProblem) + is_trivial_initialization(prob.f) +end \ No newline at end of file