Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/ExaModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ export ExaModel,
@add_con,
@add_con!,
set_parameter!,
objective,
constraint,
constraint!,
subexpr,
solution,
multipliers,
multipliers_L,
Expand Down
40 changes: 28 additions & 12 deletions src/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,26 @@ Performs dense gradient evaluation via the reverse pass on the computation (sub)
- `y`: result vector
- `adj`: adjoint propagated up to the current node
"""
@inline function drpass(d::D, y, adj) where {D<:Union{Real,AdjointNull}}
@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:Union{Real,AdjointNull}}
nothing
end
@inline function drpass(d::D, y, adj) where {D<:AdjointNode1}
offset = drpass(d.inner, y, adj * d.y)
@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNode1}
drpass(e, e_starts, e_cnts, d.inner, y, adj * d.y)
nothing
end
@inline function drpass(d::D, y, adj) where {D<:AdjointNode2}
offset = drpass(d.inner1, y, adj * d.y1)
offset = drpass(d.inner2, y, adj * d.y2)
@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNode2}
drpass(e, e_starts, e_cnts, d.inner1, y, adj * d.y1)
drpass(e, e_starts, e_cnts, d.inner2, y, adj * d.y2)
nothing
end
@inline function drpass(d::D, y, adj) where {D<:AdjointNodeVar}
@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNodeVar}
@inbounds y[d.i] += adj
nothing
end
@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNodeExpr}
y[d.i] += e[e_starts[d.i][2]]
nothing
end

"""
gradient!(y, f, x, adj)
Expand All @@ -36,15 +40,27 @@ Performs dense gradient evalution
- `x`: variable vector
- `adj`: initial adjoint
"""
function gradient!(isexp, e, e_starts, e_cnts, y, f, x, θ, adj)
@simd for k in eachindex(f.itr)
@inbounds gradient!(isexp, e, e_starts, e_cnts, y, f.f, x, θ, f.itr[k], adj)
end
return y
end
function gradient!(isexp, e, e_starts, e_cnts, y, f, x, θ, p, adj)
graph = f(p, AdjointNodeSource(x, nothing), θ)
drpass(e, e_starts, e_cnts, graph, y, adj)
return y
end
# Simple (no-expression-cache) wrappers used by _grad! in nlp.jl
function gradient!(y, f, x, θ, adj)
@simd for k in eachindex(f.itr)
@inbounds gradient!(y, f.f, x, θ, f.itr[k], adj)
end
return y
end
function gradient!(y, f, x, θ, p, adj)
graph = f(p, AdjointNodeSource(x), θ)
drpass(graph, y, adj)
graph = f(p, AdjointNodeSource(x, nothing), θ)
drpass(nothing, nothing, nothing, graph, y, adj)
return y
end

Expand Down Expand Up @@ -167,15 +183,15 @@ Performs sparse gradient evalution
- `x`: variable vector
- `adj`: initial adjoint
"""
function sgradient!(y, f, x, θ, adj)
function sgradient!(y, f, x, θ, adj, isexp)
@simd for k in eachindex(f.itr)
@inbounds sgradient!(y, f.f, f.itr[k], x, θ, f.itr.comp1, offset1(f, k), adj)
end
return y
end

function sgradient!(y, f, p, x, θ, comp, o1, adj)
graph = f(p, AdjointNodeSource(x), θ)
function sgradient!(y, f, p, x, θ, comp, o1, adj, isexp)
graph = f(p, AdjointNodeSource(x, nothing), θ)
grpass(graph, comp, y, o1, 0, adj)
return y
end
27 changes: 21 additions & 6 deletions src/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ the index is itself a node).
struct Var{I} <: AbstractNode
i::I
end
struct Exp{I} <: AbstractNode
i::I
end

struct ParameterSource <: AbstractNode end
struct ParameterNode{I} <: AbstractNode
Expand Down Expand Up @@ -216,21 +219,22 @@ end
@inline Base.getindex(v::DataIndexed{I, n}, i) where {I, n} = DataIndexed(v, i)
@inline Base.indexed_iterate(v::DataIndexed{I, n}, idx, start = 1) where {I, n} = (DataIndexed(v, idx), idx + 1)


@inline Base.getindex(n::VarSource, i) = Var(i)
@inline Base.getindex(::ParameterSource, i) = ParameterNode(i)

@inline Node1(f::F, inner::I) where {F,I} = Node1{F,I}(inner)
@inline Node2(f::F, inner1::I1, inner2::I2) where {F,I1,I2} = Node2{F,I1,I2}(inner1, inner2)


struct Identity end

@inline (v::Var{I})(i, x, θ) where {I<:AbstractNode} = @inbounds x[v.i(i, x, θ)]
@inline (v::Var{I})(i, x, θ) where {I} = @inbounds x[v.i]
@inline (v::Var{I})(i::Identity, x, θ) where {I <: AbstractNode} = x[v]
@inline (v::Var{I})(i::Identity, x, θ) where {I <: Real} = x[v]

@inline (e::Exp{I})(i, x, θ) where {I<:AbstractNode} = @inbounds x[e.i(i, x, θ)]
@inline (e::Exp{I})(i, x, θ) where {I} = @inbounds x[e.i]

@inline (v::ParameterNode{I})(i, x, θ) where {I<:AbstractNode} = @inbounds θ[v.i(i, x, θ)]
@inline (v::ParameterNode{I})(::Any, x, θ) where {I} = @inbounds θ[v.i]
@inline (v::ParameterNode{I})(::Identity, x, θ) where {I<:AbstractNode} = @inbounds θ[v.i]
Expand Down Expand Up @@ -294,6 +298,11 @@ struct AdjointNodeVar{I,T} <: AbstractAdjointNode
x::T
end

struct AdjointNodeExpr{I,T} <: AbstractAdjointNode
i::I
x::T
end

"""
AdjointNodeSource{VT}

Expand All @@ -304,8 +313,9 @@ primal value.
# Fields
- `inner::VT`: primal variable vector (or `nothing` for a zero-valued seed)
"""
struct AdjointNodeSource{VT}
struct AdjointNodeSource{VT,OE}
inner::VT
offset_exps::OE
end

@inline AdjointNode1(f::F, x::T, y, inner::I) where {F,T,I} =
Expand All @@ -318,7 +328,6 @@ end
@inline Base.getindex(x::I, i) where {I<:AdjointNodeSource} =
@inbounds AdjointNodeVar(i, x.inner[i])


"""
SecondAdjointNode1{F, T, I} <: AbstractSecondAdjointNode

Expand Down Expand Up @@ -379,17 +388,23 @@ struct SecondAdjointNodeVar{I,T} <: AbstractSecondAdjointNode
x::T
end

struct SecondAdjointNodeExpr{I,T} <: AbstractSecondAdjointNode
i::I
x::T
end

"""
SecondAdjointNodeSource{VT}
SecondAdjointNodeSource{VT,VTI}

Factory for [`SecondAdjointNodeVar`](@ref) leaves. Indexing with `i` returns
`SecondAdjointNodeVar(i, inner[i])`, seeding the Hessian-pass tree.

# Fields
- `inner::VT`: primal variable vector (or `nothing` for a zero-valued seed)
"""
struct SecondAdjointNodeSource{VT}
struct SecondAdjointNodeSource{VT,OE}
inner::VT
offset_exps::OE
end

@inline SecondAdjointNode1(f::F, x::T, y, h, inner::I) where {F,T,I} =
Expand Down
Loading
Loading