2020import re
2121import sys
2222
23+ import tree_sitter_cpp
24+ from tree_sitter import Language , Parser , Query , QueryCursor
25+
26+
27+ _TARGET_CLASS = "ReactNativeFeatureFlagsDefaults"
28+
29+
30+ def _method_query (names : set [str ]) -> str :
31+ alternation = "|" .join (re .escape (n ) for n in sorted (names ))
32+ return f"""
33+ (class_specifier
34+ name: (type_identifier) @class_name
35+ body: (field_declaration_list
36+ (function_definition
37+ declarator: (function_declarator
38+ declarator: (field_identifier) @method_name)
39+ body: (compound_statement
40+ (return_statement (_) @return_value)))
41+ )
42+ (#eq? @class_name "{ _TARGET_CLASS } ")
43+ (#match? @method_name "^({ alternation } )$")
44+ )
45+ """
46+
2347
24- def cxx_literal (value : object ) -> str :
48+ def cxx_literal (value : bool | int | float ) -> str :
2549 if isinstance (value , bool ):
2650 return "true" if value else "false"
2751 if isinstance (value , (int , float )):
@@ -33,33 +57,41 @@ def cxx_literal(value: object) -> str:
3357
3458
3559def rewrite (source : bytes , overrides : dict [str , object ]) -> bytes :
36- text = source .decode ("utf-8" )
37- for name , value in overrides .items ():
38- cxx_type = "bool" if isinstance (value , bool ) else "double"
39- pattern = rf"""
40- ( # group 1: everything up to the value
41- { cxx_type } \s+ # return type
42- { re .escape (name )} # method name
43- \s* \( \s* \) # parameter list
44- \s+ override # override specifier
45- \s* \{{ # opening brace
46- [^}}]*? # body before the return (non-greedy, no nested braces)
47- return \s+ # return keyword
60+ lang = Language (tree_sitter_cpp .language ())
61+ tree = Parser (lang ).parse (source )
62+ matches = QueryCursor (Query (lang , _method_query (overrides .keys ()))).matches (
63+ tree .root_node
64+ )
65+
66+ matched : set [str ] = set ()
67+ replacements : list [tuple [int , int , bytes ]] = []
68+
69+ for _ , match in matches :
70+ method_node = match ["method_name" ][0 ]
71+ name = source [method_node .start_byte : method_node .end_byte ].decode ("utf-8" )
72+ rv_node = match ["return_value" ][0 ]
73+ replacements .append (
74+ (
75+ rv_node .start_byte ,
76+ rv_node .end_byte ,
77+ cxx_literal (overrides [name ]).encode ("utf-8" ),
4878 )
49- [^;]+ # the value to replace
50- ( \s* ; ) # group 2: semicolon
51- """
52- text , n = re .subn (
53- pattern ,
54- rf"\g<1>{ cxx_literal (value )} \2" ,
55- text ,
56- count = 1 ,
57- flags = re .DOTALL | re .VERBOSE ,
5879 )
59- if n != 1 :
60- raise ValueError (f"{ name } not matched" )
80+ matched .add (name )
81+
82+ unmatched = set (overrides .keys ()) - matched
83+ if unmatched :
84+ raise ValueError (f"Unmatched flags: { ', ' .join (sorted (unmatched ))} " )
85+
86+ result = bytearray ()
87+ pos = 0
88+ for start , end , replacement in replacements :
89+ result .extend (source [pos :start ])
90+ result .extend (replacement )
91+ pos = end
92+ result .extend (source [pos :])
6193
62- return text . encode ( "utf-8" )
94+ return bytes ( result )
6395
6496
6597def main () -> None :
0 commit comments