Merged
14 changes: 6 additions & 8 deletions lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl
@@ -16,16 +16,14 @@ function _alg_autodiff(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD},
alg.autodiff
end

# Type-stable helper functions using dispatch instead of runtime conditionals
_unwrap_autodiff(::Val{true}) = AutoForwardDiff()
_unwrap_autodiff(::Val{false}) = AutoFiniteDiff()
_unwrap_autodiff(x) = x

function alg_autodiff(alg)
autodiff = _alg_autodiff(alg)

if autodiff == Val(true)
return AutoForwardDiff()
elseif autodiff == Val(false)
return AutoFiniteDiff()
else
return autodiff
end
Comment on lines -22 to -28
@oscardssmith (Member), Dec 12, 2025:
alternatively, this could be

    if autodiff isa Val{true}
        return AutoForwardDiff()
    elseif autodiff isa Val{false}
        return AutoFiniteDiff()
    else
        return autodiff
    end

Contributor:

Is there any real advantage to using branching over multiple dispatch for distinguishing executable code by type? I know this is very widely used in the SciML ecosystem, but this if-else-by-type design pattern causes a lot of trouble when trying to extend the code.

Member:

no.

Member Author:

we can split it into multiple dispatch; it's just a code structure thing.

return _unwrap_autodiff(autodiff)
end

Base.@pure function determine_chunksize(u, alg::SciMLBase.DEAlgorithm)
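Regarding the review thread above, here is a minimal standalone sketch (hypothetical `unwrap`/`unwrap_isa` names, not the package's internals) of the two equivalent formulations: the dispatch split this PR adopts and the `isa` branching suggested in the review. Both resolve statically once the argument type is known, so the choice is, as the author notes, a code-structure preference.

    using ADTypes: AutoForwardDiff, AutoFiniteDiff

    # 1. Dispatch split (the pattern adopted by this PR)
    unwrap(::Val{true}) = AutoForwardDiff()
    unwrap(::Val{false}) = AutoFiniteDiff()
    unwrap(x) = x  # already an AD backend; pass through

    # 2. `isa` branching (the alternative suggested in the review)
    function unwrap_isa(autodiff)
        if autodiff isa Val{true}
            AutoForwardDiff()
        elseif autodiff isa Val{false}
            AutoFiniteDiff()
        else
            autodiff
        end
    end

    unwrap(Val(true)) == unwrap_isa(Val(true))  # true: both give AutoForwardDiff()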
168 changes: 111 additions & 57 deletions lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl
@@ -238,55 +238,101 @@ function jacobian!(J::AbstractMatrix{<:Number}, f::F, x::AbstractArray{<:Number}
nothing
end

# Type-stable dispatch based on field types for build_jac_config
# First dispatch on f.jac type
function build_jac_config(alg, f::F1, uf::F2, du1, uprev,
u, tmp, du2) where {F1, F2}
haslinsolve = hasfield(typeof(alg), :linsolve)

if !SciMLBase.has_jac(f) &&
(!SciMLBase.has_Wfact_t(f)) &&
((concrete_jac(alg) === nothing && (!haslinsolve || (haslinsolve &&
(alg.linsolve === nothing || LinearSolve.needs_concrete_A(alg.linsolve))))) ||
(concrete_jac(alg) !== nothing && concrete_jac(alg)))
jac_prototype = f.jac_prototype

if is_sparse_csc(jac_prototype)
if f.mass_matrix isa UniformScaling
idxs = diagind(jac_prototype)
@. @view(jac_prototype[idxs]) = 1
else
idxs = findall(!iszero, f.mass_matrix)
@. @view(jac_prototype[idxs]) = @view(f.mass_matrix[idxs])
end
end
_build_jac_config_jac(alg, f, uf, du1, uprev, u, tmp, du2, f.jac)
end

# f.jac is provided - no need to compute jac_config
@inline function _build_jac_config_jac(alg, f::F1, uf::F2, du1, uprev,
u, tmp, du2, jac) where {F1, F2}
(nothing, nothing)
end

# f.jac is nothing - check Wfact_t next
@inline function _build_jac_config_jac(alg, f::F1, uf::F2, du1, uprev,
u, tmp, du2, ::Nothing) where {F1, F2}
_build_jac_config_wfact(alg, f, uf, du1, uprev, u, tmp, du2, f.Wfact_t)
end

# f.Wfact_t is provided - no need to compute jac_config
@inline function _build_jac_config_wfact(alg, f::F1, uf::F2, du1, uprev,
u, tmp, du2, wfact_t) where {F1, F2}
(nothing, nothing)
end

# f.Wfact_t is nothing - dispatch on algorithm's concrete_jac and linsolve for type stability
@inline function _build_jac_config_wfact(alg, f::F1, uf::F2, du1, uprev,
u, tmp, du2, ::Nothing) where {F1, F2}
_build_jac_config_alg(alg, f, uf, du1, uprev, u, concrete_jac(alg), alg.linsolve)
end

# concrete_jac is nothing and linsolve is nothing -> need config (default factorization)
@inline function _build_jac_config_alg(alg, f::F1, uf::F2, du1, uprev, u,
::Nothing, ::Nothing) where {F1, F2}
_compute_jac_config(alg, f, uf, du1, uprev, u)
end

# concrete_jac is nothing and linsolve is provided -> check if needs concrete A
@inline function _build_jac_config_alg(alg, f::F1, uf::F2, du1, uprev, u,
::Nothing, linsolve) where {F1, F2}
if LinearSolve.needs_concrete_A(linsolve)
_compute_jac_config(alg, f, uf, du1, uprev, u)
else
(nothing, nothing)
end
end

# concrete_jac is a runtime Bool - config is needed only when it is true
@inline function _build_jac_config_alg(alg, f::F1, uf::F2, du1, uprev, u,
cj::Bool, linsolve) where {F1, F2}
if cj
_compute_jac_config(alg, f, uf, du1, uprev, u)
else
(nothing, nothing)
end
end
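The helper chain above trades one multi-clause runtime conditional for small methods that each dispatch on the type of a single field (`Nothing` versus anything else), so inference sees one concrete method per combination. A minimal standalone sketch of the pattern, with hypothetical names:

    # Each step peels off one field; `Nothing` means "not user-supplied".
    struct Prob{J, W}
        jac::J      # user-supplied Jacobian, or nothing
        Wfact::W    # user-supplied W factorization, or nothing
    end

    needs_config(p::Prob) = _check_jac(p, p.jac)

    _check_jac(p, jac) = false                     # jac provided: no config needed
    _check_jac(p, ::Nothing) = _check_wfact(p, p.Wfact)

    _check_wfact(p, Wfact) = false                 # Wfact provided: no config needed
    _check_wfact(p, ::Nothing) = true              # neither provided: build a config

    needs_config(Prob(nothing, nothing))   # true
    needs_config(Prob(identity, nothing))  # false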

autodiff_alg = gpu_safe_autodiff(alg_autodiff(alg), u)
dense = autodiff_alg isa AutoSparse ? ADTypes.dense_ad(autodiff_alg) : autodiff_alg
# Actually compute the jacobian config
@inline function _compute_jac_config(alg, f::F1, uf::F2, du1, uprev, u) where {F1, F2}
jac_prototype = f.jac_prototype

if dense isa AutoFiniteDiff
dir_forward = @set dense.dir = 1
dir_reverse = @set dense.dir = -1
if is_sparse_csc(jac_prototype)
if f.mass_matrix isa UniformScaling
idxs = diagind(jac_prototype)
@. @view(jac_prototype[idxs]) = 1
else
idxs = findall(!iszero, f.mass_matrix)
@. @view(jac_prototype[idxs]) = @view(f.mass_matrix[idxs])
end
end

if autodiff_alg isa AutoSparse
autodiff_alg_forward = @set autodiff_alg.dense_ad = dir_forward
autodiff_alg_reverse = @set autodiff_alg.dense_ad = dir_reverse
else
autodiff_alg_forward = dir_forward
autodiff_alg_reverse = dir_reverse
end
autodiff_alg = gpu_safe_autodiff(alg_autodiff(alg), u)
dense = autodiff_alg isa AutoSparse ? ADTypes.dense_ad(autodiff_alg) : autodiff_alg

jac_config_forward = DI.prepare_jacobian(
uf, du1, autodiff_alg_forward, u, strict = Val(false))
jac_config_reverse = DI.prepare_jacobian(
uf, du1, autodiff_alg_reverse, u, strict = Val(false))
if dense isa AutoFiniteDiff
dir_forward = @set dense.dir = 1
dir_reverse = @set dense.dir = -1

jac_config = (jac_config_forward, jac_config_reverse)
if autodiff_alg isa AutoSparse
autodiff_alg_forward = @set autodiff_alg.dense_ad = dir_forward
autodiff_alg_reverse = @set autodiff_alg.dense_ad = dir_reverse
else
jac_config1 = DI.prepare_jacobian(uf, du1, autodiff_alg, u, strict = Val(false))
jac_config = (jac_config1, jac_config1)
autodiff_alg_forward = dir_forward
autodiff_alg_reverse = dir_reverse
end

jac_config_forward = DI.prepare_jacobian(
uf, du1, autodiff_alg_forward, u, strict = Val(false))
jac_config_reverse = DI.prepare_jacobian(
uf, du1, autodiff_alg_reverse, u, strict = Val(false))

jac_config = (jac_config_forward, jac_config_reverse)
else
jac_config = (nothing, nothing)
jac_config1 = DI.prepare_jacobian(uf, du1, autodiff_alg, u, strict = Val(false))
jac_config = (jac_config1, jac_config1)
end

jac_config
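For context on what the returned pair is for, here is a hedged sketch of DifferentiationInterface's documented prepare-then-reuse workflow, which configs like these feed into. The `f!`, `u`, and backend below are illustrative stand-ins, not the package's internals.

    import DifferentiationInterface as DI
    using ADTypes: AutoForwardDiff

    f!(du, u) = (du .= u .^ 2; nothing)  # stands in for the wrapped ODE RHS

    u = rand(3)
    du = similar(u)
    J = zeros(3, 3)

    # Preparation is paid once, then reused across solver steps.
    prep = DI.prepare_jacobian(f!, du, AutoForwardDiff(), u)
    DI.jacobian!(f!, du, J, prep, AutoForwardDiff(), u)  # J[i, i] == 2u[i]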
@@ -365,32 +411,40 @@ end
# Fallback for other AD backends
gpu_safe_autodiff(backend, u) = backend

# Type-stable dispatch based on field type for build_grad_config
# When f.tgrad is Nothing, we need to compute the config
function build_grad_config(alg, f::F1, tf::F2, du1, t) where {F1, F2}
if !SciMLBase.has_tgrad(f)
ad = ADTypes.dense_ad(alg_autodiff(alg))
_build_grad_config_dispatch(alg, f, tf, du1, t, f.tgrad)
end

# Apply GPU-safe wrapping for AutoForwardDiff when dealing with GPU arrays
ad = gpu_safe_autodiff(ad, du1)
# f.tgrad is nothing - need to compute grad_config
@inline function _build_grad_config_dispatch(alg, f::F1, tf::F2, du1, t, ::Nothing) where {F1, F2}
ad = ADTypes.dense_ad(alg_autodiff(alg))

if ad isa AutoFiniteDiff
dir_true = @set ad.dir = 1
dir_false = @set ad.dir = -1
# Apply GPU-safe wrapping for AutoForwardDiff when dealing with GPU arrays
ad = gpu_safe_autodiff(ad, du1)

grad_config_true = DI.prepare_derivative(tf, du1, dir_true, t)
grad_config_false = DI.prepare_derivative(tf, du1, dir_false, t)
if ad isa AutoFiniteDiff
dir_true = @set ad.dir = 1
dir_false = @set ad.dir = -1

grad_config = (grad_config_true, grad_config_false)
elseif ad isa AutoForwardDiff
grad_config1 = DI.prepare_derivative(tf, du1, ad, convert(eltype(du1), t))
grad_config = (grad_config1, grad_config1)
else
grad_config1 = DI.prepare_derivative(tf, du1, ad, t)
grad_config = (grad_config1, grad_config1)
end
return grad_config
grad_config_true = DI.prepare_derivative(tf, du1, dir_true, t)
grad_config_false = DI.prepare_derivative(tf, du1, dir_false, t)

grad_config = (grad_config_true, grad_config_false)
elseif ad isa AutoForwardDiff
grad_config1 = DI.prepare_derivative(tf, du1, ad, convert(eltype(du1), t))
grad_config = (grad_config1, grad_config1)
else
return (nothing, nothing)
grad_config1 = DI.prepare_derivative(tf, du1, ad, t)
grad_config = (grad_config1, grad_config1)
end
return grad_config
end

# f.tgrad is provided - return nothing tuple
@inline function _build_grad_config_dispatch(alg, f::F1, tf::F2, du1, t, tgrad) where {F1, F2}
(nothing, nothing)
end
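Why `build_grad_config` prepares two configs for `AutoFiniteDiff`: the backend carries a `dir` field choosing the finite-difference step direction, and near the end of the integration interval the step must be able to point backward in `t`. A hedged sketch of the direction-pair pattern, assuming Setfield's `@set` (Accessors exports the same macro) and ADTypes' `dir` field:

    using ADTypes: AutoFiniteDiff
    using Setfield: @set

    ad = AutoFiniteDiff()
    ad_forward = @set ad.dir = 1    # difference steps point toward t + h
    ad_reverse = @set ad.dir = -1   # difference steps point toward t - h

    (ad_forward.dir, ad_reverse.dir)  # (1, -1); all other settings are shared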

function sparsity_colorvec(f::F, x) where F
18 changes: 15 additions & 3 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl
@@ -772,6 +772,19 @@ function alg_cache(
alg_autodiff(alg), size(tab.H, 1))
end

# Function barrier helper for type stability - Julia specializes on concrete types
@inline function _make_rosenbrock_cache(
u, uprev, dense, du, du1, du2, dtC, dtd, ks, fsalfirst, fsallast,
dT, J::JType, W::WType, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
linsolve::F, jac_config::JCType, grad_config::GCType, reltol, alg,
step_limiter!, stage_limiter!, interp_order) where {JType, WType, F, JCType, GCType}
RosenbrockCache(
u, uprev, dense, du, du1, du2, dtC, dtd, ks, fsalfirst, fsallast,
dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg,
step_limiter!, stage_limiter!, interp_order)
end
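The helper above is the classic Julia function-barrier idiom: even when a value's type cannot be inferred at the call site, passing it through a function call compiles the callee against its concrete runtime type. A minimal standalone sketch with hypothetical names:

    struct Cache{W}
        W::W
    end

    @inline _barrier(W::WType) where {WType} = Cache(W)  # specializes on WType

    function build(flag::Bool)
        W = flag ? zeros(2, 2) : (zeros(2, 2),)  # inferred as a Union here
        _barrier(W)  # barrier: the callee body sees W's concrete type
    end

    build(true)   # Cache{Matrix{Float64}}
    build(false)  # Cache{Tuple{Matrix{Float64}}}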

function alg_cache(
alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr, Rodas6P},
u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -822,9 +835,8 @@ function alg_cache(
Pl=Pl, Pr=Pr,
assumptions=LinearSolve.OperatorAssumptions(true))


# Return the cache struct with vectors
RosenbrockCache(
# Use function barrier to ensure type stability in cache construction
_make_rosenbrock_cache(
u, uprev, dense, du, du1, du2, dtC, dtd, ks, fsalfirst, fsallast,
dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg,