fix: re-write loop to make it work

This commit is contained in:
Eatswap 2022-12-13 17:03:36 +08:00
parent 8367d5dd54
commit 6787722945
Signed by: Eatswap
GPG Key ID: BE661106A1F3FA0B
4 changed files with 174 additions and 56 deletions

81
AST.cpp
View File

@ -121,20 +121,9 @@ std::shared_ptr<ExprAST> IfAST::getResult(Context* parent) {
} }
std::shared_ptr<Value> LoopAST::eval(Context* parent) { std::shared_ptr<Value> LoopForeverAST::eval(Context* parent) {
// Eval start & end // No context is needed
auto s = this->start->eval(parent);
auto e = this->end->eval(parent);
if (s->isArray() || e->isArray())
throw std::runtime_error("Cannot eval loop start or end as integer");
auto sc = std::dynamic_pointer_cast<SingleValue>(s);
auto ec = std::dynamic_pointer_cast<SingleValue>(e);
if (!sc || !ec)
throw std::runtime_error("Unexpected error");
if (sc->isT()) {
while (true) { while (true) {
// No need to create a new context
for (auto& stmt : this->body) { for (auto& stmt : this->body) {
auto* ptr = stmt.get(); auto* ptr = stmt.get();
if (ptr->getType() == T_IfAST) if (ptr->getType() == T_IfAST)
@ -143,21 +132,30 @@ std::shared_ptr<Value> LoopAST::eval(Context* parent) {
continue; continue;
if (ptr->getType() == T_ReturnAST) if (ptr->getType() == T_ReturnAST)
return ptr->eval(parent); return ptr->eval(parent);
stmt->eval(parent); ptr->eval(parent);
} }
} }
} else { throw std::runtime_error("Unexpected error");
if (!sc->isInt() || !ec->isInt()) }
throw std::runtime_error("Cannot eval loop start or end as integer");
auto ss = sc->getInt();
auto ee = ec->getInt();
std::shared_ptr <Value> LoopForAST::eval(Context* parent) {
// Create a new context // Create a new context
auto ctx = std::make_shared<Context>(parent); auto ctx = std::make_shared<Context>(parent);
for (auto i = ss; i <= ee; ++i) { // Eval condition
ctx->setVariable(this->loopVar, std::make_shared<SingleValue>(i)); auto s = std::dynamic_pointer_cast<SingleValue>(this->start->eval(parent)->copy());
auto e = std::dynamic_pointer_cast<SingleValue>(this->end->eval(parent));
// Assert that s and e are numeric
if (!s || !e || (!s->isInt() && !s->isFloat()) || (!e->isInt() && !e->isFloat()))
throw std::runtime_error("LoopForAST: start and end must be numeric");
// Main loop
while (*s <= *e) {
// Set the variable
ctx->setVariable(this->name, s);
// Eval body
for (auto& stmt : this->body) { for (auto& stmt : this->body) {
auto* ptr = stmt.get(); auto* ptr = stmt.get();
if (ptr->getType() == T_IfAST) if (ptr->getType() == T_IfAST)
@ -166,16 +164,49 @@ std::shared_ptr<Value> LoopAST::eval(Context* parent) {
continue; continue;
if (ptr->getType() == T_ReturnAST) if (ptr->getType() == T_ReturnAST)
return ptr->eval(ctx.get()); return ptr->eval(ctx.get());
stmt->eval(ctx.get()); ptr->eval(ctx.get());
} }
// Increment
s->operator++();
}
// Return nil
return std::make_shared<SingleValue>(false);
}
std::shared_ptr <Value> LoopDoTimesAST::eval(Context* parent) {
// Create a new context
auto ctx = std::make_shared<Context>(parent);
// Eval condition
auto terminate = std::dynamic_pointer_cast<SingleValue>(this->times->eval(parent)->copy());
if (!terminate || !terminate->isInt())
throw std::runtime_error("DOTIMES: times must be an integer");
auto n = terminate->getInt();
// Main Loop
for (std::int64_t i = 0; i < n; ++i) {
// Set Variable
ctx->setVariable(this->name, std::make_shared<SingleValue>(i));
// Eval Body
for (auto& stmt : this->body) {
auto* ptr = stmt.get();
if (ptr->getType() == T_IfAST)
ptr = dynamic_cast<IfAST*>(ptr)->getResult(ctx.get()).get();
if (!ptr)
continue;
if (ptr->getType() == T_ReturnAST)
return ptr->eval(ctx.get());
ptr->eval(ctx.get());
} }
} }
// Return nil // Return nil
return std::make_shared<SingleValue>(); return std::make_shared<SingleValue>(false);
} }
std::shared_ptr<Value> UnaryAST::eval(Context* parent) { std::shared_ptr<Value> UnaryAST::eval(Context* parent) {
auto val = this->expr->eval(parent); auto val = this->expr->eval(parent);
auto valS = std::dynamic_pointer_cast<SingleValue>(val); auto valS = std::dynamic_pointer_cast<SingleValue>(val);

43
AST.h
View File

@ -17,6 +17,9 @@ enum ASTType {
T_FuncCallAST, T_FuncCallAST,
T_IfAST, T_IfAST,
T_LoopAST, T_LoopAST,
T_LoopForeverAST,
T_LoopForAST,
T_LoopDoTimesAST,
T_UnaryAST, T_UnaryAST,
T_BinaryAST, T_BinaryAST,
T_ListAST, T_ListAST,
@ -132,20 +135,52 @@ public:
} }
}; };
class LoopAST : public ExprAST { class LoopAST : public ExprAST {};
class LoopForeverAST : public LoopAST {
private: private:
std::string loopVar; std::vector<std::shared_ptr<ExprAST>> body;
public:
explicit LoopForeverAST(std::vector<std::shared_ptr<ExprAST>> body) : body(std::move(body)) {}
std::shared_ptr<Value> eval(Context* parent) override final;
inline ASTType getType() const override final {
return T_LoopForeverAST;
}
};
class LoopForAST : public LoopAST {
private:
std::string name;
std::shared_ptr<ExprAST> start; std::shared_ptr<ExprAST> start;
std::shared_ptr<ExprAST> end; std::shared_ptr<ExprAST> end;
std::vector<std::shared_ptr<ExprAST>> body; std::vector<std::shared_ptr<ExprAST>> body;
public: public:
LoopAST(std::string loopVar, std::shared_ptr<ExprAST> start, std::shared_ptr<ExprAST> end, std::vector<std::shared_ptr<ExprAST>> body) : loopVar(std::move(loopVar)), start(std::move(start)), end(std::move(end)), body(std::move(body)) {} LoopForAST(std::string name, std::shared_ptr<ExprAST> start, std::shared_ptr<ExprAST> end, std::vector<std::shared_ptr<ExprAST>> body) : name(std::move(name)), start(std::move(start)), end(std::move(end)), body(std::move(body)) {}
std::shared_ptr<Value> eval(Context* parent) override final; std::shared_ptr<Value> eval(Context* parent) override final;
inline ASTType getType() const override final { inline ASTType getType() const override final {
return T_LoopAST; return T_LoopForAST;
}
};
class LoopDoTimesAST : public LoopAST {
private:
std::string name;
std::shared_ptr<ExprAST> times;
std::vector<std::shared_ptr<ExprAST>> body;
public:
LoopDoTimesAST(std::string name, std::shared_ptr<ExprAST> times, std::vector<std::shared_ptr<ExprAST>> body) : name(std::move(name)), times(std::move(times)), body(std::move(body)) {}
std::shared_ptr<Value> eval(Context* parent) override final;
inline ASTType getType() const override final {
return T_LoopDoTimesAST;
} }
}; };

View File

@ -99,25 +99,21 @@ std::shared_ptr<ReturnAST> DLDriver::constructReturnAST(std::shared_ptr<ExprAST>
} }
std::shared_ptr<LoopAST> DLDriver::constructLoopAST(std::vector<std::shared_ptr<ExprAST>> body) { std::shared_ptr<LoopAST> DLDriver::constructLoopAST(std::vector<std::shared_ptr<ExprAST>> body) {
return std::make_shared<LoopAST>( return std::make_shared<LoopForeverAST>(
"",
std::make_shared<LiteralAST>(true),
std::make_shared<LiteralAST>(true),
std::move(body) std::move(body)
); );
} }
std::shared_ptr<LoopAST> DLDriver::constructLoopAST(std::string id, std::shared_ptr<ExprAST> to, std::vector<std::shared_ptr<ExprAST>> body) { std::shared_ptr<LoopAST> DLDriver::constructLoopAST(std::string id, std::shared_ptr<ExprAST> to, std::vector<std::shared_ptr<ExprAST>> body) {
return std::make_shared<LoopAST>( return std::make_shared<LoopDoTimesAST>(
std::move(id), std::move(id),
std::make_shared<LiteralAST>(std::int64_t(0)),
std::move(to), std::move(to),
std::move(body) std::move(body)
); );
} }
std::shared_ptr<LoopAST> DLDriver::constructLoopAST(std::string id, std::shared_ptr<ExprAST> from, std::shared_ptr<ExprAST> to, std::vector<std::shared_ptr<ExprAST>> body) { std::shared_ptr<LoopAST> DLDriver::constructLoopAST(std::string id, std::shared_ptr<ExprAST> from, std::shared_ptr<ExprAST> to, std::vector<std::shared_ptr<ExprAST>> body) {
return std::make_shared<LoopAST>( return std::make_shared<LoopForAST>(
std::move(id), std::move(id),
std::move(from), std::move(from),
std::move(to), std::move(to),

56
value.h
View File

@ -106,6 +106,62 @@ public:
bool operator==(const SingleValue& rhs) const { bool operator==(const SingleValue& rhs) const {
return type == rhs.type && value == rhs.value; return type == rhs.type && value == rhs.value;
} }
bool operator<(const SingleValue& rhs) const {
if ((!this->isInt() && !this->isFloat()) || (!rhs.isInt() && !rhs.isFloat()))
throw std::runtime_error("Cannot compare non-numeric values");
if (this->isInt() && rhs.isInt())
return this->getInt() < rhs.getInt();
if (this->isInt())
return this->getInt() < rhs.getFloat();
if (rhs.isInt())
return this->getFloat() < rhs.getInt();
return this->getFloat() < rhs.getFloat();
}
bool operator<=(const SingleValue& rhs) const {
if ((!this->isInt() && !this->isFloat()) || (!rhs.isInt() && !rhs.isFloat()))
throw std::runtime_error("Cannot compare non-numeric values");
if (this->isInt() && rhs.isInt())
return this->getInt() <= rhs.getInt();
if (this->isInt())
return this->getInt() <= rhs.getFloat();
if (rhs.isInt())
return this->getFloat() <= rhs.getInt();
return this->getFloat() <= rhs.getFloat();
}
SingleValue& operator++() {
if (this->isInt())
this->value = this->getInt() + 1;
else if (this->isFloat())
this->value = this->getFloat() + 1;
else
throw std::runtime_error("Cannot increment non-numeric value");
return *this;
}
SingleValue& operator--() {
if (this->isInt())
this->value = this->getInt() - 1;
else if (this->isFloat())
this->value = this->getFloat() - 1;
else
throw std::runtime_error("Cannot decrement non-numeric value");
return *this;
}
SingleValue operator++(int) {
SingleValue tmp(*this);
operator++();
return tmp;
}
SingleValue operator--(int) {
SingleValue tmp(*this);
operator--();
return tmp;
}
}; };
class ArrayValue : public Value { class ArrayValue : public Value {