diff --git a/README.md b/README.md
index 52234ba..7e17db2 100644
--- a/README.md
+++ b/README.md
@@ -38,4 +38,9 @@ https://github.com/user-attachments/assets/80c8b16b-df09-4607-bcc6-2b0e760f03c5
#### Robot FPS:
https://github.com/user-attachments/assets/d44efbd1-59c2-4828-ae88-d8b374fb27e2
+#### Platform2D environment:
+https://github.com/user-attachments/assets/468f3eb5-ea9f-4eb0-8ca1-6b8b67f37d02
+
+
+
diff --git a/examples/Platform2D/.gitattributes b/examples/Platform2D/.gitattributes
new file mode 100644
index 0000000..8ad74f7
--- /dev/null
+++ b/examples/Platform2D/.gitattributes
@@ -0,0 +1,2 @@
+# Normalize EOL for all files that Git considers text files.
+* text=auto eol=lf
diff --git a/examples/Platform2D/.gitignore b/examples/Platform2D/.gitignore
new file mode 100644
index 0000000..7de8ea5
--- /dev/null
+++ b/examples/Platform2D/.gitignore
@@ -0,0 +1,3 @@
+# Godot 4+ specific ignores
+.godot/
+android/
diff --git a/examples/Platform2D/Platform2D.csproj b/examples/Platform2D/Platform2D.csproj
new file mode 100644
index 0000000..6fa3be0
--- /dev/null
+++ b/examples/Platform2D/Platform2D.csproj
@@ -0,0 +1,11 @@
+
+
+ net6.0
+ net7.0
+ net8.0
+ true
+
+
+
+
+
\ No newline at end of file
diff --git a/examples/Platform2D/Platform2D.sln b/examples/Platform2D/Platform2D.sln
new file mode 100644
index 0000000..5427097
--- /dev/null
+++ b/examples/Platform2D/Platform2D.sln
@@ -0,0 +1,19 @@
+Microsoft Visual Studio Solution File, Format Version 12.00
+# Visual Studio 2012
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Platform2D", "Platform2D.csproj", "{8552EC7B-EF81-42D4-828B-B6CD9D17C897}"
+EndProject
+Global
+ GlobalSection(SolutionConfigurationPlatforms) = preSolution
+ Debug|Any CPU = Debug|Any CPU
+ ExportDebug|Any CPU = ExportDebug|Any CPU
+ ExportRelease|Any CPU = ExportRelease|Any CPU
+ EndGlobalSection
+ GlobalSection(ProjectConfigurationPlatforms) = postSolution
+ {8552EC7B-EF81-42D4-828B-B6CD9D17C897}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {8552EC7B-EF81-42D4-828B-B6CD9D17C897}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {8552EC7B-EF81-42D4-828B-B6CD9D17C897}.ExportDebug|Any CPU.ActiveCfg = ExportDebug|Any CPU
+ {8552EC7B-EF81-42D4-828B-B6CD9D17C897}.ExportDebug|Any CPU.Build.0 = ExportDebug|Any CPU
+ {8552EC7B-EF81-42D4-828B-B6CD9D17C897}.ExportRelease|Any CPU.ActiveCfg = ExportRelease|Any CPU
+ {8552EC7B-EF81-42D4-828B-B6CD9D17C897}.ExportRelease|Any CPU.Build.0 = ExportRelease|Any CPU
+ EndGlobalSection
+EndGlobal
diff --git a/examples/Platform2D/addons/godot_rl_agents/controller/ai_controller_2d.gd b/examples/Platform2D/addons/godot_rl_agents/controller/ai_controller_2d.gd
new file mode 100644
index 0000000..06d928b
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/controller/ai_controller_2d.gd
@@ -0,0 +1,136 @@
+extends Node2D
+class_name AIController2D
+
+enum ControlModes {
+ INHERIT_FROM_SYNC, ## Inherit setting from sync node
+ HUMAN, ## Test the environment manually
+ TRAINING, ## Train a model
+ ONNX_INFERENCE, ## Load a pretrained model using an .onnx file
+ RECORD_EXPERT_DEMOS ## Record observations and actions for expert demonstrations
+}
+@export var control_mode: ControlModes = ControlModes.INHERIT_FROM_SYNC
+## The path to a trained .onnx model file to use for inference (overrides the path set in sync node).
+@export var onnx_model_path := ""
+## Once the number of steps has passed, the flag 'needs_reset' will be set to 'true' for this instance.
+@export var reset_after := 1000
+
+@export_group("Record expert demos mode options")
+## Path where the demos will be saved. The file can later be used for imitation learning.
+@export var expert_demo_save_path: String
+## The action that erases the last recorded episode from the currently recorded data.
+@export var remove_last_episode_key: InputEvent
+## Action will be repeated for n frames. Will introduce control lag if larger than 1.
+## Can be used to ensure that action_repeat on inference and training matches
+## the recorded demonstrations.
+@export var action_repeat: int = 1
+
+@export_group("Multi-policy mode options")
+## Allows you to set certain agents to use different policies.
+## Changing has no effect with default SB3 training. Works with Rllib example.
+## Tutorial: https://github.com/edbeeching/godot_rl_agents/blob/main/docs/TRAINING_MULTIPLE_POLICIES.md
+@export var policy_name: String = "shared_policy"
+
+var onnx_model: ONNXModel
+
+var heuristic := "human"
+var done := false
+var reward := 0.0
+var n_steps := 0
+var needs_reset := false
+
+var _player: Node2D
+
+
+func _ready():
+ add_to_group("AGENT")
+
+
+func init(player: Node2D):
+ _player = player
+
+
+#region Methods that need implementing using the "extend script" option in Godot
+func get_obs() -> Dictionary:
+ assert(false, "the get_obs method is not implemented when extending from ai_controller")
+ return {"obs": []}
+
+
+func get_reward() -> float:
+ assert(false, "the get_reward method is not implemented when extending from ai_controller")
+ return 0.0
+
+
+func get_action_space() -> Dictionary:
+ assert(
+ false, "the get_action_space method is not implemented when extending from ai_controller"
+ )
+ return {
+ "example_actions_continous": {"size": 2, "action_type": "continuous"},
+ "example_actions_discrete": {"size": 2, "action_type": "discrete"},
+ }
+
+
+func set_action(action) -> void:
+ assert(false, "the set_action method is not implemented when extending from ai_controller")
+
+
+#endregion
+
+
+#region Methods that sometimes need implementing using the "extend script" option in Godot
+# Only needed if you are recording expert demos with this AIController
+func get_action() -> Array:
+ assert(
+ false,
+ "the get_action method is not implemented in extended AIController but demo_recorder is used"
+ )
+ return []
+
+
+# For providing additional info (e.g. `is_success` for SB3 training)
+func get_info() -> Dictionary:
+ return {}
+
+
+#endregion
+
+
+func _physics_process(delta):
+ n_steps += 1
+ if n_steps > reset_after:
+ needs_reset = true
+
+
+func get_obs_space():
+ # may need overriding if the obs space is complex
+ var obs = get_obs()
+ return {
+ "obs": {"size": [len(obs["obs"])], "space": "box"},
+ }
+
+
+func reset():
+ n_steps = 0
+ needs_reset = false
+
+
+func reset_if_done():
+ if done:
+ reset()
+
+
+func set_heuristic(h):
+ # sets the heuristic to "human" or "model"; nothing to change here
+ heuristic = h
+
+
+func get_done():
+ return done
+
+
+func set_done_false():
+ done = false
+
+
+func zero_reward():
+ reward = 0.0
diff --git a/examples/Platform2D/addons/godot_rl_agents/controller/ai_controller_3d.gd b/examples/Platform2D/addons/godot_rl_agents/controller/ai_controller_3d.gd
new file mode 100644
index 0000000..61a0529
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/controller/ai_controller_3d.gd
@@ -0,0 +1,136 @@
+extends Node3D
+class_name AIController3D
+
+enum ControlModes {
+ INHERIT_FROM_SYNC, ## Inherit setting from sync node
+ HUMAN, ## Test the environment manually
+ TRAINING, ## Train a model
+ ONNX_INFERENCE, ## Load a pretrained model using an .onnx file
+ RECORD_EXPERT_DEMOS ## Record observations and actions for expert demonstrations
+}
+@export var control_mode: ControlModes = ControlModes.INHERIT_FROM_SYNC
+## The path to a trained .onnx model file to use for inference (overrides the path set in sync node).
+@export var onnx_model_path := ""
+## Once the number of steps has passed, the flag 'needs_reset' will be set to 'true' for this instance.
+@export var reset_after := 1000
+
+@export_group("Record expert demos mode options")
+## Path where the demos will be saved. The file can later be used for imitation learning.
+@export var expert_demo_save_path: String
+## The action that erases the last recorded episode from the currently recorded data.
+@export var remove_last_episode_key: InputEvent
+## Action will be repeated for n frames. Will introduce control lag if larger than 1.
+## Can be used to ensure that action_repeat on inference and training matches
+## the recorded demonstrations.
+@export var action_repeat: int = 1
+
+@export_group("Multi-policy mode options")
+## Allows you to set certain agents to use different policies.
+## Changing has no effect with default SB3 training. Works with Rllib example.
+## Tutorial: https://github.com/edbeeching/godot_rl_agents/blob/main/docs/TRAINING_MULTIPLE_POLICIES.md
+@export var policy_name: String = "shared_policy"
+
+var onnx_model: ONNXModel
+
+var heuristic := "human"
+var done := false
+var reward := 0.0
+var n_steps := 0
+var needs_reset := false
+
+var _player: Node3D
+
+
+func _ready():
+ add_to_group("AGENT")
+
+
+func init(player: Node3D):
+ _player = player
+
+
+#region Methods that need implementing using the "extend script" option in Godot
+func get_obs() -> Dictionary:
+ assert(false, "the get_obs method is not implemented when extending from ai_controller")
+ return {"obs": []}
+
+
+func get_reward() -> float:
+ assert(false, "the get_reward method is not implemented when extending from ai_controller")
+ return 0.0
+
+
+func get_action_space() -> Dictionary:
+ assert(
+ false, "the get_action_space method is not implemented when extending from ai_controller"
+ )
+ return {
+ "example_actions_continous": {"size": 2, "action_type": "continuous"},
+ "example_actions_discrete": {"size": 2, "action_type": "discrete"},
+ }
+
+
+func set_action(action) -> void:
+ assert(false, "the set_action method is not implemented when extending from ai_controller")
+
+
+#endregion
+
+
+#region Methods that sometimes need implementing using the "extend script" option in Godot
+# Only needed if you are recording expert demos with this AIController
+func get_action() -> Array:
+ assert(
+ false,
+ "the get_action method is not implemented in extended AIController but demo_recorder is used"
+ )
+ return []
+
+
+# For providing additional info (e.g. `is_success` for SB3 training)
+func get_info() -> Dictionary:
+ return {}
+
+
+#endregion
+
+
+func _physics_process(delta):
+ n_steps += 1
+ if n_steps > reset_after:
+ needs_reset = true
+
+
+func get_obs_space():
+ # may need overriding if the obs space is complex
+ var obs = get_obs()
+ return {
+ "obs": {"size": [len(obs["obs"])], "space": "box"},
+ }
+
+
+func reset():
+ n_steps = 0
+ needs_reset = false
+
+
+func reset_if_done():
+ if done:
+ reset()
+
+
+func set_heuristic(h):
+ # sets the heuristic to "human" or "model"; nothing to change here
+ heuristic = h
+
+
+func get_done():
+ return done
+
+
+func set_done_false():
+ done = false
+
+
+func zero_reward():
+ reward = 0.0
diff --git a/examples/Platform2D/addons/godot_rl_agents/godot_rl_agents.gd b/examples/Platform2D/addons/godot_rl_agents/godot_rl_agents.gd
new file mode 100644
index 0000000..e4fe136
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/godot_rl_agents.gd
@@ -0,0 +1,16 @@
+@tool
+extends EditorPlugin
+
+
+func _enter_tree():
+ # Initialization of the plugin goes here.
+ # Add the new type with a name, a parent type, a script and an icon.
+ add_custom_type("Sync", "Node", preload("sync.gd"), preload("icon.png"))
+ #add_custom_type("RaycastSensor2D2", "Node", preload("raycast_sensor_2d.gd"), preload("icon.png"))
+
+
+func _exit_tree():
+ # Clean-up of the plugin goes here.
+ # Always remember to remove it from the engine when deactivated.
+ remove_custom_type("Sync")
+ #remove_custom_type("RaycastSensor2D2")
diff --git a/examples/Platform2D/addons/godot_rl_agents/icon.png b/examples/Platform2D/addons/godot_rl_agents/icon.png
new file mode 100644
index 0000000..fd8190e
Binary files /dev/null and b/examples/Platform2D/addons/godot_rl_agents/icon.png differ
diff --git a/examples/Platform2D/addons/godot_rl_agents/onnx/csharp/ONNXInference.cs b/examples/Platform2D/addons/godot_rl_agents/onnx/csharp/ONNXInference.cs
new file mode 100644
index 0000000..6dcfa18
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/onnx/csharp/ONNXInference.cs
@@ -0,0 +1,109 @@
+using Godot;
+using Microsoft.ML.OnnxRuntime;
+using Microsoft.ML.OnnxRuntime.Tensors;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace GodotONNX
+{
+ ///
+ public partial class ONNXInference : GodotObject
+ {
+
+ private InferenceSession session;
+ ///
+ /// Path to the ONNX model. Use Initialize to change it.
+ ///
+ private string modelPath;
+ private int batchSize;
+
+ private SessionOptions SessionOpt;
+
+ ///
+ /// init function
+ ///
+ ///
+ ///
+ /// Returns the output size of the model
+ public int Initialize(string Path, int BatchSize)
+ {
+ modelPath = Path;
+ batchSize = BatchSize;
+ SessionOpt = SessionConfigurator.MakeConfiguredSessionOptions();
+ session = LoadModel(modelPath);
+ return session.OutputMetadata["output"].Dimensions[1];
+ }
+
+
+ ///
+ public Godot.Collections.Dictionary> RunInference(Godot.Collections.Array obs, int state_ins)
+ {
+ //Current model: Any (Godot Rl Agents)
+ //Expects a tensor of shape [batch_size, input_size] type float named obs and a tensor of shape [batch_size] type float named state_ins
+
+ //Fill the input tensors
+ // create span from inputSize
+ var span = new float[obs.Count]; //There's probably a better way to do this
+ for (int i = 0; i < obs.Count; i++)
+ {
+ span[i] = obs[i];
+ }
+
+ IReadOnlyCollection inputs = new List
+ {
+ NamedOnnxValue.CreateFromTensor("obs", new DenseTensor(span, new int[] { batchSize, obs.Count })),
+ NamedOnnxValue.CreateFromTensor("state_ins", new DenseTensor(new float[] { state_ins }, new int[] { batchSize }))
+ };
+ IReadOnlyCollection outputNames = new List { "output", "state_outs" }; //ONNX is sensible to these names, as well as the input names
+
+ IDisposableReadOnlyCollection results;
+ //We do not use "using" here so we get a better exception explanation later
+ try
+ {
+ results = session.Run(inputs, outputNames);
+ }
+ catch (OnnxRuntimeException e)
+ {
+ //This error usually means that the model is not compatible with the input, because of the input shape (size)
+ GD.Print("Error at inference: ", e);
+ return null;
+ }
+ //Can't convert IEnumerable to Variant, so we have to convert it to an array or something
+ Godot.Collections.Dictionary> output = new Godot.Collections.Dictionary>();
+ DisposableNamedOnnxValue output1 = results.First();
+ DisposableNamedOnnxValue output2 = results.Last();
+ Godot.Collections.Array output1Array = new Godot.Collections.Array();
+ Godot.Collections.Array output2Array = new Godot.Collections.Array();
+
+ foreach (float f in output1.AsEnumerable())
+ {
+ output1Array.Add(f);
+ }
+
+ foreach (float f in output2.AsEnumerable())
+ {
+ output2Array.Add(f);
+ }
+
+ output.Add(output1.Name, output1Array);
+ output.Add(output2.Name, output2Array);
+
+ //Output is a dictionary of arrays, ex: { "output" : [0.1, 0.2, 0.3, 0.4, ...], "state_outs" : [0.5, ...]}
+ results.Dispose();
+ return output;
+ }
+ ///
+        public InferenceSession LoadModel(string Path)
+        {
+            //FileAccess.Open returns null on failure; fail with a clear message instead of a NullReferenceException.
+            using Godot.FileAccess file = FileAccess.Open(Path, Godot.FileAccess.ModeFlags.Read);
+            if (file == null) throw new System.IO.IOException($"Could not open ONNX model '{Path}': {FileAccess.GetOpenError()}");
+            //The whole file is read into memory and handed to ONNX Runtime; "using" closes the file when the method exits.
+            return new InferenceSession(file.GetBuffer((int)file.GetLength()), SessionOpt); //Load the model
+        }
+        public void FreeDisposables()
+        {
+            session?.Dispose(); //session is null if Initialize was never called or LoadModel threw
+            SessionOpt?.Dispose(); //still dispose the options when no session was created
+        }
+ }
+}
diff --git a/examples/Platform2D/addons/godot_rl_agents/onnx/csharp/SessionConfigurator.cs b/examples/Platform2D/addons/godot_rl_agents/onnx/csharp/SessionConfigurator.cs
new file mode 100644
index 0000000..ad7a41c
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/onnx/csharp/SessionConfigurator.cs
@@ -0,0 +1,131 @@
+using Godot;
+using Microsoft.ML.OnnxRuntime;
+
+namespace GodotONNX
+{
+ ///
+
+ public static class SessionConfigurator
+ {
+ public enum ComputeName
+ {
+ CUDA,
+ ROCm,
+ DirectML,
+ CoreML,
+ CPU
+ }
+
+ ///
+ public static SessionOptions MakeConfiguredSessionOptions()
+ {
+ SessionOptions sessionOptions = new();
+ SetOptions(sessionOptions);
+ return sessionOptions;
+ }
+
+ private static void SetOptions(SessionOptions sessionOptions)
+ {
+ sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING;
+ ApplySystemSpecificOptions(sessionOptions);
+ }
+
+ ///
+ static public void ApplySystemSpecificOptions(SessionOptions sessionOptions)
+ {
+ //Most code for this function is verbose only, the only reason it exists is to track
+ //implementation progress of the different compute APIs.
+
+ //December 2022: CUDA is not working.
+
+ string OSName = OS.GetName(); //Get OS Name
+
+ //ComputeName ComputeAPI = ComputeCheck(); //Get Compute API
+ // //TODO: Get CPU architecture
+
+ //Linux can use OpenVINO (C#) on x64 and ROCm on x86 (GDNative/C++)
+ //Windows can use OpenVINO (C#) on x64
+ //TODO: try TensorRT instead of CUDA
+ //TODO: Use OpenVINO for Intel Graphics
+
+ // Temporarily using CPU on all platforms to avoid errors detected with DML
+ ComputeName ComputeAPI = ComputeName.CPU;
+
+ //match OS and Compute API
+ GD.Print($"OS: {OSName} Compute API: {ComputeAPI}");
+
+ // CPU is the default execution provider, so no explicit append is necessary
+ // sessionOptions.AppendExecutionProvider_CPU(0);
+
+ /*
+ switch (OSName)
+ {
+ case "Windows": //Can use CUDA, DirectML
+ if (ComputeAPI is ComputeName.CUDA)
+ {
+ //CUDA
+ //sessionOptions.AppendExecutionProvider_CUDA(0);
+ //sessionOptions.AppendExecutionProvider_DML(0);
+ }
+ else if (ComputeAPI is ComputeName.DirectML)
+ {
+ //DirectML
+ //sessionOptions.AppendExecutionProvider_DML(0);
+ }
+ break;
+ case "X11": //Can use CUDA, ROCm
+ if (ComputeAPI is ComputeName.CUDA)
+ {
+ //CUDA
+ //sessionOptions.AppendExecutionProvider_CUDA(0);
+ }
+ if (ComputeAPI is ComputeName.ROCm)
+ {
+ //ROCm, only works on x86
+ //Research indicates that this has to be compiled as a GDNative plugin
+ //GD.Print("ROCm not supported yet, using CPU.");
+ //sessionOptions.AppendExecutionProvider_CPU(0);
+ }
+ break;
+ case "macOS": //Can use CoreML
+ if (ComputeAPI is ComputeName.CoreML)
+ { //CoreML
+ //TODO: Needs testing
+ //sessionOptions.AppendExecutionProvider_CoreML(0);
+ //CoreML on ARM64, out of the box, on x64 needs .tar file from GitHub
+ }
+ break;
+ default:
+ GD.Print("OS not Supported.");
+ break;
+ }
+ */
+ }
+
+
+ ///
+ public static ComputeName ComputeCheck()
+ {
+ string adapterName = Godot.RenderingServer.GetVideoAdapterName();
+ //string adapterVendor = Godot.RenderingServer.GetVideoAdapterVendor();
+ adapterName = adapterName.ToUpper(new System.Globalization.CultureInfo(""));
+ //TODO: GPU vendors for MacOS, what do they even use these days?
+
+ if (adapterName.Contains("INTEL"))
+ {
+ return ComputeName.DirectML;
+ }
+ if (adapterName.Contains("AMD") || adapterName.Contains("RADEON"))
+ {
+ return ComputeName.DirectML;
+ }
+ if (adapterName.Contains("NVIDIA"))
+ {
+ return ComputeName.CUDA;
+ }
+
+ GD.Print("Graphics Card not recognized."); //Should use CPU
+ return ComputeName.CPU;
+ }
+ }
+}
diff --git a/examples/Platform2D/addons/godot_rl_agents/onnx/csharp/docs/ONNXInference.xml b/examples/Platform2D/addons/godot_rl_agents/onnx/csharp/docs/ONNXInference.xml
new file mode 100644
index 0000000..91b07d6
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/onnx/csharp/docs/ONNXInference.xml
@@ -0,0 +1,31 @@
+
+
+
+
+ The main ONNXInference Class that handles the inference process.
+
+
+
+
+ Starts the inference process.
+
+ Path to the ONNX model, expects a path inside resources.
+ How many observations will the model receive.
+
+
+
+ Runs the given input through the model and returns the output.
+
+ Dictionary containing all observations.
+ How many different agents are creating these observations.
+ A Dictionary of arrays, containing instructions based on the observations.
+
+
+
+ Loads the given model into the inference process, using the best Execution provider available.
+
+ Path to the ONNX model, expects a path inside resources.
+ InferenceSession ready to run.
+
+
+
\ No newline at end of file
diff --git a/examples/Platform2D/addons/godot_rl_agents/onnx/csharp/docs/SessionConfigurator.xml b/examples/Platform2D/addons/godot_rl_agents/onnx/csharp/docs/SessionConfigurator.xml
new file mode 100644
index 0000000..f160c02
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/onnx/csharp/docs/SessionConfigurator.xml
@@ -0,0 +1,29 @@
+
+
+
+
+ The main SessionConfigurator Class that handles the execution options and providers for the inference process.
+
+
+
+
+ Creates a SessionOptions with all available execution providers.
+
+ SessionOptions with all available execution providers.
+
+
+
+ Appends any execution provider available in the current system.
+
+
+ This function is mainly verbose for tracking implementation progress of different compute APIs.
+
+
+
+
+ Checks for available GPUs.
+
+ An integer identifier for each compute platform.
+
+
+
\ No newline at end of file
diff --git a/examples/Platform2D/addons/godot_rl_agents/onnx/wrapper/ONNX_wrapper.gd b/examples/Platform2D/addons/godot_rl_agents/onnx/wrapper/ONNX_wrapper.gd
new file mode 100644
index 0000000..e27f2c3
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/onnx/wrapper/ONNX_wrapper.gd
@@ -0,0 +1,51 @@
+extends Resource
+class_name ONNXModel
+var inferencer_script = load("res://addons/godot_rl_agents/onnx/csharp/ONNXInference.cs")
+
+var inferencer = null
+
+## How many action values the model outputs
+var action_output_size: int
+
+## Used to differentiate models
+## that only output continuous action mean (e.g. sb3, cleanrl export)
+## versus models that output mean and logstd (e.g. rllib export)
+var action_means_only: bool
+
+## Whether action_means_only has been set already for this model
+var action_means_only_set: bool
+
+# Must provide the path to the model and the batch size
+func _init(model_path, batch_size):
+ inferencer = inferencer_script.new()
+ action_output_size = inferencer.Initialize(model_path, batch_size)
+
+# This function is the one that will be called from the game,
+# requires the observation as an array and the state_ins as an int
+# returns a Dictionary with the model outputs, as produced by the C# RunInference call.
+func run_inference(obs: Array, state_ins: int) -> Dictionary:
+ if inferencer == null:
+ printerr("Inferencer not initialized")
+ return {}
+ return inferencer.RunInference(obs, state_ins)
+
+
+# On deletion, disposes the C# inferencer's ONNX objects and then frees the inferencer itself.
+func _notification(what):
+	if what == NOTIFICATION_PREDELETE and inferencer != null:  # inferencer is still null if _init never assigned it
+		inferencer.FreeDisposables()
+		inferencer.free()
+# Check whether agent uses a continuous actions model with only action means or not
+func set_action_means_only(agent_action_space):
+ action_means_only_set = true
+ var continuous_only: bool = true
+ var continuous_actions: int
+ for action in agent_action_space:
+ if not agent_action_space[action]["action_type"] == "continuous":
+ continuous_only = false
+ break
+ else:
+ continuous_actions += agent_action_space[action]["size"]
+ if continuous_only:
+ if continuous_actions == action_output_size:
+ action_means_only = true
diff --git a/examples/Platform2D/addons/godot_rl_agents/plugin.cfg b/examples/Platform2D/addons/godot_rl_agents/plugin.cfg
new file mode 100644
index 0000000..b1bc988
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/plugin.cfg
@@ -0,0 +1,7 @@
+[plugin]
+
+name="GodotRLAgents"
+description="Custom nodes for the godot rl agents toolkit "
+author="Edward Beeching"
+version="0.1"
+script="godot_rl_agents.gd"
diff --git a/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_2d/ExampleRaycastSensor2D.tscn b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_2d/ExampleRaycastSensor2D.tscn
new file mode 100644
index 0000000..5edb6c7
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_2d/ExampleRaycastSensor2D.tscn
@@ -0,0 +1,48 @@
+[gd_scene load_steps=5 format=3 uid="uid://ddeq7mn1ealyc"]
+
+[ext_resource type="Script" path="res://addons/godot_rl_agents/sensors/sensors_2d/RaycastSensor2D.gd" id="1"]
+
+[sub_resource type="GDScript" id="2"]
+script/source = "extends Node2D
+
+
+
+func _physics_process(delta: float) -> void:
+ print(\"step start\")
+
+"
+
+[sub_resource type="GDScript" id="1"]
+script/source = "extends RayCast2D
+
+var steps = 1
+
+func _physics_process(delta: float) -> void:
+ print(\"processing raycast\")
+ steps += 1
+ if steps % 2:
+ force_raycast_update()
+
+ print(is_colliding())
+"
+
+[sub_resource type="CircleShape2D" id="3"]
+
+[node name="ExampleRaycastSensor2D" type="Node2D"]
+script = SubResource("2")
+
+[node name="ExampleAgent" type="Node2D" parent="."]
+position = Vector2(573, 314)
+rotation = 0.286234
+
+[node name="RaycastSensor2D" type="Node2D" parent="ExampleAgent"]
+script = ExtResource("1")
+
+[node name="TestRayCast2D" type="RayCast2D" parent="."]
+script = SubResource("1")
+
+[node name="StaticBody2D" type="StaticBody2D" parent="."]
+position = Vector2(1, 52)
+
+[node name="CollisionShape2D" type="CollisionShape2D" parent="StaticBody2D"]
+shape = SubResource("3")
diff --git a/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_2d/GridSensor2D.gd b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_2d/GridSensor2D.gd
new file mode 100644
index 0000000..48b132e
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_2d/GridSensor2D.gd
@@ -0,0 +1,235 @@
+@tool
+extends ISensor2D
+class_name GridSensor2D
+
+@export var debug_view := false:
+ get:
+ return debug_view
+ set(value):
+ debug_view = value
+ _update()
+
+@export_flags_2d_physics var detection_mask := 0:
+ get:
+ return detection_mask
+ set(value):
+ detection_mask = value
+ _update()
+
+@export var collide_with_areas := false:
+ get:
+ return collide_with_areas
+ set(value):
+ collide_with_areas = value
+ _update()
+
+@export var collide_with_bodies := true:
+ get:
+ return collide_with_bodies
+ set(value):
+ collide_with_bodies = value
+ _update()
+
+@export_range(1, 200, 0.1) var cell_width := 20.0:
+ get:
+ return cell_width
+ set(value):
+ cell_width = value
+ _update()
+
+@export_range(1, 200, 0.1) var cell_height := 20.0:
+ get:
+ return cell_height
+ set(value):
+ cell_height = value
+ _update()
+
+@export_range(1, 21, 2, "or_greater") var grid_size_x := 3:
+ get:
+ return grid_size_x
+ set(value):
+ grid_size_x = value
+ _update()
+
+@export_range(1, 21, 2, "or_greater") var grid_size_y := 3:
+ get:
+ return grid_size_y
+ set(value):
+ grid_size_y = value
+ _update()
+
+var _obs_buffer: PackedFloat64Array
+var _rectangle_shape: RectangleShape2D
+var _collision_mapping: Dictionary
+var _n_layers_per_cell: int
+
+var _highlighted_cell_color: Color
+var _standard_cell_color: Color
+
+
+func get_observation():
+ return _obs_buffer
+
+
+func _update():
+ if Engine.is_editor_hint():
+ if is_node_ready():
+ _spawn_nodes()
+
+
+func _ready() -> void:
+ _set_colors()
+
+ if Engine.is_editor_hint():
+ if get_child_count() == 0:
+ _spawn_nodes()
+ else:
+ _spawn_nodes()
+
+
+func _set_colors() -> void:
+ _standard_cell_color = Color(100.0 / 255.0, 100.0 / 255.0, 100.0 / 255.0, 100.0 / 255.0)
+ _highlighted_cell_color = Color(255.0 / 255.0, 100.0 / 255.0, 100.0 / 255.0, 100.0 / 255.0)
+
+
+func _get_collision_mapping() -> Dictionary:
+ # defines which layer is mapped to which cell obs index
+ var total_bits = 0
+ var collision_mapping = {}
+ for i in 32:
+ var bit_mask = 2 ** i
+ if (detection_mask & bit_mask) > 0:
+ collision_mapping[i] = total_bits
+ total_bits += 1
+
+ return collision_mapping
+
+
+func _spawn_nodes():
+ for cell in get_children():
+ cell.name = "_%s" % cell.name # Otherwise naming below will fail
+ cell.queue_free()
+
+ _collision_mapping = _get_collision_mapping()
+ #prints("collision_mapping", _collision_mapping, len(_collision_mapping))
+ # allocate memory for the observations
+ _n_layers_per_cell = len(_collision_mapping)
+ _obs_buffer = PackedFloat64Array()
+ _obs_buffer.resize(grid_size_x * grid_size_y * _n_layers_per_cell)
+ _obs_buffer.fill(0)
+ #prints(len(_obs_buffer), _obs_buffer )
+
+ _rectangle_shape = RectangleShape2D.new()
+ _rectangle_shape.set_size(Vector2(cell_width, cell_height))
+
+ var shift := Vector2(
+ -(grid_size_x / 2) * cell_width,
+ -(grid_size_y / 2) * cell_height,
+ )
+
+ for i in grid_size_x:
+ for j in grid_size_y:
+ var cell_position = Vector2(i * cell_width, j * cell_height) + shift
+ _create_cell(i, j, cell_position)
+
+
+func _create_cell(i: int, j: int, position: Vector2):
+ var cell := Area2D.new()
+ cell.position = position
+ cell.name = "GridCell %s %s" % [i, j]
+ cell.modulate = _standard_cell_color
+
+ if collide_with_areas:
+ cell.area_entered.connect(_on_cell_area_entered.bind(i, j))
+ cell.area_exited.connect(_on_cell_area_exited.bind(i, j))
+
+ if collide_with_bodies:
+ cell.body_entered.connect(_on_cell_body_entered.bind(i, j))
+ cell.body_exited.connect(_on_cell_body_exited.bind(i, j))
+
+ cell.collision_layer = 0
+ cell.collision_mask = detection_mask
+ cell.monitorable = true
+ add_child(cell)
+ cell.set_owner(get_tree().edited_scene_root)
+
+ var col_shape := CollisionShape2D.new()
+ col_shape.shape = _rectangle_shape
+ col_shape.name = "CollisionShape2D"
+ cell.add_child(col_shape)
+ col_shape.set_owner(get_tree().edited_scene_root)
+
+ if debug_view:
+ var quad = MeshInstance2D.new()
+ quad.name = "MeshInstance2D"
+ var quad_mesh = QuadMesh.new()
+
+ quad_mesh.set_size(Vector2(cell_width, cell_height))
+
+ quad.mesh = quad_mesh
+ cell.add_child(quad)
+ quad.set_owner(get_tree().edited_scene_root)
+
+
+func _update_obs(cell_i: int, cell_j: int, collision_layer: int, entered: bool):
+ for key in _collision_mapping:
+ var bit_mask = 2 ** key
+ if (collision_layer & bit_mask) > 0:
+ var collison_map_index = _collision_mapping[key]
+
+ var obs_index = (
+ (cell_i * grid_size_y * _n_layers_per_cell)
+ + (cell_j * _n_layers_per_cell)
+ + collison_map_index
+ )
+ #prints(obs_index, cell_i, cell_j)
+ if entered:
+ _obs_buffer[obs_index] += 1
+ else:
+ _obs_buffer[obs_index] -= 1
+
+
+# Debug helper: tints cell (cell_i, cell_j) depending on whether any of its layer counters is above zero.
+func _toggle_cell(cell_i: int, cell_j: int):
+	var cell = get_node_or_null("GridCell %s %s" % [cell_i, cell_j])
+	if cell == null:
+		print("cell not found, returning")
+		return  # without this, the modulate assignments below would run on null
+	var n_hits = 0
+	# index of this cell's first layer slot in _obs_buffer; the cell's layers follow it contiguously
+	var start_index = (cell_i * grid_size_y * _n_layers_per_cell) + (cell_j * _n_layers_per_cell)
+	for i in _n_layers_per_cell:
+		n_hits += _obs_buffer[start_index + i]
+	if n_hits > 0:
+		cell.modulate = _highlighted_cell_color
+	else:
+		cell.modulate = _standard_cell_color
+
+
+func _on_cell_area_entered(area: Area2D, cell_i: int, cell_j: int):
+ #prints("_on_cell_area_entered", cell_i, cell_j)
+ _update_obs(cell_i, cell_j, area.collision_layer, true)
+ if debug_view:
+ _toggle_cell(cell_i, cell_j)
+ #print(_obs_buffer)
+
+
+func _on_cell_area_exited(area: Area2D, cell_i: int, cell_j: int):
+ #prints("_on_cell_area_exited", cell_i, cell_j)
+ _update_obs(cell_i, cell_j, area.collision_layer, false)
+ if debug_view:
+ _toggle_cell(cell_i, cell_j)
+
+
+func _on_cell_body_entered(body: Node2D, cell_i: int, cell_j: int):
+ #prints("_on_cell_body_entered", cell_i, cell_j)
+ _update_obs(cell_i, cell_j, body.collision_layer, true)
+ if debug_view:
+ _toggle_cell(cell_i, cell_j)
+
+
+func _on_cell_body_exited(body: Node2D, cell_i: int, cell_j: int):
+ #prints("_on_cell_body_exited", cell_i, cell_j)
+ _update_obs(cell_i, cell_j, body.collision_layer, false)
+ if debug_view:
+ _toggle_cell(cell_i, cell_j)
diff --git a/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_2d/ISensor2D.gd b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_2d/ISensor2D.gd
new file mode 100644
index 0000000..67669a1
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_2d/ISensor2D.gd
@@ -0,0 +1,25 @@
+extends Node2D
+class_name ISensor2D
+
+var _obs: Array = []
+var _active := false
+
+
+func get_observation():
+ pass
+
+
+func activate():
+ _active = true
+
+
+func deactivate():
+ _active = false
+
+
+func _update_observation():
+ pass
+
+
+func reset():
+ pass
diff --git a/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_2d/RaycastSensor2D.gd b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_2d/RaycastSensor2D.gd
new file mode 100644
index 0000000..9bb54ed
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_2d/RaycastSensor2D.gd
@@ -0,0 +1,118 @@
+@tool
+extends ISensor2D
+class_name RaycastSensor2D
+## Sensor that owns a fan of RayCast2D children and reports, for every ray,
+## how close the nearest hit is (see _get_raycast_distance()).
+## Each exported property below calls _update() from its setter.
+
+## Physics layers the rays collide with; copied to every ray in _spawn_nodes().
+@export_flags_2d_physics var collision_mask := 1:
+	get:
+		return collision_mask
+	set(value):
+		collision_mask = value
+		_update()
+
+## Whether the rays report hits on areas; copied to every ray.
+@export var collide_with_areas := false:
+	get:
+		return collide_with_areas
+	set(value):
+		collide_with_areas = value
+		_update()
+
+## Whether the rays report hits on physics bodies; copied to every ray.
+@export var collide_with_bodies := true:
+	get:
+		return collide_with_bodies
+	set(value):
+		collide_with_bodies = value
+		_update()
+
+## Number of rays created by _spawn_nodes(), spread evenly over cone_width.
+@export var n_rays := 16.0:
+	get:
+		return n_rays
+	set(value):
+		n_rays = value
+		_update()
+
+## Length of every ray; also the distance used to normalize observations.
+@export_range(5, 3000, 5.0) var ray_length := 200:
+	get:
+		return ray_length
+	set(value):
+		ray_length = value
+		_update()
+## Total opening angle of the ray fan, in degrees.
+@export_range(5, 360, 5.0) var cone_width := 360.0:
+	get:
+		return cone_width
+	set(value):
+		cone_width = value
+		_update()
+
+## Editor-only switch checked by _update(): when true the rays are re-spawned
+## on property changes, when false the RayCast2D children are removed.
+@export var debug_draw := true:
+	get:
+		return debug_draw
+	set(value):
+		debug_draw = value
+		_update()
+
+# Angle (in degrees) of each spawned ray, filled by _spawn_nodes().
+var _angles = []
+# The RayCast2D children created by _spawn_nodes(), in creation order.
+var rays := []
+
+
+## Editor-only refresh, called from the setter of every exported property.
+## With debug_draw enabled the rays are rebuilt so the preview matches the
+## current settings; otherwise every RayCast2D child is removed and freed.
+## Does nothing when the script runs outside the editor.
+func _update():
+	if not Engine.is_editor_hint():
+		return
+
+	if debug_draw:
+		_spawn_nodes()
+		return
+
+	for child in get_children():
+		if child is RayCast2D:
+			remove_child(child)
+			# remove_child() only detaches the node; free it as well so it does
+			# not linger as an orphan.
+			child.queue_free()
+	# Drop the references to the rays that were just freed, so that later code
+	# (e.g. _spawn_nodes()) never touches a freed instance.
+	rays = []
+
+
+## Builds the rays once the node is ready. The script is a @tool script, so
+## this also runs inside the editor.
+func _ready() -> void:
+	_spawn_nodes()
+
+
+## Frees the current rays and creates n_rays new RayCast2D children, spread
+## evenly over cone_width degrees and symmetric around the local +X axis.
+## Also refills _angles with the angle (in degrees) of every ray.
+func _spawn_nodes():
+	for old_ray in rays:
+		old_ray.queue_free()
+	rays = []
+	_angles = []
+
+	var angle_step = cone_width / (n_rays)
+	# Center of the first slice, which keeps the fan symmetric around 0 degrees.
+	var first_angle = angle_step / 2 - cone_width / 2
+
+	for i in n_rays:
+		var angle_deg = first_angle + i * angle_step
+		var angle_rad = deg_to_rad(angle_deg)
+
+		var ray = RayCast2D.new()
+		ray.name = "node_" + str(i)
+		ray.target_position = Vector2(ray_length * cos(angle_rad), ray_length * sin(angle_rad))
+		# Rays stay disabled; calculate_raycasts() enables them only while querying.
+		ray.enabled = false
+		ray.collide_with_areas = collide_with_areas
+		ray.collide_with_bodies = collide_with_bodies
+		ray.collision_mask = collision_mask
+
+		add_child(ray)
+		rays.append(ray)
+		_angles.append(angle_deg)
+
+
+## Returns one normalized closeness value per ray (see calculate_raycasts()).
+func get_observation() -> Array:
+	return calculate_raycasts()
+
+
+## Forces an immediate update of every ray and collects the result of
+## _get_raycast_distance() for each, in the order the rays were spawned.
+func calculate_raycasts() -> Array:
+	var closeness_per_ray = []
+	for ray in rays:
+		# Enable the ray only for the duration of this query.
+		ray.enabled = true
+		ray.force_raycast_update()
+		closeness_per_ray.append(_get_raycast_distance(ray))
+		ray.enabled = false
+	return closeness_per_ray
+
+
+## Converts a ray hit into a closeness value: 0.0 when the ray hits nothing
+## (or the hit is at least ray_length away), approaching 1.0 as the hit point
+## gets closer to this node's global position.
+func _get_raycast_distance(ray: RayCast2D) -> float:
+	if not ray.is_colliding():
+		return 0.0
+
+	var to_hit = global_position - ray.get_collision_point()
+	var hit_distance = clampf(to_hit.length(), 0.0, ray_length)
+	return (ray_length - hit_distance) / ray_length
diff --git a/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_2d/RaycastSensor2D.tscn b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_2d/RaycastSensor2D.tscn
new file mode 100644
index 0000000..5ca402c
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_2d/RaycastSensor2D.tscn
@@ -0,0 +1,7 @@
+[gd_scene load_steps=2 format=3 uid="uid://drvfihk5esgmv"]
+
+[ext_resource type="Script" path="res://addons/godot_rl_agents/sensors/sensors_2d/RaycastSensor2D.gd" id="1"]
+
+[node name="RaycastSensor2D" type="Node2D"]
+script = ExtResource("1")
+n_rays = 17.0
diff --git a/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/ExampleRaycastSensor3D.tscn b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/ExampleRaycastSensor3D.tscn
new file mode 100644
index 0000000..a8057c7
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/ExampleRaycastSensor3D.tscn
@@ -0,0 +1,6 @@
+[gd_scene format=3 uid="uid://biu787qh4woik"]
+
+[node name="ExampleRaycastSensor3D" type="Node3D"]
+
+[node name="Camera3D" type="Camera3D" parent="."]
+transform = Transform3D(1, 0, 0, 0, 1, 0, 0, 0, 1, 0.804183, 0, 2.70146)
diff --git a/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/GridSensor3D.gd b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/GridSensor3D.gd
new file mode 100644
index 0000000..24de9a4
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/GridSensor3D.gd
@@ -0,0 +1,258 @@
+@tool
+extends ISensor3D
+class_name GridSensor3D
+## Sensor made of a grid_size_x by grid_size_z grid of Area3D cells. For every
+## cell and every layer enabled in detection_mask it keeps a counter of the
+## objects currently overlapping the cell.
+## Each exported property below calls _update() from its setter.
+
+## When true, _create_cell() adds a box mesh to every cell and the signal
+## handlers switch its material while the cell is occupied.
+@export var debug_view := false:
+	get:
+		return debug_view
+	set(value):
+		debug_view = value
+		_update()
+
+## Physics layers to detect. Used as the collision mask of every cell; each
+## enabled layer gets one observation slot per cell.
+@export_flags_3d_physics var detection_mask := 0:
+	get:
+		return detection_mask
+	set(value):
+		detection_mask = value
+		_update()
+
+## When true, cells are connected to area_entered / area_exited.
+@export var collide_with_areas := false:
+	get:
+		return collide_with_areas
+	set(value):
+		collide_with_areas = value
+		_update()
+
+## When true, cells are connected to body_entered / body_exited.
+@export var collide_with_bodies := false:
+	# NOTE! The sensor will not detect StaticBody3D, add an area to static bodies to detect them
+	get:
+		return collide_with_bodies
+	set(value):
+		collide_with_bodies = value
+		_update()
+
+## Size of a cell along X and Z; also the spacing between neighboring cells.
+@export_range(0.1, 2, 0.1) var cell_width := 1.0:
+	get:
+		return cell_width
+	set(value):
+		cell_width = value
+		_update()
+
+## Size of a cell along Y.
+@export_range(0.1, 2, 0.1) var cell_height := 1.0:
+	get:
+		return cell_height
+	set(value):
+		cell_height = value
+		_update()
+
+## Number of cells along X.
+@export_range(1, 21, 1, "or_greater") var grid_size_x := 3:
+	get:
+		return grid_size_x
+	set(value):
+		grid_size_x = value
+		_update()
+
+## Number of cells along Z.
+@export_range(1, 21, 1, "or_greater") var grid_size_z := 3:
+	get:
+		return grid_size_z
+	set(value):
+		grid_size_z = value
+		_update()
+
+# Flat counter buffer; the slot for cell (i, j) and layer slot k is
+# (i * grid_size_z + j) * _n_layers_per_cell + k (see _update_obs()).
+var _obs_buffer: PackedFloat64Array
+# Collision shape resource shared by all cells.
+var _box_shape: BoxShape3D
+# Maps a physics layer bit index to its slot inside a cell's observation.
+var _collision_mapping: Dictionary
+# Number of entries in _collision_mapping.
+var _n_layers_per_cell: int
+
+# Debug materials for occupied / empty cells, built by _make_materials().
+var _highlighted_box_material: StandardMaterial3D
+var _standard_box_material: StandardMaterial3D
+
+
+## Returns the live counter buffer (not a copy of it is made here); see
+## _update_obs() for the index layout.
+func get_observation():
+	return _obs_buffer
+
+
+## Zeroes every counter in the observation buffer.
+func reset():
+	_obs_buffer.fill(0)
+
+
+## Rebuilds the grid after a property change, but only inside the editor and
+## only once the node is ready.
+func _update():
+	if Engine.is_editor_hint() and is_node_ready():
+		_spawn_nodes()
+
+
+## Prepares the debug materials and builds the grid. In the editor the grid
+## is only built when the node has no children yet; in-game it is always built.
+func _ready() -> void:
+	_make_materials()
+
+	var needs_spawn := not Engine.is_editor_hint() or get_child_count() == 0
+	if needs_spawn:
+		_spawn_nodes()
+
+
+## Creates the two semi-transparent debug materials (grey for empty cells,
+## red for occupied ones). Does nothing when both already exist.
+func _make_materials() -> void:
+	var already_built := _highlighted_box_material != null and _standard_box_material != null
+	if already_built:
+		return
+
+	var grey := Color(100.0 / 255.0, 100.0 / 255.0, 100.0 / 255.0, 100.0 / 255.0)
+	var red := Color(255.0 / 255.0, 100.0 / 255.0, 100.0 / 255.0, 100.0 / 255.0)
+
+	_standard_box_material = StandardMaterial3D.new()
+	_standard_box_material.transparency = BaseMaterial3D.TRANSPARENCY_ALPHA
+	_standard_box_material.albedo_color = grey
+
+	_highlighted_box_material = StandardMaterial3D.new()
+	_highlighted_box_material.transparency = BaseMaterial3D.TRANSPARENCY_ALPHA
+	_highlighted_box_material.albedo_color = red
+
+
+## Builds the layer-to-slot table: every bit index (0..31) that is set in
+## detection_mask is mapped to the next free slot, in ascending bit order.
+func _get_collision_mapping() -> Dictionary:
+	var mapping = {}
+	for layer_bit in 32:
+		if (detection_mask & (1 << layer_bit)) > 0:
+			# The number of layers mapped so far is the next free slot.
+			var slot = mapping.size()
+			mapping[layer_bit] = slot
+	return mapping
+
+
+## Rebuilds the whole grid: frees the existing children, recomputes the layer
+## mapping, reallocates the zeroed observation buffer and creates one cell per
+## grid position.
+func _spawn_nodes():
+	for old_cell in get_children():
+		# Rename before freeing, otherwise naming the new cells below will fail.
+		old_cell.name = "_%s" % old_cell.name
+		old_cell.queue_free()
+
+	_collision_mapping = _get_collision_mapping()
+	_n_layers_per_cell = len(_collision_mapping)
+
+	# One counter per (cell, detected layer), all starting at zero.
+	_obs_buffer = PackedFloat64Array()
+	_obs_buffer.resize(grid_size_x * grid_size_z * _n_layers_per_cell)
+	_obs_buffer.fill(0)
+
+	# A single shape resource, shared by every cell.
+	_box_shape = BoxShape3D.new()
+	_box_shape.size = Vector3(cell_width, cell_height, cell_width)
+
+	# Offset applied to every cell position; note the integer divisions.
+	var origin_shift := Vector3(
+		-(grid_size_x / 2) * cell_width,
+		0,
+		-(grid_size_z / 2) * cell_width,
+	)
+
+	for i in grid_size_x:
+		for j in grid_size_z:
+			var local_position = Vector3(i * cell_width, 0.0, j * cell_width) + origin_shift
+			_create_cell(i, j, local_position)
+
+
+## Creates the Area3D for grid position (i, j) at the given local position,
+## connects the enter/exit signals selected by collide_with_areas and
+## collide_with_bodies (with i and j bound), and adds its collision shape.
+## With debug_view enabled a box mesh is added as well.
+func _create_cell(i: int, j: int, position: Vector3):
+	var cell := Area3D.new()
+	cell.name = "GridCell %s %s" % [i, j]
+	cell.position = position
+	# The cell only detects; it is not placed on any layer itself.
+	cell.collision_layer = 0
+	cell.collision_mask = detection_mask
+	cell.monitorable = true
+	cell.input_ray_pickable = false
+
+	if collide_with_areas:
+		cell.area_entered.connect(_on_cell_area_entered.bind(i, j))
+		cell.area_exited.connect(_on_cell_area_exited.bind(i, j))
+
+	if collide_with_bodies:
+		cell.body_entered.connect(_on_cell_body_entered.bind(i, j))
+		cell.body_exited.connect(_on_cell_body_exited.bind(i, j))
+
+	add_child(cell)
+	# Every created node gets the edited scene root as owner.
+	var scene_root = get_tree().edited_scene_root
+	cell.set_owner(scene_root)
+
+	var col_shape := CollisionShape3D.new()
+	col_shape.name = "CollisionShape3D"
+	col_shape.shape = _box_shape
+	cell.add_child(col_shape)
+	col_shape.set_owner(scene_root)
+
+	if not debug_view:
+		return
+
+	# Each cell gets its own BoxMesh, so _toggle_cell() can change the material
+	# of one cell without affecting the others.
+	var box_mesh = BoxMesh.new()
+	box_mesh.size = Vector3(cell_width, cell_height, cell_width)
+	box_mesh.material = _standard_box_material
+
+	var box = MeshInstance3D.new()
+	box.name = "MeshInstance3D"
+	box.mesh = box_mesh
+	cell.add_child(box)
+	box.set_owner(scene_root)
+
+
+## Adjusts the counters of cell (cell_i, cell_j): for every mapped layer that
+## is set in collision_layer the counter is increased by one when `entered`
+## is true and decreased by one otherwise.
+func _update_obs(cell_i: int, cell_j: int, collision_layer: int, entered: bool):
+	# Index of the cell's first counter inside the flat buffer.
+	var cell_offset = (cell_i * grid_size_z * _n_layers_per_cell) + (cell_j * _n_layers_per_cell)
+	var delta = 1 if entered else -1
+
+	for layer_bit in _collision_mapping:
+		if (collision_layer & (1 << layer_bit)) <= 0:
+			continue
+		var slot = _collision_mapping[layer_bit]
+		_obs_buffer[cell_offset + slot] += delta
+
+
+## Updates the debug material of cell (cell_i, cell_j): highlighted while the
+## sum of the cell's counters is positive, standard otherwise.
+## Returns without doing anything when the cell node or its debug mesh does
+## not exist.
+func _toggle_cell(cell_i: int, cell_j: int):
+	var cell = get_node_or_null("GridCell %s %s" % [cell_i, cell_j])
+
+	if cell == null:
+		print("cell not found, returning")
+		# Actually return: everything below dereferences `cell`.
+		return
+
+	var cell_mesh = cell.get_node_or_null("MeshInstance3D")
+	if cell_mesh == null:
+		# _create_cell() only adds the mesh when debug_view is enabled, so the
+		# cell may have been built without one.
+		return
+
+	var n_hits = 0
+	var start_index = (cell_i * grid_size_z * _n_layers_per_cell) + (cell_j * _n_layers_per_cell)
+	for i in _n_layers_per_cell:
+		n_hits += _obs_buffer[start_index + i]
+
+	if n_hits > 0:
+		cell_mesh.mesh.material = _highlighted_box_material
+	else:
+		cell_mesh.mesh.material = _standard_box_material
+
+
+## Handler for a cell's area_entered signal (connected in _create_cell() with
+## the cell indices bound). Registers the hit and, with debug_view on,
+## refreshes the cell's debug material.
+func _on_cell_area_entered(area: Area3D, cell_i: int, cell_j: int):
+	var entering_layers: int = area.collision_layer
+	_update_obs(cell_i, cell_j, entering_layers, true)
+	if not debug_view:
+		return
+	_toggle_cell(cell_i, cell_j)
+
+
+## Handler for a cell's area_exited signal (connected in _create_cell()).
+## Unregisters the hit and, with debug_view on, refreshes the debug material.
+func _on_cell_area_exited(area: Area3D, cell_i: int, cell_j: int):
+	var leaving_layers: int = area.collision_layer
+	_update_obs(cell_i, cell_j, leaving_layers, false)
+	if not debug_view:
+		return
+	_toggle_cell(cell_i, cell_j)
+
+
+## Handler for a cell's body_entered signal (connected in _create_cell()).
+## Registers the hit and, with debug_view on, refreshes the debug material.
+func _on_cell_body_entered(body: Node3D, cell_i: int, cell_j: int):
+	var entering_layers = body.collision_layer
+	_update_obs(cell_i, cell_j, entering_layers, true)
+	if not debug_view:
+		return
+	_toggle_cell(cell_i, cell_j)
+
+
+## Handler for a cell's body_exited signal (connected in _create_cell()).
+## Unregisters the hit and, with debug_view on, refreshes the debug material.
+func _on_cell_body_exited(body: Node3D, cell_i: int, cell_j: int):
+	var leaving_layers = body.collision_layer
+	_update_obs(cell_i, cell_j, leaving_layers, false)
+	if not debug_view:
+		return
+	_toggle_cell(cell_i, cell_j)
diff --git a/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/ISensor3D.gd b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/ISensor3D.gd
new file mode 100644
index 0000000..aca3c2d
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/ISensor3D.gd
@@ -0,0 +1,25 @@
+extends Node3D
+class_name ISensor3D
+## Base script for 3D sensors. It declares the methods a sensor exposes
+## (observation, activation, reset); the bodies here are placeholders that
+## subclasses such as GridSensor3D override.
+
+# Observation storage; this base script itself never reads or writes it.
+var _obs: Array = []
+# Set by activate() and cleared by deactivate().
+var _active := false
+
+
+## Returns the sensor's current observation.
+## Placeholder: the base implementation does nothing and yields null.
+func get_observation():
+	pass
+
+
+## Marks the sensor as active.
+func activate():
+	_active = true
+
+
+## Marks the sensor as inactive.
+func deactivate():
+	_active = false
+
+
+## Hook for refreshing the observation. Empty in the base script.
+func _update_observation():
+	pass
+
+
+## Hook for resetting the sensor's state. Empty in the base script.
+func reset():
+	pass
diff --git a/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/RGBCameraSensor3D.gd b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/RGBCameraSensor3D.gd
new file mode 100644
index 0000000..96dfb6a
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/RGBCameraSensor3D.gd
@@ -0,0 +1,63 @@
+extends Node3D
+class_name RGBCameraSensor3D
+# Not referenced anywhere else in this script.
+var camera_pixels = null
+
+# Sprite showing the rendered viewport image.
+@onready var camera_texture := $Control/CameraTexture as Sprite2D
+# Sprite showing the downscaled image (only used when downscaling is enabled).
+@onready var processed_texture := $Control/ProcessedTexture as Sprite2D
+# Viewport whose size is set to render_image_resolution in _ready().
+@onready var sub_viewport := $SubViewport as SubViewport
+# Texture backing processed_texture; created lazily in get_camera_pixel_encoding().
+@onready var displayed_image: ImageTexture
+
+## Resolution the sub-viewport renders at.
+@export var render_image_resolution := Vector2(36, 36)
+## Display size does not affect rendered or sent image resolution.
+## Scale is relative to either render image or downscale image resolution
+## depending on which mode is set.
+@export var displayed_image_scale_factor := Vector2(8, 8)
+
+@export_group("Downscale image options")
+## Enable to downscale the rendered image before sending the obs.
+@export var downscale_image: bool = false
+## If downscale_image is true, will display the downscaled image instead of rendered image.
+@export var display_downscaled_image: bool = true
+## This is the resolution of the image that will be sent after downscaling
+@export var resized_image_resolution := Vector2(36, 36)
+
+
+## Applies the render resolution and display scale, then hides whichever of
+## the two preview sprites is not in use.
+func _ready():
+	sub_viewport.size = render_image_resolution
+	camera_texture.scale = displayed_image_scale_factor
+
+	var show_processed := downscale_image and display_downscaled_image
+	if not show_processed:
+		processed_texture.visible = false
+		return
+
+	camera_texture.visible = false
+	processed_texture.scale = displayed_image_scale_factor
+
+
+## Returns the current camera image as a hex-encoded string of its raw data.
+## With downscale_image enabled the image is first resized (nearest-neighbor)
+## to resized_image_resolution and, if display_downscaled_image is set, shown
+## on the processed_texture sprite.
+func get_camera_pixel_encoding():
+	var image := camera_texture.get_texture().get_image() as Image
+	if not downscale_image:
+		return image.get_data().hex_encode()
+
+	image.resize(
+		resized_image_resolution.x, resized_image_resolution.y, Image.INTERPOLATE_NEAREST
+	)
+	if display_downscaled_image:
+		if processed_texture.texture:
+			# Reuse the texture created on an earlier call.
+			displayed_image.update(image)
+		else:
+			displayed_image = ImageTexture.create_from_image(image)
+			processed_texture.texture = displayed_image
+
+	return image.get_data().hex_encode()
+
+
+## Returns the shape of the sent image as [channels, height, width], where
+## channels is 4 when the sub-viewport has a transparent background and 3
+## otherwise. Asserts that the sent resolution is at least 36x36.
+func get_camera_shape() -> Array:
+	var size = resized_image_resolution if downscale_image else render_image_resolution
+
+	assert(
+		size.x >= 36 and size.y >= 36,
+		"Camera sensor sent image resolution must be 36x36 or larger."
+	)
+	var channels = 4 if sub_viewport.transparent_bg else 3
+	return [channels, size.y, size.x]
diff --git a/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/RGBCameraSensor3D.tscn b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/RGBCameraSensor3D.tscn
new file mode 100644
index 0000000..d58649c
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/RGBCameraSensor3D.tscn
@@ -0,0 +1,35 @@
+[gd_scene load_steps=3 format=3 uid="uid://baaywi3arsl2m"]
+
+[ext_resource type="Script" path="res://addons/godot_rl_agents/sensors/sensors_3d/RGBCameraSensor3D.gd" id="1"]
+
+[sub_resource type="ViewportTexture" id="ViewportTexture_y72s3"]
+viewport_path = NodePath("SubViewport")
+
+[node name="RGBCameraSensor3D" type="Node3D"]
+script = ExtResource("1")
+
+[node name="RemoteTransform" type="RemoteTransform3D" parent="."]
+remote_path = NodePath("../SubViewport/Camera")
+
+[node name="SubViewport" type="SubViewport" parent="."]
+size = Vector2i(36, 36)
+render_target_update_mode = 3
+
+[node name="Camera" type="Camera3D" parent="SubViewport"]
+near = 0.5
+
+[node name="Control" type="Control" parent="."]
+layout_mode = 3
+anchors_preset = 15
+anchor_right = 1.0
+anchor_bottom = 1.0
+grow_horizontal = 2
+grow_vertical = 2
+metadata/_edit_use_anchors_ = true
+
+[node name="CameraTexture" type="Sprite2D" parent="Control"]
+texture = SubResource("ViewportTexture_y72s3")
+centered = false
+
+[node name="ProcessedTexture" type="Sprite2D" parent="Control"]
+centered = false
diff --git a/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/RaycastSensor3D.gd b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/RaycastSensor3D.gd
new file mode 100644
index 0000000..1357529
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/RaycastSensor3D.gd
@@ -0,0 +1,185 @@
+@tool
+extends ISensor3D
+class_name RayCastSensor3D
+## Sensor that owns a grid of RayCast3D children (n_rays_width by
+## n_rays_height) and reports, for every ray, how close the nearest hit is.
+## Unless noted otherwise, the exported properties call _update() from their
+## setter.
+
+## Physics layers the rays collide with; copied to every ray.
+@export_flags_3d_physics var collision_mask = 1:
+	get:
+		return collision_mask
+	set(value):
+		collision_mask = value
+		_update()
+## Layers that count as the "positive" class when class_sensor is enabled
+## (see calculate_raycasts()).
+@export_flags_3d_physics var boolean_class_mask = 1:
+	get:
+		return boolean_class_mask
+	set(value):
+		boolean_class_mask = value
+		_update()
+
+## Number of rays across the horizontal opening angle.
+@export var n_rays_width := 6.0:
+	get:
+		return n_rays_width
+	set(value):
+		n_rays_width = value
+		_update()
+
+## Number of rays across the vertical opening angle.
+@export var n_rays_height := 6.0:
+	get:
+		return n_rays_height
+	set(value):
+		n_rays_height = value
+		_update()
+
+## Length of every ray; also the distance used to normalize observations.
+@export var ray_length := 10.0:
+	get:
+		return ray_length
+	set(value):
+		ray_length = value
+		_update()
+
+## Horizontal opening angle of the ray grid, in degrees.
+@export var cone_width := 60.0:
+	get:
+		return cone_width
+	set(value):
+		cone_width = value
+		_update()
+
+## Vertical opening angle of the ray grid, in degrees.
+@export var cone_height := 60.0:
+	get:
+		return cone_height
+	set(value):
+		cone_height = value
+		_update()
+
+## Whether the rays report hits on areas; copied to every ray.
+@export var collide_with_areas := false:
+	get:
+		return collide_with_areas
+	set(value):
+		collide_with_areas = value
+		_update()
+
+## Whether the rays report hits on physics bodies; copied to every ray.
+@export var collide_with_bodies := true:
+	get:
+		return collide_with_bodies
+	set(value):
+		collide_with_bodies = value
+		_update()
+
+## When true, calculate_raycasts() appends a second value per ray telling
+## whether the hit object is on a layer of boolean_class_mask.
+## This property has no setter and does not trigger _update().
+@export var class_sensor := false
+
+# The RayCast3D children created by _spawn_nodes(), in creation order.
+var rays := []
+# Debug line mesh; only ever assigned by the unused _create_debug_lines().
+var geo = null
+
+
+## Rebuilds the rays after a property change, but only inside the editor and
+## only once the node is ready.
+func _update():
+	if Engine.is_editor_hint() and is_node_ready():
+		_spawn_nodes()
+
+
+## Builds the rays. In the editor they are only built when the node has no
+## children yet; in-game they are always built.
+func _ready() -> void:
+	var needs_spawn := not Engine.is_editor_hint() or get_child_count() == 0
+	if needs_spawn:
+		_spawn_nodes()
+
+
+## Frees all children and creates a new grid of RayCast3D nodes covering
+## cone_width by cone_height degrees, each ray pointing to the position given
+## by to_spherical_coords() for its pair of angles.
+func _spawn_nodes():
+	print("spawning nodes")
+	for child in get_children():
+		child.queue_free()
+	if geo:
+		geo.clear()
+	#$Lines.remove_points()
+	rays = []
+
+	var horizontal_step = cone_width / (n_rays_width)
+	var vertical_step = cone_height / (n_rays_height)
+
+	# Center of the first slice on each axis, which keeps the grid symmetric.
+	var horizontal_start = horizontal_step / 2 - cone_width / 2
+	var vertical_start = vertical_step / 2 - cone_height / 2
+
+	# Ray end points, kept for the (currently disabled) debug line drawing.
+	var points = []
+
+	for i in n_rays_width:
+		var angle_w = horizontal_start + i * horizontal_step
+		for j in n_rays_height:
+			var angle_h = vertical_start + j * vertical_step
+			var cast_to = to_spherical_coords(ray_length, angle_w, angle_h)
+			points.append(cast_to)
+
+			var ray = RayCast3D.new()
+			ray.target_position = cast_to
+			ray.name = "node_" + str(i) + " " + str(j)
+			ray.enabled = true
+			ray.collide_with_bodies = collide_with_bodies
+			ray.collide_with_areas = collide_with_areas
+			ray.collision_mask = collision_mask
+
+			add_child(ray)
+			ray.set_owner(get_tree().edited_scene_root)
+			rays.append(ray)
+			ray.force_raycast_update()
+
+
+# if Engine.editor_hint:
+# _create_debug_lines(points)
+
+
+## Draws one aqua line from the sensor origin to every point in `points`.
+## Not called at the moment: its only call site in this script is commented out.
+## NOTE(review): `geo` is created as an ImmediateMesh, but add_child(),
+## begin(), set_color(), add_vertex() and end() look like the Godot 3
+## ImmediateGeometry node API -- confirm against the Godot 4 ImmediateMesh
+## reference before re-enabling this.
+func _create_debug_lines(points):
+	if not geo:
+		geo = ImmediateMesh.new()
+		add_child(geo)
+
+	geo.clear()
+	geo.begin(Mesh.PRIMITIVE_LINES)
+	for point in points:
+		geo.set_color(Color.AQUA)
+		# Each line segment runs from the local origin to the ray end point.
+		geo.add_vertex(Vector3.ZERO)
+		geo.add_vertex(point)
+	geo.end()
+
+
+## Forwards to geo.display() when a debug mesh exists.
+## NOTE(review): nothing in this script assigns `geo` except the unused
+## _create_debug_lines(), so this is presumably dead code -- verify.
+func display():
+	if geo:
+		geo.display()
+
+
+## Converts a radius and two angles in degrees into a local-space vector:
+## `inc` rotates around the Y axis (0 points along +Z), `azimuth` raises the
+## direction towards +Y.
+func to_spherical_coords(r, inc, azimuth) -> Vector3:
+	var inc_rad = deg_to_rad(inc)
+	var azimuth_rad = deg_to_rad(azimuth)
+	var cos_azimuth = cos(azimuth_rad)
+	return Vector3(
+		r * sin(inc_rad) * cos_azimuth,
+		r * sin(azimuth_rad),
+		r * cos(inc_rad) * cos_azimuth
+	)
+
+
+## Returns the flat list of ray readings (see calculate_raycasts()).
+func get_observation() -> Array:
+	return calculate_raycasts()
+
+
+## Forces an immediate update of every ray and returns a flat list with one
+## closeness value per ray. With class_sensor enabled each closeness value is
+## followed by 1.0 if the hit object is on a layer that is in both
+## collision_mask and boolean_class_mask, and by 0.0 otherwise.
+func calculate_raycasts() -> Array:
+	var readings = []
+	for ray in rays:
+		ray.enabled = true
+		ray.force_raycast_update()
+		readings.append(_get_raycast_distance(ray))
+
+		if class_sensor:
+			var hit_class: float = 0
+			var collider = ray.get_collider()
+			if collider:
+				var visible_layers = collider.collision_layer & collision_mask
+				hit_class = (visible_layers & boolean_class_mask) > 0
+			readings.append(float(hit_class))
+
+		ray.enabled = false
+	return readings
+
+
+## Converts a ray hit into a closeness value: 0.0 when the ray hits nothing
+## (or the hit is at least ray_length away), approaching 1.0 as the hit point
+## gets closer to this node's global origin.
+func _get_raycast_distance(ray: RayCast3D) -> float:
+	if not ray.is_colliding():
+		return 0.0
+
+	var to_hit = global_transform.origin - ray.get_collision_point()
+	var hit_distance = clampf(to_hit.length(), 0.0, ray_length)
+	return (ray_length - hit_distance) / ray_length
diff --git a/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/RaycastSensor3D.tscn b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/RaycastSensor3D.tscn
new file mode 100644
index 0000000..35f9796
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/sensors/sensors_3d/RaycastSensor3D.tscn
@@ -0,0 +1,27 @@
+[gd_scene load_steps=2 format=3 uid="uid://b803cbh1fmy66"]
+
+[ext_resource type="Script" path="res://addons/godot_rl_agents/sensors/sensors_3d/RaycastSensor3D.gd" id="1"]
+
+[node name="RaycastSensor3D" type="Node3D"]
+script = ExtResource("1")
+n_rays_width = 4.0
+n_rays_height = 2.0
+ray_length = 11.0
+
+[node name="node_1 0" type="RayCast3D" parent="."]
+target_position = Vector3(-1.38686, -2.84701, 10.5343)
+
+[node name="node_1 1" type="RayCast3D" parent="."]
+target_position = Vector3(-1.38686, 2.84701, 10.5343)
+
+[node name="node_2 0" type="RayCast3D" parent="."]
+target_position = Vector3(1.38686, -2.84701, 10.5343)
+
+[node name="node_2 1" type="RayCast3D" parent="."]
+target_position = Vector3(1.38686, 2.84701, 10.5343)
+
+[node name="node_3 0" type="RayCast3D" parent="."]
+target_position = Vector3(4.06608, -2.84701, 9.81639)
+
+[node name="node_3 1" type="RayCast3D" parent="."]
+target_position = Vector3(4.06608, 2.84701, 9.81639)
diff --git a/examples/Platform2D/addons/godot_rl_agents/sync.gd b/examples/Platform2D/addons/godot_rl_agents/sync.gd
new file mode 100644
index 0000000..f47decb
--- /dev/null
+++ b/examples/Platform2D/addons/godot_rl_agents/sync.gd
@@ -0,0 +1,598 @@
+extends Node
+class_name Sync
+## Coordinates all nodes in the "AGENT" group: sorts them by control mode,
+## exchanges observations and actions with a training server over TCP, runs
+## ONNX inference locally and records expert demonstrations.
+
+# Command line flags, kept here as a reminder (presumably the ones used for
+# fast headless training -- confirm): --fixed-fps 2000 --disable-render-loop
+
+enum ControlModes {
+	HUMAN,  ## Test the environment manually
+	TRAINING,  ## Train a model
+	ONNX_INFERENCE  ## Load a pretrained model using an .onnx file
+}
+## Mode applied to every agent whose own mode is INHERIT_FROM_SYNC.
+@export var control_mode: ControlModes = ControlModes.TRAINING
+## Action will be repeated for n frames (Godot physics steps).
+@export_range(1, 10, 1, "or_greater") var action_repeat := 8
+## Speeds up the physics in the environment to enable faster training.
+@export_range(0, 10, 0.1, "or_greater") var speed_up := 1.0
+## The path to a trained .onnx model file to use for inference (only needed for the 'Onnx Inference' control mode).
+@export var onnx_model_path := ""
+
+# Onnx model stored for each requested path
+var onnx_models: Dictionary
+
+# Tick count (in milliseconds) taken when this node becomes ready.
+@onready var start_time = Time.get_ticks_msec()
+
+# Protocol version, compared with the server's version in _handshake().
+const MAJOR_VERSION := "0"
+const MINOR_VERSION := "7"
+# Defaults; not read in the part of the script shown here (presumably used by
+# connect_to_server() and _set_seed() -- confirm).
+const DEFAULT_PORT := "11008"
+const DEFAULT_SEED := "1"
+# TCP stream to the training server; messages are read and written through it.
+var stream: StreamPeerTCP = null
+# Result of connect_to_server(); gates _training_process().
+var connected = false
+var message_center
+var should_connect = true
+
+# Every node in the "AGENT" group, filled by _get_agents().
+var all_agents: Array
+# Agents in TRAINING mode.
+var agents_training: Array
+## Policy name of each agent, for use with multi-policy multi-agent RL cases
+var agents_training_policy_names: Array[String] = ["shared_policy"]
+# Agents in ONNX_INFERENCE mode.
+var agents_inference: Array
+# Agents in HUMAN mode.
+var agents_heuristic: Array
+
+## For recording expert demos
+var agent_demo_record: Node
+## File path for writing recorded trajectories
+var expert_demo_save_path: String
+## Stores recorded trajectories
+var demo_trajectories: Array
+## A trajectory includes obs: Array, acts: Array, terminal (set in Python env instead)
+var current_demo_trajectory: Array
+
+# When true, _training_process() sends a "step" reply on its next run.
+var need_to_send_obs = false
+# Set from _get_args() in _initialize().
+var args = null
+# Becomes true at the end of _initialize().
+var initialized = false
+# When true, _training_process() sends a "reset" reply on its next run.
+var just_reset = false
+var onnx_model = null
+# Number of physics frames processed so far; drives the action-repeat cadence.
+var n_action_steps = 0
+
+# Action / observation spaces of the agents, indexed like the agent arrays.
+var _action_space_training: Array[Dictionary] = []
+var _action_space_inference: Array[Dictionary] = []
+var _obs_space_training: Array[Dictionary] = []
+
+
+## Called when the node enters the scene tree for the first time.
+## Waits for the parent's `ready` signal, pauses the tree while _initialize()
+## runs, then waits on a one second SceneTree timer before unpausing.
+func _ready():
+	# Make sure the whole parent scene is ready before the agents are collected.
+	await get_parent().ready
+	get_tree().set_pause(true)
+	_initialize()
+	await get_tree().create_timer(1.0).timeout
+	get_tree().set_pause(false)
+
+
+## One-time setup: collects the agents, reads the arguments, applies the
+## speed-up to the engine, initializes training / inference / demo recording
+## and finally sets the seed and the action repeat.
+func _initialize():
+	_get_agents()
+	args = _get_args()
+	# Scale the physics tick rate with the speed-up (60 ticks at speed-up 1).
+	Engine.physics_ticks_per_second = _get_speedup() * 60
+	Engine.time_scale = _get_speedup() * 1.0
+	prints(
+		"physics ticks",
+		Engine.physics_ticks_per_second,
+		Engine.time_scale,
+		_get_speedup(),
+		speed_up
+	)
+
+	# Every agent starts out with the "human" heuristic; the initializers below
+	# switch the training / inference agents to "model" where applicable.
+	_set_heuristic("human", all_agents)
+
+	_initialize_training_agents()
+	_initialize_inference_agents()
+	_initialize_demo_recording()
+
+	_set_seed()
+	_set_action_repeat()
+	initialized = true
+
+
+## Caches the observation and action space of every training agent and tries
+## to connect to the training server. On success the agents are switched to
+## the "model" heuristic and the handshake / env info exchange is performed;
+## otherwise a warning is pushed. Does nothing without training agents.
+func _initialize_training_agents():
+	if agents_training.is_empty():
+		return
+
+	var n_training := agents_training.size()
+	_obs_space_training.resize(n_training)
+	_action_space_training.resize(n_training)
+	for agent_idx in n_training:
+		var agent = agents_training[agent_idx]
+		_obs_space_training[agent_idx] = agent.get_obs_space()
+		_action_space_training[agent_idx] = agent.get_action_space()
+
+	connected = connect_to_server()
+	if not connected:
+		push_warning(
+			"Couldn't connect to Python server, using human controls instead. ",
+			"Did you start the training server using e.g. `gdrl` from the console?"
+		)
+		return
+
+	_set_heuristic("model", agents_training)
+	_handshake()
+	_send_env_info()
+
+
+## Loads the ONNX model of every inference agent (one ONNXModel per distinct
+## path, cached in onnx_models), stores each agent's action space and switches
+## the agents to the "model" heuristic. Agents without their own model path
+## use the path set on this node. Does nothing without inference agents.
+func _initialize_inference_agents():
+	if agents_inference.is_empty():
+		return
+
+	if control_mode == ControlModes.ONNX_INFERENCE:
+		assert(
+			FileAccess.file_exists(onnx_model_path),
+			"Onnx Model Path set on Sync node does not exist: %s" % onnx_model_path
+		)
+		onnx_models[onnx_model_path] = ONNXModel.new(onnx_model_path, 1)
+
+	for agent in agents_inference:
+		var action_space = agent.get_action_space()
+		_action_space_inference.append(action_space)
+
+		var uses_sync_model: bool = agent.onnx_model_path.is_empty()
+		var model_path = onnx_model_path if uses_sync_model else agent.onnx_model_path
+
+		if uses_sync_model:
+			# The sync node's model is only loaded in ONNX_INFERENCE mode (above).
+			assert(
+				onnx_models.has(onnx_model_path),
+				(
+					"Node %s has no onnx model path set " % agent.get_path()
+					+ "and sync node's control mode is not set to OnnxInference. "
+					+ "Either add the path to the AIController, "
+					+ "or if you want to use the path set on sync node instead, "
+					+ "set control mode to OnnxInference."
+				)
+			)
+			prints(
+				"Info: AIController %s" % agent.get_path(),
+				"has no onnx model path set.",
+				"Using path set on the sync node instead."
+			)
+		elif not onnx_models.has(model_path):
+			# First agent requesting this path: load the model once and cache it.
+			assert(
+				FileAccess.file_exists(agent.onnx_model_path),
+				(
+					"Onnx Model Path set on %s node does not exist: %s"
+					% [agent.get_path(), agent.onnx_model_path]
+				)
+			)
+			onnx_models[model_path] = ONNXModel.new(model_path, 1)
+
+		var agent_onnx_model: ONNXModel = onnx_models[model_path]
+		agent.onnx_model = agent_onnx_model
+		if not agent_onnx_model.action_means_only_set:
+			agent_onnx_model.set_action_means_only(action_space)
+
+	_set_heuristic("model", agents_inference)
+
+
+## Prepares expert demo recording when a recording agent exists: takes over
+## its save path, registers the "RemoveLastDemoEpisode" input action with the
+## agent's key, resets the current trajectory to two empty lists (obs, acts)
+## and switches the agent to the "demo_record" heuristic.
+func _initialize_demo_recording():
+	if not agent_demo_record:
+		return
+
+	expert_demo_save_path = agent_demo_record.expert_demo_save_path
+	assert(
+		not expert_demo_save_path.is_empty(),
+		"Expert demo save path set in %s is empty." % agent_demo_record.get_path()
+	)
+
+	var remove_action := "RemoveLastDemoEpisode"
+	InputMap.add_action(remove_action)
+	InputMap.action_add_event(remove_action, agent_demo_record.remove_last_episode_key)
+
+	# Slot 0 collects the observations, slot 1 the actions.
+	current_demo_trajectory.resize(2)
+	current_demo_trajectory[0] = []
+	current_demo_trajectory[1] = []
+	agent_demo_record.heuristic = "demo_record"
+
+
+## Runs every physics frame. Demo recording is processed on every frame; the
+## training, inference and heuristic steps only run on every action_repeat-th
+## frame (starting with the first one).
+func _physics_process(_delta):
+	_demo_record_process()
+
+	var is_action_frame := n_action_steps % action_repeat == 0
+	n_action_steps += 1
+	if not is_action_frame:
+		return
+
+	_training_process()
+	_inference_process()
+	_heuristic_process()
+
+
+## One exchange with the training server (only while connected): pauses the
+## tree, gathers obs and info from the training agents, sends either a
+## "reset" reply (right after a reset) or a "step" reply (when an observation
+## is due) and then calls handle_message().
+func _training_process():
+	if connected:
+		get_tree().set_pause(true)
+
+		var obs = _get_obs_from_agents(agents_training)
+		var info = _get_info_from_agents(agents_training)
+
+		if just_reset:
+			just_reset = false
+
+			var reply = {"type": "reset", "obs": obs, "info": info}
+			_send_dict_as_json_message(reply)
+			# this should go straight to getting the action and setting it on the agent, no need to perform one physics tick
+			get_tree().set_pause(false)
+			return
+
+		if need_to_send_obs:
+			need_to_send_obs = false
+			var reward = _get_reward_from_agents()
+			var done = _get_done_from_agents()
+			#_reset_agents_if_done() # this ensures the new observation is from the next env instance : NEEDS REFACTOR
+
+			var reply = {"type": "step", "obs": obs, "reward": reward, "done": done, "info": info}
+			_send_dict_as_json_message(reply)
+
+		# NOTE(review): the return value is stored but never used.
+		var handled = handle_message()
+
+
+## Runs every inference agent's ONNX model on its current observation,
+## applies the resulting actions, resets agents that are done and unpauses
+## the tree. Does nothing without inference agents.
+func _inference_process():
+	if agents_inference.is_empty():
+		return
+
+	var obs: Array = _get_obs_from_agents(agents_inference)
+	var actions = []
+
+	for agent_id in agents_inference.size():
+		var model: ONNXModel = agents_inference[agent_id].onnx_model
+		var inference_result = model.run_inference(obs[agent_id]["obs"], 1.0)
+		var action_dict = _extract_action_dict(
+			inference_result["output"], _action_space_inference[agent_id], model.action_means_only
+		)
+		actions.append(action_dict)
+
+	_set_agent_actions(actions, agents_inference)
+	_reset_agents_if_done(agents_inference)
+	get_tree().set_pause(false)
+
+
+## Records expert demonstrations from the recording agent (if there is one).
+## Handles the "RemoveLastDemoEpisode" input on every call. On every
+## action_repeat-th step of the recording agent it appends the current
+## observation and, for non-terminal states, the applied action to the
+## current trajectory. When the episode ends, the trajectory is stored in
+## demo_trajectories and a new one is started.
+func _demo_record_process():
+	if not agent_demo_record:
+		return
+
+	if Input.is_action_just_pressed("RemoveLastDemoEpisode"):
+		print("[Sync script][Demo recorder] Removing last recorded episode.")
+		# Nothing to remove before the first episode has been recorded;
+		# remove_at(-1) on an empty array would be an out-of-range index.
+		if not demo_trajectories.is_empty():
+			demo_trajectories.remove_at(demo_trajectories.size() - 1)
+		print("Remaining episode count: %d" % demo_trajectories.size())
+
+	if n_action_steps % agent_demo_record.action_repeat != 0:
+		return
+
+	var obs_dict: Dictionary = agent_demo_record.get_obs()
+
+	# Get the current obs from the agent
+	assert(
+		obs_dict.has("obs"),
+		"Demo recorder needs an 'obs' key in get_obs() returned dictionary to record obs from."
+	)
+	current_demo_trajectory[0].append(obs_dict.obs)
+
+	# Get the action applied for the current obs from the agent
+	agent_demo_record.set_action()
+	var acts = agent_demo_record.get_action()
+
+	var terminal = agent_demo_record.get_done()
+	if not terminal:
+		# Record actions only for non-terminal states
+		current_demo_trajectory[1].append(acts)
+		return
+
+	# Episode finished: store a deep copy of the trajectory and start a new one.
+	agent_demo_record.set_done_false()
+	demo_trajectories.append(current_demo_trajectory.duplicate(true))
+	print("[Sync script][Demo recorder] Recorded episode count: %d" % demo_trajectories.size())
+	current_demo_trajectory[0].clear()
+	current_demo_trajectory[1].clear()
+
+
+## Resets the human-controlled agents that are done.
+func _heuristic_process():
+	# _reset_agents_if_done() already receives the whole list, so one call is
+	# enough; calling it once per agent only repeated the same work.
+	_reset_agents_if_done(agents_heuristic)
+
+
+## Converts the flat output array of a model into a dictionary with one entry
+## per action of `action_space`.
+## - "discrete" actions consume `size` logits; the result is the index of the
+##   largest logit.
+## - "continuous" actions consume `size` mean values (clamped to [-1, 1]);
+##   unless `action_means_only` is true, the `size` values that follow the
+##   means are skipped.
+## Any other action type triggers an assertion.
+func _extract_action_dict(action_array: Array, action_space: Dictionary, action_means_only: bool):
+	var index = 0
+	var result = {}
+	for key in action_space.keys():
+		var size = action_space[key]["size"]
+		var action_type = action_space[key]["action_type"]
+		if action_type == "discrete":
+			# Argmax over this action's logits. Start from -INF: logits can all
+			# be negative, and starting from 0.0 would then always pick index 0.
+			var largest_logit: float = -INF  # Value of the largest logit for this action in the actions array
+			var largest_logit_idx: int = 0  # Index of the largest logit for this action in the actions array
+			for logit_idx in range(0, size):
+				var logit_value = action_array[index + logit_idx]
+				if logit_value > largest_logit:
+					largest_logit = logit_value
+					largest_logit_idx = logit_idx
+			result[key] = largest_logit_idx  # Index of the largest logit is the discrete action value
+			index += size
+		elif action_type == "continuous":
+			# For continuous actions, we only take the action mean values
+			result[key] = clamp_array(action_array.slice(index, index + size), -1.0, 1.0)
+			if action_means_only:
+				index += size  # model only outputs action means, so we move index by size
+			else:
+				index += size * 2  # model outputs logstd after action mean, we skip the logstd part
+
+		else:
+			assert(
+				false,
+				(
+					'Only "discrete" and "continuous" action types supported. Found: %s action type set.'
+					% action_type
+				)
+			)
+
+	return result
+
+
+## For AIControllers that inherit mode from sync, sets the correct mode.
+## Agents with any other control mode are left untouched.
+func _set_agent_mode(agent: Node):
+	if agent.control_mode != agent.ControlModes.INHERIT_FROM_SYNC:
+		return
+
+	# Translate this node's mode into the agent's own enum.
+	match control_mode:
+		ControlModes.HUMAN:
+			agent.control_mode = agent.ControlModes.HUMAN
+		ControlModes.TRAINING:
+			agent.control_mode = agent.ControlModes.TRAINING
+		ControlModes.ONNX_INFERENCE:
+			agent.control_mode = agent.ControlModes.ONNX_INFERENCE
+
+
+## Collects every node of the "AGENT" group, resolves inherited control modes
+## and sorts the agents into the training / inference / heuristic lists (or
+## stores the single demo-recording agent). Finally stores the policy name of
+## every training agent.
+func _get_agents():
+	all_agents = get_tree().get_nodes_in_group("AGENT")
+	for agent in all_agents:
+		_set_agent_mode(agent)
+
+		var mode = agent.control_mode
+		if mode == agent.ControlModes.TRAINING:
+			agents_training.append(agent)
+		elif mode == agent.ControlModes.ONNX_INFERENCE:
+			agents_inference.append(agent)
+		elif mode == agent.ControlModes.HUMAN:
+			agents_heuristic.append(agent)
+		elif mode == agent.ControlModes.RECORD_EXPERT_DEMOS:
+			assert(
+				not agent_demo_record,
+				"Currently only a single AIController can be used for recording expert demos."
+			)
+			agent_demo_record = agent
+
+	var n_training := agents_training.size()
+	agents_training_policy_names.resize(n_training)
+	for i in n_training:
+		agents_training_policy_names[i] = agents_training[i].policy_name
+
+
+## Calls set_heuristic([param heuristic]) on every agent in [param agents].
+func _set_heuristic(heuristic, agents: Array):
+	for agent_idx in agents.size():
+		agents[agent_idx].set_heuristic(heuristic)
+
+
+## Reads the handshake message sent by the server and prints a warning when its
+## version numbers differ from MAJOR_VERSION / MINOR_VERSION.
+func _handshake():
+	print("performing handshake")
+
+	var json_dict = _get_dict_json_message()
+	if json_dict == null:
+		# _get_dict_json_message() returns null when the server disconnected;
+		# indexing null below would raise an error.
+		return
+	assert(json_dict["type"] == "handshake")
+	var major_version = json_dict["major_version"]
+	var minor_version = json_dict["minor_version"]
+	if major_version != MAJOR_VERSION:
+		print("WARNING: major version mismatch ", major_version, " ", MAJOR_VERSION)
+	if minor_version != MINOR_VERSION:
+		print("WARNING: minor version mismatch ", minor_version, " ", MINOR_VERSION)
+
+	print("handshake complete")
+
+
+## Blocks until the server has sent data, then returns the received string parsed
+## with JSON.parse_string().
+## If the stream is no longer connected while waiting, requests the game to quit
+## and returns null.
+func _get_dict_json_message():
+	while stream.get_available_bytes() == 0:
+		stream.poll()
+		if stream.get_status() != StreamPeerTCP.STATUS_CONNECTED:
+			print("server disconnected status, closing")
+			get_tree().quit()
+			return null
+
+		# Brief pause between polls while no data is available
+		OS.delay_usec(10)
+
+	var message = stream.get_string()
+	var json_data = JSON.parse_string(message)
+
+	return json_data
+
+
+## Serializes [param dict] to JSON (no indentation, keys not sorted) and writes it to the stream.
+func _send_dict_as_json_message(dict):
+	var payload = JSON.stringify(dict, "", false)
+	stream.put_string(payload)
+
+
+## Waits for the server's "env_info" message and replies with the observation space,
+## action space, number of training agents and their policy names.
+func _send_env_info():
+	var request = _get_dict_json_message()
+	assert(request["type"] == "env_info")
+
+	var reply = {}
+	reply["type"] = "env_info"
+	reply["observation_space"] = _obs_space_training
+	reply["action_space"] = _action_space_training
+	reply["n_agents"] = agents_training.size()
+	reply["agent_policy_names"] = agents_training_policy_names
+	_send_dict_as_json_message(reply)
+
+
+## Opens a TCP connection to the server on 127.0.0.1 (port from _get_port()) and
+## blocks until the attempt has resolved.
+## Returns true when the stream ends up connected, false otherwise.
+func connect_to_server():
+	print("Waiting for one second to allow server to start")
+	OS.delay_msec(1000)
+	print("trying to connect to server")
+	stream = StreamPeerTCP.new()
+
+	# "localhost" was not working on windows VM, had to use the IP
+	var ip = "127.0.0.1"
+	var port = _get_port()
+	# Bail out if the attempt could not even be started: the stream would presumably
+	# stay in STATUS_NONE, and waiting for a status change would then never finish.
+	if stream.connect_to_host(ip, port) != OK:
+		return false
+	stream.set_no_delay(true)  # TODO check if this improves performance or not
+	stream.poll()
+	# Poll until the attempt has either succeeded or failed
+	while stream.get_status() == StreamPeerTCP.STATUS_CONNECTING:
+		stream.poll()
+	return stream.get_status() == StreamPeerTCP.STATUS_CONNECTED
+
+
+## Parses the command line arguments into a dictionary.
+## "--key=value" is stored as {"key": "value"}; arguments without "=" are stored
+## with an empty string as their value. Leading "-" characters are removed from keys.
+func _get_args():
+	print("getting command line arguments")
+	var arguments = {}
+	for argument in OS.get_cmdline_args():
+		print(argument)
+		if argument.find("=") > -1:
+			# Split at the first "=" only, so that a value which itself contains "="
+			# is kept whole instead of being cut off at the second "=".
+			var key_value = argument.split("=", true, 1)
+			arguments[key_value[0].lstrip("--")] = key_value[1]
+		else:
+			# Options without an argument will be present in the dictionary,
+			# with the value set to an empty string.
+			arguments[argument.lstrip("--")] = ""
+
+	return arguments
+
+
+## Returns the "speedup" command line argument as a float, falling back to speed_up.
+func _get_speedup():
+	print(args)
+	var speedup_text = args.get("speedup", str(speed_up))
+	return speedup_text.to_float()
+
+
+## Returns the "port" command line argument as an int, falling back to DEFAULT_PORT.
+func _get_port():
+	var port_text = args.get("port", DEFAULT_PORT)
+	return port_text.to_int()
+
+
+## Seeds the random number generator with the "env_seed" command line argument
+## (DEFAULT_SEED when the argument is absent).
+func _set_seed():
+	var seed_text = args.get("env_seed", DEFAULT_SEED)
+	seed(seed_text.to_int())
+
+
+## Overrides action_repeat with the "action_repeat" command line argument, if given.
+func _set_action_repeat():
+	var repeat_text = args.get("action_repeat", str(action_repeat))
+	action_repeat = repeat_text.to_int()
+
+
+## Disconnects the TCP stream from the server.
+func disconnect_from_server():
+ stream.disconnect_from_host()
+
+
+## Reads one message from the server and dispatches on its "type":
+## - "close": requests the game to quit.
+## - "reset": flags all agents for reset.
+## - "call": calls the named method on every agent, sends back the results and then
+##   handles the next message.
+## - "action": applies the received actions to the training agents.
+## Returns true when the message was handled, false otherwise.
+func handle_message() -> bool:
+	# get json message: reset, step, close
+	var message = _get_dict_json_message()
+	if message == null:
+		# No message could be read (e.g. the server disconnected); indexing null
+		# below would raise an error.
+		return false
+
+	match message["type"]:
+		"close":
+			print("received close message, closing game")
+			get_tree().quit()
+			get_tree().set_pause(false)
+			return true
+		"reset":
+			print("resetting all agents")
+			_reset_agents()
+			just_reset = true
+			get_tree().set_pause(false)
+			return true
+		"call":
+			var method = message["method"]
+			var returns = _call_method_on_agents(method)
+			var reply = {"type": "call", "returns": returns}
+			print("calling method from Python")
+			_send_dict_as_json_message(reply)
+			# A "call" does not advance the environment, so keep reading messages
+			return handle_message()
+		"action":
+			var action = message["action"]
+			_set_agent_actions(action, agents_training)
+			need_to_send_obs = true
+			get_tree().set_pause(false)
+			return true
+
+	print("message was not handled")
+	return false
+
+
+## Calls [param method] on every agent and returns the results in agent order.
+func _call_method_on_agents(method):
+	var returns = []
+	returns.resize(all_agents.size())
+	for agent_idx in all_agents.size():
+		returns[agent_idx] = all_agents[agent_idx].call(method)
+	return returns
+
+
+## Clears the done flag of every agent in [param agents] that reports being done.
+func _reset_agents_if_done(agents = all_agents):
+	for agent in agents:
+		if not agent.get_done():
+			continue
+		agent.set_done_false()
+
+
+## Flags every agent in [param agents] as needing a reset (the agents are not reset here).
+func _reset_agents(agents = all_agents):
+	for agent_idx in agents.size():
+		agents[agent_idx].needs_reset = true
+
+
+## Returns the result of get_obs() for every agent in [param agents], in order.
+func _get_obs_from_agents(agents: Array = all_agents):
+	var obs = []
+	obs.resize(agents.size())
+	for agent_idx in agents.size():
+		obs[agent_idx] = agents[agent_idx].get_obs()
+	return obs
+
+
+## Returns the current reward of every agent in [param agents] and zeroes each
+## agent's reward right after reading it.
+func _get_reward_from_agents(agents: Array = agents_training):
+	var rewards = []
+	for agent in agents:
+		var agent_reward = agent.get_reward()
+		agent.zero_reward()
+		rewards.append(agent_reward)
+	return rewards
+
+
+## Returns the result of get_info() for every agent in [param agents], in order.
+func _get_info_from_agents(agents: Array = all_agents):
+	var info = []
+	info.resize(agents.size())
+	for agent_idx in agents.size():
+		info[agent_idx] = agents[agent_idx].get_info()
+	return info
+
+
+## Returns the done flag of every agent in [param agents]; agents that were done
+## have their flag cleared after it has been read.
+func _get_done_from_agents(agents: Array = agents_training):
+	var dones = []
+	for agent in agents:
+		var was_done = agent.get_done()
+		dones.append(was_done)
+		if was_done:
+			agent.set_done_false()
+	return dones
+
+
+## Passes actions[i] to set_action() of agents[i] for every entry in [param actions].
+func _set_agent_actions(actions, agents: Array = all_agents):
+	var action_count = len(actions)
+	for action_idx in action_count:
+		agents[action_idx].set_action(actions[action_idx])
+
+
+## Returns a new array holding every element of [param arr] clamped to the
+## [param min]..[param max] range. [param arr] itself is not modified.
+func clamp_array(arr: Array, min: float, max: float):
+	var output: Array = []
+	output.resize(arr.size())
+	for element_idx in arr.size():
+		output[element_idx] = clamp(arr[element_idx], min, max)
+	return output
+
+
+## Save recorded expert demos on window exit (Close game window instead of "Stop" button in Godot Editor).
+## Nothing is written when no trajectories were recorded or no save path is set.
+func _notification(what):
+	if what != NOTIFICATION_PREDELETE:
+		return
+
+	if demo_trajectories.size() == 0 or expert_demo_save_path.is_empty():
+		return
+
+	var json_string = JSON.stringify(demo_trajectories, "", false)
+	var file = FileAccess.open(expert_demo_save_path, FileAccess.WRITE)
+
+	if not file:
+		var open_error: Error = FileAccess.get_open_error()
+		assert(not open_error, "There was an error opening the file: %d" % open_error)
+		# Without this return, a build where the assert does not stop execution
+		# would go on to call store_line() on a null file.
+		return
+
+	file.store_line(json_string)
+	var write_error = file.get_error()
+	assert(not write_error, "There was an error after trying to write to the file: %d" % write_error)
diff --git a/examples/Platform2D/assets/player/jump/Player1Jump1.png b/examples/Platform2D/assets/player/jump/Player1Jump1.png
new file mode 100644
index 0000000..9a1719b
Binary files /dev/null and b/examples/Platform2D/assets/player/jump/Player1Jump1.png differ
diff --git a/examples/Platform2D/assets/player/jump/Player1Jump2.png b/examples/Platform2D/assets/player/jump/Player1Jump2.png
new file mode 100644
index 0000000..d64645e
Binary files /dev/null and b/examples/Platform2D/assets/player/jump/Player1Jump2.png differ
diff --git a/examples/Platform2D/assets/player/jump/Player1Jump3.png b/examples/Platform2D/assets/player/jump/Player1Jump3.png
new file mode 100644
index 0000000..79cecf1
Binary files /dev/null and b/examples/Platform2D/assets/player/jump/Player1Jump3.png differ
diff --git a/examples/Platform2D/assets/player/move/Player-1.png b/examples/Platform2D/assets/player/move/Player-1.png
new file mode 100644
index 0000000..aacd953
Binary files /dev/null and b/examples/Platform2D/assets/player/move/Player-1.png differ
diff --git a/examples/Platform2D/assets/player/move/Player-2.png b/examples/Platform2D/assets/player/move/Player-2.png
new file mode 100644
index 0000000..e750568
Binary files /dev/null and b/examples/Platform2D/assets/player/move/Player-2.png differ
diff --git a/examples/Platform2D/assets/player/move/Player-3.png b/examples/Platform2D/assets/player/move/Player-3.png
new file mode 100644
index 0000000..f9b894b
Binary files /dev/null and b/examples/Platform2D/assets/player/move/Player-3.png differ
diff --git a/examples/Platform2D/assets/tilesheet.png b/examples/Platform2D/assets/tilesheet.png
new file mode 100644
index 0000000..ee1352d
Binary files /dev/null and b/examples/Platform2D/assets/tilesheet.png differ
diff --git a/examples/Platform2D/icon.svg b/examples/Platform2D/icon.svg
new file mode 100644
index 0000000..b370ceb
--- /dev/null
+++ b/examples/Platform2D/icon.svg
@@ -0,0 +1 @@
+
diff --git a/examples/Platform2D/license.md b/examples/Platform2D/license.md
new file mode 100644
index 0000000..fbcf1fa
--- /dev/null
+++ b/examples/Platform2D/license.md
@@ -0,0 +1,5 @@
+Platform2D Environment made by Ivan Dodic (https://github.com/Ivan-267)
+
+The following license is only for the graphical assets in the folder "assets", specifically .png files:
+Author: Ivan Dodic (https://github.com/Ivan-267),
+License: https://creativecommons.org/licenses/by/4.0/
\ No newline at end of file
diff --git a/examples/Platform2D/model.onnx b/examples/Platform2D/model.onnx
new file mode 100644
index 0000000..d5b25f7
Binary files /dev/null and b/examples/Platform2D/model.onnx differ
diff --git a/examples/Platform2D/project.godot b/examples/Platform2D/project.godot
new file mode 100644
index 0000000..3137106
--- /dev/null
+++ b/examples/Platform2D/project.godot
@@ -0,0 +1,54 @@
+; Engine configuration file.
+; It's best edited using the editor UI and not directly,
+; since the parameters that go here are not all obvious.
+;
+; Format:
+; [section] ; section goes between []
+; param=value ; assign values to parameters
+
+config_version=5
+
+[application]
+
+config/name="Platform2D"
+run/main_scene="res://scenes/training_scene/training_scene.tscn"
+config/features=PackedStringArray("4.3", "C#", "Forward Plus")
+config/icon="res://icon.svg"
+
+[dotnet]
+
+project/assembly_name="Platform2D"
+
+[editor_plugins]
+
+enabled=PackedStringArray("res://addons/godot_rl_agents/plugin.cfg")
+
+[input]
+
+move_left={
+"deadzone": 0.5,
+"events": [Object(InputEventKey,"resource_local_to_scene":false,"resource_name":"","device":-1,"window_id":0,"alt_pressed":false,"shift_pressed":false,"ctrl_pressed":false,"meta_pressed":false,"pressed":false,"keycode":0,"physical_keycode":65,"key_label":0,"unicode":97,"location":0,"echo":false,"script":null)
+, Object(InputEventKey,"resource_local_to_scene":false,"resource_name":"","device":-1,"window_id":0,"alt_pressed":false,"shift_pressed":false,"ctrl_pressed":false,"meta_pressed":false,"pressed":false,"keycode":0,"physical_keycode":4194319,"key_label":0,"unicode":0,"location":0,"echo":false,"script":null)
+]
+}
+move_right={
+"deadzone": 0.5,
+"events": [Object(InputEventKey,"resource_local_to_scene":false,"resource_name":"","device":-1,"window_id":0,"alt_pressed":false,"shift_pressed":false,"ctrl_pressed":false,"meta_pressed":false,"pressed":false,"keycode":0,"physical_keycode":68,"key_label":0,"unicode":100,"location":0,"echo":false,"script":null)
+, Object(InputEventKey,"resource_local_to_scene":false,"resource_name":"","device":-1,"window_id":0,"alt_pressed":false,"shift_pressed":false,"ctrl_pressed":false,"meta_pressed":false,"pressed":false,"keycode":0,"physical_keycode":4194321,"key_label":0,"unicode":0,"location":0,"echo":false,"script":null)
+]
+}
+jump={
+"deadzone": 0.5,
+"events": [Object(InputEventKey,"resource_local_to_scene":false,"resource_name":"","device":-1,"window_id":0,"alt_pressed":false,"shift_pressed":false,"ctrl_pressed":false,"meta_pressed":false,"pressed":false,"keycode":0,"physical_keycode":32,"key_label":0,"unicode":32,"location":0,"echo":false,"script":null)
+, Object(InputEventKey,"resource_local_to_scene":false,"resource_name":"","device":-1,"window_id":0,"alt_pressed":false,"shift_pressed":false,"ctrl_pressed":false,"meta_pressed":false,"pressed":false,"keycode":0,"physical_keycode":4194326,"key_label":0,"unicode":0,"location":0,"echo":false,"script":null)
+]
+}
+
+[physics]
+
+common/max_physics_steps_per_frame=64
+2d/solver/solver_iterations=2
+
+[rendering]
+
+environment/defaults/default_clear_color=Color(0, 0, 0, 1)
diff --git a/examples/Platform2D/readme.md b/examples/Platform2D/readme.md
new file mode 100644
index 0000000..9d84afa
--- /dev/null
+++ b/examples/Platform2D/readme.md
@@ -0,0 +1,69 @@
+# Platform2D environment
+https://github.com/user-attachments/assets/9e61e70d-0968-4432-8952-44c38f7e971f
+
+## Goal:
+The player must pick up all of the coins and reach the goal, while avoiding traps and not falling off the map.
+The player is not allowed to move to a lower level without first picking up all of the coins in the current one.
+
+## Observations:
+```gdscript
+func get_obs() -> Dictionary:
+ var obs: Array[float]
+
+ for sensor in raycast_sensors:
+ obs.append_array(sensor.get_observation())
+
+ var player_velocity := player.get_real_velocity()
+ player_velocity /= Vector2(player.speed, player.jump_velocity)
+
+ obs.append_array(
+ [
+ clampf(player_velocity.x, -1.0, 1.0),
+ clampf(player_velocity.y, -1.0, 1.0),
+ float(player.is_on_floor())
+ ]
+ )
+
+ var goal_pos_global := player.map_manager.goal_position
+ var player_to_goal := player.to_local(goal_pos_global)
+ var goal_direction := player_to_goal.normalized()
+ var goal_dist := clampf(player_to_goal.length() / 640.0, 0, 1.0)
+
+ obs.append_array([goal_direction.x, goal_direction.y, goal_dist])
+ return {"obs": obs}
+```
+
+Observations include data from multiple raycast sensors (and a grid sensor with 2 cells to detect any remaining coins to the left or right of the player in the same row),
+the player velocity, whether jumping is allowed or not (`is_on_floor()`), as well as a direction vector and a distance scalar toward the goal.
+
+
+## Actions:
+```gdscript
+func get_action_space() -> Dictionary:
+ return {
+ "move": {"size": 3, "action_type": "discrete"},
+ "jump": {"size": 2, "action_type": "discrete"},
+ }
+```
+The player can stand still, move left/right, and jump.
+
+## Running inference:
+
+If you’d just like to test the env using the pre-trained onnx model, open `res://scenes/training_scene/inference_scene.tscn` in Godot, then press `F6`.
+
+## Training:
+
+There’s an included onnx file that was trained with https://github.com/edbeeching/godot_rl_agents/blob/main/examples/stable_baselines3_example.py
+
+Command-line arguments used (ONNX export and model saving were also enabled; enable them as needed, and add `env_path` to set the exported executable of the environment):
+
+```python
+--speedup=32
+--n_parallel=8
+--timesteps=10_000_000
+--linear_lr_schedule
+```
+
+Stats from the training session (success rate only):
+
+
diff --git a/examples/Platform2D/scenes/game_scene/game_scene.tscn b/examples/Platform2D/scenes/game_scene/game_scene.tscn
new file mode 100644
index 0000000..e7a2fac
--- /dev/null
+++ b/examples/Platform2D/scenes/game_scene/game_scene.tscn
@@ -0,0 +1,21 @@
+[gd_scene load_steps=4 format=3 uid="uid://danlnf1x033rf"]
+
+[ext_resource type="TileSet" uid="uid://sdmrwh4xd6qj" path="res://scenes/tileset/tileset.tres" id="1_rsvlm"]
+[ext_resource type="Script" path="res://scenes/tilemap/tile_map_layer.gd" id="2_0030j"]
+[ext_resource type="PackedScene" uid="uid://d2qsl7semlkyv" path="res://scenes/player/player.tscn" id="3_5v4jr"]
+
+[node name="GameScene" type="Node2D"]
+
+[node name="TileMapLayer" type="TileMapLayer" parent="."]
+tile_set = ExtResource("1_rsvlm")
+navigation_enabled = false
+script = ExtResource("2_0030j")
+total_rows = 12
+max_coins_per_row = 3
+
+[node name="Player" parent="." node_paths=PackedStringArray("map_manager") instance=ExtResource("3_5v4jr")]
+position = Vector2(50, -99.46)
+map_manager = NodePath("../TileMapLayer")
+
+[node name="Camera2D" type="Camera2D" parent="Player"]
+zoom = Vector2(0.785, 0.785)
diff --git a/examples/Platform2D/scenes/player/extended_grid_sensor_2d.gd b/examples/Platform2D/scenes/player/extended_grid_sensor_2d.gd
new file mode 100644
index 0000000..6d812e3
--- /dev/null
+++ b/examples/Platform2D/scenes/player/extended_grid_sensor_2d.gd
@@ -0,0 +1,42 @@
+extends GridSensor2D
+
+## Simple modification to enable detecting a single physics layer of a tilemap
+## without needing to check the collision layer of the specific tile.
+## It also clamps all observation values to the 0-1 range, and overrides some of
+## the export variable ranges.
+## Note: Meant to be used with only a single `detection mask` layer per sensor.
+
+# The *_override values below are copied onto cell_width, cell_height, grid_size_x
+# and grid_size_y in _ready().
+@export_range(1, 10_000, 0.1) var cell_width_override := 20.0
+@export_range(1, 10_000, 0.1) var cell_height_override := 20.0
+@export_range(1, 21, 1, "or_greater") var grid_size_x_override := 2
+@export_range(1, 21, 1, "or_greater") var grid_size_y_override := 1
+
+
+## Copies the override values onto the grid settings, then runs the parent _ready().
+func _ready() -> void:
+ # The assignments come first so that super._ready() sees the overridden values
+ cell_width = cell_width_override
+ cell_height = cell_height_override
+ grid_size_x = grid_size_x_override
+ grid_size_y = grid_size_y_override
+ super._ready()
+
+## Returns the observation buffer with every value clamped to the 0-1 range.
+## NOTE(review): the old loop could leave entries at 0 that are now reported as 1;
+## a model trained on the old output may need checking - confirm.
+func get_observation():
+	# There can be more than one object in a cell at a time
+	# to simplify the obs, we clamp the values to 0-1 range
+	var obs: Array[float]
+	obs.resize(_obs_buffer.size())
+	# Iterate over indices: looping over `_obs_buffer` itself yields its values, which
+	# were then used as indices, so entries could stay unset or be out of range.
+	for obs_idx in _obs_buffer.size():
+		obs[obs_idx] = clampf(_obs_buffer[obs_idx], 0, 1)
+	return obs
+
+## Updates the observation for cell ([param cell_i], [param cell_j]) when a body enters it,
+## passing the sensor's detection_mask to _update_obs (the body itself is not inspected).
+func _on_cell_body_entered(_body: Node2D, cell_i: int, cell_j: int):
+ #prints("_on_cell_body_entered", cell_i, cell_j)
+ _update_obs(cell_i, cell_j, detection_mask, true)
+ if debug_view:
+ _toggle_cell(cell_i, cell_j)
+
+
+## Updates the observation for cell ([param cell_i], [param cell_j]) when a body leaves it,
+## passing the sensor's detection_mask to _update_obs (the body itself is not inspected).
+func _on_cell_body_exited(_body: Node2D, cell_i: int, cell_j: int):
+ #prints("_on_cell_body_exited", cell_i, cell_j)
+ _update_obs(cell_i, cell_j, detection_mask, false)
+ if debug_view:
+ _toggle_cell(cell_i, cell_j)
diff --git a/examples/Platform2D/scenes/player/player.gd b/examples/Platform2D/scenes/player/player.gd
new file mode 100644
index 0000000..d5ae260
--- /dev/null
+++ b/examples/Platform2D/scenes/player/player.gd
@@ -0,0 +1,90 @@
+extends CharacterBody2D
+class_name Player
+
+## Horizontal speed; velocity.x is set to direction * speed.
+@export var speed := 700.0
+## Value assigned to velocity.y when a jump starts.
+@export var jump_velocity := -1400.0
+## Map the player moves in; used for the fall check, coin count and map reset.
+@export var map_manager: MapManager
+## Controller that provides the control mode and receives rewards / episode ends.
+@export var ai_controller: PlayerAIController
+
+## Movement direction used when not in human control mode.
+var requested_movement: float
+## Whether a jump is requested for the current physics step.
+var requested_jump: bool
+
+@onready var animated_sprite: AnimatedSprite2D = $AnimatedSprite2D
+## Transform at ready time; restored when an episode ends.
+@onready var initial_transform := transform
+
+
+## Per physics step: updates the velocity from the controls, moves the body, then
+## ends the episode if the player has fallen below the map.
+func _physics_process(delta: float) -> void:
+ handle_movement(delta)
+ move_and_slide()
+ end_episode_on_fell_down()
+
+
+## Updates velocity and the sprite animation from gravity and the current controls.
+## In human control mode the input actions are read directly; otherwise
+## requested_movement / requested_jump (set elsewhere) are used.
+func handle_movement(delta: float):
+ # Gravity
+ if not is_on_floor():
+ velocity += Vector2.DOWN * 4000 * delta
+
+ # Controls (human or AI controlled)
+ var direction: float
+ if ai_controller.control_mode == AIController2D.ControlModes.HUMAN:
+ direction = Input.get_axis("move_left", "move_right")
+ requested_jump = Input.is_action_just_pressed("jump")
+ else:
+ direction = requested_movement
+
+ # Horizontal movement
+ velocity.x = direction * speed
+ if velocity.x:
+ # Face the sprite in the direction of horizontal movement
+ animated_sprite.flip_h = velocity.x < 0
+ # NOTE(review): presumably nested under the check above (play "move" only while
+ # moving on the floor) - confirm against the original indentation
+ if is_on_floor():
+ animated_sprite.animation = "move"
+ animated_sprite.play()
+
+ # Jump
+ if requested_jump and is_on_floor():
+ velocity.y = jump_velocity
+ animated_sprite.animation = "jump"
+ animated_sprite.play()
+
+ # Stop animation if not moving
+ if velocity.length_squared() < 0.01 and is_on_floor():
+ animated_sprite.animation = "move"
+ animated_sprite.stop()
+
+
+## Ends the episode with a reward of -1.0 once the player's y position exceeds
+## the map height (total_rows * tile height).
+func end_episode_on_fell_down() -> void:
+	var map_height = map_manager.total_rows * map_manager.tile_set.tile_size.y
+	if position.y > map_height:
+		end_episode(-1.0)
+
+
+## Ends the current episode: forwards [param final_reward] and [param success] to the
+## AI controller, moves the player back to its initial transform and resets the map.
+func end_episode(final_reward := 0.0, success := false) -> void:
+ ai_controller.end_episode(final_reward, success)
+ transform = initial_transform
+ map_manager.reset()
+
+
+## Reacts to the player's Area2D touching a shape of the map: coins are picked up,
+## the goal and spikes trigger their handlers. Bodies other than a MapManager are ignored.
+func _on_area_2d_body_shape_entered(
+	body_rid: RID, body: Node2D, _body_shape_index: int, _local_shape_index: int
+) -> void:
+	var map := body as MapManager
+	if map == null:
+		return
+
+	# Look up which tile the touched shape belongs to
+	var coords = map.get_coords_for_body_rid(body_rid)
+	var tile = map.get_cell_atlas_coords(coords)
+	if tile == MapManager.Tiles.COIN:
+		map.remove_coin_from_position(coords)
+		player_picked_up_coin()
+	elif tile == MapManager.Tiles.GOAL:
+		player_reached_goal()
+	elif tile == MapManager.Tiles.SPIKES:
+		player_hit_spikes()
+
+
+## Adds 1 to the AI controller's reward for a collected coin.
+func player_picked_up_coin() -> void:
+ ai_controller.reward += 1
+
+
+## Ends the episode as a success with a reward of +10, but only when no coins remain.
+func player_reached_goal() -> void:
+	if map_manager.remaining_coins != 0:
+		return
+	end_episode(+10, true)
+
+
+## Ends the episode with a reward of -0.1 when the player touches spikes.
+func player_hit_spikes() -> void:
+ end_episode(-0.1)
diff --git a/examples/Platform2D/scenes/player/player.tscn b/examples/Platform2D/scenes/player/player.tscn
new file mode 100644
index 0000000..43d65de
--- /dev/null
+++ b/examples/Platform2D/scenes/player/player.tscn
@@ -0,0 +1,126 @@
+[gd_scene load_steps=14 format=3 uid="uid://d2qsl7semlkyv"]
+
+[ext_resource type="Script" path="res://scenes/player/player.gd" id="1_uuo8p"]
+[ext_resource type="Texture2D" uid="uid://djgf0w4s12f86" path="res://assets/player/jump/Player1Jump1.png" id="2_6wykd"]
+[ext_resource type="Texture2D" uid="uid://dtrkm6ibrh22k" path="res://assets/player/move/Player-1.png" id="2_t2oqs"]
+[ext_resource type="Texture2D" uid="uid://b5ty3hrl7jtj3" path="res://assets/player/jump/Player1Jump2.png" id="3_0ssda"]
+[ext_resource type="Texture2D" uid="uid://mohyg2vfkunp" path="res://assets/player/move/Player-2.png" id="3_y4ujj"]
+[ext_resource type="Texture2D" uid="uid://dlec4dcwqdi66" path="res://assets/player/move/Player-3.png" id="4_iy4ba"]
+[ext_resource type="Texture2D" uid="uid://f7aey6fwsl0i" path="res://assets/player/jump/Player1Jump3.png" id="4_ngojv"]
+[ext_resource type="Script" path="res://addons/godot_rl_agents/sensors/sensors_2d/RaycastSensor2D.gd" id="6_ybkrn"]
+[ext_resource type="Script" path="res://scenes/player/extended_grid_sensor_2d.gd" id="7_jr4fg"]
+[ext_resource type="Script" path="res://scenes/player/player_ai_controller.gd" id="12_dkdhh"]
+
+[sub_resource type="SpriteFrames" id="SpriteFrames_5tff7"]
+animations = [{
+"frames": [{
+"duration": 1.0,
+"texture": ExtResource("2_6wykd")
+}, {
+"duration": 1.0,
+"texture": ExtResource("3_0ssda")
+}, {
+"duration": 1.0,
+"texture": ExtResource("4_ngojv")
+}],
+"loop": true,
+"name": &"jump",
+"speed": 6.0
+}, {
+"frames": [{
+"duration": 1.0,
+"texture": ExtResource("2_t2oqs")
+}, {
+"duration": 1.0,
+"texture": ExtResource("3_y4ujj")
+}, {
+"duration": 1.0,
+"texture": ExtResource("4_iy4ba")
+}],
+"loop": true,
+"name": &"move",
+"speed": 6.0
+}]
+
+[sub_resource type="CapsuleShape2D" id="CapsuleShape2D_jl23j"]
+radius = 32.0
+height = 80.0
+
+[sub_resource type="RectangleShape2D" id="RectangleShape2D_vqw8d"]
+size = Vector2(68.14, 81.28)
+
+[node name="Player" type="CharacterBody2D" node_paths=PackedStringArray("ai_controller")]
+collision_layer = 2
+script = ExtResource("1_uuo8p")
+ai_controller = NodePath("AIController2D")
+
+[node name="AnimatedSprite2D" type="AnimatedSprite2D" parent="."]
+scale = Vector2(0.64, 0.64)
+sprite_frames = SubResource("SpriteFrames_5tff7")
+animation = &"move"
+autoplay = "move"
+
+[node name="CollisionShape2D" type="CollisionShape2D" parent="."]
+shape = SubResource("CapsuleShape2D_jl23j")
+
+[node name="Area2D" type="Area2D" parent="."]
+collision_mask = 29
+
+[node name="CollisionShape2D" type="CollisionShape2D" parent="Area2D"]
+position = Vector2(0, 0.5)
+shape = SubResource("RectangleShape2D_vqw8d")
+
+[node name="AIController2D" type="Node2D" parent="." node_paths=PackedStringArray("player", "raycast_sensors")]
+script = ExtResource("12_dkdhh")
+player = NodePath("..")
+raycast_sensors = [NodePath("RaycastGround"), NodePath("RaycastSpike"), NodePath("RaycastCoin"), NodePath("RaycastCoin2"), NodePath("GridSensor2DCoin")]
+reset_after = 2500
+
+[node name="RaycastGround" type="Node2D" parent="AIController2D"]
+visible = false
+rotation = 1.5708
+script = ExtResource("6_ybkrn")
+n_rays = 32.0
+ray_length = 3000
+cone_width = 205.0
+debug_draw = false
+
+[node name="RaycastSpike" type="Node2D" parent="AIController2D"]
+visible = false
+rotation = 1.5708
+script = ExtResource("6_ybkrn")
+collision_mask = 8
+n_rays = 32.0
+ray_length = 3000
+cone_width = 205.0
+debug_draw = false
+
+[node name="RaycastCoin" type="Node2D" parent="AIController2D"]
+visible = false
+script = ExtResource("6_ybkrn")
+collision_mask = 4
+n_rays = 9.0
+ray_length = 1280
+cone_width = 100.0
+debug_draw = false
+
+[node name="RaycastCoin2" type="Node2D" parent="AIController2D"]
+visible = false
+rotation = 3.14159
+script = ExtResource("6_ybkrn")
+collision_mask = 4
+n_rays = 9.0
+ray_length = 1280
+cone_width = 100.0
+debug_draw = false
+
+[node name="GridSensor2DCoin" type="Node2D" parent="AIController2D"]
+visible = false
+position = Vector2(1000, 0)
+script = ExtResource("7_jr4fg")
+cell_width_override = 2000.0
+cell_height_override = 480.0
+detection_mask = 4
+grid_size_y = 1
+
+[connection signal="body_shape_entered" from="Area2D" to="." method="_on_area_2d_body_shape_entered"]
diff --git a/examples/Platform2D/scenes/player/player_ai_controller.gd b/examples/Platform2D/scenes/player/player_ai_controller.gd
new file mode 100644
index 0000000..dbe0374
--- /dev/null
+++ b/examples/Platform2D/scenes/player/player_ai_controller.gd
@@ -0,0 +1,82 @@
+extends AIController2D
+class_name PlayerAIController
+
+## Player controlled by this AI controller.
+@export var player: Player
+## Sensors whose get_observation() results are included in the observations.
+@export var raycast_sensors: Array[Node2D]
+
+## Success flag of the most recently ended episode, reported by get_info().
+var is_success := false
+
+
+## Per physics step: counts the step, performs a pending reset, and ends the episode
+## with a reward of -0.1 on timeout or when coins were left behind.
+func _physics_process(_delta):
+	n_steps += 1
+	if needs_reset:
+		reset()
+
+	if n_steps > reset_after:
+		player.end_episode(-0.1)
+
+	# To help training, we reset the episode if there are any remaining coins
+	# in the row above the player
+	var map := player.map_manager
+	var row_above: int = map.get_grid_pos(player.global_position).y - 1
+	if map.count_coins_in_grid_row(row_above) > 0:
+		player.end_episode(-0.1)
+
+
+## Records the outcome of the episode: stores [param success], adds [param final_reward]
+## to the reward, marks the episode as done and calls reset().
+func end_episode(final_reward := 0.0, success := false) -> void:
+ is_success = success
+ reward += final_reward
+ done = true
+ reset()
+
+
+
+## Returns {"is_success": ...} while the episode is flagged as done, otherwise an
+## empty dictionary.
+func get_info() -> Dictionary:
+	return {"is_success": is_success} if done else {}
+
+
+## Builds the observation array: all sensor observations, the player velocity scaled
+## by (speed, jump_velocity) and clamped to [-1, 1], the on-floor flag, the direction
+## to the goal in the player's local space and the goal distance (divided by 640
+## and clamped to [0, 1]).
+func get_obs() -> Dictionary:
+	var obs: Array[float]
+
+	for sensor in raycast_sensors:
+		obs.append_array(sensor.get_observation())
+
+	var scaled_velocity := player.get_real_velocity()
+	scaled_velocity /= Vector2(player.speed, player.jump_velocity)
+	obs.append(clampf(scaled_velocity.x, -1.0, 1.0))
+	obs.append(clampf(scaled_velocity.y, -1.0, 1.0))
+	obs.append(float(player.is_on_floor()))
+
+	var to_goal := player.to_local(player.map_manager.goal_position)
+	var goal_direction := to_goal.normalized()
+	obs.append_array([goal_direction.x, goal_direction.y])
+	obs.append(clampf(to_goal.length() / 640.0, 0, 1.0))
+
+	return {"obs": obs}
+
+
+## Returns the currently accumulated reward.
+func get_reward() -> float:
+ return reward
+
+
+## Describes the action space: a discrete "move" action with 3 options and a
+## discrete "jump" action with 2 options.
+func get_action_space() -> Dictionary:
+	var action_space := {}
+	action_space["move"] = {"size": 3, "action_type": "discrete"}
+	action_space["jump"] = {"size": 2, "action_type": "discrete"}
+	return action_space
+
+
+## Applies [param action] to the player: "move" values 0/1/2 become a movement of
+## -1/0/+1, and "jump" equal to 1 requests a jump. Each jump action costs 0.01 reward.
+func set_action(action) -> void:
+	var move_value = action.move
+	var jump_value = action.jump
+
+	player.requested_movement = move_value - 1
+	player.requested_jump = jump_value == 1
+
+	reward -= jump_value * 0.01
diff --git a/examples/Platform2D/scenes/tilemap/tile_map_layer.gd b/examples/Platform2D/scenes/tilemap/tile_map_layer.gd
new file mode 100644
index 0000000..0d1cc09
--- /dev/null
+++ b/examples/Platform2D/scenes/tilemap/tile_map_layer.gd
@@ -0,0 +1,145 @@
+extends TileMapLayer
+class_name MapManager
+
+
+## Maps tile names to tileset atlas coordinates.
+## Cells without a tile are detected elsewhere in this script by comparing
+## get_cell_atlas_coords() against Vector2i(-1, -1).
+class Tiles:
+ const PLATFORM_LEFT_EDGE = Vector2i(0, 0)
+ const PLATFORM_MIDDLE = Vector2i(1, 0)
+ const PLATFORM_RIGHT_EDGE = Vector2i(2, 0)
+ const GROUND = Vector2i(0, 1) # currently not used
+ const GROUND_2 = Vector2i(1, 1) # currently not used
+ const SPIKES = Vector2i(2, 1)
+ const COIN = Vector2i(0, 2)
+ const GOAL = Vector2i(1, 2)
+
+
+## Vertical distance, in rows, between two consecutive platform rows
+@export var rows_between_walkable_platforms: int = 3
+## Must be a multiple of rows_between_walkable_platforms
+@export var total_rows: int = 40
+## Number of columns filled with platform tiles in each platform row
+@export var total_columns: int = 10
+
+## Coin limit for the whole map (checked once per platform row in build_map())
+@export var max_coins := 200
+## Coin limit for a single platform row
+@export var max_coins_per_row := 5
+
+## Remaining coin count
+var remaining_coins := 0
+
+## Goal position in global coordinates
+var goal_position: Vector2
+
+
+## Generates the initial map when the node is ready.
+func _ready() -> void:
+ update_map()
+
+
+## Clears the current map and builds a new one.
+func update_map():
+ clear_map()
+ build_map()
+
+
+## Resets the remaining coin count and removes all cells from the layer.
+func clear_map():
+ remaining_coins = 0
+ clear()
+
+
+## Fills the layer with a new map. Every rows_between_walkable_platforms-th row gets
+## a full-width platform; passages are carved into all but the last platform row;
+## coins, at most one spike and (on the last platform row) the goal are placed in the
+## row directly above the platform. Also updates remaining_coins and goal_position.
+func build_map():
+ var coins_total = 0
+
+ for y in range(0, total_rows, rows_between_walkable_platforms):
+ var coins_in_row = 0
+ var walkable_tiles_in_row = total_columns
+ # Place a walkable platform
+ for x in range(0, total_columns):
+ set_cell(Vector2i(x, y), 0, Tiles.PLATFORM_MIDDLE)
+
+ # Carve passages on all but the last platform row
+ if y < total_rows - rows_between_walkable_platforms:
+ # Make total_columns - 1 carving attempts at random columns (the first and last
+ # column are never picked); attempts failing the checks below carve nothing
+ for i in range(total_columns - 1):
+ var rand_x = randi_range(1, total_columns - 2)
+ if y > 0:
+ # Carve only where there is no carved passage at the same column in the previous platform row
+ # NOTE(review): re-rolls until such a column is found; this never ends if every
+ # candidate column of the previous platform row is carved - confirm this cannot occur
+ while (
+ get_cell_atlas_coords(Vector2i(rand_x, y - rows_between_walkable_platforms))
+ == Vector2i(-1, -1)
+ ):
+ rand_x = randi_range(1, total_columns - 2)
+
+ # Skip the attempt unless the cell and both of its neighbours are still
+ # untouched middle tiles
+ if not (
+ get_cell_atlas_coords(Vector2i(rand_x, y)) == Tiles.PLATFORM_MIDDLE
+ and get_cell_atlas_coords(Vector2i(rand_x - 1, y)) == Tiles.PLATFORM_MIDDLE
+ and get_cell_atlas_coords(Vector2i(rand_x + 1, y)) == Tiles.PLATFORM_MIDDLE
+ ):
+ continue
+
+ erase_cell(Vector2i(rand_x, y))
+ walkable_tiles_in_row -= 1
+ set_cell(Vector2i(rand_x - 1, y), 0, Tiles.PLATFORM_RIGHT_EDGE)
+ set_cell(Vector2i(rand_x + 1, y), 0, Tiles.PLATFORM_LEFT_EDGE)
+
+ # COINS: Place coins at random columns, directly above cells that hold a platform
+ # tile (not above carved passages) and only where the cell above is still empty
+ # NOTE(review): max_coins is only checked before the row starts, so coins_total can
+ # exceed it within a row - confirm whether that is intended
+ if y < total_rows - 1 and coins_total < max_coins:
+ while coins_in_row < max_coins_per_row and coins_in_row < walkable_tiles_in_row:
+ var rand_x = randi_range(0, total_columns - 1)
+ var current_cell_coord = get_cell_atlas_coords(Vector2i(rand_x, y))
+ var above_cell_coord = get_cell_atlas_coords(Vector2i(rand_x, y - 1))
+
+ if (
+ (current_cell_coord != Vector2i(-1, -1))
+ and (above_cell_coord == Vector2i(-1, -1))
+ ):
+ var coin_pos := Vector2i(rand_x, y - 1)
+ set_cell(coin_pos, 0, Tiles.COIN)
+ coins_total += 1
+ remaining_coins += 1
+ coins_in_row += 1
+
+ # TRAPS: Add traps (up to 1 per row, depending on coins placed)
+ if y > 0:
+ var rand_x = randi_range(0, total_columns - 1)
+ var current_cell_coord = get_cell_atlas_coords(Vector2i(rand_x, y))
+ var previous_row_cell_coord = get_cell_atlas_coords(
+ Vector2i(rand_x, y - rows_between_walkable_platforms)
+ )
+ var above_cell_coord = get_cell_atlas_coords(Vector2i(rand_x, y - 1))
+
+ # A single random column is tried: the spike is placed only on a middle tile whose
+ # cell above is empty and which has a tile above it in the previous platform row
+ if (
+ (previous_row_cell_coord != Vector2i(-1, -1))
+ and (above_cell_coord == Vector2i(-1, -1))
+ and (current_cell_coord == Tiles.PLATFORM_MIDDLE)
+ ):
+ var spike_pos := Vector2i(rand_x, y - 1)
+ remove_coin_from_position(spike_pos, false)
+ set_cell(spike_pos, 0, Tiles.SPIKES)
+
+ # GOAL: Add 1 goal on the last level
+ if y == total_rows - rows_between_walkable_platforms:
+ var rand_x = randi_range(0, total_columns - 1)
+ var goal_pos := Vector2i(rand_x, y - 1)
+ # The goal may replace a coin; keep remaining_coins in sync before overwriting the cell
+ remove_coin_from_position(goal_pos, false)
+ set_cell(goal_pos, 0, Tiles.GOAL)
+ goal_position = to_global(map_to_local(goal_pos))
+
+
+## Returns how many of the first total_columns cells in grid row [param grid_y]
+## hold a coin tile.
+func count_coins_in_grid_row(grid_y: int) -> int:
+	var coin_count := 0
+	for column in range(total_columns):
+		var tile := get_cell_atlas_coords(Vector2i(column, grid_y))
+		if tile == Tiles.COIN:
+			coin_count += 1
+	return coin_count
+
+
+## Decrements remaining_coins when the cell at [param grid_position] holds a coin tile;
+## with [param clear_cell] set, the cell is also emptied.
+## NOTE(review): presumably the cell is only emptied when it held a coin (last line
+## nested under the coin check) - confirm against the original indentation.
+func remove_coin_from_position(grid_position: Vector2i, clear_cell := true):
+ if get_cell_atlas_coords(grid_position) == Tiles.COIN:
+ remaining_coins -= 1
+ if clear_cell: set_cell(grid_position, -1)
+
+
+## Converts a global position to the map coordinates of the cell containing it.
+func get_grid_pos(position_global: Vector2) -> Vector2i:
+	var position_local := to_local(position_global)
+	return local_to_map(position_local)
+
+
+## Schedules a rebuild of the map; update_map() is called deferred rather than immediately.
+func reset():
+ call_deferred("update_map")
diff --git a/examples/Platform2D/scenes/tileset/tileset.tres b/examples/Platform2D/scenes/tileset/tileset.tres
new file mode 100644
index 0000000..bcaf652
--- /dev/null
+++ b/examples/Platform2D/scenes/tileset/tileset.tres
@@ -0,0 +1,31 @@
+[gd_resource type="TileSet" load_steps=3 format=3 uid="uid://sdmrwh4xd6qj"]
+
+[ext_resource type="Texture2D" uid="uid://mow8g6gd34j3" path="res://assets/tilesheet.png" id="1_xi302"]
+
+[sub_resource type="TileSetAtlasSource" id="TileSetAtlasSource_oky61"]
+texture = ExtResource("1_xi302")
+texture_region_size = Vector2i(160, 160)
+0:0/0 = 0
+0:0/0/physics_layer_0/polygon_0/points = PackedVector2Array(-80, -80, 80, -80, 80, 80, -80, 80)
+1:0/0 = 0
+1:0/0/physics_layer_0/polygon_0/points = PackedVector2Array(-80, -80, 80, -80, 80, 80, -80, 80)
+2:0/0 = 0
+2:0/0/physics_layer_0/polygon_0/points = PackedVector2Array(-80, -80, 80, -80, 80, 80, -80, 80)
+0:1/0 = 0
+0:1/0/physics_layer_0/polygon_0/points = PackedVector2Array(-80, -80, 80, -80, 80, 80, -80, 80)
+1:1/0 = 0
+1:1/0/physics_layer_0/polygon_0/points = PackedVector2Array(-80, -80, 80, -80, 80, 80, -80, 80)
+2:1/0 = 0
+2:1/0/physics_layer_2/polygon_0/points = PackedVector2Array(-80, 80, -50.9091, -7.27273, 50.9091, -7.27273, 80, 80)
+0:2/0 = 0
+0:2/0/physics_layer_1/polygon_0/points = PackedVector2Array(-21.8182, -21.8182, 21.8182, -21.8182, 21.8182, 21.8182, -21.8182, 21.8182, -21.8182, 21.8182)
+1:2/0 = 0
+1:2/0/physics_layer_3/polygon_0/points = PackedVector2Array(-50.9091, -50.9091, -7.27273, 80, 7.27273, 80, 50.9091, -50.9091)
+
+[resource]
+tile_size = Vector2i(160, 160)
+physics_layer_0/collision_layer = 1
+physics_layer_1/collision_layer = 4
+physics_layer_2/collision_layer = 8
+physics_layer_3/collision_layer = 16
+sources/0 = SubResource("TileSetAtlasSource_oky61")
diff --git a/examples/Platform2D/scenes/training_scene/inference_scene.tscn b/examples/Platform2D/scenes/training_scene/inference_scene.tscn
new file mode 100644
index 0000000..e29efcd
--- /dev/null
+++ b/examples/Platform2D/scenes/training_scene/inference_scene.tscn
@@ -0,0 +1,14 @@
+[gd_scene load_steps=3 format=3 uid="uid://c2baxuisewykf"]
+
+[ext_resource type="PackedScene" uid="uid://danlnf1x033rf" path="res://scenes/game_scene/game_scene.tscn" id="1_pwlfw"]
+[ext_resource type="Script" path="res://addons/godot_rl_agents/sync.gd" id="3_bqjwy"]
+
+[node name="InferenceScene" type="Node2D"]
+
+[node name="GameScene" parent="." instance=ExtResource("1_pwlfw")]
+
+[node name="Sync" type="Node" parent="."]
+script = ExtResource("3_bqjwy")
+control_mode = 2
+action_repeat = 4
+onnx_model_path = "model.onnx"
diff --git a/examples/Platform2D/scenes/training_scene/training_scene.tscn b/examples/Platform2D/scenes/training_scene/training_scene.tscn
new file mode 100644
index 0000000..a3c6410
--- /dev/null
+++ b/examples/Platform2D/scenes/training_scene/training_scene.tscn
@@ -0,0 +1,34 @@
+[gd_scene load_steps=3 format=3 uid="uid://fp0m16qnoe0r"]
+
+[ext_resource type="PackedScene" uid="uid://danlnf1x033rf" path="res://scenes/game_scene/game_scene.tscn" id="1_bccmi"]
+[ext_resource type="Script" path="res://addons/godot_rl_agents/sync.gd" id="3_sbmsv"]
+
+[node name="TrainingScene" type="Node2D"]
+
+[node name="GameScene" parent="." instance=ExtResource("1_bccmi")]
+
+[node name="GameScene2" parent="." instance=ExtResource("1_bccmi")]
+position = Vector2(10000, 0)
+
+[node name="GameScene3" parent="." instance=ExtResource("1_bccmi")]
+position = Vector2(20000, 0)
+
+[node name="GameScene4" parent="." instance=ExtResource("1_bccmi")]
+position = Vector2(-10000, 0)
+
+[node name="GameScene5" parent="." instance=ExtResource("1_bccmi")]
+position = Vector2(0, 10048)
+
+[node name="GameScene6" parent="." instance=ExtResource("1_bccmi")]
+position = Vector2(10000, 10048)
+
+[node name="GameScene7" parent="." instance=ExtResource("1_bccmi")]
+position = Vector2(20000, 10048)
+
+[node name="GameScene8" parent="." instance=ExtResource("1_bccmi")]
+position = Vector2(-10000, 10048)
+
+[node name="Sync" type="Node" parent="."]
+script = ExtResource("3_sbmsv")
+action_repeat = 4
+onnx_model_path = "model.onnx"