diff --git a/src/ExaModels.jl b/src/ExaModels.jl index 61ac4c857..bf283443e 100644 --- a/src/ExaModels.jl +++ b/src/ExaModels.jl @@ -78,6 +78,10 @@ export ExaModel, @add_con, @add_con!, set_parameter!, + objective, + constraint, + constraint!, + subexpr, solution, multipliers, multipliers_L, diff --git a/src/gradient.jl b/src/gradient.jl index 3bf806cf4..631703697 100644 --- a/src/gradient.jl +++ b/src/gradient.jl @@ -8,22 +8,26 @@ Performs dense gradient evaluation via the reverse pass on the computation (sub) - `y`: result vector - `adj`: adjoint propagated up to the current node """ -@inline function drpass(d::D, y, adj) where {D<:Union{Real,AdjointNull}} +@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:Union{Real,AdjointNull}} nothing end -@inline function drpass(d::D, y, adj) where {D<:AdjointNode1} - offset = drpass(d.inner, y, adj * d.y) +@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNode1} + drpass(e, e_starts, e_cnts, d.inner, y, adj * d.y) nothing end -@inline function drpass(d::D, y, adj) where {D<:AdjointNode2} - offset = drpass(d.inner1, y, adj * d.y1) - offset = drpass(d.inner2, y, adj * d.y2) +@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNode2} + drpass(e, e_starts, e_cnts, d.inner1, y, adj * d.y1) + drpass(e, e_starts, e_cnts, d.inner2, y, adj * d.y2) nothing end -@inline function drpass(d::D, y, adj) where {D<:AdjointNodeVar} +@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNodeVar} @inbounds y[d.i] += adj nothing end +@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNodeExpr} + y[d.i] += e[e_starts[d.i][2]] + nothing +end """ gradient!(y, f, x, adj) @@ -36,6 +40,18 @@ Performs dense gradient evalution - `x`: variable vector - `adj`: initial adjoint """ +function gradient!(isexp, e, e_starts, e_cnts, y, f, x, θ, adj) + @simd for k in eachindex(f.itr) + @inbounds gradient!(isexp, e, e_starts, e_cnts, y, f.f, x, θ, f.itr[k], adj) + end + return y +end +function gradient!(isexp, e, e_starts, e_cnts, y, f, x, θ, p, adj) + graph = f(p, AdjointNodeSource(x, nothing), θ) + drpass(e, e_starts, e_cnts, graph, y, adj) + return y +end +# Simple (no-expression-cache) wrappers used by _grad! in nlp.jl function gradient!(y, f, x, θ, adj) @simd for k in eachindex(f.itr) @inbounds gradient!(y, f.f, x, θ, f.itr[k], adj) @@ -43,8 +59,8 @@ function gradient!(y, f, x, θ, adj) return y end function gradient!(y, f, x, θ, p, adj) - graph = f(p, AdjointNodeSource(x), θ) - drpass(graph, y, adj) + graph = f(p, AdjointNodeSource(x, nothing), θ) + drpass(nothing, nothing, nothing, graph, y, adj) return y end @@ -167,15 +183,15 @@ Performs sparse gradient evalution - `x`: variable vector - `adj`: initial adjoint """ -function sgradient!(y, f, x, θ, adj) +function sgradient!(y, f, x, θ, adj, isexp) @simd for k in eachindex(f.itr) @inbounds sgradient!(y, f.f, f.itr[k], x, θ, f.itr.comp1, offset1(f, k), adj) end return y end -function sgradient!(y, f, p, x, θ, comp, o1, adj) - graph = f(p, AdjointNodeSource(x), θ) +function sgradient!(y, f, p, x, θ, comp, o1, adj, isexp) + graph = f(p, AdjointNodeSource(x, nothing), θ) grpass(graph, comp, y, o1, 0, adj) return y end diff --git a/src/graph.jl b/src/graph.jl index d38fa433a..eb6c34431 100644 --- a/src/graph.jl +++ b/src/graph.jl @@ -138,6 +138,9 @@ the index is itself a node). struct Var{I} <: AbstractNode i::I end +struct Exp{I} <: AbstractNode + i::I +end struct ParameterSource <: AbstractNode end struct ParameterNode{I} <: AbstractNode @@ -216,14 +219,12 @@ end @inline Base.getindex(v::DataIndexed{I, n}, i) where {I, n} = DataIndexed(v, i) @inline Base.indexed_iterate(v::DataIndexed{I, n}, idx, start = 1) where {I, n} = (DataIndexed(v, idx), idx + 1) - @inline Base.getindex(n::VarSource, i) = Var(i) @inline Base.getindex(::ParameterSource, i) = ParameterNode(i) @inline Node1(f::F, inner::I) where {F,I} = Node1{F,I}(inner) @inline Node2(f::F, inner1::I1, inner2::I2) where {F,I1,I2} = Node2{F,I1,I2}(inner1, inner2) - struct Identity end @inline (v::Var{I})(i, x, θ) where {I<:AbstractNode} = @inbounds x[v.i(i, x, θ)] @@ -231,6 +232,9 @@ struct Identity end @inline (v::Var{I})(i::Identity, x, θ) where {I <: AbstractNode} = x[v] @inline (v::Var{I})(i::Identity, x, θ) where {I <: Real} = x[v] +@inline (e::Exp{I})(i, x, θ) where {I<:AbstractNode} = @inbounds x[e.i(i, x, θ)] +@inline (e::Exp{I})(i, x, θ) where {I} = @inbounds x[e.i] + @inline (v::ParameterNode{I})(i, x, θ) where {I<:AbstractNode} = @inbounds θ[v.i(i, x, θ)] @inline (v::ParameterNode{I})(::Any, x, θ) where {I} = @inbounds θ[v.i] @inline (v::ParameterNode{I})(::Identity, x, θ) where {I<:AbstractNode} = @inbounds θ[v.i] @@ -294,6 +298,11 @@ struct AdjointNodeVar{I,T} <: AbstractAdjointNode x::T end +struct AdjointNodeExpr{I,T} <: AbstractAdjointNode + i::I + x::T +end + """ AdjointNodeSource{VT} @@ -304,8 +313,9 @@ primal value. # Fields - `inner::VT`: primal variable vector (or `nothing` for a zero-valued seed) """ -struct AdjointNodeSource{VT} +struct AdjointNodeSource{VT,OE} inner::VT + offset_exps::OE end @inline AdjointNode1(f::F, x::T, y, inner::I) where {F,T,I} = @@ -318,7 +328,6 @@ end @inline Base.getindex(x::I, i) where {I<:AdjointNodeSource} = @inbounds AdjointNodeVar(i, x.inner[i]) - """ SecondAdjointNode1{F, T, I} <: AbstractSecondAdjointNode @@ -379,8 +388,13 @@ struct SecondAdjointNodeVar{I,T} <: AbstractSecondAdjointNode x::T end +struct SecondAdjointNodeExpr{I,T} <: AbstractSecondAdjointNode + i::I + x::T +end + """ - SecondAdjointNodeSource{VT} + SecondAdjointNodeSource{VT,VTI} Factory for [`SecondAdjointNodeVar`](@ref) leaves. Indexing with `i` returns `SecondAdjointNodeVar(i, inner[i])`, seeding the Hessian-pass tree. @@ -388,8 +402,9 @@ Factory for [`SecondAdjointNodeVar`](@ref) leaves. Indexing with `i` returns # Fields - `inner::VT`: primal variable vector (or `nothing` for a zero-valued seed) """ -struct SecondAdjointNodeSource{VT} +struct SecondAdjointNodeSource{VT,OE} inner::VT + offset_exps::OE end @inline SecondAdjointNode1(f::F, x::T, y, h, inner::I) where {F,T,I} = diff --git a/src/hessian.jl b/src/hessian.jl index 64f92a07f..3a8821614 100644 --- a/src/hessian.jl +++ b/src/hessian.jl @@ -1,9 +1,13 @@ """ - hdrpass(t1::T1, t2::T2, comp, y1, y2, o2, cnt, adj) + hdrpass(e, e_starts, e_cnts, t1::T1, t2::T2, comp, y1, y2, o2, cnt, adj) Performs sparse hessian evaluation (`(df1/dx)(df2/dx)'` portion) via the reverse pass on the computation (sub)graph formed by second-order forward pass # Arguments: +- `e`: expression Jacobian values +- `e_starts`: expression start indices +- `e_cnts`: expression counts +- `isexp`: expression indicator vector - `t1`: second-order computation (sub)graph regarding f1 - `t2`: second-order computation (sub)graph regarding f2 - `comp`: a `Compressor`, which helps map counter to sparse vector index @@ -14,6 +18,9 @@ Performs sparse hessian evaluation (`(df1/dx)(df2/dx)'` portion) via the reverse - `adj`: second adjoint propagated up to the current node """ @inline function hdrpass( + e, + e_starts, + e_cnts, t1::T1, t2::T2, comp, @@ -23,10 +30,13 @@ Performs sparse hessian evaluation (`(df1/dx)(df2/dx)'` portion) via the reverse cnt, adj, ) where {T1<:SecondAdjointNode1,T2<:SecondAdjointNode1} - cnt = hdrpass(t1.inner, t2.inner, comp, y1, y2, o2, cnt, adj * t1.y * t2.y) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner, t2.inner, comp, y1, y2, o2, cnt, adj * t1.y * t2.y) cnt end @inline function hdrpass( + e, + e_starts, + e_cnts, t1::SecondAdjointNode1, t2::SecondAdjointNode1, comp::Nothing, @@ -36,12 +46,15 @@ end cnt, adj, ) # despecialized - cnt = hdrpass(t1.inner, t2.inner, comp, y1, y2, o2, cnt, adj * t1.y * t2.y) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner, t2.inner, comp, y1, y2, o2, cnt, adj * t1.y * t2.y) cnt end @inline function hdrpass( + e, + e_starts, + e_cnts, t1::T1, t2::T2, comp, @@ -50,11 +63,14 @@ end o2, cnt, adj, -) where {T1<:SecondAdjointNodeVar,T2<:SecondAdjointNode1} - cnt = hdrpass(t1, t2.inner, comp, y1, y2, o2, cnt, adj * t2.y) +) where {T1<:Union{SecondAdjointNodeVar,SecondAdjointNodeExpr},T2<:SecondAdjointNode1} + cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner, comp, y1, y2, o2, cnt, adj * t2.y) cnt end @inline function hdrpass( + e, + e_starts, + e_cnts, t1::SecondAdjointNodeVar, t2::SecondAdjointNode1, comp::Nothing, @@ -64,12 +80,30 @@ end cnt, adj, ) # despecialized - cnt = hdrpass(t1, t2.inner, comp, y1, y2, o2, cnt, adj * t2.y) + cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner, comp, y1, y2, o2, cnt, adj * t2.y) + cnt +end +function hdrpass( + e, + e_starts, + e_cnts, + t1::SecondAdjointNodeExpr, + t2::SecondAdjointNode1, + comp::Nothing, + y1, + y2, + o2, + cnt, + adj, +) # despecialized + cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner, comp, y1, y2, o2, cnt, adj * t2.y) cnt end - @inline function hdrpass( + e, + e_starts, + e_cnts, t1::T1, t2::T2, comp, @@ -78,11 +112,14 @@ end o2, cnt, adj, -) where {T1<:SecondAdjointNode1,T2<:SecondAdjointNodeVar} - cnt = hdrpass(t1.inner, t2, comp, y1, y2, o2, cnt, adj * t1.y) +) where {T1<:SecondAdjointNode1,T2<:Union{SecondAdjointNodeVar,SecondAdjointNodeExpr}} + cnt = hdrpass(e, e_starts, e_cnts, t1.inner, t2, comp, y1, y2, o2, cnt, adj * t1.y) cnt end @inline function hdrpass( + e, + e_starts, + e_cnts, t1::SecondAdjointNode1, t2::SecondAdjointNodeVar, comp::Nothing, @@ -92,12 +129,31 @@ end cnt, adj, ) # despecialized - cnt = hdrpass(t1.inner, t2, comp, y1, y2, o2, cnt, adj * t1.y) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner, t2, comp, y1, y2, o2, cnt, adj * t1.y) + cnt +end +function hdrpass( + e, + e_starts, + e_cnts, + t1::SecondAdjointNode1, + t2::SecondAdjointNodeExpr, + comp::Nothing, + y1, + y2, + o2, + cnt, + adj, +) # despecialized + cnt = hdrpass(e, e_starts, e_cnts, t1.inner, t2, comp, y1, y2, o2, cnt, adj * t1.y) cnt end @inline function hdrpass( + e, + e_starts, + e_cnts, t1::T1, t2::T2, comp, @@ -107,13 +163,16 @@ end cnt, adj, ) where {T1<:SecondAdjointNode2,T2<:SecondAdjointNode2} - cnt = hdrpass(t1.inner1, t2.inner1, comp, y1, y2, o2, cnt, adj * t1.y1 * t2.y1) - cnt = hdrpass(t1.inner1, t2.inner2, comp, y1, y2, o2, cnt, adj * t1.y1 * t2.y2) - cnt = hdrpass(t1.inner2, t2.inner1, comp, y1, y2, o2, cnt, adj * t1.y2 * t2.y1) - cnt = hdrpass(t1.inner2, t2.inner2, comp, y1, y2, o2, cnt, adj * t1.y2 * t2.y2) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner1, t2.inner1, comp, y1, y2, o2, cnt, adj * t1.y1 * t2.y1) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner1, t2.inner2, comp, y1, y2, o2, cnt, adj * t1.y1 * t2.y2) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner2, t2.inner1, comp, y1, y2, o2, cnt, adj * t1.y2 * t2.y1) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner2, t2.inner2, comp, y1, y2, o2, cnt, adj * t1.y2 * t2.y2) cnt end @inline function hdrpass( + e, + e_starts, + e_cnts, t1::SecondAdjointNode2, t2::SecondAdjointNode2, comp::Nothing, @@ -123,15 +182,18 @@ end cnt, adj, ) # despecialized - cnt = hdrpass(t1.inner1, t2.inner1, comp, y1, y2, o2, cnt, adj * t1.y1 * t2.y1) - cnt = hdrpass(t1.inner1, t2.inner2, comp, y1, y2, o2, cnt, adj * t1.y1 * t2.y2) - cnt = hdrpass(t1.inner2, t2.inner1, comp, y1, y2, o2, cnt, adj * t1.y2 * t2.y1) - cnt = hdrpass(t1.inner2, t2.inner2, comp, y1, y2, o2, cnt, adj * t1.y2 * t2.y2) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner1, t2.inner1, comp, y1, y2, o2, cnt, adj * t1.y1 * t2.y1) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner1, t2.inner2, comp, y1, y2, o2, cnt, adj * t1.y1 * t2.y2) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner2, t2.inner1, comp, y1, y2, o2, cnt, adj * t1.y2 * t2.y1) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner2, t2.inner2, comp, y1, y2, o2, cnt, adj * t1.y2 * t2.y2) cnt end @inline function hdrpass( + e, + e_starts, + e_cnts, t1::T1, t2::T2, comp, @@ -141,11 +203,14 @@ end cnt, adj, ) where {T1<:SecondAdjointNode1,T2<:SecondAdjointNode2} - cnt = hdrpass(t1.inner, t2.inner1, comp, y1, y2, o2, cnt, adj * t1.y * t2.y1) - cnt = hdrpass(t1.inner, t2.inner2, comp, y1, y2, o2, cnt, adj * t1.y * t2.y2) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner, t2.inner1, comp, y1, y2, o2, cnt, adj * t1.y * t2.y1) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner, t2.inner2, comp, y1, y2, o2, cnt, adj * t1.y * t2.y2) cnt end @inline function hdrpass( + e, + e_starts, + e_cnts, t1::SecondAdjointNode1, t2::SecondAdjointNode2, comp::Nothing, @@ -155,12 +220,15 @@ end cnt, adj, ) # despecialized - cnt = hdrpass(t1.inner, t2.inner1, comp, y1, y2, o2, cnt, adj * t1.y * t2.y1) - cnt = hdrpass(t1.inner, t2.inner2, comp, y1, y2, o2, cnt, adj * t1.y * t2.y2) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner, t2.inner1, comp, y1, y2, o2, cnt, adj * t1.y * t2.y1) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner, t2.inner2, comp, y1, y2, o2, cnt, adj * t1.y * t2.y2) cnt end @inline function hdrpass( + e, + e_starts, + e_cnts, t1::T1, t2::T2, comp, @@ -170,11 +238,14 @@ end cnt, adj, ) where {T1<:SecondAdjointNode2,T2<:SecondAdjointNode1} - cnt = hdrpass(t1.inner1, t2.inner, comp, y1, y2, o2, cnt, adj * t1.y1 * t2.y) - cnt = hdrpass(t1.inner2, t2.inner, comp, y1, y2, o2, cnt, adj * t1.y2 * t2.y) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner1, t2.inner, comp, y1, y2, o2, cnt, adj * t1.y1 * t2.y) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner2, t2.inner, comp, y1, y2, o2, cnt, adj * t1.y2 * t2.y) cnt end function hdrpass( + e, + e_starts, + e_cnts, t1::SecondAdjointNode2, t2::SecondAdjointNode1, comp::Nothing, @@ -184,12 +255,15 @@ function hdrpass( cnt, adj, ) # despecialized - cnt = hdrpass(t1.inner1, t2.inner, comp, y1, y2, o2, cnt, adj * t1.y1 * t2.y) - cnt = hdrpass(t1.inner2, t2.inner, comp, y1, y2, o2, cnt, adj * t1.y2 * t2.y) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner1, t2.inner, comp, y1, y2, o2, cnt, adj * t1.y1 * t2.y) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner2, t2.inner, comp, y1, y2, o2, cnt, adj * t1.y2 * t2.y) cnt end @inline function hdrpass( + e, + e_starts, + e_cnts, t1::T1, t2::T2, comp, @@ -198,12 +272,15 @@ end o2, cnt, adj, -) where {T1<:SecondAdjointNodeVar,T2<:SecondAdjointNode2} - cnt = hdrpass(t1, t2.inner1, comp, y1, y2, o2, cnt, adj * t2.y1) - cnt = hdrpass(t1, t2.inner2, comp, y1, y2, o2, cnt, adj * t2.y2) +) where {T1<:Union{SecondAdjointNodeVar,SecondAdjointNodeExpr},T2<:SecondAdjointNode2} + cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner1, comp, y1, y2, o2, cnt, adj * t2.y1) + cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner2, comp, y1, y2, o2, cnt, adj * t2.y2) cnt end @inline function hdrpass( + e, + e_starts, + e_cnts, t1::SecondAdjointNodeVar, t2::SecondAdjointNode2, comp::Nothing, @@ -213,12 +290,32 @@ end cnt, adj, ) # despecialized - cnt = hdrpass(t1, t2.inner1, comp, y1, y2, o2, cnt, adj * t2.y1) - cnt = hdrpass(t1, t2.inner2, comp, y1, y2, o2, cnt, adj * t2.y2) + cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner1, comp, y1, y2, o2, cnt, adj * t2.y1) + cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner2, comp, y1, y2, o2, cnt, adj * t2.y2) + cnt +end +function hdrpass( + e, + e_starts, + e_cnts, + t1::SecondAdjointNodeExpr, + t2::SecondAdjointNode2, + comp::Nothing, + y1, + y2, + o2, + cnt, + adj, +) # despecialized + cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner1, comp, y1, y2, o2, cnt, adj * t2.y1) + cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner2, comp, y1, y2, o2, cnt, adj * t2.y2) cnt end @inline function hdrpass( + e, + e_starts, + e_cnts, t1::T1, t2::T2, comp, @@ -227,12 +324,15 @@ end o2, cnt, adj, -) where {T1<:SecondAdjointNode2,T2<:SecondAdjointNodeVar} - cnt = hdrpass(t1.inner1, t2, comp, y1, y2, o2, cnt, adj * t1.y1) - cnt = hdrpass(t1.inner2, t2, comp, y1, y2, o2, cnt, adj * t1.y2) +) where {T1<:SecondAdjointNode2,T2<:Union{SecondAdjointNodeVar,SecondAdjointNodeExpr}} + cnt = hdrpass(e, e_starts, e_cnts, t1.inner1, t2, comp, y1, y2, o2, cnt, adj * t1.y1) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner2, t2, comp, y1, y2, o2, cnt, adj * t1.y2) cnt end @inline function hdrpass( + e, + e_starts, + e_cnts, t1::SecondAdjointNode2, t2::SecondAdjointNodeVar, comp::Nothing, @@ -242,13 +342,32 @@ end cnt, adj, ) # despecialized - cnt = hdrpass(t1.inner1, t2, comp, y1, y2, o2, cnt, adj * t1.y1) - cnt = hdrpass(t1.inner2, t2, comp, y1, y2, o2, cnt, adj * t1.y2) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner1, t2, comp, y1, y2, o2, cnt, adj * t1.y1) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner2, t2, comp, y1, y2, o2, cnt, adj * t1.y2) + cnt +end +function hdrpass( + e, + e_starts, + e_cnts, + t1::SecondAdjointNode2, + t2::SecondAdjointNodeExpr, + comp::Nothing, + y1, + y2, + o2, + cnt, + adj, +) # despecialized + cnt = hdrpass(e, e_starts, e_cnts, t1.inner1, t2, comp, y1, y2, o2, cnt, adj * t1.y1) + cnt = hdrpass(e, e_starts, e_cnts, t1.inner2, t2, comp, y1, y2, o2, cnt, adj * t1.y2) cnt end - @inline function hdrpass( + e, + e_starts, + e_cnts, t1::T1, t2::T2, comp, @@ -260,35 +379,100 @@ end ) where {T1<:SecondAdjointNodeVar,T2<:SecondAdjointNodeVar} i, j = t1.i, t2.i @inbounds if i == j - y1[o2+comp(cnt+=1)] += 2 * adj + y1[o2+comp(cnt += 1)] += 2 * adj else - y1[o2+comp(cnt+=1)] += adj + y1[o2+comp(cnt += 1)] += adj end cnt end @inline function hdrpass( + e, + e_starts, + e_cnts, t1::T1, t2::T2, - comp::Nothing, + comp, y1, y2, o2, - cnt::Tuple{<:Tuple,<:Tuple}, + cnt, adj, -) where {T1<:SecondAdjointNodeVar,T2<:SecondAdjointNodeVar} - pair = (t1.i, t2.i) - mapping, uniques = cnt - idx = _hpass_find_pair(pair, uniques, 1) - if idx === 0 - return ((mapping..., length(uniques) + 1), (uniques..., pair)) - else - return ((mapping..., idx), uniques) +) where {T1<:SecondAdjointNodeExpr,T2<:SecondAdjointNodeVar} + (cnt_start, e_start) = e_starts[t1.i] + len = e_cnts[cnt_start] + cnt += 1 + for i in 1:len + @inbounds y1[o2+comp(cnt)] += e[e_start+i-1] * adj + cnt += e_cnts[cnt_start+i] + end + return cnt +end + +@inline function hdrpass( + e, + e_starts, + e_cnts, + t1::T1, + t2::T2, + comp, + y1, + y2, + o2, + cnt, + adj, +) where {T1<:SecondAdjointNodeVar,T2<:SecondAdjointNodeExpr} + (cnt_start, e_start) = e_starts[t2.i] + len = e_cnts[cnt_start] + cnt += 1 + for i in 1:len + @inbounds y1[o2+comp(cnt)] += e[e_start+i-1] * adj + cnt += e_cnts[cnt_start+i] + end + return cnt +end + +@inline function hdrpass( + e, + e_starts, + e_cnts, + t1::T1, + t2::T2, + comp, + y1, + y2, + o2, + cnt, + adj, +) where {T1<:SecondAdjointNodeExpr,T2<:SecondAdjointNodeExpr} + (cnt_start1, e_start1) = e_starts[t1.i] + len1 = e_cnts[cnt_start1] + (cnt_start2, e_start2) = e_starts[t2.i] + len2 = e_cnts[cnt_start2] + + cnt += 1 + for i in 1:len1 + val1 = e[e_start1+i-1] + for j in 1:len2 + val2 = e[e_start2+j-1] + ind = o2 + comp(cnt) + @inbounds if t1.i == t2.i && i == j + y1[ind] += 2 * val1 * val2 * adj + else + y1[ind] += val1 * val2 * adj + end + cnt += e_cnts[cnt_start2+j] + end + cnt += e_cnts[cnt_start1+i] end + return cnt end @inline function hdrpass( + e, + e_starts, + e_cnts, t1::T1, t2::T2, comp, @@ -320,21 +504,34 @@ end @inline hdrpass(::SecondAdjointNull, ::SecondAdjointNull, comp, y1, y2, o2, cnt, adj) = cnt """ - hrpass(t::D, comp, y1, y2, o2, cnt, adj, adj2) + hrpass(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t::D, comp, y1, y2, o2, cnt, adj, adj2) Performs sparse hessian evaluation (`d²f/dx²` portion) via the reverse pass on the computation (sub)graph formed by second-order forward pass # Arguments: +- `e`: expression Jacobian values (e1) +- `e_starts`: expression Jacobian start indices +- `e_cnts`: expression Jacobian counts +- `e2`: expression Hessian values +- `e2_starts`: expression Hessian start indices +- `e2_cnts`: expression Hessian counts +- `isexp`: expression indicator vector - `comp`: a `Compressor`, which helps map counter to sparse vector index - `y1`: result vector #1 - `y2`: result vector #2 (only used when evaluating sparsity) - `o2`: index offset - `cnt`: counter - `adj`: first adjoint propagated up to the current node -- `adj`: second adjoint propagated up to the current node +- `adj2`: second adjoint propagated up to the current node """ @inline function hrpass( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::D, comp, y1, @@ -346,7 +543,14 @@ Performs sparse hessian evaluation (`d²f/dx²` portion) via the reverse pass on ) where {D<:Union{SecondAdjointNull,Real}} cnt end + @inline function hrpass( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::D, comp, y1, @@ -355,12 +559,83 @@ end cnt, adj, adj2, -) where {D<:SecondAdjointNode1} +) where {D<:SecondAdjointNodeExpr} + (cnt_start2, e_start2) = e2_starts[t.i] + len2 = e2_cnts[cnt_start2] + cnt += 1 + for i in 1:len2 + @inbounds y1[o2+comp(cnt)] += adj * e2[e_start2+i-1] + cnt += e2_cnts[cnt_start2+i] + end + return cnt +end + +@inline function hrpass( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, + t::D, + comp, + y1::V, + y2::V, + o2, + cnt, + adj, + adj2, +) where {D<:SecondAdjointNodeExpr,I<:Integer,V<:AbstractVector{I}} + (cnt_start2, e_start2) = e2_starts[t.i] + len2 = e2_cnts[cnt_start2] + cnt += 1 + for i in 1:len2 + ind = o2 + comp(cnt) + val = e2[e_start2+i-1] + r = unpack_row(val) + c = unpack_col(val) + if y1 === y2 + if r != 0 || c != 0 + @inbounds y1[ind] = pack_indices(r, c) + end + else + if r != 0 || c != 0 + @inbounds y1[ind] = r + @inbounds y2[ind] = c + end + end + cnt += e2_cnts[cnt_start2+i] + end + return cnt +end - cnt = hrpass(t.inner, comp, y1, y2, o2, cnt, adj * t.y, adj2 * (t.y)^2 + adj * t.h) +@inline function hrpass( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, + t::D, + comp, + y1, + y2, + o2, + cnt, + adj, + adj2, +) where {D<:SecondAdjointNode1} + cnt = hrpass(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner, comp, y1, y2, o2, cnt, adj * t.y, adj2 * (t.y)^2 + adj * t.h) cnt end + @inline function hrpass( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::D, comp, y1, @@ -370,19 +645,23 @@ end adj, adj2, ) where {D<:SecondAdjointNode2} - adj2y1y2 = adj2 * t.y1 * t.y2 adjh12 = adj * t.h12 - cnt = hrpass(t.inner1, comp, y1, y2, o2, cnt, adj * t.y1, adj2 * (t.y1)^2 + adj * t.h11) - cnt = hrpass(t.inner2, comp, y1, y2, o2, cnt, adj * t.y2, adj2 * (t.y2)^2 + adj * t.h22) - cnt = hdrpass(t.inner1, t.inner2, comp, y1, y2, o2, cnt, adj2y1y2 + adjh12) + cnt = hrpass(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner1, comp, y1, y2, o2, cnt, adj * t.y1, adj2 * (t.y1)^2 + adj * t.h11) + cnt = hrpass(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner2, comp, y1, y2, o2, cnt, adj * t.y2, adj2 * (t.y2)^2 + adj * t.h22) + cnt = hdrpass(e, e_starts, e_cnts, t.inner1, t.inner2, comp, y1, y2, o2, cnt, adj2y1y2 + adjh12) cnt end @inline hrpass0(args...) = hrpass(args...) - @inline function hrpass0( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::D, comp, y1, @@ -392,10 +671,17 @@ end adj, adj2, ) where {N<:Union{FirstFixed{typeof(*)},SecondFixed{typeof(*)}},D<:SecondAdjointNode1{N}} - cnt = hrpass0(t.inner, comp, y1, y2, o2, cnt, adj * t.y, adj2 * (t.y)^2) + cnt = hrpass0(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner, comp, y1, y2, o2, cnt, adj * t.y, adj2 * (t.y)^2) cnt end + @inline function hrpass0( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::D, comp, y1, @@ -405,10 +691,17 @@ end adj, adj2, ) where {N<:Union{FirstFixed{typeof(+)},SecondFixed{typeof(+)}},D<:SecondAdjointNode1{N}} - cnt = hrpass0(t.inner, comp, y1, y2, o2, cnt, adj, adj2) + cnt = hrpass0(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner, comp, y1, y2, o2, cnt, adj, adj2) cnt end + @inline function hrpass0( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::D, comp, y1, @@ -418,10 +711,17 @@ end adj, adj2, ) where {D<:SecondAdjointNode1{FirstFixed{typeof(-)}}} - cnt = hrpass0(t.inner, comp, y1, y2, o2, cnt, -adj, adj2) + cnt = hrpass0(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner, comp, y1, y2, o2, cnt, -adj, adj2) cnt end + @inline function hrpass0( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::D, comp, y1, @@ -431,11 +731,17 @@ end adj, adj2, ) where {D<:SecondAdjointNode1{SecondFixed{typeof(-)}}} - cnt = hrpass0(t.inner, comp, y1, y2, o2, cnt, adj, adj2) + cnt = hrpass0(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner, comp, y1, y2, o2, cnt, adj, adj2) cnt end @inline function hrpass0( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::D, comp, y1, @@ -445,10 +751,17 @@ end adj, adj2, ) where {D<:SecondAdjointNode1{typeof(+)}} - cnt = hrpass0(t.inner, comp, y1, y2, o2, cnt, adj, adj2) + cnt = hrpass0(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner, comp, y1, y2, o2, cnt, adj, adj2) cnt end + @inline function hrpass0( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::D, comp, y1, @@ -458,11 +771,17 @@ end adj, adj2, ) where {D<:SecondAdjointNode1{typeof(-)}} - cnt = hrpass0(t.inner, comp, y1, y2, o2, cnt, -adj, adj2) + cnt = hrpass0(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner, comp, y1, y2, o2, cnt, -adj, adj2) cnt end @inline function hrpass0( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::D, comp, y1, @@ -472,12 +791,18 @@ end adj, adj2, ) where {D<:SecondAdjointNode2{typeof(+)}} - cnt = hrpass0(t.inner1, comp, y1, y2, o2, cnt, adj, adj2) - cnt = hrpass0(t.inner2, comp, y1, y2, o2, cnt, adj, adj2) + cnt = hrpass0(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner1, comp, y1, y2, o2, cnt, adj, adj2) + cnt = hrpass0(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner2, comp, y1, y2, o2, cnt, adj, adj2) cnt end @inline function hrpass0( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::D, comp, y1, @@ -487,11 +812,18 @@ end adj, adj2, ) where {D<:SecondAdjointNode2{typeof(-)}} - cnt = hrpass0(t.inner1, comp, y1, y2, o2, cnt, adj, adj2) - cnt = hrpass0(t.inner2, comp, y1, y2, o2, cnt, -adj, adj2) + cnt = hrpass0(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner1, comp, y1, y2, o2, cnt, adj, adj2) + cnt = hrpass0(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner2, comp, y1, y2, o2, cnt, -adj, adj2) cnt end + @inline function hrpass0( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::T, comp, y1, @@ -503,7 +835,14 @@ end ) where {T<:SecondAdjointNodeVar} cnt end + @inline function hrpass0( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::T, comp::Nothing, y1, @@ -516,23 +855,49 @@ end cnt end +@inline function hrpass0( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, + t::T, + comp, + y1, + y2, + o2, + cnt, + adj, + adj2, +) where {T<:SecondAdjointNodeExpr} + (cnt_start2, e_start2) = e2_starts[t.i] + len2 = e2_cnts[cnt_start2] + cnt += 1 + for i in 1:len2 + @inbounds y1[o2+comp(cnt)] += adj * e2[e_start2+i-1] + cnt += e2_cnts[cnt_start2+i] + end + + + return cnt +end @inline function hdrpass( - t1::SecondAdjointNodeVar, - t2::SecondAdjointNodeVar, + e, + e_starts, + e_cnts, + t1::T1, + t2::T2, comp::Nothing, y1, y2, o2, cnt::Vector, adj, -) +) where {T1<:SecondAdjointNodeVar,T2<:SecondAdjointNodeVar} push!(cnt, (t1.i, t2.i)) - return cnt -end -function hrpass(t::SecondAdjointNodeVar, comp::Nothing, y1, y2, o2, cnt::Vector, adj, adj2) - push!(cnt, (t.i, t.i)) - return cnt + cnt end # Tuple-based sparsity detection for Hessian: cnt = (mapping_acc::Tuple, unique_acc::Tuple) @@ -543,7 +908,56 @@ end return _hpass_find_pair(x, Base.tail(t), i + 1) end +@inline function hdrpass( + e, + e_starts, + e_cnts, + t1::T1, + t2::T2, + comp::Nothing, + y1, + y2, + o2, + cnt::Tuple{<:Tuple,<:Tuple}, + adj, +) where {T1<:SecondAdjointNodeVar,T2<:SecondAdjointNodeVar} + pair = (t1.i, t2.i) + mapping, uniques = cnt + idx = _hpass_find_pair(pair, uniques, 1) + if idx === 0 + return ((mapping..., length(uniques) + 1), (uniques..., pair)) + else + return ((mapping..., idx), uniques) + end +end + +@inline function hrpass( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, + t::T, + comp::Nothing, + y1, + y2, + o2, + cnt::Vector, + adj, + adj2, +) where {T<:SecondAdjointNodeVar} + push!(cnt, (t.i, t.i)) + cnt +end + @inline function hrpass( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::T, comp::Nothing, y1, @@ -564,6 +978,12 @@ end end @inline function hrpass( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::T, comp, y1::Tuple{V1,V2}, @@ -577,7 +997,14 @@ end @inbounds y[t.i] += adj2 * v[t.i] return (cnt += 1) end + @inline function hrpass( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::T, comp, y1, @@ -587,10 +1014,17 @@ end adj, adj2, ) where {T<:SecondAdjointNodeVar} - @inbounds y1[o2+comp(cnt+=1)] += adj2 + @inbounds y1[o2+comp(cnt += 1)] += adj2 cnt end + @inline function hrpass( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::T, comp, y1::V, @@ -601,11 +1035,26 @@ end adj2, ) where {T<:SecondAdjointNodeVar,I<:Integer,V<:AbstractVector{I}} ind = o2 + comp(cnt += 1) - @inbounds y1[ind] = t.i - @inbounds y2[ind] = t.i + if y1 === y2 + if t.i != 0 + @inbounds y1[ind] = pack_indices(t.i, t.i) + end + else + if t.i != 0 + @inbounds y1[ind] = t.i + @inbounds y2[ind] = t.i + end + end cnt end + @inline function hrpass( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, t::T, comp, y1::V, @@ -619,7 +1068,15 @@ end @inbounds y1[ind] = ((t.i, t.i), ind) cnt end + +@inline pack_indices(i, j) = (UInt64(i) << 32) | UInt64(j) +@inline unpack_row(v) = Int(v >> 32) +@inline unpack_col(v) = Int(v & 0xFFFFFFFF) + @inline function hdrpass( + e, + e_starts, + e_cnts, t1::T1, t2::T2, comp, @@ -631,16 +1088,33 @@ end ) where {T1<:SecondAdjointNodeVar,T2<:SecondAdjointNodeVar,I<:Integer,V<:AbstractVector{I}} i, j = t1.i, t2.i ind = o2 + comp(cnt += 1) - @inbounds if i >= j - y1[ind] = i - y2[ind] = j + + if y1 === y2 + if i != 0 || j != 0 + @inbounds if i >= j + y1[ind] = pack_indices(i, j) + else + y1[ind] = pack_indices(j, i) + end + end else - y1[ind] = j - y2[ind] = i + if i != 0 || j != 0 + @inbounds if i >= j + y1[ind] = i + y2[ind] = j + else + y1[ind] = j + y2[ind] = i + end + end end cnt end + @inline function hdrpass( + e, + e_starts, + e_cnts, t1::T1, t2::T2, comp, @@ -665,19 +1139,250 @@ end cnt end +@inline function hrpass0( + e, + e_starts, + e_cnts, + e2, + e2_starts, + e2_cnts, + t::T, + comp, + y1::V, + y2::V, + o2, + cnt, + adj, + adj2, +) where {T<:SecondAdjointNodeExpr,I<:Integer,V<:AbstractVector{I}} + (cnt_start2, e_start2) = e2_starts[t.i] + len2 = e2_cnts[cnt_start2] + cnt += 1 + for i in 1:len2 + ind = o2 + comp(cnt) + val = e2[e_start2+i-1] + r = unpack_row(val) + c = unpack_col(val) + if y1 === y2 + if r != 0 || c != 0 + @inbounds y1[ind] = pack_indices(r, c) + end + else + if r != 0 || c != 0 + @inbounds y1[ind] = r + @inbounds y2[ind] = c + end + end + cnt += e2_cnts[cnt_start2+i] + end + return cnt +end + +@inline function hdrpass( + e, + e_starts, + e_cnts, + t1::T1, + t2::T2, + comp, + y1::V, + y2::V, + o2, + cnt, + adj, +) where {T1<:SecondAdjointNodeExpr,T2<:SecondAdjointNodeVar,I<:Integer,V<:AbstractVector{I}} + (cnt_start, e_start) = e_starts[t1.i] + len = e_cnts[cnt_start] + j = t2.i + cnt += 1 + for i in 1:len + ind = o2 + comp(cnt) + idx = e[e_start+i-1] + if y1 === y2 + if idx != 0 || j != 0 + @inbounds if idx >= j + y1[ind] = pack_indices(idx, j) + else + y1[ind] = pack_indices(j, idx) + end + end + else + if idx != 0 || j != 0 + @inbounds if idx >= j + y1[ind] = idx + y2[ind] = j + else + y1[ind] = j + y2[ind] = idx + end + end + end + cnt += e_cnts[cnt_start+i] + end + return cnt +end + +@inline function hdrpass( + e, + e_starts, + e_cnts, + t1::T1, + t2::T2, + comp, + y1::V, + y2::V, + o2, + cnt, + adj, +) where {T1<:SecondAdjointNodeVar,T2<:SecondAdjointNodeExpr,I<:Integer,V<:AbstractVector{I}} + i = t1.i + (cnt_start, e_start) = e_starts[t2.i] + len = e_cnts[cnt_start] + cnt += 1 + for k in 1:len + ind = o2 + comp(cnt) + idx = e[e_start+k-1] + if y1 === y2 + if i != 0 || idx != 0 + @inbounds if i >= idx + y1[ind] = pack_indices(i, idx) + else + y1[ind] = pack_indices(idx, i) + end + end + else + if i != 0 || idx != 0 + @inbounds if i >= idx + y1[ind] = i + y2[ind] = idx + else + y1[ind] = idx + y2[ind] = i + end + end + end + cnt += e_cnts[cnt_start+k] + end + return cnt +end + +@inline function hdrpass( + e, + e_starts, + e_cnts, + t1::T1, + t2::T2, + comp, + y1::V, + y2::V, + o2, + cnt, + adj, +) where {T1<:SecondAdjointNodeExpr,T2<:SecondAdjointNodeExpr,I<:Integer,V<:AbstractVector{I}} + (cnt_start1, e_start1) = e_starts[t1.i] + len1 = e_cnts[cnt_start1] + (cnt_start2, e_start2) = e_starts[t2.i] + len2 = e_cnts[cnt_start2] + + cnt += 1 + for i in 1:len1 + idx1 = e[e_start1+i-1] + for j in 1:len2 + idx2 = e[e_start2+j-1] + ind = o2 + comp(cnt) + if y1 === y2 + if idx1 != 0 || idx2 != 0 + @inbounds if idx1 >= idx2 + y1[ind] = pack_indices(idx1, idx2) + else + y1[ind] = pack_indices(idx2, idx1) + end + end + else + if idx1 != 0 || idx2 != 0 + @inbounds if idx1 >= idx2 + y1[ind] = idx1 + y2[ind] = idx2 + else + y1[ind] = idx2 + y2[ind] = idx1 + end + end + end + cnt += e_cnts[cnt_start2+j] + end + cnt += e_cnts[cnt_start1+i] + end + return cnt +end + """ - shessian!(y1, y2, f, x, adj1, adj2) + shessian!(y1, y2, f, x, θ, e1, e1_starts, e1_cnts, e2, e2_starts, e2_cnts, adj1, adj2) -Performs sparse jacobian evalution +Performs sparse hessian evaluation # Arguments: - `y1`: result vector #1 - `y2`: result vector #2 (only used when evaluating sparsity) - `f`: the function to be differentiated in `SIMDFunction` format - `x`: variable vector +- `θ`: parameter vector +- `e1`: expression Jacobian values +- `e1_starts`: expression Jacobian start indices +- `e1_cnts`: expression Jacobian counts +- `e2`: expression Hessian values +- `e2_starts`: expression Hessian start indices +- `e2_cnts`: expression Hessian counts - `adj1`: initial first adjoint - `adj2`: initial second adjoint +- `isexp`: expression indicator vector """ +function shessian!(y1, y2, f, x, θ, e1, e1_starts, e1_cnts, e2, e2_starts, e2_cnts, adj1, adj2, isexp) + @simd for k in eachindex(f.itr) + @inbounds shessian!( + y1, + y2, + f.f, + f.itr[k], + x, + θ, + e1, e1_starts, e1_cnts, + e2, e2_starts, e2_cnts, + f.f.comp2, + offset2(f, k), + adj1, + adj2, + isexp, + ) + end +end + +function shessian!(y1, y2, f, x, θ, e1, e1_starts, e1_cnts, e2, e2_starts, e2_cnts, adj1s::V, adj2, isexp) where {V<:AbstractVector} + @simd for k in eachindex(f.itr) + @inbounds shessian!( + y1, + y2, + f.f, + f.itr[k], + x, + θ, + e1, e1_starts, e1_cnts, + e2, e2_starts, e2_cnts, + f.f.comp2, + offset2(f, k), + adj1s[offset0(f, k)], + adj2, + isexp, + ) + end +end + +function shessian!(y1, y2, f, p, x, θ, e1, e1_starts, e1_cnts, e2, e2_starts, e2_cnts, comp, o2, adj1, adj2, isexp) + graph = f(p, SecondAdjointNodeSource(x, nothing), θ) + hrpass0(e1, e1_starts, e1_cnts, e2, e2_starts, e2_cnts, graph, comp, y1, y2, o2, 0, adj1, adj2) +end + +# Simple (no-expression-cache) wrappers used by nlp.jl callbacks function shessian!(y1, y2, f, x, θ, adj1, adj2) @simd for k in eachindex(f.itr) @inbounds shessian!( @@ -687,13 +1392,17 @@ function shessian!(y1, y2, f, x, θ, adj1, adj2) f.itr[k], x, θ, + nothing, nothing, nothing, + nothing, nothing, nothing, f.f.comp2, offset2(f, k), adj1, adj2, + nothing, ) end end + function shessian!(y1, y2, f, x, θ, adj1s::V, adj2) where {V<:AbstractVector} @simd for k in eachindex(f.itr) @inbounds shessian!( @@ -703,15 +1412,18 @@ function shessian!(y1, y2, f, x, θ, adj1s::V, adj2) where {V<:AbstractVector} f.itr[k], x, θ, + nothing, nothing, nothing, + nothing, nothing, nothing, f.f.comp2, offset2(f, k), adj1s[offset0(f, k)], adj2, + nothing, ) end end function shessian!(y1, y2, f, p, x, θ, comp, o2, adj1, adj2) - graph = f(p, SecondAdjointNodeSource(x), θ) - hrpass0(graph, comp, y1, y2, o2, 0, adj1, adj2) + graph = f(p, SecondAdjointNodeSource(x, nothing), θ) + hrpass0(nothing, nothing, nothing, nothing, nothing, nothing, graph, comp, y1, y2, o2, 0, adj1, adj2) end diff --git a/src/jacobian.jl b/src/jacobian.jl index bc60abcb2..c94467592 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -1,5 +1,5 @@ """ - jrpass(d::D, comp, i, y1, y2, o1, cnt, adj) + jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D<:AdjointNode1} Performs sparse jacobian evaluation via the reverse pass on the computation (sub)graph formed by forward pass @@ -15,6 +15,9 @@ Performs sparse jacobian evaluation via the reverse pass on the computation (sub """ @inline function jrpass( d::D, + e, + e_starts, + e_cnts, comp, i, y1, @@ -25,51 +28,76 @@ Performs sparse jacobian evaluation via the reverse pass on the computation (sub ) where {D<:Union{AdjointNull,Real}} return cnt end -@inline function jrpass(d::D, comp, i, y1, y2, o1, cnt, adj) where {D<:AdjointNode1} - cnt = jrpass(d.inner, comp, i, y1, y2, o1, cnt, adj * d.y) +@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D<:AdjointNode1} + cnt = jrpass(d.inner, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj * d.y) return cnt end -@inline function jrpass(d::D, comp, i, y1, y2, o1, cnt, adj) where {D<:AdjointNode2} - cnt = jrpass(d.inner1, comp, i, y1, y2, o1, cnt, adj * d.y1) - cnt = jrpass(d.inner2, comp, i, y1, y2, o1, cnt, adj * d.y2) +@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D<:AdjointNode2} + cnt = jrpass(d.inner1, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj * d.y1) + cnt = jrpass(d.inner2, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj * d.y2) return cnt end -@inline function jrpass(d::D, comp, i, y1, y2, o1, cnt, adj) where {D<:AdjointNodeVar} +# jac_coord +@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D<:AdjointNodeVar} @inbounds y1[o1+comp(cnt+=1)] += adj return cnt end +@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D<:AdjointNodeExpr} + (cnt_start, e_start) = e_starts[d.i] + len = e_cnts[cnt_start] + cnt += 1 + for i in 1:len + @inbounds y1[o1+comp(cnt)] += adj * e[e_start + i - 1] + cnt += e_cnts[cnt_start + i] + end + return cnt +end +# jprod_nln @inline function jrpass( d::D, + e, + e_starts, + e_cnts, comp, - i, + o0, y1::Tuple{V1,V2}, - y2, + y2::Nothing, o1, cnt, adj, ) where {D<:AdjointNodeVar,V1<:AbstractVector,V2<:AbstractVector} (y, v) = y1 - @inbounds y[i] += adj * v[d.i] - return (cnt += 1) + @inbounds y[o0] += adj * v[d.i] + return 0 end +# TODO; jprod expressions +# jtprod_nln @inline function jrpass( d::D, + e, + e_starts, + e_cnts, comp, - i, - y1, + o0, + y1::Nothing, y2::Tuple{V1,V2}, o1, cnt, adj, ) where {D<:AdjointNodeVar,V1<:AbstractVector,V2<:AbstractVector} y, v = y2 - @inbounds y[d.i] += adj * v[i] - return (cnt += 1) + @inbounds y[d.i] += adj * v[o0] + return 0 end +# TODO; jtprod expressions +# jac_structure @inline function jrpass( d::D, + e, + e_starts, + e_cnts, comp, - i, + o0, y1::V, y2::V, o1, @@ -77,14 +105,82 @@ end adj, ) where {D<:AdjointNodeVar,I<:Integer,V<:AbstractVector{I}} ind = o1 + comp(cnt += 1) - @inbounds y1[ind] = i + @inbounds y1[ind] = o0 @inbounds y2[ind] = d.i return cnt end @inline function jrpass( d::D, + e, + e_starts, + e_cnts, comp, - i, + o0, + y1::V, + y2::V, + o1, + cnt, + adj, +) where {D<:AdjointNodeExpr,I<:Integer,V<:AbstractVector{I}} + (cnt_start, e_start) = e_starts[d.i] + len = e_cnts[cnt_start] + cnt += 1 + for i in 1:len + ind = o1 + comp(cnt) + @inbounds y1[ind] = o0 + @inbounds y2[ind] = e[e_start + i - 1] + cnt += e_cnts[cnt_start + i] + end + return cnt +end +# no rows when precomputing expressions +@inline function jrpass( + d::D, + e, + e_starts, + e_cnts, + comp, + o0, + y1::Nothing, + y2::V, + o1, + cnt, + adj, +) where {D<:AdjointNodeVar,I<:Integer,V<:AbstractVector{I}} + ind = o1 + comp(cnt += 1) + @inbounds y2[ind] = d.i + return cnt +end +@inline function jrpass( + d::D, + e, + e_starts, + e_cnts, + comp, + o0, + y1::Nothing, + y2::V, + o1, + cnt, + adj, +) where {D<:AdjointNodeExpr,I<:Integer,V<:AbstractVector{I}} + (cnt_start, e_start) = e_starts[d.i] + len = e_cnts[cnt_start] + cnt += 1 + for i in 1:len + ind = o1 + comp(cnt) + @inbounds y2[ind] = e[e_start + i - 1] + cnt += e_cnts[cnt_start + i] + end + return cnt +end +@inline function jrpass( + d::D, + e, + e_starts, + e_cnts, + comp, + o0, y1::V, y2, o1, @@ -92,11 +188,10 @@ end adj, ) where {D<:AdjointNodeVar,I<:Tuple{Tuple{Int,Int},Int},V<:AbstractVector{I}} ind = o1 + comp(cnt += 1) - @inbounds y1[ind] = ((i, d.i), ind) + @inbounds y1[ind] = ((o0, d.i), ind) return cnt end - """ sjacobian!(y1, y2, f, x, adj) @@ -109,12 +204,16 @@ Performs sparse jacobian evalution - `x`: variable vector - `adj`: initial adjoint """ -function sjacobian!(y1, y2, f, x, θ, adj) +function sjacobian!(e, e_starts, e_cnts, isexp, y1, y2, f, x, θ, adj) @simd for i in eachindex(f.itr) @inbounds sjacobian!( + isexp, y1, y2, f.f, + e, + e_starts, + e_cnts, f.itr[i], x, θ, @@ -126,7 +225,31 @@ function sjacobian!(y1, y2, f, x, θ, adj) end end -function sjacobian!(y1, y2, f, p, x, θ, comp, o0, o1, adj) - graph = f(p, AdjointNodeSource(x), θ) - jrpass(graph, comp, o0, y1, y2, o1, 0, adj) +function sjacobian!(isexp, y1, y2, f, e, e_starts, e_cnts, p, x, θ, comp, o0, o1, adj) + s = AdjointNodeSource(x, nothing) + graph = f(p, s, θ) + jrpass(graph, e, e_starts, e_cnts, comp, o0, y1, y2, o1, 0, adj) end + +# Simple (no-expression-cache) wrappers used by _jac_coord!, _jprod_nln!, _jtprod_nln! in nlp.jl +function sjacobian!(y1, y2, f, x, θ, adj) + @simd for i in eachindex(f.itr) + @inbounds sjacobian!( + nothing, + y1, + y2, + f.f, + nothing, + nothing, + nothing, + f.itr[i], + x, + θ, + f.f.comp1, + offset0(f, i), + offset1(f, i), + adj, + ) + end +end + diff --git a/src/nlp.jl b/src/nlp.jl index 1eed2917c..6059081da 100644 --- a/src/nlp.jl +++ b/src/nlp.jl @@ -1,13 +1,13 @@ abstract type AbstractVariable end abstract type AbstractParameter end abstract type AbstractConstraint end +abstract type AbstractExpression end abstract type AbstractObjective end struct NaNSource{T} end Base.getindex(::NaNSource{T}, i) where {T} = T(NaN) Base.eltype(::NaNSource{T}) where {T} = T Base.eltype(::Type{NaNSource{T}}) where {T} = T - """ Variable @@ -152,6 +152,12 @@ Constraint ) end +struct ExpressionAug{R,F,I} <: AbstractConstraint + inner::R + f::F + itr::I + oa::Int +end """ ConstraintAugmentation @@ -405,7 +411,6 @@ end ) @inline default_T(backend) = Float64 - Base.show(io::IO, c::ExaCore{T,VT,B}) where {T,VT,B} = print( io, """ @@ -507,8 +512,7 @@ function ExaModel(c::C; prod = false, kwargs...) where {C<:ExaCore} lcon = (c.lcon), ucon = (c.ucon), minimize = c.minimize, - ) - , + ), NLPModels.Counters(), build_extension(c; prod), c.tag, @@ -588,7 +592,6 @@ function __bound_check(a::UnitRange{Int}, b::I) where {I<:Integer} @assert(b in a, "Variable index bound error") end - function append!(backend, a, b::Base.Generator, lb) if lb != 0 b = _adapt_gen(b) @@ -812,7 +815,6 @@ function set_value!(model::ExaModel, param::Parameter, values) copyto!(view(model.θ, param.offset+1:param.offset+param.length), values) return nothing end - @inline _var_range(v::Variable) = v.offset+1 : v.offset+v.length @inline _con_range(c::Constraint) = c.offset+1 : c.offset+total(c.size) @@ -1193,7 +1195,6 @@ add_con!(c, bus, gen.bus => -pg[gen.i] for gen in data.gen) # subtrac gen = _adapt_gen(gen) f = SIMDFunction(T, gen, offset0(c1, 0), c.nnzj, c.nnzh) - pars = gen.iter _add_con!(c, f, pars, _constraint_dims(c1), tag) end @@ -1217,7 +1218,6 @@ c, _ = add_con!(c, g[i] += x[i] + x[i+1] for i = 1:9) """ @inline function add_con!(c::ExaCore{T}, gen::Base.Generator; tag = nothing) where T gen = _adapt_gen(gen) - # Probe the generator: the result is a ConAugPair which carries the target # constraint alongside the index/expression pair. probe = gen.f(DataSource()) @@ -1305,6 +1305,64 @@ c, s = add_expr(c, x[i, k]^2 for (i, k) in itr) return (ExaCore(c; refs = add_refs(c.refs, name, ex)), ex) end +# ── Functional-style aliases (PR #215 API) ──────────────────────────────────── +# These provide a functional style API alternative to the @add_* macros. +# They delegate to add_var / add_obj / add_con / add_expr but discard the +# updated core (mutations are not visible since ExaCore is immutable). +# NOTE: These work correctly only when ExaCore is used with `concrete = Val(false)` +# (the LegacyExaCore mutable path) or when the caller captures the returned core. +# For the immutable ExaCore, prefer the @add_var / @add_obj / @add_con / @add_expr macros. + +""" + variable(core, dims...; kwargs...) + +Functional alias for [`add_var`](@ref). Returns the `Variable` object. +""" +@inline function variable(c::C, ns...; kwargs...) where {C<:ExaCore} + _, v = add_var(c, ns...; kwargs...) + return v +end + +""" + objective(core, gen; kwargs...) + +Functional alias for [`add_obj`](@ref). Returns the `Objective` object. +""" +@inline function objective(c::C, gen; kwargs...) where {C<:ExaCore} + _, o = add_obj(c, gen; kwargs...) + return o +end + +""" + constraint(core, gen; kwargs...) + +Functional alias for [`add_con`](@ref). Returns the `Constraint` object. +""" +@inline function constraint(c::C, gen; kwargs...) where {C<:ExaCore} + _, con = add_con(c, gen; kwargs...) + return con +end + +""" + constraint!(core, c1, gen; kwargs...) + +Functional alias for [`add_con!`](@ref). Returns the `ConstraintAugmentation` object. +""" +@inline function constraint!(c::C, c1, gen; kwargs...) where {C<:ExaCore} + _, ca = add_con!(c, c1, gen; kwargs...) + return ca +end + +""" + subexpr(core, gen; kwargs...) + +Functional alias for [`add_expr`](@ref). Returns the `Expression` object. +""" +@inline function subexpr(c::C, gen; kwargs...) where {C<:ExaCore} + _, ex = add_expr(c, gen; kwargs...) + return ex +end + function jac_structure!(m::AbstractExaModel{T}, rows::AbstractVector, cols::AbstractVector) where T _jac_structure!(T, m.cons, rows, cols) return rows, cols @@ -1417,7 +1475,7 @@ function hess_coord!( m::AbstractExaModel, x::AbstractVector, hess::AbstractVector; - obj_weight = one(eltype(x)), + obj_weight=one(eltype(x)), ) fill!(hess, zero(eltype(hess))) _obj_hess_coord!(m.objs, x, m.θ, hess, obj_weight) @@ -1429,11 +1487,11 @@ function hess_coord!( x::AbstractVector, y::AbstractVector, hess::AbstractVector; - obj_weight = one(eltype(x)), + obj_weight=one(eltype(x)), ) fill!(hess, zero(eltype(hess))) _obj_hess_coord!(m.objs, x, m.θ, hess, obj_weight) - _con_hess_coord!(m.cons, x, m.θ, y, hess, obj_weight) + _con_hess_coord!(m.cons, x, m.θ, y, hess) return hess end @@ -1443,9 +1501,9 @@ _obj_hess_coord!(objs::Tuple{}, x, θ, hess, obj_weight) = nothing shessian!(hess, nothing, first(objs), x, θ, obj_weight, zero(eltype(hess))) end -_con_hess_coord!(cons::Tuple{}, x, θ, y, hess, obj_weight) = nothing -@inline function _con_hess_coord!(cons::Tuple, x, θ, y, hess, obj_weight) - _con_hess_coord!(Base.tail(cons), x, θ, y, hess, obj_weight) +_con_hess_coord!(cons::Tuple{}, x, θ, y, hess) = nothing +@inline function _con_hess_coord!(cons::Tuple, x, θ, y, hess) + _con_hess_coord!(Base.tail(cons), x, θ, y, hess) shessian!(hess, nothing, first(cons), x, θ, y, zero(eltype(hess))) end @@ -1454,7 +1512,7 @@ function hprod!( x::AbstractVector, v::AbstractVector, Hv::AbstractVector; - obj_weight = one(eltype(x)), + obj_weight=one(eltype(x)), ) fill!(Hv, zero(eltype(Hv))) _obj_hprod!(m.objs, x, m.θ, v, Hv, obj_weight) @@ -1467,11 +1525,11 @@ function hprod!( y::AbstractVector, v::AbstractVector, Hv::AbstractVector; - obj_weight = one(eltype(x)), + obj_weight=one(eltype(x)), ) fill!(Hv, zero(eltype(Hv))) _obj_hprod!(m.objs, x, m.θ, v, Hv, obj_weight) - _con_hprod!(m.cons, x, m.θ, y, v, Hv, obj_weight) + _con_hprod!(m.cons, x, m.θ, y, v, Hv) return Hv end @@ -1481,9 +1539,9 @@ _obj_hprod!(objs::Tuple{}, x, θ, v, Hv, obj_weight) = nothing shessian!((Hv, v), nothing, first(objs), x, θ, obj_weight, zero(eltype(Hv))) end -_con_hprod!(cons::Tuple{}, x, θ, y, v, Hv, obj_weight) = nothing -@inline function _con_hprod!(cons::Tuple, x, θ, y, v, Hv, obj_weight) - _con_hprod!(Base.tail(cons), x, θ, y, v, Hv, obj_weight) +_con_hprod!(cons::Tuple{}, x, θ, y, v, Hv) = nothing +@inline function _con_hprod!(cons::Tuple, x, θ, y, v, Hv) + _con_hprod!(Base.tail(cons), x, θ, y, v, Hv) shessian!((Hv, v), nothing, first(cons), x, θ, y, zero(eltype(Hv))) end diff --git a/src/register.jl b/src/register.jl index 78780db58..4a2e05627 100644 --- a/src/register.jl +++ b/src/register.jl @@ -67,8 +67,7 @@ macro register_univariate(f, df, ddf) @inline $f(t::T) where {T<:ExaModels.AbstractSecondAdjointNode} = ExaModels.SecondAdjointNode1($f, $f(t.x), $df(t.x), $ddf(t.x), t) - @inline (n::ExaModels.Node1{typeof($f),I})(i, x, θ) where {I} = - $f(n.inner(i, x, θ)) + @inline (n::ExaModels.Node1{typeof($f),I})(i, x, θ) where {I} = $f(n.inner(i, x, θ)) end, ) end @@ -246,12 +245,9 @@ macro register_bivariate(f, df1, df2, ddf11, ddf12, ddf22) ) end - @inline (n::ExaModels.Node2{typeof($f),I1,I2})(i, x, θ) where {I1,I2} = - $f(n.inner1(i, x, θ), n.inner2(i, x, θ)) - @inline (n::ExaModels.Node2{typeof($f),I1,I2})(i, x, θ) where {I1<:Real,I2} = - $f(n.inner1, n.inner2(i, x, θ)) - @inline (n::ExaModels.Node2{typeof($f),I1,I2})(i, x, θ) where {I1,I2<:Real} = - $f(n.inner1(i, x, θ), n.inner2) + @inline (n::ExaModels.Node2{typeof($f),I1,I2})(i, x, θ) where {I1,I2} = $f(n.inner1(i, x, θ), n.inner2(i, x, θ)) + @inline (n::ExaModels.Node2{typeof($f),I1,I2})(i, x, θ) where {I1<:Real,I2} = $f(n.inner1, n.inner2(i, x, θ)) + @inline (n::ExaModels.Node2{typeof($f),I1,I2})(i, x, θ) where {I1,I2<:Real} = $f(n.inner1(i, x, θ), n.inner2) end, ) end diff --git a/src/simdfunction.jl b/src/simdfunction.jl index e26d11763..d5f5a8ee6 100644 --- a/src/simdfunction.jl +++ b/src/simdfunction.jl @@ -25,7 +25,7 @@ struct SIMDFunction{F,C1,C2} end @inline (sf::SIMDFunction{F,C1,C2})(i, x, θ) where {F,C1,C2} = sf.f(i, x, θ) -@inline (sf::SIMDFunction{F,C1,C2})(i, x, θ) where {F <: Real,C1,C2} = sf.f +@inline (sf::SIMDFunction{F,C1,C2})(i, x, θ) where {F<:Real,C1,C2} = sf.f """ SIMDFunction(gen::Base.Generator, o0 = 0, o1 = 0, o2 = 0) @@ -73,13 +73,13 @@ end @inline function _simdfunction(T, f, o0, o1, o2) f = replace_T(T, f) - d = f(Identity(), AdjointNodeSource(NaNSource{T}()), NaNSource{T}()) + d = f(Identity(), AdjointNodeSource(NaNSource{T}(), nothing), NaNSource{T}()) raw1 = Any[] ExaModels.grpass(d, nothing, nothing, nothing, raw1, T(NaN)) - t = f(Identity(), SecondAdjointNodeSource(NaNSource{T}()), NaNSource{T}()) + t = f(Identity(), SecondAdjointNodeSource(NaNSource{T}(), nothing), NaNSource{T}()) raw2 = Any[] - ExaModels.hrpass0(t, nothing, nothing, nothing, nothing, raw2, T(NaN), T(NaN)) + ExaModels.hrpass0(nothing, nothing, nothing, nothing, nothing, nothing, t, nothing, nothing, nothing, nothing, raw2, T(NaN), T(NaN)) unique1 = _ident_unique(raw1) o1step = length(unique1) @@ -91,7 +91,8 @@ end mapping2 = Int[findfirst(y -> y === x, unique2) for x in raw2] c2 = Compressor(ntuple(i -> mapping2[i], _hr0_val(typeof(t)))) - SIMDFunction(f, c1, c2, o0, o1, o2, o1step, o2step) + f = SIMDFunction(f, c1, c2, o0, o1, o2, o1step, o2step) + return f end # === Val-based compile-time NTuple size computation (juliac-compatible, no @generated) === diff --git a/test/ADTest/ADTest.jl b/test/ADTest/ADTest.jl index 66205fd20..410634ae8 100644 --- a/test/ADTest/ADTest.jl +++ b/test/ADTest/ADTest.jl @@ -3,6 +3,8 @@ module ADTest using ExaModels using Test, ForwardDiff, SpecialFunctions +include("expression.jl") + const FUNCTIONS = [ ("basic-functions-:+", x -> +(x[1])), ("basic-functions-:-", x -> -(x[1])), diff --git a/test/ADTest/expression.jl b/test/ADTest/expression.jl new file mode 100644 index 000000000..ca7bc479b --- /dev/null +++ b/test/ADTest/expression.jl @@ -0,0 +1,126 @@ + +function test_expression() + @testset "AD Expression Tests" begin + @testset "Basic tests" begin + m = ExaCore() + v = variable(m, 5) + e1 = subexpr(m, (4,), v[i] * v[i + 1] for i in 1:4) + e2 = subexpr(m, (4,), e1[i] + v[i] for i in 1:4) + c = constraint(m, e2[i] / i for i in 1:4; ucon = 10.0) + o = objective(m, e2[i] for i in 1:4) + mod = ExaModel(m) + + x = Float64[i for i in 1:mod.meta.nvar] + + # Test Jacobian structure + jac_rows = zeros(Int, mod.meta.nnzj) + jac_cols = zeros(Int, mod.meta.nnzj) + jac_structure!(mod, jac_rows, jac_cols) + @test mod.meta.nnzj > 0 + + # Test Jacobian values + jac_buffer = zeros(mod.meta.nnzj) + jac_coord!(mod, x, jac_buffer) + @test all(isfinite, jac_buffer) + + # Test Hessian structure + hess_rows = zeros(Int, mod.meta.nnzh) + hess_cols = zeros(Int, mod.meta.nnzh) + hess_structure!(mod, hess_rows, hess_cols) + @test mod.meta.nnzh > 0 + + # Test Hessian values (objective only) + hess_buffer = zeros(mod.meta.nnzh) + hess_coord!(mod, x, hess_buffer; obj_weight=1.0) + @test all(isfinite, hess_buffer) + + # Test Hessian values (with constraints) + y = ones(mod.meta.ncon) + hess_buffer2 = zeros(mod.meta.nnzh) + hess_coord!(mod, x, y, hess_buffer2; obj_weight=1.0) + @test all(isfinite, hess_buffer2) + end + + @testset "Simple quadratic" begin + # Test a simple quadratic: f(x) = x^2, constraint: x^2 - 1 = 0 + # Hessian of f = 2 + # Hessian of constraint = 2 + m = ExaCore() + v = variable(m, 1) + o = objective(m, v[1]^2) + c = constraint(m, v[1]^2; lcon=1.0, ucon=1.0) + mod = ExaModel(m) + + x = [3.0] # arbitrary point + + # Hessian structure + hess_rows = zeros(Int, mod.meta.nnzh) + hess_cols = zeros(Int, mod.meta.nnzh) + hess_structure!(mod, hess_rows, hess_cols) + + # Objective Hessian only + hess_obj = zeros(mod.meta.nnzh) + hess_coord!(mod, x, hess_obj; obj_weight=1.0) + @test any(h ≈ 2.0 for h in hess_obj) + + # Full Hessian (obj + constraints) + y = [1.0] # constraint multiplier + hess_full = zeros(mod.meta.nnzh) + hess_coord!(mod, x, y, hess_full; obj_weight=1.0) + @test sum(hess_full) ≈ 4.0 + end + + @testset "Expression cross-derivative" begin + # Test with expression: e = x*y, f(e) = e, c(e) = e - 1 = 0 + m = ExaCore() + v = variable(m, 2) + e1 = subexpr(m, (1,), v[1] * v[2] for _ in 1:1) # e = x*y + o = objective(m, e1[1] for _ in 1:1) # f = e = x*y + c = constraint(m, e1[1] for _ in 1:1; lcon=1.0, ucon=1.0) # c = x*y = 1 + mod = ExaModel(m) + + x = zeros(mod.meta.nvar) + x[1:2] .= [2.0, 3.0] + + # Hessian structure + hess_rows = zeros(Int, mod.meta.nnzh) + hess_cols = zeros(Int, mod.meta.nnzh) + hess_structure!(mod, hess_rows, hess_cols) + + # Objective Hessian only + hess_obj = zeros(mod.meta.nnzh) + hess_coord!(mod, x, hess_obj; obj_weight=1.0) + @test any(h ≈ 1.0 for h in hess_obj) + + # Full Hessian + y = [1.0] + hess_full = zeros(mod.meta.nnzh) + hess_coord!(mod, x, y, hess_full; obj_weight=1.0) + @test any(h ≈ 2.0 for h in hess_full) || (sum(hess_full) ≈ 2.0) + end + + @testset "Nested expressions" begin + # Test nested expressions: e1 = x^2, e2 = e1 + x, f = e2 + m = ExaCore() + v = variable(m, 1) + e1 = subexpr(m, (1,), v[1]^2 for _ in 1:1) # e1 = x^2 + e2 = subexpr(m, (1,), e1[1] + v[1] for _ in 1:1) # e2 = e1 + x = x^2 + x + o = objective(m, e2[1] for _ in 1:1) # f = e2 = x^2 + x + mod = ExaModel(m) + + x = zeros(mod.meta.nvar) + x[1] = 5.0 + + # Hessian structure + hess_rows = zeros(Int, mod.meta.nnzh) + hess_cols = zeros(Int, mod.meta.nnzh) + hess_structure!(mod, hess_rows, hess_cols) + + # Objective Hessian only + hess_obj = zeros(mod.meta.nnzh) + hess_coord!(mod, x, hess_obj; obj_weight=1.0) + # Sum of hessian entries should be 2 + @test sum(hess_obj) ≈ 2.0 + end + end +end diff --git a/test/NLPTest/NLPTest.jl b/test/NLPTest/NLPTest.jl index f0816d144..c1c3fd48b 100644 --- a/test/NLPTest/NLPTest.jl +++ b/test/NLPTest/NLPTest.jl @@ -9,10 +9,10 @@ import ..BACKENDS import ..ad_tolerance, ..sol_tolerance, ..solver_tolerance const NLP_TEST_ARGUMENTS = [ - ("luksan_struct", 3), - ("luksan_struct", 20), ("luksan_vlcek", 3), ("luksan_vlcek", 20), + ("luksan_struct", 3), + ("luksan_struct", 20), ("ac_power", "pglib_opf_case3_lmbd.m"), ("ac_power", "pglib_opf_case14_ieee.m"), ("trivialmax", 1), # Issue #518 in MadNLP @@ -47,7 +47,11 @@ include("conaug_test.jl") function test_nlp(m1, m2; full = false, tol = sol_tolerance(eltype(m1.meta.x0), eltype(m2.meta.x0))) @testset "NLP meta tests" begin - list = [:nvar, :ncon, :x0, :lvar, :uvar, :y0, :lcon, :ucon] + list = [:ncon, :y0, :lcon, :ucon] + @test length(varis1) == length(varis2) + @test m1.meta.lvar[varis1] == m2.meta.lvar[varis2] + @test m1.meta.uvar[varis1] == m2.meta.uvar[varis2] + @test m1.meta.x0[varis1] == m2.meta.x0[varis2] if full append!(list, [:nnzj, :nnzh]) @@ -67,10 +71,6 @@ function test_nlp(m1, m2; full = false, tol = sol_tolerance(eltype(m1.meta.x0), end @testset "NLP callback tests" begin - x0 = copy(m2.meta.x0) - y0 = randn(eltype(m2.meta.x0), m2.meta.ncon) - u = randn(eltype(m2.meta.x0), m2.meta.nvar) - v = randn(eltype(m2.meta.x0), m2.meta.ncon) @test NLPModels.obj(m1, x0) ≈ NLPModels.obj(m2, x0) atol = tol rtol = tol @test NLPModels.cons(m1, x0) ≈ NLPModels.cons(m2, x0) atol = tol rtol = tol @@ -86,7 +86,6 @@ function test_nlp(m1, m2; full = false, tol = sol_tolerance(eltype(m1.meta.x0), jac_I_buffer2 = zeros(Int, m2.meta.nnzj) jac_J_buffer1 = zeros(Int, m1.meta.nnzj) jac_J_buffer2 = zeros(Int, m2.meta.nnzj) - hess_buffer1 = zeros(m1.meta.nnzh) hess_buffer2 = zeros(m2.meta.nnzh) hess_I_buffer1 = zeros(Int, m1.meta.nnzh) @@ -94,10 +93,10 @@ function test_nlp(m1, m2; full = false, tol = sol_tolerance(eltype(m1.meta.x0), hess_J_buffer1 = zeros(Int, m1.meta.nnzh) hess_J_buffer2 = zeros(Int, m2.meta.nnzh) - NLPModels.jac_coord!(m1, x0, jac_buffer1) - NLPModels.jac_coord!(m2, x0, jac_buffer2) - NLPModels.hess_coord!(m1, x0, y0, hess_buffer1) - NLPModels.hess_coord!(m2, x0, y0, hess_buffer2) + NLPModels.jac_coord!(m1, x01, jac_buffer1) + NLPModels.jac_coord!(m2, x02, jac_buffer2) + NLPModels.hess_coord!(m1, x01, y0, hess_buffer1) + NLPModels.hess_coord!(m2, x02, y0, hess_buffer2) NLPModels.jac_structure!(m1, jac_I_buffer1, jac_J_buffer1) NLPModels.jac_structure!(m2, jac_I_buffer2, jac_J_buffer2) NLPModels.hess_structure!(m1, hess_I_buffer1, hess_J_buffer1) @@ -114,10 +113,9 @@ function test_nlp(m1, m2; full = false, tol = sol_tolerance(eltype(m1.meta.x0), end function test_nlp_solution(result1, result2; tol = sol_tolerance(eltype(result1.solution),eltype(result2.solution))) - @testset "solution test" begin @test result1.status == result2.status - for field in [:solution, :multipliers, :multipliers_L, :multipliers_U] + for field in [:multipliers, :multipliers_L, :multipliers_U] @testset "$field" begin @test getfield(result1, field) ≈ getfield(result2, field) atol = tol rtol = tol end @@ -149,6 +147,9 @@ function runtests() @testset "NLP test" begin for backend in BACKENDS @testset "$backend" begin + @testset "Subexpr Test" begin + test_subexpr(backend) + end for (name, args) in NLP_TEST_ARGUMENTS @testset "$name $args" begin @@ -156,16 +157,19 @@ function runtests() jump_model = getfield(@__MODULE__, Symbol("_jump_$(name)_model")) m, vars0, cons0 = exa_model(nothing, args) + varis0 = m.varis m0 = WrapperNLPModel(m) m, vars2, cons2 = jump_model(nothing, args) m2 = MathOptNLPModel(m) + varis2 = [x for x in 1:m2.meta.nvar] set_optimizer(m, MadNLP.Optimizer) set_optimizer_attribute(m, "print_level", MadNLP.ERROR) optimize!(m) m, vars1, cons1 = exa_model(backend, args) + varis1 = m.varis m1 = WrapperNLPModel(m) @testset "Backend test" begin @@ -183,7 +187,7 @@ function runtests() result2 = solver(m2) @testset "$sname" begin - test_nlp_solution(result1, result2) + test_nlp_solution((result1, varis1), (result2, varis2)) end end end diff --git a/test/NLPTest/subexpr_test.jl b/test/NLPTest/subexpr_test.jl index ecb1142fd..7ed301445 100644 --- a/test/NLPTest/subexpr_test.jl +++ b/test/NLPTest/subexpr_test.jl @@ -42,7 +42,7 @@ function test_subexpr_basic(backend) x_vals = solution(result2, x2) return @test subexpr_vals ≈ x_vals .^ 2 atol = sol_tolerance(eltype(c1.x0)) rtol = sol_tolerance(eltype(c1.x0)) end - + """ Test multi-dimensional subexpressions with automatic dimension inference. """ @@ -65,12 +65,10 @@ function test_subexpr_multidim(backend) # Add some constraints to make it non-trivial @add_con(c, x[0, i] - 0.0 for i in 0:N) # Initial condition @add_con(c, x[T, i] - 1.0 for i in 0:N) # Final condition - m = ExaModel(c) # Wrap in WrapperNLPModel for GPU compatibility with Ipopt result = NLPModelsIpopt.ipopt(WrapperNLPModel(m); print_level = 0, tol = solver_tolerance(eltype(c.x0))) - @test result.status == :first_order # Check subexpression values match the definition x_sol = solution(result, x) @@ -101,12 +99,10 @@ function test_subexpr_auto_dims(backend) # Use in objective @add_obj(c, s[t, i] for t in 1:T, i in 1:N) - m = ExaModel(c) # Wrap in WrapperNLPModel for GPU compatibility with Ipopt result = NLPModelsIpopt.ipopt(WrapperNLPModel(m); print_level = 0, tol = solver_tolerance(eltype(c.x0))) - return @test result.status == :first_order end """ @@ -124,7 +120,6 @@ function test_subexpr_in_obj_and_con(backend) # Use in constraint @add_con(c, s[i] + s[i + 1] for i in 1:4; lcon = 1.0, ucon = 3.0) - m = ExaModel(c) # Wrap in WrapperNLPModel for GPU compatibility with Ipopt @@ -164,7 +159,6 @@ end # # Subexpression uses both x and θ: s[i] = x[i] * θ[i] # # With x start = 2.0 and θ = [1,2,3], expect start = [2,4,6] # @add_expr(c2, s2, x2[i] * θ[i] for i in 1:3) - # start_vals2 = c2.x0[(s2.offset+1):(s2.offset+s2.length)] # @test Array(start_vals2) ≈ [2.0, 4.0, 6.0] @@ -208,7 +202,6 @@ function test_subexpr_reduced_basic(backend) @test result1.status == result2.status return @test solution(result1, x1) ≈ solution(result2, x2) atol = sol_tolerance(eltype(c1.x0)) rtol = sol_tolerance(eltype(c1.x0)) end - # """ # Test that reduced and lifted subexpressions produce equivalent solutions. # """ @@ -227,7 +220,6 @@ end # @add_expr(c2, s2, sqrt(x2[i]) for i in 1:5) # @add_obj(c2, (s2[i] - 1)^2 for i in 1:5) # @add_con(c2, s2[i] + s2[i + 1] for i in 1:4; lcon = 1.0, ucon = 3.0) -# m2 = ExaModel(c2) # # Lifted has more vars/cons # @test m1.meta.nvar > m2.meta.nvar @@ -236,7 +228,6 @@ end # # Solve both (wrap in WrapperNLPModel for GPU compatibility with Ipopt) # result1 = NLPModelsIpopt.ipopt(WrapperNLPModel(m1); print_level = 0, tol = solver_tolerance(eltype(c1.x0))) # result2 = NLPModelsIpopt.ipopt(WrapperNLPModel(m2); print_level = 0, tol = solver_tolerance(eltype(c2.x0))) - # # Both should converge to same solution # @test result1.status == :first_order # @test result2.status == :first_order @@ -303,7 +294,6 @@ function test_subexpr_reduced_nested(backend) # 2*x^2 = 2 => x = 1 return @test solution(result, x) ≈ ones(5) atol = sol_tolerance(eltype(c.x0)) rtol = sol_tolerance(eltype(c.x0)) end - # """ # Test mixed lifted and reduced subexpressions. # """ @@ -318,7 +308,6 @@ end # @add_expr(c, s_reduced, s_lifted[i] * 2 for i in 1:5) # @add_obj(c, (s_reduced[i] - 2)^2 for i in 1:5) - # m = ExaModel(c) # # Only lifted subexpr adds vars/cons @@ -398,7 +387,6 @@ function test_subexpr_reduced_0based_nested(backend) # (x+1)*2 = 4 => x = 1 return @test solution(result, x) ≈ ones(T + 1, N + 1) atol = sol_tolerance(eltype(c.x0)) rtol = sol_tolerance(eltype(c.x0)) end - """ Run all subexpression tests. """ @@ -418,7 +406,6 @@ function test_subexpr(backend) # @testset "Subexpr in obj and con (lifted)" begin # test_subexpr_in_obj_and_con(backend) # end - # @testset "Subexpr lifted start values" begin # test_subexpr_lifted_start_values(backend) # end @@ -426,7 +413,6 @@ function test_subexpr(backend) @testset "Subexpr reduced basic" begin test_subexpr_reduced_basic(backend) end - # @testset "Subexpr lifted vs reduced" begin # test_subexpr_lifted_vs_reduced(backend) # end @@ -438,7 +424,6 @@ function test_subexpr(backend) @testset "Subexpr reduced nested" begin test_subexpr_reduced_nested(backend) end - # @testset "Subexpr mixed lifted and reduced" begin # test_subexpr_mixed(backend) # end @@ -452,4 +437,3 @@ function test_subexpr(backend) end end - diff --git a/test/TwoStageTest/Project.toml b/test/TwoStageTest/Project.toml new file mode 100644 index 000000000..dc597fc6e --- /dev/null +++ b/test/TwoStageTest/Project.toml @@ -0,0 +1,2 @@ +[deps] +ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2" diff --git a/test/runtests.jl b/test/runtests.jl index f115393a4..859c1f50f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,6 +36,9 @@ include("PrettyPrintTest.jl") @info "Running NLP Test" NLPTest.runtests() + @info "Running AD Test" + ADTest.runtests() + @info "Running JuMP Test" JuMPTest.runtests()