Skip to content

Commit

Permalink
torii mcp crate
Browse files Browse the repository at this point in the history
  • Loading branch information
Larkooo committed Mar 4, 2025
1 parent a57aed9 commit 917e50d
Show file tree
Hide file tree
Showing 13 changed files with 510 additions and 505 deletions.
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ torii-runner = { path = "crates/torii/runner" }
torii-server = { path = "crates/torii/server" }
torii-sqlite = { path = "crates/torii/sqlite" }
torii-typed-data = { path = "crates/torii/typed-data" }
torii-mcp = { path = "crates/torii/mcp" }

# sozo
sozo-ops = { path = "crates/sozo/ops" }
Expand Down
13 changes: 13 additions & 0 deletions crates/torii/mcp/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[package]
edition.workspace = true
name = "torii-mcp"
version.workspace = true

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
serde.workspace = true
serde_json.workspace = true
sqlx.workspace = true
tokio.workspace = true
torii-sqlite.workspace = true
3 changes: 3 additions & 0 deletions crates/torii/mcp/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub mod types;
pub mod tools;
pub mod resources;
8 changes: 8 additions & 0 deletions crates/torii/mcp/src/resources/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#[derive(Clone, Debug)]
pub struct Resource {
pub name: &'static str,
}

pub fn get_resources() -> Vec<Resource> {
vec![] // Add resources as needed
}
18 changes: 18 additions & 0 deletions crates/torii/mcp/src/tools/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use serde_json::Value;

pub mod query;
pub mod schema;

#[derive(Clone, Debug)]
pub struct Tool {
pub name: &'static str,
pub description: &'static str,
pub input_schema: Value,
}

pub fn get_tools() -> Vec<Tool> {
vec![
query::get_tool(),
schema::get_tool(),
]
}
76 changes: 76 additions & 0 deletions crates/torii/mcp/src/tools/query.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use std::sync::Arc;

use serde_json::{json, Value};
use sqlx::SqlitePool;

use crate::types::{JsonRpcRequest, JsonRpcResponse, JsonRpcError};
use torii_sqlite::utils::map_row_to_json;
use crate::types::JSONRPC_VERSION;

use super::Tool;

pub fn get_tool() -> Tool {
Tool {
name: "query",
description: "Execute a SQL query on the database",
input_schema: json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "SQL query to execute"
}
},
"required": ["query"]
}),
}
}

pub async fn handle(pool: Arc<SqlitePool>, request: JsonRpcRequest) -> JsonRpcResponse {
let Some(params) = request.params else {
return JsonRpcResponse::invalid_params(request.id, "Missing params");
};

let args = params.get("arguments").and_then(Value::as_object);
if let Some(query) = args.and_then(|args| args.get("query").and_then(Value::as_str)) {
match sqlx::query(query).fetch_all(&*pool).await {
Ok(rows) => {
// Convert rows to JSON using shared mapping function
let result = rows.iter().map(map_row_to_json).collect::<Vec<_>>();

JsonRpcResponse {
jsonrpc: JSONRPC_VERSION.to_string(),
id: request.id,
result: Some(json!({
"content": [{
"type": "text",
"text": serde_json::to_string(&result).unwrap()
}]
})),
error: None,
}
}
Err(e) => JsonRpcResponse {
jsonrpc: JSONRPC_VERSION.to_string(),
id: request.id,
result: None,
error: Some(JsonRpcError {
code: -32603,
message: "Database error".to_string(),
data: Some(json!({ "details": e.to_string() })),
}),
},
}
} else {
JsonRpcResponse {
jsonrpc: JSONRPC_VERSION.to_string(),
id: request.id,
result: None,
error: Some(JsonRpcError {
code: -32602,
message: "Invalid params".to_string(),
data: Some(json!({ "details": "Missing query parameter" })),
}),
}
}
}
116 changes: 116 additions & 0 deletions crates/torii/mcp/src/tools/schema.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use std::sync::Arc;

use serde_json::{json, Value};
use sqlx::SqlitePool;
use sqlx::Row;
use crate::types::{JsonRpcRequest, JsonRpcResponse, JsonRpcError};
use crate::types::JSONRPC_VERSION;

use super::Tool;

pub fn get_tool() -> Tool {
Tool {
name: "schema",
description: "Retrieve the database schema including tables, columns, and their types",
input_schema: json!({
"type": "object",
"properties": {
"table": {
"type": "string",
"description": "Optional table name to get schema for. If omitted, returns schema for all tables."
}
}
}),
}
}

