diff --git a/src/ast.rs b/src/ast.rs index 11450b8..bc5da94 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -1,10 +1,10 @@ -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum TopAST { Function(FunctionAST), Prototype(PrototypeAST), } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum ExprAST { NumberExpr(NumberExprAST), VariableExpr(VariableExprAST), @@ -12,36 +12,36 @@ pub enum ExprAST { CallExpr(CallExprAST), } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct NumberExprAST { pub val: f64, } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct VariableExprAST { pub name: String, } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct BinaryExprAST { pub op: char, pub lhs: Box, pub rhs: Box, } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct CallExprAST { pub callee: String, pub args: Vec, } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct PrototypeAST { pub name: String, pub args: Vec, } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct FunctionAST { pub proto: PrototypeAST, pub body: ExprAST, diff --git a/src/parser.rs b/src/parser.rs index b56105c..b70a7f9 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -205,3 +205,96 @@ pub fn generate_ast(input: &str) -> Result> { let lexer = Lexer::new(input.chars()); Parser::parse(lexer) } + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn scan_input_1() { + let input = r#" + extern sin(a); + "#; + let ast = generate_ast(input).unwrap(); + let result = vec![TopAST::Prototype(PrototypeAST { + name: String::from("sin"), + args: vec![String::from("a")], + })]; + assert_eq!(ast, result); + } + + #[test] + fn scan_input_2() { + let input = r#" + def foo(x y) x+foo(y, 4.0); + "#; + let ast = generate_ast(input).unwrap(); + let result = vec![TopAST::Function(FunctionAST { + proto: PrototypeAST { + name: "foo".to_string(), + args: vec!["x".to_string(), "y".to_string()], + }, + body: ExprAST::BinaryExpr(BinaryExprAST { + op: '+', + lhs: Box::new(ExprAST::VariableExpr(VariableExprAST { + name: "x".to_string(), + })), + rhs: Box::new(ExprAST::CallExpr(CallExprAST { + callee: "foo".to_string(), + args: vec![ + ExprAST::VariableExpr(VariableExprAST { + name: "y".to_string(), + }), + ExprAST::NumberExpr(NumberExprAST { val: 4.0 }), + ], + })), + }), + })]; + assert_eq!(ast, result); + } + + #[test] + fn scan_input_3() { + let input = r#" + def foo(x y) x+y y; + "#; + let ast = generate_ast(input).unwrap(); + let result = vec![ + TopAST::Function(FunctionAST { + proto: PrototypeAST { + name: "foo".to_string(), + args: vec!["x".to_string(), "y".to_string()], + }, + body: ExprAST::BinaryExpr(BinaryExprAST { + op: '+', + lhs: Box::new(ExprAST::VariableExpr(VariableExprAST { + name: "x".to_string(), + })), + rhs: Box::new(ExprAST::VariableExpr(VariableExprAST { + name: "y".to_string(), + })), + }), + }), + TopAST::Function(FunctionAST { + proto: PrototypeAST { + name: "".to_string(), + args: vec![], + }, + body: ExprAST::VariableExpr(VariableExprAST { + name: "y".to_string(), + }), + }), + ]; + assert_eq!(ast, result); + } + + #[test] + fn scan_bad_input_1() { + let input = r#" + def foo(x y) x+y ); + "#; + let ast = generate_ast(input); + assert!(ast.is_err()); + } +}