Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 68 additions & 1 deletion src/rsrs/args.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
rsrs::{
rsrs_factors::{
null_and_extract::{ExtractOptions, IdOptions, PivotMethod},
null_and_extract::{ExtractOptions, IdOptions, NonSymmetricIdCombination, PivotMethod},
rsrs_operator::FactType,
},
sketch::Shift,
Expand Down Expand Up @@ -141,6 +141,8 @@ pub struct RsrsArgs<Item: RlstScalar> {
num_threads: usize,
flush_factors: bool,
store_far: bool,
#[serde(default)]
nonsymmetric_id_combination: NonSymmetricIdCombination,
}

impl<Item> RsrsArgs<Item>
Expand Down Expand Up @@ -202,6 +204,7 @@ where
num_threads,
flush_factors,
store_far,
nonsymmetric_id_combination: NonSymmetricIdCombination::default(),
}
}
}
Expand Down Expand Up @@ -269,6 +272,7 @@ impl<Item: RlstScalar + std::fmt::Display> RsrsOptions<Item> {
tol_null: args.tol_null,
tol_id: args.tol_id,
store_far: args.store_far,
nonsymmetric_id_combination: args.nonsymmetric_id_combination,
},
lu_options: ExtractOptions {
block_extraction_method: args.near_block_extraction_method,
Expand Down Expand Up @@ -330,6 +334,18 @@ impl<Item: RlstScalar + std::fmt::Display> RsrsOptions<Item> {
)
.unwrap();

if matches!(
self.id_options.nonsymmetric_id_combination,
NonSymmetricIdCombination::Concat
) {
write!(
&mut id,
"_nsid_{:?}",
self.id_options.nonsymmetric_id_combination
)
.unwrap();
}

match self.id_options.qr_method{
RankRevealingQrType::RRQR => write!(
&mut id,
Expand Down Expand Up @@ -376,3 +392,54 @@ where
self
}
}

#[cfg(test)]
mod tests {
use super::*;

fn make_args() -> RsrsArgs<f64> {
RsrsArgs::new(
8,
16,
0,
0,
Shift::False,
NullMethod::Projection,
RankRevealingQrType::RRQR,
BlockExtractionMethod::LuLstSq,
BlockExtractionMethod::LuLstSq,
PivotMethod::LuHybrid(0.0),
PivotMethod::LuHybrid(0.0),
1e-16,
40.0,
1e-16,
1e-16,
1,
1,
Symmetry::NoSymm,
RankPicking::Min,
FactType::Joint,
true,
1,
false,
false,
)
.with_fixed_rank_sampling_mode(FixedRankSamplingMode::Constant)
}

#[test]
fn default_nonsymmetric_id_mode_keeps_legacy_identifier() {
let args = make_args();
let identifier = RsrsOptions::new(Some(args)).to_identifier();
assert!(identifier.contains("_fsamp_Constant"));
assert!(!identifier.contains("_nsid_"));
}

#[test]
fn concat_nonsymmetric_id_mode_is_encoded_in_identifier() {
let mut args = make_args();
args.nonsymmetric_id_combination = NonSymmetricIdCombination::Concat;
let identifier = RsrsOptions::new(Some(args)).to_identifier();
assert!(identifier.contains("_fsamp_Constant_nsid_Concat"));
}
}
14 changes: 12 additions & 2 deletions src/rsrs/rsrs_factors/commutative_factors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::rsrs::rsrs_factors::base_factors::{
};
use crate::rsrs::rsrs_factors::null_and_extract::{
extract_lu_factor_from_blocks, near_box_extraction, null_near_field_into, ExtractOptions,
ExtractionScratch, IdOptions, PivotMethod,
ExtractionScratch, IdOptions, NonSymmetricIdCombination, PivotMethod,
};
use crate::rsrs::rsrs_factors::rsrs_operator::FactType;
use crate::rsrs::sketch::SketchData;
Expand Down Expand Up @@ -544,7 +544,17 @@ where
let start: Instant = Instant::now();
let test_shape = [subs_sample_dim, near_field_inds.len()];
let sketch_shape = [subs_sample_dim, target_inds.len()];
let null_shape = [test_shape[0] - test_shape[1], sketch_shape[1]];
let base_null_rows = test_shape[0].saturating_sub(test_shape[1]);
let null_rows = if !symmetry.symm_val()
&& matches!(
id_options.nonsymmetric_id_combination,
NonSymmetricIdCombination::Concat
) {
2 * base_null_rows
} else {
base_null_rows
};
let null_shape = [null_rows, sketch_shape[1]];

// nullification of the near field
null_near_field_into(
Expand Down
98 changes: 97 additions & 1 deletion src/rsrs/rsrs_factors/null_and_extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ pub enum PivotMethod {
LuHybrid(f64), //TODO: Change to Item
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum NonSymmetricIdCombination {
#[default]
Sum,
Concat,
}

#[derive(Debug, Clone)]
pub struct ExtractOptions<Item: RlstScalar> {
pub block_extraction_method: BlockExtractionMethod,
Expand All @@ -52,6 +59,7 @@ pub struct IdOptions<Item: RlstScalar> {
pub tol_null: Real<Item>,
pub tol_id: Real<Item>,
pub store_far: bool,
pub nonsymmetric_id_combination: NonSymmetricIdCombination,
}

pub struct ExtractionScratch<Item: RlstScalar> {
Expand All @@ -78,6 +86,47 @@ impl<Item: RlstScalar> Default for ExtractionScratch<Item> {
}
}

fn usable_nullspace_rows(subs_sample_dim: usize, near_field_len: usize) -> usize {
subs_sample_dim.saturating_sub(near_field_len)
}

fn concat_nonsymmetric_id_sketches<Item: RlstScalar>(
primary: &mut DynamicArray<Item, 2>,
secondary: &DynamicArray<Item, 2>,
rows_per_sketch: usize,
) {
let primary_shape = primary.shape();
let secondary_shape = secondary.shape();
assert_eq!(
primary_shape[1], secondary_shape[1],
"Cannot concatenate ID sketches with different column counts"
);

let kept_primary_rows = primary_shape[0].min(rows_per_sketch);
let kept_secondary_rows = secondary_shape[0].min(rows_per_sketch);
let cols = primary_shape[1];
let mut combined = rlst_dynamic_array2!(Item, [kept_primary_rows + kept_secondary_rows, cols]);

if kept_primary_rows > 0 {
combined
.r_mut()
.into_subview([0, 0], [kept_primary_rows, cols])
.fill_from(primary.r().into_subview([0, 0], [kept_primary_rows, cols]));
}
if kept_secondary_rows > 0 {
combined
.r_mut()
.into_subview([kept_primary_rows, 0], [kept_secondary_rows, cols])
.fill_from(
secondary
.r()
.into_subview([0, 0], [kept_secondary_rows, cols]),
);
}

*primary = combined;
}

#[allow(clippy::too_many_arguments)]
fn null_sketch_near_field_into<
Item: RlstScalar
Expand Down Expand Up @@ -217,7 +266,18 @@ pub fn null_near_field_into<
test_scratch,
normal_scratch,
);
far_field_sketch.sum_into(aux_sketch.r());
match id_options.nonsymmetric_id_combination {
NonSymmetricIdCombination::Sum => {
far_field_sketch.sum_into(aux_sketch.r());
}
NonSymmetricIdCombination::Concat => {
concat_nonsymmetric_id_sketches(
far_field_sketch,
aux_sketch,
usable_nullspace_rows(subs_sample_dim, near_field_inds.len()),
);
}
}
}
}

Expand Down Expand Up @@ -666,3 +726,39 @@ where
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn concat_mode_stacks_sketches_by_rows() {
let mut primary = rlst_dynamic_array2!(f64, [3, 2]);
primary.r_mut()[[0, 0]] = 1.0;
primary.r_mut()[[0, 1]] = 2.0;
primary.r_mut()[[1, 0]] = 3.0;
primary.r_mut()[[1, 1]] = 4.0;
primary.r_mut()[[2, 0]] = 99.0;
primary.r_mut()[[2, 1]] = 99.0;

let mut secondary = rlst_dynamic_array2!(f64, [3, 2]);
secondary.r_mut()[[0, 0]] = 5.0;
secondary.r_mut()[[0, 1]] = 6.0;
secondary.r_mut()[[1, 0]] = 7.0;
secondary.r_mut()[[1, 1]] = 8.0;
secondary.r_mut()[[2, 0]] = 88.0;
secondary.r_mut()[[2, 1]] = 88.0;

concat_nonsymmetric_id_sketches(&mut primary, &secondary, 2);

assert_eq!(primary.shape(), [4, 2]);
assert_eq!(primary.r()[[0, 0]], 1.0);
assert_eq!(primary.r()[[0, 1]], 2.0);
assert_eq!(primary.r()[[1, 0]], 3.0);
assert_eq!(primary.r()[[1, 1]], 4.0);
assert_eq!(primary.r()[[2, 0]], 5.0);
assert_eq!(primary.r()[[2, 1]], 6.0);
assert_eq!(primary.r()[[3, 0]], 7.0);
assert_eq!(primary.r()[[3, 1]], 8.0);
}
}
Loading