Skip to content

Commit

Permalink
feat(W-17153162): ai:models:call update help information
Browse files Browse the repository at this point in the history
  • Loading branch information
justinwilaby committed Nov 21, 2024
1 parent 377435c commit f1c1031
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 22 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,11 @@ USAGE
[--opts <value>] [-o <value>] [-r <value>]
ARGUMENTS
MODEL_RESOURCE The resource ID or alias of the model to call.
MODEL_RESOURCE The resource ID or alias of the model to call. The --app flag must be included if an alias is used.
FLAGS
-a, --app=<value> app to run command against
-a, --app=<value> The name or ID of the app. If an alias for the MODEL_RESOURCE argument is used, this flag is
required.
-j, --json Output response as JSON
-o, --output=<value> The file path where the command writes the model response.
-p, --prompt=<value> (required) The input prompt for the model.
Expand Down
96 changes: 86 additions & 10 deletions src/commands/ai/models/call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@ import color from '@heroku-cli/color'
import {flags} from '@heroku-cli/command'
import {Args, ux} from '@oclif/core'
import fs from 'node:fs'
import {ChatCompletionResponse, EmbeddingResponse, ImageResponse, ModelList} from '../../../lib/ai/types'
import {
type ChatCompletionRequest,
ChatCompletionResponse, type CreateEmbeddingRequest,
EmbeddingResponse,
type ImageRequest,
ImageResponse,
ModelList,
} from '../../../lib/ai/types'
import Command from '../../../lib/base'
import {openUrl} from '../../../lib/open-url'

