This commit is contained in:
Greg 2024-10-03 22:12:40 +03:00
parent c12d2047b1
commit 680df31650
5 changed files with 329 additions and 75 deletions

131
ast.go
View file

@ -5,50 +5,6 @@ import (
"strings" "strings"
) )
type Expr interface {
expr()
accept(v Visitor)
}
type Unary struct {
op Token
right Expr
}
func (u *Unary) expr() {}
func (u *Unary) accept(v Visitor) {
v.visitUnary(u)
}
type Grouping struct {
expression Expr
}
func (g *Grouping) expr() {}
func (g *Grouping) accept(v Visitor) {
v.visitGrouping(g)
}
type Literal struct {
value any
}
func (l *Literal) expr() {}
func (l *Literal) accept(v Visitor) {
v.visitLiteral(l)
}
type Binary struct {
left Expr
op Token
right Expr
}
func (b *Binary) expr() {}
func (b *Binary) accept(v Visitor) {
v.visitBinary(b)
}
type Visitor interface { type Visitor interface {
visitBinary(b *Binary) visitBinary(b *Binary)
visitLiteral(l *Literal) visitLiteral(l *Literal)
@ -56,11 +12,61 @@ type Visitor interface {
visitUnary(u *Unary) visitUnary(u *Unary)
} }
type Expr interface {
expr()
accept(v Visitor)
}
type Binary struct {
left Expr
op Token
right Expr
}
type Unary struct {
op Token
right Expr
}
type Grouping struct {
expression Expr
}
type Literal struct {
value any
}
func (u *Unary) expr() {}
func (g *Grouping) expr() {}
func (l *Literal) expr() {}
func (b *Binary) expr() {}
func (u *Unary) accept(v Visitor) {
v.visitUnary(u)
}
func (g *Grouping) accept(v Visitor) {
v.visitGrouping(g)
}
func (l *Literal) accept(v Visitor) {
v.visitLiteral(l)
}
func (b *Binary) accept(v Visitor) {
v.visitBinary(b)
}
type AstStringer struct { type AstStringer struct {
str strings.Builder str strings.Builder
} }
func (as AstStringer) String(expr Expr) string { func (as AstStringer) String(expr Expr) string {
if expr == nil {
return ""
}
expr.accept(&as) expr.accept(&as)
return as.str.String() return as.str.String()
} }
@ -91,3 +97,40 @@ func (as *AstStringer) visitUnary(u *Unary) {
u.right.accept(as) u.right.accept(as)
as.str.WriteString(")") as.str.WriteString(")")
} }
type AstToRPN struct {
str strings.Builder
}
func (as AstToRPN) String(expr Expr) string {
if expr == nil {
return ""
}
expr.accept(&as)
return as.str.String()
}
func (as *AstToRPN) visitBinary(b *Binary) {
b.left.accept(as)
as.str.WriteString(" ")
b.right.accept(as)
as.str.WriteString(" ")
as.str.WriteString(b.op.lexeme)
}
func (as *AstToRPN) visitLiteral(l *Literal) {
as.str.WriteString(fmt.Sprintf("%v", l.value))
}
func (as *AstToRPN) visitGrouping(g *Grouping) {
g.expression.accept(as)
as.str.WriteString(" group")
}
func (as *AstToRPN) visitUnary(u *Unary) {
u.right.accept(as)
as.str.WriteString(fmt.Sprintf(" %s", u.op.lexeme))
}

21
error.go Normal file
View file

@ -0,0 +1,21 @@
package main
import (
"fmt"
"log"
)
var hadError = false
func printError(token Token, message string) {
if token.typ == EOF {
report(token.line, " at and", message)
} else {
report(token.line, fmt.Sprintf(" at '%s'", token.lexeme), message)
}
}
func report(line int, where string, message string) {
log.Printf("[%d] Error %s: %s", line, where, message)
hadError = true
}

26
glox.go
View file

