fun declaration

This commit is contained in:
Greg 2024-10-09 23:36:18 +03:00
parent a6e0673c3b
commit ce0492c489
6 changed files with 115 additions and 7 deletions

View file

@ -140,3 +140,16 @@ func (as *AstStringer) visitWhileStmt(w *WhileStmt) {
func (as *AstStringer) visitBreakStmt(b *BreakStmt) {
as.str.WriteString("(break)")
}
func (as *AstStringer) visitFunStmt(f *FunStmt) {
as.str.WriteString(fmt.Sprintf("(fun %s", f.name.lexeme))
if len(f.args) != 0 {
as.str.WriteString("(")
for _, arg := range f.args {
as.str.WriteString(arg.lexeme)
}
as.str.WriteString(")")
}
as.str.WriteString(")")
}

View file

@ -4,3 +4,20 @@ type Callable struct {
arity int
call func(*Interpreter, ...any) any
}
func newCallable(f *FunStmt) *Callable {
return &Callable{
arity: len(f.args),
call: func(i *Interpreter, args ...any) any {
env := newEnvironment(i.globals)
for idx, arg := range f.args {
env.set(arg.lexeme, args[idx])
}
i.executeBlock(f.body, env)
return nil
},
}
}

View file

@ -197,6 +197,10 @@ func (i *Interpreter) visitCall(c *Call) any {
return callable.call(i, args...)
}
func (i *Interpreter) visitFunStmt(f *FunStmt) {
i.env.set(f.name.lexeme, newCallable(f))
}
func (i *Interpreter) visitPrintStmt(p *PrintStmt) {
fmt.Printf("%v\n", i.evaluate(p.val))
}
@ -217,9 +221,13 @@ func (i *Interpreter) visitVarStmt(v *VarStmt) {
}
func (i *Interpreter) visitBlockStmt(b *BlockStmt) {
i.executeBlock(b, newEnvironment(i.env))
}
func (i *Interpreter) executeBlock(b *BlockStmt, current *Environment) {
parentEnv := i.env
i.env = newEnvironment(parentEnv)
i.env = current
for _, stmt := range b.stmts {

View file

@ -43,18 +43,23 @@ func (p *Parser) parse() ([]Stmt, []error) {
return stmts, p.errors
}
// declaration -> varDecl | statement
// declaration -> varDecl | funDecl | statement
func (p *Parser) declaration() Stmt {
defer p.synchronize()
if p.match(VAR) {
return p.varDecl()
}
if p.match(FUN) {
return p.function("function")
}
return p.statement()
}
// varDecl -> "var" IDENTIFIER ("=" expression)? ";"
func (p *Parser) varDecl() Stmt {
name := p.consume(IDENTIFIER, "expect identifier for variable")
name := p.consume(IDENTIFIER, "Expect identifier for variable")
var initializer Expr = nil
if p.match(EQUAL) {
@ -66,6 +71,37 @@ func (p *Parser) varDecl() Stmt {
return &VarStmt{name, initializer}
}
// funDecl -> "fun" function
// function -> IDENTIFIER "(" parameters? ")" blockStmt
// parameters -> IDENTIFIER ( "," IDENTIFIER )*
func (p *Parser) function(kind string) Stmt {
name := p.consume(IDENTIFIER, fmt.Sprintf("Expect %s name.", kind))
p.consume(LEFT_PAREN, fmt.Sprintf("Expect '(' after %s name.", kind))
args := []Token{}
for !p.check(RIGHT_PAREN) {
args = append(
args,
p.consume(
IDENTIFIER,
fmt.Sprintf("Expect %s argument.", kind),
),
)
if p.check(COMMA) {
p.advance()
}
}
p.consume(RIGHT_PAREN, fmt.Sprintf("Expect ')' after %s name.", kind))
p.consume(LEFT_BRACE, fmt.Sprintf("Expect '{' after %s arguments.", kind))
body := p.blockStmt()
return &FunStmt{name, args, body}
}
// statement -> exprStmt
//
// | whileStmt
@ -132,7 +168,7 @@ func (p *Parser) printStmt() Stmt {
}
// blockStmt -> "{" statement* "}"
func (p *Parser) blockStmt() Stmt {
func (p *Parser) blockStmt() *BlockStmt {
stmts := []Stmt{}
for !p.check(RIGHT_BRACE) && !p.isAtEnd() {

14
stmt.go
View file

@ -3,10 +3,11 @@ package main
type StmtVisitor interface {
visitIfStmt(i *IfStmt)
visitVarStmt(v *VarStmt)
visitEnvStmt(e *EnvStmt)
visitFunStmt(f *FunStmt)
visitExprStmt(es *ExprStmt)
visitPrintStmt(p *PrintStmt)
visitBlockStmt(b *BlockStmt)
visitEnvStmt(e *EnvStmt)
visitWhileStmt(w *WhileStmt)
visitBreakStmt(b *BreakStmt)
}
@ -49,7 +50,14 @@ type WhileStmt struct {
type BreakStmt struct{}
type FunStmt struct {
name Token
args []Token
body *BlockStmt
}
func (i *IfStmt) stmt() {}
func (f *FunStmt) stmt() {}
func (e *EnvStmt) stmt() {}
func (vs *VarStmt) stmt() {}
func (es *ExprStmt) stmt() {}
@ -89,3 +97,7 @@ func (w *WhileStmt) accept(v StmtVisitor) {
func (b *BreakStmt) accept(v StmtVisitor) {
v.visitBreakStmt(b)
}
func (f *FunStmt) accept(v StmtVisitor) {
v.visitFunStmt(f)
}

View file

@ -1,4 +1,26 @@
print "functions test";
print "native function";
print clock();
print clock();
fun count(n) {
if (n > 1) count(n - 1);
print n;
}
count(10);
fun hi(name, surname) {
print "hello, " + name + " " + surname + "!";
}
hi("John", "Doe");
fun re(turn) {
print "before retrun";
return turn;
print "should not be printed";
}
print re("print");