Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions src/nlp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1024,11 +1024,14 @@ end

"""
add_con(core, generator; start = 0, lcon = 0, ucon = 0, name = nothing, tag = nothing)
add_con(core, expr, pars; start = 0, lcon = 0, ucon = 0, name = nothing, tag = nothing)
add_con(core, dims...; start = 0, lcon = 0, ucon = 0, name = nothing, tag = nothing)

Adds constraints to `core` and returns `(core, Constraint)`.

**Generator form**: pass a `generator` that yields one expression per constraint row.
**Generator form**: pass a `generator` that yields one expression per constraint row. Can
instead use the low-level `add_con(core, expr, pars; ...)` if you have already built an
expression tree.

**Dims form**: pass integer or `UnitRange` dimensions to create empty constraints,
then use [`add_con!`](@ref) / [`@add_con!`](@ref) to accumulate terms afterwards.
Expand Down Expand Up @@ -1099,25 +1102,37 @@ Constraint
gen = _get_generator(ns)
dims = _get_con_dims(ns)
gen = _adapt_gen(gen)
f = _simdfunction(T, gen.f(DataSource()), c.ncon, c.nnzj, c.nnzh)
pars = gen.iter
f, pars = _get_function_and_iter(T, gen, c)

_add_con(c, f, pars, dims, start, lcon, ucon, name, tag)
end

@inline _get_generator(ns) = (Null(nothing) for _ in _empty_con_itr(ns))
@inline _get_generator(gen::Tuple{G}) where G <: Base.Generator = gen[1]
@inline _get_generator(n::Tuple{N}) where N <: AbstractNode = (n[1] for _ in 1:1)
@inline _get_generator(n::Tuple{E, I}) where {E<:AbstractNode, I} = n

# Infer constraint dims from the original arguments, preserving range start info.
@inline _get_con_dims(ns) = ns
@inline _get_con_dims(gen::Tuple{G}) where G <: Base.Generator = _infer_subexpr_dims(gen[1].iter)
@inline _get_con_dims(n::Tuple{N}) where N <: AbstractNode = (1,)
@inline _get_con_dims(n::Tuple{E, I}) where {E<:AbstractNode, I} = _infer_subexpr_dims(n[2])

# Build an iterator for empty constraints: 1:n for 1D, collected ProductIterator for multi-dim.
_empty_con_itr(ns::Tuple{Any}) = 1:_length(ns[1])
_empty_con_itr(ns::Tuple) = collect(Iterators.product(map(n -> 1:_length(n), ns)...))

# Extract simd function and iterator from a generator
@inline function _get_function_and_iter(T, gen::Base.Generator, c)
f = _simdfunction(T, gen.f(DataSource()), c.ncon, c.nnzj, c.nnzh)
pars = gen.iter
return f, pars
end
@inline function _get_function_and_iter(T, gen::Tuple{E, I}, c) where {E<:AbstractNode, I}
f = _simdfunction(T, gen[1], c.ncon, c.nnzj, c.nnzh)
pars = gen[2]
return f, pars
end

function _add_con(c, f, pars, dims, start, lcon, ucon, name, tag)
nitr = length(pars)
Expand Down Expand Up @@ -1687,6 +1702,8 @@ end

_adapt_gen(gen) = Base.Generator(gen.f, collect(gen.iter))
_adapt_gen(gen::Base.Generator{P}) where {P<:Union{AbstractArray,AbstractRange}} = gen
_adapt_gen(gen::Tuple{E, I}) where {E<:AbstractNode, I} = (gen[1], collect(gen[2]))
_adapt_gen(gen::Tuple{E, I}) where {E<:AbstractNode, I<:Union{AbstractArray,AbstractRange}} = gen

function Base.getproperty(core::E, name::Symbol) where {E <: Union{ExaCore, ExaModel}}
if hasfield(E, name)
Expand Down
14 changes: 14 additions & 0 deletions test/NLPTest/feature_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,17 @@ function test_nonunit_expr(backend)
end
end

function test_generator_free_constr(backend)
@testset "add_con(core, expr, itr)" begin
c = ExaCore(; backend, concrete = Val(true))
@add_var(c, x, 2)
gen = (sin(x[i]) for i in 1:2)
expr = gen.f(ExaModels.DataSource())
@add_con(c, g, expr, 1:2)
@test c.g isa ExaModels.Constraint
end
end

function test_features(backend)
@testset "Const" begin
test_const(backend)
Expand All @@ -184,4 +195,7 @@ function test_features(backend)
@testset "Non-unit expression indexing" begin
test_nonunit_expr(backend)
end
@testset "Generator-free constraint" begin
test_generator_free_constr(backend)
end
end
Loading