diff --git a/lib/error.ml b/lib/error.ml index adb6703..2ac35d5 100644 --- a/lib/error.ml +++ b/lib/error.ml @@ -4,9 +4,9 @@ type lexer_error = { pos : code_pos; msg : string } module LexerError = struct type t = lexer_error - let make (pos : code_pos) (msg : string) : lexer_error = { pos; msg } + let make (pos : code_pos) (msg : string) : t = { pos; msg } - let show (e : lexer_error) = + let show (e : t) = Printf.sprintf "LexerError at line %d, column %d: %s" e.pos.line e.pos.col e.msg end @@ -15,27 +15,22 @@ type parser_error = { pos : code_pos; msg : string } module ParserError = struct type t = parser_error - let make (pos : code_pos) (msg : string) : parser_error = { pos; msg } + let make (pos : code_pos) (msg : string) : t = { pos; msg } - let show (e : parser_error) = + let show (e : t) = Printf.sprintf "ParserError at line %d, column %d: %s" e.pos.line e.pos.col e.msg end (* type runtime_error = { pos : code_pos; msg : string; type_ : runtime_error_type } *) -type runtime_error = Error of { pos : code_pos; msg : string } | Break | Continue +type runtime_error = { pos : code_pos; msg : string } module RuntimeError = struct - type t = parser_error + type t = runtime_error - let make (pos : code_pos) (msg : string) : runtime_error = Error { pos; msg } - let break () : runtime_error = Break - let continue () : runtime_error = Continue + let make (pos : code_pos) (msg : string) : t = { pos; msg } - let show (e : runtime_error) = - match e with - | Error { pos; msg } -> - Printf.sprintf "RuntimeError at line %d, column %d: %s" pos.line pos.col msg - | Break | Continue -> assert false + let show ({ pos; msg } : t) = + Printf.sprintf "RuntimeError at line %d, column %d: %s" pos.line pos.col msg end type lox_error = diff --git a/lib/interpreter.ml b/lib/interpreter.ml index 54b7912..581086c 100644 --- a/lib/interpreter.ml +++ b/lib/interpreter.ml @@ -1,5 +1,3 @@ -let ( let* ) = Result.bind - open Environment open Error open Expr @@ -9,22 +7,69 @@ open Value let value_of_literal (literal : literal) : Value.lox_value = match literal with String s -> String s | Number x -> Number x | Bool b -> Bool b | Nil -> Nil -let rec interpret_expr (env : environment) (expr_node : expr_node) : - (lox_value, runtime_error) result = +type 'a interpreter_result = + | Ok of 'a + | Break + | Continue + | Return of lox_value + | Error of runtime_error + +module InterpreterResult = struct + type 'a t = 'a interpreter_result + + let ok value = Ok value + let break () = Break + let continue () = Continue + let return (value : lox_value) = Return value + let error e = Error e + + let map f result = + match result with + | Ok x -> f x |> ok + | Break -> Break + | Continue -> Continue + | Return _ as r -> r + | Error _ as e -> e + + let map_error f result = + match result with + | Error e -> f e |> error + | Ok _ as x -> x + | Break -> Break + | Continue -> Continue + | Return _ as r -> r + + let bind result f = + match result with + | Ok value -> f value + | Break -> Break + | Continue -> Continue + | Return _ as r -> r + | Error _ as e -> e +end + +open InterpreterResult + +let ( let* ) = InterpreterResult.bind + +type expr_result = lox_value interpreter_result +type stmt_result = unit interpreter_result + +let rec interpret_expr (env : environment) (expr_node : expr_node) : lox_value interpreter_result = let { pos; expr } = expr_node in match expr with - | Literal literal -> value_of_literal literal |> Result.ok + | Literal literal -> value_of_literal literal |> ok | Variable name -> ( let value_opt = Env.get env name in match value_opt with | Some x -> Ok x | None -> let msg = Printf.sprintf "name \"%s\" is not defined" name in - RuntimeError.make pos msg |> Result.error) + RuntimeError.make pos msg |> error) | Assignment { name; expr } -> if not (Env.is_defined env name) then let msg = Printf.sprintf "tried to assign to undefined variable %s" name in - RuntimeError.make pos msg |> Result.error + RuntimeError.make pos msg |> error else let* value = interpret_expr env expr in Env.update env name value; @@ -32,44 +77,44 @@ let rec interpret_expr (env : environment) (expr_node : expr_node) : | Unary { op; expr } -> ( let* expr = interpret_expr env expr in match (op, expr) with - | Neg, Number x -> Number (-.x) |> Result.ok - | Not, value -> Bool (lox_value_to_bool value |> not) |> Result.ok + | Neg, Number x -> Number (-.x) |> ok + | Not, value -> Bool (lox_value_to_bool value |> not) |> ok | _, _ -> let msg = Printf.sprintf "Invalid operant of type %s to operator %s" (type_string_of_lox_value expr) (show_unary_op op) in - RuntimeError.make pos msg |> Result.error) + RuntimeError.make pos msg |> error) | Binary { op; left; right } -> ( let* left = interpret_expr env left in let* right = interpret_expr env right in match (left, op, right) with - | String a, Plus, String b -> String (a ^ b) |> Result.ok - | Number x, Plus, Number y -> Number (x +. y) |> Result.ok - | Number x, Minus, Number y -> Number (x -. y) |> Result.ok - | Number x, Mul, Number y -> Number (x *. y) |> Result.ok + | String a, Plus, String b -> String (a ^ b) |> ok + | Number x, Plus, Number y -> Number (x +. y) |> ok + | Number x, Minus, Number y -> Number (x -. y) |> ok + | Number x, Mul, Number y -> Number (x *. y) |> ok | Number x, Div, Number y -> - if y <> 0. then Number (x /. y) |> Result.ok + if y <> 0. then Number (x /. y) |> ok else let msg = "Division by 0" in - RuntimeError.make pos msg |> Result.error - | Bool b, And, Bool c -> Bool (b && c) |> Result.ok - | Bool b, Or, Bool c -> Bool (b || c) |> Result.ok + RuntimeError.make pos msg |> error + | Bool b, And, Bool c -> Bool (b && c) |> ok + | Bool b, Or, Bool c -> Bool (b || c) |> ok | _, Equal, _ -> Ok (Bool (left = right)) - | Number x, Greater, Number y -> Bool (x > y) |> Result.ok - | Number x, GreaterEqual, Number y -> Bool (x >= y) |> Result.ok - | Number x, Less, Number y -> Bool (x < y) |> Result.ok - | Number x, LessEqual, Number y -> Bool (x <= y) |> Result.ok - | String a, Greater, String b -> Bool (a > b) |> Result.ok - | String a, GreaterEqual, String b -> Bool (a >= b) |> Result.ok - | String a, Less, String b -> Bool (a < b) |> Result.ok - | String a, LessEqual, String b -> Bool (a <= b) |> Result.ok + | Number x, Greater, Number y -> Bool (x > y) |> ok + | Number x, GreaterEqual, Number y -> Bool (x >= y) |> ok + | Number x, Less, Number y -> Bool (x < y) |> ok + | Number x, LessEqual, Number y -> Bool (x <= y) |> ok + | String a, Greater, String b -> Bool (a > b) |> ok + | String a, GreaterEqual, String b -> Bool (a >= b) |> ok + | String a, Less, String b -> Bool (a < b) |> ok + | String a, LessEqual, String b -> Bool (a <= b) |> ok | _, _, _ -> let msg = Printf.sprintf "Invalid operands of type %s and %s to operator %s" (type_string_of_lox_value left) (type_string_of_lox_value right) (show_binary_op op) in - RuntimeError.make pos msg |> Result.error) + RuntimeError.make pos msg |> error) | Logical { op; left; right } -> ( let* left = interpret_expr env left in match (op, lox_value_to_bool left) with @@ -77,56 +122,60 @@ let rec interpret_expr (env : environment) (expr_node : expr_node) : | _ -> interpret_expr env right) | Call { callee; args } -> ( let* callee = interpret_expr env callee in - let f (acc : (lox_value list, runtime_error) result) (arg : expr_node) : - (lox_value list, runtime_error) result = - match acc with - | Ok acc -> - let* arg = interpret_expr env arg in - Ok (arg :: acc) - | Error e -> Error e + let f (acc : lox_value list interpreter_result) arg = + let* acc = acc in + let* arg = interpret_expr env arg in + Ok (arg :: acc) in let* args = List.fold_left f (Ok []) args in let args = List.rev args in match callee with - | Function (NativeFunction { name; arity; fn }) -> + | NativeFunction { name; arity; fn } -> ( let args_len = List.length args in if args_len <> arity then let msg = Printf.sprintf "Native Function %s has arity %d, but was called with %d args" name args_len arity in - RuntimeError.make pos msg |> Result.error - else fn args |> Result.map_error (RuntimeError.make pos) - | Function (LoxFunction { name; arity; arg_names; body }) -> + RuntimeError.make pos msg |> error + else + match fn args with Ok value -> Ok value | Error s -> RuntimeError.make pos s |> error) + | Function { name; arity; arg_names; body } -> ( let args_len = List.length args in if args_len <> arity then let msg = Printf.sprintf "Function %s has arity %d, but was called with %d args" name args_len arity in - RuntimeError.make pos msg |> Result.error + RuntimeError.make pos msg |> error else let env = Env.push_frame env in let () = List.iter2 (fun name value -> assert (Env.define env name value)) arg_names args in - let* () = interpret_stmt env body in - Ok Nil + let result = interpret_stmt env body in + match result with + | Ok () -> Ok Nil + | Return value -> Ok value + | Error _ as e -> e + | _ -> assert false) | _ -> ignore args; let msg = Printf.sprintf "%s object is not callable" (type_string_of_lox_value callee) in - RuntimeError.make pos msg |> Result.error) + RuntimeError.make pos msg |> error) -and interpret_stmt (env : environment) (stmt_node : stmt_node) : (unit, runtime_error) result = - let { pos; stmt } = stmt_node in - ignore pos; +and interpret_stmt (env : environment) (stmt_node : stmt_node) : unit interpreter_result = + let { stmt; pos } = stmt_node in match stmt with | Expr expr -> let* value = interpret_expr env expr in ignore value; Ok () - | Break -> RuntimeError.break () |> Result.error - | Continue -> RuntimeError.continue () |> Result.error + | Break -> break () + | Continue -> continue () + | Return expr -> + let* value = interpret_expr env expr in + Return value | Print expr -> let* value = interpret_expr env expr in print_endline (Value.string_of_lox_value value); @@ -139,7 +188,7 @@ and interpret_stmt (env : environment) (stmt_node : stmt_node) : (unit, runtime_ let msg = Printf.sprintf "Tried to define %s, but a variable of that name was already defined" name in - RuntimeError.make pos msg |> Result.error + RuntimeError.make pos msg |> error | FunDecl { name; arg_names; body } -> let fn = make_lox_function name arg_names body in let success = Env.define env name fn in @@ -149,7 +198,7 @@ and interpret_stmt (env : environment) (stmt_node : stmt_node) : (unit, runtime_ Printf.sprintf "Tried to define function %s, but a variable of that name was already defined" name in - RuntimeError.make pos msg |> Result.error + RuntimeError.make pos msg |> error | Block stmts -> let env = Env.enter env in let rec _interpret stmts = @@ -171,9 +220,9 @@ and interpret_stmt (env : environment) (stmt_node : stmt_node) : (unit, runtime_ if cond then let result = interpret_stmt env body in match result with - | Ok () | Error Continue -> interpret_stmt env stmt_node - | Error Break -> Ok () - | Error e -> Error e + | Ok () | Continue -> interpret_stmt env stmt_node + | Break -> Ok () + | other -> other else Ok () | For { init; cond; update; body } -> let env = Env.enter env in @@ -182,7 +231,7 @@ and interpret_stmt (env : environment) (stmt_node : stmt_node) : (unit, runtime_ cond |> Option.map (interpret_expr env) |> Option.value ~default:(Ok (Value.Bool true)) - |> Result.map Value.lox_value_to_bool + |> map Value.lox_value_to_bool in let do_update () = update |> Option.map (interpret_expr env) |> Option.value ~default:(Ok Value.Nil) @@ -192,12 +241,12 @@ and interpret_stmt (env : environment) (stmt_node : stmt_node) : (unit, runtime_ if cond then let result = interpret_stmt env body in match result with - | Ok () | Error Continue -> + | Ok () | Continue -> let* value = do_update () in ignore value; loop () - | Error Break -> Ok () - | Error e -> Error e + | Break -> Ok () + | other -> other else Ok () in loop () diff --git a/lib/lox.ml b/lib/lox.ml index 5d88c74..93a5c7b 100644 --- a/lib/lox.ml +++ b/lib/lox.ml @@ -49,9 +49,11 @@ let run ?(env : Environment.environment option) ?(debug = false) (source : strin let rec interpret_stmts (stmts : Stmt.stmt_node list) = match stmts with | [] -> Ok () - | stmt :: tail -> - let* () = Interpreter.interpret_stmt env stmt in - interpret_stmts tail + | stmt :: tail -> ( + match Interpreter.interpret_stmt env stmt with + | Ok () -> interpret_stmts tail + | Error e -> Error e + | _ -> assert false) in interpret_stmts stmts |> Error.of_runtimer_error diff --git a/lib/loxstd.ml b/lib/loxstd.ml index b35f4bf..d981164 100644 --- a/lib/loxstd.ml +++ b/lib/loxstd.ml @@ -23,6 +23,6 @@ let exit : native_function = { name = "exit"; arity = 1; fn } let init_std (env : environment) = - let register_fn fn = Env.define_global env fn.name (Function (NativeFunction fn)) in + let register_fn fn = Env.define_global env fn.name (NativeFunction fn) in let _ = List.map register_fn [ clock; exit ] in () diff --git a/lib/stmt.ml b/lib/stmt.ml index 14b3fcd..d10fbc3 100644 --- a/lib/stmt.ml +++ b/lib/stmt.ml @@ -5,6 +5,7 @@ type stmt = | Expr of expr_node | Break | Continue + | Return of expr_node | Print of expr_node | VarDecl of { name : string; init : expr_node option } | FunDecl of { name : string; arg_names : string list; body : stmt_node } @@ -28,6 +29,9 @@ let rec show_stmt ?(indent = 0) stmt = | Expr expr -> indent_s ^ "Expr\n" ^ show_expr_ind expr.expr | Break -> indent_s ^ "Break" | Continue -> indent_s ^ "Continue" + | Return expr -> + let expr_s = show_expr_ind expr.expr in + "Return\n" ^ expr_s | Print expr -> indent_s ^ "Print\n" ^ show_expr_ind expr.expr | VarDecl { name; init } -> let init_s = match init with Some init -> " = \n" ^ show_expr_ind init.expr | None -> "" in @@ -90,6 +94,7 @@ let make_stmt_node (pos : code_pos) (stmt : stmt) : stmt_node = { stmt; pos } let make_expr_stmt (pos : code_pos) (expr : expr_node) : stmt_node = Expr expr |> make_stmt_node pos let make_break (pos : code_pos) : stmt_node = Break |> make_stmt_node pos let make_continue (pos : code_pos) : stmt_node = Continue |> make_stmt_node pos +let make_return (pos : code_pos) (expr : expr_node) = Return expr |> make_stmt_node pos let make_print (pos : code_pos) (expr : expr_node) : stmt_node = Print expr |> make_stmt_node pos let make_var_decl (pos : code_pos) (name : string) (init : expr_node option) = diff --git a/lib/value.ml b/lib/value.ml index e55f588..b4cf40d 100644 --- a/lib/value.ml +++ b/lib/value.ml @@ -1,6 +1,7 @@ type lox_function = { name : string; arity : int; + (* env : Environment.environment; *) arg_names : string list; body : Stmt.stmt_node; [@printer fun fmt _ -> fprintf fmt "
"] } @@ -13,16 +14,19 @@ type native_function = { } [@@deriving show { with_path = false }] -and function_ = NativeFunction of native_function | LoxFunction of lox_function -[@@deriving show { with_path = false }] - -and lox_value = Function of function_ | String of string | Number of float | Bool of bool | Nil +and lox_value = + | Function of lox_function + | NativeFunction of native_function + | String of string + | Number of float + | Bool of bool + | Nil [@@deriving show { with_path = false }] let string_of_lox_value lox_value = match lox_value with - | Function (NativeFunction { name; arity; _ }) -> Printf.sprintf "