From 67ceaa273c1e4e39fe4b3b6f79ce9c0928828d55 Mon Sep 17 00:00:00 2001 From: stringhandler Date: Mon, 2 Mar 2026 17:24:58 +0200 Subject: [PATCH] add padding keyword to increase program size --- src/ast.rs | 81 +++++++++++++++++++++++++++++++++++++++ src/compile/mod.rs | 48 +++++++++++++++++++++++ src/debug.rs | 3 ++ src/error.rs | 2 + src/lexer.rs | 94 +++++++++++++++++++++++++++++++++++++++++++++- src/parse.rs | 71 ++++++++++++++++++++++++++++++++++ 6 files changed, 297 insertions(+), 2 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 6cda4851..808f99fe 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -287,6 +287,8 @@ pub enum CallName { Panic, /// [`dbg!`]. Debug, + /// [`padding!`] + Padding(NonZeroUsize), /// Cast from the given source type. TypeCast(ResolvedType), /// A custom function that was defined previously. @@ -1184,6 +1186,14 @@ impl AbstractSyntaxTree for Call { scope.track_call(from, TrackedCallName::Debug(arg_ty)); args } + CallName::Padding(_size) => { + let args_tys = []; + check_argument_types(from.args(), &args_tys).with_span(from)?; + let args = analyze_arguments(from.args(), &args_tys, scope)?; + scope.track_call(from, TrackedCallName::Padding); + + args + } CallName::TypeCast(source) => { if StructuralType::from(&source) != StructuralType::from(ty) { return Err(Error::InvalidCast(source, ty.clone())).with_span(from); @@ -1311,6 +1321,7 @@ impl AbstractSyntaxTree for CallName { parse::CallName::Assert => Ok(Self::Assert), parse::CallName::Panic => Ok(Self::Panic), parse::CallName::Debug => Ok(Self::Debug), + parse::CallName::Padding(size) => Ok(Self::Padding(*size)), parse::CallName::TypeCast(target) => { scope.resolve(target).map(Self::TypeCast).with_span(from) } @@ -1570,3 +1581,73 @@ impl AsRef for ModuleAssignment { &self.span } } + +#[cfg(test)] +mod test { + use crate::{error, parse::ParseFromStr}; + + use super::*; + + fn parse_padding(input: &str) -> parse::Call { + // Parse the if expression + let parsed_expr = parse::Expression::parse_from_str(input).expect("Failed to parse"); + + // Extract the parsed If from the expression + let parsed_if = match parsed_expr.inner() { + parse::ExpressionInner::Single(single) => match single.inner() { + parse::SingleExpressionInner::Call(call) => match call.name() { + parse::CallName::Padding(_) => call.clone(), + _ => panic!("Expected padding call"), + }, + _ => panic!("Expected If expression"), + }, + _ => panic!("Expected Single expression"), + }; + parsed_if + } + + #[test] + fn test_ast_padding() { + let input = "padding::<10>()"; + + let parsed_call = &parse_padding(input); + + // Analyze the if expression with u8 as the expected type + let expected_type = ResolvedType::unit(); + let mut scope = Scope::default(); + let ast_padding = Call::analyze(parsed_call, &expected_type, &mut scope) + .expect("Failed to analyze Padding expression"); + + // Verify the structure + assert_eq!( + ast_padding.args().len(), + 0, + "Args did not analyse correctly" + ); + assert_eq!( + ast_padding.name(), + &CallName::Padding(NonZeroUsize::new(10).unwrap()), + "Call name was not padding" + ); + } + + #[test] + fn test_ast_padding_should_fail_with_args() { + let input = "padding::<10>(1)"; + + let parsed_call = &parse_padding(input); + + // Analyze the if expression with u8 as the expected type + let expected_type = ResolvedType::unit(); + let mut scope = Scope::default(); + let res = Call::analyze(parsed_call, &expected_type, &mut scope); + + assert!( + matches!( + res.unwrap_err().error(), + error::Error::InvalidNumberOfArguments(0, 1) + ), + "padding parsed correctly but should have failed" + ); + } +} diff --git a/src/compile/mod.rs b/src/compile/mod.rs index 2af17e6f..30ff2cac 100644 --- a/src/compile/mod.rs +++ b/src/compile/mod.rs @@ -423,6 +423,24 @@ impl Call { let iden = ProgNode::iden(scope.ctx()); scope.with_debug_symbol(args, &iden, self) } + CallName::Padding(size) => { + fn recurse_padding<'x>( + index: usize, + scope: &mut Scope<'x>, + me: &Call, + ) -> Result>, RichError> { + if index == 0 { + Ok(PairBuilder::unit(scope.ctx())) + } else { + let left = { PairBuilder::unit(scope.ctx()) }; + let right = recurse_padding(index - 1, scope, me)?; + let pair = left.pair(right); + let drop_iden = ProgNode::drop_(&ProgNode::iden(scope.ctx())); + pair.comp(&drop_iden).with_span(me) + } + } + recurse_padding(size.get(), scope, self) + } CallName::TypeCast(..) => { // A cast converts between two structurally equal types. // Structural equality of SimplicityHL types A and B means @@ -680,3 +698,33 @@ impl Match { input.comp(&output).with_span(self) } } + +#[cfg(test)] +mod test { + + use crate::{ + ast, + parse::{self, ParseFromStr}, + }; + + use super::*; + + fn compile_program( + input: &str, + ) -> Result>, crate::error::RichError> { + let parse_program = parse::Program::parse_from_str(input).expect("Failed to parse"); + let ast_program = ast::Program::analyze(&parse_program).expect("Failed to analyze"); + ast_program.compile(Arguments::default(), false) + } + + #[test] + fn test_padding_compiles() { + let input_program = r#" + fn main() { + padding::<20>(); + }"#; + + let padding_node = + compile_program(input_program).expect("padding expression should compile"); + } +} diff --git a/src/debug.rs b/src/debug.rs index 610e0b15..61347abb 100644 --- a/src/debug.rs +++ b/src/debug.rs @@ -42,6 +42,7 @@ pub enum TrackedCallName { UnwrapRight(ResolvedType), Unwrap, Debug(ResolvedType), + Padding, } /// Fallible call expression with runtime input value. @@ -60,6 +61,7 @@ pub enum FallibleCallName { UnwrapLeft(Value), UnwrapRight(Value), Unwrap, + Padding, } /// Debug expression with runtime input value. @@ -188,6 +190,7 @@ impl TrackedCall { }) .map(Either::Right) } + TrackedCallName::Padding => FallibleCallName::Padding, }; Some(Either::Left(FallibleCall { text: Arc::clone(&self.text), diff --git a/src/error.rs b/src/error.rs index c06cc90b..5b2be74b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -437,6 +437,7 @@ pub enum Error { ModuleRedefined(ModuleName), ArgumentMissing(WitnessName), ArgumentTypeMismatch(WitnessName, ResolvedType, ResolvedType), + PaddingSizeZero, } #[rustfmt::skip] @@ -582,6 +583,7 @@ impl fmt::Display for Error { f, "Parameter `{name}` was declared with type `{declared}` but its assigned argument is of type `{assigned}`" ), + Error::PaddingSizeZero => write!(f, "Padding size cannot be zero") } } } diff --git a/src/lexer.rs b/src/lexer.rs index 71c004b6..97ad5cc6 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -130,8 +130,14 @@ pub fn lexer<'src>( .ignore_then(digits_with_underscore(2)) .map(|s: &str| Token::BinLiteral(Binary::from_str_unchecked(s.replace('_', "").as_str()))); - let macros = - choice((just("assert!"), just("panic!"), just("dbg!"), just("list!"))).map(Token::Macro); + let macros = choice(( + just("assert!"), + just("panic!"), + just("dbg!"), + just("list!"), + // just("padding!"), + )) + .map(Token::Macro); let keyword = text::ident().map(|s| match s { "fn" => Token::Fn, @@ -243,6 +249,75 @@ mod tests { use super::*; + /// Helper function to get the variant name of a token + fn variant_name(token: &Token) -> &'static str { + match token { + Token::Fn => "Fn", + Token::Let => "Let", + Token::Type => "Type", + Token::Mod => "Mod", + Token::Const => "Const", + Token::Match => "Match", + Token::Arrow => "Arrow", + Token::Colon => "Colon", + Token::Semi => "Semi", + Token::Comma => "Comma", + Token::Eq => "Eq", + Token::FatArrow => "FatArrow", + Token::LParen => "LParen", + Token::RParen => "RParen", + Token::LBracket => "LBracket", + Token::RBracket => "RBracket", + Token::LBrace => "LBrace", + Token::RBrace => "RBrace", + Token::LAngle => "LAngle", + Token::RAngle => "RAngle", + Token::DecLiteral(_) => "DecLiteral", + Token::HexLiteral(_) => "HexLiteral", + Token::BinLiteral(_) => "BinLiteral", + Token::Bool(_) => "Bool", + Token::Ident(_) => "Ident", + Token::Jet(_) => "Jet", + Token::Witness(_) => "Witness", + Token::Param(_) => "Param", + Token::Macro(_) => "Macro", + Token::Comment => "Comment", + Token::BlockComment => "BlockComment", + } + } + + /// Macro to assert that a sequence of tokens matches the expected variant types + macro_rules! assert_tokens_match { + ($tokens:expr, $($expected:ident),* $(,)?) => { + { + let tokens = $tokens.as_ref().expect("Expected Some tokens"); + let expected_variants = vec![$( stringify!($expected) ),*]; + + assert_eq!( + tokens.len(), + expected_variants.len(), + "Expected {} tokens, got {}.\nTokens: {:?}", + expected_variants.len(), + tokens.len(), + tokens + ); + + for (idx, ((token, _span), expected_variant)) in tokens.iter().zip(expected_variants.iter()).enumerate() { + let actual_variant = variant_name(token); + assert_eq!( + actual_variant, + *expected_variant, + "Token at index {} does not match: expected {}, got {} (token: {:?})", + idx, + expected_variant, + actual_variant, + token + ); + } + } + }; + } + fn lex<'src>( input: &'src str, ) -> (Option>>, Vec>) { @@ -344,4 +419,19 @@ mod tests { assert!(lex_errs.is_empty()); } + + #[test] + fn test_lexer_padding_detection() { + let expr = "padding::<10>()"; + + let (tokens, lex_errs) = lexer().parse(expr).into_output_errors(); + + // let _ = tokens.unwrap(); + + assert!(lex_errs.is_empty()); + + assert_tokens_match!( + tokens, Ident, Colon, Colon, LAngle, DecLiteral, RAngle, LParen, RParen + ); + } } diff --git a/src/parse.rs b/src/parse.rs index f47dda5e..059d77b9 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -216,6 +216,8 @@ pub enum CallName { ArrayFold(FunctionName, NonZeroUsize), /// Loop over the given function a bounded number of times until it returns success. ForWhile(FunctionName), + /// Padding to add to the transaction to increase the weight of the transaction to fit CPU heavy programs + Padding(NonZeroUsize), } /// A type alias. @@ -819,6 +821,7 @@ impl fmt::Display for CallName { CallName::Fold(name, bound) => write!(f, "fold::<{name}, {bound}>"), CallName::ArrayFold(name, size) => write!(f, "array_fold::<{name}, {size}>"), CallName::ForWhile(name) => write!(f, "for_while::<{name}>"), + CallName::Padding(size) => write!(f, "padding::<{size}>"), } } } @@ -1437,6 +1440,29 @@ impl ChumskyParse for CallName { Token::Macro("dbg!") => CallName::Debug, }; + let padding = just(Token::Ident("padding")) + .ignore_then(turbofish_start.clone()) + .then(select! { Token::DecLiteral(s) => s }.labelled("size")) + .then_ignore(generics_close.clone()) + .validate(|(_, size_str), e, emit| { + let size = match size_str.as_inner().parse::() { + Ok(0) => { + emit.emit(Error::PaddingSizeZero.with_span(e.span())); + NonZeroUsize::new(1).unwrap() + } + Ok(n) => NonZeroUsize::new(n).unwrap(), + Err(_) => { + emit.emit( + Error::CannotParse(format!("Invalid number: {}", size_str)) + .with_span(e.span()), + ); + NonZeroUsize::new(1).unwrap() + } + }; + + CallName::Padding(size) + }); + let jet = select! { Token::Jet(s) => JetName::from_str_unchecked(s) }.map(CallName::Jet); let custom_func = FunctionName::parser().map(CallName::Custom); @@ -1451,6 +1477,8 @@ impl ChumskyParse for CallName { for_while, simple_builtins, jet, + padding, + // Note: Add built-in functions before this, otherwise they will not be matched. custom_func, )) } @@ -2163,3 +2191,46 @@ impl crate::ArbitraryRec for Match { }) } } + +#[cfg(test)] +mod test { + + use super::*; + + fn parse_padding(input: &str) -> Result { + // Parse the if expression + let parsed_expr = Expression::parse_from_str(input).map_err(|_| "Failed to parse")?; + + // Extract the parsed If from the expression + let parsed_if = match parsed_expr.inner() { + ExpressionInner::Single(single) => match single.inner() { + SingleExpressionInner::Call(call) => match call.name() { + CallName::Padding(_) => Ok(call.clone()), + _ => Err("Expected padding call"), + }, + _ => Err("Expected Call expression"), + }, + _ => Err("Expected Single expression"), + }; + parsed_if + } + + #[test] + fn test_parse_padding() { + let input = "padding::<10>()"; + + parse_padding(input).unwrap(); + } + + #[test] + fn test_parse_padding_should_fail_with_multiple_generics() { + let input = "padding::<10, 22>()"; + + let parsed_call = parse_padding(input); + + assert!( + parsed_call.is_err(), + "padding parsed correctly but should have failed" + ); + } +}