add: implemented AST eval

This commit is contained in:
Eatswap 2022-12-13 13:06:39 +08:00
parent 29f7667a20
commit 4e9cae13ff
Signed by: Eatswap
GPG Key ID: BE661106A1F3FA0B
4 changed files with 532 additions and 18 deletions

441
AST.cpp
View File

@ -1,7 +1,448 @@
#include <utility>
#include <iostream>
#include <algorithm>
#include <numeric>
#include "AST.h"
namespace DragonLisp {
std::shared_ptr<Value> ArrayRefAST::eval(Context* parent) {
// Eval this->index
auto idx = std::dynamic_pointer_cast<SingleValue>(this->index->eval(parent));
if (!idx || !idx->isInt())
throw std::runtime_error("Cannot eval index as integer");
auto var = parent->getVariable(this->name);
if (!var)
throw std::runtime_error("Variable not found: " + this->name);
auto varC = std::dynamic_pointer_cast<ArrayValue>(var);
if (!varC)
throw std::runtime_error("Cannot reference from non-array variable: " + this->name);
if (varC->getSize() <= idx->getInt())
throw std::runtime_error("Index out of range: " + std::to_string(idx->getInt()) + " >= " + std::to_string(varC->getSize()));
return std::make_shared<SingleValue>((*varC)[idx->getInt()]);
}
std::shared_ptr<Value> ArrayRefAST::set(Context* parent, std::shared_ptr<Value> value) {
// Eval this->index
auto idx = std::dynamic_pointer_cast<SingleValue>(this->index->eval(parent));
if (!idx || !idx->isInt())
throw std::runtime_error("Cannot eval index as integer");
auto var = parent->getVariable(this->name);
if (!var)
throw std::runtime_error("Variable not found: " + this->name);
auto varC = std::dynamic_pointer_cast<ArrayValue>(var);
if (!varC)
throw std::runtime_error("Cannot reference from non-array variable: " + this->name);
if (varC->getSize() <= idx->getInt())
throw std::runtime_error("Index out of range: " + std::to_string(idx->getInt()) + " >= " + std::to_string(varC->getSize()));
auto val = std::dynamic_pointer_cast<SingleValue>(value);
if (!val)
throw std::runtime_error("Cannot set array element to another array");
varC->set(idx->getInt(), *val);
return value;
}
std::shared_ptr<Value> IdentifierAST::eval(Context* parent) {
auto var = parent->getVariable(this->name);
if (!var)
throw std::runtime_error("Variable not found: " + this->name);
return var->copy();
}
std::shared_ptr<Value> IdentifierAST::set(Context* parent, std::shared_ptr<Value> value) {
parent->setVariable(this->name, value);
return value;
}
std::shared_ptr<Value> FuncDefAST::eval(Context* parent, std::vector<std::shared_ptr<Value>> arg) {
// Create a new context
auto ctx = std::make_shared<Context>(parent);
// Set arguments
for (size_t i = 0; i < this->args.size(); i++) {
if (i >= arg.size())
throw std::runtime_error("Too few arguments");
ctx->setVariable(this->args[i], arg[i]);
}
// Eval body
std::shared_ptr<Value> ret = std::make_shared<SingleValue>(); // which is nil
for (auto& stmt : this->body) {
auto* ptr = stmt.get();
if (ptr->getType() == T_IfAST)
ptr = dynamic_cast<IfAST*>(ptr)->getResult(ctx.get()).get();
if (!ptr)
continue;
if (ptr->getType() == T_ReturnAST)
return ptr->eval(ctx.get());
ret = ptr->eval(ctx.get());
}
return ret;
}
std::shared_ptr<Value> FuncCallAST::eval(Context* parent) {
// Get the function
auto func = parent->getFunc(this->name);
if (!func)
throw std::runtime_error("Function not defined: " + this->name);
// Eval arguments
std::vector<std::shared_ptr<Value>> arg;
for (const auto& a : this->args) {
arg.push_back(a->eval(parent));
}
// Get the global context
auto globalCtx = parent;
while (globalCtx->getParent())
globalCtx = globalCtx->getParent();
// Eval under global context
return func->eval(globalCtx, arg);
}
std::shared_ptr<ExprAST> IfAST::getResult(Context* parent) {
// Eval condition
auto c = this->cond->eval(parent);
bool ok = c->isArray();
if (!ok) {
auto cc = std::dynamic_pointer_cast<SingleValue>(c);
if (!cc)
throw std::runtime_error("Unexpected error");
if (!cc->isNil())
ok = true;
}
return ok ? this->then : this->els;
}
std::shared_ptr<Value> LoopAST::eval(Context* parent) {
// Eval start & end
auto s = this->start->eval(parent);
auto e = this->end->eval(parent);
if (s->isArray() || e->isArray())
throw std::runtime_error("Cannot eval loop start or end as integer");
auto sc = std::dynamic_pointer_cast<SingleValue>(s);
auto ec = std::dynamic_pointer_cast<SingleValue>(e);
if (!sc || !ec)
throw std::runtime_error("Unexpected error");
if (sc->isT()) {
while (true) {
// No need to create a new context
for (auto& stmt : this->body) {
auto* ptr = stmt.get();
if (ptr->getType() == T_IfAST)
ptr = dynamic_cast<IfAST*>(ptr)->getResult(parent).get();
if (!ptr)
continue;
if (ptr->getType() == T_ReturnAST)
return ptr->eval(parent);
}
}
} else {
if (!sc->isInt() || !ec->isInt())
throw std::runtime_error("Cannot eval loop start or end as integer");
auto ss = sc->getInt();
auto ee = ec->getInt();
// Create a new context
auto ctx = std::make_shared<Context>(parent);
for (auto i = ss; i <= ee; ++i) {
ctx->setVariable(this->loopVar, std::make_shared<SingleValue>(i));
for (auto& stmt : this->body) {
auto* ptr = stmt.get();
if (ptr->getType() == T_IfAST)
ptr = dynamic_cast<IfAST*>(ptr)->getResult(ctx.get()).get();
if (!ptr)
continue;
if (ptr->getType() == T_ReturnAST)
return ptr->eval(ctx.get());
}
}
}
// Return nil
return std::make_shared<SingleValue>();
}
std::shared_ptr<Value> UnaryAST::eval(Context* parent) {
auto val = this->expr->eval(parent);
auto valS = std::dynamic_pointer_cast<SingleValue>(val);
switch (this->op) {
case NOT:
return std::make_shared<SingleValue>(val->isArray() || (valS && !valS->isNil()));
case MAKE_ARRAY:
if (val->isArray() || (valS && !valS->isInt()))
throw std::runtime_error("Array size must be an integer");
return std::make_shared<ArrayValue>(valS->getInt());
case PRINT:
std::cout << val->toString() << std::endl;
return val;
default:
throw std::runtime_error("Unexpected error");
}
}
std::shared_ptr<Value> BinaryAST::eval(Context* parent) {
// All binary operators requires
// both operands to be int / float.
auto lv = std::dynamic_pointer_cast<SingleValue>(this->lhs->eval(parent));
auto rv = std::dynamic_pointer_cast<SingleValue>(this->rhs->eval(parent));
if (!lv || !rv || !(lv->isInt() || lv->isFloat()) || !(rv->isInt() || rv->isFloat()))
throw std::runtime_error("Both operands must be int or float");
if (lv->isFloat() || rv->isFloat()) {
double l = lv->isInt() ? lv->getInt() : lv->getFloat();
double r = rv->isInt() ? rv->getInt() : rv->getFloat();
switch (this->op) {
case LESS:
return std::make_shared<SingleValue>(l < r);
case LESS_EQUAL:
return std::make_shared<SingleValue>(l <= r);
case GREATER:
return std::make_shared<SingleValue>(l > r);
case GREATER_EQUAL:
return std::make_shared<SingleValue>(l >= r);
case MOD:
case REM:
return std::make_shared<SingleValue>(std::fmod(l, r));
default:
throw std::runtime_error("This operator cannot be applied to float");
}
} else {
std::int64_t l = lv->getInt();
std::int64_t r = rv->getInt();
switch (this->op) {
case LESS:
return std::make_shared<SingleValue>(l < r);
case LESS_EQUAL:
return std::make_shared<SingleValue>(l <= r);
case GREATER:
return std::make_shared<SingleValue>(l > r);
case GREATER_EQUAL:
return std::make_shared<SingleValue>(l >= r);
case MOD:
case REM:
return std::make_shared<SingleValue>(l % r);
case LOGNOR:
return std::make_shared<SingleValue>(~(l | r));
default:
throw std::runtime_error("Unexpected error");
}
}
}
std::shared_ptr<Value> ListAST::eval(Context* parent) {
auto ret = this->exprs[0]->eval(parent);
auto retS = std::dynamic_pointer_cast<SingleValue>(ret);
if (this->exprs.size() == 1) {
switch (this->op) {
case LOGAND:
case LOGNOR:
case LOGXOR:
case LOGEQV:
if (ret->isArray() || (retS && !retS->isInt()))
throw std::runtime_error("Cannot apply selected operator to non-integer");
case MAX:
case MIN:
case PLUS:
case MULTIPLY:
if (ret->isArray() || (retS && !retS->isInt() && !retS->isFloat()))
throw std::runtime_error("Cannot apply selected operator to non-integer or non-float");
case AND:
case OR:
return ret;
case EQUAL:
case NOT_EQUAL:
if (ret->isArray() || (retS && !retS->isInt() && !retS->isFloat()))
throw std::runtime_error("Cannot apply selected operator to non-integer or non-float");
return std::make_shared<SingleValue>(true);
default:;
}
throw std::runtime_error("Invalid argument count for selected operator");
}
// Transform #1: Eval all values.
std::vector<std::shared_ptr<Value>> vals;
std::transform(this->exprs.begin(), this->exprs.end(), std::back_inserter(vals), [&](std::shared_ptr<ExprAST>& ptr) {
return ptr->eval(parent);
});
// For And, return NIL if any value is NIL. Otherwise, return the last value.
if (this->op == AND) {
if (std::any_of(vals.begin(), vals.end(), [](std::shared_ptr<Value>& ptr) {
auto val = std::dynamic_pointer_cast<SingleValue>(ptr);
return ptr->isArray() && val && val->isNil();
}))
return std::make_shared<SingleValue>(false);
return vals.back();
}
// For Or, return the first non-NIL value or NIL if all values are NIL.
if (this->op == OR) {
auto it = std::find_if(vals.begin(), vals.end(), [](std::shared_ptr<Value>& ptr) {
auto val = std::dynamic_pointer_cast<SingleValue>(ptr);
return !ptr->isArray() || (val && !val->isNil());
});
if (it == vals.end())
return std::make_shared<SingleValue>(false);
return *it;
}
// Now, all operators require all values to be int / float.
if (std::any_of(vals.begin(), vals.end(), [](std::shared_ptr<Value>& ptr) {
auto val = std::dynamic_pointer_cast<SingleValue>(ptr);
return ptr->isArray() || (val && !val->isInt() && !val->isFloat());
}))
throw std::runtime_error("All values must be int or float");
// Transform #2: Convert all values
std::vector<std::shared_ptr<SingleValue>> vals2;
std::transform(vals.begin(), vals.end(), std::back_inserter(vals2), [](std::shared_ptr<Value>& ptr) {
return std::dynamic_pointer_cast<SingleValue>(ptr);
});
// Assure that no value is nullptr.
if (std::any_of(vals2.begin(), vals2.end(), [](std::shared_ptr<SingleValue>& ptr) {
return ptr == nullptr;
}))
throw std::runtime_error("Unexpected error");
bool hasFloat = std::any_of(vals2.begin(), vals2.end(), [](std::shared_ptr<SingleValue>& ptr) {
return ptr->isFloat();
});
std::vector<std::int64_t> intVal;
std::vector<double> floatVal;
switch (this->op) {
default:
throw std::runtime_error("Unexpected error");
case LOGAND:
case LOGIOR:
case LOGXOR:
case LOGEQV:
if (hasFloat)
throw std::runtime_error("Cannot apply selected operator to non-integer");
std::transform(vals2.begin(), vals2.end(), std::back_inserter(intVal), [&](std::shared_ptr<SingleValue>& ptr) {
return ptr->getInt();
});
return std::make_shared<SingleValue>(std::reduce(intVal.begin() + 1, intVal.end(), intVal[0], [this](std::int64_t x, std::int64_t y) {
switch (this->op) {
case LOGAND:
return x & y;
case LOGIOR:
return x | y;
case LOGXOR:
return x ^ y;
case LOGEQV:
return ~(x ^ y);
default:
throw std::runtime_error("Unexpected error");
}
}));
case MAX:
return std::max_element(vals2.begin(), vals2.end(), [](std::shared_ptr<SingleValue>& x, std::shared_ptr<SingleValue>& y) {
return (x->isFloat() ? x->getFloat() : x->getInt()) < (y->isFloat() ? y->getFloat() : y->getInt());
})->operator->()->copy();
case MIN:
return std::min_element(vals2.begin(), vals2.end(), [](std::shared_ptr<SingleValue>& x, std::shared_ptr<SingleValue>& y) {
return (x->isFloat() ? x->getFloat() : x->getInt()) < (y->isFloat() ? y->getFloat() : y->getInt());
})->operator->()->copy();
case EQUAL:
return std::all_of(vals2.begin() + 1, vals2.end(), [&](std::shared_ptr<SingleValue>& ptr) {
return ptr->operator==(*vals2[0]);
}) ? std::make_shared<SingleValue>(true) : std::make_shared<SingleValue>(false);
case NOT_EQUAL:
return !std::all_of(vals2.begin() + 1, vals2.end(), [&](std::shared_ptr<SingleValue>& ptr) {
return ptr->operator==(*vals2[0]);
}) ? std::make_shared<SingleValue>(true) : std::make_shared<SingleValue>(false);
case PLUS:
case MINUS:
case MULTIPLY:
case DIVIDE:
auto opFunc = [this](auto x, auto y) {
switch (this->op) {
case PLUS:
return x + y;
case MINUS:
return x - y;
case MULTIPLY:
return x * y;
case DIVIDE:
return x / y;
default:
throw std::runtime_error("Unexpected error");
}
};
if (hasFloat) {
std::transform(vals2.begin(), vals2.end(), std::back_inserter(floatVal), [&](std::shared_ptr<SingleValue>& ptr) {
return (ptr->isFloat() ? ptr->getFloat() : ptr->getInt());
});
return std::make_shared<SingleValue>(std::reduce(floatVal.begin() + 1, floatVal.end(), floatVal[0], opFunc));
} else {
std::transform(vals2.begin(), vals2.end(), std::back_inserter(intVal), [&](std::shared_ptr<SingleValue>& ptr) {
return ptr->getInt();
});
return std::make_shared<SingleValue>(std::reduce(intVal.begin() + 1, intVal.end(), intVal[0], opFunc));
}
}
}
std::shared_ptr<Value> VarOpAST::eval(Context* parent) {
// DEFVAR || SETQ
if (this->op == SETQ && !parent->hasVariable(this->name)) {
throw std::runtime_error("Variable not defined: " + this->name);
}
// Eval value
auto val = this->expr->eval(parent);
parent->setVariable(this->name, val->copy());
return val;
}
std::shared_ptr<Value> LValOpAST::eval(Context* parent) {
// Eval value
auto val = this->expr->eval(parent);
auto lv = std::dynamic_pointer_cast<LValueAST>(this->lval);
if (!lv)
throw std::runtime_error("Unexpected error");
if (SETF == this->op)
return lv->set(parent, val);
auto original = lv->eval(parent);
auto originalVal = std::dynamic_pointer_cast<SingleValue>(original);
if (original->isArray() || !originalVal || !originalVal->isInt())
throw std::runtime_error("Cannot apply INC or DEC to non-integer value");
auto base = originalVal->getInt();
auto valVal = std::dynamic_pointer_cast<SingleValue>(val);
if (val->isArray() || !valVal || !valVal->isInt())
throw std::runtime_error("Cannot INC or DEC by non-integer value");
auto delta = valVal->getInt();
switch (this->op) {
case DECF:
delta = -delta;
case INCF:
base += delta;
return lv->set(parent, std::make_shared<SingleValue>(base));
default:;
}
throw std::runtime_error("Unexpected error");
}
std::shared_ptr<Value> ReturnAST::eval(Context* parent) {
return this->expr->eval(parent);
}
} // end of namespace DragonLisp

