diff --git a/ast_string.go b/ast_string.go index de5f9f8..be86ec0 100644 --- a/ast_string.go +++ b/ast_string.go @@ -117,3 +117,11 @@ func (as *AstStringer) visitIfStmt(i *IfStmt) { func (as *AstStringer) visitEnvStmt(e *EnvStmt) { as.str.WriteString("(env)") } + +func (as *AstStringer) visitWhileStmt(w *WhileStmt) { + as.str.WriteString("(while ") + w.cond.accept(as) + as.str.WriteString(" ") + w.body.accept(as) + as.str.WriteString(")") +} diff --git a/env.go b/env.go index 38d9b98..4324359 100644 --- a/env.go +++ b/env.go @@ -33,5 +33,10 @@ func (env *Environment) exists(key string) bool { } func (env *Environment) set(key string, val any) { + + if env.parent != nil && env.parent.exists(key) { + env.parent.set(key, val) + } + env.values[key] = val } diff --git a/interpreter.go b/interpreter.go index afc6f9b..bf51d06 100644 --- a/interpreter.go +++ b/interpreter.go @@ -210,7 +210,12 @@ func (i *Interpreter) visitEnvStmt(e *EnvStmt) { fmt.Printf("%*s", ident, "") fmt.Printf("%+v\n", *e) } +} +func (i *Interpreter) visitWhileStmt(w *WhileStmt) { + for isTruthy(i.evaluate(w.cond)) { + w.body.accept(i) + } } func (i *Interpreter) panic(re *RuntimeError) { diff --git a/parser.go b/parser.go index d40e8d8..6049e17 100644 --- a/parser.go +++ b/parser.go @@ -66,14 +66,20 @@ func (p *Parser) varDecl() Stmt { return &VarStmt{name, initializer} } -// statement -> exprStmt | printStmt | block | ifStmt | env +// statement -> exprStmt +// +// | whileStmt +// | printStmt +// | blockStmt +// | ifStmt +// | env func (p *Parser) statement() Stmt { if p.match(PRINT) { return p.printStmt() } if p.match(LEFT_BRACE) { - return p.block() + return p.blockStmt() } if p.match(IF) { @@ -84,6 +90,10 @@ func (p *Parser) statement() Stmt { return p.envStmt() } + if p.match(WHILE) { + return p.whileStmt() + } + return p.exprStmt() } @@ -111,8 +121,8 @@ func (p *Parser) printStmt() Stmt { return &PrintStmt{expr} } -// block -> "{" statement* "}" -func (p *Parser) block() Stmt { +// blockStmt -> "{" statement* "}" +func (p *Parser) blockStmt() Stmt { stmts := []Stmt{} for !p.check(RIGHT_BRACE) && !p.isAtEnd() { @@ -140,6 +150,16 @@ func (p *Parser) ifStmt() Stmt { return &IfStmt{name, expr, then, or} } +// while -> "while" "(" expression ")" statement +func (p *Parser) whileStmt() Stmt { + p.consume(LEFT_PAREN, "Expect '(' after 'while'.") + cond := p.expression() + p.consume(RIGHT_PAREN, "Expect ')' after 'while' expression.") + body := p.statement() + + return &WhileStmt{cond, body} +} + // env -> "env" ";" func (p *Parser) envStmt() Stmt { p.consume(SEMICOLON, "Expect ';' after 'env'.") diff --git a/stmt.go b/stmt.go index cd6efb2..429c97b 100644 --- a/stmt.go +++ b/stmt.go @@ -7,6 +7,7 @@ type StmtVisitor interface { visitPrintStmt(p *PrintStmt) visitBlockStmt(b *BlockStmt) visitEnvStmt(e *EnvStmt) + visitWhileStmt(w *WhileStmt) } type Stmt interface { @@ -40,12 +41,18 @@ type IfStmt struct { or Stmt } +type WhileStmt struct { + cond Expr + body Stmt +} + func (i *IfStmt) stmt() {} +func (e *EnvStmt) stmt() {} func (vs *VarStmt) stmt() {} func (es *ExprStmt) stmt() {} func (p *PrintStmt) stmt() {} func (b *BlockStmt) stmt() {} -func (e *EnvStmt) stmt() {} +func (w *WhileStmt) stmt() {} func (p *PrintStmt) accept(v StmtVisitor) { v.visitPrintStmt(p) @@ -70,3 +77,7 @@ func (i *IfStmt) accept(v StmtVisitor) { func (e *EnvStmt) accept(v StmtVisitor) { v.visitEnvStmt(e) } + +func (w *WhileStmt) accept(v StmtVisitor) { + v.visitWhileStmt(w) +}