diff --git a/README.md b/README.md
index 3b046dc..67e9c16 100644
--- a/README.md
+++ b/README.md
@@ -108,10 +108,11 @@ USAGE
     [--opts <value>] [-o <value>] [-r <value>]
 
 ARGUMENTS
-  MODEL_RESOURCE  The resource ID or alias of the model to call.
+  MODEL_RESOURCE  The resource ID or alias of the model to call. The --app flag must be included if an alias is used.
 
 FLAGS
-  -a, --app=<value>     app to run command against
+  -a, --app=<value>     The name or ID of the app. If an alias for the MODEL_RESOURCE argument is used, this flag is
+                        required.
   -j, --json            Output response as JSON
   -o, --output=<value>  The file path where the command writes the model response.
   -p, --prompt=<value>  (required) The input prompt for the model.
diff --git a/src/commands/ai/models/call.ts b/src/commands/ai/models/call.ts
index e9b20f7..df6f66d 100644
--- a/src/commands/ai/models/call.ts
+++ b/src/commands/ai/models/call.ts
@@ -2,14 +2,21 @@ import color from '@heroku-cli/color'
 import {flags} from '@heroku-cli/command'
 import {Args, ux} from '@oclif/core'
 import fs from 'node:fs'
-import {ChatCompletionResponse, EmbeddingResponse, ImageResponse, ModelList} from '../../../lib/ai/types'
+import {
+  type ChatCompletionRequest,
+  ChatCompletionResponse, type CreateEmbeddingRequest,
+  EmbeddingResponse,
+  type ImageRequest,
+  ImageResponse,
+  ModelList,
+} from '../../../lib/ai/types'
 import Command from '../../../lib/base'
 import {openUrl} from '../../../lib/open-url'
 
 export default class Call extends Command {
   static args = {
     model_resource: Args.string({
-      description: 'The resource ID or alias of the model to call.',
+      description: 'The resource ID or alias of the model to call. The --app flag must be included if an alias is used.',
       required: true,
     }),
   }
@@ -21,7 +28,10 @@ export default class Call extends Command {
   ]
 
   static flags = {
-    app: flags.app({required: false}),
+    app: flags.app({
+      required: false,
+      description: 'The name or ID of the app. If an alias for the MODEL_RESOURCE argument is used, this flag is required.',
+    }),
     // interactive: flags.boolean({
     //   char: 'i',
     //   description: 'Use interactive mode for conversation beyond the initial prompt (not available on all models)',
@@ -69,19 +79,19 @@
     // Note: modelType will always be lower case. MarcusBlankenship 11/13/24.
     switch (modelType) {
       case 'text-to-embedding': {
-        const embedding = await this.createEmbedding(prompt, options)
+        const embedding = await this.createEmbedding(prompt, options as CreateEmbeddingRequest)
         await this.displayEmbedding(embedding, output, json)
         break
       }
 
       case 'text-to-image': {
-        const image = await this.generateImage(prompt, options)
+        const image = await this.generateImage(prompt, options as ImageRequest)
         await this.displayImageResult(image, output, browser, json)
         break
       }
 
       case 'text-to-text': {
-        const completion = await this.createChatCompletion(prompt, options)
+        const completion = await this.createChatCompletion(prompt, options as ChatCompletionRequest)
         await this.displayChatCompletion(completion, output, json)
         break
       }
@@ -98,7 +108,7 @@
    * @param opts JSON string containing options.
    * @returns The parsed options as an object.
    */
-  private parseOptions(optfile?: string, opts?: string) {
+  private parseOptions(optfile?: string, opts?: string): unknown {
     const options = {}
 
     if (optfile) {
@@ -138,7 +148,11 @@
     return options
   }
 
-  private async createChatCompletion(prompt: string, options = {}) {
+  private async createChatCompletion(prompt: string, options: ChatCompletionRequest = {} as ChatCompletionRequest) {
+    if (!this.isChatCompletionRequest(options)) {
+      return ux.error('Unexpected chat completion options', {exit: 1})
+    }
+
     const {body: chatCompletionResponse} = await this.herokuAI.post<ChatCompletionResponse>('/v1/chat/completions', {
       body: {
         ...options,
@@ -164,7 +178,11 @@
     }
   }
 
-  private async generateImage(prompt: string, options = {}) {
+  private async generateImage(prompt: string, options: ImageRequest = {} as ImageRequest) {
+    if (!this.isImageRequest(options)) {
+      return ux.error('Unexpected image options', {exit: 1})
+    }
+
     const {body: imageResponse} = await this.herokuAI.post<ImageResponse>('/v1/images/generations', {
       body: {
         ...options,
@@ -201,7 +219,11 @@
     ux.error('Unexpected response format', {exit: 1})
   }
 
-  private async createEmbedding(input: string, options = {}) {
+  private async createEmbedding(input: string, options: CreateEmbeddingRequest = {} as CreateEmbeddingRequest) {
+    if (!this.isEmbeddingsRequest(options)) {
+      return ux.error('Unexpected embedding options', {exit: 1})
+    }
+
     const {body: EmbeddingResponse} = await this.herokuAI.post<EmbeddingResponse>('/v1/embeddings', {
       body: {
         ...options,
@@ -223,4 +245,58 @@
       json ? ux.styledJSON(embedding) : ux.log(content)
     }
   }
+
+  private isEmbeddingsRequest(obj: unknown): obj is CreateEmbeddingRequest {
+    const embeddingRequestKeys = new Set([
+      'model',
+      'user',
+      'dimensions',
+      'encoding_format',
+      'input',
+    ])
+    const keys = Object.keys(obj ?? {})
+    return keys.every(key => embeddingRequestKeys.has(key as keyof CreateEmbeddingRequest))
+  }
+
+  private isImageRequest(obj: unknown): obj is ImageRequest {
+    const imageRequestKeys = new Set([
+      'prompt',
+      'model',
+      'n',
+      'quality',
+      'response_format',
+      'size',
+      'style',
+      'user',
+      'sampler',
+      'seed',
+      'steps',
+      'cfg_scale',
+      'clip_guidance_preset',
+      'style_preset',
+    ])
+
+    const keys = Object.keys(obj ?? {})
+    return keys.every(key => imageRequestKeys.has(key as keyof ImageRequest))
+  }
+
+  private isChatCompletionRequest(obj: unknown): obj is ChatCompletionRequest {
+    const chatCompletionRequestKeys = new Set([
+      'messages',
+      'model',
+      'temperature',
+      'top_p',
+      'n',
+      'stream',
+      'stop',
+      'max_tokens',
+      'presence_penalty',
+      'frequency_penalty',
+      'tools',
+      'tool_choice',
+      'user',
+    ])
+    const keys = Object.keys(obj ?? {})
+    return keys.every(key => chatCompletionRequestKeys.has(key as keyof ChatCompletionRequest))
+  }
 }
diff --git a/src/lib/ai/types.ts b/src/lib/ai/types.ts
index 422e8bd..4e0f96b 100644
--- a/src/lib/ai/types.ts
+++ b/src/lib/ai/types.ts
@@ -117,6 +117,49 @@ export type ChatCompletionChoice = {
   } | null
 }
 
+export interface ChatCompletionRequest {
+  messages: ChatMessage[];
+  model: string;
+  temperature?: number;
+  top_p?: number;
+  n?: number;
+  stream?: boolean;
+  stop?: string | string[];
+  max_tokens?: number;
+  presence_penalty?: number;
+  frequency_penalty?: number;
+  tools?: Tool[];
+  tool_choice?: 'none' | 'auto' | ToolChoice;
+  user?: string;
+}
+
+export interface ChatMessage {
+  role: 'system' | 'user' | 'assistant' | 'tool';
+  content: string;
+  name?: string;
+  tool_calls?: ToolCall[];
+}
+
+export interface Tool {
+  type: 'function';
+  function: {
+    name: string;
+    description?: string;
+    parameters: {
+      type: 'object';
+      properties: Record<string, unknown>;
+      required?: string[];
+    };
+  };
+}
+
+export interface ToolChoice {
+  type: 'function';
+  function: {
+    name: string;
+  };
+}
+
 /**
  * Chat completion response schema.
  */
@@ -158,6 +201,70 @@ export type Image = {
   readonly url?: string | null
 }
 
+export type ImageRequest = {
+  prompt: string,
+  model: string,
+  n: number,
+  quality: string,
+  response_format: ResponseFormat,
+  size: string,
+  style: string,
+  user: string,
+  sampler: SamplerType,
+  seed: number,
+  steps: number,
+  cfg_scale: number,
+  clip_guidance_preset: ClipGuidancePreset,
+  style_preset: StylePreset
+}
+
+export enum ResponseFormat {
+  Url = 'url',
+  Base64 = 'base64',
+}
+
+export enum SamplerType {
+  DDIM = 'DDIM',
+  DDPM = 'DDPM',
+  KDPMPP2M = 'K_DPMPP_2M',
+  KDPMPP2SANCESTRAL = 'K_DPMPP_2S_ANCESTRAL',
+  KDPM2 = 'K_DPM_2',
+  KDPM2ANCESTRAL = 'K_DPM_2_ANCESTRAL',
+  KEULER = 'K_EULER',
+  KEULERANCESTRAL = 'K_EULER_ANCESTRAL',
+  KHEUN = 'K_HEUN',
+  KLMS = 'K_LMS',
+}
+
+export enum ClipGuidancePreset {
+  None = 'NONE',
+  FastBlue = 'FAST_BLUE',
+  FastGreen = 'FAST_GREEN',
+  SimpleSlow = 'SIMPLE SLOW',
+  Slower = 'SLOWER',
+  Slowest = 'SLOWEST',
+}
+
+export enum StylePreset {
+  '3DModel' = '3DModel',
+  AnalogFilm = 'analog-film',
+  Anime = 'anime',
+  Cinematic = 'cinematic',
+  ComicBook = 'comic-book',
+  DigitalArt = 'digital-art',
+  Enhance = 'enhance',
+  FantasyArt = 'fantasy-art',
+  Isometric = 'isometric',
+  LineArt = 'line-art',
+  LowPoly = 'low-poly',
+  ModelingCompound = 'modeling-compound',
+  NeonPunk = 'neon-punk',
+  Origami = 'origami',
+  Photographic = 'photographic',
+  PixelArt = 'pixel-art',
+  TileTexture = 'tile-texture',
+}
+
 /**
  * Image response schema.
  */
@@ -180,6 +287,14 @@ export type Embedding = {
   readonly object: 'embeddings'
 }
 
+export interface CreateEmbeddingRequest {
+  model: string;
+  input: string | string[] | number[];
+  user?: string;
+  encoding_format?: 'float' | 'base64';
+  dimensions?: number;
+}
+
 /**
  * Embedding response schema.
  */
diff --git a/test/commands/ai/models/call.test.ts b/test/commands/ai/models/call.test.ts
index 5d46a1f..4f4b1d1 100644
--- a/test/commands/ai/models/call.test.ts
+++ b/test/commands/ai/models/call.test.ts
@@ -1,23 +1,31 @@
-import fs from 'node:fs'
-import {stdout, stderr} from 'stdout-stderr'
+import * as client from '@heroku-cli/command'
 import {expect} from 'chai'
 import nock from 'nock'
+import fs from 'node:fs'
 import sinon from 'sinon'
+import {stderr, stdout} from 'stdout-stderr'
+import heredoc from 'tsheredoc'
 import Cmd from '../../../../src/commands/ai/models/call'
 import * as openUrl from '../../../../src/lib/open-url'
-import stripAnsi from '../../../helpers/strip-ansi'
-import {runCommand} from '../../../run-command'
 import {
-  addon3, addon3Attachment1,
-  addon5, addon5Attachment1,
-  addon6, addon6Attachment1,
+  addon3,
+  addon3Attachment1,
+  addon5,
+  addon5Attachment1,
+  addon6,
+  addon6Attachment1,
   availableModels,
   chatCompletionResponse,
   embeddingsResponse,
-  imageContentBase64, imageContent, imageResponseBase64, imageResponseUrl, imageUrl,
+  imageContent,
+  imageContentBase64,
+  imageResponseBase64,
+  imageResponseUrl,
+  imageUrl,
   stringifiedEmbeddingsVector,
 } from '../../../helpers/fixtures'
-import heredoc from 'tsheredoc'
+import stripAnsi from '../../../helpers/strip-ansi'
+import {runCommand} from '../../../run-command'
 
 describe('ai:models:call', function () {
   const {env} = process
@@ -33,6 +41,8 @@
     defaultInferenceApi = nock('https://us.inference.heroku.com')
       .get('/available-models')
       .reply(200, availableModels)
+
+    sandbox.replaceGetter(client.APIClient.prototype, 'auth', () => '1234')
   })
 
   afterEach(function () {
@@ -42,6 +52,7 @@
     inferenceApi.done()
     nock.cleanAll()
     sandbox.restore()
+    sinon.restore()
   })
 
   context('when targeting a LLM (Text-to-Text) model resource', function () {
@@ -256,7 +267,7 @@
       expect(writeFileSyncMock.calledWith(
         'model-output.txt',
-        "Hello! I'm an AI assistant created by a company called Anthropic. It's nice to meet you.",
+        'Hello! I\'m an AI assistant created by a company called Anthropic. It\'s nice to meet you.',
       )).to.be.true
       expect(stdout.output).to.eq('')
       expect(stripAnsi(stderr.output)).to.eq('')
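
A note on the typing pattern this diff introduces: parseOptions now returns unknown, and each request path narrows that value back to a concrete request type through a user-defined type guard (obj is ChatCompletionRequest, and so on) that checks the parsed object's keys against an allow-list. Below is a minimal standalone sketch of the same pattern; the Options interface, isOptions, and the model string are illustrative names, not part of the plugin.

// Illustrative sketch of the key-allow-list type guard pattern (names are hypothetical).
interface Options {
  model: string;
  temperature?: number;
}

// Allow-list of key names an Options object may carry.
const optionKeys = new Set(['model', 'temperature'])

// Narrows `unknown` (e.g. the result of JSON.parse) to Options when every
// present key is on the allow-list. Like the guards in the diff, this checks
// key names only: required fields and value types are not verified, so an
// empty object also passes.
function isOptions(obj: unknown): obj is Options {
  return Object.keys(obj ?? {}).every(key => optionKeys.has(key))
}

const parsed: unknown = JSON.parse('{"model": "example-model", "temperature": 0.7}')
if (isOptions(parsed)) {
  console.log(parsed.model) // type-safe access after narrowing
}

The trade-off is that the guard stays cheap and mirrors the documented option names, at the cost of accepting structurally incomplete objects; the inference API itself remains the final validator.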