diff --git a/lib/error.ml b/lib/error.ml index e5e2cf6..2c62403 100644 --- a/lib/error.ml +++ b/lib/error.ml @@ -21,38 +21,38 @@ module ParserError = struct Printf.printf "ParserError at line %d, column %d: %s\n" e.pos.line e.pos.col e.msg end -type interpreter_error = { pos : code_pos; msg : string } +type runtime_error = { pos : code_pos; msg : string } -module InterpreterError = struct +module RuntimeError = struct type t = parser_error - let make (pos : code_pos) (msg : string) : interpreter_error = { pos; msg } + let make (pos : code_pos) (msg : string) : runtime_error = { pos; msg } - let print (e : interpreter_error) = - Printf.printf "InterpreterError at line %d, column %d: %s\n" e.pos.line e.pos.col e.msg + let print (e : runtime_error) = + Printf.printf "RuntimeError at line %d, column %d: %s\n" e.pos.line e.pos.col e.msg end type lox_error = | LexerError of lexer_error list | ParserError of parser_error list - | InterpreterError of interpreter_error + | RuntimeError of runtime_error let print_error (e : lox_error) = match e with | LexerError es -> let num_errors = List.length es in - assert (num_errors != 0); + assert (num_errors <> 0); Printf.printf "found %d %s:\n" num_errors (if num_errors = 1 then "LexerError" else "LexerErrors"); List.iter LexerError.print es | ParserError es -> let num_errors = List.length es in - assert (num_errors != 0); + assert (num_errors <> 0); Printf.printf "found %d %s:\n" num_errors (if num_errors = 1 then "ParserError" else "ParserErrors"); List.iter ParserError.print es - | InterpreterError e -> InterpreterError.print e + | RuntimeError e -> RuntimeError.print e let of_lexer_error e = Result.map_error (fun e -> LexerError e) e let of_parser_error e = Result.map_error (fun e -> ParserError e) e -let of_interpreter_error e = Result.map_error (fun e -> InterpreterError e) e +let of_runtimer_error e = Result.map_error (fun e -> RuntimeError e) e diff --git a/lib/expr.ml b/lib/expr.ml index 6b4874d..5485060 100644 --- a/lib/expr.ml +++ b/lib/expr.ml @@ -1,12 +1,12 @@ type literal = String of string | Number of float | Bool of bool | Nil -(* [@@deriving show { with_path = false }] *) +[@@deriving show { with_path = false }] -let show_literal literal = - match literal with - | String s -> s - | Number x -> string_of_float x - | Bool b -> string_of_bool b - | Nil -> "nil" +(* let show_literal literal = + match literal with + | String s -> s + | Number x -> string_of_float x + | Bool b -> string_of_bool b + | Nil -> "nil" *) type binary_op = | Plus diff --git a/lib/interpreter.ml b/lib/interpreter.ml index 1984f93..07423f8 100644 --- a/lib/interpreter.ml +++ b/lib/interpreter.ml @@ -2,14 +2,15 @@ let ( let* ) = Result.bind open Expr open Error +open Stmt open Value let value_of_literal (literal : literal) : Value.lox_value = match literal with String s -> String s | Number x -> Number x | Bool b -> Bool b | Nil -> Nil -let rec interpret_expr (expr : expr_node) : (lox_value, interpreter_error) result = - let pos = expr.pos in - match expr.expr with +let rec interpret_expr (expr : expr_node) : (lox_value, runtime_error) result = + let { pos; expr } = expr in + match expr with | Literal literal -> Ok (value_of_literal literal) | BinaryExpr { op; left; right } -> ( let* left = interpret_expr left in @@ -19,21 +20,28 @@ let rec interpret_expr (expr : expr_node) : (lox_value, interpreter_error) resul | Number x, Plus, Number y -> Ok (Number (x +. y)) | Number x, Minus, Number y -> Ok (Number (x -. y)) | Number x, Mul, Number y -> Ok (Number (x *. y)) - | Number x, Div, Number y -> Ok (Number (x /. y)) - | Number x, Equal, Number y -> Ok (Bool (x = y)) + | Number x, Div, Number y -> + if y <> 0. then Ok (Number (x /. y)) + else + let msg = "Division by 0" in + Error { pos; msg } + | Bool b, And, Bool c -> Ok (Bool (b && c)) + | Bool b, Or, Bool c -> Ok (Bool (b || c)) + | _, Equal, _ -> Ok (Bool (left = right)) | Number x, Greater, Number y -> Ok (Bool (x > y)) | Number x, GreaterEqual, Number y -> Ok (Bool (x >= y)) | Number x, Less, Number y -> Ok (Bool (x < y)) | Number x, LessEqual, Number y -> Ok (Bool (x <= y)) - | Bool b, And, Bool c -> Ok (Bool (b && c)) - | Bool b, Or, Bool c -> Ok (Bool (b || c)) - | _, Equal, _ -> Ok (Bool (left = right)) + | String a, Greater, String b -> Ok (Bool (a > b)) + | String a, GreaterEqual, String b -> Ok (Bool (a >= b)) + | String a, Less, String b -> Ok (Bool (a < b)) + | String a, LessEqual, String b -> Ok (Bool (a <= b)) | _, _, _ -> let msg = Printf.sprintf "Invalid operands of type %s and %s to operator %s" (type_string_of_lox_value left) (type_string_of_lox_value right) (show_binary_op op) in - Error (InterpreterError.make pos msg)) + Error { pos; msg }) | UnaryExpr { op; expr } -> ( let* expr = interpret_expr expr in match (op, expr) with @@ -44,4 +52,16 @@ let rec interpret_expr (expr : expr_node) : (lox_value, interpreter_error) resul Printf.sprintf "Invalid operant of type %s to operator %s" (type_string_of_lox_value expr) (show_unary_op op) in - Error (InterpreterError.make pos msg)) + Error (RuntimeError.make pos msg)) + +let interpret_stmt (stmt : stmt_node) : (unit, runtime_error) result = + let { pos; stmt } = stmt in + ignore pos; + match stmt with + | Print expr -> + let* value = interpret_expr expr in + print_endline (Value.string_of_lox_value value); + Ok () + | Expr expr -> + let* _ = interpret_expr expr in + Ok () diff --git a/lib/lox.ml b/lib/lox.ml index c973fea..bdd04e7 100644 --- a/lib/lox.ml +++ b/lib/lox.ml @@ -5,6 +5,7 @@ module Expr = Expr module Interpreter = Interpreter module Lexer = Lexer module Parser = Parser +module Stmt = Stmt type lox_error = Error.lox_error @@ -16,9 +17,16 @@ let run (source : string) : (unit, lox_error) result = print_newline (); *) let* ast = Error.of_parser_error (Parser.parse tokens) in (* Printf.printf "%s\n" (Expr.show_expr expr); *) - let* value = Error.of_interpreter_error (Interpreter.interpret_expr ast) in - print_endline (Value.string_of_lox_value value); - Ok () + (* let* value = Error.of_interpreter_error (Interpreter.interpret_expr ast) in + print_endline (Value.string_of_lox_value value); *) + let rec interpret_stmts (stmts : Stmt.stmt_node list) = + match stmts with + | [] -> Ok () + | stmt :: tail -> + let* _ = Interpreter.interpret_stmt stmt in + interpret_stmts tail + in + interpret_stmts ast |> Error.of_runtimer_error let runRepl () : unit = try diff --git a/lib/parser.ml b/lib/parser.ml index 76aaff1..c419cb9 100644 --- a/lib/parser.ml +++ b/lib/parser.ml @@ -3,8 +3,10 @@ let ( let* ) = Result.bind open Error open Expr open Lexer +open Stmt -type parse_result = (expr_node, parser_error list) result +type parse_result = (stmt_node list, parser_error list) result +type stmt_result = (stmt_node, parser_error) result type expr_result = (expr_node, parser_error) result type state = { tokens : token list; errors_rev : parser_error list } @@ -16,6 +18,8 @@ let append_error msg pos state = let advance state = state := { !state with tokens = List.tl !state.tokens } let peek state = List.hd !state.tokens +let peek_tt (state : state ref) : token_type = (peek state).token_type +let cur_pos state = (peek state).pos let next state = assert (not ((List.hd !state.tokens).token_type == Eof)); @@ -24,16 +28,22 @@ let next state = token let advance_if state tt = - if (peek state).token_type == tt then ( + if peek_tt state == tt then ( advance state; true) else false -let matches state tts = - let f = ( == ) (peek state).token_type in - Array.fold_left (fun acc tt -> acc || f tt) false tts +let consume state tt = + if advance_if state tt then Ok () + else + Error + (ParserError.make (cur_pos state) + (Printf.sprintf "Expected %s, but got %s" (show_token_type tt) + (show_token_type (peek_tt state)))) -let cur_pos state = (peek state).pos +let matches state tts = + let f = ( == ) (peek_tt state) in + Array.fold_left (fun acc tt -> acc || f tt) false tts let collect_chain (state : state ref) (tts : token_type array) (higher_prec : state ref -> expr_result) : ((expr_node * token) array, parser_error) result = @@ -49,7 +59,7 @@ let collect_chain (state : state ref) (tts : token_type array) let primary (state : state ref) : expr_result = let pos = cur_pos state in - match (peek state).token_type with + match peek_tt state with | Number x -> advance state; Ok (make_number pos x) @@ -77,7 +87,7 @@ let rec grouping (state : state ref) : expr_result = if advance_if state RightParen then Ok expr (* expect a ) here *) else let pos = cur_pos state in - let tt = (peek state).token_type in + let tt = peek_tt state in let msg = Printf.sprintf "Expected RightParen, got %s instead" (show_token_type tt) in Error { pos; msg }) else primary state @@ -100,7 +110,7 @@ and neg_not (state : state ref) : expr_result = and mul_or_div (state : state ref) : expr_result = let* expr = neg_not state in let* exprs_tokens = collect_chain state [| Star; Slash |] neg_not in - let f acc (expr, token) = + let f acc (expr, (token : token)) = let pos = token.pos in let op : binary_op = match token.token_type with @@ -116,7 +126,7 @@ and mul_or_div (state : state ref) : expr_result = and sum_or_diff (state : state ref) : expr_result = let* expr = mul_or_div state in let* exprs_tokens = collect_chain state [| Plus; Minus |] mul_or_div in - let f acc (expr, token) = + let f acc (expr, (token : token)) = let pos = token.pos in let op : binary_op = match token.token_type with @@ -135,7 +145,7 @@ and inequality (state : state ref) : expr_result = let* exprs_tokens = collect_chain state [| Greater; GreaterEqual; Less; LessEqual |] sum_or_diff in - let f acc (expr, token) = + let f acc (expr, (token : token)) = let pos = token.pos in let (op : binary_op) = match token.token_type with @@ -153,7 +163,7 @@ and inequality (state : state ref) : expr_result = and equality (state : state ref) : expr_result = let* expr = inequality state in let* exprs_tokens = collect_chain state [| EqualEqual; BangEqual |] inequality in - let f acc (expr, token) = + let f acc (expr, (token : token)) = let pos = token.pos in match token.token_type with | EqualEqual -> make_binary pos Equal acc expr @@ -167,26 +177,40 @@ and equality (state : state ref) : expr_result = and expression (state : state ref) : expr_result = equality state +let statement (state : state ref) : stmt_result = + let pos = cur_pos state in + match peek_tt state with + | Print -> + advance state; + let* expr = expression state in + let* _ = consume state Semicolon in + let stmt = make_print pos expr in + Ok stmt + | tt -> + let msg = + Printf.sprintf "Statement stating with %s not yet implemented" (show_token_type tt) + in + Error (ParserError.make pos msg) + let rec synchronise (state : state ref) = - match (peek state).token_type with + match peek_tt state with | Semicolon -> advance state | Class | Fun | Var | For | If | While | Print | Return | Eof -> () | _ -> advance state; synchronise state -let parse (tokens : token list) : parse_result = +let rec parse (tokens : token list) : parse_result = let state = ref { tokens; errors_rev = [] } in - let result = expression state |> Result.map_error (fun e -> [ e ]) in - assert (Result.is_error result || (peek state).token_type = Eof); - result -(* let expr = State.expression state in - let state = - if not (State.is_at_end state) then - let tt = (State.peek state).token_type in - let msg = Printf.sprintf "Unexpected %s at end" (show_token_type tt) in - State.append_error msg (State.peek state).pos state - else state - in - (* if List.length state.errors_rev != 0 then Ok expr else Error (List.rev state.errors_rev) *) - match state.errors_rev with [] -> Ok expr | es -> Error (List.rev es) *) + let result = statement state in + match result with + | Ok stmt when peek_tt state == Eof -> Ok [ stmt ] + | Ok stmt -> + let* stmts = parse !state.tokens in + Ok (stmt :: stmts) + | Error e -> ( + synchronise state; + if peek_tt state == Eof then Error [ e ] + else + let tail_result = parse !state.tokens in + match tail_result with Ok _ -> Error [ e ] | Error es -> Error (e :: es)) diff --git a/lib/stmt.ml b/lib/stmt.ml new file mode 100644 index 0000000..9fbd7aa --- /dev/null +++ b/lib/stmt.ml @@ -0,0 +1,10 @@ +type stmt = Expr of Expr.expr_node | Print of Expr.expr_node +and stmt_node = { stmt : stmt; pos : Error.code_pos } + +let make_expr_stmt (pos : Error.code_pos) (expr : Expr.expr_node) : stmt_node = + let stmt = Expr expr in + { stmt; pos } + +let make_print (pos : Error.code_pos) (expr : Expr.expr_node) : stmt_node = + let stmt = Print expr in + { stmt; pos }