Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 45 additions & 6 deletions extensions/pyo3/private/pyo3.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,43 @@ def _py_pyo3_library_impl(ctx):
is_windows = extension.basename.endswith(".dll")

# https://pyo3.rs/v0.26.0/building-and-distribution#manual-builds
ext = ctx.actions.declare_file("{}{}".format(
ctx.label.name,
".pyd" if is_windows else ".so",
))
#
# Determine the on-disk and logical Python module layout.
#
# `module` is a full dotted module path (e.g. "foo.bar"). We split on the
# last "." such that:
# - module_prefix == "foo"
# - module_name == "bar"
#
# `module_name` must match the `#[pymodule] fn <name>(...)` in the Rust code
# and is also what we pass to the stub generator.
module_path = ctx.attr.module_name if ctx.attr.module_name else ctx.label.name.replace("/", ".")

if module_path.startswith(".") or module_path.endswith(".") or ".." in module_path:
fail("Invalid `module` value '{}': expected a dotted module path like 'foo.bar'.".format(module_path))

last_dot = module_path.rfind(".")
if last_dot == -1:
module_prefix = None
module_name = module_path
else:
module_prefix = module_path[:last_dot]
module_name = module_path[last_dot + 1:]

if not module_name:
fail("Invalid `module` value '{}': module name may not be empty.".format(module_path))

# Convert module_prefix (e.g. "foo.bar") into a path ("foo/bar") and place
# the extension and stubs in the corresponding directory.
if module_prefix:
module_prefix_path = module_prefix.replace(".", "/")
module_relpath = "{}/{}.{}".format(module_prefix_path, module_name, "pyd" if is_windows else "so")
stub_relpath = "{}/{}.pyi".format(module_prefix_path, module_name)
else:
module_relpath = "{}.{}".format(module_name, "pyd" if is_windows else "so")
stub_relpath = "{}.pyi".format(module_name)

ext = ctx.actions.declare_file(module_relpath)
ctx.actions.symlink(
output = ext,
target_file = extension,
Expand All @@ -99,10 +132,10 @@ def _py_pyo3_library_impl(ctx):

stub = None
if _stubs_enabled(ctx.attr.stubs, toolchain):
stub = ctx.actions.declare_file("{}.pyi".format(ctx.label.name))
stub = ctx.actions.declare_file(stub_relpath)

args = ctx.actions.args()
args.add(ctx.label.name, format = "--module_name=%s")
args.add(module_name, format = "--module_name=%s")
args.add(ext, format = "--module_path=%s")
args.add(stub, format = "--output=%s")
ctx.actions.run(
Expand Down Expand Up @@ -180,6 +213,9 @@ py_pyo3_library = rule(
"imports": attr.string_list(
doc = "List of import directories to be added to the `PYTHONPATH`.",
),
"module_name": attr.string(
doc = "A full dotted Python module path implemented by this extension (e.g. `foo.bar`).",
),
"stubs": attr.int(
doc = "Whether or not to generate stubs. `-1` will default to the global config, `0` will never generate, and `1` will always generate stubs.",
default = -1,
Expand Down Expand Up @@ -218,6 +254,7 @@ def pyo3_extension(
stubs = None,
version = None,
compilation_mode = "opt",
module_name = None,
**kwargs):
"""Define a PyO3 python extension module.

Expand Down Expand Up @@ -259,6 +296,7 @@ def pyo3_extension(
For more details see [rust_shared_library][rsl].
compilation_mode (str, optional): The [compilation_mode](https://bazel.build/reference/command-line-reference#flag--compilation_mode)
value to build the extension for. If set to `"current"`, the current configuration will be used.
module_name (str, optional): A full dotted Python module path implemented by this extension (e.g. `foo.bar`).
**kwargs (dict): Additional keyword arguments.
"""
tags = kwargs.pop("tags", [])
Expand Down Expand Up @@ -318,6 +356,7 @@ def pyo3_extension(
compilation_mode = compilation_mode,
stubs = stubs_int,
imports = imports,
module_name = module_name,
tags = tags,
visibility = visibility,
**kwargs
Expand Down
15 changes: 15 additions & 0 deletions extensions/pyo3/test/module_prefix/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
load("@rules_python//python:defs.bzl", "py_test")
load("//:defs.bzl", "pyo3_extension")

pyo3_extension(
name = "module_prefix",
srcs = ["bar.rs"],
edition = "2021",
module_name = "foo.bar",
)

py_test(
name = "module_prefix_import_test",
srcs = ["module_prefix_import_test.py"],
deps = [":module_prefix"],
)
12 changes: 12 additions & 0 deletions extensions/pyo3/test/module_prefix/bar.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use pyo3::prelude::*;

#[pyfunction]
fn thing() -> PyResult<&'static str> {
Ok("hello from rust")
}

#[pymodule]
fn bar(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(thing, m)?)?;
Ok(())
}
18 changes: 18 additions & 0 deletions extensions/pyo3/test/module_prefix/module_prefix_import_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Tests that a pyo3 extension can be imported via a module prefix."""

import unittest
from test.module_prefix.foo import bar


class ModulePrefixImportTest(unittest.TestCase):
"""Test Class."""

def test_import_and_call(self) -> None:
"""Test that a pyo3 extension can be imported via a module prefix."""

result = bar.thing()
self.assertEqual("hello from rust", result)


if __name__ == "__main__":
unittest.main()