<html>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'], ["\\(","\\)"] ],
processEscapes: true
}
});
</script>
<script src='https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-MML-AM_CHTML' async></script>
<head>
<link rel="stylesheet" href="writeup.css">
</head>
<body>
<h1 id="interpretability-in-reinforcement-learning-agents">Interpretability in Reinforcement Learning Agents</h1>
<p><em>Dmitry Nikulin, Sebastian Kosch, Fabian Steuer, Hoagy Cunningham</em></p>
<p><a href="https://github.com/Pastafarianist/rl-attention">Github Repository</a></p>
<h3 id="attention-regularization">Attention Regularization</h3>
<h4 id="background">Background</h4>
<p>We began our project as part of AI Safety Camp 3, trying to better understand the decision-making processes used by RL agents. Interpretability methods can be split into <em>ante hoc</em> and <em>post hoc</em>, depending on whether they are applied during the training of the agent or to an agent after training. In the former approach, one may add an attention layer to the architecture to force the model to focus spatially. The latter includes techniques such as gradient-based saliency.</p>
<p>We aimed to combine these methods in a way that suits reinforcement learners, which we hoped would be especially fertile ground because most of the work in this space has focused on image classifiers.</p>
<p>Our initial goal was to train a reinforcement learning agent with an attention layer, and then apply regularization to the activations of the attention layer to promote more meaningful attention maps. We started with the work of Zhao et al.<a href="#fn1" class="footnoteRef" id="fnref1"><sup>1</sup></a>, who used a Rainbow-based implementation: a standard neural network architecture equipped with two attention layers that filter the output of the last convolutional layer.</p>
<p>We recreated the bones of this architecture in TensorFlow using the stable-baselines package. We chose it because it is a fork of an established library, OpenAI Baselines, that aims to improve on the upstream's code style and documentation. However, because Baselines lacks Rainbow, we used PPO, a state-of-the-art algorithm that is available in the library.</p>
<p>Our initial experiments found that adding the attention layers to the baseline architecture made no noticeable change in its performance.</p>
<h4 id="penalizing attention">Penalizing Diffuse Attention</h4>
<p>In this architecture, the agent is not forced to choose one or even a limited number of parts of the image to focus its attention on. In fact, it can produce a very diffuse attention tensor and thereby pass all information from the convolutional network to the decision-making layers.</p>
<p>To incentivise informative results from the attention layer in our architecture, we added an extra term to the loss function.</p>
<p>We considered other forms for this loss, including the divergence from a sum-of-Gaussians, but settled on using the entropy of the final activations of the attention tensor as the basis for our loss term.</p>
<p align="center">
<b>Diagram of our architecture, showing the attention layer. Adapted from Zhao et al.</b>
</p>
<p><img src="images/architecture.png" class="archi" alt="Diagram of the architecture used in our models." width="600"/></p>
<p>Here we can see the way in which the attention filters the results, allowing us to see the focus of the model more clearly.</p>
<p>The formulas below show how the attention loss $L$ with coefficient $c$ is computed from the final attention layer $A$ of shape $7\times 7\times 2$, that is, two $7\times 7$ maps $A^1$ and $A^2$, each normalised to sum to 1.</p>
<p>$$A^1,A^2\in \mathbb{R}^{7\times7}$$</p>
<p>$$\sum_{i,j} A^k_{i,j} = 1, \quad k \in \{1, 2\}$$</p>
<p>$$L = -c \sum_{k=1}^{2} \sum_{i,j} A^k_{i,j}\log A^k_{i,j}$$</p>
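<p>A minimal sketch of this loss in plain Python may make the effect of the penalty easier to see. It assumes each $7\times 7$ map is normalised to sum to 1 (here by a spatial softmax) and that the entropies of the two maps are added together. The names <code>softmax_2d</code> and <code>attention_entropy_loss</code> and the small <code>eps</code> inside the logarithm are illustrative assumptions, not the TensorFlow code from the repository.</p>

```python
import math


def softmax_2d(logits):
    """Normalise a 2-D grid of logits so that all entries sum to 1."""
    peak = max(v for row in logits for v in row)  # subtract max for stability
    exps = [[math.exp(v - peak) for v in row] for row in logits]
    total = sum(sum(row) for row in exps)
    return [[e / total for e in row] for row in exps]


def attention_entropy_loss(attention_maps, coef, eps=1e-12):
    """Return coef times the summed Shannon entropy of the attention maps.

    attention_maps: list of 2-D grids, each assumed to sum to 1.
    eps guards against log(0) and is an assumption of this sketch.
    """
    entropy = 0.0
    for amap in attention_maps:
        for row in amap:
            for a in row:
                entropy -= a * math.log(a + eps)
    return coef * entropy


# A fully diffuse map: every one of the 49 cells gets the same weight.
uniform = [[1.0 / 49] * 7 for _ in range(7)]

# A focused map: one cell has a much larger logit than the rest.
peaked_logits = [[0.0] * 7 for _ in range(7)]
peaked_logits[3][3] = 10.0
peaked = softmax_2d(peaked_logits)

diffuse_loss = attention_entropy_loss([uniform, uniform], coef=5e-3)
focused_loss = attention_entropy_loss([peaked, peaked], coef=5e-3)
print(diffuse_loss, focused_loss)
```

<p>The uniform map attains the maximum entropy of $\log 49$ per map, so it receives the largest penalty, while the peaked map is penalised far less. This is the pressure towards focused attention that the extra loss term is meant to provide.</p>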
<p>We ran a number of experiments with different coefficients for the attention loss, settling on the range between 1e-3 and 1e-2 as the region of interest.</p>
<p>Our results show that at 1e-3 performance does not appear to be reduced at all, but reduced entropy is only weakly incentivized. At 5e-3 the entropy of the attention activations is strongly reduced, but so is performance, with the exception of one run that performed as well as the unregularized models.</p>
<div class="figure">
<img src="images/scatter.png" alt="Scatter plot of score at 10M steps, attention loss" class="center"/>
<p class="caption">Scatter plot of score at 10M steps, attention loss</p>
</div>
<p align="center">
<b>Activation of the final attention layer plotted over the input frames for models trained with attention coefficient of 0 (Left) and 0.005 (Right)</b>
</p>
<figure class="video_container">
<video controls="true" allowfullscreen="true" width="350"> <source src="images/a2_no_attn_loss.mp4" type="video/webm"> </video> <video controls="true" allowfullscreen="true" width="350"> <source src="images/a2_attn_loss.mp4" type="video/webm"> </video>
</figure>
<h3 id="beyond-heatmaps">Beyond Heatmaps</h3>
<p>While working on attention regularization, we also started to think about ways to better visualise the structure of the RL agent.</p>
<p>The argument for the importance of our work requires the following:</p>
<ol style="list-style-type: decimal">
<li><p>We can create data structures representing trained models that have the capacity to be highly informative about the decision-making process.</p></li>
<li><p>The data structures are not as informative as they could be because the weights that have been learned are not conducive to being understood in this way.</p></li>
<li><p>We can find parameterisations of 'informativeness' in this context which allow us to apply regularisation and thus incentivise the models to act in a comprehensible way.</p></li>
</ol>
<p>We started by using the Simonyan gradient<a href="#fn2" class="footnoteRef" id="fnref2"><sup>2</sup></a> to visualise the attention gradient, and got results that were noisy and had few peaks. Using SmoothGrad and VarGrad gave us less noisy, more specific areas, but we found that the more samples we used, the more washed out the images became, until they hardly gave any information at all, suggesting that much of the apparent structure was due to noise.</p>
<p align="center">
<b> Simonyan Gradient (Left), SmoothGrad (Right) and VarGrad (Bottom) of the attention layer without entropy penalty. </b>
</p>
<figure class="video_container" align="center">
<video controls="true" allowfullscreen="true" width="350" align="left"> <source src="images/simonyan_no_attn.mp4" type="video/webm"> </video> <video controls="true" allowfullscreen="true" width="350" align="right"> <source src="images/coef_0.0_sum_smoothgrad_50.mp4" type="video/webm"> </video> <video controls="true" allowfullscreen="true" width="350" align="center"> <source src="images/coef_0.0_sum_vargrad_50.mp4" type="video/webm"> </video>
</figure>
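<p>For readers unfamiliar with the two methods, the sketch below shows the general idea: SmoothGrad averages the gradient over several noisy copies of the input, while VarGrad takes the variance of those gradients. It is a toy illustration in plain Python rather than the code behind the videos. The function <code>smoothgrad_and_vargrad</code>, the stand-in gradient <code>toy_grad</code> and the noise level are all hypothetical choices.</p>

```python
import random


def smoothgrad_and_vargrad(grad_fn, x, n_samples=50, noise_std=0.1, seed=0):
    """Return (mean, variance) of grad_fn over Gaussian-perturbed copies of x.

    grad_fn maps a list of floats to the gradient of a scalar score with
    respect to that list. The per-feature mean is the SmoothGrad map and
    the per-feature variance is the VarGrad map.
    """
    rng = random.Random(seed)
    sums = [0.0] * len(x)
    sq_sums = [0.0] * len(x)
    for _ in range(n_samples):
        noisy = [v + rng.gauss(0.0, noise_std) for v in x]
        grad = grad_fn(noisy)
        for i, g in enumerate(grad):
            sums[i] += g
            sq_sums[i] += g * g
    mean = [s / n_samples for s in sums]
    var = [sq / n_samples - m * m for sq, m in zip(sq_sums, mean)]
    return mean, var


def toy_grad(x):
    """Hypothetical stand-in for a network: score = x0**2 + 3*x1."""
    return [2.0 * x[0], 3.0]


smooth, var = smoothgrad_and_vargrad(toy_grad, [1.0, -2.0])
print(smooth, var)
```

<p>In this toy case the second feature has a constant gradient, so its VarGrad value is essentially zero, while the first feature's gradient depends on the noisy input and therefore has non-zero variance.</p>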
<p>We wanted a representation that could capture more of the decision-making process than a single heatmap.</p>
<p>We therefore started to experiment with more detailed structures for explaining actions. Instead of simply taking the activation of the attention layer, we take its most important activations and then recursively take the most important activations of the layer below, creating a tree structure in which the components required for an action are visualised together.</p>
<p>[Diagram of tree structure]</p>
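<p>The recursive selection described above can be sketched on a toy fully connected network. This is only an illustration under stated assumptions: the real model is convolutional, "importance" is taken here to be activation times weight, and the function <code>build_tree</code> with its parameter <code>k</code> is hypothetical.</p>

```python
def build_tree(layers, activations, layer_idx, unit, k=2):
    """Recursively collect the k strongest contributors to a unit.

    layers[l][j][i] is the weight from unit i in layer l to unit j in
    layer l + 1, and activations[l][i] is the activation of unit i in
    layer l. Contribution is taken to be activation * weight, which is
    an assumption of this sketch.
    """
    node = {"layer": layer_idx, "unit": unit, "children": []}
    if layer_idx == 0:
        return node  # reached the input layer, nothing below to expand
    weights = layers[layer_idx - 1][unit]
    below = activations[layer_idx - 1]
    contrib = [a * w for a, w in zip(below, weights)]
    ranked = sorted(range(len(contrib)), key=lambda i: contrib[i], reverse=True)
    for i in ranked[:k]:
        child = build_tree(layers, activations, layer_idx - 1, i, k)
        node["children"].append(child)
    return node


# Toy network: 3 input units, 2 hidden units, 1 output unit.
layers = [
    [[1.0, 0.0, 2.0], [0.5, 0.5, 0.5]],  # input to hidden
    [[1.0, -1.0]],                       # hidden to output
]
activations = [[1.0, 2.0, 3.0], [7.0, 3.0], [4.0]]

tree = build_tree(layers, activations, layer_idx=2, unit=0, k=1)
print(tree)
```

<p>With <code>k=1</code> the tree is a single path: the output unit is explained by hidden unit 0, which in turn is explained by input unit 2. Larger values of <code>k</code> give a branching tree.</p>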
<p>To select the most important nodes we started by using k-means clustering on each layer, with the intention of using a gap statistic to decide how many clusters to select. Unfortunately this turned out to be too slow for our goal of making an interactive visualisation tool, so we switched to a peak detection algorithm to select which neurons to dig down into. The result is then plotted on top of the stack of input frames.</p>
<p>[Squares on a breakout frame]</p>
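<p>The text does not specify which peak detection algorithm is used, so the sketch below uses one simple possibility: a cell counts as a peak if it exceeds a threshold and is strictly larger than all of its neighbours. The function <code>find_peaks_2d</code>, the threshold value and the example grid are illustrative assumptions.</p>

```python
def find_peaks_2d(grid, threshold=0.0):
    """Return (row, col) of cells above threshold that beat all neighbours."""
    rows, cols = len(grid), len(grid[0])
    peaks = []
    for r in range(rows):
        for c in range(cols):
            value = grid[r][c]
            # Collect the up-to-8 surrounding cells, clipped at the borders.
            neighbours = [
                grid[rr][cc]
                for rr in range(max(r - 1, 0), min(r + 2, rows))
                for cc in range(max(c - 1, 0), min(c + 2, cols))
                if (rr, cc) != (r, c)
            ]
            if value > threshold and all(value > n for n in neighbours):
                peaks.append((r, c))
    return peaks


# Hypothetical activation map with two clear local maxima.
activation = [
    [0.0, 0.1, 0.0, 0.0],
    [0.1, 0.9, 0.1, 0.0],
    [0.0, 0.1, 0.0, 0.3],
    [0.0, 0.0, 0.2, 0.7],
]
peaks = find_peaks_2d(activation, threshold=0.5)
print(peaks)
```

<p>A single pass over the grid like this is cheap compared with running k-means and a gap statistic on every layer, which is the kind of trade-off that matters for an interactive tool.</p>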
<p>This is now being turned into a JavaScript program that lets users explore the way a trained model plays a game in interactive 3D, selecting the filters that they want to explore and seeing the most relevant nodes drawn out.</p>
<div class="footnotes">
<hr />
<ol>
<li id="fn1"><p>Zhao Yang et al. <a href="https://arxiv.org/pdf/1812.11276.pdf">Learn to Interpret Atari Agents</a><a href="#fnref1">↩</a></p></li>
<li id="fn2"><p>Simonyan et al. <a href="https://arxiv.org/abs/1312.6034.pdf">Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps</a><a href="#fnref2">↩</a></p></li>
</ol>
</div>
</body>
</html>