Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ pub enum CallName {
Panic,
/// [`dbg!`].
Debug,
/// [`padding!`]
Padding(NonZeroUsize),
/// Cast from the given source type.
TypeCast(ResolvedType),
/// A custom function that was defined previously.
Expand Down Expand Up @@ -1184,6 +1186,14 @@ impl AbstractSyntaxTree for Call {
scope.track_call(from, TrackedCallName::Debug(arg_ty));
args
}
CallName::Padding(_size) => {
let args_tys = [];
check_argument_types(from.args(), &args_tys).with_span(from)?;
let args = analyze_arguments(from.args(), &args_tys, scope)?;
scope.track_call(from, TrackedCallName::Padding);

args
}
CallName::TypeCast(source) => {
if StructuralType::from(&source) != StructuralType::from(ty) {
return Err(Error::InvalidCast(source, ty.clone())).with_span(from);
Expand Down Expand Up @@ -1311,6 +1321,7 @@ impl AbstractSyntaxTree for CallName {
parse::CallName::Assert => Ok(Self::Assert),
parse::CallName::Panic => Ok(Self::Panic),
parse::CallName::Debug => Ok(Self::Debug),
parse::CallName::Padding(size) => Ok(Self::Padding(*size)),
parse::CallName::TypeCast(target) => {
scope.resolve(target).map(Self::TypeCast).with_span(from)
}
Expand Down Expand Up @@ -1570,3 +1581,73 @@ impl AsRef<Span> for ModuleAssignment {
&self.span
}
}

#[cfg(test)]
mod test {
use crate::{error, parse::ParseFromStr};

use super::*;

fn parse_padding(input: &str) -> parse::Call {
// Parse the if expression
let parsed_expr = parse::Expression::parse_from_str(input).expect("Failed to parse");

// Extract the parsed If from the expression
let parsed_if = match parsed_expr.inner() {
parse::ExpressionInner::Single(single) => match single.inner() {
parse::SingleExpressionInner::Call(call) => match call.name() {
parse::CallName::Padding(_) => call.clone(),
_ => panic!("Expected padding call"),
},
_ => panic!("Expected If expression"),
},
_ => panic!("Expected Single expression"),
};
parsed_if
}

#[test]
fn test_ast_padding() {
let input = "padding::<10>()";

let parsed_call = &parse_padding(input);

// Analyze the if expression with u8 as the expected type
let expected_type = ResolvedType::unit();
let mut scope = Scope::default();
let ast_padding = Call::analyze(parsed_call, &expected_type, &mut scope)
.expect("Failed to analyze Padding expression");

// Verify the structure
assert_eq!(
ast_padding.args().len(),
0,
"Args did not analyse correctly"
);
assert_eq!(
ast_padding.name(),
&CallName::Padding(NonZeroUsize::new(10).unwrap()),
"Call name was not padding"
);
}

#[test]
fn test_ast_padding_should_fail_with_args() {
let input = "padding::<10>(1)";

let parsed_call = &parse_padding(input);

// Analyze the if expression with u8 as the expected type
let expected_type = ResolvedType::unit();
let mut scope = Scope::default();
let res = Call::analyze(parsed_call, &expected_type, &mut scope);

assert!(
matches!(
res.unwrap_err().error(),
error::Error::InvalidNumberOfArguments(0, 1)
),
"padding parsed correctly but should have failed"
);
}
}
48 changes: 48 additions & 0 deletions src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,24 @@ impl Call {
let iden = ProgNode::iden(scope.ctx());
scope.with_debug_symbol(args, &iden, self)
}
CallName::Padding(size) => {
fn recurse_padding<'x>(
index: usize,
scope: &mut Scope<'x>,
me: &Call,
) -> Result<PairBuilder<ProgNode<'x>>, RichError> {
if index == 0 {
Ok(PairBuilder::unit(scope.ctx()))
} else {
let left = { PairBuilder::unit(scope.ctx()) };
let right = recurse_padding(index - 1, scope, me)?;
let pair = left.pair(right);
let drop_iden = ProgNode::drop_(&ProgNode::iden(scope.ctx()));
pair.comp(&drop_iden).with_span(me)
}
}
recurse_padding(size.get(), scope, self)
}
CallName::TypeCast(..) => {
// A cast converts between two structurally equal types.
// Structural equality of SimplicityHL types A and B means
Expand Down Expand Up @@ -680,3 +698,33 @@ impl Match {
input.comp(&output).with_span(self)
}
}

#[cfg(test)]
mod test {

use crate::{
ast,
parse::{self, ParseFromStr},
};

use super::*;

fn compile_program(
input: &str,
) -> Result<Arc<named::CommitNode<Elements>>, crate::error::RichError> {
let parse_program = parse::Program::parse_from_str(input).expect("Failed to parse");
let ast_program = ast::Program::analyze(&parse_program).expect("Failed to analyze");
ast_program.compile(Arguments::default(), false)
}

#[test]
fn test_padding_compiles() {
let input_program = r#"
fn main() {
padding::<20>();
}"#;

let padding_node =
compile_program(input_program).expect("padding expression should compile");
}
}
3 changes: 3 additions & 0 deletions src/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub enum TrackedCallName {
UnwrapRight(ResolvedType),
Unwrap,
Debug(ResolvedType),
Padding,
}

/// Fallible call expression with runtime input value.
Expand All @@ -60,6 +61,7 @@ pub enum FallibleCallName {
UnwrapLeft(Value),
UnwrapRight(Value),
Unwrap,
Padding,
}

/// Debug expression with runtime input value.
Expand Down Expand Up @@ -188,6 +190,7 @@ impl TrackedCall {
})
.map(Either::Right)
}
TrackedCallName::Padding => FallibleCallName::Padding,
};
Some(Either::Left(FallibleCall {
text: Arc::clone(&self.text),
Expand Down
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ pub enum Error {
ModuleRedefined(ModuleName),
ArgumentMissing(WitnessName),
ArgumentTypeMismatch(WitnessName, ResolvedType, ResolvedType),
PaddingSizeZero,
}

#[rustfmt::skip]
Expand Down Expand Up @@ -582,6 +583,7 @@ impl fmt::Display for Error {
f,
"Parameter `{name}` was declared with type `{declared}` but its assigned argument is of type `{assigned}`"
),
Error::PaddingSizeZero => write!(f, "Padding size cannot be zero")
}
}
}
Expand Down
94 changes: 92 additions & 2 deletions src/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,14 @@ pub fn lexer<'src>(
.ignore_then(digits_with_underscore(2))
.map(|s: &str| Token::BinLiteral(Binary::from_str_unchecked(s.replace('_', "").as_str())));

