glox/interpreter.go
2024-10-14 22:53:26 +03:00

381 lines
6.8 KiB
Go

package main
import (
"errors"
"fmt"
"log"
"reflect"
"slices"
)
type Interpreter struct {
env *Environment
globals *Environment
locals map[Expr]int
errors []error
brk bool
}
type RuntimeError struct {
token Token
msg string
}
type Return struct {
val any
}
func (re *RuntimeError) Error() string {
return fmt.Sprintf("RuntimeError [%d][%s] Error: %s", re.token.line, re.token.typ, re.msg)
}
func newInterpreter() *Interpreter {
globals := newEnvironment(nil)
defineGlobals(globals)
return &Interpreter{
env: globals,
globals: globals,
locals: map[Expr]int{},
errors: []error{},
brk: false,
}
}
func (i *Interpreter) interpret(stmts []Stmt) error {
defer i.recover()
i.errors = []error{}
for _, stmt := range stmts {
stmt.accept(i)
}
return errors.Join(i.errors...)
}
func (i *Interpreter) recover() {
if err := recover(); err != nil {
_, ok := err.(*RuntimeError)
if !ok {
panic(err)
}
}
}
func (i *Interpreter) evaluate(e Expr) any {
return e.accept(i)
}
func (i *Interpreter) visitBinary(b *Binary) any {
left := i.evaluate(b.left)
right := i.evaluate(b.right)
switch b.op.typ {
case MINUS:
i.checkIfFloats(b.op, left, right)
return left.(float64) - right.(float64)
case SLASH:
i.checkIfFloats(b.op, left, right)
return left.(float64) / right.(float64)
case STAR:
i.checkIfFloats(b.op, left, right)
return left.(float64) * right.(float64)
case GREATER:
i.checkIfFloats(b.op, left, right)
return left.(float64) > right.(float64)
case LESS:
i.checkIfFloats(b.op, left, right)
return left.(float64) < right.(float64)
case GREATER_EQUAL:
i.checkIfFloats(b.op, left, right)
return left.(float64) >= right.(float64)
case LESS_EQUAL:
i.checkIfFloats(b.op, left, right)
return left.(float64) <= right.(float64)
case BANG_EQUAL:
return !reflect.DeepEqual(left, right)
case EQUAL_EQUAL:
return reflect.DeepEqual(left, right)
case PLUS:
if isFloats(left, right) {
return left.(float64) + right.(float64)
}
if isStrings(left, right) {
return left.(string) + right.(string)
}
}
i.panic(&RuntimeError{b.op, fmt.Sprintf("Operands must be numbers or strings: %v %s %v", left, b.op.lexeme, right)})
return nil
}
func (i *Interpreter) visitLiteral(l *Literal) any {
return l.value
}
func (i *Interpreter) visitGrouping(g *Grouping) any {
return i.evaluate(g.expression)
}
func (i *Interpreter) visitUnary(u *Unary) any {
val := i.evaluate(u.right)
switch u.op.typ {
case MINUS:
i.checkIfFloat(u.op, val)
return -val.(float64)
case BANG:
return !isTruthy(val)
}
return nil
}
func (i *Interpreter) visitVariable(v *Variable) any {
return i.lookUpVariable(v.name, v)
}
func (i *Interpreter) visitAssignment(a *Assign) any {
val := i.evaluate(a.value)
distance, isLocal := i.locals[a]
if isLocal {
i.env.assignAt(distance, a.variable, val)
return val
}
err := i.globals.assign(a.variable, val)
if err != nil {
i.panic(err)
}
return val
}
func (i *Interpreter) visitLogical(lo *Logical) any {
left := i.evaluate(lo.left)
shortOr := lo.operator.typ == OR && isTruthy(left)
shortAnd := lo.operator.typ == AND && !isTruthy(left)
if shortOr || shortAnd {
return left
}
return i.evaluate(lo.right)
}
func (i *Interpreter) visitCall(c *Call) any {
callee := i.evaluate(c.callee)
args := []any{}
for _, arg := range c.args {
args = append(args, i.evaluate(arg))
}
callable, ok := callee.(Callable)
if !ok {
i.panic(&RuntimeError{c.paren, "Can only call function and classes."})
}
if callable.arity() != len(args) {
i.panic(&RuntimeError{
c.paren,
fmt.Sprintf(
"Expected %d arguments but got %d",
callable.arity(),
len(args),
),
})
}
return callable.call(i, args...)
}
func (i *Interpreter) visitFunStmt(f *FunStmt) {
i.env.define(f.name.lexeme, newFunction(f.name, f.args, f.body, i.env))
}
func (i *Interpreter) visitLambda(l *Lambda) any {
return newFunction(l.name, l.args, l.body, i.env)
}
func (i *Interpreter) visitReturnStmt(r *ReturnStmt) {
var value any
if r.value != nil {
value = i.evaluate(r.value)
}
panic(Return{value})
}
func (i *Interpreter) visitPrintStmt(p *PrintStmt) {
fmt.Printf("%v\n", i.evaluate(p.val))
}
func (i *Interpreter) visitExprStmt(se *ExprStmt) {
i.evaluate(se.expr)
}
func (i *Interpreter) visitVarStmt(v *VarStmt) {
var val any = nil
if v.initializer != nil {
val = i.evaluate(v.initializer)
}
i.env.define(v.name.lexeme, val)
}
func (i *Interpreter) visitBlockStmt(b *BlockStmt) {
i.executeBlock(b.stmts, newEnvironment(i.env))
}
func (i *Interpreter) executeBlock(stmts []Stmt, current *Environment) {
parentEnv := i.env
i.env = current
// need to restore environment after
// panic(Return) in visitReturnStmt
defer func() {
i.env = parentEnv
}()
for _, stmt := range stmts {
if i.brk {
break
}
stmt.accept(i)
}
}
func (i *Interpreter) visitBreakStmt(b *BreakStmt) {
i.brk = true
}
func (i *Interpreter) visitIfStmt(iff *IfStmt) {
if isTruthy(i.evaluate(iff.cond)) {
iff.then.accept(i)
} else if iff.or != nil {
iff.or.accept(i)
}
}
func (i *Interpreter) visitEnvStmt(e *EnvStmt) {
walker := i.env
flatten := []*Environment{}
for walker != nil {
flatten = slices.Insert(flatten, 0, walker)
walker = walker.enclosing
}
fmt.Printf("globals: %+v\n", *i.globals)
for ident, e := range flatten {
fmt.Printf("%*s", ident, "")
fmt.Printf("%+v\n", *e)
}
}
func (i *Interpreter) visitWhileStmt(w *WhileStmt) {
for isTruthy(i.evaluate(w.cond)) {
if i.brk {
i.brk = false
break
}
w.body.accept(i)
}
}
func (i *Interpreter) resolve(expr Expr, depth int) {
i.locals[expr] = depth
}
func (i *Interpreter) lookUpVariable(name Token, expr Expr) any {
distance, isLocal := i.locals[expr]
if !isLocal {
return i.globals.get(name.lexeme)
}
return i.env.getAt(distance, name.lexeme)
}
func (i *Interpreter) panic(re *RuntimeError) {
i.errors = append(i.errors, re)
log.Println(re)
panic(re)
}
func (i *Interpreter) checkIfFloat(op Token, val any) {
if _, ok := val.(float64); ok {
return
}
i.panic(&RuntimeError{op, "value must be a number."})
}
func (i *Interpreter) checkIfFloats(op Token, a any, b any) {
if isFloats(a, b) {
return
}
i.panic(&RuntimeError{op, fmt.Sprintf("Operands must be numbers: %v %s %v", a, op.lexeme, b)})
}
func isFloats(a any, b any) bool {
if a == nil || b == nil {
return false
}
ltype := reflect.TypeOf(a)
rtype := reflect.TypeOf(b)
return ltype.Kind() == rtype.Kind() && ltype.Kind() == reflect.Float64
}
func isStrings(a any, b any) bool {
if a == nil || b == nil {
return false
}
ltype := reflect.TypeOf(a)
rtype := reflect.TypeOf(b)
return ltype.Kind() == rtype.Kind() && ltype.Kind() == reflect.String
}
func isTruthy(val any) bool {
if val == nil {
return false
}
if boolean, ok := val.(bool); ok {
return boolean
}
return true
}