Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions lib/bindings/python/rust/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ impl HttpAsyncEngine {
}
}

#[derive(FromPyObject)]
struct HttpError {
code: u16,
message: String,
}

#[async_trait]
impl<Req, Resp> AsyncEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>, Error> for HttpAsyncEngine
where
Expand All @@ -153,18 +159,10 @@ where
Err(e) => {
if let Some(py_err) = e.downcast_ref::<PyErr>() {
Python::with_gil(|py| {
let err_val = py_err.clone_ref(py).into_value(py);
let bound_err = err_val.bind(py);

// check: Py03 exceptions cannot be cross-compiled, so we duck-type by name
// and fields.
if let Ok(type_name) = bound_err.get_type().name()
&& type_name.to_string().contains("HttpError")
&& let (Ok(code), Ok(message)) =
(bound_err.getattr("code"), bound_err.getattr("message"))
&& let (Ok(code), Ok(message)) =
(code.extract::<u16>(), message.extract::<String>())
{
// With the Stable ABI, we can't subclass Python's built-in exceptions in PyO3, so instead we
// implement the exception in Python and assume that it's an HttpError if the code and message
// are present.
if let Ok(HttpError { code, message }) = py_err.value(py).extract() {
// SSE panics if there are carriage returns or newlines
let message = message.replace(['\r', '\n'], "");
return Err(http_error::HttpError { code, message })?;
Expand Down
26 changes: 23 additions & 3 deletions lib/bindings/python/src/dynamo/llm/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,33 @@

# flake8: noqa

import logging

logger = logging.getLogger(__name__)

_MAX_MESSAGE_LENGTH = 8192


class HttpError(Exception):
def __init__(self, code: int, message: str):
if not (isinstance(code, int) and 0 <= code < 600):
# These ValueErrors are easier to trace to here than the TypeErrors that
# would be raised otherwise.
if not isinstance(code, int) or isinstance(code, bool):
raise ValueError("HttpError status code must be an integer")

if not isinstance(message, str):
raise ValueError("HttpError message must be a string")

if not (0 <= code < 600):
raise ValueError("HTTP status code must be an integer between 0 and 599")
if not (isinstance(message, str) and 0 < len(message) <= 8192):
raise ValueError("HTTP error message must be a string of length <= 8192")

if len(message) > _MAX_MESSAGE_LENGTH:
logger.warning(
f"HttpError message length {len(message)} exceeds max length {_MAX_MESSAGE_LENGTH}, truncating..."
)
message = message[: (_MAX_MESSAGE_LENGTH - 3)] + "..."

self.code = code
self.message = message

super().__init__(f"HTTP {code}: {message}")
16 changes: 16 additions & 0 deletions lib/bindings/python/tests/test_http_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,19 @@ def test_raise_http_error():
def test_invalid_http_error_code():
with pytest.raises(ValueError):
HttpError(1700, "Invalid Code")


def test_invalid_http_error_message():
with pytest.raises(ValueError):
# The second argument must be a string, not bytes.
HttpError(400, b"Bad Request")


def test_long_http_error_message():
message = ("A" * 8192) + "B"
error = HttpError(400, message)
assert len(error.message) == 8192

# Ensure the exception string uses the truncated message too.
assert message[:8189] in str(error)
assert "B" not in str(error)
Loading