diff --git a/lib/expr.ml b/lib/expr.ml index 3cfad5b..52c355f 100644 --- a/lib/expr.ml +++ b/lib/expr.ml @@ -23,13 +23,15 @@ type binary_op = [@@deriving show { with_path = false }] type unary_op = Neg | Not [@@deriving show { with_path = false }] +type logical_op = And | Or [@@deriving show { with_path = false }] type expr = | Literal of literal | Variable of string | Assignment of { name : string; expr : expr_node } - | BinaryExpr of { op : binary_op; left : expr_node; right : expr_node } - | UnaryExpr of { op : unary_op; expr : expr_node } + | Unary of { op : unary_op; expr : expr_node } + | Binary of { op : binary_op; left : expr_node; right : expr_node } + | Logical of { op : logical_op; left : expr_node; right : expr_node } and expr_node = { expr : expr; pos : Error.code_pos } @@ -39,11 +41,14 @@ let rec show_expr ?(indent = 0) expr = match expr with | Literal literal -> indent_s ^ show_literal literal | Variable name -> indent_s ^ "Variable " ^ name - | Assignment { name; expr } -> indent_s ^ name ^ " = \n" ^ show_expr expr.expr ~indent:(indent + 2) - | BinaryExpr { op; left; right } -> + | Assignment { name; expr } -> indent_s ^ name ^ " = \n" ^ show_indented expr.expr + | Unary { op; expr } -> indent_s ^ show_unary_op op ^ "\n" ^ show_indented expr.expr + | Binary { op; left; right } -> indent_s ^ show_binary_op op ^ "\n" ^ show_indented left.expr ^ "\n" ^ show_indented right.expr - | UnaryExpr { op; expr } -> indent_s ^ show_unary_op op ^ "\n" ^ show_indented expr.expr + | Logical { op; left; right } -> + indent_s ^ show_logical_op op ^ "\n" ^ show_indented left.expr ^ "\n" + ^ show_indented right.expr let show_expr_node expr_node = show_expr expr_node.expr let make_expr_node (pos : Error.code_pos) (expr : expr) : expr_node = { expr; pos } @@ -63,8 +68,11 @@ let make_variable (pos : Error.code_pos) (name : string) : expr_node = let make_assignment (pos : Error.code_pos) (name : string) (expr : expr_node) : expr_node = Assignment { name; expr } |> make_expr_node pos -let make_binary (pos : Error.code_pos) (op : binary_op) (left : expr_node) (right : expr_node) = - make_expr_node pos (BinaryExpr { op; left; right }) - let make_unary (pos : Error.code_pos) (op : unary_op) (expr : expr_node) = - make_expr_node pos (UnaryExpr { op; expr }) + Unary { op; expr } |> make_expr_node pos + +let make_binary (pos : Error.code_pos) (op : binary_op) (left : expr_node) (right : expr_node) = + Binary { op; left; right } |> make_expr_node pos + +let make_logical (pos : Error.code_pos) (op : logical_op) (left : expr_node) (right : expr_node) = + Logical { op; left; right } |> make_expr_node pos diff --git a/lib/interpreter.ml b/lib/interpreter.ml index f6d6fa5..bf25e15 100644 --- a/lib/interpreter.ml +++ b/lib/interpreter.ml @@ -28,7 +28,18 @@ let rec interpret_expr (env : environment) (expr : expr_node) : (lox_value, runt let* value = interpret_expr env expr in Env.update env name value; Ok value - | BinaryExpr { op; left; right } -> ( + | Unary { op; expr } -> ( + let* expr = interpret_expr env expr in + match (op, expr) with + | Neg, Number x -> Ok (Number (-.x)) + | Not, Bool b -> Ok (Bool (not b)) + | _, _ -> + let msg = + Printf.sprintf "Invalid operant of type %s to operator %s" + (type_string_of_lox_value expr) (show_unary_op op) + in + Error (RuntimeError.make pos msg)) + | Binary { op; left; right } -> ( let* left = interpret_expr env left in let* right = interpret_expr env right in match (left, op, right) with @@ -58,17 +69,11 @@ let rec interpret_expr (env : environment) (expr : expr_node) : (lox_value, runt (type_string_of_lox_value left) (type_string_of_lox_value right) (show_binary_op op) in Error { pos; msg }) - | UnaryExpr { op; expr } -> ( - let* expr = interpret_expr env expr in - match (op, expr) with - | Neg, Number x -> Ok (Number (-.x)) - | Not, Bool b -> Ok (Bool (not b)) - | _, _ -> - let msg = - Printf.sprintf "Invalid operant of type %s to operator %s" - (type_string_of_lox_value expr) (show_unary_op op) - in - Error (RuntimeError.make pos msg)) + | Logical { op; left; right } -> ( + let* left = interpret_expr env left in + match (op, lox_value_to_bool left) with + | And, false | Or, true -> Ok left (* short circuit *) + | _ -> interpret_expr env right) let rec interpret_stmt (env : environment) (stmt : stmt_node) : (unit, runtime_error) result = let { pos; stmt } = stmt in @@ -98,3 +103,8 @@ let rec interpret_stmt (env : environment) (stmt : stmt_node) : (unit, runtime_e | [] -> Ok () in _interpret stmts + | If { cond; then_; else_ } -> + let* cond = interpret_expr env cond in + let cond = lox_value_to_bool cond in + if cond then interpret_stmt env then_ + else Option.map (interpret_stmt env) else_ |> Option.value ~default:(Ok ()) diff --git a/lib/parser.ml b/lib/parser.ml index cc8e10b..9e950d7 100644 --- a/lib/parser.ml +++ b/lib/parser.ml @@ -35,7 +35,7 @@ let consume state tt = let pos = cur_pos state in let tt' = peek_tt state in let msg = Printf.sprintf "Expected %s, but got %s" (show_token_type tt) (show_token_type tt') in - Error (ParserError.make pos msg) + ParserError.make pos msg |> Result.error let consume_identifier state = match peek_tt state with @@ -44,15 +44,15 @@ let consume_identifier state = Ok name | tt -> let pos = cur_pos state in - let msg = Printf.sprintf "Expected iedntifier, but got %s" (show_token_type tt) in - Error (ParserError.make pos msg) + let msg = Printf.sprintf "Expected identifier, but got %s" (show_token_type tt) in + ParserError.make pos msg |> Result.error let matches state tts = let f = ( = ) (peek_tt state) in Array.fold_left (fun acc tt -> acc || f tt) false tts let collect_chain (state : state ref) (tts : token_type array) - (higher_prec : state ref -> expr_result) : ((expr_node * token) array, parser_error) result = + (higher_prec : state ref -> expr_result) : ((expr_node * token) list, parser_error) result = let rec collect_chain_rec (acc : (expr_node * token) list) = if (not (is_at_end state)) && matches state tts then let token = next state in @@ -61,29 +61,29 @@ let collect_chain (state : state ref) (tts : token_type array) collect_chain_rec acc else Ok acc in - collect_chain_rec [] |> Result.map (fun l -> Array.of_list (List.rev l)) + collect_chain_rec [] |> Result.map (fun l -> List.rev l) let primary (state : state ref) : expr_result = let pos = cur_pos state in match peek_tt state with | Number x -> advance state; - Ok (make_number pos x) + make_number pos x |> Result.ok | String s -> advance state; - Ok (make_string pos s) + make_string pos s |> Result.ok | True -> advance state; - Ok (make_bool pos true) + make_bool pos true |> Result.ok | False -> advance state; - Ok (make_bool pos false) + make_bool pos false |> Result.ok | Nil -> advance state; - Ok (make_nil pos) + make_nil pos |> Result.ok | Identifier name -> advance state; - Ok (make_variable pos name) + make_variable pos name |> Result.ok | tt -> advance state; let msg = Printf.sprintf "Expected valid expression, got %s instead" (show_token_type tt) in @@ -129,7 +129,7 @@ and mul_or_div (state : state ref) : expr_result = in make_binary pos op acc expr in - let expr = Array.fold_left f expr exprs_tokens in + let expr = List.fold_left f expr exprs_tokens in Ok expr and sum_or_diff (state : state ref) : expr_result = @@ -145,7 +145,7 @@ and sum_or_diff (state : state ref) : expr_result = in make_binary pos op acc expr in - let expr = Array.fold_left f expr exprs_tokens in + let expr = List.fold_left f expr exprs_tokens in Ok expr and inequality (state : state ref) : expr_result = @@ -166,7 +166,7 @@ and inequality (state : state ref) : expr_result = in make_binary pos op acc expr in - let expr = Array.fold_left f expr exprs_tokens in + let expr = List.fold_left f expr exprs_tokens in Ok expr and equality (state : state ref) : expr_result = @@ -181,19 +181,41 @@ and equality (state : state ref) : expr_result = make_unary pos Not expr | _ -> assert false (* should only be here if tt is == != *) in - let expr = Array.fold_left f expr exprs_tokens in + let expr = List.fold_left f expr exprs_tokens in + Ok expr + +and logical_and (state : state ref) : expr_result = + let* expr = equality state in + let* exprs_tokens = collect_chain state [| And |] equality in + let f acc (expr, (token : token)) = + let pos = token.pos in + assert (token.token_type = And); + make_logical pos And acc expr + in + let expr = List.fold_left f expr exprs_tokens in + Ok expr + +and logical_or (state : state ref) : expr_result = + let* expr = logical_and state in + let* exprs_tokens = collect_chain state [| Or |] logical_and in + let f acc (expr, (token : token)) = + let pos = token.pos in + assert (token.token_type = Or); + make_logical pos Or acc expr + in + let expr = List.fold_left f expr exprs_tokens in Ok expr and assignment (state : state ref) : expr_result = - let* expr = equality state in + let* expr = logical_or state in if Equal = peek_tt state then let pos = (next state).pos in let* rhs = assignment state in match expr.expr with - | Variable name -> Ok (make_assignment pos name rhs) + | Variable name -> make_assignment pos name rhs |> Result.ok | _ -> let msg = "Invalid assignment target" in - Error (ParserError.make pos msg) + ParserError.make pos msg |> Result.error else Ok expr and expression (state : state ref) : expr_result = assignment state @@ -204,7 +226,7 @@ let rec block (state : state ref) : stmt_result = let rec collect_stmts state = if is_at_end state then let msg = "Unterminated block" in - Error (ParserError.make pos msg) + ParserError.make pos msg |> Result.error else match peek_tt state with | RightBrace -> Ok [] @@ -215,7 +237,19 @@ let rec block (state : state ref) : stmt_result = in let* stmts = collect_stmts state in let* _ = consume state RightBrace in - Ok (make_block pos stmts) + make_block pos stmts |> Result.ok + +and if_then_else (state : state ref) : stmt_result = + let pos = cur_pos state in + let* _ = consume state If in + let* _ = consume state LeftParen in + let* cond = expression state in + let* _ = consume state RightParen in + let* then_ = statement state in + let* (else_ : stmt_node option) = + if advance_if state Else then statement state |> Result.map Option.some else Ok None + in + make_if pos cond then_ else_ |> Result.ok and statement (state : state ref) : stmt_result = let pos = cur_pos state in @@ -227,6 +261,7 @@ and statement (state : state ref) : stmt_result = let stmt = make_print pos expr in Ok stmt | LeftBrace -> block state + | If -> if_then_else state | _ -> let* expr = expression state in let* _ = consume state Semicolon in @@ -248,7 +283,7 @@ and var_declaration (state : state ref) : stmt_result = Ok None in let* _ = consume state Semicolon in - Ok (make_var_decl pos name init) + make_var_decl pos name init |> Result.ok and declaration (state : state ref) : stmt_result = match peek_tt state with Var -> var_declaration state | _ -> statement state @@ -270,7 +305,6 @@ let rec parse_impl (state : state ref) : parse_result = let* stmts = parse_impl state in Ok (stmt :: stmts) | Error e -> ( - print_endline e.msg; synchronise state; if peek_tt state = Eof then Error [ e ] else diff --git a/lib/stmt.ml b/lib/stmt.ml index bcaca44..15c9598 100644 --- a/lib/stmt.ml +++ b/lib/stmt.ml @@ -5,40 +5,47 @@ type stmt = | Print of expr_node | VarDecl of { name : string; init : expr_node option } | Block of stmt_node list + | If of { cond : expr_node; then_ : stmt_node; else_ : stmt_node option } and stmt_node = { stmt : stmt; pos : Error.code_pos } let rec show_stmt ?(indent = 0) stmt = let indent_s = String.make indent ' ' in + let show_expr_ind ?(depth = 2) = show_expr ~indent:(indent + depth) in + let show_stmt_ind ?(depth = 2) = show_stmt ~indent:(indent + depth) in match stmt with - | Expr expr -> indent_s ^ "Expr\n" ^ show_expr ~indent:(indent + 2) expr.expr - | Print expr -> indent_s ^ "Print\n" ^ show_expr ~indent:(indent + 2) expr.expr + | Expr expr -> indent_s ^ "Expr\n" ^ show_expr_ind expr.expr + | Print expr -> indent_s ^ "Print\n" ^ show_expr_ind expr.expr | VarDecl { name; init } -> ( indent_s ^ "Var " ^ name - ^ - match init with Some init -> " = \n" ^ show_expr ~indent:(indent + 2) init.expr | None -> "") + ^ match init with Some init -> " = \n" ^ show_expr_ind init.expr | None -> "") | Block stmts -> let stmts_s = - List.fold_left - (fun acc stmt -> acc ^ show_stmt ~indent:(indent + 2) stmt.stmt ^ "\n") - "" stmts + List.fold_left (fun acc stmt -> acc ^ show_stmt_ind stmt.stmt ^ "\n") "" stmts in "Block" ^ stmts_s ^ "End" + | If { cond; then_; else_ } -> + let cond_s = show_expr_ind cond.expr in + let then_s = show_stmt_ind ~depth:4 then_.stmt in + let else_s = Option.map (fun stmt -> show_stmt_ind ~depth:4 stmt.stmt) else_ in + indent_s ^ "If\n" ^ cond_s ^ "\n" ^ indent_s ^ " Then\n" ^ then_s + ^ if Option.is_some else_s then "\n" ^ indent_s ^ " Else\n" ^ Option.get else_s else "" let show_stmt_node stmt_node = show_stmt stmt_node.stmt +let make_stmt_node (pos : Error.code_pos) (stmt : stmt) : stmt_node = { stmt; pos } let make_expr_stmt (pos : Error.code_pos) (expr : expr_node) : stmt_node = - let stmt = Expr expr in - { stmt; pos } + Expr expr |> make_stmt_node pos let make_print (pos : Error.code_pos) (expr : expr_node) : stmt_node = - let stmt = Print expr in - { stmt; pos } + Print expr |> make_stmt_node pos let make_var_decl (pos : Error.code_pos) (name : string) (init : expr_node option) = - let stmt = VarDecl { name; init } in - { stmt; pos } + VarDecl { name; init } |> make_stmt_node pos let make_block (pos : Error.code_pos) (stmts : stmt_node list) : stmt_node = - let stmt = Block stmts in - { stmt; pos } + Block stmts |> make_stmt_node pos + +let make_if (pos : Error.code_pos) (cond : expr_node) (then_ : stmt_node) (else_ : stmt_node option) + = + If { cond; then_; else_ } |> make_stmt_node pos diff --git a/lib/value.ml b/lib/value.ml index 75c7194..3fe9dde 100644 --- a/lib/value.ml +++ b/lib/value.ml @@ -14,3 +14,6 @@ let type_string_of_lox_value lox_value = | Number _ -> "Number" | Bool _ -> "Bool" | Nil -> "Nil" + +let lox_value_to_bool lox_value = + match lox_value with String _ -> true | Number _ -> true | Bool b -> b | Nil -> false