From 3590a78154a605c53a0234c9ae0dc19bacfab51a Mon Sep 17 00:00:00 2001 From: Moritz Gmeiner Date: Mon, 12 Aug 2024 16:31:28 +0200 Subject: [PATCH] more work on the parser --- lib/error.ml | 24 +++++++-- lib/expr.ml | 36 +++++++++++++ lib/lexer.ml | 11 ++-- lib/lexer.mli | 85 +++++++++++++++++++++++++++++++ lib/lox.ml | 6 ++- lib/parser.ml | 137 ++++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 287 insertions(+), 12 deletions(-) create mode 100644 lib/expr.ml create mode 100644 lib/lexer.mli create mode 100644 lib/parser.ml diff --git a/lib/error.ml b/lib/error.ml index d7a4471..8e7a4f0 100644 --- a/lib/error.ml +++ b/lib/error.ml @@ -4,15 +4,24 @@ type lexer_error = { pos : code_pos; msg : string } module LexerError = struct type t = lexer_error - let make (pos : code_pos) (msg : string) : lexer_error = - (* let pos = { line; col } in *) - { pos; msg } + let make (pos : code_pos) (msg : string) : lexer_error = { pos; msg } let print (e : lexer_error) = Printf.printf "LexerError at line %d, column %d: %s\n" e.pos.line e.pos.col e.msg end -type lox_error = LexerError of lexer_error list +type parser_error = { pos : code_pos; msg : string } + +module ParserError = struct + type t = parser_error + + let make (pos : code_pos) (msg : string) : parser_error = { pos; msg } + + let print (e : parser_error) = + Printf.printf "ParserError at line %d, column %d: %s\n" e.pos.line e.pos.col e.msg +end + +type lox_error = LexerError of lexer_error list | ParserError of parser_error list let print_error (e : lox_error) = match e with @@ -22,5 +31,12 @@ let print_error (e : lox_error) = Printf.printf "found %d %s:\n" num_errors (if num_errors = 1 then "LexerError" else "LexerErrors"); List.iter LexerError.print es + | ParserError es -> + let num_errors = List.length es in + assert (num_errors != 0); + Printf.printf "found %d %s:\n" num_errors + (if num_errors = 1 then "ParserError" else "ParserErrors"); + List.iter ParserError.print es let of_lexer_error e = Result.map_error (fun e -> LexerError e) e +let of_parser_error e = Result.map_error (fun e -> ParserError e) e diff --git a/lib/expr.ml b/lib/expr.ml new file mode 100644 index 0000000..f2b135c --- /dev/null +++ b/lib/expr.ml @@ -0,0 +1,36 @@ +type literal = String of string | Number of float | Bool of bool | Nil +(* [@@deriving show { with_path = false }] *) + +let show_literal literal = + match literal with + | String s -> s + | Number x -> string_of_float x + | Bool b -> string_of_bool b + | Nil -> "nil" + +type binary_op = Plus | Minus | Mul | Div | Equal | Less | Greater | LessEqual | GreaterEqual +[@@deriving show { with_path = false }] + +type unary_op = Neg | Not [@@deriving show { with_path = false }] + +type expr = + | Literal of literal + | BinaryExpr of { op : binary_op; left : expr; right : expr } + | UnaryExpr of { op : unary_op; expr : expr } +(* [@@deriving show { with_path = false }] *) + +let rec show_expr ?(indent = 0) expr = + let show_indented = show_expr ~indent:(indent + 2) in + let ident_s = String.make indent ' ' in + match expr with + | Literal literal -> ident_s ^ show_literal literal + | BinaryExpr { op; left; right } -> + ident_s ^ show_binary_op op ^ "\n" ^ show_indented left ^ "\n" ^ show_indented right + | UnaryExpr { op; expr } -> ident_s ^ show_unary_op op ^ "\n" ^ show_indented expr + +let make_string (s : string) = Literal (String s) +let make_number (x : float) = Literal (Number x) +let make_bool (b : bool) = Literal (Bool b) +let make_nil () = Literal Nil +let make_binary (op : binary_op) (left : expr) (right : expr) = BinaryExpr { op; left; right } +let make_unary (op : unary_op) (expr : expr) = UnaryExpr { op; expr } diff --git a/lib/lexer.ml b/lib/lexer.ml index b2a7afc..fd7a3c8 100644 --- a/lib/lexer.ml +++ b/lib/lexer.ml @@ -56,7 +56,7 @@ type state = { } module State = struct - type t = state + (* type t = state *) let is_digit c = match c with '0' .. '9' -> true | _ -> false let is_alpha c = match c with 'a' .. 'z' | 'A' .. 'Z' -> true | _ -> false @@ -95,10 +95,6 @@ module State = struct | Some c when f c -> advance_while f (snd (advance state)) | _ -> state (* EOF or no match *) - let last_char (state : state) : char = - assert (state.cur_pos > 0); - state.source.[state.cur_pos - 1] - let append_token (pos : code_pos) (token_type : token_type) (state : state) : state = (* let pos = { line = state.line; col = state.col } in *) { state with tokens_rev = { token_type; pos } :: state.tokens_rev } @@ -208,5 +204,6 @@ let tokenize (source : string) : lexer_result = { source; start_pos = 0; cur_pos = 0; tokens_rev = []; errors_rev = []; line = 1; col = 0 } in (* reverse the reversed tokens/errors *) - if List.length state.errors_rev = 0 then Ok (List.rev state.tokens_rev) - else Error (List.rev state.errors_rev) + (* if List.length state.errors_rev = 0 then Ok (List.rev state.tokens_rev) + else Error (List.rev state.errors_rev) *) + match state.errors_rev with [] -> Ok (List.rev state.tokens_rev) | es -> Error (List.rev es) diff --git a/lib/lexer.mli b/lib/lexer.mli new file mode 100644 index 0000000..0cdbd71 --- /dev/null +++ b/lib/lexer.mli @@ -0,0 +1,85 @@ +type token_type = + | LeftParen + | RightParen + | LeftBrace + | RightBrace + | Plus + | Minus + | Star + | Slash + | Bang + | Dot + | Comma + | Semicolon + | Equal + | EqualEqual + | BangEqual + | Greater + | Less + | GreaterEqual + | LessEqual + | Identifier of string + | String of string + | Number of float + | And + | Class + | Else + | False + | Fun + | For + | If + | Nil + | Or + | Print + | Return + | Super + | This + | True + | Var + | While + | Comment of string + | Eof + +val pp_token_type : Format.formatter -> token_type -> unit +val show_token_type : token_type -> string +val keywords : (string, token_type) Hashtbl.t + +type token = { token_type : token_type; pos : Error.code_pos } + +val show_token : token -> string + +type lexer_result = (token list, Error.lexer_error list) result + +(* type state = { + source : string; + start_pos : int; + cur_pos : int; + tokens_rev : token list; + errors_rev : Error.lexer_error list; + line : int; + col : int; + } + module State : sig + type t = state + + val is_digit : char -> bool + val is_alpha : char -> bool + val is_alphanum : char -> bool + val is_identifier : char -> bool + val is_at_end : state -> bool + val get_lexeme : state -> int -> int -> string + val advance : state -> char * state + val peek : state -> char option + val advance_if : char -> state -> bool * state + val advance_until : char -> state -> bool * state + val advance_while : (char -> bool) -> state -> state + val last_char : state -> char + val append_token : Error.code_pos -> token_type -> state -> state + val append_error : Error.code_pos -> string -> state -> state + val parse_number : state -> state + val parse_keyword_or_identifier : state -> state + val parse_block_commend : state -> state + val tokenize_rec : state -> state + end *) + +val tokenize : string -> lexer_result diff --git a/lib/lox.ml b/lib/lox.ml index 5195f48..58d60fb 100644 --- a/lib/lox.ml +++ b/lib/lox.ml @@ -1,7 +1,9 @@ let ( let* ) = Result.bind -module Lexer = Lexer module Error = Error +module Expr = Expr +module Lexer = Lexer +module Parser = Parser type token = Lexer.token type lox_error = Error.lox_error @@ -13,6 +15,8 @@ let run (source : string) : (unit, lox_error) result = Printf.printf "Got %d tokens\n" (List.length tokens); List.iter f tokens; print_endline ""; + let* expr = Error.of_parser_error (Parser.parse tokens) in + Printf.printf "%s\n" (Expr.show_expr expr); Ok () let runRepl () : unit = diff --git a/lib/parser.ml b/lib/parser.ml new file mode 100644 index 0000000..59da51a --- /dev/null +++ b/lib/parser.ml @@ -0,0 +1,137 @@ +open Error +open Expr +open Lexer + +type parse_result = (expr, parser_error list) result +type state = { tokens : token list; errors_rev : parser_error list } + +module State = struct + let is_at_end state = (List.hd state.tokens).token_type == Eof + + let append_error msg pos state = + let e = { pos; msg } in + { state with errors_rev = e :: state.errors_rev } + + let advance state = { state with tokens = List.tl state.tokens } + let peek state = List.hd state.tokens + + let next state = + assert (not ((List.hd state.tokens).token_type == Eof)); + (List.hd state.tokens, advance state) + + let advance_if state tt = + if (peek state).token_type == tt then (true, advance state) else (false, state) + + let matches state tts = + let f = ( == ) (peek state).token_type in + List.fold_left (fun acc tt -> acc || f tt) false tts + + let collect_chain (state : state) (tts : token_type list) higher_prec : + (expr * token) array * state = + (* ([], state) *) + let state_ref = ref state in + let out_list_rev = ref [] in + while (not (is_at_end !state_ref)) && matches !state_ref tts do + let token, state = next !state_ref in + let expr, state = higher_prec state in + state_ref := state; + out_list_rev := (expr, token) :: !out_list_rev + done; + (Array.of_list (List.rev !out_list_rev), !state_ref) + + let mul_or_div (state : state) : expr * state = + let token, state = next state in + (make_string @@ show_token_type token.token_type, state) + + let sum_or_diff (state : state) : expr * state = + let expr, state = mul_or_div state in + (* if (not (is_at_end state)) && matches state [ Plus; Minus ] then + let token, state = next state in + let expr2, state = sum_or_diff state in + let (op : binary_op) = + match token.token_type with + | Plus -> Plus + | Minus -> Minus + | _ -> assert false (* should only be here if tt is + - *) + in + let expr = make_binary op expr expr2 in + (expr, state) + else (expr, state) *) + (* turn expr and state to refs for the loop *) + (* Printf.printf "expr: %s\n\n" (show_expr expr); + let expr_ref, state_ref = (ref expr, ref state) in + while (not (is_at_end !state_ref)) && matches !state_ref [ Plus; Minus ] do + let token, state = next !state_ref in + let expr2, state = mul_or_div state in + let (op : binary_op) = + match token.token_type with + | Plus -> Plus + | Minus -> Minus + | _ -> assert false (* should only be here if tt is + - *) + in + let expr = make_binary op !expr_ref expr2 in + (* (expr_ref, state_ref) := (expr, state) *) + Printf.printf "expr: %s\n\n" (show_expr expr); + expr_ref := expr; + state_ref := state + done; + (!expr_ref, !state_ref) *) + let exprs_tokens, state = collect_chain state [ Plus; Minus ] mul_or_div in + let f acc (expr, token) = + let op : binary_op = + match token.token_type with Plus -> Plus | Minus -> Minus | _ -> assert false + in + make_binary op acc expr + in + let expr = Array.fold_left f expr exprs_tokens in + (expr, state) + + let rec inequality (state : state) : expr * state = + let expr, state = sum_or_diff state in + if (not (is_at_end state)) && matches state [ Greater; GreaterEqual; Less; LessEqual ] then + let token, state = next state in + let expr2, state = inequality state in + (* TODO: maybe rework to only have Less and Greater as ops; performance? *) + let (op : binary_op) = + match token.token_type with + | Greater -> Greater + | Less -> Less + | GreaterEqual -> GreaterEqual + | LessEqual -> LessEqual + | _ -> assert false (* should only be here if tt is > < >= <= *) + in + let expr = make_binary op expr expr2 in + (expr, state) + else (expr, state) + + let rec equality (state : state) : expr * state = + let expr, state = inequality state in + if matches state [ EqualEqual; BangEqual ] then + let token, state = next state in + let expr2, state = equality state in + let expr = + match token.token_type with + | EqualEqual -> make_binary Equal expr expr2 + | BangEqual -> + let expr = make_binary Equal expr expr2 in + make_unary Not expr + | _ -> assert false (* should only be here if tt is == != *) + in + (expr, state) + else (expr, state) + + let expression (state : state) : expr * state = equality state +end + +let parse (tokens : token list) : parse_result = + let state = { tokens; errors_rev = [] } in + let expr, state = State.expression state in + let state = + if not (State.is_at_end state) then + let tt = (State.peek state).token_type in + let msg = Printf.sprintf "Unexpected %s at end" (show_token_type tt) in + State.append_error msg (State.peek state).pos state + else state + in + (* if List.length state.errors_rev != 0 then Ok expr else Error (List.rev state.errors_rev) *) + match state.errors_rev with [] -> Ok expr | es -> Error (List.rev es)