diff --git a/lib/rexml/xpath_parser.rb b/lib/rexml/xpath_parser.rb
index 9c856a65..edd11f21 100644
--- a/lib/rexml/xpath_parser.rb
+++ b/lib/rexml/xpath_parser.rb
@@ -385,31 +385,63 @@ def calls_position_dependent_function?(expr)
expr.any? {|part| calls_position_dependent_function?(part) }
end
- # Detects simple position-based predicates that can be optimized in axis scanning, such as [1], [position()=1], [position() < 2], [position() > 3]
- # Returns operators and values such as [:==, 1], [:<, 2], [:>, 3]
+ # Detects simple position-based predicates that can be optimized in axis scanning, such as [1], [position()=1], [position() < 2], [last()-3], etc.
+ # Returns operators and values such as [:index_eq, 0], [:index_lt, 1], [:index_gt, 2], [:reverse_index_eq, 0], [:reverse_index_lt, 1], [:reverse_index_gt, 2]
# Returns nil if the predicate is not a simple position-based predicate
def position_operation(predicate_expr)
- return [:==, predicate_expr[1]] if predicate_expr[0] == :literal && predicate_expr[1].is_a?(Integer)
+ return [:index_eq, predicate_expr[1] - 1] if predicate_expr[0] == :literal && predicate_expr[1].is_a?(Integer)
+
+ reverse_index = last_minus_integer(predicate_expr)
+ return [:reverse_index_eq, reverse_index] if reverse_index
op, left, right = predicate_expr
return unless op == :eq || op == :lt || op == :lteq || op == :gt || op == :gteq
return unless [left, right].include?([:function, 'position', []])
- literal = [left, right].find {|part| part[0] == :literal && part[1].is_a?(Integer) }
- return unless literal
+ if right == [:function, 'position', []]
+ op = { eq: :eq, lt: :gt, lteq: :gteq, gt: :lt, gteq: :lteq }[op]
+ left, right = right, left
+ end
- value = literal[1]
- case op
- when :eq
- [:==, value]
- when :lt
- literal == right ? [:<, value] : [:>, value]
- when :lteq
- literal == right ? [:<, value + 1] : [:>, value - 1]
- when :gt
- literal == right ? [:>, value]: [:<, value]
- when :gteq
- literal == right ? [:>, value - 1] : [:<, value + 1]
+ index = right[1] - 1 if right[0] == :literal && right[1].is_a?(Integer)
+ reverse_index = last_minus_integer(right)
+
+ if index
+ case op
+ when :eq
+ [:index_eq, index]
+ when :lt
+ [:index_lt, index]
+ when :lteq
+ [:index_lt, index + 1]
+ when :gt
+ [:index_gt, index]
+ when :gteq
+ [:index_gt, index - 1]
+ end
+ elsif reverse_index
+ case op
+ when :eq
+ [:reverse_index_eq, reverse_index]
+ when :lt
+ [:reverse_index_gt, reverse_index]
+ when :lteq
+ [:reverse_index_gt, reverse_index - 1]
+ when :gt
+ [:reverse_index_lt, reverse_index]
+ when :gteq
+ [:reverse_index_lt, reverse_index + 1]
+ end
+ end
+ end
+
+ # If the expression is `last()-INTEGER` or `last()` (equivalent to `last()-0`), returns the integer part.
+ # Otherwise, returns nil.
+ def last_minus_integer(expr)
+ if expr == [:function, 'last', []]
+ 0
+ elsif expr[0] == :minus && expr[1] == [:function, 'last', []] && expr[2][0] == :literal && expr[2][1].is_a?(Integer)
+ expr[2][1]
end
end
@@ -435,57 +467,18 @@ def following_sibling(nodeset, tester, selector)
def preceding_following_sibling(nodeset, tester, selector, reverse:)
nodeset = nodeset.select {|node| node.respond_to?(:parent) && node.parent }
- case selector
- when :uniq
- nodeset.group_by(&:parent).flat_map do |parent, sibling_nodes|
- sets = Set.new.compare_by_identity
- sibling_nodes.each {|sibling| sets << sibling }
- children = parent.children
- children = children.reverse if reverse
- children.drop_while {|child| !sets.include?(child) }.drop(1)
- end.select(&tester)
- when :nodesets
- nodesets = nodeset.map do |node|
- parent = node.parent
- index = parent.children.index(node)
- reverse ? parent.children[0...index].reverse : parent.children[index + 1..-1]
- end
- non_optimized_nodesets_select(nodesets, tester, selector)
- else
- operator, value = selector
- nodeset.group_by(&:parent).flat_map do |parent, sibling_nodes|
- anchors = Set.new.compare_by_identity
- sibling_nodes.each {|sibling| anchors << sibling }
- children = parent.children
- children = children.reverse if reverse
- followings = children.drop_while {|child| !anchors.include?(child) }.drop(1)
- anchor_indexes = Set[0]
- last_anchor = 0
- index = 0
- matched = []
- followings.each do |node|
- if tester.call(node)
- case operator
- when :==
- # anchor_indexes only contain values smaller or equal to `index`,
- # so value <= 0 case doesn't accidentally match any node.
- matched << node if anchor_indexes.include?(index - value + 1)
- when :<
- # Position from the last anchor will be the minimum possible position for the node
- matched << node if index - last_anchor + 1 < value
- when :>
- # Position from the first anchor(==0) will be the maximum possible position for the node
- matched << node if index + 1 > value
- end
- index += 1
- end
- if anchors.include?(node)
- anchor_indexes << index
- last_anchor = index
- end
- end
- matched
+ nodeset.group_by(&:parent).flat_map do |parent, sibling_nodes|
+ anchors = Set.new.compare_by_identity.replace(sibling_nodes)
+ children = parent.children
+ children = children.reverse if reverse
+ followings = children.drop_while {|child| !anchors.include?(child) }.drop(1)
+ events = [:push]
+ followings.each do |node|
+ events << node if tester.call(node)
+ events << :push if anchors.include?(node)
end
+ anchors.size.times { events << :pop }
+ sequence_positional_scan(events, selector)
end
end
@@ -506,7 +499,7 @@ def ancestor(nodeset, tester, selector, include_self: false)
end
ancestors.select(&tester)
else
- # Slow pass
+ # Slow path
nodesets = nodeset.map do |node|
ancestors = []
ancestors << node if include_self
@@ -537,12 +530,18 @@ def non_optimized_nodesets_select(nodesets, tester, selector)
operator, value = selector
nodes =
case operator
- when :==
- nodesets.map {|nodeset| nodeset[value - 1] if value >= 1 }.compact
- when :<
- nodesets.flat_map {|nodeset| nodeset[0...value - 1] if value >= 1 }.compact
- when :>
- nodesets.flat_map {|nodeset| value <= 0 ? nodeset : nodeset.drop(value) }
+ when :index_eq
+ nodesets.map {|nodeset| nodeset[value] if value >= 0 }.compact
+ when :index_lt
+ nodesets.flat_map {|nodeset| nodeset[0...value] if value >= 0 }.compact
+ when :index_gt
+ nodesets.flat_map {|nodeset| value < 0 ? nodeset : nodeset.drop(value + 1) }
+ when :reverse_index_eq
+ nodesets.map {|nodeset| nodeset[-(value + 1)] if value >= 0 }.compact
+ when :reverse_index_lt
+ nodesets.flat_map {|nodeset| nodeset.last(value) if value > 0 }.compact
+ when :reverse_index_gt
+ nodesets.flat_map {|nodeset| value < 0 ? nodeset : nodeset[0...-(value + 1)] }
end
seen = Set.new.compare_by_identity
nodes.each {|node| seen << node }
@@ -791,48 +790,142 @@ def descendant_or_self(nodeset, tester, selector)
# Scanner for descendant axis
def descendant(nodeset, tester, selector, include_self: false)
nodeset = nodeset.select {|node| node.respond_to?(:children) }
- case selector
- when :uniq
- seen = Set.new.compare_by_identity
- recursive = ->(node) do
- node_type = node.node_type
- return if seen.include?(node)
- seen << node if node_type != :xmldecl
- return unless node_type == :element || node_type == :document
- node.children.each do |child|
- recursive.call(child)
+ targets = Set.new.compare_by_identity.replace(nodeset)
+ descendant_anchor_roots(nodeset).flat_map do |root|
+ descendant_positional_scan(root, targets, tester, selector, include_self)
+ end
+ end
+
+ def descendant_anchor_roots(nodes)
+ seen = Set.new.compare_by_identity
+ nodes.each do |node|
+ descendant_traverse(node, include_self: false) do |n|
+ if !seen.include?(n)
+ seen << n
+ true
end
end
- nodeset.each do |node|
+ end
+ nodes.reject {|node| seen.include?(node) }
+ end
+
+ def descendant_positional_scan(root, targets, tester, selector, include_self)
+ events = []
+ descendant_traverse_event(root) do |type, node|
+ if type == :enter
if include_self
- recursive.call(node)
+ events << :push if targets.include?(node)
+ events << node if tester.call(node)
else
- node.children.each(&recursive)
+ events << node if !node.equal?(root) && tester.call(node)
+ events << :push if targets.include?(node)
end
+ elsif type == :leave
+ events << :pop if targets.include?(node)
end
- seen.select(&tester)
- else
- nodesets = nodeset.map do |node|
- new_nodeset = []
- new_nodes = {}
- descendant_recursive(node, new_nodeset, new_nodes, include_self)
- new_nodeset
+ end
+ sequence_positional_scan(events, selector)
+ end
+
+ # Select nodes matching the positional predicate from a sequence of events.
+ # Events are: :push, :pop, and node
+ # push/pop is an event that pushes/pops an anchor of position-based predicate
+ def sequence_positional_scan(events, selector)
+ case selector
+ when :uniq
+ return events.grep_v(Symbol)
+ when :nodesets
+ nodes = []
+ indexes = []
+ nodesets = []
+ last_range = nil
+ events.each do |e|
+ case e
+ when :push
+ indexes << nodes.size
+ when :pop
+ start_idx = indexes.pop
+ range = start_idx...nodes.size
+ nodesets << nodes[range] if last_range != range
+ last_range = range
+ else
+ nodes << e
+ end
end
- non_optimized_nodesets_select(nodesets, tester, selector)
+ return nodesets
end
+
+ operator, value = selector
+ reverse = operator == :reverse_index_eq || operator == :reverse_index_lt || operator == :reverse_index_gt
+ events = events.reverse_each if reverse
+ start_event = reverse ? :pop : :push
+ end_event = reverse ? :push : :pop
+ anchor_indexes = []
+ anchor_set = Set.new
+ node_index = 0
+ result = []
+ events.each do |event|
+ if event == start_event
+ anchor_indexes << node_index
+ anchor_set << node_index
+ elsif event == end_event
+ idx = anchor_indexes.pop
+ anchor_set.delete(idx) if anchor_indexes.last != idx
+ else # event is a node that passed the tester
+ case operator
+ when :index_eq, :reverse_index_eq
+ result << event if anchor_set.include?(node_index - value)
+ when :index_lt, :reverse_index_lt
+ result << event if node_index - anchor_indexes.last < value
+ when :index_gt, :reverse_index_gt
+ result << event if node_index > value
+ end
+ node_index += 1
+ end
+ end
+ result
end
- def descendant_recursive(node, new_nodeset, new_nodes, include_self)
- if include_self
- return if new_nodes.key?(node)
- new_nodeset << node
- new_nodes[node] = true
+ # Scans the node and its descendants in document order
+ # and dispatches two types of events: :enter and :leave for each node.
+ def descendant_traverse_event(node)
+ stack = [node]
+ until stack.empty?
+ if stack.last
+ node = stack.last
+ # Push nil as a mark that we are entering this node.
+ # If we see this mark again, we know to pop the node and yield :leave
+ stack << nil
+ yield :enter, node
+ node_type = node.node_type
+ if node_type == :element or node_type == :document
+ node.children.reverse_each do |child|
+ stack << child if child.node_type != :xmldecl
+ end
+ end
+ else
+ stack.pop
+ node = stack.pop
+ yield :leave, node
+ end
end
+ end
+
+ # Scans the descendants of a node in document order and yields each node to the block.
+ # If a block returns falsy value, the node's descendants are not traversed.
+ def descendant_traverse(anchor_node, include_self:)
+ stack = [anchor_node]
+ until stack.empty?
+ node = stack.pop
+ if include_self || !node.equal?(anchor_node)
+ next unless yield node
+ end
- node_type = node.node_type
- if node_type == :element or node_type == :document
- node.children.each do |child|
- descendant_recursive(child, new_nodeset, new_nodes, true)
+ node_type = node.node_type
+ if node_type == :element or node_type == :document
+ node.children.reverse_each do |child|
+ stack << child if child.node_type != :xmldecl
+ end
end
end
end
@@ -885,36 +978,14 @@ def preceding_node_of( node )
# Scanner for following axis
def following(nodeset, tester, selector)
- nodesets = nodeset.select {|node| node.respond_to?(:parent) }.map do |node|
- following_nodes(node)
+ anchors = Set.new.compare_by_identity.replace(nodeset)
+ events = []
+ descendant_traverse_event(nodeset.first.document || nodeset.first.root) do |type, node|
+ events << :push if type == :leave && anchors.include?(node)
+ events << node if !events.empty? && type == :enter && tester.call(node)
end
- non_optimized_nodesets_select(nodesets, tester, selector)
- end
-
- def following_nodes(node)
- followings = []
- following_node = next_sibling_node(node)
- while following_node
- followings << following_node
- following_node = following_node_of(following_node)
- end
- followings
- end
-
- def following_node_of( node )
- return node.children[0] if node.kind_of?(Element) and node.children.size > 0
-
- next_sibling_node(node)
- end
-
- def next_sibling_node(node)
- psn = node.next_sibling_node
- while psn.nil?
- return nil if node.parent.nil? or node.parent.class == Document
- node = node.parent
- psn = node.next_sibling_node
- end
- psn
+ anchors.size.times { events << :pop }
+ sequence_positional_scan(events, selector)
end
def child(nodeset)
diff --git a/test/xpath/test_axis_preceding_sibling.rb b/test/xpath/test_axis_preceding_sibling.rb
index 0e1ddff2..92af633f 100644
--- a/test/xpath/test_axis_preceding_sibling.rb
+++ b/test/xpath/test_axis_preceding_sibling.rb
@@ -98,8 +98,12 @@ def test_preceding_following_sibling_multiple_anchors
XML
+ assert_equal(%w[1 3 4 5 7 9 11], XPath.match(doc, "/a/anchor/preceding-sibling::b[@id][position() mod 4 = 1]").map {|n| n.attributes["id"] })
+ assert_equal(%w[5 9 10 12], XPath.match(doc, "/a/anchor/following-sibling::b[@id][position() mod 4 = 1]").map {|n| n.attributes["id"] })
+
assert_equal(%w[2 7 9], XPath.match(doc, "/a/anchor/preceding-sibling::b[position() = 3]").map {|n| n.attributes["id"] })
assert_equal(%w[2 3 4 7 8 9 10 11], XPath.match(doc, "/a/anchor/preceding-sibling::b[position() <= 3]").map {|n| n.attributes["id"] })
+ assert_equal(%w[2 3 4 7 8 9 10 11], XPath.match(doc, "/a/anchor/preceding-sibling::b[4 > position()]").map {|n| n.attributes["id"] })
assert_equal(%w[1 2 3 4 5 6 7 8], XPath.match(doc, "/a/anchor/preceding-sibling::b[position() >= 4]").map {|n| n.attributes["id"] })
assert_equal(%w[2 7 a2], XPath.match(doc, "/a/anchor/preceding-sibling::*[@id][position() = 3]").map {|n| n.attributes["id"] })
assert_equal(%w[2 3 4 7 8 9 a2 10 11], XPath.match(doc, "/a/anchor/preceding-sibling::*[@id][position() <= 3]").map {|n| n.attributes["id"] })
@@ -108,9 +112,22 @@ def test_preceding_following_sibling_multiple_anchors
assert_equal(%w[7 12], XPath.match(doc, "/a/anchor/following-sibling::b[position() = 3]").map {|n| n.attributes["id"] })
assert_equal(%w[5 6 7 10 11 12], XPath.match(doc, "/a/anchor/following-sibling::b[position() <= 3]").map {|n| n.attributes["id"] })
assert_equal(%w[8 9 10 11 12], XPath.match(doc, "/a/anchor/following-sibling::b[position() >= 4]").map {|n| n.attributes["id"] })
+ assert_equal(%w[8 9 10 11 12], XPath.match(doc, "/a/anchor/following-sibling::b[3 < position()]").map {|n| n.attributes["id"] })
assert_equal(%w[7 a3], XPath.match(doc, "/a/anchor/following-sibling::*[@id][position() = 3]").map {|n| n.attributes["id"] })
assert_equal(%w[5 6 7 10 11 a3 12], XPath.match(doc, "/a/anchor/following-sibling::*[@id][position() <= 3]").map {|n| n.attributes["id"] })
assert_equal(%w[8 9 a2 10 11 a3 12], XPath.match(doc, "/a/anchor/following-sibling::*[@id][position() >= 4]").map {|n| n.attributes["id"] })
+
+ assert_equal(%w[1], XPath.match(doc, "/a/anchor/preceding-sibling::b[last()]").map {|n| n.attributes["id"] })
+ assert_equal(%w[4], XPath.match(doc, "/a/anchor/preceding-sibling::b[last() - 3]").map {|n| n.attributes["id"] })
+ assert_equal(%w[4 5 6 7 8 9 10 11], XPath.match(doc, "/a/anchor/preceding-sibling::b[position() <= last() - 3]").map {|n| n.attributes["id"] })
+ assert_equal(%w[4 5 6 7 8 9 10 11], XPath.match(doc, "/a/anchor/preceding-sibling::b[last() - 2 > position()]").map {|n| n.attributes["id"] })
+ assert_equal(%w[1 2 3 4 5], XPath.match(doc, "/a/anchor/preceding-sibling::b[position() >= last() - 4]").map {|n| n.attributes["id"] })
+
+ assert_equal(%w[12], XPath.match(doc, "/a/anchor/following-sibling::b[last()]").map {|n| n.attributes["id"] })
+ assert_equal(%w[9], XPath.match(doc, "/a/anchor/following-sibling::b[last() - 3]").map {|n| n.attributes["id"] })
+ assert_equal(%w[5 6 7 8 9], XPath.match(doc, "/a/anchor/following-sibling::b[position() <= last() - 3]").map {|n| n.attributes["id"] })
+ assert_equal(%w[8 9 10 11 12], XPath.match(doc, "/a/anchor/following-sibling::b[position() >= last() - 4]").map {|n| n.attributes["id"] })
+ assert_equal(%w[8 9 10 11 12], XPath.match(doc, "/a/anchor/following-sibling::b[last() - 5 < position()]").map {|n| n.attributes["id"] })
end
end
end
diff --git a/test/xpath/test_base.rb b/test/xpath/test_base.rb
index 911fda12..1d9dc67e 100644
--- a/test/xpath/test_base.rb
+++ b/test/xpath/test_base.rb
@@ -93,6 +93,56 @@ def test_descendant
assert_equal( 1, XPath.match( doc, "/descendant::z[1]" ).size )
end
+ def test_descendant_positions
+ positions = ['[1]', '[position() <= 2]', '[last()]', '[position() mod 3 = 1]']
+ anchors = {
+ '//a[@id=1]' => [%w[2], %w[2 3], %w[18], %w[2 5 8 10 13 16]],
+ '//b[@id=2]' => [%w[3], %w[3 4], %w[4], %w[3]],
+ '//e[@id=9]' => [%w[10], %w[10 11], %w[14], %w[10 13]],
+ '//f[@id=12]' => [%w[13], %w[13], %w[13], %w[13]],
+ }
+ anchors.each do |anchor_xpath, expecteds|
+ positions.zip(expecteds).each do |position, expected|
+ xpath = "#{anchor_xpath}/descendant::*#{position}"
+ assert_equal(expected, XPath.match(@@doc, xpath).map {|n| n['id'] }, xpath)
+ end
+ end
+ expecteds = anchors.values.transpose.map {|ids| ids.flatten.uniq.sort_by(&:to_i) }
+ positions.zip(expecteds).each do |position, expected|
+ xpath = "(#{anchors.keys.join('|')})/descendant::*#{position}"
+ assert_equal(expected, XPath.match(@@doc, xpath).map {|n| n['id'] }, xpath)
+ end
+ end
+
+ def test_descendant_or_self_positions
+ positions = ['[1]', '[position() <= 2]', '[last()]', '[position() mod 3 = 1]']
+ anchors = {
+ '//a[@id=1]' => [%w[1], %w[1 2], %w[18], %w[1 4 7 9 12 15 18]],
+ '//b[@id=2]' => [%w[2], %w[2 3], %w[4], %w[2]],
+ '//e[@id=9]' => [%w[9], %w[9 10], %w[14], %w[9 12]],
+ '//f[@id=12]' => [%w[12], %w[12 13], %w[13], %w[12]],
+ }
+ anchors.each do |anchor_xpath, expecteds|
+ positions.zip(expecteds).each do |position, expected|
+ xpath = "#{anchor_xpath}/descendant-or-self::*#{position}"
+ assert_equal(expected, XPath.match(@@doc, xpath).map {|n| n['id'] }, xpath)
+ end
+ end
+ expecteds = anchors.values.transpose.map {|ids| ids.flatten.uniq.sort_by(&:to_i) }
+ positions.zip(expecteds).each do |position, expected|
+ xpath = "(#{anchors.keys.join('|')})/descendant-or-self::*#{position}"
+ assert_equal(expected, XPath.match(@@doc, xpath).map {|n| n['id'] }, xpath)
+ end
+ end
+
+ def test_following_positions
+ anchor_xpath = '(//b[@id=2]|//e[@id=9]|//f[@id=12])'
+ assert_equal(%w[5 14 15], XPath.match(@@doc, "#{anchor_xpath}/following::*[1]").map {|n| n['id'] })
+ assert_equal(%w[5 6 14 15 16], XPath.match(@@doc, "#{anchor_xpath}/following::*[position()<=2]").map {|n| n['id'] })
+ assert_equal(%w[18], XPath.match(@@doc, "#{anchor_xpath}/following::*[last()]").map {|n| n['id'] })
+ assert_equal(%w[5 8 10 13 14 15 16 17 18], XPath.match(@@doc, "#{anchor_xpath}/following::*[position() mod 3 = 1]").map {|n| n['id'] })
+ end
+
def test_root
source = ""
doc = Document.new( source )
diff --git a/test/xpath/test_predicate.rb b/test/xpath/test_predicate.rb
index 164e3e96..7cafc49c 100644
--- a/test/xpath/test_predicate.rb
+++ b/test/xpath/test_predicate.rb
@@ -116,6 +116,9 @@ def test_predicate_out_of_range_position
assert_equal(%w[a b c d], parser.parse("#{base}[position()>0]", doc).map(&:name))
assert_equal(%w[a b c d], parser.parse("#{base}[position()>-1]", doc).map(&:name))
assert_equal(%w[a b c d], parser.parse("#{base}[position()<10]", doc).map(&:name))
+ assert_equal(%w[], parser.parse("#{base}[position()last()-10]", doc).map(&:name))
+ assert_equal(%w[], parser.parse("#{base}[last()-10]", doc).map(&:name))
# non-optimizable case
base_no_opt = '/r/*[position()!=name()]'
@@ -128,6 +131,9 @@ def test_predicate_out_of_range_position
assert_equal(%w[a b c d], parser.parse("#{base_no_opt}[position()>0]", doc).map(&:name))
assert_equal(%w[a b c d], parser.parse("#{base_no_opt}[position()>-1]", doc).map(&:name))
assert_equal(%w[a b c d], parser.parse("#{base_no_opt}[position()<10]", doc).map(&:name))
+ assert_equal(%w[], parser.parse("#{base_no_opt}[position()last()-10]", doc).map(&:name))
+ assert_equal(%w[], parser.parse("#{base_no_opt}[last()-10]", doc).map(&:name))
end
def test_predicate_parenthesized_position