Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AI Completions #5910

Merged
merged 18 commits into from
Jun 18, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@
- [Named arguments syntax is now recognized in IDE][5774]. Connections to
function arguments will now use named argument syntax instead of inserting
wildcards on all preceding arguments.
- [AI-powered code completions][5910]. It is now possible to get AI-powered
completions when using the node searcher with Tables.

#### EnsoGL (rendering engine)

Expand Down Expand Up @@ -520,6 +522,7 @@
[5802]: https://github.com/enso-org/enso/pull/5802
[5850]: https://github.com/enso-org/enso/pull/5850
[5863]: https://github.com/enso-org/enso/pull/5863
[5910]: https://github.com/enso-org/enso/pull/5910

#### Enso Compiler

Expand Down
4 changes: 4 additions & 0 deletions app/gui/controller/engine-protocol/src/language_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ trait API {
/// VCS snapshot if no `commit_id` is provided.
#[MethodInput=VcsRestoreInput, rpc_name="vcs/restore"]
fn restore_vcs(&self, root: Path, commit_id: Option<String>) -> response::RestoreVcs;

/// Request an AI-powered completion for the given prompt; generation stops at the given
/// stop sequence. The response carries the suggested code snippet.
#[MethodInput=AiCompletionInput, rpc_name="ai/completion"]
fn ai_completion(&self, prompt: String, stop_sequence: String) -> response::AiCompletion;
}}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ pub struct Completion {
pub current_version: SuggestionsDatabaseVersion,
}

/// Response of `ai/completion` method.
#[derive(Hash, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(missing_docs)]
pub struct AiCompletion {
    /// The code snippet suggested by the AI for the submitted prompt.
    pub code: String,
}

/// Response of `get_component_groups` method.
#[derive(Hash, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
Expand Down
5 changes: 5 additions & 0 deletions app/gui/src/controller/graph/executed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ impl Handle {
self.execution_ctx.attach_visualization(visualization).await
}

/// See [`model::ExecutionContext::get_ai_completion`].
pub async fn get_ai_completion(&self, code: &str, stop: &str) -> FallibleResult<String> {
    let execution_context = &self.execution_ctx;
    execution_context.get_ai_completion(code, stop).await
}

