Skip to content

Commit 6434c7b

Browse files
authored
feat: light program 1 byte discriminator (#2302)
* feat: pinocchio account add custom discriminator, add 1 byte discriminator compress decompress test * feat: add 1 byte discriminator account to stress test * randomize tests and format * address feedback * test: discriminators with 2-7 bytes * feat: add discriminator compile time collision detection * fix doc comment
1 parent b378643 commit 6434c7b

24 files changed

Lines changed: 2218 additions & 290 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

js/compressed-token/src/v3/actions/create-mint-interface.ts

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,7 @@ export async function createMintInterface(
8181

8282
// Default: light-token mint creation
8383
if (!('secretKey' in mintAuthority)) {
84-
throw new Error(
85-
'mintAuthority must be a Signer for light-token mints',
86-
);
84+
throw new Error('mintAuthority must be a Signer for light-token mints');
8785
}
8886
if (
8987
addressTreeInfo &&

js/compressed-token/src/v3/get-mint-interface.ts

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,7 @@ export async function getMintInterface(
103103
);
104104

105105
if (!compressedAccount?.data?.data) {
106-
throw new Error(
107-
`Light mint not found for ${address.toString()}`,
108-
);
106+
throw new Error(`Light mint not found for ${address.toString()}`);
109107
}
110108

111109
const compressedData = Buffer.from(compressedAccount.data.data);

sdk-libs/macros/src/lib.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,14 @@ pub fn light_account_derive(input: TokenStream) -> TokenStream {
395395
/// - The `compression_info` field must be first or last field in the struct
396396
/// - Struct should be `#[repr(C)]` for predictable memory layout
397397
/// - Use `[u8; 32]` instead of `Pubkey` for address fields
398-
#[proc_macro_derive(LightPinocchioAccount, attributes(compress_as, skip))]
398+
///
399+
/// ## Custom discriminator
400+
///
401+
/// Use `#[light_pinocchio(discriminator = [1u8])]` to override the default
402+
/// 8-byte SHA256 discriminator with a shorter custom discriminator (1-8 bytes).
403+
/// Variants with short discriminators should be declared last in `ProgramAccounts`
404+
/// enums to avoid prefix-matching conflicts during dispatch.
405+
#[proc_macro_derive(LightPinocchioAccount, attributes(compress_as, skip, light_pinocchio))]
399406
pub fn light_pinocchio_account_derive(input: TokenStream) -> TokenStream {
400407
let input = parse_macro_input!(input as DeriveInput);
401408
into_token_stream(light_pdas::account::derive::derive_light_pinocchio_account(

sdk-libs/macros/src/light_pdas/account/derive.rs

Lines changed: 191 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,68 @@ pub fn derive_light_pinocchio_account(input: DeriveInput) -> Result<TokenStream>
117117
derive_light_account_internal(input, Framework::Pinocchio)
118118
}
119119

120+
/// Parses the `discriminator` bytes from `#[light_pinocchio(discriminator = [...])]` if present.
121+
/// Returns None if the attribute is absent (use hash-derived discriminator).
122+
fn parse_pinocchio_discriminator(attrs: &[syn::Attribute]) -> Result<Option<Vec<u8>>> {
123+
for attr in attrs {
124+
if !attr.path().is_ident("light_pinocchio") {
125+
continue;
126+
}
127+
let meta_list = attr.meta.require_list()?;
128+
let nested: Punctuated<syn::Meta, Token![,]> =
129+
meta_list.parse_args_with(Punctuated::parse_terminated)?;
130+
for meta in &nested {
131+
if let syn::Meta::NameValue(nv) = meta {
132+
if nv.path.is_ident("discriminator") {
133+
if let syn::Expr::Array(arr) = &nv.value {
134+
let bytes: Vec<u8> = arr
135+
.elems
136+
.iter()
137+
.map(|e| {
138+
if let syn::Expr::Lit(lit) = e {
139+
if let syn::Lit::Int(i) = &lit.lit {
140+
return i
141+
.base10_parse::<u8>()
142+
.map_err(|err| syn::Error::new_spanned(i, err));
143+
}
144+
}
145+
if let syn::Expr::Cast(cast) = e {
146+
if let syn::Expr::Lit(lit) = cast.expr.as_ref() {
147+
if let syn::Lit::Int(i) = &lit.lit {
148+
return i
149+
.base10_parse::<u8>()
150+
.map_err(|err| syn::Error::new_spanned(i, err));
151+
}
152+
}
153+
}
154+
Err(syn::Error::new_spanned(e, "expected integer literal"))
155+
})
156+
.collect::<Result<Vec<u8>>>()?;
157+
if bytes.is_empty() {
158+
return Err(syn::Error::new_spanned(
159+
arr,
160+
"discriminator must have at least one byte",
161+
));
162+
}
163+
if bytes.len() > 8 {
164+
return Err(syn::Error::new_spanned(
165+
arr,
166+
"discriminator must not exceed 8 bytes",
167+
));
168+
}
169+
return Ok(Some(bytes));
170+
}
171+
return Err(syn::Error::new_spanned(
172+
&nv.value,
173+
"discriminator must be an array like [1u8]",
174+
));
175+
}
176+
}
177+
}
178+
}
179+
Ok(None)
180+
}
181+
120182
/// Internal implementation of LightAccount derive, parameterized by framework.
121183
fn derive_light_account_internal(input: DeriveInput, framework: Framework) -> Result<TokenStream> {
122184
// Convert DeriveInput to ItemStruct for macros that need it
@@ -125,8 +187,35 @@ fn derive_light_account_internal(input: DeriveInput, framework: Framework) -> Re
125187
// Generate LightHasherSha implementation
126188
let hasher_impl = derive_light_hasher_sha(item_struct.clone())?;
127189

128-
// Generate LightDiscriminator implementation
129-
let discriminator_impl = discriminator::anchor_discriminator(item_struct)?;
190+
// Check for custom discriminator argument from #[light_pinocchio(discriminator = [...])]
191+
// Only valid for the Pinocchio framework; reject it on Anchor to avoid silent misuse.
192+
let discriminator_impl = if let Some(disc_bytes) = parse_pinocchio_discriminator(&input.attrs)?
193+
{
194+
if framework != Framework::Pinocchio {
195+
return Err(syn::Error::new_spanned(
196+
&input.ident,
197+
"#[light_pinocchio(discriminator = [...])] is only valid with \
198+
#[derive(LightPinocchioAccount)], not with #[derive(LightAccount)]",
199+
));
200+
}
201+
let mut padded = [0u8; 8];
202+
let copy_len = disc_bytes.len().min(8);
203+
padded[..copy_len].copy_from_slice(&disc_bytes[..copy_len]);
204+
let discriminator_tokens: proc_macro2::TokenStream = format!("{padded:?}").parse().unwrap();
205+
let slice_tokens: proc_macro2::TokenStream = format!("{disc_bytes:?}").parse().unwrap();
206+
let struct_name = &input.ident;
207+
let (impl_gen, type_gen, where_clause) = input.generics.split_for_impl();
208+
quote! {
209+
impl #impl_gen LightDiscriminator for #struct_name #type_gen #where_clause {
210+
const LIGHT_DISCRIMINATOR: [u8; 8] = #discriminator_tokens;
211+
const LIGHT_DISCRIMINATOR_SLICE: &'static [u8] = &#slice_tokens;
212+
fn discriminator() -> [u8; 8] { Self::LIGHT_DISCRIMINATOR }
213+
}
214+
}
215+
} else {
216+
// Generate LightDiscriminator implementation via SHA256
217+
discriminator::anchor_discriminator(item_struct)?
218+
};
130219

131220
// Generate unified LightAccount implementation (includes PackedXxx struct)
132221
let light_account_impl = generate_light_account_impl(&input, framework)?;
@@ -747,6 +836,106 @@ mod tests {
747836

748837
use super::*;
749838

839+
#[test]
840+
fn test_light_pinocchio_custom_discriminator() {
841+
let input: DeriveInput = parse_quote! {
842+
#[light_pinocchio(discriminator = [1u8])]
843+
pub struct OneByteRecord {
844+
pub compression_info: CompressionInfo,
845+
pub owner: [u8; 32],
846+
}
847+
};
848+
849+
let result = derive_light_pinocchio_account(input);
850+
assert!(
851+
result.is_ok(),
852+
"LightPinocchioAccount with custom discriminator should succeed: {:?}",
853+
result.err()
854+
);
855+
856+
let output = result.unwrap().to_string();
857+
858+
// Should contain custom discriminator (1, 0, 0, 0, 0, 0, 0, 0)
859+
assert!(
860+
output.contains("LIGHT_DISCRIMINATOR"),
861+
"Should have LIGHT_DISCRIMINATOR"
862+
);
863+
assert!(
864+
output.contains("1 , 0 , 0 , 0 , 0 , 0 , 0 , 0")
865+
|| output.contains("1, 0, 0, 0, 0, 0, 0, 0"),
866+
"LIGHT_DISCRIMINATOR should be [1,0,0,0,0,0,0,0]"
867+
);
868+
// LIGHT_DISCRIMINATOR_SLICE must be &[1] (1 byte), NOT the padded &[1, 0, 0, 0, 0, 0, 0, 0]
869+
assert!(
870+
output.contains("LIGHT_DISCRIMINATOR_SLICE"),
871+
"Should have LIGHT_DISCRIMINATOR_SLICE"
872+
);
873+
// Verify the slice contains exactly 1 element (not 8)
874+
// The generated token stream renders as `& [1u8]` or `& [1]`
875+
assert!(
876+
output.contains("& [1u8]") || output.contains("& [1]"),
877+
"LIGHT_DISCRIMINATOR_SLICE should be &[1] (1 byte), got: {output}"
878+
);
879+
}
880+
881+
#[test]
882+
fn test_light_pinocchio_custom_discriminator_empty_rejected() {
883+
let input: DeriveInput = parse_quote! {
884+
#[light_pinocchio(discriminator = [])]
885+
pub struct EmptyDisc {
886+
pub compression_info: CompressionInfo,
887+
pub owner: [u8; 32],
888+
}
889+
};
890+
let result = derive_light_pinocchio_account(input);
891+
assert!(
892+
result.is_err(),
893+
"Empty discriminator array should be rejected"
894+
);
895+
let err = result.unwrap_err().to_string();
896+
assert!(
897+
err.contains("at least one byte"),
898+
"Error should mention 'at least one byte', got: {err}"
899+
);
900+
}
901+
902+
#[test]
903+
fn test_light_pinocchio_custom_discriminator_too_long_rejected() {
904+
let input: DeriveInput = parse_quote! {
905+
#[light_pinocchio(discriminator = [1, 2, 3, 4, 5, 6, 7, 8, 9])]
906+
pub struct TooLongDisc {
907+
pub compression_info: CompressionInfo,
908+
pub owner: [u8; 32],
909+
}
910+
};
911+
let result = derive_light_pinocchio_account(input);
912+
assert!(
913+
result.is_err(),
914+
"Discriminator longer than 8 bytes should be rejected"
915+
);
916+
let err = result.unwrap_err().to_string();
917+
assert!(
918+
err.contains("exceed 8 bytes"),
919+
"Error should mention max length, got: {err}"
920+
);
921+
}
922+
923+
#[test]
924+
fn test_light_pinocchio_discriminator_rejected_on_anchor() {
925+
let input: DeriveInput = parse_quote! {
926+
#[light_pinocchio(discriminator = [1u8])]
927+
pub struct AnchorRecord {
928+
pub compression_info: CompressionInfo,
929+
pub owner: Pubkey,
930+
}
931+
};
932+
let result = derive_light_account(input);
933+
assert!(
934+
result.is_err(),
935+
"#[light_pinocchio(discriminator)] should be rejected with LightAccount (Anchor)"
936+
);
937+
}
938+
750939
#[test]
751940
fn test_light_account_basic() {
752941
let input: DeriveInput = parse_quote! {

0 commit comments

Comments
 (0)