diff --git a/com.unity.ml-agents/Runtime/Academy.cs b/com.unity.ml-agents/Runtime/Academy.cs
index c42fd1afef..1b7cdb457a 100644
--- a/com.unity.ml-agents/Runtime/Academy.cs
+++ b/com.unity.ml-agents/Runtime/Academy.cs
@@ -91,9 +91,13 @@ public class Academy : IDisposable
/// <term>1.3.0</term>
/// <description>Support both continuous and discrete actions.</description>
/// </item>
+ /// <item>
+ /// <term>1.4.0</term>
+ /// <description>Support training analytics sent from python trainer to the editor.</description>
+ /// </item>
/// </list>
/// </remarks>
- const string k_ApiVersion = "1.3.0";
+ const string k_ApiVersion = "1.4.0";
/// <summary>
/// Unity package version of com.unity.ml-agents.
@@ -406,6 +410,7 @@ void InitializeEnvironment()
EnableAutomaticStepping();
SideChannelManager.RegisterSideChannel(new EngineConfigurationChannel());
+ SideChannelManager.RegisterSideChannel(new TrainingAnalyticsSideChannel());
m_EnvironmentParameters = new EnvironmentParameters();
m_StatsRecorder = new StatsRecorder();
diff --git a/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs b/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs
new file mode 100644
index 0000000000..fb480b7a11
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs
@@ -0,0 +1,40 @@
+using System;
+using UnityEngine;
+
+namespace Unity.MLAgents.Analytics
+{
+ internal static class AnalyticsUtils
+ {
+ /// <summary>
+ /// Hash a string to remove PII or secret info before sending to analytics
+ /// </summary>
+ /// <param name="s">The string to hash.</param>
+ /// <returns>A string containing the Hash128 of the input string.</returns>
+ public static string Hash(string s)
+ {
+ var behaviorNameHash = Hash128.Compute(s);
+ return behaviorNameHash.ToString();
+ }
+
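+ // Global switch consulted before each EditorAnalytics send; unit tests turn it off
+ // via the DisableAnalyticsSending scope below.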
+ internal static bool s_SendEditorAnalytics = true;
+
+ /// <summary>
+ /// Helper class to temporarily disable sending analytics from unit tests.
+ /// </summary>
+ internal class DisableAnalyticsSending : IDisposable
+ {
+ private bool m_PreviousSendEditorAnalytics;
+
+ public DisableAnalyticsSending()
+ {
+ m_PreviousSendEditorAnalytics = s_SendEditorAnalytics;
+ s_SendEditorAnalytics = false;
+ }
+
+ public void Dispose()
+ {
+ s_SendEditorAnalytics = m_PreviousSendEditorAnalytics;
+ }
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs.meta b/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs.meta
new file mode 100644
index 0000000000..b00fab1c90
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: af1ef3e70f1242938d7b39284b1a892b
+timeCreated: 1610575760
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Analytics/Events.cs b/com.unity.ml-agents/Runtime/Analytics/Events.cs
index 276587e761..ed22d3c67d 100644
--- a/com.unity.ml-agents/Runtime/Analytics/Events.cs
+++ b/com.unity.ml-agents/Runtime/Analytics/Events.cs
@@ -91,4 +91,67 @@ public static EventObservationSpec FromSensor(ISensor sensor)
};
}
}
+
+ internal struct RemotePolicyInitializedEvent
+ {
+ public string TrainingSessionGuid;
+ /// <summary>
+ /// Hash of the BehaviorName.
+ /// </summary>
+ public string BehaviorName;
+ public List<EventObservationSpec> ObservationSpecs;
+ public EventActionSpec ActionSpec;
+
+ /// <summary>
+ /// This will be the same as in TrainingEnvironmentInitializedEvent if available, but
+ /// TrainingEnvironmentInitializedEvent may not always be available with older trainers.
+ /// </summary>
+ public string MLAgentsEnvsVersion;
+ public string TrainerCommunicationVersion;
+ }
+
+ internal struct TrainingEnvironmentInitializedEvent
+ {
+ public string TrainingSessionGuid;
+
+ public string TrainerPythonVersion;
+ public string MLAgentsVersion;
+ public string MLAgentsEnvsVersion;
+ public string TorchVersion;
+ public string TorchDeviceType;
+ public int NumEnvironments;
+ public int NumEnvironmentParameters;
+ }
+
+ [Flags]
+ internal enum RewardSignals
+ {
+ Extrinsic = 1 << 0,
+ Gail = 1 << 1,
+ Curiosity = 1 << 2,
+ Rnd = 1 << 3,
+ }
+
+ [Flags]
+ internal enum TrainingFeatures
+ {
+ BehavioralCloning = 1 << 0,
+ Recurrent = 1 << 1,
+ Threaded = 1 << 2,
+ SelfPlay = 1 << 3,
+ Curriculum = 1 << 4,
+ }
+
+ internal struct TrainingBehaviorInitializedEvent
+ {
+ public string TrainingSessionGuid;
+
+ public string BehaviorName;
+ public string TrainerType;
+ public RewardSignals RewardSignalFlags;
+ public TrainingFeatures TrainingFeatureFlags;
+ public string VisualEncoder;
+ public int NumNetworkLayers;
+ public int NumNetworkHiddenUnits;
+ }
}
diff --git a/com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs b/com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs
index fd8d8bc99a..9dc1f4e535 100644
--- a/com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs
+++ b/com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs
@@ -116,7 +116,10 @@ ActionSpec actionSpec
// Note - to debug, use JsonUtility.ToJson on the event.
//Debug.Log(JsonUtility.ToJson(data, true));
#if UNITY_EDITOR
- EditorAnalytics.SendEventWithLimit(k_EventName, data, k_EventVersion);
+ if (AnalyticsUtils.s_SendEditorAnalytics)
+ {
+ EditorAnalytics.SendEventWithLimit(k_EventName, data, k_EventVersion);
+ }
#else
return;
#endif
@@ -143,8 +146,7 @@ ActionSpec actionSpec
var inferenceEvent = new InferenceEvent();
// Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
- var behaviorNameHash = Hash128.Compute(behaviorName);
- inferenceEvent.BehaviorName = behaviorNameHash.ToString();
+ inferenceEvent.BehaviorName = AnalyticsUtils.Hash(behaviorName);
inferenceEvent.BarracudaModelSource = barracudaModel.IrSource;
inferenceEvent.BarracudaModelVersion = barracudaModel.IrVersion;
diff --git a/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs b/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs
new file mode 100644
index 0000000000..6bdea00171
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs
@@ -0,0 +1,246 @@
+using System;
+using System.Collections.Generic;
+using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Sensors;
+using UnityEngine;
+using UnityEngine.Analytics;
+
+#if UNITY_EDITOR
+using UnityEditor;
+using UnityEditor.Analytics;
+#endif
+
+namespace Unity.MLAgents.Analytics
+{
+ internal class TrainingAnalytics
+ {
+ const string k_VendorKey = "unity.ml-agents";
+ const string k_TrainingEnvironmentInitializedEventName = "ml_agents_training_environment_initialized";
+ const string k_TrainingBehaviorInitializedEventName = "ml_agents_training_behavior_initialized";
+ const string k_RemotePolicyInitializedEventName = "ml_agents_remote_policy_initialized";
+
+ private static readonly string[] s_EventNames =
+ {
+ k_TrainingEnvironmentInitializedEventName,
+ k_TrainingBehaviorInitializedEventName,
+ k_RemotePolicyInitializedEventName
+ };
+
+ /// <summary>
+ /// Whether or not we've registered this particular event yet
+ /// </summary>
+ static bool s_EventsRegistered = false;
+
+ /// <summary>
+ /// Hourly limit for this event name
+ /// </summary>
+ const int k_MaxEventsPerHour = 1000;
+
+ /// <summary>
+ /// Maximum number of items in this event.
+ /// </summary>
+ const int k_MaxNumberOfElements = 1000;
+
+ private static bool s_SentEnvironmentInitialized;
+ /// <summary>
+ /// Behaviors that we've already sent events for.
+ /// </summary>
+ private static HashSet<string> s_SentRemotePolicyInitialized;
+ private static HashSet<string> s_SentTrainingBehaviorInitialized;
+
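+ // Generated once per session (when analytics are first enabled) so the three event
+ // types sent by this class can be correlated to the same training run.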
+ private static Guid s_TrainingSessionGuid;
+
+ // These are set when the RpcCommunicator connects
+ private static string s_TrainerPackageVersion = "";
+ private static string s_TrainerCommunicationVersion = "";
+
+ static bool EnableAnalytics()
+ {
+ if (s_EventsRegistered)
+ {
+ return true;
+ }
+
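+ // Register each event name once; if any registration fails, report analytics as
+ // unavailable so we never send an unregistered event.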
+ foreach (var eventName in s_EventNames)
+ {
+#if UNITY_EDITOR
+ AnalyticsResult result = EditorAnalytics.RegisterEventWithLimit(eventName, k_MaxEventsPerHour, k_MaxNumberOfElements, k_VendorKey);
+#else
+ AnalyticsResult result = AnalyticsResult.UnsupportedPlatform;
+#endif
+ if (result != AnalyticsResult.Ok)
+ {
+ return false;
+ }
+ }
+ s_EventsRegistered = true;
+
+ if (s_SentRemotePolicyInitialized == null)
+ {
+ s_SentRemotePolicyInitialized = new HashSet<string>();
+ s_SentTrainingBehaviorInitialized = new HashSet<string>();
+ s_TrainingSessionGuid = Guid.NewGuid();
+ }
+
+ return s_EventsRegistered;
+ }
+
+ /// <summary>
+ /// Cache information about the trainer when it becomes available in the RpcCommunicator.
+ /// </summary>
+ /// <param name="packageVersion">Package version of the Python trainer.</param>
+ /// <param name="communicationVersion">Communication version of the Python trainer.</param>
+ public static void SetTrainerInformation(string packageVersion, string communicationVersion)
+ {
+ s_TrainerPackageVersion = packageVersion;
+ s_TrainerCommunicationVersion = communicationVersion;
+ }
+
+ public static bool IsAnalyticsEnabled()
+ {
+#if UNITY_EDITOR
+ return EditorAnalytics.enabled;
+#else
+ return false;
+#endif
+ }
+
+ public static void TrainingEnvironmentInitialized(TrainingEnvironmentInitializedEvent tbiEvent)
+ {
+ if (!IsAnalyticsEnabled())
+ return;
+
+ if (!EnableAnalytics())
+ return;
+
+ if (s_SentEnvironmentInitialized)
+ {
+ // We already sent a TrainingEnvironmentInitializedEvent. Exit so we don't resend.
+ return;
+ }
+
+ s_SentEnvironmentInitialized = true;
+ tbiEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString();
+
+ // Note - to debug, use JsonUtility.ToJson on the event.
+ // Debug.Log(
+ // $"Would send event {k_TrainingEnvironmentInitializedEventName} with body {JsonUtility.ToJson(tbiEvent, true)}"
+ // );
+#if UNITY_EDITOR
+ if (AnalyticsUtils.s_SendEditorAnalytics)
+ {
+ EditorAnalytics.SendEventWithLimit(k_TrainingEnvironmentInitializedEventName, tbiEvent);
+ }
+#else
+ return;
+#endif
+ }
+
+ public static void RemotePolicyInitialized(
+ string fullyQualifiedBehaviorName,
+ IList<ISensor> sensors,
+ ActionSpec actionSpec
+ )
+ {
+ if (!IsAnalyticsEnabled())
+ return;
+
+ if (!EnableAnalytics())
+ return;
+
+ // Extract base behavior name (no team ID)
+ var behaviorName = ParseBehaviorName(fullyQualifiedBehaviorName);
+ var added = s_SentRemotePolicyInitialized.Add(behaviorName);
+
+ if (!added)
+ {
+ // We previously added this model. Exit so we don't resend.
+ return;
+ }
+
+ var data = GetEventForRemotePolicy(behaviorName, sensors, actionSpec);
+ // Note - to debug, use JsonUtility.ToJson on the event.
+ // Debug.Log(
+ // $"Would send event {k_RemotePolicyInitializedEventName} with body {JsonUtility.ToJson(data, true)}"
+ // );
+#if UNITY_EDITOR
+ if (AnalyticsUtils.s_SendEditorAnalytics)
+ {
+ EditorAnalytics.SendEventWithLimit(k_RemotePolicyInitializedEventName, data);
+ }
+#else
+ return;
+#endif
+ }
+
+ internal static string ParseBehaviorName(string fullyQualifiedBehaviorName)
+ {
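+ // Fully qualified names look like "BehaviorName?team=42"; everything from the
+ // last '?' onward is the team suffix, which is stripped before reporting.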
+ var lastQuestionIndex = fullyQualifiedBehaviorName.LastIndexOf("?");
+ if (lastQuestionIndex < 0)
+ {
+ // Nothing to remove
+ return fullyQualifiedBehaviorName;
+ }
+
+ return fullyQualifiedBehaviorName.Substring(0, lastQuestionIndex);
+ }
+
+ public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent tbiEvent)
+ {
+ if (!IsAnalyticsEnabled())
+ return;
+
+ if (!EnableAnalytics())
+ return;
+
+ var behaviorName = tbiEvent.BehaviorName;
+ var added = s_SentTrainingBehaviorInitialized.Add(behaviorName);
+
+ if (!added)
+ {
+ // We previously added this model. Exit so we don't resend.
+ return;
+ }
+
+ // Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
+ tbiEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString();
+ tbiEvent.BehaviorName = AnalyticsUtils.Hash(tbiEvent.BehaviorName);
+
+ // Note - to debug, use JsonUtility.ToJson on the event.
+ // Debug.Log(
+ // $"Would send event {k_TrainingBehaviorInitializedEventName} with body {JsonUtility.ToJson(tbiEvent, true)}"
+ // );
+#if UNITY_EDITOR
+ if (AnalyticsUtils.s_SendEditorAnalytics)
+ {
+ EditorAnalytics.SendEventWithLimit(k_TrainingBehaviorInitializedEventName, tbiEvent);
+ }
+#else
+ return;
+#endif
+ }
+
+ static RemotePolicyInitializedEvent GetEventForRemotePolicy(
+ string behaviorName,
+ IList<ISensor> sensors,
+ ActionSpec actionSpec)
+ {
+ var remotePolicyEvent = new RemotePolicyInitializedEvent();
+
+ // Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
+ remotePolicyEvent.BehaviorName = AnalyticsUtils.Hash(behaviorName);
+
+ remotePolicyEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString();
+ remotePolicyEvent.ActionSpec = EventActionSpec.FromActionSpec(actionSpec);
+ remotePolicyEvent.ObservationSpecs = new List<EventObservationSpec>(sensors.Count);
+ foreach (var sensor in sensors)
+ {
+ remotePolicyEvent.ObservationSpecs.Add(EventObservationSpec.FromSensor(sensor));
+ }
+
+ remotePolicyEvent.MLAgentsEnvsVersion = s_TrainerPackageVersion;
+ remotePolicyEvent.TrainerCommunicationVersion = s_TrainerCommunicationVersion;
+ return remotePolicyEvent;
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs.meta b/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs.meta
new file mode 100644
index 0000000000..9109c265a2
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 5ad0bc6b45614bb7929d25dd59d5ac38
+timeCreated: 1608168600
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
index 01a6706c98..b9044dd6a9 100644
--- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
+++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
@@ -6,6 +6,7 @@
using UnityEngine;
using System.Runtime.CompilerServices;
using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Analytics;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Policies;
@@ -435,6 +436,7 @@ public static UnityRLCapabilities ToRLCapabilities(this UnityRLCapabilitiesProto
ConcatenatedPngObservations = proto.ConcatenatedPngObservations,
CompressedChannelMapping = proto.CompressedChannelMapping,
HybridActions = proto.HybridActions,
+ TrainingAnalytics = proto.TrainingAnalytics,
};
}
@@ -446,6 +448,7 @@ public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps)
ConcatenatedPngObservations = rlCaps.ConcatenatedPngObservations,
CompressedChannelMapping = rlCaps.CompressedChannelMapping,
HybridActions = rlCaps.HybridActions,
+ TrainingAnalytics = rlCaps.TrainingAnalytics,
};
}
@@ -476,5 +479,54 @@ internal static bool IsTrivialMapping(ISensor sensor)
}
return true;
}
+
+ #region Analytics
+
+ internal static TrainingEnvironmentInitializedEvent ToTrainingEnvironmentInitializedEvent(
+ this TrainingEnvironmentInitialized inputProto)
+ {
+ return new TrainingEnvironmentInitializedEvent
+ {
+ TrainerPythonVersion = inputProto.PythonVersion,
+ MLAgentsVersion = inputProto.MlagentsVersion,
+ MLAgentsEnvsVersion = inputProto.MlagentsEnvsVersion,
+ TorchVersion = inputProto.TorchVersion,
+ TorchDeviceType = inputProto.TorchDeviceType,
+ NumEnvironments = inputProto.NumEnvs,
+ NumEnvironmentParameters = inputProto.NumEnvironmentParameters,
+ };
+ }
+
+ internal static TrainingBehaviorInitializedEvent ToTrainingBehaviorInitializedEvent(
+ this TrainingBehaviorInitialized inputProto)
+ {
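+ // Collapse the individual boolean fields from the proto into the single
+ // [Flags] bitmasks defined in Events.cs.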
+ RewardSignals rewardSignals = 0;
+ rewardSignals |= inputProto.ExtrinsicRewardEnabled ? RewardSignals.Extrinsic : 0;
+ rewardSignals |= inputProto.GailRewardEnabled ? RewardSignals.Gail : 0;
+ rewardSignals |= inputProto.CuriosityRewardEnabled ? RewardSignals.Curiosity : 0;
+ rewardSignals |= inputProto.RndRewardEnabled ? RewardSignals.Rnd : 0;
+
+ TrainingFeatures trainingFeatures = 0;
+ trainingFeatures |= inputProto.BehavioralCloningEnabled ? TrainingFeatures.BehavioralCloning : 0;
+ trainingFeatures |= inputProto.RecurrentEnabled ? TrainingFeatures.Recurrent : 0;
+ trainingFeatures |= inputProto.TrainerThreaded ? TrainingFeatures.Threaded : 0;
+ trainingFeatures |= inputProto.SelfPlayEnabled ? TrainingFeatures.SelfPlay : 0;
+ trainingFeatures |= inputProto.CurriculumEnabled ? TrainingFeatures.Curriculum : 0;
+
+
+ return new TrainingBehaviorInitializedEvent
+ {
+ BehaviorName = inputProto.BehaviorName,
+ TrainerType = inputProto.TrainerType,
+ RewardSignalFlags = rewardSignals,
+ TrainingFeatureFlags = trainingFeatures,
+ VisualEncoder = inputProto.VisualEncoder,
+ NumNetworkLayers = inputProto.NumNetworkLayers,
+ NumNetworkHiddenUnits = inputProto.NumNetworkHiddenUnits,
+ };
+ }
+
+ #endregion
+
}
}
diff --git a/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs b/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs
index 515ff52770..683dfebd17 100644
--- a/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs
+++ b/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs
@@ -9,6 +9,7 @@
using System.Linq;
using UnityEngine;
using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Analytics;
using Unity.MLAgents.CommunicatorObjects;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.SideChannels;
@@ -114,10 +115,12 @@ public UnityRLInitParameters Initialize(CommunicatorInitParameters initParameter
},
out input);
- var pythonCommunicationVersion = initializationInput.RlInitializationInput.CommunicationVersion;
var pythonPackageVersion = initializationInput.RlInitializationInput.PackageVersion;
+ var pythonCommunicationVersion = initializationInput.RlInitializationInput.CommunicationVersion;
var unityCommunicationVersion = initParameters.unityCommunicationVersion;
+ TrainingAnalytics.SetTrainerInformation(pythonPackageVersion, pythonCommunicationVersion);
+
var communicationIsCompatible = CheckCommunicationVersionsAreCompatible(unityCommunicationVersion,
pythonCommunicationVersion,
pythonPackageVersion);
diff --git a/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs b/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
index 48db8815b0..79627e48d4 100644
--- a/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
+++ b/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
@@ -8,6 +8,7 @@ internal class UnityRLCapabilities
public bool ConcatenatedPngObservations;
public bool CompressedChannelMapping;
public bool HybridActions;
+ public bool TrainingAnalytics;
/// <summary>
/// A class holding the capabilities flags for Reinforcement Learning across C# and the Trainer codebase. This
@@ -17,12 +18,14 @@ public UnityRLCapabilities(
bool baseRlCapabilities = true,
bool concatenatedPngObservations = true,
bool compressedChannelMapping = true,
- bool hybridActions = true)
+ bool hybridActions = true,
+ bool trainingAnalytics = true)
{
BaseRLCapabilities = baseRlCapabilities;
ConcatenatedPngObservations = concatenatedPngObservations;
CompressedChannelMapping = compressedChannelMapping;
HybridActions = hybridActions;
+ TrainingAnalytics = trainingAnalytics;
}
/// <summary>
diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
index 24137bbc40..3953aea214 100644
--- a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
+++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
@@ -25,16 +25,16 @@ static CapabilitiesReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
- "dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMilAEKGFVuaXR5UkxD",
+ "dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMirwEKGFVuaXR5UkxD",
"YXBhYmlsaXRpZXNQcm90bxIaChJiYXNlUkxDYXBhYmlsaXRpZXMYASABKAgS",
"IwobY29uY2F0ZW5hdGVkUG5nT2JzZXJ2YXRpb25zGAIgASgIEiAKGGNvbXBy",
"ZXNzZWRDaGFubmVsTWFwcGluZxgDIAEoCBIVCg1oeWJyaWRBY3Rpb25zGAQg",
- "ASgIQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZw",
- "cm90bzM="));
+ "ASgIEhkKEXRyYWluaW5nQW5hbHl0aWNzGAUgASgIQiWqAiJVbml0eS5NTEFn",
+ "ZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
- new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions" }, null, null, null)
+ new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions", "TrainingAnalytics" }, null, null, null)
}));
}
#endregion
@@ -75,6 +75,7 @@ public UnityRLCapabilitiesProto(UnityRLCapabilitiesProto other) : this() {
concatenatedPngObservations_ = other.concatenatedPngObservations_;
compressedChannelMapping_ = other.compressedChannelMapping_;
hybridActions_ = other.hybridActions_;
+ trainingAnalytics_ = other.trainingAnalytics_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}
@@ -139,6 +140,20 @@ public bool HybridActions {
}
}
+ /// Field number for the "trainingAnalytics" field.
+ public const int TrainingAnalyticsFieldNumber = 5;
+ private bool trainingAnalytics_;
+ /// <summary>
+ /// support for training analytics
+ /// </summary>
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool TrainingAnalytics {
+ get { return trainingAnalytics_; }
+ set {
+ trainingAnalytics_ = value;
+ }
+ }
+
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLCapabilitiesProto);
@@ -156,6 +171,7 @@ public bool Equals(UnityRLCapabilitiesProto other) {
if (ConcatenatedPngObservations != other.ConcatenatedPngObservations) return false;
if (CompressedChannelMapping != other.CompressedChannelMapping) return false;
if (HybridActions != other.HybridActions) return false;
+ if (TrainingAnalytics != other.TrainingAnalytics) return false;
return Equals(_unknownFields, other._unknownFields);
}
@@ -166,6 +182,7 @@ public override int GetHashCode() {
if (ConcatenatedPngObservations != false) hash ^= ConcatenatedPngObservations.GetHashCode();
if (CompressedChannelMapping != false) hash ^= CompressedChannelMapping.GetHashCode();
if (HybridActions != false) hash ^= HybridActions.GetHashCode();
+ if (TrainingAnalytics != false) hash ^= TrainingAnalytics.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
@@ -195,6 +212,10 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteRawTag(32);
output.WriteBool(HybridActions);
}
+ if (TrainingAnalytics != false) {
+ output.WriteRawTag(40);
+ output.WriteBool(TrainingAnalytics);
+ }
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
@@ -215,6 +236,9 @@ public int CalculateSize() {
if (HybridActions != false) {
size += 1 + 1;
}
+ if (TrainingAnalytics != false) {
+ size += 1 + 1;
+ }
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
@@ -238,6 +262,9 @@ public void MergeFrom(UnityRLCapabilitiesProto other) {
if (other.HybridActions != false) {
HybridActions = other.HybridActions;
}
+ if (other.TrainingAnalytics != false) {
+ TrainingAnalytics = other.TrainingAnalytics;
+ }
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}
@@ -265,6 +292,10 @@ public void MergeFrom(pb::CodedInputStream input) {
HybridActions = input.ReadBool();
break;
}
+ case 40: {
+ TrainingAnalytics = input.ReadBool();
+ break;
+ }
}
}
}
diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs
new file mode 100644
index 0000000000..099563e949
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs
@@ -0,0 +1,850 @@
+// <auto-generated>
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: mlagents_envs/communicator_objects/training_analytics.proto
+// </auto-generated>
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace Unity.MLAgents.CommunicatorObjects {
+
+ /// <summary>Holder for reflection information generated from mlagents_envs/communicator_objects/training_analytics.proto</summary>
+ internal static partial class TrainingAnalyticsReflection {
+
+ #region Descriptor
+ /// <summary>File descriptor for mlagents_envs/communicator_objects/training_analytics.proto</summary>
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static TrainingAnalyticsReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjttbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3RyYWluaW5n",
+ "X2FuYWx5dGljcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi2QEKHlRy",
+ "YWluaW5nRW52aXJvbm1lbnRJbml0aWFsaXplZBIYChBtbGFnZW50c192ZXJz",
+ "aW9uGAEgASgJEh0KFW1sYWdlbnRzX2VudnNfdmVyc2lvbhgCIAEoCRIWCg5w",
+ "eXRob25fdmVyc2lvbhgDIAEoCRIVCg10b3JjaF92ZXJzaW9uGAQgASgJEhkK",
+ "EXRvcmNoX2RldmljZV90eXBlGAUgASgJEhAKCG51bV9lbnZzGAYgASgFEiIK",
+ "Gm51bV9lbnZpcm9ubWVudF9wYXJhbWV0ZXJzGAcgASgFIq0DChtUcmFpbmlu",
+ "Z0JlaGF2aW9ySW5pdGlhbGl6ZWQSFQoNYmVoYXZpb3JfbmFtZRgBIAEoCRIU",
+ "Cgx0cmFpbmVyX3R5cGUYAiABKAkSIAoYZXh0cmluc2ljX3Jld2FyZF9lbmFi",
+ "bGVkGAMgASgIEhsKE2dhaWxfcmV3YXJkX2VuYWJsZWQYBCABKAgSIAoYY3Vy",
+ "aW9zaXR5X3Jld2FyZF9lbmFibGVkGAUgASgIEhoKEnJuZF9yZXdhcmRfZW5h",
+ "YmxlZBgGIAEoCBIiChpiZWhhdmlvcmFsX2Nsb25pbmdfZW5hYmxlZBgHIAEo",
+ "CBIZChFyZWN1cnJlbnRfZW5hYmxlZBgIIAEoCBIWCg52aXN1YWxfZW5jb2Rl",
+ "chgJIAEoCRIaChJudW1fbmV0d29ya19sYXllcnMYCiABKAUSIAoYbnVtX25l",
+ "dHdvcmtfaGlkZGVuX3VuaXRzGAsgASgFEhgKEHRyYWluZXJfdGhyZWFkZWQY",
+ "DCABKAgSGQoRc2VsZl9wbGF5X2VuYWJsZWQYDSABKAgSGgoSY3VycmljdWx1",
+ "bV9lbmFibGVkGA4gASgIQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0",
+ "b3JPYmplY3RzYgZwcm90bzM="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized.Parser, new[]{ "MlagentsVersion", "MlagentsEnvsVersion", "PythonVersion", "TorchVersion", "TorchDeviceType", "NumEnvs", "NumEnvironmentParameters" }, null, null, null),
+ new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized.Parser, new[]{ "BehaviorName", "TrainerType", "ExtrinsicRewardEnabled", "GailRewardEnabled", "CuriosityRewardEnabled", "RndRewardEnabled", "BehavioralCloningEnabled", "RecurrentEnabled", "VisualEncoder", "NumNetworkLayers", "NumNetworkHiddenUnits", "TrainerThreaded", "SelfPlayEnabled", "CurriculumEnabled" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ internal sealed partial class TrainingEnvironmentInitialized : pb::IMessage<TrainingEnvironmentInitialized> {
+ private static readonly pb::MessageParser<TrainingEnvironmentInitialized> _parser = new pb::MessageParser<TrainingEnvironmentInitialized>(() => new TrainingEnvironmentInitialized());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser<TrainingEnvironmentInitialized> Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::Unity.MLAgents.CommunicatorObjects.TrainingAnalyticsReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public TrainingEnvironmentInitialized() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public TrainingEnvironmentInitialized(TrainingEnvironmentInitialized other) : this() {
+ mlagentsVersion_ = other.mlagentsVersion_;
+ mlagentsEnvsVersion_ = other.mlagentsEnvsVersion_;
+ pythonVersion_ = other.pythonVersion_;
+ torchVersion_ = other.torchVersion_;
+ torchDeviceType_ = other.torchDeviceType_;
+ numEnvs_ = other.numEnvs_;
+ numEnvironmentParameters_ = other.numEnvironmentParameters_;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public TrainingEnvironmentInitialized Clone() {
+ return new TrainingEnvironmentInitialized(this);
+ }
+
+ /// Field number for the "mlagents_version" field.
+ public const int MlagentsVersionFieldNumber = 1;
+ private string mlagentsVersion_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string MlagentsVersion {
+ get { return mlagentsVersion_; }
+ set {
+ mlagentsVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "mlagents_envs_version" field.
+ public const int MlagentsEnvsVersionFieldNumber = 2;
+ private string mlagentsEnvsVersion_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string MlagentsEnvsVersion {
+ get { return mlagentsEnvsVersion_; }
+ set {
+ mlagentsEnvsVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "python_version" field.
+ public const int PythonVersionFieldNumber = 3;
+ private string pythonVersion_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string PythonVersion {
+ get { return pythonVersion_; }
+ set {
+ pythonVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "torch_version" field.
+ public const int TorchVersionFieldNumber = 4;
+ private string torchVersion_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string TorchVersion {
+ get { return torchVersion_; }
+ set {
+ torchVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "torch_device_type" field.
+ public const int TorchDeviceTypeFieldNumber = 5;
+ private string torchDeviceType_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string TorchDeviceType {
+ get { return torchDeviceType_; }
+ set {
+ torchDeviceType_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "num_envs" field.
+ public const int NumEnvsFieldNumber = 6;
+ private int numEnvs_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int NumEnvs {
+ get { return numEnvs_; }
+ set {
+ numEnvs_ = value;
+ }
+ }
+
+ /// Field number for the "num_environment_parameters" field.
+ public const int NumEnvironmentParametersFieldNumber = 7;
+ private int numEnvironmentParameters_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int NumEnvironmentParameters {
+ get { return numEnvironmentParameters_; }
+ set {
+ numEnvironmentParameters_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as TrainingEnvironmentInitialized);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(TrainingEnvironmentInitialized other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (MlagentsVersion != other.MlagentsVersion) return false;
+ if (MlagentsEnvsVersion != other.MlagentsEnvsVersion) return false;
+ if (PythonVersion != other.PythonVersion) return false;
+ if (TorchVersion != other.TorchVersion) return false;
+ if (TorchDeviceType != other.TorchDeviceType) return false;
+ if (NumEnvs != other.NumEnvs) return false;
+ if (NumEnvironmentParameters != other.NumEnvironmentParameters) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (MlagentsVersion.Length != 0) hash ^= MlagentsVersion.GetHashCode();
+ if (MlagentsEnvsVersion.Length != 0) hash ^= MlagentsEnvsVersion.GetHashCode();
+ if (PythonVersion.Length != 0) hash ^= PythonVersion.GetHashCode();
+ if (TorchVersion.Length != 0) hash ^= TorchVersion.GetHashCode();
+ if (TorchDeviceType.Length != 0) hash ^= TorchDeviceType.GetHashCode();
+ if (NumEnvs != 0) hash ^= NumEnvs.GetHashCode();
+ if (NumEnvironmentParameters != 0) hash ^= NumEnvironmentParameters.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (MlagentsVersion.Length != 0) {
+ output.WriteRawTag(10);
+ output.WriteString(MlagentsVersion);
+ }
+ if (MlagentsEnvsVersion.Length != 0) {
+ output.WriteRawTag(18);
+ output.WriteString(MlagentsEnvsVersion);
+ }
+ if (PythonVersion.Length != 0) {
+ output.WriteRawTag(26);
+ output.WriteString(PythonVersion);
+ }
+ if (TorchVersion.Length != 0) {
+ output.WriteRawTag(34);
+ output.WriteString(TorchVersion);
+ }
+ if (TorchDeviceType.Length != 0) {
+ output.WriteRawTag(42);
+ output.WriteString(TorchDeviceType);
+ }
+ if (NumEnvs != 0) {
+ output.WriteRawTag(48);
+ output.WriteInt32(NumEnvs);
+ }
+ if (NumEnvironmentParameters != 0) {
+ output.WriteRawTag(56);
+ output.WriteInt32(NumEnvironmentParameters);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (MlagentsVersion.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(MlagentsVersion);
+ }
+ if (MlagentsEnvsVersion.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(MlagentsEnvsVersion);
+ }
+ if (PythonVersion.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(PythonVersion);
+ }
+ if (TorchVersion.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(TorchVersion);
+ }
+ if (TorchDeviceType.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(TorchDeviceType);
+ }
+ if (NumEnvs != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumEnvs);
+ }
+ if (NumEnvironmentParameters != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumEnvironmentParameters);
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(TrainingEnvironmentInitialized other) {
+ if (other == null) {
+ return;
+ }
+ if (other.MlagentsVersion.Length != 0) {
+ MlagentsVersion = other.MlagentsVersion;
+ }
+ if (other.MlagentsEnvsVersion.Length != 0) {
+ MlagentsEnvsVersion = other.MlagentsEnvsVersion;
+ }
+ if (other.PythonVersion.Length != 0) {
+ PythonVersion = other.PythonVersion;
+ }
+ if (other.TorchVersion.Length != 0) {
+ TorchVersion = other.TorchVersion;
+ }
+ if (other.TorchDeviceType.Length != 0) {
+ TorchDeviceType = other.TorchDeviceType;
+ }
+ if (other.NumEnvs != 0) {
+ NumEnvs = other.NumEnvs;
+ }
+ if (other.NumEnvironmentParameters != 0) {
+ NumEnvironmentParameters = other.NumEnvironmentParameters;
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ MlagentsVersion = input.ReadString();
+ break;
+ }
+ case 18: {
+ MlagentsEnvsVersion = input.ReadString();
+ break;
+ }
+ case 26: {
+ PythonVersion = input.ReadString();
+ break;
+ }
+ case 34: {
+ TorchVersion = input.ReadString();
+ break;
+ }
+ case 42: {
+ TorchDeviceType = input.ReadString();
+ break;
+ }
+ case 48: {
+ NumEnvs = input.ReadInt32();
+ break;
+ }
+ case 56: {
+ NumEnvironmentParameters = input.ReadInt32();
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ internal sealed partial class TrainingBehaviorInitialized : pb::IMessage<TrainingBehaviorInitialized> {
+ private static readonly pb::MessageParser<TrainingBehaviorInitialized> _parser = new pb::MessageParser<TrainingBehaviorInitialized>(() => new TrainingBehaviorInitialized());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser<TrainingBehaviorInitialized> Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::Unity.MLAgents.CommunicatorObjects.TrainingAnalyticsReflection.Descriptor.MessageTypes[1]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public TrainingBehaviorInitialized() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public TrainingBehaviorInitialized(TrainingBehaviorInitialized other) : this() {
+ behaviorName_ = other.behaviorName_;
+ trainerType_ = other.trainerType_;
+ extrinsicRewardEnabled_ = other.extrinsicRewardEnabled_;
+ gailRewardEnabled_ = other.gailRewardEnabled_;
+ curiosityRewardEnabled_ = other.curiosityRewardEnabled_;
+ rndRewardEnabled_ = other.rndRewardEnabled_;
+ behavioralCloningEnabled_ = other.behavioralCloningEnabled_;
+ recurrentEnabled_ = other.recurrentEnabled_;
+ visualEncoder_ = other.visualEncoder_;
+ numNetworkLayers_ = other.numNetworkLayers_;
+ numNetworkHiddenUnits_ = other.numNetworkHiddenUnits_;
+ trainerThreaded_ = other.trainerThreaded_;
+ selfPlayEnabled_ = other.selfPlayEnabled_;
+ curriculumEnabled_ = other.curriculumEnabled_;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public TrainingBehaviorInitialized Clone() {
+ return new TrainingBehaviorInitialized(this);
+ }
+
+ /// Field number for the "behavior_name" field.
+ public const int BehaviorNameFieldNumber = 1;
+ private string behaviorName_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string BehaviorName {
+ get { return behaviorName_; }
+ set {
+ behaviorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "trainer_type" field.
+ public const int TrainerTypeFieldNumber = 2;
+ private string trainerType_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string TrainerType {
+ get { return trainerType_; }
+ set {
+ trainerType_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "extrinsic_reward_enabled" field.
+ public const int ExtrinsicRewardEnabledFieldNumber = 3;
+ private bool extrinsicRewardEnabled_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool ExtrinsicRewardEnabled {
+ get { return extrinsicRewardEnabled_; }
+ set {
+ extrinsicRewardEnabled_ = value;
+ }
+ }
+
+ /// Field number for the "gail_reward_enabled" field.
+ public const int GailRewardEnabledFieldNumber = 4;
+ private bool gailRewardEnabled_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool GailRewardEnabled {
+ get { return gailRewardEnabled_; }
+ set {
+ gailRewardEnabled_ = value;
+ }
+ }
+
+ /// Field number for the "curiosity_reward_enabled" field.
+ public const int CuriosityRewardEnabledFieldNumber = 5;
+ private bool curiosityRewardEnabled_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool CuriosityRewardEnabled {
+ get { return curiosityRewardEnabled_; }
+ set {
+ curiosityRewardEnabled_ = value;
+ }
+ }
+
+ /// Field number for the "rnd_reward_enabled" field.
+ public const int RndRewardEnabledFieldNumber = 6;
+ private bool rndRewardEnabled_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool RndRewardEnabled {
+ get { return rndRewardEnabled_; }
+ set {
+ rndRewardEnabled_ = value;
+ }
+ }
+
+ /// Field number for the "behavioral_cloning_enabled" field.
+ public const int BehavioralCloningEnabledFieldNumber = 7;
+ private bool behavioralCloningEnabled_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool BehavioralCloningEnabled {
+ get { return behavioralCloningEnabled_; }
+ set {
+ behavioralCloningEnabled_ = value;
+ }
+ }
+
+ /// Field number for the "recurrent_enabled" field.
+ public const int RecurrentEnabledFieldNumber = 8;
+ private bool recurrentEnabled_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool RecurrentEnabled {
+ get { return recurrentEnabled_; }
+ set {
+ recurrentEnabled_ = value;
+ }
+ }
+
+ /// Field number for the "visual_encoder" field.
+ public const int VisualEncoderFieldNumber = 9;
+ private string visualEncoder_ = "";
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string VisualEncoder {
+ get { return visualEncoder_; }
+ set {
+ visualEncoder_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "num_network_layers" field.
+ public const int NumNetworkLayersFieldNumber = 10;
+ private int numNetworkLayers_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int NumNetworkLayers {
+ get { return numNetworkLayers_; }
+ set {
+ numNetworkLayers_ = value;
+ }
+ }
+
+ /// Field number for the "num_network_hidden_units" field.
+ public const int NumNetworkHiddenUnitsFieldNumber = 11;
+ private int numNetworkHiddenUnits_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int NumNetworkHiddenUnits {
+ get { return numNetworkHiddenUnits_; }
+ set {
+ numNetworkHiddenUnits_ = value;
+ }
+ }
+
+ /// Field number for the "trainer_threaded" field.
+ public const int TrainerThreadedFieldNumber = 12;
+ private bool trainerThreaded_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool TrainerThreaded {
+ get { return trainerThreaded_; }
+ set {
+ trainerThreaded_ = value;
+ }
+ }
+
+ /// Field number for the "self_play_enabled" field.
+ public const int SelfPlayEnabledFieldNumber = 13;
+ private bool selfPlayEnabled_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool SelfPlayEnabled {
+ get { return selfPlayEnabled_; }
+ set {
+ selfPlayEnabled_ = value;
+ }
+ }
+
+ /// Field number for the "curriculum_enabled" field.
+ public const int CurriculumEnabledFieldNumber = 14;
+ private bool curriculumEnabled_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool CurriculumEnabled {
+ get { return curriculumEnabled_; }
+ set {
+ curriculumEnabled_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as TrainingBehaviorInitialized);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(TrainingBehaviorInitialized other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (BehaviorName != other.BehaviorName) return false;
+ if (TrainerType != other.TrainerType) return false;
+ if (ExtrinsicRewardEnabled != other.ExtrinsicRewardEnabled) return false;
+ if (GailRewardEnabled != other.GailRewardEnabled) return false;
+ if (CuriosityRewardEnabled != other.CuriosityRewardEnabled) return false;
+ if (RndRewardEnabled != other.RndRewardEnabled) return false;
+ if (BehavioralCloningEnabled != other.BehavioralCloningEnabled) return false;
+ if (RecurrentEnabled != other.RecurrentEnabled) return false;
+ if (VisualEncoder != other.VisualEncoder) return false;
+ if (NumNetworkLayers != other.NumNetworkLayers) return false;
+ if (NumNetworkHiddenUnits != other.NumNetworkHiddenUnits) return false;
+ if (TrainerThreaded != other.TrainerThreaded) return false;
+ if (SelfPlayEnabled != other.SelfPlayEnabled) return false;
+ if (CurriculumEnabled != other.CurriculumEnabled) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (BehaviorName.Length != 0) hash ^= BehaviorName.GetHashCode();
+ if (TrainerType.Length != 0) hash ^= TrainerType.GetHashCode();
+ if (ExtrinsicRewardEnabled != false) hash ^= ExtrinsicRewardEnabled.GetHashCode();
+ if (GailRewardEnabled != false) hash ^= GailRewardEnabled.GetHashCode();
+ if (CuriosityRewardEnabled != false) hash ^= CuriosityRewardEnabled.GetHashCode();
+ if (RndRewardEnabled != false) hash ^= RndRewardEnabled.GetHashCode();
+ if (BehavioralCloningEnabled != false) hash ^= BehavioralCloningEnabled.GetHashCode();
+ if (RecurrentEnabled != false) hash ^= RecurrentEnabled.GetHashCode();
+ if (VisualEncoder.Length != 0) hash ^= VisualEncoder.GetHashCode();
+ if (NumNetworkLayers != 0) hash ^= NumNetworkLayers.GetHashCode();
+ if (NumNetworkHiddenUnits != 0) hash ^= NumNetworkHiddenUnits.GetHashCode();
+ if (TrainerThreaded != false) hash ^= TrainerThreaded.GetHashCode();
+ if (SelfPlayEnabled != false) hash ^= SelfPlayEnabled.GetHashCode();
+ if (CurriculumEnabled != false) hash ^= CurriculumEnabled.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (BehaviorName.Length != 0) {
+ output.WriteRawTag(10);
+ output.WriteString(BehaviorName);
+ }
+ if (TrainerType.Length != 0) {
+ output.WriteRawTag(18);
+ output.WriteString(TrainerType);
+ }
+ if (ExtrinsicRewardEnabled != false) {
+ output.WriteRawTag(24);
+ output.WriteBool(ExtrinsicRewardEnabled);
+ }
+ if (GailRewardEnabled != false) {
+ output.WriteRawTag(32);
+ output.WriteBool(GailRewardEnabled);
+ }
+ if (CuriosityRewardEnabled != false) {
+ output.WriteRawTag(40);
+ output.WriteBool(CuriosityRewardEnabled);
+ }
+ if (RndRewardEnabled != false) {
+ output.WriteRawTag(48);
+ output.WriteBool(RndRewardEnabled);
+ }
+ if (BehavioralCloningEnabled != false) {
+ output.WriteRawTag(56);
+ output.WriteBool(BehavioralCloningEnabled);
+ }
+ if (RecurrentEnabled != false) {
+ output.WriteRawTag(64);
+ output.WriteBool(RecurrentEnabled);
+ }
+ if (VisualEncoder.Length != 0) {
+ output.WriteRawTag(74);
+ output.WriteString(VisualEncoder);
+ }
+ if (NumNetworkLayers != 0) {
+ output.WriteRawTag(80);
+ output.WriteInt32(NumNetworkLayers);
+ }
+ if (NumNetworkHiddenUnits != 0) {
+ output.WriteRawTag(88);
+ output.WriteInt32(NumNetworkHiddenUnits);
+ }
+ if (TrainerThreaded != false) {
+ output.WriteRawTag(96);
+ output.WriteBool(TrainerThreaded);
+ }
+ if (SelfPlayEnabled != false) {
+ output.WriteRawTag(104);
+ output.WriteBool(SelfPlayEnabled);
+ }
+ if (CurriculumEnabled != false) {
+ output.WriteRawTag(112);
+ output.WriteBool(CurriculumEnabled);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (BehaviorName.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(BehaviorName);
+ }
+ if (TrainerType.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(TrainerType);
+ }
+ if (ExtrinsicRewardEnabled != false) {
+ size += 1 + 1;
+ }
+ if (GailRewardEnabled != false) {
+ size += 1 + 1;
+ }
+ if (CuriosityRewardEnabled != false) {
+ size += 1 + 1;
+ }
+ if (RndRewardEnabled != false) {
+ size += 1 + 1;
+ }
+ if (BehavioralCloningEnabled != false) {
+ size += 1 + 1;
+ }
+ if (RecurrentEnabled != false) {
+ size += 1 + 1;
+ }
+ if (VisualEncoder.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(VisualEncoder);
+ }
+ if (NumNetworkLayers != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumNetworkLayers);
+ }
+ if (NumNetworkHiddenUnits != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumNetworkHiddenUnits);
+ }
+ if (TrainerThreaded != false) {
+ size += 1 + 1;
+ }
+ if (SelfPlayEnabled != false) {
+ size += 1 + 1;
+ }
+ if (CurriculumEnabled != false) {
+ size += 1 + 1;
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(TrainingBehaviorInitialized other) {
+ if (other == null) {
+ return;
+ }
+ if (other.BehaviorName.Length != 0) {
+ BehaviorName = other.BehaviorName;
+ }
+ if (other.TrainerType.Length != 0) {
+ TrainerType = other.TrainerType;
+ }
+ if (other.ExtrinsicRewardEnabled != false) {
+ ExtrinsicRewardEnabled = other.ExtrinsicRewardEnabled;
+ }
+ if (other.GailRewardEnabled != false) {
+ GailRewardEnabled = other.GailRewardEnabled;
+ }
+ if (other.CuriosityRewardEnabled != false) {
+ CuriosityRewardEnabled = other.CuriosityRewardEnabled;
+ }
+ if (other.RndRewardEnabled != false) {
+ RndRewardEnabled = other.RndRewardEnabled;
+ }
+ if (other.BehavioralCloningEnabled != false) {
+ BehavioralCloningEnabled = other.BehavioralCloningEnabled;
+ }
+ if (other.RecurrentEnabled != false) {
+ RecurrentEnabled = other.RecurrentEnabled;
+ }
+ if (other.VisualEncoder.Length != 0) {
+ VisualEncoder = other.VisualEncoder;
+ }
+ if (other.NumNetworkLayers != 0) {
+ NumNetworkLayers = other.NumNetworkLayers;
+ }
+ if (other.NumNetworkHiddenUnits != 0) {
+ NumNetworkHiddenUnits = other.NumNetworkHiddenUnits;
+ }
+ if (other.TrainerThreaded != false) {
+ TrainerThreaded = other.TrainerThreaded;
+ }
+ if (other.SelfPlayEnabled != false) {
+ SelfPlayEnabled = other.SelfPlayEnabled;
+ }
+ if (other.CurriculumEnabled != false) {
+ CurriculumEnabled = other.CurriculumEnabled;
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ BehaviorName = input.ReadString();
+ break;
+ }
+ case 18: {
+ TrainerType = input.ReadString();
+ break;
+ }
+ case 24: {
+ ExtrinsicRewardEnabled = input.ReadBool();
+ break;
+ }
+ case 32: {
+ GailRewardEnabled = input.ReadBool();
+ break;
+ }
+ case 40: {
+ CuriosityRewardEnabled = input.ReadBool();
+ break;
+ }
+ case 48: {
+ RndRewardEnabled = input.ReadBool();
+ break;
+ }
+ case 56: {
+ BehavioralCloningEnabled = input.ReadBool();
+ break;
+ }
+ case 64: {
+ RecurrentEnabled = input.ReadBool();
+ break;
+ }
+ case 74: {
+ VisualEncoder = input.ReadString();
+ break;
+ }
+ case 80: {
+ NumNetworkLayers = input.ReadInt32();
+ break;
+ }
+ case 88: {
+ NumNetworkHiddenUnits = input.ReadInt32();
+ break;
+ }
+ case 96: {
+ TrainerThreaded = input.ReadBool();
+ break;
+ }
+ case 104: {
+ SelfPlayEnabled = input.ReadBool();
+ break;
+ }
+ case 112: {
+ CurriculumEnabled = input.ReadBool();
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs.meta b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs.meta
new file mode 100644
index 0000000000..8e9d358feb
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 9e6ac06a3931742d798cf922de6b99f0
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs b/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
index 176eac44d7..fbeff3a15e 100644
--- a/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
+++ b/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
@@ -1,7 +1,9 @@
using System.Collections.Generic;
using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Analytics;
using Unity.MLAgents.Sensors;
+
namespace Unity.MLAgents.Policies
{
/// <summary>
@@ -14,6 +16,7 @@ internal class RemotePolicy : IPolicy
string m_FullyQualifiedBehaviorName;
ActionSpec m_ActionSpec;
ActionBuffers m_LastActionBuffer;
+ private bool m_AnalyticsSent = false;
internal ICommunicator m_Communicator;
@@ -24,13 +27,22 @@ public RemotePolicy(
{
m_FullyQualifiedBehaviorName = fullyQualifiedBehaviorName;
m_Communicator = Academy.Instance.Communicator;
- m_Communicator.SubscribeBrain(m_FullyQualifiedBehaviorName, actionSpec);
+ m_Communicator?.SubscribeBrain(m_FullyQualifiedBehaviorName, actionSpec);
m_ActionSpec = actionSpec;
}
/// <inheritdoc />
public void RequestDecision(AgentInfo info, List<ISensor> sensors)
{
+ if (!m_AnalyticsSent)
+ {
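+ // Send the one-time analytics event on the first decision request, the
+ // earliest point at which this policy's sensors are known.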
+ m_AnalyticsSent = true;
+ TrainingAnalytics.RemotePolicyInitialized(
+ m_FullyQualifiedBehaviorName,
+ sensors,
+ m_ActionSpec
+ );
+ }
m_AgentId = info.episodeId;
m_Communicator?.PutObservations(m_FullyQualifiedBehaviorName, info, sensors);
}
diff --git a/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs b/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs
new file mode 100644
index 0000000000..6005954f00
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs
@@ -0,0 +1,50 @@
+using System;
+using UnityEngine;
+using Unity.MLAgents.Analytics;
+using Unity.MLAgents.CommunicatorObjects;
+
+namespace Unity.MLAgents.SideChannels
+{
+ public class TrainingAnalyticsSideChannel : SideChannel
+ {
+ const string k_TrainingAnalyticsConfigId = "b664a4a9-d86f-5a5f-95cb-e8353a7e8356";
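+ // This GUID must match the channel ID used by the corresponding analytics channel
+ // on the Python trainer side; SideChannelManager routes incoming messages by this ID.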
+
+ /// <summary>
+ /// Initializes the side channel. The constructor is internal because only one instance is
+ /// supported at a time, and is created by the Academy.
+ /// </summary>
+ internal TrainingAnalyticsSideChannel()
+ {
+ ChannelId = new Guid(k_TrainingAnalyticsConfigId);
+ }
+
+ /// <inheritdoc/>
+ protected override void OnMessageReceived(IncomingMessage msg)
+ {
+ Google.Protobuf.WellKnownTypes.Any anyMessage = null;
+ try
+ {
+ anyMessage = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(msg.GetRawBytes());
+ }
+ catch (Google.Protobuf.InvalidProtocolBufferException)
+ {
+ // Bad message, nothing we can do about it, so just ignore.
+ return;
+ }
+
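+ // The payload is a protobuf Any; dispatch on whichever message type it wraps.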
+ if (anyMessage.Is(TrainingEnvironmentInitialized.Descriptor))
+ {
+ var envInitProto = anyMessage.Unpack<TrainingEnvironmentInitialized>();
+ var envInitEvent = envInitProto.ToTrainingEnvironmentInitializedEvent();
+ TrainingAnalytics.TrainingEnvironmentInitialized(envInitEvent);
+ }
+ else if (anyMessage.Is(TrainingBehaviorInitialized.Descriptor))
+ {
+ var behaviorInitProto = anyMessage.Unpack<TrainingBehaviorInitialized>();
+ var behaviorTrainingEvent = behaviorInitProto.ToTrainingBehaviorInitializedEvent();
+ TrainingAnalytics.TrainingBehaviorInitialized(behaviorTrainingEvent);
+ }
+ // Don't do anything for unknown types, since the user probably can't do anything about it.
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs.meta b/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs.meta
new file mode 100644
index 0000000000..757d0d0d4f
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 13c87198bbd54b40a0b93308eb37933e
+timeCreated: 1608337471
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs b/com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs
index 0705796965..489c64f2f0 100644
--- a/com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs
+++ b/com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs
@@ -26,6 +26,11 @@ ActionSpec GetContinuous2vis8vec2actionActionSpec()
[SetUp]
public void SetUp()
{
+ if (Academy.IsInitialized)
+ {
+ Academy.Instance.Dispose();
+ }
+
continuousONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousONNXPath, typeof(NNModel));
var go = new GameObject("SensorA");
sensor_21_20_3 = go.AddComponent<Test3DSensorComponent>();
@@ -65,5 +70,18 @@ public void TestModelEvent()
Assert.IsTrue(jsonString.Contains("SensorName"));
Assert.IsTrue(jsonString.Contains("Flags"));
}
+
+ [Test]
+ public void TestBarracudaPolicy()
+ {
+ // Explicitly request decisions for a policy so we get code coverage on the event sending
+ using (new AnalyticsUtils.DisableAnalyticsSending())
+ {
+ var sensors = new List<ISensor> { sensor_21_20_3.Sensor, sensor_20_22_3.Sensor };
+ var policy = new BarracudaPolicy(GetContinuous2vis8vec2actionActionSpec(), continuousONNXModel, InferenceDevice.CPU, "testBehavior");
+ policy.RequestDecision(new AgentInfo(), sensors);
+ }
+ Academy.Instance.Dispose();
+ }
}
}
diff --git a/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs b/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs
new file mode 100644
index 0000000000..6a5b30e72a
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs
@@ -0,0 +1,42 @@
+using System.Collections.Generic;
+using NUnit.Framework;
+using Unity.MLAgents.Sensors;
+using UnityEngine;
+using Unity.Barracuda;
+using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Analytics;
+using Unity.MLAgents.Policies;
+using UnityEditor;
+
+namespace Unity.MLAgents.Tests.Analytics
+{
+ [TestFixture]
+ public class TrainingAnalyticsTests
+ {
+ [TestCase("foo?team=42", ExpectedResult = "foo")]
+ [TestCase("foo", ExpectedResult = "foo")]
+ [TestCase("foo?bar?team=1337", ExpectedResult = "foo?bar")]
+ public string TestParseBehaviorName(string fullyQualifiedBehaviorName)
+ {
+ return TrainingAnalytics.ParseBehaviorName(fullyQualifiedBehaviorName);
+ }
+
+ [Test]
+ public void TestRemotePolicy()
+ {
+ if (Academy.IsInitialized)
+ {
+ Academy.Instance.Dispose();
+ }
+
+ using (new AnalyticsUtils.DisableAnalyticsSending())
+ {
+ var actionSpec = ActionSpec.MakeContinuous(3);
+ var policy = new RemotePolicy(actionSpec, "TestBehavior?team=42");
+ policy.RequestDecision(new AgentInfo(), new List<ISensor>());
+ }
+
+ Academy.Instance.Dispose();
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs.meta b/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs.meta
new file mode 100644
index 0000000000..df394c157a
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 70b8f1544bc34b4e8f1bc1068c64f01c
+timeCreated: 1610419546
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
index fed5e97039..2529ecd96f 100644
--- a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
+++ b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
@@ -1,7 +1,9 @@
using NUnit.Framework;
-using Unity.MLAgents.Policies;
-using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Analytics;
+using Unity.MLAgents.CommunicatorObjects;
+using Unity.MLAgents.Demonstrations;
+using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Tests
@@ -169,5 +171,31 @@ public void TestIsTrivialMapping()
sparseChannelSensor.Mapping = new[] { 0, 0, 0, 1, 1, 1 };
Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), false);
}
+
+ [Test]
+ public void TestDefaultTrainingEvents()
+ {
+ var trainingEnvInit = new TrainingEnvironmentInitialized
+ {
+ PythonVersion = "test",
+ };
+ var trainingEnvInitEvent = trainingEnvInit.ToTrainingEnvironmentInitializedEvent();
+ Assert.AreEqual(trainingEnvInit.PythonVersion, trainingEnvInitEvent.TrainerPythonVersion);
+
+ var trainingBehavInit = new TrainingBehaviorInitialized
+ {
+ BehaviorName = "testBehavior",
+ ExtrinsicRewardEnabled = true,
+ CuriosityRewardEnabled = true,
+
+ RecurrentEnabled = true,
+ SelfPlayEnabled = true,
+ };
+ var trainingBehavInitEvent = trainingBehavInit.ToTrainingBehaviorInitializedEvent();
+ Assert.AreEqual(trainingBehavInit.BehaviorName, trainingBehavInitEvent.BehaviorName);
+
+ Assert.AreEqual(RewardSignals.Extrinsic | RewardSignals.Curiosity, trainingBehavInitEvent.RewardSignalFlags);
+ Assert.AreEqual(TrainingFeatures.Recurrent | TrainingFeatures.SelfPlay, trainingBehavInitEvent.TrainingFeatureFlags);
+ }
}
}
diff --git a/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs b/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs
new file mode 100644
index 0000000000..d71abf6555
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs
@@ -0,0 +1,65 @@
+using System;
+using System.Linq;
+using System.Text;
+using NUnit.Framework;
+using Google.Protobuf;
+using Unity.MLAgents.Analytics;
+using Unity.MLAgents.SideChannels;
+using Unity.MLAgents.CommunicatorObjects;
+
+
+namespace Unity.MLAgents.Tests
+{
+ /// <summary>
+ /// These tests send messages through the event handling code.
+ /// There's no output to test, so just make sure there are no exceptions
+ /// (and get the code coverage above the minimum).
+ /// </summary>
+ public class TrainingAnalyticsSideChannelTests
+ {
+ [Test]
+ public void TestTrainingEnvironmentReceived()
+ {
+ var anyMsg = Google.Protobuf.WellKnownTypes.Any.Pack(new TrainingEnvironmentInitialized());
+ var anyMsgBytes = anyMsg.ToByteArray();
+ var sideChannel = new TrainingAnalyticsSideChannel();
+ using (new AnalyticsUtils.DisableAnalyticsSending())
+ {
+ sideChannel.ProcessMessage(anyMsgBytes);
+ }
+ }
+
+ [Test]
+ public void TestTrainingBehaviorReceived()
+ {
+ var anyMsg = Google.Protobuf.WellKnownTypes.Any.Pack(new TrainingBehaviorInitialized());
+ var anyMsgBytes = anyMsg.ToByteArray();
+ var sideChannel = new TrainingAnalyticsSideChannel();
+ using (new AnalyticsUtils.DisableAnalyticsSending())
+ {
+ sideChannel.ProcessMessage(anyMsgBytes);
+ }
+ }
+
+ [Test]
+ public void TestInvalidProtobufMessage()
+ {
+ // Test an invalid (non-protobuf) message. This should silently ignore the data.
+ var badBytes = Encoding.ASCII.GetBytes("Lorem ipsum");
+ var sideChannel = new TrainingAnalyticsSideChannel();
+ using (new AnalyticsUtils.DisableAnalyticsSending())
+ {
+ sideChannel.ProcessMessage(badBytes);
+ }
+
+ // Test an almost-valid message. This should silently ignore the data.
+ var anyMsg = Google.Protobuf.WellKnownTypes.Any.Pack(new TrainingBehaviorInitialized());
+ var anyMsgBytes = anyMsg.ToByteArray();
+ var truncatedMessage = new ArraySegment<byte>(anyMsgBytes, 0, anyMsgBytes.Length - 1).ToArray();
+ using (new AnalyticsUtils.DisableAnalyticsSending())
+ {
+ sideChannel.ProcessMessage(truncatedMessage);
+ }
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs.meta b/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs.meta
new file mode 100644
index 0000000000..ebb5915235
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: c2a71036ddec4ba4bf83c5e8ba1b8daa
+timeCreated: 1610574895
\ No newline at end of file
diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
index bd87dc3b23..054fec848a 100644
--- a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
+++ b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
@@ -19,7 +19,7 @@
name='mlagents_envs/communicator_objects/capabilities.proto',
package='communicator_objects',
syntax='proto3',
- serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"\x94\x01\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x12\x15\n\rhybridActions\x18\x04 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
+ serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"\xaf\x01\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x12\x15\n\rhybridActions\x18\x04 \x01(\x08\x12\x19\n\x11trainingAnalytics\x18\x05 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
)
@@ -60,6 +60,13 @@
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='trainingAnalytics', full_name='communicator_objects.UnityRLCapabilitiesProto.trainingAnalytics', index=4,
+ number=5, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
],
extensions=[
],
@@ -73,7 +80,7 @@
oneofs=[
],
serialized_start=80,
- serialized_end=228,
+ serialized_end=255,
)
DESCRIPTOR.message_types_by_name['UnityRLCapabilitiesProto'] = _UNITYRLCAPABILITIESPROTO
diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
index f1799fe6e5..69cf46fd0e 100644
--- a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
+++ b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
@@ -29,6 +29,7 @@ class UnityRLCapabilitiesProto(google___protobuf___message___Message):
concatenatedPngObservations = ... # type: builtin___bool
compressedChannelMapping = ... # type: builtin___bool
hybridActions = ... # type: builtin___bool
+ trainingAnalytics = ... # type: builtin___bool
def __init__(self,
*,
@@ -36,12 +37,13 @@ class UnityRLCapabilitiesProto(google___protobuf___message___Message):
concatenatedPngObservations : typing___Optional[builtin___bool] = None,
compressedChannelMapping : typing___Optional[builtin___bool] = None,
hybridActions : typing___Optional[builtin___bool] = None,
+ trainingAnalytics : typing___Optional[builtin___bool] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> UnityRLCapabilitiesProto: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
- def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions"]) -> None: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions",u"trainingAnalytics"]) -> None: ...
else:
- def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions"]) -> None: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions",u"trainingAnalytics",b"trainingAnalytics"]) -> None: ...
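A quick sketch of the regenerated binding in use; the new capability is an ordinary proto3 bool, so it defaults to False for older trainers that never set it:

```python
from mlagents_envs.communicator_objects.capabilities_pb2 import (
    UnityRLCapabilitiesProto,
)

caps = UnityRLCapabilitiesProto(hybridActions=True, trainingAnalytics=True)
assert caps.trainingAnalytics
caps.ClearField("trainingAnalytics")
assert not caps.trainingAnalytics  # proto3 bool falls back to False
```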
diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/training_analytics_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/training_analytics_pb2.py
new file mode 100644
index 0000000000..1e775c9710
--- /dev/null
+++ b/ml-agents-envs/mlagents_envs/communicator_objects/training_analytics_pb2.py
@@ -0,0 +1,243 @@
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: mlagents_envs/communicator_objects/training_analytics.proto
+
+import sys
+_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf import descriptor_pb2
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor.FileDescriptor(
+ name='mlagents_envs/communicator_objects/training_analytics.proto',
+ package='communicator_objects',
+ syntax='proto3',
+ serialized_pb=_b('\n;mlagents_envs/communicator_objects/training_analytics.proto\x12\x14\x63ommunicator_objects\"\xd9\x01\n\x1eTrainingEnvironmentInitialized\x12\x18\n\x10mlagents_version\x18\x01 \x01(\t\x12\x1d\n\x15mlagents_envs_version\x18\x02 \x01(\t\x12\x16\n\x0epython_version\x18\x03 \x01(\t\x12\x15\n\rtorch_version\x18\x04 \x01(\t\x12\x19\n\x11torch_device_type\x18\x05 \x01(\t\x12\x10\n\x08num_envs\x18\x06 \x01(\x05\x12\"\n\x1anum_environment_parameters\x18\x07 \x01(\x05\"\xad\x03\n\x1bTrainingBehaviorInitialized\x12\x15\n\rbehavior_name\x18\x01 \x01(\t\x12\x14\n\x0ctrainer_type\x18\x02 \x01(\t\x12 \n\x18\x65xtrinsic_reward_enabled\x18\x03 \x01(\x08\x12\x1b\n\x13gail_reward_enabled\x18\x04 \x01(\x08\x12 \n\x18\x63uriosity_reward_enabled\x18\x05 \x01(\x08\x12\x1a\n\x12rnd_reward_enabled\x18\x06 \x01(\x08\x12\"\n\x1a\x62\x65havioral_cloning_enabled\x18\x07 \x01(\x08\x12\x19\n\x11recurrent_enabled\x18\x08 \x01(\x08\x12\x16\n\x0evisual_encoder\x18\t \x01(\t\x12\x1a\n\x12num_network_layers\x18\n \x01(\x05\x12 \n\x18num_network_hidden_units\x18\x0b \x01(\x05\x12\x18\n\x10trainer_threaded\x18\x0c \x01(\x08\x12\x19\n\x11self_play_enabled\x18\r \x01(\x08\x12\x1a\n\x12\x63urriculum_enabled\x18\x0e \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
+)
+
+
+
+
+_TRAININGENVIRONMENTINITIALIZED = _descriptor.Descriptor(
+ name='TrainingEnvironmentInitialized',
+ full_name='communicator_objects.TrainingEnvironmentInitialized',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='mlagents_version', full_name='communicator_objects.TrainingEnvironmentInitialized.mlagents_version', index=0,
+ number=1, type=9, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b("").decode('utf-8'),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='mlagents_envs_version', full_name='communicator_objects.TrainingEnvironmentInitialized.mlagents_envs_version', index=1,
+ number=2, type=9, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b("").decode('utf-8'),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='python_version', full_name='communicator_objects.TrainingEnvironmentInitialized.python_version', index=2,
+ number=3, type=9, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b("").decode('utf-8'),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='torch_version', full_name='communicator_objects.TrainingEnvironmentInitialized.torch_version', index=3,
+ number=4, type=9, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b("").decode('utf-8'),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='torch_device_type', full_name='communicator_objects.TrainingEnvironmentInitialized.torch_device_type', index=4,
+ number=5, type=9, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b("").decode('utf-8'),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='num_envs', full_name='communicator_objects.TrainingEnvironmentInitialized.num_envs', index=5,
+ number=6, type=5, cpp_type=1, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='num_environment_parameters', full_name='communicator_objects.TrainingEnvironmentInitialized.num_environment_parameters', index=6,
+ number=7, type=5, cpp_type=1, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=86,
+ serialized_end=303,
+)
+
+
+_TRAININGBEHAVIORINITIALIZED = _descriptor.Descriptor(
+ name='TrainingBehaviorInitialized',
+ full_name='communicator_objects.TrainingBehaviorInitialized',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='behavior_name', full_name='communicator_objects.TrainingBehaviorInitialized.behavior_name', index=0,
+ number=1, type=9, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b("").decode('utf-8'),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='trainer_type', full_name='communicator_objects.TrainingBehaviorInitialized.trainer_type', index=1,
+ number=2, type=9, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b("").decode('utf-8'),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='extrinsic_reward_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.extrinsic_reward_enabled', index=2,
+ number=3, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='gail_reward_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.gail_reward_enabled', index=3,
+ number=4, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='curiosity_reward_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.curiosity_reward_enabled', index=4,
+ number=5, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='rnd_reward_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.rnd_reward_enabled', index=5,
+ number=6, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='behavioral_cloning_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.behavioral_cloning_enabled', index=6,
+ number=7, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='recurrent_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.recurrent_enabled', index=7,
+ number=8, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='visual_encoder', full_name='communicator_objects.TrainingBehaviorInitialized.visual_encoder', index=8,
+ number=9, type=9, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b("").decode('utf-8'),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='num_network_layers', full_name='communicator_objects.TrainingBehaviorInitialized.num_network_layers', index=9,
+ number=10, type=5, cpp_type=1, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='num_network_hidden_units', full_name='communicator_objects.TrainingBehaviorInitialized.num_network_hidden_units', index=10,
+ number=11, type=5, cpp_type=1, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='trainer_threaded', full_name='communicator_objects.TrainingBehaviorInitialized.trainer_threaded', index=11,
+ number=12, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='self_play_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.self_play_enabled', index=12,
+ number=13, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='curriculum_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.curriculum_enabled', index=13,
+ number=14, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=306,
+ serialized_end=735,
+)
+
+DESCRIPTOR.message_types_by_name['TrainingEnvironmentInitialized'] = _TRAININGENVIRONMENTINITIALIZED
+DESCRIPTOR.message_types_by_name['TrainingBehaviorInitialized'] = _TRAININGBEHAVIORINITIALIZED
+_sym_db.RegisterFileDescriptor(DESCRIPTOR)
+
+TrainingEnvironmentInitialized = _reflection.GeneratedProtocolMessageType('TrainingEnvironmentInitialized', (_message.Message,), dict(
+ DESCRIPTOR = _TRAININGENVIRONMENTINITIALIZED,
+ __module__ = 'mlagents_envs.communicator_objects.training_analytics_pb2'
+ # @@protoc_insertion_point(class_scope:communicator_objects.TrainingEnvironmentInitialized)
+ ))
+_sym_db.RegisterMessage(TrainingEnvironmentInitialized)
+
+TrainingBehaviorInitialized = _reflection.GeneratedProtocolMessageType('TrainingBehaviorInitialized', (_message.Message,), dict(
+ DESCRIPTOR = _TRAININGBEHAVIORINITIALIZED,
+ __module__ = 'mlagents_envs.communicator_objects.training_analytics_pb2'
+ # @@protoc_insertion_point(class_scope:communicator_objects.TrainingBehaviorInitialized)
+ ))
+_sym_db.RegisterMessage(TrainingBehaviorInitialized)
+
+
+DESCRIPTOR.has_options = True
+DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\"Unity.MLAgents.CommunicatorObjects'))
+# @@protoc_insertion_point(module_scope)
diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/training_analytics_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/training_analytics_pb2.pyi
new file mode 100644
index 0000000000..a347de6874
--- /dev/null
+++ b/ml-agents-envs/mlagents_envs/communicator_objects/training_analytics_pb2.pyi
@@ -0,0 +1,97 @@
+# @generated by generate_proto_mypy_stubs.py. Do not edit!
+import sys
+from google.protobuf.descriptor import (
+ Descriptor as google___protobuf___descriptor___Descriptor,
+)
+
+from google.protobuf.message import (
+ Message as google___protobuf___message___Message,
+)
+
+from typing import (
+ Optional as typing___Optional,
+ Text as typing___Text,
+)
+
+from typing_extensions import (
+ Literal as typing_extensions___Literal,
+)
+
+
+builtin___bool = bool
+builtin___bytes = bytes
+builtin___float = float
+builtin___int = int
+
+
+class TrainingEnvironmentInitialized(google___protobuf___message___Message):
+ DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
+ mlagents_version = ... # type: typing___Text
+ mlagents_envs_version = ... # type: typing___Text
+ python_version = ... # type: typing___Text
+ torch_version = ... # type: typing___Text
+ torch_device_type = ... # type: typing___Text
+ num_envs = ... # type: builtin___int
+ num_environment_parameters = ... # type: builtin___int
+
+ def __init__(self,
+ *,
+ mlagents_version : typing___Optional[typing___Text] = None,
+ mlagents_envs_version : typing___Optional[typing___Text] = None,
+ python_version : typing___Optional[typing___Text] = None,
+ torch_version : typing___Optional[typing___Text] = None,
+ torch_device_type : typing___Optional[typing___Text] = None,
+ num_envs : typing___Optional[builtin___int] = None,
+ num_environment_parameters : typing___Optional[builtin___int] = None,
+ ) -> None: ...
+ @classmethod
+ def FromString(cls, s: builtin___bytes) -> TrainingEnvironmentInitialized: ...
+ def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
+ def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
+ if sys.version_info >= (3,):
+ def ClearField(self, field_name: typing_extensions___Literal[u"mlagents_envs_version",u"mlagents_version",u"num_environment_parameters",u"num_envs",u"python_version",u"torch_device_type",u"torch_version"]) -> None: ...
+ else:
+ def ClearField(self, field_name: typing_extensions___Literal[u"mlagents_envs_version",b"mlagents_envs_version",u"mlagents_version",b"mlagents_version",u"num_environment_parameters",b"num_environment_parameters",u"num_envs",b"num_envs",u"python_version",b"python_version",u"torch_device_type",b"torch_device_type",u"torch_version",b"torch_version"]) -> None: ...
+
+class TrainingBehaviorInitialized(google___protobuf___message___Message):
+ DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
+ behavior_name = ... # type: typing___Text
+ trainer_type = ... # type: typing___Text
+ extrinsic_reward_enabled = ... # type: builtin___bool
+ gail_reward_enabled = ... # type: builtin___bool
+ curiosity_reward_enabled = ... # type: builtin___bool
+ rnd_reward_enabled = ... # type: builtin___bool
+ behavioral_cloning_enabled = ... # type: builtin___bool
+ recurrent_enabled = ... # type: builtin___bool
+ visual_encoder = ... # type: typing___Text
+ num_network_layers = ... # type: builtin___int
+ num_network_hidden_units = ... # type: builtin___int
+ trainer_threaded = ... # type: builtin___bool
+ self_play_enabled = ... # type: builtin___bool
+ curriculum_enabled = ... # type: builtin___bool
+
+ def __init__(self,
+ *,
+ behavior_name : typing___Optional[typing___Text] = None,
+ trainer_type : typing___Optional[typing___Text] = None,
+ extrinsic_reward_enabled : typing___Optional[builtin___bool] = None,
+ gail_reward_enabled : typing___Optional[builtin___bool] = None,
+ curiosity_reward_enabled : typing___Optional[builtin___bool] = None,
+ rnd_reward_enabled : typing___Optional[builtin___bool] = None,
+ behavioral_cloning_enabled : typing___Optional[builtin___bool] = None,
+ recurrent_enabled : typing___Optional[builtin___bool] = None,
+ visual_encoder : typing___Optional[typing___Text] = None,
+ num_network_layers : typing___Optional[builtin___int] = None,
+ num_network_hidden_units : typing___Optional[builtin___int] = None,
+ trainer_threaded : typing___Optional[builtin___bool] = None,
+ self_play_enabled : typing___Optional[builtin___bool] = None,
+ curriculum_enabled : typing___Optional[builtin___bool] = None,
+ ) -> None: ...
+ @classmethod
+ def FromString(cls, s: builtin___bytes) -> TrainingBehaviorInitialized: ...
+ def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
+ def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
+ if sys.version_info >= (3,):
+ def ClearField(self, field_name: typing_extensions___Literal[u"behavior_name",u"behavioral_cloning_enabled",u"curiosity_reward_enabled",u"curriculum_enabled",u"extrinsic_reward_enabled",u"gail_reward_enabled",u"num_network_hidden_units",u"num_network_layers",u"recurrent_enabled",u"rnd_reward_enabled",u"self_play_enabled",u"trainer_threaded",u"trainer_type",u"visual_encoder"]) -> None: ...
+ else:
+ def ClearField(self, field_name: typing_extensions___Literal[u"behavior_name",b"behavior_name",u"behavioral_cloning_enabled",b"behavioral_cloning_enabled",u"curiosity_reward_enabled",b"curiosity_reward_enabled",u"curriculum_enabled",b"curriculum_enabled",u"extrinsic_reward_enabled",b"extrinsic_reward_enabled",u"gail_reward_enabled",b"gail_reward_enabled",u"num_network_hidden_units",b"num_network_hidden_units",u"num_network_layers",b"num_network_layers",u"recurrent_enabled",b"recurrent_enabled",u"rnd_reward_enabled",b"rnd_reward_enabled",u"self_play_enabled",b"self_play_enabled",u"trainer_threaded",b"trainer_threaded",u"trainer_type",b"trainer_type",u"visual_encoder",b"visual_encoder"]) -> None: ...
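As a sanity check on the generated bindings, a small sketch that mirrors the C# TestDefaultTrainingEvents test: round-trip a TrainingBehaviorInitialized message and rely on proto3 scalar defaults for the unset fields.

```python
from mlagents_envs.communicator_objects.training_analytics_pb2 import (
    TrainingBehaviorInitialized,
)

msg = TrainingBehaviorInitialized(
    behavior_name="testBehavior",
    extrinsic_reward_enabled=True,
    curiosity_reward_enabled=True,
    recurrent_enabled=True,
    self_play_enabled=True,
)
parsed = TrainingBehaviorInitialized.FromString(msg.SerializeToString())
assert parsed.behavior_name == "testBehavior"
assert parsed.curiosity_reward_enabled
assert parsed.num_network_layers == 0  # unset proto3 int32 -> 0
assert parsed.trainer_type == ""  # unset proto3 string -> ""
```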
diff --git a/ml-agents-envs/mlagents_envs/environment.py b/ml-agents-envs/mlagents_envs/environment.py
index 14b4dddb53..5d844749c4 100644
--- a/ml-agents-envs/mlagents_envs/environment.py
+++ b/ml-agents-envs/mlagents_envs/environment.py
@@ -62,7 +62,8 @@ class UnityEnvironment(BaseEnv):
# * 1.1.0 - support concatenated PNGs for compressed observations.
# * 1.2.0 - support compression mapping for stacked compressed observations.
# * 1.3.0 - support action spaces with both continuous and discrete actions.
- API_VERSION = "1.3.0"
+ # * 1.4.0 - support training analytics sent from python trainer to the editor.
+ API_VERSION = "1.4.0"
# Default port that the editor listens on. If an environment executable
# isn't specified, this port will be used.
@@ -120,6 +121,7 @@ def _get_capabilities_proto() -> UnityRLCapabilitiesProto:
capabilities.concatenatedPngObservations = True
capabilities.compressedChannelMapping = True
capabilities.hybridActions = True
+ capabilities.trainingAnalytics = True
return capabilities
@staticmethod
@@ -183,6 +185,7 @@ def __init__(
self._worker_id = worker_id
self._side_channel_manager = SideChannelManager(side_channels)
self._log_folder = log_folder
+ self.academy_capabilities: UnityRLCapabilitiesProto = None # type: ignore
# If the environment name is None, a new environment will not be launched
# and the communicator will directly try to connect to an existing unity environment.
@@ -239,6 +242,7 @@ def __init__(
self._env_actions: Dict[str, ActionTuple] = {}
self._is_first_message = True
self._update_behavior_specs(aca_output)
+ self.academy_capabilities = aca_params.capabilities
@staticmethod
def _get_communicator(worker_id, base_port, timeout_wait):
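A hedged sketch of how a consumer can use the new academy_capabilities attribute to gate analytics messages; file_name=None attaches to an already-running editor and is used here purely for illustration:

```python
from mlagents_envs.environment import UnityEnvironment

env = UnityEnvironment(file_name=None)  # attach to a running editor
try:
    caps = env.academy_capabilities
    if caps is not None and caps.trainingAnalytics:
        # The environment understands the analytics side channel, so it is
        # safe to send TrainingEnvironmentInitialized and
        # TrainingBehaviorInitialized messages to it.
        pass
finally:
    env.close()
```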
diff --git a/ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py b/ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py
index 4e4c15d795..ce1715ba07 100644
--- a/ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py
+++ b/ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py
@@ -53,7 +53,7 @@ def on_message_received(self, msg: IncomingMessage) -> None:
"""
raise UnityCommunicationException(
"The EngineConfigurationChannel received a message from Unity, "
- + "this should not have happend."
+ + "this should not have happened."
)
def set_configuration_parameters(
diff --git a/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py b/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
index 9629f1b5d5..ff516a5eb5 100644
--- a/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
+++ b/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
@@ -28,7 +28,7 @@ def __init__(self) -> None:
def on_message_received(self, msg: IncomingMessage) -> None:
raise UnityCommunicationException(
"The EnvironmentParametersChannel received a message from Unity, "
- + "this should not have happend."
+ + "this should not have happened."
)
def set_float_parameter(self, key: str, value: float) -> None:
diff --git a/ml-agents/mlagents/trainers/env_manager.py b/ml-agents/mlagents/trainers/env_manager.py
index 905badd9f4..a4a90fdc70 100644
--- a/ml-agents/mlagents/trainers/env_manager.py
+++ b/ml-agents/mlagents/trainers/env_manager.py
@@ -12,6 +12,7 @@
from mlagents.trainers.policy import Policy
from mlagents.trainers.agent_processor import AgentManager, AgentManagerQueue
from mlagents.trainers.action_info import ActionInfo
+from mlagents.trainers.settings import TrainerSettings
from mlagents_envs.logging_util import get_logger
AllStepResult = Dict[BehaviorName, Tuple[DecisionSteps, TerminalSteps]]
@@ -76,6 +77,17 @@ def set_env_parameters(self, config: Dict = None) -> None:
"""
pass
+ def on_training_started(
+ self, behavior_name: str, trainer_settings: TrainerSettings
+ ) -> None:
+ """
+ Handle training starting for a new behavior type. Generally nothing is necessary here.
+ :param behavior_name:
+ :param trainer_settings:
+ :return:
+ """
+ pass
+
@property
@abstractmethod
def training_behaviors(self) -> Dict[BehaviorName, BehaviorSpec]:
diff --git a/ml-agents/mlagents/trainers/learn.py b/ml-agents/mlagents/trainers/learn.py
index 82b65f59b4..c5bf8daac0 100644
--- a/ml-agents/mlagents/trainers/learn.py
+++ b/ml-agents/mlagents/trainers/learn.py
@@ -28,7 +28,6 @@
from mlagents_envs.base_env import BaseEnv
from mlagents.trainers.subprocess_env_manager import SubprocessEnvManager
from mlagents_envs.side_channel.side_channel import SideChannel
-from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig
from mlagents_envs.timers import (
hierarchical_timer,
get_timer_tree,
@@ -109,17 +108,8 @@ def run_training(run_seed: int, options: RunOptions) -> None:
env_settings.env_args,
os.path.abspath(run_logs_dir), # Unity environment requires absolute path
)
- engine_config = EngineConfig(
- width=engine_settings.width,
- height=engine_settings.height,
- quality_level=engine_settings.quality_level,
- time_scale=engine_settings.time_scale,
- target_frame_rate=engine_settings.target_frame_rate,
- capture_frame_rate=engine_settings.capture_frame_rate,
- )
- env_manager = SubprocessEnvManager(
- env_factory, engine_config, env_settings.num_envs
- )
+
+ env_manager = SubprocessEnvManager(env_factory, options, env_settings.num_envs)
env_parameter_manager = EnvironmentParameterManager(
options.environment_parameters, run_seed, restore=checkpoint_settings.resume
)
diff --git a/ml-agents/mlagents/trainers/subprocess_env_manager.py b/ml-agents/mlagents/trainers/subprocess_env_manager.py
index 1eb5a4aa9f..4bad60c8b4 100644
--- a/ml-agents/mlagents/trainers/subprocess_env_manager.py
+++ b/ml-agents/mlagents/trainers/subprocess_env_manager.py
@@ -16,6 +16,7 @@
from mlagents_envs.base_env import BaseEnv, BehaviorName, BehaviorSpec
from mlagents_envs import logging_util
from mlagents.trainers.env_manager import EnvManager, EnvironmentStep, AllStepResult
+from mlagents.trainers.settings import TrainerSettings
from mlagents_envs.timers import (
TimerNode,
timed,
@@ -23,7 +24,7 @@
reset_timers,
get_timer_root,
)
-from mlagents.trainers.settings import ParameterRandomizationSettings
+from mlagents.trainers.settings import ParameterRandomizationSettings, RunOptions
from mlagents.trainers.action_info import ActionInfo
from mlagents_envs.side_channel.environment_parameters_channel import (
EnvironmentParametersChannel,
@@ -33,9 +34,10 @@
EngineConfig,
)
from mlagents_envs.side_channel.stats_side_channel import (
- StatsSideChannel,
EnvironmentStats,
+ StatsSideChannel,
)
+from mlagents.training_analytics_side_channel import TrainingAnalyticsSideChannel
from mlagents_envs.side_channel.side_channel import SideChannel
@@ -51,6 +53,7 @@ class EnvironmentCommand(enum.Enum):
CLOSE = 5
ENV_EXITED = 6
CLOSED = 7
+ TRAINING_STARTED = 8
class EnvironmentRequest(NamedTuple):
@@ -112,17 +115,30 @@ def worker(
step_queue: Queue,
pickled_env_factory: str,
worker_id: int,
- engine_configuration: EngineConfig,
+ run_options: RunOptions,
log_level: int = logging_util.INFO,
) -> None:
env_factory: Callable[
[int, List[SideChannel]], UnityEnvironment
] = cloudpickle.loads(pickled_env_factory)
env_parameters = EnvironmentParametersChannel()
+
+ engine_config = EngineConfig(
+ width=run_options.engine_settings.width,
+ height=run_options.engine_settings.height,
+ quality_level=run_options.engine_settings.quality_level,
+ time_scale=run_options.engine_settings.time_scale,
+ target_frame_rate=run_options.engine_settings.target_frame_rate,
+ capture_frame_rate=run_options.engine_settings.capture_frame_rate,
+ )
engine_configuration_channel = EngineConfigurationChannel()
- engine_configuration_channel.set_configuration(engine_configuration)
+ engine_configuration_channel.set_configuration(engine_config)
+
stats_channel = StatsSideChannel()
- env: BaseEnv = None
+ training_analytics_channel: Optional[TrainingAnalyticsSideChannel] = None
+ if worker_id == 0:
+ training_analytics_channel = TrainingAnalyticsSideChannel()
+ env: UnityEnvironment = None
# Set log level. On some platforms, the logger isn't common with the
# main process, so we need to set it again.
logging_util.set_log_level(log_level)
@@ -137,9 +153,21 @@ def _generate_all_results() -> AllStepResult:
return all_step_result
try:
- env = env_factory(
- worker_id, [env_parameters, engine_configuration_channel, stats_channel]
- )
+ side_channels = [env_parameters, engine_configuration_channel, stats_channel]
+ if training_analytics_channel is not None:
+ side_channels.append(training_analytics_channel)
+
+ env = env_factory(worker_id, side_channels)
+ if (
+ not env.academy_capabilities
+ or not env.academy_capabilities.trainingAnalytics
+ ):
+ # Make sure we don't try to send training analytics if the environment doesn't know how to process
+ # them. This wouldn't be catastrophic, but would result in unknown SideChannel UUIDs being used.
+ training_analytics_channel = None
+ if training_analytics_channel:
+ training_analytics_channel.environment_initialized(run_options)
+
while True:
req: EnvironmentRequest = parent_conn.recv()
if req.cmd == EnvironmentCommand.STEP:
@@ -170,6 +198,12 @@ def _generate_all_results() -> AllStepResult:
for k, v in req.payload.items():
if isinstance(v, ParameterRandomizationSettings):
v.apply(k, env_parameters)
+ elif req.cmd == EnvironmentCommand.TRAINING_STARTED:
+ behavior_name, trainer_config = req.payload
+ if training_analytics_channel:
+ training_analytics_channel.training_started(
+ behavior_name, trainer_config
+ )
elif req.cmd == EnvironmentCommand.RESET:
env.reset()
all_step_result = _generate_all_results()
@@ -210,7 +244,7 @@ class SubprocessEnvManager(EnvManager):
def __init__(
self,
env_factory: Callable[[int, List[SideChannel]], BaseEnv],
- engine_configuration: EngineConfig,
+ run_options: RunOptions,
n_env: int = 1,
):
super().__init__()
@@ -220,7 +254,7 @@ def __init__(
for worker_idx in range(n_env):
self.env_workers.append(
self.create_worker(
- worker_idx, self.step_queue, env_factory, engine_configuration
+ worker_idx, self.step_queue, env_factory, run_options
)
)
self.workers_alive += 1
@@ -230,7 +264,7 @@ def create_worker(
worker_id: int,
step_queue: Queue,
env_factory: Callable[[int, List[SideChannel]], BaseEnv],
- engine_configuration: EngineConfig,
+ run_options: RunOptions,
) -> UnityEnvWorker:
parent_conn, child_conn = Pipe()
@@ -244,7 +278,7 @@ def create_worker(
step_queue,
pickled_env_factory,
worker_id,
- engine_configuration,
+ run_options,
logger.level,
),
)
@@ -308,6 +342,20 @@ def set_env_parameters(self, config: Dict = None) -> None:
for ew in self.env_workers:
ew.send(EnvironmentCommand.ENVIRONMENT_PARAMETERS, config)
+ def on_training_started(
+ self, behavior_name: str, trainer_settings: TrainerSettings
+ ) -> None:
+ """
+ Handle training starting for a new behavior type by forwarding the event to all workers.
+ :param behavior_name:
+ :param trainer_settings:
+ :return:
+ """
+ for ew in self.env_workers:
+ ew.send(
+ EnvironmentCommand.TRAINING_STARTED, (behavior_name, trainer_settings)
+ )
+
@property
def training_behaviors(self) -> Dict[BehaviorName, BehaviorSpec]:
result: Dict[BehaviorName, BehaviorSpec] = {}
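For illustration, a minimal sketch of the request that on_training_started broadcasts. Each worker unpacks the (behavior_name, trainer_config) payload and, if it owns an analytics channel (only worker 0 does), forwards it via training_started; "MyBehavior" and the default TrainerSettings() are stand-ins.

```python
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.subprocess_env_manager import (
    EnvironmentCommand,
    EnvironmentRequest,
)

# The manager sends one of these to every worker process.
req = EnvironmentRequest(
    EnvironmentCommand.TRAINING_STARTED, ("MyBehavior", TrainerSettings())
)

# Worker side: unpack the payload exactly as the TRAINING_STARTED branch does.
behavior_name, trainer_config = req.payload
assert req.cmd == EnvironmentCommand.TRAINING_STARTED
assert behavior_name == "MyBehavior"
```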
diff --git a/ml-agents/mlagents/trainers/tests/simple_test_envs.py b/ml-agents/mlagents/trainers/tests/simple_test_envs.py
index 01a196c8db..e7f44b0f56 100644
--- a/ml-agents/mlagents/trainers/tests/simple_test_envs.py
+++ b/ml-agents/mlagents/trainers/tests/simple_test_envs.py
@@ -72,6 +72,8 @@ def __init__(
self.step_result: Dict[str, Tuple[DecisionSteps, TerminalSteps]] = {}
self.agent_id: Dict[str, int] = {}
self.step_size = step_size # defines the difficulty of the test
+ # Allow this object to stand in for a UnityEnvironment during tests
+ self.academy_capabilities = None
for name in self.names:
self.agent_id[name] = 0
diff --git a/ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py b/ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py
index 9b99d833c3..4fd9cc96c2 100644
--- a/ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py
+++ b/ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py
@@ -4,6 +4,7 @@
import pytest
from queue import Empty as EmptyQueue
+from mlagents.trainers.settings import RunOptions
from mlagents.trainers.subprocess_env_manager import (
SubprocessEnvManager,
EnvironmentResponse,
@@ -12,7 +13,6 @@
)
from mlagents.trainers.env_manager import EnvironmentStep
from mlagents_envs.base_env import BaseEnv
-from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
from mlagents_envs.exception import UnityEnvironmentException
from mlagents.trainers.tests.simple_test_envs import (
@@ -54,16 +54,13 @@ class SubprocessEnvManagerTest(unittest.TestCase):
)
def test_environments_are_created(self, mock_create_worker):
mock_create_worker.side_effect = create_worker_mock
- env = SubprocessEnvManager(mock_env_factory, EngineConfig.default_config(), 2)
+ run_options = RunOptions()
+ env = SubprocessEnvManager(mock_env_factory, run_options, 2)
# Creates two processes
env.create_worker.assert_has_calls(
[
- mock.call(
- 0, env.step_queue, mock_env_factory, EngineConfig.default_config()
- ),
- mock.call(
- 1, env.step_queue, mock_env_factory, EngineConfig.default_config()
- ),
+ mock.call(0, env.step_queue, mock_env_factory, run_options),
+ mock.call(1, env.step_queue, mock_env_factory, run_options),
]
)
self.assertEqual(len(env.env_workers), 2)
@@ -73,9 +70,7 @@ def test_environments_are_created(self, mock_create_worker):
)
def test_reset_passes_reset_params(self, mock_create_worker):
mock_create_worker.side_effect = create_worker_mock
- manager = SubprocessEnvManager(
- mock_env_factory, EngineConfig.default_config(), 1
- )
+ manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 1)
params = {"test": "params"}
manager._reset_env(params)
manager.env_workers[0].send.assert_called_with(
@@ -87,9 +82,7 @@ def test_reset_passes_reset_params(self, mock_create_worker):
)
def test_reset_collects_results_from_all_envs(self, mock_create_worker):
mock_create_worker.side_effect = create_worker_mock
- manager = SubprocessEnvManager(
- mock_env_factory, EngineConfig.default_config(), 4
- )
+ manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 4)
params = {"test": "params"}
res = manager._reset_env(params)
@@ -117,9 +110,7 @@ def create_worker_mock(worker_id, step_queue, env_factor, engine_c):
)
mock_create_worker.side_effect = create_worker_mock
- manager = SubprocessEnvManager(
- mock_env_factory, EngineConfig.default_config(), 4
- )
+ manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 4)
res = manager.training_behaviors
for env in manager.env_workers:
@@ -134,9 +125,7 @@ def create_worker_mock(worker_id, step_queue, env_factor, engine_c):
)
def test_step_takes_steps_for_all_non_waiting_envs(self, mock_create_worker):
mock_create_worker.side_effect = create_worker_mock
- manager = SubprocessEnvManager(
- mock_env_factory, EngineConfig.default_config(), 3
- )
+ manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 3)
manager.step_queue = Mock()
manager.step_queue.get_nowait.side_effect = [
EnvironmentResponse(EnvironmentCommand.STEP, 0, StepResponse(0, None, {})),
@@ -176,9 +165,7 @@ def test_advance(self, mock_create_worker, training_behaviors_mock, step_mock):
brain_name = "testbrain"
action_info_dict = {brain_name: MagicMock()}
mock_create_worker.side_effect = create_worker_mock
- env_manager = SubprocessEnvManager(
- mock_env_factory, EngineConfig.default_config(), 3
- )
+ env_manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 3)
training_behaviors_mock.return_value = [brain_name]
agent_manager_mock = mock.Mock()
mock_policy = mock.Mock()
@@ -219,9 +206,7 @@ def simple_env_factory(worker_id, config):
env = SimpleEnvironment(["1D"], action_sizes=(0, 1))
return env
- env_manager = SubprocessEnvManager(
- simple_env_factory, EngineConfig.default_config(), num_envs
- )
+ env_manager = SubprocessEnvManager(simple_env_factory, RunOptions(), num_envs)
# Run PPO using env_manager
check_environment_trains(
simple_env_factory(0, []),
@@ -250,9 +235,7 @@ def failing_step_env_factory(_worker_id, _config):
)
return env
- env_manager = SubprocessEnvManager(
- failing_step_env_factory, EngineConfig.default_config()
- )
+ env_manager = SubprocessEnvManager(failing_step_env_factory, RunOptions())
# Expect the exception raised to be routed back up to the top level.
with pytest.raises(CustomTestOnlyException):
check_environment_trains(
@@ -275,9 +258,7 @@ def failing_env_factory(worker_id, config):
time.sleep(0.5)
raise UnityEnvironmentException()
- env_manager = SubprocessEnvManager(
- failing_env_factory, EngineConfig.default_config(), num_envs
- )
+ env_manager = SubprocessEnvManager(failing_env_factory, RunOptions(), num_envs)
with pytest.raises(UnityEnvironmentException):
env_manager.reset()
env_manager.close()
diff --git a/ml-agents/mlagents/trainers/trainer_controller.py b/ml-agents/mlagents/trainers/trainer_controller.py
index 7f9808f5dd..c4a60f8b3a 100644
--- a/ml-agents/mlagents/trainers/trainer_controller.py
+++ b/ml-agents/mlagents/trainers/trainer_controller.py
@@ -130,6 +130,9 @@ def _create_trainer_and_manager(
target=self.trainer_update_func, args=(trainer,), daemon=True
)
self.trainer_threads.append(trainerthread)
+ env_manager.on_training_started(
+ brain_name, self.trainer_factory.trainer_config[brain_name]
+ )
policy = trainer.create_policy(
parsed_behavior_id,
diff --git a/ml-agents/mlagents/training_analytics_side_channel.py b/ml-agents/mlagents/training_analytics_side_channel.py
new file mode 100644
index 0000000000..f964f13fac
--- /dev/null
+++ b/ml-agents/mlagents/training_analytics_side_channel.py
@@ -0,0 +1,99 @@
+import sys
+from typing import Optional
+import uuid
+import mlagents_envs
+import mlagents.trainers
+from mlagents import torch_utils
+from mlagents.trainers.settings import RewardSignalType
+from mlagents_envs.exception import UnityCommunicationException
+from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage
+from mlagents_envs.communicator_objects.training_analytics_pb2 import (
+ TrainingEnvironmentInitialized,
+ TrainingBehaviorInitialized,
+)
+from google.protobuf.any_pb2 import Any
+
+from mlagents.trainers.settings import TrainerSettings, RunOptions
+
+
+class TrainingAnalyticsSideChannel(SideChannel):
+ """
+ Side channel that sends information about the training to the Unity environment so it can be logged.
+ """
+
+ def __init__(self) -> None:
+ # >>> uuid.uuid5(uuid.NAMESPACE_URL, "com.unity.ml-agents/TrainingAnalyticsSideChannel")
+ # UUID('b664a4a9-d86f-5a5f-95cb-e8353a7e8356')
+ super().__init__(uuid.UUID("b664a4a9-d86f-5a5f-95cb-e8353a7e8356"))
+ self.run_options: Optional[RunOptions] = None
+
+ def on_message_received(self, msg: IncomingMessage) -> None:
+ raise UnityCommunicationException(
+ "The TrainingAnalyticsSideChannel received a message from Unity, "
+ + "this should not have happened."
+ )
+
+ def environment_initialized(self, run_options: RunOptions) -> None:
+ self.run_options = run_options
+ # Tuple of (major, minor, patch)
+ vi = sys.version_info
+ env_params = run_options.environment_parameters
+
+ msg = TrainingEnvironmentInitialized(
+ python_version=f"{vi[0]}.{vi[1]}.{vi[2]}",
+ mlagents_version=mlagents.trainers.__version__,
+ mlagents_envs_version=mlagents_envs.__version__,
+ torch_version=torch_utils.torch.__version__,
+ torch_device_type=torch_utils.default_device().type,
+ num_envs=run_options.env_settings.num_envs,
+ num_environment_parameters=len(env_params) if env_params else 0,
+ )
+
+ any_message = Any()
+ any_message.Pack(msg)
+
+ env_init_msg = OutgoingMessage()
+ env_init_msg.set_raw_bytes(any_message.SerializeToString())
+ super().queue_message_to_send(env_init_msg)
+
+ def training_started(self, behavior_name: str, config: TrainerSettings) -> None:
+ msg = TrainingBehaviorInitialized(
+ behavior_name=behavior_name,
+ trainer_type=config.trainer_type.value,
+ extrinsic_reward_enabled=(
+ RewardSignalType.EXTRINSIC in config.reward_signals
+ ),
+ gail_reward_enabled=(RewardSignalType.GAIL in config.reward_signals),
+ curiosity_reward_enabled=(
+ RewardSignalType.CURIOSITY in config.reward_signals
+ ),
+ rnd_reward_enabled=(RewardSignalType.RND in config.reward_signals),
+ behavioral_cloning_enabled=config.behavioral_cloning is not None,
+ recurrent_enabled=config.network_settings.memory is not None,
+ visual_encoder=config.network_settings.vis_encode_type.value,
+ num_network_layers=config.network_settings.num_layers,
+ num_network_hidden_units=config.network_settings.hidden_units,
+ trainer_threaded=config.threaded,
+ self_play_enabled=config.self_play is not None,
+ curriculum_enabled=self._behavior_uses_curriculum(behavior_name),
+ )
+
+ any_message = Any()
+ any_message.Pack(msg)
+
+ training_start_msg = OutgoingMessage()
+ training_start_msg.set_raw_bytes(any_message.SerializeToString())
+
+ super().queue_message_to_send(training_start_msg)
+
+ def _behavior_uses_curriculum(self, behavior_name: str) -> bool:
+ if not self.run_options or not self.run_options.environment_parameters:
+ return False
+
+ for param_settings in self.run_options.environment_parameters.values():
+ for lesson in param_settings.curriculum:
+ cc = lesson.completion_criteria
+ if cc and cc.behavior == behavior_name:
+ return True
+
+ return False
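A short usage sketch, assuming an environment constructed with this channel in its side_channels list. Messages are only queued on the channel; they go over the wire with the next communicator exchange (for example a reset() or step()).

```python
from mlagents_envs.environment import UnityEnvironment
from mlagents.training_analytics_side_channel import TrainingAnalyticsSideChannel
from mlagents.trainers.settings import RunOptions, TrainerSettings

channel = TrainingAnalyticsSideChannel()
env = UnityEnvironment(file_name=None, side_channels=[channel])

channel.environment_initialized(RunOptions())
channel.training_started("MyBehavior", TrainerSettings())  # illustrative names
env.reset()  # queued analytics messages are delivered with this exchange
env.close()
```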
diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto
index 109b1f0c97..723cc2cd5c 100644
--- a/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto
+++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto
@@ -19,4 +19,7 @@ message UnityRLCapabilitiesProto {
// support for hybrid action spaces (discrete + continuous)
bool hybridActions = 4;
+
+ // support for training analytics
+ bool trainingAnalytics = 5;
}
diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/training_analytics.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/training_analytics.proto
new file mode 100644
index 0000000000..52f0ed6109
--- /dev/null
+++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/training_analytics.proto
@@ -0,0 +1,31 @@
+syntax = "proto3";
+
+option csharp_namespace = "Unity.MLAgents.CommunicatorObjects";
+package communicator_objects;
+
+message TrainingEnvironmentInitialized {
+ string mlagents_version = 1;
+ string mlagents_envs_version = 2;
+ string python_version = 3;
+ string torch_version = 4;
+ string torch_device_type = 5;
+ int32 num_envs = 6;
+ int32 num_environment_parameters = 7;
+}
+
+message TrainingBehaviorInitialized {
+ string behavior_name = 1;
+ string trainer_type = 2;
+ bool extrinsic_reward_enabled = 3;
+ bool gail_reward_enabled = 4;
+ bool curiosity_reward_enabled = 5;
+ bool rnd_reward_enabled = 6;
+ bool behavioral_cloning_enabled = 7;
+ bool recurrent_enabled = 8;
+ string visual_encoder = 9;
+ int32 num_network_layers = 10;
+ int32 num_network_hidden_units = 11;
+ bool trainer_threaded = 12;
+ bool self_play_enabled = 13;
+ bool curriculum_enabled = 14;
+}
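Since the generated *_pb2 modules earlier in this diff must stay in sync with this schema, a quick sketch that inspects the generated descriptors to confirm the field numbering matches the definitions above:

```python
from mlagents_envs.communicator_objects.training_analytics_pb2 import (
    TrainingBehaviorInitialized,
    TrainingEnvironmentInitialized,
)

for desc in (
    TrainingEnvironmentInitialized.DESCRIPTOR,
    TrainingBehaviorInitialized.DESCRIPTOR,
):
    print(desc.name, [(f.number, f.name) for f in desc.fields])
# e.g. TrainingEnvironmentInitialized [(1, 'mlagents_version'), ...]
```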