From 1af4030dc175fbfad30cb9fdfc8222406b47bb7d Mon Sep 17 00:00:00 2001 From: Greg Date: Mon, 14 Oct 2024 22:53:26 +0300 Subject: [PATCH] scope and binding --- ast_string.go | 16 ++- callable.go | 4 +- env.go | 17 ++++ expr.go | 2 +- glox.go | 36 ++++++- interpreter.go | 42 ++++++-- parser.go | 21 ++-- resolver.go | 188 ++++++++++++++++++++++++++++++++++++ stack.go | 47 +++++++++ stmt.go | 4 +- tests/functions.lox | 2 +- tests/scope_and_binding.lox | 12 +++ 12 files changed, 359 insertions(+), 32 deletions(-) create mode 100644 resolver.go create mode 100644 stack.go create mode 100644 tests/scope_and_binding.lox diff --git a/ast_string.go b/ast_string.go index 6b6a470..9d7c575 100644 --- a/ast_string.go +++ b/ast_string.go @@ -74,7 +74,9 @@ func (as *AstStringer) visitLogical(l *Logical) any { func (as *AstStringer) visitCall(c *Call) any { as.str.WriteString("(call ") c.callee.accept(as) - as.str.WriteString(" ") + if len(c.args) != 0 { + as.str.WriteString(" ") + } for i, arg := range c.args { arg.accept(as) if i < len(c.args)-1 { @@ -98,7 +100,9 @@ func (as *AstStringer) visitLambda(l *Lambda) any { } as.str.WriteString(")") } - l.body.accept(as) + for _, stmt := range l.body { + stmt.accept(as) + } as.str.WriteString(")") return nil @@ -137,7 +141,7 @@ func (as *AstStringer) visitBlockStmt(b *BlockStmt) { func (as *AstStringer) visitIfStmt(i *IfStmt) { as.str.WriteString("(if ") - i.expr.accept(as) + i.cond.accept(as) as.str.WriteString(" ") i.then.accept(as) if i.or != nil { @@ -164,7 +168,7 @@ func (as *AstStringer) visitBreakStmt(b *BreakStmt) { } func (as *AstStringer) visitFunStmt(f *FunStmt) { - as.str.WriteString(fmt.Sprintf("(fun %s", f.name.lexeme)) + as.str.WriteString(fmt.Sprintf("(fun %s ", f.name.lexeme)) if len(f.args) != 0 { as.str.WriteString("(") for i, arg := range f.args { @@ -175,7 +179,9 @@ func (as *AstStringer) visitFunStmt(f *FunStmt) { } as.str.WriteString(")") } - f.body.accept(as) + for _, stmt := range f.body { + stmt.accept(as) + } as.str.WriteString(")") } diff --git a/callable.go b/callable.go index eceb357..8db5da6 100644 --- a/callable.go +++ b/callable.go @@ -8,7 +8,7 @@ type Callable interface { type Function struct { name Token args []Token - body *BlockStmt + body []Stmt closure *Environment } @@ -41,6 +41,6 @@ func (f *Function) arity() int { return len(f.args) } -func newFunction(name Token, args []Token, body *BlockStmt, env *Environment) Callable { +func newFunction(name Token, args []Token, body []Stmt, env *Environment) Callable { return &Function{name, args, body, env} } diff --git a/env.go b/env.go index 1b86e8d..3ca7b57 100644 --- a/env.go +++ b/env.go @@ -44,3 +44,20 @@ func (env *Environment) assign(key Token, val any) *RuntimeError { return env.enclosing.assign(key, val) } + +func (env *Environment) getAt(distance int, key string) any { + return env.ancestor(distance).get(key) +} + +func (env *Environment) assignAt(distance int, key Token, val any) { + env.ancestor(distance).values[key.lexeme] = val +} + +func (env *Environment) ancestor(distance int) *Environment { + parent := env + for i := 0; i < distance; i++ { + parent = parent.enclosing + } + + return parent +} diff --git a/expr.go b/expr.go index 6530d6c..08b72ed 100644 --- a/expr.go +++ b/expr.go @@ -60,7 +60,7 @@ type Call struct { type Lambda struct { name Token args []Token - body *BlockStmt + body []Stmt } func (c *Call) expr() {} diff --git a/glox.go b/glox.go index e25252f..91a3096 100644 --- a/glox.go +++ b/glox.go @@ -28,12 +28,22 @@ func (gl *Glox) runPrompt() { scanner := bufio.NewScanner(os.Stdin) scanner.Split(bufio.ScanLines) + doRun := func(line []byte) { + defer func() { + if err := recover(); err != nil { + log.Println(err) + } + }() + + gl.run(line) + } + for { print("> ") if !scanner.Scan() { break } - gl.run(scanner.Bytes()) + doRun(scanner.Bytes()) } } @@ -48,11 +58,29 @@ func (gl *Glox) runFile(path string) { } func (gl *Glox) run(source []byte) { - tokens, _ := newScanner(source).scan() + tokens, err := newScanner(source).scan() - stmts, _ := newParser(tokens).parse() + if err != nil { + panic(err) + } + + stmts, parseErrs := newParser(tokens).parse() + + if parseErrs != nil { + panic(parseErrs) + } fmt.Println(AstStringer{stmts: stmts}) - gl.Interpreter.interpret(stmts) + resolveErrs := newResolver(gl.Interpreter).resolveStmts(stmts...) + + if resolveErrs != nil { + panic(resolveErrs) + } + + interpreterErrs := gl.Interpreter.interpret(stmts) + + if interpreterErrs != nil { + panic(interpreterErrs) + } } diff --git a/interpreter.go b/interpreter.go index 7cd8c22..85868d7 100644 --- a/interpreter.go +++ b/interpreter.go @@ -1,6 +1,7 @@ package main import ( + "errors" "fmt" "log" "reflect" @@ -10,6 +11,7 @@ import ( type Interpreter struct { env *Environment globals *Environment + locals map[Expr]int errors []error brk bool } @@ -36,12 +38,13 @@ func newInterpreter() *Interpreter { return &Interpreter{ env: globals, globals: globals, + locals: map[Expr]int{}, errors: []error{}, brk: false, } } -func (i *Interpreter) interpret(stmts []Stmt) []error { +func (i *Interpreter) interpret(stmts []Stmt) error { defer i.recover() i.errors = []error{} @@ -50,7 +53,7 @@ func (i *Interpreter) interpret(stmts []Stmt) []error { stmt.accept(i) } - return i.errors + return errors.Join(i.errors...) } func (i *Interpreter) recover() { @@ -135,12 +138,19 @@ func (i *Interpreter) visitUnary(u *Unary) any { } func (i *Interpreter) visitVariable(v *Variable) any { - return i.env.get(v.name.lexeme) + return i.lookUpVariable(v.name, v) } func (i *Interpreter) visitAssignment(a *Assign) any { val := i.evaluate(a.value) - err := i.env.assign(a.variable, val) + distance, isLocal := i.locals[a] + + if isLocal { + i.env.assignAt(distance, a.variable, val) + return val + } + + err := i.globals.assign(a.variable, val) if err != nil { i.panic(err) } @@ -229,10 +239,10 @@ func (i *Interpreter) visitVarStmt(v *VarStmt) { } func (i *Interpreter) visitBlockStmt(b *BlockStmt) { - i.executeBlock(b, newEnvironment(i.env)) + i.executeBlock(b.stmts, newEnvironment(i.env)) } -func (i *Interpreter) executeBlock(b *BlockStmt, current *Environment) { +func (i *Interpreter) executeBlock(stmts []Stmt, current *Environment) { parentEnv := i.env i.env = current @@ -243,7 +253,7 @@ func (i *Interpreter) executeBlock(b *BlockStmt, current *Environment) { i.env = parentEnv }() - for _, stmt := range b.stmts { + for _, stmt := range stmts { if i.brk { break @@ -259,7 +269,7 @@ func (i *Interpreter) visitBreakStmt(b *BreakStmt) { } func (i *Interpreter) visitIfStmt(iff *IfStmt) { - if isTruthy(i.evaluate(iff.expr)) { + if isTruthy(i.evaluate(iff.cond)) { iff.then.accept(i) } else if iff.or != nil { @@ -278,6 +288,8 @@ func (i *Interpreter) visitEnvStmt(e *EnvStmt) { walker = walker.enclosing } + fmt.Printf("globals: %+v\n", *i.globals) + for ident, e := range flatten { fmt.Printf("%*s", ident, "") fmt.Printf("%+v\n", *e) @@ -296,6 +308,20 @@ func (i *Interpreter) visitWhileStmt(w *WhileStmt) { } } +func (i *Interpreter) resolve(expr Expr, depth int) { + i.locals[expr] = depth +} + +func (i *Interpreter) lookUpVariable(name Token, expr Expr) any { + distance, isLocal := i.locals[expr] + + if !isLocal { + return i.globals.get(name.lexeme) + } + + return i.env.getAt(distance, name.lexeme) +} + func (i *Interpreter) panic(re *RuntimeError) { i.errors = append(i.errors, re) log.Println(re) diff --git a/parser.go b/parser.go index 94f01ae..3ff5eb5 100644 --- a/parser.go +++ b/parser.go @@ -1,8 +1,8 @@ package main import ( + "errors" "fmt" - "log" ) type Parser struct { @@ -28,7 +28,7 @@ func newParser(tokens []Token) *Parser { } // program -> declaration* EOF -func (p *Parser) parse() ([]Stmt, []error) { +func (p *Parser) parse() ([]Stmt, error) { defer p.recover() stmts := []Stmt{} @@ -40,7 +40,7 @@ func (p *Parser) parse() ([]Stmt, []error) { } } - return stmts, p.errors + return stmts, errors.Join(p.errors...) } // declaration -> varDecl | funDecl | statement @@ -97,7 +97,7 @@ func (p *Parser) function(kind string) Stmt { p.consume(RIGHT_PAREN, fmt.Sprintf("Expect ')' after %s name.", kind)) p.consume(LEFT_BRACE, fmt.Sprintf("Expect '{' after %s arguments.", kind)) - body := p.blockStmt() + body := p.block() return &FunStmt{name, args, body} } @@ -172,8 +172,7 @@ func (p *Parser) printStmt() Stmt { return &PrintStmt{expr} } -// blockStmt -> "{" statement* "}" -func (p *Parser) blockStmt() *BlockStmt { +func (p *Parser) block() []Stmt { stmts := []Stmt{} for !p.check(RIGHT_BRACE) && !p.isAtEnd() { @@ -182,7 +181,12 @@ func (p *Parser) blockStmt() *BlockStmt { p.consume(RIGHT_BRACE, "Unclosed block: Expected '}'.") - return &BlockStmt{stmts} + return stmts +} + +// blockStmt -> "{" statement* "}" +func (p *Parser) blockStmt() *BlockStmt { + return &BlockStmt{p.block()} } // breakStmt -> break ";" @@ -496,7 +500,7 @@ func (p *Parser) lambda() Expr { p.consume(RIGHT_PAREN, "Expect ')' after lambda arguments.") p.consume(LEFT_BRACE, "Expect '{' before lambda body.") - body := p.blockStmt() + body := p.block() return &Lambda{name, args, body} } @@ -583,7 +587,6 @@ func (p *Parser) recover() { func (p *Parser) panic(pe *ParseError) { p.errors = append(p.errors, pe) - log.Println(pe) panic(pe) } diff --git a/resolver.go b/resolver.go new file mode 100644 index 0000000..3dc15c7 --- /dev/null +++ b/resolver.go @@ -0,0 +1,188 @@ +package main + +type Scope map[string]bool + +type Resolver struct { + interpreter *Interpreter + scopes Stack[Scope] +} + +type ResolveError struct { + msg string +} + +func (r *ResolveError) Error() string { + return r.msg +} + +func newResolver(i *Interpreter) *Resolver { + return &Resolver{i, NewStack[Scope]()} +} + +func (r *Resolver) resolveStmts(stmts ...Stmt) error { + for _, stmt := range stmts { + stmt.accept(r) + } + + return nil +} + +func (r *Resolver) resolveExprs(exprs ...Expr) error { + for _, expr := range exprs { + expr.accept(r) + } + + return nil +} + +func (r *Resolver) beginScope() { + r.scopes.Push(map[string]bool{}) +} + +func (r *Resolver) endScope() { + r.scopes.Pop() +} + +func (r *Resolver) declare(token Token) { + if !r.scopes.Empty() { + r.scopes.Peek()[token.lexeme] = false + } +} + +func (r *Resolver) define(token Token) { + if !r.scopes.Empty() { + r.scopes.Peek()[token.lexeme] = true + } +} + +func (r *Resolver) visitBlockStmt(b *BlockStmt) { + r.beginScope() + r.resolveStmts(b.stmts...) + r.endScope() +} + +func (r *Resolver) visitVarStmt(v *VarStmt) { + r.declare(v.name) + if v.initializer != nil { + r.resolveExprs(v.initializer) + } + r.define(v.name) +} + +func (r *Resolver) visitVariable(v *Variable) any { + if !r.scopes.Empty() { + defined, declared := r.scopes.Peek()[v.name.lexeme] + + if declared && !defined { + panic(&ResolveError{"Can't read local variable in its own initializer."}) + } + } + + r.resolveLocal(v, v.name) + return nil +} + +func (r *Resolver) visitAssignment(a *Assign) any { + r.resolveExprs(a.value) + r.resolveLocal(a, a.variable) + return nil +} + +func (r *Resolver) resolveLocal(expr Expr, name Token) { + for i := r.scopes.Size() - 1; i >= 0; i-- { + if _, exists := r.scopes.At(i)[name.lexeme]; exists { + r.interpreter.resolve(expr, r.scopes.Size()-1-i) + return + } + } +} + +func (r *Resolver) visitFunStmt(fun *FunStmt) { + r.declare(fun.name) + r.define(fun.name) + r.resolveFun(fun) +} + +func (r *Resolver) resolveFun(fun *FunStmt) { + r.beginScope() + for _, arg := range fun.args { + r.declare(arg) + r.define(arg) + } + r.resolveStmts(fun.body...) + r.endScope() +} + +func (r *Resolver) visitExprStmt(es *ExprStmt) { + r.resolveExprs(es.expr) +} + +func (r *Resolver) visitBreakStmt(b *BreakStmt) {} +func (r *Resolver) visitEnvStmt(b *EnvStmt) {} +func (r *Resolver) visitIfStmt(ifs *IfStmt) { + r.resolveExprs(ifs.cond) + r.resolveStmts(ifs.then) + if ifs.or != nil { + r.resolveStmts(ifs.or) + } +} + +func (r *Resolver) visitPrintStmt(p *PrintStmt) { + r.resolveExprs(p.val) +} + +func (r *Resolver) visitReturnStmt(ret *ReturnStmt) { + if ret.value != nil { + r.resolveExprs(ret.value) + } +} + +func (r *Resolver) visitWhileStmt(w *WhileStmt) { + r.resolveExprs(w.cond) + r.resolveStmts(w.body) +} + +func (r *Resolver) visitBinary(b *Binary) any { + r.resolveExprs(b.left) + r.resolveExprs(b.right) + return nil +} + +func (r *Resolver) visitCall(c *Call) any { + r.resolveExprs(c.callee) + for _, arg := range c.args { + r.resolveExprs(arg) + } + return nil +} + +func (r *Resolver) visitGrouping(g *Grouping) any { + r.resolveExprs(g.expression) + return nil +} + +func (r *Resolver) visitLambda(l *Lambda) any { + r.beginScope() + for _, arg := range l.args { + r.declare(arg) + r.define(arg) + } + r.resolveStmts(l.body...) + r.endScope() + return nil +} + +func (r *Resolver) visitLiteral(l *Literal) any { + return nil +} + +func (r *Resolver) visitLogical(l *Logical) any { + r.resolveExprs(l.left) + r.resolveExprs(l.right) + return nil +} + +func (r *Resolver) visitUnary(u *Unary) any { + r.resolveExprs(u.right) + return nil +} diff --git a/stack.go b/stack.go new file mode 100644 index 0000000..627ee36 --- /dev/null +++ b/stack.go @@ -0,0 +1,47 @@ +package main + +type Stack[Item any] interface { + Push(Item) + Pop() Item + Peek() Item + At(int) Item + Size() int + Empty() bool +} + +type node[Item any] struct { + item Item + next *node[Item] +} + +type stack[OfType any] []OfType + +func NewStack[OfType any]() Stack[OfType] { + return &stack[OfType]{} +} + +func (s *stack[Item]) Push(item Item) { + *s = append(*s, item) +} + +func (s *stack[Item]) Pop() Item { + last := s.Peek() + *s = (*s)[:len(*s)-1] + return last +} + +func (s *stack[Item]) At(idx int) Item { + return (*s)[idx] +} + +func (s *stack[Item]) Peek() Item { + return (*s)[len(*s)-1] +} + +func (s *stack[_]) Size() int { + return len(*s) +} + +func (s *stack[_]) Empty() bool { + return s.Size() == 0 +} diff --git a/stmt.go b/stmt.go index b4c6258..99bc9b9 100644 --- a/stmt.go +++ b/stmt.go @@ -39,7 +39,7 @@ type EnvStmt struct{} type IfStmt struct { name Token - expr Expr + cond Expr then Stmt or Stmt } @@ -54,7 +54,7 @@ type BreakStmt struct{} type FunStmt struct { name Token args []Token - body *BlockStmt + body []Stmt } type ReturnStmt struct { diff --git a/tests/functions.lox b/tests/functions.lox index abc0c62..1568247 100644 --- a/tests/functions.lox +++ b/tests/functions.lox @@ -4,6 +4,7 @@ print "native function"; print clock(); fun count(n) { + print n; if (n > 1) count(n - 1); print n; } @@ -43,7 +44,6 @@ for (var i = 1; i <= 10; i = i + 1) { print fib(i); } - fun makeCounter() { var i = 0; diff --git a/tests/scope_and_binding.lox b/tests/scope_and_binding.lox new file mode 100644 index 0000000..8262136 --- /dev/null +++ b/tests/scope_and_binding.lox @@ -0,0 +1,12 @@ +var a = "global"; + +{ + fun showA() { + print a; + } + + showA(); + var a = "inner v2"; + + showA(); +} \ No newline at end of file