1010
1111#![ allow( clippy:: new_without_default) ]
1212
13- use wasm_bindgen:: prelude:: * ;
14- use ruvector_cnn:: contrastive:: { InfoNCELoss as RustInfoNCE , TripletLoss as RustTriplet , TripletDistance } ;
13+ use ruvector_cnn:: contrastive:: {
14+ InfoNCELoss as RustInfoNCE , TripletDistance , TripletLoss as RustTriplet ,
15+ } ;
1516use ruvector_cnn:: simd;
17+ use wasm_bindgen:: prelude:: * ;
1618
1719/// Initialize panic hook for better error messages
1820#[ wasm_bindgen( start) ]
@@ -94,9 +96,8 @@ impl WasmCnnEmbedder {
9496 let mean: f32 = channel_data. iter ( ) . sum :: < f32 > ( ) / pixels_per_channel as f32 ;
9597
9698 // Variance
97- let variance: f32 = channel_data. iter ( )
98- . map ( |x| ( x - mean) . powi ( 2 ) )
99- . sum :: < f32 > ( ) / pixels_per_channel as f32 ;
99+ let variance: f32 = channel_data. iter ( ) . map ( |x| ( x - mean) . powi ( 2 ) ) . sum :: < f32 > ( )
100+ / pixels_per_channel as f32 ;
100101
101102 // Store in embedding
102103 if c * 2 < self . embedding_dim {
@@ -195,7 +196,12 @@ impl WasmInfoNCELoss {
195196 /// Compute loss for a batch of embedding pairs
196197 /// embeddings: [2N, D] flattened where (i, i+N) are positive pairs
197198 #[ wasm_bindgen]
198- pub fn forward ( & self , embeddings : & [ f32 ] , batch_size : usize , dim : usize ) -> Result < f32 , JsValue > {
199+ pub fn forward (
200+ & self ,
201+ embeddings : & [ f32 ] ,
202+ batch_size : usize ,
203+ dim : usize ,
204+ ) -> Result < f32 , JsValue > {
199205 if embeddings. len ( ) != 2 * batch_size * dim {
200206 return Err ( JsValue :: from_str ( & format ! (
201207 "Expected {} elements, got {}" ,
@@ -269,17 +275,29 @@ impl WasmTripletLoss {
269275 negatives : & [ f32 ] ,
270276 dim : usize ,
271277 ) -> Result < f32 , JsValue > {
272- if anchors. len ( ) % dim != 0 || positives. len ( ) != anchors. len ( ) || negatives. len ( ) != anchors. len ( ) {
278+ if anchors. len ( ) % dim != 0
279+ || positives. len ( ) != anchors. len ( )
280+ || negatives. len ( ) != anchors. len ( )
281+ {
273282 return Err ( JsValue :: from_str ( "Invalid triplet dimensions" ) ) ;
274283 }
275284
276285 let batch_size = anchors. len ( ) / dim;
277286 let mut total_loss = 0.0f64 ;
278287
279288 for i in 0 ..batch_size {
280- let a: Vec < f64 > = anchors[ i * dim..( i + 1 ) * dim] . iter ( ) . map ( |& x| x as f64 ) . collect ( ) ;
281- let p: Vec < f64 > = positives[ i * dim..( i + 1 ) * dim] . iter ( ) . map ( |& x| x as f64 ) . collect ( ) ;
282- let n: Vec < f64 > = negatives[ i * dim..( i + 1 ) * dim] . iter ( ) . map ( |& x| x as f64 ) . collect ( ) ;
289+ let a: Vec < f64 > = anchors[ i * dim..( i + 1 ) * dim]
290+ . iter ( )
291+ . map ( |& x| x as f64 )
292+ . collect ( ) ;
293+ let p: Vec < f64 > = positives[ i * dim..( i + 1 ) * dim]
294+ . iter ( )
295+ . map ( |& x| x as f64 )
296+ . collect ( ) ;
297+ let n: Vec < f64 > = negatives[ i * dim..( i + 1 ) * dim]
298+ . iter ( )
299+ . map ( |& x| x as f64 )
300+ . collect ( ) ;
283301 total_loss += self . inner . forward ( & a, & p, & n) ;
284302 }
285303
@@ -351,14 +369,28 @@ impl LayerOps {
351369 ) -> Vec < f32 > {
352370 let channels = gamma. len ( ) ;
353371 let mut output = vec ! [ 0.0f32 ; input. len( ) ] ;
354- simd:: batch_norm_simd ( input, & mut output, gamma, beta, mean, var, epsilon, channels) ;
372+ simd:: batch_norm_simd (
373+ input,
374+ & mut output,
375+ gamma,
376+ beta,
377+ mean,
378+ var,
379+ epsilon,
380+ channels,
381+ ) ;
355382 output
356383 }
357384
358385 /// Apply global average pooling
359386 /// Returns one value per channel
360387 #[ wasm_bindgen]
361- pub fn global_avg_pool ( input : & [ f32 ] , height : usize , width : usize , channels : usize ) -> Vec < f32 > {
388+ pub fn global_avg_pool (
389+ input : & [ f32 ] ,
390+ height : usize ,
391+ width : usize ,
392+ channels : usize ,
393+ ) -> Vec < f32 > {
362394 let mut output = vec ! [ 0.0f32 ; channels] ;
363395 simd:: global_avg_pool_simd ( input, & mut output, height, width, channels) ;
364396 output
@@ -382,7 +414,8 @@ mod tests {
382414 input_size : 8 ,
383415 embedding_dim : 64 ,
384416 normalize : true ,
385- } ) ) . unwrap ( ) ;
417+ } ) )
418+ . unwrap ( ) ;
386419
387420 let image_data = vec ! [ 128u8 ; 8 * 8 * 3 ] ;
388421 let embedding = embedder. extract ( & image_data, 8 , 8 ) . unwrap ( ) ;
0 commit comments