Skip to content

Commit db8f469

Browse files
committed
implement bytestream-handling & serialization methods
simplify bytestream handling and add test Signed-off-by: Oliver Guggenbühl <[email protected]>
1 parent b5eb395 commit db8f469

File tree

2 files changed

+176
-27
lines changed

2 files changed

+176
-27
lines changed

docling_haystack/converter.py

Lines changed: 98 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
"""Docling Haystack converter module."""
77

8+
import tempfile
89
from abc import ABC, abstractmethod
910
from enum import Enum
1011
from pathlib import Path
@@ -13,7 +14,10 @@
1314
from docling.chunking import BaseChunk, BaseChunker, HybridChunker
1415
from docling.datamodel.document import DoclingDocument
1516
from docling.document_converter import DocumentConverter
16-
from haystack import Document, component
17+
from haystack import Document, component, default_from_dict, default_to_dict, logging
18+
from haystack.dataclasses.byte_stream import ByteStream
19+
20+
logger = logging.getLogger(__name__)
1721

1822

1923
class ExportType(str, Enum):
@@ -100,42 +104,109 @@ def __init__(
100104
)
101105
self._meta_extractor = meta_extractor or MetaExtractor()
102106

107+
def _handle_bytestream(self, bytestream: ByteStream) -> tuple[str, bool]:
108+
"""Save ByteStream to a temporary file if needed."""
109+
suffix = (
110+
f".{bytestream.meta.get('file_extension', '')}"
111+
if bytestream.meta.get("file_extension")
112+
else None
113+
)
114+
temp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
115+
temp_file.write(bytestream.data)
116+
temp_file.close()
117+
return temp_file.name, True
118+
103119
@component.output_types(documents=list[Document])
def run(
    self,
    paths: Iterable[Union[Path, str, ByteStream]],
):
    """Run the DoclingConverter.

    Each source is converted independently: a source that fails to convert
    or export is logged with a warning and skipped, and processing continues
    with the next one.

    Args:
        paths: The input document locations, either as local paths, URLs,
            or ByteStream objects. ByteStream objects are written to a
            temporary file for the duration of their conversion.

    Returns:
        dict: A dictionary with the key ``documents`` holding the list of
        output Haystack Documents.

    Raises:
        RuntimeError: If the configured export type is not supported.
    """
    # A bad export type is a configuration error, not a per-source failure:
    # check it outside the best-effort handler below so it is not swallowed.
    if self._export_type not in (ExportType.DOC_CHUNKS, ExportType.MARKDOWN):
        raise RuntimeError(f"Unexpected export type: {self._export_type}")

    documents: list[Document] = []
    for source in paths:
        temp_path = None  # set only when this source needed a temp file
        try:
            if isinstance(source, ByteStream):
                filepath, is_temp = self._handle_bytestream(source)
                if is_temp:
                    temp_path = filepath
            else:
                filepath = str(source)

            dl_doc = self._converter.convert(
                source=filepath,
                **self._convert_kwargs,
            ).document
            documents.extend(self._dl_doc_to_documents(dl_doc))
        except Exception as e:
            # Deliberate best-effort: one bad input must not abort the batch.
            logger.warning(
                "Could not process {source}. Skipping it. Error: {error}",
                source=source,
                error=e,
            )
        finally:
            # Clean up right away so temp files do not pile up over a batch.
            if temp_path is not None:
                try:
                    Path(temp_path).unlink()
                except OSError as e:
                    logger.debug(
                        "Failed to delete temporary file {temp_file}: {error}",
                        temp_file=temp_path,
                        error=e,
                    )
    return {"documents": documents}

def _dl_doc_to_documents(self, dl_doc: DoclingDocument) -> list[Document]:
    """Export one converted Docling document as Haystack Documents."""
    if self._export_type == ExportType.DOC_CHUNKS:
        return [
            Document(
                content=self._chunker.serialize(chunk=chunk),
                meta=self._meta_extractor.extract_chunk_meta(chunk=chunk),
            )
            for chunk in self._chunker.chunk(dl_doc=dl_doc)
        ]
    if self._export_type == ExportType.MARKDOWN:
        return [
            Document(
                content=dl_doc.export_to_markdown(**self._md_export_kwargs),
                meta=self._meta_extractor.extract_dl_doc_meta(dl_doc=dl_doc),
            )
        ]
    raise RuntimeError(f"Unexpected export type: {self._export_type}")
