diff --git a/ast_string.go b/ast_string.go index 0b88a14..19173ab 100644 --- a/ast_string.go +++ b/ast_string.go @@ -74,8 +74,12 @@ func (as *AstStringer) visitLogical(l *Logical) any { func (as *AstStringer) visitCall(c *Call) any { as.str.WriteString("(call ") c.callee.accept(as) - for _, arg := range c.arguments { + as.str.WriteString(" ") + for i, arg := range c.args { arg.accept(as) + if i < len(c.args)-1 { + as.str.WriteString(" ") + } } as.str.WriteString(")") @@ -145,11 +149,21 @@ func (as *AstStringer) visitFunStmt(f *FunStmt) { as.str.WriteString(fmt.Sprintf("(fun %s", f.name.lexeme)) if len(f.args) != 0 { as.str.WriteString("(") - for _, arg := range f.args { + for i, arg := range f.args { as.str.WriteString(arg.lexeme) + if i < len(f.args)-1 { + as.str.WriteString(" ") + } } as.str.WriteString(")") } + f.body.accept(as) as.str.WriteString(")") } + +func (as *AstStringer) visitReturnStmt(r *ReturnStmt) { + as.str.WriteString("(return ") + r.value.accept(as) + as.str.WriteString(")") +} diff --git a/callable.go b/callable.go index d87dec4..eddffc9 100644 --- a/callable.go +++ b/callable.go @@ -8,11 +8,24 @@ type Callable struct { func newCallable(f *FunStmt) *Callable { return &Callable{ arity: len(f.args), - call: func(i *Interpreter, args ...any) any { + call: func(i *Interpreter, args ...any) (ret any) { + + defer func() { + if err := recover(); err != nil { + re, ok := err.(Return) + + if !ok { + panic(err) + } + + ret = re.val + } + }() + env := newEnvironment(i.globals) for idx, arg := range f.args { - env.set(arg.lexeme, args[idx]) + env.define(arg.lexeme, args[idx]) } i.executeBlock(f.body, env) diff --git a/env.go b/env.go index 4324359..1b86e8d 100644 --- a/env.go +++ b/env.go @@ -1,12 +1,14 @@ package main +import "fmt" + type Environment struct { - values map[string]any - parent *Environment + values map[string]any + enclosing *Environment } -func newEnvironment(parent *Environment) *Environment { - return &Environment{values: map[string]any{}, parent: parent} +func newEnvironment(enclosing *Environment) *Environment { + return &Environment{map[string]any{}, enclosing} } func (env *Environment) get(key string) any { @@ -14,8 +16,8 @@ func (env *Environment) get(key string) any { return found } - if env.parent != nil { - return env.parent.get(key) + if env.enclosing != nil { + return env.enclosing.get(key) } return nil @@ -23,20 +25,22 @@ func (env *Environment) get(key string) any { func (env *Environment) exists(key string) bool { _, ok := env.values[key] - - if !ok && env.parent != nil { - return env.parent.exists(key) - } - return ok - } -func (env *Environment) set(key string, val any) { - - if env.parent != nil && env.parent.exists(key) { - env.parent.set(key, val) - } - +func (env *Environment) define(key string, val any) { env.values[key] = val } + +func (env *Environment) assign(key Token, val any) *RuntimeError { + if env.exists(key.lexeme) { + env.values[key.lexeme] = val + return nil + } + + if env.enclosing == nil { + return &RuntimeError{key, fmt.Sprintf("Can't assign: undefined variable '%s'.", key.lexeme)} + } + + return env.enclosing.assign(key, val) +} diff --git a/expr.go b/expr.go index b8d0cff..d617dbd 100644 --- a/expr.go +++ b/expr.go @@ -51,9 +51,9 @@ type Logical struct { } type Call struct { - callee Expr - paren Token - arguments []Expr + callee Expr + paren Token + args []Expr } func (c *Call) expr() {} diff --git a/expr_rpn.go b/expr_rpn.go deleted file mode 100644 index e988011..0000000 --- a/expr_rpn.go +++ /dev/null @@ -1,72 +0,0 @@ -package main - -import ( - "fmt" - "strings" -) - -type ExprToRPN struct { - str strings.Builder -} - -func (as ExprToRPN) String(expr Expr) string { - - if expr == nil { - return "" - } - - expr.accept(&as) - return as.str.String() -} - -func (as *ExprToRPN) visitBinary(b *Binary) any { - b.left.accept(as) - as.str.WriteString(" ") - b.right.accept(as) - as.str.WriteString(" ") - as.str.WriteString(b.op.lexeme) - return nil -} - -func (as *ExprToRPN) visitLiteral(l *Literal) any { - as.str.WriteString(fmt.Sprintf("%v", l.value)) - return nil -} - -func (as *ExprToRPN) visitGrouping(g *Grouping) any { - g.expression.accept(as) - as.str.WriteString(" group") - return nil -} - -func (as *ExprToRPN) visitUnary(u *Unary) any { - u.right.accept(as) - as.str.WriteString(fmt.Sprintf(" %s", u.op.lexeme)) - return nil -} - -func (as *ExprToRPN) visitVariable(va *Variable) any { - as.str.WriteString(va.name.lexeme) - return nil -} - -func (as *ExprToRPN) visitAssignment(a *Assign) any { - as.str.WriteString(fmt.Sprintf("%v %s =", a.value, a.variable.lexeme)) - return nil -} - -func (as *ExprToRPN) visitLogical(lo *Logical) any { - lo.left.accept(as) - lo.right.accept(as) - as.str.WriteString(" or") - return nil -} - -func (as *ExprToRPN) visitCall(c *Call) any { - for _, arg := range c.arguments { - arg.accept(as) - } - c.callee.accept(as) - as.str.WriteString(" call") - return nil -} diff --git a/globals.go b/globals.go index f9412af..437ce88 100644 --- a/globals.go +++ b/globals.go @@ -4,7 +4,7 @@ import "time" func defineGlobals(env *Environment) { - env.set("clock", &Callable{ + env.define("clock", &Callable{ arity: 0, call: func(i *Interpreter, arg ...any) any { return time.Now().Unix() diff --git a/interpreter.go b/interpreter.go index fd82d80..62d9945 100644 --- a/interpreter.go +++ b/interpreter.go @@ -19,6 +19,10 @@ type RuntimeError struct { msg string } +type Return struct { + val any +} + func (re *RuntimeError) Error() string { return fmt.Sprintf("RuntimeError [%d][%s] Error: %s", re.token.line, re.token.typ, re.msg) } @@ -131,25 +135,15 @@ func (i *Interpreter) visitUnary(u *Unary) any { } func (i *Interpreter) visitVariable(v *Variable) any { - - if !i.env.exists(v.name.lexeme) { - i.panic(&RuntimeError{v.name, fmt.Sprintf("Can't assign: undefined variable '%s'.", v.name.lexeme)}) - } - - val := i.env.get(v.name.lexeme) - return val + return i.env.get(v.name.lexeme) } func (i *Interpreter) visitAssignment(a *Assign) any { - - if !i.env.exists(a.variable.lexeme) { - i.panic(&RuntimeError{a.variable, fmt.Sprintf("Can't assign: undefined variable '%s'.", a.variable.lexeme)}) - } - val := i.evaluate(a.value) - - i.env.set(a.variable.lexeme, val) - + err := i.env.assign(a.variable, val) + if err != nil { + i.panic(err) + } return val } @@ -173,7 +167,7 @@ func (i *Interpreter) visitCall(c *Call) any { args := []any{} - for _, arg := range c.arguments { + for _, arg := range c.args { args = append(args, i.evaluate(arg)) } @@ -198,7 +192,17 @@ func (i *Interpreter) visitCall(c *Call) any { } func (i *Interpreter) visitFunStmt(f *FunStmt) { - i.env.set(f.name.lexeme, newCallable(f)) + i.env.define(f.name.lexeme, newCallable(f)) +} + +func (i *Interpreter) visitReturnStmt(r *ReturnStmt) { + var value any + + if r.value != nil { + value = i.evaluate(r.value) + } + + panic(Return{value}) } func (i *Interpreter) visitPrintStmt(p *PrintStmt) { @@ -217,7 +221,7 @@ func (i *Interpreter) visitVarStmt(v *VarStmt) { val = i.evaluate(v.initializer) } - i.env.set(v.name.lexeme, val) + i.env.define(v.name.lexeme, val) } func (i *Interpreter) visitBlockStmt(b *BlockStmt) { @@ -229,6 +233,12 @@ func (i *Interpreter) executeBlock(b *BlockStmt, current *Environment) { parentEnv := i.env i.env = current + // need to restore environment after + // panic(Return) in visitReturnStmt + defer func() { + i.env = parentEnv + }() + for _, stmt := range b.stmts { if i.brk { @@ -238,7 +248,6 @@ func (i *Interpreter) executeBlock(b *BlockStmt, current *Environment) { stmt.accept(i) } - i.env = parentEnv } func (i *Interpreter) visitBreakStmt(b *BreakStmt) { @@ -262,7 +271,7 @@ func (i *Interpreter) visitEnvStmt(e *EnvStmt) { for walker != nil { flatten = slices.Insert(flatten, 0, walker) - walker = walker.parent + walker = walker.enclosing } for ident, e := range flatten { diff --git a/parser.go b/parser.go index 44826e0..3b6ad76 100644 --- a/parser.go +++ b/parser.go @@ -110,6 +110,7 @@ func (p *Parser) function(kind string) Stmt { // | blockStmt // | breakStmt // | ifStmt +// | returnStmt // | env func (p *Parser) statement() Stmt { if p.match(PRINT) { @@ -140,6 +141,10 @@ func (p *Parser) statement() Stmt { return p.breakStmt() } + if p.match(RETURN) { + return p.returnStmt() + } + return p.exprStmt() } @@ -272,6 +277,13 @@ func (p *Parser) envStmt() Stmt { return &EnvStmt{} } +// return -> "return" expression ";" +func (p *Parser) returnStmt() Stmt { + ret := p.expression() + p.consume(SEMICOLON, "Expect ';' after return;") + return &ReturnStmt{ret} +} + // expression -> assignment func (p *Parser) expression() Expr { return p.assignment() diff --git a/stmt.go b/stmt.go index 4d8c064..b4c6258 100644 --- a/stmt.go +++ b/stmt.go @@ -10,6 +10,7 @@ type StmtVisitor interface { visitBlockStmt(b *BlockStmt) visitWhileStmt(w *WhileStmt) visitBreakStmt(b *BreakStmt) + visitReturnStmt(r *ReturnStmt) } type Stmt interface { @@ -56,15 +57,20 @@ type FunStmt struct { body *BlockStmt } -func (i *IfStmt) stmt() {} -func (f *FunStmt) stmt() {} -func (e *EnvStmt) stmt() {} -func (vs *VarStmt) stmt() {} -func (es *ExprStmt) stmt() {} -func (p *PrintStmt) stmt() {} -func (b *BlockStmt) stmt() {} -func (w *WhileStmt) stmt() {} -func (b *BreakStmt) stmt() {} +type ReturnStmt struct { + value Expr +} + +func (i *IfStmt) stmt() {} +func (f *FunStmt) stmt() {} +func (e *EnvStmt) stmt() {} +func (vs *VarStmt) stmt() {} +func (es *ExprStmt) stmt() {} +func (p *PrintStmt) stmt() {} +func (b *BlockStmt) stmt() {} +func (w *WhileStmt) stmt() {} +func (b *BreakStmt) stmt() {} +func (r *ReturnStmt) stmt() {} func (p *PrintStmt) accept(v StmtVisitor) { v.visitPrintStmt(p) @@ -101,3 +107,7 @@ func (b *BreakStmt) accept(v StmtVisitor) { func (f *FunStmt) accept(v StmtVisitor) { v.visitFunStmt(f) } + +func (r *ReturnStmt) accept(v StmtVisitor) { + v.visitReturnStmt(r) +} diff --git a/tests/functions.lox b/tests/functions.lox index 8b2a1c6..f8ccf07 100644 --- a/tests/functions.lox +++ b/tests/functions.lox @@ -18,9 +18,27 @@ hi("John", "Doe"); fun re(turn) { - print "before retrun"; + print "before return"; return turn; print "should not be printed"; } -print re("print"); \ No newline at end of file +print re("turn"); + + +fun sum(start, end) { + if (start == end) return start; + return start + sum(start + 1, end); +} + +print sum(1, 3); + +fun fib(n) { + if (n <= 1) return n; + + return fib(n - 2) + fib(n - 1); +} + +for (var i = 1; i <= 10; i = i + 1) { + print fib(i); +} \ No newline at end of file