Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions tests/compress/graphs/sddl2/test_sddl2_code_execution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,119 @@ TEST_F(SDDL2CodeExecutionTest, AssumeRecordWithGlobalVariableReference)
expect_success(prog, input, expected_sizes);
}

// ============================================================================
// Parameterized Records
// ============================================================================

TEST_F(SDDL2CodeExecutionTest, ConsumeParameterizedRecord)
{
const std::vector<size_t> expected_sizes = { 40 };
const auto input = gen<uint8_t>(sum(expected_sizes));

const auto prog = R"(
Record Entry(N) = {
items: Int32LE[N]
}
: Entry(10)
)";

expect_success(prog, input, expected_sizes);
}

TEST_F(SDDL2CodeExecutionTest, ConsumeParameterizedRecordMultipleParams)
{
const std::vector<size_t> expected_sizes = { 22 };
const auto input = gen<uint8_t>(sum(expected_sizes));

const auto prog = R"(
Record Foo(A, B) = {
x: Int32LE[A],
y: Int16LE[B]
}
ParamFoo = Foo(3, 5)
: ParamFoo
)";

expect_success(prog, input, expected_sizes);
}

TEST_F(SDDL2CodeExecutionTest, ConsumeParameterizedRecordMultipleCalls)
{
const std::vector<size_t> expected_sizes = { 8, 12 };
const auto input = gen<uint8_t>(sum(expected_sizes));

const auto prog = R"(
Record Entry(N) = {
items: Int32LE[N]
}
: Entry(2)
: Entry(3)
)";

expect_success(prog, input, expected_sizes);
}

TEST_F(SDDL2CodeExecutionTest, ConsumeArrayOfParameterizedRecord)
{
const std::vector<size_t> expected_sizes = { 60 };
const auto input = gen<uint8_t>(sum(expected_sizes));

const auto prog = R"(
Record Entry(N) = {
items: Int32LE[N]
}
: Entry(5)[3]
)";

expect_success(prog, input, expected_sizes);
}

TEST_F(SDDL2CodeExecutionTest, AssumeParameterizedRecord)
{
const std::vector<size_t> expected_sizes = { 1, 21 };
std::vector<uint8_t> input(22, 0);
input[0] = 5; // n = 5
input[21] = 3; // tag = 3

const auto prog = R"(
Record Entry(N) = {
items: Int32LE[N],
tag: Byte
}
n: Byte
entry: Entry(n)
expect entry.tag == 3
)";

expect_success(prog, input, expected_sizes);
}

TEST_F(SDDL2CodeExecutionTest, AssumeNestedParameterizedRecord)
{
const std::vector<size_t> expected_sizes = { 1, 10 };
std::vector<uint8_t> input(11, 0);
input[0] = 2; // n = 2
input[9] = 3; // foo.bar.tag = 3
input[10] = 5; // foo.tag = 5

const auto prog = R"(
Record Bar(N) = {
items: Int32LE[N],
tag: Byte
}
Record Foo(M) = {
bar: Bar(M),
tag: Byte
}
N: Byte
foo: Foo(N)
expect foo.bar.tag == 3
expect foo.tag == 5
expect N == 2
)";

expect_success(prog, input, expected_sizes);
}
} // namespace testing
} // namespace sddl2
} // namespace openzl
3 changes: 3 additions & 0 deletions tools/sddl2/compiler/Syntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ static const std::map<Symbol, poly::string_view> syms_to_debug_strs{

{ Symbol::BYTES, "BYTES" },
{ Symbol::RECORD, "RECORD" },

{ Symbol::WHEN, "WHEN" },
};

poly::string_view sym_to_debug_str(Symbol sym)
Expand Down Expand Up @@ -185,6 +187,7 @@ const std::vector<std::pair<poly::string_view, Symbol>> strs_to_syms{
{ "BFloat16BE", Symbol::BF16BE },
{ "Bytes", Symbol::BYTES },
{ "Record", Symbol::RECORD },
{ "when", Symbol::WHEN },
};

/* These symbols can't actually be accessed via these names. */
Expand Down
5 changes: 4 additions & 1 deletion tools/sddl2/compiler/Syntax.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ enum class Symbol {

// Other Fields
BYTES,
RECORD
RECORD,

// Control Flow
WHEN
};

/// @returns a name string for a symbol.
Expand Down
86 changes: 74 additions & 12 deletions tools/sddl2/compiler/codegen/CodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,15 @@ class CodeGeneratorImpl {
(void)log_;
AssemblyOutput output;
for (const auto& node : ast) {
output += generateOp(*node->as_op());
if (auto op = node->as_op()) {
output += generateOp(*op);
} else if (auto when = node->as_when()) {
(void)when;
throw CodegenError(node->loc(), "Not yet implemented!");
} else {
throw InvariantViolation(
node->loc(), "Expected an operation or when.");
}
}
return output.str();
}
Expand Down Expand Up @@ -177,12 +185,27 @@ class CodeGeneratorImpl {
case ConvertedNodeType::BYTES:
case ConvertedNodeType::ARRAY:
case ConvertedNodeType::RECORD:
case ConvertedNodeType::CALL:
case ConvertedNodeType::WHEN:
default:
throw InvariantViolation(
node->loc(), "Expected a value, got a type.");
}
}

