diff --git a/ast.go b/ast.go index 2e6d431..937f159 100644 --- a/ast.go +++ b/ast.go @@ -1,20 +1,20 @@ package main -import ( - "fmt" - "strings" -) - type Visitor interface { - visitBinary(b *Binary) - visitLiteral(l *Literal) - visitGrouping(g *Grouping) - visitUnary(u *Unary) + visitUnary(u *Unary) any + visitBinary(b *Binary) any + visitLiteral(l *Literal) any + visitGrouping(g *Grouping) any } type Expr interface { expr() - accept(v Visitor) + accept(v Visitor) any +} + +type Unary struct { + op Token + right Expr } type Binary struct { @@ -23,114 +23,31 @@ type Binary struct { right Expr } -type Unary struct { - op Token - right Expr +type Literal struct { + value any } type Grouping struct { expression Expr } -type Literal struct { - value any -} - func (u *Unary) expr() {} -func (g *Grouping) expr() {} -func (l *Literal) expr() {} func (b *Binary) expr() {} +func (l *Literal) expr() {} +func (g *Grouping) expr() {} -func (u *Unary) accept(v Visitor) { - v.visitUnary(u) +func (u *Unary) accept(v Visitor) any { + return v.visitUnary(u) } -func (g *Grouping) accept(v Visitor) { - v.visitGrouping(g) +func (b *Binary) accept(v Visitor) any { + return v.visitBinary(b) } -func (l *Literal) accept(v Visitor) { - v.visitLiteral(l) +func (l *Literal) accept(v Visitor) any { + return v.visitLiteral(l) } -func (b *Binary) accept(v Visitor) { - v.visitBinary(b) -} - -type AstStringer struct { - str strings.Builder -} - -func (as AstStringer) String(expr Expr) string { - - if expr == nil { - return "" - } - - expr.accept(&as) - return as.str.String() -} - -func (as *AstStringer) visitBinary(b *Binary) { - as.str.WriteString("(") - as.str.WriteString(b.op.lexeme) - as.str.WriteString(" ") - b.left.accept(as) - as.str.WriteString(" ") - b.right.accept(as) - as.str.WriteString(")") - -} - -func (as *AstStringer) visitLiteral(l *Literal) { - as.str.WriteString(fmt.Sprintf("%v", l.value)) -} - -func (as *AstStringer) visitGrouping(g *Grouping) { - as.str.WriteString("(group ") - g.expression.accept(as) - as.str.WriteString(")") -} - -func (as *AstStringer) visitUnary(u *Unary) { - as.str.WriteString(fmt.Sprintf("(%s ", u.op.lexeme)) - u.right.accept(as) - as.str.WriteString(")") -} - -type AstToRPN struct { - str strings.Builder -} - -func (as AstToRPN) String(expr Expr) string { - - if expr == nil { - return "" - } - - expr.accept(&as) - return as.str.String() -} - -func (as *AstToRPN) visitBinary(b *Binary) { - b.left.accept(as) - as.str.WriteString(" ") - b.right.accept(as) - as.str.WriteString(" ") - as.str.WriteString(b.op.lexeme) - -} - -func (as *AstToRPN) visitLiteral(l *Literal) { - as.str.WriteString(fmt.Sprintf("%v", l.value)) -} - -func (as *AstToRPN) visitGrouping(g *Grouping) { - g.expression.accept(as) - as.str.WriteString(" group") -} - -func (as *AstToRPN) visitUnary(u *Unary) { - u.right.accept(as) - as.str.WriteString(fmt.Sprintf(" %s", u.op.lexeme)) +func (g *Grouping) accept(v Visitor) any { + return v.visitGrouping(g) } diff --git a/ast_rpn.go b/ast_rpn.go new file mode 100644 index 0000000..96b2e2f --- /dev/null +++ b/ast_rpn.go @@ -0,0 +1,46 @@ +package main + +import ( + "fmt" + "strings" +) + +type AstToRPN struct { + str strings.Builder +} + +func (as AstToRPN) String(expr Expr) string { + + if expr == nil { + return "" + } + + expr.accept(&as) + return as.str.String() +} + +func (as *AstToRPN) visitBinary(b *Binary) any { + b.left.accept(as) + as.str.WriteString(" ") + b.right.accept(as) + as.str.WriteString(" ") + as.str.WriteString(b.op.lexeme) + return nil +} + +func (as *AstToRPN) visitLiteral(l *Literal) any { + as.str.WriteString(fmt.Sprintf("%v", l.value)) + return nil +} + +func (as *AstToRPN) visitGrouping(g *Grouping) any { + g.expression.accept(as) + as.str.WriteString(" group") + return nil +} + +func (as *AstToRPN) visitUnary(u *Unary) any { + u.right.accept(as) + as.str.WriteString(fmt.Sprintf(" %s", u.op.lexeme)) + return nil +} diff --git a/ast_string.go b/ast_string.go new file mode 100644 index 0000000..f54f2f9 --- /dev/null +++ b/ast_string.go @@ -0,0 +1,51 @@ +package main + +import ( + "fmt" + "strings" +) + +type AstStringer struct { + str strings.Builder +} + +func (as AstStringer) String(expr Expr) string { + + if expr == nil { + return "" + } + + expr.accept(&as) + return as.str.String() +} + +func (as *AstStringer) visitBinary(b *Binary) any { + as.str.WriteString("(") + as.str.WriteString(b.op.lexeme) + as.str.WriteString(" ") + b.left.accept(as) + as.str.WriteString(" ") + b.right.accept(as) + as.str.WriteString(")") + return nil + +} + +func (as *AstStringer) visitLiteral(l *Literal) any { + as.str.WriteString(fmt.Sprintf("%v", l.value)) + return nil +} + +func (as *AstStringer) visitGrouping(g *Grouping) any { + as.str.WriteString("(group ") + g.expression.accept(as) + as.str.WriteString(")") + return nil +} + +func (as *AstStringer) visitUnary(u *Unary) any { + as.str.WriteString(fmt.Sprintf("(%s ", u.op.lexeme)) + u.right.accept(as) + as.str.WriteString(")") + return nil +} diff --git a/error.go b/error.go index 58940cb..8f1544f 100644 --- a/error.go +++ b/error.go @@ -6,6 +6,7 @@ import ( ) var hadError = false +var hadRuntimeError = false func printError(token Token, message string) { if token.typ == EOF { @@ -19,3 +20,8 @@ func report(line int, where string, message string) { log.Printf("[%d] Error %s: %s", line, where, message) hadError = true } + +func reportRuntimeError(token Token, message string) { + log.Printf("[%d] Error: %s", token.line, message) + hadRuntimeError = true +} diff --git a/glox.go b/glox.go index 7a5bb2b..4feeb79 100644 --- a/glox.go +++ b/glox.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "fmt" "log" "os" ) @@ -31,6 +32,7 @@ func runPrompt() { } run([]byte(scanner.Text())) hadError = false + hadRuntimeError = false } } @@ -40,15 +42,39 @@ func runFile(path string) { try(err) run(file) + + switch { + case hadError: + os.Exit(65) + case hadRuntimeError: + os.Exit(70) + default: + os.Exit(0) + } } func run(source []byte) { tokens := newScanner(source).scan() + if hadError { + return + } + ast := newParser(tokens).parse() + if hadError { + return + } + println(AstStringer{}.String(ast)) - println(AstToRPN{}.String(ast)) + + res := newInterpreter().evaluate(ast) + + if hadRuntimeError { + return + } + + fmt.Printf("%v\n", res) } func try(err error) { diff --git a/interpreter.go b/interpreter.go new file mode 100644 index 0000000..d621fc4 --- /dev/null +++ b/interpreter.go @@ -0,0 +1,148 @@ +package main + +import ( + "fmt" + "reflect" +) + +type Interpreter struct{} + +type RuntimeError struct { + token Token + msg string +} + +func (re RuntimeError) Error() string { + return re.msg +} + +func newInterpreter() *Interpreter { + return &Interpreter{} +} + +func (i *Interpreter) evaluate(e Expr) any { + defer i.recover() + return e.accept(i) +} + +func (i *Interpreter) recover() { + if err := recover(); err != nil { + pe, ok := err.(RuntimeError) + + if !ok { + panic(err) + } + + reportRuntimeError(pe.token, pe.msg) + hadRuntimeError = true + } +} + +func (i *Interpreter) visitBinary(b *Binary) any { + left := i.evaluate(b.left) + right := i.evaluate(b.right) + + switch b.op.typ { + case MINUS: + checkIfFloats(b.op, left, right) + return left.(float64) - right.(float64) + case SLASH: + checkIfFloats(b.op, left, right) + return left.(float64) / right.(float64) + case STAR: + checkIfFloats(b.op, left, right) + return left.(float64) * right.(float64) + case GREATER: + checkIfFloats(b.op, left, right) + return left.(float64) > right.(float64) + case LESS: + checkIfFloats(b.op, left, right) + return left.(float64) < right.(float64) + case GREATER_EQUAL: + checkIfFloats(b.op, left, right) + return left.(float64) >= right.(float64) + case LESS_EQUAL: + checkIfFloats(b.op, left, right) + return left.(float64) <= right.(float64) + case BANG_EQUAL: + return !reflect.DeepEqual(left, right) + case EQUAL_EQUAL: + return reflect.DeepEqual(left, right) + case PLUS: + if isFloats(left, right) { + return left.(float64) + right.(float64) + } + + if isStrings(left, right) { + return left.(string) + right.(string) + } + } + + panic(RuntimeError{b.op, fmt.Sprintf("Operands must be numbers or strings: %v %s %v", left, b.op.lexeme, right)}) + + return nil +} + +func (i *Interpreter) visitLiteral(l *Literal) any { + return l.value +} + +func (i *Interpreter) visitGrouping(g *Grouping) any { + return i.evaluate(g.expression) +} + +func (i *Interpreter) visitUnary(u *Unary) any { + val := i.evaluate(u.right) + + switch u.op.typ { + case MINUS: + checkIfFloat(u.op, val) + return -val.(float64) + case BANG: + return !isTruthy(val) + } + + return nil +} + +func checkIfFloat(op Token, val any) { + if _, ok := val.(float64); ok { + return + } + + panic(RuntimeError{op, "value must ne a number."}) +} + +func checkIfFloats(op Token, a any, b any) { + if isFloats(a, b) { + return + } + + panic(RuntimeError{op, fmt.Sprintf("Operands must be numbers: %v %s %v", a, op.lexeme, b)}) +} + +func isFloats(a any, b any) bool { + ltype := reflect.TypeOf(a) + rtype := reflect.TypeOf(b) + + return ltype.Kind() == rtype.Kind() && ltype.Kind() == reflect.Float64 +} + +func isStrings(a any, b any) bool { + ltype := reflect.TypeOf(a) + rtype := reflect.TypeOf(b) + + return ltype.Kind() == rtype.Kind() && ltype.Kind() == reflect.String +} + +func isTruthy(val any) bool { + if val == nil { + return false + } + + if boolean, ok := val.(bool); ok { + return boolean + } + + return true +} diff --git a/parser.go b/parser.go index 8dd1604..11cf9bb 100644 --- a/parser.go +++ b/parser.go @@ -88,7 +88,7 @@ func (p *Parser) factor() Expr { for p.match(SLASH, STAR) { op := p.previous() right := p.unary() - exp = &Unary{op, right} + exp = &Binary{exp, op, right} } return exp @@ -190,14 +190,7 @@ func (p *Parser) synchronize() { } switch p.peek().typ { - case CLASS: - case FOR: - case FUN: - case IF: - case PRINT: - case RETURN: - case VAR: - case WHILE: + case CLASS, FOR, FUN, IF, PRINT, RETURN, VAR, WHILE: return }