class statement

This commit is contained in:
Greg 2024-11-06 00:09:50 +02:00
parent c1549bbc6d
commit 0d98ae8ab1
7 changed files with 56 additions and 3 deletions

View file

@ -190,3 +190,7 @@ func (as *AstStringer) visitReturnStmt(r *ReturnStmt) {
r.value.accept(as) r.value.accept(as)
as.str.WriteString(")") as.str.WriteString(")")
} }
func (as *AstStringer) visitClassStmt(c *ClassStmt) {
fmt.Printf("(class %s)", c.name.lexeme)
}

9
class.go Normal file
View file

@ -0,0 +1,9 @@
package main
type Class struct {
name string
}
func (c *Class) String() string {
return c.name
}

View file

@ -205,6 +205,12 @@ func (i *Interpreter) visitFunStmt(f *FunStmt) {
i.env.define(f.name.lexeme, newFunction(f.name, f.args, f.body, i.env)) i.env.define(f.name.lexeme, newFunction(f.name, f.args, f.body, i.env))
} }
func (i *Interpreter) visitClassStmt(c *ClassStmt) {
i.env.define(c.name.lexeme, nil)
class := &Class{c.name.lexeme}
i.env.assign(c.name, class)
}
func (i *Interpreter) visitLambda(l *Lambda) any { func (i *Interpreter) visitLambda(l *Lambda) any {
return newFunction(l.name, l.args, l.body, i.env) return newFunction(l.name, l.args, l.body, i.env)
} }

View file

@ -48,7 +48,8 @@ func (p *Parser) parse() ([]Stmt, error) {
return stmts, errors.Join(p.errors...) return stmts, errors.Join(p.errors...)
} }
// declaration -> varDecl | funDecl | statement // declaration ->
// varDecl | funDecl | classDecl | statement
func (p *Parser) declaration() Stmt { func (p *Parser) declaration() Stmt {
defer p.synchronize() defer p.synchronize()
if p.match(VAR) { if p.match(VAR) {
@ -59,6 +60,10 @@ func (p *Parser) declaration() Stmt {
return p.function("function") return p.function("function")
} }
if p.match(CLASS) {
return p.classDecl()
}
return p.statement() return p.statement()
} }
@ -76,10 +81,23 @@ func (p *Parser) varDecl() Stmt {
return &VarStmt{name, initializer} return &VarStmt{name, initializer}
} }
// classDecl -> "class" IDENTIFIER "{" function* "}"
func (p *Parser) classDecl() Stmt {
name := p.consume(IDENTIFIER, "Expect identifier for variable")
p.consume(LEFT_BRACE, "Expect '{' after class identifier")
methods := []FunStmt{}
for !p.isAtEnd() && !p.check(RIGHT_BRACE) {
methods = append(methods, *p.function("method"))
}
p.consume(RIGHT_BRACE, "Expect '}' after class definition")
return &ClassStmt{name, methods}
}
// funDecl -> "fun" function // funDecl -> "fun" function
// function -> IDENTIFIER "(" parameters? ")" blockStmt // function -> IDENTIFIER "(" parameters? ")" blockStmt
// parameters -> IDENTIFIER ( "," IDENTIFIER )* // parameters -> IDENTIFIER ( "," IDENTIFIER )*
func (p *Parser) function(kind string) Stmt { func (p *Parser) function(kind string) *FunStmt {
name := p.consume(IDENTIFIER, fmt.Sprintf("Expect %s name.", kind)) name := p.consume(IDENTIFIER, fmt.Sprintf("Expect %s name.", kind))
p.consume(LEFT_PAREN, fmt.Sprintf("Expect '(' after %s name.", kind)) p.consume(LEFT_PAREN, fmt.Sprintf("Expect '(' after %s name.", kind))

View file

@ -186,3 +186,8 @@ func (r *Resolver) visitUnary(u *Unary) any {
r.resolveExprs(u.right) r.resolveExprs(u.right)
return nil return nil
} }
func (r *Resolver) visitClassStmt(c *ClassStmt) {
r.declare(c.name)
r.define(c.name)
}

11
stmt.go
View file

@ -11,6 +11,7 @@ type StmtVisitor interface {
visitWhileStmt(w *WhileStmt) visitWhileStmt(w *WhileStmt)
visitBreakStmt(b *BreakStmt) visitBreakStmt(b *BreakStmt)
visitReturnStmt(r *ReturnStmt) visitReturnStmt(r *ReturnStmt)
visitClassStmt(c *ClassStmt)
} }
type Stmt interface { type Stmt interface {
@ -49,6 +50,11 @@ type WhileStmt struct {
body Stmt body Stmt
} }
type ClassStmt struct {
name Token
methods []FunStmt
}
type BreakStmt struct{} type BreakStmt struct{}
type FunStmt struct { type FunStmt struct {
@ -71,6 +77,7 @@ func (b *BlockStmt) stmt() {}
func (w *WhileStmt) stmt() {} func (w *WhileStmt) stmt() {}
func (b *BreakStmt) stmt() {} func (b *BreakStmt) stmt() {}
func (r *ReturnStmt) stmt() {} func (r *ReturnStmt) stmt() {}
func (c *ClassStmt) stmt() {}
func (p *PrintStmt) accept(v StmtVisitor) { func (p *PrintStmt) accept(v StmtVisitor) {
v.visitPrintStmt(p) v.visitPrintStmt(p)
@ -111,3 +118,7 @@ func (f *FunStmt) accept(v StmtVisitor) {
func (r *ReturnStmt) accept(v StmtVisitor) { func (r *ReturnStmt) accept(v StmtVisitor) {
v.visitReturnStmt(r) v.visitReturnStmt(r)
} }
func (c *ClassStmt) accept(v StmtVisitor) {
v.visitClassStmt(c)
}