user defined funcs

This commit is contained in:
Greg 2024-10-11 17:01:12 +03:00
parent ce0492c489
commit 1fdd522c8f
10 changed files with 137 additions and 129 deletions

View file

@ -74,8 +74,12 @@ func (as *AstStringer) visitLogical(l *Logical) any {
func (as *AstStringer) visitCall(c *Call) any { func (as *AstStringer) visitCall(c *Call) any {
as.str.WriteString("(call ") as.str.WriteString("(call ")
c.callee.accept(as) c.callee.accept(as)
for _, arg := range c.arguments { as.str.WriteString(" ")
for i, arg := range c.args {
arg.accept(as) arg.accept(as)
if i < len(c.args)-1 {
as.str.WriteString(" ")
}
} }
as.str.WriteString(")") as.str.WriteString(")")
@ -145,11 +149,21 @@ func (as *AstStringer) visitFunStmt(f *FunStmt) {
as.str.WriteString(fmt.Sprintf("(fun %s", f.name.lexeme)) as.str.WriteString(fmt.Sprintf("(fun %s", f.name.lexeme))
if len(f.args) != 0 { if len(f.args) != 0 {
as.str.WriteString("(") as.str.WriteString("(")
for _, arg := range f.args { for i, arg := range f.args {
as.str.WriteString(arg.lexeme) as.str.WriteString(arg.lexeme)
if i < len(f.args)-1 {
as.str.WriteString(" ")
}
} }
as.str.WriteString(")") as.str.WriteString(")")
} }
f.body.accept(as)
as.str.WriteString(")") as.str.WriteString(")")
} }
func (as *AstStringer) visitReturnStmt(r *ReturnStmt) {
as.str.WriteString("(return ")
r.value.accept(as)
as.str.WriteString(")")
}

View file

@ -8,11 +8,24 @@ type Callable struct {
func newCallable(f *FunStmt) *Callable { func newCallable(f *FunStmt) *Callable {
return &Callable{ return &Callable{
arity: len(f.args), arity: len(f.args),
call: func(i *Interpreter, args ...any) any { call: func(i *Interpreter, args ...any) (ret any) {
defer func() {
if err := recover(); err != nil {
re, ok := err.(Return)
if !ok {
panic(err)
}
ret = re.val
}
}()
env := newEnvironment(i.globals) env := newEnvironment(i.globals)
for idx, arg := range f.args { for idx, arg := range f.args {
env.set(arg.lexeme, args[idx]) env.define(arg.lexeme, args[idx])
} }
i.executeBlock(f.body, env) i.executeBlock(f.body, env)

38
env.go
View file

@ -1,12 +1,14 @@
package main package main
import "fmt"
type Environment struct { type Environment struct {
values map[string]any values map[string]any
parent *Environment enclosing *Environment
} }
func newEnvironment(parent *Environment) *Environment { func newEnvironment(enclosing *Environment) *Environment {
return &Environment{values: map[string]any{}, parent: parent} return &Environment{map[string]any{}, enclosing}
} }
func (env *Environment) get(key string) any { func (env *Environment) get(key string) any {
@ -14,8 +16,8 @@ func (env *Environment) get(key string) any {
return found return found
} }
if env.parent != nil { if env.enclosing != nil {
return env.parent.get(key) return env.enclosing.get(key)
} }
return nil return nil
@ -23,20 +25,22 @@ func (env *Environment) get(key string) any {
func (env *Environment) exists(key string) bool { func (env *Environment) exists(key string) bool {
_, ok := env.values[key] _, ok := env.values[key]
if !ok && env.parent != nil {
return env.parent.exists(key)
}
return ok return ok
} }
func (env *Environment) set(key string, val any) { func (env *Environment) define(key string, val any) {
if env.parent != nil && env.parent.exists(key) {
env.parent.set(key, val)
}
env.values[key] = val env.values[key] = val
} }
func (env *Environment) assign(key Token, val any) *RuntimeError {
if env.exists(key.lexeme) {
env.values[key.lexeme] = val
return nil
}
if env.enclosing == nil {
return &RuntimeError{key, fmt.Sprintf("Can't assign: undefined variable '%s'.", key.lexeme)}
}
return env.enclosing.assign(key, val)
}

View file

@ -53,7 +53,7 @@ type Logical struct {
type Call struct { type Call struct {
callee Expr callee Expr
paren Token paren Token
arguments []Expr args []Expr
} }
func (c *Call) expr() {} func (c *Call) expr() {}

View file

@ -1,72 +0,0 @@
package main
import (
"fmt"
"strings"
)
type ExprToRPN struct {
str strings.Builder
}
func (as ExprToRPN) String(expr Expr) string {
if expr == nil {
return ""
}
expr.accept(&as)
return as.str.String()
}
func (as *ExprToRPN) visitBinary(b *Binary) any {
b.left.accept(as)
as.str.WriteString(" ")
b.right.accept(as)
as.str.WriteString(" ")
as.str.WriteString(b.op.lexeme)
return nil
}
func (as *ExprToRPN) visitLiteral(l *Literal) any {
as.str.WriteString(fmt.Sprintf("%v", l.value))
return nil
}
func (as *ExprToRPN) visitGrouping(g *Grouping) any {
g.expression.accept(as)
as.str.WriteString(" group")
return nil
}
func (as *ExprToRPN) visitUnary(u *Unary) any {
u.right.accept(as)
as.str.WriteString(fmt.Sprintf(" %s", u.op.lexeme))
return nil
}
func (as *ExprToRPN) visitVariable(va *Variable) any {
as.str.WriteString(va.name.lexeme)
return nil
}
func (as *ExprToRPN) visitAssignment(a *Assign) any {
as.str.WriteString(fmt.Sprintf("%v %s =", a.value, a.variable.lexeme))
return nil
}
func (as *ExprToRPN) visitLogical(lo *Logical) any {
lo.left.accept(as)
lo.right.accept(as)
as.str.WriteString(" or")
return nil
}
func (as *ExprToRPN) visitCall(c *Call) any {
for _, arg := range c.arguments {
arg.accept(as)
}
c.callee.accept(as)
as.str.WriteString(" call")
return nil
}

View file

@ -4,7 +4,7 @@ import "time"
func defineGlobals(env *Environment) { func defineGlobals(env *Environment) {
env.set("clock", &Callable{ env.define("clock", &Callable{
arity: 0, arity: 0,
call: func(i *Interpreter, arg ...any) any { call: func(i *Interpreter, arg ...any) any {
return time.Now().Unix() return time.Now().Unix()

View file

@ -19,6 +19,10 @@ type RuntimeError struct {
msg string msg string
} }
type Return struct {
val any
}
func (re *RuntimeError) Error() string { func (re *RuntimeError) Error() string {
return fmt.Sprintf("RuntimeError [%d][%s] Error: %s", re.token.line, re.token.typ, re.msg) return fmt.Sprintf("RuntimeError [%d][%s] Error: %s", re.token.line, re.token.typ, re.msg)
} }
@ -131,25 +135,15 @@ func (i *Interpreter) visitUnary(u *Unary) any {
} }
func (i *Interpreter) visitVariable(v *Variable) any { func (i *Interpreter) visitVariable(v *Variable) any {
return i.env.get(v.name.lexeme)
if !i.env.exists(v.name.lexeme) {
i.panic(&RuntimeError{v.name, fmt.Sprintf("Can't assign: undefined variable '%s'.", v.name.lexeme)})
}
val := i.env.get(v.name.lexeme)
return val
} }
func (i *Interpreter) visitAssignment(a *Assign) any { func (i *Interpreter) visitAssignment(a *Assign) any {
if !i.env.exists(a.variable.lexeme) {
i.panic(&RuntimeError{a.variable, fmt.Sprintf("Can't assign: undefined variable '%s'.", a.variable.lexeme)})
}
val := i.evaluate(a.value) val := i.evaluate(a.value)
err := i.env.assign(a.variable, val)
i.env.set(a.variable.lexeme, val) if err != nil {
i.panic(err)
}
return val return val
} }
@ -173,7 +167,7 @@ func (i *Interpreter) visitCall(c *Call) any {
args := []any{} args := []any{}
for _, arg := range c.arguments { for _, arg := range c.args {
args = append(args, i.evaluate(arg)) args = append(args, i.evaluate(arg))
} }
@ -198,7 +192,17 @@ func (i *Interpreter) visitCall(c *Call) any {
} }
func (i *Interpreter) visitFunStmt(f *FunStmt) { func (i *Interpreter) visitFunStmt(f *FunStmt) {
i.env.set(f.name.lexeme, newCallable(f)) i.env.define(f.name.lexeme, newCallable(f))
}
func (i *Interpreter) visitReturnStmt(r *ReturnStmt) {
var value any
if r.value != nil {
value = i.evaluate(r.value)
}
panic(Return{value})
} }
func (i *Interpreter) visitPrintStmt(p *PrintStmt) { func (i *Interpreter) visitPrintStmt(p *PrintStmt) {
@ -217,7 +221,7 @@ func (i *Interpreter) visitVarStmt(v *VarStmt) {
val = i.evaluate(v.initializer) val = i.evaluate(v.initializer)
} }
i.env.set(v.name.lexeme, val) i.env.define(v.name.lexeme, val)
} }
func (i *Interpreter) visitBlockStmt(b *BlockStmt) { func (i *Interpreter) visitBlockStmt(b *BlockStmt) {
@ -229,6 +233,12 @@ func (i *Interpreter) executeBlock(b *BlockStmt, current *Environment) {
parentEnv := i.env parentEnv := i.env
i.env = current i.env = current
// need to restore environment after
// panic(Return) in visitReturnStmt
defer func() {
i.env = parentEnv
}()
for _, stmt := range b.stmts { for _, stmt := range b.stmts {
if i.brk { if i.brk {
@ -238,7 +248,6 @@ func (i *Interpreter) executeBlock(b *BlockStmt, current *Environment) {
stmt.accept(i) stmt.accept(i)
} }
i.env = parentEnv
} }
func (i *Interpreter) visitBreakStmt(b *BreakStmt) { func (i *Interpreter) visitBreakStmt(b *BreakStmt) {
@ -262,7 +271,7 @@ func (i *Interpreter) visitEnvStmt(e *EnvStmt) {
for walker != nil { for walker != nil {
flatten = slices.Insert(flatten, 0, walker) flatten = slices.Insert(flatten, 0, walker)
walker = walker.parent walker = walker.enclosing
} }
for ident, e := range flatten { for ident, e := range flatten {

View file

@ -110,6 +110,7 @@ func (p *Parser) function(kind string) Stmt {
// | blockStmt // | blockStmt
// | breakStmt // | breakStmt
// | ifStmt // | ifStmt
// | returnStmt
// | env // | env
func (p *Parser) statement() Stmt { func (p *Parser) statement() Stmt {
if p.match(PRINT) { if p.match(PRINT) {
@ -140,6 +141,10 @@ func (p *Parser) statement() Stmt {
return p.breakStmt() return p.breakStmt()
} }
if p.match(RETURN) {
return p.returnStmt()
}
return p.exprStmt() return p.exprStmt()
} }
@ -272,6 +277,13 @@ func (p *Parser) envStmt() Stmt {
return &EnvStmt{} return &EnvStmt{}
} }
// return -> "return" expression ";"
func (p *Parser) returnStmt() Stmt {
ret := p.expression()
p.consume(SEMICOLON, "Expect ';' after return;")
return &ReturnStmt{ret}
}
// expression -> assignment // expression -> assignment
func (p *Parser) expression() Expr { func (p *Parser) expression() Expr {
return p.assignment() return p.assignment()

10
stmt.go
View file

@ -10,6 +10,7 @@ type StmtVisitor interface {
visitBlockStmt(b *BlockStmt) visitBlockStmt(b *BlockStmt)
visitWhileStmt(w *WhileStmt) visitWhileStmt(w *WhileStmt)
visitBreakStmt(b *BreakStmt) visitBreakStmt(b *BreakStmt)
visitReturnStmt(r *ReturnStmt)
} }
type Stmt interface { type Stmt interface {
@ -56,6 +57,10 @@ type FunStmt struct {
body *BlockStmt body *BlockStmt
} }
type ReturnStmt struct {
value Expr
}
func (i *IfStmt) stmt() {} func (i *IfStmt) stmt() {}
func (f *FunStmt) stmt() {} func (f *FunStmt) stmt() {}
func (e *EnvStmt) stmt() {} func (e *EnvStmt) stmt() {}
@ -65,6 +70,7 @@ func (p *PrintStmt) stmt() {}
func (b *BlockStmt) stmt() {} func (b *BlockStmt) stmt() {}
func (w *WhileStmt) stmt() {} func (w *WhileStmt) stmt() {}
func (b *BreakStmt) stmt() {} func (b *BreakStmt) stmt() {}
func (r *ReturnStmt) stmt() {}
func (p *PrintStmt) accept(v StmtVisitor) { func (p *PrintStmt) accept(v StmtVisitor) {
v.visitPrintStmt(p) v.visitPrintStmt(p)
@ -101,3 +107,7 @@ func (b *BreakStmt) accept(v StmtVisitor) {
func (f *FunStmt) accept(v StmtVisitor) { func (f *FunStmt) accept(v StmtVisitor) {
v.visitFunStmt(f) v.visitFunStmt(f)
} }
func (r *ReturnStmt) accept(v StmtVisitor) {
v.visitReturnStmt(r)
}

View file

@ -18,9 +18,27 @@ hi("John", "Doe");
fun re(turn) { fun re(turn) {
print "before retrun"; print "before return";
return turn; return turn;
print "should not be printed"; print "should not be printed";
} }
print re("print"); print re("turn");
fun sum(start, end) {
if (start == end) return start;
return start + sum(start + 1, end);
}
print sum(1, 3);
fun fib(n) {
if (n <= 1) return n;
return fib(n - 2) + fib(n - 1);
}
for (var i = 1; i <= 10; i = i + 1) {
print fib(i);
}