diff --git a/ast_string.go b/ast_string.go index 252c85e..0b88a14 100644 --- a/ast_string.go +++ b/ast_string.go @@ -140,3 +140,16 @@ func (as *AstStringer) visitWhileStmt(w *WhileStmt) { func (as *AstStringer) visitBreakStmt(b *BreakStmt) { as.str.WriteString("(break)") } + +func (as *AstStringer) visitFunStmt(f *FunStmt) { + as.str.WriteString(fmt.Sprintf("(fun %s", f.name.lexeme)) + if len(f.args) != 0 { + as.str.WriteString("(") + for _, arg := range f.args { + as.str.WriteString(arg.lexeme) + } + as.str.WriteString(")") + } + as.str.WriteString(")") + +} diff --git a/callable.go b/callable.go index 14da667..d87dec4 100644 --- a/callable.go +++ b/callable.go @@ -4,3 +4,20 @@ type Callable struct { arity int call func(*Interpreter, ...any) any } + +func newCallable(f *FunStmt) *Callable { + return &Callable{ + arity: len(f.args), + call: func(i *Interpreter, args ...any) any { + env := newEnvironment(i.globals) + + for idx, arg := range f.args { + env.set(arg.lexeme, args[idx]) + } + + i.executeBlock(f.body, env) + + return nil + }, + } +} diff --git a/interpreter.go b/interpreter.go index 307f837..fd82d80 100644 --- a/interpreter.go +++ b/interpreter.go @@ -197,6 +197,10 @@ func (i *Interpreter) visitCall(c *Call) any { return callable.call(i, args...) } +func (i *Interpreter) visitFunStmt(f *FunStmt) { + i.env.set(f.name.lexeme, newCallable(f)) +} + func (i *Interpreter) visitPrintStmt(p *PrintStmt) { fmt.Printf("%v\n", i.evaluate(p.val)) } @@ -217,9 +221,13 @@ func (i *Interpreter) visitVarStmt(v *VarStmt) { } func (i *Interpreter) visitBlockStmt(b *BlockStmt) { + i.executeBlock(b, newEnvironment(i.env)) +} + +func (i *Interpreter) executeBlock(b *BlockStmt, current *Environment) { parentEnv := i.env - i.env = newEnvironment(parentEnv) + i.env = current for _, stmt := range b.stmts { diff --git a/parser.go b/parser.go index 8b2a638..44826e0 100644 --- a/parser.go +++ b/parser.go @@ -43,18 +43,23 @@ func (p *Parser) parse() ([]Stmt, []error) { return stmts, p.errors } -// declaration -> varDecl | statement +// declaration -> varDecl | funDecl | statement func (p *Parser) declaration() Stmt { defer p.synchronize() if p.match(VAR) { return p.varDecl() } + + if p.match(FUN) { + return p.function("function") + } + return p.statement() } // varDecl -> "var" IDENTIFIER ("=" expression)? ";" func (p *Parser) varDecl() Stmt { - name := p.consume(IDENTIFIER, "expect identifier for variable") + name := p.consume(IDENTIFIER, "Expect identifier for variable") var initializer Expr = nil if p.match(EQUAL) { @@ -66,6 +71,37 @@ func (p *Parser) varDecl() Stmt { return &VarStmt{name, initializer} } +// funDecl -> "fun" function +// function -> IDENTIFIER "(" parameters? ")" blockStmt +// parameters -> IDENTIFIER ( "," IDENTIFIER )* +func (p *Parser) function(kind string) Stmt { + name := p.consume(IDENTIFIER, fmt.Sprintf("Expect %s name.", kind)) + + p.consume(LEFT_PAREN, fmt.Sprintf("Expect '(' after %s name.", kind)) + + args := []Token{} + for !p.check(RIGHT_PAREN) { + args = append( + args, + p.consume( + IDENTIFIER, + fmt.Sprintf("Expect %s argument.", kind), + ), + ) + + if p.check(COMMA) { + p.advance() + } + } + + p.consume(RIGHT_PAREN, fmt.Sprintf("Expect ')' after %s name.", kind)) + p.consume(LEFT_BRACE, fmt.Sprintf("Expect '{' after %s arguments.", kind)) + + body := p.blockStmt() + + return &FunStmt{name, args, body} +} + // statement -> exprStmt // // | whileStmt @@ -132,7 +168,7 @@ func (p *Parser) printStmt() Stmt { } // blockStmt -> "{" statement* "}" -func (p *Parser) blockStmt() Stmt { +func (p *Parser) blockStmt() *BlockStmt { stmts := []Stmt{} for !p.check(RIGHT_BRACE) && !p.isAtEnd() { diff --git a/stmt.go b/stmt.go index c4cfe5c..4d8c064 100644 --- a/stmt.go +++ b/stmt.go @@ -3,10 +3,11 @@ package main type StmtVisitor interface { visitIfStmt(i *IfStmt) visitVarStmt(v *VarStmt) + visitEnvStmt(e *EnvStmt) + visitFunStmt(f *FunStmt) visitExprStmt(es *ExprStmt) visitPrintStmt(p *PrintStmt) visitBlockStmt(b *BlockStmt) - visitEnvStmt(e *EnvStmt) visitWhileStmt(w *WhileStmt) visitBreakStmt(b *BreakStmt) } @@ -49,7 +50,14 @@ type WhileStmt struct { type BreakStmt struct{} +type FunStmt struct { + name Token + args []Token + body *BlockStmt +} + func (i *IfStmt) stmt() {} +func (f *FunStmt) stmt() {} func (e *EnvStmt) stmt() {} func (vs *VarStmt) stmt() {} func (es *ExprStmt) stmt() {} @@ -89,3 +97,7 @@ func (w *WhileStmt) accept(v StmtVisitor) { func (b *BreakStmt) accept(v StmtVisitor) { v.visitBreakStmt(b) } + +func (f *FunStmt) accept(v StmtVisitor) { + v.visitFunStmt(f) +} diff --git a/tests/functions.lox b/tests/functions.lox index 16d66bd..8b2a1c6 100644 --- a/tests/functions.lox +++ b/tests/functions.lox @@ -1,4 +1,26 @@ -print "functions test"; +print "native function"; -print clock(); \ No newline at end of file +print clock(); + +fun count(n) { + if (n > 1) count(n - 1); + print n; +} + +count(10); + +fun hi(name, surname) { + print "hello, " + name + " " + surname + "!"; +} + +hi("John", "Doe"); + + +fun re(turn) { + print "before retrun"; + return turn; + print "should not be printed"; +} + +print re("print"); \ No newline at end of file