scope and binding

This commit is contained in:
Greg 2024-10-14 22:53:26 +03:00
parent a24586e601
commit 1af4030dc1
12 changed files with 359 additions and 32 deletions

View file

@ -74,7 +74,9 @@ func (as *AstStringer) visitLogical(l *Logical) any {
func (as *AstStringer) visitCall(c *Call) any {
as.str.WriteString("(call ")
c.callee.accept(as)
if len(c.args) != 0 {
as.str.WriteString(" ")
}
for i, arg := range c.args {
arg.accept(as)
if i < len(c.args)-1 {
@ -98,7 +100,9 @@ func (as *AstStringer) visitLambda(l *Lambda) any {
}
as.str.WriteString(")")
}
l.body.accept(as)
for _, stmt := range l.body {
stmt.accept(as)
}
as.str.WriteString(")")
return nil
@ -137,7 +141,7 @@ func (as *AstStringer) visitBlockStmt(b *BlockStmt) {
func (as *AstStringer) visitIfStmt(i *IfStmt) {
as.str.WriteString("(if ")
i.expr.accept(as)
i.cond.accept(as)
as.str.WriteString(" ")
i.then.accept(as)
if i.or != nil {
@ -164,7 +168,7 @@ func (as *AstStringer) visitBreakStmt(b *BreakStmt) {
}
func (as *AstStringer) visitFunStmt(f *FunStmt) {
as.str.WriteString(fmt.Sprintf("(fun %s", f.name.lexeme))
as.str.WriteString(fmt.Sprintf("(fun %s ", f.name.lexeme))
if len(f.args) != 0 {
as.str.WriteString("(")
for i, arg := range f.args {
@ -175,7 +179,9 @@ func (as *AstStringer) visitFunStmt(f *FunStmt) {
}
as.str.WriteString(")")
}
f.body.accept(as)
for _, stmt := range f.body {
stmt.accept(as)
}
as.str.WriteString(")")
}

View file

@ -8,7 +8,7 @@ type Callable interface {
type Function struct {
name Token
args []Token
body *BlockStmt
body []Stmt
closure *Environment
}
@ -41,6 +41,6 @@ func (f *Function) arity() int {
return len(f.args)
}
func newFunction(name Token, args []Token, body *BlockStmt, env *Environment) Callable {
func newFunction(name Token, args []Token, body []Stmt, env *Environment) Callable {
return &Function{name, args, body, env}
}

17
env.go
View file

@ -44,3 +44,20 @@ func (env *Environment) assign(key Token, val any) *RuntimeError {
return env.enclosing.assign(key, val)
}
func (env *Environment) getAt(distance int, key string) any {
return env.ancestor(distance).get(key)
}
func (env *Environment) assignAt(distance int, key Token, val any) {
env.ancestor(distance).values[key.lexeme] = val
}
func (env *Environment) ancestor(distance int) *Environment {
parent := env
for i := 0; i < distance; i++ {
parent = parent.enclosing
}
return parent
}

View file

@ -60,7 +60,7 @@ type Call struct {
type Lambda struct {
name Token
args []Token
body *BlockStmt
body []Stmt
}
func (c *Call) expr() {}

36
glox.go
View file

@ -28,12 +28,22 @@ func (gl *Glox) runPrompt() {
scanner := bufio.NewScanner(os.Stdin)
scanner.Split(bufio.ScanLines)
doRun := func(line []byte) {
defer func() {
if err := recover(); err != nil {
log.Println(err)
}
}()
gl.run(line)
}
for {
print("> ")
if !scanner.Scan() {
break
}
gl.run(scanner.Bytes())
doRun(scanner.Bytes())
}
}
@ -48,11 +58,29 @@ func (gl *Glox) runFile(path string) {
}
func (gl *Glox) run(source []byte) {
tokens, _ := newScanner(source).scan()
tokens, err := newScanner(source).scan()
stmts, _ := newParser(tokens).parse()
if err != nil {
panic(err)
}
stmts, parseErrs := newParser(tokens).parse()
if parseErrs != nil {
panic(parseErrs)
}
fmt.Println(AstStringer{stmts: stmts})
gl.Interpreter.interpret(stmts)
resolveErrs := newResolver(gl.Interpreter).resolveStmts(stmts...)
if resolveErrs != nil {
panic(resolveErrs)
}
interpreterErrs := gl.Interpreter.interpret(stmts)
if interpreterErrs != nil {
panic(interpreterErrs)
}
}

View file

@ -1,6 +1,7 @@
package main
import (
"errors"
"fmt"
"log"
"reflect"
@ -10,6 +11,7 @@ import (
type Interpreter struct {
env *Environment
globals *Environment
locals map[Expr]int
errors []error
brk bool
}
@ -36,12 +38,13 @@ func newInterpreter() *Interpreter {
return &Interpreter{
env: globals,
globals: globals,
locals: map[Expr]int{},
errors: []error{},
brk: false,
}
}
func (i *Interpreter) interpret(stmts []Stmt) []error {
func (i *Interpreter) interpret(stmts []Stmt) error {
defer i.recover()
i.errors = []error{}
@ -50,7 +53,7 @@ func (i *Interpreter) interpret(stmts []Stmt) []error {
stmt.accept(i)
}
return i.errors
return errors.Join(i.errors...)
}
func (i *Interpreter) recover() {
@ -135,12 +138,19 @@ func (i *Interpreter) visitUnary(u *Unary) any {
}
func (i *Interpreter) visitVariable(v *Variable) any {
return i.env.get(v.name.lexeme)
return i.lookUpVariable(v.name, v)
}
func (i *Interpreter) visitAssignment(a *Assign) any {
val := i.evaluate(a.value)
err := i.env.assign(a.variable, val)
distance, isLocal := i.locals[a]
if isLocal {
i.env.assignAt(distance, a.variable, val)
return val
}
err := i.globals.assign(a.variable, val)
if err != nil {
i.panic(err)
}
@ -229,10 +239,10 @@ func (i *Interpreter) visitVarStmt(v *VarStmt) {
}
func (i *Interpreter) visitBlockStmt(b *BlockStmt) {
i.executeBlock(b, newEnvironment(i.env))
i.executeBlock(b.stmts, newEnvironment(i.env))
}
func (i *Interpreter) executeBlock(b *BlockStmt, current *Environment) {
func (i *Interpreter) executeBlock(stmts []Stmt, current *Environment) {
parentEnv := i.env
i.env = current
@ -243,7 +253,7 @@ func (i *Interpreter) executeBlock(b *BlockStmt, current *Environment) {
i.env = parentEnv
}()
for _, stmt := range b.stmts {
for _, stmt := range stmts {
if i.brk {
break
@ -259,7 +269,7 @@ func (i *Interpreter) visitBreakStmt(b *BreakStmt) {
}
func (i *Interpreter) visitIfStmt(iff *IfStmt) {
if isTruthy(i.evaluate(iff.expr)) {
if isTruthy(i.evaluate(iff.cond)) {
iff.then.accept(i)
} else if iff.or != nil {
@ -278,6 +288,8 @@ func (i *Interpreter) visitEnvStmt(e *EnvStmt) {
walker = walker.enclosing
}
fmt.Printf("globals: %+v\n", *i.globals)
for ident, e := range flatten {
fmt.Printf("%*s", ident, "")
fmt.Printf("%+v\n", *e)
@ -296,6 +308,20 @@ func (i *Interpreter) visitWhileStmt(w *WhileStmt) {
}
}
func (i *Interpreter) resolve(expr Expr, depth int) {
i.locals[expr] = depth
}
func (i *Interpreter) lookUpVariable(name Token, expr Expr) any {
distance, isLocal := i.locals[expr]
if !isLocal {
return i.globals.get(name.lexeme)
}
return i.env.getAt(distance, name.lexeme)
}
func (i *Interpreter) panic(re *RuntimeError) {
i.errors = append(i.errors, re)
log.Println(re)

View file

@ -1,8 +1,8 @@
package main
import (
"errors"
"fmt"
"log"
)
type Parser struct {
@ -28,7 +28,7 @@ func newParser(tokens []Token) *Parser {
}
// program -> declaration* EOF
func (p *Parser) parse() ([]Stmt, []error) {
func (p *Parser) parse() ([]Stmt, error) {
defer p.recover()
stmts := []Stmt{}
@ -40,7 +40,7 @@ func (p *Parser) parse() ([]Stmt, []error) {
}
}
return stmts, p.errors
return stmts, errors.Join(p.errors...)
}
// declaration -> varDecl | funDecl | statement
@ -97,7 +97,7 @@ func (p *Parser) function(kind string) Stmt {
p.consume(RIGHT_PAREN, fmt.Sprintf("Expect ')' after %s name.", kind))
p.consume(LEFT_BRACE, fmt.Sprintf("Expect '{' after %s arguments.", kind))
body := p.blockStmt()
body := p.block()
return &FunStmt{name, args, body}
}
@ -172,8 +172,7 @@ func (p *Parser) printStmt() Stmt {
return &PrintStmt{expr}
}
// blockStmt -> "{" statement* "}"
func (p *Parser) blockStmt() *BlockStmt {
func (p *Parser) block() []Stmt {
stmts := []Stmt{}
for !p.check(RIGHT_BRACE) && !p.isAtEnd() {
@ -182,7 +181,12 @@ func (p *Parser) blockStmt() *BlockStmt {
p.consume(RIGHT_BRACE, "Unclosed block: Expected '}'.")
return &BlockStmt{stmts}
return stmts
}
// blockStmt -> "{" statement* "}"
func (p *Parser) blockStmt() *BlockStmt {
return &BlockStmt{p.block()}
}
// breakStmt -> break ";"
@ -496,7 +500,7 @@ func (p *Parser) lambda() Expr {
p.consume(RIGHT_PAREN, "Expect ')' after lambda arguments.")
p.consume(LEFT_BRACE, "Expect '{' before lambda body.")
body := p.blockStmt()
body := p.block()
return &Lambda{name, args, body}
}
@ -583,7 +587,6 @@ func (p *Parser) recover() {
func (p *Parser) panic(pe *ParseError) {
p.errors = append(p.errors, pe)
log.Println(pe)
panic(pe)
}

188
resolver.go Normal file
View file

@ -0,0 +1,188 @@
package main
type Scope map[string]bool
type Resolver struct {
interpreter *Interpreter
scopes Stack[Scope]
}
type ResolveError struct {
msg string
}
func (r *ResolveError) Error() string {
return r.msg
}
func newResolver(i *Interpreter) *Resolver {
return &Resolver{i, NewStack[Scope]()}
}
func (r *Resolver) resolveStmts(stmts ...Stmt) error {
for _, stmt := range stmts {
stmt.accept(r)
}
return nil
}
func (r *Resolver) resolveExprs(exprs ...Expr) error {
for _, expr := range exprs {
expr.accept(r)
}
return nil
}
func (r *Resolver) beginScope() {
r.scopes.Push(map[string]bool{})
}
func (r *Resolver) endScope() {
r.scopes.Pop()
}
func (r *Resolver) declare(token Token) {
if !r.scopes.Empty() {
r.scopes.Peek()[token.lexeme] = false
}
}
func (r *Resolver) define(token Token) {
if !r.scopes.Empty() {
r.scopes.Peek()[token.lexeme] = true
}
}
func (r *Resolver) visitBlockStmt(b *BlockStmt) {
r.beginScope()
r.resolveStmts(b.stmts...)
r.endScope()
}
func (r *Resolver) visitVarStmt(v *VarStmt) {
r.declare(v.name)
if v.initializer != nil {
r.resolveExprs(v.initializer)
}
r.define(v.name)
}
func (r *Resolver) visitVariable(v *Variable) any {
if !r.scopes.Empty() {
defined, declared := r.scopes.Peek()[v.name.lexeme]
if declared && !defined {
panic(&ResolveError{"Can't read local variable in its own initializer."})
}
}
r.resolveLocal(v, v.name)
return nil
}
func (r *Resolver) visitAssignment(a *Assign) any {
r.resolveExprs(a.value)
r.resolveLocal(a, a.variable)
return nil
}
func (r *Resolver) resolveLocal(expr Expr, name Token) {
for i := r.scopes.Size() - 1; i >= 0; i-- {
if _, exists := r.scopes.At(i)[name.lexeme]; exists {
r.interpreter.resolve(expr, r.scopes.Size()-1-i)
return
}
}
}
func (r *Resolver) visitFunStmt(fun *FunStmt) {
r.declare(fun.name)
r.define(fun.name)
r.resolveFun(fun)
}
func (r *Resolver) resolveFun(fun *FunStmt) {
r.beginScope()
for _, arg := range fun.args {
r.declare(arg)
r.define(arg)
}
r.resolveStmts(fun.body...)
r.endScope()
}
func (r *Resolver) visitExprStmt(es *ExprStmt) {
r.resolveExprs(es.expr)
}
func (r *Resolver) visitBreakStmt(b *BreakStmt) {}
func (r *Resolver) visitEnvStmt(b *EnvStmt) {}
func (r *Resolver) visitIfStmt(ifs *IfStmt) {
r.resolveExprs(ifs.cond)
r.resolveStmts(ifs.then)
if ifs.or != nil {
r.resolveStmts(ifs.or)
}
}
func (r *Resolver) visitPrintStmt(p *PrintStmt) {
r.resolveExprs(p.val)
}
func (r *Resolver) visitReturnStmt(ret *ReturnStmt) {
if ret.value != nil {
r.resolveExprs(ret.value)
}
}
func (r *Resolver) visitWhileStmt(w *WhileStmt) {
r.resolveExprs(w.cond)
r.resolveStmts(w.body)
}
func (r *Resolver) visitBinary(b *Binary) any {
r.resolveExprs(b.left)
r.resolveExprs(b.right)
return nil
}
func (r *Resolver) visitCall(c *Call) any {
r.resolveExprs(c.callee)
for _, arg := range c.args {
r.resolveExprs(arg)
}
return nil
}
func (r *Resolver) visitGrouping(g *Grouping) any {
r.resolveExprs(g.expression)
return nil
}
func (r *Resolver) visitLambda(l *Lambda) any {
r.beginScope()
for _, arg := range l.args {
r.declare(arg)
r.define(arg)
}
r.resolveStmts(l.body...)
r.endScope()
return nil
}
func (r *Resolver) visitLiteral(l *Literal) any {
return nil
}
func (r *Resolver) visitLogical(l *Logical) any {
r.resolveExprs(l.left)
r.resolveExprs(l.right)
return nil
}
func (r *Resolver) visitUnary(u *Unary) any {
r.resolveExprs(u.right)
return nil
}

47
stack.go Normal file
View file

@ -0,0 +1,47 @@
package main
type Stack[Item any] interface {
Push(Item)
Pop() Item
Peek() Item
At(int) Item
Size() int
Empty() bool
}
type node[Item any] struct {
item Item
next *node[Item]
}
type stack[OfType any] []OfType
func NewStack[OfType any]() Stack[OfType] {
return &stack[OfType]{}
}
func (s *stack[Item]) Push(item Item) {
*s = append(*s, item)
}
func (s *stack[Item]) Pop() Item {
last := s.Peek()
*s = (*s)[:len(*s)-1]
return last
}
func (s *stack[Item]) At(idx int) Item {
return (*s)[idx]
}
func (s *stack[Item]) Peek() Item {
return (*s)[len(*s)-1]
}
func (s *stack[_]) Size() int {
return len(*s)
}
func (s *stack[_]) Empty() bool {
return s.Size() == 0
}

View file

@ -39,7 +39,7 @@ type EnvStmt struct{}
type IfStmt struct {
name Token
expr Expr
cond Expr
then Stmt
or Stmt
}
@ -54,7 +54,7 @@ type BreakStmt struct{}
type FunStmt struct {
name Token
args []Token
body *BlockStmt
body []Stmt
}
type ReturnStmt struct {

View file

@ -4,6 +4,7 @@ print "native function";
print clock();
fun count(n) {
print n;
if (n > 1) count(n - 1);
print n;
}
@ -43,7 +44,6 @@ for (var i = 1; i <= 10; i = i + 1) {
print fib(i);
}
fun makeCounter() {
var i = 0;

View file

@ -0,0 +1,12 @@
var a = "global";
{
fun showA() {
print a;
}
showA();
var a = "inner v2";
showA();
}