Skip to content

Commit d731cfc

Browse files
committed
Fix remaining tests
1 parent a5af452 commit d731cfc

File tree

4 files changed

+13
-11
lines changed

4 files changed

+13
-11
lines changed

cpmpy/solvers/rc2.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,6 @@ def solve(self, time_limit=None, stratified=True, adapt=True, exhaust=True, minz
195195

196196
# the user vars are only the Booleans (e.g. to ensure solveAll behaves consistently)
197197
self.user_vars = get_user_vars(self.user_vars, self.ivarmap)
198-
199-
# TODO I believe assumptions can be added in the WCNF as `soft`
200198

201199
# TODO: set time limit
202200
if time_limit is not None:
@@ -284,19 +282,17 @@ def transform_objective(self, expr):
284282
else:
285283
raise NotImplementedError(f"CPM_rc2: Non supported objective {flat_obj} (yet?)")
286284

287-
try:
288-
terms, cons, k = _encode_lin_expr(self.ivarmap, xs, weights, self.encoding)
289-
except TypeError:
290-
raise NotImplementedError(f"CPM_rc2: Unsupported objective: {flat_obj}")
285+
terms, cons, k = _encode_lin_expr(self.ivarmap, xs, weights, self.encoding)
291286

292287
self += cons
293288
const += k
294289

290+
terms = [(w, x) for w,x in terms if w != 0] # positive coefficients only
295291
ws, xs = zip(*terms) # unzip
296-
new_weights, new_xs, k = only_positive_coefficients_(ws, xs)
292+
new_weights, new_xs, k = only_positive_coefficients_(ws, xs) # this is actually only_non_negative_coefficients
297293
const += k
298294

299-
return new_weights, new_xs, const
295+
return list(new_weights), list(new_xs), const
300296

301297

302298
def objective(self, expr, minimize):

cpmpy/transformations/int2bool.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ..expressions.core import BoolVal, Comparison, Expression, Operator
1111
from ..expressions.globalconstraints import DirectConstraint
1212
from ..expressions.variables import _BoolVarImpl, _IntVarImpl, boolvar
13+
from ..expressions.utils import is_int
1314

1415
UNKNOWN_COMPARATOR_ERROR = ValueError("Comparator is not known or should have been simplified by linearize.")
1516
EMPTY_DOMAIN_ERROR = ValueError("Attempted to encode variable with empty domain (which is unsat)")
@@ -101,7 +102,9 @@ def _encode_lin_expr(ivarmap, xs, weights, encoding, cmp=None):
101102
k = 0
102103
for w, x in zip(weights, xs):
103104
# the linear may contain Boolean as well as integer variables
104-
if isinstance(x, _BoolVarImpl):
105+
if is_int(x):
106+
k += w * x
107+
elif isinstance(x, _BoolVarImpl):
105108
terms += [(w, x)]
106109
elif isinstance(x, _IntVarImpl):
107110
x_enc, x_cons = _encode_int_var(ivarmap, x, _decide_encoding(x, cmp, encoding))
@@ -111,7 +114,7 @@ def _encode_lin_expr(ivarmap, xs, weights, encoding, cmp=None):
111114
terms += new_terms
112115
k += k_
113116
else:
114-
raise TypeError
117+
raise TypeError(f"Term {w} * {x} for {type(x)}")
115118
return terms, domain_constraints, k
116119

117120

cpmpy/transformations/linearize.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,7 @@ def canonical_comparison(lst_of_expr):
631631
return newlist
632632

633633
def only_positive_coefficients_(ws, xs):
634+
""" Helper function which weight """
634635
indices = {i for i, (w, x) in enumerate(zip(ws, xs)) if w < 0 and isinstance(x, _BoolVarImpl)}
635636
nw, na = zip(*[(-w, ~x) if i in indices else (w, x) for i, (w, x) in enumerate(zip(ws, xs))])
636637
cons = sum(ws[i] for i in indices)
@@ -645,6 +646,7 @@ def only_positive_coefficients(lst_of_expr):
645646
646647
Resulting expression is linear.
647648
"""
649+
# TODO this should be renamed to only_non_negative_coefficients, because it does not remove terms with coefficient 0. I think it does not, because it risks removing user variables.
648650
newlist = []
649651
for cpm_expr in lst_of_expr:
650652
if isinstance(cpm_expr, Comparison):

tests/test_rc2_obj.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ def test_rc2_solve_with_integer_variables(self):
182182
model = cp.Model()
183183
x = cp.boolvar(2)
184184
y = cp.intvar(0, 3, shape=2)
185-
model.maximize(cp.sum(x) + cp.sum(y))
185+
z = cp.intvar(0, 3)
186+
model.maximize(cp.sum(x) + cp.sum(y) + 0 * z)
186187
# Add constraints
187188
model += (x[0] != x[1]) # both must be different
188189
model += (y[0] < y[1]) # y[0] must be less than y[1]

0 commit comments

Comments
 (0)