Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions promptguard/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,8 @@ def _request(self, method: str, path: str, **kwargs) -> dict[str, Any]:
if response.status_code >= 400:
raise _parse_error(response)

return response.json()
data: dict[str, Any] = response.json()
return data

if last_exc:
raise last_exc
Expand Down Expand Up @@ -658,7 +659,8 @@ async def _request(self, method: str, path: str, **kwargs) -> dict[str, Any]:
if response.status_code >= 400:
raise _parse_error(response)

return response.json()
data: dict[str, Any] = response.json()
return data

if last_exc:
raise last_exc
Expand Down
6 changes: 4 additions & 2 deletions promptguard/integrations/llamaindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,16 @@ def _extract_response_from_payload(self, payload: dict[str, Any]) -> str | None:
if isinstance(response, str):
return response
if hasattr(response, "text"):
return response.text
resp_text: str | None = response.text
return resp_text
if hasattr(response, "message") and hasattr(response.message, "content"):
return str(response.message.content)
return str(response)

raw = payload.get("raw")
if raw and hasattr(raw, "text"):
return raw.text
raw_text: str | None = raw.text
return raw_text

return None

Expand Down
3 changes: 2 additions & 1 deletion promptguard/patches/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def _handle_pre_scan_decision(

if decision.redacted and decision.redacted_messages:
if get_mode() == "enforce" and apply_redaction:
return apply_redaction(args, kwargs, decision.redacted_messages)
redacted: dict = apply_redaction(args, kwargs, decision.redacted_messages)
return redacted
logger.warning(
"[monitor] PromptGuard would redact: %s (event=%s)",
decision.threat_type,
Expand Down
14 changes: 7 additions & 7 deletions promptguard/patches/anthropic_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,17 @@ def _extract_messages(args, kwargs) -> tuple[list[dict[str, str]], str | None, d


def _apply_redaction(args, kwargs, redacted: list[dict[str, str]]) -> dict:
kwargs = dict(kwargs)
system = kwargs.get("system")
messages = kwargs.get("messages")
new_kwargs: dict = dict(kwargs)
system = new_kwargs.get("system")
messages = new_kwargs.get("messages")
has_system = system is not None
offset = 1 if has_system else 0

if has_system and redacted:
kwargs["system"] = redacted[0]["content"]
new_kwargs["system"] = redacted[0]["content"]

if messages and redacted:
result = []
result: list = []
for i, msg in enumerate(messages):
idx = i + offset
if idx < len(redacted):
Expand All @@ -101,9 +101,9 @@ def _apply_redaction(args, kwargs, redacted: list[dict[str, str]]) -> dict:
result.append(msg)
else:
result.append(msg)
kwargs["messages"] = result
new_kwargs["messages"] = result

return kwargs
return new_kwargs


def _extract_response_content(response: Any) -> str | None:
Expand Down
6 changes: 4 additions & 2 deletions promptguard/patches/cohere_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,17 @@ def _extract_messages(args, kwargs) -> tuple[list[dict[str, str]], str | None, d
def _extract_response_text(response: Any) -> str | None:
try:
if hasattr(response, "text"):
return response.text
text: str | None = response.text
return text
if hasattr(response, "message") and hasattr(response.message, "content"):
parts = response.message.content
if isinstance(parts, list):
texts = [p.text for p in parts if hasattr(p, "text")]
return "\n".join(texts) if texts else None
return str(parts) if parts else None
if isinstance(response, dict):
return response.get("text", response.get("message", {}).get("content"))
dict_text: str | None = response.get("text", response.get("message", {}).get("content"))
return dict_text
except Exception:
logger.debug("Failed to extract Cohere response text", exc_info=True)
return None
Expand Down
6 changes: 4 additions & 2 deletions promptguard/patches/google_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,12 @@ def _extract_messages(
def _extract_response_text(response: Any) -> str | None:
try:
if hasattr(response, "text"):
return response.text
text: str | None = response.text
return text
if hasattr(response, "candidates") and response.candidates:
parts = response.candidates[0].content.parts
return _extract_text_from_parts(parts)
from_parts: str | None = _extract_text_from_parts(parts)
return from_parts
except Exception:
logger.debug("Failed to extract Google response text", exc_info=True)
return None
Expand Down
19 changes: 11 additions & 8 deletions promptguard/patches/openai_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

def _messages_to_guard_format(messages: Any) -> list[dict[str, str]]:
"""Convert OpenAI-style messages to the guard API format."""
result = []
result: list[dict[str, str]] = []
if not messages:
return result
for msg in messages:
Expand Down Expand Up @@ -68,8 +68,9 @@ def _apply_redaction(args, kwargs, redacted: list[dict[str, str]]) -> dict:
"""Apply redacted content back into kwargs."""
messages = kwargs.get("messages") or (args[1] if len(args) > 1 else None)
if not messages or not redacted:
return kwargs
result = []
new_kwargs: dict = dict(kwargs)
return new_kwargs
result: list = []
for i, msg in enumerate(messages):
if i < len(redacted):
if isinstance(msg, dict):
Expand All @@ -80,9 +81,9 @@ def _apply_redaction(args, kwargs, redacted: list[dict[str, str]]) -> dict:
result.append(msg)
else:
result.append(msg)
kwargs = dict(kwargs)
kwargs["messages"] = result
return kwargs
new_kwargs = dict(kwargs)
new_kwargs["messages"] = result
return new_kwargs


def _extract_response_content(response: Any) -> str | None:
Expand All @@ -91,11 +92,13 @@ def _extract_response_content(response: Any) -> str | None:
if hasattr(response, "choices") and response.choices:
choice = response.choices[0]
if hasattr(choice, "message") and hasattr(choice.message, "content"):
return choice.message.content
content: str | None = choice.message.content
return content
if isinstance(response, dict):
choices = response.get("choices", [])
if choices:
return choices[0].get("message", {}).get("content")
msg_content: str | None = choices[0].get("message", {}).get("content")
return msg_content
except Exception:
logger.debug("Failed to extract OpenAI response text", exc_info=True)
return None
Expand Down
Loading