/// See [`model::ExecutionContext::modify_visualization`].
pub fn modify_visualization(
&self,
Expand Down
78 changes: 69 additions & 9 deletions app/gui/src/controller/searcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ pub mod action;
pub mod breadcrumbs;
pub mod component;

use crate::controller::graph::executed::Handle;
use crate::model::execution_context::QualifiedMethodPointer;
use crate::model::execution_context::Visualization;
pub use action::Action;



// =================
// === Constants ===
// =================
Expand Down Expand Up @@ -90,6 +92,16 @@ pub struct NotSupported {
#[fail(display = "An action cannot be executed when searcher is in \"edit node\" mode.")]
pub struct CannotExecuteWhenEditingNode;

// Raised when an AI query is submitted but the searcher was opened without a source (`this`)
// node — the AI prompt is built from that node's value, so it is required.
#[allow(missing_docs)]
#[derive(Copy, Clone, Debug, Fail)]
#[fail(display = "An action cannot be executed when searcher is run without `this` argument.")]
pub struct CannotRunWithoutThisArgument;

// Raised when the visualization attached to build the AI prompt finished without delivering
// any data update.
#[allow(missing_docs)]
#[derive(Copy, Clone, Debug, Fail)]
#[fail(display = "No visualization data received for an AI suggestion.")]
pub struct NoAIVisualizationDataReceived;

#[allow(missing_docs)]
#[derive(Copy, Clone, Debug, Fail)]
#[fail(display = "Cannot commit expression in current mode ({:?}).", mode)]
Expand All @@ -103,14 +115,15 @@ pub struct CannotCommitExpression {
// =====================

/// The notification emitted by Searcher Controller
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Notification {
    /// A new Suggestion list is available.
    NewActionList,
    /// Code should be inserted by means of using an AI autocompletion. The payload is the
    /// expression returned by the AI completion endpoint.
    AISuggestionUpdated(String),
}



// ===================
// === Suggestions ===
// ===================
Expand Down Expand Up @@ -154,7 +167,6 @@ impl Default for Actions {
}



// ===================
// === Input Parts ===
// ===================
Expand Down Expand Up @@ -288,7 +300,7 @@ impl ParsedInput {
chain.wrapped.args.push(argument);
}
Some(chain)
// If there isn't any expression part, the pattern is the whole input.
// If there isn't any expression part, the pattern is the whole input.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is an accidental change :)

} else if let Some(sast) = pattern_sast {
let chain = ast::prefix::Chain::from_ast_non_strict(&sast.wrapped);
Some(ast::Shifted::new(self.pattern_offset, chain))
Expand All @@ -314,7 +326,6 @@ impl Display for ParsedInput {
}



// ================
// === ThisNode ===
// ================
Expand Down Expand Up @@ -370,7 +381,6 @@ impl ThisNode {
}



// ===========================
// === Searcher Controller ===
// ===========================
Expand Down Expand Up @@ -720,12 +730,64 @@ impl Searcher {
self.notifier.notify(Notification::NewActionList);
}

/// Prefix marking the searcher input as an AI query.
const AI_QUERY_PREFIX: &'static str = "AI:";
/// Token the user appends to the query to accept (submit) it.
const AI_QUERY_ACCEPT_TOKEN: &'static str = "#";
/// Stop sequence passed to the completion endpoint; generation ends when it is produced.
const AI_STOP_SEQUENCE: &'static str = "`";
/// Placeholder in the prompt template that is substituted with the user's goal text. Must stay
/// in sync with `goal_placeholder` in `Standard.Visualization.AI`.
const AI_GOAL_PLACEHOLDER: &'static str = "__$$GOAL$$__";

/// Build an AI prompt for the given accepted query and request a completion for it.
///
/// The prompt template is obtained by attaching the `build_ai_prompt` visualization (from
/// `Standard.Visualization.AI`) to the `this` node; the first visualization update carries the
/// template. The user's goal then replaces [`Self::AI_GOAL_PLACEHOLDER`] in the template, the
/// completed prompt is sent to the language server's `ai/completion` endpoint, and the
/// returned code is published as a [`Notification::AISuggestionUpdated`] notification.
async fn accept_ai_query(
    query: String,
    this: ThisNode,
    graph: Handle,
    notifier: notification::Publisher<Notification>,
) -> FallibleResult {
    let vis_ptr = QualifiedMethodPointer::from_qualified_text(
        "Standard.Visualization.AI",
        "Standard.Visualization.AI",
        "build_ai_prompt",
    )?;
    let vis = Visualization::new(this.id, vis_ptr, vec![]);
    let mut stream = graph.attach_visualization(vis.clone()).await?;
    let first_update = stream.next().await;
    // Detach before any fallible processing of the received data, so that an error below does
    // not leave the visualization attached.
    graph.detach_visualization(vis.id).await?;
    let data = first_update.ok_or(NoAIVisualizationDataReceived)?;
    let prompt = std::str::from_utf8(&data)?;
    let prompt_with_goal = prompt.replace(Self::AI_GOAL_PLACEHOLDER, &query);
    let completion = graph.get_ai_completion(&prompt_with_goal, Self::AI_STOP_SEQUENCE).await?;
    notifier.publish(Notification::AISuggestionUpdated(completion)).await;
    Ok(())
}

fn handle_ai_query(&self, query: String) -> FallibleResult {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is PoC, but docs are needed here for future refactoring/development. Please describe shortly what this function does and why.

let query = query.trim_start_matches(Self::AI_QUERY_PREFIX);
if !query.ends_with(Self::AI_QUERY_ACCEPT_TOKEN) {
return Ok(());
}
let query = query.trim_end_matches(Self::AI_QUERY_ACCEPT_TOKEN).trim().to_string();
let this = self.this_arg.clone();
if this.is_none() {
return Err(CannotRunWithoutThisArgument.into());
}
let this = this.as_ref().as_ref().unwrap().clone();
let graph = self.graph.clone_ref();
let notifier = self.notifier.clone_ref();
executor::global::spawn(async move {
if let Err(e) = Self::accept_ai_query(query, this, graph, notifier).await {
error!("error when handling AI query: {e}");
}
});

Ok(())
}

/// Set the Searcher Input.
///
/// This function should be called each time user modifies Searcher input in view. It may result
/// in a new action list (the appropriate notification will be emitted).
#[profile(Debug)]
pub fn set_input(&self, new_input: String) -> FallibleResult {
if new_input.starts_with(Self::AI_QUERY_PREFIX) {
return self.handle_ai_query(new_input);
}
debug!("Manually setting input to {new_input}.");
let parsed_input = ParsedInput::new(new_input, self.ide.parser())?;
let old_expr = self.data.borrow().input.expression.repr();
Expand Down Expand Up @@ -1617,7 +1679,6 @@ impl Drop for EditGuard {
}



// === SimpleFunctionCall ===

/// A simple function call is an AST where function is a single identifier with optional
Expand Down Expand Up @@ -1669,7 +1730,6 @@ fn apply_this_argument(this_var: &str, ast: &Ast) -> Ast {
}



// =============
// === Tests ===
// =============
Expand Down
8 changes: 7 additions & 1 deletion app/gui/src/model/execution_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,6 @@ pub trait API: Debug {
FallibleResult<futures::channel::mpsc::UnboundedReceiver<VisualizationUpdateData>>,
>;


/// Detach the visualization from this execution context.
#[allow(clippy::needless_lifetimes)] // Note: Needless lifetimes
fn detach_visualization<'a>(
Expand Down Expand Up @@ -493,6 +492,13 @@ pub trait API: Debug {
futures::future::join_all(detach_actions).boxed_local()
}

/// Get an AI completion for the given `prompt`, with specified `stop` sequence.
///
/// The resolved string is the code snippet suggested for the prompt.
fn get_ai_completion<'a>(
    &'a self,
    prompt: &str,
    stop: &str,
) -> BoxFuture<'a, FallibleResult<String>>;

/// Interrupt the program execution.
#[allow(clippy::needless_lifetimes)] // Note: Needless lifetimes
fn interrupt<'a>(&'a self) -> BoxFuture<'a, FallibleResult>;
Expand Down
8 changes: 8 additions & 0 deletions app/gui/src/model/execution_context/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,14 @@ impl model::execution_context::API for ExecutionContext {
}
}

/// Mock AI completion: the plain (offline) execution context has no language server to query,
/// so it always resolves immediately to an empty snippet.
fn get_ai_completion<'a>(
    &'a self,
    _prompt: &str,
    _stop: &str,
) -> LocalBoxFuture<'a, FallibleResult<String>> {
    // `String::new` is the idiomatic, allocation-free empty string (vs `"".to_string()`).
    futures::future::ready(Ok(String::new())).boxed_local()
}

