let ( let* ) = Result.bind

open Error
open Expr
open Lexer
open Stmt

type parse_result = (stmt_node list, parser_error list) result
type stmt_result = (stmt_node, parser_error) result
type expr_result = (expr_node, parser_error) result

(** Parser state: the tokens that have not been consumed yet. Every function
    below takes a [state ref] and replaces its contents as tokens are
    consumed. The helpers read the head of the list with [List.hd], so the
    list must never be empty; NOTE(review): this assumes the lexer always
    terminates the token list with an [Eof] token -- confirm. *)
type state = { tokens : token list }

(** [is_at_end state] is [true] when the current token is [Eof]. *)
let is_at_end state = (List.hd !state.tokens).token_type = Eof

(** [advance state] drops the current token. *)
let advance state = state := { tokens = List.tl !state.tokens }

(** [next state] returns the current token and consumes it. Asserts that the
    current token is not [Eof]. *)
let next state =
  assert (not (is_at_end state));
  let token = List.hd !state.tokens in
  advance state;
  token

(** [peek state] is the current token; nothing is consumed. *)
let peek state = List.hd !state.tokens

(** [peek_tt state] is the type of the current token. *)
let peek_tt (state : state ref) : token_type = (peek state).token_type

(** [cur_pos state] is the position of the current token. *)
let cur_pos state = (peek state).pos

(** [advance_if state tt] consumes the current token and returns [true] when
    its type equals [tt]; otherwise leaves the state alone and returns
    [false]. *)
let advance_if state tt =
  if peek_tt state = tt then (
    advance state;
    true)
  else false

(** [consume state tt] consumes a token of type [tt], or returns a parser
    error located at the current token when the type differs. *)
let consume state tt =
  if advance_if state tt then Ok ()
  else
    Error
      (ParserError.make (cur_pos state)
         (Printf.sprintf "Expected %s, but got %s" (show_token_type tt)
            (show_token_type (peek_tt state))))

(** [matches state tts] is [true] when the type of the current token is one
    of [tts]. Nothing is consumed. *)
let matches state tts = Array.mem (peek_tt state) tts

(** [collect_chain state tts higher_prec] repeatedly consumes an operator
    token whose type is in [tts] and parses the operand after it with
    [higher_prec]. Returns the (operand, operator token) pairs in source
    order, or the first error reported by [higher_prec]. *)
let collect_chain (state : state ref) (tts : token_type array)
    (higher_prec : state ref -> expr_result) :
    ((expr_node * token) array, parser_error) result =
  let rec collect_chain_rec (acc : (expr_node * token) list) =
    if (not (is_at_end state)) && matches state tts then
      let token = next state in
      let* expr = higher_prec state in
      collect_chain_rec ((expr, token) :: acc)
    else Ok acc
  in
  (* The accumulator is built back to front, hence the reversal. *)
  collect_chain_rec [] |> Result.map (fun l -> Array.of_list (List.rev l))

(** [fold_chain state tts higher_prec combine] parses one binary precedence
    level: a first operand, then the chain gathered by {!collect_chain},
    folded from the left with [combine]. *)
let fold_chain (state : state ref) (tts : token_type array)
    (higher_prec : state ref -> expr_result)
    (combine : expr_node -> expr_node * token -> expr_node) : expr_result =
  let* first = higher_prec state in
  let* rest = collect_chain state tts higher_prec in
  Ok (Array.fold_left combine first rest)

(** [primary state] parses a literal: number, string, [True], [False] or
    [Nil]. Any other token yields an error at that token and is left
    unconsumed. *)
let primary (state : state ref) : expr_result =
  let pos = cur_pos state in
  match peek_tt state with
  | Number x ->
      advance state;
      Ok (make_number pos x)
  | String s ->
      advance state;
      Ok (make_string pos s)
  | True ->
      advance state;
      Ok (make_bool pos true)
  | False ->
      advance state;
      Ok (make_bool pos false)
  | Nil ->
      advance state;
      Ok (make_nil pos)
  | tt ->
      let msg =
        Printf.sprintf "Unexpected %s, expected valid expression"
          (show_token_type tt)
      in
      Error { msg; pos }

(** [grouping state] parses a parenthesised expression, or falls through to
    {!primary} when the current token is not [LeftParen]. *)
let rec grouping (state : state ref) : expr_result =
  if advance_if state LeftParen then
    let* expr = expression state in
    (* The group has to be closed by a RightParen. *)
    if advance_if state RightParen then Ok expr
    else
      let pos = cur_pos state in
      let tt = peek_tt state in
      let msg =
        Printf.sprintf "Expected RightParen, got %s instead"
          (show_token_type tt)
      in
      Error { pos; msg }
  else primary state

(** [neg_not state] parses a prefix [Bang] or [Minus] applied to another
    unary expression, or falls through to {!grouping}. *)
and neg_not (state : state ref) : expr_result =
  if matches state [| Bang; Minus |] then
    let token = next state in
    let pos = token.pos in
    let* expr = neg_not state in
    let op =
      match token.token_type with
      | Bang -> Not
      | Minus -> Neg
      | _ -> assert false (* [matches] only lets Bang and Minus through *)
    in
    Ok (make_unary pos op expr)
  else grouping state

