Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/supported-types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,7 @@ purpose, but with a `typing.Literal` the decoded values are literal `int` or
A literal can be composed of any of the following objects:

- `None`
- `bool` values (`True` and `False`)
- `int` values
- `str` values
- Nested `typing.Literal` types
Expand Down Expand Up @@ -1170,6 +1171,12 @@ values, or doesn't match any of their component types.
File "<stdin>", line 1, in <module>
msgspec.ValidationError: Expected `int`, got `str`

>>> msgspec.json.decode(b'true', type=Literal[True])
True

>>> msgspec.json.decode(b'false', type=Literal[True, False])
False

``NewType``
-----------

Expand Down
68 changes: 64 additions & 4 deletions src/msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -2820,6 +2820,8 @@ AssocList_Sort(AssocList* list) {
#define MS_TYPE_TYPEDDICT (1ull << 33)
#define MS_TYPE_DATACLASS (1ull << 34)
#define MS_TYPE_NAMEDTUPLE (1ull << 35)
#define MS_TYPE_BOOLLITERAL_TRUE (1ull << 36)
#define MS_TYPE_BOOLLITERAL_FALSE (1ull << 37)
/* Constraints */
#define MS_CONSTR_INT_MIN (1ull << 42)
#define MS_CONSTR_INT_MAX (1ull << 43)
Expand Down Expand Up @@ -2943,6 +2945,8 @@ typedef struct {
PyObject *int_lookup;
PyObject *str_lookup;
bool literal_none;
bool literal_bool_true;
bool literal_bool_false;
} LiteralInfo;

typedef struct {
Expand Down Expand Up @@ -3449,7 +3453,7 @@ typenode_simple_repr(TypeNode *self) {
if (self->types & (MS_TYPE_ANY | MS_TYPE_CUSTOM | MS_TYPE_CUSTOM_GENERIC) || self->types == 0) {
return PyUnicode_FromString("any");
}
if (self->types & MS_TYPE_BOOL) {
if (self->types & (MS_TYPE_BOOL | MS_TYPE_BOOLLITERAL_TRUE | MS_TYPE_BOOLLITERAL_FALSE)) {
if (!strbuilder_extend_literal(&builder, "bool")) return NULL;
}
if (self->types & (MS_TYPE_INT | MS_TYPE_INTENUM | MS_TYPE_INTLITERAL)) {
Expand Down Expand Up @@ -3543,6 +3547,8 @@ typedef struct {
PyObject *literal_str_values;
PyObject *literal_str_lookup;
bool literal_none;
bool literal_bool_true;
bool literal_bool_false;
/* Constraints */
int64_t c_int_min;
int64_t c_int_max;
Expand Down Expand Up @@ -4433,6 +4439,14 @@ typenode_collect_literal(TypeNodeCollectState *state, PyObject *literal) {
if (obj == Py_None || obj == NONE_TYPE) {
state->literal_none = true;
}
else if (type == &PyBool_Type) {
if (obj == Py_True) {
state->literal_bool_true = true;
}
else {
state->literal_bool_false = true;
}
}
else if (type == &PyLong_Type) {
if (state->literal_int_values == NULL) {
state->literal_int_values = PySet_New(NULL);
Expand Down Expand Up @@ -4469,7 +4483,7 @@ typenode_collect_literal(TypeNodeCollectState *state, PyObject *literal) {
invalid:
PyErr_Format(
PyExc_TypeError,
"Literal may only contain None/integers/strings - %R is not supported",
"Literal may only contain None/booleans/integers/strings - %R is not supported",
literal
);

Expand Down Expand Up @@ -4507,6 +4521,12 @@ typenode_collect_convert_literals(TypeNodeCollectState *state) {
if (info->literal_none) {
state->types |= MS_TYPE_NONE;
}
if (info->literal_bool_true) {
state->types |= MS_TYPE_BOOLLITERAL_TRUE;
}
if (info->literal_bool_false) {
state->types |= MS_TYPE_BOOLLITERAL_FALSE;
}
Py_DECREF(cached);
return 0;
}
Expand Down Expand Up @@ -4535,6 +4555,12 @@ typenode_collect_convert_literals(TypeNodeCollectState *state) {
if (state->literal_none) {
state->types |= MS_TYPE_NONE;
}
if (state->literal_bool_true) {
state->types |= MS_TYPE_BOOLLITERAL_TRUE;
}
if (state->literal_bool_false) {
state->types |= MS_TYPE_BOOLLITERAL_FALSE;
}

if (n == 1) {
/* A single `Literal` object, cache the lookups on it */
Expand All @@ -4545,6 +4571,8 @@ typenode_collect_convert_literals(TypeNodeCollectState *state) {
Py_XINCREF(state->literal_str_lookup);
info->str_lookup = state->literal_str_lookup;
info->literal_none = state->literal_none;
info->literal_bool_true = state->literal_bool_true;
info->literal_bool_false = state->literal_bool_false;
PyObject_GC_Track(info);
PyObject *literal = PyList_GET_ITEM(state->literals, 0);
int status = PyObject_SetAttr(
Expand Down Expand Up @@ -15341,6 +15369,18 @@ mpack_decode_bool(DecoderState *self, PyObject *val, TypeNode *type, PathNode *p
Py_INCREF(val);
return val;
}
if (val == Py_True && (type->types & MS_TYPE_BOOLLITERAL_TRUE)) {
Py_INCREF(Py_True);
return Py_True;
}
if (val == Py_False && (type->types & MS_TYPE_BOOLLITERAL_FALSE)) {
Py_INCREF(Py_False);
return Py_False;
}
if (type->types & (MS_TYPE_BOOLLITERAL_TRUE | MS_TYPE_BOOLLITERAL_FALSE)) {
ms_raise_validation_error(path, "Invalid enum value %R%U", val);
return NULL;
}
return ms_validation_error("bool", type, path);
}

Expand Down Expand Up @@ -16966,10 +17006,14 @@ json_decode_true(JSONDecoderState *self, TypeNode *type, PathNode *path) {
if (MS_UNLIKELY(c1 != 'r' || c2 != 'u' || c3 != 'e')) {
return json_err_invalid(self, "invalid character");
}
if (type->types & (MS_TYPE_ANY | MS_TYPE_BOOL)) {
if (type->types & (MS_TYPE_ANY | MS_TYPE_BOOL | MS_TYPE_BOOLLITERAL_TRUE)) {
Py_INCREF(Py_True);
return Py_True;
}
if (type->types & MS_TYPE_BOOLLITERAL_FALSE) {
ms_raise_validation_error(path, "Invalid enum value %R%U", Py_True);
return NULL;
}
return ms_validation_error("bool", type, path);
}

Expand All @@ -16987,10 +17031,14 @@ json_decode_false(JSONDecoderState *self, TypeNode *type, PathNode *path) {
if (MS_UNLIKELY(c1 != 'a' || c2 != 'l' || c3 != 's' || c4 != 'e')) {
return json_err_invalid(self, "invalid character");
}
if (type->types & (MS_TYPE_ANY | MS_TYPE_BOOL)) {
if (type->types & (MS_TYPE_ANY | MS_TYPE_BOOL | MS_TYPE_BOOLLITERAL_FALSE)) {
Py_INCREF(Py_False);
return Py_False;
}
if (type->types & MS_TYPE_BOOLLITERAL_TRUE) {
ms_raise_validation_error(path, "Invalid enum value %R%U", Py_False);
return NULL;
}
return ms_validation_error("bool", type, path);
}

Expand Down Expand Up @@ -20639,6 +20687,18 @@ convert_bool(
Py_INCREF(obj);
return obj;
}
if (obj == Py_True && (type->types & MS_TYPE_BOOLLITERAL_TRUE)) {
Py_INCREF(Py_True);
return Py_True;
}
if (obj == Py_False && (type->types & MS_TYPE_BOOLLITERAL_FALSE)) {
Py_INCREF(Py_False);
return Py_False;
}
if (type->types & (MS_TYPE_BOOLLITERAL_TRUE | MS_TYPE_BOOLLITERAL_FALSE)) {
ms_raise_validation_error(path, "Invalid enum value %R%U", obj);
return NULL;
}
return ms_validation_error("bool", type, path);
}

Expand Down
6 changes: 3 additions & 3 deletions src/msgspec/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,11 @@ class LiteralType(Type):
Parameters
----------
values: tuple
A tuple of possible values for this literal instance. Only `str` or
`int` literals are supported.
A tuple of possible values for this literal instance. Only `bool`,
`str`, or `int` literals are supported.
"""

values: Union[Tuple[str, ...], Tuple[int, ...]]
values: Union[Tuple[bool, ...], Tuple[str, ...], Tuple[int, ...]]


class CustomType(Type):
Expand Down
33 changes: 28 additions & 5 deletions tests/unit/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,13 +869,8 @@ def test_int_literal_values_out_of_range(self, values):
@pytest.mark.parametrize(
"typ",
[
Literal[1, False],
Literal["ok", b"bad"],
Literal[1, object()],
Union[Literal[1, 2], Literal[3, False]],
Union[Literal["one", "two"], Literal[3, False]],
Literal[Literal[1, 2], Literal[3, False]],
Literal[Literal["one", "two"], Literal[3, False]],
Literal[1, 2, List[int]],
Literal[1, 2, List],
],
Expand Down Expand Up @@ -952,6 +947,34 @@ def test_nested_literals(self):
with pytest.raises(ValidationError, match="Invalid enum value 'carrot'"):
dec.decode(msgspec.msgpack.encode("carrot"))

@pytest.mark.parametrize(
"typ, good, bad",
[
(Literal[True], [True], [False]),
(Literal[False], [False], [True]),
(Literal[True, False], [True, False], []),
(Literal[1, False], [1, False], [True]),
(Literal[True, "yes", None], [True, "yes", None], [False]),
],
)
def test_literal_bool(self, typ, good, bad):
dec = msgspec.msgpack.Decoder(typ)
for val in good:
assert dec.decode(msgspec.msgpack.encode(val)) == val
for val in bad:
with pytest.raises(ValidationError):
dec.decode(msgspec.msgpack.encode(val))

def test_literal_bool_error_message(self):
dec = msgspec.msgpack.Decoder(Literal[True])
with pytest.raises(ValidationError, match="Invalid enum value False"):
dec.decode(msgspec.msgpack.encode(False))

def test_mix_bool_and_bool_literal(self):
dec = msgspec.msgpack.Decoder(Union[Literal[True], bool])
assert dec.decode(msgspec.msgpack.encode(True)) is True
assert dec.decode(msgspec.msgpack.encode(False)) is False

def test_mix_int_and_int_literal(self):
dec = msgspec.msgpack.Decoder(Union[Literal[-1, 1], int])
for x in [-1, 1, 10]:
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,18 @@ def test_int_literal(self):
with pytest.raises(ValidationError, match="Expected `int`, got `str`"):
convert("A", typ)

def test_bool_literal(self):
assert convert(True, Literal[True]) is True
assert convert(False, Literal[False]) is False
assert convert(True, Literal[True, False]) is True
assert convert(False, Literal[True, False]) is False
with pytest.raises(ValidationError, match="Invalid enum value False"):
convert(False, Literal[True])
with pytest.raises(ValidationError, match="Invalid enum value True"):
convert(True, Literal[False])
with pytest.raises(ValidationError, match="Expected `bool`, got `str`"):
convert("yes", Literal[True])


class TestSequences:
def test_any_sequence(self):
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,6 +1199,10 @@ class TestLiteral:
(1, 2, "three", "four"),
(1, None),
("one", None),
(True,),
(False,),
(True, False),
(True, 1, "yes"),
],
)
def test_literal(self, values):
Expand Down Expand Up @@ -1229,6 +1233,29 @@ def test_str_literal_errors(self):
with pytest.raises(msgspec.ValidationError, match="Invalid enum value 'bad'"):
dec.decode(b'"bad"')

def test_bool_literal_true_only(self):
dec = msgspec.json.Decoder(Literal[True])
assert dec.decode(b"true") is True
with pytest.raises(msgspec.ValidationError, match="Invalid enum value False"):
dec.decode(b"false")

def test_bool_literal_false_only(self):
dec = msgspec.json.Decoder(Literal[False])
assert dec.decode(b"false") is False
with pytest.raises(msgspec.ValidationError, match="Invalid enum value True"):
dec.decode(b"true")

def test_bool_literal_errors(self):
dec = msgspec.json.Decoder(Literal[True])
with pytest.raises(msgspec.ValidationError, match="Expected `bool`, got `int`"):
dec.decode(b"42")
with pytest.raises(msgspec.ValidationError, match="Expected `bool`, got `str`"):
dec.decode(b'"hello"')
with pytest.raises(
msgspec.ValidationError, match="Expected `bool`, got `null`"
):
dec.decode(b"null")


class TestFloat:
@pytest.mark.parametrize(
Expand Down
Loading