export default class Call extends Command {
static args = {
model_resource: Args.string({
description: 'The resource ID or alias of the model to call.',
description: 'The resource ID or alias of the model to call. The --app flag must be included if an alias is used.',
required: true,
}),
}
Expand All @@ -21,7 +28,10 @@ export default class Call extends Command {
]

static flags = {
app: flags.app({required: false}),
app: flags.app({
required: false,
description: 'The name or ID of the app. If an alias for the MODEL_RESOURCE argument is used, this flag is required.',
}),
// interactive: flags.boolean({
// char: 'i',
// description: 'Use interactive mode for conversation beyond the initial prompt (not available on all models)',
Expand Down Expand Up @@ -69,19 +79,19 @@ export default class Call extends Command {
// Note: modelType will always be lower case. MarcusBlankenship 11/13/24.
switch (modelType) {
case 'text-to-embedding': {
const embedding = await this.createEmbedding(prompt, options)
const embedding = await this.createEmbedding(prompt, options as CreateEmbeddingRequest)
await this.displayEmbedding(embedding, output, json)
break
}

case 'text-to-image': {
const image = await this.generateImage(prompt, options)
const image = await this.generateImage(prompt, options as ImageRequest)
await this.displayImageResult(image, output, browser, json)
break
}

case 'text-to-text': {
const completion = await this.createChatCompletion(prompt, options)
const completion = await this.createChatCompletion(prompt, options as ChatCompletionRequest)
await this.displayChatCompletion(completion, output, json)
break
}
Expand All @@ -98,7 +108,7 @@ export default class Call extends Command {
* @param opts JSON string containing options.
* @returns The parsed options as an object.
*/
private parseOptions(optfile?: string, opts?: string) {
private parseOptions(optfile?: string, opts?: string): unknown {
const options = {}

if (optfile) {
Expand Down Expand Up @@ -138,7 +148,11 @@ export default class Call extends Command {
return options
}

private async createChatCompletion(prompt: string, options = {}) {
private async createChatCompletion(prompt: string, options: ChatCompletionRequest = {} as ChatCompletionRequest) {
if (!this.isChatCompletionRequest(options)) {
return ux.error('Unexpected chat completion options', {exit: 1})
}

const {body: chatCompletionResponse} = await this.herokuAI.post<ChatCompletionResponse>('/v1/chat/completions', {
body: {
...options,
Expand All @@ -164,7 +178,11 @@ export default class Call extends Command {
}
}

private async generateImage(prompt: string, options = {}) {
private async generateImage(prompt: string, options: ImageRequest = {} as ImageRequest) {
if (!this.isImageRequest(options)) {
return ux.error('Unexpected image options', {exit: 1})
}

const {body: imageResponse} = await this.herokuAI.post<ImageResponse>('/v1/images/generations', {
body: {
...options,
Expand Down Expand Up @@ -201,7 +219,11 @@ export default class Call extends Command {
ux.error('Unexpected response format', {exit: 1})
}

private async createEmbedding(input: string, options = {}) {
private async createEmbedding(input: string, options: CreateEmbeddingRequest = {} as CreateEmbeddingRequest) {
if (!this.isEmbeddingsRequest(options)) {
return ux.error('Unexpected embedding options', {exit: 1})
}

const {body: EmbeddingResponse} = await this.herokuAI.post<EmbeddingResponse>('/v1/embeddings', {
body: {
...options,
Expand All @@ -223,4 +245,58 @@ export default class Call extends Command {
json ? ux.styledJSON(embedding) : ux.log(content)
}
}

/**
 * Type guard for embedding options parsed from `--opts`/`--optfile`.
 *
 * Accepts the object when every key it carries is a recognized
 * `CreateEmbeddingRequest` option; an empty (or nullish) object passes.
 */
private isEmbeddingsRequest(obj: unknown): obj is CreateEmbeddingRequest {
  const allowedKeys: ReadonlyArray<keyof CreateEmbeddingRequest> = [
    'model',
    'user',
    'dimensions',
    'encoding_format',
    'input',
  ]
  for (const key of Object.keys(obj ?? {})) {
    if (!allowedKeys.includes(key as keyof CreateEmbeddingRequest)) {
      return false
    }
  }

  return true
}

/**
 * Type guard for image-generation options parsed from `--opts`/`--optfile`.
 *
 * Accepts the object when it contains no key outside the `ImageRequest`
 * shape; an empty (or nullish) object passes.
 */
private isImageRequest(obj: unknown): obj is ImageRequest {
  const allowed = new Set<string>([
    'prompt',
    'model',
    'n',
    'quality',
    'response_format',
    'size',
    'style',
    'user',
    'sampler',
    'seed',
    'steps',
    'cfg_scale',
    'clip_guidance_preset',
    'style_preset',
  ])

  // Reject as soon as any unrecognized key is present.
  const hasUnknownKey = Object.keys(obj ?? {}).some(key => !allowed.has(key))
  return !hasUnknownKey
}

/**
 * Type guard for chat-completion options parsed from `--opts`/`--optfile`.
 *
 * Accepts the object when every key it carries belongs to the
 * `ChatCompletionRequest` shape; an empty (or nullish) object passes.
 */
private isChatCompletionRequest(obj: unknown): obj is ChatCompletionRequest {
  const allowedKeys: ReadonlyArray<keyof ChatCompletionRequest> = [
    'messages',
    'model',
    'temperature',
    'top_p',
    'n',
    'stream',
    'stop',
    'max_tokens',
    'presence_penalty',
    'frequency_penalty',
    'tools',
    'tool_choice',
    'user',
  ]

  return Object.keys(obj ?? {})
    .every(key => allowedKeys.includes(key as keyof ChatCompletionRequest))
}
}
115 changes: 115 additions & 0 deletions src/lib/ai/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,49 @@ export type ChatCompletionChoice = {
} | null
}

/**
 * Request payload for the `/v1/chat/completions` endpoint
 * (OpenAI-style chat completion request shape).
 */
export interface ChatCompletionRequest {
  /** Conversation messages; the CLI appends the user prompt as the final message. */
  messages: ChatMessage[];
  /** Model ID to run the completion against. */
  model: string;
  /** Sampling temperature — presumably higher means more random; confirm valid range with the API. */
  temperature?: number;
  /** Nucleus-sampling probability mass. */
  top_p?: number;
  /** Number of completion choices to generate. */
  n?: number;
  /** Whether to stream the response. */
  stream?: boolean;
  /** Stop sequence(s) that end generation. */
  stop?: string | string[];
  /** Maximum number of tokens to generate. */
  max_tokens?: number;
  presence_penalty?: number;
  frequency_penalty?: number;
  /** Tool (function) definitions the model may call. */
  tools?: Tool[];
  /** Tool selection strategy: disabled, automatic, or a specific function. */
  tool_choice?: 'none' | 'auto' | ToolChoice;
  /** Opaque end-user identifier. */
  user?: string;
}

/**
 * A single message in a chat-completion conversation.
 */
export interface ChatMessage {
  /** Who authored the message. */
  role: 'system' | 'user' | 'assistant' | 'tool';
  /** Message text. */
  content: string;
  /** Optional participant name. */
  name?: string;
  /** Tool invocations attached to an assistant message. */
  tool_calls?: ToolCall[];
}

/**
 * A callable tool (function) definition supplied with a chat-completion request.
 */
export interface Tool {
  /** Only function tools are supported. */
  type: 'function';
  function: {
    name: string;
    description?: string;
    /** Function argument schema — presumably a JSON-Schema object; confirm against the API docs. */
    parameters: {
      type: 'object';
      properties: Record<string, any>;
      required?: string[];
    };
  };
}

/**
 * Explicit tool selection for `tool_choice`: forces the model to call
 * the named function.
 */
export interface ToolChoice {
  type: 'function';
  function: {
    /** Name of the function the model must call. */
    name: string;
  };
}

/**
* Chat completion response schema.
*/
Expand Down Expand Up @@ -158,6 +201,70 @@ export type Image = {
readonly url?: string | null
}

/**
 * Request payload for the `/v1/images/generations` endpoint.
 *
 * NOTE(review): every field is declared required here, yet call sites build
 * this object from user-supplied `--opts`/`--optfile` JSON and force-cast a
 * partial object (`options as ImageRequest`). Most fields are presumably
 * optional server-side — confirm against the API and mark them `?` if so.
 */
export type ImageRequest = {
  /** Text prompt describing the image to generate. */
  prompt: string,
  /** Model ID to generate with. */
  model: string,
  /** Number of images to generate. */
  n: number,
  quality: string,
  /** How the image payload is returned (URL vs inline base64). */
  response_format: ResponseFormat,
  /** Output dimensions — presumably a "WxH" string; confirm accepted values. */
  size: string,
  style: string,
  /** Opaque end-user identifier. */
  user: string,
  /** Diffusion sampler to use. */
  sampler: SamplerType,
  /** RNG seed for reproducible output. */
  seed: number,
  /** Number of diffusion steps. */
  steps: number,
  /** Classifier-free guidance scale. */
  cfg_scale: number,
  clip_guidance_preset: ClipGuidancePreset,
  style_preset: StylePreset
}

/**
 * How the generated image is returned in the response.
 *
 * NOTE(review): the OpenAI images API spells the inline variant `b64_json`;
 * this API apparently uses `base64` — confirm against the service docs.
 */
export enum ResponseFormat {
  Url = 'url',
  Base64 = 'base64',
}

/**
 * Diffusion sampler identifiers accepted by the image-generation API
 * (Stability-style `sampler` option values).
 */
export enum SamplerType {
  DDIM = 'DDIM',
  DDPM = 'DDPM',
  KDPMPP2M = 'K_DPMPP_2M',
  KDPMPP2SANCESTRAL = 'K_DPMPP_2S_ANCESTRAL',
  KDPM2 = 'K_DPM_2',
  KDPM2ANCESTRAL = 'K_DPM_2_ANCESTRAL',
  KEULER = 'K_EULER',
  KEULERANCESTRAL = 'K_EULER_ANCESTRAL',
  KHEUN = 'K_HEUN',
  KLMS = 'K_LMS',
}

/**
 * Clip guidance presets for image generation (`clip_guidance_preset` option).
 *
 * Values follow the Stability-style SCREAMING_SNAKE_CASE identifiers. The
 * earlier `SimpleSlow = 'SIMPLE SLOW'` entry (note the space) was a garbled
 * merge of the two distinct presets `SIMPLE` and `SLOW`; both are now exposed
 * separately, and the broken member is kept only for backward compatibility.
 */
export enum ClipGuidancePreset {
  None = 'NONE',
  FastBlue = 'FAST_BLUE',
  FastGreen = 'FAST_GREEN',
  Simple = 'SIMPLE',
  Slow = 'SLOW',
  /** @deprecated `'SIMPLE SLOW'` is not a valid API value; use `Simple` or `Slow`. */
  SimpleSlow = 'SIMPLE SLOW',
  Slower = 'SLOWER',
  Slowest = 'SLOWEST',
}

/**
 * Style presets for image generation (`style_preset` option).
 *
 * All wire values use the API's kebab-case identifiers. The `'3DModel'`
 * member previously mapped to the literal `'3DModel'`, breaking the
 * kebab-case convention every other member follows; it now maps to the
 * API identifier `'3d-model'`.
 */
export enum StylePreset {
  '3DModel' = '3d-model',
  AnalogFilm = 'analog-film',
  Anime = 'anime',
  Cinematic = 'cinematic',
  ComicBook = 'comic-book',
  DigitalArt = 'digital-art',
  Enhance = 'enhance',
  FantasyArt = 'fantasy-art',
  Isometric = 'isometric',
  LineArt = 'line-art',
  LowPoly = 'low-poly',
  ModelingCompound = 'modeling-compound',
  NeonPunk = 'neon-punk',
  Origami = 'origami',
  Photographic = 'photographic',
  PixelArt = 'pixel-art',
  TileTexture = 'tile-texture',
}

/**
* Image response schema.
*/
Expand All @@ -180,6 +287,14 @@ export type Embedding = {
readonly object: 'embeddings'
}

/**
 * Request payload for the `/v1/embeddings` endpoint
 * (OpenAI-style embedding request shape).
 */
export interface CreateEmbeddingRequest {
  /** Model ID to embed with. */
  model: string;
  /** Text (or pre-tokenized input) to embed. */
  input: string | string[] | number[];
  /** Opaque end-user identifier. */
  user?: string;
  /** Encoding of the returned vector. */
  encoding_format?: 'float' | 'base64';
  /** Requested output dimensionality, when the model supports truncation. */
  dimensions?: number;
}

/**
* Embedding response schema.
*/
Expand Down
31 changes: 21 additions & 10 deletions test/commands/ai/models/call.test.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
import fs from 'node:fs'
import {stdout, stderr} from 'stdout-stderr'
import * as client from '@heroku-cli/command'
import {expect} from 'chai'
import nock from 'nock'
import fs from 'node:fs'
import sinon from 'sinon'
import {stderr, stdout} from 'stdout-stderr'
import heredoc from 'tsheredoc'
import Cmd from '../../../../src/commands/ai/models/call'
import * as openUrl from '../../../../src/lib/open-url'
import stripAnsi from '../../../helpers/strip-ansi'
import {runCommand} from '../../../run-command'
import {
addon3, addon3Attachment1,
addon5, addon5Attachment1,
addon6, addon6Attachment1,
addon3,
addon3Attachment1,
addon5,
addon5Attachment1,
addon6,
addon6Attachment1,
availableModels,
chatCompletionResponse,
embeddingsResponse,
imageContentBase64, imageContent, imageResponseBase64, imageResponseUrl, imageUrl,
imageContent,
imageContentBase64,
imageResponseBase64,
imageResponseUrl,
imageUrl,
stringifiedEmbeddingsVector,
} from '../../../helpers/fixtures'
import heredoc from 'tsheredoc'
import stripAnsi from '../../../helpers/strip-ansi'
import {runCommand} from '../../../run-command'

describe('ai:models:call', function () {
const {env} = process
Expand All @@ -33,6 +41,8 @@ describe('ai:models:call', function () {
defaultInferenceApi = nock('https://us.inference.heroku.com')
.get('/available-models')
.reply(200, availableModels)

sandbox.replaceGetter(client.APIClient.prototype, 'auth', () => '1234')
})

afterEach(function () {
Expand All @@ -42,6 +52,7 @@ describe('ai:models:call', function () {
inferenceApi.done()
nock.cleanAll()
sandbox.restore()
sinon.restore()
})

context('when targeting a LLM (Text-to-Text) model resource', function () {
Expand Down Expand Up @@ -256,7 +267,7 @@ describe('ai:models:call', function () {

expect(writeFileSyncMock.calledWith(
'model-output.txt',
"Hello! I'm an AI assistant created by a company called Anthropic. It's nice to meet you.",
'Hello! I\'m an AI assistant created by a company called Anthropic. It\'s nice to meet you.',
)).to.be.true
expect(stdout.output).to.eq('')
expect(stripAnsi(stderr.output)).to.eq('')
Expand Down

0 comments on commit f1c1031

Please sign in to comment.