diff --git a/ast_string.go b/ast_string.go index 19173ab..6b6a470 100644 --- a/ast_string.go +++ b/ast_string.go @@ -86,6 +86,24 @@ func (as *AstStringer) visitCall(c *Call) any { return nil } +func (as *AstStringer) visitLambda(l *Lambda) any { + as.str.WriteString("(lambda ") + if len(l.args) != 0 { + as.str.WriteString("(") + for i, arg := range l.args { + as.str.WriteString(arg.lexeme) + if i < len(l.args)-1 { + as.str.WriteString(" ") + } + } + as.str.WriteString(")") + } + l.body.accept(as) + as.str.WriteString(")") + + return nil +} + func (as *AstStringer) visitPrintStmt(p *PrintStmt) { as.str.WriteString("(print ") p.val.accept(as) @@ -159,7 +177,6 @@ func (as *AstStringer) visitFunStmt(f *FunStmt) { } f.body.accept(as) as.str.WriteString(")") - } func (as *AstStringer) visitReturnStmt(r *ReturnStmt) { diff --git a/callable.go b/callable.go index be3cc74..eceb357 100644 --- a/callable.go +++ b/callable.go @@ -6,8 +6,10 @@ type Callable interface { } type Function struct { - definition *FunStmt - closure *Environment + name Token + args []Token + body *BlockStmt + closure *Environment } func (f *Function) call(i *Interpreter, args ...any) (ret any) { @@ -26,19 +28,19 @@ func (f *Function) call(i *Interpreter, args ...any) (ret any) { env := newEnvironment(f.closure) - for idx, arg := range f.definition.args { + for idx, arg := range f.args { env.define(arg.lexeme, args[idx]) } - i.executeBlock(f.definition.body, env) + i.executeBlock(f.body, env) return nil } func (f *Function) arity() int { - return len(f.definition.args) + return len(f.args) } -func newFunction(fun *FunStmt, env *Environment) Callable { - return &Function{fun, env} +func newFunction(name Token, args []Token, body *BlockStmt, env *Environment) Callable { + return &Function{name, args, body, env} } diff --git a/expr.go b/expr.go index d617dbd..6530d6c 100644 --- a/expr.go +++ b/expr.go @@ -3,6 +3,7 @@ package main type ExprVisitor interface { visitCall(c *Call) any visitUnary(u *Unary) any + visitLambda(l *Lambda) any visitBinary(b *Binary) any visitLiteral(l *Literal) any visitGrouping(g *Grouping) any @@ -56,10 +57,17 @@ type Call struct { args []Expr } +type Lambda struct { + name Token + args []Token + body *BlockStmt +} + func (c *Call) expr() {} func (u *Unary) expr() {} func (a *Assign) expr() {} func (b *Binary) expr() {} +func (l *Lambda) expr() {} func (l *Literal) expr() {} func (g *Grouping) expr() {} func (v *Variable) expr() {} @@ -96,3 +104,7 @@ func (l *Logical) accept(v ExprVisitor) any { func (c *Call) accept(v ExprVisitor) any { return v.visitCall(c) } + +func (l *Lambda) accept(v ExprVisitor) any { + return v.visitLambda(l) +} diff --git a/interpreter.go b/interpreter.go index debd954..7cd8c22 100644 --- a/interpreter.go +++ b/interpreter.go @@ -192,7 +192,11 @@ func (i *Interpreter) visitCall(c *Call) any { } func (i *Interpreter) visitFunStmt(f *FunStmt) { - i.env.define(f.name.lexeme, newFunction(f, i.env)) + i.env.define(f.name.lexeme, newFunction(f.name, f.args, f.body, i.env)) +} + +func (i *Interpreter) visitLambda(l *Lambda) any { + return newFunction(l.name, l.args, l.body, i.env) } func (i *Interpreter) visitReturnStmt(r *ReturnStmt) { diff --git a/parser.go b/parser.go index 3b6ad76..94f01ae 100644 --- a/parser.go +++ b/parser.go @@ -431,7 +431,15 @@ func (p *Parser) arguments(callee Expr) Expr { return &Call{callee, paren, arguments} } -// primary -> NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")" | IDENTIFIER +// primary -> IDENTIFIER +// +// | NUMBER +// | STRING +// | "true" +// | "false" +// | "nil" +// | "(" expression ")" +// | lambda func (p *Parser) primary() Expr { switch { case p.match(FALSE): @@ -442,6 +450,10 @@ func (p *Parser) primary() Expr { return &Literal{nil} } + if p.match(FUN) { + return p.lambda() + } + if p.match(NUMBER, STRING) { return &Literal{p.previous().literal} } @@ -461,6 +473,34 @@ func (p *Parser) primary() Expr { return nil } +func (p *Parser) lambda() Expr { + name := p.previous() + + p.consume(LEFT_PAREN, "Expect '(' before lambda arguments.") + + args := []Token{} + for !p.check(RIGHT_PAREN) { + args = append( + args, + p.consume( + IDENTIFIER, + "Expect lambda argument.", + ), + ) + + if p.check(COMMA) { + p.advance() + } + } + + p.consume(RIGHT_PAREN, "Expect ')' after lambda arguments.") + p.consume(LEFT_BRACE, "Expect '{' before lambda body.") + + body := p.blockStmt() + + return &Lambda{name, args, body} +} + func (p *Parser) previous() Token { return p.tokens[p.current-1] } diff --git a/tests/functions.lox b/tests/functions.lox index 0865723..abc0c62 100644 --- a/tests/functions.lox +++ b/tests/functions.lox @@ -58,4 +58,14 @@ fun makeCounter() { var counter = makeCounter(); counter(); -counter(); \ No newline at end of file +counter(); + +fun thrice(fn) { + for (var i = 1; i <= 3; i = i + 1) { + fn(i); + } +} + +thrice(fun (a) { print a; }); + +print fun () { return "hello, "; }() + "world"; \ No newline at end of file