Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/pull_request_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Here are some tasks to complete before merging this PR:

## Jupyter

- [ ] Add references to the bottom of the notebook.
- [ ] Make sure it runs on cpu and gpu
- [ ] Comment the shapes of any numpy arrays or torch tensors. Make assertions on the output.
- [ ] Functions that have more than one parameter should have a `*` before the first or second parameter to force the user to use named arguments, unless there is only one parameter.
Expand All @@ -30,6 +31,6 @@ if __name__ == '__main__':
## Logseq

- [ ] Make logseq notes and flash cards.
- [ ] Do not use logseq aliases so that the graph looks clean and its more navigable.
- [ ] Use singular nouns for tags.
- [ ] Use spaces in filenames instead of `-` or `_` just so that you don't have to use aliases (ugly I know).
- [ ] Use `-` or `_` in filenames instead of spaces.
- [ ] Use aliases for spaces and plurals.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,7 @@ repos:
- id: enforce-ascii
files: notes/pages/.*\.md
- id: mdlinker
files: notes/pages/.*\.md
args:
- "--fix"
- "--allow-dirty"
22 changes: 22 additions & 0 deletions continuing_education/lib/episodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ class SARSA:
next_action: Action | None = None
action_log_prob: LogProb = LogProb(tensor(0.0))

def to_sars(self) -> "SARS":
return SARS.from_sarsa(self)


