Skip to content

Commit 4fd228b

Browse files
blumuWilliam Blum
andauthored
Support for stable-baseline3 RL algorithms (#86)
This PR implements changes to the CyberBattleEnv's observation and action space to be consumable by the gym RL algorithms from https://stable-baselines3.readthedocs.io/en/master/ Changes include: - Observation fields can now be optionally padded to the shape expected by their corresponding gym space. (Requires more memory but is needed to train with stable-baseline agents) - Add gym wrappers to flatten the Action and Observation spaces from CyberBattleSim - Add option to CyberBattleEnv to allow invalid moves and return negative reward instead - Flatten multi-dimensioned `MultiBinary` spaces * works with spaces.MultiBinary([list]) and spaces.MultiBinary(number) * working with `nodes_privilegelevel` * works with `leaked_credentials` * works with `credential_cache_matrix` * works with `discovered_nodes_properties` - Add a `stable-baseline` test notebook - Fix some python 3.8 warnings Co-authored-by: William Blum <william.blum@microsoft.com>
1 parent 9b4d294 commit 4fd228b

12 files changed

Lines changed: 452 additions & 142 deletions

createstubs.sh

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ createstub() {
1818
echo stub $name already created
1919
fi
2020
}
21+
param=$1
22+
if [[ $param == "--recreate" ]]; then
23+
echo 'Deleting typing directory'
24+
rm -Rf typings/
25+
fi
26+
27+
echo 'Creating stubs'
2128

2229
mkdir -p typings/
2330

@@ -30,23 +37,21 @@ createstub ordered_set
3037
createstub asciichartpy
3138
createstub networkx
3239
createstub boolean
40+
createstub IPython
3341

3442

3543
if [ ! -d "typings/gym" ]; then
3644
pyright --createstub gym
3745
# Patch gym stubs
3846
echo ' spaces = ...' >> typings/gym/spaces/dict.pyi
3947
echo ' nvec = ...' >> typings/gym/spaces/space.pyi
48+
echo ' spaces = ...' >> typings/gym/spaces/space.pyi
49+
echo ' spaces = ...' >> typings/gym/spaces/tuple.pyi
50+
echo ' n = ...' >> typings/gym/spaces/multi_binary.pyi
4051
else
4152
echo stub gym already created
4253
fi
4354

44-
if [ ! -d "typings/IPython" ]; then
45-
pyright --createstub IPython.core.display
46-
else
47-
echo stub 'IPython' already created
48-
fi
49-
5055

5156
echo 'Typing stub generation completed'
5257

cyberbattle/_env/cyberbattle_env.py

Lines changed: 112 additions & 82 deletions
Large diffs are not rendered by default.
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""Space flattening wrappers for the CyberBattleEnv gym environment.
2+
"""
3+
from collections import OrderedDict
4+
from sqlite3 import NotSupportedError
5+
from gym import spaces
6+
import numpy as np
7+
from cyberbattle._env.cyberbattle_env import DummySpace, CyberBattleEnv, Action
8+
from gym.core import ObservationWrapper, ActionWrapper
9+
10+
11+
class FlattenObservationWrapper(ObservationWrapper):
    """
    Flatten all nested dictionaries and tuples from the
    observation space of a CyberBattleSim environment `CyberBattleEnv`.
    The resulting observation space is a dictionary containing only
    subspaces of types: `Discrete`, `MultiBinary`, and `MultiDiscrete`.
    """

    def flatten_multibinary_space(self, space: spaces.Space) -> spaces.Space:
        """Return a 1-D `MultiBinary` space equivalent to `space` when it is a
        multi-dimensional `MultiBinary`; return `space` unchanged otherwise."""
        if isinstance(space, spaces.MultiBinary):
            # space.n is either a scalar (already flat) or a shape tuple/list/array
            if type(space.n) in [tuple, list, np.ndarray]:
                flatten_dim = np.multiply.reduce(space.n)
                print(f'// MultiBinary flattened from {space.n} -> {flatten_dim}')
                return spaces.MultiBinary(flatten_dim)
            else:
                print(f'// MultiBinary already flat: {space.n}')
                return space
        else:
            return space

    def __init__(self, env: CyberBattleEnv, ignore_fields=None):
        """Wrap `env` and expose a flattened observation space.

        env: the CyberBattleEnv to wrap.
        ignore_fields: observation keys to drop from the flattened space
            (defaults to ['action_mask']).
        """
        ObservationWrapper.__init__(self, env)
        self.env = env
        # Fix: avoid a mutable default argument; materialize the default here.
        self.ignore_fields = ['action_mask'] if ignore_fields is None else ignore_fields

        space_dict = OrderedDict({})
        for key, space in env.observation_space.spaces.items():
            if key in self.ignore_fields:
                print('Filtering out field', key)
            elif isinstance(space, spaces.Dict):
                # one flattened entry per sub-key of the nested Dict
                for k2, subspace in space.items():
                    space_dict[f"{key}_{k2}"] = self.flatten_multibinary_space(subspace)
            elif isinstance(space, spaces.Tuple):
                # one flattened entry per position in the Tuple
                for i, subspace in enumerate(space.spaces):
                    space_dict[f"{key}_{i}"] = self.flatten_multibinary_space(subspace)
            elif isinstance(space, spaces.MultiBinary):
                space_dict[key] = self.flatten_multibinary_space(space)
            elif isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)):
                space_dict[key] = space
            elif isinstance(space, DummySpace):
                print(f'warning: unsupported observation space: {space} : {type(space)}')
            else:
                raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")

        self.observation_space = spaces.Dict(space_dict)

    def flatten_multibinary_observation(self, space, o):
        """Flatten observation `o` to match `flatten_multibinary_space(space)`:
        reshape multi-dimensional MultiBinary values into a flat tuple,
        pass everything else through unchanged."""
        if isinstance(space, spaces.MultiBinary) and \
                type(space.n) in [tuple, list, np.ndarray] and \
                len(space.n) > 1:
            flatten_dim = np.multiply.reduce(space.n)
            return tuple(o.reshape(flatten_dim))
        else:
            return o

    def observation(self, observation: dict):
        """Map a raw CyberBattleEnv observation dict onto the flattened space
        declared in `__init__` (same keys, nested entries expanded)."""
        o = OrderedDict({})
        for key, space in self.env.observation_space.spaces.items():
            value = observation[key]
            if key in self.ignore_fields:
                continue
            elif isinstance(space, spaces.Dict):
                for subkey, subspace in space.items():
                    o[f"{key}_{subkey}"] = self.flatten_multibinary_observation(subspace, value[subkey])
            elif isinstance(space, spaces.Tuple):
                for i, subspace in enumerate(space.spaces):
                    o[f"{key}_{i}"] = self.flatten_multibinary_observation(subspace, value[i])
            elif isinstance(space, spaces.MultiBinary):
                o[key] = self.flatten_multibinary_observation(space, value)
            elif isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)):
                o[key] = value
            elif isinstance(space, DummySpace):
                # unsupported spaces were excluded from the flattened space; skip their values too
                continue
            else:
                raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")

        return o
89+
90+
class FlattenActionWrapper(ActionWrapper):
    """
    Flatten all nested dictionaries and tuples from the
    action space of a CyberBattleSim environment `CyberBattleEnv`.
    The resulting action space is a single `MultiDiscrete` vector:
    [action kind, source node, target node, port, credential index].
    """

    def __init__(self, env: CyberBattleEnv):
        ActionWrapper.__init__(self, env)
        self.env = env

        self.action_space = spaces.MultiDiscrete([
            # connect, local vulnerabilities, remote vulnerabilities
            1 + env.bounds.local_attacks_count + env.bounds.remote_attacks_count,

            # source node
            env.bounds.maximum_node_count,

            # target node
            env.bounds.maximum_node_count,

            # target port (for connect action only)
            env.bounds.port_count,

            # credential used (for connect action only)
            env.bounds.maximum_total_credentials
        ]
        )

    def action(self, action: np.ndarray) -> Action:
        """Translate a flat MultiDiscrete action vector into a CyberBattleEnv
        action dict.

        Kind 0 is 'connect'; the next `local_attacks_count` kinds are local
        vulnerabilities; the following `remote_attacks_count` kinds are remote
        vulnerabilities.

        Raises ValueError if the action kind is out of range.
        """
        action_type = action[0]
        if action_type == 0:
            # connect takes (source node, target node, port, credential index)
            return {'connect': action[1:5]}

        action_type -= 1
        if action_type < self.env.bounds.local_attacks_count:
            return {'local_vulnerability': np.array([action[1], action_type])}

        action_type -= self.env.bounds.local_attacks_count
        if action_type < self.env.bounds.remote_attacks_count:
            return {'remote_vulnerability': np.array([action[1], action[2], action_type])}

        # Fix: previously raised sqlite3.NotSupportedError — an accidental
        # auto-import from the DB-API with no relation to gym actions.
        # An out-of-range action vector is a plain ValueError.
        raise ValueError(f'Unsupported action: {action}')

    def reverse_action(self, action):
        """Inverse mapping (CyberBattleEnv action dict -> flat vector) is not implemented."""
        raise NotImplementedError

cyberbattle/_env/graph_wrapper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,15 +182,15 @@ def __add_node(self, observation):
182182
creds = onp.full(self._bounds.maximum_total_credentials, -1, dtype=onp.int8)
183183
self.__graph.add_node(
184184
node_index,
185-
name=observation['discovered_nodes'][node_index],
185+
name=observation['_discovered_nodes'][node_index],
186186
privilege_level=None, flags=None, # these are set by __update_nodes()
187187
credentials=creds,
188188
has_leaked_creds=False,
189189
)
190190

191191
def __update_edges(self, observation):
192-
g_orig = observation['explored_network']
193-
node_ids = {n: i for i, n in enumerate(observation['discovered_nodes'])}
192+
g_orig = observation['_explored_network']
193+
node_ids = {n: i for i, n in enumerate(observation['_discovered_nodes'])}
194194
for (from_name, to_name), edge_properties in g_orig.edges.items():
195195
self.__graph.add_edge(node_ids[from_name], node_ids[to_name], **edge_properties)
196196

cyberbattle/agents/baseline/agent_wrapper.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,11 @@ def get(self, a: StateAugmentation, node) -> ndarray:
6969
node_prop = a.observation['discovered_nodes_properties']
7070

7171
# list of all properties set/unset on the node
72-
# Remap to get rid of unknown value 0: 1 -> 1, and -1 -> 0 (and 0-> 0)
7372
assert node < len(node_prop), f'invalid node index {node} (not discovered yet)'
74-
remapped = np.array((1 + node_prop[node]) / 2, dtype=np.int_)
73+
74+
# Remap to get rid of the unknown value (2):
75+
# 1->1, 0->0, 2->0
76+
remapped = np.array(node_prop[node] % 2, dtype=np.int_)
7577
return remapped
7678

7779

@@ -85,7 +87,7 @@ def __init__(self, p: EnvironmentBounds):
8587
def get(self, a: StateAugmentation, node) -> ndarray:
8688
assert node is not None, 'feature only valid in the context of a node'
8789

88-
discovered_node_count = len(a.observation['discovered_nodes_properties'])
90+
discovered_node_count = a.observation['discovered_node_count']
8991

9092
assert node < discovered_node_count, f'invalid node index {node} (not discovered yet)'
9193

@@ -110,13 +112,14 @@ def __init__(self, p: EnvironmentBounds):
110112
super().__init__(p, [2] * p.property_count)
111113

112114
def get(self, a: StateAugmentation, node) -> ndarray:
113-
node_prop = np.array(a.observation['discovered_nodes_properties'])
115+
n = a.observation['discovered_node_count']
116+
node_prop = np.array(a.observation['discovered_nodes_properties'])[:n]
114117

115118
# keep last window of entries
116119
node_prop_window = node_prop[-self.window_size:, :]
117120

118-
# Remap to get rid of unknown value 0: 1 -> 1, and -1 -> 0 (and 0-> 0)
119-
node_prop_window_remapped = np.int32((1 + node_prop_window) / 2)
121+
# Remap to get rid of the unknown value (2)
122+
node_prop_window_remapped = np.int32(node_prop_window % 2)
120123

121124
countby = np.sum(node_prop_window_remapped, axis=0)
122125

@@ -131,9 +134,11 @@ def __init__(self, p: EnvironmentBounds):
131134
super().__init__(p, [2] * p.port_count)
132135

133136
def get(self, a: StateAugmentation, node):
134-
ccm = a.observation['credential_cache_matrix']
137+
n = a.observation['credential_cache_length']
135138
known_credports = np.zeros(self.env_properties.port_count, dtype=np.int32)
136-
known_credports[np.int32(ccm[:, 1])] = 1
139+
if n > 0:
140+
ccm = np.array(a.observation['credential_cache_matrix'])[:n]
141+
known_credports[np.int32(ccm[:, 1])] = 1
137142
return known_credports
138143

139144

@@ -145,9 +150,11 @@ def __init__(self, p: EnvironmentBounds):
145150
super().__init__(p, [2] * p.port_count)
146151

147152
def get(self, a: StateAugmentation, node):
148-
ccm = a.observation['credential_cache_matrix']
149153
known_credports = np.zeros(self.env_properties.port_count, dtype=np.int32)
150-
known_credports[np.int32(ccm[-self.window_size:, 1])] = 1
154+
n = a.observation['credential_cache_length']
155+
if n > 0:
156+
ccm = np.array(a.observation['credential_cache_matrix'])[:n]
157+
known_credports[np.int32(ccm[-self.window_size:, 1])] = 1
151158
return known_credports
152159

153160

@@ -158,8 +165,13 @@ def __init__(self, p: EnvironmentBounds):
158165
super().__init__(p, [p.maximum_total_credentials + 1] * p.port_count)
159166

160167
def get(self, a: StateAugmentation, node):
161-
ccm = a.observation['credential_cache_matrix']
162-
return np.bincount(np.int32(ccm[:, 1]), minlength=self.env_properties.port_count)
168+
n = a.observation['credential_cache_length']
169+
if n > 0:
170+
ccm = np.array(a.observation['credential_cache_matrix'])[:n]
171+
ports = np.int32(ccm[:, 1])
172+
else:
173+
ports = np.zeros(0)
174+
return np.bincount(ports, minlength=self.env_properties.port_count)
163175

164176

165177
class Feature_discovered_credential_count(Feature):
@@ -169,7 +181,8 @@ def __init__(self, p: EnvironmentBounds):
169181
super().__init__(p, [p.maximum_total_credentials + 1])
170182

171183
def get(self, a: StateAugmentation, node):
172-
return [len(a.observation['credential_cache_matrix'])]
184+
n = a.observation['credential_cache_length']
185+
return [n]
173186

174187

175188
class Feature_discovered_node_count(Feature):
@@ -179,7 +192,7 @@ def __init__(self, p: EnvironmentBounds):
179192
super().__init__(p, [p.maximum_node_count + 1])
180193

181194
def get(self, a: StateAugmentation, node):
182-
return [len(a.observation['discovered_nodes_properties'])]
195+
return [a.observation['discovered_node_count']]
183196

184197

185198
class Feature_discovered_notowned_node_count(Feature):
@@ -190,10 +203,10 @@ def __init__(self, p: EnvironmentBounds, clip: Optional[int]):
190203
super().__init__(p, [self.clip + 1])
191204

192205
def get(self, a: StateAugmentation, node):
193-
node_props = a.observation['discovered_nodes_properties']
194-
discovered = len(node_props)
206+
discovered = a.observation['discovered_node_count']
207+
node_props = np.array(a.observation['discovered_nodes_properties'][:discovered])
195208
# here we assume that a node is owned just if all its properties are known
196-
owned = np.count_nonzero(np.all(node_props != 0, axis=1))
209+
owned = np.count_nonzero(np.all(node_props != 2, axis=1))
197210
diff = discovered - owned
198211
return [min(diff, self.clip)]
199212

@@ -355,7 +368,7 @@ def specialize_to_gymaction(self, source_node: np.int32, observation, abstract_a
355368

356369
abstract_action_index_int = int(abstract_action_index)
357370

358-
node_prop = np.array(observation['discovered_nodes_properties'])
371+
discovered_nodes_count = observation['discovered_node_count']
359372

360373
if abstract_action_index_int < self.n_local_actions:
361374
vuln = abstract_action_index_int
@@ -365,8 +378,6 @@ def specialize_to_gymaction(self, source_node: np.int32, observation, abstract_a
365378
if abstract_action_index_int < self.n_remote_actions:
366379
vuln = abstract_action_index_int
367380

368-
discovered_nodes_count = len(node_prop)
369-
370381
if discovered_nodes_count <= 1:
371382
return None
372383

@@ -382,11 +393,11 @@ def specialize_to_gymaction(self, source_node: np.int32, observation, abstract_a
382393
abstract_action_index_int -= self.n_remote_actions
383394
port = np.int32(abstract_action_index_int)
384395

385-
discovered_credentials = np.array(observation['credential_cache_matrix'])
386-
n_discovered_creds = len(discovered_credentials)
396+
n_discovered_creds = observation['credential_cache_length']
387397
if n_discovered_creds <= 0:
388398
# no credential available in the cache: cannot produce a valid connect action
389399
return None
400+
discovered_credentials = np.array(observation['credential_cache_matrix'])[:n_discovered_creds]
390401

391402
nodes_not_owned = discovered_nodes_notowned(observation)
392403

cyberbattle/agents/baseline/baseline_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import cyberbattle.agents.baseline.agent_dql as dqla
1313
import cyberbattle.agents.baseline.agent_wrapper as w
1414
import cyberbattle.agents.baseline.learner as learner
15+
import cyberbattle.agents.baseline.agent_tabularqlearning as tqa
1516

1617
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
1718

@@ -69,3 +70,24 @@ def test_agent_training() -> None:
6970
)
7071

7172
assert random_run
73+
74+
75+
def test_tabularq_agent_training() -> None:
76+
tabularq_run = learner.epsilon_greedy_search(
77+
cyberbattlechain,
78+
ep,
79+
learner=tqa.QTabularLearner(
80+
ep,
81+
gamma=0.015, learning_rate=0.01, exploit_percentile=100),
82+
episode_count=training_episode_count,
83+
iteration_count=iteration_count,
84+
epsilon=0.90,
85+
epsilon_exponential_decay=5000,
86+
epsilon_minimum=0.01,
87+
verbosity=Verbosity.Quiet,
88+
render=False,
89+
plot_episodes_length=False,
90+
title="Tabular Q-learning"
91+
)
92+
93+
assert tabularq_run

0 commit comments

Comments
 (0)