33
AST.h
View File

@ -22,6 +22,7 @@ enum ASTType {
T_ListAST,
T_VarOpAST,
T_LValOpAST,
T_ReturnAST,
};
/// BaseAST - Base class for all AST nodes.
@ -97,7 +98,7 @@ private:
std::vector<std::shared_ptr<ExprAST>> args;
public:
FuncCallAST(std::string name, std::vector<std::unique_ptr<ExprAST>> args);
FuncCallAST(std::string name, std::vector<std::shared_ptr<ExprAST>> args) : name(std::move(name)), args(std::move(args)) {}
std::shared_ptr<Value> eval(Context* parent) override final;
@ -113,7 +114,7 @@ private:
std::shared_ptr<ExprAST> els;
public:
IfAST(std::unique_ptr<ExprAST> cond, std::unique_ptr<ExprAST> then, std::unique_ptr<ExprAST> els);
IfAST(std::shared_ptr<ExprAST> cond, std::shared_ptr<ExprAST> then, std::shared_ptr<ExprAST> els) : cond(std::move(cond)), then(std::move(then)), els(std::move(els)) {}
std::shared_ptr<Value> eval(Context* parent) override final {
throw std::runtime_error("You should use IfAST::getResult() instead of IfAST::eval()");
@ -131,10 +132,10 @@ private:
std::string loopVar;
std::shared_ptr<ExprAST> start;
std::shared_ptr<ExprAST> end;
std::shared_ptr<FuncDefAST> body;
std::vector<std::shared_ptr<ExprAST>> body;
public:
LoopAST(std::string loopVar, std::unique_ptr<ExprAST> start, std::unique_ptr<ExprAST> end, std::unique_ptr<FuncDefAST> body);
LoopAST(std::string loopVar, std::shared_ptr<ExprAST> start, std::shared_ptr<ExprAST> end, std::vector<std::shared_ptr<ExprAST>> body) : loopVar(std::move(loopVar)), start(std::move(start)), end(std::move(end)), body(std::move(body)) {}
std::shared_ptr<Value> eval(Context* parent) override final;
@ -149,7 +150,7 @@ private:
Token op;
public:
UnaryAST(std::unique_ptr<ExprAST> expr, Token op);
UnaryAST(std::shared_ptr<ExprAST> expr, Token op) : expr(std::move(expr)), op(std::move(op)) {}
std::shared_ptr<Value> eval(Context* parent) override final;
@ -165,7 +166,7 @@ private:
Token op;
public:
BinaryAST(std::unique_ptr<ExprAST> lhs, std::unique_ptr<ExprAST> rhs, Token op);
BinaryAST(std::shared_ptr<ExprAST> lhs, std::shared_ptr<ExprAST> rhs, Token op) : lhs(std::move(lhs)), rhs(std::move(rhs)), op(std::move(op)) {}
std::shared_ptr<Value> eval(Context* parent) override final;
@ -180,7 +181,7 @@ private:
Token op;
public:
ListAST(std::vector<std::unique_ptr<ExprAST>> exprs, Token op);
ListAST(std::vector<std::shared_ptr<ExprAST>> exprs, Token op) : exprs(std::move(exprs)), op(std::move(op)) {}
std::shared_ptr<Value> eval(Context* parent) override final;
@ -196,7 +197,7 @@ private:
Token op;
public:
VarOpAST(std::string name, std::unique_ptr<ExprAST> expr, Token op);
VarOpAST(std::string name, std::shared_ptr<ExprAST> expr, Token op) : name(std::move(name)), expr(std::move(expr)), op(std::move(op)) {}
std::shared_ptr<Value> eval(Context* parent) override final;
@ -212,7 +213,7 @@ private:
Token op;
public:
LValOpAST(std::unique_ptr<ExprAST> lval, std::unique_ptr<ExprAST> expr, Token op);
LValOpAST(std::shared_ptr<ExprAST> lval, std::shared_ptr<ExprAST> expr, Token op) : lval(std::move(lval)), expr(std::move(expr)), op(std::move(op)) {}
std::shared_ptr<Value> eval(Context* parent) override final;
@ -221,6 +222,20 @@ public:
}
};
class ReturnAST : public ExprAST {
private:
std::shared_ptr<ExprAST> expr;
public:
explicit ReturnAST(std::shared_ptr<ExprAST> expr) : expr(std::move(expr)) {}
std::shared_ptr<Value> eval(Context* parent) override final;
inline ASTType getType() const override final {
return T_ReturnAST;
}
};
}
#endif // __DRAGON_LISP_AST_H__

