diff --git a/crates/medmodels-core/src/medrecord/group_mapping.rs b/crates/medmodels-core/src/medrecord/group_mapping.rs index 6c194914..c554b779 100644 --- a/crates/medmodels-core/src/medrecord/group_mapping.rs +++ b/crates/medmodels-core/src/medrecord/group_mapping.rs @@ -98,12 +98,8 @@ impl GroupMapping { group: Group, node_index: NodeIndex, ) -> Result<(), MedRecordError> { - let nodes_in_group = - self.nodes_in_group - .get_mut(&group) - .ok_or(MedRecordError::IndexError(format!( - "Cannot find group {group}" - )))?; + // TODO: This was changed. Add a test for adding to a non-existing group + let nodes_in_group = self.nodes_in_group.entry(group.clone()).or_default(); if !nodes_in_group.insert(node_index.clone()) { return Err(MedRecordError::AssertionError(format!( @@ -124,12 +120,8 @@ impl GroupMapping { group: Group, edge_index: EdgeIndex, ) -> Result<(), MedRecordError> { - let edges_in_group = - self.edges_in_group - .get_mut(&group) - .ok_or(MedRecordError::IndexError(format!( - "Cannot find group {group}" - )))?; + // TODO: This was changed. Add a test for adding to a non-existing group + let edges_in_group = self.edges_in_group.entry(group.clone()).or_default(); if !edges_in_group.insert(edge_index) { return Err(MedRecordError::AssertionError(format!( @@ -376,11 +368,6 @@ mod test { .add_group("0".into(), Some(vec!["0".into()]), None) .unwrap(); - // Adding to a non-existing group should fail - assert!(group_mapping - .add_node_to_group("50".into(), "1".into()) - .is_err_and(|e| matches!(e, MedRecordError::IndexError(_)))); - // Adding a node to a group that already is in the group should fail assert!(group_mapping .add_node_to_group("0".into(), "0".into()) @@ -414,11 +401,6 @@ mod test { .add_group("0".into(), None, Some(vec![0])) .unwrap(); - // Adding to a non-existing group should fail - assert!(group_mapping - .add_edge_to_group("50".into(), 1) - .is_err_and(|e| matches!(e, MedRecordError::IndexError(_)))); - // Adding an edge to a group that already is in the group should fail assert!(group_mapping .add_edge_to_group("0".into(), 0) diff --git a/crates/medmodels-core/src/medrecord/mod.rs b/crates/medmodels-core/src/medrecord/mod.rs index 4dc4bab3..d62d2f5c 100644 --- a/crates/medmodels-core/src/medrecord/mod.rs +++ b/crates/medmodels-core/src/medrecord/mod.rs @@ -497,6 +497,44 @@ impl MedRecord { .map_err(MedRecordError::from) } + // TODO: Add tests + pub fn add_node_with_group( + &mut self, + node_index: NodeIndex, + attributes: Attributes, + group: Group, + ) -> Result<(), MedRecordError> { + match self.schema.schema_type() { + SchemaType::Inferred => { + let nodes_in_group = self + .group_mapping + .nodes_in_group + .get(&group) + .map(|nodes| nodes.len()) + .unwrap_or(0); + + self.schema + .update_node(&attributes, Some(&group), nodes_in_group == 0); + } + SchemaType::Provided => { + self.schema + .validate_node(&node_index, &attributes, Some(&group))?; + } + } + + self.graph + .add_node(node_index.clone(), attributes) + .map_err(MedRecordError::from)?; + + self.group_mapping + .add_node_to_group(group, node_index.clone()) + .inspect_err(|_| { + self.graph + .remove_node(&node_index, &mut self.group_mapping) + .expect("Node must exist"); + }) + } + pub fn remove_node(&mut self, node_index: &NodeIndex) -> Result { self.group_mapping.remove_node(node_index); @@ -513,6 +551,23 @@ impl MedRecord { Ok(()) } + // TODO: Add tests + pub fn add_nodes_with_group( + &mut self, + nodes: Vec<(NodeIndex, Attributes)>, + group: Group, + ) -> Result<(), MedRecordError> { + if !self.contains_group(&group) { + self.add_group(group.clone(), None, None)?; + } + + for (node_index, attributes) in nodes.into_iter() { + self.add_node_with_group(node_index, attributes, group.clone())?; + } + + Ok(()) + } + pub fn add_nodes_dataframes( &mut self, nodes_dataframes: impl IntoIterator>, @@ -532,6 +587,27 @@ impl MedRecord { self.add_nodes(nodes) } + // TODO: Add tests + pub fn add_nodes_dataframes_with_group( + &mut self, + nodes_dataframes: impl IntoIterator>, + group: Group, + ) -> Result<(), MedRecordError> { + let nodes = nodes_dataframes + .into_iter() + .map(|dataframe_input| { + let dataframe_input = dataframe_input.into(); + + dataframe_to_nodes(dataframe_input.dataframe, &dataframe_input.index_column) + }) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(); + + self.add_nodes_with_group(nodes, group) + } + pub fn add_edge( &mut self, source_node_index: NodeIndex, @@ -569,6 +645,53 @@ impl MedRecord { } } + // TODO: Add tests + pub fn add_edge_with_group( + &mut self, + source_node_index: NodeIndex, + target_node_index: NodeIndex, + attributes: Attributes, + group: Group, + ) -> Result { + let edge_index = self + .graph + .add_edge(source_node_index, target_node_index, attributes.to_owned()) + .map_err(MedRecordError::from)?; + + match self.schema.schema_type() { + SchemaType::Inferred => { + let edges_in_group = self + .group_mapping + .edges_in_group + .get(&group) + .map(|edges| edges.len()) + .unwrap_or(0); + + self.schema + .update_edge(&attributes, Some(&group), edges_in_group == 0); + } + SchemaType::Provided => { + self.schema + .validate_edge(&edge_index, &attributes, Some(&group)) + .inspect_err(|_| { + self.graph + .remove_edge(&edge_index) + .expect("Edge must exist"); + })?; + } + } + + self.group_mapping + .add_edge_to_group(group, edge_index) + .inspect_err(|_| { + self.graph + .remove_edge(&edge_index) + .expect("Edge must exist"); + })?; + + Ok(edge_index) + } + pub fn remove_edge(&mut self, edge_index: &EdgeIndex) -> Result { self.group_mapping.remove_edge(edge_index); @@ -589,6 +712,29 @@ impl MedRecord { .collect() } + // TODO: Add tests + pub fn add_edges_with_group( + &mut self, + edges: Vec<(NodeIndex, NodeIndex, Attributes)>, + group: Group, + ) -> Result, MedRecordError> { + if !self.contains_group(&group) { + self.add_group(group.clone(), None, None)?; + } + + edges + .into_iter() + .map(|(source_edge_index, target_node_index, attributes)| { + self.add_edge_with_group( + source_edge_index, + target_node_index, + attributes, + group.clone(), + ) + }) + .collect() + } + pub fn add_edges_dataframes( &mut self, edges_dataframes: impl IntoIterator>, @@ -612,6 +758,31 @@ impl MedRecord { self.add_edges(edges) } + // TODO: Add tests + pub fn add_edges_dataframes_with_group( + &mut self, + edges_dataframes: impl IntoIterator>, + group: Group, + ) -> Result, MedRecordError> { + let edges = edges_dataframes + .into_iter() + .map(|dataframe_input| { + let dataframe_input = dataframe_input.into(); + + dataframe_to_edges( + dataframe_input.dataframe, + &dataframe_input.source_index_column, + &dataframe_input.target_index_column, + ) + }) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(); + + self.add_edges_with_group(edges, group) + } + pub fn add_group( &mut self, group: Group, @@ -1861,11 +2032,6 @@ mod test { .add_group("0".into(), Some(vec!["0".into()]), None) .unwrap(); - // Adding to a non-existing group should fail - assert!(medrecord - .add_node_to_group("1".into(), "0".into()) - .is_err_and(|e| matches!(e, MedRecordError::IndexError(_)))); - // Adding a non-existing node to a group should fail assert!(medrecord .add_node_to_group("0".into(), "50".into()) @@ -1931,11 +2097,6 @@ mod test { .add_group("0".into(), None, Some(vec![0])) .unwrap(); - // Adding to a non-existing group should fail - assert!(medrecord - .add_edge_to_group("1".into(), 0) - .is_err_and(|e| matches!(e, MedRecordError::IndexError(_)))); - // Adding a non-existing edge to a group should fail assert!(medrecord .add_edge_to_group("0".into(), 50) diff --git a/crates/medmodels-python/src/medrecord/mod.rs b/crates/medmodels-python/src/medrecord/mod.rs index 8ccb7759..3c83f13b 100644 --- a/crates/medmodels-python/src/medrecord/mod.rs +++ b/crates/medmodels-python/src/medrecord/mod.rs @@ -444,6 +444,17 @@ impl PyMedRecord { .map_err(PyMedRecordError::from)?) } + pub fn add_nodes_with_group( + &mut self, + nodes: Vec<(PyNodeIndex, PyAttributes)>, + group: PyGroup, + ) -> PyResult<()> { + Ok(self + .0 + .add_nodes_with_group(nodes.deep_into(), group.into()) + .map_err(PyMedRecordError::from)?) + } + pub fn add_nodes_dataframes( &mut self, nodes_dataframes: Vec<(PyDataFrame, String)>, @@ -454,6 +465,17 @@ impl PyMedRecord { .map_err(PyMedRecordError::from)?) } + pub fn add_nodes_dataframes_with_group( + &mut self, + nodes_dataframes: Vec<(PyDataFrame, String)>, + group: PyGroup, + ) -> PyResult<()> { + Ok(self + .0 + .add_nodes_dataframes_with_group(nodes_dataframes, group.into()) + .map_err(PyMedRecordError::from)?) + } + pub fn remove_edges( &mut self, edge_indices: Vec, @@ -546,6 +568,17 @@ impl PyMedRecord { .map_err(PyMedRecordError::from)?) } + pub fn add_edges_with_group( + &mut self, + relations: Vec<(PyNodeIndex, PyNodeIndex, PyAttributes)>, + group: PyGroup, + ) -> PyResult> { + Ok(self + .0 + .add_edges_with_group(relations.deep_into(), group.into()) + .map_err(PyMedRecordError::from)?) + } + pub fn add_edges_dataframes( &mut self, edges_dataframes: Vec<(PyDataFrame, String, String)>, @@ -556,6 +589,17 @@ impl PyMedRecord { .map_err(PyMedRecordError::from)?) } + pub fn add_edges_dataframes_with_group( + &mut self, + edges_dataframes: Vec<(PyDataFrame, String, String)>, + group: PyGroup, + ) -> PyResult> { + Ok(self + .0 + .add_edges_dataframes_with_group(edges_dataframes, group.into()) + .map_err(PyMedRecordError::from)?) + } + #[pyo3(signature = (group, node_indices_to_add=None, edge_indices_to_add=None))] pub fn add_group( &mut self, diff --git a/medmodels/_medmodels.pyi b/medmodels/_medmodels.pyi index f8c3f32b..37013a46 100644 --- a/medmodels/_medmodels.pyi +++ b/medmodels/_medmodels.pyi @@ -238,9 +238,15 @@ class PyMedRecord: self, node_index: NodeIndexInputList, attribute: MedRecordAttribute ) -> None: ... def add_nodes(self, nodes: Sequence[NodeTuple]) -> None: ... + def add_nodes_with_group( + self, nodes: Sequence[NodeTuple], group: Group + ) -> None: ... def add_nodes_dataframes( self, nodes_dataframe: List[PolarsNodeDataFrameInput] ) -> None: ... + def add_nodes_dataframes_with_group( + self, nodes_dataframe: List[PolarsNodeDataFrameInput], group: Group + ) -> None: ... def remove_edges( self, edge_index: EdgeIndexInputList ) -> Dict[EdgeIndex, Attributes]: ... @@ -257,9 +263,15 @@ class PyMedRecord: self, edge_index: EdgeIndexInputList, attribute: MedRecordAttribute ) -> None: ... def add_edges(self, edges: Sequence[EdgeTuple]) -> List[EdgeIndex]: ... + def add_edges_with_group( + self, edges: Sequence[EdgeTuple], group: Group + ) -> List[EdgeIndex]: ... def add_edges_dataframes( self, edges_dataframe: List[PolarsEdgeDataFrameInput] ) -> List[EdgeIndex]: ... + def add_edges_dataframes_with_group( + self, edges_dataframe: List[PolarsEdgeDataFrameInput], group: Group + ) -> List[EdgeIndex]: ... def add_group( self, group: Group, diff --git a/medmodels/medrecord/medrecord.py b/medmodels/medrecord/medrecord.py index dadfc505..1be84341 100644 --- a/medmodels/medrecord/medrecord.py +++ b/medmodels/medrecord/medrecord.py @@ -925,15 +925,10 @@ def add_nodes( if is_node_tuple(nodes): nodes = [nodes] - self._medrecord.add_nodes(nodes) - if group is None: - return - - if not self.contains_group(group): - self.add_group(group) - - self.add_nodes_to_group(group, [node[0] for node in nodes]) + self._medrecord.add_nodes(nodes) + else: + self._medrecord.add_nodes_with_group(nodes, group) def add_nodes_pandas( self, @@ -976,24 +971,14 @@ def add_nodes_polars( group (Optional[Group]): The name of the group to add the nodes to. If not specified, the nodes are added to the MedRecord without a group. """ - self._medrecord.add_nodes_dataframes( - nodes if isinstance(nodes, list) else [nodes] - ) - if group is None: - return - - if not self.contains_group(group): - self.add_group(group) - - if isinstance(nodes, list): - node_indices = [ - nodes for node in nodes for nodes in node[0][node[1]].to_list() - ] + self._medrecord.add_nodes_dataframes( + nodes if isinstance(nodes, list) else [nodes] + ) else: - node_indices = nodes[0][nodes[1]].to_list() - - self.add_nodes_to_group(group, node_indices) + self._medrecord.add_nodes_dataframes_with_group( + nodes if isinstance(nodes, list) else [nodes], group + ) @overload def remove_edges(self, edges: Union[EdgeIndex, EdgeIndexQuery]) -> Attributes: ... @@ -1073,17 +1058,10 @@ def add_edges( if is_edge_tuple(edges): edges = [edges] - edge_indices = self._medrecord.add_edges(edges) - if group is None: - return edge_indices - - if not self.contains_group(group): - self.add_group(group) + return self._medrecord.add_edges(edges) - self.add_edges_to_group(group, edge_indices) - - return edge_indices + return self._medrecord.add_edges_with_group(edges, group) def add_edges_pandas( self, @@ -1134,19 +1112,14 @@ def add_edges_polars( Returns: List[EdgeIndex]: A list of the edge indices added. """ - edge_indices = self._medrecord.add_edges_dataframes( - edges if isinstance(edges, list) else [edges] - ) - if group is None: - return edge_indices - - if not self.contains_group(group): - self.add_group(group) - - self.add_edges_to_group(group, edge_indices) + return self._medrecord.add_edges_dataframes( + edges if isinstance(edges, list) else [edges] + ) - return edge_indices + return self._medrecord.add_edges_dataframes_with_group( + edges if isinstance(edges, list) else [edges], group + ) def add_group( self, diff --git a/tests/medrecord/test_medrecord.py b/tests/medrecord/test_medrecord.py index 4be6db6a..58952f0c 100644 --- a/tests/medrecord/test_medrecord.py +++ b/tests/medrecord/test_medrecord.py @@ -1694,14 +1694,6 @@ def test_invalid_add_nodes_to_group(self) -> None: medrecord.add_group("0", ["0"]) - # Adding to a non-existing group should fail - with pytest.raises(IndexError): - medrecord.add_nodes_to_group("50", "1") - - # Adding to a non-existing group should fail - with pytest.raises(IndexError): - medrecord.add_nodes_to_group("50", ["1", "2"]) - # Adding a non-existing node to a group should fail with pytest.raises(IndexError): medrecord.add_nodes_to_group("0", "50") @@ -1799,14 +1791,6 @@ def test_invalid_add_edges_to_group(self) -> None: medrecord.add_group("0", edges=[0]) - # Adding to a non-existing group should fail - with pytest.raises(IndexError): - medrecord.add_edges_to_group("50", 1) - - # Adding to a non-existing group should fail - with pytest.raises(IndexError): - medrecord.add_edges_to_group("50", [1, 2]) - # Adding a non-existing edge to a group should fail with pytest.raises(IndexError): medrecord.add_edges_to_group("0", 50)