diff --git a/CHANGELOG.md b/CHANGELOG.md index 30576e07..e7b91e50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - Lowercase model name before invoking for hypermode hosted models [#221](https://github.com/gohypermode/runtime/pull/221) - Improve HTTP error messages [#222](https://github.com/gohypermode/runtime/pull/222) +- Add host function for direct logging [#224](https://github.com/gohypermode/runtime/pull/224) ## 2024-06-03 - Version 0.8.2 diff --git a/graphql/datasource/source.go b/graphql/datasource/source.go index aa0db4ba..027dbce9 100644 --- a/graphql/datasource/source.go +++ b/graphql/datasource/source.go @@ -68,7 +68,11 @@ func (s Source) callFunction(ctx context.Context, callInfo callInfo) (any, []res ctx = context.WithValue(ctx, utils.ExecutionIdContextKey, executionId) ctx = context.WithValue(ctx, utils.PluginContextKey, info.Plugin) - // Create output buffers for the function to write logs to + // Also prepare a slice to capture log messages sent through the "log" host function. + messages := []utils.LogMessage{} + ctx = context.WithValue(ctx, utils.FunctionMessagesContextKey, &messages) + + // Create output buffers for the function to write stdout/stderr to buffers := utils.OutputBuffers{} // Get a module instance for this request. @@ -92,8 +96,9 @@ func (s Source) callFunction(ctx context.Context, callInfo callInfo) (any, []res Buffers: buffers, } - // Transform error lines in the output buffers to GraphQL errors - gqlErrors := transformErrors(buffers, callInfo) + // Transform messages (and error lines in the output buffers) to GraphQL errors + messages = append(messages, utils.TransformConsoleOutput(buffers)...) + gqlErrors := transformErrors(messages, callInfo) return result, gqlErrors, err } @@ -238,8 +243,7 @@ func transformValue(data []byte, tf *fieldInfo) (result []byte, err error) { return buf.Bytes(), nil } -func transformErrors(buffers utils.OutputBuffers, ci callInfo) []resolve.GraphQLError { - messages := utils.TransformConsoleOutput(buffers) +func transformErrors(messages []utils.LogMessage, ci callInfo) []resolve.GraphQLError { errors := make([]resolve.GraphQLError, 0, len(messages)) for _, msg := range messages { // Only include errors. Other messages will be captured later and diff --git a/hostfunctions/hostfns.go b/hostfunctions/hostfns.go index 97e9c68b..87c385a9 100644 --- a/hostfunctions/hostfns.go +++ b/hostfunctions/hostfns.go @@ -22,6 +22,7 @@ func Instantiate(ctx context.Context, runtime *wazero.Runtime) error { b := (*runtime).NewHostModuleBuilder(hostModuleName) // Each host function should get a line here: + b.NewFunctionBuilder().WithFunc(hostLog).Export("log") b.NewFunctionBuilder().WithFunc(hostExecuteGQL).Export("executeGQL") b.NewFunctionBuilder().WithFunc(hostInvokeClassifier).Export("invokeClassifier") b.NewFunctionBuilder().WithFunc(hostComputeEmbedding).Export("computeEmbedding") diff --git a/hostfunctions/log.go b/hostfunctions/log.go new file mode 100644 index 00000000..b3b3e5c4 --- /dev/null +++ b/hostfunctions/log.go @@ -0,0 +1,36 @@ +/* + * Copyright 2024 Hypermode, Inc. + */ + +package hostfunctions + +import ( + "context" + "hmruntime/logger" + "hmruntime/utils" + + wasm "github.com/tetratelabs/wazero/api" +) + +func hostLog(ctx context.Context, mod wasm.Module, pLevel uint32, pMessage uint32) { + + var level, message string + err := readParams2(ctx, mod, pLevel, pMessage, &level, &message) + if err != nil { + logger.Err(ctx, err).Msg("Error reading input parameters.") + } + + // write to the logger + logger.Get(ctx). + WithLevel(logger.ParseLevel(level)). + Str("text", message). + Bool("user_visible", true). + Msg("Message logged from function.") + + // also store messages in the context, so we can return them to the caller + messages := ctx.Value(utils.FunctionMessagesContextKey).(*[]utils.LogMessage) + *messages = append(*messages, utils.LogMessage{ + Level: level, + Message: message, + }) +} diff --git a/logger/logwriter.go b/logger/logwriter.go index 60d14f09..7abb9ad7 100644 --- a/logger/logwriter.go +++ b/logger/logwriter.go @@ -41,7 +41,7 @@ func (w logWriter) Write(p []byte) (n int, err error) { func (w logWriter) logMessage(line string) { l, message := utils.SplitConsoleOutputLine(line) - level := parseLevel(l) + level := ParseLevel(l) if level == zerolog.NoLevel { level = w.level } @@ -52,7 +52,7 @@ func (w logWriter) logMessage(line string) { Msg("Message logged from function.") } -func parseLevel(level string) zerolog.Level { +func ParseLevel(level string) zerolog.Level { switch level { case "debug": return zerolog.DebugLevel diff --git a/utils/context.go b/utils/context.go index a1caef4c..0e84e05b 100644 --- a/utils/context.go +++ b/utils/context.go @@ -9,3 +9,4 @@ type contextKey string const ExecutionIdContextKey contextKey = "execution_id" const PluginContextKey contextKey = "plugin" const FunctionOutputContextKey contextKey = "function_output" +const FunctionMessagesContextKey contextKey = "function_messages"