Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 4 additions & 22 deletions crates/medmodels-core/src/medrecord/group_mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,8 @@ impl GroupMapping {
group: Group,
node_index: NodeIndex,
) -> Result<(), MedRecordError> {
let nodes_in_group =
self.nodes_in_group
.get_mut(&group)
.ok_or(MedRecordError::IndexError(format!(
"Cannot find group {group}"
)))?;
// TODO: This was changed. Add a test for adding to a non-existing group
let nodes_in_group = self.nodes_in_group.entry(group.clone()).or_default();

if !nodes_in_group.insert(node_index.clone()) {
return Err(MedRecordError::AssertionError(format!(
Expand All @@ -124,12 +120,8 @@ impl GroupMapping {
group: Group,
edge_index: EdgeIndex,
) -> Result<(), MedRecordError> {
let edges_in_group =
self.edges_in_group
.get_mut(&group)
.ok_or(MedRecordError::IndexError(format!(
"Cannot find group {group}"
)))?;
// TODO: This was changed. Add a test for adding to a non-existing group
let edges_in_group = self.edges_in_group.entry(group.clone()).or_default();

if !edges_in_group.insert(edge_index) {
return Err(MedRecordError::AssertionError(format!(
Expand Down Expand Up @@ -376,11 +368,6 @@ mod test {
.add_group("0".into(), Some(vec!["0".into()]), None)
.unwrap();

// Adding to a non-existing group should fail
assert!(group_mapping
.add_node_to_group("50".into(), "1".into())
.is_err_and(|e| matches!(e, MedRecordError::IndexError(_))));

// Adding a node to a group that already is in the group should fail
assert!(group_mapping
.add_node_to_group("0".into(), "0".into())
Expand Down Expand Up @@ -414,11 +401,6 @@ mod test {
.add_group("0".into(), None, Some(vec![0]))
.unwrap();

// Adding to a non-existing group should fail
assert!(group_mapping
.add_edge_to_group("50".into(), 1)
.is_err_and(|e| matches!(e, MedRecordError::IndexError(_))));

// Adding an edge to a group that already is in the group should fail
assert!(group_mapping
.add_edge_to_group("0".into(), 0)
Expand Down
181 changes: 171 additions & 10 deletions crates/medmodels-core/src/medrecord/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,44 @@ impl MedRecord {
.map_err(MedRecordError::from)
}

// TODO: Add tests
pub fn add_node_with_group(
&mut self,
node_index: NodeIndex,
attributes: Attributes,
group: Group,
) -> Result<(), MedRecordError> {
match self.schema.schema_type() {
SchemaType::Inferred => {
let nodes_in_group = self
.group_mapping
.nodes_in_group
.get(&group)
.map(|nodes| nodes.len())
.unwrap_or(0);

self.schema
.update_node(&attributes, Some(&group), nodes_in_group == 0);
}
SchemaType::Provided => {
self.schema
.validate_node(&node_index, &attributes, Some(&group))?;
}
}

self.graph
.add_node(node_index.clone(), attributes)
.map_err(MedRecordError::from)?;

self.group_mapping
.add_node_to_group(group, node_index.clone())
.inspect_err(|_| {
self.graph
.remove_node(&node_index, &mut self.group_mapping)
.expect("Node must exist");
})
}

pub fn remove_node(&mut self, node_index: &NodeIndex) -> Result<Attributes, MedRecordError> {
self.group_mapping.remove_node(node_index);

Expand All @@ -513,6 +551,23 @@ impl MedRecord {
Ok(())
}

// TODO: Add tests
pub fn add_nodes_with_group(
&mut self,
nodes: Vec<(NodeIndex, Attributes)>,
group: Group,
) -> Result<(), MedRecordError> {
if !self.contains_group(&group) {
self.add_group(group.clone(), None, None)?;
}

for (node_index, attributes) in nodes.into_iter() {
self.add_node_with_group(node_index, attributes, group.clone())?;
}

Ok(())
}

pub fn add_nodes_dataframes(
&mut self,
nodes_dataframes: impl IntoIterator<Item = impl Into<NodeDataFrameInput>>,
Expand All @@ -532,6 +587,27 @@ impl MedRecord {
self.add_nodes(nodes)
}

// TODO: Add tests
pub fn add_nodes_dataframes_with_group(
&mut self,
nodes_dataframes: impl IntoIterator<Item = impl Into<NodeDataFrameInput>>,
group: Group,
) -> Result<(), MedRecordError> {
let nodes = nodes_dataframes
.into_iter()
.map(|dataframe_input| {
let dataframe_input = dataframe_input.into();

dataframe_to_nodes(dataframe_input.dataframe, &dataframe_input.index_column)
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect();

self.add_nodes_with_group(nodes, group)
}

pub fn add_edge(
&mut self,
source_node_index: NodeIndex,
Expand Down Expand Up @@ -569,6 +645,53 @@ impl MedRecord {
}
}

// TODO: Add tests
pub fn add_edge_with_group(
&mut self,
source_node_index: NodeIndex,
target_node_index: NodeIndex,
attributes: Attributes,
group: Group,
) -> Result<EdgeIndex, MedRecordError> {
let edge_index = self
.graph
.add_edge(source_node_index, target_node_index, attributes.to_owned())
.map_err(MedRecordError::from)?;

match self.schema.schema_type() {
SchemaType::Inferred => {
let edges_in_group = self
.group_mapping
.edges_in_group
.get(&group)
.map(|edges| edges.len())
.unwrap_or(0);

self.schema
.update_edge(&attributes, Some(&group), edges_in_group == 0);
}
SchemaType::Provided => {
self.schema
.validate_edge(&edge_index, &attributes, Some(&group))
.inspect_err(|_| {
self.graph
.remove_edge(&edge_index)
.expect("Edge must exist");
})?;
}
}

self.group_mapping
.add_edge_to_group(group, edge_index)
.inspect_err(|_| {
self.graph
.remove_edge(&edge_index)
.expect("Edge must exist");
})?;

Ok(edge_index)
}

pub fn remove_edge(&mut self, edge_index: &EdgeIndex) -> Result<Attributes, MedRecordError> {
self.group_mapping.remove_edge(edge_index);

Expand All @@ -589,6 +712,29 @@ impl MedRecord {
.collect()
}

// TODO: Add tests
pub fn add_edges_with_group(
&mut self,
edges: Vec<(NodeIndex, NodeIndex, Attributes)>,
group: Group,
) -> Result<Vec<EdgeIndex>, MedRecordError> {
if !self.contains_group(&group) {
self.add_group(group.clone(), None, None)?;
}

edges
.into_iter()
.map(|(source_edge_index, target_node_index, attributes)| {
self.add_edge_with_group(
source_edge_index,
target_node_index,
attributes,
group.clone(),
)
})
.collect()
}

pub fn add_edges_dataframes(
&mut self,
edges_dataframes: impl IntoIterator<Item = impl Into<EdgeDataFrameInput>>,
Expand All @@ -612,6 +758,31 @@ impl MedRecord {
self.add_edges(edges)
}

// TODO: Add tests
pub fn add_edges_dataframes_with_group(
&mut self,
edges_dataframes: impl IntoIterator<Item = impl Into<EdgeDataFrameInput>>,
group: Group,
) -> Result<Vec<EdgeIndex>, MedRecordError> {
let edges = edges_dataframes
.into_iter()
.map(|dataframe_input| {
let dataframe_input = dataframe_input.into();

dataframe_to_edges(
dataframe_input.dataframe,
&dataframe_input.source_index_column,
&dataframe_input.target_index_column,
)
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect();

self.add_edges_with_group(edges, group)
}

pub fn add_group(
&mut self,
group: Group,
Expand Down Expand Up @@ -1861,11 +2032,6 @@ mod test {
.add_group("0".into(), Some(vec!["0".into()]), None)
.unwrap();

// Adding to a non-existing group should fail
assert!(medrecord
.add_node_to_group("1".into(), "0".into())
.is_err_and(|e| matches!(e, MedRecordError::IndexError(_))));

// Adding a non-existing node to a group should fail
assert!(medrecord
.add_node_to_group("0".into(), "50".into())
Expand Down Expand Up @@ -1931,11 +2097,6 @@ mod test {
.add_group("0".into(), None, Some(vec![0]))
.unwrap();

// Adding to a non-existing group should fail
assert!(medrecord
.add_edge_to_group("1".into(), 0)
.is_err_and(|e| matches!(e, MedRecordError::IndexError(_))));

// Adding a non-existing edge to a group should fail
assert!(medrecord
.add_edge_to_group("0".into(), 50)
Expand Down
44 changes: 44 additions & 0 deletions crates/medmodels-python/src/medrecord/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,17 @@ impl PyMedRecord {
.map_err(PyMedRecordError::from)?)
}

pub fn add_nodes_with_group(
&mut self,
nodes: Vec<(PyNodeIndex, PyAttributes)>,
group: PyGroup,
) -> PyResult<()> {
Ok(self
.0
.add_nodes_with_group(nodes.deep_into(), group.into())
.map_err(PyMedRecordError::from)?)
}

pub fn add_nodes_dataframes(
&mut self,
nodes_dataframes: Vec<(PyDataFrame, String)>,
Expand All @@ -454,6 +465,17 @@ impl PyMedRecord {
.map_err(PyMedRecordError::from)?)
}

pub fn add_nodes_dataframes_with_group(
&mut self,
nodes_dataframes: Vec<(PyDataFrame, String)>,
group: PyGroup,
) -> PyResult<()> {
Ok(self
.0
.add_nodes_dataframes_with_group(nodes_dataframes, group.into())
.map_err(PyMedRecordError::from)?)
}

pub fn remove_edges(
&mut self,
edge_indices: Vec<EdgeIndex>,
Expand Down Expand Up @@ -546,6 +568,17 @@ impl PyMedRecord {
.map_err(PyMedRecordError::from)?)
}

pub fn add_edges_with_group(
&mut self,
relations: Vec<(PyNodeIndex, PyNodeIndex, PyAttributes)>,
group: PyGroup,
) -> PyResult<Vec<EdgeIndex>> {
Ok(self
.0
.add_edges_with_group(relations.deep_into(), group.into())
.map_err(PyMedRecordError::from)?)
}

pub fn add_edges_dataframes(
&mut self,
edges_dataframes: Vec<(PyDataFrame, String, String)>,
Expand All @@ -556,6 +589,17 @@ impl PyMedRecord {
.map_err(PyMedRecordError::from)?)
}

pub fn add_edges_dataframes_with_group(
&mut self,
edges_dataframes: Vec<(PyDataFrame, String, String)>,
group: PyGroup,
) -> PyResult<Vec<EdgeIndex>> {
Ok(self
.0
.add_edges_dataframes_with_group(edges_dataframes, group.into())
.map_err(PyMedRecordError::from)?)
}

#[pyo3(signature = (group, node_indices_to_add=None, edge_indices_to_add=None))]
pub fn add_group(
&mut self,
Expand Down
12 changes: 12 additions & 0 deletions medmodels/_medmodels.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,15 @@ class PyMedRecord:
self, node_index: NodeIndexInputList, attribute: MedRecordAttribute
) -> None: ...
def add_nodes(self, nodes: Sequence[NodeTuple]) -> None: ...
def add_nodes_with_group(
self, nodes: Sequence[NodeTuple], group: Group
) -> None: ...
def add_nodes_dataframes(
self, nodes_dataframe: List[PolarsNodeDataFrameInput]
) -> None: ...
def add_nodes_dataframes_with_group(
self, nodes_dataframe: List[PolarsNodeDataFrameInput], group: Group
) -> None: ...
def remove_edges(
self, edge_index: EdgeIndexInputList
) -> Dict[EdgeIndex, Attributes]: ...
Expand All @@ -257,9 +263,15 @@ class PyMedRecord:
self, edge_index: EdgeIndexInputList, attribute: MedRecordAttribute
) -> None: ...
def add_edges(self, edges: Sequence[EdgeTuple]) -> List[EdgeIndex]: ...
def add_edges_with_group(
self, edges: Sequence[EdgeTuple], group: Group
) -> List[EdgeIndex]: ...
def add_edges_dataframes(
self, edges_dataframe: List[PolarsEdgeDataFrameInput]
) -> List[EdgeIndex]: ...
def add_edges_dataframes_with_group(
self, edges_dataframe: List[PolarsEdgeDataFrameInput], group: Group
) -> List[EdgeIndex]: ...
def add_group(
self,
group: Group,
Expand Down
Loading