diff --git a/env/AnimalAI-Environment/.gitignore b/env/AnimalAI-Environment/.gitignore
new file mode 100755
index 00000000..3aed551d
--- /dev/null
+++ b/env/AnimalAI-Environment/.gitignore
@@ -0,0 +1,14 @@
+AnimalAI-environment.sln
+AnimalAIOlympics-PrivDev-UnitySDK.sln
+Assembly-CSharp.csproj
+Assembly-CSharp-Editor.csproj
+iOSBLAS.csproj
+Library
+Logs
+MacBLAS.csproj
+obj
+omnisharp.json
+Packages
+UIElementsSchema
+.vscode
+Temp
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics.meta
new file mode 100755
index 00000000..59a4dc3f
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics.meta
@@ -0,0 +1,9 @@
+fileFormatVersion: 2
+guid: 11630fa83cc8b4194b94352e3e6cdb9d
+folderAsset: yes
+timeCreated: 1504127524
+licenseType: Pro
+DefaultImporter:
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor.meta
new file mode 100755
index 00000000..9ded768b
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor.meta
@@ -0,0 +1,9 @@
+fileFormatVersion: 2
+guid: 67b4fb0b937cc471eae742addf6bda86
+folderAsset: yes
+timeCreated: 1503177274
+licenseType: Free
+DefaultImporter:
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/AgentEditor.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/AgentEditor.cs
new file mode 100755
index 00000000..83eced8f
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/AgentEditor.cs
@@ -0,0 +1,86 @@
+using UnityEngine;
+using UnityEditor;
+
+namespace MLAgents
+{
+/*
+ This code is meant to modify the behavior of the inspector on Brain Components.
+ Depending on the type of brain that is used, the available fields will be modified in the inspector accordingly.
+*/
+ [CustomEditor(typeof(Agent), true)]
+ [CanEditMultipleObjects]
+ public class AgentEditor : Editor
+ {
+
+ public override void OnInspectorGUI()
+ {
+ SerializedObject serializedAgent = serializedObject;
+ serializedAgent.Update();
+
+ SerializedProperty brain = serializedAgent.FindProperty("brain");
+ SerializedProperty actionsPerDecision = serializedAgent.FindProperty(
+ "agentParameters.numberOfActionsBetweenDecisions");
+ SerializedProperty maxSteps = serializedAgent.FindProperty(
+ "agentParameters.maxStep");
+ SerializedProperty isResetOnDone = serializedAgent.FindProperty(
+ "agentParameters.resetOnDone");
+ SerializedProperty isODD = serializedAgent.FindProperty(
+ "agentParameters.onDemandDecision");
+ SerializedProperty cameras = serializedAgent.FindProperty(
+ "agentParameters.agentCameras");
+
+ EditorGUILayout.PropertyField(brain);
+
+ EditorGUILayout.LabelField("Agent Cameras");
+ for (int i = 0; i < cameras.arraySize; i++)
+ {
+ EditorGUILayout.PropertyField(
+ cameras.GetArrayElementAtIndex(i),
+ new GUIContent("Camera " + (i + 1).ToString() + ": "));
+ }
+
+ EditorGUILayout.BeginHorizontal();
+ if (GUILayout.Button("Add Camera", EditorStyles.miniButton))
+ {
+ cameras.arraySize++;
+ }
+
+ if (GUILayout.Button("Remove Camera", EditorStyles.miniButton))
+ {
+ cameras.arraySize--;
+ }
+
+ EditorGUILayout.EndHorizontal();
+
+ EditorGUILayout.PropertyField(
+ maxSteps,
+ new GUIContent(
+ "Max Step", "The per-agent maximum number of steps."));
+ EditorGUILayout.PropertyField(
+ isResetOnDone,
+ new GUIContent(
+ "Reset On Done",
+ "If checked, the agent will reset on done. Else, AgentOnDone() will be called."));
+ EditorGUILayout.PropertyField(
+ isODD,
+ new GUIContent(
+ "On Demand Decisions",
+ "If checked, you must manually request decisions."));
+ if (!isODD.boolValue)
+ {
+ EditorGUILayout.PropertyField(
+ actionsPerDecision,
+ new GUIContent(
+ "Decision Interval",
+ "The agent will automatically request a decision every X" +
+ " steps and perform an action at every step."));
+ actionsPerDecision.intValue = Mathf.Max(1, actionsPerDecision.intValue);
+ }
+
+ serializedAgent.ApplyModifiedProperties();
+
+ EditorGUILayout.LabelField("", GUI.skin.horizontalSlider);
+ base.OnInspectorGUI();
+ }
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/AgentEditor.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/AgentEditor.cs.meta
new file mode 100755
index 00000000..66bc325f
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/AgentEditor.cs.meta
@@ -0,0 +1,12 @@
+fileFormatVersion: 2
+guid: c3b291e1cd0c64781861652b579d0ac1
+timeCreated: 1503270350
+licenseType: Free
+MonoImporter:
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BrainEditor.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BrainEditor.cs
new file mode 100755
index 00000000..1a442052
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BrainEditor.cs
@@ -0,0 +1,36 @@
+using UnityEngine;
+using UnityEditor;
+
+
+namespace MLAgents
+{
+ ///
+ /// CustomEditor for the Brain base class. Defines the default Inspector view for a Brain.
+ /// Shows the BrainParameters of the Brain and expose a tool to deep copy BrainParameters
+ /// between brains.
+ ///
+ [CustomEditor(typeof(Brain))]
+ public class BrainEditor : Editor
+ {
+ public override void OnInspectorGUI()
+ {
+ var brain = (Brain) target;
+ var brainToCopy = EditorGUILayout.ObjectField(
+ "Copy Brain Parameters from : ", null, typeof(Brain), false) as Brain;
+ if (brainToCopy != null)
+ {
+ brain.brainParameters = brainToCopy.brainParameters.Clone();
+ EditorUtility.SetDirty(brain);
+ AssetDatabase.SaveAssets();
+ return;
+ }
+ var serializedBrain = serializedObject;
+ serializedBrain.Update();
+ EditorGUILayout.PropertyField(serializedBrain.FindProperty("brainParameters"), true);
+ serializedBrain.ApplyModifiedProperties();
+
+ // Draws a horizontal thick line
+ EditorGUILayout.LabelField("", GUI.skin.horizontalSlider);
+ }
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BrainEditor.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BrainEditor.cs.meta
new file mode 100755
index 00000000..c7eaf3f0
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BrainEditor.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 7b07bebd03994ed08559c725da882b62
+timeCreated: 1537834304
\ No newline at end of file
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BrainParametersDrawer.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BrainParametersDrawer.cs
new file mode 100755
index 00000000..9363307b
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BrainParametersDrawer.cs
@@ -0,0 +1,369 @@
+using UnityEngine;
+using UnityEditor;
+
+namespace MLAgents
+{
+ ///
+ /// PropertyDrawer for BrainParameters. Defines how BrainParameters are displayed in the
+ /// Inspector.
+ ///
+ [CustomPropertyDrawer(typeof(BrainParameters))]
+ public class BrainParametersDrawer : PropertyDrawer
+ {
+ // The height of a line in the Unity Inspectors
+ private const float LineHeight = 17f;
+ private const int VecObsNumLine = 3;
+ private const string CamResPropName = "cameraResolutions";
+ private const string ActionSizePropName = "vectorActionSize";
+ private const string ActionTypePropName = "vectorActionSpaceType";
+ private const string ActionDescriptionPropName = "vectorActionDescriptions";
+ private const string VecObsPropName = "vectorObservationSize";
+ private const string NumVecObsPropName ="numStackedVectorObservations";
+ private const string CamWidthPropName = "width";
+ private const string CamHeightPropName = "height";
+ private const string CamGrayPropName = "blackAndWhite";
+ private const int DefaultCameraWidth = 84;
+ private const int DefaultCameraHeight = 84;
+ private const bool DefaultCameraGray = false;
+
+ ///
+ public override float GetPropertyHeight(SerializedProperty property, GUIContent label)
+ {
+ if (property.isExpanded)
+ {
+ return LineHeight +
+ GetHeightDrawVectorObservation() +
+ GetHeightDrawVisualObservation(property) +
+ GetHeightDrawVectorAction(property) +
+ GetHeightDrawVectorActionDescriptions(property);
+ }
+ return LineHeight;
+ }
+
+ ///
+ public override void OnGUI(Rect position, SerializedProperty property, GUIContent label)
+ {
+ var indent = EditorGUI.indentLevel;
+ EditorGUI.indentLevel = 0;
+ position.height = LineHeight;
+ property.isExpanded = EditorGUI.Foldout(position, property.isExpanded, label);
+ position.y += LineHeight;
+ if (property.isExpanded)
+ {
+ EditorGUI.BeginProperty(position, label, property);
+ EditorGUI.indentLevel++;
+
+ // Vector Observations
+ DrawVectorObservation(position, property);
+ position.y += GetHeightDrawVectorObservation();
+
+ //Visual Observations
+ DrawVisualObservations(position, property);
+ position.y += GetHeightDrawVisualObservation(property);
+
+ // Vector Action
+ DrawVectorAction(position, property);
+ position.y += GetHeightDrawVectorAction(property);
+
+ // Vector Action Descriptions
+ DrawVectorActionDescriptions(position, property);
+ position.y += GetHeightDrawVectorActionDescriptions(property);
+ EditorGUI.EndProperty();
+ }
+ EditorGUI.indentLevel = indent;
+ }
+
+ ///
+ /// Draws the Vector Observations for the Brain Parameters
+ ///
+ /// Rectangle on the screen to use for the property GUI.
+ /// The SerializedProperty of the BrainParameters
+ /// to make the custom GUI for.
+ private static void DrawVectorObservation(Rect position, SerializedProperty property)
+ {
+ EditorGUI.LabelField(position, "Vector Observation");
+ position.y += LineHeight;
+
+ EditorGUI.indentLevel++;
+ EditorGUI.PropertyField(position,
+ property.FindPropertyRelative(VecObsPropName),
+ new GUIContent("Space Size",
+ "Length of state " +
+ "vector for brain (In Continuous state space)." +
+ "Or number of possible values (in Discrete state space)."));
+ position.y += LineHeight;
+
+ EditorGUI.PropertyField(position,
+ property.FindPropertyRelative(NumVecObsPropName),
+ new GUIContent("Stacked Vectors",
+ "Number of states that will be stacked before " +
+ "beeing fed to the neural network."));
+ position.y += LineHeight;
+ EditorGUI.indentLevel--;
+ }
+
+ ///
+ /// The Height required to draw the Vector Observations paramaters
+ ///
+ /// The height of the drawer of the Vector Observations
+ private static float GetHeightDrawVectorObservation()
+ {
+ return VecObsNumLine * LineHeight;
+ }
+
+ ///
+ /// Draws the Visual Observations parameters for the Brain Parameters
+ ///
+ /// Rectangle on the screen to use for the property GUI.
+ /// The SerializedProperty of the BrainParameters
+ /// to make the custom GUI for.
+ private static void DrawVisualObservations(Rect position, SerializedProperty property)
+ {
+ EditorGUI.LabelField(position, "Visual Observations");
+ position.y += LineHeight;
+ var quarter = position.width / 4;
+ var resolutions = property.FindPropertyRelative(CamResPropName);
+ DrawVisualObsButtons(position, resolutions);
+ position.y += LineHeight;
+
+ // Display the labels for the columns : Index, Width, Height and Gray
+ var indexRect = new Rect(position.x, position.y, quarter, position.height);
+ var widthRect = new Rect(position.x + quarter, position.y, quarter, position.height);
+ var heightRect = new Rect(position.x + 2*quarter, position.y, quarter, position.height);
+ var bwRect = new Rect(position.x + 3*quarter, position.y, quarter, position.height);
+ EditorGUI.indentLevel++;
+ if (resolutions.arraySize > 0)
+ {
+ EditorGUI.LabelField(indexRect, "Index");
+ indexRect.y += LineHeight;
+ EditorGUI.LabelField(widthRect, "Width");
+ widthRect.y += LineHeight;
+ EditorGUI.LabelField(heightRect, "Height");
+ heightRect.y += LineHeight;
+ EditorGUI.LabelField(bwRect, "Gray");
+ bwRect.y += LineHeight;
+ }
+
+ // Iterate over the resolutions
+ for (var i = 0; i < resolutions.arraySize; i++)
+ {
+ EditorGUI.LabelField(indexRect, "Obs " + i);
+ indexRect.y += LineHeight;
+ var res = resolutions.GetArrayElementAtIndex(i);
+ var w = res.FindPropertyRelative("width");
+ w.intValue = EditorGUI.IntField(widthRect, w.intValue);
+ widthRect.y += LineHeight;
+ var h = res.FindPropertyRelative("height");
+ h.intValue = EditorGUI.IntField(heightRect, h.intValue);
+ heightRect.y += LineHeight;
+ var bw = res.FindPropertyRelative("blackAndWhite");
+ bw.boolValue = EditorGUI.Toggle(bwRect, bw.boolValue);
+ bwRect.y += LineHeight;
+ }
+ EditorGUI.indentLevel--;
+ }
+
+ ///
+ /// Draws the buttons to add and remove the visual observations parameters
+ ///
+ /// Rectangle on the screen to use for the property GUI.
+ /// The SerializedProperty of the resolution array
+ /// to make the custom GUI for.
+ private static void DrawVisualObsButtons(Rect position, SerializedProperty resolutions)
+ {
+ var widthEighth = position.width / 8;
+ var addButtonRect = new Rect(position.x + widthEighth, position.y,
+ 3 * widthEighth, position.height);
+ var removeButtonRect = new Rect(position.x + 4 * widthEighth, position.y,
+ 3 * widthEighth, position.height);
+ if (resolutions.arraySize == 0)
+ {
+ addButtonRect.width *= 2;
+ }
+ // Display the buttons
+ if (GUI.Button(addButtonRect, "Add New", EditorStyles.miniButton))
+ {
+ resolutions.arraySize += 1;
+ var newRes = resolutions.GetArrayElementAtIndex(resolutions.arraySize - 1);
+ newRes.FindPropertyRelative(CamWidthPropName).intValue = DefaultCameraWidth;
+ newRes.FindPropertyRelative(CamHeightPropName).intValue = DefaultCameraHeight;
+ newRes.FindPropertyRelative(CamGrayPropName).boolValue = DefaultCameraGray;
+
+ }
+ if (resolutions.arraySize > 0)
+ {
+ if (GUI.Button(removeButtonRect, "Remove Last", EditorStyles.miniButton))
+ {
+ resolutions.arraySize -= 1;
+ }
+ }
+ }
+
+ ///
+ /// The Height required to draw the Visual Observations parameters
+ ///
+ /// The height of the drawer of the Visual Observations
+ private static float GetHeightDrawVisualObservation(SerializedProperty property)
+ {
+ var visObsSize = property.FindPropertyRelative(CamResPropName).arraySize + 2;
+ if (property.FindPropertyRelative(CamResPropName).arraySize > 0)
+ {
+ visObsSize += 1;
+ }
+ return LineHeight * visObsSize;
+ }
+
+ ///
+ /// Draws the Vector Actions parameters for the Brain Parameters
+ ///
+ /// Rectangle on the screen to use for the property GUI.
+ /// The SerializedProperty of the BrainParameters
+ /// to make the custom GUI for.
+ private static void DrawVectorAction(Rect position, SerializedProperty property)
+ {
+ EditorGUI.LabelField(position, "Vector Action");
+ position.y += LineHeight;
+ EditorGUI.indentLevel++;
+ var bpVectorActionType = property.FindPropertyRelative(ActionTypePropName);
+ EditorGUI.PropertyField(
+ position,
+ bpVectorActionType,
+ new GUIContent("Space Type",
+ "Corresponds to whether state vector contains a single integer (Discrete) " +
+ "or a series of real-valued floats (Continuous)."));
+ position.y += LineHeight;
+ if (bpVectorActionType.enumValueIndex == 1)
+ {
+ DrawContinuousVectorAction(position, property);
+ }
+ else
+ {
+ DrawDiscreteVectorAction(position, property);
+ }
+ }
+
+ ///
+ /// Draws the Continuous Vector Actions parameters for the Brain Parameters
+ ///
+ /// Rectangle on the screen to use for the property GUI.
+ /// The SerializedProperty of the BrainParameters
+ /// to make the custom GUI for.
+ private static void DrawContinuousVectorAction(Rect position, SerializedProperty property)
+ {
+ var vecActionSize = property.FindPropertyRelative(ActionSizePropName);
+ vecActionSize.arraySize = 1;
+ SerializedProperty continuousActionSize =
+ vecActionSize.GetArrayElementAtIndex(0);
+ EditorGUI.PropertyField(
+ position,
+ continuousActionSize,
+ new GUIContent("Space Size", "Length of continuous action vector."));
+ }
+
+ ///
+ /// Draws the Discrete Vector Actions parameters for the Brain Parameters
+ ///
+ /// Rectangle on the screen to use for the property GUI.
+ /// The SerializedProperty of the BrainParameters
+ /// to make the custom GUI for.
+ private static void DrawDiscreteVectorAction(Rect position, SerializedProperty property)
+ {
+ var vecActionSize = property.FindPropertyRelative(ActionSizePropName);
+ vecActionSize.arraySize = EditorGUI.IntField(
+ position, "Branches Size", vecActionSize.arraySize);
+ position.y += LineHeight;
+ position.x += 20;
+ position.width -= 20;
+ for (var branchIndex = 0;
+ branchIndex < vecActionSize.arraySize;
+ branchIndex++)
+ {
+ SerializedProperty branchActionSize =
+ vecActionSize.GetArrayElementAtIndex(branchIndex);
+
+ EditorGUI.PropertyField(
+ position,
+ branchActionSize,
+ new GUIContent("Branch " + branchIndex + " Size",
+ "Number of possible actions for the branch number " + branchIndex + "."));
+ position.y += LineHeight;
+ }
+ }
+
+ ///
+ /// The Height required to draw the Vector Action parameters
+ ///
+ /// The height of the drawer of the Vector Action
+ private static float GetHeightDrawVectorAction(SerializedProperty property)
+ {
+ var actionSize = 2 + property.FindPropertyRelative(ActionSizePropName).arraySize;
+ if (property.FindPropertyRelative(ActionTypePropName).enumValueIndex == 0)
+ {
+ actionSize += 1;
+ }
+ return actionSize * LineHeight;
+ }
+
+ ///
+ /// Draws the Vector Actions descriptions for the Brain Parameters
+ ///
+ /// Rectangle on the screen to use for the property GUI.
+ /// The SerializedProperty of the BrainParameters
+ /// to make the custom GUI for.
+ private static void DrawVectorActionDescriptions(Rect position, SerializedProperty property)
+ {
+ var bpVectorActionType = property.FindPropertyRelative(ActionTypePropName);
+ var vecActionSize = property.FindPropertyRelative(ActionSizePropName);
+ var numberOfDescriptions = 0;
+ if (bpVectorActionType.enumValueIndex == 1)
+ {
+ numberOfDescriptions = vecActionSize.GetArrayElementAtIndex(0).intValue;
+ }
+ else
+ {
+ numberOfDescriptions = vecActionSize.arraySize;
+ }
+
+ EditorGUI.indentLevel++;
+ var vecActionDescriptions =
+ property.FindPropertyRelative(ActionDescriptionPropName);
+ vecActionDescriptions.arraySize = numberOfDescriptions;
+ if (bpVectorActionType.enumValueIndex == 1)
+ {
+ //Continuous case :
+ EditorGUI.PropertyField(
+ position,
+ vecActionDescriptions,
+ new GUIContent("Action Descriptions",
+ "A list of strings used to name the available actionsm for the Brain."),
+ true);
+ position.y += LineHeight;
+ }
+ else
+ {
+ // Discrete case :
+ EditorGUI.PropertyField(
+ position,
+ vecActionDescriptions,
+ new GUIContent("Branch Descriptions",
+ "A list of strings used to name the available branches for the Brain."),
+ true);
+ position.y += LineHeight;
+ }
+ }
+ ///
+ /// The Height required to draw the Action Descriptions
+ ///
+ /// The height of the drawer of the Action Descriptions
+ private static float GetHeightDrawVectorActionDescriptions(SerializedProperty property)
+ {
+ var descriptionSize = 1;
+ if (property.FindPropertyRelative(ActionDescriptionPropName).isExpanded)
+ {
+ var descriptions = property.FindPropertyRelative(ActionDescriptionPropName);
+ descriptionSize += descriptions.arraySize + 1;
+ }
+ return descriptionSize * LineHeight;
+ }
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BrainParametersDrawer.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BrainParametersDrawer.cs.meta
new file mode 100755
index 00000000..9379a5f0
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BrainParametersDrawer.cs.meta
@@ -0,0 +1,12 @@
+fileFormatVersion: 2
+guid: b060ae8e687cf49bcae88b24db17bfa6
+timeCreated: 1517291065
+licenseType: Free
+MonoImporter:
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BroadcastHubDrawer.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BroadcastHubDrawer.cs
new file mode 100755
index 00000000..b8b481ba
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BroadcastHubDrawer.cs
@@ -0,0 +1,208 @@
+using UnityEngine;
+using UnityEditor;
+using System;
+using System.Linq;
+using UnityEditor.SceneManagement;
+
+namespace MLAgents
+{
+ ///
+ /// PropertyDrawer for BroadcastHub. Used to display the BroadcastHub in the Inspector.
+ ///
+ [CustomPropertyDrawer(typeof(BroadcastHub))]
+ public class BroadcastHubDrawer : PropertyDrawer
+ {
+ private BroadcastHub _hub;
+ // The height of a line in the Unity Inspectors
+ private const float LineHeight = 17f;
+ // The vertical space left below the BroadcastHub UI.
+ private const float ExtraSpaceBelow = 10f;
+ // The horizontal size of the Control checkbox
+ private const int ControlSize = 80;
+
+ ///
+ /// Computes the height of the Drawer depending on the property it is showing
+ ///
+ /// The property that is being drawn.
+ /// The label of the property being drawn.
+ /// The vertical space needed to draw the property.
+ public override float GetPropertyHeight(SerializedProperty property, GUIContent label)
+ {
+ LazyInitializeHub(property, label);
+ var numLines = _hub.Count + 2 + (_hub.Count > 0 ? 1 : 0);
+ return (numLines) * LineHeight + ExtraSpaceBelow;
+ }
+
+ ///
+ public override void OnGUI(Rect position, SerializedProperty property, GUIContent label)
+ {
+ LazyInitializeHub(property, label);
+ position.height = LineHeight;
+ EditorGUI.LabelField(position, new GUIContent(label.text,
+ "The Broadcast Hub helps you define which Brains you want to expose to " +
+ "the external process"));
+ position.y += LineHeight;
+
+ EditorGUI.BeginProperty(position, label, property);
+
+ EditorGUI.indentLevel++;
+ DrawAddRemoveButtons(position);
+ position.y += LineHeight;
+
+ // This is the labels for each columns
+ var brainWidth = position.width - ControlSize;
+ var brainRect = new Rect(
+ position.x, position.y, brainWidth, position.height);
+ var controlRect = new Rect(
+ position.x + brainWidth, position.y, ControlSize, position.height);
+ if (_hub.Count > 0)
+ {
+ EditorGUI.LabelField(brainRect, "Brains");
+ brainRect.y += LineHeight;
+ EditorGUI.LabelField(controlRect, "Control");
+ controlRect.y += LineHeight;
+ controlRect.x += 15;
+ }
+ DrawBrains(brainRect, controlRect);
+ EditorGUI.indentLevel--;
+ EditorGUI.EndProperty();
+ }
+
+ ///
+ /// Draws the Add and Remove buttons.
+ ///
+ /// The position at which to draw.
+ private void DrawAddRemoveButtons(Rect position)
+ {
+ // This is the rectangle for the Add button
+ var addButtonRect = position;
+ addButtonRect.x += 20;
+ if (_hub.Count > 0)
+ {
+ addButtonRect.width /= 2;
+ addButtonRect.width -= 24;
+ var buttonContent = new GUIContent(
+ "Add New", "Add a new Brain to the Broadcast Hub");
+ if (GUI.Button(addButtonRect, buttonContent, EditorStyles.miniButton))
+ {
+ MarkSceneAsDirty();
+ AddBrain();
+ }
+ // This is the rectangle for the Remove button
+ var removeButtonRect = position;
+ removeButtonRect.x = position.width / 2 + 15;
+ removeButtonRect.width = addButtonRect.width - 18;
+ buttonContent = new GUIContent(
+ "Remove Last", "Remove the last Brain from the Broadcast Hub");
+ if (GUI.Button(removeButtonRect, buttonContent, EditorStyles.miniButton))
+ {
+ MarkSceneAsDirty();
+ RemoveLastBrain();
+ }
+ }
+ else
+ {
+ addButtonRect.width -= 50;
+ var buttonContent = new GUIContent(
+ "Add Brain to Broadcast Hub", "Add a new Brain to the Broadcast Hub");
+ if (GUI.Button(addButtonRect, buttonContent, EditorStyles.miniButton))
+ {
+ MarkSceneAsDirty();
+ AddBrain();
+ }
+ }
+ }
+
+ ///
+ /// Draws the Brain and Control checkbox for the brains contained in the BroadCastHub.
+ ///
+ /// The Rect to draw the Brains.
+ /// The Rect to draw the control checkbox.
+ private void DrawBrains(Rect brainRect, Rect controlRect)
+ {
+ for (var index = 0; index < _hub.Count; index++)
+ {
+ var exposedBrains = _hub.broadcastingBrains;
+ var brain = exposedBrains[index];
+ // This is the rectangle for the brain
+ EditorGUI.BeginChangeCheck();
+ var newBrain = EditorGUI.ObjectField(
+ brainRect, brain, typeof(Brain), true) as Brain;
+ brainRect.y += LineHeight;
+ if (EditorGUI.EndChangeCheck())
+ {
+ MarkSceneAsDirty();
+ _hub.broadcastingBrains.RemoveAt(index);
+ var brainToInsert = exposedBrains.Contains(newBrain) ? null : newBrain;
+ exposedBrains.Insert(index, brainToInsert);
+ break;
+ }
+ // This is the Rectangle for the control checkbox
+ EditorGUI.BeginChangeCheck();
+ if (brain is LearningBrain)
+ {
+ var isTraining = _hub.IsControlled(brain);
+ isTraining = EditorGUI.Toggle(controlRect, isTraining);
+ _hub.SetControlled(brain, isTraining);
+ }
+ controlRect.y += LineHeight;
+ if (EditorGUI.EndChangeCheck())
+ {
+ MarkSceneAsDirty();
+ }
+ }
+ }
+
+ ///
+ /// Lazy initializes the Drawer with the property to be drawn.
+ ///
+ /// The SerializedProperty of the BroadcastHub
+ /// to make the custom GUI for.
+ /// The label of this property.
+ private void LazyInitializeHub(SerializedProperty property, GUIContent label)
+ {
+ if (_hub != null)
+ {
+ return;
+ }
+ var target = property.serializedObject.targetObject;
+ _hub = fieldInfo.GetValue(target) as BroadcastHub;
+ if (_hub == null)
+ {
+ _hub = new BroadcastHub();
+ fieldInfo.SetValue(target, _hub);
+ }
+ }
+
+ ///
+ /// Signals that the property has been modified and requires the scene to be saved for
+ /// the changes to persist. Only works when the Editor is not playing.
+ ///
+ private static void MarkSceneAsDirty()
+ {
+ if (!EditorApplication.isPlaying)
+ {
+ EditorSceneManager.MarkSceneDirty(EditorSceneManager.GetActiveScene());
+ }
+ }
+
+ ///
+ /// Removes the last Brain from the BroadcastHub
+ ///
+ private void RemoveLastBrain()
+ {
+ if (_hub.Count > 0)
+ {
+ _hub.broadcastingBrains.RemoveAt(_hub.broadcastingBrains.Count - 1);
+ }
+ }
+
+ ///
+ /// Adds a new Brain to the BroadcastHub. The value of this brain will not be initialized.
+ ///
+ private void AddBrain()
+ {
+ _hub.broadcastingBrains.Add(null);
+ }
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BroadcastHubDrawer.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BroadcastHubDrawer.cs.meta
new file mode 100755
index 00000000..7ab682eb
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/BroadcastHubDrawer.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: aa1bef9e5833447ab7251fc6f7a3a609
+timeCreated: 1536852419
\ No newline at end of file
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/DemonstrationDrawer.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/DemonstrationDrawer.cs
new file mode 100755
index 00000000..9792cab4
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/DemonstrationDrawer.cs
@@ -0,0 +1,95 @@
+using System.Text;
+using MLAgents;
+using UnityEditor;
+
+///
+/// Renders a custom UI for Demonstration Scriptable Object.
+///
+[CustomEditor(typeof(Demonstration))]
+[CanEditMultipleObjects]
+public class DemonstrationEditor : Editor
+{
+ SerializedProperty brainParameters;
+ SerializedProperty demoMetaData;
+
+ void OnEnable()
+ {
+ brainParameters = serializedObject.FindProperty("brainParameters");
+ demoMetaData = serializedObject.FindProperty("metaData");
+ }
+
+ ///
+ /// Renders Inspector UI for Demonstration metadata.
+ ///
+ void MakeMetaDataProperty(SerializedProperty property)
+ {
+ var nameProp = property.FindPropertyRelative("demonstrationName");
+ var expProp = property.FindPropertyRelative("numberExperiences");
+ var epiProp = property.FindPropertyRelative("numberEpisodes");
+ var rewProp = property.FindPropertyRelative("meanReward");
+
+ var nameLabel = nameProp.displayName + ": " + nameProp.stringValue;
+ var expLabel = expProp.displayName + ": " + expProp.intValue;
+ var epiLabel = epiProp.displayName + ": " + epiProp.intValue;
+ var rewLabel = rewProp.displayName + ": " + rewProp.floatValue;
+
+ EditorGUILayout.LabelField(nameLabel);
+ EditorGUILayout.LabelField(expLabel);
+ EditorGUILayout.LabelField(epiLabel);
+ EditorGUILayout.LabelField(rewLabel);
+ }
+
+ ///
+ /// Constructs label for action size array.
+ ///
+ static string BuildActionArrayLabel(SerializedProperty actionSizeProperty)
+ {
+ var actionSize = actionSizeProperty.arraySize;
+ StringBuilder actionLabel = new StringBuilder("[ ");
+ for (int i = 0; i < actionSize; i++)
+ {
+ actionLabel.Append(actionSizeProperty.GetArrayElementAtIndex(i).intValue);
+ if (i < actionSize - 1)
+ {
+ actionLabel.Append(", ");
+ }
+ }
+
+ actionLabel.Append(" ]");
+ return actionLabel.ToString();
+ }
+
+ ///
+ /// Renders Inspector UI for Brain Parameters of Demonstration.
+ ///
+ void MakeBrainParametersProperty(SerializedProperty property)
+ {
+ var vecObsSizeProp = property.FindPropertyRelative("vectorObservationSize");
+ var numStackedProp = property.FindPropertyRelative("numStackedVectorObservations");
+ var actSizeProperty = property.FindPropertyRelative("vectorActionSize");
+ var camResProp = property.FindPropertyRelative("cameraResolutions");
+ var actSpaceTypeProp = property.FindPropertyRelative("vectorActionSpaceType");
+
+ var vecObsSizeLabel = vecObsSizeProp.displayName + ": " + vecObsSizeProp.intValue;
+ var numStackedLabel = numStackedProp.displayName + ": " + numStackedProp.intValue;
+ var vecActSizeLabel = actSizeProperty.displayName + ": " + BuildActionArrayLabel(actSizeProperty);
+ var camResLabel = camResProp.displayName + ": " + camResProp.arraySize;
+ var actSpaceTypeLabel = actSpaceTypeProp.displayName + ": " + (SpaceType) actSpaceTypeProp.enumValueIndex;
+
+ EditorGUILayout.LabelField(vecObsSizeLabel);
+ EditorGUILayout.LabelField(numStackedLabel);
+ EditorGUILayout.LabelField(vecActSizeLabel);
+ EditorGUILayout.LabelField(camResLabel);
+ EditorGUILayout.LabelField(actSpaceTypeLabel);
+ }
+
+ public override void OnInspectorGUI()
+ {
+ serializedObject.Update();
+ EditorGUILayout.LabelField("Meta Data", EditorStyles.boldLabel);
+ MakeMetaDataProperty(demoMetaData);
+ EditorGUILayout.LabelField("Brain Parameters", EditorStyles.boldLabel);
+ MakeBrainParametersProperty(brainParameters);
+ serializedObject.ApplyModifiedProperties();
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/DemonstrationDrawer.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/DemonstrationDrawer.cs.meta
new file mode 100755
index 00000000..57c06813
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/DemonstrationDrawer.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 84f9cd83f56c74790a51444a6cfe4945
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/DemonstrationImporter.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/DemonstrationImporter.cs
new file mode 100755
index 00000000..e8ffa067
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/DemonstrationImporter.cs
@@ -0,0 +1,60 @@
+using System;
+using System.IO;
+using MLAgents.CommunicatorObjects;
+using UnityEditor;
+using UnityEngine;
+using UnityEditor.Experimental.AssetImporters;
+
+namespace MLAgents
+{
+ ///
+ /// Asset Importer used to parse demonstration files.
+ ///
+ [ScriptedImporter(1, new[] {"demo"})]
+ public class DemonstrationImporter : ScriptedImporter
+ {
+ private const string IconPath = "Assets/ML-Agents/Resources/DemoIcon.png";
+
+ public override void OnImportAsset(AssetImportContext ctx)
+ {
+ var inputType = Path.GetExtension(ctx.assetPath);
+ if (inputType == null)
+ {
+ throw new Exception("Demonstration import error.");
+ }
+
+ try
+ {
+ // Read first two proto objects containing metadata and brain parameters.
+ Stream reader = File.OpenRead(ctx.assetPath);
+
+ var metaDataProto = DemonstrationMetaProto.Parser.ParseDelimitedFrom(reader);
+ var metaData = new DemonstrationMetaData(metaDataProto);
+
+ reader.Seek(DemonstrationStore.MetaDataBytes + 1, 0);
+ var brainParamsProto = BrainParametersProto.Parser.ParseDelimitedFrom(reader);
+ var brainParameters = new BrainParameters(brainParamsProto);
+
+ reader.Close();
+
+ var demonstration = ScriptableObject.CreateInstance();
+ demonstration.Initialize(brainParameters, metaData);
+ userData = demonstration.ToString();
+
+ Texture2D texture = (Texture2D)
+ AssetDatabase.LoadAssetAtPath(IconPath, typeof(Texture2D));
+
+#if UNITY_2017_3_OR_NEWER
+ ctx.AddObjectToAsset(ctx.assetPath, demonstration, texture);
+ ctx.SetMainObject(demonstration);
+#else
+ ctx.SetMainAsset(ctx.assetPath, model);
+#endif
+ }
+ catch
+ {
+ return;
+ }
+ }
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/DemonstrationImporter.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/DemonstrationImporter.cs.meta
new file mode 100755
index 00000000..bbdca977
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/DemonstrationImporter.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 7bd65ce151aaa4a41a45312543c56be1
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/HeuristicBrainEditor.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/HeuristicBrainEditor.cs
new file mode 100755
index 00000000..84630893
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/HeuristicBrainEditor.cs
@@ -0,0 +1,56 @@
+using UnityEngine;
+using UnityEditor;
+
+namespace MLAgents
+{
+ ///
+ /// CustomEditor for the Heuristic Brain class. Defines the default Inspector view for a
+ /// HeuristicBrain.
+ /// Shows the BrainParameters of the Brain and expose a tool to deep copy BrainParameters
+ /// between brains. Provides a drag box for a Decision Monoscript that will be used by
+ /// the Heuristic Brain.
+ ///
+ [CustomEditor(typeof(HeuristicBrain))]
+ public class HeuristicBrainEditor : BrainEditor
+ {
+ public override void OnInspectorGUI()
+ {
+ EditorGUILayout.LabelField("Heuristic Brain", EditorStyles.boldLabel);
+ var brain = (HeuristicBrain) target;
+ base.OnInspectorGUI();
+
+ // Expose the Heuristic Brain's Monoscript for decision in a drag and drop box.
+ brain.decisionScript = EditorGUILayout.ObjectField(
+ "Decision Script", brain.decisionScript, typeof(MonoScript), true) as MonoScript;
+
+ CheckIsDecision(brain);
+ // Draw an error box if the Decision is not set.
+ if (brain.decisionScript == null)
+ {
+ EditorGUILayout.HelpBox("You need to add a 'Decision' component to this Object",
+ MessageType.Error);
+ }
+ }
+
+ ///
+ /// Ensures tht the Monoscript for the decision of the HeuristicBrain is either null or
+ /// an implementation of Decision. If the Monoscript is not an implementation of
+ /// Decision, it will be set to null.
+ ///
+ /// The HeuristicBrain with the decision script attached
+ private static void CheckIsDecision(HeuristicBrain brain)
+ {
+ if (brain.decisionScript != null)
+ {
+ var decisionInstance = (CreateInstance(brain.decisionScript.name) as Decision);
+ if (decisionInstance == null)
+ {
+ Debug.LogError(
+ "Instance of " + brain.decisionScript.name + " couldn't be created. " +
+ "The the script class needs to derive from Decision.");
+ brain.decisionScript = null;
+ }
+ }
+ }
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/HeuristicBrainEditor.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/HeuristicBrainEditor.cs.meta
new file mode 100755
index 00000000..304d51f2
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/HeuristicBrainEditor.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: c3347a9ad704411896dd4898423c6515
+timeCreated: 1536852553
\ No newline at end of file
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/LearningBrainEditor.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/LearningBrainEditor.cs
new file mode 100755
index 00000000..cd5bcaad
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/LearningBrainEditor.cs
@@ -0,0 +1,84 @@
+using UnityEngine;
+using UnityEditor;
+
+namespace MLAgents
+{
+ ///
+ /// CustomEditor for the LearningBrain class. Defines the default Inspector view for a
+ /// LearningBrain.
+ /// Shows the BrainParameters of the Brain and expose a tool to deep copy BrainParameters
+ /// between brains. Also exposes a drag box for the Model that will be used by the
+ /// LearningBrain.
+ ///
+ [CustomEditor(typeof(LearningBrain))]
+ public class LearningBrainEditor : BrainEditor
+ {
+ private const string ModelPropName = "model";
+ private const string InferenceDevicePropName = "inferenceDevice";
+ private const float TimeBetweenModelReloads = 2f;
+ // Time since the last reload of the model
+ private float _timeSinceModelReload;
+ // Whether or not the model needs to be reloaded
+ private bool _requireReload;
+
+ ///
+ /// Called when the user opens the Inspector for the LearningBrain
+ ///
+ public void OnEnable()
+ {
+ _requireReload = true;
+ EditorApplication.update += IncreaseTimeSinceLastModelReload;
+ }
+
+ ///
+ /// Called when the user leaves the Inspector for the LearningBrain
+ ///
+ public void OnDisable()
+ {
+ EditorApplication.update -= IncreaseTimeSinceLastModelReload;
+ }
+
+ public override void OnInspectorGUI()
+ {
+ EditorGUILayout.LabelField("Learning Brain", EditorStyles.boldLabel);
+ var brain = (LearningBrain) target;
+ var serializedBrain = serializedObject;
+ EditorGUI.BeginChangeCheck();
+ base.OnInspectorGUI();
+ serializedBrain.Update();
+ var tfGraphModel = serializedBrain.FindProperty(ModelPropName);
+ EditorGUILayout.ObjectField(tfGraphModel);
+ var inferenceDevice = serializedBrain.FindProperty(InferenceDevicePropName);
+ EditorGUILayout.PropertyField(inferenceDevice);
+ serializedBrain.ApplyModifiedProperties();
+ if (EditorGUI.EndChangeCheck())
+ {
+ _requireReload = true;
+ }
+ if (_requireReload && _timeSinceModelReload > TimeBetweenModelReloads)
+ {
+ brain.ReloadModel();
+ _requireReload = false;
+ _timeSinceModelReload = 0;
+ }
+ // Display all failed checks
+ var failedChecks = brain.GetModelFailedChecks();
+ foreach (var check in failedChecks)
+ {
+ if (check != null)
+ {
+ EditorGUILayout.HelpBox(check, MessageType.Warning);
+ }
+ }
+ }
+
+ ///
+ /// Increases the time since last model reload by the deltaTime since the last Update call
+ /// from the UnityEditor
+ ///
+ private void IncreaseTimeSinceLastModelReload()
+ {
+ _timeSinceModelReload += Time.deltaTime;
+ }
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/LearningBrainEditor.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/LearningBrainEditor.cs.meta
new file mode 100755
index 00000000..ce3229c2
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/LearningBrainEditor.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: b538d92cc78b4a62a596822eca31423e
+timeCreated: 1536970736
\ No newline at end of file
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/NNModelImporter.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/NNModelImporter.cs
new file mode 100755
index 00000000..9abf9899
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/NNModelImporter.cs
@@ -0,0 +1,29 @@
+using System.IO;
+using UnityEditor;
+using UnityEngine;
+using UnityEditor.Experimental.AssetImporters;
+using MLAgents.InferenceBrain;
+
+namespace MLAgents
+{
+ ///
+ /// Asset Importer of barracuda models.
+ ///
+ [ScriptedImporter(1, new[] {"nn"})]
+ public class NNModelImporter : ScriptedImporter {
+ private const string IconPath = "Assets/ML-Agents/Resources/NNModelIcon.png";
+
+ public override void OnImportAsset(AssetImportContext ctx)
+ {
+ var model = File.ReadAllBytes(ctx.assetPath);
+ var asset = ScriptableObject.CreateInstance();
+ asset.Value = model;
+
+ Texture2D texture = (Texture2D)
+ AssetDatabase.LoadAssetAtPath(IconPath, typeof(Texture2D));
+
+ ctx.AddObjectToAsset(ctx.assetPath, asset, texture);
+ ctx.SetMainObject(asset);
+ }
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/NNModelImporter.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/NNModelImporter.cs.meta
new file mode 100755
index 00000000..cfb75680
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/NNModelImporter.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 83221ad3db87f4b3b91b041047cb2bc5
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/PlayerBrainEditor.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/PlayerBrainEditor.cs
new file mode 100755
index 00000000..2efc0d7b
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/PlayerBrainEditor.cs
@@ -0,0 +1,106 @@
+using System.Collections;
+using System.Collections.Generic;
+using UnityEngine;
+
+using UnityEditor;
+using System.Linq;
+
+namespace MLAgents
+{
+ ///
+ /// CustomEditor for the PlayerBrain class. Defines the default Inspector view for a
+ /// PlayerBrain.
+ /// Shows the BrainParameters of the Brain and expose a tool to deep copy BrainParameters
+ /// between brains. Also exposes the key mappings for either continuous or discrete control
+ /// depending on the Vector Action Space Type of the Brain Parameter. These mappings are the
+ /// ones that will be used by the PlayerBrain.
+ ///
+ [CustomEditor(typeof(PlayerBrain))]
+ public class PlayerBrainEditor : BrainEditor
+ {
+ private const string KeyContinuousPropName = "keyContinuousPlayerActions";
+ private const string KeyDiscretePropName = "discretePlayerActions";
+ private const string AxisContinuousPropName = "axisContinuousPlayerActions";
+
+ public override void OnInspectorGUI()
+ {
+ EditorGUILayout.LabelField("Player Brain", EditorStyles.boldLabel);
+ var brain = (PlayerBrain) target;
+ var serializedBrain = serializedObject;
+ base.OnInspectorGUI();
+
+ serializedBrain.Update();
+ if (brain.brainParameters.vectorActionSpaceType == SpaceType.continuous)
+ {
+ DrawContinuousKeyMapping(serializedBrain, brain);
+ }
+ else
+ {
+ DrawDiscreteKeyMapping(serializedBrain);
+ }
+ serializedBrain.ApplyModifiedProperties();
+ }
+
+ ///
+ /// Draws the UI for continuous control key mapping to actions.
+ ///
+ /// The SerializedObject corresponding to the brain.
+ /// The Brain of which properties are displayed.
+ private static void DrawContinuousKeyMapping(
+ SerializedObject serializedBrain, PlayerBrain brain)
+ {
+ GUILayout.Label("Edit the continuous inputs for your actions", EditorStyles.boldLabel);
+ var keyActionsProp = serializedBrain.FindProperty(KeyContinuousPropName);
+ var axisActionsProp = serializedBrain.FindProperty(AxisContinuousPropName);
+ EditorGUILayout.PropertyField(keyActionsProp , true);
+ EditorGUILayout.PropertyField(axisActionsProp, true);
+ var keyContinuous = brain.keyContinuousPlayerActions;
+ var axisContinuous = brain.axisContinuousPlayerActions;
+ var brainParams = brain.brainParameters;
+ if (keyContinuous == null)
+ {
+ keyContinuous = new PlayerBrain.KeyContinuousPlayerAction[0];
+ }
+ if (axisContinuous == null)
+ {
+ axisContinuous = new PlayerBrain.AxisContinuousPlayerAction[0];
+ }
+ foreach (var action in keyContinuous)
+ {
+ if (action.index >= brainParams.vectorActionSize[0])
+ {
+ EditorGUILayout.HelpBox(
+ $"Key {action.key.ToString()} is assigned to index " +
+ $"{action.index.ToString()} but the action size is only of size " +
+ $"{brainParams.vectorActionSize.ToString()}",
+ MessageType.Error);
+ }
+ }
+ foreach (var action in axisContinuous)
+ {
+ if (action.index >= brainParams.vectorActionSize[0])
+ {
+ EditorGUILayout.HelpBox(
+ $"Axis {action.axis} is assigned to index {action.index.ToString()} " +
+ $"but the action size is only of size {brainParams.vectorActionSize}",
+ MessageType.Error);
+ }
+ }
+ GUILayout.Label("You can change axis settings from Edit->Project Settings->Input",
+ EditorStyles.helpBox );
+ }
+
+ ///
+ /// Draws the UI for discrete control key mapping to actions.
+ ///
+ /// The SerializedObject corresponding to the brain.
+ private static void DrawDiscreteKeyMapping(SerializedObject serializedBrain)
+ {
+ GUILayout.Label("Edit the discrete inputs for your actions",
+ EditorStyles.boldLabel);
+ var dhas = serializedBrain.FindProperty(KeyDiscretePropName);
+ serializedBrain.Update();
+ EditorGUILayout.PropertyField(dhas, true);
+ }
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/PlayerBrainEditor.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/PlayerBrainEditor.cs.meta
new file mode 100755
index 00000000..82acf52b
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/PlayerBrainEditor.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 0d99e43f78e54b4f96a346219e2ca2d2
+timeCreated: 1536851993
\ No newline at end of file
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/ResetParameterDrawer.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/ResetParameterDrawer.cs
new file mode 100755
index 00000000..abc5aee8
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/ResetParameterDrawer.cs
@@ -0,0 +1,179 @@
+using UnityEngine;
+using UnityEditor;
+using System;
+using System.Linq;
+using UnityEditor.SceneManagement;
+
+namespace MLAgents
+{
+ ///
+ /// PropertyDrawer for ResetParameters. Defines how ResetParameters are displayed in the
+ /// Inspector.
+ ///
+ [CustomPropertyDrawer(typeof(ResetParameters))]
+ public class ResetParameterDrawer : PropertyDrawer
+ {
+ private ResetParameters _parameters;
+ // The height of a line in the Unity Inspectors
+ private const float LineHeight = 17f;
+ // This is the prefix for the key when you add a reset parameter
+ private const string NewKeyPrefix = "Param-";
+
+ ///
+ /// Computes the height of the Drawer depending on the property it is showing
+ ///
+ /// The property that is being drawn.
+ /// The label of the property being drawn.
+ /// The vertical space needed to draw the property.
+ public override float GetPropertyHeight(SerializedProperty property, GUIContent label)
+ {
+ LazyInitializeParameters(property, label);
+ return (_parameters.Count + 2) * LineHeight;
+ }
+
+ ///
+ public override void OnGUI(Rect position, SerializedProperty property, GUIContent label)
+ {
+ LazyInitializeParameters(property, label);
+ position.height = LineHeight;
+ EditorGUI.LabelField(position, label);
+ position.y += LineHeight;
+ var width = position.width / 2 - 24;
+ var keyRect = new Rect(position.x + 20, position.y, width, position.height);
+ var valueRect = new Rect(position.x + width + 30, position.y, width, position.height);
+ DrawAddRemoveButtons(keyRect, valueRect);
+ EditorGUI.BeginProperty(position, label, property);
+ foreach (var parameter in _parameters)
+ {
+ var key = parameter.Key;
+ var value = parameter.Value;
+ keyRect.y += LineHeight;
+ valueRect.y += LineHeight;
+ EditorGUI.BeginChangeCheck();
+ var newKey = EditorGUI.TextField(keyRect, key);
+ if (EditorGUI.EndChangeCheck())
+ {
+ MarkSceneAsDirty();
+ try
+ {
+ _parameters.Remove(key);
+ _parameters.Add(newKey, value);
+ }
+ catch (Exception e)
+ {
+ Debug.Log(e.Message);
+ }
+ break;
+ }
+
+ EditorGUI.BeginChangeCheck();
+ value = EditorGUI.FloatField(valueRect, value);
+ if (EditorGUI.EndChangeCheck())
+ {
+ MarkSceneAsDirty();
+ _parameters[key] = value;
+ break;
+ }
+ }
+ EditorGUI.EndProperty();
+ }
+
+ ///
+ /// Draws the Add and Remove buttons.
+ ///
+ /// The rectangle for the Add New button.
+ /// The rectangle for the Remove Last button.
+ private void DrawAddRemoveButtons(Rect addRect, Rect removeRect)
+ {
+ // This is the Add button
+ if (_parameters.Count == 0)
+ {
+ addRect.width *= 2;
+ }
+ if (GUI.Button(addRect,
+ new GUIContent("Add New", "Add a new item to the default reset parameters"),
+ EditorStyles.miniButton))
+ {
+ MarkSceneAsDirty();
+ AddParameter();
+ }
+
+ // If there are no items in the ResetParameters, Hide the Remove button
+ if (_parameters.Count == 0)
+ {
+ return;
+ }
+ // This is the Remove button
+ if (GUI.Button(removeRect,
+ new GUIContent(
+ "Remove Last", "Remove the last item from the default reset parameters"),
+ EditorStyles.miniButton))
+ {
+ MarkSceneAsDirty();
+ RemoveLastParameter();
+ }
+ }
+
+ ///
+ /// Signals that the property has been modified and requires the scene to be saved for
+ /// the changes to persist. Only works when the Editor is not playing.
+ ///
+ private static void MarkSceneAsDirty()
+ {
+ if (!EditorApplication.isPlaying)
+ {
+ EditorSceneManager.MarkSceneDirty(EditorSceneManager.GetActiveScene());
+ }
+ }
+
+ ///
+ /// Ensures that the state of the Drawer is synchronized with the property.
+ ///
+ /// The SerializedProperty of the ResetParameters
+ /// to make the custom GUI for.
+ /// The label of this property.
+ private void LazyInitializeParameters(SerializedProperty property, GUIContent label)
+ {
+ if (_parameters != null)
+ {
+ return;
+ }
+ var target = property.serializedObject.targetObject;
+ _parameters = fieldInfo.GetValue(target) as ResetParameters;
+ if (_parameters == null)
+ {
+ _parameters = new ResetParameters();
+ fieldInfo.SetValue(target, _parameters);
+ }
+ }
+
+ ///
+ /// Removes the last ResetParameter from the ResetParameters
+ ///
+ private void RemoveLastParameter()
+ {
+ if (_parameters.Count > 0)
+ {
+ string key = _parameters.Keys.ToList()[_parameters.Count - 1];
+ _parameters.Remove(key);
+ }
+ }
+
+ ///
+ /// Adds a new ResetParameter to the ResetParameters with a default name.
+ ///
+ private void AddParameter()
+ {
+ string key = NewKeyPrefix + _parameters.Count;
+ var value = default(float);
+ try
+ {
+ _parameters.Add(key, value);
+ }
+ catch (Exception e)
+ {
+ Debug.Log(e.Message);
+ }
+ }
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/ResetParameterDrawer.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/ResetParameterDrawer.cs.meta
new file mode 100755
index 00000000..2a082721
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Editor/ResetParameterDrawer.cs.meta
@@ -0,0 +1,12 @@
+fileFormatVersion: 2
+guid: 740b9a60fe38f476ab020dcf91f3f94a
+timeCreated: 1517291065
+licenseType: Free
+MonoImporter:
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins.meta
new file mode 100755
index 00000000..2e066a06
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins.meta
@@ -0,0 +1,9 @@
+fileFormatVersion: 2
+guid: d6d56028f4c564724878c82cfa3c9e14
+folderAsset: yes
+timeCreated: 1502996258
+licenseType: Pro
+DefaultImporter:
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core.meta
new file mode 100755
index 00000000..42930051
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 13df47c141a644f57bdb0a667879ef0b
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda.md b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda.md
new file mode 100755
index 00000000..61b0acc7
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda.md
@@ -0,0 +1,166 @@
+
+
+# Barracuda
+
+**Barracuda** is a lightweight and **cross-platform** Neural Net **inference library for Unity**. Barracuda can execute both on GPU and CPU. Currently Barracuda is in the early development stage, so adventures are expected.
+
+## Using Barracuda
+Typically the following steps are needed to use Barracuda in application:
+1. load model,
+2. create inference engine (the worker),
+3. execute model and
+4. fetch results.
+
+But first you have to convert your TensorFlow (or ONNX) model to Barracuda format with python scripts. Example usage:
+```bash
+python onnx_to_barracuda.py Models/mnist/model.onnx Destination/mnist.bytes
+```
+See _Converting models to Barracuda_ paragraph below for more information.
+
+### Load Model into Barracuda
+Once you have your TensorFlow (or ONNX) model converted, you can load resulting Barracuda file via `ModelLoader`:
+```C#
+var model = ModelLoader.LoadFromStreamingAssets(modelName + ".bytes");
+```
+
+### Create inference engine (Worker)
+Inference engine in Barracuda is called Worker. Worker is responsible for converting model into executable tasks and scheduling them on GPU or CPU.
+```C#
+var worker = BarracudaWorkerFactory.CreateWorker(BarracudaWorkerFactory.Type.ComputeFast, model)
+```
+
+### Execute the model
+Inputs can be provided both as sole `Tensor` object (assuming Model has only one input) or as a dictionary of name and `Tensor` pairs.
+
+```C#
+var inputs = new Dictionary();
+inputs[name1] = new Tensor(...);
+inputs[name2] = new Tensor(...);
+worker.Execute(inputs);
+```
+Execution is asynchronous for GPU backends. Currently implementation is synchronous for CPU backends, however it is good to assume that execution will be async for all backends in the future.
+
+### Fetch outputs
+If model has only single output, then simple `worker.Fetch()` can be used, otherwise output names should be provided.
+```C#
+var O = worker.Fetch(outputName);
+```
+
+### Cleanup
+As a Barracuda client you are responsible to `Dispose` _worker_, _inputs_ and _outputs_ you fetched. This is necessary to properly free GPU resources.
+```C#
+O.Dispose();
+worker.Dispose();
+```
+
+## Working with data
+
+### Tensor
+Barracuda stores data in `batch`,`height`,`width`,`channels` also known as _NHWC_ or _channels-last_ format. You can interact with `Tensor` data via multi-dimensional array operators:
+```C#
+var tensor = new Tensor(batchCount, height, width, channelCount);
+tensor[n, y, x, c] = 1.0f; // as N batches of 3 dimensional data: N x {X, Y, C}
+tensor[n, c] = 2.0f; // as N batches of 1 dimensional data: N x {C}
+tensor[ i] = 3.0f; // as flat array
+```
+
+There are number of `Tensor` constructors that cover variety of scenarios. By default tensors are initialized with `0` upon construction, unless intialization `Array` is provided.
+```C#
+tensor = new Tensor(batchCount, height, width, channelCount); // batch of 3 dimensional data, 0 initialized: batchCount x {height, width, channelCount}
+tensor = new Tensor(batchCount, elementCount); // batch of 1 dimensional data, 0 initialized: batchCount x {elementCount}
+
+var stridedArray = new float[batchCount * elementCount] { ... };
+tensor = new Tensor(batchCount, elementCount, stridedArray); // batch of 1 dimensional data, initialized from strided array
+
+var jaggedArray = new float[batchCount][elementCount] { ... };
+tensor = new Tensor(batchCount, elementCount, jaggedArray); // batch of 1 dimensional data, initialized from jagged array
+
+Texture2D texture = ...;
+tensor = new Tensor(texture); // tensor initialized with texture data: 1 x { texture.width, texture.height, 3}
+```
+
+You can query shape of the `Tensor` object, but you can not change it. Shape of the `Tensor` is immutable. If you want to have different shape of `Tensor`, you have to construct the new instance of `Tensor` object.
+```C#
+var shape = tensor.shape;
+Debug.Log(shape + " or " + shape.batch + shape.height + shape.width + shape.channels);
+```
+
+### Texture as input
+You can directly pass `Texture2D`, `Texture2DArray`, `Texture3D` or `RenderTexture` to Barracuda without accessing individual pixels on CPU:
+```C#
+var channelCount = 3; // you can treat input pixels as 1 (grayscale), 3 (color) or 4 (color with alpha) channels
+var tensor = new Tensor(texture, channelCount);
+```
+You can batch multiple textures into the single `Tensor` object:
+```C#
+var textures = new [] { texture0, texture1, texture2, texture3 }; // these textures will form a batch
+var tensor = new Tensor(textures, channelCount);
+```
+Note that to form a batch all textures must have the same width and height dimensions.
+
+### Texture as output
+If you want to use Barracuda execution results further in the graphics pipeline, you can copy data from `Tensor` into `RenderTexture` without stalling CPU or GPU:
+```C#
+ var tensor = worker.Fetch();
+ var texture = BarracudaTextureUtils.TensorToRenderTexture(tensor);
+```
+If you wish, you can reuse the same `RenderTexture` multiple times:
+```C#
+ var texture = new RenderTexture(width, height, 0);
+ // ...
+ var tensor = worker.Fetch();
+ BarracudaTextureUtils.TensorToRenderTexture(tensor, texture);
+```
+
+## Introspecting Barracuda models
+Barracuda model has very simple memory representation. Once model is loaded you can query for inputs and outputs:
+```C#
+string[] inputNames = model.inputs; // query model inputs
+string[] outputNames = model.outputs; // query model outputs
+```
+Or you can directly iterate through the layers and investigate what model is going to do:
+```C#
+foreach (var layer in model.layers)
+ Debug.Log(layer.name + " does " + layer.type);
+```
+
+## Verbose mode
+You can turn on verbose mode for different parts of Barracuda:
+```C#
+bool verbose = true;
+var model = ModelLoader.LoadFromStreamingAssets(modelName + ".bytes", verbose); // verbose loader
+var worker = BarracudaWorkerFactory.CreateWorker(BarracudaWorkerFactory.Type.ComputeFast, model, verbose); // verbose execution
+```
+
+## Converting TensorFlow and ONNX models to Barracuda format
+Barracuda comes with dedicated python scripts to convert pre-trained TensorFlow and ONNX models to Barracuda format.
+
+Convert from TensorFlow:
+```bash
+python tensorflow_to_barracuda.py Models/3DBall-tf-model.pb Destination/3DBall-bc.bytes
+```
+
+Convert from ONNX:
+```bash
+python onnx_to_barracuda.py Models/mnist/model.onnx Destination/mnist-bc.bytes
+```
+
+If network has multiple outputs, but you need only particular ones during the inference, there is an optional `-trim` flag to remove unused outputs and calculations.
+For example:
+```bash
+python tensorflow_to_barracuda.py Models/3DBall-tf-model.pb Destination/3DBall-bc.bytes -trim action$
+```
+Trim will first remove outputs that do not match regular expression from the graph. In this case only output that ends with `action` will be left.
+Next trim will strip all nodes that do not participate in the evaluation of the output.
+
+
+P.S. Python 3.5 or 3.6 is recommended
+
+P.P.S. We plan to migrate Tensorflow and ONNX converters from Python to C# in the future.
+
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda.md.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda.md.meta
new file mode 100755
index 00000000..4a967c38
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda.md.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: 3cf2bcd7dcfe144bebf6cf271e7dfbe0
+TextScriptImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda.meta
new file mode 100755
index 00000000..c142006d
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 4d59cec597ba94288831c0cade38b14e
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Barracuda.dll b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Barracuda.dll
new file mode 100755
index 00000000..0ad29477
Binary files /dev/null and b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Barracuda.dll differ
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Barracuda.dll.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Barracuda.dll.meta
new file mode 100755
index 00000000..3e4f56da
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Barracuda.dll.meta
@@ -0,0 +1,30 @@
+fileFormatVersion: 2
+guid: de59cc66e5e394f93b2a692e50bce97f
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ iconMap: {}
+ executionOrder: {}
+ isPreloaded: 0
+ isOverridable: 0
+ platformData:
+ - first:
+ Any:
+ second:
+ enabled: 1
+ settings: {}
+ - first:
+ Editor: Editor
+ second:
+ enabled: 0
+ settings:
+ DefaultValueInitialized: true
+ - first:
+ Windows Store Apps: WindowsStoreApps
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins.meta
new file mode 100755
index 00000000..d253192d
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: a7bba248e968b476a875260a8127a595
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX.meta
new file mode 100755
index 00000000..ecc28271
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 5087a463bec2b4b76808e7307a94887f
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.asmdef b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.asmdef
new file mode 100755
index 00000000..9d6f291a
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.asmdef
@@ -0,0 +1,11 @@
+{
+ "name": "MacBLAS",
+ "references": [],
+ "optionalUnityReferences": [],
+ "includePlatforms": [
+ "Editor",
+ "macOSStandalone"
+ ],
+ "excludePlatforms": [],
+ "allowUnsafeCode": true
+}
\ No newline at end of file
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.asmdef.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.asmdef.meta
new file mode 100755
index 00000000..4a3cefc8
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.asmdef.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: 53fc9961397934ed38a573ce1392c80c
+AssemblyDefinitionImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.cs
new file mode 100755
index 00000000..6e39c3fa
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.cs
@@ -0,0 +1,29 @@
+#if UNITY_STANDALONE_OSX || UNITY_EDITOR_OSX
+using System.Runtime.InteropServices;
+using Barracuda;
+using UnityEngine;
+using UnityEngine.Scripting;
+
+
+[Preserve]
+public class MacBLAS : BLASPlugin
+{
+ [DllImport("macblas")]
+ static extern unsafe void macsgemm(float* Ap, int AN, int AM,
+ float* Bp, int BN, int BM,
+ float* Cp, int CN, int CM,
+ int bs, bool transposeA, bool transposeB);
+
+ public bool IsCurrentPlatformSupported()
+ {
+ return Application.platform == RuntimePlatform.OSXEditor ||
+ Application.platform == RuntimePlatform.OSXPlayer;
+ }
+
+ public unsafe void SGEMM(float* Ap, int AN, int AM, float* Bp, int BN, int BM, float* Cp, int CN, int CM, int bs,
+ bool transposeA = false, bool transposeB = false)
+ {
+ macsgemm(Ap, AN, AM, Bp, BN, BM, Cp, CN, CM, bs, transposeA, transposeB);
+ }
+}
+#endif // UNITY_OSX
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.cs.meta
new file mode 100755
index 00000000..b90d4acb
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 680f04373f71f48a89408105d3f58a08
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle.meta
new file mode 100755
index 00000000..c73e2100
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle.meta
@@ -0,0 +1,40 @@
+fileFormatVersion: 2
+guid: 6633afded85ec4f00a4cc653053461bb
+folderAsset: yes
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ iconMap: {}
+ executionOrder: {}
+ isPreloaded: 0
+ isOverridable: 0
+ platformData:
+ - first:
+ '': OSXIntel
+ second:
+ enabled: 1
+ settings: {}
+ - first:
+ '': OSXIntel64
+ second:
+ enabled: 1
+ settings: {}
+ - first:
+ Any:
+ second:
+ enabled: 0
+ settings: {}
+ - first:
+ Editor: Editor
+ second:
+ enabled: 1
+ settings:
+ DefaultValueInitialized: true
+ - first:
+ Standalone: OSXUniversal
+ second:
+ enabled: 1
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents.meta
new file mode 100755
index 00000000..a0a3fd80
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 5de42c62131964fc999e1dc3d292cc31
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/Info.plist b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/Info.plist
new file mode 100755
index 00000000..22d6943a
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/Info.plist
@@ -0,0 +1,40 @@
+
+
+
+
+ BuildMachineOSBuild
+ 14F27
+ CFBundleDevelopmentRegion
+ en
+ CFBundleExecutable
+ macblas
+ CFBundleIdentifier
+ com.unity3d.macblas
+ CFBundleInfoDictionaryVersion
+ 6.0
+ CFBundleName
+ macblas
+ CFBundlePackageType
+ BNDL
+ CFBundleShortVersionString
+ 0.1.4
+ CFBundleVersion
+ 1
+ DTCompiler
+ com.apple.compilers.llvm.clang.1_0
+ DTPlatformBuild
+ 6A1052d
+ DTPlatformVersion
+ GM
+ DTSDKBuild
+ 14A382
+ DTSDKName
+ macosx10.10
+ DTXcode
+ 0610
+ DTXcodeBuild
+ 6A1052d
+ NSHumanReadableCopyright
+ Copyright © 2018 Unity Technologies. All rights reserved.
+
+
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/Info.plist.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/Info.plist.meta
new file mode 100755
index 00000000..2a9aa9e4
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/Info.plist.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: 844f003f25d444aafad9fb1fcea17bbc
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS.meta
new file mode 100755
index 00000000..dc277cfa
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 0620b207d80004fe595413acf79f2f66
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS/macblas b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS/macblas
new file mode 100755
index 00000000..e3f52632
Binary files /dev/null and b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS/macblas differ
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS/macblas.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS/macblas.meta
new file mode 100755
index 00000000..7077e866
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS/macblas.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: e9ef2c9e25cad478aa1220d6cf68a2ed
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature.meta
new file mode 100755
index 00000000..2a52881c
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 93038b433855548879a151644d2354c1
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature/CodeResources b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature/CodeResources
new file mode 100755
index 00000000..0710b400
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature/CodeResources
@@ -0,0 +1,105 @@
+
+
+
+
+ files
+
+ files2
+
+ rules
+
+ ^Resources/
+
+ ^Resources/.*\.lproj/
+
+ optional
+
+ weight
+ 1000
+
+ ^Resources/.*\.lproj/locversion.plist$
+
+ omit
+
+ weight
+ 1100
+
+ ^version.plist$
+
+
+ rules2
+
+ .*\.dSYM($|/)
+
+ weight
+ 11
+
+ ^(.*/)?\.DS_Store$
+
+ omit
+
+ weight
+ 2000
+
+ ^(Frameworks|SharedFrameworks|PlugIns|Plug-ins|XPCServices|Helpers|MacOS|Library/(Automator|Spotlight|LoginItems))/
+
+ nested
+
+ weight
+ 10
+
+ ^.*
+
+ ^Info\.plist$
+
+ omit
+
+ weight
+ 20
+
+ ^PkgInfo$
+
+ omit
+
+ weight
+ 20
+
+ ^Resources/
+
+ weight
+ 20
+
+ ^Resources/.*\.lproj/
+
+ optional
+
+ weight
+ 1000
+
+ ^Resources/.*\.lproj/locversion.plist$
+
+ omit
+
+ weight
+ 1100
+
+ ^[^/]+$
+
+ nested
+
+ weight
+ 10
+
+ ^embedded\.provisionprofile$
+
+ weight
+ 20
+
+ ^version\.plist$
+
+ weight
+ 20
+
+
+
+
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature/CodeResources.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature/CodeResources.meta
new file mode 100755
index 00000000..87c151ef
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature/CodeResources.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: 523ab7e7760c743a9977ecfedabe1691
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS.meta
new file mode 100755
index 00000000..0d588e91
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 256085e1b062345239f3d7d88741f96c
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.asmdef b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.asmdef
new file mode 100755
index 00000000..ba581665
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.asmdef
@@ -0,0 +1,11 @@
+{
+ "name": "iOSBLAS",
+ "references": [],
+ "optionalUnityReferences": [],
+ "includePlatforms": [
+ "Editor",
+ "iOS"
+ ],
+ "excludePlatforms": [],
+ "allowUnsafeCode": true
+}
\ No newline at end of file
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.asmdef.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.asmdef.meta
new file mode 100755
index 00000000..5b93d769
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.asmdef.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: 005937e819cd540429ad05eabcfb642f
+AssemblyDefinitionImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.cs
new file mode 100755
index 00000000..03e3c8b8
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.cs
@@ -0,0 +1,27 @@
+#if UNITY_IOS
+using System.Runtime.InteropServices;
+using Barracuda;
+using UnityEngine;
+using UnityEngine.Scripting;
+
+[Preserve]
+public class iOSBLAS : BLASPlugin
+{
+ [DllImport("__Internal")]
+ static extern unsafe void iossgemm(float* Ap, int AN, int AM,
+ float* Bp, int BN, int BM,
+ float* Cp, int CN, int CM,
+ int bs, bool transposeA, bool transposeB);
+
+ public bool IsCurrentPlatformSupported()
+ {
+ return Application.platform == RuntimePlatform.IPhonePlayer;
+ }
+
+ public unsafe void SGEMM(float* Ap, int AN, int AM, float* Bp, int BN, int BM, float* Cp, int CN, int CM, int bs,
+ bool transposeA = false, bool transposeB = false)
+ {
+ iossgemm(Ap, AN, AM, Bp, BN, BM, Cp, CN, CM, bs, transposeA, transposeB);
+ }
+}
+#endif // UNITY_IOS
\ No newline at end of file
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.cs.meta
new file mode 100755
index 00000000..9304817b
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 75424b0c6afc14ea7a1debef68240d9e
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.mm b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.mm
new file mode 100755
index 00000000..15cbe6c7
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.mm
@@ -0,0 +1,15 @@
+#import
+
+extern "C"
+{
+void iossgemm(float* Ap, int AN, int AM,
+ float* Bp, int BN, int BM,
+ float* Cp, int CN, int CM,
+ int bs, bool transposeA, bool transposeB)
+ {
+ cblas_sgemm(CblasRowMajor, transposeA ? CblasTrans : CblasNoTrans,
+ transposeB ? CblasTrans : CblasNoTrans,
+ AN, BM, BN, 1.0f, Ap, AM, Bp, BM, 1.0f, Cp, CM);
+ }
+
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.mm.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.mm.meta
new file mode 100755
index 00000000..2fa3f6de
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.mm.meta
@@ -0,0 +1,102 @@
+fileFormatVersion: 2
+guid: 100b08f95d9f349118f287b0170140d4
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ iconMap: {}
+ executionOrder: {}
+ isPreloaded: 0
+ isOverridable: 0
+ platformData:
+ - first:
+ '': Any
+ second:
+ enabled: 0
+ settings:
+ Exclude Android: 1
+ Exclude Editor: 1
+ Exclude Linux: 1
+ Exclude Linux64: 1
+ Exclude LinuxUniversal: 1
+ Exclude OSXUniversal: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude iOS: 0
+ - first:
+ Android: Android
+ second:
+ enabled: 0
+ settings:
+ CPU: ARMv7
+ - first:
+ Any:
+ second:
+ enabled: 0
+ settings: {}
+ - first:
+ Editor: Editor
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ DefaultValueInitialized: true
+ OS: AnyOS
+ - first:
+ Facebook: Win
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ - first:
+ Facebook: Win64
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ - first:
+ Standalone: Linux
+ second:
+ enabled: 0
+ settings:
+ CPU: x86
+ - first:
+ Standalone: Linux64
+ second:
+ enabled: 0
+ settings:
+ CPU: x86_64
+ - first:
+ Standalone: OSXUniversal
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ - first:
+ Standalone: Win
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ - first:
+ Standalone: Win64
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ - first:
+ iPhone: iOS
+ second:
+ enabled: 1
+ settings:
+ AddToEmbeddedBinaries: false
+ CompileFlags:
+ FrameworkDependencies: Accelerate;
+ - first:
+ tvOS: tvOS
+ second:
+ enabled: 1
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources.meta
new file mode 100755
index 00000000..da72593c
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 264a957219ea041c58af860601fe1881
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Activation.compute b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Activation.compute
new file mode 100755
index 00000000..d43c11cf
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Activation.compute
@@ -0,0 +1,679 @@
+#pragma kernel Relu
+#pragma kernel Relu_CNyx
+#pragma kernel Relu_Nyxc
+#pragma kernel Relu6
+#pragma kernel Relu6_CNyx
+#pragma kernel Relu6_Nyxc
+#pragma kernel Tanh
+#pragma kernel Tanh_CNyx
+#pragma kernel Tanh_Nyxc
+#pragma kernel Swish
+#pragma kernel Swish_CNyx
+#pragma kernel Swish_Nyxc
+#pragma kernel Sigmoid
+#pragma kernel Sigmoid_CNyx
+#pragma kernel Sigmoid_Nyxc
+#pragma kernel Elu
+#pragma kernel Elu_CNyx
+#pragma kernel Elu_Nyxc
+#pragma kernel LeakyRelu
+#pragma kernel LeakyRelu_CNyx
+#pragma kernel LeakyRelu_Nyxc
+#pragma kernel Exp
+#pragma kernel Exp_CNyx
+#pragma kernel Exp_Nyxc
+#pragma kernel Pow
+#pragma kernel Pow_CNyx
+#pragma kernel Pow_Nyxc
+#pragma kernel Softmax
+
+#include "Tensor.cginc"
+
+TENSOR_DECL(X)
+TENSOR_DECL_RW(O)
+
+float _Alpha;
+
+float relu(float v)
+{
+ return 0.5f * (v + abs(v));
+}
+
+float relu6(float v)
+{
+ return min(max(0, v), 6);
+}
+
+float swish(float v)
+{
+ return v / (1.f + exp(-v));
+}
+
+float sigmoid(float v)
+{
+ return 1.f / (1.f + exp(-v));
+}
+
+float elu(float v)
+{
+ if (v <= 0)
+ v = _Alpha * (exp(v) - 1);
+ return v;
+}
+
+float lrelu(float v)
+{
+ return max(v, _Alpha * v);
+}
+
+float signed_pow(float f, float e)
+{
+ // handle negative f
+ float v = pow(abs(f), e);
+ float s = (e % 2 == 1) ?
+ sign(f): // exponent is odd => sign(f) * pow(abs(f), e)
+ 1; // exponent is even => pow(abs(f), e)
+ return v * s;
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void Relu(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= O.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = relu(v);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void Relu6(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= O.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = relu6(v);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void Tanh(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = tanh(v);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+ void Sigmoid(uint3 dispatchThreadID : SV_DispatchThreadID)
+ {
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= O.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = sigmoid(v);
+ O.Set(n, y, x, c, v);
+ }
+ }
+
+ NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void Swish(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= O.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = swish(v);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void Elu(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = elu(v);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void LeakyRelu(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = lrelu(v);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void Exp(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = exp(v);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void Pow(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = signed_pow(v, _Alpha);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
+void Relu_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint nyx = dispatchThreadID.y;
+
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (c >= X.channels) return;
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = relu(v);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
+void Relu_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
+ TENSOR_ARGS2(X, O);
+
+ uint nyxc = dispatchThreadID.x;
+
+ uint c = nyxc % X.channels;
+ uint nyx = nyxc / X.channels;
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = relu(v);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
+void Relu6_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint nyx = dispatchThreadID.y;
+
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (c >= X.channels) return;
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = relu6(v);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
+void Relu6_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
+ TENSOR_ARGS2(X, O);
+
+ uint nyxc = dispatchThreadID.x;
+
+ uint c = nyxc % X.channels;
+ uint nyx = nyxc / X.channels;
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = relu6(v);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
+void Tanh_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint nyx = dispatchThreadID.y;
+
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (c >= X.channels) return;
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = tanh(v);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
+void Tanh_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
+ TENSOR_ARGS2(X, O);
+
+ uint nyxc = dispatchThreadID.x;
+
+ uint c = nyxc % X.channels;
+ uint nyx = nyxc / X.channels;
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = tanh(v);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
+void Sigmoid_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint nyx = dispatchThreadID.y;
+
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (c >= X.channels) return;
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = sigmoid(v);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
+void Sigmoid_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
+ TENSOR_ARGS2(X, O);
+
+ uint nyxc = dispatchThreadID.x;
+
+ uint c = nyxc % X.channels;
+ uint nyx = nyxc / X.channels;
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = sigmoid(v);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
+void Swish_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint nyx = dispatchThreadID.y;
+
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (c >= X.channels) return;
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = swish(v);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
+void Swish_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
+ TENSOR_ARGS2(X, O);
+
+ uint nyxc = dispatchThreadID.x;
+
+ uint c = nyxc % X.channels;
+ uint nyx = nyxc / X.channels;
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = swish(v);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
+void Elu_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint nyx = dispatchThreadID.y;
+
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (c >= X.channels) return;
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = elu(v);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
+void Elu_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
+ TENSOR_ARGS2(X, O);
+
+ uint nyxc = dispatchThreadID.x;
+
+ uint c = nyxc % X.channels;
+ uint nyx = nyxc / X.channels;
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = elu(v);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
+void LeakyRelu_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint nyx = dispatchThreadID.y;
+
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (c >= X.channels) return;
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = lrelu(v);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
+void LeakyRelu_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
+ TENSOR_ARGS2(X, O);
+
+ uint nyxc = dispatchThreadID.x;
+
+ uint c = nyxc % X.channels;
+ uint nyx = nyxc / X.channels;
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = lrelu(v);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
+void Exp_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint nyx = dispatchThreadID.y;
+
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (c >= X.channels) return;
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = exp(v);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
+void Exp_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
+ TENSOR_ARGS2(X, O);
+
+ uint nyxc = dispatchThreadID.x;
+
+ uint c = nyxc % X.channels;
+ uint nyx = nyxc / X.channels;
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = exp(v);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
+void Pow_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint nyx = dispatchThreadID.y;
+
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (c >= X.channels) return;
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = signed_pow(v, _Alpha);
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
+void Pow_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
+ TENSOR_ARGS2(X, O);
+
+ uint nyxc = dispatchThreadID.x;
+
+ uint c = nyxc % X.channels;
+ uint nyx = nyxc / X.channels;
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (n >= X.batch) return;
+
+ float v = X.Get(n, y, x, c);
+ v = signed_pow(v, _Alpha);
+ O.Set(n, y, x, c, v);
+}
+
+
+NUMTHREADS((64,4,1), (64,2,1), (64,1,1))
+void Softmax(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.flatWidth, O.flatHeight, 1);
+ TENSOR_ARGS2(X, O);
+
+ uint x = dispatchThreadID.x;
+ uint y = dispatchThreadID.y;
+
+ if (x >= O.GetFlatWidth()) return;
+ if (y >= O.GetFlatHeight()) return;
+
+ float maxV = -FLT_MAX;
+ for (uint i = 0; i < X.GetFlatWidth(); ++i)
+ {
+ float v = X.Get(y, i);
+ if (v > maxV)
+ maxV = v;
+ }
+
+ float acc = 0.0f;
+ for (i = 0; i < X.GetFlatWidth(); ++i)
+ {
+ float v = X.Get(y, i);
+ acc += exp(v - maxV);
+ }
+
+ float v = X.Get(y, x);
+ v = exp(v - maxV) / acc;
+ O.Set(y, x, v);
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Activation.compute.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Activation.compute.meta
new file mode 100755
index 00000000..1c31e435
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Activation.compute.meta
@@ -0,0 +1,9 @@
+fileFormatVersion: 2
+guid: fdc94044b2f234c0fa80ada3771a2ae7
+timeCreated: 1495527718
+licenseType: Pro
+ComputeShaderImporter:
+ currentAPIMask: 196608
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/BarracudaReferenceImpl.compute b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/BarracudaReferenceImpl.compute
new file mode 100755
index 00000000..14e4e327
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/BarracudaReferenceImpl.compute
@@ -0,0 +1,885 @@
+#pragma kernel Dense
+#pragma kernel Conv2D
+#pragma kernel DepthwiseConv2D
+#pragma kernel Conv2DTrans
+#pragma kernel Upsample2D
+#pragma kernel Unstride2D
+#pragma kernel MaxPool2D
+#pragma kernel AvgPool2D
+#pragma kernel GlobalMaxPool2D
+#pragma kernel GlobalAvgPool2D
+#pragma kernel ScaleBias
+#pragma kernel InstanceNorm
+#pragma kernel Dropout
+#pragma kernel Relu
+#pragma kernel Swish
+#pragma kernel Softmax
+#pragma kernel Tanh
+#pragma kernel Sigmoid
+#pragma kernel Relu6
+#pragma kernel Elu
+#pragma kernel LeakyRelu
+#pragma kernel Exp
+#pragma kernel Pow
+#pragma kernel Copy
+#pragma kernel BroadcastAdd
+#pragma kernel BroadcastSub
+#pragma kernel BroadcastMul
+#pragma kernel BroadcastDiv
+#pragma kernel BroadcastPow
+#pragma kernel BroadcastMin
+#pragma kernel BroadcastMax
+#pragma kernel TextureToTensor
+#pragma kernel TensorToTexture
+
+#include "Tensor.cginc"
+#include "Random.cginc"
+
+TENSOR_DECL(X)
+TENSOR_DECL(W)
+TENSOR_DECL(K)
+TENSOR_DECL(B)
+TENSOR_DECL_RW(O)
+
+uint4 _Pad;
+uint4 _Pool;
+uint4 _Stride;
+float _Alpha;
+float _Seed;
+
+[numthreads(8,8,1)]
+void Dense(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.flatWidth, O.flatHeight, 1);
+ TENSOR_ARGS4(X, W, B, O);
+
+ uint x = dispatchThreadID.x;
+ uint y = dispatchThreadID.y;
+
+ if (x >= O.GetFlatWidth()) return;
+ if (y >= O.GetFlatHeight()) return;
+
+ float acc = B.Get(x);
+ for (uint i = 0; i < X.GetFlatWidth(); ++i)
+ acc += X.Get(y, i) * W.Get(i, x);
+
+ O.Set(y, x, acc);
+}
+
+[numthreads(4,4,4)]
+void Relu(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= O.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = 0.5f * (v + abs(v));
+
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void Swish(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= O.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = v / (1 + exp(-v));
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void Tanh(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = tanh(v);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void Sigmoid(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = 1 / (1 + exp(-v));
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void Relu6(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = min(max(0, v), 6);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void Elu(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ if (v <= 0)
+ v = _Alpha * (exp(v) - 1);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void LeakyRelu(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = max(v, _Alpha * v);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void Exp(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = exp(v);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+float signed_pow(float f, float e)
+{
+ // handle negative f
+ float v = pow(abs(f), e);
+ float s = (e % 2 == 1) ?
+ sign(f): // exponent is odd => sign(f) * pow(abs(f), e)
+ 1; // exponent is even => pow(abs(f), e)
+ return v * s;
+}
+
+[numthreads(4,4,4)]
+void Pow(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = signed_pow(v, _Alpha);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void BroadcastAdd(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS3(X, B, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v =
+ X.BroadcastGet(n, y, x, c) +
+ B.BroadcastGet(n, y, x, c);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void BroadcastSub(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS3(X, B, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v =
+ X.BroadcastGet(n, y, x, c) -
+ B.BroadcastGet(n, y, x, c);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void BroadcastMul(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS3(X, B, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float v =
+ X.BroadcastGet(n, y, x, c) *
+ B.BroadcastGet(n, y, x, c);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void BroadcastDiv(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS3(X, B, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v =
+ X.BroadcastGet(n, y, x, c) /
+ B.BroadcastGet(n, y, x, c);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void BroadcastPow(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS3(X, B, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = signed_pow(
+ X.BroadcastGet(n, y, x, c),
+ B.BroadcastGet(n, y, x, c));
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void BroadcastMin(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS3(X, B, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = min(
+ X.BroadcastGet(n, y, x, c),
+ B.BroadcastGet(n, y, x, c));
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void BroadcastMax(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS3(X, B, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = max(
+ X.BroadcastGet(n, y, x, c),
+ B.BroadcastGet(n, y, x, c));
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void Copy(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // NOTE: dispatched over X (not O)
+ DISPATCH_ARGS(X.channels, X.width, X.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= X.channels) return; if (x >= X.width) return; if (y >= X.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ O.Set(n + _Pad[0], y + _Pad[1], x + _Pad[2], c + _Pad[3], v);
+ }
+}
+
+[numthreads(4,4,4)]
+void Dropout(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float4 seed = float4(n / O.batch, y / O.height, x / O.width, c / O.channels);
+ seed = frac(seed + _Seed);
+
+ float v = X.Get(n, y, x, c);
+ v *= Bernoulli(seed, 1 - _Alpha) / (1 - _Alpha);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void ScaleBias(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS4(X, W, B, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= O.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ float scale = W.Get(0, 0, 0, c);
+ float bias = B.Get(0, 0, 0, c);
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = v * scale + bias;
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(16,4,1)]
+void Softmax(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.flatWidth, O.flatHeight, 1);
+ TENSOR_ARGS2(X, O);
+
+ uint x = dispatchThreadID.x;
+ uint y = dispatchThreadID.y;
+
+ if (x >= O.GetFlatWidth()) return;
+ if (y >= O.GetFlatHeight()) return;
+
+ float maxV = -FLT_MAX;
+ for (uint i = 0; i < X.GetFlatWidth(); ++i)
+ {
+ float v = X.Get(y, i);
+ if (v > maxV)
+ maxV = v;
+ }
+
+ float acc = 0.0f;
+ for (i = 0; i < X.GetFlatWidth(); ++i)
+ {
+ float v = X.Get(y, i);
+ acc += exp(v - maxV);
+ }
+
+ float v = X.Get(y, x);
+ v = exp(v - maxV) / acc;
+ O.Set(y, x, v);
+}
+
+[numthreads(4,4,4)]
+void Upsample2D(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // NOTE: dispatched over X (not O)
+ DISPATCH_ARGS(X.channels, X.width, X.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= X.channels) return;
+ if (x >= X.width) return;
+ if (y >= X.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+
+ for (uint dy = 0; dy < _Pool.y; ++dy)
+ for (uint dx = 0; dx < _Pool.x; ++dx)
+ {
+ uint oy = y * _Pool.y + dy;
+ uint ox = x * _Pool.x + dx;
+ O.Set(n, oy, ox, c, v);
+ }
+ }
+}
+
+[numthreads(4,4,4)]
+void MaxPool2D(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= O.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float maxV = -FLT_MAX;
+ for (uint dy = 0; dy < _Pool.y; ++dy)
+ for (uint dx = 0; dx < _Pool.x; ++dx)
+ {
+ uint2 pos = uint2(x, y) * _Stride.xy + uint2(dx, dy);
+ float v = X.SafeGet(n, pos, c, _Pad.xy);
+ maxV = max(v, maxV);
+ }
+
+ O.Set(n, y, x, c, maxV);
+ }
+}
+
+[numthreads(4,4,4)]
+void AvgPool2D(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= O.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ uint2 leftCorner = _Pad.xy;
+ uint2 rightCorner = uint2(X.width, X.height) + _Pad.xy;
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float acc = 0;
+ float counter = 0;
+ for (uint dy = 0; dy < _Pool.y; ++dy)
+ for (uint dx = 0; dx < _Pool.x; ++dx)
+ {
+ uint2 pos = uint2(x, y) * _Stride.xy + uint2(dx, dy);
+
+ bool mask = all(pos >= leftCorner) && all(pos < rightCorner);
+ acc += (mask)? X.Get(n, pos.y - leftCorner.y, pos.x - leftCorner.x, c): 0;
+ counter += (mask)? 1: 0;
+ }
+
+ acc /= counter;
+ O.Set(n, y, x, c, acc);
+ }
+}
+
+[numthreads(32,1,1)]
+void GlobalMaxPool2D(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, 1, 1);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ if (c >= O.channels) return;
+ //ASSERT(X.batch == O.batch)
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float maxV = -FLT_MAX;
+ for (uint y = 0; y < X.height; ++y)
+ for (uint x = 0; x < X.width; ++x)
+ {
+ float v = X.Get(n, y, x, c);
+ maxV = max(v, maxV);
+ }
+
+ O.Set(n, 0, 0, c, maxV);
+ }
+}
+
+[numthreads(32,1,1)]
+void GlobalAvgPool2D(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, 1, 1);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ if (c >= O.channels) return;
+ //ASSERT(X.batch == O.batch)
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = 0;
+ for (uint y = 0; y < X.height; ++y)
+ for (uint x = 0; x < X.width; ++x)
+ v += X.Get(n, y, x, c);
+
+ v /= (X.height * X.width);
+ O.Set(n, 0, 0, c, v);
+ }
+}
+
+[numthreads(32,1,1)]
+void InstanceNorm(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, 1, 1);
+ TENSOR_ARGS4(X, W, B, O);
+
+ uint c = dispatchThreadID.x;
+ if (c >= O.channels) return;
+ //ASSERT(X.shape == O.shape)
+
+ float gamma = W.Get(0, 0, 0, c);
+ float beta = B.Get(0, 0, 0, c);
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ uint x, y;
+ // calc mean
+ float acc = 0;
+ for (y = 0; y < O.height; ++y)
+ for (x = 0; x < O.width; ++x)
+ acc += X.Get(n, y, x, c);
+ float mean = acc / (O.width * O.height);
+
+ // calc variance
+ acc = 0;
+ for (y = 0; y < O.height; ++y)
+ for (x = 0; x < O.width; ++x)
+ {
+ float delta = X.Get(n, y, x, c) - mean;
+ acc += delta * delta;
+ }
+ float var = acc / (O.width * O.height);
+
+ // normalization factor
+ float invNormFactor = 1 / sqrt(var + FLT_EPSILON);
+
+ float scale = gamma * invNormFactor;
+ float bias = beta - gamma * mean * invNormFactor;
+
+ // apply normalization
+ for (y = 0; y < O.height; ++y)
+ for (x = 0; x < O.width; ++x)
+ {
+ float v = X.Get(n, y, x, c);
+ v = v * scale + bias;
+ O.Set(n, y, x, c, v);
+ }
+ }
+}
+
+[numthreads(4,4,4)]
+void Conv2D(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_ARGS4(X, K, B, O);
+
+ uint k = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (k >= K.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc = B.Get(k);
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos = uint2(x, y) * _Stride.xy + uint2(dx, dy);
+ for (uint c = 0; c < X.channels; ++c)
+ {
+ float v = X.SafeGet(n, pos, c, _Pad.xy);
+ acc += v * K.Get(dy, dx, c, k);
+ }
+ }
+ }
+
+ O.Set(n, y, x, k, acc);
+ }
+}
+
+NUMTHREADS((16,4,4), (8,4,4), (4,4,4))
+void DepthwiseConv2D(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_ARGS4(X, K, B, O);
+
+ uint k = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (k >= K.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc = B.Get(k);
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos = uint2(x, y) * _Stride.xy + uint2(dx, dy);
+ float v = X.SafeGet(n, pos, k, _Pad.xy);
+ acc += v * K.Get(dy, dx, 0, k);
+ }
+
+ O.Set(n, y, x, k, acc);
+ }
+}
+
+[numthreads(4,4,4)]
+void Unstride2D(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= O.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ int xx = (int)x - (int)_Pad.x;
+ int yy = (int)y - (int)_Pad.y;
+
+ int my = yy % _Stride.y;
+ int mx = xx % _Stride.x;
+
+ int oy = yy / _Stride.y;
+ int ox = xx / _Stride.x;
+
+ bool mask = ox >= 0 && oy >= 0 && ox < (int)X.width && oy < (int)X.height &&
+ my == 0 && mx == 0;
+
+ float v = mask ? X.Get(n, (uint)oy, (uint)ox, c) : 0;
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(4,4,4)]
+void Conv2DTrans(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_ARGS4(X, K, B, O);
+
+ uint k = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (k >= K.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ uint2 strideMask = _Stride.xy - 1;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc = B.Get(k);
+ for (uint dy = y & strideMask.y; dy < K.GetKernelHeight(); dy += _Stride.y)
+ {
+ for (uint dx = x & strideMask.x; dx < K.GetKernelWidth(); dx += _Stride.x)
+ {
+ for (uint c = 0; c < X.channels; ++c)
+ {
+ uint xx = x + dx;
+ uint yy = y + dy;
+
+ uint oy = (yy - _Pad.y) / _Stride.y;
+ uint ox = (xx - _Pad.x) / _Stride.x;
+
+ bool mask = xx >= _Pad.x && yy >= _Pad.y && ox < X.width && oy < X.height;
+
+ float v = (mask)? X.Get(n, oy, ox, c): 0;
+ acc += v * K.Get(K.GetKernelHeight() - 1 - dy, K.GetKernelWidth() - 1 - dx, c, k);
+ }
+ }
+ }
+
+ O.Set(n, y, x, k, acc);
+ }
+}
+
+
+Texture2D Xtex2D;
+Texture3D Xtex3D;
+Texture2DArray Xtex2DArray;
+SamplerState samplerXtex2D { Filter = MIN_MAG_LINEAR_MIP_POINT; AddressU = Clamp; AddressV = Clamp; };
+SamplerState samplerXtex3D { Filter = MIN_MAG_LINEAR_MIP_POINT; AddressU = Clamp; AddressV = Clamp; AddressW = Clamp; };
+SamplerState samplerXtex2DArray { Filter = MIN_MAG_LINEAR_MIP_POINT; AddressU = Clamp; AddressV = Clamp; };
+
+RWTexture2D Otex2D;
+RWTexture3D Otex3D;
+RWTexture2DArray Otex2DArray;
+
+bool _FlipY;
+
+// TODO: call TextureToTensor(v, dispatchThreadID) from Tex2DToTensor() { v = Xtex2D.SampleLevel }
+[numthreads(8,8,1)]
+void TextureToTensor(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ TENSOR_ARG_RW(O);
+
+ uint b = _Pad.x;
+ uint x = dispatchThreadID.x + _Pad.y;
+ uint y = dispatchThreadID.y + _Pad.z;
+ uint c = dispatchThreadID.z + _Pad.w;
+
+ // calculate texture coordinates:
+ // offset by 0.5 to get texel centers
+ // divide by texture resolution (_Pool)
+ float3 uvw = (float3)dispatchThreadID + float3(0.5f, 0.5f, 0);
+ uvw /= (float3)_Pool.xyz;
+ if (_FlipY)
+ uvw.y = 1 - uvw.y;
+
+ float4 v = Xtex2D.SampleLevel(samplerXtex2D, uvw.xy, 0);
+ //texArray.SampleLevel(smpArray, loc, 0);
+
+ if (_Stride.w == 1)
+ {
+ // TODO: interpret color as
+ O.Set(b, y, x, c+0, (v.r + v.g + v.b) / 3.0f);
+ }
+ else if (_Stride.w == 3)
+ {
+ O.Set(b, y, x, c+0, v.r);
+ O.Set(b, y, x, c+1, v.g);
+ O.Set(b, y, x, c+2, v.b);
+ }
+ else if (_Stride.w == 4)
+ {
+ O.Set(b, y, x, c+0, v.r);
+ O.Set(b, y, x, c+1, v.g);
+ O.Set(b, y, x, c+2, v.b);
+ O.Set(b, y, x, c+3, v.a);
+ }
+}
+
+[numthreads(8,8,1)]
+void TensorToTexture(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ TENSOR_ARG(X);
+
+ uint b = _Pad.x;
+ uint x = dispatchThreadID.x + _Pad.y;
+ uint y = dispatchThreadID.y + _Pad.z;
+ uint c = dispatchThreadID.z + _Pad.w;
+
+ if (_FlipY)
+ y = X.height - 1 - y;
+
+ float4 v = 0;
+
+ if (X.channels - c == 1)
+ {
+ // broadcast to all channels
+ v = X.Get(b, y, x, c);
+ }
+ else if (X.channels - c == 3)
+ {
+ v.r = X.Get(b, y, x, c+0);
+ v.g = X.Get(b, y, x, c+1);
+ v.b = X.Get(b, y, x, c+2);
+ v.a = 1;
+ }
+ else if (X.channels - c >= 4)
+ {
+ v.r = X.Get(b, y, x, c+0);
+ v.g = X.Get(b, y, x, c+1);
+ v.b = X.Get(b, y, x, c+2);
+ v.a = X.Get(b, y, x, c+3);
+ }
+
+
+ Otex2D[dispatchThreadID.xy] = v;
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/BarracudaReferenceImpl.compute.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/BarracudaReferenceImpl.compute.meta
new file mode 100755
index 00000000..e8147972
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/BarracudaReferenceImpl.compute.meta
@@ -0,0 +1,9 @@
+fileFormatVersion: 2
+guid: b4b1b304aae6c404cb0cdab46b8fa084
+timeCreated: 1495527718
+licenseType: Pro
+ComputeShaderImporter:
+ currentAPIMask: 196608
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Broadcast.compute b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Broadcast.compute
new file mode 100755
index 00000000..240b7e86
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Broadcast.compute
@@ -0,0 +1,149 @@
+#pragma kernel BroadcastAdd
+#pragma kernel BroadcastSub
+#pragma kernel BroadcastMul
+#pragma kernel BroadcastDiv
+#pragma kernel BroadcastPow
+#pragma kernel BroadcastMin
+#pragma kernel BroadcastMax
+
+#include "Tensor.cginc"
+
+TENSOR_DECL(X)
+TENSOR_DECL(B)
+TENSOR_DECL_RW(O)
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void BroadcastAdd(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS3(X, B, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v =
+ X.BroadcastGet(n, y, x, c) +
+ B.BroadcastGet(n, y, x, c);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void BroadcastSub(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS3(X, B, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v =
+ X.BroadcastGet(n, y, x, c) -
+ B.BroadcastGet(n, y, x, c);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void BroadcastMul(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS3(X, B, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float v =
+ X.BroadcastGet(n, y, x, c) *
+ B.BroadcastGet(n, y, x, c);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void BroadcastDiv(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS3(X, B, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v =
+ X.BroadcastGet(n, y, x, c) /
+ B.BroadcastGet(n, y, x, c);
+ O.Set(n, y, x, c, v);
+ }
+}
+
+float signed_pow(float f, float e)
+{
+ // handle negative f
+ float v = pow(abs(f), e);
+ float s = (e % 2 == 1) ?
+ sign(f): // exponent is odd => sign(f) * pow(abs(f), e)
+ 1; // exponent is even => pow(abs(f), e)
+ return v * s;
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void BroadcastPow(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS3(X, B, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = signed_pow(
+ X.BroadcastGet(n, y, x, c),
+ B.BroadcastGet(n, y, x, c));
+ O.Set(n, y, x, c, v);
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void BroadcastMin(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS3(X, B, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = min(
+ X.BroadcastGet(n, y, x, c),
+ B.BroadcastGet(n, y, x, c));
+ O.Set(n, y, x, c, v);
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void BroadcastMax(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS3(X, B, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = max(
+ X.BroadcastGet(n, y, x, c),
+ B.BroadcastGet(n, y, x, c));
+ O.Set(n, y, x, c, v);
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Broadcast.compute.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Broadcast.compute.meta
new file mode 100755
index 00000000..70f38084
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Broadcast.compute.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 72dd00e416ab94bd79e7264a1fadef9d
+ComputeShaderImporter:
+ externalObjects: {}
+ currentAPIMask: 65536
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Conv.compute b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Conv.compute
new file mode 100755
index 00000000..ce0ba1a4
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Conv.compute
@@ -0,0 +1,396 @@
+#pragma kernel Conv2D
+#pragma kernel Conv2D_RegisterBlock4x2
+//#pragma kernel Conv2D_L1Cached64_RegisterBlock4x4
+
+#pragma kernel DepthwiseConv2D
+
+#pragma kernel Conv2DTrans
+#pragma kernel Conv2DTrans_L1Cached64_RegisterBlock2x2
+
+#include "Tensor.cginc"
+
+TENSOR_DECL(X)
+TENSOR_DECL(K)
+TENSOR_DECL(B)
+TENSOR_DECL(WBK)
+TENSOR_DECL_RW(O)
+
+uint4 _Pad;
+uint4 _Stride;
+
+NUMTHREADS((16,4,4), (8,4,4), (4,4,4))
+void Conv2D(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ uint k = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (k >= K.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ uint2 leftCorner = _Pad.xy;
+ uint2 rightCorner = uint2(X.width, X.height) + _Pad.xy;
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc = B.Get(k);
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos = uint2(x, y) * _Stride.xy + uint2(dx, dy);
+ // @TODO: investigate
+ // WARNING: had to move both y check into the loop (as opposed to checking y in parent loop) - due to potential bug in Metal compiler
+ if (any(pos < leftCorner)) continue;
+ if (any(pos >= rightCorner)) continue;
+
+ for (uint c = 0; c < X.channels; ++c)
+ acc = fastfma(X.Get(n, pos.y - leftCorner.y, pos.x - leftCorner.x, c), K.Get(dy, dx, c, k), acc);
+ }
+ }
+
+ O.Set(n, y, x, k, acc);
+ }
+}
+
+
+#define SIZE_W 4
+#define SIZE_H 2
+NUMTHREADS((64, 2, 2), (32, 2, 2), (16, 2, 2))
+void Conv2D_RegisterBlock4x2(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ uint k = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (k >= K.channels) return;
+ if (x*SIZE_W >= O.width) return;
+ if (y*SIZE_H >= O.height) return;
+
+ uint2 leftCorner = _Pad.xy;
+ uint2 rightCorner = uint2(X.width, X.height) + _Pad.xy;
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc[SIZE_H*SIZE_W];
+ [unroll]
+ for (uint q = 0; q < SIZE_H*SIZE_W; ++q)
+ acc[q] = B.Get(k);
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos[SIZE_H*SIZE_W];
+ [unroll]
+ for (uint q = 0; q < SIZE_H*SIZE_W; ++q)
+ pos[q] = uint2(x*SIZE_W+(q%SIZE_W), y*SIZE_H+(q/SIZE_W)) * _Stride.xy + uint2(dx, dy);
+
+ for (uint c = 0; c < X.channels; ++c)
+ [unroll]
+ for (q = 0; q < SIZE_H*SIZE_W; ++q)
+ if (all(pos[q] >= leftCorner) && all(pos[q] < rightCorner))
+ acc[q] = fastfma(X.Get(n, pos[q] - leftCorner, c), K.Get(dy, dx, c, k), acc[q]);
+ }
+ }
+
+ [unroll]
+ for (q = 0; q < SIZE_H*SIZE_W; ++q)
+ O.Set(n, y*SIZE_H+(q/SIZE_W), x*SIZE_W+(q%SIZE_W), k, acc[q]);
+ }
+}
+#undef SIZE_W
+#undef SIZE_H
+
+#undef L1CACHESIZE
+#define L1CACHESIZE 64
+#undef SIZE
+#define SIZE 4
+groupshared float Conv2D_L1Cached64_Reg_Loop_safe_X[SIZE*SIZE][L1CACHESIZE];
+[numthreads(L1CACHESIZE, 1, 1)]
+void Conv2D_L1Cached64_RegisterBlock4x4(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ #define X_ Conv2D_L1Cached64_Reg_Loop_safe_X
+
+ uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
+ uint x = groupID.y;
+ uint y = groupID.z;
+
+ // need all threads to load channels, thus will do late check against kernel count
+ if (x*SIZE >= O.width) return;
+ if (y*SIZE >= O.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc[SIZE*SIZE];
+ [unroll]
+ for (uint q = 0; q < SIZE*SIZE; ++q)
+ acc[q] = B.SafeGet(k);
+
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos[SIZE*SIZE];
+ [unroll]
+ for (uint q = 0; q < SIZE*SIZE; ++q)
+ pos[q] = uint2(x*SIZE+(q%SIZE), y*SIZE+(q/SIZE)) * _Stride.xy + uint2(dx, dy);
+
+ for (uint c = 0; c < X.channels; c += L1CACHESIZE)
+ {
+ // Cache X
+ uint dc = groupThreadID.x;
+ [unroll]
+ for (q = 0; q < SIZE*SIZE; ++q)
+ X_[q][dc] = X.SafeGet(n, pos[q], c + dc, _Pad.xy);
+ GroupMemoryBarrierWithGroupSync();
+
+ // X * K
+ if (k < K.channels) // need all threads to load channels, thus late check against kernel count
+ {
+ uint kIndex = K.Index(dy, dx, c, k);
+ for (dc = 0; dc < L1CACHESIZE; ++dc)
+ {
+ [unroll]
+ for (q = 0; q < SIZE*SIZE; ++q)
+ acc[q] = fastfma(X_[q][dc], K.data[kIndex], acc[q]);
+ kIndex += K.channels;
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ uint remainderW = (O.width - x*SIZE);
+ uint remainderH = (O.height - y*SIZE);
+
+ if (k < K.channels) // need all threads to load channels, thus late check against kernel count
+ [unroll]
+ for (q = 0; q < SIZE*SIZE; ++q)
+ if (q/SIZE < remainderH && q%SIZE < remainderW)
+ O.Set(n, y*SIZE+(q/SIZE), x*SIZE+(q%SIZE), k, acc[q]);
+ }
+
+ #undef X_
+}
+
+
+NUMTHREADS((16,4,4), (8,4,4), (4,4,4))
+void DepthwiseConv2D(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ uint k = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (k >= K.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ uint2 leftCorner = _Pad.xy;
+ uint2 rightCorner = uint2(X.width, X.height) + _Pad.xy;
+
+ uint2 leftKernelCorner = uint2(x, y) * _Stride.xy;
+ uint2 rightKernelCorner = leftKernelCorner + uint2(K.GetKernelWidth(), K.GetKernelHeight());
+
+ if (any(leftKernelCorner < leftCorner) || any(rightKernelCorner >= rightCorner))
+ {
+ // path with edge-cases checks
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc = B.Get(k);
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos = leftKernelCorner + uint2(dx, dy);
+ if (any(pos < leftCorner)) continue;
+ if (any(pos >= rightCorner)) continue;
+
+ acc = fastfma(
+ X.Get(n, pos.y - leftCorner.y, pos.x - leftCorner.x, k),
+ K.Get(dy, dx, 0, k),
+ acc);
+ }
+
+ O.Set(n, y, x, k, acc);
+ }
+ }
+ else
+ {
+ // kernel is guaranteed to be within X,
+ // no need to check against edge-cases
+ leftKernelCorner -= leftCorner;
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc = B.Get(k);
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos = leftKernelCorner + uint2(dx, dy);
+
+ acc = fastfma(
+ X.Get(n, pos, k),
+ K.Get(dy, dx, 0, k),
+ acc);
+ }
+
+ O.Set(n, y, x, k, acc);
+ }
+ }
+}
+
+
+// Significantly faster than Conv2DTrans
+[numthreads(16,2,2)]
+void Conv2DTrans(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // NOTE: dispatched over X (not O)
+ DISPATCH_ARGS(K.kernelCount, X.width, X.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ uint k = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (k >= K.channels) return;
+ if (x >= X.width) return;
+ if (y >= X.height) return;
+
+ uint2 pad = _Pad.xy / _Stride.xy;
+ uint2 leftCorner = pad;
+ uint2 rightCorner = uint2(X.width, X.height) + pad;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ for (uint sy = 0; sy < _Stride.y; ++sy)
+ {
+ for (uint sx = 0; sx < _Stride.x; ++sx)
+ {
+ float acc = B.Get(k);
+ for (uint dy = sy; dy < K.GetKernelHeight(); dy += _Stride.y)
+ {
+ for (uint dx = sx; dx < K.GetKernelWidth(); dx += _Stride.x)
+ {
+ uint2 pos = uint2(x, y) + uint2(sx + dx, sy + dy) / _Stride.xy;
+
+ if (any(pos < leftCorner)) continue;
+ if (any(pos >= rightCorner)) continue;
+
+ for (uint c = 0; c < X.channels; ++c)
+ {
+ acc = fastfma( X.Get(n, pos - leftCorner, c),
+ K.Get( K.GetKernelHeight() - 1 - dy,
+ K.GetKernelWidth() - 1 - dx, c, k),
+ acc);
+ }
+ }
+ }
+
+ uint oy = y * _Stride.y + sy;
+ uint ox = x * _Stride.x + sx;
+ if (oy < O.height && ox < O.width)
+ O.Set(n, oy, ox, k, acc);
+ }
+ }
+ }
+}
+
+#undef L1CACHESIZE
+#define L1CACHESIZE 64
+#undef SIZE
+#define SIZE 2
+groupshared float Conv2DTrans_L1Cached64_Reg_Loop_safe_X[SIZE*SIZE][L1CACHESIZE];
+[numthreads(L1CACHESIZE, 1, 1)]
+void Conv2DTrans_L1Cached64_RegisterBlock2x2(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ // NOTE: dispatched over X (not O)
+ DISPATCH_ARGS(K.kernelCount, X.width / SIZE, X.height / SIZE);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ #define X_ Conv2DTrans_L1Cached64_Reg_Loop_safe_X
+
+ uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
+ uint x = groupID.y;
+ uint y = groupID.z;
+
+ // need all threads to load channels, thus will do late check against kernel count
+ if (x*SIZE >= X.width) return;
+ if (y*SIZE >= X.height) return;
+
+ uint2 pad = _Pad.xy / _Stride.xy;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ for (uint sy = 0; sy < _Stride.y; ++sy)
+ {
+ for (uint sx = 0; sx < _Stride.x; ++sx)
+ {
+ float acc[SIZE*SIZE];
+ [unroll]
+ for (uint q = 0; q < SIZE*SIZE; ++q)
+ acc[q] = B.SafeGet(k);
+
+ for (uint dy = sy; dy < K.GetKernelHeight(); dy += _Stride.y)
+ {
+ for (uint dx = sx; dx < K.GetKernelWidth(); dx += _Stride.x)
+ {
+ uint2 pos[SIZE*SIZE];
+ [unroll]
+ for (uint q = 0; q < SIZE*SIZE; ++q)
+ pos[q] = uint2(x*SIZE+(q%SIZE), y*SIZE+(q/SIZE)) + uint2(dx+sx, dy+sy) / _Stride.xy;
+
+ for (uint c = 0; c < X.channels; c += L1CACHESIZE)
+ {
+ // Cache X
+ uint dc = groupThreadID.x;
+ [unroll]
+ for (q = 0; q < SIZE*SIZE; ++q)
+ X_[q][dc] = X.SafeGet(n, pos[q], c + dc, pad);
+ GroupMemoryBarrierWithGroupSync();
+
+ // X * K
+ if (k < K.channels) // need all threads to load channels, thus late check against kernel count
+ {
+ //uint kIndex = K.Index(dy, dx, c, k);
+ for (dc = 0; dc < L1CACHESIZE; ++dc)
+ {
+ [unroll]
+ for (q = 0; q < SIZE*SIZE; ++q)
+ acc[q] = fastfma( X_[q][dc],
+ K.Get( K.GetKernelHeight() - 1 - dy,
+ K.GetKernelWidth() - 1 - dx, c + dc, k),
+ acc[q]);
+ //kIndex += K.channels;
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+
+ if (k < K.channels) // need all threads to load channels, thus late check against kernel count
+ [unroll]
+ for (q = 0; q < SIZE*SIZE; ++q)
+ {
+ uint ox = (x*SIZE+(q%SIZE)) * _Stride.x + sx;
+ uint oy = (y*SIZE+(q/SIZE)) * _Stride.y + sy;
+ if (ox < O.width && oy < O.height)
+ O.Set(n, oy, ox, k, acc[q]);
+ }
+ }
+ }
+ }
+
+ #undef X_
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Conv.compute.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Conv.compute.meta
new file mode 100755
index 00000000..bc66c8b6
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Conv.compute.meta
@@ -0,0 +1,9 @@
+fileFormatVersion: 2
+guid: 7f508b82f984146e8bf0ad8520c316c7
+timeCreated: 1507457340
+licenseType: Pro
+ComputeShaderImporter:
+ currentAPIMask: 196608
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/ConvOld.compute b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/ConvOld.compute
new file mode 100755
index 00000000..1e957493
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/ConvOld.compute
@@ -0,0 +1,418 @@
+//#pragma kernel Conv2D_Kmod16_Nmod8_KNY
+//#pragma kernel Conv2D_Cache_KCmod32_KNyx
+//#pragma kernel Conv2D_Cache_KCmod32_KNyxDiv2
+// NOTE: DISABLED 64 version because as it is slower than 32 version on AMD GPU
+//#pragma kernel Conv2D_Cache_KCmod64_KNyx
+
+#include "Tensor.cginc"
+
+TENSOR_DECL(X)
+TENSOR_DECL(K)
+TENSOR_DECL(B)
+TENSOR_DECL(WBK)
+TENSOR_DECL_RW(O)
+
+uint4 _Pad;
+uint4 _Stride;
+
+NUMTHREADS((16,8,1), (16,8,1), (16,4,1))
+void Conv2D_Kmod16_Nmod8_KNY(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(K.channels, O.batch, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ uint k = dispatchThreadID.x;
+ uint n = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ for (uint x = 0; x < O.width; ++x)
+ {
+ float v = B.Get(k);
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint oy = y * _Stride.y + dy;
+ uint ox = x * _Stride.x + dx;
+ // @TODO: investigate
+ // WARNING: had to move both y check into the loop (as opposed to checking y in parent loop) - due to potential bug in Metal compiler
+ if (oy < _Pad.y) continue;
+ if (oy - _Pad.w >= X.height) continue;
+ if (ox < _Pad.x) continue;
+ if (ox - _Pad.z >= X.width) continue;
+
+ for (uint c = 0; c < X.channels; ++c)
+ {
+ v += X.Get(n, oy-_Pad.y, ox-_Pad.x, c) * K.Get(dy, dx, c, k);
+ }
+ }
+ }
+ O.Set(n, y, x, k, v);
+ }
+}
+
+#undef CTILE
+#define CTILE NUMTHREAD(16, 8, 8)
+groupshared float Conv_Xcache[4][CTILE][CTILE];
+groupshared float Conv_Kcache[4][CTILE][CTILE];
+[numthreads(CTILE, CTILE, 1)]
+void Conv2D_Cache_KCmod32_KNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount / 2, O.batch * O.height * O.width / 2, 1);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ #define X_ Conv_Xcache
+ #define K_ Conv_Kcache
+
+ uint gx = groupThreadID.x;
+ uint gy = groupThreadID.y;
+
+ uint k = CTILE * groupID.x + groupThreadID.x;
+ uint nyx = CTILE * groupID.y + groupThreadID.y;
+
+ uint width = O.width;
+ uint height = O.height;
+
+ uint x = nyx % width;
+ uint ny = nyx / width;
+ uint y = ny % height;
+ uint n = ny / height;
+
+ float b0 = B.Get(k*2+0);
+ float b1 = B.Get(k*2+1);
+ float4 v = float4(b0, b1,
+ b0, b1);
+
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ bool mask = true;
+ uint oy = y * _Stride.y + dy;
+ uint ox = x * _Stride.x + dx;
+ // @TODO: investigate
+ // WARNING: had to move both y check into the loop (as opposed to checking y in parent loop) - due to potential bug in Metal compiler
+ if (oy < _Pad.y) mask = false;
+ if (oy - _Pad.w >= X.height) mask = false;
+ if (ox < _Pad.x) mask = false;
+ if (ox - _Pad.z >= X.width) mask = false;
+
+ for (uint m = 0; m < X.channels/(CTILE*2); ++m)
+ {
+ float x0 = 0;
+ float x1 = 0;
+ float x2 = 0;
+ float x3 = 0;
+
+ if (mask)
+ {
+ x0 = X.Get(n*2+0, oy-_Pad.y, ox-_Pad.x, (m*CTILE + gx)*2+0);
+ x1 = X.Get(n*2+0, oy-_Pad.y, ox-_Pad.x, (m*CTILE + gx)*2+1);
+ x2 = X.Get(n*2+1, oy-_Pad.y, ox-_Pad.x, (m*CTILE + gx)*2+0);
+ x3 = X.Get(n*2+1, oy-_Pad.y, ox-_Pad.x, (m*CTILE + gx)*2+1);
+ }
+
+ float k0 = K.Get(dy, dx, (m*CTILE + gy)*2+0, k*2+0);
+ float k1 = K.Get(dy, dx, (m*CTILE + gy)*2+0, k*2+1);
+ float k2 = K.Get(dy, dx, (m*CTILE + gy)*2+1, k*2+0);
+ float k3 = K.Get(dy, dx, (m*CTILE + gy)*2+1, k*2+1);
+
+ //X_[gy][gx] = float4(x0, x1,
+ // x2, x3);
+ //K_[gy][gx] = float4(k0, k1,
+ // k2, k3);
+ X_[0][gy][gx] = x0;
+ X_[1][gy][gx] = x1;
+ X_[2][gy][gx] = x2;
+ X_[3][gy][gx] = x3;
+
+ K_[0][gy][gx] = k0;
+ K_[1][gy][gx] = k1;
+ K_[2][gy][gx] = k2;
+ K_[3][gy][gx] = k3;
+
+ GroupMemoryBarrierWithGroupSync();
+
+ [unroll]
+ for (uint i = 0; i < CTILE; ++i)
+ {
+ float4 x = //X_[gy][i];
+ float4( X_[0][gy][i],
+ X_[1][gy][i],
+ X_[2][gy][i],
+ X_[3][gy][i]);
+ float4 k = //K_[i][gx];
+ float4( K_[0][i][gx],
+ K_[1][i][gx],
+ K_[2][i][gx],
+ K_[3][i][gx]);
+
+ v.x = mad(k.x, x.x, v.x);
+ v.x = mad(k.z, x.y, v.x);
+
+ v.y = mad(k.y, x.x, v.y);
+ v.y = mad(k.w, x.y, v.y);
+
+ v.z = mad(k.x, x.z, v.z);
+ v.z = mad(k.z, x.w, v.z);
+
+ v.w = mad(k.y, x.z, v.w);
+ v.w = mad(k.w, x.w, v.w);
+
+ //v.x += k.x*x.x + k.z*x.y;
+ //v.y += k.y*x.x + k.w*x.y;
+ //v.z += k.x*x.z + k.z*x.w;
+ //v.w += k.y*x.z + k.w*x.w;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ O.Set(n*2+0, y, x, k*2+0, v.x);
+ O.Set(n*2+0, y, x, k*2+1, v.y);
+ O.Set(n*2+1, y, x, k*2+0, v.z);
+ O.Set(n*2+1, y, x, k*2+1, v.w);
+
+ #undef X_
+ #undef K_
+}
+
+#undef CTILE
+//#define CTILE NUMTHREAD(16, 8, 8)
+#define CTILE 16
+groupshared float Conv_Xcache2[4][CTILE][CTILE];
+groupshared float Conv_Kcache2[4][CTILE][CTILE];
+[numthreads(CTILE, CTILE, 1)]
+void Conv2D_Cache_KCmod32_KNyxDiv2(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount / 2, O.batch * O.height * O.width / 2, 1);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ #define X_ Conv_Xcache2
+ #define K_ Conv_Kcache2
+
+ uint gx = groupThreadID.x;
+ uint gy = groupThreadID.y;
+
+ uint k = CTILE * groupID.x + groupThreadID.x;
+ uint nyx = CTILE * groupID.y + groupThreadID.y;
+
+ uint width = O.width / 2;
+ uint height = O.height;
+
+ uint x = nyx % width;
+ uint ny = nyx / width;
+ uint y = ny % height;
+ uint n = ny / height;
+
+ float b0 = B.Get(k*2+0);
+ float b1 = B.Get(k*2+1);
+ float4 v = float4(b0, b1,
+ b0, b1);
+
+ bool mask = n < O.batch;
+
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ // @TODO: investigate
+ // WARNING: had to move both y check into the loop (as opposed to checking y in parent loop) - due to potential bug in Metal compiler
+ bool maskY = mask;
+ uint oy = y * _Stride.y + dy;
+ if (oy < _Pad.y) maskY = false;
+ if (oy - _Pad.w >= X.height) maskY = false;
+
+ bool maskL = maskY;
+ uint oxL = (x*2+0) * _Stride.x + dx;
+ if (oxL < _Pad.x) maskL = false;
+ if (oxL - _Pad.z >= X.width) maskL = false;
+
+ bool maskR = maskY;
+ uint oxR = (x*2+1) * _Stride.x + dx;
+ if (oxR < _Pad.x) maskR = false;
+ if (oxR - _Pad.z >= X.width) maskR = false;
+
+ for (uint m = 0; m < X.channels/(CTILE*2); ++m)
+ {
+ if (maskL)
+ {
+ X_[0][gy][gx] = X.Get(n, oy-_Pad.y, oxL-_Pad.x, (m*CTILE + gx)*2+0);
+ X_[1][gy][gx] = X.Get(n, oy-_Pad.y, oxL-_Pad.x, (m*CTILE + gx)*2+1);
+ }
+ else
+ {
+ X_[0][gy][gx] = X_[1][gy][gx] = 0;
+ }
+
+ if (maskR)
+ {
+ X_[2][gy][gx] = X.Get(n, oy-_Pad.y, oxR-_Pad.x, (m*CTILE + gx)*2+0);
+ X_[3][gy][gx] = X.Get(n, oy-_Pad.y, oxR-_Pad.x, (m*CTILE + gx)*2+1);
+ }
+ else
+ {
+ X_[2][gy][gx] = X_[3][gy][gx] = 0;
+ }
+
+
+ K_[0][gy][gx] = K.Get(dy, dx, (m*CTILE + gy)*2+0, k*2+0);
+ K_[1][gy][gx] = K.Get(dy, dx, (m*CTILE + gy)*2+0, k*2+1);
+ K_[2][gy][gx] = K.Get(dy, dx, (m*CTILE + gy)*2+1, k*2+0);
+ K_[3][gy][gx] = K.Get(dy, dx, (m*CTILE + gy)*2+1, k*2+1);
+
+ GroupMemoryBarrierWithGroupSync();
+
+ [unroll]
+ for (uint i = 0; i < CTILE; ++i)
+ {
+ float4 x =
+ float4( X_[0][gy][i],
+ X_[1][gy][i],
+ X_[2][gy][i],
+ X_[3][gy][i]);
+ float4 k =
+ float4( K_[0][i][gx],
+ K_[1][i][gx],
+ K_[2][i][gx],
+ K_[3][i][gx]);
+
+ v.x = mad(k.x, x.x, v.x);
+ v.x = mad(k.z, x.y, v.x);
+
+ v.y = mad(k.y, x.x, v.y);
+ v.y = mad(k.w, x.y, v.y);
+
+ v.z = mad(k.x, x.z, v.z);
+ v.z = mad(k.z, x.w, v.z);
+
+ v.w = mad(k.y, x.z, v.w);
+ v.w = mad(k.w, x.w, v.w);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ O.Set(n, y, x*2+0, k*2+0, v.x);
+ O.Set(n, y, x*2+0, k*2+1, v.y);
+ if (mask && x*2+1 < O.width)
+ {
+ O.Set(n, y, x*2+1, k*2+0, v.z);
+ O.Set(n, y, x*2+1, k*2+1, v.w);
+ }
+
+ #undef X_
+ #undef K_
+}
+
+
+#undef CTILE
+//#define CTILE NUMTHREAD(16, 8, 8)
+#define CTILE 16
+#define RTILE 4
+groupshared float Conv_XcacheR[RTILE*RTILE][CTILE*CTILE];
+groupshared float Conv_KcacheR[RTILE*RTILE][CTILE*CTILE];
+[numthreads(CTILE, CTILE, 1)]
+void Conv2D_Cache_KCmod64_KNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount / 4, O.batch * O.height * O.width / 4, 1);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ #define X_ Conv_XcacheR
+ #define K_ Conv_KcacheR
+
+ uint gx = groupThreadID.x;
+ uint gy = groupThreadID.y;
+
+ uint k = CTILE * groupID.x + groupThreadID.x;
+ uint nyx = CTILE * groupID.y + groupThreadID.y;
+
+ uint x = nyx % O.width;
+ uint ny = nyx / O.width;
+ uint y = ny % O.height;
+ uint n = ny / O.height;
+
+ float v[RTILE][RTILE];
+ for (uint xxxx = 0; xxxx < RTILE; ++xxxx)
+ {
+ float b = B.Get(k*RTILE+xxxx);
+ for (uint yyyy = 0; yyyy < RTILE; ++yyyy)
+ v[yyyy][xxxx] = b;
+ }
+
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ bool mask = true;
+ uint oy = y * _Stride.y + dy;
+ uint ox = x * _Stride.x + dx;
+ // @TODO: investigate
+ // WARNING: had to move both y check into the loop (as opposed to checking y in parent loop) - due to potential bug in Metal compiler
+ if (oy < _Pad.y) mask = false;
+ if (oy - _Pad.w >= X.height) mask = false;
+ if (ox < _Pad.x) mask = false;
+ if (ox - _Pad.z >= X.width) mask = false;
+
+ for (uint m = 0; m < X.channels/(CTILE*RTILE); ++m)
+ {
+ for (uint yy = 0; yy < RTILE; ++yy)
+ for (uint xx = 0; xx < RTILE; ++xx)
+ {
+ if (mask)
+ X_[yy*RTILE+xx][gy*CTILE+gx] = X.Get(n*RTILE+yy, oy - _Pad.y, ox - _Pad.x, (m*CTILE + gx)*RTILE+xx);
+ else
+ X_[yy*RTILE+xx][gy*CTILE+gx] = 0;
+ K_[yy*RTILE+xx][gy*CTILE+gx] = K.Get(dy, dx, (m*CTILE + gy)*RTILE+yy, k*RTILE+xx);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint ii = 0; ii < CTILE; ++ii)
+ {
+ float x[RTILE][RTILE];
+ float k[RTILE][RTILE];
+
+ [unroll]
+ for (uint yy = 0; yy < RTILE; ++yy)
+ {
+ [unroll]
+ for (uint xx = 0; xx < RTILE; ++xx)
+ {
+ x[yy][xx] = X_[yy*RTILE+xx][gy*CTILE+ii];
+ k[yy][xx] = K_[yy*RTILE+xx][ii*CTILE+gx];
+ }
+ }
+
+
+ [unroll]
+ for (uint yyy = 0; yyy < RTILE; ++yyy)
+ {
+ [unroll]
+ for (uint xxx = 0; xxx < RTILE; ++xxx)
+ {
+ [unroll]
+ for (uint i = 0; i < RTILE; ++i)
+ {
+ v[yyy][xxx] = mad(x[yyy][i], k[i][xxx], v[yyy][xxx]);
+ }
+ }
+ }
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ for (uint yy = 0; yy < RTILE; ++yy)
+ for (uint xx = 0; xx < RTILE; ++xx)
+ O.Set(n*RTILE+yy, y, x, k*RTILE+xx, v[yy][xx]);
+
+ #undef X_
+ #undef K_
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/ConvOld.compute.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/ConvOld.compute.meta
new file mode 100755
index 00000000..dae45fc7
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/ConvOld.compute.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: a89bb2d7cde74429c8475f7cd8bcdb01
+ComputeShaderImporter:
+ externalObjects: {}
+ currentAPIMask: 0
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Dense.compute b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Dense.compute
new file mode 100755
index 00000000..bd2c76cd
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Dense.compute
@@ -0,0 +1,305 @@
+#pragma kernel Dense_L1Cached64
+#pragma kernel DenseTiled16x16
+//#pragma kernel DenseTiled32x32
+//#pragma kernel DenseTiled64x64
+
+#include "Tensor.cginc"
+
+TENSOR_DECL(X)
+TENSOR_DECL(W)
+TENSOR_DECL(B)
+TENSOR_DECL(WBK)
+TENSOR_DECL_RW(O)
+
+// NOTE: usually this path is used for <16 batches
+#undef CACHESIZE
+#define CACHESIZE 64
+groupshared float Dense_L1Cached64_X[CACHESIZE];
+[numthreads(CACHESIZE, 1, 1)]
+void Dense_L1Cached64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(O.flatWidth, O.flatHeight, 1);
+ TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
+
+ #define X_ Dense_L1Cached64_X
+
+ uint x = CACHESIZE * groupID.x + groupThreadID.x;
+ uint y = groupID.y;
+
+ uint wIndex = W.Index(0, x);
+
+ float acc = B.Get(x);
+ // loop over X columns (flatWidth) and W rows (height) in CACHESIZE steps
+ for (uint i = 0; i < X.GetFlatWidth(); i += CACHESIZE)
+ {
+ // Cache X
+ // coalescent reads
+ X_[groupThreadID.x] = X.SafeGet(y, i + groupThreadID.x);
+ GroupMemoryBarrierWithGroupSync();
+
+ // X * W
+ if (i + CACHESIZE <= X.GetFlatWidth())
+ {
+ [unroll]
+ for (uint di = 0; di < CACHESIZE; ++di)
+ {
+ acc = fastfma(X_[di], W.data[wIndex], acc);
+ wIndex += W.GetFlatWidth();
+ }
+ }
+ else
+ {
+ // handle remainder of the line < CACHESIZE
+ for (uint di = 0; i + di < X.GetFlatWidth(); ++di)
+ {
+ acc = fastfma(X_[di], W.data[wIndex], acc);
+ wIndex += W.GetFlatWidth();
+ }
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ // needed all threads to load matrix line, x might be out of the bounds for writing
+ if (x < O.GetFlatWidth())
+ O.Set(y, x, acc);
+
+ #undef X_
+}
+
+
+#undef TILE_WIDTH
+#define TILE_WIDTH NUMTHREAD(16,8,8)
+groupshared float DenseTiled_Xcache[TILE_WIDTH][TILE_WIDTH];
+groupshared float DenseTiled_Wcache[TILE_WIDTH][TILE_WIDTH];
+[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
+void DenseTiled16x16(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(O.flatWidth, O.flatHeight, 1);
+ TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
+
+ #define X_ DenseTiled_Xcache
+ #define W_ DenseTiled_Wcache
+
+ uint tx = groupThreadID.x;
+ uint ty = groupThreadID.y;
+ uint x = groupID.x*TILE_WIDTH + tx;
+ uint y = groupID.y*TILE_WIDTH + ty;
+
+ bool mask = (x < O.GetFlatWidth() && y < O.GetFlatHeight());
+
+ float v = B.Get(x);
+ for (uint m = 0; m < X.GetFlatWidth()/TILE_WIDTH; ++m)
+ {
+ if (mask)
+ {
+ X_[ty][tx] = X.Get(y, m*TILE_WIDTH + tx);
+ W_[ty][tx] = W.Get(m*TILE_WIDTH + ty, x);
+ }
+ else
+ {
+ X_[ty][tx] = 0;
+ W_[ty][tx] = 0;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ [unroll]
+ for (uint i = 0; i < TILE_WIDTH; ++i)
+ {
+ v = fastfma(X_[ty][i], W_[i][tx], v);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ if (mask)
+ O.Set(y, x, v);
+
+ #undef X_
+ #undef W_
+}
+
+#undef TILE_WIDTH
+#define TILE_WIDTH NUMTHREAD(16,8,8) // 32 crashes on MacBookPro/AMD
+groupshared float DenseTiled_Xcache32[2*2][TILE_WIDTH][TILE_WIDTH];
+groupshared float DenseTiled_Wcache32[2*2][TILE_WIDTH][TILE_WIDTH];
+[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
+void DenseTiled32x32(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(O.flatWidth / 2, O.flatHeight / 2, 1);
+ TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
+
+ #define X_ DenseTiled_Xcache32
+ #define W_ DenseTiled_Wcache32
+
+ uint tx = groupThreadID.x;
+ uint ty = groupThreadID.y;
+ uint x = groupID.x*TILE_WIDTH + tx;
+ uint y = groupID.y*TILE_WIDTH + ty;
+
+ float b0 = B.Get(x*2+0);
+ float b1 = B.Get(x*2+1);
+ float4 v = float4(b0, b1,
+ b0, b1);
+
+ for (uint m = 0; m < X.GetFlatWidth()/(TILE_WIDTH*2);)
+ {
+ float x0 = X.Get(y*2+0, m*TILE_WIDTH*2 + tx*2+0);
+ float x1 = X.Get(y*2+0, m*TILE_WIDTH*2 + tx*2+1);
+ float x2 = X.Get(y*2+1, m*TILE_WIDTH*2 + tx*2+0);
+ float x3 = X.Get(y*2+1, m*TILE_WIDTH*2 + tx*2+1);
+
+ float w0 = W.Get(m*TILE_WIDTH*2 + ty*2+0, x*2+0);
+ float w1 = W.Get(m*TILE_WIDTH*2 + ty*2+0, x*2+1);
+ float w2 = W.Get(m*TILE_WIDTH*2 + ty*2+1, x*2+0);
+ float w3 = W.Get(m*TILE_WIDTH*2 + ty*2+1, x*2+1);
+
+ ++m;
+
+ X_[0][ty][tx] = x0;
+ X_[1][ty][tx] = x1;
+ X_[2][ty][tx] = x2;
+ X_[3][ty][tx] = x3;
+
+ W_[0][ty][tx] = w0;
+ W_[1][ty][tx] = w1;
+ W_[2][ty][tx] = w2;
+ W_[3][ty][tx] = w3;
+
+ GroupMemoryBarrierWithGroupSync();
+
+ [unroll]
+ for (uint i = 0; i < TILE_WIDTH; ++i)
+ {
+ float4 x =
+ float4( X_[0][ty][i],
+ X_[1][ty][i],
+ X_[2][ty][i],
+ X_[3][ty][i]);
+ float4 w =
+ float4( W_[0][i][tx],
+ W_[1][i][tx],
+ W_[2][i][tx],
+ W_[3][i][tx]);
+
+ v.x = fastfma(w.x, x.x, v.x);
+ v.y = fastfma(w.y, x.x, v.y);
+ v.z = fastfma(w.x, x.z, v.z);
+ v.w = fastfma(w.y, x.z, v.w);
+
+ v.x = fastfma(w.z, x.y, v.x);
+ v.y = fastfma(w.w, x.y, v.y);
+ v.z = fastfma(w.z, x.w, v.z);
+ v.w = fastfma(w.w, x.w, v.w);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ O.Set(y*2+0, x*2+0, v.x);
+ O.Set(y*2+0, x*2+1, v.y);
+ O.Set(y*2+1, x*2+0, v.z);
+ O.Set(y*2+1, x*2+1, v.w);
+
+ #undef X_
+ #undef W_
+}
+
+#undef TILE_WIDTH
+#define TILE_WIDTH NUMTHREAD(16,8,8)
+groupshared float DenseTiled_Xcache64[4*4][TILE_WIDTH*TILE_WIDTH];
+groupshared float DenseTiled_Wcache64[4*4][TILE_WIDTH*TILE_WIDTH];
+[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
+void DenseTiled64x64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(O.flatWidth / 4, O.flatHeight / 4, 1);
+ TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
+
+ #define X_ DenseTiled_Xcache64
+ #define W_ DenseTiled_Wcache64
+
+ uint tx = groupThreadID.x;
+ uint ty = groupThreadID.y;
+ uint x = groupID.x*TILE_WIDTH + tx;
+ uint y = groupID.y*TILE_WIDTH + ty;
+
+ float b0 = B.Get(x*4+0);
+ float b1 = B.Get(x*4+1);
+ float b2 = B.Get(x*4+2);
+ float b3 = B.Get(x*4+3);
+
+ float4 v0, v1, v2, v3;
+ v0 = v1 = v2 = v3 = float4(b0, b1, b2, b3);
+
+ for (uint m = 0; m < X.GetFlatWidth()/(TILE_WIDTH*4); ++m)
+ {
+ for (uint yy = 0; yy < 4; ++yy)
+ for (uint xx = 0; xx < 4; ++xx)
+ {
+ X_[yy*4+xx][ty*TILE_WIDTH+tx] = X.Get(y*4+yy, (m*TILE_WIDTH + tx)*4+xx);
+ W_[yy*4+xx][ty*TILE_WIDTH+tx] = W.Get((m*TILE_WIDTH + ty)*4+yy, x*4+xx);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint i = 0; i < TILE_WIDTH; ++i)
+ {
+ [unroll]
+ for (uint q = 0; q < 4; ++q)
+ {
+ float x0 = X_[0*4+q][ty*TILE_WIDTH+i];
+ float x1 = X_[1*4+q][ty*TILE_WIDTH+i];
+ float x2 = X_[2*4+q][ty*TILE_WIDTH+i];
+ float x3 = X_[3*4+q][ty*TILE_WIDTH+i];
+
+ float w0 = W_[q*4+0][i*TILE_WIDTH+tx];
+ float w1 = W_[q*4+1][i*TILE_WIDTH+tx];
+ float w2 = W_[q*4+2][i*TILE_WIDTH+tx];
+ float w3 = W_[q*4+3][i*TILE_WIDTH+tx];
+
+ v0.x = fastfma(x0, w0, v0.x); //--
+ v1.x = fastfma(x1, w0, v1.x);
+ v2.x = fastfma(x2, w0, v2.x);
+ v3.x = fastfma(x3, w0, v3.x);
+ v0.y = fastfma(x0, w1, v0.y); //--
+ v1.y = fastfma(x1, w1, v1.y);
+ v2.y = fastfma(x2, w1, v2.y);
+ v3.y = fastfma(x3, w1, v3.y);
+ v0.z = fastfma(x0, w2, v0.z); //--
+ v1.z = fastfma(x1, w2, v1.z);
+ v2.z = fastfma(x2, w2, v2.z);
+ v3.z = fastfma(x3, w2, v3.z);
+ v0.w = fastfma(x0, w3, v0.w); //--
+ v1.w = fastfma(x1, w3, v1.w);
+ v2.w = fastfma(x2, w3, v2.w);
+ v3.w = fastfma(x3, w3, v3.w);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+
+ O.Set(y*4+0, x*4+0, v0.x);
+ O.Set(y*4+0, x*4+1, v0.y);
+ O.Set(y*4+0, x*4+2, v0.z);
+ O.Set(y*4+0, x*4+3, v0.w);
+
+ O.Set(y*4+1, x*4+0, v1.x);
+ O.Set(y*4+1, x*4+1, v1.y);
+ O.Set(y*4+1, x*4+2, v1.z);
+ O.Set(y*4+1, x*4+3, v1.w);
+
+ O.Set(y*4+2, x*4+0, v2.x);
+ O.Set(y*4+2, x*4+1, v2.y);
+ O.Set(y*4+2, x*4+2, v2.z);
+ O.Set(y*4+2, x*4+3, v2.w);
+
+ O.Set(y*4+3, x*4+0, v3.x);
+ O.Set(y*4+3, x*4+1, v3.y);
+ O.Set(y*4+3, x*4+2, v3.z);
+ O.Set(y*4+3, x*4+3, v3.w);
+
+ #undef X_
+ #undef W_
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Dense.compute.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Dense.compute.meta
new file mode 100755
index 00000000..33ad83ca
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Dense.compute.meta
@@ -0,0 +1,9 @@
+fileFormatVersion: 2
+guid: 6b08c0ac202ad41deb8881132b21894c
+timeCreated: 1507457322
+licenseType: Pro
+ComputeShaderImporter:
+ currentAPIMask: 196608
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/DenseFP16.compute b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/DenseFP16.compute
new file mode 100755
index 00000000..759bd7b0
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/DenseFP16.compute
@@ -0,0 +1,72 @@
+#pragma kernel DenseFP16Div2
+
+#include "Tensor.cginc"
+
+TENSOR_DECL(X)
+TENSOR_DECL(W)
+TENSOR_DECL(B)
+TENSOR_DECL(WBK)
+TENSOR_DECL_RW(O)
+
+float f16tof32_(uint src)
+{
+ // Based on Fabian Giesen's public domain half_to_float_fast3
+ const uint magic = 113 << 23;
+ const uint shiftedExp = 0x7c00 << 13; // exponent mask after shift
+
+ // Mask out sign bit
+ uint o = src & 0x7fff;
+ if (o)
+ {
+ // Move exponent + mantissa to correct bits
+ o <<= 13;
+ uint exponent = o & shiftedExp;
+ if (exponent == 0)
+ {
+ // Handle denormal
+ o = asuint(asfloat(o + magic) - asfloat(magic));
+ }
+ else if (exponent == shiftedExp) // Inf/NaN
+ o += (255 - 31) << 23;
+ else
+ o += (127 - 15) << 23;
+ }
+
+ // Copy sign bit
+ o |= (src & 0x8000) << 16;
+
+ return asfloat(o);
+}
+
+float2 Unpack(SharedTensor t, uint y, uint x)
+{
+ uint v = asuint(t.data[t.Index(y, x) >> 1]);
+ // TEMPORARY: f16tof32 is broken in GLSL/Metal compiler
+ // using custom conversion function for now
+ //return float2(f16tof32(v), f16tof32(v>>16));
+ return float2(f16tof32_(v), f16tof32_(v>>16));
+}
+
+// NOTE: usually this path is used for <16 batches
+NUMTHREADS((256,1,1), (128,1,1), (64,1,1))
+void DenseFP16Div2(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.flatWidth/2, O.flatHeight, 1);
+ TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
+
+ uint x = dispatchThreadID.x;
+ uint y = dispatchThreadID.y;
+
+ if (x*2 >= O.GetFlatWidth()) return;
+ if (y >= O.GetFlatHeight()) return;
+
+ float2 acc = Unpack(B, 0, x*2);
+ for (uint i = 0; i < X.width; ++i)
+ {
+ float2 w = Unpack(W, i, x*2);
+ acc += X.Get(y, i) * w;
+ }
+
+ O.Set(y, x*2+0, acc[0]);
+ O.Set(y, x*2+1, acc[1]);
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/DenseFP16.compute.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/DenseFP16.compute.meta
new file mode 100755
index 00000000..f0111a62
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/DenseFP16.compute.meta
@@ -0,0 +1,9 @@
+fileFormatVersion: 2
+guid: cff3cb66e54744fa4888ef91a11ec90c
+timeCreated: 1508334838
+licenseType: Pro
+ComputeShaderImporter:
+ currentAPIMask: 196608
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Experimental.compute b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Experimental.compute
new file mode 100755
index 00000000..db2e7ddc
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Experimental.compute
@@ -0,0 +1,4284 @@
+#if EXPERIMENTAL_KERNELS_ENABLED
+/*
+#pragma kernel Dense
+#pragma kernel DenseTiled
+#pragma kernel Dense10x16
+#pragma kernel DenseTiled32x32
+#pragma kernel DenseTiled64x64
+#pragma kernel Dense64
+#pragma kernel Relu
+#pragma kernel Relu256xV
+#pragma kernel Relu16x16
+#pragma kernel ReluChannelsFirst16x2x16
+#pragma kernel Relu_Cmod16_CNyx
+#pragma kernel Relu_Nyxc
+#pragma kernel Softmax
+#pragma kernel Softmax256x2
+#pragma kernel MaxPooling2D
+#pragma kernel MaxPooling2D16x4x4
+*/
+/*
+#pragma kernel Conv2D_Kernel3x3_32Channel
+#pragma kernel Conv2D_Kernel3x3_1Channel
+#pragma kernel Conv2D
+//#pragma kernel Conv2DTiled16x16_Kernel3x3
+#pragma kernel Conv2DTiled14x14_Kernel3x3
+#pragma kernel Conv2DTiled13x13_Kernel3x3
+//#pragma kernel Conv2DTiled12x12_Kernel3x3
+#pragma kernel Fill
+
+#pragma kernel Conv2D_Kernel3x3_Kmod16_Cmod4_KN
+#pragma kernel Conv2D_Kernel3x3_Kmod16_Cmod4_KNyx
+//#pragma kernel Conv2D_Kernel3x3_Cache_KCmod32_KNyx
+//#pragma kernel Conv2D_Kernel3x3_Cache_KCmod64_KNyx
+*/
+
+
+// @TODO: BIAS and WEIGHTS have changed format
+// BIAS (0,0,x,0) -> (0,0,0,x) --> (x)
+// WEIGHTS (y,0,x,0) -> (y,0,0,x) --> (y,x)
+// DENSE_OUT (y,0,x,0) -> (y,0,0,x) --> (y,x)
+
+
+//#pragma kernel Conv2D_Kmod16_Nmod8_KNY
+//#pragma kernel Conv2D_Kernel3x3_64
+
+#define BOUNDS_CHECKS 0
+
+RWStructuredBuffer Edata;
+
+struct Tensor
+{
+ uint batch, height, width, channels;
+ uint offset;
+ uint dataLength;
+
+ uint Index(uint b, uint h, uint w, uint ch)
+ {
+ uint index =
+ b * height * width * channels +
+ h * width * channels +
+ w * channels +
+ ch;
+ return index + offset;
+ }
+ void Set(uint b, uint h, uint w, uint ch, float v, RWStructuredBuffer data)
+ {
+ data[Index(b,h,w,ch)] = v;
+ }
+ void Set(int b, uint h, uint w, uint ch, float v, RWStructuredBuffer data, int dataLength)
+ {
+ uint index = Index(b,h,w,ch);
+ #if BOUNDS_CHECKS
+ if (index < 0 || index >= dataLength)
+ {
+ InterlockedAdd(Edata[1], 1);
+ return;
+ }
+ #endif
+
+ data[Index(b,h,w,ch)] = v;
+ }
+
+ float Get(uint b, uint h, uint w, uint ch, StructuredBuffer data)
+ {
+ return data[Index(b,h,w,ch)];
+ }
+ float Get(uint b, uint h, uint w, uint ch, StructuredBuffer data, int dataLength)
+ {
+ int index = Index(b,h,w,ch);
+ #if BOUNDS_CHECKS
+ if (index < 0 || index >= dataLength)
+ {
+ InterlockedAdd(Edata[0], 1);
+ return 0.0f;
+ }
+ #endif
+
+ return data[Index(b,h,w,ch)];
+ }
+};
+
+#define X ((Tensor)Xdecl)
+int4 Xdecl[2];
+StructuredBuffer Xdata;
+
+#define O ((Tensor)Odecl)
+int4 Odecl[2];
+RWStructuredBuffer Odata;
+
+#define W ((Tensor)Wdecl)
+int4 Wdecl[2];
+
+#define B ((Tensor)Bdecl)
+int4 Bdecl[2];
+
+#define K ((Tensor)Kdecl)
+int4 Kdecl[2];
+
+#define WBK ((Tensor)WBKdecl)
+int4 WBKdecl[2];
+StructuredBuffer WBKdata;
+
+uint _FilterSize;
+uint _Border;
+uint _Offset;
+
+[numthreads(1,1,1)]
+void Dense(uint3 groupID : SV_GroupID)
+{
+ uint b = groupID.y;
+ uint x = groupID.x;
+ float v = B.Get(0, 0, x, 0, WBKdata, WBK.dataLength);
+ for (uint i = 0; i < X.width; ++i)
+ v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength);
+
+ O.Set(b, 0, x, 0, v, Odata, O.dataLength);
+}
+
+[numthreads(10,16,1)]
+void Dense10x16(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint x = 10*groupID.x + groupThreadID.x;
+ uint b = 16*groupID.y + groupThreadID.y;
+ float v = B.Get(0, 0, x, 0, WBKdata, WBK.dataLength);
+
+ for (uint i = 0; i < X.width;)
+ {
+ // can unroll up to 16 because numthreads.y=16
+ v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength); ++i;
+ v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength); ++i;
+ v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength); ++i;
+ v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength); ++i;
+
+ v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength); ++i;
+ v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength); ++i;
+ v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength); ++i;
+ v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength); ++i;
+ }
+ O.Set(b, 0, x, 0, v, Odata);
+}
+
+
+#undef THREAD_COUNT
+#define THREAD_COUNT 64 // ATM support only 8x8
+
+#undef BLOCK_WIDTH
+#define BLOCK_WIDTH 8
+
+#undef LOAD_WIDTH
+#define LOAD_WIDTH THREAD_COUNT
+
+#undef LOAD_DEPTH
+#define LOAD_DEPTH BLOCK_WIDTH
+
+groupshared float Conv_KcacheR[LOAD_DEPTH][LOAD_WIDTH];
+groupshared float Conv_XcacheR[LOAD_DEPTH][LOAD_WIDTH];
+[numthreads(THREAD_COUNT, 1, 1)]
+void Conv2D_Kernel3x3_64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ Conv_XcacheR
+ #define K_ Conv_KcacheR
+
+ uint id = groupThreadID.x;
+ uint bx = groupID.x;
+ uint by = groupID.y;
+
+ uint bbx = id % BLOCK_WIDTH;
+ uint bby = id / BLOCK_WIDTH;
+
+ uint width = O.width;
+ uint height = O.height;
+
+ // ASSERT(LOAD_WIDTH == THREAD_COUNT)
+ uint loadNYX = by*LOAD_WIDTH + id; // only works for 8x8
+ uint loadX = loadNYX % width;
+ uint loadNY = loadNYX / width;
+ uint loadY = loadNY % height;
+ uint loadN = loadNY / height;
+
+ float v[BLOCK_WIDTH][BLOCK_WIDTH];
+ for (uint yy = 0; yy < BLOCK_WIDTH; ++yy)
+ for (uint xx = 0; xx < BLOCK_WIDTH; ++xx)
+ {
+ float bias = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xx, 0, WBKdata, WBK.dataLength);
+ v[yy][xx] = bias;
+ }
+
+ for (uint dy = 0; dy < 3; ++dy)
+ {
+ bool mask = true;
+
+ if (loadY+dy < _Offset) mask = false;
+ if (loadY+dy-_Offset >= X.height) mask = false;
+
+ for (uint dx = 0; dx < 3; ++dx)
+ {
+ if (loadX+dx < _Offset) mask = false;
+ if (loadX+dx-_Offset >= X.width) mask = false;
+
+ for (uint m = 0; m < X.channels/LOAD_DEPTH; ++m)
+ {
+ for (uint q = 0; q < LOAD_DEPTH; ++q)
+ {
+ if (mask)
+ X_[q][id] = X.Get(loadN, loadY+dy-_Offset, loadX+dx-_Offset, m*LOAD_DEPTH + q, Xdata);
+ else
+ X_[q][id] = 0;
+ K_[q][id] = K.Get(dy, dx, m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id, WBKdata, WBK.dataLength);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
+ [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
+ [unroll] for (uint i = 0; i < LOAD_DEPTH; ++i)
+ {
+ v[yyy][xxx] += X_[i][bby*BLOCK_WIDTH + yyy] * K_[i][bbx*BLOCK_WIDTH + xxx];
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
+ for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
+ {
+ //O.Set(by*LOAD_WIDTH + bby*BLOCK_WIDTH + yyy, y, x, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xxx, v[yyy][xxx], Odata);
+ uint saveNYX = by*LOAD_WIDTH + bby*BLOCK_WIDTH + yyy;
+ //uint saveNYX = by*LOAD_WIDTH + ((id>>3)<<3) + yyy;
+ uint saveX = saveNYX % width;
+ uint saveNY = saveNYX / width;
+ uint saveY = saveNY % height;
+ uint saveN = saveNY / height;
+
+ uint saveK = bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xxx;
+ O.Set(saveN, saveY, saveX, saveK, v[yyy][xxx], Odata);
+ }
+
+ #undef X_
+ #undef K_
+}
+
+
+#undef THREAD_COUNT
+#define THREAD_COUNT 64 // ATM support only 8x8
+
+#undef BLOCK_WIDTH
+#define BLOCK_WIDTH 8
+
+#undef LOAD_WIDTH
+#define LOAD_WIDTH THREAD_COUNT
+
+#undef LOAD_DEPTH
+#define LOAD_DEPTH BLOCK_WIDTH
+
+#if 1
+
+groupshared float DenseTiled_XcacheR[32][LOAD_WIDTH];
+groupshared float DenseTiled_WcacheR[LOAD_DEPTH][LOAD_WIDTH];
+
+[numthreads(THREAD_COUNT, 1, 1)]
+void Dense64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ DenseTiled_XcacheR
+ #define W_ DenseTiled_WcacheR
+
+ uint id = groupThreadID.x;
+ uint bx = groupID.x;
+ uint by = groupID.y;
+
+ uint bbx = id % BLOCK_WIDTH;
+ uint bby = id / BLOCK_WIDTH;
+
+ float v[BLOCK_WIDTH][BLOCK_WIDTH];
+ for (uint yy = 0; yy < BLOCK_WIDTH; ++yy)
+ for (uint xx = 0; xx < BLOCK_WIDTH; ++xx)
+ {
+ float bias = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xx, 0, WBKdata, WBK.dataLength);
+ v[yy][xx] = bias;
+ }
+
+ for (uint m = 0; m < X.width/LOAD_DEPTH; ++m)
+ {
+ for (uint q = 0; q < LOAD_DEPTH; ++q)
+ {
+ X_[q][id] = X.Get(by*LOAD_WIDTH + id, 0, m*LOAD_DEPTH + q, 0, Xdata);
+ W_[q][id] = W.Get(0, m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id, 0, WBKdata, WBK.dataLength);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
+ {
+ X_[yyy][id] = X.Get(by*LOAD_WIDTH + id, 0, m*LOAD_DEPTH + yyy, 0, Xdata);
+ [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
+ [unroll] for (uint i = 0; i < LOAD_DEPTH; ++i)
+ {
+ v[yyy][xxx] += X_[i][bby*BLOCK_WIDTH + yyy] * W_[i][bbx*BLOCK_WIDTH + xxx];
+ }
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
+ for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
+ O.Set(by*LOAD_WIDTH + bby*BLOCK_WIDTH + yyy, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xxx, 0, v[yyy][xxx], Odata);
+
+ #undef X_
+ #undef W_
+}
+
+#elif 1
+groupshared float DenseTiled_XcacheR[LOAD_DEPTH][LOAD_WIDTH];
+groupshared float DenseTiled_WcacheR[LOAD_DEPTH][LOAD_WIDTH];
+
+[numthreads(THREAD_COUNT, 1, 1)]
+void Dense64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ DenseTiled_XcacheR
+ #define W_ DenseTiled_WcacheR
+
+ uint id = groupThreadID.x;
+ uint bx = groupID.x;
+ uint by = groupID.y;
+
+ uint bbx = id % BLOCK_WIDTH;
+ uint bby = id / BLOCK_WIDTH;
+
+ float v[BLOCK_WIDTH][BLOCK_WIDTH];
+ for (uint yy = 0; yy < BLOCK_WIDTH; ++yy)
+ for (uint xx = 0; xx < BLOCK_WIDTH; ++xx)
+ {
+ float bias = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xx, 0, WBKdata, WBK.dataLength);
+ v[yy][xx] = bias;
+ }
+
+ for (uint m = 0; m < X.width/LOAD_DEPTH; ++m)
+ {
+ for (uint q = 0; q < LOAD_DEPTH; ++q)
+ {
+ X_[q][id] = X.Get(by*LOAD_WIDTH + id, 0, m*LOAD_DEPTH + q, 0, Xdata);
+ W_[q][id] = W.Get(0, m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id, 0, WBKdata, WBK.dataLength);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
+ [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
+ [unroll] for (uint i = 0; i < LOAD_DEPTH; ++i)
+ {
+ //v[yyy][xxx] += X_[i][bby*BLOCK_WIDTH + yyy] * W_[i][bbx*BLOCK_WIDTH + xxx];
+ v[yyy][xxx] = mad(X_[i][bby*BLOCK_WIDTH + yyy], W_[i][bbx*BLOCK_WIDTH + xxx], v[yyy][xxx]);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
+ for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
+ O.Set(by*LOAD_WIDTH + bby*BLOCK_WIDTH + yyy, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xxx, 0, v[yyy][xxx], Odata);
+
+ #undef X_
+ #undef W_
+}
+
+#elif 1
+
+// unroll array to help some "naive" compilers to map to regs
+// could be easier to lay out zigzagging patterns
+groupshared float DenseTiled_XcacheR[LOAD_DEPTH][LOAD_WIDTH];
+groupshared float DenseTiled_WcacheR[LOAD_DEPTH][LOAD_WIDTH];
+
+[numthreads(THREAD_COUNT, 1, 1)]
+void Dense64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ DenseTiled_XcacheR
+ #define W_ DenseTiled_WcacheR
+
+ uint id = groupThreadID.x;
+ uint bx = groupID.x;
+ uint by = groupID.y;
+
+ uint bbx = id % BLOCK_WIDTH;
+ uint bby = id / BLOCK_WIDTH;
+
+ //float v[BLOCK_WIDTH][BLOCK_WIDTH];
+ float
+ v00, v01, v02, v03, v04, v05, v06, v07,
+ v10, v11, v12, v13, v14, v15, v16, v17,
+ v20, v21, v22, v23, v24, v25, v26, v27,
+ v30, v31, v32, v33, v34, v35, v36, v37,
+ v40, v41, v42, v43, v44, v45, v46, v47,
+ v50, v51, v52, v53, v54, v55, v56, v57,
+ v60, v61, v62, v63, v64, v65, v66, v67,
+ v70, v71, v72, v73, v74, v75, v76, v77;
+
+ float b0 = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + 0, 0, WBKdata, WBK.dataLength);
+ float b1 = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + 1, 0, WBKdata, WBK.dataLength);
+ float b2 = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + 2, 0, WBKdata, WBK.dataLength);
+ float b3 = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + 3, 0, WBKdata, WBK.dataLength);
+ float b4 = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + 4, 0, WBKdata, WBK.dataLength);
+ float b5 = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + 5, 0, WBKdata, WBK.dataLength);
+ float b6 = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + 6, 0, WBKdata, WBK.dataLength);
+ float b7 = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + 7, 0, WBKdata, WBK.dataLength);
+
+ #define L_(y, x) v##y##x = b##x
+ L_(0,0); L_(0,1); L_(0,2); L_(0,3); L_(0,4); L_(0,5); L_(0,6); L_(0,7);
+ L_(1,0); L_(1,1); L_(1,2); L_(1,3); L_(1,4); L_(1,5); L_(1,6); L_(1,7);
+ L_(2,0); L_(2,1); L_(2,2); L_(2,3); L_(2,4); L_(2,5); L_(2,6); L_(2,7);
+ L_(3,0); L_(3,1); L_(3,2); L_(3,3); L_(3,4); L_(3,5); L_(3,6); L_(3,7);
+ L_(4,0); L_(4,1); L_(4,2); L_(4,3); L_(4,4); L_(4,5); L_(4,6); L_(4,7);
+ L_(5,0); L_(5,1); L_(5,2); L_(5,3); L_(5,4); L_(5,5); L_(5,6); L_(5,7);
+ L_(6,0); L_(6,1); L_(6,2); L_(6,3); L_(6,4); L_(6,5); L_(6,6); L_(6,7);
+ L_(7,0); L_(7,1); L_(7,2); L_(7,3); L_(7,4); L_(7,5); L_(7,6); L_(7,7);
+ #undef L_
+
+ for (uint m = 0; m < X.width/LOAD_DEPTH; ++m)
+ {
+ for (uint q = 0; q < LOAD_DEPTH; ++q)
+ {
+ X_[q][id] = X.Get(by*LOAD_WIDTH + id, 0, m*LOAD_DEPTH + q, 0, Xdata);
+ W_[q][id] = W.Get(0, m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id, 0, WBKdata, WBK.dataLength);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ [unroll] for (uint i = 0; i < LOAD_DEPTH; ++i)
+ {
+ //v[yyy][xxx] += X_[i][bby*BLOCK_WIDTH + yyy] * W_[i][bbx*BLOCK_WIDTH + xxx];
+ #define XW_(y, x) v##y##x += X_[i][bby*BLOCK_WIDTH + ##y] * W_[i][bbx*BLOCK_WIDTH + ##x]
+ XW_(0,0); XW_(0,1); XW_(0,2); XW_(0,3); XW_(0,4); XW_(0,5); XW_(0,6); XW_(0,7);
+ XW_(1,0); XW_(1,1); XW_(1,2); XW_(1,3); XW_(1,4); XW_(1,5); XW_(1,6); XW_(1,7);
+ XW_(2,0); XW_(2,1); XW_(2,2); XW_(2,3); XW_(2,4); XW_(2,5); XW_(2,6); XW_(2,7);
+ XW_(3,0); XW_(3,1); XW_(3,2); XW_(3,3); XW_(3,4); XW_(3,5); XW_(3,6); XW_(3,7);
+ XW_(4,0); XW_(4,1); XW_(4,2); XW_(4,3); XW_(4,4); XW_(4,5); XW_(4,6); XW_(4,7);
+ XW_(5,0); XW_(5,1); XW_(5,2); XW_(5,3); XW_(5,4); XW_(5,5); XW_(5,6); XW_(5,7);
+ XW_(6,0); XW_(6,1); XW_(6,2); XW_(6,3); XW_(6,4); XW_(6,5); XW_(6,6); XW_(6,7);
+ XW_(7,0); XW_(7,1); XW_(7,2); XW_(7,3); XW_(7,4); XW_(7,5); XW_(7,6); XW_(7,7);
+ #undef XW_
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ #define S_(a, b) O.Set(by*LOAD_WIDTH + bby*BLOCK_WIDTH + ##a, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + ##b, 0, v##a##b, Odata)
+ S_(0,0); S_(0,1); S_(0,2); S_(0,3); S_(0,4); S_(0,5); S_(0,6); S_(0,7);
+ S_(1,0); S_(1,1); S_(1,2); S_(1,3); S_(1,4); S_(1,5); S_(1,6); S_(1,7);
+ S_(2,0); S_(2,1); S_(2,2); S_(2,3); S_(2,4); S_(2,5); S_(2,6); S_(2,7);
+ S_(3,0); S_(3,1); S_(3,2); S_(3,3); S_(3,4); S_(3,5); S_(3,6); S_(3,7);
+ S_(4,0); S_(4,1); S_(4,2); S_(4,3); S_(4,4); S_(4,5); S_(4,6); S_(4,7);
+ S_(5,0); S_(5,1); S_(5,2); S_(5,3); S_(5,4); S_(5,5); S_(5,6); S_(5,7);
+ S_(6,0); S_(6,1); S_(6,2); S_(6,3); S_(6,4); S_(6,5); S_(6,6); S_(6,7);
+ S_(7,0); S_(7,1); S_(7,2); S_(7,3); S_(7,4); S_(7,5); S_(7,6); S_(7,7);
+ #undef S_
+
+ #undef X_
+ #undef W_
+}
+
+#elif 1
+
+groupshared float DenseTiled_XcacheR[2][LOAD_DEPTH][LOAD_WIDTH];
+groupshared float DenseTiled_WcacheR[2][LOAD_DEPTH][LOAD_WIDTH];
+
+[numthreads(THREAD_COUNT, 1, 1)]
+void Dense64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ DenseTiled_XcacheR
+ #define W_ DenseTiled_WcacheR
+
+ uint id = groupThreadID.x;
+ uint bx = groupID.x;
+ uint by = groupID.y;
+
+ uint bbx = id % BLOCK_WIDTH;
+ uint bby = id / BLOCK_WIDTH;
+
+ float v[BLOCK_WIDTH][BLOCK_WIDTH];
+ for (uint yy = 0; yy < BLOCK_WIDTH; ++yy)
+ [unroll] for (uint xx = 0; xx < BLOCK_WIDTH; ++xx)
+ {
+ float bias = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xx, 0, WBKdata, WBK.dataLength);
+ v[yy][xx] = bias;
+ }
+
+ uint m = 0;
+ for (uint q = 0; q < LOAD_DEPTH; ++q)
+ {
+ X_[0][q][id] = X.Get(by*LOAD_WIDTH + id, 0, m*LOAD_DEPTH + q, 0, Xdata);
+ W_[0][q][id] = W.Get(0, m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id, 0, WBKdata, WBK.dataLength);
+ }
+ GroupMemoryBarrierWithGroupSync();
+
+ ++m;
+
+ for (; m < X.width/LOAD_DEPTH; ++m)
+ {
+ for (uint q = 0; q < LOAD_DEPTH; ++q)
+ {
+ X_[1][q][id] = X.Get(by*LOAD_WIDTH + id, 0, m*LOAD_DEPTH + q, 0, Xdata);
+ W_[1][q][id] = W.Get(0, m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id, 0, WBKdata, WBK.dataLength);
+ }
+
+ for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
+ [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
+ [unroll]
+ for (uint i = 0; i < LOAD_DEPTH; ++i)
+ {
+ v[yyy][xxx] += X_[0][i][bby*BLOCK_WIDTH + yyy] * W_[0][i][bbx*BLOCK_WIDTH + xxx];
+ }
+
+ ++m;
+ GroupMemoryBarrierWithGroupSync();
+
+ if (m < X.width/LOAD_DEPTH)
+ {
+ for (uint q = 0; q < LOAD_DEPTH; ++q)
+ {
+ X_[0][q][id] = X.Get(by*LOAD_WIDTH + id, 0, m*LOAD_DEPTH + q, 0, Xdata);
+ W_[0][q][id] = W.Get(0, m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id, 0, WBKdata, WBK.dataLength);
+ }
+ }
+
+ for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
+ [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
+ [unroll]
+ for (uint i = 0; i < LOAD_DEPTH; ++i)
+ {
+ v[yyy][xxx] += X_[1][i][bby*BLOCK_WIDTH + yyy] * W_[1][i][bbx*BLOCK_WIDTH + xxx];
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
+ [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
+ O.Set(by*LOAD_WIDTH + bby*BLOCK_WIDTH + yyy, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xxx, 0, v[yyy][xxx], Odata);
+
+ #undef X_
+ #undef W_
+}
+
+#else
+
+groupshared float DenseTiled_XcacheR[LOAD_DEPTH][LOAD_WIDTH];
+groupshared float DenseTiled_WcacheR[LOAD_DEPTH][LOAD_WIDTH];
+
+[numthreads(THREAD_COUNT, 1, 1)]
+void Dense64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ DenseTiled_XcacheR
+ #define W_ DenseTiled_WcacheR
+
+ uint id = groupThreadID.x;
+ uint bx = groupID.x;
+ uint by = groupID.y;
+
+ uint n = by * LOAD_WIDTH + id;
+ uint x = bx * LOAD_WIDTH + id;
+
+ float v[LOAD_WIDTH];
+ float bias = B.Get(0, 0, x, 0, WBKdata, WBK.dataLength);
+ [unroll] for (uint xx = 0; xx < LOAD_WIDTH; ++xx)
+ v[xx] = bias;
+
+ for (uint m = 0; m < X.width/LOAD_DEPTH; ++m)
+ {
+ float ww[LOAD_DEPTH];
+ for (uint q = 0; q < LOAD_DEPTH; ++q)
+ {
+ X_[q][id] = X.Get(n, 0, m*LOAD_DEPTH + q, 0, Xdata);
+ //W_[q][id] = W.Get(0, m*LOAD_DEPTH + q, x, 0, WBKdata, WBK.dataLength);
+ ww[q] = W.Get(0, m*LOAD_DEPTH + q, x, 0, WBKdata, WBK.dataLength);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint w = 0; w < LOAD_WIDTH; ++w)
+ {
+ [unroll]
+ for (uint i = 0; i < LOAD_DEPTH; ++i)
+ {
+ //v[w] += X_[i][w] * W_[i][id];
+ v[w] += X_[i][w] * ww[i];
+ }
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ [unroll] for ( xx = 0; xx < LOAD_WIDTH; ++xx)
+ O.Set(by * LOAD_WIDTH + xx, 0, x, 0, v[xx], Odata);
+
+ #undef X_
+ #undef W_
+}
+#endif
+
+#if 1
+#undef TILE_WIDTH
+#define TILE_WIDTH 16
+groupshared float DenseTiled_Xcache64[16][TILE_WIDTH*TILE_WIDTH];
+groupshared float DenseTiled_Wcache64[16][TILE_WIDTH*TILE_WIDTH];
+[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
+void DenseTiled64x64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ DenseTiled_Xcache64
+ #define W_ DenseTiled_Wcache64
+
+ uint tx = groupThreadID.x;
+ uint ty = groupThreadID.y;
+ uint x = groupID.x*TILE_WIDTH + tx;
+ uint n = groupID.y*TILE_WIDTH + ty;
+
+ float b0 = B.Get(0, 0, x*4+0, 0, WBKdata, WBK.dataLength);
+ float b1 = B.Get(0, 0, x*4+1, 0, WBKdata, WBK.dataLength);
+ float b2 = B.Get(0, 0, x*4+2, 0, WBKdata, WBK.dataLength);
+ float b3 = B.Get(0, 0, x*4+3, 0, WBKdata, WBK.dataLength);
+
+ float4 v0, v1, v2, v3;
+ v0 = v1 = v2 = v3 = float4(b0, b1, b2, b3);
+
+ for (uint m = 0; m < X.width/(TILE_WIDTH*4); ++m)
+ {
+ for (uint yy = 0; yy < 4; ++yy)
+ for (uint xx = 0; xx < 4; ++xx)
+ {
+ X_[yy*4+xx][ty*TILE_WIDTH+tx] = X.Get(n*4+yy, 0, (m*TILE_WIDTH + tx)*4+xx, 0, Xdata);
+ W_[yy*4+xx][ty*TILE_WIDTH+tx] = W.Get(0, (m*TILE_WIDTH + ty)*4+yy, x*4+xx, 0, WBKdata, WBK.dataLength);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ //[unroll]
+ for (uint i = 0; i < TILE_WIDTH; ++i)
+ {
+ [unroll]
+ for (uint q = 0; q < 4; ++q)
+ {
+ float x0 = X_[0*4+q][ty*TILE_WIDTH+i];
+ float x1 = X_[1*4+q][ty*TILE_WIDTH+i];
+ float x2 = X_[2*4+q][ty*TILE_WIDTH+i];
+ float x3 = X_[3*4+q][ty*TILE_WIDTH+i];
+
+ float w0 = W_[q*4+0][i*TILE_WIDTH+tx];
+ float w1 = W_[q*4+1][i*TILE_WIDTH+tx];
+ float w2 = W_[q*4+2][i*TILE_WIDTH+tx];
+ float w3 = W_[q*4+3][i*TILE_WIDTH+tx];
+
+ v0.x = mad(x0, w0, v0.x); //--
+ v1.x = mad(x1, w0, v1.x);
+ v2.x = mad(x2, w0, v2.x);
+ v3.x = mad(x3, w0, v3.x);
+ v0.y = mad(x0, w1, v0.y); //--
+ v1.y = mad(x1, w1, v1.y);
+ v2.y = mad(x2, w1, v2.y);
+ v3.y = mad(x3, w1, v3.y);
+ v0.z = mad(x0, w2, v0.z); //--
+ v1.z = mad(x1, w2, v1.z);
+ v2.z = mad(x2, w2, v2.z);
+ v3.z = mad(x3, w2, v3.z);
+ v0.w = mad(x0, w3, v0.w); //--
+ v1.w = mad(x1, w3, v1.w);
+ v2.w = mad(x2, w3, v2.w);
+ v3.w = mad(x3, w3, v3.w);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+
+ O.Set(n*4+0, 0, x*4+0, 0, v0.x, Odata);
+ O.Set(n*4+0, 0, x*4+1, 0, v0.y, Odata);
+ O.Set(n*4+0, 0, x*4+2, 0, v0.z, Odata);
+ O.Set(n*4+0, 0, x*4+3, 0, v0.w, Odata);
+
+ O.Set(n*4+1, 0, x*4+0, 0, v1.x, Odata);
+ O.Set(n*4+1, 0, x*4+1, 0, v1.y, Odata);
+ O.Set(n*4+1, 0, x*4+2, 0, v1.z, Odata);
+ O.Set(n*4+1, 0, x*4+3, 0, v1.w, Odata);
+
+ O.Set(n*4+2, 0, x*4+0, 0, v2.x, Odata);
+ O.Set(n*4+2, 0, x*4+1, 0, v2.y, Odata);
+ O.Set(n*4+2, 0, x*4+2, 0, v2.z, Odata);
+ O.Set(n*4+2, 0, x*4+3, 0, v2.w, Odata);
+
+ O.Set(n*4+3, 0, x*4+0, 0, v3.x, Odata);
+ O.Set(n*4+3, 0, x*4+1, 0, v3.y, Odata);
+ O.Set(n*4+3, 0, x*4+2, 0, v3.z, Odata);
+ O.Set(n*4+3, 0, x*4+3, 0, v3.w, Odata);
+
+ #undef X_
+ #undef W_
+}
+
+#else
+
+#define TILE_WIDTH 16
+#define RTILE 4
+groupshared float DenseTiled_Xcache64[RTILE*RTILE][TILE_WIDTH*TILE_WIDTH];
+groupshared float DenseTiled_Wcache64[RTILE*RTILE][TILE_WIDTH*TILE_WIDTH];
+[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
+void DenseTiled64x64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ DenseTiled_Xcache64
+ #define W_ DenseTiled_Wcache64
+
+ uint tx = groupThreadID.x;
+ uint ty = groupThreadID.y;
+ uint x = groupID.x*TILE_WIDTH + tx;
+ uint n = groupID.y*TILE_WIDTH + ty;
+
+ float v[RTILE*RTILE];
+ [unroll] for (uint xxxx = 0; xxxx < RTILE; ++xxxx)
+ {
+ float b = B.Get(0, 0, x*RTILE+xxxx, 0, WBKdata, WBK.dataLength);
+ [unroll] for (uint yyyy = 0; yyyy < RTILE; ++yyyy)
+ v[yyyy*RTILE+xxxx] = b;
+ }
+
+ for (uint m = 0; m < X.width/(TILE_WIDTH*RTILE); ++m)
+ {
+ for (uint yy = 0; yy < RTILE; ++yy)
+ [unroll] for (uint xx = 0; xx < RTILE; ++xx)
+ {
+ X_[yy*RTILE+xx][ty*TILE_WIDTH+tx] = X.Get(n*RTILE+yy, 0, (m*TILE_WIDTH + tx)*RTILE+xx, 0, Xdata);
+ W_[yy*RTILE+xx][ty*TILE_WIDTH+tx] = W.Get(0, (m*TILE_WIDTH + ty)*RTILE+yy, x*RTILE+xx, 0, WBKdata, WBK.dataLength);
+ }
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint ii = 0; ii < TILE_WIDTH; ++ii)
+ {
+ [unroll] for (uint yy = 0; yy < RTILE; ++yy)
+ [unroll] for (uint xx = 0; xx < RTILE; ++xx)
+ [unroll] for (uint i = 0; i < RTILE; ++i)
+ {
+ float x = X_[yy*RTILE+i][ty*TILE_WIDTH+ii];
+ float w = W_[i*RTILE+xx][ii*TILE_WIDTH+tx];
+ v[yy*RTILE+xx] = mad(x, w, v[yy*RTILE+xx]);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+
+ [unroll] for (uint yy = 0; yy < RTILE; ++yy)
+ [unroll] for (uint xx = 0; xx < RTILE; ++xx)
+ O.Set(n*RTILE+yy, 0, x*RTILE+xx, 0, v[yy*RTILE+xx], Odata);
+
+ #undef X_
+ #undef W_
+}
+
+#endif
+
+#undef TILE_WIDTH
+#define TILE_WIDTH 16 // 32 crashes on MacBookPro/AMD
+groupshared float DenseTiled_Xcache32[4][TILE_WIDTH][TILE_WIDTH];
+groupshared float DenseTiled_Wcache32[4][TILE_WIDTH][TILE_WIDTH];
+[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
+void DenseTiled32x32(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ DenseTiled_Xcache32
+ #define W_ DenseTiled_Wcache32
+
+ uint tx = groupThreadID.x;
+ uint ty = groupThreadID.y;
+ uint x = groupID.x*TILE_WIDTH + tx;
+ uint n = groupID.y*TILE_WIDTH + ty;
+
+ float b0 = B.Get(0, 0, x*2+0, 0, WBKdata, WBK.dataLength);
+ float b1 = B.Get(0, 0, x*2+1, 0, WBKdata, WBK.dataLength);
+ float4 v = float4(b0, b1,
+ b0, b1);
+
+ for (uint m = 0; m < X.width/(TILE_WIDTH*2);)
+ {
+ // @TODO: read in float2s
+ float x0 = X.Get(n*2+0, 0, m*TILE_WIDTH*2 + tx*2+0, 0, Xdata);
+ float x1 = X.Get(n*2+0, 0, m*TILE_WIDTH*2 + tx*2+1, 0, Xdata);
+ float x2 = X.Get(n*2+1, 0, m*TILE_WIDTH*2 + tx*2+0, 0, Xdata);
+ float x3 = X.Get(n*2+1, 0, m*TILE_WIDTH*2 + tx*2+1, 0, Xdata);
+
+ float w0 = W.Get(0, m*TILE_WIDTH*2 + ty*2+0, x*2+0, 0, WBKdata, WBK.dataLength);
+ float w1 = W.Get(0, m*TILE_WIDTH*2 + ty*2+0, x*2+1, 0, WBKdata, WBK.dataLength);
+ float w2 = W.Get(0, m*TILE_WIDTH*2 + ty*2+1, x*2+0, 0, WBKdata, WBK.dataLength);
+ float w3 = W.Get(0, m*TILE_WIDTH*2 + ty*2+1, x*2+1, 0, WBKdata, WBK.dataLength);
+
+ ++m;
+
+ X_[0][ty][tx] = x0;
+ X_[1][ty][tx] = x1;
+ X_[2][ty][tx] = x2;
+ X_[3][ty][tx] = x3;
+
+ W_[0][ty][tx] = w0;
+ W_[1][ty][tx] = w1;
+ W_[2][ty][tx] = w2;
+ W_[3][ty][tx] = w3;
+
+ GroupMemoryBarrierWithGroupSync();
+
+ [unroll]
+ for (uint i = 0; i < TILE_WIDTH; ++i)
+ {
+ float4 x = //X_[ty][i];
+ float4( X_[0][ty][i],
+ X_[1][ty][i],
+ X_[2][ty][i],
+ X_[3][ty][i]);
+ float4 w = //W_[i][tx];
+ float4( W_[0][i][tx],
+ W_[1][i][tx],
+ W_[2][i][tx],
+ W_[3][i][tx]);
+
+ v.x = mad(w.x, x.x, v.x);
+ v.y = mad(w.y, x.x, v.y);
+ v.z = mad(w.x, x.z, v.z);
+ v.w = mad(w.y, x.z, v.w);
+
+ v.x = mad(w.z, x.y, v.x);
+ v.y = mad(w.w, x.y, v.y);
+ v.z = mad(w.z, x.w, v.z);
+ v.w = mad(w.w, x.w, v.w);
+
+ //v.x += k.x*x.x + k.z*x.y;
+ //v.y += k.y*x.x + k.w*x.y;
+ //v.z += k.x*x.z + k.z*x.w;
+ //v.w += k.y*x.z + k.w*x.w;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ O.Set(n*2+0, 0, x*2+0, 0, v.x, Odata);
+ O.Set(n*2+0, 0, x*2+1, 0, v.y, Odata);
+ O.Set(n*2+1, 0, x*2+0, 0, v.z, Odata);
+ O.Set(n*2+1, 0, x*2+1, 0, v.w, Odata);
+
+ #undef X_
+ #undef W_
+}
+
+// sligtly faster on AMD (56ms vs 62ms)
+#undef TILE_WIDTH
+#define TILE_WIDTH 16
+//#define CACHE_ONLY_X
+//#define TRANSPOSE_W
+//#define TRANSPOSE_X
+groupshared float DenseTiled_XcacheF[TILE_WIDTH][TILE_WIDTH];
+#if !defined(CACHE_ONLY_X)
+groupshared float DenseTiled_WcacheF[TILE_WIDTH][TILE_WIDTH];
+#endif
+[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
+void DenseTiled16x16_amd(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ DenseTiled_XcacheF
+ #define W_ DenseTiled_WcacheF
+
+ uint tx = groupThreadID.x;
+ uint ty = groupThreadID.y;
+ uint x = groupID.x*TILE_WIDTH + tx;
+ uint b = groupID.y*TILE_WIDTH + ty;
+
+ float v = B.Get(0, 0, x, 0, WBKdata, WBK.dataLength);
+
+ for (uint m = 0; m < X.width/TILE_WIDTH; ++m)
+ {
+ #if defined(TRANSPOSE_X)
+ X_[tx][ty] = X.Get(b, 0, m*TILE_WIDTH + tx, 0, Xdata);
+ #else
+ X_[ty][tx] = X.Get(b, 0, m*TILE_WIDTH + tx, 0, Xdata);
+ #endif
+
+ #if defined(CACHE_ONLY_X)
+ float ww = WBKdata[wi];
+ #else
+ #if defined(TRANSPOSE_W)
+ W_[tx][ty] = W.Get(0, m*TILE_WIDTH + ty, x, 0, WBKdata, WBK.dataLength);
+ #else
+ W_[ty][tx] = W.Get(0, m*TILE_WIDTH + ty, x, 0, WBKdata, WBK.dataLength);
+ #endif
+ #endif
+ GroupMemoryBarrierWithGroupSync();
+
+ //[unroll(groupthreads)]
+ [unroll]
+ for (uint i = 0; i < TILE_WIDTH; ++i)
+ {
+ #if defined(TRANSPOSE_X)
+ float x = X_[i][ty];
+ #else
+ float x = X_[ty][i];
+ #endif
+
+ #if defined(CACHE_ONLY_X)
+ //float w = ww;
+ //if (i != TILE_WIDTH-1) { wi += W.width; ww = WBKdata[wi]; }
+ float w = W.Get(0, m*TILE_WIDTH + i, x, 0, WBKdata, WBK.dataLength);
+ #else
+ #if defined(TRANSPOSE_W)
+ float w = W_[tx][i];
+ #else
+ float w = W_[i][tx];
+ #endif
+ #endif
+
+ v += x * w;
+ }
+ }
+
+ O.Set(b, 0, x, 0, v, Odata);
+
+ #undef X_
+ #undef W_
+}
+
+#undef TILE_WIDTH
+#define TILE_WIDTH 16
+groupshared float DenseTiled_Xcache[TILE_WIDTH][TILE_WIDTH];
+groupshared float DenseTiled_Wcache[TILE_WIDTH][TILE_WIDTH];
+[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
+void DenseTiled(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ DenseTiled_Xcache
+ #define W_ DenseTiled_Wcache
+
+ uint tx = groupThreadID.x;
+ uint ty = groupThreadID.y;
+ uint x = groupID.x*TILE_WIDTH + tx;
+ uint b = groupID.y*TILE_WIDTH + ty;
+
+ bool mask = (x < O.width && b < O.batch);
+
+ float v = B.Get(0, 0, x, 0, WBKdata, WBK.dataLength);
+
+ for (uint m = 0; m < X.width/TILE_WIDTH; ++m)
+ {
+ if (mask)
+ {
+ X_[ty][tx] = X.Get(b, 0, m*TILE_WIDTH + tx, 0, Xdata);
+ W_[ty][tx] = W.Get(0, m*TILE_WIDTH + ty, x, 0, WBKdata, WBK.dataLength);
+ }
+ else
+ {
+ X_[ty][tx] = 0;
+ W_[ty][tx] = 0;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ [unroll]
+ for (uint i = 0; i < TILE_WIDTH; ++i)
+ {
+ v += X_[ty][i] * W_[i][tx];
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ if (mask)
+ O.Set(b, 0, x, 0, v, Odata);
+
+ #undef X_
+ #undef W_
+}
+
+
+groupshared float DenseTiled_XcacheP[TILE_WIDTH][TILE_WIDTH];
+groupshared float DenseTiled_WcacheP[TILE_WIDTH][TILE_WIDTH];
+// Prefetch - seems to be the same performance as DenseTiled16x16 without prefetch, has higher register pressure
+[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
+void DenseTiledPrefetch16x16(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ DenseTiled_XcacheP
+ #define W_ DenseTiled_WcacheP
+
+ uint tx = groupThreadID.x;
+ uint ty = groupThreadID.y;
+ uint x = groupID.x*TILE_WIDTH + tx;
+ uint b = groupID.y*TILE_WIDTH + ty;
+
+ float v = B.Get(0, 0, x, 0, WBKdata, WBK.dataLength);
+
+ float Xregs[TILE_WIDTH][TILE_WIDTH];
+ float Wregs[TILE_WIDTH][TILE_WIDTH];
+ for (uint m = 0; m < X.width/TILE_WIDTH; ++m)
+ {
+ Xregs[ty][tx] = X.Get(b, 0, m*TILE_WIDTH + tx, 0, Xdata);
+ Wregs[ty][tx] = W.Get(0, m*TILE_WIDTH + ty, x, 0, WBKdata, WBK.dataLength);
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ for (m = 0; m < X.width/TILE_WIDTH; ++m)
+ {
+ X_[ty][tx] = Xregs[ty][tx];
+ W_[ty][tx] = Wregs[ty][tx];
+
+ Xregs[ty][tx] = X.Get(b, 0, m*TILE_WIDTH + tx, 0, Xdata);
+ Wregs[ty][tx] = W.Get(0, m*TILE_WIDTH + ty, x, 0, WBKdata, WBK.dataLength);
+
+ for (uint i = 0; i < TILE_WIDTH;)
+ {
+ // can unroll up to 16 because TILE_WIDTH=16
+ v += X_[ty][i] * W_[i][tx]; ++i;
+ v += X_[ty][i] * W_[i][tx]; ++i;
+ v += X_[ty][i] * W_[i][tx]; ++i;
+ v += X_[ty][i] * W_[i][tx]; ++i;
+
+ v += X_[ty][i] * W_[i][tx]; ++i;
+ v += X_[ty][i] * W_[i][tx]; ++i;
+ v += X_[ty][i] * W_[i][tx]; ++i;
+ v += X_[ty][i] * W_[i][tx]; ++i;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ O.Set(b, 0, x, 0, v, Odata);
+ #undef X_
+ #undef W_
+}
+
+[numthreads(1,1,1)]
+void Relu(uint3 groupID : SV_GroupID)
+{
+ uint x = groupID.x;
+ uint b = groupID.y;
+ uint c = groupID.z;
+ for (uint y = 0; y < X.height; ++y)
+ {
+ float v = X.Get(b, y, x, c, Xdata, X.dataLength);
+ v = 0.5f * (v + abs(v));
+ O.Set(b, y, x, c, v, Odata, O.dataLength);
+ }
+}
+
+[numthreads(16,16,1)]
+void Relu_Cmod16_CNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint c = 16*groupID.x + groupThreadID.x;
+ uint nyx = 16*groupID.y + groupThreadID.y;
+
+ uint width = X.width;
+ uint height = X.height;
+
+ uint x = nyx % width;
+ uint ny = nyx / width;
+ uint y = ny % height;
+ uint n = ny / height;
+
+ float v = X.Get(n, y, x, c, Xdata, X.dataLength);
+ v = 0.5f * (v + abs(v));
+ O.Set(n, y, x, c, v, Odata, O.dataLength);
+}
+
+[numthreads(512,1,1)]
+void Relu_Nyxc(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint nyxc = 512*groupID.x + groupThreadID.x;
+
+ uint width = X.width;
+ uint height = X.height;
+ uint channels = X.channels;
+
+ uint c = nyxc % channels;
+ uint nyx = nyxc / channels;
+ uint x = nyx % width;
+ uint ny = nyx / width;
+ uint y = ny % height;
+ uint n = ny / height;
+
+ float v = X.Get(n, y, x, c, Xdata, X.dataLength);
+ v = 0.5f * (v + abs(v));
+ O.Set(n, y, x, c, v, Odata, O.dataLength);
+}
+
+[numthreads(16,16,1)]
+void Relu16x16(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint x = 16*groupID.x + groupThreadID.x;
+ uint b = 16*groupID.y + groupThreadID.y;
+ uint c = groupID.z;
+
+ for (uint y = 0; y < X.height; ++y)
+ {
+ float v = X.Get(b, y, x, c, Xdata, X.dataLength);
+ v = 0.5f * (v + abs(v));
+ O.Set(b, y, x, c, v, Odata, O.dataLength);
+ }
+}
+
+[numthreads(16,16,1)]
+void Relu16x16_(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint x = 16*groupID.x + groupThreadID.x;
+ uint b = 16*groupID.y + groupThreadID.y;
+
+ for (uint y = 0; y < X.height; ++y)
+ {
+ for (uint c = 0; c < X.channels; ++c)
+ {
+ float v = X.Get(b, y, x, c, Xdata, X.dataLength);
+ v = 0.5f * (v + abs(v));
+ O.Set(b, y, x, c, v, Odata, O.dataLength);
+ }
+ }
+}
+
+
+// channels, width, batch
+[numthreads(16,2,16)]
+void ReluChannelsFirst16x2x16(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint c = 16*groupID.x + groupThreadID.x;
+ uint x = 2*groupID.y + groupThreadID.y;
+ uint b = 16*groupID.z + groupThreadID.z;
+
+ for (uint y = 0; y < X.height; ++y)
+ {
+ float v = X.Get(b, y, x, c, Xdata, X.dataLength);
+ v = 0.5f * (v + abs(v));
+ O.Set(b, y, x, c, v, Odata, O.dataLength);
+ }
+}
+
+[numthreads(256,1,1)]
+void Relu256xV(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint x = 256*groupID.x + groupThreadID.x;
+ uint b = groupID.y;
+ uint c = groupID.z;
+
+ for (uint y = 0; y < X.height; ++y)
+ {
+ float v = 0;
+ for (uint b = 0; b < X.batch; )
+ {
+ v = X.Get(b, y, x, c, Xdata, X.dataLength);
+ v = 0.5f * (v + abs(v));
+ O.Set(b, y, x, c, v, Odata, O.dataLength);
+ ++b;
+
+ v = X.Get(b, y, x, c, Xdata, X.dataLength);
+ v = 0.5f * (v + abs(v));
+ O.Set(b, y, x, c, v, Odata, O.dataLength);
+ ++b;
+
+ v = X.Get(b, y, x, c, Xdata, X.dataLength);
+ v = 0.5f * (v + abs(v));
+ O.Set(b, y, x, c, v, Odata, O.dataLength);
+ ++b;
+
+ v = X.Get(b, y, x, c, Xdata, X.dataLength);
+ v = 0.5f * (v + abs(v));
+ O.Set(b, y, x, c, v, Odata, O.dataLength);
+ ++b;
+ }
+ }
+}
+
+
+#define FLT_MAX 3.402823466e+38F
+
+[numthreads(1,1,1)]
+void Softmax(uint3 groupID : SV_GroupID)
+{
+ uint b = groupID.x;
+ uint x = groupID.y;
+
+ float maxV = -FLT_MAX;
+ for (uint i = 0; i < X.width; ++i)
+ {
+ float v = X.Get(b, 0, i, 0, Xdata, X.dataLength);
+ if (v > maxV)
+ maxV = v;
+ }
+
+ float sum = 0.0f;
+ for (i = 0; i < X.width; ++i)
+ {
+ float v = X.Get(b, 0, i, 0, Xdata, X.dataLength);
+ sum += exp(v - maxV);
+ }
+
+ float v = X.Get(b, 0, x, 0, Xdata, X.dataLength);
+ v = exp(v - maxV) / sum;
+ O.Set(b, 0, x, 0, v, Odata, O.dataLength);
+}
+
+[numthreads(256,2,1)]
+void Softmax256x2(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint b = 256*groupID.x + groupThreadID.x;
+ uint x = 2*groupID.y + groupThreadID.y;
+
+ float maxV = -FLT_MAX;
+ for (uint i = 0; i < X.width; ++i)
+ {
+ float v = X.Get(b, 0, i, 0, Xdata, X.dataLength);
+ if (v > maxV)
+ maxV = v;
+ }
+
+ float sum = 0.0f;
+ for (i = 0; i < X.width; ++i)
+ {
+ float v = X.Get(b, 0, i, 0, Xdata, X.dataLength);
+ sum += exp(v - maxV);
+ }
+
+ float v = X.Get(b, 0, x, 0, Xdata, X.dataLength);
+ v = exp(v - maxV) / sum;
+ O.Set(b, 0, x, 0, v, Odata, O.dataLength);
+}
+
+[numthreads(1,1,1)]
+void MaxPooling2D(uint3 groupID : SV_GroupID)
+{
+ uint c = groupID.x;
+ uint x = groupID.y;
+ uint y = groupID.z;
+
+ for (uint b = 0; b < O.batch; ++b)
+ {
+ float v0 = X.Get(b, y*2, x*2, c, Xdata, X.dataLength);
+ float v1 = X.Get(b, y*2+1, x*2, c, Xdata, X.dataLength);
+ float v2 = X.Get(b, y*2, x*2+1, c, Xdata, X.dataLength);
+ float v3 = X.Get(b, y*2+1, x*2+1, c, Xdata, X.dataLength);
+ float v = max(v0, max(v1, max(v2, v3)));
+ O.Set(b, y, x, c, v, Odata, O.dataLength);
+ }
+}
+
+[numthreads(16,4,4)]
+void MaxPooling2D16x4x4(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint c = 16*groupID.x + groupThreadID.x;
+ uint x = 4*groupID.y + groupThreadID.y;
+ uint y = 4*groupID.z + groupThreadID.z;
+
+ for (uint b = 0; b < O.batch; ++b)
+ {
+ float v0 = X.Get(b, y*2, x*2, c, Xdata, X.dataLength);
+ float v1 = X.Get(b, y*2+1, x*2, c, Xdata, X.dataLength);
+ float v2 = X.Get(b, y*2, x*2+1, c, Xdata, X.dataLength);
+ float v3 = X.Get(b, y*2+1, x*2+1, c, Xdata, X.dataLength);
+ float v = max(v0, max(v1, max(v2, v3)));
+ O.Set(b, y, x, c, v, Odata, O.dataLength);
+ }
+}
+
+[numthreads(16,16,2)]
+void Conv2D_Valid(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint k = 16*groupID.x + groupThreadID.x;
+ uint n = 16*groupID.y + groupThreadID.y;
+ uint y = 2*groupID.z + groupThreadID.z + _FilterSize;
+
+ //for (int y = _FilterSize; y < X.height - _FilterSize; ++y)
+ {
+ for (uint x = _FilterSize; x < X.width - _FilterSize; ++x)
+ {
+ float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
+ for (int i = -(int)_FilterSize; i < (int)_FilterSize + 1; ++i)
+ {
+ for (int j = -(int)_FilterSize; j < (int)_FilterSize + 1; ++j)
+ {
+ for (uint c = 0; c < X.channels; ++c)
+ {
+ v += X.Get(n, y+j, x+i, c, Xdata, X.dataLength) * K.Get(_FilterSize+j, _FilterSize+i, c, k, WBKdata, WBK.dataLength);
+ }
+ }
+ }
+ O.Set(n, y-_FilterSize, x-_FilterSize, k, v, Odata, O.dataLength);
+ }
+ }
+}
+
+[numthreads(16,8,1)]
+void Conv2D_Kmod16_Nmod8_KNY(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint k = 16*groupID.x + groupThreadID.x;
+ uint n = 8*groupID.y + groupThreadID.y;
+ uint y = 1*groupID.z + groupThreadID.z;
+
+ //for (int y = _FilterSize; y < X.height - _FilterSize; ++y)
+ {
+ for (uint x = 0; x < X.width - _Border; ++x)
+ {
+ float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
+ for (uint j = 0; j < 2*_FilterSize+1; ++j)
+ {
+ if (y+j < _Offset) continue;
+ if (y+j-_Offset >= X.height) continue;
+
+ for (uint i = 0; i < 2*_FilterSize+1; ++i)
+ {
+ if (x+i < _Offset) continue;
+ if (x+i-_Offset >= X.width) continue;
+
+ for (uint c = 0; c < X.channels; ++c)
+ {
+ v += X.Get(n, y+j-_Offset, x+i-_Offset, c, Xdata, X.dataLength) * K.Get(j, i, c, k, WBKdata, WBK.dataLength);
+ }
+ }
+ }
+ O.Set(n, y, x, k, v, Odata, O.dataLength);
+ }
+ }
+}
+
+[numthreads(1,1,1)]
+void Conv2D(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint k = 1*groupID.x + groupThreadID.x;
+ uint n = 1*groupID.y + groupThreadID.y;
+ uint y = 1*groupID.z + groupThreadID.z;
+
+ //for (int y = _FilterSize; y < X.height - _FilterSize; ++y)
+ {
+ for (uint x = 0; x < X.width - _Border; ++x)
+ {
+ float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
+ for (uint j = 0; j < 2*_FilterSize+1; ++j)
+ {
+ if (y+j < _Offset) continue;
+ if (y+j-_Offset >= X.height) continue;
+
+ for (uint i = 0; i < 2*_FilterSize+1; ++i)
+ {
+ if (x+i < _Offset) continue;
+ if (x+i-_Offset >= X.width) continue;
+
+ for (uint c = 0; c < X.channels; ++c)
+ {
+ v += X.Get(n, y+j-_Offset, x+i-_Offset, c, Xdata, X.dataLength) * K.Get(j, i, c, k, WBKdata, WBK.dataLength);
+ }
+ }
+ }
+ O.Set(n, y, x, k, v, Odata, O.dataLength);
+ }
+ }
+}
+
+#if 0
+
+#define MAX_TILE_WIDTH 16
+#define KERNEL_COUNT 4
+#define KERNEL_SIZE 3
+#define KERNEL_RADIUS 1 //(KERNEL_SIZE-1)/2
+groupshared float XCcache[MAX_TILE_WIDTH+KERNEL_SIZE-1][MAX_TILE_WIDTH+KERNEL_SIZE-1];
+groupshared float Kcache[KERNEL_SIZE][KERNEL_SIZE][KERNEL_COUNT];
+
+#undef TILE_WIDTH
+#define TILE_WIDTH 13
+[numthreads(TILE_WIDTH,TILE_WIDTH,KERNEL_COUNT)]
+void Conv2DTiled14x14_Kernel3x3(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint tx = groupThreadID.x;
+ uint ty = groupThreadID.y;
+ uint tk = groupThreadID.z;
+ uint gx = groupID.x;
+ uint gy = groupID.y;
+ uint gk = groupID.z;
+ uint tileCornerX = gx*TILE_WIDTH;
+ uint tileCornerY = gy*TILE_WIDTH;
+ uint x = tileCornerX + tx;
+ uint y = tileCornerY + ty;
+ uint k = gk*KERNEL_COUNT + tk;
+ uint idx = ty*TILE_WIDTH + tx;
+
+ for (uint b = 0; b < X.batch; ++b)
+ {
+ float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
+ for (uint c = 0; c < X.channels; ++c)
+ {
+ if (tk == 0)
+ XCcache[ty][tx] = X.Get(b, y, x, c, Xdata);
+ else if (tk == 1 && idx < TILE_WIDTH * 2)
+ {
+ uint yy = idx / 2;
+ uint xx = idx % 2 + TILE_WIDTH;
+ XCcache[yy][xx] = X.Get(b, tileCornerY+yy, tileCornerX+xx, c, Xdata);
+ }
+ else if (tk == 2 && idx < (TILE_WIDTH + 2) * 2)
+ {
+ uint yy = idx / (TILE_WIDTH + 2) + TILE_WIDTH;
+ uint xx = idx % (TILE_WIDTH + 2);
+ XCcache[yy][xx] = X.Get(b, tileCornerY+yy, tileCornerX+xx, c, Xdata);
+ }
+ if (tk == 3)
+ {
+ uint kk = idx / (KERNEL_SIZE * KERNEL_SIZE);
+ uint kyx = idx % (KERNEL_SIZE * KERNEL_SIZE);
+ if (kk < KERNEL_COUNT)
+ {
+ uint yy = kyx / KERNEL_SIZE;
+ uint xx = kyx % KERNEL_SIZE;
+ Kcache[yy][xx][kk] = K.Get(yy, xx, c, gk*KERNEL_COUNT+kk, WBKdata, WBK.dataLength);
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+
+ for (int i = 0; i < KERNEL_SIZE; ++i)
+ {
+ for (int j = 0; j < KERNEL_SIZE; ++j)
+ {
+ v += XCcache[ty+j][tx+i] * Kcache[j][i][tk];
+ }
+ }
+ }
+ O.Set(b, y, x, k, v, Odata, O.dataLength);
+ GroupMemoryBarrierWithGroupSync();
+ }
+}
+
+#undef TILE_WIDTH
+#define TILE_WIDTH 12
+[numthreads(TILE_WIDTH,TILE_WIDTH,KERNEL_COUNT)]
+void Conv2DTiled13x13_Kernel3x3(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint tx = groupThreadID.x;
+ uint ty = groupThreadID.y;
+ uint tk = groupThreadID.z;
+ uint gx = groupID.x;
+ uint gy = groupID.y;
+ uint gk = groupID.z;
+ uint tileCornerX = gx*TILE_WIDTH;
+ uint tileCornerY = gy*TILE_WIDTH;
+ uint x = tileCornerX + tx;
+ uint y = tileCornerY + ty;
+ uint k = gk*KERNEL_COUNT + tk;
+ uint idx = ty*TILE_WIDTH + tx;
+
+ for (uint b = 0; b < X.batch; ++b)
+ {
+ float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
+ for (uint c = 0; c < X.channels; ++c)
+ {
+ if (tk == 0)
+ XCcache[ty][tx] = X.Get(b, y, x, c, Xdata);
+ else if (tk == 1 && idx < TILE_WIDTH * 2)
+ {
+ uint yy = idx / 2;
+ uint xx = idx % 2 + TILE_WIDTH;
+ XCcache[yy][xx] = X.Get(b, tileCornerY+yy, tileCornerX+xx, c, Xdata);
+ }
+ else if (tk == 2 && idx < (TILE_WIDTH + 2) * 2)
+ {
+ uint yy = idx / (TILE_WIDTH + 2) + TILE_WIDTH;
+ uint xx = idx % (TILE_WIDTH + 2);
+ XCcache[yy][xx] = X.Get(b, tileCornerY+yy, tileCornerX+xx, c, Xdata);
+ }
+ if (tk == 3)
+ {
+ uint kk = idx / (KERNEL_SIZE * KERNEL_SIZE);
+ uint kyx = idx % (KERNEL_SIZE * KERNEL_SIZE);
+ if (kk < KERNEL_COUNT)
+ {
+ uint yy = kyx / KERNEL_SIZE;
+ uint xx = kyx % KERNEL_SIZE;
+ Kcache[yy][xx][kk] = K.Get(yy, xx, c, gk*KERNEL_COUNT+kk, WBKdata, WBK.dataLength);
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+
+ for (int i = 0; i < KERNEL_SIZE; ++i)
+ {
+ for (int j = 0; j < KERNEL_SIZE; ++j)
+ {
+ v += XCcache[ty+j][tx+i] * Kcache[j][i][tk];
+ }
+ }
+ }
+ O.Set(b, y, x, k, v, Odata, O.dataLength);
+ GroupMemoryBarrierWithGroupSync();
+ }
+}
+
+/*
+#undef TILE_WIDTH
+#define TILE_WIDTH 12
+[numthreads(TILE_WIDTH,TILE_WIDTH,KERNEL_COUNT)]
+void Conv2DTiled12x12_Kernel3x3(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint tx = groupThreadID.x;
+ uint ty = groupThreadID.y;
+ uint tk = groupThreadID.z;
+ uint gx = groupID.x;
+ uint gy = groupID.y;
+ uint gk = groupID.z;
+ uint tileCornerX = gx*TILE_WIDTH;
+ uint tileCornerY = gy*TILE_WIDTH;
+ uint x = tileCornerX + tx;
+ uint y = tileCornerY + ty;
+ uint k = gk*KERNEL_COUNT + tk;
+ uint idx = ty*TILE_WIDTH + tx;
+
+ for (uint b = 0; b < X.batch; ++b)
+ {
+ float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
+ for (uint c = 0; c < X.channels; ++c)
+ {
+ if (gk == 0)
+ XCcache[ty][tx] = X.Get(b, y, x, c, Xdata);
+ else if (gk == 1 && idx < TILE_WIDTH * 2)
+ {
+ uint yy = idx / 2;
+ uint xx = idx % 2 + TILE_WIDTH;
+ XCcache[yy][xx] = X.Get(b, tileCornerY+yy, tileCornerX+xx, c, Xdata);
+ }
+ else if (gk == 2 && idx < (TILE_WIDTH + 2) * 2)
+ {
+ uint yy = idx / (TILE_WIDTH + 2) + TILE_WIDTH;
+ uint xx = idx % (TILE_WIDTH + 2);
+ XCcache[yy][xx] = X.Get(b, tileCornerY+yy, tileCornerX+xx, c, Xdata);
+ }
+ else if (gk == 3 && ty < KERNEL_SIZE && tx < KERNEL_SIZE)
+ Kcache[ty][tx][tk] = K.Get(ty, tx, c, k, WBKdata, WBK.dataLength);
+ GroupMemoryBarrierWithGroupSync();
+
+ for (int i = 0; i < KERNEL_SIZE; ++i)
+ {
+ for (int j = 0; j < KERNEL_SIZE; ++j)
+ {
+ v += XCcache[ty+j][tx+i] * Kcache[j][i][tk];
+ }
+ }
+ }
+ O.Set(b, y-KERNEL_RADIUS, x-KERNEL_RADIUS, k, v, Odata, O.dataLength);
+ GroupMemoryBarrierWithGroupSync();
+ }
+}
+*/
+
+// %TODO: only supports up to 32 channels now
+#undef KERNEL_COUNT
+#undef CHANNEL_COUNT
+#define KERNEL_COUNT 16
+#define CHANNEL_COUNT 32
+groupshared float K2cache[CHANNEL_COUNT][KERNEL_COUNT][9];
+[numthreads(KERNEL_COUNT,CHANNEL_COUNT,1)]
+void Conv2D_Kernel3x3_32Channel_Valid(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint tk = groupThreadID.x;
+ uint k = KERNEL_COUNT*groupID.x + tk;
+ uint n = CHANNEL_COUNT*groupID.y + groupThreadID.y;
+
+ for (uint q = 0; q < 9; ++q)
+ {
+ uint tc = n % CHANNEL_COUNT;
+ K2cache[tc][tk][q] = K.Get(q/3, q%3, tc, k, WBKdata, WBK.dataLength);
+ }
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint y = 0; y < X.height - _FilterSize*2; ++y)
+ {
+ for (uint x = 0; x < X.width - _FilterSize*2; ++x)
+ {
+ float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
+ for (uint q = 0; q < 9; ++q)
+ for (uint c = 0; c < CHANNEL_COUNT; c += 4)
+ {
+ //K.Get(q/3, q%3, c, k, WBKdata, WBK.dataLength);
+ v += X.Get(n, y+q/3, x+q%3, c+0, Xdata, X.dataLength) * K2cache[c+0][tk][q];
+ v += X.Get(n, y+q/3, x+q%3, c+1, Xdata, X.dataLength) * K2cache[c+1][tk][q];
+ v += X.Get(n, y+q/3, x+q%3, c+2, Xdata, X.dataLength) * K2cache[c+2][tk][q];
+ v += X.Get(n, y+q/3, x+q%3, c+3, Xdata, X.dataLength) * K2cache[c+3][tk][q];
+ }
+ O.Set(n, y, x, k, v, Odata, O.dataLength);
+ }
+ }
+}
+
+[numthreads(KERNEL_COUNT,CHANNEL_COUNT,1)]
+void Conv2D_Kernel3x3_32Channel(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint tk = groupThreadID.x;
+ uint k = KERNEL_COUNT*groupID.x + tk;
+ uint n = CHANNEL_COUNT*groupID.y + groupThreadID.y;
+
+ for (uint q = 0; q < 9; ++q)
+ {
+ uint tc = n % CHANNEL_COUNT;
+ K2cache[tc][tk][q] = K.Get(q/3, q%3, tc, k, WBKdata, WBK.dataLength);
+ }
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint y = 0; y < X.height - _Border; ++y)
+ {
+ for (uint x = 0; x < X.width - _Border; ++x)
+ {
+ float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
+ for (uint dy = 0; dy < 3; ++dy)
+ {
+ if (y+dy < _Offset) continue;
+ if (y+dy-_Offset >= X.height) continue;
+ for (uint dx = 0; dx < 3; ++dx)
+ {
+ if (x+dx < _Offset) continue;
+ if (x+dx-_Offset >= X.width) continue;
+
+ uint q = dy*3+dx;
+ for (uint c = 0; c < CHANNEL_COUNT; c += 4)
+ {
+ //K.Get(q/3, q%3, c, k, WBKdata, WBK.dataLength);
+ v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+0, Xdata, X.dataLength) * K2cache[c+0][tk][q];
+ v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+1, Xdata, X.dataLength) * K2cache[c+1][tk][q];
+ v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+2, Xdata, X.dataLength) * K2cache[c+2][tk][q];
+ v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+3, Xdata, X.dataLength) * K2cache[c+3][tk][q];
+ }
+ }
+ }
+ O.Set(n, y, x, k, v, Odata, O.dataLength);
+ }
+ }
+}
+
+groupshared float X2cache[2][CHANNEL_COUNT][KERNEL_COUNT];
+[numthreads(KERNEL_COUNT,CHANNEL_COUNT,1)]
+void Conv2D_Kernel3x3_32Channel_(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint tk = groupThreadID.x;
+ uint tn = groupThreadID.y;
+ uint k = KERNEL_COUNT*groupID.x + tk;
+ uint n = CHANNEL_COUNT*groupID.y + tn;
+
+ for (uint q = 0; q < 9; ++q)
+ {
+ uint tc = n % CHANNEL_COUNT;
+ K2cache[q][tc][tk] = K.Get(q/3, q%3, tc, k, WBKdata, WBK.dataLength);
+ }
+ //GroupMemoryBarrierWithGroupSync(); <-- unnecessary, we have one inside the loop
+
+ for (uint y = 0; y < X.height - _FilterSize*2; ++y)
+ {
+ for (uint x = 0; x < X.width - _FilterSize*2; ++x)
+ {
+ float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
+ for (uint cBlock = 0; cBlock < CHANNEL_COUNT; cBlock += KERNEL_COUNT)
+ {
+ for (uint q = 0; q < 9; ++q)
+ {
+ uint tc = k % KERNEL_COUNT;
+ X2cache[q%2][tn][tc] = X.Get(n, y+q/3, x+q%3, cBlock+tc, Xdata, X.dataLength);
+ GroupMemoryBarrierWithGroupSync();
+
+ for (tc = 0; tc < KERNEL_COUNT; ++tc)
+ v += X2cache[q%2][tn][tc] * K2cache[q][cBlock+tc][tk];
+ }
+ }
+ O.Set(n, y, x, k, v, Odata, O.dataLength);
+ }
+ }
+}
+
+// 16x8 => 0.101
+// 32x4 => 0.114
+// 8x8 => 0.131
+
+#define PARAM_X 16
+#define PARAM_Y 8
+[numthreads(PARAM_X, PARAM_Y, 1)]
+void Conv2D_Kernel3x3_Kmod16_Cmod4_KN(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint k = PARAM_X * groupID.x + groupThreadID.x;
+ uint n = PARAM_Y * groupID.y + groupThreadID.y;
+
+ for (uint y = 0; y < X.height - _Border; ++y)
+ {
+ for (uint x = 0; x < X.width - _Border; ++x)
+ {
+ float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
+ for (uint dy = 0; dy < 3; ++dy)
+ {
+ if (y+dy < _Offset) continue;
+ if (y+dy-_Offset >= X.height) continue;
+ for (uint dx = 0; dx < 3; ++dx)
+ {
+ if (x+dx < _Offset) continue;
+ if (x+dx-_Offset >= X.width) continue;
+
+ for (uint c = 0; c < X.channels; c += 4)
+ {
+ v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+0, Xdata, X.dataLength) * K.Get(dy, dx, c+0, k, WBKdata, WBK.dataLength);
+ v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+1, Xdata, X.dataLength) * K.Get(dy, dx, c+1, k, WBKdata, WBK.dataLength);
+ v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+2, Xdata, X.dataLength) * K.Get(dy, dx, c+2, k, WBKdata, WBK.dataLength);
+ v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+3, Xdata, X.dataLength) * K.Get(dy, dx, c+3, k, WBKdata, WBK.dataLength);
+ }
+ }
+ }
+ O.Set(n, y, x, k, v, Odata, O.dataLength);
+ }
+ }
+}
+#undef PARAM_X
+#undef PARAM_Y
+#define PARAM_X 16
+#define PARAM_Y 8
+
+// 16x8 => 0.096
+// 8x8 => 0.117
+[numthreads(PARAM_X, PARAM_Y, 1)]
+void Conv2D_Kernel3x3_Kmod16_Cmod4_KNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint k = PARAM_X * groupID.x + groupThreadID.x;
+ uint nyx = PARAM_Y * groupID.y + groupThreadID.y;
+
+ uint width = X.width - _Border;
+ uint height = X.height - _Border;
+
+ uint x = nyx % width;
+ uint ny = nyx / width;
+ uint y = ny % height;
+ uint n = ny / height;
+
+ //for (uint y = 0; y < X.height - _Border; ++y)
+ //{
+ // for (uint x = 0; x < X.width - _Border; ++x)
+ // {
+ float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
+ for (uint dy = 0; dy < 3; ++dy)
+ {
+ if (y+dy < _Offset) continue;
+ if (y+dy-_Offset >= X.height) continue;
+ for (uint dx = 0; dx < 3; ++dx)
+ {
+ if (x+dx < _Offset) continue;
+ if (x+dx-_Offset >= X.width) continue;
+
+ for (uint c = 0; c < X.channels; c += 4)
+ {
+ v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+0, Xdata, X.dataLength) * K.Get(dy, dx, c+0, k, WBKdata, WBK.dataLength);
+ v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+1, Xdata, X.dataLength) * K.Get(dy, dx, c+1, k, WBKdata, WBK.dataLength);
+ v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+2, Xdata, X.dataLength) * K.Get(dy, dx, c+2, k, WBKdata, WBK.dataLength);
+ v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+3, Xdata, X.dataLength) * K.Get(dy, dx, c+3, k, WBKdata, WBK.dataLength);
+ }
+ }
+ }
+ O.Set(n, y, x, k, v, Odata, O.dataLength);
+ // }
+ //}
+}
+
+#undef CTILE
+#define CTILE 16
+
+#undef PARAM_X
+#undef PARAM_Y
+#define PARAM_X CTILE
+#define PARAM_Y CTILE
+
+#define TYPE float
+
+groupshared TYPE Conv_XcacheT[CTILE][CTILE];
+groupshared TYPE Conv_KcacheT[CTILE][CTILE];
+
+[numthreads(PARAM_X, PARAM_Y, 1)]
+void Conv2D_Kernel3x3_Cache_KCmod16_KNyx_(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ Conv_XcacheT
+ #define K_ Conv_KcacheT
+
+ uint gx = groupThreadID.x;
+ uint gy = groupThreadID.y;
+
+ uint k = PARAM_X * groupID.x + groupThreadID.x;
+ uint nyx = PARAM_Y * groupID.y + groupThreadID.y;
+
+ uint width = X.width - _Border;
+ uint height = X.height - _Border;
+
+ uint x = nyx % width;
+ uint ny = nyx / width;
+ uint y = ny % height;
+ uint n = ny / height;
+
+ //half v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
+ TYPE v = WBKdata[k + B.offset];
+ for (uint dy = 0; dy < 3; ++dy)
+ {
+ bool mask = true;
+
+ if (y+dy < _Offset) mask = false;
+ if (y+dy-_Offset >= X.height) mask = false;
+
+ for (uint dx = 0; dx < 3; ++dx)
+ {
+ if (x+dx < _Offset) mask = false;
+ if (x+dx-_Offset >= X.width) mask = false;
+
+ int Xi = (( n * X.height +
+ y+dy-_Offset ) * X.width +
+ x+dx-_Offset ) * X.channels +
+ gx;
+
+ int Ki = (( dy * K.height +
+ dx ) * K.width +
+ /*m*CTILE +*/ gy ) * K.channels +
+ k + K.offset;
+
+ for (uint m = 0; m < X.channels/CTILE; ++m)
+ {
+ if (mask)
+ {
+ //X_[gy][gx] = X.Get(n, y+dy-_Offset, x+dx-_Offset, m*CTILE + gx, Xdata);
+ X_[gy][gx] = Xdata[Xi + m*CTILE];
+ }
+ else
+ {
+ X_[gy][gx] = 0;
+ }
+ //K_[gy][gx] = K.Get(dy, dx, m*CTILE + gy, k, WBKdata, WBK.dataLength);
+ //K_[gy][gx] = WBKdata[((
+ // dy * K.height +
+ // dx ) * K.width +
+ // m*CTILE + gy ) * K.channels +
+ // k + K.offset];
+ //K_[gy][gx] = WBKdata[Ki + m*CTILE * K.channels];
+ K_[gy][gx] = WBKdata[Ki + m*CTILE * K.channels];
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint i = 0; i < CTILE;)
+ {
+ /*
+ // can unroll up to CTILE
+ half4 x4 = ((half4[CTILE][CTILE/4])(X_))[gy][i];
+ half4 k4 = ((half4[CTILE][CTILE/4])(K_))[gx][i];
+
+ v += dot(x4, k4); ++i;
+ v += dot(x4, k4); ++i;
+ */
+
+ v += X_[gy][i] * K_[i][gx]; ++i;
+ v += X_[gy][i] * K_[i][gx]; ++i;
+ v += X_[gy][i] * K_[i][gx]; ++i;
+ v += X_[gy][i] * K_[i][gx]; ++i;
+ v += X_[gy][i] * K_[i][gx]; ++i;
+ v += X_[gy][i] * K_[i][gx]; ++i;
+ v += X_[gy][i] * K_[i][gx]; ++i;
+ v += X_[gy][i] * K_[i][gx]; ++i;
+
+ }
+ }
+ }
+ }
+ //O.Set(n, y, x, k, v, Odata, O.dataLength);
+ Odata[((
+ n * O.height +
+ y ) * O.width +
+ x ) * O.channels +
+ k] = v;
+
+ #undef X_
+ #undef K_
+}
+
+#undef CTILE
+#define CTILE 16
+groupshared float Conv_XcacheA[4][CTILE][CTILE];
+groupshared float Conv_Kcache0[CTILE][CTILE];
+groupshared float Conv_Kcache1[CTILE][CTILE];
+groupshared float Conv_Kcache2[CTILE][CTILE];
+groupshared float Conv_Kcache3[CTILE][CTILE];
+[numthreads(CTILE, CTILE, 1)]
+void Conv2D_Kernel3x3_Cache_KCmod32_KNyx____(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ Conv_XcacheA
+ #define K_0 Conv_Kcache0
+ #define K_1 Conv_Kcache1
+ #define K_2 Conv_Kcache2
+ #define K_3 Conv_Kcache3
+
+
+
+ uint gx = groupThreadID.x;
+ uint gy = groupThreadID.y;
+
+ uint k = CTILE * groupID.x + groupThreadID.x;
+ uint nyx = CTILE * groupID.y + groupThreadID.y;
+
+ uint width = X.width - _Border;
+ uint height = X.height - _Border;
+
+ uint x = nyx % width;
+ uint ny = nyx / width;
+ uint y = ny % height;
+ uint n = ny / height;
+
+ float b0 = B.Get(0, 0, k*2+0, 0, WBKdata, WBK.dataLength);
+ float b1 = B.Get(0, 0, k*2+1, 0, WBKdata, WBK.dataLength);
+ float4 v = float4(b0, b1,
+ b0, b1);
+
+ for (uint dy = 0; dy < 3; ++dy)
+ {
+ bool mask = true;
+
+ if (y+dy < _Offset) mask = false;
+ if (y+dy-_Offset >= X.height) mask = false;
+
+ for (uint dx = 0; dx < 3; ++dx)
+ {
+ if (x+dx < _Offset) mask = false;
+ if (x+dx-_Offset >= X.width) mask = false;
+
+ for (uint m = 0; m < X.channels/(CTILE*2); ++m)
+ {
+ float x0 = 0;
+ float x1 = 0;
+ float x2 = 0;
+ float x3 = 0;
+
+ if (mask)
+ {
+ x0 = X.Get(n*2+0, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*2+0, Xdata);
+ x1 = X.Get(n*2+0, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*2+1, Xdata);
+ x2 = X.Get(n*2+1, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*2+0, Xdata);
+ x3 = X.Get(n*2+1, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*2+1, Xdata);
+ }
+
+ float k0 = K.Get(dy, dx, (m*CTILE + gy)*2+0, k*2+0, WBKdata, WBK.dataLength);
+ float k1 = K.Get(dy, dx, (m*CTILE + gy)*2+0, k*2+1, WBKdata, WBK.dataLength);
+ float k2 = K.Get(dy, dx, (m*CTILE + gy)*2+1, k*2+0, WBKdata, WBK.dataLength);
+ float k3 = K.Get(dy, dx, (m*CTILE + gy)*2+1, k*2+1, WBKdata, WBK.dataLength);
+
+ //X_[gy][gx] = float4(x0, x1,
+ // x2, x3);
+ //K_[gy][gx] = float4(k0, k1,
+ // k2, k3);
+ X_[0][gy][gx] = x0;
+ X_[1][gy][gx] = x1;
+ X_[2][gy][gx] = x2;
+ X_[3][gy][gx] = x3;
+
+ K_0[gy][gx] = k0;
+ K_1[gy][gx] = k1;
+ K_2[gy][gx] = k2;
+ K_3[gy][gx] = k3;
+
+ GroupMemoryBarrierWithGroupSync();
+
+ [unroll]
+ for (uint i = 0; i < CTILE; ++i)
+ {
+ float4 x = //X_[gy][i];
+ float4( X_[0][gy][i],
+ X_[1][gy][i],
+ X_[2][gy][i],
+ X_[3][gy][i]);
+ //float4 k = //K_[i][gx];
+ // float4( K_0[i][gx],
+ // K_1[i][gx],
+ // K_2[i][gx],
+ // K_3[i][gx]);
+ k0 = K_0[i][gx];
+ k1 = K_1[i][gx];
+ k2 = K_2[i][gx];
+ k3 = K_3[i][gx];
+
+ v.x = mad(k0, x.x, v.x);
+ v.x = mad(k2, x.y, v.x);
+
+ v.y = mad(k1, x.x, v.y);
+ v.y = mad(k2, x.y, v.y);
+
+ v.z = mad(k0, x.z, v.z);
+ v.z = mad(k2, x.w, v.z);
+
+ v.w = mad(k1, x.z, v.w);
+ v.w = mad(k3, x.w, v.w);
+
+ //v.x += k.x*x.x + k.z*x.y;
+ //v.y += k.y*x.x + k.w*x.y;
+ //v.z += k.x*x.z + k.z*x.w;
+ //v.w += k.y*x.z + k.w*x.w;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ //Odata[nyx * O.channels + k] = v;
+
+ /*Odata[((
+ n * O.height +
+ y ) * O.width +
+ x ) * O.channels +
+ k] = v;
+ */
+
+ O.Set(n*2+0, y, x, k*2+0, v.x, Odata);
+ O.Set(n*2+0, y, x, k*2+1, v.y, Odata);
+ O.Set(n*2+1, y, x, k*2+0, v.z, Odata);
+ O.Set(n*2+1, y, x, k*2+1, v.w, Odata);
+
+ #undef X_
+ #undef K_
+}
+
+
+#undef CTILE
+#define CTILE 16
+groupshared float Conv_Xcache[4][CTILE][CTILE];
+groupshared float Conv_Kcache[4][CTILE][CTILE];
+[numthreads(CTILE, CTILE, 1)]
+void Conv2D_Kernel3x3_Cache_KCmod32_KNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ Conv_Xcache
+ #define K_ Conv_Kcache
+
+ uint gx = groupThreadID.x;
+ uint gy = groupThreadID.y;
+
+ uint k = CTILE * groupID.x + groupThreadID.x;
+ uint nyx = CTILE * groupID.y + groupThreadID.y;
+
+ uint width = X.width - _Border;
+ uint height = X.height - _Border;
+
+ uint x = nyx % width;
+ uint ny = nyx / width;
+ uint y = ny % height;
+ uint n = ny / height;
+
+ float b0 = B.Get(0, 0, k*2+0, 0, WBKdata, WBK.dataLength);
+ float b1 = B.Get(0, 0, k*2+1, 0, WBKdata, WBK.dataLength);
+ float4 v = float4(b0, b1,
+ b0, b1);
+
+ for (uint dy = 0; dy < 3; ++dy)
+ {
+ bool mask = true;
+
+ if (y+dy < _Offset) mask = false;
+ if (y+dy-_Offset >= X.height) mask = false;
+
+ for (uint dx = 0; dx < 3; ++dx)
+ {
+ if (x+dx < _Offset) mask = false;
+ if (x+dx-_Offset >= X.width) mask = false;
+
+ for (uint m = 0; m < X.channels/(CTILE*2); ++m)
+ {
+ float x0 = 0;
+ float x1 = 0;
+ float x2 = 0;
+ float x3 = 0;
+
+ if (mask)
+ {
+ x0 = X.Get(n*2+0, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*2+0, Xdata);
+ x1 = X.Get(n*2+0, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*2+1, Xdata);
+ x2 = X.Get(n*2+1, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*2+0, Xdata);
+ x3 = X.Get(n*2+1, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*2+1, Xdata);
+ }
+
+ float k0 = K.Get(dy, dx, (m*CTILE + gy)*2+0, k*2+0, WBKdata, WBK.dataLength);
+ float k1 = K.Get(dy, dx, (m*CTILE + gy)*2+0, k*2+1, WBKdata, WBK.dataLength);
+ float k2 = K.Get(dy, dx, (m*CTILE + gy)*2+1, k*2+0, WBKdata, WBK.dataLength);
+ float k3 = K.Get(dy, dx, (m*CTILE + gy)*2+1, k*2+1, WBKdata, WBK.dataLength);
+
+ //X_[gy][gx] = float4(x0, x1,
+ // x2, x3);
+ //K_[gy][gx] = float4(k0, k1,
+ // k2, k3);
+ X_[0][gy][gx] = x0;
+ X_[1][gy][gx] = x1;
+ X_[2][gy][gx] = x2;
+ X_[3][gy][gx] = x3;
+
+ K_[0][gy][gx] = k0;
+ K_[1][gy][gx] = k1;
+ K_[2][gy][gx] = k2;
+ K_[3][gy][gx] = k3;
+
+ GroupMemoryBarrierWithGroupSync();
+
+ [unroll]
+ for (uint i = 0; i < CTILE; ++i)
+ {
+ float4 x = //X_[gy][i];
+ float4( X_[0][gy][i],
+ X_[1][gy][i],
+ X_[2][gy][i],
+ X_[3][gy][i]);
+ float4 k = //K_[i][gx];
+ float4( K_[0][i][gx],
+ K_[1][i][gx],
+ K_[2][i][gx],
+ K_[3][i][gx]);
+
+ v.x = mad(k.x, x.x, v.x);
+ v.x = mad(k.z, x.y, v.x);
+
+ v.y = mad(k.y, x.x, v.y);
+ v.y = mad(k.w, x.y, v.y);
+
+ v.z = mad(k.x, x.z, v.z);
+ v.z = mad(k.z, x.w, v.z);
+
+ v.w = mad(k.y, x.z, v.w);
+ v.w = mad(k.w, x.w, v.w);
+
+ //v.x += k.x*x.x + k.z*x.y;
+ //v.y += k.y*x.x + k.w*x.y;
+ //v.z += k.x*x.z + k.z*x.w;
+ //v.w += k.y*x.z + k.w*x.w;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ //Odata[nyx * O.channels + k] = v;
+
+ /*Odata[((
+ n * O.height +
+ y ) * O.width +
+ x ) * O.channels +
+ k] = v;
+ */
+
+ O.Set(n*2+0, y, x, k*2+0, v.x, Odata);
+ O.Set(n*2+0, y, x, k*2+1, v.y, Odata);
+ O.Set(n*2+1, y, x, k*2+0, v.z, Odata);
+ O.Set(n*2+1, y, x, k*2+1, v.w, Odata);
+
+ #undef X_
+ #undef K_
+}
+
+#if 0 // =====================================================================================================
+
+#undef CTILE
+#define CTILE 16
+#define RTILE 4
+groupshared float Conv_XcacheR[RTILE*RTILE][CTILE*CTILE];
+groupshared float Conv_KcacheR[RTILE*RTILE][CTILE*CTILE];
+[numthreads(CTILE, CTILE, 1)]
+void Conv2D_Kernel3x3_Cache_KCmod64_KNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ Conv_XcacheR
+ #define K_ Conv_KcacheR
+
+ uint gx = groupThreadID.x;
+ uint gy = groupThreadID.y;
+
+ uint k = CTILE * groupID.x + groupThreadID.x;
+ uint nyx = CTILE * groupID.y + groupThreadID.y;
+
+ uint width = X.width - _Border;
+ uint height = X.height - _Border;
+
+ uint x = nyx % width;
+ uint ny = nyx / width;
+ uint y = ny % height;
+ uint n = ny / height;
+
+ float v[RTILE][RTILE];
+ for (uint xxxx = 0; xxxx < RTILE; ++xxxx)
+ {
+ float b = B.Get(0, 0, k*RTILE+xxxx, 0, WBKdata, WBK.dataLength);
+ for (uint yyyy = 0; yyyy < RTILE; ++yyyy)
+ v[yyyy][xxxx] = b;
+ }
+
+ for (uint dy = 0; dy < 3; ++dy)
+ {
+ bool mask = true;
+
+ if (y+dy < _Offset) mask = false;
+ if (y+dy-_Offset >= X.height) mask = false;
+
+ for (uint dx = 0; dx < 3; ++dx)
+ {
+ if (x+dx < _Offset) mask = false;
+ if (x+dx-_Offset >= X.width) mask = false;
+
+ for (uint m = 0; m < X.channels/(CTILE*RTILE); ++m)
+ {
+ for (uint yy = 0; yy < RTILE; ++yy)
+ for (uint xx = 0; xx < RTILE; ++xx)
+ {
+ if (mask)
+ X_[yy*RTILE+xx][gy*CTILE+gx] = X.Get(n*RTILE+yy, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*RTILE+xx, Xdata);
+ else
+ X_[yy*RTILE+xx][gy*CTILE+gx] = 0;
+ K_[yy*RTILE+xx][gy*CTILE+gx] = K.Get(dy, dx, (m*CTILE + gy)*RTILE+yy, k*RTILE+xx, WBKdata, WBK.dataLength);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint ii = 0; ii < CTILE; ++ii)
+ {
+ float x[RTILE][RTILE];
+ float k[RTILE][RTILE];
+
+ [unroll]
+ for (uint yy = 0; yy < RTILE; ++yy)
+ {
+ [unroll]
+ for (uint xx = 0; xx < RTILE; ++xx)
+ {
+ x[yy][xx] = X_[yy*RTILE+xx][gy*CTILE+ii];
+ k[yy][xx] = K_[yy*RTILE+xx][ii*CTILE+gx];
+ }
+ }
+
+
+ [unroll]
+ for (uint yyy = 0; yyy < RTILE; ++yyy)
+ {
+ [unroll]
+ for (uint xxx = 0; xxx < RTILE; ++xxx)
+ {
+ [unroll]
+ for (uint i = 0; i < RTILE; ++i)
+ {
+ v[yyy][xxx] = mad(x[yyy][i], k[i][xxx], v[yyy][xxx]);
+ }
+ }
+ }
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ for (uint yy = 0; yy < RTILE; ++yy)
+ for (uint xx = 0; xx < RTILE; ++xx)
+ O.Set(n*RTILE+yy, y, x, k*RTILE+xx, v[yy][xx], Odata);
+
+ #undef X_
+ #undef K_
+}
+
+#elif 1 // =====================================================================================================
+
+#undef CTILE
+#define CTILE 16
+groupshared float2 Conv_KcacheR[8][CTILE*CTILE];
+groupshared float2 Conv_XcacheR[8][CTILE*CTILE];
+[numthreads(CTILE, CTILE, 1)]
+void Conv2D_Kernel3x3_Cache_KCmod64_KNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ Conv_XcacheR
+ #define K_ Conv_KcacheR
+
+ uint gx = groupThreadID.x;
+ uint gy = groupThreadID.y;
+
+ uint k = CTILE * groupID.x + groupThreadID.x;
+ uint nyx = CTILE * groupID.y + groupThreadID.y;
+
+ uint width = X.width - _Border;
+ uint height = X.height - _Border;
+
+ uint x = nyx % width;
+ uint ny = nyx / width;
+ uint y = ny % height;
+ uint n = ny / height;
+
+ float b0 = B.Get(0, 0, k*4+0, 0, WBKdata, WBK.dataLength);
+ float b1 = B.Get(0, 0, k*4+1, 0, WBKdata, WBK.dataLength);
+ float b2 = B.Get(0, 0, k*4+2, 0, WBKdata, WBK.dataLength);
+ float b3 = B.Get(0, 0, k*4+3, 0, WBKdata, WBK.dataLength);
+
+ float4 v0, v1, v2, v3;
+ v0 = v1 = v2 = v3 = float4(b0, b1, b2, b3);
+
+ for (uint dy = 0; dy < 3; ++dy)
+ {
+ bool mask = true;
+
+ if (y+dy < _Offset) mask = false;
+ if (y+dy-_Offset >= X.height) mask = false;
+
+ for (uint dx = 0; dx < 3; ++dx)
+ {
+ if (x+dx < _Offset) mask = false;
+ if (x+dx-_Offset >= X.width) mask = false;
+
+ for (uint m = 0; m < X.channels/(CTILE*4); ++m)
+ {
+ for (uint yy = 0; yy < 4; ++yy)
+ for (uint xx = 0; xx < 2; ++xx)
+ {
+ // 111ms
+ if (mask)
+ {
+ X_[yy*2+xx][gy*CTILE+gx].x = X.Get(n*4+yy, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*4+xx*2+0, Xdata);
+ X_[yy*2+xx][gy*CTILE+gx].y = X.Get(n*4+yy, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*4+xx*2+1, Xdata);
+ }
+ else
+ {
+ X_[yy*2+xx][gy*CTILE+gx].x = 0;
+ X_[yy*2+xx][gy*CTILE+gx].y = 0;
+ }
+
+ K_[yy*2+xx][gy*CTILE+gx].x = K.Get(dy, dx, (m*CTILE + gy)*4+yy, k*4+xx*2+0, WBKdata, WBK.dataLength);
+ K_[yy*2+xx][gy*CTILE+gx].y = K.Get(dy, dx, (m*CTILE + gy)*4+yy, k*4+xx*2+1, WBKdata, WBK.dataLength);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint i = 0; i < CTILE; ++i)
+ {
+ #if 1 // ----------------------------------------------------------
+
+ float2 x[8];
+ float2 k[8];
+
+ // 109ms
+ // dcl_temps 29
+ for (uint regs = 0; regs < 8; ++regs)
+ {
+ x[regs] = X_[regs][gy*CTILE+i];
+ k[regs] = K_[regs][i*CTILE+gx];
+ }
+
+ for (uint q = 0; q < 4; ++q)
+ {
+ float
+ k0 = k[q*2+0].x,
+ k1 = k[q*2+0].y,
+ k2 = k[q*2+1].x,
+ k3 = k[q*2+1].y;
+ float
+ x0 = x[0+q/2].x,
+ x1 = x[2+q/2].x,
+ x2 = x[4+q/2].x,
+ x3 = x[6+q/2].x;
+
+ v0.x = mad(x0, k0, v0.x); //--
+ v1.x = mad(x1, k0, v1.x);
+ v2.x = mad(x2, k0, v2.x);
+ v3.x = mad(x3, k0, v3.x);
+ v0.y = mad(x0, k1, v0.y); //--
+ v1.y = mad(x1, k1, v1.y);
+ v2.y = mad(x2, k1, v2.y);
+ v3.y = mad(x3, k1, v3.y);
+ v0.z = mad(x0, k2, v0.z); //--
+ v1.z = mad(x1, k2, v1.z);
+ v2.z = mad(x2, k2, v2.z);
+ v3.z = mad(x3, k2, v3.z);
+ v0.w = mad(x0, k3, v0.w); //--
+ v1.w = mad(x1, k3, v1.w);
+ v2.w = mad(x2, k3, v2.w);
+ v3.w = mad(x3, k3, v3.w);
+
+ ++q;
+
+ k0 = k[q*2+0].x;
+ k1 = k[q*2+0].y;
+ k2 = k[q*2+1].x;
+ k3 = k[q*2+1].y;
+
+ x0 = x[0+q/2].y;
+ x1 = x[2+q/2].y;
+ x2 = x[4+q/2].y;
+ x3 = x[6+q/2].y;
+
+ v0.x = mad(x0, k0, v0.x); //--
+ v1.x = mad(x1, k0, v1.x);
+ v2.x = mad(x2, k0, v2.x);
+ v3.x = mad(x3, k0, v3.x);
+ v0.y = mad(x0, k1, v0.y); //--
+ v1.y = mad(x1, k1, v1.y);
+ v2.y = mad(x2, k1, v2.y);
+ v3.y = mad(x3, k1, v3.y);
+ v0.z = mad(x0, k2, v0.z); //--
+ v1.z = mad(x1, k2, v1.z);
+ v2.z = mad(x2, k2, v2.z);
+ v3.z = mad(x3, k2, v3.z);
+ v0.w = mad(x0, k3, v0.w); //--
+ v1.w = mad(x1, k3, v1.w);
+ v2.w = mad(x2, k3, v2.w);
+ v3.w = mad(x3, k3, v3.w);
+ }
+
+ #endif // ----------------------------------------------------------
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ #if 1 // ----------------------------------------------------------
+
+ // 117ms
+ O.Set(n*4+0, y, x, k*4+0, v0.x, Odata);
+ O.Set(n*4+0, y, x, k*4+1, v0.y, Odata);
+ O.Set(n*4+0, y, x, k*4+2, v0.z, Odata);
+ O.Set(n*4+0, y, x, k*4+3, v0.w, Odata);
+
+ O.Set(n*4+1, y, x, k*4+0, v1.x, Odata);
+ O.Set(n*4+1, y, x, k*4+1, v1.y, Odata);
+ O.Set(n*4+1, y, x, k*4+2, v1.z, Odata);
+ O.Set(n*4+1, y, x, k*4+3, v1.w, Odata);
+
+ O.Set(n*4+2, y, x, k*4+0, v2.x, Odata);
+ O.Set(n*4+2, y, x, k*4+1, v2.y, Odata);
+ O.Set(n*4+2, y, x, k*4+2, v2.z, Odata);
+ O.Set(n*4+2, y, x, k*4+3, v2.w, Odata);
+
+ O.Set(n*4+3, y, x, k*4+0, v3.x, Odata);
+ O.Set(n*4+3, y, x, k*4+1, v3.y, Odata);
+ O.Set(n*4+3, y, x, k*4+2, v3.z, Odata);
+ O.Set(n*4+3, y, x, k*4+3, v3.w, Odata);
+
+ #else // ----------------------------------------------------------
+
+ // 118ms
+ O.Set(n*4+0, y, x, k*4+0, v0.x, Odata);
+ O.Set(n*4+1, y, x, k*4+0, v1.x, Odata);
+ O.Set(n*4+2, y, x, k*4+0, v2.x, Odata);
+ O.Set(n*4+3, y, x, k*4+0, v3.x, Odata);
+
+ O.Set(n*4+0, y, x, k*4+1, v0.y, Odata);
+ O.Set(n*4+1, y, x, k*4+1, v1.y, Odata);
+ O.Set(n*4+2, y, x, k*4+1, v2.y, Odata);
+ O.Set(n*4+3, y, x, k*4+1, v3.y, Odata);
+
+ O.Set(n*4+0, y, x, k*4+2, v0.z, Odata);
+ O.Set(n*4+1, y, x, k*4+2, v1.z, Odata);
+ O.Set(n*4+2, y, x, k*4+2, v2.z, Odata);
+ O.Set(n*4+3, y, x, k*4+2, v3.z, Odata);
+
+ O.Set(n*4+0, y, x, k*4+3, v0.w, Odata);
+ O.Set(n*4+1, y, x, k*4+3, v1.w, Odata);
+ O.Set(n*4+2, y, x, k*4+3, v2.w, Odata);
+ O.Set(n*4+3, y, x, k*4+3, v3.w, Odata);
+
+ #endif // ----------------------------------------------------------
+
+
+ #undef X_
+ #undef K_
+}
+
+#elif 1 // =====================================================================================================
+
+#undef CTILE
+#define CTILE 16
+groupshared float Conv_KcacheR[16][CTILE*CTILE];
+groupshared float Conv_XcacheR[16][CTILE*CTILE];
+[numthreads(CTILE, CTILE, 1)]
+void Conv2D_Kernel3x3_Cache_KCmod64_KNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ Conv_XcacheR
+ #define K_ Conv_KcacheR
+
+ uint gx = groupThreadID.x;
+ uint gy = groupThreadID.y;
+
+ uint k = CTILE * groupID.x + groupThreadID.x;
+ uint nyx = CTILE * groupID.y + groupThreadID.y;
+
+ uint width = X.width - _Border;
+ uint height = X.height - _Border;
+
+ uint x = nyx % width;
+ uint ny = nyx / width;
+ uint y = ny % height;
+ uint n = ny / height;
+
+ float b0 = B.Get(0, 0, k*4+0, 0, WBKdata, WBK.dataLength);
+ float b1 = B.Get(0, 0, k*4+1, 0, WBKdata, WBK.dataLength);
+ float b2 = B.Get(0, 0, k*4+2, 0, WBKdata, WBK.dataLength);
+ float b3 = B.Get(0, 0, k*4+3, 0, WBKdata, WBK.dataLength);
+
+ float4 v0, v1, v2, v3;
+ v0 = v1 = v2 = v3 = float4(b0, b1, b2, b3);
+
+ for (uint dy = 0; dy < 3; ++dy)
+ {
+ bool mask = true;
+
+ if (y+dy < _Offset) mask = false;
+ if (y+dy-_Offset >= X.height) mask = false;
+
+ for (uint dx = 0; dx < 3; ++dx)
+ {
+ if (x+dx < _Offset) mask = false;
+ if (x+dx-_Offset >= X.width) mask = false;
+
+ for (uint m = 0; m < X.channels/(CTILE*4); ++m)
+ {
+ for (uint yy = 0; yy < 4; ++yy)
+ for (uint xx = 0; xx < 4; ++xx)
+ {
+ #if 1 // ----------------------------------------------------------
+
+ // 111ms
+ if (mask)
+ X_[yy*4+xx][gy*CTILE+gx] = X.Get(n*4+yy, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*4+xx, Xdata);
+ else
+ X_[yy*4+xx][gy*CTILE+gx] = 0;
+ K_[yy*4+xx][gy*CTILE+gx] = K.Get(dy, dx, (m*CTILE + gy)*4+yy, k*4+xx, WBKdata, WBK.dataLength);
+
+ #else // ----------------------------------------------------------
+
+ // 122ms
+ if (mask)
+ X_[yy*4+(gx%4)][gy*CTILE+xx*4+(gx/4)] = X.Get(n*4+yy, y+dy-_Offset, x+dx-_Offset, m*CTILE*4 + xx*CTILE + gx, Xdata);
+ else
+ X_[yy*4+(gx%4)][gy*CTILE+xx*4+(gx/4)] = 0;
+ K_[yy*4+(k%4)][gy*CTILE+xx*4+(gx/4)] = K.Get(dy, dx, (m*CTILE + gy)*4+yy, CTILE*groupID.x*4 + xx*CTILE + gx, WBKdata, WBK.dataLength);
+
+ #endif // ----------------------------------------------------------
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint i = 0; i < CTILE; ++i)
+ {
+
+ #if 0 // ----------------------------------------------------------
+
+ float x[16];
+ float k[16];
+
+ k[0] = K_[0][i*CTILE+gx];
+ x[0] = X_[0][gy*CTILE+i];
+ x[4] = X_[4][gy*CTILE+i];
+ x[8] = X_[8][gy*CTILE+i];
+ x[12] = X_[12][gy*CTILE+i];
+
+ for (uint q = 0; q < 3; ++q)
+ {
+ k[q*4+1] = K_[q*4+1][i*CTILE+gx];
+ v0.x = mad(x[0*4+q], k[q*4+0], v0.x); //--
+ v1.x = mad(x[1*4+q], k[q*4+0], v1.x);
+ x[0*4+q+1] = X_[0*4+q+1][gy*CTILE+i];
+ v2.x = mad(x[2*4+q], k[q*4+0], v2.x);
+ v3.x = mad(x[3*4+q], k[q*4+0], v3.x);
+ k[q*4+2] = K_[q*4+2][i*CTILE+gx];
+ v0.y = mad(x[0*4+q], k[q*4+1], v0.y); //--
+ v1.y = mad(x[1*4+q], k[q*4+1], v1.y);
+ x[1*4+q+1] = X_[1*4+q+1][gy*CTILE+i];
+ v2.y = mad(x[2*4+q], k[q*4+1], v2.y);
+ v3.y = mad(x[3*4+q], k[q*4+1], v3.y);
+ k[q*4+3] = K_[q*4+3][i*CTILE+gx];
+ v0.z = mad(x[0*4+q], k[q*4+2], v0.z); //--
+ v1.z = mad(x[1*4+q], k[q*4+2], v1.z);
+ x[2*4+q+1] = X_[2*4+q+1][gy*CTILE+i];
+ v2.z = mad(x[2*4+q], k[q*4+2], v2.z);
+ v3.z = mad(x[3*4+q], k[q*4+2], v3.z);
+ k[q*4+4] = K_[q*4+4][i*CTILE+gx];
+ v0.w = mad(x[0*4+q], k[q*4+3], v0.w); //--
+ v1.w = mad(x[1*4+q], k[q*4+3], v1.w);
+ x[3*4+q+1] = X_[3*4+q+1][gy*CTILE+i];
+ v2.w = mad(x[2*4+q], k[q*4+3], v2.w);
+ v3.w = mad(x[3*4+q], k[q*4+3], v3.w);
+ }
+ {
+ k[q*4+1] = K_[q*4+1][i*CTILE+gx];
+ v0.x = mad(x[0*4+q], k[q*4+0], v0.x); //--
+ v1.x = mad(x[1*4+q], k[q*4+0], v1.x);
+ v2.x = mad(x[2*4+q], k[q*4+0], v2.x);
+ v3.x = mad(x[3*4+q], k[q*4+0], v3.x);
+ k[q*4+2] = K_[q*4+2][i*CTILE+gx];
+ v0.y = mad(x[0*4+q], k[q*4+1], v0.y); //--
+ v1.y = mad(x[1*4+q], k[q*4+1], v1.y);
+ v2.y = mad(x[2*4+q], k[q*4+1], v2.y);
+ v3.y = mad(x[3*4+q], k[q*4+1], v3.y);
+ k[q*4+3] = K_[q*4+3][i*CTILE+gx];
+ v0.z = mad(x[0*4+q], k[q*4+2], v0.z); //--
+ v1.z = mad(x[1*4+q], k[q*4+2], v1.z);
+ v2.z = mad(x[2*4+q], k[q*4+2], v2.z);
+ v3.z = mad(x[3*4+q], k[q*4+2], v3.z);
+ v0.w = mad(x[0*4+q], k[q*4+3], v0.w); //--
+ v1.w = mad(x[1*4+q], k[q*4+3], v1.w);
+ v2.w = mad(x[2*4+q], k[q*4+3], v2.w);
+ v3.w = mad(x[3*4+q], k[q*4+3], v3.w);
+ }
+
+ #elif 0 // ----------------------------------------------------------
+
+ //float x[4];
+ //float k[4];
+
+ float k0 = K_[0*4+0][i*CTILE+gx];
+ float x0 = X_[0*4+0][gy*CTILE+i];
+ float x1 = X_[1*4+0][gy*CTILE+i];
+ float x2 = X_[2*4+0][gy*CTILE+i];
+ float x3 = X_[3*4+0][gy*CTILE+i];
+
+ float k1, k2, k3;
+ float x0p, x1p, x2p, x3p;
+
+ uint q = 0;
+ //for (uint q = 0; q < 4;)
+ {
+ //x[regs] = X_[regs][gy*CTILE+i];
+
+ k1 = K_[q*4+1][i*CTILE+gx];
+ v0.x = mad(x0, k0, v0.x); //--
+ v1.x = mad(x1, k0, v1.x);
+ x0p = X_[0*4+q+1][gy*CTILE+i];
+ v2.x = mad(x2, k0, v2.x);
+ v3.x = mad(x3, k0, v3.x);
+
+ k2 = K_[q*4+2][i*CTILE+gx];
+ v0.y = mad(x0, k1, v0.y); //--
+ v1.y = mad(x1, k1, v1.y);
+ x1p = X_[1*4+q+1][gy*CTILE+i];
+ v2.y = mad(x2, k1, v2.y);
+ v3.y = mad(x3, k1, v3.y);
+
+ k3 = K_[q*4+3][i*CTILE+gx];
+ v0.z = mad(x0, k2, v0.z); //--
+ v1.z = mad(x1, k2, v1.z);
+ x2p = X_[2*4+q+1][gy*CTILE+i];
+ v2.z = mad(x2, k2, v2.z);
+ v3.z = mad(x3, k2, v3.z);
+
+ k0 = K_[q*4+4][i*CTILE+gx];
+ v0.w = mad(x0, k3, v0.w); //--
+ v1.w = mad(x1, k3, v1.w);
+ x3p = X_[3*4+q+1][gy*CTILE+i];
+ v2.w = mad(x2, k3, v2.w);
+ v3.w = mad(x3, k3, v3.w);
+
+ ++q;
+
+ k1 = K_[q*4+1][i*CTILE+gx];
+ v0.x = mad(x0p, k0, v0.x); //--
+ v1.x = mad(x1p, k0, v1.x);
+ x0 = X_[0*4+q+1][gy*CTILE+i];
+ v2.x = mad(x2p, k0, v2.x);
+ v3.x = mad(x3p, k0, v3.x);
+
+ k2 = K_[q*4+2][i*CTILE+gx];
+ v0.y = mad(x0p, k1, v0.y); //--
+ v1.y = mad(x1p, k1, v1.y);
+ x1 = X_[1*4+q+1][gy*CTILE+i];
+ v2.y = mad(x2p, k1, v2.y);
+ v3.y = mad(x3p, k1, v3.y);
+
+ k3 = K_[q*4+3][i*CTILE+gx];
+ v0.z = mad(x0p, k2, v0.z); //--
+ v1.z = mad(x1p, k2, v1.z);
+ x2 = X_[2*4+q+1][gy*CTILE+i];
+ v2.z = mad(x2p, k2, v2.z);
+ v3.z = mad(x3p, k2, v3.z);
+
+ k0 = K_[q*4+4][i*CTILE+gx];
+ v0.w = mad(x0p, k3, v0.w); //--
+ v1.w = mad(x1p, k3, v1.w);
+ x3 = X_[3*4+q+1][gy*CTILE+i];
+ v2.w = mad(x2p, k3, v2.w);
+ v3.w = mad(x3p, k3, v3.w);
+
+ ++q;
+
+ k1 = K_[q*4+1][i*CTILE+gx];
+ v0.x = mad(x0, k0, v0.x); //--
+ v1.x = mad(x1, k0, v1.x);
+ x0p = X_[0*4+q+1][gy*CTILE+i];
+ v2.x = mad(x2, k0, v2.x);
+ v3.x = mad(x3, k0, v3.x);
+
+ k2 = K_[q*4+2][i*CTILE+gx];
+ v0.y = mad(x0, k1, v0.y); //--
+ v1.y = mad(x1, k1, v1.y);
+ x1p = X_[1*4+q+1][gy*CTILE+i];
+ v2.y = mad(x2, k1, v2.y);
+ v3.y = mad(x3, k1, v3.y);
+
+ k3 = K_[q*4+3][i*CTILE+gx];
+ v0.z = mad(x0, k2, v0.z); //--
+ v1.z = mad(x1, k2, v1.z);
+ x2p = X_[2*4+q+1][gy*CTILE+i];
+ v2.z = mad(x2, k2, v2.z);
+ v3.z = mad(x3, k2, v3.z);
+
+ k0 = K_[q*4+4][i*CTILE+gx];
+ v0.w = mad(x0, k3, v0.w); //--
+ v1.w = mad(x1, k3, v1.w);
+ x3p = X_[3*4+q+1][gy*CTILE+i];
+ v2.w = mad(x2, k3, v2.w);
+ v3.w = mad(x3, k3, v3.w);
+
+ ++q;
+
+ k1 = K_[q*4+1][i*CTILE+gx];
+ v0.x = mad(x0p, k0, v0.x); //--
+ v1.x = mad(x1p, k0, v1.x);
+ //x0p = X_[0*4+q][gy*CTILE+i];
+ v2.x = mad(x2p, k0, v2.x);
+ v3.x = mad(x3p, k0, v3.x);
+
+ k2 = K_[q*4+2][i*CTILE+gx];
+ v0.y = mad(x0p, k1, v0.y); //--
+ v1.y = mad(x1p, k1, v1.y);
+ //x1p = X_[1*4+q][gy*CTILE+i];
+ v2.y = mad(x2p, k1, v2.y);
+ v3.y = mad(x3p, k1, v3.y);
+
+ k3 = K_[q*4+3][i*CTILE+gx];
+ v0.z = mad(x0p, k2, v0.z); //--
+ v1.z = mad(x1p, k2, v1.z);
+ //x2p = X_[2*4+q][gy*CTILE+i];
+ v2.z = mad(x2p, k2, v2.z);
+ v3.z = mad(x3p, k2, v3.z);
+
+ //k0 = K_[(q+1)*4][i*CTILE+gx];
+ v0.w = mad(x0p, k3, v0.w); //--
+ v1.w = mad(x1p, k3, v1.w);
+ //x3p = X_[3*4+q][gy*CTILE+i];
+ v2.w = mad(x2p, k3, v2.w);
+ v3.w = mad(x3p, k3, v3.w);
+
+ ++q;
+ }
+
+
+ #elif 1 // ----------------------------------------------------------
+
+ float x[16];
+ float k[16];
+
+ // 109ms
+ // dcl_temps 29
+ for (uint regs = 0; regs < 16; ++regs)
+ {
+ x[regs] = X_[regs][gy*CTILE+i];
+ k[regs] = K_[regs][i*CTILE+gx];
+ }
+
+ for (uint q = 0; q < 4; ++q)
+ {
+ v0.x = mad(x[0*4+q], k[q*4+0], v0.x); //--
+ v1.x = mad(x[1*4+q], k[q*4+0], v1.x);
+ v2.x = mad(x[2*4+q], k[q*4+0], v2.x);
+ v3.x = mad(x[3*4+q], k[q*4+0], v3.x);
+ v0.y = mad(x[0*4+q], k[q*4+1], v0.y); //--
+ v1.y = mad(x[1*4+q], k[q*4+1], v1.y);
+ v2.y = mad(x[2*4+q], k[q*4+1], v2.y);
+ v3.y = mad(x[3*4+q], k[q*4+1], v3.y);
+ v0.z = mad(x[0*4+q], k[q*4+2], v0.z); //--
+ v1.z = mad(x[1*4+q], k[q*4+2], v1.z);
+ v2.z = mad(x[2*4+q], k[q*4+2], v2.z);
+ v3.z = mad(x[3*4+q], k[q*4+2], v3.z);
+ v0.w = mad(x[0*4+q], k[q*4+3], v0.w); //--
+ v1.w = mad(x[1*4+q], k[q*4+3], v1.w);
+ v2.w = mad(x[2*4+q], k[q*4+3], v2.w);
+ v3.w = mad(x[3*4+q], k[q*4+3], v3.w);
+ }
+
+ #elif 1 // ----------------------------------------------------------
+
+ // 111ms
+ // dcl_temps 34
+ [unroll]
+ for (uint regs = 0; regs < 16; ++regs)
+ {
+ x[regs] = X_[regs][gy*CTILE+i];
+ k[regs] = K_[regs][i*CTILE+gx];
+ }
+ v0.x = mad(x[0*4+0], k[0*4+0], v0.x); //--
+ v1.x = mad(x[1*4+0], k[0*4+0], v1.x);
+ v2.x = mad(x[2*4+0], k[0*4+0], v2.x);
+ v3.x = mad(x[3*4+0], k[0*4+0], v3.x);
+ v0.y = mad(x[0*4+0], k[0*4+1], v0.y); //--
+ v1.y = mad(x[1*4+0], k[0*4+1], v1.y);
+ v2.y = mad(x[2*4+0], k[0*4+1], v2.y);
+ v3.y = mad(x[3*4+0], k[0*4+1], v3.y);
+ v0.z = mad(x[0*4+0], k[0*4+2], v0.z); //--
+ v1.z = mad(x[1*4+0], k[0*4+2], v1.z);
+ v2.z = mad(x[2*4+0], k[0*4+2], v2.z);
+ v3.z = mad(x[3*4+0], k[0*4+2], v3.z);
+ v0.w = mad(x[0*4+0], k[0*4+3], v0.w); //--
+ v1.w = mad(x[1*4+0], k[0*4+3], v1.w);
+ v2.w = mad(x[2*4+0], k[0*4+3], v2.w);
+ v3.w = mad(x[3*4+0], k[0*4+3], v3.w);
+
+ v0.x = mad(x[0*4+1], k[1*4+0], v0.x); //--
+ v1.x = mad(x[1*4+1], k[1*4+0], v1.x);
+ v2.x = mad(x[2*4+1], k[1*4+0], v2.x);
+ v3.x = mad(x[3*4+1], k[1*4+0], v3.x);
+ v0.y = mad(x[0*4+1], k[1*4+1], v0.y); //--
+ v1.y = mad(x[1*4+1], k[1*4+1], v1.y);
+ v2.y = mad(x[2*4+1], k[1*4+1], v2.y);
+ v3.y = mad(x[3*4+1], k[1*4+1], v3.y);
+ v0.z = mad(x[0*4+1], k[1*4+2], v0.z); //--
+ v1.z = mad(x[1*4+1], k[1*4+2], v1.z);
+ v2.z = mad(x[2*4+1], k[1*4+2], v2.z);
+ v3.z = mad(x[3*4+1], k[1*4+2], v3.z);
+ v0.w = mad(x[0*4+1], k[1*4+3], v0.w); //--
+ v1.w = mad(x[1*4+1], k[1*4+3], v1.w);
+ v2.w = mad(x[2*4+1], k[1*4+3], v2.w);
+ v3.w = mad(x[3*4+1], k[1*4+3], v3.w);
+
+ v0.x = mad(x[0*4+2], k[2*4+0], v0.x); //--
+ v1.x = mad(x[1*4+2], k[2*4+0], v1.x);
+ v2.x = mad(x[2*4+2], k[2*4+0], v2.x);
+ v3.x = mad(x[3*4+2], k[2*4+0], v3.x);
+ v0.y = mad(x[0*4+2], k[2*4+1], v0.y); //--
+ v1.y = mad(x[1*4+2], k[2*4+1], v1.y);
+ v2.y = mad(x[2*4+2], k[2*4+1], v2.y);
+ v3.y = mad(x[3*4+2], k[2*4+1], v3.y);
+ v0.z = mad(x[0*4+2], k[2*4+2], v0.z); //--
+ v1.z = mad(x[1*4+2], k[2*4+2], v1.z);
+ v2.z = mad(x[2*4+2], k[2*4+2], v2.z);
+ v3.z = mad(x[3*4+2], k[2*4+2], v3.z);
+ v0.w = mad(x[0*4+2], k[2*4+3], v0.w); //--
+ v1.w = mad(x[1*4+2], k[2*4+3], v1.w);
+ v2.w = mad(x[2*4+2], k[2*4+3], v2.w);
+ v3.w = mad(x[3*4+2], k[2*4+3], v3.w);
+
+ v0.x = mad(x[0*4+3], k[3*4+0], v0.x); //--
+ v1.x = mad(x[1*4+3], k[3*4+0], v1.x);
+ v2.x = mad(x[2*4+3], k[3*4+0], v2.x);
+ v3.x = mad(x[3*4+3], k[3*4+0], v3.x);
+ v0.y = mad(x[0*4+3], k[3*4+1], v0.y); //--
+ v1.y = mad(x[1*4+3], k[3*4+1], v1.y);
+ v2.y = mad(x[2*4+3], k[3*4+1], v2.y);
+ v3.y = mad(x[3*4+3], k[3*4+1], v3.y);
+ v0.z = mad(x[0*4+3], k[3*4+2], v0.z); //--
+ v1.z = mad(x[1*4+3], k[3*4+2], v1.z);
+ v2.z = mad(x[2*4+3], k[3*4+2], v2.z);
+ v3.z = mad(x[3*4+3], k[3*4+2], v3.z);
+ v0.w = mad(x[0*4+3], k[3*4+3], v0.w); //--
+ v1.w = mad(x[1*4+3], k[3*4+3], v1.w);
+ v2.w = mad(x[2*4+3], k[3*4+3], v2.w);
+ v3.w = mad(x[3*4+3], k[3*4+3], v3.w);
+
+ #else // ----------------------------------------------------------
+
+ // 115 ms, reg dependencies
+ // dcl_temps 32
+ [unroll]
+ for (uint regs = 0; regs < 16; ++regs)
+ {
+ x[regs] = X_[regs][gy*CTILE+i];
+ k[regs] = K_[regs][i*CTILE+gx];
+ }
+
+ v0.x = mad(x[0*4+0], k[0*4+0], v0.x); //--
+ v0.x = mad(x[0*4+1], k[1*4+0], v0.x);
+ v0.x = mad(x[0*4+2], k[2*4+0], v0.x);
+ v0.x = mad(x[0*4+3], k[3*4+0], v0.x);
+ v0.y = mad(x[0*4+0], k[0*4+1], v0.y); //--
+ v0.y = mad(x[0*4+1], k[1*4+1], v0.y);
+ v0.y = mad(x[0*4+2], k[2*4+1], v0.y);
+ v0.y = mad(x[0*4+3], k[3*4+1], v0.y);
+ v0.z = mad(x[0*4+0], k[0*4+2], v0.z); //--
+ v0.z = mad(x[0*4+1], k[1*4+2], v0.z);
+ v0.z = mad(x[0*4+2], k[2*4+2], v0.z);
+ v0.z = mad(x[0*4+3], k[3*4+2], v0.z);
+ v0.w = mad(x[0*4+0], k[0*4+3], v0.w); //--
+ v0.w = mad(x[0*4+1], k[1*4+3], v0.w);
+ v0.w = mad(x[0*4+2], k[2*4+3], v0.w);
+ v0.w = mad(x[0*4+3], k[3*4+3], v0.w);
+
+ v1.x = mad(x[1*4+0], k[0*4+0], v1.x); //--
+ v1.x = mad(x[1*4+1], k[1*4+0], v1.x);
+ v1.x = mad(x[1*4+2], k[2*4+0], v1.x);
+ v1.x = mad(x[1*4+3], k[3*4+0], v1.x);
+ v1.y = mad(x[1*4+0], k[0*4+1], v1.y); //--
+ v1.y = mad(x[1*4+1], k[1*4+1], v1.y);
+ v1.y = mad(x[1*4+2], k[2*4+1], v1.y);
+ v1.y = mad(x[1*4+3], k[3*4+1], v1.y);
+ v1.z = mad(x[1*4+0], k[0*4+2], v1.z); //--
+ v1.z = mad(x[1*4+1], k[1*4+2], v1.z);
+ v1.z = mad(x[1*4+2], k[2*4+2], v1.z);
+ v1.z = mad(x[1*4+3], k[3*4+2], v1.z);
+ v1.w = mad(x[1*4+0], k[0*4+3], v1.w); //--
+ v1.w = mad(x[1*4+1], k[1*4+3], v1.w);
+ v1.w = mad(x[1*4+2], k[2*4+3], v1.w);
+ v1.w = mad(x[1*4+3], k[3*4+3], v1.w);
+
+ v2.x = mad(x[2*4+0], k[0*4+0], v2.x); //--
+ v2.x = mad(x[2*4+1], k[1*4+0], v2.x);
+ v2.x = mad(x[2*4+2], k[2*4+0], v2.x);
+ v2.x = mad(x[2*4+3], k[3*4+0], v2.x);
+ v2.y = mad(x[2*4+0], k[0*4+1], v2.y); //--
+ v2.y = mad(x[2*4+1], k[1*4+1], v2.y);
+ v2.y = mad(x[2*4+2], k[2*4+1], v2.y);
+ v2.y = mad(x[2*4+3], k[3*4+1], v2.y);
+ v2.z = mad(x[2*4+0], k[0*4+2], v2.z); //--
+ v2.z = mad(x[2*4+1], k[1*4+2], v2.z);
+ v2.z = mad(x[2*4+2], k[2*4+2], v2.z);
+ v2.z = mad(x[2*4+3], k[3*4+2], v2.z);
+ v2.w = mad(x[2*4+0], k[0*4+3], v2.w); //--
+ v2.w = mad(x[2*4+1], k[1*4+3], v2.w);
+ v2.w = mad(x[2*4+2], k[2*4+3], v2.w);
+ v2.w = mad(x[2*4+3], k[3*4+3], v2.w);
+
+ v3.x = mad(x[3*4+0], k[0*4+0], v3.x); //--
+ v3.x = mad(x[3*4+1], k[1*4+0], v3.x);
+ v3.x = mad(x[3*4+2], k[2*4+0], v3.x);
+ v3.x = mad(x[3*4+3], k[3*4+0], v3.x);
+ v3.y = mad(x[3*4+0], k[0*4+1], v3.y); //--
+ v3.y = mad(x[3*4+1], k[1*4+1], v3.y);
+ v3.y = mad(x[3*4+2], k[2*4+1], v3.y);
+ v3.y = mad(x[3*4+3], k[3*4+1], v3.y);
+ v3.z = mad(x[3*4+0], k[0*4+2], v3.z); //--
+ v3.z = mad(x[3*4+1], k[1*4+2], v3.z);
+ v3.z = mad(x[3*4+2], k[2*4+2], v3.z);
+ v3.z = mad(x[3*4+3], k[3*4+2], v3.z);
+ v3.w = mad(x[3*4+0], k[0*4+3], v3.w); //--
+ v3.w = mad(x[3*4+1], k[1*4+3], v3.w);
+ v3.w = mad(x[3*4+2], k[2*4+3], v3.w);
+ v3.w = mad(x[3*4+3], k[3*4+3], v3.w);
+
+ #endif // ----------------------------------------------------------
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ #if 1 // ----------------------------------------------------------
+
+ // 117ms
+ O.Set(n*4+0, y, x, k*4+0, v0.x, Odata);
+ O.Set(n*4+0, y, x, k*4+1, v0.y, Odata);
+ O.Set(n*4+0, y, x, k*4+2, v0.z, Odata);
+ O.Set(n*4+0, y, x, k*4+3, v0.w, Odata);
+
+ O.Set(n*4+1, y, x, k*4+0, v1.x, Odata);
+ O.Set(n*4+1, y, x, k*4+1, v1.y, Odata);
+ O.Set(n*4+1, y, x, k*4+2, v1.z, Odata);
+ O.Set(n*4+1, y, x, k*4+3, v1.w, Odata);
+
+ O.Set(n*4+2, y, x, k*4+0, v2.x, Odata);
+ O.Set(n*4+2, y, x, k*4+1, v2.y, Odata);
+ O.Set(n*4+2, y, x, k*4+2, v2.z, Odata);
+ O.Set(n*4+2, y, x, k*4+3, v2.w, Odata);
+
+ O.Set(n*4+3, y, x, k*4+0, v3.x, Odata);
+ O.Set(n*4+3, y, x, k*4+1, v3.y, Odata);
+ O.Set(n*4+3, y, x, k*4+2, v3.z, Odata);
+ O.Set(n*4+3, y, x, k*4+3, v3.w, Odata);
+
+ #else // ----------------------------------------------------------
+
+ // 118ms
+ O.Set(n*4+0, y, x, k*4+0, v0.x, Odata);
+ O.Set(n*4+1, y, x, k*4+0, v1.x, Odata);
+ O.Set(n*4+2, y, x, k*4+0, v2.x, Odata);
+ O.Set(n*4+3, y, x, k*4+0, v3.x, Odata);
+
+ O.Set(n*4+0, y, x, k*4+1, v0.y, Odata);
+ O.Set(n*4+1, y, x, k*4+1, v1.y, Odata);
+ O.Set(n*4+2, y, x, k*4+1, v2.y, Odata);
+ O.Set(n*4+3, y, x, k*4+1, v3.y, Odata);
+
+ O.Set(n*4+0, y, x, k*4+2, v0.z, Odata);
+ O.Set(n*4+1, y, x, k*4+2, v1.z, Odata);
+ O.Set(n*4+2, y, x, k*4+2, v2.z, Odata);
+ O.Set(n*4+3, y, x, k*4+2, v3.z, Odata);
+
+ O.Set(n*4+0, y, x, k*4+3, v0.w, Odata);
+ O.Set(n*4+1, y, x, k*4+3, v1.w, Odata);
+ O.Set(n*4+2, y, x, k*4+3, v2.w, Odata);
+ O.Set(n*4+3, y, x, k*4+3, v3.w, Odata);
+
+ #endif // ----------------------------------------------------------
+
+
+ #undef X_
+ #undef K_
+}
+
+#else // =====================================================================================================
+
+#undef CTILE
+#define CTILE 16
+#define RTILE 4
+groupshared float Conv_XcacheR[RTILE*RTILE][CTILE*CTILE];
+groupshared float Conv_KcacheR[RTILE*RTILE][CTILE*CTILE];
+[numthreads(CTILE, CTILE, 1)]
+void Conv2D_Kernel3x3_Cache_KCmod64_KNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ Conv_XcacheR
+ #define K_ Conv_KcacheR
+
+ uint gx = groupThreadID.x;
+ uint gy = groupThreadID.y;
+
+ uint k = CTILE * groupID.x + groupThreadID.x;
+ uint nyx = CTILE * groupID.y + groupThreadID.y;
+
+ uint width = X.width - _Border;
+ uint height = X.height - _Border;
+
+ uint x = nyx % width;
+ uint ny = nyx / width;
+ uint y = ny % height;
+ uint n = ny / height;
+
+ float v[RTILE*RTILE];
+ for (uint xxxx = 0; xxxx < RTILE; ++xxxx)
+ {
+ float b = B.Get(0, 0, k*RTILE+xxxx, 0, WBKdata, WBK.dataLength);
+ for (uint yyyy = 0; yyyy < RTILE; ++yyyy)
+ v[yyyy*RTILE+xxxx] = b;
+ }
+
+ for (uint dy = 0; dy < 3; ++dy)
+ {
+ bool mask = true;
+
+ if (y+dy < _Offset) mask = false;
+ if (y+dy-_Offset >= X.height) mask = false;
+
+ for (uint dx = 0; dx < 3; ++dx)
+ {
+ if (x+dx < _Offset) mask = false;
+ if (x+dx-_Offset >= X.width) mask = false;
+
+ for (uint m = 0; m < X.channels/(CTILE*RTILE); ++m)
+ {
+
+ for (uint yy = 0; yy < RTILE; ++yy)
+ for (uint xx = 0; xx < RTILE; ++xx)
+ {
+ if (mask)
+ X_[yy*RTILE+xx][gy*CTILE+gx] = X.Get(n*RTILE+yy, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*RTILE+xx, Xdata);
+ else
+ X_[yy*RTILE+xx][gy*CTILE+gx] = 0;
+ K_[yy*RTILE+xx][gy*CTILE+gx] = K.Get(dy, dx, (m*CTILE + gy)*RTILE+yy, k*RTILE+xx, WBKdata, WBK.dataLength);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint ii = 0; ii < CTILE; ++ii)
+ {
+ float x[RTILE*RTILE];
+ float k[RTILE*RTILE];
+
+ [unroll]
+ for (uint iii = 0; iii < RTILE*RTILE; ++iii)
+ {
+ x[iii] = X_[iii][gy*CTILE+ii];
+ k[iii] = K_[iii][ii*CTILE+gx];
+ }
+
+ [unroll]
+ for (uint r = 0; r < RTILE*RTILE; ++r)
+ {
+ [unroll]
+ for (uint i = 0; i < RTILE; ++i)
+ {
+ uint xxx = r % RTILE;
+ v[r] = mad(x[r], k[i*RTILE+xxx], v[r]);
+
+ //v[yyy][xxx] += x[yyy][i] * k[i][xxx];
+ }
+ }
+
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ for (uint yy = 0; yy < RTILE; ++yy)
+ for (uint xx = 0; xx < RTILE; ++xx)
+ O.Set(n*RTILE+yy, y, x, k*RTILE+xx, v[yy*RTILE+xx], Odata);
+
+ #undef X_
+ #undef K_
+}
+#endif
+
+[numthreads(CTILE, CTILE, 1)]
+void Conv2D_Kernel3x3_Cache_KCmod16_KNyx_TEMPLATE(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ #define X_ Conv_XcacheT
+ #define K_ Conv_KcacheT
+
+ uint gx = groupThreadID.x;
+ uint gy = groupThreadID.y;
+
+ uint k = CTILE * groupID.x + groupThreadID.x;
+ uint nyx = CTILE * groupID.y + groupThreadID.y;
+
+ uint width = X.width - _Border;
+ uint height = X.height - _Border;
+
+ uint x = nyx % width;
+ uint ny = nyx / width;
+ uint y = ny % height;
+ uint n = ny / height;
+
+ float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
+ for (uint dy = 0; dy < 3; ++dy)
+ {
+ bool mask = true;
+
+ if (y+dy < _Offset) mask = false;
+ if (y+dy-_Offset >= X.height) mask = false;
+
+ for (uint dx = 0; dx < 3; ++dx)
+ {
+ if (x+dx < _Offset) mask = false;
+ if (x+dx-_Offset >= X.width) mask = false;
+
+ //for (uint m = 0; m < (9*128)/CTILE; ++m)
+ for (uint m = 0; m < X.channels/CTILE; ++m)
+ {
+ if (mask)
+ X_[gy][gx] = X.Get(n, y+dy-_Offset, x+dx-_Offset, m*CTILE + gx, Xdata);
+ else
+ X_[gy][gx] = 0;
+ K_[gy][gx] = K.Get(dy, dx, m*CTILE + gy, k, WBKdata, WBK.dataLength);
+ GroupMemoryBarrierWithGroupSync();
+
+ [unroll]
+ for (uint i = 0; i < CTILE; ++i)
+ {
+ float x = X_[gy][i];
+ float k =.25;// K_[i][gx];
+ v += x * k;
+ }
+ }
+ }
+ }
+
+ //Odata[nyx * O.channels + k] = v;
+
+ Odata[((
+ n * O.height +
+ y ) * O.width +
+ x ) * O.channels +
+ k] = v;
+
+ #undef X_
+ #undef K_
+}
+// %TODO: only supports up to 51 kernels (51 = 16*16*2/(9kernel+1bias)) for now. Add a loop to handle more!
+/*
+groupshared float K1cache[KERNEL_SIZE][KERNEL_SIZE][32];
+groupshared float B1cache[32];
+[numthreads(16,16,2)]
+void Conv2D_Kernel3x3_1Channel(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint k = 16*groupID.x + groupThreadID.x;
+ uint n = 16*groupID.y + groupThreadID.y;
+ uint y = 2*groupID.z + groupThreadID.z + _FilterSize;
+
+ uint idx = 16*16*groupThreadID.z + 16*groupThreadID.y + groupThreadID.x;
+ if (idx < 9 * K.channels)
+ {
+ uint kx = idx / K.channels;
+ uint kk = idx % K.channels;
+ K1cache[kx/3][kx%3][kk] = K.Get(kx/3, kx%3, 0, kk, WBKdata, WBK.dataLength);
+ }
+ else if (idx < 10 * K.channels)
+ {
+ uint kk = idx % K.channels;
+ B1cache[kk] = B.Get(0, 0, kk, 0, WBKdata, WBK.dataLength);
+ }
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint x = _FilterSize; x < X.width - _FilterSize; ++x)
+ {
+ float v = B1cache[k];//B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
+ for (int i = -_FilterSize; i < _FilterSize + 1; ++i)
+ {
+ for (int j = -_FilterSize; j < _FilterSize + 1; ++j)
+ {
+ v += X.Get(n, y+j, x+i, 0, Xdata, X.dataLength) * K1cache[_FilterSize+j][_FilterSize+i][k];
+ }
+ }
+ O.Set(n, y-_FilterSize, x-_FilterSize, k, v, Odata, O.dataLength);
+ }
+}
+*/
+
+groupshared float K1cache[32][9];
+[numthreads(32,16,1)]
+void Conv2D_Kernel3x3_1Channel(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ uint tk = groupThreadID.x;
+ uint k = 32*groupID.x + tk;
+ uint n = 16*groupID.y + groupThreadID.y;
+
+ //for (uint q = 0; q < 9; ++q)
+ {
+ uint q = n % 9;
+ K1cache[tk][q] = K.Get(q/3, q%3, 0, k, WBKdata, WBK.dataLength);
+ }
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint y = 0; y < X.height - _FilterSize*2; ++y)
+ {
+ for (uint x = 0; x < X.width - _FilterSize*2; ++x)
+ {
+ float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
+ //for (uint q = 0; q < 9; ++q)
+ // v += X.Get(n, y+q/3, x+q%3, 0, Xdata, X.dataLength) * K1cache[tk][q];
+ v += X.Get(n, y+0, x+0, 0, Xdata, X.dataLength) * K1cache[tk][0];
+ v += X.Get(n, y+0, x+1, 0, Xdata, X.dataLength) * K1cache[tk][1];
+ v += X.Get(n, y+0, x+2, 0, Xdata, X.dataLength) * K1cache[tk][2];
+
+ v += X.Get(n, y+1, x+0, 0, Xdata, X.dataLength) * K1cache[tk][3];
+ v += X.Get(n, y+1, x+1, 0, Xdata, X.dataLength) * K1cache[tk][4];
+ v += X.Get(n, y+1, x+2, 0, Xdata, X.dataLength) * K1cache[tk][5];
+
+ v += X.Get(n, y+2, x+0, 0, Xdata, X.dataLength) * K1cache[tk][6];
+ v += X.Get(n, y+2, x+1, 0, Xdata, X.dataLength) * K1cache[tk][7];
+ v += X.Get(n, y+2, x+2, 0, Xdata, X.dataLength) * K1cache[tk][8];
+
+ O.Set(n, y, x, k, v, Odata, O.dataLength);
+ }
+ }
+}
+
+float fillValue;
+
+[numthreads(1,1,1)]
+void Fill(uint3 groupID : SV_GroupID)
+{
+ uint b = groupID.x;
+ uint h = groupID.y;
+ uint w = groupID.z;
+ for (uint ch = 0; ch < O.channels; ++ch)
+ O.Set(b, h, w, ch+1, fillValue, Odata, O.dataLength);
+}
+#endif
+
+
+/*
+Cbufferconsts{
+ uint n;
+ uint dispatchDim_x;};
+#define groupDim_x 512
+groupshared float Accumulate_sharedMem[groupDim_x * channels];
+[numthreads(groupDim_x, 1, 1)]
+void Accumulate(uint tid: SV_GroupIndex, uint3 groupIdx: groupID)
+{
+ #define sharedMem Reduce_sharedMem
+ unsigned int i = groupIdx.x * (groupDim_x * 2) + tid;
+ unsigned int dispatchSize = (groupDim_x * 2) * dispatchDim_x;
+ sharedMem[tid] = 0;
+ do {
+ sharedMem[tid] += g_idata[i] + g_idata[i+groupDim_x];
+ i += dispatchSize;
+ } while (i < n);
+ GroupMemoryBarrierWithGroupSync();
+
+ if (groupDim_x >= 256)
+ {
+ if (tid < 128) { sharedMem[tid] += sharedMem[tid + 128 * channels]; }
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ if (groupDim_x >= 128)
+ {
+ if (tid < 64) { sharedMem[tid] += sharedMem[tid + 64]; }
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ if (tid < 32)
+ {
+ if (groupDim_x >= 64) sharedMem[tid] += sharedMem[tid + 32* channels];
+ if (groupDim_x >= 32) sharedMem[tid] += sharedMem[tid + 16* channels];
+ if (groupDim_x >= 16) sharedMem[tid] += sharedMem[tid + 8* channels];
+ if (groupDim_x >= 8) sharedMem[tid] += sharedMem[tid + 4* channels];
+ if (groupDim_x >= 4) sharedMem[tid] += sharedMem[tid + 2* channels];
+ if (groupDim_x >= 2) sharedMem[tid] += sharedMem[tid + 1* channels];
+ }
+
+ if (tid == 0) g_odata[groupIdx.x] = sharedMem[0];
+
+ #undef sharedMem
+}
+*/
+ /*
+// Could do to reduce across NxN patch fitting within a group, HW <= HW / N
+// Repeat, until HW == 1
+
+// Alternatively reduce across Y axis, then X
+
+#undef MAX_CHANNELS
+#define MAX_CHANNELS 2048
+groupshared float GlobalAvgPool2D_AccumulatorPerChannel[MAX_CHANNELS];
+[numthreads(4,8,8)]
+void GlobalAvgPool2D(uint3 dispatchThreadID : SV_DispatchThreadID, uint threadID : SV_ThreadID)
+{
+ // NOTE: dispatched over X (not O)
+ DISPATCH_ARGS(X.channels, X.width, X.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= X.channels || c >= MAX_CHANNELS) return;
+ if (x >= X.width) return;
+ if (y >= X.height) return;
+
+ // Accumulate
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ // Clear accumulator
+ // @TODO: ThreadID
+ //uint threadID = groupThreadID.x * 4 + groupThreadID.y * 8 + groupThreadID.z * 8;
+ if (threadID < MAX_CHANNELS)
+ GlobalAvgPool2D_AccumulatorPerChannel[threadID] = 0;
+ GroupMemoryBarrierWithGroupSync();
+
+ GlobalAvgPool2D_AccumulatorPerChannel[c] += X.Get(n, y, x, c);
+ // @TODO: atomicAdd?
+
+ GroupMemoryBarrierWithGroupSync();
+ if (threadID < MAX_CHANNELS)
+ {
+ float v = GlobalAvgPool2D_AccumulatorPerChannel[threadID];
+ O.Set(n, 0, 0, c, v / (X.width * X.height));
+ }
+ }
+}*/
+
+
+[numthreads(64,2,2)]
+void Conv2D_Reg2x2(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ uint k = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (k >= K.channels) return;
+ if (x*2 >= O.width) return;
+ if (y*2 >= O.height) return;
+
+ uint2 leftCorner = _Pad.xy;
+ uint2 rightCorner = uint2(X.width, X.height) + _Pad.xy;
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float4 acc = B.Get(k);
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos0 = uint2(x*2+0, y*2+0) * _Stride.xy + uint2(dx, dy);
+ uint2 pos1 = uint2(x*2+1, y*2+0) * _Stride.xy + uint2(dx, dy);
+ uint2 pos2 = uint2(x*2+0, y*2+1) * _Stride.xy + uint2(dx, dy);
+ uint2 pos3 = uint2(x*2+1, y*2+1) * _Stride.xy + uint2(dx, dy);
+
+ for (uint c = 0; c < X.channels; ++c)
+ {
+ if (all(pos0 >= leftCorner) && all(pos0 < rightCorner))
+ acc.x = fastfma(X.Get(n, pos0 - leftCorner, c), K.Get(dy, dx, c, k), acc.x);
+ if (all(pos1 >= leftCorner) && all(pos1 < rightCorner))
+ acc.y = fastfma(X.Get(n, pos1 - leftCorner, c), K.Get(dy, dx, c, k), acc.y);
+ if (all(pos2 >= leftCorner) && all(pos2 < rightCorner))
+ acc.z = fastfma(X.Get(n, pos2 - leftCorner, c), K.Get(dy, dx, c, k), acc.z);
+ if (all(pos3 >= leftCorner) && all(pos3 < rightCorner))
+ acc.w = fastfma(X.Get(n, pos3 - leftCorner, c), K.Get(dy, dx, c, k), acc.w);
+ }
+ }
+ }
+
+ O.Set(n, y*2+0, x*2+0, k, acc.x);
+ O.Set(n, y*2+0, x*2+1, k, acc.y);
+ O.Set(n, y*2+1, x*2+0, k, acc.z);
+ O.Set(n, y*2+1, x*2+1, k, acc.w);
+ }
+}
+
+#define SIZE 2
+[numthreads(64, 2, 2)]
+void Conv2D_Reg_Loop(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ uint k = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (k >= K.channels) return;
+ if (x*SIZE >= O.width) return;
+ if (y*SIZE >= O.height) return;
+
+ uint2 leftCorner = _Pad.xy;
+ uint2 rightCorner = uint2(X.width, X.height) + _Pad.xy;
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc[SIZE*SIZE];
+ [unroll]
+ for (uint q = 0; q < SIZE*SIZE; ++q)
+ acc[q] = B.Get(k);
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos[SIZE*SIZE];
+ [unroll]
+ for (uint q = 0; q < SIZE*SIZE; ++q)
+ pos[q] = uint2(x*SIZE+(q%SIZE), y*SIZE+(q/SIZE)) * _Stride.xy + uint2(dx, dy);
+
+ // @TODO: investigate
+ // WARNING: had to move both y check into the loop (as opposed to checking y in parent loop) - due to potential bug in Metal compiler
+
+ for (uint c = 0; c < X.channels; ++c)
+ [unroll]
+ for (q = 0; q < SIZE*SIZE; ++q)
+ if (all(pos[q] >= leftCorner) && all(pos[q] < rightCorner))
+ acc[q] = fastfma(X.Get(n, pos[q] - leftCorner, c), K.Get(dy, dx, c, k), acc[q]);
+ }
+ }
+
+ [unroll]
+ for (q = 0; q < SIZE*SIZE; ++q)
+ O.Set(n, y*SIZE+(q/SIZE), x*SIZE+(q%SIZE), k, acc[q]);
+ }
+}
+
+NUMTHREADS((16,4,4), (8,4,4), (16,2,2))
+//[numthreads(64, 1, 1)]
+void Conv2D_safe(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ uint k = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (k >= K.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc = B.Get(k);
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos = uint2(x, y) * _Stride.xy + uint2(dx, dy);
+
+ for (uint c = 0; c < X.channels; ++c)
+ acc = fastfma(X.SafeGet(n, pos, c, _Pad.xy), K.Get(dy, dx, c, k), acc);
+ }
+ }
+
+ O.Set(n, y, x, k, acc);
+ }
+}
+
+
+#undef L1CACHESIZE
+#define L1CACHESIZE 32
+groupshared float Conv2D_L1Cached32_X[L1CACHESIZE];
+[numthreads(L1CACHESIZE, 1, 1)]
+void Conv2D_L1Cached32(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ #define X_ Conv2D_L1Cached32_X
+
+ uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
+ uint x = groupID.y;
+ uint y = groupID.z;
+
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc = B.SafeGet(k);
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos = uint2(x,y) * _Stride.xy + uint2(dx,dy);
+
+ for (uint c = 0; c < X.channels; c += L1CACHESIZE)
+ {
+ // Cache X
+ X_[groupThreadID.x] = X.SafeGet(n, pos, c + groupThreadID.x, _Pad.xy);
+ GroupMemoryBarrierWithGroupSync();
+
+ // X * K
+ if (k < K.channels)
+ {
+ for (uint dc = 0; dc < L1CACHESIZE; ++dc)
+ acc = fastfma(X_[dc], K.Get(dy, dx, c + dc, k), acc);
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ O.Set(n, y, x, k, acc);
+ }
+
+ #undef X_
+}
+
+#undef L1CACHESIZE
+#define L1CACHESIZE 64
+groupshared float Conv2D_L1Cached64_X[L1CACHESIZE];
+[numthreads(L1CACHESIZE, 1, 1)]
+void Conv2D_L1Cached64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ #define X_ Conv2D_L1Cached64_X
+
+ uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
+ uint x = groupID.y;
+ uint y = groupID.z;
+
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc = B.SafeGet(k);
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos = uint2(x,y) * _Stride.xy + uint2(dx,dy);
+ for (uint c = 0; c < X.channels; c += L1CACHESIZE)
+ {
+ // Cache X
+ X_[groupThreadID.x] = X.SafeGet(n, pos, c + groupThreadID.x, _Pad.xy);
+ GroupMemoryBarrierWithGroupSync();
+
+ // X * K
+ if (k < K.channels)
+ {
+ for (uint dc = 0; dc < L1CACHESIZE; ++dc)
+ acc = fastfma(X_[dc], K.Get(dy, dx, c + dc, k), acc);
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ O.Set(n, y, x, k, acc);
+ }
+
+ #undef X_
+}
+
+
+#undef SIZE
+#define SIZE 2
+[numthreads(64, 2, 2)]
+void Conv2D_Reg_Loop_safe(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ uint k = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (k >= K.channels) return;
+ if (x*SIZE >= O.width) return;
+ if (y*SIZE >= O.height) return;
+
+ uint2 leftCorner = _Pad.xy;
+ uint2 rightCorner = uint2(X.width, X.height) + _Pad.xy;
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc[SIZE*SIZE];
+ [unroll]
+ for (uint q = 0; q < SIZE*SIZE; ++q)
+ acc[q] = B.Get(k);
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos[SIZE*SIZE];
+ [unroll]
+ for (uint q = 0; q < SIZE*SIZE; ++q)
+ pos[q] = uint2(x*SIZE+(q%SIZE), y*SIZE+(q/SIZE)) * _Stride.xy + uint2(dx, dy);
+
+ // @TODO: investigate
+ // WARNING: had to move both y check into the loop (as opposed to checking y in parent loop) - due to potential bug in Metal compiler
+
+ for (uint c = 0; c < X.channels; ++c)
+ [unroll]
+ for (q = 0; q < SIZE*SIZE; ++q)
+ acc[q] = fastfma(X.SafeGet(n, pos[q], c, _Pad.xy), K.Get(dy, dx, c, k), acc[q]);
+ }
+ }
+
+ [unroll]
+ for (q = 0; q < SIZE*SIZE; ++q)
+ O.Set(n, y*SIZE+(q/SIZE), x*SIZE+(q%SIZE), k, acc[q]);
+ }
+}
+
+
+#undef L1CACHESIZE
+#define L1CACHESIZE 64
+#undef SIZE
+#define SIZE 2
+groupshared float Conv2D_L1Cached64_Reg_Loop2x2_X[SIZE*SIZE][L1CACHESIZE];
+[numthreads(L1CACHESIZE, 1, 1)]
+void Conv2D_L1Cached64_Reg_Loop2x2(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ #define X_ Conv2D_L1Cached64_Reg_Loop2x2_X
+
+ uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
+ uint x = groupID.y;
+ uint y = groupID.z;
+
+ // need all threads to load channels, thus will do late check against kernel count
+ if (x*SIZE >= O.width) return;
+ if (y*SIZE >= O.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc[SIZE*SIZE];
+ [unroll]
+ for (uint q = 0; q < SIZE*SIZE; ++q)
+ acc[q] = B.SafeGet(k);
+
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos[SIZE*SIZE];
+ [unroll]
+ for (uint q = 0; q < SIZE*SIZE; ++q)
+ pos[q] = uint2(x*SIZE+(q%SIZE), y*SIZE+(q/SIZE)) * _Stride.xy + uint2(dx, dy);
+
+ for (uint c = 0; c < X.channels; c += L1CACHESIZE)
+ {
+ // Cache X
+ uint dc = groupThreadID.x;
+ [unroll]
+ for (q = 0; q < SIZE*SIZE; ++q)
+ X_[q][dc] = X.SafeGet(n, pos[q], c + dc, _Pad.xy);
+ GroupMemoryBarrierWithGroupSync();
+
+ // X * K
+ if (k < K.channels) // need all threads to load channels, thus late check against kernel count
+ {
+ uint kIndex = K.Index(dy, dx, c, k);
+ for (dc = 0; dc < L1CACHESIZE; ++dc)
+ {
+ for (q = 0; q < SIZE*SIZE; ++q)
+ acc[q] = fastfma(X_[q][dc], K.data[kIndex], acc[q]); //K.Get(dy, dx, c + dc, k);
+ kIndex += K.channels;
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ if (k < K.channels) // need all threads to load channels, thus late check against kernel count
+ [unroll]
+ for (q = 0; q < SIZE*SIZE; ++q)
+ O.Set(n, y*SIZE+(q/SIZE), x*SIZE+(q%SIZE), k, acc[q]);
+ }
+
+ #undef X_
+}
+
+
+#undef L1CACHESIZE
+#define L1CACHESIZE 64
+#undef SIZE
+#define SIZE 4
+groupshared float Conv2D_L1Cached64_Reg_Loop_X[SIZE*SIZE][L1CACHESIZE];
+[numthreads(L1CACHESIZE, 1, 1)]
+void Conv2D_L1Cached64_Reg_Loop(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ #define X_ Conv2D_L1Cached64_Reg_Loop_X
+
+ uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
+ uint x = groupID.y;
+ uint y = groupID.z;
+
+ // need all threads to load channels, thus will do late check against kernel count
+ if (x*SIZE >= O.width) return;
+ if (y*SIZE >= O.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc[SIZE*SIZE];
+ [unroll]
+ for (uint q = 0; q < SIZE*SIZE; ++q)
+ acc[q] = B.SafeGet(k);
+
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos[SIZE*SIZE];
+ [unroll]
+ for (uint q = 0; q < SIZE*SIZE; ++q)
+ pos[q] = uint2(x*SIZE+(q%SIZE), y*SIZE+(q/SIZE)) * _Stride.xy + uint2(dx, dy);
+
+ for (uint c = 0; c < X.channels; c += L1CACHESIZE)
+ {
+ // Cache X
+ uint dc = groupThreadID.x;
+ [unroll]
+ for (q = 0; q < SIZE*SIZE; ++q)
+ X_[q][dc] = X.SafeGet(n, pos[q], c + dc, _Pad.xy);
+ GroupMemoryBarrierWithGroupSync();
+
+ // X * K
+ if (k < K.channels) // need all threads to load channels, thus late check against kernel count
+ {
+ uint kIndex = K.Index(dy, dx, c, k);
+ for (dc = 0; dc < L1CACHESIZE; ++dc)
+ {
+ for (q = 0; q < SIZE*SIZE; ++q)
+ acc[q] = fastfma(X_[q][dc], K.data[kIndex], acc[q]);//K.Get(dy, dx, c + dc, k);
+ kIndex += K.channels;
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ if (k < K.channels) // need all threads to load channels, thus late check against kernel count
+ [unroll]
+ for (q = 0; q < SIZE*SIZE; ++q)
+ O.Set(n, y*SIZE+(q/SIZE), x*SIZE+(q%SIZE), k, acc[q]);
+ }
+
+ #undef X_
+}
+
+
+#undef L1CACHESIZE
+#define L1CACHESIZE 64
+#define SIZE_W 4
+#define SIZE_H 2
+groupshared float Conv2D_L1Cached64_Reg_Loop_safe__X[SIZE_H*SIZE_W][L1CACHESIZE];
+[numthreads(L1CACHESIZE, 1, 1)]
+void Conv2D_L1Cached64_Reg_Loop_safe_(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ #define X_ Conv2D_L1Cached64_Reg_Loop_safe__X
+
+ uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
+ uint x = groupID.y;
+ uint y = groupID.z;
+
+ // need all threads to load channels, thus will do late check against kernel count
+ if (x*SIZE_W >= O.width) return;
+ if (y*SIZE_H >= O.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc[SIZE_H*SIZE_W];
+ [unroll]
+ for (uint q = 0; q < SIZE_H*SIZE_W; ++q)
+ acc[q] = B.SafeGet(k);
+
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos[SIZE_H*SIZE_W];
+ [unroll]
+ for (uint q = 0; q < SIZE_H*SIZE_W; ++q)
+ pos[q] = uint2(x*SIZE_W+(q%SIZE_W), y*SIZE_H+(q/SIZE_W)) * _Stride.xy + uint2(dx, dy);
+
+ for (uint c = 0; c < X.channels; c += L1CACHESIZE)
+ {
+ // Cache X
+ uint dc = groupThreadID.x;
+ [unroll]
+ for (q = 0; q < SIZE_H*SIZE_W; ++q)
+ X_[q][dc] = X.SafeGet(n, pos[q], c + dc, _Pad.xy);
+ GroupMemoryBarrierWithGroupSync();
+
+ // X * K
+ if (k < K.channels) // need all threads to load channels, thus late check against kernel count
+ {
+ uint kIndex = K.Index(dy, dx, c, k);
+ for (dc = 0; dc < L1CACHESIZE; ++dc)
+ {
+ [unroll]
+ for (q = 0; q < SIZE_H*SIZE_W; ++q)
+ acc[q] = fastfma(X_[q][dc], K.data[kIndex], acc[q]);
+ kIndex += K.channels;
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ if (k < K.channels) // need all threads to load channels, thus late check against kernel count
+ [unroll]
+ for (q = 0; q < SIZE_H*SIZE_W; ++q)
+ {
+ uint ox = x*SIZE_W+(q%SIZE_W);
+ uint oy = y*SIZE_H+(q/SIZE_W);
+ if (ox < O.width && oy < O.height)
+ O.Set(n, oy, ox, k, acc[q]);
+ }
+ }
+
+ #undef X_
+}
+#undef SIZE_H
+#undef SIZE_W
+
+
+/*
+#undef L1CACHESIZE
+#define L1CACHESIZE 32
+#define SIZE_W 4
+#define SIZE_H 2
+groupshared float Conv2D_L1Cached64_Reg_Loop_safe__X[SIZE_H*SIZE_W][L1CACHESIZE];
+[numthreads(L1CACHESIZE, SIZE_W, SIZE_H)]
+void Conv2D_L1Cached64_Reg_Loop_safe_(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ #define X_ Conv2D_L1Cached64_Reg_Loop_safe__X
+
+ uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
+ uint x = SIZE_W * groupID.y + groupThreadID.y;
+ uint y = SIZE_H * groupID.z + groupThreadID.z;
+
+ // need all threads to load channels, thus will do late check against kernel count
+ //if (x*SIZE_W >= O.width) return;
+ //if (y*SIZE_H >= O.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc[SIZE_H*SIZE_W];
+ [unroll]
+ for (uint q = 0; q < SIZE_H*SIZE_W; ++q)
+ acc[q] = B.SafeGet(k);
+
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ //uint2 pos[SIZE_H*SIZE_W];
+ //[unroll]
+ //for (uint q = 0; q < SIZE_H*SIZE_W; ++q)
+ // pos[q] = uint2(x*SIZE_W+(q%SIZE_W), y*SIZE_H+(q/SIZE_W)) * _Stride.xy + uint2(dx, dy);
+
+ for (uint c = 0; c < X.channels; c += L1CACHESIZE)
+ {
+ // Cache X
+ uint dc = groupThreadID.x;
+ uint gx = groupThreadID.y;
+ uint gy = groupThreadID.z;
+ //[unroll]
+ //for (q = 0; q < SIZE_H*SIZE_W; ++q)
+ //{
+ uint2 pos = uint2(x*SIZE_W+gx, y*SIZE_H+gy) * _Stride.xy + uint2(dx, dy);
+ X_[SIZE_W*gy+gx][dc] = X.SafeGet(n, pos, c + dc, _Pad.xy);
+ //}
+ GroupMemoryBarrierWithGroupSync();
+
+ // X * K
+ if (k < K.channels &&
+ x*SIZE_W < O.width &&
+ y*SIZE_H < O.height) // need all threads to load channels, thus late check against kernel count
+ {
+ uint kIndex = K.Index(dy, dx, c, k);
+ for (dc = 0; dc < L1CACHESIZE; ++dc)
+ {
+ [unroll]
+ for (q = 0; q < SIZE_H*SIZE_W; ++q)
+ acc[q] += X_[q][dc] * K.data[kIndex];//K.Get(dy, dx, c + dc, k);
+ kIndex += K.channels;
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ if (k < K.channels) // need all threads to load channels, thus late check against kernel count
+ [unroll]
+ for (q = 0; q < SIZE_H*SIZE_W; ++q)
+ {
+ uint ox = x*SIZE_W+(q%SIZE_W);
+ uint oy = y*SIZE_H+(q/SIZE_W);
+ if (ox < O.width && oy < O.height)
+ O.Set(n, oy, ox, k, acc[q]);
+ }
+ }
+
+ #undef X_
+}
+#undef SIZE_H
+#undef SIZE_W
+*/
+
+/*
+#undef L1CACHESIZE
+#define L1CACHESIZE 64
+groupshared float Conv2D_RegCached_X[4][L1CACHESIZE];
+[numthreads(L1CACHESIZE, 1, 1)]
+void Conv2D_RegCached(uint3 dispatchThreadID : SV_DispatchThreadID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ #define X_ Conv2D_RegCached_X
+
+ uint k = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (x*2 >= O.width) return;
+ if (y*2 >= O.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float4 acc = B.SafeGet(k);
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint2 pos0 = uint2(x*2+0,y*2+0) * _Stride + uint2(dx,dy);
+ uint2 pos1 = uint2(x*2+1,y*2+0) * _Stride + uint2(dx,dy);
+ uint2 pos2 = uint2(x*2+0,y*2+1) * _Stride + uint2(dx,dy);
+ uint2 pos3 = uint2(x*2+1,y*2+1) * _Stride + uint2(dx,dy);
+
+ // Cache X
+ uint c_ = groupThreadID.x;
+ if (c_ < X.channels)
+ {
+ X_[0][c_] = X.SafeGet(n, pos0, c_, _Pad.xy);
+ X_[1][c_] = X.SafeGet(n, pos1, c_, _Pad.xy);
+ X_[2][c_] = X.SafeGet(n, pos2, c_, _Pad.xy);
+ X_[3][c_] = X.SafeGet(n, pos3, c_, _Pad.xy);
+ }
+ GroupMemoryBarrierWithGroupSync();
+
+ // X * K
+ if (k < K.channels)
+ for (uint c = 0; c < X.channels; ++c)
+ {
+ acc.x += X_[0][c] * K.Get(dy, dx, c, k);
+ acc.y += X_[1][c] * K.Get(dy, dx, c, k);
+ acc.z += X_[2][c] * K.Get(dy, dx, c, k);
+ acc.w += X_[3][c] * K.Get(dy, dx, c, k);
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+
+ O.Set(n, y*2+0, x*2+0, k, acc.x);
+ O.Set(n, y*2+0, x*2+1, k, acc.y);
+ O.Set(n, y*2+1, x*2+0, k, acc.z);
+ O.Set(n, y*2+1, x*2+1, k, acc.w);
+ }
+}
+*/
+
+/*
+[numthreads(16,4,4)]
+void Conv2DTrans(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ uint k = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (k >= K.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ uint2 strideMask = _Stride.xy - 1;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc = B.Get(k);
+ for (uint dy = 0; dy < K.GetKernelHeight(); dy += _Stride.y)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); dx += _Stride.x)
+ {
+ uint dxShifted = dx + (x&strideMask.x);
+ uint dyShifted = dy + (y&strideMask.y);
+
+ uint xx = x + dxShifted;
+ uint yy = y + dyShifted;
+
+ uint oy = (yy - _Pad.y) / _Stride.y;
+ uint ox = (xx - _Pad.x) / _Stride.x;
+
+ bool mask = xx >= _Pad.x && yy >= _Pad.y && ox < X.width && oy < X.height;
+ if (!mask) continue;
+
+ // [unroll] - crashes metal compiler
+ for (uint c = 0; c < X.channels; ++c)
+ {
+ acc += X.Get(n, oy, ox, c) * K.Get( K.GetKernelHeight() - 1 - dyShifted,
+ K.GetKernelWidth() - 1 - dxShifted, c, k);
+ }
+ }
+ }
+
+ O.Set(n, y, x, k, acc);
+ }
+}
+*/
+
+
+
+#undef SIZE
+#define SIZE 4
+[numthreads(16, 4, 4)]
+void Conv2DTrans_Reg_Loop_safe(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ uint k = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (k >= K.channels) return;
+ if (x*SIZE >= O.width) return;
+ if (y*SIZE >= O.height) return;
+
+ uint2 strideMask = _Stride.xy - 1;
+
+ uint2 pad = _Pad.xy / _Stride.xy;
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc[SIZE*SIZE];
+ [unroll]
+ for (uint q = 0; q < SIZE*SIZE; ++q)
+ acc[q] = B.Get(k);
+
+ for (uint dy = 0; dy < K.GetKernelHeight(); dy += _Stride.y)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); dx += _Stride.x)
+ {
+ uint2 kernelPos[SIZE*SIZE];
+ uint2 pos[SIZE*SIZE];
+
+ [unroll]
+ for (uint q = 0; q < SIZE*SIZE; ++q)
+ {
+ uint2 xy = uint2(x*SIZE+(q%SIZE), y*SIZE+(q/SIZE));
+ kernelPos[q] = uint2(dx, dy) + (xy & strideMask);
+ pos[q] = (xy + kernelPos[q]) / _Stride.xy;
+
+ // transpose
+ kernelPos[q] = uint2(K.GetKernelWidth(), K.GetKernelHeight()) - 1 - kernelPos[q];
+ }
+
+ for (uint c = 0; c < X.channels; ++c)
+ [unroll]
+ for (q = 0; q < SIZE*SIZE; ++q)
+ acc[q] = fastfma(X.SafeGet(n, pos[q], c, pad.xy), K.Get(kernelPos[q].y, kernelPos[q].x, c, k), acc[q]);
+ //acc[q] += X.SafeGet(n, pos[q], c, pad.xy) * K.Get(kernelPos[q].y, kernelPos[q].x, c, k);
+ }
+ }
+
+ [unroll]
+ for (q = 0; q < SIZE*SIZE; ++q)
+ O.Set(n, y*SIZE+(q/SIZE), x*SIZE+(q%SIZE), k, acc[q]);
+ }
+}
+
+
+
+#undef L1CACHESIZE
+#define L1CACHESIZE 64
+#define SIZE_W 4
+#define SIZE_H 2
+groupshared float Conv2DTrans_L1Cached64_Reg_Loop_safe__X[SIZE_H*SIZE_W][L1CACHESIZE];
+[numthreads(L1CACHESIZE, 1, 1)]
+void Conv2DTrans_L1Cached64_Reg_Loop_safe_(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ #define X_ Conv2DTrans_L1Cached64_Reg_Loop_safe__X
+
+ uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
+ uint x = groupID.y;
+ uint y = groupID.z;
+
+ // need all threads to load channels, thus will do late check against kernel count
+ if (x*SIZE_W >= O.width) return;
+ if (y*SIZE_H >= O.height) return;
+
+ uint2 strideMask = _Stride.xy - 1;
+ uint2 pad = _Pad.xy / _Stride.xy;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc[SIZE_H*SIZE_W];
+ [unroll]
+ for (uint q = 0; q < SIZE_H*SIZE_W; ++q)
+ acc[q] = B.SafeGet(k);
+
+ for (uint dy = 0; dy < K.GetKernelHeight(); dy += _Stride.y)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); dx += _Stride.x)
+ {
+ uint2 kernelPos[SIZE_H*SIZE_W];
+ uint2 pos[SIZE_H*SIZE_W];
+
+ [unroll]
+ for (uint q = 0; q < SIZE_H*SIZE_W; ++q)
+ {
+ uint2 xy = uint2(x*SIZE_W+(q%SIZE_W), y*SIZE_H+(q/SIZE_W));
+ kernelPos[q] = uint2(dx, dy) + (xy & strideMask);
+ pos[q] = (xy + kernelPos[q]) / _Stride.xy;
+
+ // transpose
+ kernelPos[q] = uint2(K.GetKernelWidth(), K.GetKernelHeight()) - 1 - kernelPos[q];
+ }
+
+ for (uint c = 0; c < X.channels; c += L1CACHESIZE)
+ {
+ // Cache X
+ uint dc = groupThreadID.x;
+ [unroll]
+ for (q = 0; q < SIZE_H*SIZE_W; ++q)
+ X_[q][dc] = X.SafeGet(n, pos[q], c + dc, pad.xy);
+ GroupMemoryBarrierWithGroupSync();
+
+ // X * K
+ if (k < K.channels) // need all threads to load channels, thus late check against kernel count
+ {
+ for (dc = 0; dc < L1CACHESIZE; ++dc)
+ {
+ [unroll]
+ for (q = 0; q < SIZE_H*SIZE_W; ++q)
+ acc[q] = fastfma(X_[q][dc], K.Get(kernelPos[q].y, kernelPos[q].x, c + dc, k), acc[q]);
+ //acc[q] += X_[q][dc] * K.Get(kernelPos[q].y, kernelPos[q].x, c + dc, k);
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ if (k < K.channels) // need all threads to load channels, thus late check against kernel count
+ [unroll]
+ for (q = 0; q < SIZE_H*SIZE_W; ++q)
+ {
+ uint ox = x*SIZE_W+(q%SIZE_W);
+ uint oy = y*SIZE_H+(q/SIZE_W);
+ if (ox < O.width && oy < O.height)
+ O.Set(n, oy, ox, k, acc[q]);
+ }
+ }
+
+ #undef X_
+}
+#undef SIZE_H
+#undef SIZE_W
+
+
+/*
+#undef L1CACHESIZE
+#define L1CACHESIZE 64
+groupshared float Conv2DTrans_L1Cached64_Reg_Loop_safe_X[L1CACHESIZE];
+[numthreads(L1CACHESIZE, 1, 1)]
+void Conv2DTrans_L1Cached64_Reg_Loop_safe(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ DISPATCH_ARGS(K.kernelCount, X.width, X.height);
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ #define X_ Conv2DTrans_L1Cached64_Reg_Loop_safe_X
+
+ uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
+ uint x = groupID.y;
+ uint y = groupID.z;
+
+ // need all threads to load channels, thus will do late check against kernel count
+ if (x >= X.width) return;
+ if (y >= X.height) return;
+
+ uint2 pad = _Pad.xy / _Stride.xy;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ for (uint sy = 0; sy < _Stride.y; ++sy)
+ {
+ for (uint sx = 0; sx < _Stride.x; ++sx)
+ {
+ float acc = B.SafeGet(k);
+
+ for (uint dy = sy; dy < K.GetKernelHeight(); dy += _Stride.y)
+ {
+ for (uint dx = sx; dx < K.GetKernelWidth(); dx += _Stride.x)
+ {
+ uint2 pos = uint2(x, y) + uint2(sx + dx, sy + dy) / _Stride.xy;
+
+ for (uint c = 0; c < X.channels; c += L1CACHESIZE)
+ {
+ // Cache X
+ uint dc = groupThreadID.x;
+ X_[dc] = X.SafeGet(n, pos, c + dc, pad);
+ GroupMemoryBarrierWithGroupSync();
+
+ // X * K
+ if (k < K.channels) // need all threads to load channels, thus late check against kernel count
+ {
+ for (dc = 0; dc < L1CACHESIZE; ++dc)
+ {
+ acc = fastfma( X_[dc],
+ K.Get( K.GetKernelHeight() - 1 - dy,
+ K.GetKernelWidth() - 1 - dx, c + dc, k),
+ acc);
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ uint oy = y * _Stride.y + sy;
+ uint ox = x * _Stride.x + sx;
+ if (oy < O.height && ox < O.width && k < K.channels)
+ O.Set(n, oy, ox, k, acc);
+ }
+ }
+ }
+
+ #undef X_
+}
+*/
+#endif
+
+
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Experimental.compute.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Experimental.compute.meta
new file mode 100755
index 00000000..49e7b42d
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Experimental.compute.meta
@@ -0,0 +1,9 @@
+fileFormatVersion: 2
+guid: 299ca130202014274b506123e830c52d
+timeCreated: 1506672486
+licenseType: Pro
+ComputeShaderImporter:
+ currentAPIMask: 196608
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/FastNV.compute b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/FastNV.compute
new file mode 100755
index 00000000..afb83544
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/FastNV.compute
@@ -0,0 +1,188 @@
+//#pragma kernel Dense64
+//#pragma kernel Conv2D_Kernel3x3_64
+
+#include "Tensor.cginc"
+
+TENSOR_DECL(X)
+TENSOR_DECL(W)
+TENSOR_DECL(K)
+TENSOR_DECL(B)
+TENSOR_DECL(WBK)
+TENSOR_DECL_RW(O)
+
+uint4 _Pad;
+uint4 _Stride;
+
+#undef THREAD_COUNT
+#define THREAD_COUNT 64 // ATM support only 8x8
+
+#undef BLOCK_WIDTH
+#define BLOCK_WIDTH 8
+
+#undef LOAD_WIDTH
+#define LOAD_WIDTH THREAD_COUNT
+
+#undef LOAD_DEPTH
+#define LOAD_DEPTH BLOCK_WIDTH
+
+groupshared float DenseTiled_XcacheR[LOAD_DEPTH][LOAD_WIDTH];
+groupshared float DenseTiled_WcacheR[LOAD_DEPTH][LOAD_WIDTH];
+
+[numthreads(THREAD_COUNT, 1, 1)]
+void Dense64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ // @TODO: DISPATCH_ARGS(...)
+ TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
+
+ #define X_ DenseTiled_XcacheR
+ #define W_ DenseTiled_WcacheR
+
+ uint id = groupThreadID.x;
+ uint bx = groupID.x;
+ uint by = groupID.y;
+
+ uint bbx = id % BLOCK_WIDTH;
+ uint bby = id / BLOCK_WIDTH;
+
+ float v[BLOCK_WIDTH][BLOCK_WIDTH];
+ for (uint yy = 0; yy < BLOCK_WIDTH; ++yy)
+ for (uint xx = 0; xx < BLOCK_WIDTH; ++xx)
+ {
+ float bias = B.Get(bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xx);
+ v[yy][xx] = bias;
+ }
+
+ for (uint m = 0; m < X.GetFlatWidth()/LOAD_DEPTH; ++m)
+ {
+ for (uint q = 0; q < LOAD_DEPTH; ++q)
+ {
+ X_[q][id] = X.Get(by*LOAD_WIDTH + id, m*LOAD_DEPTH + q);
+ W_[q][id] = W.Get(m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
+ [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
+ [unroll] for (uint i = 0; i < LOAD_DEPTH; ++i)
+ {
+ v[yyy][xxx] = mad(X_[i][bby*BLOCK_WIDTH + yyy], W_[i][bbx*BLOCK_WIDTH + xxx], v[yyy][xxx]);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
+ for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
+ O.Set(by*LOAD_WIDTH + bby*BLOCK_WIDTH + yyy, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xxx, v[yyy][xxx]);
+
+ #undef X_
+ #undef W_
+}
+
+
+#undef THREAD_COUNT
+#define THREAD_COUNT 64 // ATM support only 8x8
+
+#undef BLOCK_WIDTH
+#define BLOCK_WIDTH 8
+
+#undef LOAD_WIDTH
+#define LOAD_WIDTH THREAD_COUNT
+
+#undef LOAD_DEPTH
+#define LOAD_DEPTH BLOCK_WIDTH
+
+groupshared float Conv_KcacheR[LOAD_DEPTH][LOAD_WIDTH];
+groupshared float Conv_XcacheR[LOAD_DEPTH][LOAD_WIDTH];
+[numthreads(THREAD_COUNT, 1, 1)]
+void Conv2D_Kernel3x3_64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+ // @TODO: DISPATCH_ARGS(...)
+ TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
+
+ #define X_ Conv_XcacheR
+ #define K_ Conv_KcacheR
+
+ uint id = groupThreadID.x;
+ uint bx = groupID.x;
+ uint by = groupID.y;
+
+ uint bbx = id % BLOCK_WIDTH;
+ uint bby = id / BLOCK_WIDTH;
+
+ uint width = O.width;
+ uint height = O.height;
+
+ // ASSERT(LOAD_WIDTH == THREAD_COUNT)
+ uint loadNYX = by*LOAD_WIDTH + id; // only works for 8x8
+ uint loadX = loadNYX % width;
+ uint loadNY = loadNYX / width;
+ uint loadY = loadNY % height;
+ uint loadN = loadNY / height;
+
+ // @TODO: validate that _Stride works, added the following 2 lines without testing
+ loadX *= _Stride.x;
+ loadY *= _Stride.y;
+
+ float v[BLOCK_WIDTH][BLOCK_WIDTH];
+ [unroll] for (uint yy = 0; yy < BLOCK_WIDTH; ++yy)
+ [unroll] for (uint xx = 0; xx < BLOCK_WIDTH; ++xx)
+ {
+ float bias = B.Get(bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xx);
+ v[yy][xx] = bias;
+ }
+
+ for (uint dy = 0; dy < 3; ++dy)
+ {
+ bool mask = true;
+
+ if (loadY+dy < _Pad.y) mask = false;
+ if (loadY+dy - _Pad.w >= X.height) mask = false;
+
+ for (uint dx = 0; dx < 3; ++dx)
+ {
+ if (loadX+dx < _Pad.x) mask = false;
+ if (loadX+dx - _Pad.z >= X.width) mask = false;
+
+ for (uint m = 0; m < X.channels/LOAD_DEPTH; ++m)
+ {
+ for (uint q = 0; q < LOAD_DEPTH; ++q)
+ {
+ if (mask)
+ X_[q][id] = X.Get(loadN, loadY+dy-_Pad.y, loadX+dx-_Pad.x, m*LOAD_DEPTH + q);
+ else
+ X_[q][id] = 0;
+ K_[q][id] = K.Get(dy, dx, m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
+ [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
+ [unroll] for (uint i = 0; i < LOAD_DEPTH; ++i)
+ {
+ v[yyy][xxx] += X_[i][bby*BLOCK_WIDTH + yyy] * K_[i][bbx*BLOCK_WIDTH + xxx];
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+ }
+ }
+
+ [unroll] for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
+ [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
+ {
+ uint saveNYX = by*LOAD_WIDTH + bby*BLOCK_WIDTH + yyy;
+ uint saveX = saveNYX % width;
+ uint saveNY = saveNYX / width;
+ uint saveY = saveNY % height;
+ uint saveN = saveNY / height;
+
+ uint saveK = bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xxx;
+ O.Set(saveN, saveY, saveX, saveK, v[yyy][xxx]);
+ }
+
+ #undef X_
+ #undef K_
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/FastNV.compute.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/FastNV.compute.meta
new file mode 100755
index 00000000..91a84252
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/FastNV.compute.meta
@@ -0,0 +1,9 @@
+fileFormatVersion: 2
+guid: c7c673db45e6845d5abaed4ed5ef42e1
+timeCreated: 1507294253
+licenseType: Pro
+ComputeShaderImporter:
+ currentAPIMask: 196608
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Generic.compute b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Generic.compute
new file mode 100755
index 00000000..4ca7769c
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Generic.compute
@@ -0,0 +1,339 @@
+#pragma kernel ScaleBias
+#pragma kernel ScaleBias_CNyx
+#pragma kernel Upsample2D
+#pragma kernel AvgPool2D
+#pragma kernel MaxPool2D
+#pragma kernel AvgPool2D_NoPads
+#pragma kernel MaxPool2D_NoPads
+//#pragma kernel MaxPool2D_Pool2x2_NoPads
+#pragma kernel GlobalAvgPool2D
+#pragma kernel InstanceNorm
+#pragma kernel Copy
+
+#include "Tensor.cginc"
+
+TENSOR_DECL(X)
+TENSOR_DECL(W)
+TENSOR_DECL(B)
+TENSOR_DECL(WBK)
+TENSOR_DECL_RW(O)
+
+uint4 _Pool;
+uint4 _Stride;
+uint4 _Pad;
+float _Alpha;
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void ScaleBias(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= O.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ float bias = B.Get(0, 0, 0, c);
+ float scale = W.Get(0, 0, 0, c);
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ v = v * scale + bias;
+ O.Set(n, y, x, c, v);
+ }
+}
+
+NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
+void ScaleBias_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
+ TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
+
+ uint c = dispatchThreadID.x;
+ uint nyx = dispatchThreadID.y;
+
+ uint x = nyx % X.width;
+ uint ny = nyx / X.width;
+ uint y = ny % X.height;
+ uint n = ny / X.height;
+
+ if (c >= X.channels) return;
+ if (n >= X.batch) return;
+
+ float bias = B.Get(0, 0, 0, c);
+ float scale = W.Get(0, 0, 0, c);
+
+ float v = X.Get(n, y, x, c);
+ v = v * scale + bias;
+ O.Set(n, y, x, c, v);
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void Upsample2D(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // NOTE: dispatched over X (not O)
+ DISPATCH_ARGS(X.channels, X.width, X.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= X.channels) return;
+ if (x >= X.width) return;
+ if (y >= X.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+
+ for (uint dy = 0; dy < _Pool.y; ++dy)
+ for (uint dx = 0; dx < _Pool.x; ++dx)
+ {
+ uint oy = y * _Pool.y + dy;
+ uint ox = x * _Pool.x + dx;
+ O.Set(n, oy, ox, c, v);
+ }
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void MaxPool2D(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= O.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float maxV = -FLT_MAX;
+ for (uint dy = 0; dy < _Pool.y; ++dy)
+ for (uint dx = 0; dx < _Pool.x; ++dx)
+ {
+ uint oy = y * _Stride.y + dy;
+ uint ox = x * _Stride.x + dx;
+
+ bool mask = (oy >= _Pad.y) && (ox >= _Pad.x) && (oy - _Pad.w < X.height) && (ox - _Pad.z < X.width);
+ float v = (mask)? X.Get(n, oy - _Pad.y, ox - _Pad.x, c): 0;
+ maxV = max(v, maxV);
+ }
+
+ O.Set(n, y, x, c, maxV);
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void AvgPool2D(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= O.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float acc = 0;
+ float counter = 0;
+ for (uint dy = 0; dy < _Pool.y; ++dy)
+ for (uint dx = 0; dx < _Pool.x; ++dx)
+ {
+ uint oy = y * _Stride.y + dy;
+ uint ox = x * _Stride.x + dx;
+
+ bool mask = (oy >= _Pad.y) && (ox >= _Pad.x) && (oy - _Pad.w < X.height) && (ox - _Pad.z < X.width);
+ acc += (mask)? X.Get(n, oy - _Pad.y, ox - _Pad.x, c): 0;
+ counter += (mask)? 1: 0;
+ }
+
+ acc /= counter;
+ O.Set(n, y, x, c, acc);
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void MaxPool2D_NoPads(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= O.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float maxV = -FLT_MAX;
+ for (uint dy = 0; dy < _Pool[1]; ++dy)
+ for (uint dx = 0; dx < _Pool[0]; ++dx)
+ {
+ float v = X.Get(n, y * _Stride[1] + dy, x * _Stride[0] + dx, c);
+ maxV = max(v, maxV);
+ }
+
+ O.Set(n, y, x, c, maxV);
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void AvgPool2D_NoPads(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= O.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ float invPoolSize = 1.0f / (_Pool[0] * _Pool[1]);
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = 0;
+ for (uint dy = 0; dy < _Pool[1]; ++dy)
+ for (uint dx = 0; dx < _Pool[0]; ++dx)
+ v += X.Get(n, y * _Stride[1] + dy, x * _Stride[0] + dx, c) * invPoolSize;
+
+ O.Set(n, y, x, c, v);
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+//NUMTHREADS((16,4,4), (16,4,2), (16,2,2))
+void MaxPool2D_Pool2x2_NoPads(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, O.width, O.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (c >= O.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v0 = X.Get(n, y*2, x*2, c);
+ float v1 = X.Get(n, y*2+1, x*2, c);
+ float v2 = X.Get(n, y*2, x*2+1, c);
+ float v3 = X.Get(n, y*2+1, x*2+1, c);
+ float v = max(v0, max(v1, max(v2, v3)));
+
+ O.Set(n, y, x, c, v);
+ }
+}
+
+[numthreads(32,1,1)]
+void GlobalAvgPool2D(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, 1, 1);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x;
+ if (c >= O.channels) return;
+ //ASSERT(X.batch == O.batch)
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = 0;
+ for (uint y = 0; y < X.height; ++y)
+ for (uint x = 0; x < X.width; ++x)
+ v += X.Get(n, y, x, c);
+
+ v /= (X.height * X.width);
+ O.Set(n, 0, 0, c, v);
+ }
+}
+
+[numthreads(64,1,1)]
+void InstanceNorm(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ DISPATCH_ARGS(O.channels, 1, 1);
+ TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
+
+ uint c = dispatchThreadID.x;
+ if (c >= O.channels) return;
+ //ASSERT(X.shape == O.shape)
+
+ float gamma = W.Get(0, 0, 0, c);
+ float beta = B.Get(0, 0, 0, c);
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ uint x, y;
+ // calc mean
+ float acc = 0;
+ for (y = 0; y < O.height; ++y)
+ for (x = 0; x < O.width; ++x)
+ acc += X.Get(n, y, x, c);
+ float mean = acc / (O.width * O.height);
+
+ // calc variance
+ acc = 0;
+ for (y = 0; y < O.height; ++y)
+ for (x = 0; x < O.width; ++x)
+ {
+ float delta = X.Get(n, y, x, c) - mean;
+ acc += delta * delta;
+ }
+ float var = acc / (O.width * O.height);
+
+ // normalization factor
+ float invNormFactor = 1 / sqrt(var + FLT_EPSILON);
+
+ float scale = gamma * invNormFactor;
+ float bias = beta - gamma * mean * invNormFactor;
+
+ // apply normalization
+ for (y = 0; y < O.height; ++y)
+ for (x = 0; x < O.width; ++x)
+ {
+ float v = X.Get(n, y, x, c);
+ v = v * scale + bias;
+ O.Set(n, y, x, c, v);
+ }
+ }
+}
+
+NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
+void Copy(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // NOTE: dispatched over X (not O)
+ DISPATCH_ARGS(X.channels, X.width, X.height);
+ TENSOR_ARGS2(X, O);
+
+ uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
+ if (c >= X.channels) return; if (x >= X.width) return; if (y >= X.height) return;
+
+ for (uint n = 0; n < X.batch; ++n)
+ {
+ float v = X.Get(n, y, x, c);
+ O.Set(n + _Pad[0], y + _Pad[1], x + _Pad[2], c + _Pad[3], v);
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Generic.compute.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Generic.compute.meta
new file mode 100755
index 00000000..47cf3515
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Generic.compute.meta
@@ -0,0 +1,9 @@
+fileFormatVersion: 2
+guid: 62f5efacd43b24dd38ead3ce0d80cc34
+timeCreated: 1495527718
+licenseType: Pro
+ComputeShaderImporter:
+ currentAPIMask: 196608
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Random.cginc b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Random.cginc
new file mode 100755
index 00000000..0c416189
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Random.cginc
@@ -0,0 +1,70 @@
+
+// Based on: https://stackoverflow.com/questions/5149544/can-i-generate-a-random-number-inside-a-pixel-shader
+// Output: Random number: [0,1), that is between 0.0 and 0.999999... inclusive.
+// Author: Michael Pohoreski
+// Copyright: Copyleft 2012 :-)
+float RandomUsingCos(float4 seed)
+{
+ float4 K1 = float4( // Transcendental numbers:
+ 0.64341054629, // (Cahen's constant)
+ 23.14069263277926, // e^pi (Gelfond's constant)
+ 2.665144142690225, // 2^sqrt(2) (Gelfond-Schneider constant)
+ 3.14159265359 // pi
+ );
+ return frac(cos(dot(seed, K1)) * 12345.6789);
+}
+
+// Based on: https://stackoverflow.com/questions/4200224/random-noise-functions-for-glsl
+// Author: Spatial
+// 05 July 2013
+
+// A single iteration of Bob Jenkins' One-At-A-Time hashing algorithm.
+uint hash(uint x)
+{
+ x += ( x << 10u );
+ x ^= ( x >> 6u );
+ x += ( x << 3u );
+ x ^= ( x >> 11u );
+ x += ( x << 15u );
+ return x;
+}
+uint hash( uint2 v ) { return hash( v.x ^ hash(v.y) ); }
+uint hash( uint3 v ) { return hash( v.x ^ hash(v.y) ^ hash(v.z) ); }
+uint hash( uint4 v ) { return hash( v.x ^ hash(v.y) ^ hash(v.z) ^ hash(v.w) ); }
+
+// Construct a float with half-open range [0:1] using low 23 bits.
+// All zeroes yields 0.0, all ones yields the next smallest representable value below 1.0.
+float floatConstruct(uint m)
+{
+ const uint ieeeMantissa = 0x007FFFFFu; // binary32 mantissa bitmask
+ const uint ieeeOne = 0x3F800000u; // 1.0 in IEEE binary32
+
+ m &= ieeeMantissa; // Keep only mantissa bits (fractional part)
+ m |= ieeeOne; // Add fractional part to 1.0
+
+ float f = asfloat(m); // Range [1:2]
+ return f - 1.0; // Range [0:1]
+}
+
+// Pseudo-random value in half-open range [0:1].
+float RandomUsingHash(float4 seed)
+{
+ return floatConstruct(hash(asuint(seed)));
+}
+
+
+// More alternatives:
+// https://github.com/ashima/webgl-noise
+// https://www.shadertoy.com/view/4djSRW
+
+// ------------------------------------------------------------------------------------------
+
+float Random(float4 seed)
+{
+ return RandomUsingCos(seed);
+}
+
+float Bernoulli(float4 seed, float p)
+{
+ return Random(seed) <= p ? 1: 0;
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Random.cginc.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Random.cginc.meta
new file mode 100755
index 00000000..572d47b4
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Random.cginc.meta
@@ -0,0 +1,10 @@
+fileFormatVersion: 2
+guid: 5a17e0b3943a74564a02a8ed0a41228b
+timeCreated: 1520855309
+licenseType: Pro
+ShaderImporter:
+ externalObjects: {}
+ defaultTextures: []
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Tensor.cginc b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Tensor.cginc
new file mode 100755
index 00000000..2ca608f5
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Tensor.cginc
@@ -0,0 +1,311 @@
+#define BARRACUDA_MAX_THREAD_COUNT 64
+#if (BARRACUDA_MAX_THREAD_COUNT>=256)
+#define NUMTHREADS(t256,t128,t64) [numthreads t256]
+#define NUMTHREAD(t256, t128, t64) t256
+#elif (BARRACUDA_MAX_THREAD_COUNT>=128)
+#define NUMTHREADS(t256,t128,t64) [numthreads t128]
+#define NUMTHREAD(t256,t128,t64) t128
+#elif (BARRACUDA_MAX_THREAD_COUNT>=64)
+#define NUMTHREADS(t256,t128,t64) [numthreads t64]
+#define NUMTHREAD(t256,t128,t64) t64
+#endif
+
+struct Tensor
+{
+ // @TODO: actually uint seems not like a good idea anymore, consider going to int
+ uint batch, height, width, channels;
+
+ void Init(uint4 nhwc)
+ {
+ batch = nhwc.x;
+ height = nhwc.y;
+ width = nhwc.z;
+ channels = nhwc.w;
+ }
+
+ uint4 Dims()
+ {
+ return uint4(batch, height, width, channels);
+ }
+ uint GetFlatHeight()
+ {
+ return batch;
+ }
+ uint GetFlatWidth()
+ {
+ return height * width * channels;
+ }
+ uint GetKernelHeight()
+ {
+ // kernels storage: {kernel_width * kernel_height * kernel_channels * kernel_count}
+ uint kernelHeight = batch;
+ return kernelHeight;
+ }
+ uint GetKernelWidth()
+ {
+ // kernels storage: {kernel_width * kernel_height * kernel_channels * kernel_count}
+ uint kernelWidth = height;
+ return kernelWidth;
+ }
+
+ uint Index(uint b, uint h, uint w, uint ch)
+ {
+ uint index =
+ b * height * width * channels +
+ h * width * channels +
+ w * channels +
+ ch;
+ return index;
+ }
+
+ uint Index(uint b, uint i)
+ {
+ uint index =
+ b * height * width * channels +
+ i;
+ return index;
+ }
+};
+
+struct ReadonlyTensor : Tensor
+{
+ StructuredBuffer data;
+
+ void Init(uint4 nhwc, StructuredBuffer data_)
+ {
+ Tensor::Init(nhwc);
+ data = data_;
+ }
+
+ float Get(uint b, uint h, uint w, uint ch)
+ {
+ return data[Index(b,h,w,ch)];
+ }
+ float Get(uint b, uint2 pos, uint ch)
+ {
+ return data[Index(b, pos.y, pos.x, ch)];
+ }
+ float Get(uint b, uint i)
+ {
+ return data[Index(b,i)];
+ }
+ float Get(uint i)
+ {
+ return data[i];
+ }
+
+ float BroadcastGet(uint b, uint h, uint w, uint ch)
+ {
+ return Get(b % batch, h % height, w % width, ch % channels);
+ }
+ float BroadcastGet(uint b, uint2 pos, uint ch)
+ {
+ return BroadcastGet(b, pos.y, pos.x, ch);
+ }
+ float BroadcastGet(uint b, uint i)
+ {
+ return Get(b % GetFlatHeight(), i % GetFlatWidth());
+ }
+
+ float SafeGet(uint b, uint2 pos, uint ch, uint2 pad)
+ {
+ if (b >= batch || ch >= channels) return 0;
+
+ if (any(pos < pad)) return 0;
+ if (any(pos >= uint2(width, height) + pad)) return 0;
+ pos -= pad;
+
+ return data[Index(b, pos.y, pos.x, ch)];
+ }
+ float SafeGet(uint b, uint h, uint w, uint ch, uint2 pad)
+ {
+ return SafeGet(b, uint2(w, h), ch, pad);
+ }
+ float SafeGet(uint b, uint i)
+ {
+ if (b >= batch || i >= height * width * channels) return 0;
+ return Get(b,i);
+ }
+ float SafeGet(uint i)
+ {
+ if (i >= batch * height * width * channels) return 0;
+ return Get(i);
+ }
+};
+
+struct ReadWriteTensor : Tensor
+{
+ RWStructuredBuffer data;
+
+ void Init(int4 nhwc, RWStructuredBuffer data_)
+ {
+ Tensor::Init(nhwc);
+ data = data_;
+ }
+
+ float Get(uint b, uint h, uint w, uint ch)
+ {
+ return data[Index(b,h,w,ch)];
+ }
+ float Get(uint b, uint2 pos, uint ch)
+ {
+ return data[Index(b, pos.y, pos.x, ch)];
+ }
+ float Get(uint b, uint i)
+ {
+ return data[Index(b,i)];
+ }
+ float Get(uint i)
+ {
+ return data[i];
+ }
+
+ float BroadcastGet(uint b, uint h, uint w, uint ch)
+ {
+ return Get(b % batch, h % height, w % width, ch % channels);
+ }
+ float BroadcastGet(uint b, uint2 pos, uint ch)
+ {
+ return BroadcastGet(b, pos.y, pos.x, ch);
+ }
+ float BroadcastGet(uint b, uint i)
+ {
+ return Get(b % GetFlatHeight(), i % GetFlatWidth());
+ }
+
+ float SafeGet(uint b, uint2 pos, uint ch, uint2 pad)
+ {
+ if (b >= batch || ch >= channels) return 0;
+
+ if (any(pos < pad)) return 0;
+ if (any(pos >= uint2(width, height) + pad)) return 0;
+ pos -= pad;
+
+ return Get(b, pos.y, pos.x, ch);
+ }
+ float SafeGet(uint b, uint h, uint w, uint ch, uint2 pad)
+ {
+ return SafeGet(b, uint2(w, h), ch, pad);
+ }
+ float SafeGet(uint b, uint i)
+ {
+ if (b >= batch || i >= height * width * channels) return 0;
+ return Get(b,i);
+ }
+ float SafeGet(uint i)
+ {
+ if (i >= batch * height * width * channels) return 0;
+ return Get(i);
+ }
+
+
+ void Set(uint b, uint h, uint w, uint ch, float v)
+ {
+ data[Index(b,h,w,ch)] = v;
+ }
+ void Set(uint y, uint x, float v)
+ {
+ data[Index(y,x)] = v;
+ }
+ void Set(uint i, float v)
+ {
+ data[i] = v;
+ }
+};
+
+struct SharedTensor : Tensor
+{
+ StructuredBuffer data;
+ uint offset;
+
+ void Init(uint4 nhwc, uint4 info, StructuredBuffer data_)
+ {
+ Tensor::Init(nhwc);
+ data = data_;
+ offset = info.x;
+ }
+
+ float Get(uint b, uint h, uint w, uint ch)
+ {
+ return data[Index(b,h,w,ch) + offset];
+ }
+ float Get(uint b, uint2 pos, uint ch)
+ {
+ return Get(b, pos.y, pos.x, ch);
+ }
+ float Get(uint b, uint i)
+ {
+ return data[Index(b,i) + offset];
+ }
+ float Get(uint i)
+ {
+ return data[i + offset];
+ }
+
+ float BroadcastGet(uint b, uint h, uint w, uint ch)
+ {
+ return Get(b % batch, h % height, w % width, ch % channels);
+ }
+ float BroadcastGet(uint b, uint2 pos, uint ch)
+ {
+ return BroadcastGet(b, pos.y, pos.x, ch);
+ }
+ float BroadcastGet(uint b, uint i)
+ {
+ return Get(b % GetFlatHeight(), i % GetFlatWidth());
+ }
+
+ float SafeGet(uint b, uint2 pos, uint ch, uint2 pad)
+ {
+ if (b >= batch || ch >= channels) return 0;
+
+ if (any(pos < pad)) return 0;
+ if (any(pos >= uint2(width, height) + pad)) return 0;
+ pos -= pad;
+
+ return Get(b, pos, ch);
+ }
+ float SafeGet(uint b, uint h, uint w, uint ch, uint2 pad)
+ {
+ return SafeGet(b, uint2(w, h), ch, pad);
+ }
+ float SafeGet(uint b, uint i)
+ {
+ if (b >= batch || i >= height * width * channels) return 0;
+ return Get(b,i);
+ }
+ float SafeGet(uint i)
+ {
+ if (i >= batch * height * width * channels) return 0;
+ return Get(i);
+ }
+};
+
+#define TENSOR_DECL(X) uint4 X##decl[2]; StructuredBuffer X##data;
+#define TENSOR_DECL_RW(X) uint4 X ## decl[2]; RWStructuredBuffer X ## data;
+
+#define TENSOR_ARG(X) ReadonlyTensor X; X##.Init(X##decl[0], X##data); // readonly
+#define TENSOR_MODEL(X) SharedTensor X; X##.Init(X##decl[0], X##decl[1], X##data); // RO w offset
+#define TENSOR_ARG_RW(X) ReadWriteTensor X; X##.Init(X##decl[0], X##data);
+
+#define TENSOR_ARGS2(X, O) TENSOR_ARG(X); TENSOR_ARG_RW(O);
+#define TENSOR_ARGS3(X, A, O) TENSOR_ARG(X); TENSOR_MODEL(A); TENSOR_ARG_RW(O);
+#define TENSOR_ARGS4(X, A, B, O) TENSOR_ARG(X); TENSOR_MODEL(A); TENSOR_MODEL(B); TENSOR_ARG_RW(O);
+
+// shared model tensors
+#define TENSOR_SHARED_MODEL(X, S) SharedTensor X; X##.Init(X##decl[0], X##decl[1], S##data);
+#define TENSOR_SHARED2_ARGS4(X, A, B, S, O) TENSOR_ARG(X); TENSOR_SHARED_MODEL(A, S); TENSOR_SHARED_MODEL(B, S); TENSOR_ARG_RW(O);
+
+
+// purely informational - declares contract between caller of Dispatch() and kernel
+#define DISPATCH_ARGS(threadGroupsX, threadGroupsY, threadGroupsZ)
+
+
+// @TODO: move into more appropriate file
+#define FLT_MAX 3.402823466e+38F
+#define FLT_EPSILON 1e-6
+
+float fastfma(float a, float b, float c)
+{
+ return dot(float2(a,c), float2(b, 1));
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Tensor.cginc.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Tensor.cginc.meta
new file mode 100755
index 00000000..c611dd01
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/Tensor.cginc.meta
@@ -0,0 +1,9 @@
+fileFormatVersion: 2
+guid: 5761abd87a16940b2a81aaa755787fc9
+timeCreated: 1506540305
+licenseType: Pro
+ShaderImporter:
+ defaultTextures: []
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/TexConv.compute b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/TexConv.compute
new file mode 100755
index 00000000..a93817d1
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/TexConv.compute
@@ -0,0 +1,99 @@
+#pragma kernel TexConv2D
+
+#include "Tensor.cginc"
+
+TENSOR_DECL(X)
+TENSOR_DECL(K)
+TENSOR_DECL(B)
+TENSOR_DECL(WBK)
+TENSOR_DECL_RW(O)
+
+uint4 _Pad;
+uint4 _Stride;
+
+struct TextureAsTensor : Tensor
+{
+ Texture2D tex;
+ SamplerState smp;
+
+ Texture2DArray texArray;
+ SamplerState smpArray;
+
+ void Init(uint4 nhwc, Texture2D tex_, SamplerState sampler_, Texture2DArray texArray_, SamplerState samplerArray_)
+ {
+ Tensor::Init(nhwc);
+ tex = tex_;
+ smp = sampler_;
+ texArray = texArray_;
+ smpArray = samplerArray_;
+ }
+
+ float4 Get(uint b, uint y, uint x)
+ {
+ float3 loc = float3((float)x / (float)width, (float)y / (float)height, b);
+ if (batch > 1)
+ return texArray.SampleLevel(smpArray, loc, 0);
+ else
+ return tex.SampleLevel(smp, loc.xy, 0);
+ }
+};
+
+#define TENSOR_SHARED2_ARGS3(A, B, S, O) TENSOR_SHARED_ARG(A, S); TENSOR_SHARED_ARG(B, S); TENSOR_ARG_RW(O);
+Texture2DArray Xtex2DArray;
+Texture2D Xtex2D;
+SamplerState samplerXtex2D { Filter = MIN_MAG_LINEAR_MIP_POINT; AddressU = Clamp; AddressV = Clamp; };
+SamplerState samplerXtex2DArray { Filter = MIN_MAG_LINEAR_MIP_POINT; AddressU = Clamp; AddressV = Clamp; };
+
+#define MAX_CHANNELS 4
+
+NUMTHREADS((16,4,4), (16,4,2), (16,2,2))
+void TexConv2D(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+// @TODO: currently it fails to compile, needs to be investigated
+#if 0
+ DISPATCH_ARGS(K.kernelCount, O.width, O.height);
+ TextureAsTensor X; X.Init(Xdecl[0], Xtex2D, samplerXtex2D, Xtex2DArray, samplerXtex2DArray);
+
+ TENSOR_SHARED_ARG(K, WBK);
+ TENSOR_SHARED_ARG(B, WBK);
+ TENSOR_ARG_RW(O);
+
+ // ASSERT(X.channels <= MAX_CHANNELS)
+
+ uint k = dispatchThreadID.x;
+ uint x = dispatchThreadID.y;
+ uint y = dispatchThreadID.z;
+
+ if (k >= K.channels) return;
+ if (x >= O.width) return;
+ if (y >= O.height) return;
+
+ for (uint n = 0; n < O.batch; ++n)
+ {
+ float acc = B.Get(k);
+ for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
+ {
+ for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
+ {
+ uint oy = y * _Stride.y + dy;
+ uint ox = x * _Stride.x + dx;
+
+ // @TODO: investigate
+ // WARNING: had to move both y check into the loop (as opposed to checking y in parent loop) - due to potential bug in Metal compiler
+ if (oy < _Pad.y) continue;
+ if (oy - _Pad.w >= X.height) continue;
+ if (ox < _Pad.x) continue;
+ if (ox - _Pad.z >= X.width) continue;
+
+ float4 in4channels = X.Get(n, oy - _Pad.y, ox - _Pad.x);
+ for (uint c = 0; c < X.channels && c < MAX_CHANNELS; ++c)
+ {
+ acc += in4channels[c] * K.Get(dy, dx, c, k);
+ }
+ }
+ }
+
+ O.Set(n, y, x, k, acc);
+ }
+#endif
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/TexConv.compute.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/TexConv.compute.meta
new file mode 100755
index 00000000..38baaf96
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/Barracuda/Resources/TexConv.compute.meta
@@ -0,0 +1,9 @@
+fileFormatVersion: 2
+guid: 85d38d76f835143f797bca1481285596
+timeCreated: 1507637303
+licenseType: Pro
+ComputeShaderImporter:
+ currentAPIMask: 196608
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/LICENSE.md b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/LICENSE.md
new file mode 100755
index 00000000..855b4276
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/LICENSE.md
@@ -0,0 +1,6 @@
+Barracuda cross-platform Neural Net engine copyright © 2018 Unity Technologies ApS
+
+Licensed under the Unity Companion License for Unity-dependent projects--see [Unity Companion License](http://www.unity3d.com/legal/licenses/Unity_Companion_License).
+
+Unless expressly provided otherwise, the Software under this license is made available strictly on an “AS IS” BASIS WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. Please review the license for details on these and other terms and conditions.
+
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/LICENSE.md.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/LICENSE.md.meta
new file mode 100755
index 00000000..a68e6e46
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/LICENSE.md.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: dcc5ce8caa7664f8090ef0103a208c6e
+TextScriptImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/ReleaseNotes.md b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/ReleaseNotes.md
new file mode 100755
index 00000000..8707d740
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/ReleaseNotes.md
@@ -0,0 +1,82 @@
+# Release notes
+
+## 0.1.6
+- Added activation type print in verbose mode
+- Added fast and parallel CPU implementation for Swish, Relu, Add, Sub, Div, Min, Max, Tanh, Exp
+- Removed duplicate profiler blocks for ops
+- Improved scheduling on CPU for small batches of data
+- Fixed compatibility with Unity 2019.2.x
+
+## 0.1.5
+- Added Transpose, MatMul and Indentity layer support for models exported from ONNX.
+- Added BasicLSTM layer support for models exported from TF. Limited set of LSTM networks should work now.
+- Added DepthwiseConv2D layer support. Most of the networks based on the MobileNet should work now.
+- Added OneHot layer support for models exported from TF.
+- Added optimized path for Conv2D, Dense and Transpose layers with single batch executions. Performance gain up to 100%.
+- Fixed FMA performance issue on Metal GFX platforms.
+- Added fast optimized path for Sigmoid and Mul layers on CPU.
+- Fixed issue when worker is executed with different batch sizes.
+- Added ``pip`` requirements file for Python dependencies, check ``Tools/requirements.txt```.
+- Added proof of concept Docker wrappers for running model conversion inside of Docker container. Check ``Tools/docker-tensorflow-to-barracuda.sh`` and ``Tools/docker-onnx-to-barracuda.sh``. Currently it was tested only on Mac host.
+- Refactored model importers for easier integration with ML Agents.
+- Fixed input shape determination for Keras sequential model.
+- Added metadata about input shapes to model. Look for ``Model.GetShapeByName()``.
+- Added API to query constant Tensors embedded into network, look for ``Model.GetTensorByName()``.
+- Added reference implementations for Selu, Abs, Neg, Ceil, Floor, Clip, Rcp, Log layers.
+- Added support for Mean, Square, StridedSlice and Border2D layers.
+- Added support for Swish activation, now it is automatically detected in models.
+- Fixed Tanh NaN issue when large argument is passed.
+- RandomNormal and RandomUniform now supports either embedded shape constant OR previous tensor shape for input.
+- Fixed Keras/TF/ONNX FusedBatchNorm/BatchNorm import and now it takes ``epsilon`` into account.
+- Now Barracuda will fallback to CSharpFast if compute shaders are not supported on the current platform.
+- Improved compute kernel interop on Android.
+- Implemented Pix2Pix model (.pict) importer.
+
+## 0.1.4
+- Implemented fast Conv2DTrans. Useful for GAN type networks.
+- Fixed few ComputeBuffer handling issues.
+- Simplified way to pass texture via ``Tensor`` constructor.
+- Documentation improvements.
+- Added Unity Companion License as part of distribution.
+- Fixed boundary checks for Compute Copy/Concat operations.
+- Improved profiling experience, now each layer will be reported separately in Unity Profiler.
+- Fixed Broadcast layer support in ``ModelAnalyzer``.
+- Exp, Pow and other layers are now also implemented in Compute. Improves RL model inference performance on GPU.
+- Added platform specific BLAS plugin support. Out of the box Barracuda ships with Apple Accelerate framework support for iOS and macOS.
+- Added Burst BLAS plugin, greatly improves performance in Unity Editor where native OS BLAS is not available. It's packaged as separate package and requires to have Burst enabled.
+- Improved memory handling, now less GC allocations should be made per inference execution.
+
+## 0.1.3
+- Improved Barracuda support for Unity Profiler.
+- Cleaned up Barracuda APIs.
+- Added direct ``Texture`` input support. Look for ``TextureAsTensorData``. The following types of texture supported as input: ``Texture2D``, ``Texture2DArray``, ``Texture3D``, ``RenderTexture``.
+- Added ``Tensor`` to ``RenderTexture`` conversion. Look for ``TensorToRenderTexture``.
+- Autoencoder type networks can run completely on GPU now. Data roundtrip via CPU is not necessary anymore.
+- Vertical flip is applied when converting between ``Texture`` and ``Tensor`` to match conventionts. To override this behavior look for ``TextureAsTensorData.Flip`` enum.
+- Removed direct reference to WebCamTexture, now Barracuda compiles for Console targets.
+- Fixed _Conv2DTranspose_ layer support. Now GANs using _Conv2DTranspose_ work properly.
+- Added automated test for pix2pix GAN.
+
+## 0.1.2
+- Barracuda now is also available as preview package. Look for ``com.unity.barracuda`` in https://staging-packages.unity.com registry.
+- Conv2D layers are now *up to 30x faster* with ``CSharpFast`` backend (``ComputeFast`` remains best backend for convolutional networks).
+- Added profiler sample for ``Fetch()``.
+- Fixed compilation issues on Xbox One.
+- TexConv2D support was temporary disabled.
+- Barracuda logging now can be configured via static fields of ``Barracuda.D`` class, it allows both disable specific logging levels or just disable stack trace collection (helps with performance when profiling).
+- Compute Concat implementation now will fall back to C# implementation instead of throwing exception when unsupported configuration is encountered.
+- Fixed several ``ComputeBuffer`` release issues.
+- Added constructor for ``Tensor`` that allows to pass in data array.
+- Improved Flatten handling in TensorFlow models.
+- Added helper func ``ModelLoader.LoadFromStreamingAssets``.
+- Fixed .meta file packaging.
+- Small docs improvements.
+- Fixed unnecessary patching of Activation layers in ``ModelLoader``.
+- Added output trimming at run-time. See for extra parameters Worker factory.
+
+## 0.1.1
+- First internal realease as drop-in package
+- Compatibility with ML Agents models: 3DBall, PushBlock, GridWorld, Soccer.
+
+## 0.1.0
+- First internal build. Due some bugs encountered wasn't published.
\ No newline at end of file
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/ReleaseNotes.md.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/ReleaseNotes.md.meta
new file mode 100755
index 00000000..2d0ff280
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/ReleaseNotes.md.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: a129912fffc9d4ab3b5ae110be67a669
+TextScriptImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/package.json b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/package.json
new file mode 100755
index 00000000..e9e24727
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/package.json
@@ -0,0 +1,8 @@
+{
+ "name": "com.unity.barracuda",
+ "displayName": "Barracuda",
+ "version": "0.1.6-preview",
+ "unity": "2017.4",
+ "description": "Barracuda is lightweight and cross-platform Neural Net inference library. Barracuda supports inference both on GPU and CPU.",
+ "dependencies": {}
+}
\ No newline at end of file
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/package.json.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/package.json.meta
new file mode 100755
index 00000000..e4c32c93
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/Barracuda.Core/package.json.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: 73ae2d877fd444b04b5b6ef591d3fa0e
+TextScriptImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer.meta
new file mode 100755
index 00000000..af0fdcb1
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer.meta
@@ -0,0 +1,10 @@
+fileFormatVersion: 2
+guid: e44343d7e31b04d47bd5f7329c918ffe
+folderAsset: yes
+timeCreated: 1521839636
+licenseType: Free
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/Google.Protobuf.dll b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/Google.Protobuf.dll
new file mode 100755
index 00000000..6ea720de
Binary files /dev/null and b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/Google.Protobuf.dll differ
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/Google.Protobuf.dll.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/Google.Protobuf.dll.meta
new file mode 100755
index 00000000..e0850422
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/Google.Protobuf.dll.meta
@@ -0,0 +1,30 @@
+fileFormatVersion: 2
+guid: 0836ffd04a4924861a2d58aa4b111937
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ iconMap: {}
+ executionOrder: {}
+ isPreloaded: 0
+ isOverridable: 0
+ platformData:
+ - first:
+ Any:
+ second:
+ enabled: 1
+ settings: {}
+ - first:
+ Editor: Editor
+ second:
+ enabled: 0
+ settings:
+ DefaultValueInitialized: true
+ - first:
+ Windows Store Apps: WindowsStoreApps
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/Grpc.Core.dll b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/Grpc.Core.dll
new file mode 100755
index 00000000..601f87c2
Binary files /dev/null and b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/Grpc.Core.dll differ
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/Grpc.Core.dll.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/Grpc.Core.dll.meta
new file mode 100755
index 00000000..3163a0cd
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/Grpc.Core.dll.meta
@@ -0,0 +1,106 @@
+fileFormatVersion: 2
+guid: cbf24ddeec4054edc9ad4c8295556878
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ iconMap: {}
+ executionOrder: {}
+ isPreloaded: 0
+ isOverridable: 0
+ platformData:
+ - first:
+ '': Any
+ second:
+ enabled: 0
+ settings:
+ Exclude Android: 1
+ Exclude Editor: 0
+ Exclude Linux: 0
+ Exclude Linux64: 0
+ Exclude LinuxUniversal: 0
+ Exclude OSXUniversal: 0
+ Exclude Win: 0
+ Exclude Win64: 0
+ Exclude iOS: 1
+ - first:
+ Android: Android
+ second:
+ enabled: 0
+ settings:
+ CPU: ARMv7
+ - first:
+ Any:
+ second:
+ enabled: 0
+ settings: {}
+ - first:
+ Editor: Editor
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ DefaultValueInitialized: true
+ OS: AnyOS
+ - first:
+ Facebook: Win
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ - first:
+ Facebook: Win64
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ - first:
+ Standalone: Linux
+ second:
+ enabled: 1
+ settings:
+ CPU: x86
+ - first:
+ Standalone: Linux64
+ second:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ - first:
+ Standalone: LinuxUniversal
+ second:
+ enabled: 1
+ settings: {}
+ - first:
+ Standalone: OSXUniversal
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ - first:
+ Standalone: Win
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ - first:
+ Standalone: Win64
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ - first:
+ Windows Store Apps: WindowsStoreApps
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ - first:
+ iPhone: iOS
+ second:
+ enabled: 0
+ settings:
+ CompileFlags:
+ FrameworkDependencies:
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/System.Interactive.Async.dll b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/System.Interactive.Async.dll
new file mode 100755
index 00000000..364a99c3
Binary files /dev/null and b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/System.Interactive.Async.dll differ
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/System.Interactive.Async.dll.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/System.Interactive.Async.dll.meta
new file mode 100755
index 00000000..1ee8b2e1
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/System.Interactive.Async.dll.meta
@@ -0,0 +1,32 @@
+fileFormatVersion: 2
+guid: 9502ce7e38c5947dba996570732b6e9f
+timeCreated: 1521661784
+licenseType: Free
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ iconMap: {}
+ executionOrder: {}
+ isPreloaded: 0
+ isOverridable: 0
+ platformData:
+ - first:
+ Any:
+ second:
+ enabled: 1
+ settings: {}
+ - first:
+ Editor: Editor
+ second:
+ enabled: 0
+ settings:
+ DefaultValueInitialized: true
+ - first:
+ Windows Store Apps: WindowsStoreApps
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes.meta
new file mode 100755
index 00000000..6995400a
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes.meta
@@ -0,0 +1,10 @@
+fileFormatVersion: 2
+guid: b8022add2e5264884a117894eeaf9809
+folderAsset: yes
+timeCreated: 1521595360
+licenseType: Free
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux.meta
new file mode 100755
index 00000000..97848b12
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux.meta
@@ -0,0 +1,10 @@
+fileFormatVersion: 2
+guid: 50c3602c6f6244621861928757e31463
+folderAsset: yes
+timeCreated: 1521595360
+licenseType: Free
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux/native.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux/native.meta
new file mode 100755
index 00000000..a8b33def
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux/native.meta
@@ -0,0 +1,10 @@
+fileFormatVersion: 2
+guid: ba192b1e561564e1583e0a87334f8682
+folderAsset: yes
+timeCreated: 1521595360
+licenseType: Free
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so
new file mode 100755
index 00000000..9bf86dc2
Binary files /dev/null and b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so differ
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so.meta
new file mode 100755
index 00000000..62496d62
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x64.so.meta
@@ -0,0 +1,102 @@
+fileFormatVersion: 2
+guid: c9d901caf522f4dc5815786fa764a5da
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ iconMap: {}
+ executionOrder: {}
+ isPreloaded: 0
+ isOverridable: 0
+ platformData:
+ - first:
+ '': Any
+ second:
+ enabled: 0
+ settings:
+ Exclude Android: 1
+ Exclude Editor: 0
+ Exclude Linux: 1
+ Exclude Linux64: 0
+ Exclude LinuxUniversal: 0
+ Exclude OSXUniversal: 1
+ Exclude Win: 0
+ Exclude Win64: 0
+ Exclude iOS: 1
+ - first:
+ Android: Android
+ second:
+ enabled: 0
+ settings:
+ CPU: ARMv7
+ - first:
+ Any:
+ second:
+ enabled: 0
+ settings: {}
+ - first:
+ Editor: Editor
+ second:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ - first:
+ Facebook: Win
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ - first:
+ Facebook: Win64
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ - first:
+ Standalone: Linux
+ second:
+ enabled: 0
+ settings:
+ CPU: None
+ - first:
+ Standalone: Linux64
+ second:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ - first:
+ Standalone: LinuxUniversal
+ second:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ - first:
+ Standalone: OSXUniversal
+ second:
+ enabled: 0
+ settings:
+ CPU: None
+ - first:
+ Standalone: Win
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ - first:
+ Standalone: Win64
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ - first:
+ iPhone: iOS
+ second:
+ enabled: 0
+ settings:
+ AddToEmbeddedBinaries: false
+ CompileFlags:
+ FrameworkDependencies:
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so
new file mode 100755
index 00000000..fce30416
Binary files /dev/null and b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so differ
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so.meta
new file mode 100755
index 00000000..f612ded0
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/linux/native/libgrpc_csharp_ext.x86.so.meta
@@ -0,0 +1,102 @@
+fileFormatVersion: 2
+guid: 7dfb52431a6d941c89758cf0a217e3ab
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ iconMap: {}
+ executionOrder: {}
+ isPreloaded: 0
+ isOverridable: 0
+ platformData:
+ - first:
+ '': Any
+ second:
+ enabled: 0
+ settings:
+ Exclude Android: 1
+ Exclude Editor: 0
+ Exclude Linux: 0
+ Exclude Linux64: 1
+ Exclude LinuxUniversal: 0
+ Exclude OSXUniversal: 1
+ Exclude Win: 0
+ Exclude Win64: 0
+ Exclude iOS: 1
+ - first:
+ Android: Android
+ second:
+ enabled: 0
+ settings:
+ CPU: ARMv7
+ - first:
+ Any:
+ second:
+ enabled: 0
+ settings: {}
+ - first:
+ Editor: Editor
+ second:
+ enabled: 1
+ settings:
+ CPU: x86
+ DefaultValueInitialized: true
+ OS: Linux
+ - first:
+ Facebook: Win
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ - first:
+ Facebook: Win64
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ - first:
+ Standalone: Linux
+ second:
+ enabled: 1
+ settings:
+ CPU: x86
+ - first:
+ Standalone: Linux64
+ second:
+ enabled: 0
+ settings:
+ CPU: None
+ - first:
+ Standalone: LinuxUniversal
+ second:
+ enabled: 1
+ settings:
+ CPU: x86
+ - first:
+ Standalone: OSXUniversal
+ second:
+ enabled: 0
+ settings:
+ CPU: None
+ - first:
+ Standalone: Win
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ - first:
+ Standalone: Win64
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ - first:
+ iPhone: iOS
+ second:
+ enabled: 0
+ settings:
+ AddToEmbeddedBinaries: false
+ CompileFlags:
+ FrameworkDependencies:
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/osx.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/osx.meta
new file mode 100755
index 00000000..69cbe8ef
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/osx.meta
@@ -0,0 +1,10 @@
+fileFormatVersion: 2
+guid: f43fa6e62fb4c4105b270be1ae7bbbfd
+folderAsset: yes
+timeCreated: 1521595360
+licenseType: Free
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/osx/native.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/osx/native.meta
new file mode 100755
index 00000000..24fab959
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/osx/native.meta
@@ -0,0 +1,10 @@
+fileFormatVersion: 2
+guid: 55aee008fb6a3411aa96f2f9911f9207
+folderAsset: yes
+timeCreated: 1521595360
+licenseType: Free
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle
new file mode 100755
index 00000000..58390e6c
Binary files /dev/null and b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle differ
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle.meta
new file mode 100755
index 00000000..87f002f6
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/osx/native/libgrpc_csharp_ext.x64.bundle.meta
@@ -0,0 +1,142 @@
+fileFormatVersion: 2
+guid: 7eeb863bd08ba4388829c23da03a714f
+PluginImporter:
+ serializedVersion: 2
+ iconMap: {}
+ executionOrder: {}
+ isPreloaded: 0
+ isOverridable: 0
+ platformData:
+ data:
+ first:
+ '': Any
+ second:
+ enabled: 0
+ settings:
+ Exclude Android: 1
+ Exclude Editor: 0
+ Exclude Linux: 1
+ Exclude Linux64: 1
+ Exclude LinuxUniversal: 1
+ Exclude OSXIntel: 0
+ Exclude OSXIntel64: 0
+ Exclude OSXUniversal: 0
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude iOS: 1
+ data:
+ first:
+ '': OSXIntel
+ second:
+ enabled: 1
+ settings: {}
+ data:
+ first:
+ '': OSXIntel64
+ second:
+ enabled: 1
+ settings: {}
+ data:
+ first:
+ Android: Android
+ second:
+ enabled: 0
+ settings:
+ CPU: ARMv7
+ data:
+ first:
+ Any:
+ second:
+ enabled: 0
+ settings: {}
+ data:
+ first:
+ Editor: Editor
+ second:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: OSX
+ data:
+ first:
+ Facebook: Win
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ data:
+ first:
+ Facebook: Win64
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ data:
+ first:
+ Standalone: Linux
+ second:
+ enabled: 0
+ settings:
+ CPU: x86
+ data:
+ first:
+ Standalone: Linux64
+ second:
+ enabled: 0
+ settings:
+ CPU: x86_64
+ data:
+ first:
+ Standalone: LinuxUniversal
+ second:
+ enabled: 0
+ settings:
+ CPU: None
+ data:
+ first:
+ Standalone: OSXIntel
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ data:
+ first:
+ Standalone: OSXIntel64
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ data:
+ first:
+ Standalone: OSXUniversal
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ data:
+ first:
+ Standalone: Win
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ data:
+ first:
+ Standalone: Win64
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ data:
+ first:
+ iPhone: iOS
+ second:
+ enabled: 0
+ settings:
+ AddToEmbeddedBinaries: false
+ CompileFlags:
+ FrameworkDependencies:
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win.meta
new file mode 100755
index 00000000..b1e54c9a
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win.meta
@@ -0,0 +1,10 @@
+fileFormatVersion: 2
+guid: a961485c3484a4002ac4961a8481f6cc
+folderAsset: yes
+timeCreated: 1521595360
+licenseType: Free
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win/native.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win/native.meta
new file mode 100755
index 00000000..42e4968a
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win/native.meta
@@ -0,0 +1,10 @@
+fileFormatVersion: 2
+guid: af9f9f367bbc543b8ba41e58dcdd6e66
+folderAsset: yes
+timeCreated: 1521595360
+licenseType: Free
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll
new file mode 100755
index 00000000..b2e48711
Binary files /dev/null and b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll differ
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll.meta
new file mode 100755
index 00000000..888979c7
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x64.dll.meta
@@ -0,0 +1,102 @@
+fileFormatVersion: 2
+guid: f4d9429fe43154fbd9d158c129e0ff33
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ iconMap: {}
+ executionOrder: {}
+ isPreloaded: 0
+ isOverridable: 0
+ platformData:
+ - first:
+ '': Any
+ second:
+ enabled: 0
+ settings:
+ Exclude Android: 1
+ Exclude Editor: 0
+ Exclude Linux: 0
+ Exclude Linux64: 0
+ Exclude LinuxUniversal: 0
+ Exclude OSXUniversal: 0
+ Exclude Win: 1
+ Exclude Win64: 0
+ Exclude iOS: 1
+ - first:
+ Android: Android
+ second:
+ enabled: 0
+ settings:
+ CPU: ARMv7
+ - first:
+ Any:
+ second:
+ enabled: 0
+ settings: {}
+ - first:
+ Editor: Editor
+ second:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Windows
+ - first:
+ Facebook: Win
+ second:
+ enabled: 0
+ settings:
+ CPU: None
+ - first:
+ Facebook: Win64
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ - first:
+ Standalone: Linux
+ second:
+ enabled: 1
+ settings:
+ CPU: x86
+ - first:
+ Standalone: Linux64
+ second:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ - first:
+ Standalone: LinuxUniversal
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ - first:
+ Standalone: OSXUniversal
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ - first:
+ Standalone: Win
+ second:
+ enabled: 0
+ settings:
+ CPU: None
+ - first:
+ Standalone: Win64
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ - first:
+ iPhone: iOS
+ second:
+ enabled: 0
+ settings:
+ AddToEmbeddedBinaries: false
+ CompileFlags:
+ FrameworkDependencies:
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll
new file mode 100755
index 00000000..45d5c324
Binary files /dev/null and b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll differ
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll.meta
new file mode 100755
index 00000000..9c7036f3
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/ProtoBuffer/runtimes/win/native/grpc_csharp_ext.x86.dll.meta
@@ -0,0 +1,102 @@
+fileFormatVersion: 2
+guid: d74134114def74fb4ae781c015deaa95
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ iconMap: {}
+ executionOrder: {}
+ isPreloaded: 0
+ isOverridable: 0
+ platformData:
+ - first:
+ '': Any
+ second:
+ enabled: 0
+ settings:
+ Exclude Android: 1
+ Exclude Editor: 0
+ Exclude Linux: 0
+ Exclude Linux64: 0
+ Exclude LinuxUniversal: 0
+ Exclude OSXUniversal: 0
+ Exclude Win: 0
+ Exclude Win64: 1
+ Exclude iOS: 1
+ - first:
+ Android: Android
+ second:
+ enabled: 0
+ settings:
+ CPU: ARMv7
+ - first:
+ Any:
+ second:
+ enabled: 0
+ settings: {}
+ - first:
+ Editor: Editor
+ second:
+ enabled: 1
+ settings:
+ CPU: x86
+ DefaultValueInitialized: true
+ OS: Windows
+ - first:
+ Facebook: Win
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ - first:
+ Facebook: Win64
+ second:
+ enabled: 0
+ settings:
+ CPU: None
+ - first:
+ Standalone: Linux
+ second:
+ enabled: 1
+ settings:
+ CPU: x86
+ - first:
+ Standalone: Linux64
+ second:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ - first:
+ Standalone: LinuxUniversal
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ - first:
+ Standalone: OSXUniversal
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ - first:
+ Standalone: Win
+ second:
+ enabled: 1
+ settings:
+ CPU: AnyCPU
+ - first:
+ Standalone: Win64
+ second:
+ enabled: 0
+ settings:
+ CPU: None
+ - first:
+ iPhone: iOS
+ second:
+ enabled: 0
+ settings:
+ AddToEmbeddedBinaries: false
+ CompileFlags:
+ FrameworkDependencies:
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/System.IO.Abstractions.TestingHelpers.dll b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/System.IO.Abstractions.TestingHelpers.dll
new file mode 100755
index 00000000..0d2b68f2
Binary files /dev/null and b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/System.IO.Abstractions.TestingHelpers.dll differ
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/System.IO.Abstractions.TestingHelpers.dll.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/System.IO.Abstractions.TestingHelpers.dll.meta
new file mode 100755
index 00000000..7a9871ad
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/System.IO.Abstractions.TestingHelpers.dll.meta
@@ -0,0 +1,30 @@
+fileFormatVersion: 2
+guid: 2d7ba4e1037b64de5b860bcbe15755b3
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ iconMap: {}
+ executionOrder: {}
+ isPreloaded: 0
+ isOverridable: 0
+ platformData:
+ - first:
+ Any:
+ second:
+ enabled: 1
+ settings: {}
+ - first:
+ Editor: Editor
+ second:
+ enabled: 0
+ settings:
+ DefaultValueInitialized: true
+ - first:
+ Windows Store Apps: WindowsStoreApps
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/System.IO.Abstractions.dll b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/System.IO.Abstractions.dll
new file mode 100755
index 00000000..4fe6ccbf
Binary files /dev/null and b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/System.IO.Abstractions.dll differ
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/System.IO.Abstractions.dll.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/System.IO.Abstractions.dll.meta
new file mode 100755
index 00000000..d3d9b5d0
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Plugins/System.IO.Abstractions.dll.meta
@@ -0,0 +1,30 @@
+fileFormatVersion: 2
+guid: b01205587773841ad95e8ceda347e8bd
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ iconMap: {}
+ executionOrder: {}
+ isPreloaded: 0
+ isOverridable: 0
+ platformData:
+ - first:
+ Any:
+ second:
+ enabled: 1
+ settings: {}
+ - first:
+ Editor: Editor
+ second:
+ enabled: 0
+ settings:
+ DefaultValueInitialized: true
+ - first:
+ Windows Store Apps: WindowsStoreApps
+ second:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Resources.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Resources.meta
new file mode 100755
index 00000000..da5233ae
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Resources.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 1b3ab22264a5447df9e52684598ac3b0
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Resources/DemoIcon.png b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Resources/DemoIcon.png
new file mode 100755
index 00000000..ddc91181
Binary files /dev/null and b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Resources/DemoIcon.png differ
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Resources/DemoIcon.png.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Resources/DemoIcon.png.meta
new file mode 100755
index 00000000..37831fb2
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Resources/DemoIcon.png.meta
@@ -0,0 +1,86 @@
+fileFormatVersion: 2
+guid: 3352a0e8d253b4a4ea3782a6d7e09d9b
+TextureImporter:
+ fileIDToRecycleName: {}
+ externalObjects: {}
+ serializedVersion: 4
+ mipmaps:
+ mipMapMode: 0
+ enableMipMap: 1
+ sRGBTexture: 1
+ linearTexture: 0
+ fadeOut: 0
+ borderMipMap: 0
+ mipMapsPreserveCoverage: 0
+ alphaTestReferenceValue: 0.5
+ mipMapFadeDistanceStart: 1
+ mipMapFadeDistanceEnd: 3
+ bumpmap:
+ convertToNormalMap: 0
+ externalNormalMap: 0
+ heightScale: 0.25
+ normalMapFilter: 0
+ isReadable: 0
+ grayScaleToAlpha: 0
+ generateCubemap: 6
+ cubemapConvolution: 0
+ seamlessCubemap: 0
+ textureFormat: 1
+ maxTextureSize: 2048
+ textureSettings:
+ serializedVersion: 2
+ filterMode: -1
+ aniso: -1
+ mipBias: -1
+ wrapU: -1
+ wrapV: -1
+ wrapW: -1
+ nPOTScale: 1
+ lightmap: 0
+ compressionQuality: 50
+ spriteMode: 0
+ spriteExtrude: 1
+ spriteMeshType: 1
+ alignment: 0
+ spritePivot: {x: 0.5, y: 0.5}
+ spritePixelsToUnits: 100
+ spriteBorder: {x: 0, y: 0, z: 0, w: 0}
+ spriteGenerateFallbackPhysicsShape: 1
+ alphaUsage: 1
+ alphaIsTransparency: 1
+ spriteTessellationDetail: -1
+ textureType: 0
+ textureShape: 1
+ maxTextureSizeSet: 0
+ compressionQualitySet: 0
+ textureFormatSet: 0
+ platformSettings:
+ - buildTarget: DefaultTexturePlatform
+ maxTextureSize: 2048
+ resizeAlgorithm: 0
+ textureFormat: -1
+ textureCompression: 1
+ compressionQuality: 50
+ crunchedCompression: 0
+ allowsAlphaSplitting: 0
+ overridden: 0
+ androidETC2FallbackOverride: 0
+ - buildTarget: Standalone
+ maxTextureSize: 2048
+ resizeAlgorithm: 0
+ textureFormat: -1
+ textureCompression: 1
+ compressionQuality: 50
+ crunchedCompression: 0
+ allowsAlphaSplitting: 0
+ overridden: 0
+ androidETC2FallbackOverride: 0
+ spriteSheet:
+ serializedVersion: 2
+ sprites: []
+ outline: []
+ physicsShape: []
+ spritePackingTag:
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Resources/NNModelIcon.png b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Resources/NNModelIcon.png
new file mode 100755
index 00000000..10434c27
Binary files /dev/null and b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Resources/NNModelIcon.png differ
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Resources/NNModelIcon.png.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Resources/NNModelIcon.png.meta
new file mode 100755
index 00000000..9a88c6d1
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Resources/NNModelIcon.png.meta
@@ -0,0 +1,106 @@
+fileFormatVersion: 2
+guid: 8682ff569c4c7457a8a8e3a527aad537
+TextureImporter:
+ fileIDToRecycleName: {}
+ externalObjects: {}
+ serializedVersion: 4
+ mipmaps:
+ mipMapMode: 0
+ enableMipMap: 0
+ sRGBTexture: 0
+ linearTexture: 0
+ fadeOut: 0
+ borderMipMap: 0
+ mipMapsPreserveCoverage: 0
+ alphaTestReferenceValue: 0.5
+ mipMapFadeDistanceStart: 1
+ mipMapFadeDistanceEnd: 3
+ bumpmap:
+ convertToNormalMap: 0
+ externalNormalMap: 0
+ heightScale: 0.25
+ normalMapFilter: 0
+ isReadable: 0
+ grayScaleToAlpha: 0
+ generateCubemap: 6
+ cubemapConvolution: 0
+ seamlessCubemap: 0
+ textureFormat: 1
+ maxTextureSize: 2048
+ textureSettings:
+ serializedVersion: 2
+ filterMode: -1
+ aniso: 1
+ mipBias: -1
+ wrapU: 1
+ wrapV: 1
+ wrapW: -1
+ nPOTScale: 0
+ lightmap: 0
+ compressionQuality: 50
+ spriteMode: 0
+ spriteExtrude: 1
+ spriteMeshType: 1
+ alignment: 0
+ spritePivot: {x: 0.5, y: 0.5}
+ spritePixelsToUnits: 100
+ spriteBorder: {x: 0, y: 0, z: 0, w: 0}
+ spriteGenerateFallbackPhysicsShape: 1
+ alphaUsage: 1
+ alphaIsTransparency: 1
+ spriteTessellationDetail: -1
+ textureType: 2
+ textureShape: 1
+ maxTextureSizeSet: 0
+ compressionQualitySet: 0
+ textureFormatSet: 0
+ platformSettings:
+ - buildTarget: DefaultTexturePlatform
+ maxTextureSize: 2048
+ resizeAlgorithm: 0
+ textureFormat: -1
+ textureCompression: 1
+ compressionQuality: 50
+ crunchedCompression: 0
+ allowsAlphaSplitting: 0
+ overridden: 0
+ androidETC2FallbackOverride: 0
+ - buildTarget: Standalone
+ maxTextureSize: 2048
+ resizeAlgorithm: 0
+ textureFormat: -1
+ textureCompression: 1
+ compressionQuality: 50
+ crunchedCompression: 0
+ allowsAlphaSplitting: 0
+ overridden: 0
+ androidETC2FallbackOverride: 0
+ - buildTarget: iPhone
+ maxTextureSize: 2048
+ resizeAlgorithm: 0
+ textureFormat: -1
+ textureCompression: 1
+ compressionQuality: 50
+ crunchedCompression: 0
+ allowsAlphaSplitting: 0
+ overridden: 0
+ androidETC2FallbackOverride: 0
+ - buildTarget: Android
+ maxTextureSize: 2048
+ resizeAlgorithm: 0
+ textureFormat: -1
+ textureCompression: 1
+ compressionQuality: 50
+ crunchedCompression: 0
+ allowsAlphaSplitting: 0
+ overridden: 0
+ androidETC2FallbackOverride: 0
+ spriteSheet:
+ serializedVersion: 2
+ sprites: []
+ outline: []
+ physicsShape: []
+ spritePackingTag:
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts.meta
new file mode 100755
index 00000000..779df48e
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts.meta
@@ -0,0 +1,9 @@
+fileFormatVersion: 2
+guid: 9a3740bf890474fc9857a8ec39739a35
+folderAsset: yes
+timeCreated: 1502223516
+licenseType: Free
+DefaultImporter:
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Academy.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Academy.cs
new file mode 100755
index 00000000..aba87ce0
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Academy.cs
@@ -0,0 +1,658 @@
+using System.Collections.Generic;
+using UnityEngine;
+using System.IO;
+using System.Linq;
+using ArenasParameters;
+#if UNITY_EDITOR
+using UnityEditor;
+using Google.Protobuf.Collections;
+
+#endif
+
+/**
+ * Welcome to Unity Machine Learning Agents (ML-Agents).
+ *
+ * The ML-Agents toolkit contains five entities: Academy, Brain, Agent, Communicator and
+ * Python API. The academy, and all its brains and connected agents live within
+ * a learning environment (herin called Environment), while the communicator
+ * manages the communication between the learning environment and the Python
+ * API. For more information on each of these entities, in addition to how to
+ * set-up a learning environment and train the behavior of characters in a
+ * Unity scene, please browse our documentation pages on GitHub:
+ * https://github.com/Unity-Technologies/ml-agents/blob/master/docs/
+ */
+
+namespace MLAgents
+{
+ ///
+ /// Wraps the environment-level parameters that are provided within the
+ /// Editor. These parameters can be provided for training and inference
+ /// modes separately and represent screen resolution, rendering quality and
+ /// frame rate.
+ ///
+ [System.Serializable]
+ public class EnvironmentConfiguration
+ {
+ [Tooltip("Width of the environment window in pixels.")]
+ public int width;
+
+ [Tooltip("Height of the environment window in pixels.")]
+ public int height;
+
+ [Tooltip("Rendering quality of environment. (Higher is better quality.)")] [Range(0, 5)]
+ public int qualityLevel;
+
+ [Tooltip("Speed at which environment is run. (Higher is faster.)")] [Range(1f, 100f)]
+ public float timeScale;
+
+ [Tooltip("Frames per second (FPS) engine attempts to maintain.")]
+ public int targetFrameRate;
+
+ /// Initializes a new instance of the
+ /// class.
+ /// Width of environment window (pixels).
+ /// Height of environment window (pixels).
+ ///
+ /// Rendering quality of environment. Ranges from 0 to 5, with higher.
+ ///
+ ///
+ /// Speed at which environment is run. Ranges from 1 to 100, with higher
+ /// values representing faster speed.
+ ///
+ ///
+ /// Target frame rate (per second) that the engine tries to maintain.
+ ///
+ public EnvironmentConfiguration(
+ int width, int height, int qualityLevel,
+ float timeScale, int targetFrameRate)
+ {
+ this.width = width;
+ this.height = height;
+ this.qualityLevel = qualityLevel;
+ this.timeScale = timeScale;
+ this.targetFrameRate = targetFrameRate;
+ }
+ }
+
+ ///
+ /// An Academy is where Agent objects go to train their behaviors. More
+ /// specifically, an academy is a collection of Brain objects and each agent
+ /// in a scene is attached to one brain (a single brain may be attached to
+ /// multiple agents). Currently, this class is expected to be extended to
+ /// implement the desired academy behavior.
+ ///
+ ///
+ /// When an academy is run, it can either be in inference or training mode.
+ /// The mode is determined by the presence or absence of a Communicator. In
+ /// the presence of a communicator, the academy is run in training mode where
+ /// the states and observations of each agent are sent through the
+ /// communicator. In the absence of a communciator, the academy is run in
+ /// inference mode where the agent behavior is determined by the brain
+ /// attached to it (which may be internal, heuristic or player).
+ ///
+ [HelpURL("https://github.com/Unity-Technologies/ml-agents/blob/master/" +
+ "docs/Learning-Environment-Design-Academy.md")]
+ public abstract class Academy : MonoBehaviour
+ {
+ [SerializeField]
+ public BroadcastHub broadcastHub = new BroadcastHub();
+
+ // Flag to swith to play mode to allow particpants to look at their configuration
+ [HideInInspector]
+ public bool playerMode = false;
+
+ // Flag to swith to inference mode to allow particpants to watch their agent in action
+ [HideInInspector]
+ public bool externalInferenceMode = false;
+
+ private const string kApiVersion = "1.0";
+
+ /// Temporary storage for global gravity value
+ /// Used to restore oringal value when deriving Academy modifies it
+ private Vector3 originalGravity;
+
+ /// Temporary storage for global fixedDeltaTime value
+ /// Used to restore oringal value when deriving Academy modifies it
+ private float originalFixedDeltaTime;
+
+ /// Temporary storage for global maximumDeltaTime value
+ /// Used to restore oringal value when deriving Academy modifies it
+ private float originalMaximumDeltaTime;
+
+ // Fields provided in the Inspector
+
+ [SerializeField]
+ [Tooltip("Total number of steps per global episode.\nNon-positive " +
+ "values correspond to episodes without a maximum number of \n" +
+ "steps. Once the step counter reaches this maximum value, the " +
+ "environment will reset.")]
+ int maxSteps;
+
+ [SerializeField]
+ [Tooltip("The engine-level settings which correspond to rendering " +
+ "quality and engine speed during Training.")]
+ EnvironmentConfiguration trainingConfiguration =
+ new EnvironmentConfiguration(80, 80, 1, 100.0f, -1);
+
+ [SerializeField]
+ [Tooltip("The engine-level settings which correspond to rendering " +
+ "quality and engine speed during Inference.")]
+ EnvironmentConfiguration inferenceConfiguration =
+ new EnvironmentConfiguration(1280, 720, 5, 1.0f, 60);
+
+ ///
+ /// Contains a mapping from parameter names to float values. They are
+ /// used in and
+ /// to modify elements in the environment at reset time.
+ ///
+ ///
+ /// Default reset parameters are specified in the academy Editor, and can
+ /// be modified when training with an external Brain by passinga config
+ /// dictionary at reset.
+ ///
+ // [SerializeField]
+ // [Tooltip("List of custom parameters that can be changed in the " +
+ // "environment when it resets.")]
+ // public ResetParameters resetParameters;
+
+ /// Configurations for each arena in the environment
+ public ArenasConfigurations arenasConfigurations = new ArenasConfigurations();
+
+ // Fields not provided in the Inspector.
+
+ /// Boolean flag indicating whether a communicator is accessible by the
+ /// environment. This also specifies whether the environment is in
+ /// Training or Inference mode.
+ bool isCommunicatorOn;
+
+ /// Keeps track of the id of the last communicator message received.
+ /// Remains 0 if there are no communicators. Is used to ensure that
+ /// the same message is not used multiple times.
+ private ulong lastCommunicatorMessageNumber;
+
+ /// If true, the Academy will use inference settings. This field is
+ /// initialized in depending on the presence
+ /// or absence of a communicator. Furthermore, it can be modified by an
+ /// external Brain during reset via .
+ bool isInference = true;
+
+ /// The done flag of the academy. When set to true, the academy will
+ /// call instead of
+ /// at step time. If true, all agents done flags will be set to true.
+ bool done;
+
+ /// Whether the academy has reached the maximum number of steps for the
+ /// current episode.
+ bool maxStepReached;
+
+ /// The number of episodes completed by the environment. Incremented
+ /// each time the environment is reset.
+ int episodeCount;
+
+ /// The number of steps completed within the current episide. Incremented
+ /// each time a step is taken in the environment. Is reset to 0 during
+ /// .
+ int stepCount;
+
+ /// Flag that indicates whether the inference/training mode of the
+ /// environment was switched by the external Brain. This impacts the
+ /// engine settings at the next environment step.
+ bool modeSwitched;
+
+ /// Pointer to the batcher currently in use by the Academy.
+ MLAgents.Batcher brainBatcher;
+
+ /// Used to write error messages.
+ StreamWriter logWriter;
+
+ /// The path to where the log should be written.
+ string logPath;
+
+
+ // Flag used to keep track of the first time the Academy is reset.
+ bool firstAcademyReset;
+
+ // The Academy uses a series of events to communicate with agents and
+ // brains to facilitate synchronization. More specifically, it ensure
+ // that all the agents performs their steps in a consistent order (i.e. no
+ // agent can act based on a decision before another agent has had a chance
+ // to request a decision).
+
+ // Signals to all the Brains at each environment step so they can decide
+ // actions for their agents.
+ public event System.Action BrainDecideAction;
+
+ // Signals to all the agents at each environment step along with the
+ // Academy's maxStepReached, done and stepCount values. The agents rely
+ // on this event to update their own values of max step reached and done
+ // in addition to aligning on the step count of the global episode.
+ public event System.Action AgentSetStatus;
+
+ // Signals to all the agents at each environment step so they can reset
+ // if their flag has been set to done (assuming the agent has requested a
+ // decision).
+ public event System.Action AgentResetIfDone;
+
+ // Signals to all the agents at each environment step so they can send
+ // their state to their Brain if they have requested a decision.
+ public event System.Action AgentSendState;
+
+ // Signals to all the agents at each environment step so they can act if
+ // they have requested a decision.
+ public event System.Action AgentAct;
+
+ // Sigals to all the agents each time the Academy force resets.
+ public event System.Action AgentForceReset;
+
+ ///
+ /// Monobehavior function called at the very beginning of environment
+ /// creation. Academy uses this time to initialize internal data
+ /// structures, initialize the environment and check for the existence
+ /// of a communicator.
+ ///
+ void Awake()
+ {
+ InitializeEnvironment();
+ }
+
+ // Used to read Python-provided environment parameters
+ private int ReadArgs()
+ {
+ var args = System.Environment.GetCommandLineArgs();
+ var inputPort = "";
+ for (var i = 0; i < args.Length; i++)
+ {
+ if (args[i] == "--port")
+ {
+ inputPort = args[i + 1];
+ }
+ }
+
+ return int.Parse(inputPort);
+ }
+
+ ///
+ /// Initializes the environment, configures it and initialized the Academy.
+ ///
+ private void InitializeEnvironment()
+ {
+ originalGravity = Physics.gravity;
+ originalFixedDeltaTime = Time.fixedDeltaTime;
+ originalMaximumDeltaTime = Time.maximumDeltaTime;
+
+ InitializeAcademy();
+ Communicator communicator = null;
+
+
+ var exposedBrains = broadcastHub.broadcastingBrains.Where(x => x != null).ToList();
+ var controlledBrains = broadcastHub.broadcastingBrains.Where(
+ x => x != null && x is LearningBrain && broadcastHub.IsControlled(x));
+ foreach (LearningBrain brain in controlledBrains)
+ {
+ brain.SetToControlledExternally();
+ }
+
+ // Try to launch the communicator by usig the arguments passed at launch
+ try
+ {
+ communicator = new RPCCommunicator(
+ new CommunicatorParameters
+ {
+ port = ReadArgs()
+ });
+
+ }
+ // If it fails, we check if there are any external brains in the scene
+ // If there are : Launch the communicator on the default port
+ // If there arn't, there is no need for a communicator and it is set
+ // to null
+ catch
+ {
+ communicator = null;
+ if (controlledBrains.ToList().Count > 0)
+ {
+ communicator = new RPCCommunicator(
+ new CommunicatorParameters
+ {
+ port = 5005
+ });
+ }
+ }
+
+ brainBatcher = new Batcher(communicator);
+
+ foreach (var trainingBrain in exposedBrains)
+ {
+ trainingBrain.SetBatcher(brainBatcher);
+ }
+
+ if (communicator != null)
+ {
+ isCommunicatorOn = true;
+
+ var academyParameters =
+ new CommunicatorObjects.UnityRLInitializationOutput();
+ academyParameters.Name = gameObject.name;
+ academyParameters.Version = kApiVersion;
+ foreach (var brain in exposedBrains)
+ {
+ var bp = brain.brainParameters;
+ academyParameters.BrainParameters.Add(
+ bp.ToProto(brain.name, broadcastHub.IsControlled(brain)));
+ }
+
+ var pythonParameters = brainBatcher.SendAcademyParameters(academyParameters);
+ Random.InitState(pythonParameters.Seed);
+ Application.logMessageReceived += HandleLog;
+ logPath = Path.GetFullPath(".") + "/UnitySDK.log";
+ logWriter = new StreamWriter(logPath, false);
+ logWriter.WriteLine(System.DateTime.Now.ToString());
+ logWriter.WriteLine(" ");
+ logWriter.Close();
+
+ UpdateResetParameters();
+ if (playerMode && !externalInferenceMode)
+ {
+ broadcastHub.Clear();
+ isCommunicatorOn = false;
+ }
+
+ }
+
+ // If a communicator is enabled/provided, then we assume we are in
+ // training mode. In the absence of a communicator, we assume we are
+ // in inference mode.
+ isInference = !isCommunicatorOn;
+
+ BrainDecideAction += () => { };
+ AgentSetStatus += (m, d, i) => { };
+ AgentResetIfDone += () => { };
+ AgentSendState += () => { };
+ AgentAct += () => { };
+ AgentForceReset += () => { };
+
+ // Configure the environment using the configurations provided by
+ // the developer in the Editor.
+ SetIsInference(!brainBatcher.GetIsTraining());
+ ConfigureEnvironment();
+ AcademyStep();
+ }
+
+ private void UpdateResetParameters()
+ {
+ var newResetParameters = brainBatcher.GetArenasParameters();
+ if (newResetParameters != null)
+ {
+ foreach (KeyValuePair
+ kvp in newResetParameters.Arenas)
+ {
+ arenasConfigurations.Add(kvp.Key,kvp.Value);
+ }
+ }
+ }
+
+ void HandleLog(string logString, string stackTrace, LogType type)
+ {
+ logWriter = new StreamWriter(logPath, true);
+ logWriter.WriteLine(type.ToString());
+ logWriter.WriteLine(logString);
+ logWriter.WriteLine(stackTrace);
+ logWriter.Close();
+ }
+
+ ///
+ /// Configures the environment settings depending on the training/inference
+ /// mode and the corresponding parameters passed in the Editor.
+ ///
+ void ConfigureEnvironment()
+ {
+ if (isInference || playerMode || externalInferenceMode)
+ {
+ ConfigureEnvironmentHelper(inferenceConfiguration);
+ Monitor.SetActive(true);
+ }
+ else
+ {
+ ConfigureEnvironmentHelper(trainingConfiguration);
+ Monitor.SetActive(false);
+ }
+ }
+
+ ///
+ /// Helper method for initializing the environment based on the provided
+ /// configuration.
+ ///
+ ///
+ /// Environment configuration (specified in the Editor).
+ ///
+ static void ConfigureEnvironmentHelper(EnvironmentConfiguration config)
+ {
+ Screen.SetResolution(config.width, config.height, false);
+ QualitySettings.SetQualityLevel(config.qualityLevel, true);
+ Time.timeScale = config.timeScale;
+ Time.captureFramerate = 60;
+ Application.targetFrameRate = config.targetFrameRate;
+ }
+
+ ///
+ /// Initializes the academy and environment. Called during the waking-up
+ /// phase of the environment before any of the scene objects/agents have
+ /// been initialized.
+ ///
+ public virtual void InitializeAcademy()
+ {
+ }
+
+ ///
+ /// Specifies the academy behavior at every step of the environment.
+ ///
+ public virtual void AcademyStep()
+ {
+ }
+
+ ///
+ /// Specifies the academy behavior when being reset (i.e. at the completion
+ /// of a global episode).
+ ///
+ public virtual void AcademyReset()
+ {
+ }
+
+ ///
+ /// Returns the flag.
+ ///
+ ///
+ /// true, if current mode is inference, false if training.
+ ///
+ public bool GetIsInference()
+ {
+ return isInference;
+ }
+
+ ///
+ /// Sets the flag to the provided value. If
+ /// the new flag differs from the current flag value, this signals that
+ /// the environment configuration needs to be updated.
+ ///
+ ///
+ /// Environment mode, if true then inference, otherwise training.
+ ///
+ public void SetIsInference(bool isInference)
+ {
+ if (this.isInference != isInference)
+ {
+ this.isInference = isInference;
+
+ // This signals to the academy that at the next environment step
+ // the engine configurations need updating to the respective mode
+ // (i.e. training vs inference) configuraiton.
+ modeSwitched = true;
+ }
+ }
+
+ ///
+ /// Returns the current episode counter.
+ ///
+ ///
+ /// Current episode number.
+ ///
+ public int GetEpisodeCount()
+ {
+ return episodeCount;
+ }
+
+ ///
+ /// Returns the current step counter (within the current epside).
+ ///
+ ///
+ /// Current episode number.
+ ///
+ public int GetStepCount()
+ {
+ return stepCount;
+ }
+
+ ///
+ /// Sets the done flag to true.
+ ///
+ public void Done()
+ {
+ done = true;
+ }
+
+ ///
+ /// Returns whether or not the academy is done.
+ ///
+ ///
+ /// true, if academy is done, false otherwise.
+ ///
+ public bool IsDone()
+ {
+ return done;
+ }
+
+ ///
+ /// Returns whether or not the communicator is on.
+ ///
+ ///
+ /// true, if communicator is on, false otherwise.
+ ///
+ public bool IsCommunicatorOn()
+ {
+ return isCommunicatorOn;
+ }
+
+ ///
+ /// Forces the full reset. The done flags are not affected. Is either
+ /// called the first reset at inference and every external reset
+ /// at training.
+ ///
+ void ForcedFullReset()
+ {
+ EnvironmentReset();
+ AgentForceReset();
+ firstAcademyReset = true;
+ }
+
+ ///
+ /// Performs a single environment update to the Academy, Brain and Agent
+ /// objects within the environment.
+ ///
+ void EnvironmentStep()
+ {
+ if (modeSwitched)
+ {
+ ConfigureEnvironment();
+ modeSwitched = false;
+ }
+
+ if ((isCommunicatorOn) &&
+ (lastCommunicatorMessageNumber != brainBatcher.GetNumberMessageReceived()))
+ {
+ lastCommunicatorMessageNumber = brainBatcher.GetNumberMessageReceived();
+
+ UpdateResetParameters();
+
+ if (brainBatcher.GetCommand() ==
+ CommunicatorObjects.CommandProto.Reset)
+ {
+
+ SetIsInference(!brainBatcher.GetIsTraining());
+
+ ForcedFullReset();
+ }
+
+ if (brainBatcher.GetCommand() ==
+ CommunicatorObjects.CommandProto.Quit)
+ {
+#if UNITY_EDITOR
+ EditorApplication.isPlaying = false;
+#endif
+ Application.Quit();
+ return;
+ }
+ }
+ else if (!firstAcademyReset)
+ {
+ ForcedFullReset();
+ }
+
+ if ((stepCount >= maxSteps) && maxSteps > 0)
+ {
+ maxStepReached = true;
+ Done();
+ }
+
+ AgentSetStatus(maxStepReached, done, stepCount);
+
+ brainBatcher.RegisterAcademyDoneFlag(done);
+
+ if (done)
+ {
+ EnvironmentReset();
+ }
+
+ AgentResetIfDone();
+
+ AgentSendState();
+
+ BrainDecideAction();
+
+ AcademyStep();
+
+ AgentAct();
+
+ stepCount += 1;
+ }
+
+ ///
+ /// Resets the environment, including the Academy.
+ ///
+ void EnvironmentReset()
+ {
+ stepCount = 0;
+ episodeCount++;
+ done = false;
+ maxStepReached = false;
+ AcademyReset();
+ }
+
+ ///
+ /// Monobehavior function that dictates each environment step.
+ ///
+ void FixedUpdate()
+ {
+ EnvironmentStep();
+ }
+
+ ///
+ /// Cleanup function
+ ///
+ protected virtual void OnDestroy()
+ {
+ Physics.gravity = originalGravity;
+ Time.fixedDeltaTime = originalFixedDeltaTime;
+ Time.maximumDeltaTime = originalMaximumDeltaTime;
+ }
+ }
+}
\ No newline at end of file
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Academy.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Academy.cs.meta
new file mode 100755
index 00000000..e3f2f56f
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Academy.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: b1fc0029fee784d9cb9854f8912bfd07
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/ActionMasker.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/ActionMasker.cs
new file mode 100755
index 00000000..2d9c3e30
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/ActionMasker.cs
@@ -0,0 +1,138 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace MLAgents
+{
+ public class ActionMasker
+ {
+ /// When using discrete control, is the starting indices of the actions
+ /// when all the branches are concatenated with each other.
+ private int[] _startingActionIndices;
+
+ private bool[] _currentMask;
+
+ private readonly BrainParameters _brainParameters;
+
+ public ActionMasker(BrainParameters brainParameters)
+ {
+ this._brainParameters = brainParameters;
+ }
+
+ ///
+ /// Modifies an action mask for discrete control agents. When used, the agent will not be
+ /// able to perform the action passed as argument at the next decision. If no branch is
+ /// specified, the default branch will be 0. The actionIndex or actionIndices correspond
+ /// to the action the agent will be unable to perform.
+ ///
+ /// The branch for which the actions will be masked
+ /// The indices of the masked actions
+ public void SetActionMask(int branch, IEnumerable actionIndices)
+ {
+ // If the branch does not exist, raise an error
+ if (branch >= _brainParameters.vectorActionSize.Length )
+ throw new UnityAgentsException(
+ "Invalid Action Masking : Branch "+branch+" does not exist.");
+
+ int totalNumberActions = _brainParameters.vectorActionSize.Sum();
+
+ // By default, the masks are null. If we want to specify a new mask, we initialize
+ // the actionMasks with trues.
+ if (_currentMask == null)
+ {
+ _currentMask = new bool[totalNumberActions];
+ }
+
+ // If this is the first time the masked actions are used, we generate the starting
+ // indices for each branch.
+ if (_startingActionIndices == null)
+ {
+ _startingActionIndices = Utilities.CumSum(_brainParameters.vectorActionSize);
+ }
+
+ // Perform the masking
+ foreach (var actionIndex in actionIndices)
+ {
+ if (actionIndex >= _brainParameters.vectorActionSize[branch])
+ {
+ throw new UnityAgentsException(
+ "Invalid Action Masking: Action Mask is too large for specified branch.");
+ }
+ _currentMask[actionIndex + _startingActionIndices[branch]] = true;
+ }
+ }
+
+ ///
+ /// Get the current mask for an agent
+ ///
+ /// A mask for the agent. A boolean array of length equal to the total number of
+ /// actions.
+ public bool[] GetMask()
+ {
+ if (_currentMask != null)
+ {
+ AssertMask();
+ }
+ return _currentMask;
+ }
+
+ ///
+ /// Makes sure that the current mask is usable.
+ ///
+ private void AssertMask()
+ {
+ // Action Masks can only be used in Discrete Control.
+ if (_brainParameters.vectorActionSpaceType != SpaceType.discrete)
+ {
+ throw new UnityAgentsException(
+ "Invalid Action Masking : Can only set action mask for Discrete Control.");
+ }
+
+ var numBranches = _brainParameters.vectorActionSize.Length;
+ for (var branchIndex = 0 ; branchIndex < numBranches; branchIndex++ )
+ {
+ if (AreAllActionsMasked(branchIndex))
+ {
+ throw new UnityAgentsException(
+ "Invalid Action Masking : All the actions of branch " + branchIndex +
+ " are masked.");
+ }
+ }
+ }
+
+ ///
+ /// Resets the current mask for an agent
+ ///
+ public void ResetMask()
+ {
+ if (_currentMask != null)
+ {
+ Array.Clear(_currentMask, 0, _currentMask.Length);
+ }
+ }
+
+ ///
+ /// Checks if all the actions in the input branch are masked
+ ///
+ /// The index of the branch to check
+ /// True if all the actions of the branch are masked
+ private bool AreAllActionsMasked(int branch)
+ {
+ if (_currentMask == null)
+ {
+ return false;
+ }
+ var start = _startingActionIndices[branch];
+ var end = _startingActionIndices[branch + 1];
+ for (var i = start; i < end; i++)
+ {
+ if (!_currentMask[i])
+ {
+ return false;
+ }
+ }
+ return true;
+
+ }
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/ActionMasker.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/ActionMasker.cs.meta
new file mode 100755
index 00000000..a7ab396c
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/ActionMasker.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 8a0ec4ccf4ee450da7766f65228d5460
+timeCreated: 1534530911
\ No newline at end of file
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Agent.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Agent.cs
new file mode 100755
index 00000000..efdc9f19
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Agent.cs
@@ -0,0 +1,1114 @@
+using System.Collections.Generic;
+using System.Linq;
+using Google.Protobuf;
+using MLAgents.CommunicatorObjects;
+using UnityEngine;
+
+
+namespace MLAgents
+{
+ ///
+ /// Struct that contains all the information for an Agent, including its
+ /// observations, actions and current status, that is sent to the Brain.
+ ///
+ public struct AgentInfo
+ {
+ ///
+ /// Most recent agent vector (i.e. numeric) observation.
+ ///
+ public List vectorObservation;
+
+ ///
+ /// The previous agent vector observations, stacked. The length of the
+ /// history (i.e. number of vector observations to stack) is specified
+ /// in the Brain parameters.
+ ///
+ public List stackedVectorObservation;
+
+ ///
+ /// Most recent agent camera (i.e. texture) observation.
+ ///
+ public List visualObservations;
+
+ ///
+ /// Most recent text observation.
+ ///
+ public string textObservation;
+
+ ///
+ /// Keeps track of the last vector action taken by the Brain.
+ ///
+ public float[] storedVectorActions;
+
+ ///
+ /// Keeps track of the last text action taken by the Brain.
+ ///
+ public string storedTextActions;
+
+ ///
+ /// For discrete control, specifies the actions that the agent cannot take. Is true if
+ /// the action is masked.
+ ///
+ public bool[] actionMasks;
+
+ ///
+ /// Used by the Trainer to store information about the agent. This data
+ /// structure is not consumed or modified by the agent directly, they are
+ /// just the owners of their trainier's memory. Currently, however, the
+ /// size of the memory is in the Brain properties.
+ ///
+ public List memories;
+
+ ///
+ /// Current agent reward.
+ ///
+ public float reward;
+
+ ///
+ /// Whether the agent is done or not.
+ ///
+ public bool done;
+
+ ///
+ /// Whether the agent has reached its max step count for this episode.
+ ///
+ public bool maxStepReached;
+
+ ///
+ /// Unique identifier each agent receives at initialization. It is used
+ /// to separate between different agents in the environment.
+ ///
+ public int id;
+
+ ///
+ /// Converts a AgentInfo to a protobuffer generated AgentInfoProto
+ ///
+ /// The protobuf verison of the AgentInfo.
+ /// The AgentInfo to convert.
+ public CommunicatorObjects.AgentInfoProto ToProto()
+ {
+ var agentInfoProto = new CommunicatorObjects.AgentInfoProto
+ {
+ StackedVectorObservation = { stackedVectorObservation },
+ StoredVectorActions = { storedVectorActions },
+ StoredTextActions = storedTextActions,
+ TextObservation = textObservation,
+ Reward = reward,
+ MaxStepReached = maxStepReached,
+ Done = done,
+ Id = id,
+ };
+ if (memories != null)
+ {
+ agentInfoProto.Memories.Add(memories);
+ }
+
+ if (actionMasks != null)
+ {
+ agentInfoProto.ActionMask.AddRange(actionMasks);
+ }
+
+ foreach (Texture2D obs in visualObservations)
+ {
+ agentInfoProto.VisualObservations.Add(
+ ByteString.CopyFrom(obs.EncodeToPNG())
+ );
+ }
+
+ return agentInfoProto;
+ }
+ }
+
+ ///
+ /// Struct that contains the action information sent from the Brain to the
+ /// Agent.
+ ///
+ public struct AgentAction
+ {
+ public float[] vectorActions;
+ public string textActions;
+ public List memories;
+ public float value;
+ }
+
+ ///
+ /// Struct that contains all the Agent-specific parameters provided in the
+ /// Editor. This excludes the Brain linked to the Agent since it can be
+ /// modified programmatically.
+ ///
+ [System.Serializable]
+ public class AgentParameters
+ {
+ ///
+ /// The list of the Camera GameObjects the agent uses for visual
+ /// observations.
+ ///
+ public List agentCameras = new List();
+
+ ///
+ /// The maximum number of steps the agent takes before being done.
+ ///
+ ///
+ /// If set to 0, the agent can only be set to done programmatically (or
+ /// when the Academy is done).
+ /// If set to any positive integer, the agent will be set to done after
+ /// that many steps. Note that setting the max step to a value greater
+ /// than the academy max step value renders it useless.
+ ///
+ public int maxStep;
+
+ ///
+ /// Determines the behaviour of the agent when done.
+ ///
+ ///
+ /// If true, the agent will reset when done and start a new episode.
+ /// Otherwise, the agent will remain done and its behavior will be
+ /// dictated by the AgentOnDone method.
+ ///
+ public bool resetOnDone = true;
+
+ ///
+ /// Whether to enable On Demand Decisions or make a decision at
+ /// every step.
+ ///
+ public bool onDemandDecision;
+
+ ///
+ /// Number of actions between decisions (used when On Demand Decisions
+ /// is turned off).
+ ///
+ public int numberOfActionsBetweenDecisions;
+ }
+
+
+ ///
+ /// Agent Monobehavior class that is attached to a Unity GameObject, making it
+ /// an Agent. An agent produces observations and takes actions in the
+ /// environment. Observations are determined by the cameras attached
+ /// to the agent in addition to the vector observations implemented by the
+ /// user in . On the other hand, actions
+ /// are determined by decisions produced by a linked Brain. Currently, this
+ /// class is expected to be extended to implement the desired agent behavior.
+ ///
+ ///
+ /// Simply speaking, an agent roams through an environment and at each step
+ /// of the environment extracts its current observation, sends them to its
+ /// linked brain and in return receives an action from its brain. In practice,
+ /// however, an agent need not send its observation at every step since very
+ /// little may have changed between sucessive steps. Currently, how often an
+ /// agent updates its brain with a fresh observation is determined by the
+ /// Academy.
+ ///
+ /// At any step, an agent may be considered .
+ /// This could occur due to a variety of reasons:
+ /// - The agent reached an end state within its environment.
+ /// - The agent reached the maximum # of steps (i.e. timed out).
+ /// - The academy reached the maximum # of steps (forced agent to be done).
+ ///
+ /// Here, an agent reaches an end state if it completes its task successfully
+ /// or somehow fails along the way. In the case where an agent is done before
+ /// the academy, it either resets and restarts, or just lingers until the
+ /// academy is done.
+ ///
+ /// An important note regarding steps and episodes is due. Here, an agent step
+ /// corresponds to an academy step, which also corresponds to Unity
+ /// environment step (i.e. each FixedUpdate call). This is not the case for
+ /// episodes. The academy controls the global episode count and each agent
+ /// controls its own local episode count and can reset and start a new local
+ /// episode independently (based on its own experience). Thus an academy
+ /// (global) episode can be viewed as the upper-bound on an agents episode
+ /// length and that within a single global episode, an agent may have completed
+ /// multiple local episodes. Consequently, if an agent max step is
+ /// set to a value larger than the academy max steps value, then the academy
+ /// value takes precedence (since the agent max step will never be reached).
+ ///
+ /// Lastly, note that at any step the brain linked to the agent is allowed to
+ /// change programmatically with .
+ ///
+ /// Implementation-wise, it is required that this class is extended and the
+ /// virtual methods overridden. For sample implementations of agent behavior,
+ /// see the Examples/ directory within this Unity project.
+ ///
+ [HelpURL("https://github.com/Unity-Technologies/ml-agents/blob/master/" +
+ "docs/Learning-Environment-Design-Agents.md")]
+ [System.Serializable]
+ public abstract class Agent : MonoBehaviour
+ {
+ ///
+ /// The Brain attached to this agent. A brain can be attached either
+ /// directly from the Editor through AgentEditor or
+ /// programmatically through . It is OK for an agent
+ /// to not have a brain, as long as no decision is requested.
+ ///
+ [HideInInspector] public Brain brain;
+
+ ///
+ /// Agent parameters specified within the Editor via AgentEditor.
+ ///
+ [HideInInspector] public AgentParameters agentParameters;
+
+ /// Current Agent information (message sent to Brain).
+ AgentInfo info;
+
+ /// Current Agent action (message sent from Brain).
+ AgentAction action;
+
+ /// Represents the reward the agent accumulated during the current step.
+ /// It is reset to 0 at the beginning of every step.
+ /// Should be set to a positive value when the agent performs a "good"
+ /// action that we wish to reinforce/reward, and set to a negative value
+ /// when the agent performs a "bad" action that we wish to punish/deter.
+ /// Additionally, the magnitude of the reward should not exceed 1.0
+ float reward;
+
+ /// Keeps track of the cumulative reward in this episode.
+ float cumulativeReward;
+
+ /// Whether or not the agent requests an action.
+ bool requestAction;
+
+ /// Whether or not the agent requests a decision.
+ bool requestDecision;
+
+ /// Whether or not the agent has completed the episode. This may be due
+ /// to either reaching a success or fail state, or reaching the maximum
+ /// number of steps (i.e. timing out).
+ bool done;
+
+ /// Whether or not the agent reached the maximum number of steps.
+ bool maxStepReached;
+
+ /// Keeps track of the number of steps taken by the agent in this episode.
+ /// Note that this value is different for each agent, and may not overlap
+ /// with the step counter in the Academy, since agents reset based on
+ /// their own experience.
+ int stepCount;
+
+ /// Flag to signify that an agent has been reset but the fact that it is
+ /// done has not been communicated (required for On Demand Decisions).
+ bool hasAlreadyReset;
+
+ /// Flag to signify that an agent is done and should not reset until
+ /// the fact that it is done has been communicated.
+ bool terminate;
+
+ /// Unique identifier each agent receives at initialization. It is used
+ /// to separate between different agents in the environment.
+ int id;
+
+ /// Keeps track of the actions that are masked at each step.
+ private ActionMasker actionMasker;
+
+ /// Array of Texture2D used to render to from render buffer before
+ /// transforming into float tensor.
+ Texture2D[] textureArray;
+
+ ///
+ /// Demonstration recorder.
+ ///
+ private DemonstrationRecorder recorder;
+
+ private Texture2D[] blackTextures;
+
+ /// Monobehavior function that is called when the attached GameObject
+ /// becomes enabled or active.
+ void OnEnable()
+ {
+ textureArray = new Texture2D[agentParameters.agentCameras.Count];
+ for (int i = 0; i < agentParameters.agentCameras.Count; i++)
+ {
+ textureArray[i] = new Texture2D(1, 1, TextureFormat.RGB24, false);
+ }
+
+ blackTextures = new Texture2D[agentParameters.agentCameras.Count];
+ for (int i = 0; i < agentParameters.agentCameras.Count; i++)
+ {
+ blackTextures[i] = new Texture2D(brain.brainParameters.cameraResolutions[i].width,
+ brain.brainParameters.cameraResolutions[i].height,
+ TextureFormat.RGB24, false);
+ Color[] blackArray = blackTextures[i].GetPixels();
+ int nBlackArray = blackArray.Length;
+ for (int j = 0; j < nBlackArray; j++)
+ {
+ blackArray[j] = Color.black;
+ }
+ blackTextures[i].SetPixels(blackArray);
+ blackTextures[i].Apply();
+ }
+ id = gameObject.GetInstanceID();
+ Academy academy = Object.FindObjectOfType() as Academy;
+ OnEnableHelper(academy);
+
+ recorder = GetComponent();
+ }
+
+ /// Helper method for the event, created to
+ /// facilitate testing.
+ void OnEnableHelper(Academy academy)
+ {
+ info = new AgentInfo();
+ action = new AgentAction();
+
+ if (academy == null)
+ {
+ throw new UnityAgentsException(
+ "No Academy Component could be found in the scene.");
+ }
+
+ academy.AgentSetStatus += SetStatus;
+ academy.AgentResetIfDone += ResetIfDone;
+ academy.AgentSendState += SendInfo;
+ academy.AgentAct += AgentStep;
+ academy.AgentForceReset += _AgentReset;
+
+ if (brain != null)
+ {
+ ResetData();
+ }
+ else
+ {
+ Debug.Log(
+ string.Format(
+ "The Agent component attached to the " +
+ "GameObject {0} was initialized without a brain.",
+ gameObject.name));
+ }
+
+ InitializeAgent();
+ }
+
+ /// Monobehavior function that is called when the attached GameObject
+ /// becomes disabled or inactive.
+ void OnDisable()
+ {
+ Academy academy = Object.FindObjectOfType() as Academy;
+ if (academy != null)
+ {
+ academy.AgentSetStatus -= SetStatus;
+ academy.AgentResetIfDone -= ResetIfDone;
+ academy.AgentSendState -= SendInfo;
+ academy.AgentAct -= AgentStep;
+ academy.AgentForceReset -= _AgentReset;
+ }
+ }
+
+ ///
+ /// Updates the Brain for the agent. Any brain currently assigned to the
+ /// agent will be replaced with the provided one.
+ ///
+ ///
+ /// The agent unsubscribes from its current brain (if it has one) and
+ /// subscribes to the provided brain. This enables contextual brains, that
+ /// is, updating the behaviour (hence brain) of the agent depending on
+ /// the context of the game. For example, we may utilize one (wandering)
+ /// brain when an agent is randomly exploring an open world, but switch
+ /// to another (fighting) brain when it comes into contact with an enemy.
+ ///
+ /// New brain to subscribe this agent to
+ public void GiveBrain(Brain brain)
+ {
+ this.brain = brain;
+ ResetData();
+ }
+
+ ///
+ /// Returns the current step counter (within the current epside).
+ ///
+ ///
+ /// Current episode number.
+ ///
+ public int GetStepCount()
+ {
+ return stepCount;
+ }
+
+ ///
+ /// Resets the step reward and possibly the episode reward for the agent.
+ ///
+ public void ResetReward()
+ {
+ reward = 0f;
+ if (done)
+ {
+ cumulativeReward = 0f;
+ }
+ }
+
+ ///
+ /// Overrides the current step reward of the agent and updates the episode
+ /// reward accordingly.
+ ///
+ /// The new value of the reward.
+ public void SetReward(float reward)
+ {
+ cumulativeReward += (reward - this.reward);
+ this.reward = reward;
+ }
+
+ ///
+ /// Increments the step and episode rewards by the provided value.
+ ///
+ /// Incremental reward value.
+ public void AddReward(float increment)
+ {
+ reward += increment;
+ cumulativeReward += increment;
+ }
+
+ ///
+ /// Retrieves the step reward for the Agent.
+ ///
+ /// The step reward.
+ public float GetReward()
+ {
+ return reward;
+ }
+
+ ///
+ /// Retrieves the episode reward for the Agent.
+ ///
+ /// The episode reward.
+ public float GetCumulativeReward()
+ {
+ return cumulativeReward;
+ }
+
+ ///
+ /// Sets the done flag to true.
+ ///
+ public void Done()
+ {
+ done = true;
+ }
+
+ ///
+ /// Is called when the agent must request the brain for a new decision.
+ ///
+ public void RequestDecision()
+ {
+ requestDecision = true;
+ RequestAction();
+ }
+
+ ///
+ /// Is called then the agent must perform a new action.
+ ///
+ public void RequestAction()
+ {
+ requestAction = true;
+ }
+
+ ///
+ /// Indicates if the agent has reached his maximum number of steps.
+ ///
+ ///
+ /// true, if max step reached was reached, false otherwise.
+ ///
+ public bool IsMaxStepReached()
+ {
+ return maxStepReached;
+ }
+
+ ///
+ /// Indicates if the agent is done
+ ///
+ ///
+ /// true, if the agent is done, false otherwise.
+ ///
+ public bool IsDone()
+ {
+ return done;
+ }
+
+ /// Helper function that resets all the data structures associated with
+ /// the agent. Typically used when the agent is being initialized or reset
+ /// at the end of an episode.
+ void ResetData()
+ {
+ if (brain == null)
+ {
+ return;
+ }
+
+ BrainParameters param = brain.brainParameters;
+ actionMasker = new ActionMasker(param);
+ if (param.vectorActionSpaceType == SpaceType.continuous)
+ {
+ action.vectorActions = new float[param.vectorActionSize[0]];
+ info.storedVectorActions = new float[param.vectorActionSize[0]];
+ }
+ else
+ {
+ action.vectorActions = new float[param.vectorActionSize.Length];
+ info.storedVectorActions = new float[param.vectorActionSize.Length];
+ }
+
+ if (info.textObservation == null)
+ info.textObservation = "";
+ action.textActions = "";
+ info.memories = new List();
+ action.memories = new List();
+ info.vectorObservation =
+ new List(param.vectorObservationSize);
+ info.stackedVectorObservation =
+ new List(param.vectorObservationSize
+ * brain.brainParameters.numStackedVectorObservations);
+ info.stackedVectorObservation.AddRange(
+ new float[param.vectorObservationSize
+ * param.numStackedVectorObservations]);
+
+ info.visualObservations = new List();
+ }
+
+ ///
+ /// Initializes the agent, called once when the agent is enabled. Can be
+ /// left empty if there is no special, unique set-up behavior for the
+ /// agent.
+ ///
+ ///
+ /// One sample use is to store local references to other objects in the
+ /// scene which would facilitate computing this agents observation.
+ ///
+ public virtual void InitializeAgent()
+ {
+ }
+
+ ///
+ /// Sends the Agent info to the linked Brain.
+ ///
+ void SendInfoToBrain()
+ {
+ if (brain == null)
+ {
+ return;
+ }
+
+ info.memories = action.memories;
+ info.storedVectorActions = action.vectorActions;
+ info.storedTextActions = action.textActions;
+ info.vectorObservation.Clear();
+ actionMasker.ResetMask();
+ CollectObservations();
+ info.actionMasks = actionMasker.GetMask();
+
+ BrainParameters param = brain.brainParameters;
+ if (info.vectorObservation.Count != param.vectorObservationSize)
+ {
+ throw new UnityAgentsException(string.Format(
+ "Vector Observation size mismatch between continuous " +
+ "agent {0} and brain {1}. " +
+ "Was Expecting {2} but received {3}. ",
+ gameObject.name, brain.name,
+ brain.brainParameters.vectorObservationSize,
+ info.vectorObservation.Count));
+ }
+
+ info.stackedVectorObservation.RemoveRange(
+ 0, param.vectorObservationSize);
+ info.stackedVectorObservation.AddRange(info.vectorObservation);
+
+ info.visualObservations.Clear();
+ if (param.cameraResolutions.Length > agentParameters.agentCameras.Count)
+ {
+ throw new UnityAgentsException(string.Format(
+ "Not enough cameras for agent {0} : Bain {1} expecting at " +
+ "least {2} cameras but only {3} were present.",
+ gameObject.name, brain.name,
+ brain.brainParameters.cameraResolutions.Length,
+ agentParameters.agentCameras.Count));
+ }
+
+ for (int i = 0; i < brain.brainParameters.cameraResolutions.Length; i++)
+ {
+ int cameraWidth = param.cameraResolutions[i].width;
+ int cameraHeight = param.cameraResolutions[i].height;
+ if (LightStatus())
+ {
+ ObservationToTexture(
+ agentParameters.agentCameras[i],
+ cameraWidth,
+ cameraHeight,
+ ref textureArray[i]);
+ info.visualObservations.Add(textureArray[i]);
+ }
+ else
+ {
+ info.visualObservations.Add(blackTextures[i]);
+ }
+
+ }
+
+ info.reward = reward;
+ info.done = done;
+ info.maxStepReached = maxStepReached;
+ info.id = id;
+
+ brain.SendState(this, info);
+
+ if (recorder != null && recorder.record && Application.isEditor)
+ {
+ recorder.WriteExperience(info);
+ }
+
+ info.textObservation = "";
+ }
+
+
+ public virtual bool LightStatus()
+ {
+ return true;
+ }
+
+ ///
+ /// Collects the (vector, visual, text) observations of the agent.
+ /// The agent observation describes the current environment from the
+ /// perspective of the agent.
+ ///
+ ///
+ /// Simply, an agents observation is any environment information that helps
+ /// the Agent acheive its goal. For example, for a fighting Agent, its
+ /// observation could include distances to friends or enemies, or the
+ /// current level of ammunition at its disposal.
+ /// Recall that an Agent may attach vector, visual or textual observations.
+ /// Vector observations are added by calling the provided helper methods:
+ /// -
+ /// -
+ /// -
+ /// -
+ /// -
+ /// -
+ /// -
+ /// -
+ /// -
+ /// Depending on your environment, any combination of these helpers can
+ /// be used. They just need to be used in the exact same order each time
+ /// this method is called and the resulting size of the vector observation
+ /// needs to match the vectorObservationSize attribute of the linked Brain.
+ /// Visual observations are implicitly added from the cameras attached to
+ /// the Agent.
+ /// Lastly, textual observations are added using
+ /// .
+ ///
+ public virtual void CollectObservations()
+ {
+ }
+
+ ///
+ /// Sets an action mask for discrete control agents. When used, the agent will not be
+ /// able to perform the action passed as argument at the next decision. If no branch is
+ /// specified, the default branch will be 0. The actionIndex or actionIndices correspond
+ /// to the action the agent will be unable to perform.
+ ///
+ /// The indices of the masked actions on branch 0
+ protected void SetActionMask(IEnumerable actionIndices)
+ {
+ actionMasker.SetActionMask(0, actionIndices);
+ }
+
+ ///
+ /// Sets an action mask for discrete control agents. When used, the agent will not be
+ /// able to perform the action passed as argument at the next decision. If no branch is
+ /// specified, the default branch will be 0. The actionIndex or actionIndices correspond
+ /// to the action the agent will be unable to perform.
+ ///
+ /// The index of the masked action on branch 0
+ protected void SetActionMask(int actionIndex)
+ {
+ actionMasker.SetActionMask(0, new int[1] { actionIndex });
+ }
+
+ ///
+ /// Sets an action mask for discrete control agents. When used, the agent will not be
+ /// able to perform the action passed as argument at the next decision. If no branch is
+ /// specified, the default branch will be 0. The actionIndex or actionIndices correspond
+ /// to the action the agent will be unable to perform.
+ ///
+ /// The branch for which the actions will be masked
+ /// The index of the masked action
+ protected void SetActionMask(int branch, int actionIndex)
+ {
+ actionMasker.SetActionMask(branch, new int[1] { actionIndex });
+ }
+
+ ///
+ /// Modifies an action mask for discrete control agents. When used, the agent will not be
+ /// able to perform the action passed as argument at the next decision. If no branch is
+ /// specified, the default branch will be 0. The actionIndex or actionIndices correspond
+ /// to the action the agent will be unable to perform.
+ ///
+ /// The branch for which the actions will be masked
+ /// The indices of the masked actions
+ protected void SetActionMask(int branch, IEnumerable actionIndices)
+ {
+ actionMasker.SetActionMask(branch, actionIndices);
+ }
+
+
+ ///
+ /// Adds a float observation to the vector observations of the agent.
+ /// Increases the size of the agents vector observation by 1.
+ ///
+ /// Observation.
+ protected void AddVectorObs(float observation)
+ {
+ info.vectorObservation.Add(observation);
+ }
+
+ ///
+ /// Adds an integer observation to the vector observations of the agent.
+ /// Increases the size of the agents vector observation by 1.
+ ///
+ /// Observation.
+ protected void AddVectorObs(int observation)
+ {
+ info.vectorObservation.Add(observation);
+ }
+
+ ///
+ /// Adds an Vector3 observation to the vector observations of the agent.
+ /// Increases the size of the agents vector observation by 3.
+ ///
+ /// Observation.
+ protected void AddVectorObs(Vector3 observation)
+ {
+ info.vectorObservation.Add(observation.x);
+ info.vectorObservation.Add(observation.y);
+ info.vectorObservation.Add(observation.z);
+ }
+
+ ///
+ /// Adds an Vector2 observation to the vector observations of the agent.
+ /// Increases the size of the agents vector observation by 2.
+ ///
+ /// Observation.
+ protected void AddVectorObs(Vector2 observation)
+ {
+ info.vectorObservation.Add(observation.x);
+ info.vectorObservation.Add(observation.y);
+ }
+
+ ///
+ /// Adds a collection of float observations to the vector observations of the agent.
+ /// Increases the size of the agents vector observation by size of the collection.
+ ///
+ /// Observation.
+ protected void AddVectorObs(IEnumerable observation)
+ {
+ info.vectorObservation.AddRange(observation);
+ }
+
+ ///
+ /// Adds a quaternion observation to the vector observations of the agent.
+ /// Increases the size of the agents vector observation by 4.
+ ///
+ /// Observation.
+ protected void AddVectorObs(Quaternion observation)
+ {
+ info.vectorObservation.Add(observation.x);
+ info.vectorObservation.Add(observation.y);
+ info.vectorObservation.Add(observation.z);
+ info.vectorObservation.Add(observation.w);
+ }
+
+ ///
+ /// Adds a boolean observation to the vector observation of the agent.
+ /// Increases the size of the agent's vector observation by 1.
+ ///
+ ///
+ protected void AddVectorObs(bool observation)
+ {
+ info.vectorObservation.Add(observation ? 1f : 0f);
+ }
+
+ protected void AddVectorObs(int observation, int range)
+ {
+ float[] oneHotVector = new float[range];
+ oneHotVector[observation] = 1;
+ info.vectorObservation.AddRange(oneHotVector);
+ }
+
+ ///
+ /// Sets the text observation.
+ ///
+ /// The text observation.
+ public void SetTextObs(string textObservation)
+ {
+ info.textObservation = textObservation;
+ }
+
+ ///
+ /// Specifies the agent behavior at every step based on the provided
+ /// action.
+ ///
+ ///
+ /// Vector action. Note that for discrete actions, the provided array
+ /// will be of length 1.
+ ///
+ /// Text action.
+ public virtual void AgentAction(float[] vectorAction, string textAction)
+ {
+ }
+
+ ///
+ /// Specifies the agent behavior when done and
+ /// is false. This method can be
+ /// used to remove the agent from the scene.
+ ///
+ public virtual void AgentOnDone()
+ {
+ }
+
+ ///
+ /// Specifies the agent behavior when being reset, which can be due to
+ /// the agent or Academy being done (i.e. completion of local or global
+ /// episode).
+ ///
+ public virtual void AgentReset()
+ {
+ }
+
+ ///
+ /// An internal reset method that updates internal data structures in
+ /// addition to calling .
+ ///
+ void _AgentReset()
+ {
+ ResetData();
+ stepCount = 0;
+ AgentReset();
+ }
+
+ ///
+ /// Updates the vector action.
+ ///
+ /// Vector actions.
+ public void UpdateVectorAction(float[] vectorActions)
+ {
+ action.vectorActions = vectorActions;
+ }
+
+ ///
+ /// Updates the memories action.
+ ///
+ /// Memories.
+ public void UpdateMemoriesAction(List memories)
+ {
+ action.memories = memories;
+ }
+
+ public void AppendMemoriesAction(List memories)
+ {
+ action.memories.AddRange(memories);
+ }
+
+ ///
+ /// Updates the text action.
+ ///
+ /// Text actions.
+ public void UpdateTextAction(string textActions)
+ {
+ action.textActions = textActions;
+ }
+
+ ///
+ /// Updates the value of the agent.
+ ///
+ /// Text actions.
+ public void UpdateValueAction(float value)
+ {
+ action.value = value;
+ }
+
+ protected float GetValueEstimate()
+ {
+ return action.value;
+ }
+
+ ///
+ /// Scales continous action from [-1, 1] to arbitrary range.
+ ///
+ ///
+ ///
+ ///
+ ///
+ protected float ScaleAction(float rawAction, float min, float max)
+ {
+ var middle = (min + max) / 2;
+ var range = (max - min) / 2;
+ return rawAction * range + middle;
+ }
+
+ ///
+ /// Sets the status of the agent.
+ ///
+ /// If set to true
+ /// The agent must set maxStepReached.
+ /// If set to true
+ /// The agent must set done.
+ /// Number of current steps in episode
+ void SetStatus(bool academyMaxStep, bool academyDone, int academyStepCounter)
+ {
+ if (academyDone)
+ {
+ academyStepCounter = 0;
+ }
+
+ MakeRequests(academyStepCounter);
+ if (academyMaxStep)
+ {
+ maxStepReached = true;
+ }
+
+ // If the Academy needs to reset, the agent should reset
+ // even if it reseted recently.
+ if (academyDone)
+ {
+ Done();
+ hasAlreadyReset = false;
+ }
+ }
+
+ /// Signals the agent that it must reset if its done flag is set to true.
+ void ResetIfDone()
+ {
+ // If an agent is done, then it will also
+ // request for a decision and an action
+ if (IsDone())
+ {
+ if (agentParameters.resetOnDone)
+ {
+ if (agentParameters.onDemandDecision)
+ {
+ if (!hasAlreadyReset)
+ {
+ // If event based, the agent can reset as soon
+ // as it is done
+ _AgentReset();
+ hasAlreadyReset = true;
+ }
+ }
+ else if (requestDecision)
+ {
+ // If not event based, the agent must wait to request a
+ // decsion before reseting to keep multiple agents in sync.
+ _AgentReset();
+ }
+ }
+ else
+ {
+ terminate = true;
+ RequestDecision();
+ }
+ }
+ }
+
+ ///
+ /// Signals the agent that it must sent its decision to the brain.
+ ///
+ void SendInfo()
+ {
+ if (requestDecision)
+ {
+ SendInfoToBrain();
+ ResetReward();
+ done = false;
+ maxStepReached = false;
+ requestDecision = false;
+
+ hasAlreadyReset = false;
+ }
+ }
+
+ /// Used by the brain to make the agent perform a step.
+ void AgentStep()
+ {
+ if (terminate)
+ {
+ terminate = false;
+ ResetReward();
+ done = false;
+ maxStepReached = false;
+ requestDecision = false;
+ requestAction = false;
+
+ hasAlreadyReset = false;
+ OnDisable();
+ AgentOnDone();
+ }
+
+ if ((requestAction) && (brain != null))
+ {
+ requestAction = false;
+ AgentAction(action.vectorActions, action.textActions);
+ }
+
+ if ((stepCount >= agentParameters.maxStep)
+ && (agentParameters.maxStep > 0))
+ {
+ maxStepReached = true;
+ Done();
+ }
+ stepCount += 1;
+ }
+
+ ///
+ /// Is called after every step, contains the logic to decide if the agent
+ /// will request a decision at the next step.
+ ///
+ void MakeRequests(int academyStepCounter)
+ {
+ agentParameters.numberOfActionsBetweenDecisions =
+ Mathf.Max(agentParameters.numberOfActionsBetweenDecisions, 1);
+ if (!agentParameters.onDemandDecision)
+ {
+ RequestAction();
+ if (academyStepCounter %
+ agentParameters.numberOfActionsBetweenDecisions == 0)
+ {
+ RequestDecision();
+ }
+ }
+ }
+
+ ///
+ /// Converts a camera and correspinding resolution to a 2D texture.
+ ///
+ /// The 2D texture.
+ /// Camera.
+ /// Width of resulting 2D texture.
+ /// Height of resulting 2D texture.
+ /// Texture2D to render to.
+ public static void ObservationToTexture(Camera obsCamera, int width, int height, ref Texture2D texture2D)
+ {
+ Rect oldRec = obsCamera.rect;
+ obsCamera.rect = new Rect(0f, 0f, 1f, 1f);
+ var depth = 24;
+ var format = RenderTextureFormat.Default;
+ var readWrite = RenderTextureReadWrite.Default;
+
+ var tempRT =
+ RenderTexture.GetTemporary(width, height, depth, format, readWrite);
+
+ if (width != texture2D.width || height != texture2D.height)
+ {
+ texture2D.Resize(width, height);
+ }
+
+ var prevActiveRT = RenderTexture.active;
+ var prevCameraRT = obsCamera.targetTexture;
+
+ // render to offscreen texture (readonly from CPU side)
+ RenderTexture.active = tempRT;
+ obsCamera.targetTexture = tempRT;
+
+ obsCamera.Render();
+
+ texture2D.ReadPixels(new Rect(0, 0, texture2D.width, texture2D.height), 0, 0);
+ texture2D.Apply();
+ obsCamera.targetTexture = prevCameraRT;
+ obsCamera.rect = oldRec;
+ RenderTexture.active = prevActiveRT;
+ RenderTexture.ReleaseTemporary(tempRT);
+ }
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Agent.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Agent.cs.meta
new file mode 100755
index 00000000..e0d370b8
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Agent.cs.meta
@@ -0,0 +1,12 @@
+fileFormatVersion: 2
+guid: 88b6042bc9a5d4aa58d931eae49442e5
+timeCreated: 1501802662
+licenseType: Free
+MonoImporter:
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BCTeacherHelper.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BCTeacherHelper.cs
new file mode 100755
index 00000000..6c583144
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BCTeacherHelper.cs
@@ -0,0 +1,63 @@
+using System.Collections;
+using System.Collections.Generic;
+using UnityEngine;
+
+namespace MLAgents
+{
+
+ ///
+ /// Behavioral Cloning Helper script. Attach to teacher agent to enable
+ /// resetting the experience buffer, as well as toggling session recording.
+ ///
+ public class BCTeacherHelper : MonoBehaviour
+ {
+
+ bool recordExperiences;
+ bool resetBuffer;
+ Agent myAgent;
+ float bufferResetTime;
+
+ public KeyCode recordKey = KeyCode.R;
+ public KeyCode resetKey = KeyCode.C;
+
+ // Use this for initialization
+ void Start()
+ {
+ recordExperiences = true;
+ resetBuffer = false;
+ myAgent = GetComponent();
+ bufferResetTime = Time.time;
+ }
+
+ // Update is called once per frame
+ void Update()
+ {
+ if (Input.GetKeyDown(recordKey))
+ {
+ recordExperiences = !recordExperiences;
+ }
+
+ if (Input.GetKeyDown(resetKey))
+ {
+ resetBuffer = true;
+ bufferResetTime = Time.time;
+ }
+ else
+ {
+ resetBuffer = false;
+ }
+
+ Monitor.Log("Recording experiences " + recordKey, recordExperiences.ToString());
+ float timeSinceBufferReset = Time.time - bufferResetTime;
+ Monitor.Log("Seconds since buffer reset " + resetKey,
+ Mathf.FloorToInt(timeSinceBufferReset).ToString());
+ }
+
+ void FixedUpdate()
+ {
+ // Convert both bools into single comma separated string. Python makes
+ // assumption that this structure is preserved.
+ myAgent.SetTextObs(recordExperiences + "," + resetBuffer);
+ }
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BCTeacherHelper.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BCTeacherHelper.cs.meta
new file mode 100755
index 00000000..909b7060
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BCTeacherHelper.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: d1cf16abc39fb4d6ca81222fc73d1bb5
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Batcher.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Batcher.cs
new file mode 100755
index 00000000..d8eb0b9f
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Batcher.cs
@@ -0,0 +1,295 @@
+using System.Collections.Generic;
+using System.Linq;
+using UnityEngine;
+using Google.Protobuf;
+
+namespace MLAgents
+{
+ ///
+ /// The batcher is an RL specific class that makes sure that the information each object in
+ /// Unity (Academy and Brains) wants to send to External is appropriately batched together
+ /// and sent only when necessary.
+ ///
+ /// The Batcher will only send a Message to the Communicator when either :
+ /// 1 - The academy is done
+ /// 2 - At least one brain has data to send
+ ///
+ /// At each step, the batcher will keep track of the brains that queried the batcher for that
+ /// step. The batcher can only send the batched data when all the Brains have queried the
+ /// Batcher.
+ ///
+ public class Batcher
+ {
+ /// The default number of agents in the scene
+ private const int NumAgents = 32;
+
+ /// Keeps track of which brains have data to send on the current step
+ Dictionary m_hasData =
+ new Dictionary();
+
+ /// Keeps track of which brains queried the batcher on the current step
+ Dictionary m_hasQueried =
+ new Dictionary();
+
+ /// Keeps track of the agents of each brain on the current step
+ Dictionary> m_currentAgents =
+ new Dictionary>();
+
+ /// The Communicator of the batcher, sends a message at most once per step
+ Communicator m_communicator;
+
+ /// The current UnityRLOutput to be sent when all the brains queried the batcher
+ CommunicatorObjects.UnityRLOutput m_currentUnityRLOutput =
+ new CommunicatorObjects.UnityRLOutput();
+
+ /// Keeps track of the done flag of the Academy
+ bool m_academyDone;
+
+ /// Keeps track of last CommandProto sent by External
+ CommunicatorObjects.CommandProto m_command;
+
+ /// Keeps track of last EnvironmentParametersProto sent by External
+ // CommunicatorObjects.EnvironmentParametersProto m_environmentParameters;
+
+ /// Keeps track of the Arena Parameters sent by external (added by Ben)
+ CommunicatorObjects.UnityRLResetInput m_arenasParameters;
+
+ /// Keeps track of last training mode sent by External
+ bool m_isTraining;
+
+ /// Keeps track of the number of messages received
+ private ulong m_messagesReceived;
+
+ ///
+ /// Initializes a new instance of the Batcher class.
+ ///
+ /// The communicator to be used by the batcher.
+ public Batcher(Communicator communicator)
+ {
+ this.m_communicator = communicator;
+ }
+
+ ///
+ /// Sends the academy parameters through the Communicator.
+ /// Is used by the academy to send the AcademyParameters to the communicator.
+ ///
+ /// The External Initialization Parameters received.
+ /// The Unity Initialization Parameters to be sent.
+ public CommunicatorObjects.UnityRLInitializationInput SendAcademyParameters(
+ CommunicatorObjects.UnityRLInitializationOutput academyParameters)
+ {
+ CommunicatorObjects.UnityInput input;
+ var initializationInput = new CommunicatorObjects.UnityInput();
+ try
+ {
+ initializationInput = m_communicator.Initialize(
+ new CommunicatorObjects.UnityOutput
+ {
+ RlInitializationOutput = academyParameters
+ },
+ out input);
+ }
+ catch
+ {
+ throw new UnityAgentsException(
+ "The Communicator was unable to connect. Please make sure the External " +
+ "process is ready to accept communication with Unity.");
+ }
+
+ var firstRlInput = input.RlInput;
+ m_command = firstRlInput.Command;
+ m_arenasParameters = input.RlResetInput;
+ m_isTraining = firstRlInput.IsTraining;
+ return initializationInput.RlInitializationInput;
+ }
+
+ ///
+ /// Registers the done flag of the academy to the next output to be sent
+ /// to the communicator.
+ ///
+ /// If set to true
+ /// The academy done state will be sent to External at the next Exchange.
+ public void RegisterAcademyDoneFlag(bool done)
+ {
+ m_academyDone = done;
+ }
+
+ ///
+ /// Gets the command. Is used by the academy to get reset or quit signals.
+ ///
+ /// The current command.
+ public CommunicatorObjects.CommandProto GetCommand()
+ {
+ return m_command;
+ }
+
+ ///
+ /// Gets the number of messages received so far. Can be used to check for new messages.
+ ///
+ /// The number of messages received since start of the simulation
+ public ulong GetNumberMessageReceived()
+ {
+ return m_messagesReceived;
+ }
+
+ ///
+ /// Gets the arena parameters. Is used by the academy to update
+ /// the environment parameters.
+ ///
+ /// The environment parameters.
+ public CommunicatorObjects.UnityRLResetInput GetArenasParameters()
+ {
+ return m_arenasParameters;
+ }
+
+
+ ///
+ /// Gets the last training_mode flag External sent
+ ///
+ /// true, if training mode is requested, false otherwise.
+ public bool GetIsTraining()
+ {
+ return m_isTraining;
+ }
+
+ ///
+ /// Adds the brain to the list of brains which will be sending information to External.
+ ///
+ /// Brain key.
+ public void SubscribeBrain(string brainKey)
+ {
+ m_hasQueried[brainKey] = false;
+ m_hasData[brainKey] = false;
+ m_currentAgents[brainKey] = new List(NumAgents);
+ m_currentUnityRLOutput.AgentInfos.Add(
+ brainKey,
+ new CommunicatorObjects.UnityRLOutput.Types.ListAgentInfoProto());
+ }
+
+ ///
+ /// Sends the brain info. If at least one brain has an agent in need of
+ /// a decision or if the academy is done, the data is sent via
+ /// Communicator. Else, a new step is realized. The data can only be
+ /// sent once all the brains that subscribed to the batcher have tried
+ /// to send information.
+ ///
+ /// Brain key.
+ /// Agent info.
+ public void SendBrainInfo(
+ string brainKey, Dictionary agentInfo)
+ {
+ // If no communicator is initialized, the Batcher will not transmit
+ // BrainInfo
+ if (m_communicator == null)
+ {
+ return;
+ }
+
+ // The brain tried called GiveBrainInfo, update m_hasQueried
+ m_hasQueried[brainKey] = true;
+ // Populate the currentAgents dictionary
+ m_currentAgents[brainKey].Clear();
+ foreach (Agent agent in agentInfo.Keys)
+ {
+ m_currentAgents[brainKey].Add(agent);
+ }
+
+ // If at least one agent has data to send, then append data to
+ // the message and update hasSentState
+ if (m_currentAgents[brainKey].Count > 0)
+ {
+ foreach (Agent agent in m_currentAgents[brainKey])
+ {
+ CommunicatorObjects.AgentInfoProto agentInfoProto = agentInfo[agent].ToProto();
+ m_currentUnityRLOutput.AgentInfos[brainKey].Value.Add(agentInfoProto);
+ }
+
+ m_hasData[brainKey] = true;
+ }
+
+ // If any agent needs to send data, then the whole message
+ // must be sent
+ if (m_hasQueried.Values.All(x => x))
+ {
+ if (m_hasData.Values.Any(x => x) || m_academyDone)
+ {
+ m_currentUnityRLOutput.GlobalDone = m_academyDone;
+ SendBatchedMessageHelper();
+ }
+
+ // The message was just sent so we must reset hasSentState and
+ // triedSendState
+ foreach (string k in m_currentAgents.Keys)
+ {
+ m_hasData[k] = false;
+ m_hasQueried[k] = false;
+ }
+ }
+ }
+
+ ///
+ /// Helper method that sends the curent UnityRLOutput, receives the next UnityInput and
+ /// Applies the appropriate AgentAction to the agents.
+ ///
+ void SendBatchedMessageHelper()
+ {
+ var input = m_communicator.Exchange(
+ new CommunicatorObjects.UnityOutput
+ {
+ RlOutput = m_currentUnityRLOutput
+ });
+ m_messagesReceived += 1;
+
+ foreach (string k in m_currentUnityRLOutput.AgentInfos.Keys)
+ {
+ m_currentUnityRLOutput.AgentInfos[k].Value.Clear();
+ }
+
+ if (input == null)
+ {
+ m_command = CommunicatorObjects.CommandProto.Quit;
+ return;
+ }
+
+ CommunicatorObjects.UnityRLInput rlInput = input.RlInput;
+
+ if (rlInput == null)
+ {
+ m_command = CommunicatorObjects.CommandProto.Quit;
+ return;
+ }
+
+ m_command = rlInput.Command;
+ m_arenasParameters = input.RlResetInput;
+ m_isTraining = rlInput.IsTraining;
+
+ if (rlInput.AgentActions == null)
+ {
+ return;
+ }
+
+ foreach (var brainName in rlInput.AgentActions.Keys)
+ {
+ if (!m_currentAgents[brainName].Any())
+ {
+ continue;
+ }
+
+ if (!rlInput.AgentActions[brainName].Value.Any())
+ {
+ continue;
+ }
+
+ for (var i = 0; i < m_currentAgents[brainName].Count; i++)
+ {
+ var agent = m_currentAgents[brainName][i];
+ var action = rlInput.AgentActions[brainName].Value[i];
+ agent.UpdateVectorAction(action.VectorActions.ToArray());
+ agent.UpdateMemoriesAction(action.Memories.ToList());
+ agent.UpdateTextAction(action.TextActions);
+ agent.UpdateValueAction(action.Value);
+ }
+ }
+ }
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Batcher.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Batcher.cs.meta
new file mode 100755
index 00000000..07b7fed3
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Batcher.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 4243d5dc0ad5746cba578575182f8c17
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Brain.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Brain.cs
new file mode 100755
index 00000000..b8085e14
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Brain.cs
@@ -0,0 +1,93 @@
+using System.Collections.Generic;
+using UnityEngine;
+
+namespace MLAgents
+{
+ ///
+ /// Brain receive data from Agents through calls to SendState. The brain then updates the
+ /// actions of the agents at each FixedUpdate.
+ /// The Brain encapsulates the decision making process. Every Agent must be assigned a Brain,
+ /// but you can use the same Brain with more than one Agent. You can also create several
+ /// Brains, attach each of the Brain to one or more than one Agent.
+ /// Brain assets has several important properties that you can set using the Inspector window.
+ /// These properties must be appropriate for the Agents using the Brain. For example, the
+ /// Vector Observation Space Size property must match the length of the feature
+ /// vector created by an Agent exactly.
+ ///
+ public abstract class Brain : ScriptableObject
+ {
+ [SerializeField] public BrainParameters brainParameters;
+
+ protected Dictionary agentInfos =
+ new Dictionary(1024);
+
+ protected Batcher brainBatcher;
+
+ [System.NonSerialized]
+ private bool _isInitialized;
+
+ ///
+ /// Sets the Batcher of the Brain. The brain will call the batcher at every step and give
+ /// it the agent's data using SendBrainInfo at each DecideAction call.
+ ///
+ /// The Batcher the brain will use for the current session
+ public void SetBatcher(Batcher batcher)
+ {
+ if (batcher == null)
+ {
+ brainBatcher = null;
+ }
+ else
+ {
+ brainBatcher = batcher;
+ brainBatcher.SubscribeBrain(name);
+ }
+ LazyInitialize();
+ }
+
+ ///
+ /// Adds the data of an agent to the current batch so it will be processed in DecideAction.
+ ///
+ ///
+ ///
+ public void SendState(Agent agent, AgentInfo info)
+ {
+ LazyInitialize();
+ agentInfos.Add(agent, info);
+
+ }
+
+ ///
+ /// If the Brain is not initialized, it subscribes to the Academy's DecideAction Event and
+ /// calls the Initialize method to be implemented by child classes.
+ ///
+ private void LazyInitialize()
+ {
+ if (!_isInitialized)
+ {
+ FindObjectOfType().BrainDecideAction += BrainDecideAction;
+ Initialize();
+ _isInitialized = true;
+ }
+ }
+
+ ///
+ /// Calls the DecideAction method that the concrete brain implements.
+ ///
+ private void BrainDecideAction()
+ {
+ brainBatcher?.SendBrainInfo(name, agentInfos);
+ DecideAction();
+ }
+
+ ///
+ /// Is called only once at the begening of the training or inference session.
+ ///
+ protected abstract void Initialize();
+
+ ///
+ /// Is called once per Environment Step after the Brain has been initialized.
+ ///
+ protected abstract void DecideAction();
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Brain.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Brain.cs.meta
new file mode 100755
index 00000000..eaf6f070
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Brain.cs.meta
@@ -0,0 +1,12 @@
+fileFormatVersion: 2
+guid: c676a8ddf5a5f4f64b35e9ed5028679d
+timeCreated: 1503211687
+licenseType: Free
+MonoImporter:
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BrainParameters.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BrainParameters.cs
new file mode 100755
index 00000000..5252284f
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BrainParameters.cs
@@ -0,0 +1,124 @@
+using UnityEngine;
+using System.Linq;
+
+namespace MLAgents
+{
+ public enum SpaceType
+ {
+ discrete,
+ continuous
+ };
+
+ ///
+ /// The resolution of a camera used by an agent.
+ /// The width defines the number of pixels on the horizontal axis.
+ /// The height defines the number of pixels on the verical axis.
+ /// blackAndWhite defines whether or not the image is grayscale.
+ ///
+ [System.Serializable]
+ public struct Resolution
+ {
+ public int width;
+
+ /**< \brief The width of the observation in pixels */
+ public int height;
+
+ /**< \brief The height of the observation in pixels */
+ public bool blackAndWhite;
+ /**< \brief If true, the image will be in black and white.
+ * If false, it will be in colors RGB */
+ }
+
+ ///
+ /// Holds information about the Brain. It defines what are the inputs and outputs of the
+ /// decision process.
+ ///
+ [System.Serializable]
+ public class BrainParameters
+ {
+ public int vectorObservationSize = 1;
+ /**< \brief If continuous : The length of the float vector that represents
+ * the state
+ *
If discrete : The number of possible values the state can take*/
+
+ [Range(1, 50)] public int numStackedVectorObservations = 1;
+
+ public int[] vectorActionSize = new int[1]{1};
+ /**< \brief If continuous : The length of the float vector that represents
+ * the action
+ *
If discrete : The number of possible values the action can take*/
+
+ public Resolution[] cameraResolutions;
+ /**<\brief The list of observation resolutions for the brain */
+
+ public string[] vectorActionDescriptions;
+ /**< \brief The list of strings describing what the actions correpond to */
+
+ public SpaceType vectorActionSpaceType = SpaceType.discrete;
+ /**< \brief Defines if the action is discrete or continuous */
+
+ ///
+ /// Converts a Brain into to a Protobuff BrainInfoProto so it can be sent
+ ///
+ /// The BrainInfoProto generated.
+ /// The name of the brain.
+ /// Whether or not the Brain is training.
+ public CommunicatorObjects.BrainParametersProto
+ ToProto(string name, bool isTraining)
+ {
+ var brainParametersProto = new CommunicatorObjects.BrainParametersProto
+ {
+ VectorObservationSize = vectorObservationSize,
+ NumStackedVectorObservations = numStackedVectorObservations,
+ VectorActionSize = {vectorActionSize},
+ VectorActionSpaceType =
+ (CommunicatorObjects.SpaceTypeProto)vectorActionSpaceType,
+ BrainName = name,
+ IsTraining = isTraining
+ };
+ brainParametersProto.VectorActionDescriptions.AddRange(vectorActionDescriptions);
+ foreach (Resolution res in cameraResolutions)
+ {
+ brainParametersProto.CameraResolutions.Add(
+ new CommunicatorObjects.ResolutionProto
+ {
+ Width = res.width,
+ Height = res.height,
+ GrayScale = res.blackAndWhite
+ });
+ }
+ return brainParametersProto;
+ }
+
+ public BrainParameters()
+ {
+
+ }
+
+ public BrainParameters(CommunicatorObjects.BrainParametersProto brainParametersProto)
+ {
+ vectorObservationSize = brainParametersProto.VectorObservationSize;
+ numStackedVectorObservations = brainParametersProto.NumStackedVectorObservations;
+ vectorActionSize = brainParametersProto.VectorActionSize.ToArray();
+ vectorActionDescriptions = brainParametersProto.VectorActionDescriptions.ToArray();
+ vectorActionSpaceType = (SpaceType)brainParametersProto.VectorActionSpaceType;
+ }
+
+ ///
+ /// Deep clones the BrainParameter object
+ ///
+ /// A new BrainParameter object with the same values as the original.
+ public BrainParameters Clone()
+ {
+ return new BrainParameters()
+ {
+ vectorObservationSize = this.vectorObservationSize,
+ numStackedVectorObservations = this.numStackedVectorObservations,
+ vectorActionSize = (int[]) this.vectorActionSize.Clone(),
+ cameraResolutions = (Resolution[]) this.cameraResolutions.Clone(),
+ vectorActionDescriptions = (string[]) this.vectorActionDescriptions.Clone(),
+ vectorActionSpaceType = this.vectorActionSpaceType
+ };
+ }
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BrainParameters.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BrainParameters.cs.meta
new file mode 100755
index 00000000..248b4d0f
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BrainParameters.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 6108a41e9be04c238d7babaed4476134
+timeCreated: 1538758934
\ No newline at end of file
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BroadcastHub.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BroadcastHub.cs
new file mode 100755
index 00000000..569518be
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BroadcastHub.cs
@@ -0,0 +1,68 @@
+using System.Collections.Generic;
+using UnityEngine;
+
+namespace MLAgents
+{
+ ///
+ /// BroadcastHub holds reference to brains and keeps track wether or not the brain be
+ /// remotely controlled.
+ ///
+ [System.Serializable]
+ public class BroadcastHub
+ {
+ [SerializeField]
+ public List broadcastingBrains = new List();
+ [SerializeField]
+ private List _brainsToControl = new List();
+
+ ///
+ /// The number of Brains inside the BroadcastingHub.
+ ///
+ public int Count
+ {
+ get { return broadcastingBrains.Count; }
+ }
+
+ ///
+ /// Checks that a given Brain is set to be remote controlled.
+ ///
+ /// The Brain that is beeing checked
+ /// true if the Brain is set to Controlled and false otherwise. Will return
+ /// false if the Brain is not present in the Hub.
+ public bool IsControlled(Brain brain)
+ {
+ return _brainsToControl.Contains(brain);
+ }
+
+ ///
+ /// Sets a brain to controlled.
+ ///
+ /// The Brain that is being set to controlled
+ /// if true, the Brain will be set to remote controlled. Otherwise
+ /// the brain will be set to broadcast only.
+ public void SetControlled(Brain brain, bool controlled)
+ {
+ if (broadcastingBrains.Contains(brain))
+ {
+ if (controlled && !_brainsToControl.Contains(brain))
+ {
+ _brainsToControl.Add(brain);
+ }
+
+ if (!controlled && _brainsToControl.Contains(brain))
+ {
+ _brainsToControl.Remove(brain);
+ }
+ }
+ }
+
+ ///
+ /// Removes all the Brains of the BroadcastHub
+ ///
+ public void Clear()
+ {
+ broadcastingBrains.Clear();
+ _brainsToControl.Clear();
+ }
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BroadcastHub.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BroadcastHub.cs.meta
new file mode 100755
index 00000000..70bcf9b6
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/BroadcastHub.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: e43fd511c9f147e487d80e0bab3f6c6b
+timeCreated: 1536851538
\ No newline at end of file
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Communicator.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Communicator.cs
new file mode 100755
index 00000000..61f2f50c
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Communicator.cs
@@ -0,0 +1,74 @@
+using System.Collections;
+using System.Collections.Generic;
+using UnityEngine;
+using MLAgents.CommunicatorObjects;
+
+namespace MLAgents
+{
+ public struct CommunicatorParameters
+ {
+ public int port;
+ }
+
+ /**
+ This is the interface of the Communicators.
+ This does not need to be modified nor implemented to create a Unity environment.
+
+ When the Unity Communicator is initialized, it will wait for the External Communicator
+ to be initialized as well. The two communicators will then exchange their first messages
+ that will usually contain information for initialization (information that does not need
+ to be resent at each new exchange).
+
+ By convention a Unity input is from External to Unity and a Unity output is from Unity to
+ External. Inputs and outputs are relative to Unity.
+
+ By convention, when the Unity Communicator and External Communicator call exchange, the
+ exchange is NOT simultaneous but sequential. This means that when a side of the
+ communication calls exchange, the other will receive the result of its previous
+ exchange call.
+ This is what happens when A calls exchange a single time:
+ A sends data_1 to B -> B receives data_1 -> B generates and sends data_2 -> A receives data_2
+ When A calls exchange, it sends data_1 and receives data_2
+
+ Since the messages are sent back and forth with exchange and simultaneously when calling
+ initialize, External sends two messages at initialization.
+
+ The structure of the messages is as follows:
+ UnityMessage
+ ...Header
+ ...UnityOutput
+ ......UnityRLOutput
+ ......UnityRLInitializationOutput
+ ...UnityInput
+ ......UnityRLIntput
+ ......UnityRLInitializationIntput
+
+ UnityOutput and UnityInput can be extended to provide functionalities beyond RL
+ UnityRLOutput and UnityRLInput can be extended to provide new RL functionalities
+ */
+ public interface Communicator
+ {
+ ///
+ /// Initialize the communicator by sending the first UnityOutput and receiving the
+ /// first UnityInput. The second UnityInput is stored in the unityInput argument.
+ ///
+ /// The first Unity Input.
+ /// The first Unity Output.
+ /// The second Unity input.
+ UnityInput Initialize(UnityOutput unityOutput,
+ out UnityInput unityInput);
+
+ ///
+ /// Send a UnityOutput and receives a UnityInput.
+ ///
+ /// The next UnityInput.
+ /// The UnityOutput to be sent.
+ UnityInput Exchange(UnityOutput unityOutput);
+
+ ///
+ /// Close the communicator gracefully on both sides of the communication.
+ ///
+ void Close();
+
+ }
+}
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Communicator.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Communicator.cs.meta
new file mode 100755
index 00000000..165be840
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/Communicator.cs.meta
@@ -0,0 +1,12 @@
+fileFormatVersion: 2
+guid: 18600657fd7d241a199e6caf2ba7cceb
+timeCreated: 1504820023
+licenseType: Free
+MonoImporter:
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects.meta
new file mode 100755
index 00000000..cef92044
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 7ebeef5df83b74a048b7f99681672f3b
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/AgentActionProto.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/AgentActionProto.cs
new file mode 100755
index 00000000..c2aa9158
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/AgentActionProto.cs
@@ -0,0 +1,245 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/agent_action_proto.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/agent_action_proto.proto
+ public static partial class AgentActionProtoReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/agent_action_proto.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static AgentActionProtoReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjZhbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy9hZ2VudF9hY3Rpb25f",
+ "cHJvdG8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzImEKEEFnZW50QWN0",
+ "aW9uUHJvdG8SFgoOdmVjdG9yX2FjdGlvbnMYASADKAISFAoMdGV4dF9hY3Rp",
+ "b25zGAIgASgJEhAKCG1lbW9yaWVzGAMgAygCEg0KBXZhbHVlGAQgASgCQh+q",
+ "AhxNTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentActionProto), global::MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActions", "TextActions", "Memories", "Value" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class AgentActionProto : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AgentActionProto());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.AgentActionProtoReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public AgentActionProto() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public AgentActionProto(AgentActionProto other) : this() {
+ vectorActions_ = other.vectorActions_.Clone();
+ textActions_ = other.textActions_;
+ memories_ = other.memories_.Clone();
+ value_ = other.value_;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public AgentActionProto Clone() {
+ return new AgentActionProto(this);
+ }
+
+ /// Field number for the "vector_actions" field.
+ public const int VectorActionsFieldNumber = 1;
+ private static readonly pb::FieldCodec _repeated_vectorActions_codec
+ = pb::FieldCodec.ForFloat(10);
+ private readonly pbc::RepeatedField vectorActions_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField VectorActions {
+ get { return vectorActions_; }
+ }
+
+ /// Field number for the "text_actions" field.
+ public const int TextActionsFieldNumber = 2;
+ private string textActions_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string TextActions {
+ get { return textActions_; }
+ set {
+ textActions_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "memories" field.
+ public const int MemoriesFieldNumber = 3;
+ private static readonly pb::FieldCodec _repeated_memories_codec
+ = pb::FieldCodec.ForFloat(26);
+ private readonly pbc::RepeatedField memories_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField Memories {
+ get { return memories_; }
+ }
+
+ /// Field number for the "value" field.
+ public const int ValueFieldNumber = 4;
+ private float value_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public float Value {
+ get { return value_; }
+ set {
+ value_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as AgentActionProto);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(AgentActionProto other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if(!vectorActions_.Equals(other.vectorActions_)) return false;
+ if (TextActions != other.TextActions) return false;
+ if(!memories_.Equals(other.memories_)) return false;
+ if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Value, other.Value)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ hash ^= vectorActions_.GetHashCode();
+ if (TextActions.Length != 0) hash ^= TextActions.GetHashCode();
+ hash ^= memories_.GetHashCode();
+ if (Value != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Value);
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ vectorActions_.WriteTo(output, _repeated_vectorActions_codec);
+ if (TextActions.Length != 0) {
+ output.WriteRawTag(18);
+ output.WriteString(TextActions);
+ }
+ memories_.WriteTo(output, _repeated_memories_codec);
+ if (Value != 0F) {
+ output.WriteRawTag(37);
+ output.WriteFloat(Value);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ size += vectorActions_.CalculateSize(_repeated_vectorActions_codec);
+ if (TextActions.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(TextActions);
+ }
+ size += memories_.CalculateSize(_repeated_memories_codec);
+ if (Value != 0F) {
+ size += 1 + 4;
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(AgentActionProto other) {
+ if (other == null) {
+ return;
+ }
+ vectorActions_.Add(other.vectorActions_);
+ if (other.TextActions.Length != 0) {
+ TextActions = other.TextActions;
+ }
+ memories_.Add(other.memories_);
+ if (other.Value != 0F) {
+ Value = other.Value;
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10:
+ case 13: {
+ vectorActions_.AddEntriesFrom(input, _repeated_vectorActions_codec);
+ break;
+ }
+ case 18: {
+ TextActions = input.ReadString();
+ break;
+ }
+ case 26:
+ case 29: {
+ memories_.AddEntriesFrom(input, _repeated_memories_codec);
+ break;
+ }
+ case 37: {
+ Value = input.ReadFloat();
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/AgentActionProto.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/AgentActionProto.cs.meta
new file mode 100755
index 00000000..3f09aabf
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/AgentActionProto.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 4482f127d4a874cf8a11da2b2cc27dc9
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/AgentInfoProto.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/AgentInfoProto.cs
new file mode 100755
index 00000000..a620a193
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/AgentInfoProto.cs
@@ -0,0 +1,423 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/agent_info_proto.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/agent_info_proto.proto
+ public static partial class AgentInfoProtoReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/agent_info_proto.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static AgentInfoProtoReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjRhbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy9hZ2VudF9pbmZvX3By",
+ "b3RvLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyKSAgoOQWdlbnRJbmZv",
+ "UHJvdG8SIgoac3RhY2tlZF92ZWN0b3Jfb2JzZXJ2YXRpb24YASADKAISGwoT",
+ "dmlzdWFsX29ic2VydmF0aW9ucxgCIAMoDBIYChB0ZXh0X29ic2VydmF0aW9u",
+ "GAMgASgJEh0KFXN0b3JlZF92ZWN0b3JfYWN0aW9ucxgEIAMoAhIbChNzdG9y",
+ "ZWRfdGV4dF9hY3Rpb25zGAUgASgJEhAKCG1lbW9yaWVzGAYgAygCEg4KBnJl",
+ "d2FyZBgHIAEoAhIMCgRkb25lGAggASgIEhgKEG1heF9zdGVwX3JlYWNoZWQY",
+ "CSABKAgSCgoCaWQYCiABKAUSEwoLYWN0aW9uX21hc2sYCyADKAhCH6oCHE1M",
+ "QWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentInfoProto), global::MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "StackedVectorObservation", "VisualObservations", "TextObservation", "StoredVectorActions", "StoredTextActions", "Memories", "Reward", "Done", "MaxStepReached", "Id", "ActionMask" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class AgentInfoProto : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AgentInfoProto());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.AgentInfoProtoReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public AgentInfoProto() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public AgentInfoProto(AgentInfoProto other) : this() {
+ stackedVectorObservation_ = other.stackedVectorObservation_.Clone();
+ visualObservations_ = other.visualObservations_.Clone();
+ textObservation_ = other.textObservation_;
+ storedVectorActions_ = other.storedVectorActions_.Clone();
+ storedTextActions_ = other.storedTextActions_;
+ memories_ = other.memories_.Clone();
+ reward_ = other.reward_;
+ done_ = other.done_;
+ maxStepReached_ = other.maxStepReached_;
+ id_ = other.id_;
+ actionMask_ = other.actionMask_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public AgentInfoProto Clone() {
+ return new AgentInfoProto(this);
+ }
+
+ /// Field number for the "stacked_vector_observation" field.
+ public const int StackedVectorObservationFieldNumber = 1;
+ private static readonly pb::FieldCodec _repeated_stackedVectorObservation_codec
+ = pb::FieldCodec.ForFloat(10);
+ private readonly pbc::RepeatedField stackedVectorObservation_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField StackedVectorObservation {
+ get { return stackedVectorObservation_; }
+ }
+
+ /// Field number for the "visual_observations" field.
+ public const int VisualObservationsFieldNumber = 2;
+ private static readonly pb::FieldCodec _repeated_visualObservations_codec
+ = pb::FieldCodec.ForBytes(18);
+ private readonly pbc::RepeatedField visualObservations_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField VisualObservations {
+ get { return visualObservations_; }
+ }
+
+ /// Field number for the "text_observation" field.
+ public const int TextObservationFieldNumber = 3;
+ private string textObservation_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string TextObservation {
+ get { return textObservation_; }
+ set {
+ textObservation_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "stored_vector_actions" field.
+ public const int StoredVectorActionsFieldNumber = 4;
+ private static readonly pb::FieldCodec _repeated_storedVectorActions_codec
+ = pb::FieldCodec.ForFloat(34);
+ private readonly pbc::RepeatedField storedVectorActions_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField StoredVectorActions {
+ get { return storedVectorActions_; }
+ }
+
+ /// Field number for the "stored_text_actions" field.
+ public const int StoredTextActionsFieldNumber = 5;
+ private string storedTextActions_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string StoredTextActions {
+ get { return storedTextActions_; }
+ set {
+ storedTextActions_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "memories" field.
+ public const int MemoriesFieldNumber = 6;
+ private static readonly pb::FieldCodec _repeated_memories_codec
+ = pb::FieldCodec.ForFloat(50);
+ private readonly pbc::RepeatedField memories_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField Memories {
+ get { return memories_; }
+ }
+
+ /// Field number for the "reward" field.
+ public const int RewardFieldNumber = 7;
+ private float reward_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public float Reward {
+ get { return reward_; }
+ set {
+ reward_ = value;
+ }
+ }
+
+ /// Field number for the "done" field.
+ public const int DoneFieldNumber = 8;
+ private bool done_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Done {
+ get { return done_; }
+ set {
+ done_ = value;
+ }
+ }
+
+ /// Field number for the "max_step_reached" field.
+ public const int MaxStepReachedFieldNumber = 9;
+ private bool maxStepReached_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool MaxStepReached {
+ get { return maxStepReached_; }
+ set {
+ maxStepReached_ = value;
+ }
+ }
+
+ /// Field number for the "id" field.
+ public const int IdFieldNumber = 10;
+ private int id_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int Id {
+ get { return id_; }
+ set {
+ id_ = value;
+ }
+ }
+
+ /// Field number for the "action_mask" field.
+ public const int ActionMaskFieldNumber = 11;
+ private static readonly pb::FieldCodec _repeated_actionMask_codec
+ = pb::FieldCodec.ForBool(90);
+ private readonly pbc::RepeatedField actionMask_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField ActionMask {
+ get { return actionMask_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as AgentInfoProto);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(AgentInfoProto other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if(!stackedVectorObservation_.Equals(other.stackedVectorObservation_)) return false;
+ if(!visualObservations_.Equals(other.visualObservations_)) return false;
+ if (TextObservation != other.TextObservation) return false;
+ if(!storedVectorActions_.Equals(other.storedVectorActions_)) return false;
+ if (StoredTextActions != other.StoredTextActions) return false;
+ if(!memories_.Equals(other.memories_)) return false;
+ if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Reward, other.Reward)) return false;
+ if (Done != other.Done) return false;
+ if (MaxStepReached != other.MaxStepReached) return false;
+ if (Id != other.Id) return false;
+ if(!actionMask_.Equals(other.actionMask_)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ hash ^= stackedVectorObservation_.GetHashCode();
+ hash ^= visualObservations_.GetHashCode();
+ if (TextObservation.Length != 0) hash ^= TextObservation.GetHashCode();
+ hash ^= storedVectorActions_.GetHashCode();
+ if (StoredTextActions.Length != 0) hash ^= StoredTextActions.GetHashCode();
+ hash ^= memories_.GetHashCode();
+ if (Reward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Reward);
+ if (Done != false) hash ^= Done.GetHashCode();
+ if (MaxStepReached != false) hash ^= MaxStepReached.GetHashCode();
+ if (Id != 0) hash ^= Id.GetHashCode();
+ hash ^= actionMask_.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ stackedVectorObservation_.WriteTo(output, _repeated_stackedVectorObservation_codec);
+ visualObservations_.WriteTo(output, _repeated_visualObservations_codec);
+ if (TextObservation.Length != 0) {
+ output.WriteRawTag(26);
+ output.WriteString(TextObservation);
+ }
+ storedVectorActions_.WriteTo(output, _repeated_storedVectorActions_codec);
+ if (StoredTextActions.Length != 0) {
+ output.WriteRawTag(42);
+ output.WriteString(StoredTextActions);
+ }
+ memories_.WriteTo(output, _repeated_memories_codec);
+ if (Reward != 0F) {
+ output.WriteRawTag(61);
+ output.WriteFloat(Reward);
+ }
+ if (Done != false) {
+ output.WriteRawTag(64);
+ output.WriteBool(Done);
+ }
+ if (MaxStepReached != false) {
+ output.WriteRawTag(72);
+ output.WriteBool(MaxStepReached);
+ }
+ if (Id != 0) {
+ output.WriteRawTag(80);
+ output.WriteInt32(Id);
+ }
+ actionMask_.WriteTo(output, _repeated_actionMask_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ size += stackedVectorObservation_.CalculateSize(_repeated_stackedVectorObservation_codec);
+ size += visualObservations_.CalculateSize(_repeated_visualObservations_codec);
+ if (TextObservation.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(TextObservation);
+ }
+ size += storedVectorActions_.CalculateSize(_repeated_storedVectorActions_codec);
+ if (StoredTextActions.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(StoredTextActions);
+ }
+ size += memories_.CalculateSize(_repeated_memories_codec);
+ if (Reward != 0F) {
+ size += 1 + 4;
+ }
+ if (Done != false) {
+ size += 1 + 1;
+ }
+ if (MaxStepReached != false) {
+ size += 1 + 1;
+ }
+ if (Id != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id);
+ }
+ size += actionMask_.CalculateSize(_repeated_actionMask_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(AgentInfoProto other) {
+ if (other == null) {
+ return;
+ }
+ stackedVectorObservation_.Add(other.stackedVectorObservation_);
+ visualObservations_.Add(other.visualObservations_);
+ if (other.TextObservation.Length != 0) {
+ TextObservation = other.TextObservation;
+ }
+ storedVectorActions_.Add(other.storedVectorActions_);
+ if (other.StoredTextActions.Length != 0) {
+ StoredTextActions = other.StoredTextActions;
+ }
+ memories_.Add(other.memories_);
+ if (other.Reward != 0F) {
+ Reward = other.Reward;
+ }
+ if (other.Done != false) {
+ Done = other.Done;
+ }
+ if (other.MaxStepReached != false) {
+ MaxStepReached = other.MaxStepReached;
+ }
+ if (other.Id != 0) {
+ Id = other.Id;
+ }
+ actionMask_.Add(other.actionMask_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10:
+ case 13: {
+ stackedVectorObservation_.AddEntriesFrom(input, _repeated_stackedVectorObservation_codec);
+ break;
+ }
+ case 18: {
+ visualObservations_.AddEntriesFrom(input, _repeated_visualObservations_codec);
+ break;
+ }
+ case 26: {
+ TextObservation = input.ReadString();
+ break;
+ }
+ case 34:
+ case 37: {
+ storedVectorActions_.AddEntriesFrom(input, _repeated_storedVectorActions_codec);
+ break;
+ }
+ case 42: {
+ StoredTextActions = input.ReadString();
+ break;
+ }
+ case 50:
+ case 53: {
+ memories_.AddEntriesFrom(input, _repeated_memories_codec);
+ break;
+ }
+ case 61: {
+ Reward = input.ReadFloat();
+ break;
+ }
+ case 64: {
+ Done = input.ReadBool();
+ break;
+ }
+ case 72: {
+ MaxStepReached = input.ReadBool();
+ break;
+ }
+ case 80: {
+ Id = input.ReadInt32();
+ break;
+ }
+ case 90:
+ case 88: {
+ actionMask_.AddEntriesFrom(input, _repeated_actionMask_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/AgentInfoProto.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/AgentInfoProto.cs.meta
new file mode 100755
index 00000000..a6632288
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/AgentInfoProto.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 791522439b8324bff85f84309db90ecc
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/ArenaParametersProto.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/ArenaParametersProto.cs
new file mode 100755
index 00000000..7ddab857
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/ArenaParametersProto.cs
@@ -0,0 +1,634 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/arena_parameters_proto.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/arena_parameters_proto.proto
+ public static partial class ArenaParametersProtoReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/arena_parameters_proto.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static ArenaParametersProtoReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjphbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy9hcmVuYV9wYXJhbWV0",
+ "ZXJzX3Byb3RvLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyLPAwoUQXJl",
+ "bmFQYXJhbWV0ZXJzUHJvdG8SCQoBdBgBIAEoBRJGCgVpdGVtcxgCIAMoCzI3",
+ "LmNvbW11bmljYXRvcl9vYmplY3RzLkFyZW5hUGFyYW1ldGVyc1Byb3RvLkl0",
+ "ZW1zVG9TcGF3bhIRCglibGFja291dHMYAyADKAUa0AIKDEl0ZW1zVG9TcGF3",
+ "bhIMCgRuYW1lGAEgASgJElIKCXBvc2l0aW9ucxgDIAMoCzI/LmNvbW11bmlj",
+ "YXRvcl9vYmplY3RzLkFyZW5hUGFyYW1ldGVyc1Byb3RvLkl0ZW1zVG9TcGF3",
+ "bi5WZWN0b3IzEhEKCXJvdGF0aW9ucxgEIAMoAhJOCgVzaXplcxgFIAMoCzI/",
+ "LmNvbW11bmljYXRvcl9vYmplY3RzLkFyZW5hUGFyYW1ldGVyc1Byb3RvLkl0",
+ "ZW1zVG9TcGF3bi5WZWN0b3IzEk8KBmNvbG9ycxgGIAMoCzI/LmNvbW11bmlj",
+ "YXRvcl9vYmplY3RzLkFyZW5hUGFyYW1ldGVyc1Byb3RvLkl0ZW1zVG9TcGF3",
+ "bi5WZWN0b3IzGioKB1ZlY3RvcjMSCQoBeBgBIAEoAhIJCgF5GAIgASgCEgkK",
+ "AXoYAyABKAJCH6oCHE1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnBy",
+ "b3RvMw=="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.ArenaParametersProto), global::MLAgents.CommunicatorObjects.ArenaParametersProto.Parser, new[]{ "T", "Items", "Blackouts" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.ArenaParametersProto.Types.ItemsToSpawn), global::MLAgents.CommunicatorObjects.ArenaParametersProto.Types.ItemsToSpawn.Parser, new[]{ "Name", "Positions", "Rotations", "Sizes", "Colors" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.ArenaParametersProto.Types.ItemsToSpawn.Types.Vector3), global::MLAgents.CommunicatorObjects.ArenaParametersProto.Types.ItemsToSpawn.Types.Vector3.Parser, new[]{ "X", "Y", "Z" }, null, null, null)})})
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class ArenaParametersProto : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ArenaParametersProto());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.ArenaParametersProtoReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ArenaParametersProto() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ArenaParametersProto(ArenaParametersProto other) : this() {
+ t_ = other.t_;
+ items_ = other.items_.Clone();
+ blackouts_ = other.blackouts_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ArenaParametersProto Clone() {
+ return new ArenaParametersProto(this);
+ }
+
+ /// Field number for the "t" field.
+ public const int TFieldNumber = 1;
+ private int t_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int T {
+ get { return t_; }
+ set {
+ t_ = value;
+ }
+ }
+
+ /// Field number for the "items" field.
+ public const int ItemsFieldNumber = 2;
+ private static readonly pb::FieldCodec _repeated_items_codec
+ = pb::FieldCodec.ForMessage(18, global::MLAgents.CommunicatorObjects.ArenaParametersProto.Types.ItemsToSpawn.Parser);
+ private readonly pbc::RepeatedField items_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField Items {
+ get { return items_; }
+ }
+
+ /// Field number for the "blackouts" field.
+ public const int BlackoutsFieldNumber = 3;
+ private static readonly pb::FieldCodec _repeated_blackouts_codec
+ = pb::FieldCodec.ForInt32(26);
+ private readonly pbc::RepeatedField blackouts_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField Blackouts {
+ get { return blackouts_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as ArenaParametersProto);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(ArenaParametersProto other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (T != other.T) return false;
+ if(!items_.Equals(other.items_)) return false;
+ if(!blackouts_.Equals(other.blackouts_)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (T != 0) hash ^= T.GetHashCode();
+ hash ^= items_.GetHashCode();
+ hash ^= blackouts_.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (T != 0) {
+ output.WriteRawTag(8);
+ output.WriteInt32(T);
+ }
+ items_.WriteTo(output, _repeated_items_codec);
+ blackouts_.WriteTo(output, _repeated_blackouts_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (T != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(T);
+ }
+ size += items_.CalculateSize(_repeated_items_codec);
+ size += blackouts_.CalculateSize(_repeated_blackouts_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(ArenaParametersProto other) {
+ if (other == null) {
+ return;
+ }
+ if (other.T != 0) {
+ T = other.T;
+ }
+ items_.Add(other.items_);
+ blackouts_.Add(other.blackouts_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 8: {
+ T = input.ReadInt32();
+ break;
+ }
+ case 18: {
+ items_.AddEntriesFrom(input, _repeated_items_codec);
+ break;
+ }
+ case 26:
+ case 24: {
+ blackouts_.AddEntriesFrom(input, _repeated_blackouts_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ #region Nested types
+ /// Container for nested types declared in the ArenaParametersProto message type.
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static partial class Types {
+ public sealed partial class ItemsToSpawn : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ItemsToSpawn());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.ArenaParametersProto.Descriptor.NestedTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ItemsToSpawn() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ItemsToSpawn(ItemsToSpawn other) : this() {
+ name_ = other.name_;
+ positions_ = other.positions_.Clone();
+ rotations_ = other.rotations_.Clone();
+ sizes_ = other.sizes_.Clone();
+ colors_ = other.colors_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ItemsToSpawn Clone() {
+ return new ItemsToSpawn(this);
+ }
+
+ /// Field number for the "name" field.
+ public const int NameFieldNumber = 1;
+ private string name_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string Name {
+ get { return name_; }
+ set {
+ name_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "positions" field.
+ public const int PositionsFieldNumber = 3;
+ private static readonly pb::FieldCodec _repeated_positions_codec
+ = pb::FieldCodec.ForMessage(26, global::MLAgents.CommunicatorObjects.ArenaParametersProto.Types.ItemsToSpawn.Types.Vector3.Parser);
+ private readonly pbc::RepeatedField positions_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField Positions {
+ get { return positions_; }
+ }
+
+ /// Field number for the "rotations" field.
+ public const int RotationsFieldNumber = 4;
+ private static readonly pb::FieldCodec _repeated_rotations_codec
+ = pb::FieldCodec.ForFloat(34);
+ private readonly pbc::RepeatedField rotations_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField Rotations {
+ get { return rotations_; }
+ }
+
+ /// Field number for the "sizes" field.
+ public const int SizesFieldNumber = 5;
+ private static readonly pb::FieldCodec _repeated_sizes_codec
+ = pb::FieldCodec.ForMessage(42, global::MLAgents.CommunicatorObjects.ArenaParametersProto.Types.ItemsToSpawn.Types.Vector3.Parser);
+ private readonly pbc::RepeatedField sizes_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField Sizes {
+ get { return sizes_; }
+ }
+
+ /// Field number for the "colors" field.
+ public const int ColorsFieldNumber = 6;
+ private static readonly pb::FieldCodec _repeated_colors_codec
+ = pb::FieldCodec.ForMessage(50, global::MLAgents.CommunicatorObjects.ArenaParametersProto.Types.ItemsToSpawn.Types.Vector3.Parser);
+ private readonly pbc::RepeatedField colors_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField Colors {
+ get { return colors_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as ItemsToSpawn);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(ItemsToSpawn other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (Name != other.Name) return false;
+ if(!positions_.Equals(other.positions_)) return false;
+ if(!rotations_.Equals(other.rotations_)) return false;
+ if(!sizes_.Equals(other.sizes_)) return false;
+ if(!colors_.Equals(other.colors_)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (Name.Length != 0) hash ^= Name.GetHashCode();
+ hash ^= positions_.GetHashCode();
+ hash ^= rotations_.GetHashCode();
+ hash ^= sizes_.GetHashCode();
+ hash ^= colors_.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (Name.Length != 0) {
+ output.WriteRawTag(10);
+ output.WriteString(Name);
+ }
+ positions_.WriteTo(output, _repeated_positions_codec);
+ rotations_.WriteTo(output, _repeated_rotations_codec);
+ sizes_.WriteTo(output, _repeated_sizes_codec);
+ colors_.WriteTo(output, _repeated_colors_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (Name.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(Name);
+ }
+ size += positions_.CalculateSize(_repeated_positions_codec);
+ size += rotations_.CalculateSize(_repeated_rotations_codec);
+ size += sizes_.CalculateSize(_repeated_sizes_codec);
+ size += colors_.CalculateSize(_repeated_colors_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(ItemsToSpawn other) {
+ if (other == null) {
+ return;
+ }
+ if (other.Name.Length != 0) {
+ Name = other.Name;
+ }
+ positions_.Add(other.positions_);
+ rotations_.Add(other.rotations_);
+ sizes_.Add(other.sizes_);
+ colors_.Add(other.colors_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ Name = input.ReadString();
+ break;
+ }
+ case 26: {
+ positions_.AddEntriesFrom(input, _repeated_positions_codec);
+ break;
+ }
+ case 34:
+ case 37: {
+ rotations_.AddEntriesFrom(input, _repeated_rotations_codec);
+ break;
+ }
+ case 42: {
+ sizes_.AddEntriesFrom(input, _repeated_sizes_codec);
+ break;
+ }
+ case 50: {
+ colors_.AddEntriesFrom(input, _repeated_colors_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ #region Nested types
+ /// Container for nested types declared in the ItemsToSpawn message type.
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static partial class Types {
+ public sealed partial class Vector3 : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Vector3());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.ArenaParametersProto.Types.ItemsToSpawn.Descriptor.NestedTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public Vector3() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public Vector3(Vector3 other) : this() {
+ x_ = other.x_;
+ y_ = other.y_;
+ z_ = other.z_;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public Vector3 Clone() {
+ return new Vector3(this);
+ }
+
+ /// Field number for the "x" field.
+ public const int XFieldNumber = 1;
+ private float x_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public float X {
+ get { return x_; }
+ set {
+ x_ = value;
+ }
+ }
+
+ /// Field number for the "y" field.
+ public const int YFieldNumber = 2;
+ private float y_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public float Y {
+ get { return y_; }
+ set {
+ y_ = value;
+ }
+ }
+
+ /// Field number for the "z" field.
+ public const int ZFieldNumber = 3;
+ private float z_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public float Z {
+ get { return z_; }
+ set {
+ z_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as Vector3);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(Vector3 other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(X, other.X)) return false;
+ if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Y, other.Y)) return false;
+ if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Z, other.Z)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (X != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(X);
+ if (Y != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Y);
+ if (Z != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Z);
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (X != 0F) {
+ output.WriteRawTag(13);
+ output.WriteFloat(X);
+ }
+ if (Y != 0F) {
+ output.WriteRawTag(21);
+ output.WriteFloat(Y);
+ }
+ if (Z != 0F) {
+ output.WriteRawTag(29);
+ output.WriteFloat(Z);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (X != 0F) {
+ size += 1 + 4;
+ }
+ if (Y != 0F) {
+ size += 1 + 4;
+ }
+ if (Z != 0F) {
+ size += 1 + 4;
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(Vector3 other) {
+ if (other == null) {
+ return;
+ }
+ if (other.X != 0F) {
+ X = other.X;
+ }
+ if (other.Y != 0F) {
+ Y = other.Y;
+ }
+ if (other.Z != 0F) {
+ Z = other.Z;
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 13: {
+ X = input.ReadFloat();
+ break;
+ }
+ case 21: {
+ Y = input.ReadFloat();
+ break;
+ }
+ case 29: {
+ Z = input.ReadFloat();
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ }
+ #endregion
+
+ }
+
+ }
+ #endregion
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/ArenaParametersProto.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/ArenaParametersProto.cs.meta
new file mode 100755
index 00000000..35719069
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/ArenaParametersProto.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 94c2a102066cf2453bdada5a9fae589c
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/BrainParametersProto.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/BrainParametersProto.cs
new file mode 100755
index 00000000..a00cda19
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/BrainParametersProto.cs
@@ -0,0 +1,356 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/brain_parameters_proto.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/brain_parameters_proto.proto
+ public static partial class BrainParametersProtoReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/brain_parameters_proto.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static BrainParametersProtoReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjphbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy9icmFpbl9wYXJhbWV0",
+ "ZXJzX3Byb3RvLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cxo0YW5pbWFs",
+ "YWkvY29tbXVuaWNhdG9yX29iamVjdHMvcmVzb2x1dGlvbl9wcm90by5wcm90",
+ "bxo0YW5pbWFsYWkvY29tbXVuaWNhdG9yX29iamVjdHMvc3BhY2VfdHlwZV9w",
+ "cm90by5wcm90byLUAgoUQnJhaW5QYXJhbWV0ZXJzUHJvdG8SHwoXdmVjdG9y",
+ "X29ic2VydmF0aW9uX3NpemUYASABKAUSJwofbnVtX3N0YWNrZWRfdmVjdG9y",
+ "X29ic2VydmF0aW9ucxgCIAEoBRIaChJ2ZWN0b3JfYWN0aW9uX3NpemUYAyAD",
+ "KAUSQQoSY2FtZXJhX3Jlc29sdXRpb25zGAQgAygLMiUuY29tbXVuaWNhdG9y",
+ "X29iamVjdHMuUmVzb2x1dGlvblByb3RvEiIKGnZlY3Rvcl9hY3Rpb25fZGVz",
+ "Y3JpcHRpb25zGAUgAygJEkYKGHZlY3Rvcl9hY3Rpb25fc3BhY2VfdHlwZRgG",
+ "IAEoDjIkLmNvbW11bmljYXRvcl9vYmplY3RzLlNwYWNlVHlwZVByb3RvEhIK",
+ "CmJyYWluX25hbWUYByABKAkSEwoLaXNfdHJhaW5pbmcYCCABKAhCH6oCHE1M",
+ "QWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.ResolutionProtoReflection.Descriptor, global::MLAgents.CommunicatorObjects.SpaceTypeProtoReflection.Descriptor, },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.BrainParametersProto), global::MLAgents.CommunicatorObjects.BrainParametersProto.Parser, new[]{ "VectorObservationSize", "NumStackedVectorObservations", "VectorActionSize", "CameraResolutions", "VectorActionDescriptions", "VectorActionSpaceType", "BrainName", "IsTraining" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class BrainParametersProto : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BrainParametersProto());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.BrainParametersProtoReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public BrainParametersProto() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public BrainParametersProto(BrainParametersProto other) : this() {
+ vectorObservationSize_ = other.vectorObservationSize_;
+ numStackedVectorObservations_ = other.numStackedVectorObservations_;
+ vectorActionSize_ = other.vectorActionSize_.Clone();
+ cameraResolutions_ = other.cameraResolutions_.Clone();
+ vectorActionDescriptions_ = other.vectorActionDescriptions_.Clone();
+ vectorActionSpaceType_ = other.vectorActionSpaceType_;
+ brainName_ = other.brainName_;
+ isTraining_ = other.isTraining_;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public BrainParametersProto Clone() {
+ return new BrainParametersProto(this);
+ }
+
+ /// Field number for the "vector_observation_size" field.
+ public const int VectorObservationSizeFieldNumber = 1;
+ private int vectorObservationSize_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int VectorObservationSize {
+ get { return vectorObservationSize_; }
+ set {
+ vectorObservationSize_ = value;
+ }
+ }
+
+ /// Field number for the "num_stacked_vector_observations" field.
+ public const int NumStackedVectorObservationsFieldNumber = 2;
+ private int numStackedVectorObservations_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int NumStackedVectorObservations {
+ get { return numStackedVectorObservations_; }
+ set {
+ numStackedVectorObservations_ = value;
+ }
+ }
+
+ /// Field number for the "vector_action_size" field.
+ public const int VectorActionSizeFieldNumber = 3;
+ private static readonly pb::FieldCodec _repeated_vectorActionSize_codec
+ = pb::FieldCodec.ForInt32(26);
+ private readonly pbc::RepeatedField vectorActionSize_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField VectorActionSize {
+ get { return vectorActionSize_; }
+ }
+
+ /// Field number for the "camera_resolutions" field.
+ public const int CameraResolutionsFieldNumber = 4;
+ private static readonly pb::FieldCodec _repeated_cameraResolutions_codec
+ = pb::FieldCodec.ForMessage(34, global::MLAgents.CommunicatorObjects.ResolutionProto.Parser);
+ private readonly pbc::RepeatedField cameraResolutions_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField CameraResolutions {
+ get { return cameraResolutions_; }
+ }
+
+ /// Field number for the "vector_action_descriptions" field.
+ public const int VectorActionDescriptionsFieldNumber = 5;
+ private static readonly pb::FieldCodec _repeated_vectorActionDescriptions_codec
+ = pb::FieldCodec.ForString(42);
+ private readonly pbc::RepeatedField vectorActionDescriptions_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField VectorActionDescriptions {
+ get { return vectorActionDescriptions_; }
+ }
+
+ /// Field number for the "vector_action_space_type" field.
+ public const int VectorActionSpaceTypeFieldNumber = 6;
+ private global::MLAgents.CommunicatorObjects.SpaceTypeProto vectorActionSpaceType_ = 0;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::MLAgents.CommunicatorObjects.SpaceTypeProto VectorActionSpaceType {
+ get { return vectorActionSpaceType_; }
+ set {
+ vectorActionSpaceType_ = value;
+ }
+ }
+
+ /// Field number for the "brain_name" field.
+ public const int BrainNameFieldNumber = 7;
+ private string brainName_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string BrainName {
+ get { return brainName_; }
+ set {
+ brainName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "is_training" field.
+ public const int IsTrainingFieldNumber = 8;
+ private bool isTraining_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool IsTraining {
+ get { return isTraining_; }
+ set {
+ isTraining_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as BrainParametersProto);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(BrainParametersProto other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (VectorObservationSize != other.VectorObservationSize) return false;
+ if (NumStackedVectorObservations != other.NumStackedVectorObservations) return false;
+ if(!vectorActionSize_.Equals(other.vectorActionSize_)) return false;
+ if(!cameraResolutions_.Equals(other.cameraResolutions_)) return false;
+ if(!vectorActionDescriptions_.Equals(other.vectorActionDescriptions_)) return false;
+ if (VectorActionSpaceType != other.VectorActionSpaceType) return false;
+ if (BrainName != other.BrainName) return false;
+ if (IsTraining != other.IsTraining) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (VectorObservationSize != 0) hash ^= VectorObservationSize.GetHashCode();
+ if (NumStackedVectorObservations != 0) hash ^= NumStackedVectorObservations.GetHashCode();
+ hash ^= vectorActionSize_.GetHashCode();
+ hash ^= cameraResolutions_.GetHashCode();
+ hash ^= vectorActionDescriptions_.GetHashCode();
+ if (VectorActionSpaceType != 0) hash ^= VectorActionSpaceType.GetHashCode();
+ if (BrainName.Length != 0) hash ^= BrainName.GetHashCode();
+ if (IsTraining != false) hash ^= IsTraining.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (VectorObservationSize != 0) {
+ output.WriteRawTag(8);
+ output.WriteInt32(VectorObservationSize);
+ }
+ if (NumStackedVectorObservations != 0) {
+ output.WriteRawTag(16);
+ output.WriteInt32(NumStackedVectorObservations);
+ }
+ vectorActionSize_.WriteTo(output, _repeated_vectorActionSize_codec);
+ cameraResolutions_.WriteTo(output, _repeated_cameraResolutions_codec);
+ vectorActionDescriptions_.WriteTo(output, _repeated_vectorActionDescriptions_codec);
+ if (VectorActionSpaceType != 0) {
+ output.WriteRawTag(48);
+ output.WriteEnum((int) VectorActionSpaceType);
+ }
+ if (BrainName.Length != 0) {
+ output.WriteRawTag(58);
+ output.WriteString(BrainName);
+ }
+ if (IsTraining != false) {
+ output.WriteRawTag(64);
+ output.WriteBool(IsTraining);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (VectorObservationSize != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(VectorObservationSize);
+ }
+ if (NumStackedVectorObservations != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumStackedVectorObservations);
+ }
+ size += vectorActionSize_.CalculateSize(_repeated_vectorActionSize_codec);
+ size += cameraResolutions_.CalculateSize(_repeated_cameraResolutions_codec);
+ size += vectorActionDescriptions_.CalculateSize(_repeated_vectorActionDescriptions_codec);
+ if (VectorActionSpaceType != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) VectorActionSpaceType);
+ }
+ if (BrainName.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(BrainName);
+ }
+ if (IsTraining != false) {
+ size += 1 + 1;
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(BrainParametersProto other) {
+ if (other == null) {
+ return;
+ }
+ if (other.VectorObservationSize != 0) {
+ VectorObservationSize = other.VectorObservationSize;
+ }
+ if (other.NumStackedVectorObservations != 0) {
+ NumStackedVectorObservations = other.NumStackedVectorObservations;
+ }
+ vectorActionSize_.Add(other.vectorActionSize_);
+ cameraResolutions_.Add(other.cameraResolutions_);
+ vectorActionDescriptions_.Add(other.vectorActionDescriptions_);
+ if (other.VectorActionSpaceType != 0) {
+ VectorActionSpaceType = other.VectorActionSpaceType;
+ }
+ if (other.BrainName.Length != 0) {
+ BrainName = other.BrainName;
+ }
+ if (other.IsTraining != false) {
+ IsTraining = other.IsTraining;
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 8: {
+ VectorObservationSize = input.ReadInt32();
+ break;
+ }
+ case 16: {
+ NumStackedVectorObservations = input.ReadInt32();
+ break;
+ }
+ case 26:
+ case 24: {
+ vectorActionSize_.AddEntriesFrom(input, _repeated_vectorActionSize_codec);
+ break;
+ }
+ case 34: {
+ cameraResolutions_.AddEntriesFrom(input, _repeated_cameraResolutions_codec);
+ break;
+ }
+ case 42: {
+ vectorActionDescriptions_.AddEntriesFrom(input, _repeated_vectorActionDescriptions_codec);
+ break;
+ }
+ case 48: {
+ VectorActionSpaceType = (global::MLAgents.CommunicatorObjects.SpaceTypeProto) input.ReadEnum();
+ break;
+ }
+ case 58: {
+ BrainName = input.ReadString();
+ break;
+ }
+ case 64: {
+ IsTraining = input.ReadBool();
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/BrainParametersProto.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/BrainParametersProto.cs.meta
new file mode 100755
index 00000000..e3ee7bca
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/BrainParametersProto.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 7b41acc4d406e4a3c94df3399b2a6471
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/CommandProto.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/CommandProto.cs
new file mode 100755
index 00000000..7d8029bf
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/CommandProto.cs
@@ -0,0 +1,49 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/command_proto.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/command_proto.proto
+ public static partial class CommandProtoReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/command_proto.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static CommandProtoReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjFhbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy9jb21tYW5kX3Byb3Rv",
+ "LnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyotCgxDb21tYW5kUHJvdG8S",
+ "CAoEU1RFUBAAEgkKBVJFU0VUEAESCAoEUVVJVBACQh+qAhxNTEFnZW50cy5D",
+ "b21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { },
+ new pbr::GeneratedClrTypeInfo(new[] {typeof(global::MLAgents.CommunicatorObjects.CommandProto), }, null));
+ }
+ #endregion
+
+ }
+ #region Enums
+ public enum CommandProto {
+ [pbr::OriginalName("STEP")] Step = 0,
+ [pbr::OriginalName("RESET")] Reset = 1,
+ [pbr::OriginalName("QUIT")] Quit = 2,
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/CommandProto.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/CommandProto.cs.meta
new file mode 100755
index 00000000..6443336b
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/CommandProto.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 6b2ff9fe2c38b4e79aba78908cc5492c
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/DemonstrationMetaProto.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/DemonstrationMetaProto.cs
new file mode 100755
index 00000000..85a7c3b5
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/DemonstrationMetaProto.cs
@@ -0,0 +1,288 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/demonstration_meta_proto.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/demonstration_meta_proto.proto
+ public static partial class DemonstrationMetaProtoReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/demonstration_meta_proto.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static DemonstrationMetaProtoReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjxhbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy9kZW1vbnN0cmF0aW9u",
+ "X21ldGFfcHJvdG8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzIo0BChZE",
+ "ZW1vbnN0cmF0aW9uTWV0YVByb3RvEhMKC2FwaV92ZXJzaW9uGAEgASgFEhoK",
+ "EmRlbW9uc3RyYXRpb25fbmFtZRgCIAEoCRIUCgxudW1iZXJfc3RlcHMYAyAB",
+ "KAUSFwoPbnVtYmVyX2VwaXNvZGVzGAQgASgFEhMKC21lYW5fcmV3YXJkGAUg",
+ "ASgCQh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.DemonstrationMetaProto), global::MLAgents.CommunicatorObjects.DemonstrationMetaProto.Parser, new[]{ "ApiVersion", "DemonstrationName", "NumberSteps", "NumberEpisodes", "MeanReward" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class DemonstrationMetaProto : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DemonstrationMetaProto());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.DemonstrationMetaProtoReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public DemonstrationMetaProto() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public DemonstrationMetaProto(DemonstrationMetaProto other) : this() {
+ apiVersion_ = other.apiVersion_;
+ demonstrationName_ = other.demonstrationName_;
+ numberSteps_ = other.numberSteps_;
+ numberEpisodes_ = other.numberEpisodes_;
+ meanReward_ = other.meanReward_;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public DemonstrationMetaProto Clone() {
+ return new DemonstrationMetaProto(this);
+ }
+
+ /// Field number for the "api_version" field.
+ public const int ApiVersionFieldNumber = 1;
+ private int apiVersion_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int ApiVersion {
+ get { return apiVersion_; }
+ set {
+ apiVersion_ = value;
+ }
+ }
+
+ /// Field number for the "demonstration_name" field.
+ public const int DemonstrationNameFieldNumber = 2;
+ private string demonstrationName_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string DemonstrationName {
+ get { return demonstrationName_; }
+ set {
+ demonstrationName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "number_steps" field.
+ public const int NumberStepsFieldNumber = 3;
+ private int numberSteps_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int NumberSteps {
+ get { return numberSteps_; }
+ set {
+ numberSteps_ = value;
+ }
+ }
+
+ /// Field number for the "number_episodes" field.
+ public const int NumberEpisodesFieldNumber = 4;
+ private int numberEpisodes_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int NumberEpisodes {
+ get { return numberEpisodes_; }
+ set {
+ numberEpisodes_ = value;
+ }
+ }
+
+ /// Field number for the "mean_reward" field.
+ public const int MeanRewardFieldNumber = 5;
+ private float meanReward_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public float MeanReward {
+ get { return meanReward_; }
+ set {
+ meanReward_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as DemonstrationMetaProto);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(DemonstrationMetaProto other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (ApiVersion != other.ApiVersion) return false;
+ if (DemonstrationName != other.DemonstrationName) return false;
+ if (NumberSteps != other.NumberSteps) return false;
+ if (NumberEpisodes != other.NumberEpisodes) return false;
+ if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MeanReward, other.MeanReward)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (ApiVersion != 0) hash ^= ApiVersion.GetHashCode();
+ if (DemonstrationName.Length != 0) hash ^= DemonstrationName.GetHashCode();
+ if (NumberSteps != 0) hash ^= NumberSteps.GetHashCode();
+ if (NumberEpisodes != 0) hash ^= NumberEpisodes.GetHashCode();
+ if (MeanReward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MeanReward);
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (ApiVersion != 0) {
+ output.WriteRawTag(8);
+ output.WriteInt32(ApiVersion);
+ }
+ if (DemonstrationName.Length != 0) {
+ output.WriteRawTag(18);
+ output.WriteString(DemonstrationName);
+ }
+ if (NumberSteps != 0) {
+ output.WriteRawTag(24);
+ output.WriteInt32(NumberSteps);
+ }
+ if (NumberEpisodes != 0) {
+ output.WriteRawTag(32);
+ output.WriteInt32(NumberEpisodes);
+ }
+ if (MeanReward != 0F) {
+ output.WriteRawTag(45);
+ output.WriteFloat(MeanReward);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (ApiVersion != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(ApiVersion);
+ }
+ if (DemonstrationName.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(DemonstrationName);
+ }
+ if (NumberSteps != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumberSteps);
+ }
+ if (NumberEpisodes != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumberEpisodes);
+ }
+ if (MeanReward != 0F) {
+ size += 1 + 4;
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(DemonstrationMetaProto other) {
+ if (other == null) {
+ return;
+ }
+ if (other.ApiVersion != 0) {
+ ApiVersion = other.ApiVersion;
+ }
+ if (other.DemonstrationName.Length != 0) {
+ DemonstrationName = other.DemonstrationName;
+ }
+ if (other.NumberSteps != 0) {
+ NumberSteps = other.NumberSteps;
+ }
+ if (other.NumberEpisodes != 0) {
+ NumberEpisodes = other.NumberEpisodes;
+ }
+ if (other.MeanReward != 0F) {
+ MeanReward = other.MeanReward;
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 8: {
+ ApiVersion = input.ReadInt32();
+ break;
+ }
+ case 18: {
+ DemonstrationName = input.ReadString();
+ break;
+ }
+ case 24: {
+ NumberSteps = input.ReadInt32();
+ break;
+ }
+ case 32: {
+ NumberEpisodes = input.ReadInt32();
+ break;
+ }
+ case 45: {
+ MeanReward = input.ReadFloat();
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/DemonstrationMetaProto.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/DemonstrationMetaProto.cs.meta
new file mode 100755
index 00000000..f62ed064
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/DemonstrationMetaProto.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: f7abfeda342414e059423ef90ede4c30
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/EngineConfigurationProto.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/EngineConfigurationProto.cs
new file mode 100755
index 00000000..c08782d5
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/EngineConfigurationProto.cs
@@ -0,0 +1,317 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/engine_configuration_proto.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/engine_configuration_proto.proto
+ public static partial class EngineConfigurationProtoReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/engine_configuration_proto.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static EngineConfigurationProtoReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "Cj5hbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy9lbmdpbmVfY29uZmln",
+ "dXJhdGlvbl9wcm90by5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMilQEK",
+ "GEVuZ2luZUNvbmZpZ3VyYXRpb25Qcm90bxINCgV3aWR0aBgBIAEoBRIOCgZo",
+ "ZWlnaHQYAiABKAUSFQoNcXVhbGl0eV9sZXZlbBgDIAEoBRISCgp0aW1lX3Nj",
+ "YWxlGAQgASgCEhkKEXRhcmdldF9mcmFtZV9yYXRlGAUgASgFEhQKDHNob3df",
+ "bW9uaXRvchgGIAEoCEIfqgIcTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0",
+ "c2IGcHJvdG8z"));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.EngineConfigurationProto), global::MLAgents.CommunicatorObjects.EngineConfigurationProto.Parser, new[]{ "Width", "Height", "QualityLevel", "TimeScale", "TargetFrameRate", "ShowMonitor" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class EngineConfigurationProto : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new EngineConfigurationProto());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.EngineConfigurationProtoReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public EngineConfigurationProto() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public EngineConfigurationProto(EngineConfigurationProto other) : this() {
+ width_ = other.width_;
+ height_ = other.height_;
+ qualityLevel_ = other.qualityLevel_;
+ timeScale_ = other.timeScale_;
+ targetFrameRate_ = other.targetFrameRate_;
+ showMonitor_ = other.showMonitor_;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public EngineConfigurationProto Clone() {
+ return new EngineConfigurationProto(this);
+ }
+
+ /// Field number for the "width" field.
+ public const int WidthFieldNumber = 1;
+ private int width_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int Width {
+ get { return width_; }
+ set {
+ width_ = value;
+ }
+ }
+
+ /// Field number for the "height" field.
+ public const int HeightFieldNumber = 2;
+ private int height_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int Height {
+ get { return height_; }
+ set {
+ height_ = value;
+ }
+ }
+
+ /// Field number for the "quality_level" field.
+ public const int QualityLevelFieldNumber = 3;
+ private int qualityLevel_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int QualityLevel {
+ get { return qualityLevel_; }
+ set {
+ qualityLevel_ = value;
+ }
+ }
+
+ /// Field number for the "time_scale" field.
+ public const int TimeScaleFieldNumber = 4;
+ private float timeScale_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public float TimeScale {
+ get { return timeScale_; }
+ set {
+ timeScale_ = value;
+ }
+ }
+
+ /// Field number for the "target_frame_rate" field.
+ public const int TargetFrameRateFieldNumber = 5;
+ private int targetFrameRate_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int TargetFrameRate {
+ get { return targetFrameRate_; }
+ set {
+ targetFrameRate_ = value;
+ }
+ }
+
+ /// Field number for the "show_monitor" field.
+ public const int ShowMonitorFieldNumber = 6;
+ private bool showMonitor_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool ShowMonitor {
+ get { return showMonitor_; }
+ set {
+ showMonitor_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as EngineConfigurationProto);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(EngineConfigurationProto other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (Width != other.Width) return false;
+ if (Height != other.Height) return false;
+ if (QualityLevel != other.QualityLevel) return false;
+ if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(TimeScale, other.TimeScale)) return false;
+ if (TargetFrameRate != other.TargetFrameRate) return false;
+ if (ShowMonitor != other.ShowMonitor) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (Width != 0) hash ^= Width.GetHashCode();
+ if (Height != 0) hash ^= Height.GetHashCode();
+ if (QualityLevel != 0) hash ^= QualityLevel.GetHashCode();
+ if (TimeScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(TimeScale);
+ if (TargetFrameRate != 0) hash ^= TargetFrameRate.GetHashCode();
+ if (ShowMonitor != false) hash ^= ShowMonitor.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (Width != 0) {
+ output.WriteRawTag(8);
+ output.WriteInt32(Width);
+ }
+ if (Height != 0) {
+ output.WriteRawTag(16);
+ output.WriteInt32(Height);
+ }
+ if (QualityLevel != 0) {
+ output.WriteRawTag(24);
+ output.WriteInt32(QualityLevel);
+ }
+ if (TimeScale != 0F) {
+ output.WriteRawTag(37);
+ output.WriteFloat(TimeScale);
+ }
+ if (TargetFrameRate != 0) {
+ output.WriteRawTag(40);
+ output.WriteInt32(TargetFrameRate);
+ }
+ if (ShowMonitor != false) {
+ output.WriteRawTag(48);
+ output.WriteBool(ShowMonitor);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (Width != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(Width);
+ }
+ if (Height != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(Height);
+ }
+ if (QualityLevel != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(QualityLevel);
+ }
+ if (TimeScale != 0F) {
+ size += 1 + 4;
+ }
+ if (TargetFrameRate != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(TargetFrameRate);
+ }
+ if (ShowMonitor != false) {
+ size += 1 + 1;
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(EngineConfigurationProto other) {
+ if (other == null) {
+ return;
+ }
+ if (other.Width != 0) {
+ Width = other.Width;
+ }
+ if (other.Height != 0) {
+ Height = other.Height;
+ }
+ if (other.QualityLevel != 0) {
+ QualityLevel = other.QualityLevel;
+ }
+ if (other.TimeScale != 0F) {
+ TimeScale = other.TimeScale;
+ }
+ if (other.TargetFrameRate != 0) {
+ TargetFrameRate = other.TargetFrameRate;
+ }
+ if (other.ShowMonitor != false) {
+ ShowMonitor = other.ShowMonitor;
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 8: {
+ Width = input.ReadInt32();
+ break;
+ }
+ case 16: {
+ Height = input.ReadInt32();
+ break;
+ }
+ case 24: {
+ QualityLevel = input.ReadInt32();
+ break;
+ }
+ case 37: {
+ TimeScale = input.ReadFloat();
+ break;
+ }
+ case 40: {
+ TargetFrameRate = input.ReadInt32();
+ break;
+ }
+ case 48: {
+ ShowMonitor = input.ReadBool();
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/EngineConfigurationProto.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/EngineConfigurationProto.cs.meta
new file mode 100755
index 00000000..613b83d3
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/EngineConfigurationProto.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 2cebeb1263d7846b4b3c7c6e5d5e193f
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/Header.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/Header.cs
new file mode 100755
index 00000000..353e5838
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/Header.cs
@@ -0,0 +1,202 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/header.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/header.proto
+ public static partial class HeaderReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/header.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static HeaderReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CiphbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy9oZWFkZXIucHJvdG8S",
+ "FGNvbW11bmljYXRvcl9vYmplY3RzIikKBkhlYWRlchIOCgZzdGF0dXMYASAB",
+ "KAUSDwoHbWVzc2FnZRgCIAEoCUIfqgIcTUxBZ2VudHMuQ29tbXVuaWNhdG9y",
+ "T2JqZWN0c2IGcHJvdG8z"));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.Header), global::MLAgents.CommunicatorObjects.Header.Parser, new[]{ "Status", "Message" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class Header : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Header());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.HeaderReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public Header() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public Header(Header other) : this() {
+ status_ = other.status_;
+ message_ = other.message_;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public Header Clone() {
+ return new Header(this);
+ }
+
+ /// Field number for the "status" field.
+ public const int StatusFieldNumber = 1;
+ private int status_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int Status {
+ get { return status_; }
+ set {
+ status_ = value;
+ }
+ }
+
+ /// Field number for the "message" field.
+ public const int MessageFieldNumber = 2;
+ private string message_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string Message {
+ get { return message_; }
+ set {
+ message_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as Header);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(Header other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (Status != other.Status) return false;
+ if (Message != other.Message) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (Status != 0) hash ^= Status.GetHashCode();
+ if (Message.Length != 0) hash ^= Message.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (Status != 0) {
+ output.WriteRawTag(8);
+ output.WriteInt32(Status);
+ }
+ if (Message.Length != 0) {
+ output.WriteRawTag(18);
+ output.WriteString(Message);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (Status != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(Status);
+ }
+ if (Message.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(Message);
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(Header other) {
+ if (other == null) {
+ return;
+ }
+ if (other.Status != 0) {
+ Status = other.Status;
+ }
+ if (other.Message.Length != 0) {
+ Message = other.Message;
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 8: {
+ Status = input.ReadInt32();
+ break;
+ }
+ case 18: {
+ Message = input.ReadString();
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/Header.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/Header.cs.meta
new file mode 100755
index 00000000..63d91fcc
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/Header.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 8bb8aabfab48b408381733bccccd5af9
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/ResolutionProto.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/ResolutionProto.cs
new file mode 100755
index 00000000..6c005c38
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/ResolutionProto.cs
@@ -0,0 +1,231 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/resolution_proto.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/resolution_proto.proto
+ public static partial class ResolutionProtoReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/resolution_proto.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static ResolutionProtoReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjRhbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy9yZXNvbHV0aW9uX3By",
+ "b3RvLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyJECg9SZXNvbHV0aW9u",
+ "UHJvdG8SDQoFd2lkdGgYASABKAUSDgoGaGVpZ2h0GAIgASgFEhIKCmdyYXlf",
+ "c2NhbGUYAyABKAhCH6oCHE1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNi",
+ "BnByb3RvMw=="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.ResolutionProto), global::MLAgents.CommunicatorObjects.ResolutionProto.Parser, new[]{ "Width", "Height", "GrayScale" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class ResolutionProto : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ResolutionProto());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.ResolutionProtoReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ResolutionProto() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ResolutionProto(ResolutionProto other) : this() {
+ width_ = other.width_;
+ height_ = other.height_;
+ grayScale_ = other.grayScale_;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ResolutionProto Clone() {
+ return new ResolutionProto(this);
+ }
+
+ /// Field number for the "width" field.
+ public const int WidthFieldNumber = 1;
+ private int width_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int Width {
+ get { return width_; }
+ set {
+ width_ = value;
+ }
+ }
+
+ /// Field number for the "height" field.
+ public const int HeightFieldNumber = 2;
+ private int height_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int Height {
+ get { return height_; }
+ set {
+ height_ = value;
+ }
+ }
+
+ /// Field number for the "gray_scale" field.
+ public const int GrayScaleFieldNumber = 3;
+ private bool grayScale_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool GrayScale {
+ get { return grayScale_; }
+ set {
+ grayScale_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as ResolutionProto);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(ResolutionProto other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (Width != other.Width) return false;
+ if (Height != other.Height) return false;
+ if (GrayScale != other.GrayScale) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (Width != 0) hash ^= Width.GetHashCode();
+ if (Height != 0) hash ^= Height.GetHashCode();
+ if (GrayScale != false) hash ^= GrayScale.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (Width != 0) {
+ output.WriteRawTag(8);
+ output.WriteInt32(Width);
+ }
+ if (Height != 0) {
+ output.WriteRawTag(16);
+ output.WriteInt32(Height);
+ }
+ if (GrayScale != false) {
+ output.WriteRawTag(24);
+ output.WriteBool(GrayScale);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (Width != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(Width);
+ }
+ if (Height != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(Height);
+ }
+ if (GrayScale != false) {
+ size += 1 + 1;
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(ResolutionProto other) {
+ if (other == null) {
+ return;
+ }
+ if (other.Width != 0) {
+ Width = other.Width;
+ }
+ if (other.Height != 0) {
+ Height = other.Height;
+ }
+ if (other.GrayScale != false) {
+ GrayScale = other.GrayScale;
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 8: {
+ Width = input.ReadInt32();
+ break;
+ }
+ case 16: {
+ Height = input.ReadInt32();
+ break;
+ }
+ case 24: {
+ GrayScale = input.ReadBool();
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/ResolutionProto.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/ResolutionProto.cs.meta
new file mode 100755
index 00000000..b019d860
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/ResolutionProto.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: eae234f817240444a9d18b3d7366f260
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/SpaceTypeProto.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/SpaceTypeProto.cs
new file mode 100755
index 00000000..5897493b
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/SpaceTypeProto.cs
@@ -0,0 +1,49 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/space_type_proto.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/space_type_proto.proto
+ public static partial class SpaceTypeProtoReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/space_type_proto.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static SpaceTypeProtoReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjRhbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy9zcGFjZV90eXBlX3By",
+ "b3RvLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cxo0YW5pbWFsYWkvY29t",
+ "bXVuaWNhdG9yX29iamVjdHMvcmVzb2x1dGlvbl9wcm90by5wcm90byouCg5T",
+ "cGFjZVR5cGVQcm90bxIMCghkaXNjcmV0ZRAAEg4KCmNvbnRpbnVvdXMQAUIf",
+ "qgIcTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.ResolutionProtoReflection.Descriptor, },
+ new pbr::GeneratedClrTypeInfo(new[] {typeof(global::MLAgents.CommunicatorObjects.SpaceTypeProto), }, null));
+ }
+ #endregion
+
+ }
+ #region Enums
+ public enum SpaceTypeProto {
+ [pbr::OriginalName("discrete")] Discrete = 0,
+ [pbr::OriginalName("continuous")] Continuous = 1,
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/SpaceTypeProto.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/SpaceTypeProto.cs.meta
new file mode 100755
index 00000000..dcf59542
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/SpaceTypeProto.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 3e61637749b07412284363ff304da763
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityInput.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityInput.cs
new file mode 100755
index 00000000..d398127e
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityInput.cs
@@ -0,0 +1,256 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/unity_input.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/unity_input.proto
+ public static partial class UnityInputReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/unity_input.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static UnityInputReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "Ci9hbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9pbnB1dC5w",
+ "cm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMaMmFuaW1hbGFpL2NvbW11bmlj",
+ "YXRvcl9vYmplY3RzL3VuaXR5X3JsX2lucHV0LnByb3RvGkFhbmltYWxhaS9j",
+ "b21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9ybF9pbml0aWFsaXphdGlvbl9p",
+ "bnB1dC5wcm90bxo4YW5pbWFsYWkvY29tbXVuaWNhdG9yX29iamVjdHMvdW5p",
+ "dHlfcmxfcmVzZXRfaW5wdXQucHJvdG8i1gEKClVuaXR5SW5wdXQSNAoIcmxf",
+ "aW5wdXQYASABKAsyIi5jb21tdW5pY2F0b3Jfb2JqZWN0cy5Vbml0eVJMSW5w",
+ "dXQSUQoXcmxfaW5pdGlhbGl6YXRpb25faW5wdXQYAiABKAsyMC5jb21tdW5p",
+ "Y2F0b3Jfb2JqZWN0cy5Vbml0eVJMSW5pdGlhbGl6YXRpb25JbnB1dBI/Cg5y",
+ "bF9yZXNldF9pbnB1dBgDIAEoCzInLmNvbW11bmljYXRvcl9vYmplY3RzLlVu",
+ "aXR5UkxSZXNldElucHV0Qh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JPYmpl",
+ "Y3RzYgZwcm90bzM="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.UnityRlInputReflection.Descriptor, global::MLAgents.CommunicatorObjects.UnityRlInitializationInputReflection.Descriptor, global::MLAgents.CommunicatorObjects.UnityRlResetInputReflection.Descriptor, },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityInput), global::MLAgents.CommunicatorObjects.UnityInput.Parser, new[]{ "RlInput", "RlInitializationInput", "RlResetInput" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class UnityInput : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityInput());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.UnityInputReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityInput() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityInput(UnityInput other) : this() {
+ rlInput_ = other.rlInput_ != null ? other.rlInput_.Clone() : null;
+ rlInitializationInput_ = other.rlInitializationInput_ != null ? other.rlInitializationInput_.Clone() : null;
+ rlResetInput_ = other.rlResetInput_ != null ? other.rlResetInput_.Clone() : null;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityInput Clone() {
+ return new UnityInput(this);
+ }
+
+ /// Field number for the "rl_input" field.
+ public const int RlInputFieldNumber = 1;
+ private global::MLAgents.CommunicatorObjects.UnityRLInput rlInput_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::MLAgents.CommunicatorObjects.UnityRLInput RlInput {
+ get { return rlInput_; }
+ set {
+ rlInput_ = value;
+ }
+ }
+
+ /// Field number for the "rl_initialization_input" field.
+ public const int RlInitializationInputFieldNumber = 2;
+ private global::MLAgents.CommunicatorObjects.UnityRLInitializationInput rlInitializationInput_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::MLAgents.CommunicatorObjects.UnityRLInitializationInput RlInitializationInput {
+ get { return rlInitializationInput_; }
+ set {
+ rlInitializationInput_ = value;
+ }
+ }
+
+ /// Field number for the "rl_reset_input" field.
+ public const int RlResetInputFieldNumber = 3;
+ private global::MLAgents.CommunicatorObjects.UnityRLResetInput rlResetInput_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::MLAgents.CommunicatorObjects.UnityRLResetInput RlResetInput {
+ get { return rlResetInput_; }
+ set {
+ rlResetInput_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as UnityInput);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(UnityInput other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (!object.Equals(RlInput, other.RlInput)) return false;
+ if (!object.Equals(RlInitializationInput, other.RlInitializationInput)) return false;
+ if (!object.Equals(RlResetInput, other.RlResetInput)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (rlInput_ != null) hash ^= RlInput.GetHashCode();
+ if (rlInitializationInput_ != null) hash ^= RlInitializationInput.GetHashCode();
+ if (rlResetInput_ != null) hash ^= RlResetInput.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (rlInput_ != null) {
+ output.WriteRawTag(10);
+ output.WriteMessage(RlInput);
+ }
+ if (rlInitializationInput_ != null) {
+ output.WriteRawTag(18);
+ output.WriteMessage(RlInitializationInput);
+ }
+ if (rlResetInput_ != null) {
+ output.WriteRawTag(26);
+ output.WriteMessage(RlResetInput);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (rlInput_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(RlInput);
+ }
+ if (rlInitializationInput_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(RlInitializationInput);
+ }
+ if (rlResetInput_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(RlResetInput);
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(UnityInput other) {
+ if (other == null) {
+ return;
+ }
+ if (other.rlInput_ != null) {
+ if (rlInput_ == null) {
+ RlInput = new global::MLAgents.CommunicatorObjects.UnityRLInput();
+ }
+ RlInput.MergeFrom(other.RlInput);
+ }
+ if (other.rlInitializationInput_ != null) {
+ if (rlInitializationInput_ == null) {
+ RlInitializationInput = new global::MLAgents.CommunicatorObjects.UnityRLInitializationInput();
+ }
+ RlInitializationInput.MergeFrom(other.RlInitializationInput);
+ }
+ if (other.rlResetInput_ != null) {
+ if (rlResetInput_ == null) {
+ RlResetInput = new global::MLAgents.CommunicatorObjects.UnityRLResetInput();
+ }
+ RlResetInput.MergeFrom(other.RlResetInput);
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ if (rlInput_ == null) {
+ RlInput = new global::MLAgents.CommunicatorObjects.UnityRLInput();
+ }
+ input.ReadMessage(RlInput);
+ break;
+ }
+ case 18: {
+ if (rlInitializationInput_ == null) {
+ RlInitializationInput = new global::MLAgents.CommunicatorObjects.UnityRLInitializationInput();
+ }
+ input.ReadMessage(RlInitializationInput);
+ break;
+ }
+ case 26: {
+ if (rlResetInput_ == null) {
+ RlResetInput = new global::MLAgents.CommunicatorObjects.UnityRLResetInput();
+ }
+ input.ReadMessage(RlResetInput);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityInput.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityInput.cs.meta
new file mode 100755
index 00000000..846a8eb5
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityInput.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 25e46cd9eca204e19a08fa938802ef9d
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityMessage.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityMessage.cs
new file mode 100755
index 00000000..ae85911e
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityMessage.cs
@@ -0,0 +1,254 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/unity_message.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/unity_message.proto
+ public static partial class UnityMessageReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/unity_message.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static UnityMessageReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjFhbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9tZXNzYWdl",
+ "LnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cxowYW5pbWFsYWkvY29tbXVu",
+ "aWNhdG9yX29iamVjdHMvdW5pdHlfb3V0cHV0LnByb3RvGi9hbmltYWxhaS9j",
+ "b21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9pbnB1dC5wcm90bxoqYW5pbWFs",
+ "YWkvY29tbXVuaWNhdG9yX29iamVjdHMvaGVhZGVyLnByb3RvIqwBCgxVbml0",
+ "eU1lc3NhZ2USLAoGaGVhZGVyGAEgASgLMhwuY29tbXVuaWNhdG9yX29iamVj",
+ "dHMuSGVhZGVyEjcKDHVuaXR5X291dHB1dBgCIAEoCzIhLmNvbW11bmljYXRv",
+ "cl9vYmplY3RzLlVuaXR5T3V0cHV0EjUKC3VuaXR5X2lucHV0GAMgASgLMiAu",
+ "Y29tbXVuaWNhdG9yX29iamVjdHMuVW5pdHlJbnB1dEIfqgIcTUxBZ2VudHMu",
+ "Q29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.UnityOutputReflection.Descriptor, global::MLAgents.CommunicatorObjects.UnityInputReflection.Descriptor, global::MLAgents.CommunicatorObjects.HeaderReflection.Descriptor, },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityMessage), global::MLAgents.CommunicatorObjects.UnityMessage.Parser, new[]{ "Header", "UnityOutput", "UnityInput" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class UnityMessage : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityMessage());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.UnityMessageReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityMessage() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityMessage(UnityMessage other) : this() {
+ header_ = other.header_ != null ? other.header_.Clone() : null;
+ unityOutput_ = other.unityOutput_ != null ? other.unityOutput_.Clone() : null;
+ unityInput_ = other.unityInput_ != null ? other.unityInput_.Clone() : null;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityMessage Clone() {
+ return new UnityMessage(this);
+ }
+
+ /// Field number for the "header" field.
+ public const int HeaderFieldNumber = 1;
+ private global::MLAgents.CommunicatorObjects.Header header_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::MLAgents.CommunicatorObjects.Header Header {
+ get { return header_; }
+ set {
+ header_ = value;
+ }
+ }
+
+ /// Field number for the "unity_output" field.
+ public const int UnityOutputFieldNumber = 2;
+ private global::MLAgents.CommunicatorObjects.UnityOutput unityOutput_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::MLAgents.CommunicatorObjects.UnityOutput UnityOutput {
+ get { return unityOutput_; }
+ set {
+ unityOutput_ = value;
+ }
+ }
+
+ /// Field number for the "unity_input" field.
+ public const int UnityInputFieldNumber = 3;
+ private global::MLAgents.CommunicatorObjects.UnityInput unityInput_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::MLAgents.CommunicatorObjects.UnityInput UnityInput {
+ get { return unityInput_; }
+ set {
+ unityInput_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as UnityMessage);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(UnityMessage other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (!object.Equals(Header, other.Header)) return false;
+ if (!object.Equals(UnityOutput, other.UnityOutput)) return false;
+ if (!object.Equals(UnityInput, other.UnityInput)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (header_ != null) hash ^= Header.GetHashCode();
+ if (unityOutput_ != null) hash ^= UnityOutput.GetHashCode();
+ if (unityInput_ != null) hash ^= UnityInput.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (header_ != null) {
+ output.WriteRawTag(10);
+ output.WriteMessage(Header);
+ }
+ if (unityOutput_ != null) {
+ output.WriteRawTag(18);
+ output.WriteMessage(UnityOutput);
+ }
+ if (unityInput_ != null) {
+ output.WriteRawTag(26);
+ output.WriteMessage(UnityInput);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (header_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(Header);
+ }
+ if (unityOutput_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(UnityOutput);
+ }
+ if (unityInput_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(UnityInput);
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(UnityMessage other) {
+ if (other == null) {
+ return;
+ }
+ if (other.header_ != null) {
+ if (header_ == null) {
+ Header = new global::MLAgents.CommunicatorObjects.Header();
+ }
+ Header.MergeFrom(other.Header);
+ }
+ if (other.unityOutput_ != null) {
+ if (unityOutput_ == null) {
+ UnityOutput = new global::MLAgents.CommunicatorObjects.UnityOutput();
+ }
+ UnityOutput.MergeFrom(other.UnityOutput);
+ }
+ if (other.unityInput_ != null) {
+ if (unityInput_ == null) {
+ UnityInput = new global::MLAgents.CommunicatorObjects.UnityInput();
+ }
+ UnityInput.MergeFrom(other.UnityInput);
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ if (header_ == null) {
+ Header = new global::MLAgents.CommunicatorObjects.Header();
+ }
+ input.ReadMessage(Header);
+ break;
+ }
+ case 18: {
+ if (unityOutput_ == null) {
+ UnityOutput = new global::MLAgents.CommunicatorObjects.UnityOutput();
+ }
+ input.ReadMessage(UnityOutput);
+ break;
+ }
+ case 26: {
+ if (unityInput_ == null) {
+ UnityInput = new global::MLAgents.CommunicatorObjects.UnityInput();
+ }
+ input.ReadMessage(UnityInput);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityMessage.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityMessage.cs.meta
new file mode 100755
index 00000000..6df8c397
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityMessage.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: d270bf9ce3d564bb48b2095802c15ff9
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityOutput.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityOutput.cs
new file mode 100755
index 00000000..e38c26ff
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityOutput.cs
@@ -0,0 +1,219 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/unity_output.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/unity_output.proto
+ public static partial class UnityOutputReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/unity_output.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static UnityOutputReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjBhbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9vdXRwdXQu",
+ "cHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjNhbmltYWxhaS9jb21tdW5p",
+ "Y2F0b3Jfb2JqZWN0cy91bml0eV9ybF9vdXRwdXQucHJvdG8aQmFuaW1hbGFp",
+ "L2NvbW11bmljYXRvcl9vYmplY3RzL3VuaXR5X3JsX2luaXRpYWxpemF0aW9u",
+ "X291dHB1dC5wcm90byKaAQoLVW5pdHlPdXRwdXQSNgoJcmxfb3V0cHV0GAEg",
+ "ASgLMiMuY29tbXVuaWNhdG9yX29iamVjdHMuVW5pdHlSTE91dHB1dBJTChhy",
+ "bF9pbml0aWFsaXphdGlvbl9vdXRwdXQYAiABKAsyMS5jb21tdW5pY2F0b3Jf",
+ "b2JqZWN0cy5Vbml0eVJMSW5pdGlhbGl6YXRpb25PdXRwdXRCH6oCHE1MQWdl",
+ "bnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.UnityRlOutputReflection.Descriptor, global::MLAgents.CommunicatorObjects.UnityRlInitializationOutputReflection.Descriptor, },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityOutput), global::MLAgents.CommunicatorObjects.UnityOutput.Parser, new[]{ "RlOutput", "RlInitializationOutput" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class UnityOutput : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityOutput());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.UnityOutputReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityOutput() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityOutput(UnityOutput other) : this() {
+ rlOutput_ = other.rlOutput_ != null ? other.rlOutput_.Clone() : null;
+ rlInitializationOutput_ = other.rlInitializationOutput_ != null ? other.rlInitializationOutput_.Clone() : null;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityOutput Clone() {
+ return new UnityOutput(this);
+ }
+
+ /// Field number for the "rl_output" field.
+ public const int RlOutputFieldNumber = 1;
+ private global::MLAgents.CommunicatorObjects.UnityRLOutput rlOutput_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::MLAgents.CommunicatorObjects.UnityRLOutput RlOutput {
+ get { return rlOutput_; }
+ set {
+ rlOutput_ = value;
+ }
+ }
+
+ /// Field number for the "rl_initialization_output" field.
+ public const int RlInitializationOutputFieldNumber = 2;
+ private global::MLAgents.CommunicatorObjects.UnityRLInitializationOutput rlInitializationOutput_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::MLAgents.CommunicatorObjects.UnityRLInitializationOutput RlInitializationOutput {
+ get { return rlInitializationOutput_; }
+ set {
+ rlInitializationOutput_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as UnityOutput);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(UnityOutput other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (!object.Equals(RlOutput, other.RlOutput)) return false;
+ if (!object.Equals(RlInitializationOutput, other.RlInitializationOutput)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (rlOutput_ != null) hash ^= RlOutput.GetHashCode();
+ if (rlInitializationOutput_ != null) hash ^= RlInitializationOutput.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (rlOutput_ != null) {
+ output.WriteRawTag(10);
+ output.WriteMessage(RlOutput);
+ }
+ if (rlInitializationOutput_ != null) {
+ output.WriteRawTag(18);
+ output.WriteMessage(RlInitializationOutput);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (rlOutput_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(RlOutput);
+ }
+ if (rlInitializationOutput_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(RlInitializationOutput);
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(UnityOutput other) {
+ if (other == null) {
+ return;
+ }
+ if (other.rlOutput_ != null) {
+ if (rlOutput_ == null) {
+ RlOutput = new global::MLAgents.CommunicatorObjects.UnityRLOutput();
+ }
+ RlOutput.MergeFrom(other.RlOutput);
+ }
+ if (other.rlInitializationOutput_ != null) {
+ if (rlInitializationOutput_ == null) {
+ RlInitializationOutput = new global::MLAgents.CommunicatorObjects.UnityRLInitializationOutput();
+ }
+ RlInitializationOutput.MergeFrom(other.RlInitializationOutput);
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ if (rlOutput_ == null) {
+ RlOutput = new global::MLAgents.CommunicatorObjects.UnityRLOutput();
+ }
+ input.ReadMessage(RlOutput);
+ break;
+ }
+ case 18: {
+ if (rlInitializationOutput_ == null) {
+ RlInitializationOutput = new global::MLAgents.CommunicatorObjects.UnityRLInitializationOutput();
+ }
+ input.ReadMessage(RlInitializationOutput);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityOutput.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityOutput.cs.meta
new file mode 100755
index 00000000..256098d3
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityOutput.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 5b7166f97831f45ef86df5eed0042240
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInitializationInput.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInitializationInput.cs
new file mode 100755
index 00000000..f4488d45
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInitializationInput.cs
@@ -0,0 +1,174 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/unity_rl_initialization_input.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/unity_rl_initialization_input.proto
+ public static partial class UnityRlInitializationInputReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/unity_rl_initialization_input.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static UnityRlInitializationInputReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CkFhbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9ybF9pbml0",
+ "aWFsaXphdGlvbl9pbnB1dC5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi",
+ "KgoaVW5pdHlSTEluaXRpYWxpemF0aW9uSW5wdXQSDAoEc2VlZBgBIAEoBUIf",
+ "qgIcTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInitializationInput), global::MLAgents.CommunicatorObjects.UnityRLInitializationInput.Parser, new[]{ "Seed" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class UnityRLInitializationInput : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityRLInitializationInput());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.UnityRlInitializationInputReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLInitializationInput() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLInitializationInput(UnityRLInitializationInput other) : this() {
+ seed_ = other.seed_;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLInitializationInput Clone() {
+ return new UnityRLInitializationInput(this);
+ }
+
+ /// Field number for the "seed" field.
+ public const int SeedFieldNumber = 1;
+ private int seed_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int Seed {
+ get { return seed_; }
+ set {
+ seed_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as UnityRLInitializationInput);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(UnityRLInitializationInput other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (Seed != other.Seed) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (Seed != 0) hash ^= Seed.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (Seed != 0) {
+ output.WriteRawTag(8);
+ output.WriteInt32(Seed);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (Seed != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(Seed);
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(UnityRLInitializationInput other) {
+ if (other == null) {
+ return;
+ }
+ if (other.Seed != 0) {
+ Seed = other.Seed;
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 8: {
+ Seed = input.ReadInt32();
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInitializationInput.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInitializationInput.cs.meta
new file mode 100755
index 00000000..eb0f1e1c
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInitializationInput.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 6c81750abd5a9432babe1834534122c0
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInitializationOutput.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInitializationOutput.cs
new file mode 100755
index 00000000..5ec5e9e4
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInitializationOutput.cs
@@ -0,0 +1,260 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/unity_rl_initialization_output.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/unity_rl_initialization_output.proto
+ public static partial class UnityRlInitializationOutputReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/unity_rl_initialization_output.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static UnityRlInitializationOutputReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CkJhbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9ybF9pbml0",
+ "aWFsaXphdGlvbl9vdXRwdXQucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3Rz",
+ "GjphbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy9icmFpbl9wYXJhbWV0",
+ "ZXJzX3Byb3RvLnByb3RvIpQBChtVbml0eVJMSW5pdGlhbGl6YXRpb25PdXRw",
+ "dXQSDAoEbmFtZRgBIAEoCRIPCgd2ZXJzaW9uGAIgASgJEhAKCGxvZ19wYXRo",
+ "GAMgASgJEkQKEGJyYWluX3BhcmFtZXRlcnMYBSADKAsyKi5jb21tdW5pY2F0",
+ "b3Jfb2JqZWN0cy5CcmFpblBhcmFtZXRlcnNQcm90b0IfqgIcTUxBZ2VudHMu",
+ "Q29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.BrainParametersProtoReflection.Descriptor, },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInitializationOutput), global::MLAgents.CommunicatorObjects.UnityRLInitializationOutput.Parser, new[]{ "Name", "Version", "LogPath", "BrainParameters" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ ///
+ /// The request message containing the academy's parameters.
+ ///
+ public sealed partial class UnityRLInitializationOutput : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityRLInitializationOutput());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.UnityRlInitializationOutputReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLInitializationOutput() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLInitializationOutput(UnityRLInitializationOutput other) : this() {
+ name_ = other.name_;
+ version_ = other.version_;
+ logPath_ = other.logPath_;
+ brainParameters_ = other.brainParameters_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLInitializationOutput Clone() {
+ return new UnityRLInitializationOutput(this);
+ }
+
+ /// Field number for the "name" field.
+ public const int NameFieldNumber = 1;
+ private string name_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string Name {
+ get { return name_; }
+ set {
+ name_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "version" field.
+ public const int VersionFieldNumber = 2;
+ private string version_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string Version {
+ get { return version_; }
+ set {
+ version_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "log_path" field.
+ public const int LogPathFieldNumber = 3;
+ private string logPath_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string LogPath {
+ get { return logPath_; }
+ set {
+ logPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "brain_parameters" field.
+ public const int BrainParametersFieldNumber = 5;
+ private static readonly pb::FieldCodec _repeated_brainParameters_codec
+ = pb::FieldCodec.ForMessage(42, global::MLAgents.CommunicatorObjects.BrainParametersProto.Parser);
+ private readonly pbc::RepeatedField brainParameters_ = new pbc::RepeatedField();
+ ///
+ /// EnvironmentParametersProto environment_parameters = 6;
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField BrainParameters {
+ get { return brainParameters_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as UnityRLInitializationOutput);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(UnityRLInitializationOutput other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (Name != other.Name) return false;
+ if (Version != other.Version) return false;
+ if (LogPath != other.LogPath) return false;
+ if(!brainParameters_.Equals(other.brainParameters_)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (Name.Length != 0) hash ^= Name.GetHashCode();
+ if (Version.Length != 0) hash ^= Version.GetHashCode();
+ if (LogPath.Length != 0) hash ^= LogPath.GetHashCode();
+ hash ^= brainParameters_.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (Name.Length != 0) {
+ output.WriteRawTag(10);
+ output.WriteString(Name);
+ }
+ if (Version.Length != 0) {
+ output.WriteRawTag(18);
+ output.WriteString(Version);
+ }
+ if (LogPath.Length != 0) {
+ output.WriteRawTag(26);
+ output.WriteString(LogPath);
+ }
+ brainParameters_.WriteTo(output, _repeated_brainParameters_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (Name.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(Name);
+ }
+ if (Version.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(Version);
+ }
+ if (LogPath.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(LogPath);
+ }
+ size += brainParameters_.CalculateSize(_repeated_brainParameters_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(UnityRLInitializationOutput other) {
+ if (other == null) {
+ return;
+ }
+ if (other.Name.Length != 0) {
+ Name = other.Name;
+ }
+ if (other.Version.Length != 0) {
+ Version = other.Version;
+ }
+ if (other.LogPath.Length != 0) {
+ LogPath = other.LogPath;
+ }
+ brainParameters_.Add(other.brainParameters_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ Name = input.ReadString();
+ break;
+ }
+ case 18: {
+ Version = input.ReadString();
+ break;
+ }
+ case 26: {
+ LogPath = input.ReadString();
+ break;
+ }
+ case 42: {
+ brainParameters_.AddEntriesFrom(input, _repeated_brainParameters_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInitializationOutput.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInitializationOutput.cs.meta
new file mode 100755
index 00000000..1afe3779
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInitializationOutput.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: f7ac9dd525a2246688054b2442eda28a
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInput.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInput.cs
new file mode 100755
index 00000000..be338c0b
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInput.cs
@@ -0,0 +1,363 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/unity_rl_input.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/unity_rl_input.proto
+ public static partial class UnityRlInputReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/unity_rl_input.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static UnityRlInputReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjJhbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9ybF9pbnB1",
+ "dC5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMaNmFuaW1hbGFpL2NvbW11",
+ "bmljYXRvcl9vYmplY3RzL2FnZW50X2FjdGlvbl9wcm90by5wcm90bxoxYW5p",
+ "bWFsYWkvY29tbXVuaWNhdG9yX29iamVjdHMvY29tbWFuZF9wcm90by5wcm90",
+ "byLiAgoMVW5pdHlSTElucHV0EksKDWFnZW50X2FjdGlvbnMYASADKAsyNC5j",
+ "b21tdW5pY2F0b3Jfb2JqZWN0cy5Vbml0eVJMSW5wdXQuQWdlbnRBY3Rpb25z",
+ "RW50cnkSEwoLaXNfdHJhaW5pbmcYAiABKAgSMwoHY29tbWFuZBgDIAEoDjIi",
+ "LmNvbW11bmljYXRvcl9vYmplY3RzLkNvbW1hbmRQcm90bxpNChRMaXN0QWdl",
+ "bnRBY3Rpb25Qcm90bxI1CgV2YWx1ZRgBIAMoCzImLmNvbW11bmljYXRvcl9v",
+ "YmplY3RzLkFnZW50QWN0aW9uUHJvdG8abAoRQWdlbnRBY3Rpb25zRW50cnkS",
+ "CwoDa2V5GAEgASgJEkYKBXZhbHVlGAIgASgLMjcuY29tbXVuaWNhdG9yX29i",
+ "amVjdHMuVW5pdHlSTElucHV0Lkxpc3RBZ2VudEFjdGlvblByb3RvOgI4AUIf",
+ "qgIcTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.AgentActionProtoReflection.Descriptor, global::MLAgents.CommunicatorObjects.CommandProtoReflection.Descriptor, },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInput), global::MLAgents.CommunicatorObjects.UnityRLInput.Parser, new[]{ "AgentActions", "IsTraining", "Command" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLInput.Types.ListAgentActionProto), global::MLAgents.CommunicatorObjects.UnityRLInput.Types.ListAgentActionProto.Parser, new[]{ "Value" }, null, null, null),
+ null, })
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class UnityRLInput : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityRLInput());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.UnityRlInputReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLInput() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLInput(UnityRLInput other) : this() {
+ agentActions_ = other.agentActions_.Clone();
+ isTraining_ = other.isTraining_;
+ command_ = other.command_;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLInput Clone() {
+ return new UnityRLInput(this);
+ }
+
+ /// Field number for the "agent_actions" field.
+ public const int AgentActionsFieldNumber = 1;
+ private static readonly pbc::MapField.Codec _map_agentActions_codec
+ = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::MLAgents.CommunicatorObjects.UnityRLInput.Types.ListAgentActionProto.Parser), 10);
+ private readonly pbc::MapField agentActions_ = new pbc::MapField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::MapField AgentActions {
+ get { return agentActions_; }
+ }
+
+ /// Field number for the "is_training" field.
+ public const int IsTrainingFieldNumber = 2;
+ private bool isTraining_;
+ ///
+ /// EnvironmentParametersProto environment_parameters = 2;
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool IsTraining {
+ get { return isTraining_; }
+ set {
+ isTraining_ = value;
+ }
+ }
+
+ /// Field number for the "command" field.
+ public const int CommandFieldNumber = 3;
+ private global::MLAgents.CommunicatorObjects.CommandProto command_ = 0;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::MLAgents.CommunicatorObjects.CommandProto Command {
+ get { return command_; }
+ set {
+ command_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as UnityRLInput);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(UnityRLInput other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (!AgentActions.Equals(other.AgentActions)) return false;
+ if (IsTraining != other.IsTraining) return false;
+ if (Command != other.Command) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ hash ^= AgentActions.GetHashCode();
+ if (IsTraining != false) hash ^= IsTraining.GetHashCode();
+ if (Command != 0) hash ^= Command.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ agentActions_.WriteTo(output, _map_agentActions_codec);
+ if (IsTraining != false) {
+ output.WriteRawTag(16);
+ output.WriteBool(IsTraining);
+ }
+ if (Command != 0) {
+ output.WriteRawTag(24);
+ output.WriteEnum((int) Command);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ size += agentActions_.CalculateSize(_map_agentActions_codec);
+ if (IsTraining != false) {
+ size += 1 + 1;
+ }
+ if (Command != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Command);
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(UnityRLInput other) {
+ if (other == null) {
+ return;
+ }
+ agentActions_.Add(other.agentActions_);
+ if (other.IsTraining != false) {
+ IsTraining = other.IsTraining;
+ }
+ if (other.Command != 0) {
+ Command = other.Command;
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ agentActions_.AddEntriesFrom(input, _map_agentActions_codec);
+ break;
+ }
+ case 16: {
+ IsTraining = input.ReadBool();
+ break;
+ }
+ case 24: {
+ Command = (global::MLAgents.CommunicatorObjects.CommandProto) input.ReadEnum();
+ break;
+ }
+ }
+ }
+ }
+
+ #region Nested types
+ /// Container for nested types declared in the UnityRLInput message type.
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static partial class Types {
+ public sealed partial class ListAgentActionProto : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ListAgentActionProto());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.UnityRLInput.Descriptor.NestedTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ListAgentActionProto() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ListAgentActionProto(ListAgentActionProto other) : this() {
+ value_ = other.value_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ListAgentActionProto Clone() {
+ return new ListAgentActionProto(this);
+ }
+
+ /// Field number for the "value" field.
+ public const int ValueFieldNumber = 1;
+ private static readonly pb::FieldCodec _repeated_value_codec
+ = pb::FieldCodec.ForMessage(10, global::MLAgents.CommunicatorObjects.AgentActionProto.Parser);
+ private readonly pbc::RepeatedField value_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField Value {
+ get { return value_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as ListAgentActionProto);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(ListAgentActionProto other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if(!value_.Equals(other.value_)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ hash ^= value_.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ value_.WriteTo(output, _repeated_value_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ size += value_.CalculateSize(_repeated_value_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(ListAgentActionProto other) {
+ if (other == null) {
+ return;
+ }
+ value_.Add(other.value_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ value_.AddEntriesFrom(input, _repeated_value_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ }
+ #endregion
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInput.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInput.cs.meta
new file mode 100755
index 00000000..f381783a
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlInput.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 24680ffa432734c09b4660d82303cbd2
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlOutput.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlOutput.cs
new file mode 100755
index 00000000..eead0c16
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlOutput.cs
@@ -0,0 +1,330 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/unity_rl_output.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/unity_rl_output.proto
+ public static partial class UnityRlOutputReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/unity_rl_output.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static UnityRlOutputReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjNhbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9ybF9vdXRw",
+ "dXQucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRhbmltYWxhaS9jb21t",
+ "dW5pY2F0b3Jfb2JqZWN0cy9hZ2VudF9pbmZvX3Byb3RvLnByb3RvIqMCCg1V",
+ "bml0eVJMT3V0cHV0EhMKC2dsb2JhbF9kb25lGAEgASgIEkcKCmFnZW50SW5m",
+ "b3MYAiADKAsyMy5jb21tdW5pY2F0b3Jfb2JqZWN0cy5Vbml0eVJMT3V0cHV0",
+ "LkFnZW50SW5mb3NFbnRyeRpJChJMaXN0QWdlbnRJbmZvUHJvdG8SMwoFdmFs",
+ "dWUYASADKAsyJC5jb21tdW5pY2F0b3Jfb2JqZWN0cy5BZ2VudEluZm9Qcm90",
+ "bxppCg9BZ2VudEluZm9zRW50cnkSCwoDa2V5GAEgASgJEkUKBXZhbHVlGAIg",
+ "ASgLMjYuY29tbXVuaWNhdG9yX29iamVjdHMuVW5pdHlSTE91dHB1dC5MaXN0",
+ "QWdlbnRJbmZvUHJvdG86AjgBQh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JP",
+ "YmplY3RzYgZwcm90bzM="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.AgentInfoProtoReflection.Descriptor, },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLOutput), global::MLAgents.CommunicatorObjects.UnityRLOutput.Parser, new[]{ "GlobalDone", "AgentInfos" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLOutput.Types.ListAgentInfoProto), global::MLAgents.CommunicatorObjects.UnityRLOutput.Types.ListAgentInfoProto.Parser, new[]{ "Value" }, null, null, null),
+ null, })
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class UnityRLOutput : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityRLOutput());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.UnityRlOutputReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLOutput() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLOutput(UnityRLOutput other) : this() {
+ globalDone_ = other.globalDone_;
+ agentInfos_ = other.agentInfos_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLOutput Clone() {
+ return new UnityRLOutput(this);
+ }
+
+ /// Field number for the "global_done" field.
+ public const int GlobalDoneFieldNumber = 1;
+ private bool globalDone_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool GlobalDone {
+ get { return globalDone_; }
+ set {
+ globalDone_ = value;
+ }
+ }
+
+ /// Field number for the "agentInfos" field.
+ public const int AgentInfosFieldNumber = 2;
+ private static readonly pbc::MapField.Codec _map_agentInfos_codec
+ = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::MLAgents.CommunicatorObjects.UnityRLOutput.Types.ListAgentInfoProto.Parser), 18);
+ private readonly pbc::MapField agentInfos_ = new pbc::MapField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::MapField AgentInfos {
+ get { return agentInfos_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as UnityRLOutput);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(UnityRLOutput other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (GlobalDone != other.GlobalDone) return false;
+ if (!AgentInfos.Equals(other.AgentInfos)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (GlobalDone != false) hash ^= GlobalDone.GetHashCode();
+ hash ^= AgentInfos.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (GlobalDone != false) {
+ output.WriteRawTag(8);
+ output.WriteBool(GlobalDone);
+ }
+ agentInfos_.WriteTo(output, _map_agentInfos_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (GlobalDone != false) {
+ size += 1 + 1;
+ }
+ size += agentInfos_.CalculateSize(_map_agentInfos_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(UnityRLOutput other) {
+ if (other == null) {
+ return;
+ }
+ if (other.GlobalDone != false) {
+ GlobalDone = other.GlobalDone;
+ }
+ agentInfos_.Add(other.agentInfos_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 8: {
+ GlobalDone = input.ReadBool();
+ break;
+ }
+ case 18: {
+ agentInfos_.AddEntriesFrom(input, _map_agentInfos_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ #region Nested types
+ /// Container for nested types declared in the UnityRLOutput message type.
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static partial class Types {
+ public sealed partial class ListAgentInfoProto : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ListAgentInfoProto());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.UnityRLOutput.Descriptor.NestedTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ListAgentInfoProto() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ListAgentInfoProto(ListAgentInfoProto other) : this() {
+ value_ = other.value_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ListAgentInfoProto Clone() {
+ return new ListAgentInfoProto(this);
+ }
+
+ /// Field number for the "value" field.
+ public const int ValueFieldNumber = 1;
+ private static readonly pb::FieldCodec _repeated_value_codec
+ = pb::FieldCodec.ForMessage(10, global::MLAgents.CommunicatorObjects.AgentInfoProto.Parser);
+ private readonly pbc::RepeatedField value_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField Value {
+ get { return value_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as ListAgentInfoProto);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(ListAgentInfoProto other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if(!value_.Equals(other.value_)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ hash ^= value_.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ value_.WriteTo(output, _repeated_value_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ size += value_.CalculateSize(_repeated_value_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(ListAgentInfoProto other) {
+ if (other == null) {
+ return;
+ }
+ value_.Add(other.value_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ value_.AddEntriesFrom(input, _repeated_value_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ }
+ #endregion
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlOutput.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlOutput.cs.meta
new file mode 100755
index 00000000..6d7405ef
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlOutput.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: af13b8fefefa74a948934dd273f94c4a
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlResetInput.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlResetInput.cs
new file mode 100755
index 00000000..a3cd2cd3
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlResetInput.cs
@@ -0,0 +1,171 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/unity_rl_reset_input.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/unity_rl_reset_input.proto
+ public static partial class UnityRlResetInputReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/unity_rl_reset_input.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static UnityRlResetInputReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjhhbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9ybF9yZXNl",
+ "dF9pbnB1dC5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMaOmFuaW1hbGFp",
+ "L2NvbW11bmljYXRvcl9vYmplY3RzL2FyZW5hX3BhcmFtZXRlcnNfcHJvdG8u",
+ "cHJvdG8iswEKEVVuaXR5UkxSZXNldElucHV0EkMKBmFyZW5hcxgBIAMoCzIz",
+ "LmNvbW11bmljYXRvcl9vYmplY3RzLlVuaXR5UkxSZXNldElucHV0LkFyZW5h",
+ "c0VudHJ5GlkKC0FyZW5hc0VudHJ5EgsKA2tleRgBIAEoBRI5CgV2YWx1ZRgC",
+ "IAEoCzIqLmNvbW11bmljYXRvcl9vYmplY3RzLkFyZW5hUGFyYW1ldGVyc1By",
+ "b3RvOgI4AUIfqgIcTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJv",
+ "dG8z"));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.ArenaParametersProtoReflection.Descriptor, },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLResetInput), global::MLAgents.CommunicatorObjects.UnityRLResetInput.Parser, new[]{ "Arenas" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, })
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class UnityRLResetInput : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityRLResetInput());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.UnityRlResetInputReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLResetInput() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLResetInput(UnityRLResetInput other) : this() {
+ arenas_ = other.arenas_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLResetInput Clone() {
+ return new UnityRLResetInput(this);
+ }
+
+ /// Field number for the "arenas" field.
+ public const int ArenasFieldNumber = 1;
+ private static readonly pbc::MapField.Codec _map_arenas_codec
+ = new pbc::MapField.Codec(pb::FieldCodec.ForInt32(8), pb::FieldCodec.ForMessage(18, global::MLAgents.CommunicatorObjects.ArenaParametersProto.Parser), 10);
+ private readonly pbc::MapField arenas_ = new pbc::MapField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::MapField Arenas {
+ get { return arenas_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as UnityRLResetInput);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(UnityRLResetInput other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (!Arenas.Equals(other.Arenas)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ hash ^= Arenas.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ arenas_.WriteTo(output, _map_arenas_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ size += arenas_.CalculateSize(_map_arenas_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(UnityRLResetInput other) {
+ if (other == null) {
+ return;
+ }
+ arenas_.Add(other.arenas_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ arenas_.AddEntriesFrom(input, _map_arenas_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlResetInput.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlResetInput.cs.meta
new file mode 100755
index 00000000..cb26499e
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlResetInput.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: f652e016ed2db2c26941505b205c02bb
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlResetOutput.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlResetOutput.cs
new file mode 100755
index 00000000..185363fb
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlResetOutput.cs
@@ -0,0 +1,167 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/unity_rl_reset_output.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/unity_rl_reset_output.proto
+ public static partial class UnityRlResetOutputReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/unity_rl_reset_output.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static UnityRlResetOutputReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjlhbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV9ybF9yZXNl",
+ "dF9vdXRwdXQucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzIjEKElVuaXR5",
+ "UkxSZXNldE91dHB1dBIbChNhcmVuYXNfaW5zdGFuY2lhdGVkGAEgAygIQh+q",
+ "AhxNTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLResetOutput), global::MLAgents.CommunicatorObjects.UnityRLResetOutput.Parser, new[]{ "ArenasInstanciated" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class UnityRLResetOutput : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnityRLResetOutput());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.UnityRlResetOutputReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLResetOutput() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLResetOutput(UnityRLResetOutput other) : this() {
+ arenasInstanciated_ = other.arenasInstanciated_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public UnityRLResetOutput Clone() {
+ return new UnityRLResetOutput(this);
+ }
+
+ /// Field number for the "arenas_instanciated" field.
+ public const int ArenasInstanciatedFieldNumber = 1;
+ private static readonly pb::FieldCodec _repeated_arenasInstanciated_codec
+ = pb::FieldCodec.ForBool(10);
+ private readonly pbc::RepeatedField arenasInstanciated_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField ArenasInstanciated {
+ get { return arenasInstanciated_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as UnityRLResetOutput);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(UnityRLResetOutput other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if(!arenasInstanciated_.Equals(other.arenasInstanciated_)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ hash ^= arenasInstanciated_.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ arenasInstanciated_.WriteTo(output, _repeated_arenasInstanciated_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ size += arenasInstanciated_.CalculateSize(_repeated_arenasInstanciated_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(UnityRLResetOutput other) {
+ if (other == null) {
+ return;
+ }
+ arenasInstanciated_.Add(other.arenasInstanciated_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10:
+ case 8: {
+ arenasInstanciated_.AddEntriesFrom(input, _repeated_arenasInstanciated_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlResetOutput.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlResetOutput.cs.meta
new file mode 100755
index 00000000..459573ba
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityRlResetOutput.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 45aa51baac02ad9d38830e4484de559f
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityToExternal.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityToExternal.cs
new file mode 100755
index 00000000..a5a3b7ea
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityToExternal.cs
@@ -0,0 +1,43 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/unity_to_external.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from animalai/communicator_objects/unity_to_external.proto
+ public static partial class UnityToExternalReflection {
+
+ #region Descriptor
+ /// File descriptor for animalai/communicator_objects/unity_to_external.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static UnityToExternalReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjVhbmltYWxhaS9jb21tdW5pY2F0b3Jfb2JqZWN0cy91bml0eV90b19leHRl",
+ "cm5hbC5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMaMWFuaW1hbGFpL2Nv",
+ "bW11bmljYXRvcl9vYmplY3RzL3VuaXR5X21lc3NhZ2UucHJvdG8yZwoPVW5p",
+ "dHlUb0V4dGVybmFsElQKCEV4Y2hhbmdlEiIuY29tbXVuaWNhdG9yX29iamVj",
+ "dHMuVW5pdHlNZXNzYWdlGiIuY29tbXVuaWNhdG9yX29iamVjdHMuVW5pdHlN",
+ "ZXNzYWdlIgBCH6oCHE1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnBy",
+ "b3RvMw=="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.UnityMessageReflection.Descriptor, },
+ new pbr::GeneratedClrTypeInfo(null, null));
+ }
+ #endregion
+
+ }
+}
+
+#endregion Designer generated code
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityToExternal.cs.meta b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityToExternal.cs.meta
new file mode 100755
index 00000000..93547265
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityToExternal.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 199e76fc828bc4561abad51402438e07
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityToExternalGrpc.cs b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityToExternalGrpc.cs
new file mode 100755
index 00000000..9756068c
--- /dev/null
+++ b/env/AnimalAI-Environment/Assets/AnimalAIOlympics/Scripts/CommunicatorObjects/UnityToExternalGrpc.cs
@@ -0,0 +1,130 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: animalai/communicator_objects/unity_to_external.proto
+//
+#pragma warning disable 0414, 1591
+#region Designer generated code
+
+using grpc = global::Grpc.Core;
+
+namespace MLAgents.CommunicatorObjects {
+ public static partial class UnityToExternal
+ {
+ static readonly string __ServiceName = "communicator_objects.UnityToExternal";
+
+ static readonly grpc::Marshaller __Marshaller_communicator_objects_UnityMessage = grpc::Marshallers.Create((arg) => global::Google.Protobuf.MessageExtensions.ToByteArray(arg), global::MLAgents.CommunicatorObjects.UnityMessage.Parser.ParseFrom);
+
+ static readonly grpc::Method __Method_Exchange = new grpc::Method(
+ grpc::MethodType.Unary,
+ __ServiceName,
+ "Exchange",
+ __Marshaller_communicator_objects_UnityMessage,
+ __Marshaller_communicator_objects_UnityMessage);
+
+ ///