// -----------------------------------------------------------------------------
// DragonLisp AST evaluation.
//
// Defines the eval()/set() member functions of the AST node classes that are
// declared in "AST.h". Every failure is reported by throwing std::runtime_error.
//
// NOTE(review): as seen in this copy, the text inside angle brackets (the
// standard header names and all template arguments) is missing and the original
// line breaks are collapsed, so inline "//" comments now swallow the code that
// follows them on the same physical line. Confirm against version control
// before building or editing the code below.
//
// ArrayRefAST::eval(parent)
//   Evaluates `index` (must satisfy isInt()), looks up variable `name` in
//   `parent`, and throws if the variable is missing, is not an array (the
//   pointer cast fails), or getSize() <= index. Returns a newly made value
//   built from element [index].
//   NOTE(review): no explicit check for a negative index is visible here;
//   whether getSize()/operator[] guard against it depends on AST.h.
// ArrayRefAST::set(parent, value)
//   Same lookup and checks as eval(); additionally throws "Cannot set array
//   element to another array" when `value` fails the cast. Stores *val at
//   [index] via varC->set() and returns `value`.
// IdentifierAST::eval / IdentifierAST::set
//   eval returns var->copy() of the named variable (throws when not found);
//   set calls parent->setVariable(name, value) and returns `value`.
// FuncDefAST::eval(parent, arg)
//   Creates a child context of `parent`, binds args[i] to arg[i] (throws
//   "Too few arguments" when arg is shorter; surplus entries in `arg` are not
//   rejected), then runs the body in order. A statement of type T_IfAST is
//   replaced by the branch chosen by getResult() (skipped when that branch is
//   null). A T_ReturnAST whose getName() equals this->name ends the call with
//   its value; any other return name throws. Otherwise the value of the last
//   evaluated statement is returned, or the default-constructed value when the
//   body evaluates nothing.
// FuncCallAST::eval(parent)
//   Resolves the function via parent->getFunc(name), evaluates the arguments in
//   the caller's context, walks getParent() up to the root context, and calls
//   the function with that root context as its parent.
// IfAST::getResult(parent)
//   Evaluates `cond`; the condition holds when the value isArray() or is a
//   value that is not isNil(). Returns `then` or `els` without evaluating it.
// LoopForeverAST::eval(parent)
//   Repeats the body in `parent` itself (no new context). It only leaves the
//   loop through a T_ReturnAST statement (directly in the body or selected by
//   an IfAST) or an exception; the trailing throw is unreachable.
// LoopForAST::eval(parent)
//   Evaluates `start` (copied) and `end` in `parent`; both must be int or
//   float. While *s <= *e it binds `name` to s in a child context, runs the
//   body, then increments s with operator++. The bound variable is the same
//   shared object as the loop counter. Returns a value made from `false`.
// LoopDoTimesAST::eval(parent)
//   `times` must satisfy isInt(). Runs the body for i = 0 .. n-1, binding
//   `name` to a fresh value made from i each iteration. Returns a value made
//   from `false`.
// UnaryAST::eval(parent)
//   NOT yields true only for a non-array nil value; MAKE_ARRAY builds an array
//   sized by the operand's getInt(); PRINT writes toString() plus std::endl to
//   std::cout and returns the operand.
//   NOTE(review): in MAKE_ARRAY, valS is dereferenced even when the cast
//   returned null and isArray() is false - confirm this cannot happen.
// BinaryAST::eval(parent)
//   Both operands must be int or float. When either is float the comparison
//   operators and MOD/REM (both via std::fmod) are computed on doubles; other
//   operators throw. Otherwise the operation is done on std::int64_t and
//   LOGNOR is also supported.
//   NOTE(review): integer `l % r` has no visible guard for r == 0, and MOD and
//   REM share one implementation - confirm both are intended.
// ListAST::eval(parent)
//   Variadic operators. With exactly one operand: the bitwise/numeric cases
//   validate the operand and fall through to return it, EQUAL/NOT_EQUAL
//   validate it and return a value made from `true`, and every other operator
//   throws "Invalid argument count for selected operator". With more
//   operands: all are evaluated, AND/OR are resolved first, then all values
//   must be int or float and the result is a left fold (bitwise ops, + - * /),
//   a copy of the max/min element, or an equality test against the first
//   value. Arithmetic is done on doubles when any operand is float, otherwise
//   on std::int64_t.
//   NOTE(review): exprs[0] is read unconditionally (assumes at least one
//   operand) and is evaluated a second time by "Transform #1" when there is
//   more than one operand, so its side effects would run twice - confirm.
//   NOTE(review): the AND predicate `ptr->isArray() && val && val->isNil()` and
//   the OR predicate `!ptr->isArray() || (val && !val->isNil())` look inverted
//   relative to their own comments and to the truthiness test used by
//   IfAST::getResult and UnaryAST NOT - confirm.
//   NOTE(review): the single-operand switch lists LOGNOR where the
//   multi-operand switch lists LOGIOR - confirm which is intended.
//   NOTE(review): integer DIVIDE has no visible zero-divisor guard.
// VarOpAST::eval(parent)
//   For SETQ the variable must already exist (hasVariable); otherwise (the
//   original comment names DEFVAR) no check is made. Stores a copy of the
//   evaluated expression and returns the evaluated value itself.
// LValOpAST::eval(parent)
//   SETF forwards to lval->set(). INCF/DECF require the current value and the
//   delta to be integers; DECF negates the delta and deliberately falls
//   through to INCF, which stores base + delta through lval->set().
// ReturnAST::eval(parent)
//   Evaluates and returns the wrapped expression.
// -----------------------------------------------------------------------------
#include #include #include #include #include "AST.h" namespace DragonLisp { std::shared_ptr ArrayRefAST::eval(Context* parent) { // Eval this->index auto idx = std::dynamic_pointer_cast(this->index->eval(parent)); if (!idx || !idx->isInt()) throw std::runtime_error("Cannot eval index as integer"); auto var = parent->getVariable(this->name); if (!var) throw std::runtime_error("Variable not found: " + this->name); auto varC = std::dynamic_pointer_cast(var); if (!varC) throw std::runtime_error("Cannot reference from non-array variable: " + this->name); if (varC->getSize() <= idx->getInt()) throw std::runtime_error("Index out of range: " + std::to_string(idx->getInt()) + " >= " + std::to_string(varC->getSize())); return std::make_shared((*varC)[idx->getInt()]); } std::shared_ptr ArrayRefAST::set(Context* parent, std::shared_ptr value) { // Eval this->index auto idx = std::dynamic_pointer_cast(this->index->eval(parent)); if (!idx || !idx->isInt()) throw std::runtime_error("Cannot eval index as integer"); auto var = parent->getVariable(this->name); if (!var) throw std::runtime_error("Variable not found: " + this->name); auto varC = std::dynamic_pointer_cast(var); if (!varC) throw std::runtime_error("Cannot reference from non-array variable: " + this->name); if (varC->getSize() <= idx->getInt()) throw std::runtime_error("Index out of range: " + std::to_string(idx->getInt()) + " >= " + std::to_string(varC->getSize())); auto val = std::dynamic_pointer_cast(value); if (!val) throw std::runtime_error("Cannot set array element to another array"); varC->set(idx->getInt(), *val); return value; } std::shared_ptr IdentifierAST::eval(Context* parent) { auto var = parent->getVariable(this->name); if (!var) throw std::runtime_error("Variable not found: " + this->name); return var->copy(); } std::shared_ptr IdentifierAST::set(Context* parent, std::shared_ptr value) { parent->setVariable(this->name, value); return value; } std::shared_ptr FuncDefAST::eval(Context* parent, std::vector> 
arg) { // Create a new context auto ctx = std::make_shared(parent); // Set arguments for (size_t i = 0; i < this->args.size(); i++) { if (i >= arg.size()) throw std::runtime_error("Too few arguments"); ctx->setVariable(this->args[i], arg[i]); } // Eval body std::shared_ptr ret = std::make_shared(); // which is nil for (auto& stmt : this->body) { auto* ptr = stmt.get(); if (ptr->getType() == T_IfAST) ptr = dynamic_cast(ptr)->getResult(ctx.get()).get(); if (!ptr) continue; if (ptr->getType() == T_ReturnAST) { auto retAST = dynamic_cast(ptr); if (this->name == retAST->getName()) return ptr->eval(ctx.get()); throw std::runtime_error("Return name mismatch. Closure is not implemented yet!"); } ret = ptr->eval(ctx.get()); } return ret; } std::shared_ptr FuncCallAST::eval(Context* parent) { // Get the function auto func = parent->getFunc(this->name); if (!func) throw std::runtime_error("Function not defined: " + this->name); // Eval arguments std::vector> arg; for (const auto& a : this->args) { arg.push_back(a->eval(parent)); } // Get the global context auto globalCtx = parent; while (globalCtx->getParent()) globalCtx = globalCtx->getParent(); // Eval under global context return func->eval(globalCtx, arg); } std::shared_ptr IfAST::getResult(Context* parent) { // Eval condition auto c = this->cond->eval(parent); bool ok = c->isArray(); if (!ok) { auto cc = std::dynamic_pointer_cast(c); if (!cc) throw std::runtime_error("Unexpected error"); if (!cc->isNil()) ok = true; } return ok ? 
this->then : this->els; } std::shared_ptr LoopForeverAST::eval(Context* parent) { // No context is needed while (true) { for (auto& stmt : this->body) { auto* ptr = stmt.get(); if (ptr->getType() == T_IfAST) ptr = dynamic_cast(ptr)->getResult(parent).get(); if (!ptr) continue; if (ptr->getType() == T_ReturnAST) return ptr->eval(parent); ptr->eval(parent); } } throw std::runtime_error("Unexpected error"); } std::shared_ptr LoopForAST::eval(Context* parent) { // Create a new context auto ctx = std::make_shared(parent); // Eval condition auto s = std::dynamic_pointer_cast(this->start->eval(parent)->copy()); auto e = std::dynamic_pointer_cast(this->end->eval(parent)); // Assert that s and e are numeric if (!s || !e || (!s->isInt() && !s->isFloat()) || (!e->isInt() && !e->isFloat())) throw std::runtime_error("LoopForAST: start and end must be numeric"); // Main loop while (*s <= *e) { // Set the variable ctx->setVariable(this->name, s); // Eval body for (auto& stmt : this->body) { auto* ptr = stmt.get(); if (ptr->getType() == T_IfAST) ptr = dynamic_cast(ptr)->getResult(ctx.get()).get(); if (!ptr) continue; if (ptr->getType() == T_ReturnAST) return ptr->eval(ctx.get()); ptr->eval(ctx.get()); } // Increment s->operator++(); } // Return nil return std::make_shared(false); } std::shared_ptr LoopDoTimesAST::eval(Context* parent) { // Create a new context auto ctx = std::make_shared(parent); // Eval condition auto terminate = std::dynamic_pointer_cast(this->times->eval(parent)->copy()); if (!terminate || !terminate->isInt()) throw std::runtime_error("DOTIMES: times must be an integer"); auto n = terminate->getInt(); // Main Loop for (std::int64_t i = 0; i < n; ++i) { // Set Variable ctx->setVariable(this->name, std::make_shared(i)); // Eval Body for (auto& stmt : this->body) { auto* ptr = stmt.get(); if (ptr->getType() == T_IfAST) ptr = dynamic_cast(ptr)->getResult(ctx.get()).get(); if (!ptr) continue; if (ptr->getType() == T_ReturnAST) return ptr->eval(ctx.get()); 
ptr->eval(ctx.get()); } } // Return nil return std::make_shared(false); } std::shared_ptr UnaryAST::eval(Context* parent) { auto val = this->expr->eval(parent); auto valS = std::dynamic_pointer_cast(val); switch (this->op) { case NOT: return std::make_shared(!val->isArray() && valS && valS->isNil()); case MAKE_ARRAY: if (val->isArray() || (valS && !valS->isInt())) throw std::runtime_error("Array size must be an integer"); return std::make_shared(valS->getInt()); case PRINT: std::cout << val->toString() << std::endl; return val; default: throw std::runtime_error("Unexpected error"); } } std::shared_ptr BinaryAST::eval(Context* parent) { // All binary operators requires // both operands to be int / float. auto lv = std::dynamic_pointer_cast(this->lhs->eval(parent)); auto rv = std::dynamic_pointer_cast(this->rhs->eval(parent)); if (!lv || !rv || !(lv->isInt() || lv->isFloat()) || !(rv->isInt() || rv->isFloat())) throw std::runtime_error("Both operands must be int or float"); if (lv->isFloat() || rv->isFloat()) { double l = lv->isInt() ? lv->getInt() : lv->getFloat(); double r = rv->isInt() ? 
rv->getInt() : rv->getFloat(); switch (this->op) { case LESS: return std::make_shared(l < r); case LESS_EQUAL: return std::make_shared(l <= r); case GREATER: return std::make_shared(l > r); case GREATER_EQUAL: return std::make_shared(l >= r); case MOD: case REM: return std::make_shared(std::fmod(l, r)); default: throw std::runtime_error("This operator cannot be applied to float"); } } else { std::int64_t l = lv->getInt(); std::int64_t r = rv->getInt(); switch (this->op) { case LESS: return std::make_shared(l < r); case LESS_EQUAL: return std::make_shared(l <= r); case GREATER: return std::make_shared(l > r); case GREATER_EQUAL: return std::make_shared(l >= r); case MOD: case REM: return std::make_shared(l % r); case LOGNOR: return std::make_shared(~(l | r)); default: throw std::runtime_error("Unexpected error"); } } } std::shared_ptr ListAST::eval(Context* parent) { auto ret = this->exprs[0]->eval(parent); auto retS = std::dynamic_pointer_cast(ret); if (this->exprs.size() == 1) { switch (this->op) { case LOGAND: case LOGNOR: case LOGXOR: case LOGEQV: if (ret->isArray() || (retS && !retS->isInt())) throw std::runtime_error("Cannot apply selected operator to non-integer"); case MAX: case MIN: case PLUS: case MULTIPLY: if (ret->isArray() || (retS && !retS->isInt() && !retS->isFloat())) throw std::runtime_error("Cannot apply selected operator to non-integer or non-float"); case AND: case OR: return ret; case EQUAL: case NOT_EQUAL: if (ret->isArray() || (retS && !retS->isInt() && !retS->isFloat())) throw std::runtime_error("Cannot apply selected operator to non-integer or non-float"); return std::make_shared(true); default:; } throw std::runtime_error("Invalid argument count for selected operator"); } // Transform #1: Eval all values. std::vector> vals; std::transform(this->exprs.begin(), this->exprs.end(), std::back_inserter(vals), [&](std::shared_ptr& ptr) { return ptr->eval(parent); }); // For And, return NIL if any value is NIL. Otherwise, return the last value. 
if (this->op == AND) { if (std::any_of(vals.begin(), vals.end(), [](std::shared_ptr& ptr) { auto val = std::dynamic_pointer_cast(ptr); return ptr->isArray() && val && val->isNil(); })) return std::make_shared(false); return vals.back(); } // For Or, return the first non-NIL value or NIL if all values are NIL. if (this->op == OR) { auto it = std::find_if(vals.begin(), vals.end(), [](std::shared_ptr& ptr) { auto val = std::dynamic_pointer_cast(ptr); return !ptr->isArray() || (val && !val->isNil()); }); if (it == vals.end()) return std::make_shared(false); return *it; } // Now, all operators require all values to be int / float. if (std::any_of(vals.begin(), vals.end(), [](std::shared_ptr& ptr) { auto val = std::dynamic_pointer_cast(ptr); return ptr->isArray() || (val && !val->isInt() && !val->isFloat()); })) throw std::runtime_error("All values must be int or float"); // Transform #2: Convert all values std::vector> vals2; std::transform(vals.begin(), vals.end(), std::back_inserter(vals2), [](std::shared_ptr& ptr) { return std::dynamic_pointer_cast(ptr); }); // Assure that no value is nullptr. 
if (std::any_of(vals2.begin(), vals2.end(), [](std::shared_ptr& ptr) { return ptr == nullptr; })) throw std::runtime_error("Unexpected error"); bool hasFloat = std::any_of(vals2.begin(), vals2.end(), [](std::shared_ptr& ptr) { return ptr->isFloat(); }); std::vector intVal; std::vector floatVal; switch (this->op) { default: throw std::runtime_error("Unexpected error"); case LOGAND: case LOGIOR: case LOGXOR: case LOGEQV: if (hasFloat) throw std::runtime_error("Cannot apply selected operator to non-integer"); std::transform(vals2.begin(), vals2.end(), std::back_inserter(intVal), [&](std::shared_ptr& ptr) { return ptr->getInt(); }); return std::make_shared(std::accumulate(intVal.begin() + 1, intVal.end(), intVal[0], [this](std::int64_t x, std::int64_t y) { switch (this->op) { case LOGAND: return x & y; case LOGIOR: return x | y; case LOGXOR: return x ^ y; case LOGEQV: return ~(x ^ y); default: throw std::runtime_error("Unexpected error"); } })); case MAX: return std::max_element(vals2.begin(), vals2.end(), [](std::shared_ptr& x, std::shared_ptr& y) { return (x->isFloat() ? x->getFloat() : x->getInt()) < (y->isFloat() ? y->getFloat() : y->getInt()); })->operator->()->copy(); case MIN: return std::min_element(vals2.begin(), vals2.end(), [](std::shared_ptr& x, std::shared_ptr& y) { return (x->isFloat() ? x->getFloat() : x->getInt()) < (y->isFloat() ? y->getFloat() : y->getInt()); })->operator->()->copy(); case EQUAL: return std::all_of(vals2.begin() + 1, vals2.end(), [&](std::shared_ptr& ptr) { return ptr->operator==(*vals2[0]); }) ? std::make_shared(true) : std::make_shared(false); case NOT_EQUAL: return !std::all_of(vals2.begin() + 1, vals2.end(), [&](std::shared_ptr& ptr) { return ptr->operator==(*vals2[0]); }) ? 
std::make_shared(true) : std::make_shared(false); case PLUS: case MINUS: case MULTIPLY: case DIVIDE: auto opFunc = [this](auto x, auto y) { switch (this->op) { case PLUS: return x + y; case MINUS: return x - y; case MULTIPLY: return x * y; case DIVIDE: return x / y; default: throw std::runtime_error("Unexpected error"); } }; if (hasFloat) { std::transform(vals2.begin(), vals2.end(), std::back_inserter(floatVal), [&](std::shared_ptr& ptr) { return (ptr->isFloat() ? ptr->getFloat() : ptr->getInt()); }); return std::make_shared(std::accumulate(floatVal.begin() + 1, floatVal.end(), floatVal[0], opFunc)); } else { std::transform(vals2.begin(), vals2.end(), std::back_inserter(intVal), [&](std::shared_ptr& ptr) { return ptr->getInt(); }); return std::make_shared(std::accumulate(intVal.begin() + 1, intVal.end(), intVal[0], opFunc)); } } } std::shared_ptr VarOpAST::eval(Context* parent) { // DEFVAR || SETQ if (this->op == SETQ && !parent->hasVariable(this->name)) { throw std::runtime_error("Variable not defined: " + this->name); } // Eval value auto val = this->expr->eval(parent); parent->setVariable(this->name, val->copy()); return val; } std::shared_ptr LValOpAST::eval(Context* parent) { // Eval value auto val = this->expr->eval(parent); auto lv = std::dynamic_pointer_cast(this->lval); if (!lv) throw std::runtime_error("Unexpected error"); if (SETF == this->op) return lv->set(parent, val); auto original = lv->eval(parent); auto originalVal = std::dynamic_pointer_cast(original); if (original->isArray() || !originalVal || !originalVal->isInt()) throw std::runtime_error("Cannot apply INC or DEC to non-integer value"); auto base = originalVal->getInt(); auto valVal = std::dynamic_pointer_cast(val); if (val->isArray() || !valVal || !valVal->isInt()) throw std::runtime_error("Cannot INC or DEC by non-integer value"); auto delta = valVal->getInt(); switch (this->op) { case DECF: delta = -delta; case INCF: base += delta; return lv->set(parent, std::make_shared(base)); default:; } 
throw std::runtime_error("Unexpected error"); } std::shared_ptr ReturnAST::eval(Context* parent) { return this->expr->eval(parent); } } // end of namespace DragonLisp