def collect_episode(
*, env: Env, policy: DiscreteActionPolicyInterface, max_t: int, **policy_kwargs
Expand Down Expand Up @@ -51,3 +54,22 @@ def collect_episode(

assert next_action is not None
state, action, action_logprob = next_state, next_action, next_action_logprob


@dataclass
class SARS:
state: State
action: Action
reward: float
next_state: State
done: bool

@staticmethod
def from_sarsa(sarsa: SARSA) -> "SARS":
return SARS(
state=sarsa.state,
action=sarsa.action,
reward=sarsa.reward,
next_state=sarsa.next_state,
done=sarsa.done,
)
25 changes: 15 additions & 10 deletions continuing_education/lib/experiments/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ def __init__(
self.file = file
self.primary_metric = primary_metric

# Check for unstaged changes
# This is being moved to __init__ because we usually generate some notebook changes when running experiments, like plots and the like
self.repo = Repo(
path=self.file.absolute().parent, search_parent_directories=True
)
if self.repo.is_dirty(untracked_files=True):
warnings.warn(
"There are unstaged changes in the repository. Please commit or stage them before running the experiment manager."
)

@property
def is_jupytext(self) -> bool:
return (
Expand All @@ -34,29 +44,24 @@ def run_jupytext_sync(self):
subprocess.run(cmd, check=True)

def commit(self, metrics: dict[str, Any] | None = None):
repo = Repo(path=self.file.absolute().parent, search_parent_directories=True)
if metrics is None:
metrics = {}
self.run_jupytext_sync()

# Staging files
files = [self.file.relative_to(repo.working_dir)]
files = [self.file.relative_to(self.repo.working_dir)]
if self.is_jupytext:
files.append(self.file.relative_to(repo.working_dir).with_suffix(".py"))
repo.index.add(files)

# Check for unstaged changes
if repo.is_dirty(untracked_files=True):
warnings.warn(
"There are unstaged changes in the repository. Please commit or stage them before running the experiment manager."
files.append(
self.file.relative_to(self.repo.working_dir).with_suffix(".py")
)
self.repo.index.add(files)

# Committing changes
if self.primary_metric in metrics:
commit_message = f"Experiment: {self.name}, {self.primary_metric}: {metrics[self.primary_metric]}"
else:
commit_message = f"Experiment: {self.name}"
detailed_message = f"{self.description}\n\nResults:\n{pformat(metrics)}".strip()
repo.index.commit(
self.repo.index.commit(
message=commit_message + "\n\n" + detailed_message, skip_hooks=True
)
161 changes: 161 additions & 0 deletions continuing_education/policy_gradient_methods/actor_critic/a2c.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"if __name__ == \"__main__\":\n",
" __this_file = (\n",
" Path().resolve() / \"actor_critic.ipynb\"\n",
" ) # jupyter does not have __file__"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda:0\n"
]
}
],
"source": [
"import torch\n",
"\n",
"DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"if __name__ == \"__main__\":\n",
" print(DEVICE)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"from torch import nn\n",
"\n",
"from continuing_education.policy_gradient_methods.reinforce import SamplePolicy"
]
},
{
"cell_type": "markdown",
"metadata": {
"lines_to_next_cell": 2
},
"source": [
"# Actor Critic\n",
"\n",
"So in the REINFORCE algorithm, we experimented with batch updates and saw that while it slowed down the code it led to serious improvements. The size of batching needed to introduce stability in learning is called the sample efficiency of an algorithm, and REINFORCE is particularly bad at it. In Actor-Critic we introduce asynchronicity into the environments episode generation and the policy update.\n",
"\n",
"The Actor-Critic method uses asynchronicity by having neural network \"servers\" which spawn copies of themselves to each interact with copies of the environment. Trajectories are generated by these servers and the gradient updates are sent back to the main server, which averages them and updates the policy. This is a form of parallelism, and it is a very powerful tool in reinforcement learning.\n",
"\n",
"Another concept that Actor Critic introduces is that it hybridizes policy based methods (the actor) and value based methods (the critic). This gives you many of the advantages of both methods.\n",
"\n",
"The huggingface tutorial on A2C wants us to use stable-baselines3, but I call that cheating. We will implement A2C from scratch using PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Actor(SamplePolicy):\n",
" \"\"\"This is exactly the same as our SamplePolicy from REINFORCE, but we need to add a few methods.\"\"\"\n",
"\n",
" def copy(self) -> \"Actor\":\n",
" new_actor = Actor(\n",
" state_size=self.state_size,\n",
" action_size=self.action_size,\n",
" hidden_sizes=self.hidden_sizes,\n",
" )\n",
" new_actor.network.load_state_dict(self.network.state_dict())\n",
" return new_actor\n",
"\n",
" def update(self, gradients: list[nn.Module]):\n",
" raise NotImplementedError(\"Unsure how this works\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"from continuing_education.value_based_methods.dqn.dqn import QLearningModel\n",
"\n",
"\n",
"class Critic(QLearningModel):\n",
" def copy(self) -> \"Critic\":\n",
" new_critic = Critic(\n",
" state_size=self.state_size,\n",
" action_size=self.action_size,\n",
" hidden_sizes=self.hidden_sizes,\n",
" )\n",
" new_critic.network.load_state_dict(self.network.state_dict())\n",
" return new_critic\n",
"\n",
" def update(self, gradients: list[nn.Module]):\n",
" raise NotImplementedError(\"Unsure how this works\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# References\n",
"\n",
"1. Mnih, V., Badia, A. P., Mirza, M., Graves, A., Lillicrap, T. P., Harley, T., … Kavukcuoglu, K. (2016). Asynchronous Methods for Deep Reinforcement Learning. arXiv [Cs.LG]. Retrieved from http://arxiv.org/abs/1602.01783\n",
"2. UNIT 6. ACTOR CRITIC METHODS WITH ROBOTICS ENVIRONMENTS. Hugging Face. (n.d.). https://huggingface.co/learn/deep-rl-course/unit6/introduction"
]
}
],
"metadata": {
"jupytext": {
"formats": "ipynb,py:percent"
},
"kernelspec": {
"display_name": "continuing_education",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
93 changes: 93 additions & 0 deletions continuing_education/policy_gradient_methods/actor_critic/a2c.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.7
# kernelspec:
# display_name: continuing_education
# language: python
# name: python3
# ---

# %%
# %load_ext autoreload
# %autoreload 2

# %%
from pathlib import Path

if __name__ == "__main__":
__this_file = (
Path().resolve() / "actor_critic.ipynb"
) # jupyter does not have __file__

# %%
import torch

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":
print(DEVICE)

# %%
from torch import nn

from continuing_education.policy_gradient_methods.reinforce import SamplePolicy


# %% [markdown]
# # Actor Critic
#
# So in the REINFORCE algorithm, we experimented with batch updates and saw that while it slowed down the code it led to serious improvements. The size of batching needed to introduce stability in learning is called the sample efficiency of an algorithm, and REINFORCE is particularly bad at it. In Actor-Critic we introduce asynchronicity into the environments episode generation and the policy update.
#
# The Actor-Critic method uses asynchronicity by having neural network "servers" which spawn copies of themselves to each interact with copies of the environment. Trajectories are generated by these servers and the gradient updates are sent back to the main server, which averages them and updates the policy. This is a form of parallelism, and it is a very powerful tool in reinforcement learning.
#
# Another concept that Actor Critic introduces is that it hybridizes policy based methods (the actor) and value based methods (the critic). This gives you many of the advantages of both methods.
#
# The huggingface tutorial on A2C wants us to use stable-baselines3, but I call that cheating. We will implement A2C from scratch using PyTorch.


# %%
class Actor(SamplePolicy):
"""This is exactly the same as our SamplePolicy from REINFORCE, but we need to add a few methods."""

def copy(self) -> "Actor":
new_actor = Actor(
state_size=self.state_size,
action_size=self.action_size,
hidden_sizes=self.hidden_sizes,
)
new_actor.network.load_state_dict(self.network.state_dict())
return new_actor

def update(self, gradients: list[nn.Module]):
raise NotImplementedError("Unsure how this works")


# %%
from continuing_education.value_based_methods.dqn.dqn import QLearningModel


class Critic(QLearningModel):
def copy(self) -> "Critic":
new_critic = Critic(
state_size=self.state_size,
action_size=self.action_size,
hidden_sizes=self.hidden_sizes,
)
new_critic.network.load_state_dict(self.network.state_dict())
return new_critic

def update(self, gradients: list[nn.Module]):
raise NotImplementedError("Unsure how this works")


# %% [markdown]
# # References
#
# 1. Mnih, V., Badia, A. P., Mirza, M., Graves, A., Lillicrap, T. P., Harley, T., … Kavukcuoglu, K. (2016). Asynchronous Methods for Deep Reinforcement Learning. arXiv [Cs.LG]. Retrieved from http://arxiv.org/abs/1602.01783
# 2. UNIT 6. ACTOR CRITIC METHODS WITH ROBOTICS ENVIRONMENTS. Hugging Face. (n.d.). https://huggingface.co/learn/deep-rl-course/unit6/introduction
Loading