diff --git a/AST.cpp b/AST.cpp index 502cf2e..7ce3e5b 100644 --- a/AST.cpp +++ b/AST.cpp @@ -1,7 +1,448 @@ #include +#include +#include +#include #include "AST.h" namespace DragonLisp { +std::shared_ptr ArrayRefAST::eval(Context* parent) { + // Eval this->index + auto idx = std::dynamic_pointer_cast(this->index->eval(parent)); + if (!idx || !idx->isInt()) + throw std::runtime_error("Cannot eval index as integer"); + + auto var = parent->getVariable(this->name); + if (!var) + throw std::runtime_error("Variable not found: " + this->name); + auto varC = std::dynamic_pointer_cast(var); + if (!varC) + throw std::runtime_error("Cannot reference from non-array variable: " + this->name); + if (varC->getSize() <= idx->getInt()) + throw std::runtime_error("Index out of range: " + std::to_string(idx->getInt()) + " >= " + std::to_string(varC->getSize())); + return std::make_shared((*varC)[idx->getInt()]); +} + +std::shared_ptr ArrayRefAST::set(Context* parent, std::shared_ptr value) { + // Eval this->index + auto idx = std::dynamic_pointer_cast(this->index->eval(parent)); + if (!idx || !idx->isInt()) + throw std::runtime_error("Cannot eval index as integer"); + + auto var = parent->getVariable(this->name); + if (!var) + throw std::runtime_error("Variable not found: " + this->name); + auto varC = std::dynamic_pointer_cast(var); + if (!varC) + throw std::runtime_error("Cannot reference from non-array variable: " + this->name); + if (varC->getSize() <= idx->getInt()) + throw std::runtime_error("Index out of range: " + std::to_string(idx->getInt()) + " >= " + std::to_string(varC->getSize())); + + auto val = std::dynamic_pointer_cast(value); + if (!val) + throw std::runtime_error("Cannot set array element to another array"); + varC->set(idx->getInt(), *val); + return value; +} + +std::shared_ptr IdentifierAST::eval(Context* parent) { + auto var = parent->getVariable(this->name); + if (!var) + throw std::runtime_error("Variable not found: " + this->name); + return var->copy(); +} + +std::shared_ptr IdentifierAST::set(Context* parent, std::shared_ptr value) { + parent->setVariable(this->name, value); + return value; +} + +std::shared_ptr FuncDefAST::eval(Context* parent, std::vector> arg) { + // Create a new context + auto ctx = std::make_shared(parent); + + // Set arguments + for (size_t i = 0; i < this->args.size(); i++) { + if (i >= arg.size()) + throw std::runtime_error("Too few arguments"); + ctx->setVariable(this->args[i], arg[i]); + } + + // Eval body + std::shared_ptr ret = std::make_shared(); // which is nil + for (auto& stmt : this->body) { + auto* ptr = stmt.get(); + if (ptr->getType() == T_IfAST) + ptr = dynamic_cast(ptr)->getResult(ctx.get()).get(); + if (!ptr) + continue; + if (ptr->getType() == T_ReturnAST) + return ptr->eval(ctx.get()); + ret = ptr->eval(ctx.get()); + } + return ret; +} + +std::shared_ptr FuncCallAST::eval(Context* parent) { + // Get the function + auto func = parent->getFunc(this->name); + if (!func) + throw std::runtime_error("Function not defined: " + this->name); + + // Eval arguments + std::vector> arg; + for (const auto& a : this->args) { + arg.push_back(a->eval(parent)); + } + + // Get the global context + auto globalCtx = parent; + while (globalCtx->getParent()) + globalCtx = globalCtx->getParent(); + + // Eval under global context + return func->eval(globalCtx, arg); +} + + +std::shared_ptr IfAST::getResult(Context* parent) { + // Eval condition + auto c = this->cond->eval(parent); + bool ok = c->isArray(); + if (!ok) { + auto cc = std::dynamic_pointer_cast(c); + if (!cc) + throw std::runtime_error("Unexpected error"); + if (!cc->isNil()) + ok = true; + } + return ok ? this->then : this->els; +} + + +std::shared_ptr LoopAST::eval(Context* parent) { + // Eval start & end + auto s = this->start->eval(parent); + auto e = this->end->eval(parent); + if (s->isArray() || e->isArray()) + throw std::runtime_error("Cannot eval loop start or end as integer"); + auto sc = std::dynamic_pointer_cast(s); + auto ec = std::dynamic_pointer_cast(e); + if (!sc || !ec) + throw std::runtime_error("Unexpected error"); + + if (sc->isT()) { + while (true) { + // No need to create a new context + for (auto& stmt : this->body) { + auto* ptr = stmt.get(); + if (ptr->getType() == T_IfAST) + ptr = dynamic_cast(ptr)->getResult(parent).get(); + if (!ptr) + continue; + if (ptr->getType() == T_ReturnAST) + return ptr->eval(parent); + } + } + } else { + if (!sc->isInt() || !ec->isInt()) + throw std::runtime_error("Cannot eval loop start or end as integer"); + + auto ss = sc->getInt(); + auto ee = ec->getInt(); + + // Create a new context + auto ctx = std::make_shared(parent); + + for (auto i = ss; i <= ee; ++i) { + ctx->setVariable(this->loopVar, std::make_shared(i)); + for (auto& stmt : this->body) { + auto* ptr = stmt.get(); + if (ptr->getType() == T_IfAST) + ptr = dynamic_cast(ptr)->getResult(ctx.get()).get(); + if (!ptr) + continue; + if (ptr->getType() == T_ReturnAST) + return ptr->eval(ctx.get()); + } + } + } + + // Return nil + return std::make_shared(); +} + + +std::shared_ptr UnaryAST::eval(Context* parent) { + auto val = this->expr->eval(parent); + auto valS = std::dynamic_pointer_cast(val); + switch (this->op) { + case NOT: + return std::make_shared(val->isArray() || (valS && !valS->isNil())); + case MAKE_ARRAY: + if (val->isArray() || (valS && !valS->isInt())) + throw std::runtime_error("Array size must be an integer"); + return std::make_shared(valS->getInt()); + case PRINT: + std::cout << val->toString() << std::endl; + return val; + default: + throw std::runtime_error("Unexpected error"); + } +} + +std::shared_ptr BinaryAST::eval(Context* parent) { + // All binary operators requires + // both operands to be int / float. + auto lv = std::dynamic_pointer_cast(this->lhs->eval(parent)); + auto rv = std::dynamic_pointer_cast(this->rhs->eval(parent)); + if (!lv || !rv || !(lv->isInt() || lv->isFloat()) || !(rv->isInt() || rv->isFloat())) + throw std::runtime_error("Both operands must be int or float"); + if (lv->isFloat() || rv->isFloat()) { + double l = lv->isInt() ? lv->getInt() : lv->getFloat(); + double r = rv->isInt() ? rv->getInt() : rv->getFloat(); + switch (this->op) { + case LESS: + return std::make_shared(l < r); + case LESS_EQUAL: + return std::make_shared(l <= r); + case GREATER: + return std::make_shared(l > r); + case GREATER_EQUAL: + return std::make_shared(l >= r); + case MOD: + case REM: + return std::make_shared(std::fmod(l, r)); + default: + throw std::runtime_error("This operator cannot be applied to float"); + } + } else { + std::int64_t l = lv->getInt(); + std::int64_t r = rv->getInt(); + switch (this->op) { + case LESS: + return std::make_shared(l < r); + case LESS_EQUAL: + return std::make_shared(l <= r); + case GREATER: + return std::make_shared(l > r); + case GREATER_EQUAL: + return std::make_shared(l >= r); + case MOD: + case REM: + return std::make_shared(l % r); + case LOGNOR: + return std::make_shared(~(l | r)); + default: + throw std::runtime_error("Unexpected error"); + } + } +} + +std::shared_ptr ListAST::eval(Context* parent) { + auto ret = this->exprs[0]->eval(parent); + auto retS = std::dynamic_pointer_cast(ret); + if (this->exprs.size() == 1) { + switch (this->op) { + case LOGAND: + case LOGNOR: + case LOGXOR: + case LOGEQV: + if (ret->isArray() || (retS && !retS->isInt())) + throw std::runtime_error("Cannot apply selected operator to non-integer"); + case MAX: + case MIN: + case PLUS: + case MULTIPLY: + if (ret->isArray() || (retS && !retS->isInt() && !retS->isFloat())) + throw std::runtime_error("Cannot apply selected operator to non-integer or non-float"); + case AND: + case OR: + return ret; + case EQUAL: + case NOT_EQUAL: + if (ret->isArray() || (retS && !retS->isInt() && !retS->isFloat())) + throw std::runtime_error("Cannot apply selected operator to non-integer or non-float"); + return std::make_shared(true); + default:; + } + throw std::runtime_error("Invalid argument count for selected operator"); + } + + // Transform #1: Eval all values. + std::vector> vals; + std::transform(this->exprs.begin(), this->exprs.end(), std::back_inserter(vals), [&](std::shared_ptr& ptr) { + return ptr->eval(parent); + }); + + // For And, return NIL if any value is NIL. Otherwise, return the last value. + if (this->op == AND) { + if (std::any_of(vals.begin(), vals.end(), [](std::shared_ptr& ptr) { + auto val = std::dynamic_pointer_cast(ptr); + return ptr->isArray() && val && val->isNil(); + })) + return std::make_shared(false); + return vals.back(); + } + + // For Or, return the first non-NIL value or NIL if all values are NIL. + if (this->op == OR) { + auto it = std::find_if(vals.begin(), vals.end(), [](std::shared_ptr& ptr) { + auto val = std::dynamic_pointer_cast(ptr); + return !ptr->isArray() || (val && !val->isNil()); + }); + if (it == vals.end()) + return std::make_shared(false); + return *it; + } + + // Now, all operators require all values to be int / float. + if (std::any_of(vals.begin(), vals.end(), [](std::shared_ptr& ptr) { + auto val = std::dynamic_pointer_cast(ptr); + return ptr->isArray() || (val && !val->isInt() && !val->isFloat()); + })) + throw std::runtime_error("All values must be int or float"); + + // Transform #2: Convert all values + std::vector> vals2; + std::transform(vals.begin(), vals.end(), std::back_inserter(vals2), [](std::shared_ptr& ptr) { + return std::dynamic_pointer_cast(ptr); + }); + + // Assure that no value is nullptr. + if (std::any_of(vals2.begin(), vals2.end(), [](std::shared_ptr& ptr) { + return ptr == nullptr; + })) + throw std::runtime_error("Unexpected error"); + + bool hasFloat = std::any_of(vals2.begin(), vals2.end(), [](std::shared_ptr& ptr) { + return ptr->isFloat(); + }); + + std::vector intVal; + std::vector floatVal; + switch (this->op) { + default: + throw std::runtime_error("Unexpected error"); + case LOGAND: + case LOGIOR: + case LOGXOR: + case LOGEQV: + if (hasFloat) + throw std::runtime_error("Cannot apply selected operator to non-integer"); + std::transform(vals2.begin(), vals2.end(), std::back_inserter(intVal), [&](std::shared_ptr& ptr) { + return ptr->getInt(); + }); + return std::make_shared(std::reduce(intVal.begin() + 1, intVal.end(), intVal[0], [this](std::int64_t x, std::int64_t y) { + switch (this->op) { + case LOGAND: + return x & y; + case LOGIOR: + return x | y; + case LOGXOR: + return x ^ y; + case LOGEQV: + return ~(x ^ y); + default: + throw std::runtime_error("Unexpected error"); + } + })); + case MAX: + return std::max_element(vals2.begin(), vals2.end(), [](std::shared_ptr& x, std::shared_ptr& y) { + return (x->isFloat() ? x->getFloat() : x->getInt()) < (y->isFloat() ? y->getFloat() : y->getInt()); + })->operator->()->copy(); + case MIN: + return std::min_element(vals2.begin(), vals2.end(), [](std::shared_ptr& x, std::shared_ptr& y) { + return (x->isFloat() ? x->getFloat() : x->getInt()) < (y->isFloat() ? y->getFloat() : y->getInt()); + })->operator->()->copy(); + case EQUAL: + return std::all_of(vals2.begin() + 1, vals2.end(), [&](std::shared_ptr& ptr) { + return ptr->operator==(*vals2[0]); + }) ? std::make_shared(true) : std::make_shared(false); + case NOT_EQUAL: + return !std::all_of(vals2.begin() + 1, vals2.end(), [&](std::shared_ptr& ptr) { + return ptr->operator==(*vals2[0]); + }) ? std::make_shared(true) : std::make_shared(false); + case PLUS: + case MINUS: + case MULTIPLY: + case DIVIDE: + auto opFunc = [this](auto x, auto y) { + switch (this->op) { + case PLUS: + return x + y; + case MINUS: + return x - y; + case MULTIPLY: + return x * y; + case DIVIDE: + return x / y; + default: + throw std::runtime_error("Unexpected error"); + } + }; + if (hasFloat) { + std::transform(vals2.begin(), vals2.end(), std::back_inserter(floatVal), [&](std::shared_ptr& ptr) { + return (ptr->isFloat() ? ptr->getFloat() : ptr->getInt()); + }); + return std::make_shared(std::reduce(floatVal.begin() + 1, floatVal.end(), floatVal[0], opFunc)); + } else { + std::transform(vals2.begin(), vals2.end(), std::back_inserter(intVal), [&](std::shared_ptr& ptr) { + return ptr->getInt(); + }); + return std::make_shared(std::reduce(intVal.begin() + 1, intVal.end(), intVal[0], opFunc)); + } + } + +} + +std::shared_ptr VarOpAST::eval(Context* parent) { + // DEFVAR || SETQ + if (this->op == SETQ && !parent->hasVariable(this->name)) { + throw std::runtime_error("Variable not defined: " + this->name); + } + + // Eval value + auto val = this->expr->eval(parent); + parent->setVariable(this->name, val->copy()); + return val; +} + +std::shared_ptr LValOpAST::eval(Context* parent) { + // Eval value + auto val = this->expr->eval(parent); + auto lv = std::dynamic_pointer_cast(this->lval); + if (!lv) + throw std::runtime_error("Unexpected error"); + + if (SETF == this->op) + return lv->set(parent, val); + + auto original = lv->eval(parent); + auto originalVal = std::dynamic_pointer_cast(original); + if (original->isArray() || !originalVal || !originalVal->isInt()) + throw std::runtime_error("Cannot apply INC or DEC to non-integer value"); + auto base = originalVal->getInt(); + + auto valVal = std::dynamic_pointer_cast(val); + if (val->isArray() || !valVal || !valVal->isInt()) + throw std::runtime_error("Cannot INC or DEC by non-integer value"); + auto delta = valVal->getInt(); + + switch (this->op) { + case DECF: + delta = -delta; + case INCF: + base += delta; + return lv->set(parent, std::make_shared(base)); + default:; + } + throw std::runtime_error("Unexpected error"); +} + +std::shared_ptr ReturnAST::eval(Context* parent) { + return this->expr->eval(parent); +} + } // end of namespace DragonLisp diff --git a/AST.h b/AST.h index c08f8a6..7ff0354 100644 --- a/AST.h +++ b/AST.h @@ -22,6 +22,7 @@ enum ASTType { T_ListAST, T_VarOpAST, T_LValOpAST, + T_ReturnAST, }; /// BaseAST - Base class for all AST nodes. @@ -97,7 +98,7 @@ private: std::vector> args; public: - FuncCallAST(std::string name, std::vector> args); + FuncCallAST(std::string name, std::vector> args) : name(std::move(name)), args(std::move(args)) {} std::shared_ptr eval(Context* parent) override final; @@ -113,7 +114,7 @@ private: std::shared_ptr els; public: - IfAST(std::unique_ptr cond, std::unique_ptr then, std::unique_ptr els); + IfAST(std::shared_ptr cond, std::shared_ptr then, std::shared_ptr els) : cond(std::move(cond)), then(std::move(then)), els(std::move(els)) {} std::shared_ptr eval(Context* parent) override final { throw std::runtime_error("You should use IfAST::getResult() instead of IfAST::eval()"); @@ -131,10 +132,10 @@ private: std::string loopVar; std::shared_ptr start; std::shared_ptr end; - std::shared_ptr body; + std::vector> body; public: - LoopAST(std::string loopVar, std::unique_ptr start, std::unique_ptr end, std::unique_ptr body); + LoopAST(std::string loopVar, std::shared_ptr start, std::shared_ptr end, std::vector> body) : loopVar(std::move(loopVar)), start(std::move(start)), end(std::move(end)), body(std::move(body)) {} std::shared_ptr eval(Context* parent) override final; @@ -149,7 +150,7 @@ private: Token op; public: - UnaryAST(std::unique_ptr expr, Token op); + UnaryAST(std::shared_ptr expr, Token op) : expr(std::move(expr)), op(std::move(op)) {} std::shared_ptr eval(Context* parent) override final; @@ -165,7 +166,7 @@ private: Token op; public: - BinaryAST(std::unique_ptr lhs, std::unique_ptr rhs, Token op); + BinaryAST(std::shared_ptr lhs, std::shared_ptr rhs, Token op) : lhs(std::move(lhs)), rhs(std::move(rhs)), op(std::move(op)) {} std::shared_ptr eval(Context* parent) override final; @@ -180,7 +181,7 @@ private: Token op; public: - ListAST(std::vector> exprs, Token op); + ListAST(std::vector> exprs, Token op) : exprs(std::move(exprs)), op(std::move(op)) {} std::shared_ptr eval(Context* parent) override final; @@ -196,7 +197,7 @@ private: Token op; public: - VarOpAST(std::string name, std::unique_ptr expr, Token op); + VarOpAST(std::string name, std::shared_ptr expr, Token op) : name(std::move(name)), expr(std::move(expr)), op(std::move(op)) {} std::shared_ptr eval(Context* parent) override final; @@ -212,7 +213,7 @@ private: Token op; public: - LValOpAST(std::unique_ptr lval, std::unique_ptr expr, Token op); + LValOpAST(std::shared_ptr lval, std::shared_ptr expr, Token op) : lval(std::move(lval)), expr(std::move(expr)), op(std::move(op)) {} std::shared_ptr eval(Context* parent) override final; @@ -221,6 +222,20 @@ public: } }; +class ReturnAST : public ExprAST { +private: + std::shared_ptr expr; + +public: + explicit ReturnAST(std::shared_ptr expr) : expr(std::move(expr)) {} + + std::shared_ptr eval(Context* parent) override final; + + inline ASTType getType() const override final { + return T_ReturnAST; + } +}; + } #endif // __DRAGON_LISP_AST_H__ diff --git a/context.h b/context.h index 37479c5..d26fa98 100644 --- a/context.h +++ b/context.h @@ -2,6 +2,7 @@ #define __DRAGON_LISP_CONTEXT_H__ #include +#include #include #include "value.h" @@ -30,7 +31,7 @@ public: std::shared_ptr getVariable(const std::string& name) const { if (this->variables.contains(name)) - return this->variables[name]; + return this->variables.at(name); if (this->parent) return this->parent->getVariable(name); return nullptr; @@ -39,6 +40,28 @@ public: void setVariable(const std::string& name, std::shared_ptr value) { this->variables[name] = std::move(value); } + + bool hasVariable(const std::string& name) const { + if (this->variables.contains(name)) + return true; + if (this->parent) + return this->parent->hasVariable(name); + return false; + } + + std::shared_ptr getFunc(const std::string& name) const { + if (this->funcs->contains(name)) + return (*this->funcs)[name]; + return nullptr; + } + + void setFunc(const std::string& name, std::shared_ptr value) { + (*this->funcs)[name] = std::move(value); + } + + Context* getParent() const { + return this->parent; + } }; } diff --git a/value.h b/value.h index 1357d44..d429b17 100644 --- a/value.h +++ b/value.h @@ -14,6 +14,10 @@ using ValueVariant = std::variant copy() const = 0; + + virtual std::string toString() const = 0; }; class SingleValue : public Value { @@ -31,20 +35,14 @@ public: explicit SingleValue(std::string v) : value(std::move(v)) {} + explicit SingleValue(bool v) : type(v ? TYPE_T : TYPE_NIL), value() {} + SingleValue() : SingleValue(ValueType::TYPE_NIL) {} bool isArray() const override final { return false; } - static SingleValue makeT() { - return SingleValue(ValueType::TYPE_T); - } - - static SingleValue makeNil() { - return SingleValue(ValueType::TYPE_NIL); - } - ValueType getType() const { return this->type; } @@ -88,6 +86,26 @@ public: void setValue(ValueVariant v) { this->value = std::move(v); } + + std::shared_ptr copy() const override final { + return std::make_shared(*this); + } + + std::string toString() const override final { + if (this->isInt()) + return std::to_string(this->getInt()); + if (this->isFloat()) + return std::to_string(this->getFloat()); + if (this->isString()) + return this->getString(); + if (this->isT()) + return "T"; + return "NIL"; + } + + bool operator==(const SingleValue& rhs) const { + return type == rhs.type && value == rhs.value; + } }; class ArrayValue : public Value { @@ -115,6 +133,10 @@ public: return this->values[i]; } + void set(std::size_t i, SingleValue v) { + this->values[i] = std::move(v); + } + std::vector& getValues() { return this->values; } @@ -122,6 +144,19 @@ public: const std::vector& getValues() const { return this->values; } + + std::shared_ptr copy() const override final { + return std::make_shared(*this); + } + + std::string toString() const override final { + std::string result = "["; + for (const auto& i : this->values) + result.append(i.toString()).append(", "); + result.pop_back(); + result.back() = ']'; + return result; + } }; class _Unused_Variable {