@@ -35,7 +35,13 @@ def __init__(self, trx_path, coverage_path=None):
3535 self .file_coverage = []
3636
3737 def parse_trx (self ):
38- tree = ET .parse (self .trx_path )
38+ # Create a secure XML parser that disables external entity processing
39+ parser = ET .XMLParser ()
40+ parser .parser .DefaultHandler = lambda data : None
41+ parser .parser .ExternalEntityRefHandler = lambda context , base , uri , notationName : False
42+ parser .parser .EntityDeclHandler = lambda entityName , is_parameter_entity , value , base , systemId , notationName , publicId : False
43+
44+ tree = ET .parse (self .trx_path , parser )
3945 root = tree .getroot ()
4046 ns = {'t' : 'http://microsoft.com/schemas/VisualStudio/TeamTest/2010' }
4147
@@ -72,7 +78,9 @@ def parse_trx(self):
7278 duration_str = result .get ('duration' , '0' )
7379 duration = self ._parse_duration (duration_str )
7480
75- test_def = root .find (f".//t:UnitTest[@id='{ test_id } ']/t:TestMethod" , ns )
81+ # Sanitize test_id to prevent XPath injection
82+ sanitized_test_id = self ._sanitize_xml_attribute_value (test_id )
83+ test_def = root .find (f".//t:UnitTest[@id='{ sanitized_test_id } ']/t:TestMethod" , ns )
7684 class_name = test_def .get ('className' , '' ) if test_def is not None else ''
7785
7886 parts = class_name .split (',' )[0 ].rsplit ('.' , 1 )
@@ -113,7 +121,13 @@ def parse_coverage(self):
113121 if not self .coverage_path or not os .path .exists (self .coverage_path ):
114122 return
115123 try :
116- tree = ET .parse (self .coverage_path )
124+ # Create a secure XML parser that disables external entity processing
125+ parser = ET .XMLParser ()
126+ parser .parser .DefaultHandler = lambda data : None
127+ parser .parser .ExternalEntityRefHandler = lambda context , base , uri , notationName : False
128+ parser .parser .EntityDeclHandler = lambda entityName , is_parameter_entity , value , base , systemId , notationName , publicId : False
129+
130+ tree = ET .parse (self .coverage_path , parser )
117131 root = tree .getroot ()
118132 self .coverage ['lines_pct' ] = float (root .get ('line-rate' , 0 )) * 100
119133 self .coverage ['branches_pct' ] = float (root .get ('branch-rate' , 0 )) * 100
@@ -207,6 +221,44 @@ def _parse_condition_coverage(cond_str):
207221 return int (m .group (2 )), int (m .group (3 ))
208222 return 0 , 0
209223
224+ @staticmethod
225+ def _sanitize_xml_attribute_value (value ):
226+ """Sanitize XML attribute value to prevent XPath injection."""
227+ if not value :
228+ return ""
229+ # Remove potentially dangerous characters that could be used in XPath injection
230+ # Keep only alphanumeric, dash, underscore, and dot characters
231+ sanitized = re .sub (r'[^a-zA-Z0-9\-_\.]' , '' , str (value ))
232+ return sanitized
233+
234+ @staticmethod
235+ def _validate_output_path (output_path ):
236+ """Validate output path to prevent directory traversal attacks."""
237+ if not output_path :
238+ raise ValueError ("Output path cannot be empty" )
239+
240+ # Normalize the path to resolve any .. or . components
241+ normalized_path = os .path .normpath (output_path )
242+
243+ # Get absolute paths for comparison
244+ abs_output_path = os .path .abspath (normalized_path )
245+ current_dir = os .path .abspath (os .getcwd ())
246+
247+ # Ensure the output file will be created in or under the current directory
248+ if not abs_output_path .startswith (current_dir + os .sep ) and abs_output_path != current_dir :
249+ # Allow files in current directory or subdirectories only
250+ raise ValueError (f"Invalid output path: { output_path } . Path traversal detected." )
251+
252+ # Additional check: ensure no directory traversal patterns
253+ if '..' in normalized_path or normalized_path .startswith ('/' ):
254+ raise ValueError (f"Invalid output path: { output_path } . Path traversal patterns detected." )
255+
256+ # Ensure it's an HTML file
257+ if not normalized_path .lower ().endswith ('.html' ):
258+ raise ValueError (f"Output path must be an HTML file: { output_path } " )
259+
260+ return normalized_path
261+
210262 @staticmethod
211263 def _esc (text ):
212264 if text is None :
@@ -232,6 +284,9 @@ def _format_duration_display(self, seconds):
232284 return f"{ h } h { m } m"
233285
234286 def generate_html (self , output_path ):
287+ # Validate output path to prevent directory traversal attacks
288+ validated_output_path = self ._validate_output_path (output_path )
289+
235290 pass_rate = (self .results ['passed' ] / self .results ['total' ] * 100 ) if self .results ['total' ] > 0 else 0
236291
237292 by_file = {}
@@ -249,9 +304,9 @@ def generate_html(self, output_path):
249304 html += self ._html_scripts ()
250305 html += "</div></body></html>"
251306
252- with open (output_path , 'w' , encoding = 'utf-8' ) as f :
307+ with open (validated_output_path , 'w' , encoding = 'utf-8' ) as f :
253308 f .write (html )
254- return output_path
309+ return validated_output_path
255310
256311 def _html_head (self ):
257312 return """<!DOCTYPE html>
0 commit comments