From b85e94e8939b3b61b2ab7ea42f1d2b8b5a546b9b Mon Sep 17 00:00:00 2001 From: AdsonFS Date: Sun, 22 Sep 2024 22:09:00 -0300 Subject: [PATCH 1/2] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Type=20AST?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit refactor: change type ast --- examples/main.ll | 54 ++---------- examples/refactor.ll | 7 +- grammar_rules | 7 +- src/ast/ast.h | 16 ++-- src/ast/interpreter_visitor.cpp | 45 +++++----- src/ast/printer_visitor.cpp | 8 +- src/ast/semantic_visitor.cpp | 86 ++++++++++---------- src/lexi/lexi_scanner.cpp | 2 + src/parser/lang_parser.h | 1 + src/parser/validators/declaration_parser.cpp | 49 ++++++----- src/tokens/token.cpp | 3 +- src/tokens/token.h | 1 + 12 files changed, 121 insertions(+), 158 deletions(-) diff --git a/examples/main.ll b/examples/main.ll index 7c28a74..def12d9 100644 --- a/examples/main.ll +++ b/examples/main.ll @@ -1,45 +1,12 @@ ->> "start: main.ll"; - -func counter() -> func -> void { - var i -> number := 0; - func count() -> void { - i := i + 1; - >> i; - } - - return count; -} - -var cc -> func -> void; -var dd -> func -> void; -counter(); -cc := counter(); -dd := cc; -cc(); -cc(); -cc(); -dd(); -dd(); -dd(); - -var x -> number := 5; - class address { var street -> string := "123 Main St"; var city -> string := "Springfield"; var state -> string := "IL"; var zip -> string := "62704"; - - func print() -> void { - >> street; - >> city; - >> state; - >> zip; - } } class person { - var age -> number := 12 + x; + var age -> number := 12 + 4; var name -> string := "John"; var addr -> address := address(); @@ -51,26 +18,15 @@ class person { class student : person { var code -> number := 35; + <> override do metodo print func print() -> void { >> "Hi, I am a student, my name is" >> name; } } -func newperson() -> person { - var p -> person := person(); - >> p.age; - >> p.name; - p.addr.print(); - return p; -} -var person1 -> person := newperson(); - var p -> person := person(); -p.addr.print(); - -var lout -> func -> void := p.print; -lout(); - - var s -> student := student(); + +p.print(); s.print(); + diff --git a/examples/refactor.ll b/examples/refactor.ll index b24a78d..5bd1ef6 100644 --- a/examples/refactor.ll +++ b/examples/refactor.ll @@ -24,6 +24,11 @@ class person { var p -> person := person(); p.print(); - +if (x > -12) { + var y -> number := 10; + dd := counter(); + cc(); + dd(); +} var lout -> func -> void := p.print; lout(); diff --git a/grammar_rules b/grammar_rules index 1eda1c3..5786480 100644 --- a/grammar_rules +++ b/grammar_rules @@ -29,7 +29,10 @@ class_declaration : class identifier : identifier '{' ( variable_declaration variable_declaration : 'var' identifier '->' type := expression | 'var' identifier '->' type; -func_declaration : 'func' identifier '(' ')' -> type? '{' statement_list '}' +func_declaration : 'func' identifier '(' parameters? ')' -> type? '{' statement_list '}' + +parameters : identifier -> type ( COMMA identifier -> type )* +arguments : expression ( COMMA expression )* type : 'string' | 'number' | void | identifier | '(' ')' -> type? output_stream : >> output output_stream* @@ -57,5 +60,5 @@ factor : (PLUS|MINUS)factor property_chain : (identifier_or_call) ( . property_chain )* identifier_or_call : identifier | call -call : identifier '(' ')' +call : identifier '(' arguments? ')' identifier : IDENTIFIER diff --git a/src/ast/ast.h b/src/ast/ast.h index 7ea07c9..9a5fa75 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -97,12 +97,12 @@ class ClassDeclarationAST : public AST { class FunctionDeclarationAST : public AST { public: - FunctionDeclarationAST(Token identifier, std::stack types, + FunctionDeclarationAST(Token identifier, TypeAST *type, StatementListAST *statements) - : identifier(identifier), types(types), statements(statements) {} + : identifier(identifier), type(type), statements(statements) {} ASTValue *accept(ASTVisitor &visitor) override; - std::stack types; + TypeAST *type; Token identifier; StatementListAST *statements; }; @@ -126,11 +126,11 @@ class InputStreamAST : public AST { class VariableDeclarationAST : public AST { public: - VariableDeclarationAST(std::stack types, Token identifier, AST *value) - : types(types), identifier(identifier), value(value) {} + VariableDeclarationAST(TypeAST* type, Token identifier, AST *value) + : type(type), identifier(identifier), value(value) {} ASTValue *accept(ASTVisitor &visitor) override; - std::stack types; + TypeAST* type; Token identifier; AST *value; }; @@ -185,10 +185,10 @@ class PropertyChainAST : public AST { class TypeAST : public AST { public: - TypeAST(Token token) : token(token) {} + TypeAST(std::stack types) : types(types) {} ASTValue *accept(ASTVisitor &visitor) override; - Token token; + std::stack types; }; class IdentifierAST : public AST { diff --git a/src/ast/interpreter_visitor.cpp b/src/ast/interpreter_visitor.cpp index a349b97..8cbee99 100644 --- a/src/ast/interpreter_visitor.cpp +++ b/src/ast/interpreter_visitor.cpp @@ -95,18 +95,7 @@ ASTValue *InterpreterVisitor::visitClassDeclaration(ClassDeclarationAST *expr) { ASTValue * InterpreterVisitor::visitFunctionDeclaration(FunctionDeclarationAST *expr) { - - ASTValue *type = nullptr; - std::stack types = expr->types; - while (!types.empty()) { - ASTValue *t = types.top()->accept(*this); - if (type == nullptr) - type = t; - else - type = new ASTValue(new LangFunction(nullptr, type->value, this->scope)); - types.pop(); - } - + ASTValue *type = expr->type->accept(*this); FuncSymbol *func = new FuncSymbol(expr->identifier.getValue(), new ASTValue(new LangFunction(expr->statements, @@ -136,17 +125,8 @@ ASTValue *InterpreterVisitor::visitInputStream(InputStreamAST *expr) { ASTValue * InterpreterVisitor::visitVariableDeclaration(VariableDeclarationAST *expr) { ASTValue *value = expr->value->accept(*this); + ASTValue *type = expr->type->accept(*this); - ASTValue *type = nullptr; - std::stack types = expr->types; - while (!types.empty()) { - ASTValue *t = types.top()->accept(*this); - if (type == nullptr) - type = t; - else - type = new ASTValue(new LangFunction(nullptr, type->value, this->scope)); - types.pop(); - } if (dynamic_cast(value->value) != nullptr) value->value = new LangNil(type->value); this->scope->set(new VarSymbol(expr->identifier.getValue(), value)); @@ -218,7 +198,8 @@ ASTValue *InterpreterVisitor::visitUnaryOperatorExpr(UnaryOperatorAST *expr) { throw RuntimeError("invalid operator: " + expr->op.getValue(), expr->op); } -ASTValue *InterpreterVisitor::visitCall(LangObject *callee, std::string name, Token &token) { +ASTValue *InterpreterVisitor::visitCall(LangObject *callee, std::string name, + Token &token) { if (typeid(*callee) == typeid(LangFunction)) { LangFunction *func = dynamic_cast(callee); ScopedSymbolTable *currentScope = this->scope; @@ -242,7 +223,8 @@ ASTValue *InterpreterVisitor::visitCall(LangObject *callee, std::string name, To ASTValue *InterpreterVisitor::visitCall(CallAST *expr) { ASTValue *value = this->scope->getValue(expr->identifier.getValue(), this->jumpTable[expr]); - return this->visitCall(value->value, expr->identifier.getValue(), expr->identifier); + return this->visitCall(value->value, expr->identifier.getValue(), + expr->identifier); } ASTValue *InterpreterVisitor::visitPropertyChain(PropertyChainAST *expr) { @@ -258,14 +240,25 @@ ASTValue *InterpreterVisitor::visitPropertyChain(PropertyChainAST *expr) { } else if (typeid(*node) == typeid(CallAST)) { CallAST *call = dynamic_cast(node); value = instance->getScope()->getValue(call->identifier.getValue(), 0); - value = this->visitCall(value->value, call->identifier.getValue(), call->identifier); + value = this->visitCall(value->value, call->identifier.getValue(), + call->identifier); } } return value; } ASTValue *InterpreterVisitor::visitType(TypeAST *expr) { - return this->scope->getValue(expr->token.getValue(), this->jumpTable[expr]); + ASTValue *type = nullptr; + std::stack types = expr->types; + while (!types.empty()) { + ASTValue *t = types.top()->accept(*this); + if (type == nullptr) + type = t; + else + type = new ASTValue(new LangFunction(nullptr, type->value, this->scope)); + types.pop(); + } + return type; } ASTValue *InterpreterVisitor::visitIdentifier(IdentifierAST *expr) { diff --git a/src/ast/printer_visitor.cpp b/src/ast/printer_visitor.cpp index d6d2503..2951005 100644 --- a/src/ast/printer_visitor.cpp +++ b/src/ast/printer_visitor.cpp @@ -178,7 +178,13 @@ ASTValue* PrinterVisitor::visitPropertyChain(PropertyChainAST *expr) { ASTValue* PrinterVisitor::visitType(TypeAST* expr) { this->printIndent(this->indent); - std::cout << "token.getValue() << ">\n"; + std::cout << "\n"; + std::stack types = expr->types; + while(!types.empty()) { + this->printIndent(this->indent); + std::cout << "token.getValue() << ">\n"; + types.pop(); + } return new ASTValue(new LangNil()); } diff --git a/src/ast/semantic_visitor.cpp b/src/ast/semantic_visitor.cpp index 98f2116..9a935f4 100644 --- a/src/ast/semantic_visitor.cpp +++ b/src/ast/semantic_visitor.cpp @@ -70,7 +70,8 @@ ASTValue *SemanticVisitor::visitClassDeclaration(ClassDeclarationAST *expr) { expr->superclass == nullptr ? nullptr : expr->superclass->accept(*this); if (superclass != nullptr && typeid(*superclass->value) != typeid(LangClass)) throw SemanticError("Invalid superclass for class " + - expr->identifier.getValue(), expr->superclass->token); + expr->identifier.getValue(), + dynamic_cast(expr->superclass->types.top())->token); if (superclass != nullptr) { LangClass *super = dynamic_cast(superclass->value); classScope->setSymbols(super->getScope()->getSymbols()); @@ -95,33 +96,23 @@ ASTValue *SemanticVisitor::visitClassDeclaration(ClassDeclarationAST *expr) { new ASTValue(new LangClass(expr->identifier.getValue(), classScope))); if (!this->scope->set(classSymbol)) throw SemanticError("Class " + expr->identifier.getValue() + - " already declared", expr->identifier); + " already declared", + expr->identifier); return new ASTValue(new LangNil()); } ASTValue * SemanticVisitor::visitFunctionDeclaration(FunctionDeclarationAST *expr) { - ASTValue *type = nullptr; - std::stack types = expr->types; - while (!types.empty()) { - ASTValue *t = types.top()->accept(*this); - if (type == nullptr) - type = t; - else if (typeid(*t->value) != typeid(LangFunction)) - throw SemanticError("type mismatch: " + expr->identifier.getValue(), types.top()->token); - else - type = new ASTValue(new LangFunction(nullptr, type->value, this->scope)); - types.pop(); - } - + ASTValue *type = expr->type->accept(*this); FuncSymbol *func = new FuncSymbol(expr->identifier.getValue(), new ASTValue(new LangFunction(expr->statements, type->value, this->scope))); if (!this->scope->set(func)) throw SemanticError("Function " + expr->identifier.getValue() + - " already declared", expr->identifier); + " already declared", + expr->identifier); ScopedSymbolTable *currentScope = this->scope; @@ -133,7 +124,8 @@ SemanticVisitor::visitFunctionDeclaration(FunctionDeclarationAST *expr) { if (!ScopedSymbolTable::isSameType(type->value, this->currentReturnType)) throw SemanticError("Invalid return type to function: " + - expr->identifier.getValue(), expr->identifier); + expr->identifier.getValue(), + expr->identifier); this->scope = currentScope; return new ASTValue(new LangNil()); @@ -155,27 +147,18 @@ ASTValue *SemanticVisitor::visitInputStream(InputStreamAST *expr) { ASTValue * SemanticVisitor::visitVariableDeclaration(VariableDeclarationAST *expr) { ASTValue *value = expr->value->accept(*this); + ASTValue *type = expr->type->accept(*this); - ASTValue *type = nullptr; - std::stack types = expr->types; - while (!types.empty()) { - ASTValue *t = types.top()->accept(*this); - if (type == nullptr) - type = t; - else if (typeid(*t->value) != typeid(LangFunction)) - throw SemanticError("type mismatch: " + expr->identifier.getValue(), types.top()->token); - else - type = new ASTValue(new LangFunction(nullptr, type->value, this->scope)); - types.pop(); - } if (dynamic_cast(value->value) != nullptr) value->value = new LangNil(type->value); else if (!ScopedSymbolTable::isSameType(type->value, value->value)) throw SemanticError("Invalid type for variable " + - expr->identifier.getValue(), expr->identifier); + expr->identifier.getValue(), + expr->identifier); if (!this->scope->set(new VarSymbol(expr->identifier.getValue(), value))) throw SemanticError("Variable " + expr->identifier.getValue() + - " already declared", expr->identifier); + " already declared", + expr->identifier); return new ASTValue(new LangNil()); } @@ -185,7 +168,8 @@ SemanticVisitor::visitAssignmentVariable(AssignmentVariableAST *expr) { ASTValue *value = expr->value->accept(*this); if (!ScopedSymbolTable::isSameType(leftReference->value, value->value)) - throw SemanticError("Invalid type for assignment", expr->assignmentOperator); + throw SemanticError("Invalid type for assignment", + expr->assignmentOperator); leftReference->value = value->value; return leftReference; } @@ -212,7 +196,8 @@ ASTValue *SemanticVisitor::visitCall(CallAST *expr) { ScopedSymbolTable::jumpTo(expr->identifier.getValue(), this->scope); if (jumpTable[expr] == -1) throw SemanticError("Function " + expr->identifier.getValue() + - " not declared", expr->identifier); + " not declared", + expr->identifier); LangObject *callee = this->scope->getSymbol(expr->identifier.getValue(), jumpTable[expr]) ->value->value; @@ -227,7 +212,8 @@ ASTValue *SemanticVisitor::visitCall(CallAST *expr) { return new ASTValue(lang_class); } - throw SemanticError("Invalid call to " + expr->identifier.getValue(), expr->identifier); + throw SemanticError("Invalid call to " + expr->identifier.getValue(), + expr->identifier); } ASTValue *SemanticVisitor::visitPropertyChain(PropertyChainAST *expr) { @@ -236,10 +222,11 @@ ASTValue *SemanticVisitor::visitPropertyChain(PropertyChainAST *expr) { for (int i = 1; i < expr->accesses.size(); i++) { LangClass *instance = dynamic_cast(value->value); if (instance == nullptr) { - IdentifierAST *identifier = dynamic_cast(expr->accesses[i-1]); - CallAST *call = dynamic_cast(expr->accesses[i-1]); - if (identifier != nullptr) - throw SemanticError("Invalid property chain", identifier->token); + IdentifierAST *identifier = + dynamic_cast(expr->accesses[i - 1]); + CallAST *call = dynamic_cast(expr->accesses[i - 1]); + if (identifier != nullptr) + throw SemanticError("Invalid property chain", identifier->token); else throw SemanticError("Invalid property chain", call->identifier); } @@ -256,18 +243,29 @@ ASTValue *SemanticVisitor::visitPropertyChain(PropertyChainAST *expr) { } ASTValue *SemanticVisitor::visitType(TypeAST *expr) { - jumpTable[expr] = - ScopedSymbolTable::jumpTo(expr->token.getValue(), this->scope); - if (jumpTable[expr] == -1) - throw SemanticError("Type " + expr->token.getValue() + " not declared", expr->token); - return this->scope->getSymbol(expr->token.getValue(), jumpTable[expr])->value; + ASTValue *type = nullptr; + std::stack types = expr->types; + while (!types.empty()) { + IdentifierAST *identifier = types.top(); + ASTValue *t = types.top()->accept(*this); + if (type == nullptr) + type = t; + else if (typeid(*t->value) != typeid(LangFunction)) + throw SemanticError("type mismatch: " + identifier->token.getValue(), + types.top()->token); + else + type = new ASTValue(new LangFunction(nullptr, type->value, this->scope)); + types.pop(); + } + return type; } ASTValue *SemanticVisitor::visitIdentifier(IdentifierAST *expr) { jumpTable[expr] = ScopedSymbolTable::jumpTo(expr->token.getValue(), this->scope); if (jumpTable[expr] == -1) - throw SemanticError("Variable " + expr->token.getValue() + " not declared", expr->token); + throw SemanticError("Variable " + expr->token.getValue() + " not declared", + expr->token); return this->scope->getSymbol(expr->token.getValue(), jumpTable[expr])->value; } diff --git a/src/lexi/lexi_scanner.cpp b/src/lexi/lexi_scanner.cpp index 32ab7bf..dcee7ea 100644 --- a/src/lexi/lexi_scanner.cpp +++ b/src/lexi/lexi_scanner.cpp @@ -77,6 +77,8 @@ Token LexiScanner::getNextToken() { return Token(TokenType::TK_COLON, ":", Position(cl, cc, cp)); else if (currentChar == '.') return Token(TokenType::TK_DOT, ".", Position(cl, cc, cp)); + else if (currentChar == ',') + return Token(TokenType::TK_COMMA, ",", Position(cl, cc, cp)); else if (this->isSemicolon(currentChar)) return Token(TokenType::TK_SEMICOLON, ";", Position(cl, cc, cp)); else if (this->isParentheses(currentChar)) diff --git a/src/parser/lang_parser.h b/src/parser/lang_parser.h index d7f9786..4a0a05c 100644 --- a/src/parser/lang_parser.h +++ b/src/parser/lang_parser.h @@ -31,6 +31,7 @@ class LangParser { AST *logicalExpression(); AST *propertyChain(); AST *identifier_or_call(); + AST *type(); private: LexiScanner &scanner; diff --git a/src/parser/validators/declaration_parser.cpp b/src/parser/validators/declaration_parser.cpp index c35b8e0..9c812b2 100644 --- a/src/parser/validators/declaration_parser.cpp +++ b/src/parser/validators/declaration_parser.cpp @@ -5,24 +5,14 @@ AST *LangParser::variableDeclaration() { this->consume(Token(TokenType::TK_RESERVED_WORD, "var")); Token identifier = this->consume(TokenType::TK_IDENTIFIER); - - std::stack types; - while (this->match(TokenType::TK_ARROW)) { - this->consume(Token(TokenType::TK_ARROW, "->")); - if (this->match(TokenType::TK_RESERVED_WORD)) - types.push(new TypeAST(this->consume(TokenType::TK_RESERVED_WORD))); - else - types.push(new TypeAST(this->consume(TokenType::TK_IDENTIFIER))); - } - if (types.empty()) - this->consume(Token(TokenType::TK_ARROW, "->")); + TypeAST *type = dynamic_cast(this->type()); AST *node; if (this->token.getType() != TK_SEMICOLON) { this->consume(Token(TokenType::TK_ASSIGNMENT, ":=")); - node = new VariableDeclarationAST(types, identifier, this->expression()); + node = new VariableDeclarationAST(type, identifier, this->expression()); } else - node = new VariableDeclarationAST(types, identifier, new NilAST()); + node = new VariableDeclarationAST(type, identifier, new NilAST()); return node; } @@ -31,20 +21,11 @@ AST *LangParser::funcDeclaration() { Token identifier = this->consume(TokenType::TK_IDENTIFIER); this->consume(Token(TokenType::TK_PARENTHESES, "(")); this->consume(Token(TokenType::TK_PARENTHESES, ")")); - /*this->consume(Token(TokenType::TK_ARROW, "->"));*/ - - std::stack types; - while (this->match(TokenType::TK_ARROW)) { - this->consume(Token(TokenType::TK_ARROW, "->")); - if (this->match(TokenType::TK_RESERVED_WORD)) - types.push(new TypeAST(this->consume(TokenType::TK_RESERVED_WORD))); - else - types.push(new TypeAST(this->consume(TokenType::TK_IDENTIFIER))); - } + TypeAST *type = dynamic_cast(this->type()); this->consume(Token(TokenType::TK_CURLY_BRACES, "{")); AST *node = new FunctionDeclarationAST( - identifier, types, + identifier, type, dynamic_cast(this->statementList())); this->consume(Token(TokenType::TK_CURLY_BRACES, "}")); return node; @@ -56,11 +37,13 @@ AST *LangParser::classDeclaration() { Token identifier = this->consume(TokenType::TK_IDENTIFIER); if (this->match(TK_COLON)) { + std::stack types; this->consume(TK_COLON); if (this->match(TK_RESERVED_WORD)) - superclass = new TypeAST(this->consume(TK_RESERVED_WORD)); + types.push(new IdentifierAST(this->consume(TK_RESERVED_WORD))); else - superclass = new TypeAST(this->consume(TK_IDENTIFIER)); + types.push(new IdentifierAST(this->consume(TK_IDENTIFIER))); + superclass = new TypeAST(types); } this->consume(Token(TokenType::TK_CURLY_BRACES, "{")); @@ -80,3 +63,17 @@ AST *LangParser::classDeclaration() { this->consume(Token(TokenType::TK_CURLY_BRACES, "}")); return new ClassDeclarationAST(identifier, superclass, variables, methods); } + +AST *LangParser::type() { + std::stack types; + while (this->match(TokenType::TK_ARROW)) { + this->consume(Token(TokenType::TK_ARROW, "->")); + if (this->match(TokenType::TK_RESERVED_WORD)) + types.push(new IdentifierAST(this->consume(TokenType::TK_RESERVED_WORD))); + else + types.push(new IdentifierAST(this->consume(TokenType::TK_IDENTIFIER))); + } + if (types.empty()) + this->consume(Token(TokenType::TK_ARROW, "->")); + return new TypeAST(types); +} diff --git a/src/tokens/token.cpp b/src/tokens/token.cpp index 86a42a6..1baad62 100644 --- a/src/tokens/token.cpp +++ b/src/tokens/token.cpp @@ -47,7 +47,8 @@ std::string Token::toString() { return "ARROW"; case TK_COLON: return "COLON"; - break; + case TK_COMMA: + return "COMMA"; } return ""; } diff --git a/src/tokens/token.h b/src/tokens/token.h index ca259f9..ad759fd 100644 --- a/src/tokens/token.h +++ b/src/tokens/token.h @@ -6,6 +6,7 @@ enum TokenType { TK_IDENTIFIER, TK_DOT, + TK_COMMA, TK_COLON, TK_STRING, TK_NUMBER, From d2c051cd7f6ff83836b31ebf23c652af90738c3f Mon Sep 17 00:00:00 2001 From: AdsonFS Date: Mon, 23 Sep 2024 13:56:54 -0300 Subject: [PATCH 2/2] =?UTF-8?q?=E2=9C=A8=20Function=20Parameters=20And=20A?= =?UTF-8?q?rguments?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit feat: add parameters and arguments to functions --- examples/main.ll | 35 +++--------- src/ast/ast.h | 30 +++++++---- src/ast/interpreter_visitor.cpp | 38 ++++++++++--- src/ast/interpreter_visitor.h | 56 ++++++++++---------- src/ast/semantic_visitor.cpp | 43 +++++++++++---- src/core/core.h | 1 + src/core/lang_object.cpp | 1 + src/core/lang_object.h | 23 +++++++- src/parser/validators/declaration_parser.cpp | 15 +++++- src/parser/validators/expression_parser.cpp | 5 ++ src/symbols/symbol.cpp | 4 ++ src/symbols/symbol.h | 2 +- 12 files changed, 164 insertions(+), 89 deletions(-) diff --git a/examples/main.ll b/examples/main.ll index def12d9..ecebdff 100644 --- a/examples/main.ll +++ b/examples/main.ll @@ -1,32 +1,9 @@ -class address { - var street -> string := "123 Main St"; - var city -> string := "Springfield"; - var state -> string := "IL"; - var zip -> string := "62704"; -} - -class person { - var age -> number := 12 + 4; - var name -> string := "John"; - var addr -> address := address(); - - func print() -> void { - >> "Hello, I am" >> name; +<> Fibonacci +func fib(n -> number) -> number { + if (n < 2) { + return n; } + return fib(n - 1) + fib(n - 2); } - -class student : person { - var code -> number := 35; - - <> override do metodo print - func print() -> void { - >> "Hi, I am a student, my name is" >> name; - } -} - -var p -> person := person(); -var s -> student := student(); - -p.print(); -s.print(); +>> fib(28); diff --git a/src/ast/ast.h b/src/ast/ast.h index 9a5fa75..fad49ec 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -95,17 +95,6 @@ class ClassDeclarationAST : public AST { std::vector methods; }; -class FunctionDeclarationAST : public AST { -public: - FunctionDeclarationAST(Token identifier, TypeAST *type, - StatementListAST *statements) - : identifier(identifier), type(type), statements(statements) {} - ASTValue *accept(ASTVisitor &visitor) override; - - TypeAST *type; - Token identifier; - StatementListAST *statements; -}; class OutputStreamAST : public AST { public: @@ -134,6 +123,25 @@ class VariableDeclarationAST : public AST { Token identifier; AST *value; }; +class FunctionDeclarationAST : public AST { +public: + FunctionDeclarationAST(Token identifier, TypeAST *type, + std::vector parameters, + StatementListAST *statements) + : identifier(identifier), type(type), parameters(parameters), statements(statements) {} + ASTValue *accept(ASTVisitor &visitor) override; + std::vector getParameterNames() { + std::vector names; + for (auto ¶meter : parameters) + names.push_back(parameter->identifier.getValue()); + return names; + } + + TypeAST *type; + std::vector parameters; + Token identifier; + StatementListAST *statements; +}; class AssignmentVariableAST : public AST { public: diff --git a/src/ast/interpreter_visitor.cpp b/src/ast/interpreter_visitor.cpp index 8cbee99..d15a463 100644 --- a/src/ast/interpreter_visitor.cpp +++ b/src/ast/interpreter_visitor.cpp @@ -96,10 +96,17 @@ ASTValue *InterpreterVisitor::visitClassDeclaration(ClassDeclarationAST *expr) { ASTValue * InterpreterVisitor::visitFunctionDeclaration(FunctionDeclarationAST *expr) { ASTValue *type = expr->type->accept(*this); - FuncSymbol *func = - new FuncSymbol(expr->identifier.getValue(), - new ASTValue(new LangFunction(expr->statements, - type->value, this->scope))); + std::vector parameters; + for (auto ¶meter : expr->parameters) + parameters.push_back(parameter->type->accept(*this)->value); + + /*for (auto ¶meter : expr->parameters)*/ + /* parameter->accept(*this);*/ + + FuncSymbol *func = new FuncSymbol( + expr->identifier.getValue(), + new ASTValue(new LangFunction(expr->statements, parameters, expr->parameters, + type->value, this->scope))); this->scope->set(func); return new ASTValue(new LangNil()); } @@ -198,12 +205,18 @@ ASTValue *InterpreterVisitor::visitUnaryOperatorExpr(UnaryOperatorAST *expr) { throw RuntimeError("invalid operator: " + expr->op.getValue(), expr->op); } -ASTValue *InterpreterVisitor::visitCall(LangObject *callee, std::string name, +ASTValue *InterpreterVisitor::visitCall(LangObject *callee, std::string name, std::vector arguments, Token &token) { if (typeid(*callee) == typeid(LangFunction)) { LangFunction *func = dynamic_cast(callee); ScopedSymbolTable *currentScope = this->scope; this->scope = func->getScope()->newScope("functioncall"); + for (int i = 0; i < func->getParameters().size(); i++) { + std::string name = func->getParameters()[i]->identifier.getValue(); + ASTValue *value = new ASTValue(arguments[i]); + this->scope->set(new VarSymbol(name, value)); + } + LangObject *returnValue = new LangVoid(); try { func->getValue()->accept(*this); @@ -223,7 +236,11 @@ ASTValue *InterpreterVisitor::visitCall(LangObject *callee, std::string name, ASTValue *InterpreterVisitor::visitCall(CallAST *expr) { ASTValue *value = this->scope->getValue(expr->identifier.getValue(), this->jumpTable[expr]); - return this->visitCall(value->value, expr->identifier.getValue(), + std::vector arguments; + for (auto &argument : expr->arguments) + arguments.push_back(argument->accept(*this)->value); + + return this->visitCall(value->value, expr->identifier.getValue(), arguments, expr->identifier); } @@ -240,7 +257,11 @@ ASTValue *InterpreterVisitor::visitPropertyChain(PropertyChainAST *expr) { } else if (typeid(*node) == typeid(CallAST)) { CallAST *call = dynamic_cast(node); value = instance->getScope()->getValue(call->identifier.getValue(), 0); - value = this->visitCall(value->value, call->identifier.getValue(), + std::vector arguments; + for (auto &argument : call->arguments) + arguments.push_back(argument->accept(*this)->value); + + value = this->visitCall(value->value, call->identifier.getValue(), arguments, call->identifier); } } @@ -255,7 +276,8 @@ ASTValue *InterpreterVisitor::visitType(TypeAST *expr) { if (type == nullptr) type = t; else - type = new ASTValue(new LangFunction(nullptr, type->value, this->scope)); + type = + new ASTValue(new LangFunction(nullptr, {}, {}, type->value, this->scope)); types.pop(); } return type; diff --git a/src/ast/interpreter_visitor.h b/src/ast/interpreter_visitor.h index 76c7c66..6673218 100644 --- a/src/ast/interpreter_visitor.h +++ b/src/ast/interpreter_visitor.h @@ -5,61 +5,63 @@ #include "visitor.h" #include -class InterpreterVisitor: public ASTVisitor { +class InterpreterVisitor : public ASTVisitor { public: - static void setJumpTable(std::unordered_map jumpTable) { + static void setJumpTable(std::unordered_map jumpTable) { InterpreterVisitor::jumpTable = jumpTable; } + protected: private: static ScopedSymbolTable *scope; - static std::unordered_map jumpTable; + static std::unordered_map jumpTable; + + ASTValue *visitCall(LangObject *callee, std::string name, + std::vector arguments, Token &token); - ASTValue *visitCall(LangObject *callee, std::string name, Token &token); + ASTValue *visitStatementList(StatementListAST *expr) override; - ASTValue* visitStatementList(StatementListAST *expr) override; + ASTValue *visitBLock(BlockAST *expr) override; - ASTValue* visitBLock(BlockAST *expr) override; + ASTValue *visitReturn(ReturnAST *expr) override; - ASTValue* visitReturn(ReturnAST *expr) override; + ASTValue *visitWhileStatement(WhileStatementAST *expr) override; - ASTValue* visitWhileStatement(WhileStatementAST *expr) override; + ASTValue *visitForStatement(ForStatementAST *expr) override; - ASTValue* visitForStatement(ForStatementAST *expr) override; + ASTValue *visitIfStatement(IfStatementAST *expr) override; - ASTValue* visitIfStatement(IfStatementAST *expr) override; + ASTValue *visitClassDeclaration(ClassDeclarationAST *expr) override; - ASTValue* visitClassDeclaration(ClassDeclarationAST *expr) override; + ASTValue *visitFunctionDeclaration(FunctionDeclarationAST *expr) override; - ASTValue* visitFunctionDeclaration(FunctionDeclarationAST *expr) override; + ASTValue *visitOutputStream(OutputStreamAST *expr) override; - ASTValue* visitOutputStream(OutputStreamAST *expr) override; + ASTValue *visitInputStream(InputStreamAST *expr) override; - ASTValue* visitInputStream(InputStreamAST *expr) override; - - ASTValue* visitVariableDeclaration(VariableDeclarationAST *expr) override; + ASTValue *visitVariableDeclaration(VariableDeclarationAST *expr) override; - ASTValue* visitAssignmentVariable(AssignmentVariableAST *expr) override; + ASTValue *visitAssignmentVariable(AssignmentVariableAST *expr) override; - ASTValue* visitBinaryOperatorExpr(BinaryOperatorAST *expr) override; + ASTValue *visitBinaryOperatorExpr(BinaryOperatorAST *expr) override; - ASTValue* visitUnaryOperatorExpr(UnaryOperatorAST *expr) override; + ASTValue *visitUnaryOperatorExpr(UnaryOperatorAST *expr) override; - ASTValue* visitCall(CallAST *expr) override; + ASTValue *visitCall(CallAST *expr) override; - ASTValue* visitPropertyChain(PropertyChainAST *expr) override; + ASTValue *visitPropertyChain(PropertyChainAST *expr) override; - ASTValue* visitType(TypeAST* expr) override; + ASTValue *visitType(TypeAST *expr) override; - ASTValue* visitIdentifier(IdentifierAST *expr) override; + ASTValue *visitIdentifier(IdentifierAST *expr) override; - ASTValue* visitNumberExpr(NumberAST *expr) override; + ASTValue *visitNumberExpr(NumberAST *expr) override; - ASTValue* visitStringExpr(StringAST *expr) override; + ASTValue *visitStringExpr(StringAST *expr) override; - ASTValue* visitVoid(VoidAST *expr) override; + ASTValue *visitVoid(VoidAST *expr) override; - ASTValue* visitNil(NilAST *expr) override; + ASTValue *visitNil(NilAST *expr) override; }; #endif // INTERPRETER_VISITOR_H diff --git a/src/ast/semantic_visitor.cpp b/src/ast/semantic_visitor.cpp index 9a935f4..05c2a4c 100644 --- a/src/ast/semantic_visitor.cpp +++ b/src/ast/semantic_visitor.cpp @@ -1,6 +1,7 @@ #include "semantic_visitor.h" #include "../core/lang_object.h" #include "../error/error.h" +#include ScopedSymbolTable *SemanticVisitor::scope; std::unordered_map SemanticVisitor::jumpTable; @@ -56,10 +57,11 @@ ASTValue *SemanticVisitor::visitIfStatement(IfStatementAST *expr) { this->scope = this->scope->newScope("if"); expr->ifStatements->accept(*this); this->scope = this->scope->previousScope; - // } else if (expr->elseStatements != nullptr) { - this->scope = this->scope->newScope("else"); - expr->elseStatements->accept(*this); - this->scope = this->scope->previousScope; + if (expr->elseStatements != nullptr) { + this->scope = this->scope->newScope("else"); + expr->elseStatements->accept(*this); + this->scope = this->scope->previousScope; + } return new ASTValue(new LangNil()); } @@ -69,9 +71,9 @@ ASTValue *SemanticVisitor::visitClassDeclaration(ClassDeclarationAST *expr) { ASTValue *superclass = expr->superclass == nullptr ? nullptr : expr->superclass->accept(*this); if (superclass != nullptr && typeid(*superclass->value) != typeid(LangClass)) - throw SemanticError("Invalid superclass for class " + - expr->identifier.getValue(), - dynamic_cast(expr->superclass->types.top())->token); + throw SemanticError( + "Invalid superclass for class " + expr->identifier.getValue(), + dynamic_cast(expr->superclass->types.top())->token); if (superclass != nullptr) { LangClass *super = dynamic_cast(superclass->value); classScope->setSymbols(super->getScope()->getSymbols()); @@ -105,9 +107,15 @@ ASTValue *SemanticVisitor::visitClassDeclaration(ClassDeclarationAST *expr) { ASTValue * SemanticVisitor::visitFunctionDeclaration(FunctionDeclarationAST *expr) { ASTValue *type = expr->type->accept(*this); + + std::vector parameters; + for (auto ¶meter : expr->parameters) + parameters.push_back(parameter->type->accept(*this)->value); + FuncSymbol *func = new FuncSymbol(expr->identifier.getValue(), - new ASTValue(new LangFunction(expr->statements, + new ASTValue(new LangFunction(expr->statements, parameters, + expr->parameters, type->value, this->scope))); if (!this->scope->set(func)) throw SemanticError("Function " + expr->identifier.getValue() + @@ -120,6 +128,9 @@ SemanticVisitor::visitFunctionDeclaration(FunctionDeclarationAST *expr) { this->currentReturnType = new LangVoid(); this->currentFunctionType = type->value; + for (auto ¶meter : expr->parameters) + parameter->accept(*this); + expr->statements->accept(*this); if (!ScopedSymbolTable::isSameType(type->value, this->currentReturnType)) @@ -206,6 +217,19 @@ ASTValue *SemanticVisitor::visitCall(CallAST *expr) { if (typeid(*callee) == typeid(LangFunction)) { LangFunction *function = dynamic_cast(callee); + if (expr->arguments.size() != function->getArguments().size()) + throw SemanticError("Invalid number of arguments to " + + expr->identifier.getValue(), + expr->identifier); + for (int i = 0; i < expr->arguments.size(); i++) { + ASTValue *argument = expr->arguments[i]->accept(*this); + if (!ScopedSymbolTable::isSameType(function->getArguments()[i], + argument->value)) + throw SemanticError("Invalid type for argument " + std::to_string(i) + + " to " + expr->identifier.getValue(), + expr->identifier); + } + return new ASTValue(function->getReturnType()); } else if (typeid(*callee) == typeid(LangClass)) { LangClass *lang_class = dynamic_cast(callee); @@ -254,7 +278,8 @@ ASTValue *SemanticVisitor::visitType(TypeAST *expr) { throw SemanticError("type mismatch: " + identifier->token.getValue(), types.top()->token); else - type = new ASTValue(new LangFunction(nullptr, type->value, this->scope)); + type = new ASTValue( + new LangFunction(nullptr, {}, {}, type->value, this->scope)); types.pop(); } return type; diff --git a/src/core/core.h b/src/core/core.h index a627ea0..f582ca9 100644 --- a/src/core/core.h +++ b/src/core/core.h @@ -3,6 +3,7 @@ #include #include +#include #include "lang_object.h" class ASTVisitor; diff --git a/src/core/lang_object.cpp b/src/core/lang_object.cpp index 5804643..7d2005d 100644 --- a/src/core/lang_object.cpp +++ b/src/core/lang_object.cpp @@ -25,6 +25,7 @@ void LangNumber::cin(std::istream &is) { is >> this->value; } LangObject *LangNumber::operator-() { return new LangNumber(-this->value); } LangObject *LangNumber::operator+() { return new LangNumber(+this->value); } + void LangBoolean::toString(std::ostream &os) const { os << (std::string) (this->value ? "true" : "false"); } diff --git a/src/core/lang_object.h b/src/core/lang_object.h index 5b9ffd7..ec6129f 100644 --- a/src/core/lang_object.h +++ b/src/core/lang_object.h @@ -5,8 +5,10 @@ #include "istream" #include #include +#include class AST; +class VariableDeclarationAST; class LangBoolean; class LangFunction; class ScopedSymbolTable; @@ -67,12 +69,17 @@ class LangObject { class LangFunction : public LangObject { public: - LangFunction(AST *value, LangObject *returnType, ScopedSymbolTable *scope) - : value(value), returnType(returnType), scope(scope) {} + LangFunction(AST *value, std::vector arguments, + std::vector parameters, LangObject *returnType, + ScopedSymbolTable *scope) + : value(value), arguments(arguments), parameters(parameters), + returnType(returnType), scope(scope) {} void setValue(AST *value) { this->value = value; } AST *getValue() { return value; } LangObject *getReturnType() { return returnType; } ScopedSymbolTable *getScope() { return scope; } + std::vector getArguments() { return arguments; } + std::vector getParameters() { return parameters; } private: void setValue(LangObject *value) override { @@ -84,6 +91,8 @@ class LangFunction : public LangObject { bool isTrue() const override { return true; } void toString(std::ostream &os) const override; AST *value; + std::vector parameters; + std::vector arguments; ScopedSymbolTable *scope; LangObject *returnType; }; @@ -126,12 +135,22 @@ class LangNumber : public LangObject { LangObject *operator-() override; LangObject *operator+() override; + LangBoolean *operator==(const LangObject &rhs) const override { + return new LangBoolean(value == ((LangNumber &)rhs).value); + } + LangBoolean *operator<(const LangObject &rhs) const override { return new LangBoolean(value < ((LangNumber &)rhs).value); } + LangBoolean *operator>(const LangObject &rhs) const override { + return new LangBoolean(value > ((LangNumber &)rhs).value); + } LangObject *operator+(const LangObject &rhs) const override { return new LangNumber(value + ((LangNumber &)rhs).value); } + LangObject *operator-(const LangObject &rhs) const override { + return new LangNumber(value - ((LangNumber &)rhs).value); + } bool isTrue() const override { return value; } void toString(std::ostream &os) const override; diff --git a/src/parser/validators/declaration_parser.cpp b/src/parser/validators/declaration_parser.cpp index 9c812b2..633751b 100644 --- a/src/parser/validators/declaration_parser.cpp +++ b/src/parser/validators/declaration_parser.cpp @@ -19,13 +19,24 @@ AST *LangParser::variableDeclaration() { AST *LangParser::funcDeclaration() { this->consume(Token(TokenType::TK_RESERVED_WORD, "func")); Token identifier = this->consume(TokenType::TK_IDENTIFIER); + this->consume(Token(TokenType::TK_PARENTHESES, "(")); + std::vector parameters; + while (this->match(TokenType::TK_IDENTIFIER)) { + Token identifier = this->consume(TokenType::TK_IDENTIFIER); + TypeAST *type = dynamic_cast(this->type()); + parameters.push_back( + new VariableDeclarationAST(type, identifier, new NilAST())); + if (this->match(TokenType::TK_COMMA)) + this->consume(TokenType::TK_COMMA); + } this->consume(Token(TokenType::TK_PARENTHESES, ")")); - TypeAST *type = dynamic_cast(this->type()); + + TypeAST *type = dynamic_cast(this->type()); this->consume(Token(TokenType::TK_CURLY_BRACES, "{")); AST *node = new FunctionDeclarationAST( - identifier, type, + identifier, type, parameters, dynamic_cast(this->statementList())); this->consume(Token(TokenType::TK_CURLY_BRACES, "}")); return node; diff --git a/src/parser/validators/expression_parser.cpp b/src/parser/validators/expression_parser.cpp index 9f5a375..ea50281 100644 --- a/src/parser/validators/expression_parser.cpp +++ b/src/parser/validators/expression_parser.cpp @@ -108,6 +108,11 @@ AST *LangParser::identifier_or_call() { if (this->match(Token(TK_PARENTHESES, "("))) { std::vector arguments; this->consume(Token(TK_PARENTHESES, "(")); + while (this->token.getType() != TK_PARENTHESES) { + arguments.push_back(this->expression()); + if (this->match(Token(TK_COMMA, ","))) + this->consume(Token(TK_COMMA, ",")); + } this->consume(Token(TK_PARENTHESES, ")")); return new CallAST(token, arguments); } diff --git a/src/symbols/symbol.cpp b/src/symbols/symbol.cpp index efae1be..a8aec14 100644 --- a/src/symbols/symbol.cpp +++ b/src/symbols/symbol.cpp @@ -20,6 +20,10 @@ bool ScopedSymbolTable::isSameType(LangObject *lhs, LangObject *rhs) { LangNil *lhsNil = dynamic_cast(lhs); return isSameType(lhsNil->getType(), rhs); } + if (typeid(*rhs) == typeid(LangNil)) { + LangNil *rhsNil = dynamic_cast(rhs); + return isSameType(lhs, rhsNil->getType()); + } if (typeid(*lhs) != typeid(*rhs)) return false; if (typeid(*lhs) == typeid(LangFunction)) { diff --git a/src/symbols/symbol.h b/src/symbols/symbol.h index 0e2abaa..874cdf8 100644 --- a/src/symbols/symbol.h +++ b/src/symbols/symbol.h @@ -52,7 +52,7 @@ class ScopedSymbolTable { this->set(new BuiltInTypeSymbol("void", new ASTValue (new LangVoid()))); this->set(new BuiltInTypeSymbol("func", - new ASTValue(new LangFunction(nullptr, new LangVoid(), this)))); + new ASTValue(new LangFunction(nullptr, {}, {}, new LangVoid(), this)))); } std::string getName(); bool set(Symbol *symbol);