Skip to content

Commit

Permalink
Merge branch 'feat/ark/moderation_hit_type' into 'integration_2025-01…
Browse files Browse the repository at this point in the history
…-09_669760003330'

feat: [development task] ark-runtime-manual-Golang (965061)

See merge request iaasng/volcengine-go-sdk!447
  • Loading branch information
BitsAdmin committed Jan 9, 2025
2 parents e2ab935 + c481548 commit 1e488b2
Show file tree
Hide file tree
Showing 6 changed files with 312 additions and 5 deletions.
4 changes: 4 additions & 0 deletions service/arkruntime/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,3 +493,7 @@ func (c *Client) getRetryAfter(v model.Response) int64 {
}
return retryAfterInterval
}

func (c *Client) isAPIKeyAuthentication() bool {
return c.config.apiKey != ""
}
96 changes: 96 additions & 0 deletions service/arkruntime/content_generation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package arkruntime

import (
"context"
"fmt"
"net/http"
"net/url"
"strconv"

"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)

const contentGenerationTaskPath = "/contents/generations/tasks"

func (c *Client) CreateContentGenerationTask(
ctx context.Context,
request model.CreateContentGenerationTaskRequest,
setters ...requestOption,
) (response model.CreateContentGenerationTaskResponse, err error) {
if !c.isAPIKeyAuthentication() {
return response, model.ErrAKSKNotSupported
}

requestOptions := append(setters, withBody(request))
err = c.Do(ctx, http.MethodPost, c.fullURL(contentGenerationTaskPath), resourceTypeEndpoint, request.Model, &response, requestOptions...)
return
}

func (c *Client) GetContentGenerationTask(
ctx context.Context,
request model.GetContentGenerationTaskRequest,
setters ...requestOption,
) (response model.GetContentGenerationTaskResponse, err error) {
if !c.isAPIKeyAuthentication() {
return response, model.ErrAKSKNotSupported
}

url := fmt.Sprintf("%s/%s", c.fullURL(contentGenerationTaskPath), request.ID)

err = c.Do(ctx, http.MethodGet, url, resourceTypeEndpoint, "", &response, setters...)
return
}

func (c *Client) DeleteContentGenerationTask(
ctx context.Context,
request model.DeleteContentGenerationTaskRequest,
setters ...requestOption,
) (err error) {
if !c.isAPIKeyAuthentication() {
return model.ErrAKSKNotSupported
}

url := fmt.Sprintf("%s/%s", c.fullURL(contentGenerationTaskPath), request.ID)

err = c.Do(ctx, http.MethodDelete, url, resourceTypeEndpoint, "", nil, setters...)
return err
}

func (c *Client) ListContentGenerationTasks(
ctx context.Context,
request model.ListContentGenerationTasksRequest,
setters ...requestOption,
) (response model.ListContentGenerationTasksResponse, err error) {
if !c.isAPIKeyAuthentication() {
return response, model.ErrAKSKNotSupported
}

values := url.Values{}
if pageNum := request.PageNum; pageNum != nil && *pageNum > 0 {
values.Add("page_num", strconv.Itoa(*pageNum))
}
if pageSize := request.PageSize; pageSize != nil && *pageSize > 0 {
values.Add("page_size", strconv.Itoa(*pageSize))
}

if filter := request.Filter; filter != nil {
if status := filter.Status; status != nil && *status != "" {
values.Add("filter.status", *status)
}
if model := filter.Model; model != nil && *model != "" {
values.Add("filter.model", *model)
}
for _, taskID := range filter.TaskIDs {
values.Add("filter.task_ids", *taskID)
}
}

endpoint := fmt.Sprintf("%s?%s", c.fullURL(contentGenerationTaskPath), values.Encode())

err = c.Do(ctx, http.MethodGet, endpoint, resourceTypeEndpoint, "", &response, setters...)
if err != nil {
return response, err
}

return response, nil
}
99 changes: 99 additions & 0 deletions service/arkruntime/example/content_generation/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package main

import (
"context"
"fmt"
"github.com/volcengine/volcengine-go-sdk/volcengine"
"os"

"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)

/**
* Authentication
* 1.If you authorize your endpoint using an API key, you can set your api key to environment variable "ARK_API_KEY"
* client := arkruntime.NewClientWithApiKey(os.Getenv("ARK_API_KEY"))
* Note: If you use an API key, this API key will not be refreshed.
* To prevent the API from expiring and failing after some time, choose an API key with no expiration date.
*/
func main() {
client := arkruntime.NewClientWithApiKey(os.Getenv("ARK_API_KEY"))
ctx := context.Background()
modelEp := "YOUR_ENDPOINT_ID"

fmt.Println("----- create content generation task -----")
createReq := model.CreateContentGenerationTaskRequest{
Model: modelEp, // Replace with your endpoint ID
Content: []*model.CreateContentGenerationContentItem{
{
Type: model.ContentGenerationContentItemTypeText,
Text: volcengine.String("龙与地下城女骑士背景是起伏的平原,目光从镜头转向平原 --ratio 1:1"),
},
{
Type: model.ContentGenerationContentItemTypeImage,
ImageURL: &model.ImageURL{
URL: "${YOUR URL HERE}", // Replace with URL
},
},
},
}

createResponse, err := client.CreateContentGenerationTask(ctx, createReq)
if err != nil {
fmt.Printf("create content generation error: %v\n", err)
return
}
fmt.Printf("Task Created with ID: %s\n", createResponse.ID)

fmt.Println("----- get content generation task -----")
taskID := createResponse.ID

getRequest := model.GetContentGenerationTaskRequest{ID: taskID}

getResponse, err := client.GetContentGenerationTask(ctx, getRequest)
if err != nil {
fmt.Printf("get content generation task error: %v\n", err)
return
}

fmt.Printf("Task ID: %s\n", getResponse.ID)
fmt.Printf("Model: %s\n", getResponse.Model)
fmt.Printf("Status: %s\n", getResponse.Status)
fmt.Printf("Failure Reason: %v\n", getResponse.FailureReason)
fmt.Printf("Video URL: %s\n", getResponse.Content.VideoURL)
fmt.Printf("Completion Tokens: %d\n", getResponse.Usage.CompletionTokens)
fmt.Printf("Created At: %d\n", getResponse.CreatedAt)
fmt.Printf("Updated At: %d\n", getResponse.UpdatedAt)

fmt.Println("----- list content generation task -----")

listRequest := model.ListContentGenerationTasksRequest{
PageNum: volcengine.Int(1),
PageSize: volcengine.Int(10),
Filter: &model.ListContentGenerationTasksFilter{
Status: volcengine.String(model.StatusSucceeded),
//TaskIDs: volcengine.StringSlice([]string{"cgt-example-1", "cgt-example-2"}),
//Model: volcengine.String(modelEp),
},
}

listResponse, err := client.ListContentGenerationTasks(ctx, listRequest)
if err != nil {
fmt.Printf("failed to list content generation tasks: %v\n", err)
}

fmt.Printf("ListContentGenerationTasks returned %v results\n", listResponse.Total)

fmt.Println("----- delete content generation task -----")

deleteRequest := model.DeleteContentGenerationTaskRequest{ID: taskID}

err = client.DeleteContentGenerationTask(ctx, deleteRequest)
if err != nil {
fmt.Printf("delete content generation task error: %v\n", err)
} else {
fmt.Println("successfully deleted task id: ", taskID)
}

}
23 changes: 18 additions & 5 deletions service/arkruntime/model/chat_completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,13 @@ func (r FinishReason) MarshalJSON() ([]byte, error) {
return []byte(`"` + string(r) + `"`), nil // best effort to not break future API changes
}

type ChatCompletionResponseChoicesElemModerationHitType string

const (
ChatCompletionResponseChoicesElemModerationHitTypeViolence ChatCompletionResponseChoicesElemModerationHitType = "violence"
ChatCompletionResponseChoicesElemModerationHitTypeSevereViolation ChatCompletionResponseChoicesElemModerationHitType = "severe_violation"
)

type ChatCompletionChoice struct {
Index int `json:"index"`
Message ChatCompletionMessage `json:"message"`
Expand All @@ -309,7 +316,12 @@ type ChatCompletionChoice struct {
// content_filter: Omitted content due to a flag from our content filters
// null: API response still in progress or incomplete
FinishReason FinishReason `json:"finish_reason"`
LogProbs *LogProbs `json:"logprobs,omitempty"`
// ModerationHitType
// The type of content moderation strategy hit.
// Only after selecting a moderation strategy for the endpoint that supports returning moderation hit types,
// API will return the corresponding values.
ModerationHitType *ChatCompletionResponseChoicesElemModerationHitType `json:"moderation_hit_type,omitempty" yaml:"moderation_hit_type,omitempty" mapstructure:"moderation_hit_type,omitempty"`
LogProbs *LogProbs `json:"logprobs,omitempty"`
}

// ChatCompletionResponse represents a response structure for chat completion API.
Expand All @@ -332,10 +344,11 @@ type ChatCompletionStreamChoiceDelta struct {
}

type ChatCompletionStreamChoice struct {
Index int `json:"index"`
Delta ChatCompletionStreamChoiceDelta `json:"delta"`
LogProbs *LogProbs `json:"logprobs,omitempty"`
FinishReason FinishReason `json:"finish_reason"`
Index int `json:"index"`
Delta ChatCompletionStreamChoiceDelta `json:"delta"`
LogProbs *LogProbs `json:"logprobs,omitempty"`
FinishReason FinishReason `json:"finish_reason"`
ModerationHitType *ChatCompletionResponseChoicesElemModerationHitType `json:"moderation_hit_type,omitempty" yaml:"moderation_hit_type,omitempty" mapstructure:"moderation_hit_type,omitempty"`
}

type ChatCompletionStreamResponse struct {
Expand Down
94 changes: 94 additions & 0 deletions service/arkruntime/model/content_generation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package model

type ContentGenerationContentItemType string

const (
ContentGenerationContentItemTypeText ContentGenerationContentItemType = "text"
ContentGenerationContentItemTypeImage ContentGenerationContentItemType = "image_url"
)

const (
StatusSucceeded = "succeeded"
StatusCancelled = "cancelled"
StatusFailed = "failed"
StatusRunning = "running"
StatusQueued = "queued"
)

type CreateContentGenerationTaskRequest struct {
Model string `json:"model"`
Content []*CreateContentGenerationContentItem `json:"content"`
}

type CreateContentGenerationTaskResponse struct {
ID string `json:"id"`

HttpHeader
}

type GetContentGenerationTaskRequest struct {
ID string `json:"id"`
}

type GetContentGenerationTaskResponse struct {
ID string `json:"id"`
Model string `json:"model"`
Status string `json:"status"`
FailureReason *string `json:"failure_reason,omitempty"`
Content Content `json:"content"`
Usage Usage `json:"usage"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`

HttpHeader
}

type ListContentGenerationTasksRequest struct {
PageNum *int `json:"page_num,omitempty"`
PageSize *int `json:"page_size,omitempty"`
Filter *ListContentGenerationTasksFilter `json:"filter,omitempty"`
}

type DeleteContentGenerationTaskRequest struct {
ID string `json:"id"`
}

type ListContentGenerationTasksFilter struct {
Status *string `json:"status,omitempty"`
TaskIDs []*string `json:"task_ids,omitempty"`
Model *string `json:"model,omitempty"`
}

type CreateContentGenerationContentItem struct {
Type ContentGenerationContentItemType `json:"type"`
Text *string `json:"text,omitempty"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}

type ImageURL struct {
URL string `json:"url"`
}
type Content struct {
VideoURL string `json:"video_url"`
}

type ContentGenerationUsage struct {
CompletionTokens int `json:"completion_tokens"`
}

type ListContentGenerationTasksResponse struct {
Total int64 `json:"total"`
Items []ListContentGenerationTaskItem `json:"items"`
HttpHeader
}

type ListContentGenerationTaskItem struct {
ID string `json:"id"`
Model string `json:"model"`
Status string `json:"status"`
FailureReason *string `json:"failure_reason,omitempty"`
Content Content `json:"content"`
Usage Usage `json:"usage"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
}
1 change: 1 addition & 0 deletions service/arkruntime/model/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,5 @@ var (
ErrContentFieldsMisused = errors.New("can't use both Content and MultiContent properties simultaneously")
ErrBodyWithoutEndpoint = errors.New("can't fetch endpoint sts token without endpoint")
ErrBodyWithoutBot = errors.New("can't fetch bot sts token without bot id")
ErrAKSKNotSupported = errors.New("ak&sk authentication is currently not supported for this method, please use api key instead")
)

0 comments on commit 1e488b2

Please sign in to comment.