diff --git a/src/msgspec/_core.c b/src/msgspec/_core.c index 6a5e7a8b..1c1a5229 100644 --- a/src/msgspec/_core.c +++ b/src/msgspec/_core.c @@ -539,13 +539,18 @@ msgspec_get_state(PyObject *module) return (MsgspecState *)PyModule_GetState(module); } -/* Find the module instance imported in the currently running sub-interpreter - and get its state. */ +/* +with multi-phase init PyState_FindModule is not usable. +since we declare MULTIPLE_INTERPRETERS_NOT_SUPPORTED, we are guaranteed to have +at most one live module instance per process, so we can cache its state here. +state is populated populate in _core_exec and cleared by m_clear / m_free. +*/ +static MsgspecState *_core_state = NULL; + static MsgspecState * msgspec_get_global_state(void) { - PyObject *module = PyState_FindModule(&msgspecmodule); - return module == NULL ? NULL : msgspec_get_state(module); + return _core_state; } static int @@ -22331,6 +22336,7 @@ static void msgspec_free(PyObject *m) { msgspec_clear(m); + _core_state = NULL; } static int @@ -22399,133 +22405,114 @@ msgspec_traverse(PyObject *m, visitproc visit, void *arg) return 0; } -static struct PyModuleDef msgspecmodule = { - PyModuleDef_HEAD_INIT, - .m_name = "msgspec._core", - .m_size = sizeof(MsgspecState), - .m_methods = msgspec_methods, - .m_traverse = msgspec_traverse, - .m_clear = msgspec_clear, - .m_free =(freefunc)msgspec_free -}; -PyMODINIT_FUNC -PyInit__core(void) +int _core_exec(PyObject *m) { - PyObject *m, *temp_module, *temp_obj; MsgspecState *st; + PyObject *temp_module, *temp_obj; - PyDateTime_IMPORT; + /* populate the state pointer first, so any call during the remainder of init sees + a valid state */ + st = msgspec_get_state(m); + _core_state = st; - m = PyState_FindModule(&msgspecmodule); - if (m) { - Py_INCREF(m); - return m; - } + PyDateTime_IMPORT; StructMetaType.tp_base = &PyType_Type; if (PyType_Ready(&NoDefault_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&Unset_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&Factory_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&Field_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&IntLookup_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&StrLookup_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&LiteralInfo_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&TypedDictInfo_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&DataclassInfo_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&NamedTupleInfo_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&StructInfo_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&Meta_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&StructMetaType) < 0) - return NULL; + return -1; if (PyType_Ready(&StructMixinType) < 0) - return NULL; + return -1; if (PyType_Ready(&StructConfig_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&Encoder_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&Decoder_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&Ext_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&Raw_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&JSONEncoder_Type) < 0) - return NULL; + return -1; if (PyType_Ready(&JSONDecoder_Type) < 0) - return NULL; - - /* Create the module */ - m = PyModule_Create(&msgspecmodule); - if (m == NULL) - return NULL; + return -1; /* Add types */ Py_INCREF(&Factory_Type); if (PyModule_AddObject(m, "Factory", (PyObject *)&Factory_Type) < 0) - return NULL; + return -1; if (PyModule_AddObject(m, "Field", (PyObject *)&Field_Type) < 0) - return NULL; + return -1; Py_INCREF(&Meta_Type); if (PyModule_AddObject(m, "Meta", (PyObject *)&Meta_Type) < 0) - return NULL; + return -1; Py_INCREF(&StructConfig_Type); if (PyModule_AddObject(m, "StructConfig", (PyObject *)&StructConfig_Type) < 0) - return NULL; + return -1; Py_INCREF(&Ext_Type); if (PyModule_AddObject(m, "Ext", (PyObject *)&Ext_Type) < 0) - return NULL; + return -1; Py_INCREF(&Raw_Type); if (PyModule_AddObject(m, "Raw", (PyObject *)&Raw_Type) < 0) - return NULL; + return -1; Py_INCREF(&Encoder_Type); if (PyModule_AddObject(m, "MsgpackEncoder", (PyObject *)&Encoder_Type) < 0) - return NULL; + return -1; Py_INCREF(&Decoder_Type); if (PyModule_AddObject(m, "MsgpackDecoder", (PyObject *)&Decoder_Type) < 0) - return NULL; + return -1; Py_INCREF(&JSONEncoder_Type); if (PyModule_AddObject(m, "JSONEncoder", (PyObject *)&JSONEncoder_Type) < 0) - return NULL; + return -1; Py_INCREF(&JSONDecoder_Type); if (PyModule_AddObject(m, "JSONDecoder", (PyObject *)&JSONDecoder_Type) < 0) - return NULL; + return -1; Py_INCREF(&Unset_Type); if (PyModule_AddObject(m, "UnsetType", (PyObject *)&Unset_Type) < 0) - return NULL; + return -1; Py_INCREF((PyObject *)&StructMetaType); if (PyModule_AddObject(m, "StructMeta", (PyObject *)&StructMetaType) < 0) { Py_DECREF((PyObject *)&StructMetaType); - Py_DECREF(m); - return NULL; + return -1; } - st = msgspec_get_state(m); - /* Initialize GC counter */ st->gc_cycle = 0; /* Add NODEFAULT singleton */ Py_INCREF(NODEFAULT); if (PyModule_AddObject(m, "NODEFAULT", NODEFAULT) < 0) - return NULL; + return -1; /* Add UNSET singleton */ Py_INCREF(UNSET); if (PyModule_AddObject(m, "UNSET", UNSET) < 0) - return NULL; + return -1; /* Initialize the exceptions. */ st->MsgspecError = PyErr_NewExceptionWithDoc( @@ -22533,61 +22520,61 @@ PyInit__core(void) "Base class for all Msgspec exceptions", NULL, NULL ); - if (st->MsgspecError == NULL) return NULL; + if (st->MsgspecError == NULL) return -1; st->EncodeError = PyErr_NewExceptionWithDoc( "msgspec.EncodeError", "An error occurred while encoding an object", st->MsgspecError, NULL ); - if (st->EncodeError == NULL) return NULL; + if (st->EncodeError == NULL) return -1; temp_obj = PyTuple_Pack(2, st->MsgspecError, PyExc_ValueError); - if (temp_obj == NULL) return NULL; + if (temp_obj == NULL) return -1; st->DecodeError = PyErr_NewExceptionWithDoc( "msgspec.DecodeError", "An error occurred while decoding an object", temp_obj, NULL ); Py_XDECREF(temp_obj); - if (st->DecodeError == NULL) return NULL; + if (st->DecodeError == NULL) return -1; st->ValidationError = PyErr_NewExceptionWithDoc( "msgspec.ValidationError", "The message didn't match the expected schema", st->DecodeError, NULL ); - if (st->ValidationError == NULL) return NULL; + if (st->ValidationError == NULL) return -1; Py_INCREF(st->MsgspecError); if (PyModule_AddObject(m, "MsgspecError", st->MsgspecError) < 0) - return NULL; + return -1; Py_INCREF(st->EncodeError); if (PyModule_AddObject(m, "EncodeError", st->EncodeError) < 0) - return NULL; + return -1; Py_INCREF(st->DecodeError); if (PyModule_AddObject(m, "DecodeError", st->DecodeError) < 0) - return NULL; + return -1; Py_INCREF(st->ValidationError); if (PyModule_AddObject(m, "ValidationError", st->ValidationError) < 0) - return NULL; + return -1; /* Initialize the struct_lookup_cache */ st->struct_lookup_cache = PyDict_New(); - if (st->struct_lookup_cache == NULL) return NULL; + if (st->struct_lookup_cache == NULL) return -1; Py_INCREF(st->struct_lookup_cache); if (PyModule_AddObject(m, "_struct_lookup_cache", st->struct_lookup_cache) < 0) - return NULL; + return -1; #define SET_REF(attr, name) \ do { \ st->attr = PyObject_GetAttrString(temp_module, name); \ - if (st->attr == NULL) return NULL; \ + if (st->attr == NULL) return -1; \ } while (0) /* Get all imports from the typing module */ temp_module = PyImport_ImportModule("typing"); - if (temp_module == NULL) return NULL; + if (temp_module == NULL) return -1; SET_REF(typing_union, "Union"); SET_REF(typing_any, "Any"); SET_REF(typing_literal, "Literal"); @@ -22602,7 +22589,7 @@ PyInit__core(void) Py_DECREF(temp_module); temp_module = PyImport_ImportModule("msgspec._utils"); - if (temp_module == NULL) return NULL; + if (temp_module == NULL) return -1; SET_REF(concrete_types, "_CONCRETE_TYPES"); SET_REF(get_type_hints, "get_type_hints"); SET_REF(get_class_annotations, "get_class_annotations"); @@ -22613,86 +22600,86 @@ PyInit__core(void) Py_DECREF(temp_module); temp_module = PyImport_ImportModule("types"); - if (temp_module == NULL) return NULL; + if (temp_module == NULL) return -1; SET_REF(types_uniontype, "UnionType"); Py_DECREF(temp_module); /* Get the EnumMeta type */ temp_module = PyImport_ImportModule("enum"); if (temp_module == NULL) - return NULL; + return -1; temp_obj = PyObject_GetAttrString(temp_module, "EnumMeta"); Py_DECREF(temp_module); if (temp_obj == NULL) - return NULL; + return -1; if (!PyType_Check(temp_obj)) { Py_DECREF(temp_obj); PyErr_SetString(PyExc_TypeError, "enum.EnumMeta should be a type"); - return NULL; + return -1; } st->EnumMetaType = (PyTypeObject *)temp_obj; /* Get the abc.ABCMeta type and _abc_init helper */ temp_module = PyImport_ImportModule("abc"); if (temp_module == NULL) - return NULL; + return -1; temp_obj = PyObject_GetAttrString(temp_module, "ABCMeta"); if (temp_obj == NULL) { Py_DECREF(temp_module); - return NULL; + return -1; } if (!PyType_Check(temp_obj)) { Py_DECREF(temp_obj); Py_DECREF(temp_module); PyErr_SetString(PyExc_TypeError, "abc.ABCMeta should be a type"); - return NULL; + return -1; } st->ABCMetaType = (PyTypeObject *)temp_obj; temp_obj = PyObject_GetAttrString(temp_module, "_abc_init"); Py_DECREF(temp_module); if (temp_obj == NULL) - return NULL; + return -1; st->_abc_init = temp_obj; /* Get the datetime.datetime.astimezone method */ temp_module = PyImport_ImportModule("datetime"); - if (temp_module == NULL) return NULL; + if (temp_module == NULL) return -1; temp_obj = PyObject_GetAttrString(temp_module, "datetime"); Py_DECREF(temp_module); - if (temp_obj == NULL) return NULL; + if (temp_obj == NULL) return -1; st->astimezone = PyObject_GetAttrString(temp_obj, "astimezone"); Py_DECREF(temp_obj); - if (st->astimezone == NULL) return NULL; + if (st->astimezone == NULL) return -1; /* uuid module imports */ temp_module = PyImport_ImportModule("uuid"); - if (temp_module == NULL) return NULL; + if (temp_module == NULL) return -1; st->UUIDType = PyObject_GetAttrString(temp_module, "UUID"); - if (st->UUIDType == NULL) return NULL; + if (st->UUIDType == NULL) return -1; temp_obj = PyObject_GetAttrString(temp_module, "SafeUUID"); - if (temp_obj == NULL) return NULL; + if (temp_obj == NULL) return -1; st->uuid_safeuuid_unknown = PyObject_GetAttrString(temp_obj, "unknown"); Py_DECREF(temp_obj); - if (st->uuid_safeuuid_unknown == NULL) return NULL; + if (st->uuid_safeuuid_unknown == NULL) return -1; /* decimal module imports */ temp_module = PyImport_ImportModule("decimal"); - if (temp_module == NULL) return NULL; + if (temp_module == NULL) return -1; st->DecimalType = PyObject_GetAttrString(temp_module, "Decimal"); - if (st->DecimalType == NULL) return NULL; + if (st->DecimalType == NULL) return -1; /* Get the re.compile function */ temp_module = PyImport_ImportModule("re"); - if (temp_module == NULL) return NULL; + if (temp_module == NULL) return -1; st->re_compile = PyObject_GetAttrString(temp_module, "compile"); Py_DECREF(temp_module); - if (st->re_compile == NULL) return NULL; + if (st->re_compile == NULL) return -1; /* Initialize cached constant strings */ #define CACHED_STRING(attr, str) \ - if ((st->attr = PyUnicode_InternFromString(str)) == NULL) return NULL + if ((st->attr = PyUnicode_InternFromString(str)) == NULL) return -1 CACHED_STRING(str___weakref__, "__weakref__"); CACHED_STRING(str___dict__, "__dict__"); CACHED_STRING(str___msgspec_cached_hash__, "__msgspec_cached_hash__"); @@ -22727,16 +22714,41 @@ PyInit__core(void) CACHED_STRING(str_is_safe, "is_safe"); /* Initialize the Struct Type */ - PyState_AddModule(m, &msgspecmodule); st->StructType = PyObject_CallFunction( (PyObject *)&StructMetaType, "s(O){ssss}", "Struct", &StructMixinType, "__module__", "msgspec", "__doc__", Struct__doc__ ); - if (st->StructType == NULL) return NULL; + if (st->StructType == NULL) return -1; Py_INCREF(st->StructType); - if (PyModule_AddObject(m, "Struct", st->StructType) < 0) return NULL; + if (PyModule_AddObject(m, "Struct", st->StructType) < 0) return -1; #ifdef Py_GIL_DISABLED - PyUnstable_Module_SetGIL(m, Py_MOD_GIL_NOT_USED); + #if !PY313_PLUS + PyUnstable_Module_SetGIL(m, Py_MOD_GIL_NOT_USED); + #endif +#endif + return 0; +} + +static struct PyModuleDef_Slot module_slots[] = { +#if PY313_PLUS + {Py_mod_gil, Py_MOD_GIL_NOT_USED}, #endif - return m; + {Py_mod_exec, _core_exec}, + {0, NULL} +}; + +static struct PyModuleDef msgspecmodule = { + PyModuleDef_HEAD_INIT, + .m_name = "msgspec._core", + .m_size = sizeof(MsgspecState), + .m_methods = msgspec_methods, + .m_traverse = msgspec_traverse, + .m_clear = msgspec_clear, + .m_free =(freefunc)msgspec_free, + .m_slots = module_slots +}; + +PyMODINIT_FUNC +PyInit__core(void) { + return PyModuleDef_Init(&msgspecmodule); } diff --git a/tests/unit/test_module.py b/tests/unit/test_module.py new file mode 100644 index 00000000..96c287ed --- /dev/null +++ b/tests/unit/test_module.py @@ -0,0 +1,29 @@ +import sys + +import pytest + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason="3.12+ only") +class TestSubinterpreterGuard: + # sub-interpreters currently not supported. make sure we correctly report that + @pytest.mark.skipif(sys.version_info >= (3, 13), reason="3.12 only") + def test_subinterpreter_import_rejected_312(self): + interpreters = pytest.importorskip("test.support.interpreters") + + interp = interpreters.create() + try: + with pytest.raises(match=".*does not support loading in subinterpreters"): + interp.run("import msgspec._core") + finally: + interp.close() + + @pytest.mark.skipif(sys.version_info < (3, 13), reason="3.13+ only") + def test_subinterpreter_import_rejected_313(self): + interpreters = pytest.importorskip("_interpreters") + + interp = interpreters.create() + try: + res = interpreters.exec(interp, "import msgspec._core") + assert "does not support loading in subinterpreters" in res.msg + finally: + interpreters.destroy(interp)