From 85ff876a75147490e60c70c2f36e964513f1086a Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Mon, 10 Feb 2025 20:06:34 +0000 Subject: [PATCH] fix: correctly decode multi-byte characters over multiple chunks (#1316) --- src/internal/decoders/line.ts | 107 ++++++++++++++++++++++------------ src/streaming.ts | 6 +- tests/streaming.test.ts | 53 ++++++++++++++++- 3 files changed, 126 insertions(+), 40 deletions(-) diff --git a/src/internal/decoders/line.ts b/src/internal/decoders/line.ts index 34e41d1dc..66f62c057 100644 --- a/src/internal/decoders/line.ts +++ b/src/internal/decoders/line.ts @@ -13,52 +13,58 @@ export class LineDecoder { static NEWLINE_CHARS = new Set(['\n', '\r']); static NEWLINE_REGEXP = /\r\n|[\n\r]/g; - buffer: string[]; - trailingCR: boolean; + buffer: Uint8Array; + #carriageReturnIndex: number | null; textDecoder: any; // TextDecoder found in browsers; not typed to avoid pulling in either "dom" or "node" types. constructor() { - this.buffer = []; - this.trailingCR = false; + this.buffer = new Uint8Array(); + this.#carriageReturnIndex = null; } decode(chunk: Bytes): string[] { - let text = this.decodeText(chunk); - - if (this.trailingCR) { - text = '\r' + text; - this.trailingCR = false; - } - if (text.endsWith('\r')) { - this.trailingCR = true; - text = text.slice(0, -1); - } - - if (!text) { + if (chunk == null) { return []; } - const trailingNewline = LineDecoder.NEWLINE_CHARS.has(text[text.length - 1] || ''); - let lines = text.split(LineDecoder.NEWLINE_REGEXP); + const binaryChunk = + chunk instanceof ArrayBuffer ? new Uint8Array(chunk) + : typeof chunk === 'string' ? new TextEncoder().encode(chunk) + : chunk; + + let newData = new Uint8Array(this.buffer.length + binaryChunk.length); + newData.set(this.buffer); + newData.set(binaryChunk, this.buffer.length); + this.buffer = newData; + + const lines: string[] = []; + let patternIndex; + while ((patternIndex = findNewlineIndex(this.buffer, this.#carriageReturnIndex)) != null) { + if (patternIndex.carriage && this.#carriageReturnIndex == null) { + // skip until we either get a corresponding `\n`, a new `\r` or nothing + this.#carriageReturnIndex = patternIndex.index; + continue; + } - // if there is a trailing new line then the last entry will be an empty - // string which we don't care about - if (trailingNewline) { - lines.pop(); - } + // we got double \r or \rtext\n + if ( + this.#carriageReturnIndex != null && + (patternIndex.index !== this.#carriageReturnIndex + 1 || patternIndex.carriage) + ) { + lines.push(this.decodeText(this.buffer.slice(0, this.#carriageReturnIndex - 1))); + this.buffer = this.buffer.slice(this.#carriageReturnIndex); + this.#carriageReturnIndex = null; + continue; + } - if (lines.length === 1 && !trailingNewline) { - this.buffer.push(lines[0]!); - return []; - } + const endIndex = + this.#carriageReturnIndex !== null ? patternIndex.preceding - 1 : patternIndex.preceding; - if (this.buffer.length > 0) { - lines = [this.buffer.join('') + lines[0], ...lines.slice(1)]; - this.buffer = []; - } + const line = this.decodeText(this.buffer.slice(0, endIndex)); + lines.push(line); - if (!trailingNewline) { - this.buffer = [lines.pop() || '']; + this.buffer = this.buffer.slice(patternIndex.index); + this.#carriageReturnIndex = null; } return lines; @@ -102,13 +108,38 @@ export class LineDecoder { } flush(): string[] { - if (!this.buffer.length && !this.trailingCR) { + if (!this.buffer.length) { return []; } + return this.decode('\n'); + } +} - const lines = [this.buffer.join('')]; - this.buffer = []; - this.trailingCR = false; - return lines; +/** + * This function searches the buffer for the end patterns, (\r or \n) + * and returns an object with the index preceding the matched newline and the + * index after the newline char. `null` is returned if no new line is found. + * + * ```ts + * findNewLineIndex('abc\ndef') -> { preceding: 2, index: 3 } + * ``` + */ +function findNewlineIndex( + buffer: Uint8Array, + startIndex: number | null, +): { preceding: number; index: number; carriage: boolean } | null { + const newline = 0x0a; // \n + const carriage = 0x0d; // \r + + for (let i = startIndex ?? 0; i < buffer.length; i++) { + if (buffer[i] === newline) { + return { preceding: i, index: i + 1, carriage: false }; + } + + if (buffer[i] === carriage) { + return { preceding: i, index: i + 1, carriage: true }; + } } + + return null; } diff --git a/src/streaming.ts b/src/streaming.ts index 6a57a50a0..1d1ae344b 100644 --- a/src/streaming.ts +++ b/src/streaming.ts @@ -346,13 +346,17 @@ class SSEDecoder { } /** This is an internal helper function that's just used for testing */ -export function _decodeChunks(chunks: string[]): string[] { +export function _decodeChunks(chunks: string[], { flush }: { flush: boolean } = { flush: false }): string[] { const decoder = new LineDecoder(); const lines: string[] = []; for (const chunk of chunks) { lines.push(...decoder.decode(chunk)); } + if (flush) { + lines.push(...decoder.flush()); + } + return lines; } diff --git a/tests/streaming.test.ts b/tests/streaming.test.ts index 6fe9a5781..8e5d0ca31 100644 --- a/tests/streaming.test.ts +++ b/tests/streaming.test.ts @@ -2,6 +2,7 @@ import { Response } from 'node-fetch'; import { PassThrough } from 'stream'; import assert from 'assert'; import { _iterSSEMessages, _decodeChunks as decodeChunks } from 'openai/streaming'; +import { LineDecoder } from 'openai/internal/decoders/line'; describe('line decoder', () => { test('basic', () => { @@ -10,8 +11,8 @@ describe('line decoder', () => { }); test('basic with \\r', () => { - // baz is not included because the line hasn't ended yet expect(decodeChunks(['foo', ' bar\r\nbaz'])).toEqual(['foo bar']); + expect(decodeChunks(['foo', ' bar\r\nbaz'], { flush: true })).toEqual(['foo bar', 'baz']); }); test('trailing new lines', () => { @@ -29,6 +30,56 @@ describe('line decoder', () => { test('escaped new lines with \\r', () => { expect(decodeChunks(['foo', ' bar\\r\\nbaz\n'])).toEqual(['foo bar\\r\\nbaz']); }); + + test('\\r & \\n split across multiple chunks', () => { + expect(decodeChunks(['foo\r', '\n', 'bar'], { flush: true })).toEqual(['foo', 'bar']); + }); + + test('single \\r', () => { + expect(decodeChunks(['foo\r', 'bar'], { flush: true })).toEqual(['foo', 'bar']); + }); + + test('double \\r', () => { + expect(decodeChunks(['foo\r', 'bar\r'], { flush: true })).toEqual(['foo', 'bar']); + expect(decodeChunks(['foo\r', '\r', 'bar'], { flush: true })).toEqual(['foo', '', 'bar']); + // implementation detail that we don't yield the single \r line until a new \r or \n is encountered + expect(decodeChunks(['foo\r', '\r', 'bar'], { flush: false })).toEqual(['foo']); + }); + + test('double \\r then \\r\\n', () => { + expect(decodeChunks(['foo\r', '\r', '\r', '\n', 'bar', '\n'])).toEqual(['foo', '', '', 'bar']); + expect(decodeChunks(['foo\n', '\n', '\n', 'bar', '\n'])).toEqual(['foo', '', '', 'bar']); + }); + + test('double newline', () => { + expect(decodeChunks(['foo\n\nbar'], { flush: true })).toEqual(['foo', '', 'bar']); + expect(decodeChunks(['foo', '\n', '\nbar'], { flush: true })).toEqual(['foo', '', 'bar']); + expect(decodeChunks(['foo\n', '\n', 'bar'], { flush: true })).toEqual(['foo', '', 'bar']); + expect(decodeChunks(['foo', '\n', '\n', 'bar'], { flush: true })).toEqual(['foo', '', 'bar']); + }); + + test('multi-byte characters across chunks', () => { + const decoder = new LineDecoder(); + + // bytes taken from the string 'известни' and arbitrarily split + // so that some multi-byte characters span multiple chunks + expect(decoder.decode(new Uint8Array([0xd0]))).toHaveLength(0); + expect(decoder.decode(new Uint8Array([0xb8, 0xd0, 0xb7, 0xd0]))).toHaveLength(0); + expect( + decoder.decode(new Uint8Array([0xb2, 0xd0, 0xb5, 0xd1, 0x81, 0xd1, 0x82, 0xd0, 0xbd, 0xd0, 0xb8])), + ).toHaveLength(0); + + const decoded = decoder.decode(new Uint8Array([0xa])); + expect(decoded).toEqual(['известни']); + }); + + test('flushing trailing newlines', () => { + expect(decodeChunks(['foo\n', '\nbar'], { flush: true })).toEqual(['foo', '', 'bar']); + }); + + test('flushing empty buffer', () => { + expect(decodeChunks([], { flush: true })).toEqual([]); + }); }); describe('streaming decoding', () => {