@ -7,17 +7,6 @@ import (
) )
func main() { func main() {
expr := &Binary{
&Unary{Token{MINUS, "-", nil, 1}, &Literal{123}},
Token{STAR, "*", nil, 1},
&Grouping{&Grouping{&Binary{
&Unary{Token{MINUS, "-", nil, 1}, &Literal{123}},
Token{STAR, "*", nil, 1},
&Grouping{&Grouping{&Literal{45.67}}}}}},
}
println(AstStringer{}.String(expr))
switch len(os.Args) { switch len(os.Args) {
case 1: case 1:
runPrompt() runPrompt()
@ -48,24 +37,21 @@ func runPrompt() {
func runFile(path string) { func runFile(path string) {
file, err := os.ReadFile(path) file, err := os.ReadFile(path)
panic(err) try(err)
run(file) run(file)
if hadError {
os.Exit(1)
}
} }
func run(source []byte) { func run(source []byte) {
tokens := newScanner(source).scan() tokens := newScanner(source).scan()
for _, token := range tokens { ast := newParser(tokens).parse()
println(token.String())
} println(AstStringer{}.String(ast))
println(AstToRPN{}.String(ast))
} }
func panic(err error) { func try(err error) {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

206
parser.go Normal file
View file

@ -0,0 +1,206 @@
package main
import (
"fmt"
)
type Parser struct {
tokens []Token
current int
}
type ParseError struct {
token Token
message string
}
func (pe ParseError) Error() string {
return fmt.Sprintf("%s: %s", pe.token.lexeme, pe.message)
}
func newParser(tokens []Token) *Parser {
return &Parser{
tokens: tokens,
current: 0,
}
}
func (p *Parser) parse() Expr {
defer p.recover()
return p.expression()
}
func (p *Parser) recover() {
if err := recover(); err != nil {
pe := err.(ParseError)
printError(pe.token, pe.message)
}
}
// expression -> equality
func (p *Parser) expression() Expr {
return p.equality()
}
// equality -> comparison ( ( "==" | "!=" ) comparison )*
func (p *Parser) equality() Expr {
expr := p.comparison()
for p.match(EQUAL_EQUAL, BANG_EQUAL) {
op := p.previous()
right := p.comparison()
expr = &Binary{expr, op, right}
}
return expr
}
// comparison -> term ( ( ">" | ">=" | "<" | "<=" ) term )*
func (p *Parser) comparison() Expr {
expr := p.term()
for p.match(GREATER, GREATER_EQUAL, LESS, LESS_EQUAL) {
op := p.previous()
right := p.term()
expr = &Binary{expr, op, right}
}
return expr
}
// term -> factor ( ( "-" | "+" ) factor )*
func (p *Parser) term() Expr {
expr := p.factor()
for p.match(MINUS, PLUS) {
op := p.previous()
right := p.factor()
expr = &Binary{expr, op, right}
}
return expr
}
// factor -> unary ( ( "/" | "*" ) unary )*
func (p *Parser) factor() Expr {
exp := p.unary()
for p.match(SLASH, STAR) {
op := p.previous()
right := p.unary()
exp = &Unary{op, right}
}
return exp
}
// unary -> ( "!" | "-" ) unary | primary
func (p *Parser) unary() Expr {
if p.match(BANG, MINUS) {
op := p.previous()
right := p.unary()
return &Unary{op, right}
}
return p.primary()
}
// primary -> NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")"
func (p *Parser) primary() Expr {
switch {
case p.match(FALSE):
return &Literal{false}
case p.match(TRUE):
return &Literal{true}
case p.match(NIL):
return &Literal{nil}
}
if p.match(NUMBER, STRING) {
return &Literal{p.previous().literal}
}
if p.match(LEFT_PAREN) {
expr := p.expression()
p.consume(RIGHT_PAREN, "Expect ')' after expression")
return &Grouping{expr}
}
panic(ParseError{p.peek(), "Expect expression"})
return nil
}
func (p *Parser) previous() Token {
return p.tokens[p.current-1]
}
func (p *Parser) peek() Token {
return p.tokens[p.current]
}
func (p *Parser) isAtEnd() bool {
return p.peek().typ == EOF
}
func (p *Parser) advance() Token {
if !p.isAtEnd() {
p.current++
}
return p.previous()
}
func (p *Parser) check(typ TokenType) bool {
if p.isAtEnd() {
return false
}
return p.peek().typ == typ
}
func (p *Parser) match(types ...TokenType) bool {
for _, typ := range types {
if p.check(typ) {
p.advance()
return true
}
}
return false
}
func (p *Parser) consume(typ TokenType, mes string) Token {
if p.check(typ) {
return p.advance()
}
panic(ParseError{p.peek(), mes})
return Token{}
}
func (p *Parser) synchronize() {
p.advance()
for !p.isAtEnd() {
if p.previous().typ == SEMICOLON {
return
}
switch p.peek().typ {
case CLASS:
case FOR:
case FUN:
case IF:
case PRINT:
case RETURN:
case VAR:
case WHILE:
return
}
p.advance()
}
}

View file

@ -8,8 +8,6 @@ import (
"unicode/utf8" "unicode/utf8"
) )
var hadError = false
//go:generate go run golang.org/x/tools/cmd/stringer -type=TokenType //go:generate go run golang.org/x/tools/cmd/stringer -type=TokenType
type TokenType int type TokenType int
@ -193,7 +191,7 @@ func (s *Scanner) scanToken() {
break break
} }
s.printError(fmt.Sprintf("Unexpected character %s", string(c))) report(s.line, "", fmt.Sprintf("Unexpected character %s", string(c)))
} }
} }
@ -226,7 +224,7 @@ func (s *Scanner) string() {
} }
if s.isAtEnd() { if s.isAtEnd() {
s.printError("Unterminated string") report(s.line, "", "Unterminated string")
return return
} }
@ -251,7 +249,7 @@ func (s *Scanner) number() {
num, err := strconv.ParseFloat(string(s.source[s.start:s.current]), 64) num, err := strconv.ParseFloat(string(s.source[s.start:s.current]), 64)
if err != nil { if err != nil {
s.printError(err.Error()) report(s.line, "", err.Error())
} }
s.addToken(NUMBER, num) s.addToken(NUMBER, num)
@ -288,15 +286,15 @@ func (s *Scanner) match(ch rune) bool {
} }
decoded, size := utf8.DecodeRune(s.source[s.current:]) decoded, size := utf8.DecodeRune(s.source[s.current:])
if decoded != ch {
return false
}
s.current += size s.current += size
return ch == decoded return true
} }
func (s *Scanner) isAtEnd() bool { func (s *Scanner) isAtEnd() bool {
return s.current >= len(s.source) return s.current >= len(s.source)
} }
func (s *Scanner) printError(message string) {
fmt.Printf("[line %d] Error: %s\n", s.line, message)
hadError = true
}