diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_constants.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_constants.py index 8c4cded196..fd9ebcb8bf 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_constants.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_constants.py @@ -1,6 +1,16 @@ # Copyright (c) Microsoft. All rights reserved. -"""Constants for Azure Functions Agent Framework integration.""" +"""Constants for Azure Functions Agent Framework integration. + +This module contains: +- Runtime configuration constants (polling, MIME types, headers) +- JSON field name mappings for camelCase (JSON) ↔ snake_case (Python) serialization + +For serialization constants, use the DurableStateFields, ContentTypes, and EntryTypes classes +to ensure consistent field naming between to_dict() and from_dict() methods. +""" + +from typing import Final # Supported request/response formats and MIME types REQUEST_RESPONSE_FORMAT_JSON: str = "json" @@ -17,3 +27,104 @@ # Polling configuration DEFAULT_MAX_POLL_RETRIES: int = 30 DEFAULT_POLL_INTERVAL_SECONDS: float = 1.0 + + +# ============================================================================= +# JSON Field Name Constants for Durable Agent State Serialization +# ============================================================================= +# These constants ensure consistent camelCase field names in JSON serialization. +# Use these in both to_dict() and from_dict() methods to prevent mismatches. + +# NOTE: Changing these constants is a breaking change and might require a schema version bump. + + +class DurableStateFields: + """JSON field name constants for durable agent state serialization. + + All field names are in camelCase to match the JSON schema. + Use these constants in both to_dict() and from_dict() methods. + """ + + # Schema-level fields + SCHEMA_VERSION: Final[str] = "schemaVersion" + DATA: Final[str] = "data" + + # Entry discriminator + TYPE_DISCRIMINATOR: Final[str] = "$type" + + # Internal field names + JSON_TYPE: Final[str] = "json_type" + TYPE_INTERNAL: Final[str] = "type" + + # Common entry fields + CORRELATION_ID: Final[str] = "correlationId" + CREATED_AT: Final[str] = "createdAt" + MESSAGES: Final[str] = "messages" + EXTENSION_DATA: Final[str] = "extensionData" + + # Request-specific fields + RESPONSE_TYPE: Final[str] = "responseType" + RESPONSE_SCHEMA: Final[str] = "responseSchema" + ORCHESTRATION_ID: Final[str] = "orchestrationId" + + # Response-specific fields + USAGE: Final[str] = "usage" + + # Message fields + ROLE: Final[str] = "role" + CONTENTS: Final[str] = "contents" + AUTHOR_NAME: Final[str] = "authorName" + + # Content fields + TEXT: Final[str] = "text" + URI: Final[str] = "uri" + MEDIA_TYPE: Final[str] = "mediaType" + MESSAGE: Final[str] = "message" + ERROR_CODE: Final[str] = "errorCode" + DETAILS: Final[str] = "details" + CALL_ID: Final[str] = "callId" + NAME: Final[str] = "name" + ARGUMENTS: Final[str] = "arguments" + RESULT: Final[str] = "result" + FILE_ID: Final[str] = "fileId" + VECTOR_STORE_ID: Final[str] = "vectorStoreId" + CONTENT: Final[str] = "content" + + # Usage fields (noqa: S105 - these are JSON field names, not passwords) + INPUT_TOKEN_COUNT: Final[str] = "inputTokenCount" # noqa: S105 + OUTPUT_TOKEN_COUNT: Final[str] = "outputTokenCount" # noqa: S105 + TOTAL_TOKEN_COUNT: Final[str] = "totalTokenCount" # noqa: S105 + + # History field + CONVERSATION_HISTORY: Final[str] = "conversationHistory" + + +class ContentTypes: + """Content type discriminator values for the $type field. + + These values are used in the JSON $type field to identify content types. + """ + + TEXT: Final[str] = "text" + DATA: Final[str] = "data" + ERROR: Final[str] = "error" + FUNCTION_CALL: Final[str] = "functionCall" + FUNCTION_RESULT: Final[str] = "functionResult" + HOSTED_FILE: Final[str] = "hostedFile" + HOSTED_VECTOR_STORE: Final[str] = "hostedVectorStore" + REASONING: Final[str] = "reasoning" + URI: Final[str] = "uri" + USAGE: Final[str] = "usage" + UNKNOWN: Final[str] = "unknown" + + +class ApiResponseFields: + """Field names for HTTP API responses (not part of persisted schema). + + These are used in try_get_agent_response() for backward compatibility + with the HTTP API response format. + """ + + CONTENT: Final[str] = "content" + MESSAGE_COUNT: Final[str] = "message_count" + CORRELATION_ID: Final[str] = "correlationId" diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py index ffb71d2367..8982e7b4f2 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py @@ -53,11 +53,22 @@ ) from dateutil import parser as date_parser +from ._constants import ApiResponseFields, ContentTypes, DurableStateFields from ._models import RunRequest, serialize_response_format logger = get_logger("agent_framework.azurefunctions.durable_agent_state") +class DurableAgentStateEntryJsonType(str, Enum): + """Enum for conversation history entry types. + + Discriminator values for the $type field in DurableAgentStateEntry objects. + """ + + REQUEST = "request" + RESPONSE = "response" + + def _parse_created_at(value: Any) -> datetime: """Normalize created_at values coming from persisted durable state.""" if isinstance(value, datetime): @@ -71,6 +82,7 @@ def _parse_created_at(value: Any) -> datetime: except (ValueError, TypeError): pass + logger.warning("Invalid or missing created_at value in durable agent state; defaulting to current UTC time.") return datetime.now(tz=timezone.utc) @@ -84,7 +96,7 @@ def _parse_messages(data: dict[str, Any]) -> list[DurableAgentStateMessage]: List of DurableAgentStateMessage objects """ messages: list[DurableAgentStateMessage] = [] - raw_messages: list[Any] = data.get("messages", []) + raw_messages: list[Any] = data.get(DurableStateFields.MESSAGES, []) for raw_msg in raw_messages: if isinstance(raw_msg, dict): messages.append(DurableAgentStateMessage.from_dict(cast(dict[str, Any], raw_msg))) @@ -102,12 +114,14 @@ def _parse_history_entries(data_dict: dict[str, Any]) -> list[DurableAgentStateE Returns: List of DurableAgentStateEntry objects (requests and responses) """ - history_data: list[Any] = data_dict.get("conversationHistory", []) + history_data: list[Any] = data_dict.get(DurableStateFields.CONVERSATION_HISTORY, []) deserialized_history: list[DurableAgentStateEntry] = [] for raw_entry in history_data: if isinstance(raw_entry, dict): entry_dict = cast(dict[str, Any], raw_entry) - entry_type = entry_dict.get("$type") or entry_dict.get("json_type") + entry_type = entry_dict.get(DurableStateFields.TYPE_DISCRIMINATOR) or entry_dict.get( + DurableStateFields.JSON_TYPE + ) if entry_type == DurableAgentStateEntryJsonType.RESPONSE: deserialized_history.append(DurableAgentStateResponse.from_dict(entry_dict)) elif entry_type == DurableAgentStateEntryJsonType.REQUEST: @@ -129,72 +143,95 @@ def _parse_contents(data: dict[str, Any]) -> list[DurableAgentStateContent]: List of DurableAgentStateContent objects """ contents: list[DurableAgentStateContent] = [] - raw_contents: list[Any] = data.get("contents", []) + raw_contents: list[Any] = data.get(DurableStateFields.CONTENTS, []) for raw_content in raw_contents: - if isinstance(raw_content, dict): + if isinstance(raw_content, DurableAgentStateContent): + contents.append(raw_content) + + elif isinstance(raw_content, dict): content_dict = cast(dict[str, Any], raw_content) - content_type: str | None = content_dict.get("$type") - if content_type == DurableAgentStateTextContent.type: - contents.append(DurableAgentStateTextContent(text=content_dict.get("text"))) - elif content_type == DurableAgentStateDataContent.type: - contents.append( - DurableAgentStateDataContent( - uri=str(content_dict.get("uri", "")), - media_type=content_dict.get("mediaType"), + content_type: str | None = content_dict.get(DurableStateFields.TYPE_DISCRIMINATOR) + + match content_type: + case ContentTypes.TEXT: + contents.append(DurableAgentStateTextContent(text=content_dict.get(DurableStateFields.TEXT))) + + case ContentTypes.DATA: + contents.append( + DurableAgentStateDataContent( + uri=str(content_dict.get(DurableStateFields.URI, "")), + media_type=content_dict.get(DurableStateFields.MEDIA_TYPE), + ) + ) + + case ContentTypes.ERROR: + contents.append( + DurableAgentStateErrorContent( + message=content_dict.get(DurableStateFields.MESSAGE), + error_code=content_dict.get(DurableStateFields.ERROR_CODE), + details=content_dict.get(DurableStateFields.DETAILS), + ) ) - ) - elif content_type == DurableAgentStateErrorContent.type: - contents.append( - DurableAgentStateErrorContent( - message=content_dict.get("message"), - error_code=content_dict.get("errorCode"), - details=content_dict.get("details"), + + case ContentTypes.FUNCTION_CALL: + contents.append( + DurableAgentStateFunctionCallContent( + call_id=str(content_dict.get(DurableStateFields.CALL_ID, "")), + name=str(content_dict.get(DurableStateFields.NAME, "")), + arguments=content_dict.get(DurableStateFields.ARGUMENTS, {}), + ) ) - ) - elif content_type == DurableAgentStateFunctionCallContent.type: - contents.append( - DurableAgentStateFunctionCallContent( - call_id=str(content_dict.get("callId", "")), - name=str(content_dict.get("name", "")), - arguments=content_dict.get("arguments", {}), + + case ContentTypes.FUNCTION_RESULT: + contents.append( + DurableAgentStateFunctionResultContent( + call_id=str(content_dict.get(DurableStateFields.CALL_ID, "")), + result=content_dict.get(DurableStateFields.RESULT), + ) ) - ) - elif content_type == DurableAgentStateFunctionResultContent.type: - contents.append( - DurableAgentStateFunctionResultContent( - call_id=str(content_dict.get("callId", "")), - result=content_dict.get("result"), + + case ContentTypes.HOSTED_FILE: + contents.append( + DurableAgentStateHostedFileContent( + file_id=str(content_dict.get(DurableStateFields.FILE_ID, "")) + ) ) - ) - elif content_type == DurableAgentStateHostedFileContent.type: - contents.append(DurableAgentStateHostedFileContent(file_id=str(content_dict.get("fileId", "")))) - elif content_type == DurableAgentStateHostedVectorStoreContent.type: - contents.append( - DurableAgentStateHostedVectorStoreContent( - vector_store_id=str(content_dict.get("vectorStoreId", "")) + + case ContentTypes.HOSTED_VECTOR_STORE: + contents.append( + DurableAgentStateHostedVectorStoreContent( + vector_store_id=str(content_dict.get(DurableStateFields.VECTOR_STORE_ID, "")) + ) ) - ) - elif content_type == DurableAgentStateTextReasoningContent.type: - contents.append(DurableAgentStateTextReasoningContent(text=content_dict.get("text"))) - elif content_type == DurableAgentStateUriContent.type: - contents.append( - DurableAgentStateUriContent( - uri=str(content_dict.get("uri", "")), - media_type=str(content_dict.get("mediaType", "")), + + case ContentTypes.REASONING: + contents.append( + DurableAgentStateTextReasoningContent(text=content_dict.get(DurableStateFields.TEXT)) ) - ) - elif content_type == DurableAgentStateUsageContent.type: - usage_data = content_dict.get("usage") - if usage_data and isinstance(usage_data, dict): + + case ContentTypes.URI: contents.append( - DurableAgentStateUsageContent( - usage=DurableAgentStateUsage.from_dict(cast(dict[str, Any], usage_data)) + DurableAgentStateUriContent( + uri=str(content_dict.get(DurableStateFields.URI, "")), + media_type=str(content_dict.get(DurableStateFields.MEDIA_TYPE, "")), ) ) - elif content_type == DurableAgentStateUnknownContent.type: - contents.append(DurableAgentStateUnknownContent(content=content_dict.get("content", {}))) - elif isinstance(raw_content, DurableAgentStateContent): - contents.append(raw_content) + + case ContentTypes.USAGE: + usage_data = content_dict.get(DurableStateFields.USAGE) + if usage_data and isinstance(usage_data, dict): + contents.append( + DurableAgentStateUsageContent( + usage=DurableAgentStateUsage.from_dict(cast(dict[str, Any], usage_data)) + ) + ) + + case ContentTypes.UNKNOWN | _: + # Handle UNKNOWN type or any unexpected content types (including None) + contents.append( + DurableAgentStateUnknownContent(content=content_dict.get(DurableStateFields.CONTENT, {})) + ) + return contents @@ -313,17 +350,17 @@ def __init__( def to_dict(self) -> dict[str, Any]: result: dict[str, Any] = { - "conversationHistory": [entry.to_dict() for entry in self.conversation_history], + DurableStateFields.CONVERSATION_HISTORY: [entry.to_dict() for entry in self.conversation_history], } if self.extension_data is not None: - result["extensionData"] = self.extension_data + result[DurableStateFields.EXTENSION_DATA] = self.extension_data return result @classmethod def from_dict(cls, data_dict: dict[str, Any]) -> DurableAgentStateData: return cls( conversation_history=_parse_history_entries(data_dict), - extension_data=data_dict.get("extensionData"), + extension_data=data_dict.get(DurableStateFields.EXTENSION_DATA), ) @@ -374,8 +411,8 @@ def __init__(self, schema_version: str = SCHEMA_VERSION): def to_dict(self) -> dict[str, Any]: return { - "schemaVersion": self.schema_version, - "data": self.data.to_dict(), + DurableStateFields.SCHEMA_VERSION: self.schema_version, + DurableStateFields.DATA: self.data.to_dict(), } def to_json(self) -> str: @@ -388,13 +425,13 @@ def from_dict(cls, state: dict[str, Any]) -> DurableAgentState: Args: state: Dictionary containing schemaVersion and data (full state structure) """ - schema_version = state.get("schemaVersion") + schema_version = state.get(DurableStateFields.SCHEMA_VERSION) if schema_version is None: logger.warning("Resetting state as it is incompatible with the current schema, all history will be lost") return cls() - instance = cls(schema_version=state.get("schemaVersion", DurableAgentState.SCHEMA_VERSION)) - instance.data = DurableAgentStateData.from_dict(state.get("data", {})) + instance = cls(schema_version=state.get(DurableStateFields.SCHEMA_VERSION, DurableAgentState.SCHEMA_VERSION)) + instance.data = DurableAgentStateData.from_dict(state.get(DurableStateFields.DATA, {})) return instance @@ -437,20 +474,14 @@ def try_get_agent_response(self, correlation_id: str) -> dict[str, Any] | None: # Get the text content from assistant messages only content = "\n".join(message.text for message in entry.messages if message.text) - return {"content": content, "message_count": self.message_count, "correlationId": correlation_id} + return { + ApiResponseFields.CONTENT: content, + ApiResponseFields.MESSAGE_COUNT: self.message_count, + ApiResponseFields.CORRELATION_ID: correlation_id, + } return None -class DurableAgentStateEntryJsonType(str, Enum): - """Enum for conversation history entry types. - - Discriminator values for the $type field in DurableAgentStateEntry objects. - """ - - REQUEST = "request" - RESPONSE = "response" - - class DurableAgentStateEntry: """Base class for conversation history entries (requests and responses). @@ -499,23 +530,23 @@ def __init__( def to_dict(self) -> dict[str, Any]: return { - "$type": self.json_type, - "correlationId": self.correlation_id, - "createdAt": self.created_at.isoformat(), - "messages": [m.to_dict() for m in self.messages], + DurableStateFields.TYPE_DISCRIMINATOR: self.json_type, + DurableStateFields.CORRELATION_ID: self.correlation_id, + DurableStateFields.CREATED_AT: self.created_at.isoformat(), + DurableStateFields.MESSAGES: [m.to_dict() for m in self.messages], } @classmethod def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateEntry: - created_at = _parse_created_at(data.get("created_at")) + created_at = _parse_created_at(data.get(DurableStateFields.CREATED_AT)) messages = _parse_messages(data) return cls( - json_type=DurableAgentStateEntryJsonType(data.get("$type", "entry")), - correlation_id=data.get("correlationId", ""), + json_type=DurableAgentStateEntryJsonType(data.get(DurableStateFields.TYPE_DISCRIMINATOR)), + correlation_id=data.get(DurableStateFields.CORRELATION_ID), created_at=created_at, messages=messages, - extension_data=data.get("extensionData"), + extension_data=data.get(DurableStateFields.EXTENSION_DATA), ) @@ -564,26 +595,26 @@ def __init__( def to_dict(self) -> dict[str, Any]: data = super().to_dict() if self.orchestration_id is not None: - data["orchestrationId"] = self.orchestration_id + data[DurableStateFields.ORCHESTRATION_ID] = self.orchestration_id if self.response_type is not None: - data["responseType"] = self.response_type + data[DurableStateFields.RESPONSE_TYPE] = self.response_type if self.response_schema is not None: - data["responseSchema"] = self.response_schema + data[DurableStateFields.RESPONSE_SCHEMA] = self.response_schema return data @classmethod def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateRequest: - created_at = _parse_created_at(data.get("created_at")) + created_at = _parse_created_at(data.get(DurableStateFields.CREATED_AT)) messages = _parse_messages(data) return cls( - correlation_id=data.get("correlationId", ""), + correlation_id=data.get(DurableStateFields.CORRELATION_ID), created_at=created_at, messages=messages, - extension_data=data.get("extensionData"), - response_type=data.get("responseType"), - response_schema=data.get("responseSchema"), - orchestration_id=data.get("orchestrationId"), + extension_data=data.get(DurableStateFields.EXTENSION_DATA), + response_type=data.get(DurableStateFields.RESPONSE_TYPE), + response_schema=data.get(DurableStateFields.RESPONSE_SCHEMA), + orchestration_id=data.get(DurableStateFields.ORCHESTRATION_ID), ) @staticmethod @@ -592,7 +623,7 @@ def from_run_request(request: RunRequest) -> DurableAgentStateRequest: return DurableAgentStateRequest( correlation_id=request.correlation_id, messages=[DurableAgentStateMessage.from_run_request(request)], - created_at=datetime.now(tz=timezone.utc), + created_at=_parse_created_at(request.created_at), response_type=request.request_response_format, response_schema=serialize_response_format(request.response_format), orchestration_id=request.orchestration_id, @@ -620,7 +651,7 @@ class DurableAgentStateResponse(DurableAgentStateEntry): def __init__( self, - correlation_id: str, + correlation_id: str | None, created_at: datetime, messages: list[DurableAgentStateMessage], extension_data: dict[str, Any] | None = None, @@ -640,24 +671,24 @@ def __init__( def to_dict(self) -> dict[str, Any]: data = super().to_dict() if self.usage is not None: - data["usage"] = self.usage.to_dict() + data[DurableStateFields.USAGE] = self.usage.to_dict() return data @classmethod def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateResponse: - created_at = _parse_created_at(data.get("created_at")) + created_at = _parse_created_at(data.get(DurableStateFields.CREATED_AT)) messages = _parse_messages(data) - usage_dict = data.get("usage") + usage_dict = data.get(DurableStateFields.USAGE) usage: DurableAgentStateUsage | None = None if usage_dict and isinstance(usage_dict, dict): usage = DurableAgentStateUsage.from_dict(cast(dict[str, Any], usage_dict)) return cls( - correlation_id=data.get("correlationId", ""), + correlation_id=data.get(DurableStateFields.CORRELATION_ID), created_at=created_at, messages=messages, - extension_data=data.get("extensionData"), + extension_data=data.get(DurableStateFields.EXTENSION_DATA), usage=usage, ) @@ -671,14 +702,6 @@ def from_run_response(correlation_id: str, response: AgentRunResponse) -> Durabl usage=DurableAgentStateUsage.from_usage(response.usage_details), ) - def to_run_response(self) -> Any: - """Converts this DurableAgentStateResponse back to an AgentRunResponse.""" - return AgentRunResponse( - created_at=self.created_at.isoformat() if self.created_at else None, - messages=[m.to_chat_message() for m in self.messages], - usage=self.usage.to_usage_details() if self.usage else None, - ) - class DurableAgentStateMessage: """Represents a message within a conversation history entry. @@ -717,27 +740,35 @@ def __init__( def to_dict(self) -> dict[str, Any]: result: dict[str, Any] = { - "role": self.role, - "contents": [ - {"$type": c.to_dict().get("type", "text"), **{k: v for k, v in c.to_dict().items() if k != "type"}} + DurableStateFields.ROLE: self.role, + DurableStateFields.CONTENTS: [ + { + DurableStateFields.TYPE_DISCRIMINATOR: c.to_dict().get( + DurableStateFields.TYPE_INTERNAL, ContentTypes.TEXT + ), + **{k: v for k, v in c.to_dict().items() if k != DurableStateFields.TYPE_INTERNAL}, + } for c in self.contents ], } # Only include optional fields if they have values if self.created_at is not None: - result["createdAt"] = self.created_at.isoformat() + result[DurableStateFields.CREATED_AT] = self.created_at.isoformat() if self.author_name is not None: - result["authorName"] = self.author_name + result[DurableStateFields.AUTHOR_NAME] = self.author_name return result @classmethod def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateMessage: + data_created_at = data.get(DurableStateFields.CREATED_AT) + created_at = _parse_created_at(data_created_at) if data_created_at else None + return cls( - role=data.get("role", ""), + role=data.get(DurableStateFields.ROLE, ""), contents=_parse_contents(data), - author_name=data.get("authorName"), - created_at=_parse_created_at(data.get("createdAt")), - extension_data=data.get("extensionData"), + author_name=data.get(DurableStateFields.AUTHOR_NAME), + created_at=created_at, + extension_data=data.get(DurableStateFields.EXTENSION_DATA), ) @property @@ -761,7 +792,7 @@ def from_run_request(request: RunRequest) -> DurableAgentStateMessage: return DurableAgentStateMessage( role=request.role.value, contents=[DurableAgentStateTextContent(text=request.message)], - created_at=_parse_created_at(request.created_at), + created_at=_parse_created_at(request.created_at) if request.created_at else None, ) @staticmethod @@ -823,14 +854,18 @@ class DurableAgentStateDataContent(DurableAgentStateContent): uri: str = "" media_type: str | None = None - type: str = "data" + type: str = ContentTypes.DATA def __init__(self, uri: str, media_type: str | None = None) -> None: self.uri = uri self.media_type = media_type def to_dict(self) -> dict[str, Any]: - return {"$type": self.type, "uri": self.uri, "mediaType": self.media_type} + return { + DurableStateFields.TYPE_DISCRIMINATOR: self.type, + DurableStateFields.URI: self.uri, + DurableStateFields.MEDIA_TYPE: self.media_type, + } @staticmethod def from_data_content(content: DataContent) -> DurableAgentStateDataContent: @@ -856,7 +891,7 @@ class DurableAgentStateErrorContent(DurableAgentStateContent): error_code: str | None = None details: str | None = None - type: str = "error" + type: str = ContentTypes.ERROR def __init__(self, message: str | None = None, error_code: str | None = None, details: str | None = None) -> None: self.message = message @@ -864,7 +899,12 @@ def __init__(self, message: str | None = None, error_code: str | None = None, de self.details = details def to_dict(self) -> dict[str, Any]: - return {"$type": self.type, "message": self.message, "errorCode": self.error_code, "details": self.details} + return { + DurableStateFields.TYPE_DISCRIMINATOR: self.type, + DurableStateFields.MESSAGE: self.message, + DurableStateFields.ERROR_CODE: self.error_code, + DurableStateFields.DETAILS: self.details, + } @staticmethod def from_error_content(content: ErrorContent) -> DurableAgentStateErrorContent: @@ -893,7 +933,7 @@ class DurableAgentStateFunctionCallContent(DurableAgentStateContent): name: str arguments: dict[str, Any] - type: str = "functionCall" + type: str = ContentTypes.FUNCTION_CALL def __init__(self, call_id: str, name: str, arguments: dict[str, Any]) -> None: self.call_id = call_id @@ -901,7 +941,12 @@ def __init__(self, call_id: str, name: str, arguments: dict[str, Any]) -> None: self.arguments = arguments def to_dict(self) -> dict[str, Any]: - return {"$type": self.type, "callId": self.call_id, "name": self.name, "arguments": self.arguments} + return { + DurableStateFields.TYPE_DISCRIMINATOR: self.type, + DurableStateFields.CALL_ID: self.call_id, + DurableStateFields.NAME: self.name, + DurableStateFields.ARGUMENTS: self.arguments, + } @staticmethod def from_function_call_content(content: FunctionCallContent) -> DurableAgentStateFunctionCallContent: @@ -938,14 +983,18 @@ class DurableAgentStateFunctionResultContent(DurableAgentStateContent): call_id: str result: object | None = None - type: str = "functionResult" + type: str = ContentTypes.FUNCTION_RESULT def __init__(self, call_id: str, result: Any | None = None) -> None: self.call_id = call_id self.result = result def to_dict(self) -> dict[str, Any]: - return {"$type": self.type, "callId": self.call_id, "result": self.result} + return { + DurableStateFields.TYPE_DISCRIMINATOR: self.type, + DurableStateFields.CALL_ID: self.call_id, + DurableStateFields.RESULT: self.result, + } @staticmethod def from_function_result_content(content: FunctionResultContent) -> DurableAgentStateFunctionResultContent: @@ -967,13 +1016,13 @@ class DurableAgentStateHostedFileContent(DurableAgentStateContent): file_id: str - type: str = "hostedFile" + type: str = ContentTypes.HOSTED_FILE def __init__(self, file_id: str) -> None: self.file_id = file_id def to_dict(self) -> dict[str, Any]: - return {"$type": self.type, "fileId": self.file_id} + return {DurableStateFields.TYPE_DISCRIMINATOR: self.type, DurableStateFields.FILE_ID: self.file_id} @staticmethod def from_hosted_file_content(content: HostedFileContent) -> DurableAgentStateHostedFileContent: @@ -996,13 +1045,16 @@ class DurableAgentStateHostedVectorStoreContent(DurableAgentStateContent): vector_store_id: str - type: str = "hostedVectorStore" + type: str = ContentTypes.HOSTED_VECTOR_STORE def __init__(self, vector_store_id: str) -> None: self.vector_store_id = vector_store_id def to_dict(self) -> dict[str, Any]: - return {"$type": self.type, "vectorStoreId": self.vector_store_id} + return { + DurableStateFields.TYPE_DISCRIMINATOR: self.type, + DurableStateFields.VECTOR_STORE_ID: self.vector_store_id, + } @staticmethod def from_hosted_vector_store_content( @@ -1024,13 +1076,13 @@ class DurableAgentStateTextContent(DurableAgentStateContent): text: The text content of the message """ - type: str = "text" + type: str = ContentTypes.TEXT def __init__(self, text: str | None) -> None: self.text = text def to_dict(self) -> dict[str, Any]: - return {"$type": self.type, "text": self.text} + return {DurableStateFields.TYPE_DISCRIMINATOR: self.type, DurableStateFields.TEXT: self.text} @staticmethod def from_text_content(content: TextContent) -> DurableAgentStateTextContent: @@ -1050,13 +1102,13 @@ class DurableAgentStateTextReasoningContent(DurableAgentStateContent): text: The reasoning or thought process text """ - type: str = "reasoning" + type: str = ContentTypes.REASONING def __init__(self, text: str | None) -> None: self.text = text def to_dict(self) -> dict[str, Any]: - return {"$type": self.type, "text": self.text} + return {DurableStateFields.TYPE_DISCRIMINATOR: self.type, DurableStateFields.TEXT: self.text} @staticmethod def from_text_reasoning_content(content: TextReasoningContent) -> DurableAgentStateTextReasoningContent: @@ -1080,14 +1132,18 @@ class DurableAgentStateUriContent(DurableAgentStateContent): uri: str media_type: str - type: str = "uri" + type: str = ContentTypes.URI def __init__(self, uri: str, media_type: str) -> None: self.uri = uri self.media_type = media_type def to_dict(self) -> dict[str, Any]: - return {"$type": self.type, "uri": self.uri, "mediaType": self.media_type} + return { + DurableStateFields.TYPE_DISCRIMINATOR: self.type, + DurableStateFields.URI: self.uri, + DurableStateFields.MEDIA_TYPE: self.media_type, + } @staticmethod def from_uri_content(content: UriContent) -> DurableAgentStateUriContent: @@ -1130,21 +1186,21 @@ def __init__( def to_dict(self) -> dict[str, Any]: result: dict[str, Any] = { - "inputTokenCount": self.input_token_count, - "outputTokenCount": self.output_token_count, - "totalTokenCount": self.total_token_count, + DurableStateFields.INPUT_TOKEN_COUNT: self.input_token_count, + DurableStateFields.OUTPUT_TOKEN_COUNT: self.output_token_count, + DurableStateFields.TOTAL_TOKEN_COUNT: self.total_token_count, } if self.extensionData is not None: - result["extensionData"] = self.extensionData + result[DurableStateFields.EXTENSION_DATA] = self.extensionData return result @classmethod def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateUsage: return cls( - input_token_count=data.get("inputTokenCount"), - output_token_count=data.get("outputTokenCount"), - total_token_count=data.get("totalTokenCount"), - extensionData=data.get("extensionData"), + input_token_count=data.get(DurableStateFields.INPUT_TOKEN_COUNT), + output_token_count=data.get(DurableStateFields.OUTPUT_TOKEN_COUNT), + total_token_count=data.get(DurableStateFields.TOTAL_TOKEN_COUNT), + extensionData=data.get(DurableStateFields.EXTENSION_DATA), ) @staticmethod @@ -1179,17 +1235,20 @@ class DurableAgentStateUsageContent(DurableAgentStateContent): usage: DurableAgentStateUsage = DurableAgentStateUsage() - type: str = "usage" + type: str = ContentTypes.USAGE - def __init__(self, usage: DurableAgentStateUsage) -> None: - self.usage = usage + def __init__(self, usage: DurableAgentStateUsage | None) -> None: + self.usage = usage if usage is not None else DurableAgentStateUsage() def to_dict(self) -> dict[str, Any]: - return {"$type": self.type, "usage": self.usage.to_dict() if hasattr(self.usage, "to_dict") else self.usage} + return { + DurableStateFields.TYPE_DISCRIMINATOR: self.type, + DurableStateFields.USAGE: self.usage.to_dict(), + } @staticmethod def from_usage_content(content: UsageContent) -> DurableAgentStateUsageContent: - return DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_usage(content.details)) # type: ignore + return DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_usage(content.details)) def to_ai_content(self) -> UsageContent: return UsageContent(details=self.usage.to_usage_details()) @@ -1208,13 +1267,13 @@ class DurableAgentStateUnknownContent(DurableAgentStateContent): content: Any - type: str = "unknown" + type: str = ContentTypes.UNKNOWN def __init__(self, content: Any) -> None: self.content = content def to_dict(self) -> dict[str, Any]: - return {"$type": self.type, "content": self.content} + return {DurableStateFields.TYPE_DISCRIMINATOR: self.type, DurableStateFields.CONTENT: self.content} @staticmethod def from_unknown_content(content: Any) -> DurableAgentStateUnknownContent: diff --git a/python/packages/azurefunctions/tests/test_entities.py b/python/packages/azurefunctions/tests/test_entities.py index 1c3f5168a5..66f39861a1 100644 --- a/python/packages/azurefunctions/tests/test_entities.py +++ b/python/packages/azurefunctions/tests/test_entities.py @@ -1002,5 +1002,40 @@ def test_request_from_run_request_without_orchestration_id(self) -> None: assert durable_request.orchestration_id is None +class TestDurableAgentStateMessageCreatedAt: + """Test suite for DurableAgentStateMessage created_at field handling.""" + + def test_message_from_run_request_without_created_at_preserves_none(self) -> None: + """Test from_run_request preserves None created_at instead of defaulting to current time. + + When a RunRequest has no created_at value, the resulting DurableAgentStateMessage + should also have None for created_at, not default to current UTC time. + """ + run_request = RunRequest( + message="test message", + correlation_id="corr-run", + created_at=None, # Explicitly None + ) + + durable_message = DurableAgentStateMessage.from_run_request(run_request) + + assert durable_message.created_at is None + + def test_message_from_run_request_with_created_at_parses_correctly(self) -> None: + """Test from_run_request correctly parses a valid created_at timestamp.""" + run_request = RunRequest( + message="test message", + correlation_id="corr-run", + created_at="2024-01-15T10:30:00Z", + ) + + durable_message = DurableAgentStateMessage.from_run_request(run_request) + + assert durable_message.created_at is not None + assert durable_message.created_at.year == 2024 + assert durable_message.created_at.month == 1 + assert durable_message.created_at.day == 15 + + if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/samples/getting_started/azure_functions/03_callbacks/README.md b/python/samples/getting_started/azure_functions/03_callbacks/README.md index c021300561..09e50bcfd1 100644 --- a/python/samples/getting_started/azure_functions/03_callbacks/README.md +++ b/python/samples/getting_started/azure_functions/03_callbacks/README.md @@ -19,8 +19,9 @@ an HTTP API that can be polled by a web client or dashboard. Complete the shared environment setup steps in `../README.md`, including creating a virtual environment, installing dependencies, and configuring Azure OpenAI credentials and storage settings. -> **Note:** The sample stores callback events in memory for simplicity. For production scenarios you -> should persist events to Application Insights, Azure Storage, Cosmos DB, or another durable store. +> **Note:** This is a streaming example that currently uses a local in-memory store for simplicity. +> For distributed environments, consider using Redis, Service Bus, or another pub/sub mechanism for +> callback coordination. ## Running the Sample diff --git a/python/samples/getting_started/azure_functions/03_callbacks/function_app.py b/python/samples/getting_started/azure_functions/03_callbacks/function_app.py index 1ac3588724..e6702f6586 100644 --- a/python/samples/getting_started/azure_functions/03_callbacks/function_app.py +++ b/python/samples/getting_started/azure_functions/03_callbacks/function_app.py @@ -27,7 +27,9 @@ logger = logging.getLogger(__name__) -# 1. Maintain an in-memory store for callback events keyed by thread ID (replace with durable storage in production). +# 1. Maintain an in-memory store for callback events keyed by thread ID. +# NOTE: This is a streaming example using a local console logger. For distributed environments, +# consider using Redis or Service Bus for callback coordination across multiple instances. CallbackStore = DefaultDict[str, list[dict[str, Any]]] callback_events: CallbackStore = defaultdict(list)