AssemblyOutput bindParams(const ASTRecord* record, const ASTVec& args)
{
AssemblyOutput output;
for (size_t i = 0; i < record->params().size(); ++i) {
const auto& param_name = record->params()[i]->as_var()->name();
auto reg = registers_.assign(param_name);
output += generateValue(args[i]);
output += "push.i64 " + std::to_string(reg);
output += "var.store";
}
return output;
}

TypeResult generateType(const ASTPtr& type)
{
AssemblyOutput output;
Expand All @@ -209,10 +232,16 @@ class CodeGeneratorImpl {
case ConvertedNodeType::RECORD: {
auto record = type->as_record();
for (const auto& field : record->fields()) {
auto assume = field->as_op();
const auto& field_type = assume->args()[1];
auto [field_asm, _] = generateType(field_type);
output += std::move(field_asm);
if (auto assume = field->as_op()) {
const auto& field_type = assume->args()[1];
auto [field_asm, _] = generateType(field_type);
output += std::move(field_asm);
}
if (auto when = field->as_when()) {
(void)when;
throw CodegenError(
field->loc(), "Not yet implemented!");
}
}
output += "push.i64 " + std::to_string(record->fields().size());
output += "type.structure";
Expand Down Expand Up @@ -242,8 +271,27 @@ class CodeGeneratorImpl {
output += "type.fixed_array";
return { std::move(output), type };
}
case ConvertedNodeType::CALL: {
auto call = type->as_call();
const auto& target_name = call->target()->as_var()->name();
auto target = type_aliases_.at(target_name);

auto saved_regs = registers_;
// Bind params to registers
output += bindParams(target->as_record(), call->args());

// Generate the record body
auto [record_asm, _] = generateType(target);
output += std::move(record_asm);

// Restore the registers
registers_ = std::move(saved_regs);

return { std::move(output), type };
}
case ConvertedNodeType::NUM:
case ConvertedNodeType::OP:
case ConvertedNodeType::WHEN:
default:
throw InvariantViolation(
type->loc(), "Expected a type, got a value.");
Expand Down Expand Up @@ -275,30 +323,35 @@ class CodeGeneratorImpl {
// Flatten the member chain
const auto& [root, path] = flattenMember(op);

// Get the type info
auto type = assumed_types_[root.name()];

// Load the base offset
output += "push.i64 " + std::to_string(registers_.get(root.name()));
output += "var.load";

auto type = assumed_types_[root.name()];
auto saved_regs = registers_;

// Walk through the path, accumulating offsets
for (const auto& name : path) {
auto curr_record = type->as_record();
if (const auto call = type->as_call()) {
auto const target_name = call->target()->as_var()->name();
curr_record = type_aliases_.at(target_name)->as_record();
output += bindParams(curr_record, call->args());
}
for (const auto& field : curr_record->fields()) {
auto assume = field->as_op();
auto& field_name = assume->args()[0]->as_var()->name();
auto [field_asm, field_type] = generateType(assume->args()[1]);
if (name == field_name) {
type = field_type;
log_(0) << " Found " << name << std::endl;
break;
}
output += std::move(field_asm);
output += "type.sizeof";
output += "math.add";
}
}
registers_ = std::move(saved_regs);

// Load the final value if it's a builtin type
if (auto builtin = type->as_builtin_field()) {
Expand All @@ -312,9 +365,18 @@ class CodeGeneratorImpl {

AssemblyOutput generateAssign(const ASTOp& op)
{
auto& var = *op.args()[0]->as_var();
auto [type_asm, type] = generateType(op.args()[1]);
auto& var = *op.args()[0]->as_var();
auto& rhs = op.args()[1];

// If the RHS is a parameterized record, do not generate the code
if (auto record = rhs->as_record()) {
if (!record->params().empty()) {
type_aliases_[var.name()] = rhs;
return AssemblyOutput{};
}
}

auto [type_asm, type] = generateType(rhs);
type_aliases_[var.name()] = type;

AssemblyOutput output;
Expand Down Expand Up @@ -343,7 +405,7 @@ class CodeGeneratorImpl {
output += builtin_field_to_load.at(builtin->kw());
output +=
"push.i64 " + std::to_string(registers_.assign(var.name()));
} else if (type->as_record()) {
} else if (type->as_record() || type->as_call()) {
// Otherwise, save the current position and type information
output +=
"push.i64 " + std::to_string(registers_.assign(var.name()));
Expand Down
Loading
Loading