Skip to content

Commit

Permalink
Merge pull request #136 from athina-ai/feature/code-execution-v2
Browse files Browse the repository at this point in the history
ATH-2619: Added Code Execution V2 step
  • Loading branch information
vivek-athina authored Dec 5, 2024
2 parents 3d09f05 + f0c6b3f commit 74929c7
Show file tree
Hide file tree
Showing 5 changed files with 359 additions and 5 deletions.
2 changes: 1 addition & 1 deletion athina/llms/abstract_llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def embeddings(self, text: str) -> list:
raise NotImplementedError

@abstractmethod
def chat_completion(self, messages, model, **kwargs) -> str:
def chat_completion(self, messages, model, **kwargs):
"""
Fetches a chat completion response. This method should be implemented by subclasses
to interact with the specific LLM provider's chat completion API.
Expand Down
4 changes: 4 additions & 0 deletions athina/steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from athina.steps.open_ai_assistant import OpenAiAssistant
from athina.steps.transcribe_speech_to_text import TranscribeSpeechToText
from athina.steps.search import Search
from athina.steps.code_execution import CodeExecution
from athina.steps.code_execution_v2 import CodeExecutionV2
from athina.steps.spider_crawl import SpiderCrawl

__all__ = [
Expand All @@ -37,5 +39,7 @@
"OpenAiAssistant",
"TranscribeSpeechToText",
"Search",
"CodeExecution",
"CodeExecutionV2",
"SpiderCrawl",
]
344 changes: 344 additions & 0 deletions athina/steps/code_execution_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,344 @@
from typing import Union, Dict, Any, Optional, Literal, ClassVar, TypedDict
from athina.steps import Step
import io
import sys
from contextlib import redirect_stdout, redirect_stderr
from dotenv import load_dotenv
from e2b_code_interpreter import Sandbox
import time
import json

# Load environment variables
load_dotenv()

# Constants
EXECUTION_LOCAL = "local"
EXECUTION_E2B = "e2b"
ExecutionEnvironment = Literal["local", "e2b"]

VARS_START_MARKER = "__VARS_START__"
VARS_END_MARKER = "__VARS_END__"
COMMAND_PREFIX = "!"


class StepResult(TypedDict):
status: Literal["success", "error"]
data: str
metadata: Dict[str, Any]


# Extract variable serialization logic
def _serialize_variable(name: str, value: Any) -> Optional[str]:
"""
Attempt to serialize a variable to a string representation.
Returns None if serialization fails.
"""
try:
serialized_value = repr(value)
return f"{name} = {serialized_value}"
except Exception as e:
print(f"Error serializing variable {name}: {str(e)}")
return None


# Extract variable capture code into a constant
VARIABLE_CAPTURE_CODE = f"""
import json
_exported_vars = {{}}
_locals = locals()
_globals = globals()
_builtin_names = dir(__builtins__)
# Create a list of items to iterate over to prevent dictionary modification during iteration
_global_items = list(_globals.items())
for var_name, var_value in _global_items:
if (not var_name.startswith('_') and
var_name not in _builtin_names and
var_name not in ['json']):
try:
json.dumps(var_value) # Test if value is JSON serializable
_exported_vars[var_name] = var_value
except:
print(f"Could not serialize {{var_name}}")
continue
print('{VARS_START_MARKER}')
print(json.dumps(_exported_vars))
print('{VARS_END_MARKER}')
"""


class CodeExecutionV2(Step):
"""
Step that executes code using either local environment or E2B sandbox.
Attributes:
code (str): The code to execute.
session_id (str): Unique identifier for the sandbox session.
name (Optional[str]): Name identifier for the execution.
execution_environment (ExecutionEnvironment): Execution context ('local' or 'e2b').
_sandbox (Optional[Sandbox]): E2B sandbox instance.
DEFAULT_TIMEOUT (ClassVar[int]): Default timeout for sandbox operations.
sandbox_timeout (Optional[int]): Custom timeout for sandbox operations.
"""

code: str
session_id: str
name: Optional[str] = None
_sandbox: Optional[Sandbox] = None
execution_environment: ExecutionEnvironment = EXECUTION_E2B
DEFAULT_TIMEOUT: ClassVar[int] = 60 # 1 minute default timeout for sandbox
MAX_TIMEOUT: ClassVar[int] = 300 # 5 minute limit for e2b sandbox execution
sandbox_timeout: Optional[int] = None

def __init__(
self,
execution_environment: ExecutionEnvironment = EXECUTION_E2B,
sandbox_timeout: Optional[int] = None,
**data,
):
super().__init__(**data)
self._sandbox = None
self.execution_environment = execution_environment
self.sandbox_timeout = sandbox_timeout

def _create_or_initialize_sandbox(self):
"""Checks if sandbox exists and connects to it or creates a new one if not"""
if not self.session_id:
raise ValueError("session_id is required for e2b execution")

try:

running_sandboxes = Sandbox.list()

for sandbox in running_sandboxes:
if sandbox.metadata.get("session_id") == self.session_id:
# Connect to the existing sandbox
self._sandbox = Sandbox.connect(sandbox.sandbox_id)
break

if self._sandbox is None:
self._sandbox = Sandbox(
timeout=min(
self.sandbox_timeout or self.DEFAULT_TIMEOUT, self.MAX_TIMEOUT
),
metadata={"session_id": self.session_id},
)
if self.code.startswith("!"):
# Run the code as a command
commands = map(lambda x: x[1:], self.code.split("\n"))
for command in commands:
self._sandbox.commands.run(command)
print(f"Created new sandbox with ID: {self._sandbox.sandbox_id}")

except Exception as e:
print(f"Error initializing sandbox: {str(e)}")
raise RuntimeError(f"Failed to initialize sandbox: {str(e)}") from e

def _create_step_result(
self,
status: Literal["success", "error"],
data: str,
start_time: float,
exported_vars: Optional[Dict] = None,
) -> StepResult:
"""
Create a standardized result object for step execution.
Args:
status: Execution status ("success" or "error")
data: Output data or error message
start_time: Time when execution started
exported_vars: Optional dictionary of exported variables
"""
execution_time_ms = round((time.time() - start_time) * 1000)
metadata: Dict[str, Any] = {"response_time": execution_time_ms}

if exported_vars is not None:
metadata["exported_vars"] = exported_vars

return {"status": status, "data": data, "metadata": metadata}

def _execute_local(self, input_data: dict, start_time: float) -> StepResult:
"""Execute code locally using exec"""
globals_dict = {"__builtins__": __builtins__}
globals_dict.update(input_data)

stdout_buffer = io.StringIO()
stderr_buffer = io.StringIO()

try:
with redirect_stdout(stdout_buffer), redirect_stderr(stderr_buffer):
exec(self.code, globals_dict)

return self._create_step_result(
status="success", data=stdout_buffer.getvalue(), start_time=start_time
)
except Exception as e:
return self._create_step_result(
status="error",
data=f"Failed to execute the code.\nDetails:\n{str(e)}",
start_time=start_time,
)

def _prepare_input_variables(self, input_data: dict) -> list[str]:
"""
Prepare input variables for sandbox execution.
Returns a list of variable initialization statements.
"""
input_vars_code = []

for var_name, var_value in input_data.items():
if isinstance(var_value, dict) and "exported_vars" in var_value:
# Handle exported vars from previous steps
for exp_var_name, exp_var_value in var_value["exported_vars"].items():
if code := _serialize_variable(exp_var_name, exp_var_value):
input_vars_code.append(code)
else:
if code := _serialize_variable(var_name, var_value):
input_vars_code.append(code)

return input_vars_code

def _extract_exported_vars(self, stdout: str) -> dict:
"""
Extract exported variables from sandbox output.
Returns empty dict if extraction fails.
"""
try:
vars_start = stdout.find(f"{VARS_START_MARKER}\n") + len(
f"{VARS_START_MARKER}\n"
)
vars_end = stdout.find(f"\n{VARS_END_MARKER}")

if vars_start > -1 and vars_end > -1:
return json.loads(stdout[vars_start:vars_end])
except Exception as e:
print(f"Error extracting variables: {str(e)}")

return {}

def _execute_e2b(self, input_data: dict, start_time: float) -> StepResult:
"""
Execute code in E2B sandbox.
The execution follows these steps:
1. Initialize/connect to sandbox
2. Split input into commands and Python code
3. Initialize input variables in sandbox
4. Execute main code
5. Capture and extract output variables
"""
try:
self._create_or_initialize_sandbox()
if self._sandbox is None:
print("Sandbox is not initialized")
return self._create_step_result(
status="error",
data="Sandbox is not initialized",
start_time=start_time,
)

# Split code into commands and Python code
lines = self.code.split("\n")
commands = [
line.strip()[1:]
for line in lines
if line.strip().startswith(COMMAND_PREFIX)
]
python_code = [
line for line in lines if not line.strip().startswith(COMMAND_PREFIX)
]

if not python_code:
# Only commands were provided
print("Only commands were provided")
return self._create_step_result(
status="success",
data="Commands executed successfully",
start_time=start_time,
exported_vars={},
)

# Initialize input variables
input_vars_code = self._prepare_input_variables(input_data)
if input_vars_code:
setup_code = "\n".join(input_vars_code)
setup_execution = self._sandbox.run_code(setup_code)
if setup_execution.error:
print(f"Error setting up input variables: {setup_execution.error}")

# Execute main code
main_code = "\n".join(python_code)
execution = self._sandbox.run_code(main_code)
if execution.error:
return self._create_step_result(
status="error",
data=f"Failed to execute the code.\nDetails:\n{execution.error}",
start_time=start_time,
)

# Capture variables
var_execution = self._sandbox.run_code(VARIABLE_CAPTURE_CODE)
if var_execution.error:
print(f"Error capturing variables: {var_execution.error}")
return self._create_step_result(
status="success",
data="\n".join(execution.logs.stdout),
start_time=start_time,
exported_vars={},
)

# Extract and return results
exported_vars = self._extract_exported_vars(
"\n".join(var_execution.logs.stdout)
)
return self._create_step_result(
status="success",
data="\n".join(execution.logs.stdout),
start_time=start_time,
exported_vars=exported_vars,
)

except Exception as e:
print(f"\nUnexpected error: {str(e)}")
return self._create_step_result(
status="error",
data=f"Failed to execute the code.\nDetails:\n{str(e)}",
start_time=start_time,
)

def execute(self, input_data: Any) -> StepResult:
"""
Execute the code with the input data.
Args:
input_data: Dictionary containing input variables for code execution.
Returns:
Dict containing execution status, output data, and metadata.
Raises:
TypeError: If input_data is not a dictionary.
ValueError: If session_id is empty in e2b mode.
"""

if not self.code.strip():
raise ValueError("No code provided for execution")

if self.execution_environment == "e2b" and not self.session_id:
raise ValueError("session_id is required for e2b execution")

input_data = input_data or {}
if not isinstance(input_data, dict):
raise TypeError("Input data must be a dictionary")

# Start timing
start_time = time.time()

if self.execution_environment == "local":
return self._execute_local(input_data, start_time)
else:
return self._execute_e2b(input_data=input_data, start_time=start_time)
Loading

0 comments on commit 74929c7

Please sign in to comment.