diff --git a/src/trustshell/rhel_releases.py b/src/trustshell/rhel_releases.py index 9c2b231..ab9b771 100644 --- a/src/trustshell/rhel_releases.py +++ b/src/trustshell/rhel_releases.py @@ -23,13 +23,22 @@ class RHELReleaseNode: """Represents a RHEL release node with its metadata and relationships.""" - def __init__(self, name: str, node_type: str, cpes: List[str]): + def __init__( + self, + name: str, + node_type: str, + cpes: List[str], + ps_update_stream: Optional[str] = None, + ): self.name = name self.node_type = node_type # main, eus, aus, e4s # Remove any single digit like cpe:/a:redhat:enterprise_linux:9::appstream self.cpes = [ cpe for cpe in cpes if not re.search(r":redhat:enterprise_linux:\d:", cpe) ] + self.ps_update_stream = ( + ps_update_stream # The ps_update_stream associated with this node + ) self.children: Set[str] = set() self.parents: Set[str] = set() @@ -412,8 +421,11 @@ def _parse_yaml_data(self, data: Dict[str, Any]) -> None: for node_name, node_data in data["nodes"].items(): node_type = node_data.get("type", "unknown") cpes = node_data.get("cpes", []) + ps_update_stream = node_data.get( + "ps_update_stream" + ) # Extract stream field from YAML - node = RHELReleaseNode(node_name, node_type, cpes) + node = RHELReleaseNode(node_name, node_type, cpes, ps_update_stream) self.nodes[node_name] = node # Build CPE to node mapping (use node.cpes which are already filtered) @@ -483,14 +495,14 @@ def find_active_streams_for_cpe( Find active ps_update_streams that should be associated with a given CPE. This implements the rules: - 1. If CPE matches directly to an active ps_update_stream, use it - 2. If CPE matches to a parent node, consider it part of each leaf node - whose CPEs are in active streams + 1. For non-RHEL streams: If CPE matches directly to an active ps_update_stream, use it + 2. For RHEL streams: Use the RHEL release graph to find nodes matching the CPE, + then match nodes to streams using their ps_update_stream attribute Args: - cpe: The CPE to match + cpe: The CPE to match (typically from an SBOM) active_streams: Set of active ps_update_stream names - stream_cpes: Mapping of stream names to their CPEs + stream_cpes: Mapping of stream names to their CPEs (from product-definitions) Returns: Set of active stream names that should be associated with this CPE @@ -527,18 +539,16 @@ def find_active_streams_for_cpe( # Include the node itself if it's a leaf all_candidate_nodes = descendants | {node.name} - # Check if any leaf descendants have CPEs in active streams + # Check if any leaf descendants match active streams by ps_update_stream for candidate_node_name in all_candidate_nodes: if candidate_node_name in self.nodes: candidate_node = self.nodes[candidate_node_name] - # Check if this is effectively a leaf (or the original node) - # and if its CPEs are represented in active streams - for candidate_cpe in candidate_node.cpes: - for stream_name in active_streams: - if stream_name in stream_cpes: - if candidate_cpe in stream_cpes[stream_name]: - result_streams.add(stream_name) + # Check if this node's ps_update_stream matches any active stream + # This avoids relying on CPE matching between product-definitions and rhel_releases + if candidate_node.ps_update_stream: + if candidate_node.ps_update_stream in active_streams: + result_streams.add(candidate_node.ps_update_stream) return result_streams @@ -547,7 +557,7 @@ def get_all_cpes_for_stream( ) -> Set[str]: """ Get all CPEs that should be associated with a given RHEL stream by traversing - the release graph to find related nodes using CPE-based matching. + the release graph to find related nodes using ps_update_stream attribute matching. Args: stream_name: The ps_update_stream name (e.g., "rhel-9.2.0.z") @@ -562,25 +572,22 @@ def get_all_cpes_for_stream( if stream_name in stream_cpes: all_cpes.update(stream_cpes[stream_name]) - # Use the stream's CPEs to find matching nodes in the RHEL release graph - if stream_name in stream_cpes: - matching_nodes = set() - - # For each CPE in the stream, find matching nodes in the release graph - for cpe in stream_cpes[stream_name]: - nodes_for_cpe = self.find_matching_nodes_for_cpe(cpe) - matching_nodes.update(nodes_for_cpe) - - # For each matching node, collect CPEs from the node and its ancestors - for node in matching_nodes: - # Add CPEs from this node - all_cpes.update(node.cpes) - - # Add CPEs from ancestor nodes (parent releases) - ancestors = self.get_ancestors(node.name) - for ancestor_name in ancestors: - if ancestor_name in self.nodes: - all_cpes.update(self.nodes[ancestor_name].cpes) + # Find matching nodes by ps_update_stream attribute instead of CPE matching + matching_nodes = set() + for node in self.nodes.values(): + if node.ps_update_stream == stream_name: + matching_nodes.add(node) + + # For each matching node, collect CPEs from the node and its ancestors + for node in matching_nodes: + # Add CPEs from this node + all_cpes.update(node.cpes) + + # Add CPEs from ancestor nodes (parent releases) + ancestors = self.get_ancestors(node.name) + for ancestor_name in ancestors: + if ancestor_name in self.nodes: + all_cpes.update(self.nodes[ancestor_name].cpes) return all_cpes diff --git a/tests/test_product_definitions.py b/tests/test_product_definitions.py index dcb8990..a2b5cae 100644 --- a/tests/test_product_definitions.py +++ b/tests/test_product_definitions.py @@ -274,10 +274,14 @@ def _create_test_rhel_releases_yaml(self): nodes: RHEL-9.0.0.GA: type: main + ps_update_stream: rhel-9.0.0.z cpes: - cpe:/a:redhat:enterprise_linux:9::appstream - cpe:/o:redhat:enterprise_linux:9::baseos - cpe:/a:redhat:enterprise_linux:9::crb + - cpe:/a:redhat:enterprise_linux:9.0::appstream + - cpe:/o:redhat:enterprise_linux:9.0::baseos + - cpe:/a:redhat:enterprise_linux:9.0::crb RHEL-9.0.0.Z.MAIN+EUS: type: main @@ -295,6 +299,7 @@ def _create_test_rhel_releases_yaml(self): RHEL-9.2.0.GA: type: main + ps_update_stream: rhel-9.2.0.z cpes: - cpe:/a:redhat:enterprise_linux:9::appstream - cpe:/o:redhat:enterprise_linux:9::baseos @@ -306,6 +311,20 @@ def _create_test_rhel_releases_yaml(self): - cpe:/a:redhat:rhel_eus:9.2::appstream - cpe:/o:redhat:rhel_eus:9.2::baseos + RHEL-9.2.0.Z.MAIN+EUS: + type: main + cpes: + - cpe:/a:redhat:enterprise_linux:9::appstream + - cpe:/a:redhat:enterprise_linux:9.2::appstream + + RHEL-9.3.0.GA: + type: main + ps_update_stream: rhel-9.3.0.z + cpes: + - cpe:/a:redhat:enterprise_linux:9::appstream + - cpe:/a:redhat:enterprise_linux:9.2::appstream + - cpe:/a:redhat:enterprise_linux:9.3::appstream + edges: RHEL-9.0.0.GA: - RHEL-9.0.0.Z.MAIN+EUS @@ -314,6 +333,10 @@ def _create_test_rhel_releases_yaml(self): - RHEL-9.2.0.GA RHEL-9.2.0.GA: - RHEL-9.2.0.Z.EUS + - RHEL-9.2.0.Z.MAIN+EUS + RHEL-9.2.0.Z.MAIN+EUS: + - RHEL-9.2.0.Z.EUS + - RHEL-9.3.0.GA """ temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) temp_file.write(test_data) @@ -349,6 +372,14 @@ def _create_enhanced_product_definitions(self): "cpe:/o:redhat:rhel_eus:9.2::baseos", ], }, + "rhel-9.3.0.z": { + "pp_label": "rhel-9.3.0.z", + "version": "rhel-9.3.0.z", + "cpe": [ + "cpe:/a:redhat:enterprise_linux:9::appstream", + "cpe:/a:redhat:enterprise_linux:9.3::appstream", + ], + }, }, } @@ -518,5 +549,12 @@ def test_get_all_cpes_for_rhel_stream_enhanced(self, mock_service): print(f"Enhanced CPEs for rhel-9.0.0.z: {sorted(all_cpes)}") + # Get all CPEs for rhel-9.3.0.z stream + all_93_cpes = prod_defs.get_all_cpes_for_rhel_stream("rhel-9.3.0.z") + print(f"Enhanced CPEs for rhel-9.3.0.z: {sorted(all_93_cpes)}") + + assert "cpe:/a:redhat:enterprise_linux:9.2::appstream" in all_93_cpes + assert "cpe:/a:redhat:enterprise_linux:9.3::appstream" in all_93_cpes + finally: os.unlink(rhel_yaml_path) diff --git a/tests/test_rhel_releases.py b/tests/test_rhel_releases.py index f355ec4..c0ddf57 100644 --- a/tests/test_rhel_releases.py +++ b/tests/test_rhel_releases.py @@ -14,6 +14,7 @@ def create_test_rhel_data(): nodes: RHEL-9.0.0.GA: type: main + ps_update_stream: rhel-9.0.0.z cpes: - cpe:/a:redhat:enterprise_linux:9.0::appstream - cpe:/o:redhat:enterprise_linux:9.0::baseos @@ -35,6 +36,7 @@ def create_test_rhel_data(): RHEL-9.2.0.GA: type: main + ps_update_stream: rhel-9.2.0.z cpes: - cpe:/a:redhat:enterprise_linux:9.2::appstream - cpe:/o:redhat:enterprise_linux:9.2::baseos @@ -66,7 +68,6 @@ def create_test_product_definitions(): "ps_update_streams": ["rhel-9.0.0.z", "rhel-9.2.0.z"], "active_ps_update_streams": ["rhel-9.0.0.z", "rhel-9.2.0.z"], "cpe": [ - "cpe:/o:redhat:enterprise_linux:9", "cpe:/a:redhat:enterprise_linux:9", ], } @@ -209,7 +210,7 @@ class TestEnhancedProdDefs: """Test enhanced product definitions with RHEL release data.""" def test_enhance_cpe_matching_direct_match(self): - """Test enhanced CPE matching for direct matches.""" + """Test enhanced CPE matching for direct matches using ps_update_stream.""" with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f: f.write(create_test_rhel_data()) f.flush() @@ -229,12 +230,14 @@ def test_enhance_cpe_matching_direct_match(self): ], } - # Test direct match + # Test CPE that matches RHEL-9.0.0.GA node (which has ps_update_stream: rhel-9.0.0.z) + # The CPE from the GA node should match via ps_update_stream result = enhanced.enhance_cpe_matching( - "cpe:/a:redhat:rhel_eus:9.0::appstream", active_streams, stream_cpes + "cpe:/a:redhat:enterprise_linux:9.0::appstream", + active_streams, + stream_cpes, ) assert "rhel-9.0.0.z" in result - assert len(result) == 1 finally: os.unlink(f.name) @@ -275,6 +278,64 @@ def test_enhance_cpe_matching_parent_match(self): finally: os.unlink(f.name) + def test_enhance_cpe_matching_with_ps_update_stream(self): + """Test enhanced CPE matching using ps_update_stream attribute.""" + test_data_with_streams = """ +nodes: + RHEL-9.2.0.GA: + type: main + ps_update_stream: rhel-9.2.0.z + cpes: + - cpe:/a:redhat:enterprise_linux:9.2::appstream + - cpe:/o:redhat:enterprise_linux:9.2::baseos + + RHEL-9.3.0.GA: + type: main + ps_update_stream: rhel-9.3.0.z + cpes: + - cpe:/a:redhat:enterprise_linux:9.3::appstream + +edges: + RHEL-9.2.0.GA: + - RHEL-9.3.0.GA +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f: + f.write(test_data_with_streams) + f.flush() + + try: + enhanced = EnhancedProdDefs(rhel_releases_path=f.name) + + active_streams = {"rhel-9.2.0.z", "rhel-9.3.0.z"} + # Note: stream_cpes may not have all CPEs if minor version CPEs are removed + stream_cpes = { + "rhel-9.2.0.z": [ + "cpe:/a:redhat:rhel_eus:9.2::appstream", + ], + "rhel-9.3.0.z": [ + "cpe:/a:redhat:enterprise_linux:9.3::appstream", + ], + } + + # Test CPE that matches both RHEL-9.2.0.GA and RHEL-9.3.0.GA nodes + # Even though stream_cpes doesn't have all CPEs, it should still match + # because nodes have ps_update_stream attributes + result = enhanced.enhance_cpe_matching( + "cpe:/a:redhat:enterprise_linux:9.2::appstream", + active_streams, + stream_cpes, + ) + + # Should match both streams because: + # - RHEL-9.2.0.GA has ps_update_stream: rhel-9.2.0.z and contains the 9.2 CPE + # - RHEL-9.3.0.GA has ps_update_stream: rhel-9.3.0.z and contains the 9.2 CPE + # - RHEL-9.2.0.GA is a parent of RHEL-9.3.0.GA, so descendants are checked + assert "rhel-9.2.0.z" in result + assert "rhel-9.3.0.z" in result + + finally: + os.unlink(f.name) + class TestProdDefsIntegration: """Test integration with existing ProdDefs class."""