@@ -69,9 +69,11 @@ def get(self, a: StateAugmentation, node) -> ndarray:
6969 node_prop = a .observation ['discovered_nodes_properties' ]
7070
7171 # list of all properties set/unset on the node
72- # Remap to get rid of unknown value 0: 1 -> 1, and -1 -> 0 (and 0-> 0)
7372 assert node < len (node_prop ), f'invalid node index { node } (not discovered yet)'
74- remapped = np .array ((1 + node_prop [node ]) / 2 , dtype = np .int_ )
73+
74+ # Remap to get rid of the unknown value (2):
75+ # 1->1, 0->0, 2->0
76+ remapped = np .array (node_prop [node ] % 2 , dtype = np .int_ )
7577 return remapped
7678
7779
@@ -85,7 +87,7 @@ def __init__(self, p: EnvironmentBounds):
8587 def get (self , a : StateAugmentation , node ) -> ndarray :
8688 assert node is not None , 'feature only valid in the context of a node'
8789
88- discovered_node_count = len ( a .observation ['discovered_nodes_properties' ])
90+ discovered_node_count = a .observation ['discovered_node_count' ]
8991
9092 assert node < discovered_node_count , f'invalid node index { node } (not discovered yet)'
9193
@@ -110,13 +112,14 @@ def __init__(self, p: EnvironmentBounds):
110112 super ().__init__ (p , [2 ] * p .property_count )
111113
112114 def get (self , a : StateAugmentation , node ) -> ndarray :
113- node_prop = np .array (a .observation ['discovered_nodes_properties' ])
115+ n = a .observation ['discovered_node_count' ]
116+ node_prop = np .array (a .observation ['discovered_nodes_properties' ])[:n ]
114117
115118 # keep last window of entries
116119 node_prop_window = node_prop [- self .window_size :, :]
117120
118- # Remap to get rid of unknown value 0: 1 -> 1, and -1 -> 0 (and 0-> 0 )
119- node_prop_window_remapped = np .int32 (( 1 + node_prop_window ) / 2 )
121+ # Remap to get rid of the unknown value (2 )
122+ node_prop_window_remapped = np .int32 (node_prop_window % 2 )
120123
121124 countby = np .sum (node_prop_window_remapped , axis = 0 )
122125
@@ -131,9 +134,11 @@ def __init__(self, p: EnvironmentBounds):
131134 super ().__init__ (p , [2 ] * p .port_count )
132135
133136 def get (self , a : StateAugmentation , node ):
134- ccm = a .observation ['credential_cache_matrix ' ]
137+ n = a .observation ['credential_cache_length ' ]
135138 known_credports = np .zeros (self .env_properties .port_count , dtype = np .int32 )
136- known_credports [np .int32 (ccm [:, 1 ])] = 1
139+ if n > 0 :
140+ ccm = np .array (a .observation ['credential_cache_matrix' ])[:n ]
141+ known_credports [np .int32 (ccm [:, 1 ])] = 1
137142 return known_credports
138143
139144
@@ -145,9 +150,11 @@ def __init__(self, p: EnvironmentBounds):
145150 super ().__init__ (p , [2 ] * p .port_count )
146151
147152 def get (self , a : StateAugmentation , node ):
148- ccm = a .observation ['credential_cache_matrix' ]
149153 known_credports = np .zeros (self .env_properties .port_count , dtype = np .int32 )
150- known_credports [np .int32 (ccm [- self .window_size :, 1 ])] = 1
154+ n = a .observation ['credential_cache_length' ]
155+ if n > 0 :
156+ ccm = np .array (a .observation ['credential_cache_matrix' ])[:n ]
157+ known_credports [np .int32 (ccm [- self .window_size :, 1 ])] = 1
151158 return known_credports
152159
153160
@@ -158,8 +165,13 @@ def __init__(self, p: EnvironmentBounds):
158165 super ().__init__ (p , [p .maximum_total_credentials + 1 ] * p .port_count )
159166
160167 def get (self , a : StateAugmentation , node ):
161- ccm = a .observation ['credential_cache_matrix' ]
162- return np .bincount (np .int32 (ccm [:, 1 ]), minlength = self .env_properties .port_count )
168+ n = a .observation ['credential_cache_length' ]
169+ if n > 0 :
170+ ccm = np .array (a .observation ['credential_cache_matrix' ])[:n ]
171+ ports = np .int32 (ccm [:, 1 ])
172+ else :
173+ ports = np .zeros (0 )
174+ return np .bincount (ports , minlength = self .env_properties .port_count )
163175
164176
165177class Feature_discovered_credential_count (Feature ):
@@ -169,7 +181,8 @@ def __init__(self, p: EnvironmentBounds):
169181 super ().__init__ (p , [p .maximum_total_credentials + 1 ])
170182
171183 def get (self , a : StateAugmentation , node ):
172- return [len (a .observation ['credential_cache_matrix' ])]
184+ n = a .observation ['credential_cache_length' ]
185+ return [n ]
173186
174187
175188class Feature_discovered_node_count (Feature ):
@@ -179,7 +192,7 @@ def __init__(self, p: EnvironmentBounds):
179192 super ().__init__ (p , [p .maximum_node_count + 1 ])
180193
181194 def get (self , a : StateAugmentation , node ):
182- return [len ( a .observation ['discovered_nodes_properties' ]) ]
195+ return [a .observation ['discovered_node_count' ] ]
183196
184197
185198class Feature_discovered_notowned_node_count (Feature ):
@@ -190,10 +203,10 @@ def __init__(self, p: EnvironmentBounds, clip: Optional[int]):
190203 super ().__init__ (p , [self .clip + 1 ])
191204
192205 def get (self , a : StateAugmentation , node ):
193- node_props = a .observation ['discovered_nodes_properties ' ]
194- discovered = len ( node_props )
206+ discovered = a .observation ['discovered_node_count ' ]
207+ node_props = np . array ( a . observation [ 'discovered_nodes_properties' ][: discovered ] )
195208 # here we assume that a node is owned just if all its properties are known
196- owned = np .count_nonzero (np .all (node_props != 0 , axis = 1 ))
209+ owned = np .count_nonzero (np .all (node_props != 2 , axis = 1 ))
197210 diff = discovered - owned
198211 return [min (diff , self .clip )]
199212
@@ -355,7 +368,7 @@ def specialize_to_gymaction(self, source_node: np.int32, observation, abstract_a
355368
356369 abstract_action_index_int = int (abstract_action_index )
357370
358- node_prop = np . array ( observation ['discovered_nodes_properties' ])
371+ discovered_nodes_count = observation ['discovered_node_count' ]
359372
360373 if abstract_action_index_int < self .n_local_actions :
361374 vuln = abstract_action_index_int
@@ -365,8 +378,6 @@ def specialize_to_gymaction(self, source_node: np.int32, observation, abstract_a
365378 if abstract_action_index_int < self .n_remote_actions :
366379 vuln = abstract_action_index_int
367380
368- discovered_nodes_count = len (node_prop )
369-
370381 if discovered_nodes_count <= 1 :
371382 return None
372383
@@ -382,11 +393,11 @@ def specialize_to_gymaction(self, source_node: np.int32, observation, abstract_a
382393 abstract_action_index_int -= self .n_remote_actions
383394 port = np .int32 (abstract_action_index_int )
384395
385- discovered_credentials = np .array (observation ['credential_cache_matrix' ])
386- n_discovered_creds = len (discovered_credentials )
396+ n_discovered_creds = observation ['credential_cache_length' ]
387397 if n_discovered_creds <= 0 :
388398 # no credential available in the cache: cannot poduce a valid connect action
389399 return None
400+ discovered_credentials = np .array (observation ['credential_cache_matrix' ])[:n_discovered_creds ]
390401
391402 nodes_not_owned = discovered_nodes_notowned (observation )
392403
0 commit comments