let macros =
choice((just("assert!"), just("panic!"), just("dbg!"), just("list!"))).map(Token::Macro);
let macros = choice((
just("assert!"),
just("panic!"),
just("dbg!"),
just("list!"),
// just("padding!"),
))
.map(Token::Macro);

let keyword = text::ident().map(|s| match s {
"fn" => Token::Fn,
Expand Down Expand Up @@ -243,6 +249,75 @@ mod tests {

use super::*;

/// Helper function to get the variant name of a token
fn variant_name(token: &Token) -> &'static str {
match token {
Token::Fn => "Fn",
Token::Let => "Let",
Token::Type => "Type",
Token::Mod => "Mod",
Token::Const => "Const",
Token::Match => "Match",
Token::Arrow => "Arrow",
Token::Colon => "Colon",
Token::Semi => "Semi",
Token::Comma => "Comma",
Token::Eq => "Eq",
Token::FatArrow => "FatArrow",
Token::LParen => "LParen",
Token::RParen => "RParen",
Token::LBracket => "LBracket",
Token::RBracket => "RBracket",
Token::LBrace => "LBrace",
Token::RBrace => "RBrace",
Token::LAngle => "LAngle",
Token::RAngle => "RAngle",
Token::DecLiteral(_) => "DecLiteral",
Token::HexLiteral(_) => "HexLiteral",
Token::BinLiteral(_) => "BinLiteral",
Token::Bool(_) => "Bool",
Token::Ident(_) => "Ident",
Token::Jet(_) => "Jet",
Token::Witness(_) => "Witness",
Token::Param(_) => "Param",
Token::Macro(_) => "Macro",
Token::Comment => "Comment",
Token::BlockComment => "BlockComment",
}
}

/// Macro to assert that a sequence of tokens matches the expected variant types
macro_rules! assert_tokens_match {
($tokens:expr, $($expected:ident),* $(,)?) => {
{
let tokens = $tokens.as_ref().expect("Expected Some tokens");
let expected_variants = vec![$( stringify!($expected) ),*];

assert_eq!(
tokens.len(),
expected_variants.len(),
"Expected {} tokens, got {}.\nTokens: {:?}",
expected_variants.len(),
tokens.len(),
tokens
);

for (idx, ((token, _span), expected_variant)) in tokens.iter().zip(expected_variants.iter()).enumerate() {
let actual_variant = variant_name(token);
assert_eq!(
actual_variant,
*expected_variant,
"Token at index {} does not match: expected {}, got {} (token: {:?})",
idx,
expected_variant,
actual_variant,
token
);
}
}
};
}

fn lex<'src>(
input: &'src str,
) -> (Option<Vec<Token<'src>>>, Vec<Rich<'src, char, SimpleSpan>>) {
Expand Down Expand Up @@ -344,4 +419,19 @@ mod tests {

assert!(lex_errs.is_empty());
}

#[test]
fn test_lexer_padding_detection() {
let expr = "padding::<10>()";

let (tokens, lex_errs) = lexer().parse(expr).into_output_errors();

// let _ = tokens.unwrap();

assert!(lex_errs.is_empty());

assert_tokens_match!(
tokens, Ident, Colon, Colon, LAngle, DecLiteral, RAngle, LParen, RParen
);
}
}
Loading
Loading