diff --git a/src/rsrs/args.rs b/src/rsrs/args.rs index ee3a732..edac206 100644 --- a/src/rsrs/args.rs +++ b/src/rsrs/args.rs @@ -1,7 +1,7 @@ use crate::{ rsrs::{ rsrs_factors::{ - null_and_extract::{ExtractOptions, IdOptions, PivotMethod}, + null_and_extract::{ExtractOptions, IdOptions, NonSymmetricIdCombination, PivotMethod}, rsrs_operator::FactType, }, sketch::Shift, @@ -141,6 +141,8 @@ pub struct RsrsArgs { num_threads: usize, flush_factors: bool, store_far: bool, + #[serde(default)] + nonsymmetric_id_combination: NonSymmetricIdCombination, } impl RsrsArgs @@ -202,6 +204,7 @@ where num_threads, flush_factors, store_far, + nonsymmetric_id_combination: NonSymmetricIdCombination::default(), } } } @@ -269,6 +272,7 @@ impl RsrsOptions { tol_null: args.tol_null, tol_id: args.tol_id, store_far: args.store_far, + nonsymmetric_id_combination: args.nonsymmetric_id_combination, }, lu_options: ExtractOptions { block_extraction_method: args.near_block_extraction_method, @@ -330,6 +334,18 @@ impl RsrsOptions { ) .unwrap(); + if matches!( + self.id_options.nonsymmetric_id_combination, + NonSymmetricIdCombination::Concat + ) { + write!( + &mut id, + "_nsid_{:?}", + self.id_options.nonsymmetric_id_combination + ) + .unwrap(); + } + match self.id_options.qr_method{ RankRevealingQrType::RRQR => write!( &mut id, @@ -376,3 +392,54 @@ where self } } + +#[cfg(test)] +mod tests { + use super::*; + + fn make_args() -> RsrsArgs { + RsrsArgs::new( + 8, + 16, + 0, + 0, + Shift::False, + NullMethod::Projection, + RankRevealingQrType::RRQR, + BlockExtractionMethod::LuLstSq, + BlockExtractionMethod::LuLstSq, + PivotMethod::LuHybrid(0.0), + PivotMethod::LuHybrid(0.0), + 1e-16, + 40.0, + 1e-16, + 1e-16, + 1, + 1, + Symmetry::NoSymm, + RankPicking::Min, + FactType::Joint, + true, + 1, + false, + false, + ) + .with_fixed_rank_sampling_mode(FixedRankSamplingMode::Constant) + } + + #[test] + fn default_nonsymmetric_id_mode_keeps_legacy_identifier() { + let args = make_args(); + let identifier = RsrsOptions::new(Some(args)).to_identifier(); + assert!(identifier.contains("_fsamp_Constant")); + assert!(!identifier.contains("_nsid_")); + } + + #[test] + fn concat_nonsymmetric_id_mode_is_encoded_in_identifier() { + let mut args = make_args(); + args.nonsymmetric_id_combination = NonSymmetricIdCombination::Concat; + let identifier = RsrsOptions::new(Some(args)).to_identifier(); + assert!(identifier.contains("_fsamp_Constant_nsid_Concat")); + } +} diff --git a/src/rsrs/rsrs_factors/commutative_factors.rs b/src/rsrs/rsrs_factors/commutative_factors.rs index 7ff8e1a..595460e 100644 --- a/src/rsrs/rsrs_factors/commutative_factors.rs +++ b/src/rsrs/rsrs_factors/commutative_factors.rs @@ -6,7 +6,7 @@ use crate::rsrs::rsrs_factors::base_factors::{ }; use crate::rsrs::rsrs_factors::null_and_extract::{ extract_lu_factor_from_blocks, near_box_extraction, null_near_field_into, ExtractOptions, - ExtractionScratch, IdOptions, PivotMethod, + ExtractionScratch, IdOptions, NonSymmetricIdCombination, PivotMethod, }; use crate::rsrs::rsrs_factors::rsrs_operator::FactType; use crate::rsrs::sketch::SketchData; @@ -544,7 +544,17 @@ where let start: Instant = Instant::now(); let test_shape = [subs_sample_dim, near_field_inds.len()]; let sketch_shape = [subs_sample_dim, target_inds.len()]; - let null_shape = [test_shape[0] - test_shape[1], sketch_shape[1]]; + let base_null_rows = test_shape[0].saturating_sub(test_shape[1]); + let null_rows = if !symmetry.symm_val() + && matches!( + id_options.nonsymmetric_id_combination, + NonSymmetricIdCombination::Concat + ) { + 2 * base_null_rows + } else { + base_null_rows + }; + let null_shape = [null_rows, sketch_shape[1]]; // nullification of the near field null_near_field_into( diff --git a/src/rsrs/rsrs_factors/null_and_extract.rs b/src/rsrs/rsrs_factors/null_and_extract.rs index 6b70e41..1c90abc 100644 --- a/src/rsrs/rsrs_factors/null_and_extract.rs +++ b/src/rsrs/rsrs_factors/null_and_extract.rs @@ -38,6 +38,13 @@ pub enum PivotMethod { LuHybrid(f64), //TODO: Change to Item } +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)] +pub enum NonSymmetricIdCombination { + #[default] + Sum, + Concat, +} + #[derive(Debug, Clone)] pub struct ExtractOptions { pub block_extraction_method: BlockExtractionMethod, @@ -52,6 +59,7 @@ pub struct IdOptions { pub tol_null: Real, pub tol_id: Real, pub store_far: bool, + pub nonsymmetric_id_combination: NonSymmetricIdCombination, } pub struct ExtractionScratch { @@ -78,6 +86,47 @@ impl Default for ExtractionScratch { } } +fn usable_nullspace_rows(subs_sample_dim: usize, near_field_len: usize) -> usize { + subs_sample_dim.saturating_sub(near_field_len) +} + +fn concat_nonsymmetric_id_sketches( + primary: &mut DynamicArray, + secondary: &DynamicArray, + rows_per_sketch: usize, +) { + let primary_shape = primary.shape(); + let secondary_shape = secondary.shape(); + assert_eq!( + primary_shape[1], secondary_shape[1], + "Cannot concatenate ID sketches with different column counts" + ); + + let kept_primary_rows = primary_shape[0].min(rows_per_sketch); + let kept_secondary_rows = secondary_shape[0].min(rows_per_sketch); + let cols = primary_shape[1]; + let mut combined = rlst_dynamic_array2!(Item, [kept_primary_rows + kept_secondary_rows, cols]); + + if kept_primary_rows > 0 { + combined + .r_mut() + .into_subview([0, 0], [kept_primary_rows, cols]) + .fill_from(primary.r().into_subview([0, 0], [kept_primary_rows, cols])); + } + if kept_secondary_rows > 0 { + combined + .r_mut() + .into_subview([kept_primary_rows, 0], [kept_secondary_rows, cols]) + .fill_from( + secondary + .r() + .into_subview([0, 0], [kept_secondary_rows, cols]), + ); + } + + *primary = combined; +} + #[allow(clippy::too_many_arguments)] fn null_sketch_near_field_into< Item: RlstScalar @@ -217,7 +266,18 @@ pub fn null_near_field_into< test_scratch, normal_scratch, ); - far_field_sketch.sum_into(aux_sketch.r()); + match id_options.nonsymmetric_id_combination { + NonSymmetricIdCombination::Sum => { + far_field_sketch.sum_into(aux_sketch.r()); + } + NonSymmetricIdCombination::Concat => { + concat_nonsymmetric_id_sketches( + far_field_sketch, + aux_sketch, + usable_nullspace_rows(subs_sample_dim, near_field_inds.len()), + ); + } + } } } @@ -666,3 +726,39 @@ where } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn concat_mode_stacks_sketches_by_rows() { + let mut primary = rlst_dynamic_array2!(f64, [3, 2]); + primary.r_mut()[[0, 0]] = 1.0; + primary.r_mut()[[0, 1]] = 2.0; + primary.r_mut()[[1, 0]] = 3.0; + primary.r_mut()[[1, 1]] = 4.0; + primary.r_mut()[[2, 0]] = 99.0; + primary.r_mut()[[2, 1]] = 99.0; + + let mut secondary = rlst_dynamic_array2!(f64, [3, 2]); + secondary.r_mut()[[0, 0]] = 5.0; + secondary.r_mut()[[0, 1]] = 6.0; + secondary.r_mut()[[1, 0]] = 7.0; + secondary.r_mut()[[1, 1]] = 8.0; + secondary.r_mut()[[2, 0]] = 88.0; + secondary.r_mut()[[2, 1]] = 88.0; + + concat_nonsymmetric_id_sketches(&mut primary, &secondary, 2); + + assert_eq!(primary.shape(), [4, 2]); + assert_eq!(primary.r()[[0, 0]], 1.0); + assert_eq!(primary.r()[[0, 1]], 2.0); + assert_eq!(primary.r()[[1, 0]], 3.0); + assert_eq!(primary.r()[[1, 1]], 4.0); + assert_eq!(primary.r()[[2, 0]], 5.0); + assert_eq!(primary.r()[[2, 1]], 6.0); + assert_eq!(primary.r()[[3, 0]], 7.0); + assert_eq!(primary.r()[[3, 1]], 8.0); + } +}