diff --git a/ast.go b/ast.go index bced2ce..2e6d431 100644 --- a/ast.go +++ b/ast.go @@ -5,50 +5,6 @@ import ( "strings" ) -type Expr interface { - expr() - accept(v Visitor) -} - -type Unary struct { - op Token - right Expr -} - -func (u *Unary) expr() {} -func (u *Unary) accept(v Visitor) { - v.visitUnary(u) -} - -type Grouping struct { - expression Expr -} - -func (g *Grouping) expr() {} -func (g *Grouping) accept(v Visitor) { - v.visitGrouping(g) -} - -type Literal struct { - value any -} - -func (l *Literal) expr() {} -func (l *Literal) accept(v Visitor) { - v.visitLiteral(l) -} - -type Binary struct { - left Expr - op Token - right Expr -} - -func (b *Binary) expr() {} -func (b *Binary) accept(v Visitor) { - v.visitBinary(b) -} - type Visitor interface { visitBinary(b *Binary) visitLiteral(l *Literal) @@ -56,11 +12,61 @@ type Visitor interface { visitUnary(u *Unary) } +type Expr interface { + expr() + accept(v Visitor) +} + +type Binary struct { + left Expr + op Token + right Expr +} + +type Unary struct { + op Token + right Expr +} + +type Grouping struct { + expression Expr +} + +type Literal struct { + value any +} + +func (u *Unary) expr() {} +func (g *Grouping) expr() {} +func (l *Literal) expr() {} +func (b *Binary) expr() {} + +func (u *Unary) accept(v Visitor) { + v.visitUnary(u) +} + +func (g *Grouping) accept(v Visitor) { + v.visitGrouping(g) +} + +func (l *Literal) accept(v Visitor) { + v.visitLiteral(l) +} + +func (b *Binary) accept(v Visitor) { + v.visitBinary(b) +} + type AstStringer struct { str strings.Builder } func (as AstStringer) String(expr Expr) string { + + if expr == nil { + return "" + } + expr.accept(&as) return as.str.String() } @@ -91,3 +97,40 @@ func (as *AstStringer) visitUnary(u *Unary) { u.right.accept(as) as.str.WriteString(")") } + +type AstToRPN struct { + str strings.Builder +} + +func (as AstToRPN) String(expr Expr) string { + + if expr == nil { + return "" + } + + expr.accept(&as) + return as.str.String() +} + +func (as *AstToRPN) visitBinary(b *Binary) { + b.left.accept(as) + as.str.WriteString(" ") + b.right.accept(as) + as.str.WriteString(" ") + as.str.WriteString(b.op.lexeme) + +} + +func (as *AstToRPN) visitLiteral(l *Literal) { + as.str.WriteString(fmt.Sprintf("%v", l.value)) +} + +func (as *AstToRPN) visitGrouping(g *Grouping) { + g.expression.accept(as) + as.str.WriteString(" group") +} + +func (as *AstToRPN) visitUnary(u *Unary) { + u.right.accept(as) + as.str.WriteString(fmt.Sprintf(" %s", u.op.lexeme)) +} diff --git a/error.go b/error.go new file mode 100644 index 0000000..58940cb --- /dev/null +++ b/error.go @@ -0,0 +1,21 @@ +package main + +import ( + "fmt" + "log" +) + +var hadError = false + +func printError(token Token, message string) { + if token.typ == EOF { + report(token.line, " at and", message) + } else { + report(token.line, fmt.Sprintf(" at '%s'", token.lexeme), message) + } +} + +func report(line int, where string, message string) { + log.Printf("[%d] Error %s: %s", line, where, message) + hadError = true +} diff --git a/glox.go b/glox.go index c89a0eb..7a5bb2b 100644 --- a/glox.go +++ b/glox.go @@ -7,17 +7,6 @@ import ( ) func main() { - expr := &Binary{ - &Unary{Token{MINUS, "-", nil, 1}, &Literal{123}}, - Token{STAR, "*", nil, 1}, - &Grouping{&Grouping{&Binary{ - &Unary{Token{MINUS, "-", nil, 1}, &Literal{123}}, - Token{STAR, "*", nil, 1}, - &Grouping{&Grouping{&Literal{45.67}}}}}}, - } - - println(AstStringer{}.String(expr)) - switch len(os.Args) { case 1: runPrompt() @@ -48,24 +37,21 @@ func runPrompt() { func runFile(path string) { file, err := os.ReadFile(path) - panic(err) + try(err) run(file) - - if hadError { - os.Exit(1) - } } func run(source []byte) { tokens := newScanner(source).scan() - for _, token := range tokens { - println(token.String()) - } + ast := newParser(tokens).parse() + + println(AstStringer{}.String(ast)) + println(AstToRPN{}.String(ast)) } -func panic(err error) { +func try(err error) { if err != nil { log.Fatal(err) } diff --git a/parser.go b/parser.go new file mode 100644 index 0000000..8dd1604 --- /dev/null +++ b/parser.go @@ -0,0 +1,206 @@ +package main + +import ( + "fmt" +) + +type Parser struct { + tokens []Token + current int +} + +type ParseError struct { + token Token + message string +} + +func (pe ParseError) Error() string { + return fmt.Sprintf("%s: %s", pe.token.lexeme, pe.message) +} + +func newParser(tokens []Token) *Parser { + return &Parser{ + tokens: tokens, + current: 0, + } +} + +func (p *Parser) parse() Expr { + defer p.recover() + return p.expression() +} + +func (p *Parser) recover() { + if err := recover(); err != nil { + pe := err.(ParseError) + printError(pe.token, pe.message) + } +} + +// expression -> equality +func (p *Parser) expression() Expr { + return p.equality() +} + +// equality -> comparison ( ( "==" | "!=" ) comparison )* +func (p *Parser) equality() Expr { + expr := p.comparison() + + for p.match(EQUAL_EQUAL, BANG_EQUAL) { + op := p.previous() + right := p.comparison() + expr = &Binary{expr, op, right} + } + + return expr +} + +// comparison -> term ( ( ">" | ">=" | "<" | "<=" ) term )* +func (p *Parser) comparison() Expr { + expr := p.term() + + for p.match(GREATER, GREATER_EQUAL, LESS, LESS_EQUAL) { + op := p.previous() + right := p.term() + expr = &Binary{expr, op, right} + } + + return expr +} + +// term -> factor ( ( "-" | "+" ) factor )* +func (p *Parser) term() Expr { + expr := p.factor() + + for p.match(MINUS, PLUS) { + op := p.previous() + right := p.factor() + expr = &Binary{expr, op, right} + } + + return expr +} + +// factor -> unary ( ( "/" | "*" ) unary )* +func (p *Parser) factor() Expr { + exp := p.unary() + + for p.match(SLASH, STAR) { + op := p.previous() + right := p.unary() + exp = &Unary{op, right} + } + + return exp +} + +// unary -> ( "!" | "-" ) unary | primary +func (p *Parser) unary() Expr { + if p.match(BANG, MINUS) { + op := p.previous() + right := p.unary() + return &Unary{op, right} + } + + return p.primary() +} + +// primary -> NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")" +func (p *Parser) primary() Expr { + switch { + case p.match(FALSE): + return &Literal{false} + case p.match(TRUE): + return &Literal{true} + case p.match(NIL): + return &Literal{nil} + } + + if p.match(NUMBER, STRING) { + return &Literal{p.previous().literal} + } + + if p.match(LEFT_PAREN) { + expr := p.expression() + p.consume(RIGHT_PAREN, "Expect ')' after expression") + return &Grouping{expr} + } + + panic(ParseError{p.peek(), "Expect expression"}) + + return nil +} + +func (p *Parser) previous() Token { + return p.tokens[p.current-1] +} + +func (p *Parser) peek() Token { + return p.tokens[p.current] +} + +func (p *Parser) isAtEnd() bool { + return p.peek().typ == EOF +} + +func (p *Parser) advance() Token { + if !p.isAtEnd() { + p.current++ + } + + return p.previous() +} + +func (p *Parser) check(typ TokenType) bool { + if p.isAtEnd() { + return false + } + + return p.peek().typ == typ +} + +func (p *Parser) match(types ...TokenType) bool { + + for _, typ := range types { + if p.check(typ) { + p.advance() + return true + } + } + + return false +} + +func (p *Parser) consume(typ TokenType, mes string) Token { + if p.check(typ) { + return p.advance() + } + + panic(ParseError{p.peek(), mes}) + + return Token{} +} + +func (p *Parser) synchronize() { + p.advance() + + for !p.isAtEnd() { + if p.previous().typ == SEMICOLON { + return + } + + switch p.peek().typ { + case CLASS: + case FOR: + case FUN: + case IF: + case PRINT: + case RETURN: + case VAR: + case WHILE: + return + } + + p.advance() + } +} diff --git a/scanner.go b/scanner.go index afcbda0..0d4a12e 100644 --- a/scanner.go +++ b/scanner.go @@ -8,8 +8,6 @@ import ( "unicode/utf8" ) -var hadError = false - //go:generate go run golang.org/x/tools/cmd/stringer -type=TokenType type TokenType int @@ -193,7 +191,7 @@ func (s *Scanner) scanToken() { break } - s.printError(fmt.Sprintf("Unexpected character %s", string(c))) + report(s.line, "", fmt.Sprintf("Unexpected character %s", string(c))) } } @@ -226,7 +224,7 @@ func (s *Scanner) string() { } if s.isAtEnd() { - s.printError("Unterminated string") + report(s.line, "", "Unterminated string") return } @@ -251,7 +249,7 @@ func (s *Scanner) number() { num, err := strconv.ParseFloat(string(s.source[s.start:s.current]), 64) if err != nil { - s.printError(err.Error()) + report(s.line, "", err.Error()) } s.addToken(NUMBER, num) @@ -288,15 +286,15 @@ func (s *Scanner) match(ch rune) bool { } decoded, size := utf8.DecodeRune(s.source[s.current:]) + + if decoded != ch { + return false + } + s.current += size - return ch == decoded + return true } func (s *Scanner) isAtEnd() bool { return s.current >= len(s.source) } - -func (s *Scanner) printError(message string) { - fmt.Printf("[line %d] Error: %s\n", s.line, message) - hadError = true -}