@@ -117,6 +117,68 @@ pub fn derive_light_pinocchio_account(input: DeriveInput) -> Result<TokenStream>
117117 derive_light_account_internal ( input, Framework :: Pinocchio )
118118}
119119
120+ /// Parses the `discriminator` bytes from `#[light_pinocchio(discriminator = [...])]` if present.
121+ /// Returns None if the attribute is absent (use hash-derived discriminator).
122+ fn parse_pinocchio_discriminator ( attrs : & [ syn:: Attribute ] ) -> Result < Option < Vec < u8 > > > {
123+ for attr in attrs {
124+ if !attr. path ( ) . is_ident ( "light_pinocchio" ) {
125+ continue ;
126+ }
127+ let meta_list = attr. meta . require_list ( ) ?;
128+ let nested: Punctuated < syn:: Meta , Token ! [ , ] > =
129+ meta_list. parse_args_with ( Punctuated :: parse_terminated) ?;
130+ for meta in & nested {
131+ if let syn:: Meta :: NameValue ( nv) = meta {
132+ if nv. path . is_ident ( "discriminator" ) {
133+ if let syn:: Expr :: Array ( arr) = & nv. value {
134+ let bytes: Vec < u8 > = arr
135+ . elems
136+ . iter ( )
137+ . map ( |e| {
138+ if let syn:: Expr :: Lit ( lit) = e {
139+ if let syn:: Lit :: Int ( i) = & lit. lit {
140+ return i
141+ . base10_parse :: < u8 > ( )
142+ . map_err ( |err| syn:: Error :: new_spanned ( i, err) ) ;
143+ }
144+ }
145+ if let syn:: Expr :: Cast ( cast) = e {
146+ if let syn:: Expr :: Lit ( lit) = cast. expr . as_ref ( ) {
147+ if let syn:: Lit :: Int ( i) = & lit. lit {
148+ return i
149+ . base10_parse :: < u8 > ( )
150+ . map_err ( |err| syn:: Error :: new_spanned ( i, err) ) ;
151+ }
152+ }
153+ }
154+ Err ( syn:: Error :: new_spanned ( e, "expected integer literal" ) )
155+ } )
156+ . collect :: < Result < Vec < u8 > > > ( ) ?;
157+ if bytes. is_empty ( ) {
158+ return Err ( syn:: Error :: new_spanned (
159+ arr,
160+ "discriminator must have at least one byte" ,
161+ ) ) ;
162+ }
163+ if bytes. len ( ) > 8 {
164+ return Err ( syn:: Error :: new_spanned (
165+ arr,
166+ "discriminator must not exceed 8 bytes" ,
167+ ) ) ;
168+ }
169+ return Ok ( Some ( bytes) ) ;
170+ }
171+ return Err ( syn:: Error :: new_spanned (
172+ & nv. value ,
173+ "discriminator must be an array like [1u8]" ,
174+ ) ) ;
175+ }
176+ }
177+ }
178+ }
179+ Ok ( None )
180+ }
181+
120182/// Internal implementation of LightAccount derive, parameterized by framework.
121183fn derive_light_account_internal ( input : DeriveInput , framework : Framework ) -> Result < TokenStream > {
122184 // Convert DeriveInput to ItemStruct for macros that need it
@@ -125,8 +187,35 @@ fn derive_light_account_internal(input: DeriveInput, framework: Framework) -> Re
125187 // Generate LightHasherSha implementation
126188 let hasher_impl = derive_light_hasher_sha ( item_struct. clone ( ) ) ?;
127189
128- // Generate LightDiscriminator implementation
129- let discriminator_impl = discriminator:: anchor_discriminator ( item_struct) ?;
190+ // Check for custom discriminator argument from #[light_pinocchio(discriminator = [...])]
191+ // Only valid for the Pinocchio framework; reject it on Anchor to avoid silent misuse.
192+ let discriminator_impl = if let Some ( disc_bytes) = parse_pinocchio_discriminator ( & input. attrs ) ?
193+ {
194+ if framework != Framework :: Pinocchio {
195+ return Err ( syn:: Error :: new_spanned (
196+ & input. ident ,
197+ "#[light_pinocchio(discriminator = [...])] is only valid with \
198+ #[derive(LightPinocchioAccount)], not with #[derive(LightAccount)]",
199+ ) ) ;
200+ }
201+ let mut padded = [ 0u8 ; 8 ] ;
202+ let copy_len = disc_bytes. len ( ) . min ( 8 ) ;
203+ padded[ ..copy_len] . copy_from_slice ( & disc_bytes[ ..copy_len] ) ;
204+ let discriminator_tokens: proc_macro2:: TokenStream = format ! ( "{padded:?}" ) . parse ( ) . unwrap ( ) ;
205+ let slice_tokens: proc_macro2:: TokenStream = format ! ( "{disc_bytes:?}" ) . parse ( ) . unwrap ( ) ;
206+ let struct_name = & input. ident ;
207+ let ( impl_gen, type_gen, where_clause) = input. generics . split_for_impl ( ) ;
208+ quote ! {
209+ impl #impl_gen LightDiscriminator for #struct_name #type_gen #where_clause {
210+ const LIGHT_DISCRIMINATOR : [ u8 ; 8 ] = #discriminator_tokens;
211+ const LIGHT_DISCRIMINATOR_SLICE : & ' static [ u8 ] = & #slice_tokens;
212+ fn discriminator( ) -> [ u8 ; 8 ] { Self :: LIGHT_DISCRIMINATOR }
213+ }
214+ }
215+ } else {
216+ // Generate LightDiscriminator implementation via SHA256
217+ discriminator:: anchor_discriminator ( item_struct) ?
218+ } ;
130219
131220 // Generate unified LightAccount implementation (includes PackedXxx struct)
132221 let light_account_impl = generate_light_account_impl ( & input, framework) ?;
@@ -747,6 +836,106 @@ mod tests {
747836
748837 use super :: * ;
749838
839+ #[ test]
840+ fn test_light_pinocchio_custom_discriminator ( ) {
841+ let input: DeriveInput = parse_quote ! {
842+ #[ light_pinocchio( discriminator = [ 1u8 ] ) ]
843+ pub struct OneByteRecord {
844+ pub compression_info: CompressionInfo ,
845+ pub owner: [ u8 ; 32 ] ,
846+ }
847+ } ;
848+
849+ let result = derive_light_pinocchio_account ( input) ;
850+ assert ! (
851+ result. is_ok( ) ,
852+ "LightPinocchioAccount with custom discriminator should succeed: {:?}" ,
853+ result. err( )
854+ ) ;
855+
856+ let output = result. unwrap ( ) . to_string ( ) ;
857+
858+ // Should contain custom discriminator (1, 0, 0, 0, 0, 0, 0, 0)
859+ assert ! (
860+ output. contains( "LIGHT_DISCRIMINATOR" ) ,
861+ "Should have LIGHT_DISCRIMINATOR"
862+ ) ;
863+ assert ! (
864+ output. contains( "1 , 0 , 0 , 0 , 0 , 0 , 0 , 0" )
865+ || output. contains( "1, 0, 0, 0, 0, 0, 0, 0" ) ,
866+ "LIGHT_DISCRIMINATOR should be [1,0,0,0,0,0,0,0]"
867+ ) ;
868+ // LIGHT_DISCRIMINATOR_SLICE must be &[1] (1 byte), NOT the padded &[1, 0, 0, 0, 0, 0, 0, 0]
869+ assert ! (
870+ output. contains( "LIGHT_DISCRIMINATOR_SLICE" ) ,
871+ "Should have LIGHT_DISCRIMINATOR_SLICE"
872+ ) ;
873+ // Verify the slice contains exactly 1 element (not 8)
874+ // The generated token stream renders as `& [1u8]` or `& [1]`
875+ assert ! (
876+ output. contains( "& [1u8]" ) || output. contains( "& [1]" ) ,
877+ "LIGHT_DISCRIMINATOR_SLICE should be &[1] (1 byte), got: {output}"
878+ ) ;
879+ }
880+
881+ #[ test]
882+ fn test_light_pinocchio_custom_discriminator_empty_rejected ( ) {
883+ let input: DeriveInput = parse_quote ! {
884+ #[ light_pinocchio( discriminator = [ ] ) ]
885+ pub struct EmptyDisc {
886+ pub compression_info: CompressionInfo ,
887+ pub owner: [ u8 ; 32 ] ,
888+ }
889+ } ;
890+ let result = derive_light_pinocchio_account ( input) ;
891+ assert ! (
892+ result. is_err( ) ,
893+ "Empty discriminator array should be rejected"
894+ ) ;
895+ let err = result. unwrap_err ( ) . to_string ( ) ;
896+ assert ! (
897+ err. contains( "at least one byte" ) ,
898+ "Error should mention 'at least one byte', got: {err}"
899+ ) ;
900+ }
901+
902+ #[ test]
903+ fn test_light_pinocchio_custom_discriminator_too_long_rejected ( ) {
904+ let input: DeriveInput = parse_quote ! {
905+ #[ light_pinocchio( discriminator = [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ] ) ]
906+ pub struct TooLongDisc {
907+ pub compression_info: CompressionInfo ,
908+ pub owner: [ u8 ; 32 ] ,
909+ }
910+ } ;
911+ let result = derive_light_pinocchio_account ( input) ;
912+ assert ! (
913+ result. is_err( ) ,
914+ "Discriminator longer than 8 bytes should be rejected"
915+ ) ;
916+ let err = result. unwrap_err ( ) . to_string ( ) ;
917+ assert ! (
918+ err. contains( "exceed 8 bytes" ) ,
919+ "Error should mention max length, got: {err}"
920+ ) ;
921+ }
922+
923+ #[ test]
924+ fn test_light_pinocchio_discriminator_rejected_on_anchor ( ) {
925+ let input: DeriveInput = parse_quote ! {
926+ #[ light_pinocchio( discriminator = [ 1u8 ] ) ]
927+ pub struct AnchorRecord {
928+ pub compression_info: CompressionInfo ,
929+ pub owner: Pubkey ,
930+ }
931+ } ;
932+ let result = derive_light_account ( input) ;
933+ assert ! (
934+ result. is_err( ) ,
935+ "#[light_pinocchio(discriminator)] should be rejected with LightAccount (Anchor)"
936+ ) ;
937+ }
938+
750939 #[ test]
751940 fn test_light_account_basic ( ) {
752941 let input: DeriveInput = parse_quote ! {
0 commit comments