-
Notifications
You must be signed in to change notification settings - Fork 757
fix(apple): Handle partial UTF-8 sequences in streaming LLM output #16219
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,10 +11,141 @@ | |
| #import "ExecuTorchLLMError.h" | ||
|
|
||
| #import <executorch/extension/llm/runner/text_llm_runner.h> | ||
| #import <memory> | ||
|
|
||
| using namespace executorch::extension; | ||
| using namespace executorch::runtime; | ||
|
|
||
| namespace { | ||
|
|
||
| /// A streaming UTF-8 buffer that accumulates bytes until complete UTF-8 | ||
| /// sequences are formed. This handles the case where BPE tokenizers output | ||
| /// partial multi-byte UTF-8 sequences across token boundaries. | ||
| /// | ||
| /// For example, the Chinese character "清" (UTF-8: E6 B8 85) might be split | ||
| /// across two tokens: "æ¸" (E6 B8) and "ħ" (85). This buffer accumulates | ||
| /// bytes and only emits complete, valid UTF-8 strings. | ||
| class UTF8StreamingBuffer { | ||
| public: | ||
| UTF8StreamingBuffer() = default; | ||
|
|
||
| /// Process incoming token bytes and return any complete UTF-8 string. | ||
| /// Returns empty string if more bytes are needed to complete a sequence. | ||
| /// Invalid bytes are silently skipped to maintain robustness. | ||
| std::string process(const std::string& token) { | ||
| buffer_.append(token); | ||
|
|
||
| std::string result; | ||
| size_t i = 0; | ||
|
|
||
| while (i < buffer_.size()) { | ||
| unsigned char byte = static_cast<unsigned char>(buffer_[i]); | ||
| size_t seqLen = utf8SequenceLength(byte); | ||
|
|
||
| if (seqLen == 0) { | ||
| // Invalid start byte (lone continuation or illegal byte) - skip it | ||
| i++; | ||
| continue; | ||
| } | ||
|
|
||
| if (i + seqLen > buffer_.size()) { | ||
| // Incomplete sequence at the end - keep in buffer for next call | ||
| break; | ||
| } | ||
|
|
||
| // Verify all continuation bytes are valid | ||
| bool valid = true; | ||
| for (size_t j = 1; j < seqLen; j++) { | ||
| if (!isUTF8Continuation(static_cast<unsigned char>(buffer_[i + j]))) { | ||
| valid = false; | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| if (valid) { | ||
| // Append complete valid sequence to result | ||
| result.append(buffer_, i, seqLen); | ||
| i += seqLen; | ||
| } else { | ||
| // Invalid sequence - skip only the start byte and resync | ||
| i++; | ||
| } | ||
| } | ||
|
|
||
| // Keep only the incomplete sequence (if any) for next call | ||
| if (i < buffer_.size()) { | ||
| buffer_ = buffer_.substr(i); | ||
| } else { | ||
| buffer_.clear(); | ||
| } | ||
|
|
||
| return result; | ||
| } | ||
|
|
||
| /// Flush any remaining bytes in the buffer. | ||
| /// Called at the end of generation to emit any leftover content. | ||
| /// Skips any invalid bytes that couldn't form valid UTF-8. | ||
| std::string flush() { | ||
| std::string result; | ||
|
|
||
| for (size_t i = 0; i < buffer_.size(); i++) { | ||
| unsigned char byte = static_cast<unsigned char>(buffer_[i]); | ||
| size_t seqLen = utf8SequenceLength(byte); | ||
|
|
||
| // Skip invalid start bytes | ||
| if (seqLen == 0) { | ||
| continue; | ||
| } | ||
|
|
||
| // Check if we have enough bytes for this sequence | ||
| if (i + seqLen > buffer_.size()) { | ||
| // Incomplete sequence - skip remaining bytes | ||
| break; | ||
| } | ||
|
|
||
| // Verify continuation bytes | ||
| bool valid = true; | ||
| for (size_t j = 1; j < seqLen; j++) { | ||
| if (!isUTF8Continuation(static_cast<unsigned char>(buffer_[i + j]))) { | ||
| valid = false; | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| if (valid) { | ||
| result.append(buffer_, i, seqLen); | ||
| i += seqLen - 1; // -1 because loop will i++ | ||
| } | ||
| } | ||
|
|
||
| buffer_.clear(); | ||
| return result; | ||
| } | ||
|
|
||
| private: | ||
| std::string buffer_; | ||
|
|
||
| /// Returns the number of bytes expected for a UTF-8 sequence starting with | ||
| /// the given byte. Returns 0 for invalid start bytes, including overlong | ||
| /// encodings (0xC0, 0xC1) and out-of-range bytes (0xF5-0xFF). | ||
| static size_t utf8SequenceLength(unsigned char byte) { | ||
| if ((byte & 0x80) == 0x00) return 1; // 0xxxxxxx - ASCII | ||
| if (byte == 0xC0 || byte == 0xC1) return 0; // Overlong encoding - invalid | ||
| if ((byte & 0xE0) == 0xC0) return 2; // 110xxxxx | ||
| if ((byte & 0xF0) == 0xE0) return 3; // 1110xxxx | ||
| if (byte >= 0xF5) return 0; // Out of Unicode range - invalid | ||
| if ((byte & 0xF8) == 0xF0) return 4; // 11110xxx | ||
| return 0; // Continuation byte (10xxxxxx) or other invalid | ||
| } | ||
|
|
||
| /// Returns true if the byte is a valid UTF-8 continuation byte (10xxxxxx). | ||
| static bool isUTF8Continuation(unsigned char byte) { | ||
| return (byte & 0xC0) == 0x80; | ||
| } | ||
| }; | ||
|
|
||
| } // anonymous namespace | ||
|
|
||
| @interface ExecuTorchLLMConfig () | ||
|
|
||
| - (const llm::GenerationConfig &)nativeConfig; | ||
|
|
@@ -88,15 +219,47 @@ - (BOOL)generateWithPrompt:(NSString*)prompt | |
| if (![self loadWithError:error]) { | ||
| return NO; | ||
| } | ||
|
|
||
| // Create a UTF-8 streaming buffer to handle partial multi-byte sequences. | ||
| // BPE tokenizers (especially ByteLevel like GPT-2/SmolLM) can output tokens | ||
| // that split UTF-8 characters at byte boundaries. This buffer accumulates | ||
| // bytes until complete UTF-8 sequences are formed before calling the callback. | ||
| auto utf8Buffer = std::make_shared<UTF8StreamingBuffer>(); | ||
|
|
||
| auto status = _runner->generate( | ||
| prompt.UTF8String, | ||
| config.nativeConfig, | ||
| [callback](const std::string& token) { | ||
| [callback, utf8Buffer](const std::string& token) { | ||
| if (callback) { | ||
| callback(@(token.c_str())); | ||
| // Process token through UTF-8 buffer | ||
| std::string validUTF8 = utf8Buffer->process(token); | ||
|
|
||
| // Only call callback when we have complete UTF-8 sequences | ||
| if (!validUTF8.empty()) { | ||
| NSString *tokenString = [[NSString alloc] initWithBytes:validUTF8.data() | ||
| length:validUTF8.size() | ||
| encoding:NSUTF8StringEncoding]; | ||
| if (tokenString) { | ||
| callback(tokenString); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| ); | ||
|
|
||
| // Flush any remaining bytes in the buffer | ||
| if (callback) { | ||
| std::string remaining = utf8Buffer->flush(); | ||
| if (!remaining.empty()) { | ||
| NSString *remainingString = [[NSString alloc] initWithBytes:remaining.data() | ||
| length:remaining.size() | ||
| encoding:NSUTF8StringEncoding]; | ||
| if (remainingString) { | ||
| callback(remainingString); | ||
| } | ||
| } | ||
| } | ||
|
Comment on lines
+249
to
+261
|
||
|
|
||
| if (status != Error::Ok) { | ||
| if (error) { | ||
| *error = [NSError errorWithDomain:ExecuTorchLLMErrorDomain | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The loop increments
ibyseqLen - 1to account for the loop'si++, but the loop counter is incremented after the adjustment. This results in advancing byseqLentotal, which is correct. However, when the sequence is invalid (skipped viacontinue), the loop only increments by 1. This creates an inconsistency where a valid multi-byte sequence advances correctly, but an invalid start byte only advances by 1, potentially causing the loop to re-examine bytes that are part of a previously skipped sequence. Consider restructuring the loop to avoid manual adjustment and use explicit indexing instead.