From 236fb6500bd1dc178ba5820037e4f9b5256fe545 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sun, 18 May 2025 20:28:16 +0530 Subject: [PATCH 01/10] fix: respect OOP `update_initializeprob!` --- src/initialization.jl | 83 +++++++++++++++++++++++-------------------- 1 file changed, 45 insertions(+), 38 deletions(-) diff --git a/src/initialization.jl b/src/initialization.jl index 4d25633f6..3e029950d 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -259,40 +259,7 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, end end - if is_trivial_initialization(initdata) - nlsol = initdata - success = true - else - nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) - if nlsolve_alg === nothing && state_values(initprob) !== nothing - throw(OverrideInitMissingAlgorithm()) - end - if alg.abstol !== nothing - _abstol = alg.abstol - elseif abstol !== nothing - _abstol = abstol - else - throw(OverrideInitNoTolerance(:abstol)) - end - if alg.reltol !== nothing - _reltol = alg.reltol - elseif reltol !== nothing - _reltol = reltol - else - throw(OverrideInitNoTolerance(:reltol)) - end - nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol, kwargs...) - - success = if initprob isa NonlinearLeastSquaresProblem - # Do not accept StalledSuccess as a solution - # A good local minima is not a success - resid = nlsol.resid - normresid = norm(resid) - SciMLBase.successful_retcode(nlsol) && normresid <= abstol - else - SciMLBase.successful_retcode(nlsol) - end - end + nlsol, success = solve_initialization(initdata, initprob, alg, Val{is_trivial_initialization(initdata)}(); abstol, reltol, nlsolve_alg) if initdata.initializeprobmap !== nothing u0 = initdata.initializeprobmap(choose_branch(nlsol)) @@ -304,6 +271,45 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, return u0, p, success end +function solve_initialization(initdata, initprob, alg, ::Val{true}; kwargs...) + nlsol = @set initdata.initializeprob = initprob + success = true + return nlsol, success +end + +function solve_initialization(initdata, initprob, alg, ::Val{false}; reltol, abstol, nlsolve_alg) + nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) + if nlsolve_alg === nothing && state_values(initprob) !== nothing + throw(OverrideInitMissingAlgorithm()) + end + if alg.abstol !== nothing + _abstol = alg.abstol + elseif abstol !== nothing + _abstol = abstol + else + throw(OverrideInitNoTolerance(:abstol)) + end + if alg.reltol !== nothing + _reltol = alg.reltol + elseif reltol !== nothing + _reltol = reltol + else + throw(OverrideInitNoTolerance(:reltol)) + end + nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol, kwargs...) + + success = if initprob isa NonlinearLeastSquaresProblem + # Do not accept StalledSuccess as a solution + # A good local minima is not a success + resid = nlsol.resid + normresid = norm(resid) + SciMLBase.successful_retcode(nlsol) && normresid <= abstol + else + SciMLBase.successful_retcode(nlsol) + end + return nlsol, success +end + """ $(TYPEDSIGNATURES) @@ -316,10 +322,11 @@ end is_trivial_initialization(::Nothing) = true -function is_trivial_initialization(initdata::OverrideInitData) - !(initdata.initializeprob isa NonlinearLeastSquaresProblem) && - state_values(initdata.initializeprob) === nothing -end +is_trivial_initialization(::OverrideInitData{<:NonlinearLeastSquaresProblem}) = false + +is_trivial_initialization(::OverrideInitData{<:NonlinearProblem{Nothing}}) = true + +is_trivial_initialization(::OverrideInitData) = false function is_trivial_initialization(f::AbstractSciMLFunction) has_initialization_data(f) && is_trivial_initialization(f.initialization_data) From d50778beebed28482ecedad7b8424f1d006f2a51 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 23 May 2025 02:49:50 +0530 Subject: [PATCH 02/10] fix: aliasing and mutating object accumulation --- src/initialization.jl | 34 ++++++++++------------------------ src/remake.jl | 17 +++++++++++++++++ src/utils.jl | 3 --- 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/src/initialization.jl b/src/initialization.jl index 3e029950d..30e624d0a 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -258,26 +258,28 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, initdata.update_initializeprob!(initprob, valp) end end + nlsol, success = solve_initialization(initdata, initprob, alg; reltol, abstol, nlsolve_alg ) - nlsol, success = solve_initialization(initdata, initprob, alg, Val{is_trivial_initialization(initdata)}(); abstol, reltol, nlsolve_alg) - + nlsol2 = prob.f.initialization_data.initializeprob if initdata.initializeprobmap !== nothing - u0 = initdata.initializeprobmap(choose_branch(nlsol)) + u02 = initdata.initializeprobmap(nlsol2) end if initdata.initializeprobpmap !== nothing - p = initdata.initializeprobpmap(valp, choose_branch(nlsol)) + p2 = initdata.initializeprobpmap(valp, nlsol) end - return u0, p, success + u03 = isnothing(initdata.initializeprobmap) ? u0 : u02 + p3 = isnothing(initdata.initializeprobpmap) ? p : p2 + return u03, p3, success end -function solve_initialization(initdata, initprob, alg, ::Val{true}; kwargs...) - nlsol = @set initdata.initializeprob = initprob +function solve_initialization(initdata::OverrideInitData{<:AbstractNonlinearProblem{Nothing}}, initprob, alg; kwargs...) + nlsol = initprob success = true return nlsol, success end -function solve_initialization(initdata, initprob, alg, ::Val{false}; reltol, abstol, nlsolve_alg) +function solve_initialization(initdata, initprob, alg; reltol, abstol, nlsolve_alg) nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) if nlsolve_alg === nothing && state_values(initprob) !== nothing throw(OverrideInitMissingAlgorithm()) @@ -320,22 +322,6 @@ function get_initial_values(prob, integrator, f, ::NoInit, iip; kwargs...) return state_values(integrator), parameter_values(integrator), true end -is_trivial_initialization(::Nothing) = true - -is_trivial_initialization(::OverrideInitData{<:NonlinearLeastSquaresProblem}) = false - -is_trivial_initialization(::OverrideInitData{<:NonlinearProblem{Nothing}}) = true - -is_trivial_initialization(::OverrideInitData) = false - -function is_trivial_initialization(f::AbstractSciMLFunction) - has_initialization_data(f) && is_trivial_initialization(f.initialization_data) -end - -function is_trivial_initialization(prob::AbstractSciMLProblem) - is_trivial_initialization(prob.f) -end - @enum DETERMINED_STATUS OVERDETERMINED FULLY_DETERMINED UNDERDETERMINED function initialization_status(prob::AbstractSciMLProblem) diff --git a/src/remake.jl b/src/remake.jl index 45098ccdd..5a1082423 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -1229,3 +1229,20 @@ function remake(thing::AbstractEnsembleProblem; kwargs...) en_kwargs = [k for k in kwargs if first(k) ∈ fieldnames(T)] T(remake(thing.prob; setdiff(kwargs, en_kwargs)...); en_kwargs...) end + +is_trivial_initialization(::Nothing) = true + +is_trivial_initialization(::OverrideInitData{<:NonlinearLeastSquaresProblem}) = false + +is_trivial_initialization(::OverrideInitData{<:NonlinearProblem{Nothing}}) = true + +is_trivial_initialization(::OverrideInitData) = false + +function is_trivial_initialization(f::AbstractSciMLFunction) + # has_initialization_data(f) && is_trivial_initialization(f.initialization_data) + is_trivial_initialization(f.initialization_data) +end + +function is_trivial_initialization(prob::AbstractSciMLProblem) + is_trivial_initialization(prob.f) +end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index 5720be1cd..ecded5af1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -550,6 +550,3 @@ Strips a SciMLSolution object and its interpolation of their functions to better function strip_solution(sol::AbstractSciMLSolution) sol end - -choose_branch(x::OverrideInitData) = x.initializeprob -choose_branch(sol::AbstractSciMLSolution) = sol \ No newline at end of file From 584f8c3dcdd4966534ab74bf0f72233b7fb14556 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Sat, 24 May 2025 04:19:54 +0530 Subject: [PATCH 03/10] chore: return deref'd tangent for getproperty(ODEProblem) --- ext/SciMLBaseZygoteExt.jl | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 45a8e0f63..ad92a011d 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -1,13 +1,14 @@ module SciMLBaseZygoteExt using Zygote -using Zygote: @adjoint, pullback -import Zygote: literal_getproperty +using Zygote: @adjoint, pullback, @_adjoint_keepthunks, _project, pair +import Zygote: literal_getproperty, literal_getfield import ChainRulesCore using SciMLBase using SciMLBase: ODESolution, remake, ODEFunction, getobserved, build_solution, EnsembleSolution, - NonlinearSolution, AbstractTimeseriesSolution + NonlinearSolution, AbstractTimeseriesSolution, + ODEProblem using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed, observed, parameter_values, state_values, current_time using RecursiveArrayTools @@ -299,4 +300,20 @@ end ∇responsible_map(__context__, f, args...) end +@_adjoint_keepthunks function Zygote.literal_getfield(x::ODEProblem, ::Val{f}) where f + val = getfield(x, f) + function back(Δ) + Zygote.accum_param(__context__, val, Δ) === nothing && return + if isimmutable(x) + dx = (; Zygote.nt_nothing(x)..., pair(Val(f), Δ, x)...) + (_project(x, dx), nothing) + else + dx = Zygote.grad_mut(__context__, x) + dx[] = (; dx[]..., pair(Val(f), Zygote.accum(getfield(dx[], f), Δ))...) + return (dx[],nothing) + end + end + Zygote.unwrap(val), back +end + end From 53cc1422be521d138fb12b3691f8b7052eea1d9e Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 26 May 2025 17:19:01 +0530 Subject: [PATCH 04/10] chore: add accum patch --- ext/SciMLBaseZygoteExt.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index ad92a011d..6ce8eba47 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -303,17 +303,20 @@ end @_adjoint_keepthunks function Zygote.literal_getfield(x::ODEProblem, ::Val{f}) where f val = getfield(x, f) function back(Δ) + # error() Zygote.accum_param(__context__, val, Δ) === nothing && return if isimmutable(x) + error() dx = (; Zygote.nt_nothing(x)..., pair(Val(f), Δ, x)...) (_project(x, dx), nothing) else dx = Zygote.grad_mut(__context__, x) dx[] = (; dx[]..., pair(Val(f), Zygote.accum(getfield(dx[], f), Δ))...) - return (dx[],nothing) + return (dx,nothing) end end Zygote.unwrap(val), back end +Zygote.accum(::Tuple{}, ::NamedTuple{}) = () end From a03bb02cc059c88aa27f04aafe6fa00893888585 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 26 May 2025 17:30:52 +0530 Subject: [PATCH 05/10] chore: rm call to error --- ext/SciMLBaseZygoteExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 6ce8eba47..5091c6712 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -303,7 +303,6 @@ end @_adjoint_keepthunks function Zygote.literal_getfield(x::ODEProblem, ::Val{f}) where f val = getfield(x, f) function back(Δ) - # error() Zygote.accum_param(__context__, val, Δ) === nothing && return if isimmutable(x) error() From c4178b1cb5eb3d44de5bd27df9af314cbe2f9c9c Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 26 May 2025 19:05:19 +0530 Subject: [PATCH 06/10] chore: treat prob as immutable --- Project.toml | 2 +- ext/SciMLBaseChainRulesCoreExt.jl | 1 + ext/SciMLBaseZygoteExt.jl | 4 ++-- src/SciMLBase.jl | 2 ++ src/initialization.jl | 3 +-- src/remake.jl | 25 +++++++++++++++++-------- 6 files changed, 24 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index a1b6ffe74..dae152d64 100644 --- a/Project.toml +++ b/Project.toml @@ -31,6 +31,7 @@ SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" @@ -41,7 +42,6 @@ PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" RCall = "6f49c342-dc21-5d91-9882-a32aef131414" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] SciMLBaseChainRulesCoreExt = "ChainRulesCore" diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index f1be9c162..d8dc88931 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -58,6 +58,7 @@ end function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...) function ODEProblemAdjoint(ȳ) + @show "some con" (NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type) end diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 5091c6712..be6e1aa18 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -300,7 +300,7 @@ end ∇responsible_map(__context__, f, args...) end -@_adjoint_keepthunks function Zygote.literal_getfield(x::ODEProblem, ::Val{f}) where f +@_adjoint_keepthunks function Zygote.literal_getfield(x::SciMLBase.AbstractSciMLProblem, ::Val{f}) where f val = getfield(x, f) function back(Δ) Zygote.accum_param(__context__, val, Δ) === nothing && return @@ -311,7 +311,7 @@ end else dx = Zygote.grad_mut(__context__, x) dx[] = (; dx[]..., pair(Val(f), Zygote.accum(getfield(dx[], f), Δ))...) - return (dx,nothing) + return (dx[],nothing) end end Zygote.unwrap(val), back diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 3111e7cf6..0ec73b797 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -44,6 +44,8 @@ import SciMLOperators: @reexport using SciMLOperators +using Zygote + function __solve end function __init end diff --git a/src/initialization.jl b/src/initialization.jl index dff2e72cf..9dc75b2f6 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -260,9 +260,8 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, end nlsol, success = solve_initialization(initdata, initprob, alg; reltol, abstol, nlsolve_alg ) - nlsol2 = prob.f.initialization_data.initializeprob if initdata.initializeprobmap !== nothing - u02 = initdata.initializeprobmap(nlsol2) + u02 = initdata.initializeprobmap(nlsol) end if initdata.initializeprobpmap !== nothing p2 = initdata.initializeprobpmap(valp, nlsol) diff --git a/src/remake.jl b/src/remake.jl index 5a1082423..63377681c 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -703,6 +703,8 @@ function remake(prob::NonlinearProblem; if problem_type === missing problem_type = prob.problem_type end + # error() + # @show f prob = if kwargs === missing NonlinearProblem{isinplace(prob)}(f = f, u0 = newu0, p = newp, @@ -1206,18 +1208,25 @@ function maybe_eager_initialize_problem(prob::AbstractSciMLProblem, initializati if lazy_initialization === nothing lazy_initialization = !is_trivial_initialization(initialization_data) end - if initialization_data !== nothing && !lazy_initialization && - (!is_time_dependent(prob) || current_time(prob) !== nothing) + cond = initialization_data !== nothing && !lazy_initialization && + (!is_time_dependent(prob) || current_time(prob) !== nothing) + @show cond + if cond + # @show "in maybe_eager_initialize_problem" u0, p, _ = get_initial_values( prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) - if u0 !== nothing && eltype(u0) == Any && isempty(u0) - u0 = nothing - end + # if u0 !== nothing && eltype(u0) == Any && isempty(u0) + # u0 = nothing + # end else - u0 = state_values(prob) - p = parameter_values(prob) + u02 = state_values(prob) + p2 = parameter_values(prob) end - return u0, p + # @show p + + u03 = cond ? u0 : u02 + p3 = cond ? p : p2 + return u03, p3 end function remake(thing::AbstractJumpProblem; kwargs...) From c85dfa2913ba43eba69b2883edcbfa9c970fb746 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 26 May 2025 19:12:05 +0530 Subject: [PATCH 07/10] chore: gix commit history mixup --- ext/SciMLBaseChainRulesCoreExt.jl | 1 - ext/SciMLBaseZygoteExt.jl | 1 - src/SciMLBase.jl | 2 -- src/remake.jl | 26 +++++++++----------------- 4 files changed, 9 insertions(+), 21 deletions(-) diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index d8dc88931..f1be9c162 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -58,7 +58,6 @@ end function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...) function ODEProblemAdjoint(ȳ) - @show "some con" (NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type) end diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index be6e1aa18..a536693bd 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -305,7 +305,6 @@ end function back(Δ) Zygote.accum_param(__context__, val, Δ) === nothing && return if isimmutable(x) - error() dx = (; Zygote.nt_nothing(x)..., pair(Val(f), Δ, x)...) (_project(x, dx), nothing) else diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 0ec73b797..3111e7cf6 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -44,8 +44,6 @@ import SciMLOperators: @reexport using SciMLOperators -using Zygote - function __solve end function __init end diff --git a/src/remake.jl b/src/remake.jl index 63377681c..377c915db 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -262,6 +262,7 @@ function remake(prob::ODEProblem; f = missing, else ODEProblem{iip}(f, newu0, tspan, newp, prob.problem_type; kwargs...) end + @show typeof(prob) u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) @reset prob.u0 = u0 @@ -703,8 +704,6 @@ function remake(prob::NonlinearProblem; if problem_type === missing problem_type = prob.problem_type end - # error() - # @show f prob = if kwargs === missing NonlinearProblem{isinplace(prob)}(f = f, u0 = newu0, p = newp, @@ -1208,25 +1207,18 @@ function maybe_eager_initialize_problem(prob::AbstractSciMLProblem, initializati if lazy_initialization === nothing lazy_initialization = !is_trivial_initialization(initialization_data) end - cond = initialization_data !== nothing && !lazy_initialization && - (!is_time_dependent(prob) || current_time(prob) !== nothing) - @show cond - if cond - # @show "in maybe_eager_initialize_problem" + if initialization_data !== nothing && !lazy_initialization && + (!is_time_dependent(prob) || current_time(prob) !== nothing) u0, p, _ = get_initial_values( prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) - # if u0 !== nothing && eltype(u0) == Any && isempty(u0) - # u0 = nothing - # end + if u0 !== nothing && eltype(u0) == Any && isempty(u0) + u0 = nothing + end else - u02 = state_values(prob) - p2 = parameter_values(prob) + u0 = state_values(prob) + p = parameter_values(prob) end - # @show p - - u03 = cond ? u0 : u02 - p3 = cond ? p : p2 - return u03, p3 + return u0, p end function remake(thing::AbstractJumpProblem; kwargs...) From 296f43eada8a28f0ef1e4d0699afe3e1515837e2 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 26 May 2025 19:13:35 +0530 Subject: [PATCH 08/10] chore: rm unused code --- Project.toml | 2 +- src/remake.jl | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index dae152d64..a1b6ffe74 100644 --- a/Project.toml +++ b/Project.toml @@ -31,7 +31,6 @@ SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" @@ -42,6 +41,7 @@ PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" RCall = "6f49c342-dc21-5d91-9882-a32aef131414" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] SciMLBaseChainRulesCoreExt = "ChainRulesCore" diff --git a/src/remake.jl b/src/remake.jl index 377c915db..ccd70e383 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -262,7 +262,6 @@ function remake(prob::ODEProblem; f = missing, else ODEProblem{iip}(f, newu0, tspan, newp, prob.problem_type; kwargs...) end - @show typeof(prob) u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) @reset prob.u0 = u0 From 0c249abef5fb5920aeabb7abb6c64fe08566173a Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 26 May 2025 20:12:07 +0530 Subject: [PATCH 09/10] chore: pass kwargs to solve_initialization --- src/initialization.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/initialization.jl b/src/initialization.jl index 9dc75b2f6..e98a0e829 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -258,7 +258,7 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, initdata.update_initializeprob!(initprob, valp) end end - nlsol, success = solve_initialization(initdata, initprob, alg; reltol, abstol, nlsolve_alg ) + nlsol, success = solve_initialization(initdata, initprob, alg; reltol, abstol, nlsolve_alg, kwargs...) if initdata.initializeprobmap !== nothing u02 = initdata.initializeprobmap(nlsol) @@ -278,7 +278,7 @@ function solve_initialization(initdata::OverrideInitData{<:AbstractNonlinearProb return nlsol, success end -function solve_initialization(initdata, initprob, alg; reltol, abstol, nlsolve_alg) +function solve_initialization(initdata, initprob, alg; reltol, abstol, nlsolve_alg, kwargs...) nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) if nlsolve_alg === nothing && state_values(initprob) !== nothing throw(OverrideInitMissingAlgorithm()) From 40ba20b7557756bdf46172a2519ce8a899629624 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 26 May 2025 21:09:15 +0530 Subject: [PATCH 10/10] chore: pretend prob is immutable --- ext/SciMLBaseZygoteExt.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index a536693bd..bfe8a9f9a 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -304,14 +304,8 @@ end val = getfield(x, f) function back(Δ) Zygote.accum_param(__context__, val, Δ) === nothing && return - if isimmutable(x) - dx = (; Zygote.nt_nothing(x)..., pair(Val(f), Δ, x)...) - (_project(x, dx), nothing) - else - dx = Zygote.grad_mut(__context__, x) - dx[] = (; dx[]..., pair(Val(f), Zygote.accum(getfield(dx[], f), Δ))...) - return (dx[],nothing) - end + dx = (; Zygote.nt_nothing(x)..., pair(Val(f), Δ, x)...) + (_project(x, dx), nothing) end Zygote.unwrap(val), back end