diff --git a/lib/interpreter.ml b/lib/interpreter.ml index 8d7b267..54b7912 100644 --- a/lib/interpreter.ml +++ b/lib/interpreter.ml @@ -87,17 +87,37 @@ let rec interpret_expr (env : environment) (expr_node : expr_node) : in let* args = List.fold_left f (Ok []) args in let args = List.rev args in - (* let args_s = - List.fold_left (fun acc value -> acc ^ " " ^ string_of_lox_value value) "" args - in - Printf.eprintf "Called %s with args%s\n%!" (string_of_lox_value callee) args_s; *) match callee with + | Function (NativeFunction { name; arity; fn }) -> + let args_len = List.length args in + if args_len <> arity then + let msg = + Printf.sprintf "Native Function %s has arity %d, but was called with %d args" name + args_len arity + in + RuntimeError.make pos msg |> Result.error + else fn args |> Result.map_error (RuntimeError.make pos) + | Function (LoxFunction { name; arity; arg_names; body }) -> + let args_len = List.length args in + if args_len <> arity then + let msg = + Printf.sprintf "Function %s has arity %d, but was called with %d args" name args_len + arity + in + RuntimeError.make pos msg |> Result.error + else + let env = Env.push_frame env in + let () = + List.iter2 (fun name value -> assert (Env.define env name value)) arg_names args + in + let* () = interpret_stmt env body in + Ok Nil | _ -> ignore args; let msg = Printf.sprintf "%s object is not callable" (type_string_of_lox_value callee) in RuntimeError.make pos msg |> Result.error) -let rec interpret_stmt (env : environment) (stmt_node : stmt_node) : (unit, runtime_error) result = +and interpret_stmt (env : environment) (stmt_node : stmt_node) : (unit, runtime_error) result = let { pos; stmt } = stmt_node in ignore pos; match stmt with @@ -116,7 +136,19 @@ let rec interpret_stmt (env : environment) (stmt_node : stmt_node) : (unit, runt let success = Env.define env name init in if success then Ok () else - let msg = Printf.sprintf "Tried to define %s, but was already defined" name in + let msg = + Printf.sprintf "Tried to define %s, but a variable of that name was already defined" name + in + RuntimeError.make pos msg |> Result.error + | FunDecl { name; arg_names; body } -> + let fn = make_lox_function name arg_names body in + let success = Env.define env name fn in + if success then Ok () + else + let msg = + Printf.sprintf + "Tried to define function %s, but a variable of that name was already defined" name + in RuntimeError.make pos msg |> Result.error | Block stmts -> let env = Env.enter env in diff --git a/lib/parser.ml b/lib/parser.ml index 3dc4bd4..51f6b64 100644 --- a/lib/parser.ml +++ b/lib/parser.ml @@ -6,11 +6,11 @@ open Lexer open Stmt type parse_result = (stmt_node list, parser_error list) result -type state = { tokens : token list ref; is_in_loop : bool } +type state = { tokens : token list ref; is_in_loop : bool; is_in_fun : bool } type stmt_result = (stmt_node, parser_error) result type expr_result = (expr_node, parser_error) result -let make_state tokens = { tokens; is_in_loop = false } +let make_state tokens = { tokens; is_in_loop = false; is_in_fun = true } let with_is_in_loop (f : state -> 'a) (state : state) : 'a = let new_state = { state with is_in_loop = true } in @@ -18,6 +18,12 @@ let with_is_in_loop (f : state -> 'a) (state : state) : 'a = (* state.tokens <- new_state.tokens; *) result +let with_is_in_fun (f : state -> 'a) (state : state) : 'a = + let new_state = { state with is_in_fun = true; is_in_loop = false } in + let result = f new_state in + (* state.tokens <- new_state.tokens; *) + result + let is_at_end state = assert (not (List.is_empty !(state.tokens))); (List.hd !(state.tokens)).token_type = Eof @@ -61,12 +67,12 @@ let matches state tts = let f = ( = ) (peek_tt state) in Array.fold_left (fun acc tt -> acc || f tt) false tts -let collect_chain (state : state) (tts : token_type array) (higher_prec : state -> expr_result) : - ((expr_node * token) list, parser_error) result = - let rec collect_chain_rec (acc : (expr_node * token) list) = +let collect_chain (tts : token_type array) (collector : state -> ('a, parser_error) result) + (state : state) : (('a * token) list, parser_error) result = + let rec collect_chain_rec (acc : ('a * token) list) = if (not (is_at_end state)) && matches state tts then let token = next state in - let* expr = higher_prec state in + let* expr = collector state in let acc = (expr, token) :: acc in collect_chain_rec acc else Ok acc @@ -119,7 +125,7 @@ and call (state : state) : expr_result = if peek_tt state = RightParen then Ok [] else let* first_arg = expression state in - let* exprs_tokens = collect_chain state [| Comma |] expression in + let* exprs_tokens = collect_chain [| Comma |] expression state in let other_args = List.map fst exprs_tokens in let args = first_arg :: other_args in Ok args @@ -149,7 +155,7 @@ and neg_not (state : state) : expr_result = and mul_or_div (state : state) : expr_result = let* expr = neg_not state in - let* exprs_tokens = collect_chain state [| Star; Slash |] neg_not in + let* exprs_tokens = collect_chain [| Star; Slash |] neg_not state in let f acc (expr, (token : token)) = let pos = token.pos in let op : binary_op = @@ -165,7 +171,7 @@ and mul_or_div (state : state) : expr_result = and sum_or_diff (state : state) : expr_result = let* expr = mul_or_div state in - let* exprs_tokens = collect_chain state [| Plus; Minus |] mul_or_div in + let* exprs_tokens = collect_chain [| Plus; Minus |] mul_or_div state in let f acc (expr, (token : token)) = let pos = token.pos in let op : binary_op = @@ -183,7 +189,7 @@ and inequality (state : state) : expr_result = (* TODO: maybe rework to only have Less and Greater as ops; performance? *) let* expr = sum_or_diff state in let* exprs_tokens = - collect_chain state [| Greater; GreaterEqual; Less; LessEqual |] sum_or_diff + collect_chain [| Greater; GreaterEqual; Less; LessEqual |] sum_or_diff state in let f acc (expr, (token : token)) = let pos = token.pos in @@ -202,7 +208,7 @@ and inequality (state : state) : expr_result = and equality (state : state) : expr_result = let* expr = inequality state in - let* exprs_tokens = collect_chain state [| EqualEqual; BangEqual |] inequality in + let* exprs_tokens = collect_chain [| EqualEqual; BangEqual |] inequality state in let f acc (expr, (token : token)) = let pos = token.pos in match token.token_type with @@ -217,7 +223,7 @@ and equality (state : state) : expr_result = and logical_and (state : state) : expr_result = let* expr = equality state in - let* exprs_tokens = collect_chain state [| And |] equality in + let* exprs_tokens = collect_chain [| And |] equality state in let f acc (expr, (token : token)) = let pos = token.pos in assert (token.token_type = And); @@ -228,7 +234,7 @@ and logical_and (state : state) : expr_result = and logical_or (state : state) : expr_result = let* expr = logical_and state in - let* exprs_tokens = collect_chain state [| Or |] logical_and in + let* exprs_tokens = collect_chain [| Or |] logical_and state in let f acc (expr, (token : token)) = let pos = token.pos in assert (token.token_type = Or); @@ -360,7 +366,7 @@ and statement (state : state) : stmt_result = and var_declaration (state : state) : stmt_result = let pos = cur_pos state in (* consume var token *) - assert ((next state).token_type = Var); + let* () = consume state Var in let* name = consume_identifier state in let* init = if Equal = peek_tt state then @@ -374,8 +380,34 @@ and var_declaration (state : state) : stmt_result = let* () = consume state Semicolon in make_var_decl pos name init |> Result.ok +and fun_declaration (state : state) : stmt_result = + let pos = cur_pos state in + let* () = consume state Fun in + let* name = consume_identifier state in + let* () = consume state LeftParen in + let* arg_names = + if peek_tt state = RightParen then Ok [] + else + let* first_arg = consume_identifier state in + let* exprs_tokens = collect_chain [| Comma |] consume_identifier state in + let other_args = List.map fst exprs_tokens in + let args = first_arg :: other_args in + Ok args + in + if List.length arg_names >= 255 then + let msg = Printf.sprintf "Function %s can't have more than 255 arguments" name in + ParserError.make pos msg |> Result.error + else + let* () = consume state RightParen in + let* body = with_is_in_fun block state in + make_fun_decl pos name arg_names body |> Result.ok +(* make_nil pos |> make_expr_stmt pos |> Result.ok *) + and declaration (state : state) : stmt_result = - match peek_tt state with Var -> var_declaration state | _ -> statement state + match peek_tt state with + | Var -> var_declaration state + | Fun -> fun_declaration state + | _ -> statement state let rec synchronise (state : state) = match peek_tt state with diff --git a/lib/stmt.ml b/lib/stmt.ml index d1d600b..14b3fcd 100644 --- a/lib/stmt.ml +++ b/lib/stmt.ml @@ -7,6 +7,7 @@ type stmt = | Continue | Print of expr_node | VarDecl of { name : string; init : expr_node option } + | FunDecl of { name : string; arg_names : string list; body : stmt_node } | Block of stmt_node list | If of { cond : expr_node; then_ : stmt_node; else_ : stmt_node option } | While of { cond : expr_node; body : stmt_node } @@ -31,6 +32,11 @@ let rec show_stmt ?(indent = 0) stmt = | VarDecl { name; init } -> let init_s = match init with Some init -> " = \n" ^ show_expr_ind init.expr | None -> "" in indent_s ^ "Var " ^ name ^ init_s + | FunDecl { name; arg_names; body } -> + let args_s = List.fold_left (fun acc arg -> acc ^ "\n" ^ arg) "" arg_names in + let body_s = show_stmt_ind ~add:4 body.stmt in + indent_s ^ "Fun " ^ name ^ "\n" ^ indent_s ^ " Args" ^ args_s ^ indent_s ^ " Body\n" + ^ body_s | Block stmts -> let stmts_s = List.fold_left (fun acc stmt -> acc ^ show_stmt_ind stmt.stmt ^ "\n") "" stmts @@ -89,6 +95,9 @@ let make_print (pos : code_pos) (expr : expr_node) : stmt_node = Print expr |> m let make_var_decl (pos : code_pos) (name : string) (init : expr_node option) = VarDecl { name; init } |> make_stmt_node pos +let make_fun_decl (pos : code_pos) (name : string) (arg_names : string list) (body : stmt_node) = + FunDecl { name; arg_names; body } |> make_stmt_node pos + let make_block (pos : code_pos) (stmts : stmt_node list) : stmt_node = Block stmts |> make_stmt_node pos