View File

@ -2,6 +2,7 @@
#define __DRAGON_LISP_CONTEXT_H__
#include <memory>
#include <utility>
#include <variant>
#include "value.h"
@ -30,7 +31,7 @@ public:
std::shared_ptr<Value> getVariable(const std::string& name) const {
if (this->variables.contains(name))
return this->variables[name];
return this->variables.at(name);
if (this->parent)
return this->parent->getVariable(name);
return nullptr;
@ -39,6 +40,28 @@ public:
void setVariable(const std::string& name, std::shared_ptr<Value> value) {
this->variables[name] = std::move(value);
}
bool hasVariable(const std::string& name) const {
if (this->variables.contains(name))
return true;
if (this->parent)
return this->parent->hasVariable(name);
return false;
}
std::shared_ptr<FuncDefAST> getFunc(const std::string& name) const {
if (this->funcs->contains(name))
return (*this->funcs)[name];
return nullptr;
}
void setFunc(const std::string& name, std::shared_ptr<FuncDefAST> value) {
(*this->funcs)[name] = std::move(value);
}
Context* getParent() const {
return this->parent;
}
};
}

51
value.h
View File

@ -14,6 +14,10 @@ using ValueVariant = std::variant<std::monostate, std::int64_t, double, std::str
class Value {
public:
virtual bool isArray() const = 0;
virtual std::shared_ptr<Value> copy() const = 0;
virtual std::string toString() const = 0;
};
class SingleValue : public Value {
@ -31,20 +35,14 @@ public:
explicit SingleValue(std::string v) : value(std::move(v)) {}
explicit SingleValue(bool v) : type(v ? TYPE_T : TYPE_NIL), value() {}
SingleValue() : SingleValue(ValueType::TYPE_NIL) {}
bool isArray() const override final {
return false;
}
static SingleValue makeT() {
return SingleValue(ValueType::TYPE_T);
}
static SingleValue makeNil() {
return SingleValue(ValueType::TYPE_NIL);
}
ValueType getType() const {
return this->type;
}
@ -88,6 +86,26 @@ public:
void setValue(ValueVariant v) {
this->value = std::move(v);
}
std::shared_ptr<Value> copy() const override final {
return std::make_shared<SingleValue>(*this);
}
std::string toString() const override final {
if (this->isInt())
return std::to_string(this->getInt());
if (this->isFloat())
return std::to_string(this->getFloat());
if (this->isString())
return this->getString();
if (this->isT())
return "T";
return "NIL";
}
bool operator==(const SingleValue& rhs) const {
return type == rhs.type && value == rhs.value;
}
};
class ArrayValue : public Value {
@ -115,6 +133,10 @@ public:
return this->values[i];
}
void set(std::size_t i, SingleValue v) {
this->values[i] = std::move(v);
}
std::vector<SingleValue>& getValues() {
return this->values;
}
@ -122,6 +144,19 @@ public:
const std::vector<SingleValue>& getValues() const {
return this->values;
}
std::shared_ptr<Value> copy() const override final {
return std::make_shared<ArrayValue>(*this);
}
std::string toString() const override final {
std::string result = "[";
for (const auto& i : this->values)
result.append(i.toString()).append(", ");
result.pop_back();
result.back() = ']';
return result;
}
};
class _Unused_Variable {