Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions tools/sddl2/compiler/codegen/CodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class CodeGeneratorImpl {
case ConvertedNodeType::BYTES:
case ConvertedNodeType::ARRAY:
case ConvertedNodeType::RECORD:
case ConvertedNodeType::RECORD_FIELD:
case ConvertedNodeType::CALL:
case ConvertedNodeType::WHEN:
default:
Expand Down Expand Up @@ -232,21 +233,24 @@ class CodeGeneratorImpl {
case ConvertedNodeType::RECORD: {
auto record = type->as_record();
for (const auto& field : record->fields()) {
if (auto assume = field->as_op()) {
const auto& field_type = assume->args()[1];
auto [field_asm, _] = generateType(field_type);
output += std::move(field_asm);
}
if (auto when = field->as_when()) {
(void)when;
throw CodegenError(
field->loc(), "Not yet implemented!");
}
auto [field_asm, _] = generateType(field);
output += std::move(field_asm);
}
output += "push.i64 " + std::to_string(record->fields().size());
output += "type.structure";
return { std::move(output), type };
}
case ConvertedNodeType::RECORD_FIELD: {
auto field = type->as_record_field();
auto [field_asm, field_type] = generateType(field->type());
output += std::move(field_asm);
return { std::move(output), field_type };
}
case ConvertedNodeType::BYTES: {
auto bytes = type->as_bytes();
output += "push.type.bytes";
Expand Down Expand Up @@ -339,9 +343,10 @@ class CodeGeneratorImpl {
output += bindParams(curr_record, call->args());
}
for (const auto& field : curr_record->fields()) {
auto assume = field->as_op();
auto& field_name = assume->args()[0]->as_var()->name();
auto [field_asm, field_type] = generateType(assume->args()[1]);
auto record_field = field->as_record_field();
auto& field_name = record_field->name()->as_var()->name();
auto [field_asm, field_type] =
generateType(record_field->type());
if (name == field_name) {
type = field_type;
break;
Expand Down
1 change: 1 addition & 0 deletions tools/sddl2/compiler/optimizer/ConstFoldPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class ConstFoldImpl {
switch (node->converted_node_type()) {
case ConvertedNodeType::NUM:
case ConvertedNodeType::BUILTIN_FIELD:
case ConvertedNodeType::RECORD_FIELD:
return node;
case ConvertedNodeType::VAR:
return optimizeVar(*node->as_var());
Expand Down
5 changes: 5 additions & 0 deletions tools/sddl2/compiler/optimizer/DeadVarPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ class DeadVarImpl {
recordLastRefs(node->as_record()->fields());
return;
}
case ConvertedNodeType::RECORD_FIELD: {
recordLastRefs(node->as_record_field()->type());
return;
}
case ConvertedNodeType::OP: {
const auto& op = *node->as_op();
if (op.op() == Op::ASSIGN || op.op() == Op::ASSUME) {
Expand Down Expand Up @@ -101,6 +105,7 @@ class DeadVarImpl {
case ConvertedNodeType::NUM:
case ConvertedNodeType::BUILTIN_FIELD:
case ConvertedNodeType::RECORD:
case ConvertedNodeType::RECORD_FIELD:
return node;
case ConvertedNodeType::VAR:
return optimizeVar(node->as_var());
Expand Down
81 changes: 68 additions & 13 deletions tools/sddl2/compiler/parser/AST.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ const ASTOp* ASTNode::as_op() const
return nullptr;
}

const ASTRecordField* ASTNode::as_record_field() const
{
return nullptr;
}

bool ASTNode::operator==(const Symbol& sym) const
{
const auto* tok = as_sym();
Expand Down Expand Up @@ -348,20 +353,39 @@ const ASTVec& ASTRecord::fields() const
return fields_;
}

ASTVec ASTRecord::extract_fields(
const SourceLocation& loc,
const ASTPtr& paren_ptr)
{
const auto* list = paren_ptr->as_list();
if (list == nullptr) {
throw InvariantViolation(
loc, "Record declaration must be given a list of fields.");
}
if (list->list_type() != ListType::CURLY) {
throw InvariantViolation(
loc, "Record declaration fields list must be curly-braced.");
/**
* Converts a list of ASTNodes into a list of ASTRecordFields. This is used
* when parsing the fields of a record.
*/
static ASTVec toRecordFields(const ASTVec& nodes)
{
ASTVec fields;
for (const auto& node : nodes) {
if (node->as_record_field()) {
fields.push_back(node);
} else if (auto when = node->as_when()) {
fields.push_back(
std::make_shared<ASTWhen>(
when->condition(), toRecordFields(when->body())));
} else if (auto op = node->as_op()) {
if (op->op() != Op::ASSUME) {
throw ParseError(
node->loc(), "Record field must be an assume op.");
};
fields.push_back(
std::make_shared<ASTRecordField>(
op->args()[0], op->args()[1]));
} else {
throw ParseError(
node->loc(), "Record field must be an op or when.");
}
}
return unwrap_parens(list->nodes());
return fields;
}

ASTVec ASTRecord::extract_fields(const SourceLocation&, const ASTPtr& paren_ptr)
{
return toRecordFields(unwrap_curly(paren_ptr));
}

ASTVec ASTRecord::extract_params(
Expand Down Expand Up @@ -511,4 +535,35 @@ const ASTVec& ASTOp::args() const
return args_;
}

ASTRecordField::ASTRecordField(ASTPtr name, ASTPtr type)
: ASTConverted(some(name).loc() + some(type).loc()),
name_(std::move(name)),
type_(std::move(type))
{
}

const ASTRecordField* ASTRecordField::as_record_field() const
{
return this;
}

void ASTRecordField::print(std::ostream& os, size_t indent) const
{
os << std::string(indent, ' ') << "RecordField:" << std::endl;
os << std::string(indent + 2, ' ') << "Name:" << std::endl;
name_->print(os, indent + 4);
os << std::string(indent + 2, ' ') << "Type:" << std::endl;
type_->print(os, indent + 4);
}

const ASTPtr& ASTRecordField::name() const
{
return name_;
}

const ASTPtr& ASTRecordField::type() const
{
return type_;
}

} // namespace openzl::sddl2
24 changes: 24 additions & 0 deletions tools/sddl2/compiler/parser/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class ASTBuiltinField;
class ASTBytes;
class ASTArray;
class ASTRecord;
class ASTRecordField;
class ASTCall;
class ASTWhen;
class ASTOp;
Expand All @@ -38,6 +39,7 @@ enum class ConvertedNodeType {
BYTES,
ARRAY,
RECORD,
RECORD_FIELD,
CALL,
WHEN,
OP
Expand All @@ -64,6 +66,7 @@ class ASTNode {
virtual const ASTCall* as_call() const;
virtual const ASTWhen* as_when() const;
virtual const ASTOp* as_op() const;
virtual const ASTRecordField* as_record_field() const;

bool operator==(const Symbol& symbol) const;
bool operator!=(const Symbol& symbol) const;
Expand Down Expand Up @@ -355,6 +358,27 @@ class ASTOp : public ASTConverted {
const ASTVec args_;
};

class ASTRecordField : public ASTConverted {
public:
explicit ASTRecordField(ASTPtr name, ASTPtr type);

const ASTRecordField* as_record_field() const override;

void print(std::ostream& os, size_t indent) const override;

ConvertedNodeType converted_node_type() const override final
{
return ConvertedNodeType::RECORD_FIELD;
}

const ASTPtr& name() const;
const ASTPtr& type() const;

private:
const ASTPtr name_;
const ASTPtr type_;
};

/**
* Helper to build a synthetic AST tree rather than translating tokens 1:1.
*/
Expand Down
32 changes: 15 additions & 17 deletions tools/sddl2/compiler/semantic_analyzer/SemanticAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ class SemanticAnalyzerImpl {
return analyze(*node->as_array());
case ConvertedNodeType::RECORD:
return analyze(*node->as_record());
case ConvertedNodeType::RECORD_FIELD:
return analyze(*node->as_record_field());
case ConvertedNodeType::CALL:
return analyze(*node->as_call());
case ConvertedNodeType::WHEN:
Expand Down Expand Up @@ -152,21 +154,14 @@ class SemanticAnalyzerImpl {
return Type{ TypeKind::ARRAY, &array };
}

void analyzeRecordFields(const ASTVec& fields)
Type analyze(const ASTRecordField& field)
{
for (const auto& field : fields) {
if (auto op = field->as_op()) {
if (op->op() != Op::ASSUME) {
throw SemanticError(op->loc(), "Invalid record field!");
}
analyzeAssume(*op);
} else if (auto when = field->as_when()) {
expectNumeric(analyzeNode(when->condition()));
analyzeRecordFields(when->body());
} else {
throw SemanticError(field->loc(), "Invalid record field!");
}
const auto* var = field.name()->as_var();
if (!var) {
throw SemanticError(field.loc(), "Field name must be a variable.");
}
expectFieldType(analyzeNode(field.type()));
return Type{ TypeKind::NONE };
}

Type analyze(const ASTRecord& record)
Expand All @@ -184,7 +179,10 @@ class SemanticAnalyzerImpl {
}

// Validate all fields
analyzeRecordFields(record.fields());
for (const auto& field : record.fields()) {
analyzeNode(field);
}

var_types_ = std::move(saved_vars);

return Type{ TypeKind::RECORD, &record };
Expand Down Expand Up @@ -323,10 +321,10 @@ class SemanticAnalyzerImpl {
// Find the field
std::optional<Type> found_type;
for (const auto& field : record->fields()) {
auto assume = field->as_op();
auto& name = assume->args()[0]->as_var()->name();
auto record_field = field->as_record_field();
auto& name = record_field->name()->as_var()->name();
if (name == field_name) {
found_type = assumedType(analyzeNode(assume->args()[1]));
found_type = assumedType(analyzeNode(record_field->type()));
}
}
if (!found_type) {
Expand Down
3 changes: 1 addition & 2 deletions tools/sddl2/compiler/tests/SemanticAnalyzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,7 @@ TEST_F(SemanticAnalyzerTest, WhenBlockInRecordWithFieldReference)
: Data
)";

// TODO: this will pass once codegen is implemented
expect_error(prog, "code generation error");
expect_error(prog, "Undefined variable");
}

} // namespace openzl::sddl2::tests
Loading