pub async fn handle(pool: Arc<SqlitePool>, request: JsonRpcRequest) -> JsonRpcResponse {
let table_filter = request
.params
.as_ref()
.and_then(|p| p.get("arguments"))
.and_then(|args| args.get("table"))
.and_then(Value::as_str);

let schema_query = match table_filter {
Some(_table) => "SELECT
m.name as table_name,
p.*
FROM sqlite_master m
JOIN pragma_table_info(m.name) p
WHERE m.type = 'table'
AND m.name = ?
ORDER BY m.name, p.cid"
.to_string(),
_ => "SELECT
m.name as table_name,
p.*
FROM sqlite_master m
JOIN pragma_table_info(m.name) p
WHERE m.type = 'table'
ORDER BY m.name, p.cid"
.to_string(),
};

let rows = match table_filter {
Some(table) => sqlx::query(&schema_query).bind(table).fetch_all(&*pool).await,
_ => sqlx::query(&schema_query).fetch_all(&*pool).await,
};

match rows {
Ok(rows) => {
let mut schema = serde_json::Map::new();

for row in rows {
let table_name: String = row.try_get("table_name").unwrap();
let column_name: String = row.try_get("name").unwrap();
let column_type: String = row.try_get("type").unwrap();
let not_null: bool = row.try_get::<bool, _>("notnull").unwrap();
let pk: bool = row.try_get::<bool, _>("pk").unwrap();
let default_value: Option<String> = row.try_get("dflt_value").unwrap();

let table_entry = schema.entry(table_name).or_insert_with(|| {
json!({
"columns": serde_json::Map::new()
})
});

if let Some(columns) =
table_entry.get_mut("columns").and_then(|v| v.as_object_mut())
{
columns.insert(
column_name,
json!({
"type": column_type,
"nullable": !not_null,
"primary_key": pk,
"default": default_value
}),
);
}
}

JsonRpcResponse {
jsonrpc: JSONRPC_VERSION.to_string(),
id: request.id,
result: Some(json!({
"content": [{
"type": "text",
"text": serde_json::to_string_pretty(&schema).unwrap()
}]
})),
error: None,
}
}
Err(e) => JsonRpcResponse {
jsonrpc: JSONRPC_VERSION.to_string(),
id: request.id,
result: None,
error: Some(JsonRpcError {
code: -32603,
message: "Database error".to_string(),
data: Some(json!({ "details": e.to_string() })),
}),
},
}
}
109 changes: 109 additions & 0 deletions crates/torii/mcp/src/types/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use tokio::sync::broadcast;

// Constants
pub const JSONRPC_VERSION: &str = "2.0";
pub const MCP_VERSION: &str = "2024-11-05";
pub const SSE_CHANNEL_CAPACITY: usize = 100;

#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum JsonRpcMessage {
Request(JsonRpcRequest),
Notification(JsonRpcNotification),
}

#[derive(Debug, Deserialize)]
pub struct JsonRpcRequest {
pub jsonrpc: String,
pub id: Value,
pub method: String,
pub params: Option<Value>,
}

#[derive(Debug, Deserialize)]
pub struct JsonRpcNotification {
pub _jsonrpc: String,
pub _method: String,
pub _params: Option<Value>,
}

#[derive(Debug, Serialize, Clone)]
pub struct JsonRpcResponse {
pub jsonrpc: String,
pub id: Value,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcError>,
}

#[derive(Debug, Serialize, Clone)]
pub struct JsonRpcError {
pub code: i32,
pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<Value>,
}

#[derive(Debug, Serialize)]
pub struct Implementation {
pub name: String,
pub version: String,
}

#[derive(Debug, Serialize)]
pub struct ServerCapabilities {
pub tools: ToolCapabilities,
pub resources: ResourceCapabilities,
}

#[derive(Debug, Serialize)]
pub struct ToolCapabilities {
pub list_changed: bool,
}

#[derive(Debug, Serialize)]
pub struct ResourceCapabilities {
pub subscribe: bool,
pub list_changed: bool,
}

// Structure to hold SSE session information
#[derive(Clone, Debug)]
pub struct SseSession {
pub tx: broadcast::Sender<JsonRpcResponse>,
pub _session_id: String,
}

impl JsonRpcResponse {
pub fn ok(id: Value, result: Value) -> Self {
Self { jsonrpc: JSONRPC_VERSION.to_string(), id, result: Some(result), error: None }
}

pub fn error(id: Value, code: i32, message: &str, data: Option<Value>) -> Self {
Self {
jsonrpc: JSONRPC_VERSION.to_string(),
id,
result: None,
error: Some(JsonRpcError { code, message: message.to_string(), data }),
}
}

pub fn invalid_request(id: Value) -> Self {
Self::error(id, -32600, "Invalid Request", None)
}

pub fn method_not_found(id: Value) -> Self {
Self::error(id, -32601, "Method not found", None)
}

pub fn invalid_params(id: Value, details: &str) -> Self {
Self::error(id, -32602, "Invalid params", Some(json!({ "details": details })))
}

pub fn parse_error(id: Value, details: &str) -> Self {
Self::error(id, -32700, "Parse error", Some(json!({ "details": details })))
}
}
Loading

0 comments on commit 917e50d

Please sign in to comment.