diff --git a/README.md b/README.md
index 012511412..a1f4bf760 100644
--- a/README.md
+++ b/README.md
@@ -499,7 +499,7 @@
 const credential = new DefaultAzureCredential();
 const scope = 'https://cognitiveservices.azure.com/.default';
 const azureADTokenProvider = getBearerTokenProvider(credential, scope);
-const openai = new AzureOpenAI({ azureADTokenProvider });
+const openai = new AzureOpenAI({ azureADTokenProvider, apiVersion: "<The API version, e.g. 2024-10-01-preview>" });
 
 const result = await openai.chat.completions.create({
   model: 'gpt-4o',
@@ -509,6 +509,26 @@ const result = await openai.chat.completions.create({
 console.log(result.choices[0]!.message?.content);
 ```
 
+### Realtime API
+This SDK provides real-time streaming capabilities for Azure OpenAI through the `OpenAIRealtimeWS` and `OpenAIRealtimeWebSocket` clients described previously.
+
+To use the Realtime API, begin by creating a fully configured `AzureOpenAI` client and passing it into either `OpenAIRealtimeWS.azure` or `OpenAIRealtimeWebSocket.azure`. For example:
+
+```ts
+const cred = new DefaultAzureCredential();
+const scope = 'https://cognitiveservices.azure.com/.default';
+const deploymentName = 'gpt-4o-realtime-preview-1001';
+const azureADTokenProvider = getBearerTokenProvider(cred, scope);
+const client = new AzureOpenAI({
+  azureADTokenProvider,
+  apiVersion: '2024-10-01-preview',
+  deployment: deploymentName,
+});
+const rt = await OpenAIRealtimeWS.azure(client);
+```
+
+Once the instance has been created, you can begin sending requests and receiving streaming responses in real time.
+
 ### Retries
 
 Certain errors will be automatically retried 2 times by default, with a short exponential backoff.
diff --git a/examples/azure.ts b/examples/azure/chat.ts
similarity index 91%
rename from examples/azure.ts
rename to examples/azure/chat.ts
index 5fe1718fa..46df820f8 100755
--- a/examples/azure.ts
+++ b/examples/azure/chat.ts
@@ -2,6 +2,7 @@
 
 import { AzureOpenAI } from 'openai';
 import { getBearerTokenProvider, DefaultAzureCredential } from '@azure/identity';
+import 'dotenv/config';
 
 // Corresponds to your Model deployment within your OpenAI resource, e.g. gpt-4-1106-preview
 // Navigate to the Azure OpenAI Studio to deploy a model.
@@ -13,7 +14,7 @@ const azureADTokenProvider = getBearerTokenProvider(credential, scope);
 
 // Make sure to set AZURE_OPENAI_ENDPOINT with the endpoint of your Azure resource.
 // You can find it in the Azure Portal.
-const openai = new AzureOpenAI({ azureADTokenProvider });
+const openai = new AzureOpenAI({ azureADTokenProvider, apiVersion: '2024-10-01-preview' });
 
 async function main() {
   console.log('Non-streaming:');
diff --git a/examples/azure/realtime/websocket.ts b/examples/azure/realtime/websocket.ts
new file mode 100644
index 000000000..bec74e654
--- /dev/null
+++ b/examples/azure/realtime/websocket.ts
@@ -0,0 +1,60 @@
+import { OpenAIRealtimeWebSocket } from 'openai/beta/realtime/websocket';
+import { AzureOpenAI } from 'openai';
+import { DefaultAzureCredential, getBearerTokenProvider } from '@azure/identity';
+import 'dotenv/config';
+
+async function main() {
+  const cred = new DefaultAzureCredential();
+  const scope = 'https://cognitiveservices.azure.com/.default';
+  const deploymentName = 'gpt-4o-realtime-preview-1001';
+  const azureADTokenProvider = getBearerTokenProvider(cred, scope);
+  const client = new AzureOpenAI({
+    azureADTokenProvider,
+    apiVersion: '2024-10-01-preview',
+    deployment: deploymentName,
+  });
+  const rt = await OpenAIRealtimeWebSocket.azure(client);
+
+  // access the underlying `WebSocket` instance
+  rt.socket.addEventListener('open', () => {
+    console.log('Connection opened!');
+    rt.send({
+      type: 'session.update',
+      session: {
+        modalities: ['text'],
+        model: 'gpt-4o-realtime-preview',
+      },
+    });
+
+    rt.send({
+      type: 'conversation.item.create',
+      item: {
+        type: 'message',
+        role: 'user',
+        content: [{ type: 'input_text', text: 'Say a couple paragraphs!' }],
+      },
+    });
+
+    rt.send({ type: 'response.create' });
+  });
+
+  rt.on('error', (err) => {
+    // in a real world scenario this should be logged somewhere as you
+    // likely want to continue processing events regardless of any errors
+    throw err;
+  });
+
+  rt.on('session.created', (event) => {
+    console.log('session created!', event.session);
+    console.log();
+  });
+
+  rt.on('response.text.delta', (event) => process.stdout.write(event.delta));
+  rt.on('response.text.done', () => console.log());
+
+  rt.on('response.done', () => rt.close());
+
+  rt.socket.addEventListener('close', () => console.log('\nConnection closed!'));
+}
+
+main();
diff --git a/examples/azure/realtime/ws.ts b/examples/azure/realtime/ws.ts
new file mode 100644
index 000000000..ae20a1438
--- /dev/null
+++ b/examples/azure/realtime/ws.ts
@@ -0,0 +1,60 @@
+import { DefaultAzureCredential, getBearerTokenProvider } from '@azure/identity';
+import { OpenAIRealtimeWS } from 'openai/beta/realtime/ws';
+import { AzureOpenAI } from 'openai';
+import 'dotenv/config';
+
+async function main() {
+  const cred = new DefaultAzureCredential();
+  const scope = 'https://cognitiveservices.azure.com/.default';
+  const deploymentName = 'gpt-4o-realtime-preview-1001';
+  const azureADTokenProvider = getBearerTokenProvider(cred, scope);
+  const client = new AzureOpenAI({
+    azureADTokenProvider,
+    apiVersion: '2024-10-01-preview',
+    deployment: deploymentName,
+  });
+  const rt = await OpenAIRealtimeWS.azure(client);
+
+  // access the underlying `ws.WebSocket` instance
+  rt.socket.on('open', () => {
+    console.log('Connection opened!');
+    rt.send({
+      type: 'session.update',
+      session: {
+        modalities: ['text'],
+        model: 'gpt-4o-realtime-preview',
+      },
+    });
+
+    rt.send({
+      type: 'conversation.item.create',
+      item: {
+        type: 'message',
+        role: 'user',
+        content: [{ type: 'input_text', text: 'Say a couple paragraphs!' }],
+      },
+    });
+
+    rt.send({ type: 'response.create' });
+  });
+
+  rt.on('error', (err) => {
+    // in a real world scenario this should be logged somewhere as you
+    // likely want to continue processing events regardless of any errors
+    throw err;
+  });
+
+  rt.on('session.created', (event) => {
+    console.log('session created!', event.session);
+    console.log();
+  });
+
+  rt.on('response.text.delta', (event) => process.stdout.write(event.delta));
+  rt.on('response.text.done', () => console.log());
+
+  rt.on('response.done', () => rt.close());
+
+  rt.socket.on('close', () => console.log('\nConnection closed!'));
+}
+
+main();
diff --git a/examples/package.json b/examples/package.json
index b8c34ac45..70ec2c523 100644
--- a/examples/package.json
+++ b/examples/package.json
@@ -7,6 +7,7 @@
   "private": true,
   "dependencies": {
     "@azure/identity": "^4.2.0",
+    "dotenv": "^16.4.7",
     "express": "^4.18.2",
     "next": "^14.1.1",
     "openai": "file:..",
diff --git a/examples/realtime/ws.ts b/examples/realtime/ws.ts
index 4bbe85e5d..bba140800 100644
--- a/examples/realtime/ws.ts
+++ b/examples/realtime/ws.ts
@@ -9,7 +9,7 @@ async function main() {
     rt.send({
       type: 'session.update',
       session: {
-        modalities: ['foo'] as any,
+        modalities: ['text'],
         model: 'gpt-4o-realtime-preview',
       },
     });
diff --git a/src/beta/realtime/internal-base.ts b/src/beta/realtime/internal-base.ts
index 391d69911..d01e87fa4 100644
--- a/src/beta/realtime/internal-base.ts
+++ b/src/beta/realtime/internal-base.ts
@@ -1,6 +1,7 @@
 import { RealtimeClientEvent, RealtimeServerEvent, ErrorEvent } from '../../resources/beta/realtime/realtime';
 import { EventEmitter } from '../../lib/EventEmitter';
 import { OpenAIError } from '../../error';
+import OpenAI, { AzureOpenAI } from '../../index';
 
 export class OpenAIRealtimeError extends OpenAIError {
   /**
@@ -73,11 +74,20 @@
   }
 }
 
-export function buildRealtimeURL(props: { baseURL: string; model: string }): URL {
-  const path = '/realtime';
+export function isAzure(client: Pick<OpenAI, 'apiKey' | 'baseURL'>): client is AzureOpenAI {
+  return client instanceof AzureOpenAI;
+}
 
-  const url = new URL(props.baseURL + (props.baseURL.endsWith('/') ? path.slice(1) : path));
+export function buildRealtimeURL(client: Pick<OpenAI, 'apiKey' | 'baseURL'>, model: string): URL {
+  const path = '/realtime';
+  const baseURL = client.baseURL;
+  const url = new URL(baseURL + (baseURL.endsWith('/') ? path.slice(1) : path));
   url.protocol = 'wss';
-  url.searchParams.set('model', props.model);
+  if (isAzure(client)) {
+    url.searchParams.set('api-version', client.apiVersion);
+    url.searchParams.set('deployment', model);
+  } else {
+    url.searchParams.set('model', model);
+  }
   return url;
 }
diff --git a/src/beta/realtime/websocket.ts b/src/beta/realtime/websocket.ts
index e0853779d..ccb1b908f 100644
--- a/src/beta/realtime/websocket.ts
+++ b/src/beta/realtime/websocket.ts
@@ -1,8 +1,8 @@
-import { OpenAI } from '../../index';
+import { AzureOpenAI, OpenAI } from '../../index';
 import { OpenAIError } from '../../error';
 import * as Core from '../../core';
 import type { RealtimeClientEvent, RealtimeServerEvent } from '../../resources/beta/realtime/realtime';
-import { OpenAIRealtimeEmitter, buildRealtimeURL } from './internal-base';
+import { OpenAIRealtimeEmitter, buildRealtimeURL, isAzure } from './internal-base';
 
 interface MessageEvent {
   data: string;
@@ -26,6 +26,7 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter {
     props: {
       model: string;
       dangerouslyAllowBrowser?: boolean;
+      onUrl?: (url: URL) => void;
     },
     client?: Pick<OpenAI, 'apiKey' | 'baseURL'>,
   ) {
@@ -44,11 +45,13 @@
     client ??= new OpenAI({ dangerouslyAllowBrowser });
 
-    this.url = buildRealtimeURL({ baseURL: client.baseURL, model: props.model });
+    this.url = buildRealtimeURL(client, props.model);
+    props.onUrl?.(this.url);
+
+    // @ts-ignore
     this.socket = new WebSocket(this.url, [
       'realtime',
-      `openai-insecure-api-key.${client.apiKey}`,
+      ...(isAzure(client) ? [] : [`openai-insecure-api-key.${client.apiKey}`]),
       'openai-beta.realtime-v1',
     ]);
@@ -77,6 +80,41 @@
     this.socket.addEventListener('error', (event: any) => {
       this._onError(null, event.message, null);
     });
+
+    if (isAzure(client)) {
+      if (this.url.searchParams.get('Authorization') !== null) {
+        this.url.searchParams.set('Authorization', '<REDACTED>');
+      } else {
+        this.url.searchParams.set('api-key', '<REDACTED>');
+      }
+    }
+  }
+
+  static async azure(
+    client: AzureOpenAI,
+    options: { deploymentName?: string; dangerouslyAllowBrowser?: boolean } = {},
+  ): Promise<OpenAIRealtimeWebSocket> {
+    const token = await client._getAzureADToken();
+    function onUrl(url: URL) {
+      if (client.apiKey !== '') {
+        url.searchParams.set('api-key', client.apiKey);
+      } else {
+        if (token) {
+          url.searchParams.set('Authorization', `Bearer ${token}`);
+        } else {
+          throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.');
+        }
+      }
+    }
+    const deploymentName = options.deploymentName ?? client.deploymentName;
+    if (!deploymentName) {
+      throw new Error('No deployment name provided');
+    }
+    const { dangerouslyAllowBrowser } = options;
+    return new OpenAIRealtimeWebSocket(
+      { model: deploymentName, onUrl, ...(dangerouslyAllowBrowser ? { dangerouslyAllowBrowser } : {}) },
+      client,
+    );
+  }
 
   send(event: RealtimeClientEvent) {
diff --git a/src/beta/realtime/ws.ts b/src/beta/realtime/ws.ts
index 631a36cd2..51339089c 100644
--- a/src/beta/realtime/ws.ts
+++ b/src/beta/realtime/ws.ts
@@ -1,7 +1,7 @@
 import * as WS from 'ws';
-import { OpenAI } from '../../index';
+import { AzureOpenAI, OpenAI } from '../../index';
 import type { RealtimeClientEvent, RealtimeServerEvent } from '../../resources/beta/realtime/realtime';
-import { OpenAIRealtimeEmitter, buildRealtimeURL } from './internal-base';
+import { OpenAIRealtimeEmitter, buildRealtimeURL, isAzure } from './internal-base';
 
 export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter {
   url: URL;
@@ -14,12 +14,12 @@
     super();
 
     client ??= new OpenAI();
-    this.url = buildRealtimeURL({ baseURL: client.baseURL, model: props.model });
+    this.url = buildRealtimeURL(client, props.model);
     this.socket = new WS.WebSocket(this.url, {
       ...props.options,
       headers: {
         ...props.options?.headers,
-        Authorization: `Bearer ${client.apiKey}`,
+        ...(isAzure(client) ? {} : { Authorization: `Bearer ${client.apiKey}` }),
         'OpenAI-Beta': 'realtime=v1',
       },
     });
@@ -51,6 +51,20 @@
     });
   }
 
+  static async azure(
+    client: AzureOpenAI,
+    options: { deploymentName?: string; options?: WS.ClientOptions | undefined } = {},
+  ): Promise<OpenAIRealtimeWS> {
+    const deploymentName = options.deploymentName ?? client.deploymentName;
+    if (!deploymentName) {
+      throw new Error('No deployment name provided');
+    }
+    return new OpenAIRealtimeWS(
+      { model: deploymentName, options: { headers: await getAzureHeaders(client) } },
+      client,
+    );
+  }
+
   send(event: RealtimeClientEvent) {
     try {
       this.socket.send(JSON.stringify(event));
@@ -67,3 +81,16 @@
     }
   }
 }
+
+async function getAzureHeaders(client: AzureOpenAI) {
+  if (client.apiKey !== '') {
+    return { 'api-key': client.apiKey };
+  } else {
+    const token = await client._getAzureADToken();
+    if (token) {
+      return { Authorization: `Bearer ${token}` };
+    } else {
+      throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.');
+    }
+  }
+}
diff --git a/src/index.ts b/src/index.ts
index 944def00f..3de224d90 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -491,7 +491,7 @@ export interface AzureClientOptions extends ClientOptions {
 /** API Client for interfacing with the Azure OpenAI API. */
 export class AzureOpenAI extends OpenAI {
   private _azureADTokenProvider: (() => Promise<string>) | undefined;
-  private _deployment: string | undefined;
+  deploymentName: string | undefined;
   apiVersion: string = '';
   /**
    * API Client for interfacing with the Azure OpenAI API.
@@ -574,7 +574,7 @@ export class AzureOpenAI extends OpenAI {
 
     this._azureADTokenProvider = azureADTokenProvider;
     this.apiVersion = apiVersion;
-    this._deployment = deployment;
+    this.deploymentName = deployment;
   }
 
   override buildRequest(
@@ -589,7 +589,7 @@ export class AzureOpenAI extends OpenAI {
     if (!Core.isObj(options.body)) {
       throw new Error('Expected request body to be an object');
     }
-    const model = this._deployment || options.body['model'];
+    const model = this.deploymentName || options.body['model'];
     if (model !== undefined && !this.baseURL.includes('/deployments')) {
       options.path = `/deployments/${model}${options.path}`;
     }
@@ -597,7 +597,7 @@
     return super.buildRequest(options, props);
   }
 
-  private async _getAzureADToken(): Promise<string | undefined> {
+  async _getAzureADToken(): Promise<string | undefined> {
     if (typeof this._azureADTokenProvider === 'function') {
       const token = await this._azureADTokenProvider();
       if (!token || typeof token !== 'string') {