This commit is contained in:
Greg 2024-10-12 14:48:27 +03:00
parent 1117f2c104
commit a24586e601
6 changed files with 96 additions and 11 deletions

View file

@ -86,6 +86,24 @@ func (as *AstStringer) visitCall(c *Call) any {
return nil return nil
} }
func (as *AstStringer) visitLambda(l *Lambda) any {
as.str.WriteString("(lambda ")
if len(l.args) != 0 {
as.str.WriteString("(")
for i, arg := range l.args {
as.str.WriteString(arg.lexeme)
if i < len(l.args)-1 {
as.str.WriteString(" ")
}
}
as.str.WriteString(")")
}
l.body.accept(as)
as.str.WriteString(")")
return nil
}
func (as *AstStringer) visitPrintStmt(p *PrintStmt) { func (as *AstStringer) visitPrintStmt(p *PrintStmt) {
as.str.WriteString("(print ") as.str.WriteString("(print ")
p.val.accept(as) p.val.accept(as)
@ -159,7 +177,6 @@ func (as *AstStringer) visitFunStmt(f *FunStmt) {
} }
f.body.accept(as) f.body.accept(as)
as.str.WriteString(")") as.str.WriteString(")")
} }
func (as *AstStringer) visitReturnStmt(r *ReturnStmt) { func (as *AstStringer) visitReturnStmt(r *ReturnStmt) {

View file

@ -6,8 +6,10 @@ type Callable interface {
} }
type Function struct { type Function struct {
definition *FunStmt name Token
closure *Environment args []Token
body *BlockStmt
closure *Environment
} }
func (f *Function) call(i *Interpreter, args ...any) (ret any) { func (f *Function) call(i *Interpreter, args ...any) (ret any) {
@ -26,19 +28,19 @@ func (f *Function) call(i *Interpreter, args ...any) (ret any) {
env := newEnvironment(f.closure) env := newEnvironment(f.closure)
for idx, arg := range f.definition.args { for idx, arg := range f.args {
env.define(arg.lexeme, args[idx]) env.define(arg.lexeme, args[idx])
} }
i.executeBlock(f.definition.body, env) i.executeBlock(f.body, env)
return nil return nil
} }
func (f *Function) arity() int { func (f *Function) arity() int {
return len(f.definition.args) return len(f.args)
} }
func newFunction(fun *FunStmt, env *Environment) Callable { func newFunction(name Token, args []Token, body *BlockStmt, env *Environment) Callable {
return &Function{fun, env} return &Function{name, args, body, env}
} }

12
expr.go
View file

@ -3,6 +3,7 @@ package main
type ExprVisitor interface { type ExprVisitor interface {
visitCall(c *Call) any visitCall(c *Call) any
visitUnary(u *Unary) any visitUnary(u *Unary) any
visitLambda(l *Lambda) any
visitBinary(b *Binary) any visitBinary(b *Binary) any
visitLiteral(l *Literal) any visitLiteral(l *Literal) any
visitGrouping(g *Grouping) any visitGrouping(g *Grouping) any
@ -56,10 +57,17 @@ type Call struct {
args []Expr args []Expr
} }
type Lambda struct {
name Token
args []Token
body *BlockStmt
}
func (c *Call) expr() {} func (c *Call) expr() {}
func (u *Unary) expr() {} func (u *Unary) expr() {}
func (a *Assign) expr() {} func (a *Assign) expr() {}
func (b *Binary) expr() {} func (b *Binary) expr() {}
func (l *Lambda) expr() {}
func (l *Literal) expr() {} func (l *Literal) expr() {}
func (g *Grouping) expr() {} func (g *Grouping) expr() {}
func (v *Variable) expr() {} func (v *Variable) expr() {}
@ -96,3 +104,7 @@ func (l *Logical) accept(v ExprVisitor) any {
func (c *Call) accept(v ExprVisitor) any { func (c *Call) accept(v ExprVisitor) any {
return v.visitCall(c) return v.visitCall(c)
} }
func (l *Lambda) accept(v ExprVisitor) any {
return v.visitLambda(l)
}

View file

@ -192,7 +192,11 @@ func (i *Interpreter) visitCall(c *Call) any {
} }
func (i *Interpreter) visitFunStmt(f *FunStmt) { func (i *Interpreter) visitFunStmt(f *FunStmt) {
i.env.define(f.name.lexeme, newFunction(f, i.env)) i.env.define(f.name.lexeme, newFunction(f.name, f.args, f.body, i.env))
}
func (i *Interpreter) visitLambda(l *Lambda) any {
return newFunction(l.name, l.args, l.body, i.env)
} }
func (i *Interpreter) visitReturnStmt(r *ReturnStmt) { func (i *Interpreter) visitReturnStmt(r *ReturnStmt) {

View file

@ -431,7 +431,15 @@ func (p *Parser) arguments(callee Expr) Expr {
return &Call{callee, paren, arguments} return &Call{callee, paren, arguments}
} }
// primary -> NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")" | IDENTIFIER // primary -> IDENTIFIER
//
// | NUMBER
// | STRING
// | "true"
// | "false"
// | "nil"
// | "(" expression ")"
// | lambda
func (p *Parser) primary() Expr { func (p *Parser) primary() Expr {
switch { switch {
case p.match(FALSE): case p.match(FALSE):
@ -442,6 +450,10 @@ func (p *Parser) primary() Expr {
return &Literal{nil} return &Literal{nil}
} }
if p.match(FUN) {
return p.lambda()
}
if p.match(NUMBER, STRING) { if p.match(NUMBER, STRING) {
return &Literal{p.previous().literal} return &Literal{p.previous().literal}
} }
@ -461,6 +473,34 @@ func (p *Parser) primary() Expr {
return nil return nil
} }
func (p *Parser) lambda() Expr {
name := p.previous()
p.consume(LEFT_PAREN, "Expect '(' before lambda arguments.")
args := []Token{}
for !p.check(RIGHT_PAREN) {
args = append(
args,
p.consume(
IDENTIFIER,
"Expect lambda argument.",
),
)
if p.check(COMMA) {
p.advance()
}
}
p.consume(RIGHT_PAREN, "Expect ')' after lambda arguments.")
p.consume(LEFT_BRACE, "Expect '{' before lambda body.")
body := p.blockStmt()
return &Lambda{name, args, body}
}
func (p *Parser) previous() Token { func (p *Parser) previous() Token {
return p.tokens[p.current-1] return p.tokens[p.current-1]
} }

View file

@ -59,3 +59,13 @@ var counter = makeCounter();
counter(); counter();
counter(); counter();
fun thrice(fn) {
for (var i = 1; i <= 3; i = i + 1) {
fn(i);
}
}
thrice(fun (a) { print a; });
print fun () { return "hello, "; }() + "world";