-
Notifications
You must be signed in to change notification settings - Fork 326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
AI Completions #5910
AI Completions #5910
Changes from 8 commits
f9e859f
e34c3d5
b78afd0
e1cd873
00fdce4
a922cda
bbdcca2
78f72e5
a381513
bf58dbb
3a4aa47
eb5763a
6cd55d8
47b59e4
92404d5
4fdccf6
6549b8f
bce1281
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,10 +37,12 @@ pub mod action; | |
pub mod breadcrumbs; | ||
pub mod component; | ||
|
||
use crate::controller::graph::executed::Handle; | ||
use crate::model::execution_context::QualifiedMethodPointer; | ||
use crate::model::execution_context::Visualization; | ||
pub use action::Action; | ||
|
||
|
||
|
||
// ================= | ||
// === Constants === | ||
// ================= | ||
|
@@ -90,6 +92,16 @@ pub struct NotSupported { | |
#[fail(display = "An action cannot be executed when searcher is in \"edit node\" mode.")] | ||
pub struct CannotExecuteWhenEditingNode; | ||
|
||
#[allow(missing_docs)] | ||
#[derive(Copy, Clone, Debug, Fail)] | ||
#[fail(display = "An action cannot be executed when searcher is run without `this` argument.")] | ||
pub struct CannotRunWithoutThisArgument; | ||
|
||
#[allow(missing_docs)] | ||
#[derive(Copy, Clone, Debug, Fail)] | ||
#[fail(display = "No visualization data received for an AI suggestion.")] | ||
pub struct NoAIVisualizationDataReceived; | ||
|
||
#[allow(missing_docs)] | ||
#[derive(Copy, Clone, Debug, Fail)] | ||
#[fail(display = "Cannot commit expression in current mode ({:?}).", mode)] | ||
|
@@ -103,14 +115,15 @@ pub struct CannotCommitExpression { | |
// ===================== | ||
|
||
/// The notification emitted by Searcher Controller | ||
#[derive(Copy, Clone, Debug, Eq, PartialEq)] | ||
#[derive(Clone, Debug, Eq, PartialEq)] | ||
pub enum Notification { | ||
/// A new Suggestion list is available. | ||
NewActionList, | ||
/// Code should be inserted by means of using an AI autocompletion. | ||
AISuggestionUpdated(String), | ||
} | ||
|
||
|
||
|
||
// =================== | ||
// === Suggestions === | ||
// =================== | ||
|
@@ -154,7 +167,6 @@ impl Default for Actions { | |
} | ||
|
||
|
||
|
||
// =================== | ||
// === Input Parts === | ||
// =================== | ||
|
@@ -288,7 +300,7 @@ impl ParsedInput { | |
chain.wrapped.args.push(argument); | ||
} | ||
Some(chain) | ||
// If there isn't any expression part, the pattern is the whole input. | ||
// If there isn't any expression part, the pattern is the whole input. | ||
} else if let Some(sast) = pattern_sast { | ||
let chain = ast::prefix::Chain::from_ast_non_strict(&sast.wrapped); | ||
Some(ast::Shifted::new(self.pattern_offset, chain)) | ||
|
@@ -314,7 +326,6 @@ impl Display for ParsedInput { | |
} | ||
|
||
|
||
|
||
// ================ | ||
// === ThisNode === | ||
// ================ | ||
|
@@ -370,7 +381,6 @@ impl ThisNode { | |
} | ||
|
||
|
||
|
||
// =========================== | ||
// === Searcher Controller === | ||
// =========================== | ||
|
@@ -720,12 +730,64 @@ impl Searcher { | |
self.notifier.notify(Notification::NewActionList); | ||
} | ||
|
||
const AI_QUERY_PREFIX: &'static str = "AI:"; | ||
const AI_QUERY_ACCEPT_TOKEN: &'static str = "#"; | ||
const AI_STOP_SEQUENCE: &'static str = "`"; | ||
const AI_GOAL_PLACEHOLDER: &'static str = "__$$GOAL$$__"; | ||
|
||
async fn accept_ai_query( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know this is PoC, but docs are needed here for future refactoring/development. Please describe shortly what this function does and why. |
||
query: String, | ||
this: ThisNode, | ||
graph: Handle, | ||
notifier: notification::Publisher<Notification>, | ||
) -> FallibleResult { | ||
let vis_ptr = QualifiedMethodPointer::from_qualified_text( | ||
"Standard.Visualization.AI", | ||
"Standard.Visualization.AI", | ||
"build_ai_prompt", | ||
)?; | ||
let vis = Visualization::new(this.id, vis_ptr, vec![]); | ||
let mut result = graph.attach_visualization(vis.clone()).await?; | ||
let next = result.next().await.ok_or(NoAIVisualizationDataReceived)?; | ||
let prompt = std::str::from_utf8(&next)?; | ||
let prompt_with_goal = prompt.replace(Self::AI_GOAL_PLACEHOLDER, &query); | ||
graph.detach_visualization(vis.id).await?; | ||
let completion = graph.get_ai_completion(&prompt_with_goal, Self::AI_STOP_SEQUENCE).await?; | ||
notifier.publish(Notification::AISuggestionUpdated(completion)).await; | ||
Ok(()) | ||
} | ||
|
||
fn handle_ai_query(&self, query: String) -> FallibleResult { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know this is PoC, but docs are needed here for future refactoring/development. Please describe shortly what this function does and why. |
||
let query = query.trim_start_matches(Self::AI_QUERY_PREFIX); | ||
if !query.ends_with(Self::AI_QUERY_ACCEPT_TOKEN) { | ||
return Ok(()); | ||
} | ||
let query = query.trim_end_matches(Self::AI_QUERY_ACCEPT_TOKEN).trim().to_string(); | ||
let this = self.this_arg.clone(); | ||
if this.is_none() { | ||
return Err(CannotRunWithoutThisArgument.into()); | ||
} | ||
let this = this.as_ref().as_ref().unwrap().clone(); | ||
let graph = self.graph.clone_ref(); | ||
let notifier = self.notifier.clone_ref(); | ||
executor::global::spawn(async move { | ||
if let Err(e) = Self::accept_ai_query(query, this, graph, notifier).await { | ||
error!("error when handling AI query: {e}"); | ||
} | ||
}); | ||
|
||
Ok(()) | ||
} | ||
|
||
/// Set the Searcher Input. | ||
/// | ||
/// This function should be called each time user modifies Searcher input in view. It may result | ||
/// in a new action list (the appropriate notification will be emitted). | ||
#[profile(Debug)] | ||
pub fn set_input(&self, new_input: String) -> FallibleResult { | ||
if new_input.starts_with(Self::AI_QUERY_PREFIX) { | ||
return self.handle_ai_query(new_input); | ||
} | ||
debug!("Manually setting input to {new_input}."); | ||
let parsed_input = ParsedInput::new(new_input, self.ide.parser())?; | ||
let old_expr = self.data.borrow().input.expression.repr(); | ||
|
@@ -1617,7 +1679,6 @@ impl Drop for EditGuard { | |
} | ||
|
||
|
||
|
||
// === SimpleFunctionCall === | ||
|
||
/// A simple function call is an AST where function is a single identifier with optional | ||
|
@@ -1669,7 +1730,6 @@ fn apply_this_argument(this_var: &str, ast: &Ast) -> Ast { | |
} | ||
|
||
|
||
|
||
// ============= | ||
// === Tests === | ||
// ============= | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from Standard.Base import all | ||
import Standard.Table.Data.Table.Table | ||
|
||
goal_placeholder = "__$$GOAL$$__" | ||
|
||
Table.build_ai_prompt self = | ||
ops = ["aggregate","filter_by_expression","order_by","row_count","set","select_columns","transpose","join"] | ||
aggs = ["Count","Average","Sum","Median","First","Last","Maximum","Minimum"] | ||
joins = ["Inner","Left_Outer","Right_Outer","Full","Left_Exclusive","Right_Exclusive"] | ||
examples = """ | ||
Table["id","category","Unit Price","Stock"];goal=get product count by category==>>`aggregate [Aggregate_Column.Group_By "category", Aggregate_Column.Count Nothing]` | ||
Table["ID","Unit Price","Stock"];goal=order by how many items are available==>>`order_by ["Stock"]` | ||
Table["Name","Enrolled Year"];goal=select people who enrolled between 2015 and 2018==>>`filter_by_expression "[Enrolled Year] >= 2015 && [Enrolled Year] <= 2018` | ||
Table["Number of items","client name","city","unit price"];goal=compute the total value of each order==>>`set "[Number of items] * [unit price]" "total value"` | ||
Table["Number of items","client name","CITY","unit price","total value"];goal=compute the average order value by city==>>`aggregate [Aggregate_Column.Group_By "CITY", Aggregate_Column.Average "total value"]` | ||
Table["Area Code", "number"];goal=get full phone numbers==>>`set "'+1 (' + [Area Code] + ') ' + [number]" "full phone number"` | ||
Table["Name","Grade","Subject"];goal=rank students by their average grade==>>`aggregate [Aggregate_Column.Group_By "Name", Aggregate_Column.Average "Grade" "Average Grade"] . order_by [Sort_Column.Name "Average Grade" Sort_Direction.Descending]` | ||
Table["Country","Prime minister name","2018","2019","2020","2021"];goal=pivot yearly GDP values to rows==>>`transpose ["Country", "Prime minister name"] "Year" "GDP"` | ||
Table["Size","Weight","Width","stuff","thing"];goal=only select size and thing of each record==>>`select_columns ["Size", "thing"]` | ||
Table["ID","Name","Count"];goal=join it with var_17==>>`join var_17 Join_Kind.Inner` | ||
ops_prompt = "Operations available on Table are: " + (ops . join ",") | ||
aggs_prompt = "Available ways to aggregate a column are: " + (aggs . join ",") | ||
joins_prompt = "Available join kinds are: " + (joins . join ",") | ||
base_prompt = ops_prompt + '\n' + aggs_prompt + '\n' + joins_prompt + '\n' + examples | ||
columns = self.column_names . map .to_text . join "," "Table[" "];" | ||
goal_line = "goal=" + goal_placeholder + "==>>`" | ||
base_prompt + '\n' + columns + goal_line | ||
|
||
Any.build_ai_prompt self = "````" | ||
|
||
build_ai_prompt subject = subject.build_ai_prompt |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package org.enso.languageserver.ai | ||
|
||
import org.enso.jsonrpc.{HasParams, HasResult, Method} | ||
|
||
case object AICompletion extends Method("ai/completion") { | ||
case class Params(prompt: String, stopSequence: String) | ||
case class Result(code: String) | ||
|
||
implicit val hasParams = new HasParams[this.type] { | ||
type Params = AICompletion.Params | ||
} | ||
implicit val hasResult = new HasResult[this.type] { | ||
type Result = AICompletion.Result | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe this is an accidental change :)