(** [mul_or_div state] parses a left-associative chain of [Star] and [Slash]
    operators over {!neg_not} operands. *)
and mul_or_div (state : state ref) : expr_result =
  let combine lhs (rhs, (token : token)) =
    let op : binary_op =
      match token.token_type with
      | Star -> Mul
      | Slash -> Div
      | _ -> assert false (* the chain only holds Star and Slash *)
    in
    make_binary token.pos op lhs rhs
  in
  fold_chain state [| Star; Slash |] neg_not combine

(** [sum_or_diff state] parses a left-associative chain of [Plus] and [Minus]
    operators over {!mul_or_div} operands. *)
and sum_or_diff (state : state ref) : expr_result =
  let combine lhs (rhs, (token : token)) =
    let op : binary_op =
      match token.token_type with
      | Plus -> Plus
      | Minus -> Minus
      | _ -> assert false (* the chain only holds Plus and Minus *)
    in
    make_binary token.pos op lhs rhs
  in
  fold_chain state [| Plus; Minus |] mul_or_div combine

(** [inequality state] parses a left-associative chain of [Greater],
    [GreaterEqual], [Less] and [LessEqual] over {!sum_or_diff} operands. *)
and inequality (state : state ref) : expr_result =
  (* TODO: maybe rework to only have Less and Greater as ops; performance? *)
  let combine lhs (rhs, (token : token)) =
    let op : binary_op =
      match token.token_type with
      | Greater -> Greater
      | Less -> Less
      | GreaterEqual -> GreaterEqual
      | LessEqual -> LessEqual
      | _ -> assert false (* the chain only holds the four comparisons *)
    in
    make_binary token.pos op lhs rhs
  in
  fold_chain state [| Greater; GreaterEqual; Less; LessEqual |] sum_or_diff
    combine

(** [equality state] parses a left-associative chain of [EqualEqual] and
    [BangEqual] over {!inequality} operands. [BangEqual] is built as [Not]
    applied to an [Equal] node. *)
and equality (state : state ref) : expr_result =
  let combine lhs (rhs, (token : token)) =
    let pos = token.pos in
    match token.token_type with
    | EqualEqual -> make_binary pos Equal lhs rhs
    | BangEqual -> make_unary pos Not (make_binary pos Equal lhs rhs)
    | _ -> assert false (* the chain only holds EqualEqual and BangEqual *)
  in
  fold_chain state [| EqualEqual; BangEqual |] inequality combine

(** [expression state] parses an expression, starting at the lowest
    precedence level ({!equality}). *)
and expression (state : state ref) : expr_result = equality state

(** [statement state] parses one statement. Only [Print] statements (the
    keyword, an expression, then [Semicolon]) are handled; any other leading
    token produces an error at that token. *)
let statement (state : state ref) : stmt_result =
  let pos = cur_pos state in
  match peek_tt state with
  | Print ->
      advance state;
      let* expr = expression state in
      let* _ = consume state Semicolon in
      Ok (make_print pos expr)
  | tt ->
      (* Skip the offending token so that the caller makes progress, but
         never step past Eof: that would leave the token list empty and the
         next peek would raise out of List.hd. *)
      if tt <> Eof then advance state;
      let msg =
        Printf.sprintf "Statement stating with %s not yet implemented"
          (show_token_type tt)
      in
      Error (ParserError.make pos msg)

(** [synchronise state] is the error-recovery step: it skips tokens until
    either a [Semicolon] has been consumed, or the current token is one of
    [Class], [Fun], [Var], [For], [If], [While], [Print], [Return] or [Eof]
    (which is left unconsumed). *)
let rec synchronise (state : state ref) =
  match peek_tt state with
  | Semicolon -> advance state
  | Class | Fun | Var | For | If | While | Print | Return | Eof -> ()
  | _ ->
      advance state;
      synchronise state

(** [parse_impl state] parses statements until [Eof]. Returns every statement
    when all of them parse, otherwise every error that was collected; after
    an error the parser resynchronises and keeps going so that later errors
    are reported as well. An input that is already at [Eof] yields [Ok []]. *)
let rec parse_impl (state : state ref) : parse_result =
  if peek_tt state = Eof then Ok []
  else
    match statement state with
    | Ok stmt when peek_tt state = Eof -> Ok [ stmt ]
    | Ok stmt ->
        (* NOTE(review): this echoes each parsed statement to stdout, except
           the last one of the program (the branch above returns first). It
           looks like a debugging aid; confirm whether it should stay. *)
        print_endline (show_stmt stmt.stmt);
        let* stmts = parse_impl state in
        Ok (stmt :: stmts)
    | Error e -> (
        synchronise state;
        (* Statements that parse after an error are dropped: the overall
           result is an error either way. *)
        match parse_impl state with
        | Ok _ -> Error [ e ]
        | Error es -> Error (e :: es))

(** [parse tokens] parses a whole token list into statements, ignoring
    [Comment] tokens. *)
let parse (tokens : token list) : parse_result =
  (* filter out all the comment tokens *)
  let tokens =
    List.filter
      (fun tok -> match tok.token_type with Comment _ -> false | _ -> true)
      tokens
  in
  let state = ref { tokens } in
  parse_impl state