fixed closures

This commit is contained in:
Moritz Gmeiner 2024-09-03 00:48:10 +02:00
commit 671f7d5306
6 changed files with 106 additions and 72 deletions

View file

@ -37,12 +37,8 @@ fn generate_tests() {
for lox_file in glob("tests/lox/**/*.lox").unwrap() { for lox_file in glob("tests/lox/**/*.lox").unwrap() {
let lox_file = lox_file.unwrap(); let lox_file = lox_file.unwrap();
println!("found lox file: {}", lox_file.to_str().unwrap());
output += &build_test_case(lox_file); output += &build_test_case(lox_file);
} }
println!("{output}");
std::fs::write("tests/all_tests.rs", output).unwrap(); std::fs::write("tests/all_tests.rs", output).unwrap();
} }

View file

@ -205,13 +205,13 @@ impl<'a> Environment<'a> {
name: &str, name: &str,
level: usize, level: usize,
) -> Result<HeapedValue, RuntimeError> { ) -> Result<HeapedValue, RuntimeError> {
if level < self.scopes().len() { assert!(level < self.scopes().len());
let scope = self.scopes_mut().iter_mut().rev().nth(level).unwrap(); let scope = self.scopes_mut().iter_mut().rev().nth(level).unwrap();
if let Some(scoped_value) = scope.get_mut(name) { if let Some(scoped_value) = scope.get_mut(name) {
return Ok(scoped_value.heapified()); return Ok(scoped_value.heapified());
} }
}
panic!("Name not defined not caught by resolver"); panic!("Name not defined not caught by resolver");
@ -252,20 +252,10 @@ impl<'a> Environment<'a> {
for (name, level) in closure_vars { for (name, level) in closure_vars {
// special injected variables // special injected variables
if &**name == "this" || &**name == "super" { if name == "this" || name == "super" {
continue; continue;
} }
// let heap_value = self
// .scopes()
// .iter()
// .rev()
// .nth(*level)
// .expect("Closure variable got resolved to level that doesn't exist")
// .get(name)
// .expect("Closure variable was resolved, but could not be found")
// .clone();
let heap_value = self.get_local_heaped(name, *level).unwrap(); let heap_value = self.get_local_heaped(name, *level).unwrap();
let name = name.to_owned(); let name = name.to_owned();

View file

@ -15,7 +15,7 @@ pub struct LoxFunction {
name: SmolStr, name: SmolStr,
closure: ClosureScope, closure: ClosureScope,
param_names: Vec<SmolStr>, param_names: Vec<SmolStr>,
body: Box<Stmt>, body: Stmt,
} }
impl LoxFunction { impl LoxFunction {
@ -26,7 +26,6 @@ impl LoxFunction {
body: Stmt, body: Stmt,
) -> Rc<Self> { ) -> Rc<Self> {
let name = name.into(); let name = name.into();
let body = Box::new(body);
let fun = LoxFunction { let fun = LoxFunction {
name, name,
closure, closure,

View file

@ -182,12 +182,19 @@ impl Eval for Expr {
param_names, param_names,
closure_vars, closure_vars,
body, body,
} => Ok(Value::function(LoxFunction::new( } => {
name.clone(), let name = name.clone();
env.collect_closure(closure_vars), let closure = env.collect_closure(closure_vars);
param_names.clone(), let param_names = param_names.clone();
body.as_ref().clone(), let body = body.as_ref().clone();
))),
Ok(Value::function(LoxFunction::new(
name,
closure,
param_names,
body,
)))
}
Expr::Class { Expr::Class {
superclass, superclass,
name, name,
@ -315,7 +322,7 @@ pub fn call_fun(fun: Rc<LoxFunction>, env: &mut Environment) -> EvalResult<Value
env.enter_scope()?; env.enter_scope()?;
env.define(fun.name(), Value::Function(fun.clone())); env.define(fun.name(), Value::Function(Rc::clone(&fun)));
env.insert_closure(fun.closure().clone()); env.insert_closure(fun.closure().clone());

View file

@ -54,7 +54,7 @@ fn print_globals() -> LoxExternFunction {
globals.sort_by_key(|&(name, _value)| name); globals.sort_by_key(|&(name, _value)| name);
for (name, value) in globals { for (name, value) in globals {
println!("{name}: {value}"); println!("{name} = {value}");
} }
Ok(Value::Nil) Ok(Value::Nil)

View file

@ -8,7 +8,7 @@ use crate::{LoxError, ResolverError};
/*====================================================================================================================*/ /*====================================================================================================================*/
type ResolverScope = FxHashMap<SmolStr, ResolveStatus>; type ResolverScope = FxHashMap<SmolStr, ResolveStatus>;
type ResolverFrame = Vec<ResolverScope>; // type ResolverFrame = Vec<ResolverScope>;
type ResolverResult<T> = Result<T, ResolverError>; type ResolverResult<T> = Result<T, ResolverError>;
@ -20,6 +20,22 @@ enum ResolveStatus {
Defined, Defined,
} }
#[derive(Debug, Default)]
struct ResolverFrame {
scopes: Vec<ResolverScope>,
closure_vars: FxHashMap<SmolStr, usize>,
}
impl ResolverFrame {
fn new() -> Self {
ResolverFrame {
scopes: vec![ResolverScope::default()],
closure_vars: Default::default(),
}
}
}
/*====================================================================================================================*/ /*====================================================================================================================*/
pub fn resolve(statement: &mut Stmt, runtime: &mut Runtime) -> Result<(), LoxError> { pub fn resolve(statement: &mut Stmt, runtime: &mut Runtime) -> Result<(), LoxError> {
@ -35,10 +51,10 @@ pub fn resolve(statement: &mut Stmt, runtime: &mut Runtime) -> Result<(), LoxErr
#[derive(Debug)] #[derive(Debug)]
struct Resolver { struct Resolver {
global_scope: ResolverScope, global_scope: ResolverScope,
// local_scopes: Vec<ResolverScope>, // local_scopes: Vec<ResolverScope>,
frames: Vec<ResolverFrame>, frames: Vec<ResolverFrame>,
// closure_vars: FxHashMap<SmolStr, usize>,
closure_vars: FxHashMap<SmolStr, usize>,
} }
impl Resolver { impl Resolver {
@ -51,19 +67,35 @@ impl Resolver {
Resolver { Resolver {
global_scope, global_scope,
frames: vec![ResolverFrame::new()], frames: vec![ResolverFrame::default()],
closure_vars: FxHashMap::default(), // closure_vars: FxHashMap::default(),
} }
} }
fn local_scopes(&self) -> &ResolverFrame { fn frame(&self) -> &ResolverFrame {
self.frames.last().unwrap() self.frames.last().unwrap()
} }
fn local_scopes_mut(&mut self) -> &mut ResolverFrame { fn frame_mut(&mut self) -> &mut ResolverFrame {
self.frames.last_mut().unwrap() self.frames.last_mut().unwrap()
} }
fn local_scopes(&self) -> &[ResolverScope] {
&self.frame().scopes
}
fn local_scopes_mut(&mut self) -> &mut Vec<ResolverScope> {
&mut self.frame_mut().scopes
}
/* fn closure_vars(&self) -> &FxHashMap<SmolStr, usize> {
&self.frame().closure_vars
} */
fn closure_vars_mut(&mut self) -> &mut FxHashMap<SmolStr, usize> {
&mut self.frame_mut().closure_vars
}
fn enter_scope(&mut self) { fn enter_scope(&mut self) {
self.local_scopes_mut().push(ResolverScope::default()); self.local_scopes_mut().push(ResolverScope::default());
} }
@ -75,7 +107,7 @@ impl Resolver {
} }
fn push_frame(&mut self) { fn push_frame(&mut self) {
self.frames.push(vec![ResolverScope::default()]); self.frames.push(ResolverFrame::new());
} }
fn pop_frame(&mut self) { fn pop_frame(&mut self) {
@ -119,39 +151,51 @@ impl Resolver {
} }
} }
fn resolve_var(&mut self, name: &str) -> ResolverResult<Expr> { // try to recursively resolve closure variable by descending through the frames until the
let mut level = 0; // variable is found and then inserting closures at every needed level
fn resolve_closure(&mut self, name: SmolStr, frame_lvl: usize) -> Option<usize> {
let frame = &mut self.frames[frame_lvl];
for (level, scope) in frame.scopes.iter().rev().enumerate() {
if scope.contains_key(&name) {
return Some(level);
}
}
if frame_lvl == 0 {
return None;
}
if let Some(level) = self.resolve_closure(name.clone(), frame_lvl - 1) {
let frame = &mut self.frames[frame_lvl];
frame.closure_vars.insert(name, level);
Some(frame.scopes.len() - 1)
} else {
None
}
}
fn resolve_var(&mut self, name: SmolStr) -> ResolverResult<Expr> {
// resolve normal local variable // resolve normal local variable
for scope in self.local_scopes().iter().rev() { for (level, scope) in self.local_scopes().iter().rev().enumerate() {
if scope.contains_key(name) { if scope.contains_key(&name) {
return Ok(Expr::local_variable(name, level)); return Ok(Expr::local_variable(name, level));
} }
level += 1;
} }
// resolve closure variable // if we have more than one frame: look for closure variable
for frame in self.frames.iter().rev().skip(1) { if self.frames.len() > 1 {
for scope in frame.iter().rev() { //try to resolve closure variable up from one frame down
if scope.contains_key(name) { if let Some(level) = self.resolve_closure(name.clone(), self.frames.len() - 2) {
if !self.closure_vars.contains_key(name) { self.closure_vars_mut().insert(name.clone(), level);
// the level at which the closed-over variable will be collected from
// from the perspective of the parameter/closure scope i.e. the outmost of the function
self.closure_vars
.insert(name.into(), level - self.local_scopes().len());
}
// distance from actual variable refernce to parameter/closure scope return Ok(Expr::local_variable(name, self.local_scopes().len() - 1));
let level = self.local_scopes().len() - 1;
return Ok(Expr::local_variable(name, level));
}
level += 1;
} }
} }
// resolve global variable // resolve global variable
if self.global_scope.contains_key(name) { if self.global_scope.contains_key(&name) {
return Ok(Expr::global_variable(name)); return Ok(Expr::global_variable(name));
} }
@ -161,8 +205,6 @@ impl Resolver {
return Err(ResolverError::SuperOutsideMethod); return Err(ResolverError::SuperOutsideMethod);
} }
let name = name.into();
Err(ResolverError::UnresolvableVariable { name }) Err(ResolverError::UnresolvableVariable { name })
} }
@ -243,7 +285,7 @@ impl Resolver {
return Err(ResolverError::VarInOwnInitializer { name }); return Err(ResolverError::VarInOwnInitializer { name });
} }
*expr = self.resolve_var(name)?; *expr = self.resolve_var(name.clone())?;
Ok(()) Ok(())
} }
@ -260,7 +302,7 @@ impl Resolver {
let target = target.as_mut(); let target = target.as_mut();
if let Expr::Variable { name } = target { if let Expr::Variable { name } = target {
*target = self.resolve_var(name)?; *target = self.resolve_var(name.clone())?;
} else { } else {
panic!("Invalid assignment target {target}"); panic!("Invalid assignment target {target}");
} }
@ -298,7 +340,7 @@ impl Resolver {
closure_vars, closure_vars,
body, body,
} => { } => {
let old_closure_names = std::mem::take(&mut self.closure_vars); // let old_closure_names = std::mem::take(&mut self.closure_vars);
self.push_frame(); self.push_frame();
@ -314,12 +356,12 @@ impl Resolver {
self.resolve_stmt(body)?; self.resolve_stmt(body)?;
let closure_names = std::mem::take(self.closure_vars_mut());
self.pop_frame(); self.pop_frame();
let closure_names = std::mem::replace(&mut self.closure_vars, old_closure_names); for (var_name, level) in closure_names {
closure_vars.push((var_name, level));
for closure_var in closure_names {
closure_vars.push(closure_var);
} }
Ok(()) Ok(())
@ -344,7 +386,7 @@ impl Resolver {
self.declare("this").unwrap(); self.declare("this").unwrap();
if superclass.is_some() { if superclass.is_some() {
// this should never fail either! same as `this` // this should never fail either! same reason as `this`
self.declare("super").unwrap(); self.declare("super").unwrap();
} }
@ -357,7 +399,7 @@ impl Resolver {
Ok(()) Ok(())
} }
Expr::This => { Expr::This => {
*expr = self.resolve_var("this")?; *expr = self.resolve_var("this".into())?;
Ok(()) Ok(())
} }
Expr::Super { Expr::Super {