fn interrupt(&self) -> BoxFuture<FallibleResult> {
futures::future::ready(Ok(())).boxed_local()
}
Expand Down
13 changes: 13 additions & 0 deletions app/gui/src/model/execution_context/synchronized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,19 @@ impl model::execution_context::API for ExecutionContext {
self.model.dispatch_visualization_update(visualization_id, data)
}

/// Forward the AI-completion request to the language server's `ai/completion` endpoint and
/// extract the suggested code from its response.
fn get_ai_completion<'a>(
    &'a self,
    prompt: &str,
    stop: &str,
) -> BoxFuture<'a, FallibleResult<String>> {
    // NOTE(review): the `to_string` allocations suggest the generated RPC client takes
    // `&String` parameters — confirm whether the generated API could accept `&str` instead.
    self.language_server
        .client
        .ai_completion(&prompt.to_string(), &stop.to_string())
        .map(|result| result.map(|completion| completion.code).map_err(Into::into))
        .boxed_local()
}


fn interrupt(&self) -> BoxFuture<FallibleResult> {
async move {
self.language_server.client.interrupt(&self.id).await?;
Expand Down
3 changes: 3 additions & 0 deletions app/gui/src/presenter/searcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,12 @@ impl Searcher {

let weak_model = Rc::downgrade(&model);
let notifications = model.controller.subscribe();
let graph = model.view.graph().clone();
spawn_stream_handler(weak_model, notifications, move |notification, _| {
match notification {
Notification::NewActionList => action_list_changed.emit(()),
Notification::AISuggestionUpdated(expr) =>
graph.set_node_expression((input_view, node_view::Expression::new_plain(expr))),
};
std::future::ready(())
});
Expand Down
31 changes: 31 additions & 0 deletions distribution/lib/Standard/Visualization/0.0.0-dev/src/AI.enso
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from Standard.Base import all
import Standard.Table.Data.Table.Table

goal_placeholder = "__$$GOAL$$__"

## Build a few-shot AI prompt describing the available `Table` operations, example
   goal-to-code pairs, and this table's columns, ending with a goal placeholder to be
   substituted by the IDE.
Table.build_ai_prompt self =
    ops = ["aggregate","filter_by_expression","order_by","row_count","set","select_columns","transpose","join"]
    aggs = ["Count","Average","Sum","Median","First","Last","Maximum","Minimum"]
    joins = ["Inner","Left_Outer","Right_Outer","Full","Left_Exclusive","Right_Exclusive"]
    examples = """
        Table["id","category","Unit Price","Stock"];goal=get product count by category==>>`aggregate [Aggregate_Column.Group_By "category", Aggregate_Column.Count Nothing]`
        Table["ID","Unit Price","Stock"];goal=order by how many items are available==>>`order_by ["Stock"]`
        Table["Name","Enrolled Year"];goal=select people who enrolled between 2015 and 2018==>>`filter_by_expression "[Enrolled Year] >= 2015 && [Enrolled Year] <= 2018"`
        Table["Number of items","client name","city","unit price"];goal=compute the total value of each order==>>`set "[Number of items] * [unit price]" "total value"`
        Table["Number of items","client name","CITY","unit price","total value"];goal=compute the average order value by city==>>`aggregate [Aggregate_Column.Group_By "CITY", Aggregate_Column.Average "total value"]`
        Table["Area Code", "number"];goal=get full phone numbers==>>`set "'+1 (' + [Area Code] + ') ' + [number]" "full phone number"`
        Table["Name","Grade","Subject"];goal=rank students by their average grade==>>`aggregate [Aggregate_Column.Group_By "Name", Aggregate_Column.Average "Grade" "Average Grade"] . order_by [Sort_Column.Name "Average Grade" Sort_Direction.Descending]`
        Table["Country","Prime minister name","2018","2019","2020","2021"];goal=pivot yearly GDP values to rows==>>`transpose ["Country", "Prime minister name"] "Year" "GDP"`
        Table["Size","Weight","Width","stuff","thing"];goal=only select size and thing of each record==>>`select_columns ["Size", "thing"]`
        Table["ID","Name","Count"];goal=join it with var_17==>>`join var_17 Join_Kind.Inner`
    ops_prompt = "Operations available on Table are: " + (ops . join ",")
    aggs_prompt = "Available ways to aggregate a column are: " + (aggs . join ",")
    joins_prompt = "Available join kinds are: " + (joins . join ",")
    base_prompt = ops_prompt + '\n' + aggs_prompt + '\n' + joins_prompt + '\n' + examples
    columns = self.column_names . map .to_text . join "," "Table[" "];"
    goal_line = "goal=" + goal_placeholder + "==>>`"
    base_prompt + '\n' + columns + goal_line

## Fallback prompt for values that have no tailored prompt builder.
Any.build_ai_prompt self = "````"

## Entry point used by the IDE's prompt-building visualization: dispatches to the
   subject-specific `build_ai_prompt` (e.g. the `Table` one above).
build_ai_prompt subject = subject.build_ai_prompt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.enso.languageserver.ai

import org.enso.jsonrpc.{HasParams, HasResult, Method}

/** The `ai/completion` JSON-RPC method.
  *
  * Requests an AI-generated completion of `prompt`; generation stops when
  * `stopSequence` is produced. The result carries the suggested code snippet.
  */
case object AICompletion extends Method("ai/completion") {

  /** @param prompt the prompt sent to the completion model
    * @param stopSequence the sequence at which the model stops generating
    */
  case class Params(prompt: String, stopSequence: String)

  /** @param code the code snippet suggested by the model */
  case class Result(code: String)

  implicit val hasParams = new HasParams[this.type] {
    type Params = AICompletion.Params
  }
  implicit val hasResult = new HasResult[this.type] {
    type Result = AICompletion.Result
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ class MainModule(serverConfig: LanguageServerConfig, logLevel: LogLevel) {
ContentRoot.Project(serverConfig.contentRootUuid),
new File(serverConfig.contentRootPath)
)

private val openAiKey = sys.env.get("OPENAI_API_KEY")
private val openAiCfg = openAiKey.map(AICompletionConfig)

val languageServerConfig = Config(
contentRoot,
FileManagerConfig(timeout = 3.seconds),
Expand All @@ -91,7 +95,8 @@ class MainModule(serverConfig: LanguageServerConfig, logLevel: LogLevel) {
PathWatcherConfig(),
ExecutionContextConfig(),
directoriesConfig,
serverConfig.profilingConfig
serverConfig.profilingConfig,
openAiCfg
)
log.trace("Created Language Server config [{}].", languageServerConfig)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ case class VcsManagerConfig(
Path.of(ProjectDirectoriesConfig.DataDirectory)
}

case class AICompletionConfig(apiKey: String)

object VcsManagerConfig {
def apply(asyncInit: Boolean = true): VcsManagerConfig =
VcsManagerConfig(initTimeout = 5.seconds, 5.seconds, asyncInit)
Expand Down Expand Up @@ -151,7 +153,8 @@ case class Config(
pathWatcher: PathWatcherConfig,
executionContext: ExecutionContextConfig,
directories: ProjectDirectoriesConfig,
profiling: ProfilingConfig
profiling: ProfilingConfig,
aiCompletionConfig: Option[AICompletionConfig]
) extends ToLogString {

/** @inheritdoc */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.typesafe.scalalogging.LazyLogging
import org.enso.cli.task.ProgressUnit
import org.enso.cli.task.notifications.TaskNotificationApi
import org.enso.jsonrpc._
import org.enso.languageserver.ai.AICompletion
import org.enso.languageserver.boot.resource.InitializationComponent
import org.enso.languageserver.capability.CapabilityApi.{
AcquireCapability,
Expand Down Expand Up @@ -498,6 +499,9 @@ class JsonConnectionController(
.props(requestTimeout, suggestionsHandler),
InvalidateSuggestionsDatabase -> search.InvalidateSuggestionsDatabaseHandler
.props(requestTimeout, suggestionsHandler),
AICompletion -> ai.AICompletionHandler.props(
languageServerConfig.aiCompletionConfig
),
Completion -> search.CompletionHandler
.props(requestTimeout, suggestionsHandler),
ExecuteExpression -> ExecuteExpressionHandler
Expand Down
Loading