187+
188+
def to_dict(self) -> dict[str, Any]:
    """
    Serialize the component to a dictionary for pipeline persistence.

    Only ``convert_kwargs``, ``export_type`` and ``md_export_kwargs`` are
    written out. The export type is stored as its plain value rather than as
    an enum member so that the dictionary contains only primitive types.

    Returns:
        dict[str, Any]: A dictionary representation of the component
    """
    # Tolerate an export type that was supplied as a plain string.
    export_type = getattr(self._export_type, "value", self._export_type)
    return default_to_dict(
        self,
        convert_kwargs=self._convert_kwargs,
        export_type=export_type,
        md_export_kwargs=self._md_export_kwargs,
    )
200+
201+
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "DoclingConverter":
    """
    Deserialize the component from a dictionary.

    If the init parameters carry an ``export_type`` that is not already an
    ``ExportType`` member (for example a plain string read from a serialized
    pipeline), it is converted to the enum before the component is built.

    Args:
        data: Dictionary representation of the component

    Returns:
        DoclingConverter: A new instance of the component

    Raises:
        ValueError: If ``export_type`` is present but is not a valid
            ``ExportType`` value.
    """
    init_params = data.get("init_parameters") or {}
    export_type = init_params.get("export_type")
    if export_type is not None and not isinstance(export_type, ExportType):
        # Build a new mapping so the caller's dictionary is left untouched.
        data = {
            **data,
            "init_parameters": {
                **init_params,
                "export_type": ExportType(export_type),
            },
        }
    return default_from_dict(cls, data)

test/test_converter.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from docling.chunking import HybridChunker
55
from docling.datamodel.document import DoclingDocument
6+
from haystack.dataclasses.byte_stream import ByteStream
67

78
from docling_haystack.converter import DoclingConverter, ExportType
89

@@ -80,3 +81,80 @@ def test_convert_markdown(monkeypatch):
8081
with open(EXPECTED_OUT_FILE) as f:
8182
exp_data = json.load(fp=f)
8283
assert exp_data == act_data
84+
85+
86+
def test_serialization_deserialization():
    """Round-trip a configured converter through to_dict and from_dict."""
    expected_convert_kwargs = {"optimize_ocr": True}
    expected_placeholder = "[IMAGE]"

    original = DoclingConverter(
        convert_kwargs={"optimize_ocr": True},
        md_export_kwargs={"image_placeholder": "[IMAGE]"},
    )

    # component -> dict
    payload = original.to_dict()
    assert "init_parameters" in payload

    init_params = payload["init_parameters"]
    assert init_params.get("convert_kwargs") == expected_convert_kwargs
    exported_md_kwargs = init_params.get("md_export_kwargs", {})
    assert exported_md_kwargs.get("image_placeholder") == expected_placeholder

    # dict -> component
    restored = DoclingConverter.from_dict(payload)
    assert restored._convert_kwargs == expected_convert_kwargs
    assert restored._md_export_kwargs.get("image_placeholder") == expected_placeholder
107+
108+
109+
def test_bytestream_handling(monkeypatch):
    """Check that a ByteStream input is converted via a temporary file path."""
    # Canned Docling document that the mocked converter hands back.
    with open("test/data/2408.09869v5.json") as json_file:
        canned_dl_doc = DoclingDocument.model_validate_json(json_file.read())

    with open("test/data/2408.09869v5.md", "rb") as md_file:
        stream = ByteStream(
            data=md_file.read(),
            meta={"file_extension": "md", "filename": "test_file.md"},
        )

    # The mock records the source it is called with, so it can be inspected below.
    convert_mock = MagicMock(return_value=MagicMock(document=canned_dl_doc))

    def _noop_init(*args, **kwargs):
        return None

    monkeypatch.setattr(
        "docling.document_converter.DocumentConverter.__init__",
        _noop_init,
    )
    monkeypatch.setattr(
        "docling.document_converter.DocumentConverter.convert",
        convert_mock,
    )
    monkeypatch.setattr(
        "docling_haystack.converter.MetaExtractor.extract_dl_doc_meta",
        lambda self, dl_doc: {"custom_field": "test_value"},
    )

    converter = DoclingConverter(export_type=ExportType.MARKDOWN)

    # Pass the ByteStream directly through the paths parameter.
    documents = converter.run(paths=[stream])["documents"]

    assert convert_mock.called
    source_arg = convert_mock.call_args.kwargs["source"]
    assert isinstance(source_arg, str)
    assert source_arg.endswith(".md")

    assert documents
    first_doc = documents[0]
    assert first_doc.meta.get("custom_field") == "test_value"
    assert first_doc.content

0 commit comments

Comments
 (0)