diff --git a/com.unity.ml-agents/Runtime/Academy.cs b/com.unity.ml-agents/Runtime/Academy.cs
index c42fd1afef..1b7cdb457a 100644
--- a/com.unity.ml-agents/Runtime/Academy.cs
+++ b/com.unity.ml-agents/Runtime/Academy.cs
@@ -91,9 +91,13 @@ public class Academy : IDisposable
         /// 1.3.0
         /// Support both continuous and discrete actions.
         ///
+        ///
+        /// 1.4.0
+        /// Support training analytics sent from python trainer to the editor.
+        ///
         ///
         ///
-        const string k_ApiVersion = "1.3.0";
+        const string k_ApiVersion = "1.4.0";
 
         ///
         /// Unity package version of com.unity.ml-agents.
@@ -406,6 +410,7 @@ void InitializeEnvironment()
             EnableAutomaticStepping();
 
             SideChannelManager.RegisterSideChannel(new EngineConfigurationChannel());
+            SideChannelManager.RegisterSideChannel(new TrainingAnalyticsSideChannel());
 
             m_EnvironmentParameters = new EnvironmentParameters();
             m_StatsRecorder = new StatsRecorder();
diff --git a/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs b/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs
new file mode 100644
index 0000000000..fb480b7a11
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs
@@ -0,0 +1,40 @@
+using System;
+using UnityEngine;
+
+namespace Unity.MLAgents.Analytics
+{
+    internal static class AnalyticsUtils
+    {
+        ///
+        /// Hash a string to remove PII or secret info before sending to analytics.
+        ///
+        ///
+        /// A string containing the Hash128 of the input string.
+        public static string Hash(string s)
+        {
+            var behaviorNameHash = Hash128.Compute(s);
+            return behaviorNameHash.ToString();
+        }
+
+        internal static bool s_SendEditorAnalytics = true;
+
+        ///
+        /// Helper class to temporarily disable sending analytics from unit tests.
+        ///
+        internal class DisableAnalyticsSending : IDisposable
+        {
+            private bool m_PreviousSendEditorAnalytics;
+
+            public DisableAnalyticsSending()
+            {
+                m_PreviousSendEditorAnalytics = s_SendEditorAnalytics;
+                s_SendEditorAnalytics = false;
+            }
+
+            public void Dispose()
+            {
+                s_SendEditorAnalytics = m_PreviousSendEditorAnalytics;
+            }
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs.meta b/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs.meta
new file mode 100644
index 0000000000..b00fab1c90
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: af1ef3e70f1242938d7b39284b1a892b
+timeCreated: 1610575760
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Analytics/Events.cs b/com.unity.ml-agents/Runtime/Analytics/Events.cs
index 276587e761..ed22d3c67d 100644
--- a/com.unity.ml-agents/Runtime/Analytics/Events.cs
+++ b/com.unity.ml-agents/Runtime/Analytics/Events.cs
@@ -91,4 +91,67 @@ public static EventObservationSpec FromSensor(ISensor sensor)
             };
         }
     }
+
+    internal struct RemotePolicyInitializedEvent
+    {
+        public string TrainingSessionGuid;
+        ///
+        /// Hash of the BehaviorName.
+        ///
+        public string BehaviorName;
+        public List<EventObservationSpec> ObservationSpecs;
+        public EventActionSpec ActionSpec;
+
+        ///
+        /// This will be the same as TrainingEnvironmentInitializedEvent if available, but
+        /// TrainingEnvironmentInitializedEvent may not always be available with older trainers.
+        ///
+        public string MLAgentsEnvsVersion;
+        public string TrainerCommunicationVersion;
+    }
+
+    internal struct TrainingEnvironmentInitializedEvent
+    {
+        public string TrainingSessionGuid;
+
+        public string TrainerPythonVersion;
+        public string MLAgentsVersion;
+        public string MLAgentsEnvsVersion;
+        public string TorchVersion;
+        public string TorchDeviceType;
+        public int NumEnvironments;
+        public int NumEnvironmentParameters;
+    }
+
+    [Flags]
+    internal enum RewardSignals
+    {
+        Extrinsic = 1 << 0,
+        Gail = 1 << 1,
+        Curiosity = 1 << 2,
+        Rnd = 1 << 3,
+    }
+
+    [Flags]
+    internal enum TrainingFeatures
+    {
+        BehavioralCloning = 1 << 0,
+        Recurrent = 1 << 1,
+        Threaded = 1 << 2,
+        SelfPlay = 1 << 3,
+        Curriculum = 1 << 4,
+    }
+
+    internal struct TrainingBehaviorInitializedEvent
+    {
+        public string TrainingSessionGuid;
+
+        public string BehaviorName;
+        public string TrainerType;
+        public RewardSignals RewardSignalFlags;
+        public TrainingFeatures TrainingFeatureFlags;
+        public string VisualEncoder;
+        public int NumNetworkLayers;
+        public int NumNetworkHiddenUnits;
+    }
 }
diff --git a/com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs b/com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs
index fd8d8bc99a..9dc1f4e535 100644
--- a/com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs
+++ b/com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs
@@ -116,7 +116,10 @@ ActionSpec actionSpec
             // Note - to debug, use JsonUtility.ToJson on the event.
             //Debug.Log(JsonUtility.ToJson(data, true));
 #if UNITY_EDITOR
-            EditorAnalytics.SendEventWithLimit(k_EventName, data, k_EventVersion);
+            if (AnalyticsUtils.s_SendEditorAnalytics)
+            {
+                EditorAnalytics.SendEventWithLimit(k_EventName, data, k_EventVersion);
+            }
 #else
             return;
 #endif
@@ -143,8 +146,7 @@ ActionSpec actionSpec
             var inferenceEvent = new InferenceEvent();
 
             // Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
-            var behaviorNameHash = Hash128.Compute(behaviorName);
-            inferenceEvent.BehaviorName = behaviorNameHash.ToString();
+            inferenceEvent.BehaviorName = AnalyticsUtils.Hash(behaviorName);
 
             inferenceEvent.BarracudaModelSource = barracudaModel.IrSource;
             inferenceEvent.BarracudaModelVersion = barracudaModel.IrVersion;
diff --git a/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs b/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs
new file mode 100644
index 0000000000..6bdea00171
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs
@@ -0,0 +1,246 @@
+using System;
+using System.Collections.Generic;
+using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Sensors;
+using UnityEngine;
+using UnityEngine.Analytics;
+
+#if UNITY_EDITOR
+using UnityEditor;
+using UnityEditor.Analytics;
+#endif
+
+namespace Unity.MLAgents.Analytics
+{
+    internal class TrainingAnalytics
+    {
+        const string k_VendorKey = "unity.ml-agents";
+        const string k_TrainingEnvironmentInitializedEventName = "ml_agents_training_environment_initialized";
+        const string k_TrainingBehaviorInitializedEventName = "ml_agents_training_behavior_initialized";
+        const string k_RemotePolicyInitializedEventName = "ml_agents_remote_policy_initialized";
+
+        private static readonly string[] s_EventNames =
+        {
+            k_TrainingEnvironmentInitializedEventName,
+            k_TrainingBehaviorInitializedEventName,
+            k_RemotePolicyInitializedEventName
+        };
+
+        ///
+        /// Whether or not we've registered these events yet.
+        ///
+        static bool s_EventsRegistered = false;
+
+        ///
+        /// Hourly limit for each event name.
+        ///
+        const int k_MaxEventsPerHour = 1000;
+
+        ///
+        /// Maximum number of items in each event.
+        ///
+        const int k_MaxNumberOfElements = 1000;
+
+        private static bool s_SentEnvironmentInitialized;
+        ///
+        /// Behaviors that we've already sent events for.
+        ///
+        private static HashSet<string> s_SentRemotePolicyInitialized;
+        private static HashSet<string> s_SentTrainingBehaviorInitialized;
+
+        private static Guid s_TrainingSessionGuid;
+
+        // These are set when the RpcCommunicator connects
+        private static string s_TrainerPackageVersion = "";
+        private static string s_TrainerCommunicationVersion = "";
+
+        static bool EnableAnalytics()
+        {
+            if (s_EventsRegistered)
+            {
+                return true;
+            }
+
+            foreach (var eventName in s_EventNames)
+            {
+#if UNITY_EDITOR
+                AnalyticsResult result = EditorAnalytics.RegisterEventWithLimit(eventName, k_MaxEventsPerHour, k_MaxNumberOfElements, k_VendorKey);
+#else
+                AnalyticsResult result = AnalyticsResult.UnsupportedPlatform;
+#endif
+                if (result != AnalyticsResult.Ok)
+                {
+                    return false;
+                }
+            }
+            s_EventsRegistered = true;
+
+            if (s_SentRemotePolicyInitialized == null)
+            {
+                s_SentRemotePolicyInitialized = new HashSet<string>();
+                s_SentTrainingBehaviorInitialized = new HashSet<string>();
+                s_TrainingSessionGuid = Guid.NewGuid();
+            }
+
+            return s_EventsRegistered;
+        }
+
+        ///
+        /// Cache information about the trainer when it becomes available in the RpcCommunicator.
+        ///
+        ///
+        public static void SetTrainerInformation(string packageVersion, string communicationVersion)
+        {
+            s_TrainerPackageVersion = packageVersion;
+            s_TrainerCommunicationVersion = communicationVersion;
+        }
+
+        public static bool IsAnalyticsEnabled()
+        {
+#if UNITY_EDITOR
+            return EditorAnalytics.enabled;
+#else
+            return false;
+#endif
+        }
+
+        public static void TrainingEnvironmentInitialized(TrainingEnvironmentInitializedEvent tbiEvent)
+        {
+            if (!IsAnalyticsEnabled())
+                return;
+
+            if (!EnableAnalytics())
+                return;
+
+            if (s_SentEnvironmentInitialized)
+            {
+                // We already sent a TrainingEnvironmentInitializedEvent. Exit so we don't resend.
+                return;
+            }
+
+            s_SentEnvironmentInitialized = true;
+            tbiEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString();
+
+            // Note - to debug, use JsonUtility.ToJson on the event.
+            // Debug.Log(
+            //     $"Would send event {k_TrainingEnvironmentInitializedEventName} with body {JsonUtility.ToJson(tbiEvent, true)}"
+            // );
+#if UNITY_EDITOR
+            if (AnalyticsUtils.s_SendEditorAnalytics)
+            {
+                EditorAnalytics.SendEventWithLimit(k_TrainingEnvironmentInitializedEventName, tbiEvent);
+            }
+#else
+            return;
+#endif
+        }
+
+        public static void RemotePolicyInitialized(
+            string fullyQualifiedBehaviorName,
+            IList<ISensor> sensors,
+            ActionSpec actionSpec
+        )
+        {
+            if (!IsAnalyticsEnabled())
+                return;
+
+            if (!EnableAnalytics())
+                return;
+
+            // Extract base behavior name (no team ID)
+            var behaviorName = ParseBehaviorName(fullyQualifiedBehaviorName);
+            var added = s_SentRemotePolicyInitialized.Add(behaviorName);
+
+            if (!added)
+            {
+                // We previously added this model. Exit so we don't resend.
+                return;
+            }
+
+            var data = GetEventForRemotePolicy(behaviorName, sensors, actionSpec);
+            // Note - to debug, use JsonUtility.ToJson on the event.
+            // Debug.Log(
+            //     $"Would send event {k_RemotePolicyInitializedEventName} with body {JsonUtility.ToJson(data, true)}"
+            // );
+#if UNITY_EDITOR
+            if (AnalyticsUtils.s_SendEditorAnalytics)
+            {
+                EditorAnalytics.SendEventWithLimit(k_RemotePolicyInitializedEventName, data);
+            }
+#else
+            return;
+#endif
+        }
+
+        internal static string ParseBehaviorName(string fullyQualifiedBehaviorName)
+        {
+            var lastQuestionIndex = fullyQualifiedBehaviorName.LastIndexOf("?");
+            if (lastQuestionIndex < 0)
+            {
+                // Nothing to remove
+                return fullyQualifiedBehaviorName;
+            }
+
+            return fullyQualifiedBehaviorName.Substring(0, lastQuestionIndex);
+        }
+
+        public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent tbiEvent)
+        {
+            if (!IsAnalyticsEnabled())
+                return;
+
+            if (!EnableAnalytics())
+                return;
+
+            var behaviorName = tbiEvent.BehaviorName;
+            var added = s_SentTrainingBehaviorInitialized.Add(behaviorName);
+
+            if (!added)
+            {
+                // We previously added this model. Exit so we don't resend.
+                return;
+            }
+
+            // Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
+            tbiEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString();
+            tbiEvent.BehaviorName = AnalyticsUtils.Hash(tbiEvent.BehaviorName);
+
+            // Note - to debug, use JsonUtility.ToJson on the event.
+ // Debug.Log( + // $"Would send event {k_TrainingBehaviorInitializedEventName} with body {JsonUtility.ToJson(tbiEvent, true)}" + // ); +#if UNITY_EDITOR + if (AnalyticsUtils.s_SendEditorAnalytics) + { + EditorAnalytics.SendEventWithLimit(k_TrainingBehaviorInitializedEventName, tbiEvent); + } +#else + return; +#endif + } + + static RemotePolicyInitializedEvent GetEventForRemotePolicy( + string behaviorName, + IList sensors, + ActionSpec actionSpec) + { + var remotePolicyEvent = new RemotePolicyInitializedEvent(); + + // Hash the behavior name so that there's no concern about PII or "secret" data being leaked. + remotePolicyEvent.BehaviorName = AnalyticsUtils.Hash(behaviorName); + + remotePolicyEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString(); + remotePolicyEvent.ActionSpec = EventActionSpec.FromActionSpec(actionSpec); + remotePolicyEvent.ObservationSpecs = new List(sensors.Count); + foreach (var sensor in sensors) + { + remotePolicyEvent.ObservationSpecs.Add(EventObservationSpec.FromSensor(sensor)); + } + + remotePolicyEvent.MLAgentsEnvsVersion = s_TrainerPackageVersion; + remotePolicyEvent.TrainerCommunicationVersion = s_TrainerCommunicationVersion; + return remotePolicyEvent; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs.meta b/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs.meta new file mode 100644 index 0000000000..9109c265a2 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 5ad0bc6b45614bb7929d25dd59d5ac38 +timeCreated: 1608168600 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs index 01a6706c98..b9044dd6a9 100644 --- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs +++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs @@ -6,6 +6,7 @@ using UnityEngine; using System.Runtime.CompilerServices; using Unity.MLAgents.Actuators; +using Unity.MLAgents.Analytics; using Unity.MLAgents.Sensors; using Unity.MLAgents.Demonstrations; using Unity.MLAgents.Policies; @@ -435,6 +436,7 @@ public static UnityRLCapabilities ToRLCapabilities(this UnityRLCapabilitiesProto ConcatenatedPngObservations = proto.ConcatenatedPngObservations, CompressedChannelMapping = proto.CompressedChannelMapping, HybridActions = proto.HybridActions, + TrainingAnalytics = proto.TrainingAnalytics, }; } @@ -446,6 +448,7 @@ public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps) ConcatenatedPngObservations = rlCaps.ConcatenatedPngObservations, CompressedChannelMapping = rlCaps.CompressedChannelMapping, HybridActions = rlCaps.HybridActions, + TrainingAnalytics = rlCaps.TrainingAnalytics, }; } @@ -476,5 +479,54 @@ internal static bool IsTrivialMapping(ISensor sensor) } return true; } + + #region Analytics + + internal static TrainingEnvironmentInitializedEvent ToTrainingEnvironmentInitializedEvent( + this TrainingEnvironmentInitialized inputProto) + { + return new TrainingEnvironmentInitializedEvent + { + TrainerPythonVersion = inputProto.PythonVersion, + MLAgentsVersion = inputProto.MlagentsVersion, + MLAgentsEnvsVersion = inputProto.MlagentsEnvsVersion, + TorchVersion = inputProto.TorchVersion, + TorchDeviceType = inputProto.TorchDeviceType, + NumEnvironments = inputProto.NumEnvs, + NumEnvironmentParameters = inputProto.NumEnvironmentParameters, + }; + } + + internal static TrainingBehaviorInitializedEvent 
ToTrainingBehaviorInitializedEvent( + this TrainingBehaviorInitialized inputProto) + { + RewardSignals rewardSignals = 0; + rewardSignals |= inputProto.ExtrinsicRewardEnabled ? RewardSignals.Extrinsic : 0; + rewardSignals |= inputProto.GailRewardEnabled ? RewardSignals.Gail : 0; + rewardSignals |= inputProto.CuriosityRewardEnabled ? RewardSignals.Curiosity : 0; + rewardSignals |= inputProto.RndRewardEnabled ? RewardSignals.Rnd : 0; + + TrainingFeatures trainingFeatures = 0; + trainingFeatures |= inputProto.BehavioralCloningEnabled ? TrainingFeatures.BehavioralCloning : 0; + trainingFeatures |= inputProto.RecurrentEnabled ? TrainingFeatures.Recurrent : 0; + trainingFeatures |= inputProto.TrainerThreaded ? TrainingFeatures.Threaded : 0; + trainingFeatures |= inputProto.SelfPlayEnabled ? TrainingFeatures.SelfPlay : 0; + trainingFeatures |= inputProto.CurriculumEnabled ? TrainingFeatures.Curriculum : 0; + + + return new TrainingBehaviorInitializedEvent + { + BehaviorName = inputProto.BehaviorName, + TrainerType = inputProto.TrainerType, + RewardSignalFlags = rewardSignals, + TrainingFeatureFlags = trainingFeatures, + VisualEncoder = inputProto.VisualEncoder, + NumNetworkLayers = inputProto.NumNetworkLayers, + NumNetworkHiddenUnits = inputProto.NumNetworkHiddenUnits, + }; + } + + #endregion + } } diff --git a/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs b/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs index 515ff52770..683dfebd17 100644 --- a/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs +++ b/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs @@ -9,6 +9,7 @@ using System.Linq; using UnityEngine; using Unity.MLAgents.Actuators; +using Unity.MLAgents.Analytics; using Unity.MLAgents.CommunicatorObjects; using Unity.MLAgents.Sensors; using Unity.MLAgents.SideChannels; @@ -114,10 +115,12 @@ public UnityRLInitParameters Initialize(CommunicatorInitParameters initParameter }, out input); - var pythonCommunicationVersion = initializationInput.RlInitializationInput.CommunicationVersion; var pythonPackageVersion = initializationInput.RlInitializationInput.PackageVersion; + var pythonCommunicationVersion = initializationInput.RlInitializationInput.CommunicationVersion; var unityCommunicationVersion = initParameters.unityCommunicationVersion; + TrainingAnalytics.SetTrainerInformation(pythonPackageVersion, pythonCommunicationVersion); + var communicationIsCompatible = CheckCommunicationVersionsAreCompatible(unityCommunicationVersion, pythonCommunicationVersion, pythonPackageVersion); diff --git a/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs b/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs index 48db8815b0..79627e48d4 100644 --- a/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs +++ b/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs @@ -8,6 +8,7 @@ internal class UnityRLCapabilities public bool ConcatenatedPngObservations; public bool CompressedChannelMapping; public bool HybridActions; + public bool TrainingAnalytics; /// /// A class holding the capabilities flags for Reinforcement Learning across C# and the Trainer codebase. 
This @@ -17,12 +18,14 @@ public UnityRLCapabilities( bool baseRlCapabilities = true, bool concatenatedPngObservations = true, bool compressedChannelMapping = true, - bool hybridActions = true) + bool hybridActions = true, + bool trainingAnalytics = true) { BaseRLCapabilities = baseRlCapabilities; ConcatenatedPngObservations = concatenatedPngObservations; CompressedChannelMapping = compressedChannelMapping; HybridActions = hybridActions; + TrainingAnalytics = trainingAnalytics; } /// diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs index 24137bbc40..3953aea214 100644 --- a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs @@ -25,16 +25,16 @@ static CapabilitiesReflection() { byte[] descriptorData = global::System.Convert.FromBase64String( string.Concat( "CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp", - "dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMilAEKGFVuaXR5UkxD", + "dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMirwEKGFVuaXR5UkxD", "YXBhYmlsaXRpZXNQcm90bxIaChJiYXNlUkxDYXBhYmlsaXRpZXMYASABKAgS", "IwobY29uY2F0ZW5hdGVkUG5nT2JzZXJ2YXRpb25zGAIgASgIEiAKGGNvbXBy", "ZXNzZWRDaGFubmVsTWFwcGluZxgDIAEoCBIVCg1oeWJyaWRBY3Rpb25zGAQg", - "ASgIQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZw", - "cm90bzM=")); + "ASgIEhkKEXRyYWluaW5nQW5hbHl0aWNzGAUgASgIQiWqAiJVbml0eS5NTEFn", + "ZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM=")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { }, new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions" }, null, null, null) + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions", "TrainingAnalytics" }, null, null, null) })); } #endregion @@ -75,6 +75,7 @@ public UnityRLCapabilitiesProto(UnityRLCapabilitiesProto other) : this() { concatenatedPngObservations_ = other.concatenatedPngObservations_; compressedChannelMapping_ = other.compressedChannelMapping_; hybridActions_ = other.hybridActions_; + trainingAnalytics_ = other.trainingAnalytics_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -139,6 +140,20 @@ public bool HybridActions { } } + /// Field number for the "trainingAnalytics" field. 
+ public const int TrainingAnalyticsFieldNumber = 5; + private bool trainingAnalytics_; + /// + /// support for training analytics + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool TrainingAnalytics { + get { return trainingAnalytics_; } + set { + trainingAnalytics_ = value; + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as UnityRLCapabilitiesProto); @@ -156,6 +171,7 @@ public bool Equals(UnityRLCapabilitiesProto other) { if (ConcatenatedPngObservations != other.ConcatenatedPngObservations) return false; if (CompressedChannelMapping != other.CompressedChannelMapping) return false; if (HybridActions != other.HybridActions) return false; + if (TrainingAnalytics != other.TrainingAnalytics) return false; return Equals(_unknownFields, other._unknownFields); } @@ -166,6 +182,7 @@ public override int GetHashCode() { if (ConcatenatedPngObservations != false) hash ^= ConcatenatedPngObservations.GetHashCode(); if (CompressedChannelMapping != false) hash ^= CompressedChannelMapping.GetHashCode(); if (HybridActions != false) hash ^= HybridActions.GetHashCode(); + if (TrainingAnalytics != false) hash ^= TrainingAnalytics.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -195,6 +212,10 @@ public void WriteTo(pb::CodedOutputStream output) { output.WriteRawTag(32); output.WriteBool(HybridActions); } + if (TrainingAnalytics != false) { + output.WriteRawTag(40); + output.WriteBool(TrainingAnalytics); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -215,6 +236,9 @@ public int CalculateSize() { if (HybridActions != false) { size += 1 + 1; } + if (TrainingAnalytics != false) { + size += 1 + 1; + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -238,6 +262,9 @@ public void MergeFrom(UnityRLCapabilitiesProto other) { if (other.HybridActions != false) { HybridActions = other.HybridActions; } + if (other.TrainingAnalytics != false) { + TrainingAnalytics = other.TrainingAnalytics; + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -265,6 +292,10 @@ public void MergeFrom(pb::CodedInputStream input) { HybridActions = input.ReadBool(); break; } + case 40: { + TrainingAnalytics = input.ReadBool(); + break; + } } } } diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs new file mode 100644 index 0000000000..099563e949 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs @@ -0,0 +1,850 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: mlagents_envs/communicator_objects/training_analytics.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Unity.MLAgents.CommunicatorObjects { + + /// Holder for reflection information generated from mlagents_envs/communicator_objects/training_analytics.proto + internal static partial class TrainingAnalyticsReflection { + + #region Descriptor + /// File descriptor for mlagents_envs/communicator_objects/training_analytics.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static TrainingAnalyticsReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjttbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3RyYWluaW5n", + "X2FuYWx5dGljcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi2QEKHlRy", + "YWluaW5nRW52aXJvbm1lbnRJbml0aWFsaXplZBIYChBtbGFnZW50c192ZXJz", + "aW9uGAEgASgJEh0KFW1sYWdlbnRzX2VudnNfdmVyc2lvbhgCIAEoCRIWCg5w", + "eXRob25fdmVyc2lvbhgDIAEoCRIVCg10b3JjaF92ZXJzaW9uGAQgASgJEhkK", + "EXRvcmNoX2RldmljZV90eXBlGAUgASgJEhAKCG51bV9lbnZzGAYgASgFEiIK", + "Gm51bV9lbnZpcm9ubWVudF9wYXJhbWV0ZXJzGAcgASgFIq0DChtUcmFpbmlu", + "Z0JlaGF2aW9ySW5pdGlhbGl6ZWQSFQoNYmVoYXZpb3JfbmFtZRgBIAEoCRIU", + "Cgx0cmFpbmVyX3R5cGUYAiABKAkSIAoYZXh0cmluc2ljX3Jld2FyZF9lbmFi", + "bGVkGAMgASgIEhsKE2dhaWxfcmV3YXJkX2VuYWJsZWQYBCABKAgSIAoYY3Vy", + "aW9zaXR5X3Jld2FyZF9lbmFibGVkGAUgASgIEhoKEnJuZF9yZXdhcmRfZW5h", + "YmxlZBgGIAEoCBIiChpiZWhhdmlvcmFsX2Nsb25pbmdfZW5hYmxlZBgHIAEo", + "CBIZChFyZWN1cnJlbnRfZW5hYmxlZBgIIAEoCBIWCg52aXN1YWxfZW5jb2Rl", + "chgJIAEoCRIaChJudW1fbmV0d29ya19sYXllcnMYCiABKAUSIAoYbnVtX25l", + "dHdvcmtfaGlkZGVuX3VuaXRzGAsgASgFEhgKEHRyYWluZXJfdGhyZWFkZWQY", + "DCABKAgSGQoRc2VsZl9wbGF5X2VuYWJsZWQYDSABKAgSGgoSY3VycmljdWx1", + "bV9lbmFibGVkGA4gASgIQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0", + "b3JPYmplY3RzYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized.Parser, new[]{ "MlagentsVersion", "MlagentsEnvsVersion", "PythonVersion", "TorchVersion", "TorchDeviceType", "NumEnvs", "NumEnvironmentParameters" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized.Parser, new[]{ "BehaviorName", "TrainerType", "ExtrinsicRewardEnabled", "GailRewardEnabled", "CuriosityRewardEnabled", "RndRewardEnabled", "BehavioralCloningEnabled", "RecurrentEnabled", "VisualEncoder", "NumNetworkLayers", "NumNetworkHiddenUnits", "TrainerThreaded", "SelfPlayEnabled", "CurriculumEnabled" }, null, null, null) + })); + } + #endregion + + } + #region Messages + internal sealed partial class TrainingEnvironmentInitialized : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TrainingEnvironmentInitialized()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } 
+ + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.TrainingAnalyticsReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainingEnvironmentInitialized() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainingEnvironmentInitialized(TrainingEnvironmentInitialized other) : this() { + mlagentsVersion_ = other.mlagentsVersion_; + mlagentsEnvsVersion_ = other.mlagentsEnvsVersion_; + pythonVersion_ = other.pythonVersion_; + torchVersion_ = other.torchVersion_; + torchDeviceType_ = other.torchDeviceType_; + numEnvs_ = other.numEnvs_; + numEnvironmentParameters_ = other.numEnvironmentParameters_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainingEnvironmentInitialized Clone() { + return new TrainingEnvironmentInitialized(this); + } + + /// Field number for the "mlagents_version" field. + public const int MlagentsVersionFieldNumber = 1; + private string mlagentsVersion_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string MlagentsVersion { + get { return mlagentsVersion_; } + set { + mlagentsVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "mlagents_envs_version" field. + public const int MlagentsEnvsVersionFieldNumber = 2; + private string mlagentsEnvsVersion_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string MlagentsEnvsVersion { + get { return mlagentsEnvsVersion_; } + set { + mlagentsEnvsVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "python_version" field. + public const int PythonVersionFieldNumber = 3; + private string pythonVersion_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string PythonVersion { + get { return pythonVersion_; } + set { + pythonVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "torch_version" field. + public const int TorchVersionFieldNumber = 4; + private string torchVersion_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string TorchVersion { + get { return torchVersion_; } + set { + torchVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "torch_device_type" field. + public const int TorchDeviceTypeFieldNumber = 5; + private string torchDeviceType_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string TorchDeviceType { + get { return torchDeviceType_; } + set { + torchDeviceType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "num_envs" field. + public const int NumEnvsFieldNumber = 6; + private int numEnvs_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumEnvs { + get { return numEnvs_; } + set { + numEnvs_ = value; + } + } + + /// Field number for the "num_environment_parameters" field. 
+ public const int NumEnvironmentParametersFieldNumber = 7; + private int numEnvironmentParameters_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumEnvironmentParameters { + get { return numEnvironmentParameters_; } + set { + numEnvironmentParameters_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as TrainingEnvironmentInitialized); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(TrainingEnvironmentInitialized other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (MlagentsVersion != other.MlagentsVersion) return false; + if (MlagentsEnvsVersion != other.MlagentsEnvsVersion) return false; + if (PythonVersion != other.PythonVersion) return false; + if (TorchVersion != other.TorchVersion) return false; + if (TorchDeviceType != other.TorchDeviceType) return false; + if (NumEnvs != other.NumEnvs) return false; + if (NumEnvironmentParameters != other.NumEnvironmentParameters) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MlagentsVersion.Length != 0) hash ^= MlagentsVersion.GetHashCode(); + if (MlagentsEnvsVersion.Length != 0) hash ^= MlagentsEnvsVersion.GetHashCode(); + if (PythonVersion.Length != 0) hash ^= PythonVersion.GetHashCode(); + if (TorchVersion.Length != 0) hash ^= TorchVersion.GetHashCode(); + if (TorchDeviceType.Length != 0) hash ^= TorchDeviceType.GetHashCode(); + if (NumEnvs != 0) hash ^= NumEnvs.GetHashCode(); + if (NumEnvironmentParameters != 0) hash ^= NumEnvironmentParameters.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MlagentsVersion.Length != 0) { + output.WriteRawTag(10); + output.WriteString(MlagentsVersion); + } + if (MlagentsEnvsVersion.Length != 0) { + output.WriteRawTag(18); + output.WriteString(MlagentsEnvsVersion); + } + if (PythonVersion.Length != 0) { + output.WriteRawTag(26); + output.WriteString(PythonVersion); + } + if (TorchVersion.Length != 0) { + output.WriteRawTag(34); + output.WriteString(TorchVersion); + } + if (TorchDeviceType.Length != 0) { + output.WriteRawTag(42); + output.WriteString(TorchDeviceType); + } + if (NumEnvs != 0) { + output.WriteRawTag(48); + output.WriteInt32(NumEnvs); + } + if (NumEnvironmentParameters != 0) { + output.WriteRawTag(56); + output.WriteInt32(NumEnvironmentParameters); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MlagentsVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(MlagentsVersion); + } + if (MlagentsEnvsVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(MlagentsEnvsVersion); + } + if (PythonVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PythonVersion); + } + if (TorchVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TorchVersion); + } + if 
(TorchDeviceType.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TorchDeviceType); + } + if (NumEnvs != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumEnvs); + } + if (NumEnvironmentParameters != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumEnvironmentParameters); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(TrainingEnvironmentInitialized other) { + if (other == null) { + return; + } + if (other.MlagentsVersion.Length != 0) { + MlagentsVersion = other.MlagentsVersion; + } + if (other.MlagentsEnvsVersion.Length != 0) { + MlagentsEnvsVersion = other.MlagentsEnvsVersion; + } + if (other.PythonVersion.Length != 0) { + PythonVersion = other.PythonVersion; + } + if (other.TorchVersion.Length != 0) { + TorchVersion = other.TorchVersion; + } + if (other.TorchDeviceType.Length != 0) { + TorchDeviceType = other.TorchDeviceType; + } + if (other.NumEnvs != 0) { + NumEnvs = other.NumEnvs; + } + if (other.NumEnvironmentParameters != 0) { + NumEnvironmentParameters = other.NumEnvironmentParameters; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + MlagentsVersion = input.ReadString(); + break; + } + case 18: { + MlagentsEnvsVersion = input.ReadString(); + break; + } + case 26: { + PythonVersion = input.ReadString(); + break; + } + case 34: { + TorchVersion = input.ReadString(); + break; + } + case 42: { + TorchDeviceType = input.ReadString(); + break; + } + case 48: { + NumEnvs = input.ReadInt32(); + break; + } + case 56: { + NumEnvironmentParameters = input.ReadInt32(); + break; + } + } + } + } + + } + + internal sealed partial class TrainingBehaviorInitialized : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TrainingBehaviorInitialized()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Unity.MLAgents.CommunicatorObjects.TrainingAnalyticsReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainingBehaviorInitialized() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainingBehaviorInitialized(TrainingBehaviorInitialized other) : this() { + behaviorName_ = other.behaviorName_; + trainerType_ = other.trainerType_; + extrinsicRewardEnabled_ = other.extrinsicRewardEnabled_; + gailRewardEnabled_ = other.gailRewardEnabled_; + curiosityRewardEnabled_ = other.curiosityRewardEnabled_; + rndRewardEnabled_ = other.rndRewardEnabled_; + behavioralCloningEnabled_ = other.behavioralCloningEnabled_; + recurrentEnabled_ = other.recurrentEnabled_; + visualEncoder_ = other.visualEncoder_; + 
numNetworkLayers_ = other.numNetworkLayers_; + numNetworkHiddenUnits_ = other.numNetworkHiddenUnits_; + trainerThreaded_ = other.trainerThreaded_; + selfPlayEnabled_ = other.selfPlayEnabled_; + curriculumEnabled_ = other.curriculumEnabled_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainingBehaviorInitialized Clone() { + return new TrainingBehaviorInitialized(this); + } + + /// Field number for the "behavior_name" field. + public const int BehaviorNameFieldNumber = 1; + private string behaviorName_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string BehaviorName { + get { return behaviorName_; } + set { + behaviorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "trainer_type" field. + public const int TrainerTypeFieldNumber = 2; + private string trainerType_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string TrainerType { + get { return trainerType_; } + set { + trainerType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "extrinsic_reward_enabled" field. + public const int ExtrinsicRewardEnabledFieldNumber = 3; + private bool extrinsicRewardEnabled_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ExtrinsicRewardEnabled { + get { return extrinsicRewardEnabled_; } + set { + extrinsicRewardEnabled_ = value; + } + } + + /// Field number for the "gail_reward_enabled" field. + public const int GailRewardEnabledFieldNumber = 4; + private bool gailRewardEnabled_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool GailRewardEnabled { + get { return gailRewardEnabled_; } + set { + gailRewardEnabled_ = value; + } + } + + /// Field number for the "curiosity_reward_enabled" field. + public const int CuriosityRewardEnabledFieldNumber = 5; + private bool curiosityRewardEnabled_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool CuriosityRewardEnabled { + get { return curiosityRewardEnabled_; } + set { + curiosityRewardEnabled_ = value; + } + } + + /// Field number for the "rnd_reward_enabled" field. + public const int RndRewardEnabledFieldNumber = 6; + private bool rndRewardEnabled_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool RndRewardEnabled { + get { return rndRewardEnabled_; } + set { + rndRewardEnabled_ = value; + } + } + + /// Field number for the "behavioral_cloning_enabled" field. + public const int BehavioralCloningEnabledFieldNumber = 7; + private bool behavioralCloningEnabled_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool BehavioralCloningEnabled { + get { return behavioralCloningEnabled_; } + set { + behavioralCloningEnabled_ = value; + } + } + + /// Field number for the "recurrent_enabled" field. + public const int RecurrentEnabledFieldNumber = 8; + private bool recurrentEnabled_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool RecurrentEnabled { + get { return recurrentEnabled_; } + set { + recurrentEnabled_ = value; + } + } + + /// Field number for the "visual_encoder" field. 
+ public const int VisualEncoderFieldNumber = 9; + private string visualEncoder_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string VisualEncoder { + get { return visualEncoder_; } + set { + visualEncoder_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "num_network_layers" field. + public const int NumNetworkLayersFieldNumber = 10; + private int numNetworkLayers_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumNetworkLayers { + get { return numNetworkLayers_; } + set { + numNetworkLayers_ = value; + } + } + + /// Field number for the "num_network_hidden_units" field. + public const int NumNetworkHiddenUnitsFieldNumber = 11; + private int numNetworkHiddenUnits_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumNetworkHiddenUnits { + get { return numNetworkHiddenUnits_; } + set { + numNetworkHiddenUnits_ = value; + } + } + + /// Field number for the "trainer_threaded" field. + public const int TrainerThreadedFieldNumber = 12; + private bool trainerThreaded_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool TrainerThreaded { + get { return trainerThreaded_; } + set { + trainerThreaded_ = value; + } + } + + /// Field number for the "self_play_enabled" field. + public const int SelfPlayEnabledFieldNumber = 13; + private bool selfPlayEnabled_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool SelfPlayEnabled { + get { return selfPlayEnabled_; } + set { + selfPlayEnabled_ = value; + } + } + + /// Field number for the "curriculum_enabled" field. + public const int CurriculumEnabledFieldNumber = 14; + private bool curriculumEnabled_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool CurriculumEnabled { + get { return curriculumEnabled_; } + set { + curriculumEnabled_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as TrainingBehaviorInitialized); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(TrainingBehaviorInitialized other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (BehaviorName != other.BehaviorName) return false; + if (TrainerType != other.TrainerType) return false; + if (ExtrinsicRewardEnabled != other.ExtrinsicRewardEnabled) return false; + if (GailRewardEnabled != other.GailRewardEnabled) return false; + if (CuriosityRewardEnabled != other.CuriosityRewardEnabled) return false; + if (RndRewardEnabled != other.RndRewardEnabled) return false; + if (BehavioralCloningEnabled != other.BehavioralCloningEnabled) return false; + if (RecurrentEnabled != other.RecurrentEnabled) return false; + if (VisualEncoder != other.VisualEncoder) return false; + if (NumNetworkLayers != other.NumNetworkLayers) return false; + if (NumNetworkHiddenUnits != other.NumNetworkHiddenUnits) return false; + if (TrainerThreaded != other.TrainerThreaded) return false; + if (SelfPlayEnabled != other.SelfPlayEnabled) return false; + if (CurriculumEnabled != other.CurriculumEnabled) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (BehaviorName.Length != 0) hash ^= BehaviorName.GetHashCode(); + if (TrainerType.Length != 0) hash ^= TrainerType.GetHashCode(); + if 
(ExtrinsicRewardEnabled != false) hash ^= ExtrinsicRewardEnabled.GetHashCode(); + if (GailRewardEnabled != false) hash ^= GailRewardEnabled.GetHashCode(); + if (CuriosityRewardEnabled != false) hash ^= CuriosityRewardEnabled.GetHashCode(); + if (RndRewardEnabled != false) hash ^= RndRewardEnabled.GetHashCode(); + if (BehavioralCloningEnabled != false) hash ^= BehavioralCloningEnabled.GetHashCode(); + if (RecurrentEnabled != false) hash ^= RecurrentEnabled.GetHashCode(); + if (VisualEncoder.Length != 0) hash ^= VisualEncoder.GetHashCode(); + if (NumNetworkLayers != 0) hash ^= NumNetworkLayers.GetHashCode(); + if (NumNetworkHiddenUnits != 0) hash ^= NumNetworkHiddenUnits.GetHashCode(); + if (TrainerThreaded != false) hash ^= TrainerThreaded.GetHashCode(); + if (SelfPlayEnabled != false) hash ^= SelfPlayEnabled.GetHashCode(); + if (CurriculumEnabled != false) hash ^= CurriculumEnabled.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (BehaviorName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(BehaviorName); + } + if (TrainerType.Length != 0) { + output.WriteRawTag(18); + output.WriteString(TrainerType); + } + if (ExtrinsicRewardEnabled != false) { + output.WriteRawTag(24); + output.WriteBool(ExtrinsicRewardEnabled); + } + if (GailRewardEnabled != false) { + output.WriteRawTag(32); + output.WriteBool(GailRewardEnabled); + } + if (CuriosityRewardEnabled != false) { + output.WriteRawTag(40); + output.WriteBool(CuriosityRewardEnabled); + } + if (RndRewardEnabled != false) { + output.WriteRawTag(48); + output.WriteBool(RndRewardEnabled); + } + if (BehavioralCloningEnabled != false) { + output.WriteRawTag(56); + output.WriteBool(BehavioralCloningEnabled); + } + if (RecurrentEnabled != false) { + output.WriteRawTag(64); + output.WriteBool(RecurrentEnabled); + } + if (VisualEncoder.Length != 0) { + output.WriteRawTag(74); + output.WriteString(VisualEncoder); + } + if (NumNetworkLayers != 0) { + output.WriteRawTag(80); + output.WriteInt32(NumNetworkLayers); + } + if (NumNetworkHiddenUnits != 0) { + output.WriteRawTag(88); + output.WriteInt32(NumNetworkHiddenUnits); + } + if (TrainerThreaded != false) { + output.WriteRawTag(96); + output.WriteBool(TrainerThreaded); + } + if (SelfPlayEnabled != false) { + output.WriteRawTag(104); + output.WriteBool(SelfPlayEnabled); + } + if (CurriculumEnabled != false) { + output.WriteRawTag(112); + output.WriteBool(CurriculumEnabled); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (BehaviorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(BehaviorName); + } + if (TrainerType.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TrainerType); + } + if (ExtrinsicRewardEnabled != false) { + size += 1 + 1; + } + if (GailRewardEnabled != false) { + size += 1 + 1; + } + if (CuriosityRewardEnabled != false) { + size += 1 + 1; + } + if (RndRewardEnabled != false) { + size += 1 + 1; + } + if (BehavioralCloningEnabled != false) { + size += 1 + 1; + } + if (RecurrentEnabled != false) { + size += 1 + 1; + } + if (VisualEncoder.Length != 0) { + size += 
1 + pb::CodedOutputStream.ComputeStringSize(VisualEncoder); + } + if (NumNetworkLayers != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumNetworkLayers); + } + if (NumNetworkHiddenUnits != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumNetworkHiddenUnits); + } + if (TrainerThreaded != false) { + size += 1 + 1; + } + if (SelfPlayEnabled != false) { + size += 1 + 1; + } + if (CurriculumEnabled != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(TrainingBehaviorInitialized other) { + if (other == null) { + return; + } + if (other.BehaviorName.Length != 0) { + BehaviorName = other.BehaviorName; + } + if (other.TrainerType.Length != 0) { + TrainerType = other.TrainerType; + } + if (other.ExtrinsicRewardEnabled != false) { + ExtrinsicRewardEnabled = other.ExtrinsicRewardEnabled; + } + if (other.GailRewardEnabled != false) { + GailRewardEnabled = other.GailRewardEnabled; + } + if (other.CuriosityRewardEnabled != false) { + CuriosityRewardEnabled = other.CuriosityRewardEnabled; + } + if (other.RndRewardEnabled != false) { + RndRewardEnabled = other.RndRewardEnabled; + } + if (other.BehavioralCloningEnabled != false) { + BehavioralCloningEnabled = other.BehavioralCloningEnabled; + } + if (other.RecurrentEnabled != false) { + RecurrentEnabled = other.RecurrentEnabled; + } + if (other.VisualEncoder.Length != 0) { + VisualEncoder = other.VisualEncoder; + } + if (other.NumNetworkLayers != 0) { + NumNetworkLayers = other.NumNetworkLayers; + } + if (other.NumNetworkHiddenUnits != 0) { + NumNetworkHiddenUnits = other.NumNetworkHiddenUnits; + } + if (other.TrainerThreaded != false) { + TrainerThreaded = other.TrainerThreaded; + } + if (other.SelfPlayEnabled != false) { + SelfPlayEnabled = other.SelfPlayEnabled; + } + if (other.CurriculumEnabled != false) { + CurriculumEnabled = other.CurriculumEnabled; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + BehaviorName = input.ReadString(); + break; + } + case 18: { + TrainerType = input.ReadString(); + break; + } + case 24: { + ExtrinsicRewardEnabled = input.ReadBool(); + break; + } + case 32: { + GailRewardEnabled = input.ReadBool(); + break; + } + case 40: { + CuriosityRewardEnabled = input.ReadBool(); + break; + } + case 48: { + RndRewardEnabled = input.ReadBool(); + break; + } + case 56: { + BehavioralCloningEnabled = input.ReadBool(); + break; + } + case 64: { + RecurrentEnabled = input.ReadBool(); + break; + } + case 74: { + VisualEncoder = input.ReadString(); + break; + } + case 80: { + NumNetworkLayers = input.ReadInt32(); + break; + } + case 88: { + NumNetworkHiddenUnits = input.ReadInt32(); + break; + } + case 96: { + TrainerThreaded = input.ReadBool(); + break; + } + case 104: { + SelfPlayEnabled = input.ReadBool(); + break; + } + case 112: { + CurriculumEnabled = input.ReadBool(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs.meta 
b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs.meta
new file mode 100644
index 0000000000..8e9d358feb
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 9e6ac06a3931742d798cf922de6b99f0
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs b/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
index 176eac44d7..fbeff3a15e 100644
--- a/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
+++ b/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
@@ -1,7 +1,9 @@
 using System.Collections.Generic;
 using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Analytics;
 using Unity.MLAgents.Sensors;
+
 namespace Unity.MLAgents.Policies
 {
     ///
@@ -14,6 +16,7 @@ internal class RemotePolicy : IPolicy
         string m_FullyQualifiedBehaviorName;
         ActionSpec m_ActionSpec;
         ActionBuffers m_LastActionBuffer;
+        private bool m_AnalyticsSent = false;
 
         internal ICommunicator m_Communicator;
 
@@ -24,13 +27,22 @@ public RemotePolicy(
         {
             m_FullyQualifiedBehaviorName = fullyQualifiedBehaviorName;
             m_Communicator = Academy.Instance.Communicator;
-            m_Communicator.SubscribeBrain(m_FullyQualifiedBehaviorName, actionSpec);
+            m_Communicator?.SubscribeBrain(m_FullyQualifiedBehaviorName, actionSpec);
             m_ActionSpec = actionSpec;
         }
 
         ///
         public void RequestDecision(AgentInfo info, List<ISensor> sensors)
         {
+            if (!m_AnalyticsSent)
+            {
+                m_AnalyticsSent = true;
+                TrainingAnalytics.RemotePolicyInitialized(
+                    m_FullyQualifiedBehaviorName,
+                    sensors,
+                    m_ActionSpec
+                );
+            }
             m_AgentId = info.episodeId;
             m_Communicator?.PutObservations(m_FullyQualifiedBehaviorName, info, sensors);
         }
diff --git a/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs b/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs
new file mode 100644
index 0000000000..6005954f00
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs
@@ -0,0 +1,50 @@
+using System;
+using UnityEngine;
+using Unity.MLAgents.Analytics;
+using Unity.MLAgents.CommunicatorObjects;
+
+namespace Unity.MLAgents.SideChannels
+{
+    public class TrainingAnalyticsSideChannel : SideChannel
+    {
+        const string k_TrainingAnalyticsConfigId = "b664a4a9-d86f-5a5f-95cb-e8353a7e8356";
+
+        ///
+        /// Initializes the side channel. The constructor is internal because only one instance is
+        /// supported at a time, and is created by the Academy.
+        ///
+        internal TrainingAnalyticsSideChannel()
+        {
+            ChannelId = new Guid(k_TrainingAnalyticsConfigId);
+        }
+
+        ///
+        protected override void OnMessageReceived(IncomingMessage msg)
+        {
+            Google.Protobuf.WellKnownTypes.Any anyMessage = null;
+            try
+            {
+                anyMessage = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(msg.GetRawBytes());
+            }
+            catch (Google.Protobuf.InvalidProtocolBufferException)
+            {
+                // Bad message, nothing we can do about it, so just ignore.
+                return;
+            }
+
+            if (anyMessage.Is(TrainingEnvironmentInitialized.Descriptor))
+            {
+                var envInitProto = anyMessage.Unpack<TrainingEnvironmentInitialized>();
+                var envInitEvent = envInitProto.ToTrainingEnvironmentInitializedEvent();
+                TrainingAnalytics.TrainingEnvironmentInitialized(envInitEvent);
+            }
+            else if (anyMessage.Is(TrainingBehaviorInitialized.Descriptor))
+            {
+                var behaviorInitProto = anyMessage.Unpack<TrainingBehaviorInitialized>();
+                var behaviorTrainingEvent = behaviorInitProto.ToTrainingBehaviorInitializedEvent();
+                TrainingAnalytics.TrainingBehaviorInitialized(behaviorTrainingEvent);
+            }
+            // Don't do anything for unknown types, since the user probably can't do anything about it.
+        }
+    }
+}
diff --git a/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs.meta b/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs.meta
new file mode 100644
index 0000000000..757d0d0d4f
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/SideChannels/TrainingAnalyticsSideChannel.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 13c87198bbd54b40a0b93308eb37933e
+timeCreated: 1608337471
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs b/com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs
index 0705796965..489c64f2f0 100644
--- a/com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs
+++ b/com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs
@@ -26,6 +26,11 @@ ActionSpec GetContinuous2vis8vec2actionActionSpec()
         [SetUp]
         public void SetUp()
         {
+            if (Academy.IsInitialized)
+            {
+                Academy.Instance.Dispose();
+            }
+
             continuousONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousONNXPath, typeof(NNModel));
             var go = new GameObject("SensorA");
             sensor_21_20_3 = go.AddComponent();
@@ -65,5 +70,18 @@ public void TestModelEvent()
             Assert.IsTrue(jsonString.Contains("SensorName"));
             Assert.IsTrue(jsonString.Contains("Flags"));
         }
+
+        [Test]
+        public void TestBarracudaPolicy()
+        {
+            // Explicitly request decisions for a policy so we get code coverage on the event sending
+            using (new AnalyticsUtils.DisableAnalyticsSending())
+            {
+                var sensors = new List<ISensor> { sensor_21_20_3.Sensor, sensor_20_22_3.Sensor };
+                var policy = new BarracudaPolicy(GetContinuous2vis8vec2actionActionSpec(), continuousONNXModel, InferenceDevice.CPU, "testBehavior");
+                policy.RequestDecision(new AgentInfo(), sensors);
+            }
+            Academy.Instance.Dispose();
+        }
     }
 }
diff --git a/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs b/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs
new file mode 100644
index 0000000000..6a5b30e72a
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs
@@ -0,0 +1,42 @@
+using System.Collections.Generic;
+using NUnit.Framework;
+using Unity.MLAgents.Sensors;
+using UnityEngine;
+using Unity.Barracuda;
+using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Analytics;
+using Unity.MLAgents.Policies;
+using UnityEditor;
+
+namespace Unity.MLAgents.Tests.Analytics
+{
+    [TestFixture]
+    public class TrainingAnalyticsTests
+    {
+        [TestCase("foo?team=42", ExpectedResult = "foo")]
+        [TestCase("foo", ExpectedResult = "foo")]
+        [TestCase("foo?bar?team=1337", ExpectedResult = "foo?bar")]
+        public string TestParseBehaviorName(string fullyQualifiedBehaviorName)
+        {
+            return TrainingAnalytics.ParseBehaviorName(fullyQualifiedBehaviorName);
+        }
+
+        [Test]
+        public void TestRemotePolicy()
+        {
+            if (Academy.IsInitialized)
+            {
+                Academy.Instance.Dispose();
+            }
+
+            using (new AnalyticsUtils.DisableAnalyticsSending())
AnalyticsUtils.DisableAnalyticsSending()) + { + var actionSpec = ActionSpec.MakeContinuous(3); + var policy = new RemotePolicy(actionSpec, "TestBehavior?team=42"); + policy.RequestDecision(new AgentInfo(), new List<ISensor>()); + } + + Academy.Instance.Dispose(); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs.meta b/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs.meta new file mode 100644 index 0000000000..df394c157a --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 70b8f1544bc34b4e8f1bc1068c64f01c +timeCreated: 1610419546 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs index fed5e97039..2529ecd96f 100644 --- a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs @@ -1,7 +1,9 @@ using NUnit.Framework; -using Unity.MLAgents.Policies; -using Unity.MLAgents.Demonstrations; using Unity.MLAgents.Actuators; +using Unity.MLAgents.Analytics; +using Unity.MLAgents.CommunicatorObjects; +using Unity.MLAgents.Demonstrations; +using Unity.MLAgents.Policies; using Unity.MLAgents.Sensors; namespace Unity.MLAgents.Tests { @@ -169,5 +171,31 @@ public void TestIsTrivialMapping() sparseChannelSensor.Mapping = new[] { 0, 0, 0, 1, 1, 1 }; Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), false); } + + [Test] + public void TestDefaultTrainingEvents() + { + var trainingEnvInit = new TrainingEnvironmentInitialized + { + PythonVersion = "test", + }; + var trainingEnvInitEvent = trainingEnvInit.ToTrainingEnvironmentInitializedEvent(); + Assert.AreEqual(trainingEnvInit.PythonVersion, trainingEnvInitEvent.TrainerPythonVersion); + + var trainingBehavInit = new TrainingBehaviorInitialized + { + BehaviorName = "testBehavior", + ExtrinsicRewardEnabled = true, + CuriosityRewardEnabled = true, + + RecurrentEnabled = true, + SelfPlayEnabled = true, + }; + var trainingBehavInitEvent = trainingBehavInit.ToTrainingBehaviorInitializedEvent(); + Assert.AreEqual(trainingBehavInit.BehaviorName, trainingBehavInitEvent.BehaviorName); + + Assert.AreEqual(RewardSignals.Extrinsic | RewardSignals.Curiosity, trainingBehavInitEvent.RewardSignalFlags); + Assert.AreEqual(TrainingFeatures.Recurrent | TrainingFeatures.SelfPlay, trainingBehavInitEvent.TrainingFeatureFlags); + } } } diff --git a/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs b/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs new file mode 100644 index 0000000000..d71abf6555 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs @@ -0,0 +1,65 @@ +using System; +using System.Linq; +using System.Text; +using NUnit.Framework; +using Google.Protobuf; +using Unity.MLAgents.Analytics; +using Unity.MLAgents.SideChannels; +using Unity.MLAgents.CommunicatorObjects; + + +namespace Unity.MLAgents.Tests +{ + /// + /// These tests send messages through the event handling code. + /// There's no output to test, so just make sure there are no exceptions + /// (and get the code coverage above the minimum).
+ /// + public class TrainingAnalyticsSideChannelTests + { + [Test] + public void TestTrainingEnvironmentReceived() + { + var anyMsg = Google.Protobuf.WellKnownTypes.Any.Pack(new TrainingEnvironmentInitialized()); + var anyMsgBytes = anyMsg.ToByteArray(); + var sideChannel = new TrainingAnalyticsSideChannel(); + using (new AnalyticsUtils.DisableAnalyticsSending()) + { + sideChannel.ProcessMessage(anyMsgBytes); + } + } + + [Test] + public void TestTrainingBehaviorReceived() + { + var anyMsg = Google.Protobuf.WellKnownTypes.Any.Pack(new TrainingBehaviorInitialized()); + var anyMsgBytes = anyMsg.ToByteArray(); + var sideChannel = new TrainingAnalyticsSideChannel(); + using (new AnalyticsUtils.DisableAnalyticsSending()) + { + sideChannel.ProcessMessage(anyMsgBytes); + } + } + + [Test] + public void TestInvalidProtobufMessage() + { + // Test an invalid (non-protobuf) message. This should silently ignore the data. + var badBytes = Encoding.ASCII.GetBytes("Lorem ipsum"); + var sideChannel = new TrainingAnalyticsSideChannel(); + using (new AnalyticsUtils.DisableAnalyticsSending()) + { + sideChannel.ProcessMessage(badBytes); + } + + // Test an almost-valid message. This should silently ignore the data. + var anyMsg = Google.Protobuf.WellKnownTypes.Any.Pack(new TrainingBehaviorInitialized()); + var anyMsgBytes = anyMsg.ToByteArray(); + var truncatedMessage = new ArraySegment<byte>(anyMsgBytes, 0, anyMsgBytes.Length - 1).ToArray(); + using (new AnalyticsUtils.DisableAnalyticsSending()) + { + sideChannel.ProcessMessage(truncatedMessage); + } + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs.meta b/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs.meta new file mode 100644 index 0000000000..ebb5915235 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TrainingAnalyticsSideChannelTests.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: c2a71036ddec4ba4bf83c5e8ba1b8daa +timeCreated: 1610574895 \ No newline at end of file diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py index bd87dc3b23..054fec848a 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py +++ b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py @@ -19,7 +19,7 @@ name='mlagents_envs/communicator_objects/capabilities.proto', package='communicator_objects', syntax='proto3', - serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"\x94\x01\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x12\x15\n\rhybridActions\x18\x04 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') + serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"\xaf\x01\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x12\x15\n\rhybridActions\x18\x04 \x01(\x08\x12\x19\n\x11trainingAnalytics\x18\x05 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') ) @@ -60,6 +60,13 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor(
+ name='trainingAnalytics', full_name='communicator_objects.UnityRLCapabilitiesProto.trainingAnalytics', index=4, + number=5, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -73,7 +80,7 @@ oneofs=[ ], serialized_start=80, - serialized_end=228, + serialized_end=255, ) DESCRIPTOR.message_types_by_name['UnityRLCapabilitiesProto'] = _UNITYRLCAPABILITIESPROTO diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi index f1799fe6e5..69cf46fd0e 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi +++ b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi @@ -29,6 +29,7 @@ class UnityRLCapabilitiesProto(google___protobuf___message___Message): concatenatedPngObservations = ... # type: builtin___bool compressedChannelMapping = ... # type: builtin___bool hybridActions = ... # type: builtin___bool + trainingAnalytics = ... # type: builtin___bool def __init__(self, *, @@ -36,12 +37,13 @@ class UnityRLCapabilitiesProto(google___protobuf___message___Message): concatenatedPngObservations : typing___Optional[builtin___bool] = None, compressedChannelMapping : typing___Optional[builtin___bool] = None, hybridActions : typing___Optional[builtin___bool] = None, + trainingAnalytics : typing___Optional[builtin___bool] = None, ) -> None: ... @classmethod def FromString(cls, s: builtin___bytes) -> UnityRLCapabilitiesProto: ... def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... if sys.version_info >= (3,): - def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions",u"trainingAnalytics"]) -> None: ... else: - def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions",u"trainingAnalytics",b"trainingAnalytics"]) -> None: ... diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/training_analytics_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/training_analytics_pb2.py new file mode 100644 index 0000000000..1e775c9710 --- /dev/null +++ b/ml-agents-envs/mlagents_envs/communicator_objects/training_analytics_pb2.py @@ -0,0 +1,243 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: mlagents_envs/communicator_objects/training_analytics.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='mlagents_envs/communicator_objects/training_analytics.proto', + package='communicator_objects', + syntax='proto3', + serialized_pb=_b('\n;mlagents_envs/communicator_objects/training_analytics.proto\x12\x14\x63ommunicator_objects\"\xd9\x01\n\x1eTrainingEnvironmentInitialized\x12\x18\n\x10mlagents_version\x18\x01 \x01(\t\x12\x1d\n\x15mlagents_envs_version\x18\x02 \x01(\t\x12\x16\n\x0epython_version\x18\x03 \x01(\t\x12\x15\n\rtorch_version\x18\x04 \x01(\t\x12\x19\n\x11torch_device_type\x18\x05 \x01(\t\x12\x10\n\x08num_envs\x18\x06 \x01(\x05\x12\"\n\x1anum_environment_parameters\x18\x07 \x01(\x05\"\xad\x03\n\x1bTrainingBehaviorInitialized\x12\x15\n\rbehavior_name\x18\x01 \x01(\t\x12\x14\n\x0ctrainer_type\x18\x02 \x01(\t\x12 \n\x18\x65xtrinsic_reward_enabled\x18\x03 \x01(\x08\x12\x1b\n\x13gail_reward_enabled\x18\x04 \x01(\x08\x12 \n\x18\x63uriosity_reward_enabled\x18\x05 \x01(\x08\x12\x1a\n\x12rnd_reward_enabled\x18\x06 \x01(\x08\x12\"\n\x1a\x62\x65havioral_cloning_enabled\x18\x07 \x01(\x08\x12\x19\n\x11recurrent_enabled\x18\x08 \x01(\x08\x12\x16\n\x0evisual_encoder\x18\t \x01(\t\x12\x1a\n\x12num_network_layers\x18\n \x01(\x05\x12 \n\x18num_network_hidden_units\x18\x0b \x01(\x05\x12\x18\n\x10trainer_threaded\x18\x0c \x01(\x08\x12\x19\n\x11self_play_enabled\x18\r \x01(\x08\x12\x1a\n\x12\x63urriculum_enabled\x18\x0e \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') +) + + + + +_TRAININGENVIRONMENTINITIALIZED = _descriptor.Descriptor( + name='TrainingEnvironmentInitialized', + full_name='communicator_objects.TrainingEnvironmentInitialized', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='mlagents_version', full_name='communicator_objects.TrainingEnvironmentInitialized.mlagents_version', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='mlagents_envs_version', full_name='communicator_objects.TrainingEnvironmentInitialized.mlagents_envs_version', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='python_version', full_name='communicator_objects.TrainingEnvironmentInitialized.python_version', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='torch_version', full_name='communicator_objects.TrainingEnvironmentInitialized.torch_version', index=3, + number=4, 
type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='torch_device_type', full_name='communicator_objects.TrainingEnvironmentInitialized.torch_device_type', index=4, + number=5, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='num_envs', full_name='communicator_objects.TrainingEnvironmentInitialized.num_envs', index=5, + number=6, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='num_environment_parameters', full_name='communicator_objects.TrainingEnvironmentInitialized.num_environment_parameters', index=6, + number=7, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=86, + serialized_end=303, +) + + +_TRAININGBEHAVIORINITIALIZED = _descriptor.Descriptor( + name='TrainingBehaviorInitialized', + full_name='communicator_objects.TrainingBehaviorInitialized', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='behavior_name', full_name='communicator_objects.TrainingBehaviorInitialized.behavior_name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='trainer_type', full_name='communicator_objects.TrainingBehaviorInitialized.trainer_type', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='extrinsic_reward_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.extrinsic_reward_enabled', index=2, + number=3, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='gail_reward_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.gail_reward_enabled', index=3, + number=4, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='curiosity_reward_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.curiosity_reward_enabled', index=4, + number=5, type=8, cpp_type=7, label=1, + 
has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='rnd_reward_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.rnd_reward_enabled', index=5, + number=6, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='behavioral_cloning_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.behavioral_cloning_enabled', index=6, + number=7, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='recurrent_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.recurrent_enabled', index=7, + number=8, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='visual_encoder', full_name='communicator_objects.TrainingBehaviorInitialized.visual_encoder', index=8, + number=9, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='num_network_layers', full_name='communicator_objects.TrainingBehaviorInitialized.num_network_layers', index=9, + number=10, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='num_network_hidden_units', full_name='communicator_objects.TrainingBehaviorInitialized.num_network_hidden_units', index=10, + number=11, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='trainer_threaded', full_name='communicator_objects.TrainingBehaviorInitialized.trainer_threaded', index=11, + number=12, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='self_play_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.self_play_enabled', index=12, + number=13, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='curriculum_enabled', full_name='communicator_objects.TrainingBehaviorInitialized.curriculum_enabled', index=13, + number=14, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, 
extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=306, + serialized_end=735, +) + +DESCRIPTOR.message_types_by_name['TrainingEnvironmentInitialized'] = _TRAININGENVIRONMENTINITIALIZED +DESCRIPTOR.message_types_by_name['TrainingBehaviorInitialized'] = _TRAININGBEHAVIORINITIALIZED +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +TrainingEnvironmentInitialized = _reflection.GeneratedProtocolMessageType('TrainingEnvironmentInitialized', (_message.Message,), dict( + DESCRIPTOR = _TRAININGENVIRONMENTINITIALIZED, + __module__ = 'mlagents_envs.communicator_objects.training_analytics_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.TrainingEnvironmentInitialized) + )) +_sym_db.RegisterMessage(TrainingEnvironmentInitialized) + +TrainingBehaviorInitialized = _reflection.GeneratedProtocolMessageType('TrainingBehaviorInitialized', (_message.Message,), dict( + DESCRIPTOR = _TRAININGBEHAVIORINITIALIZED, + __module__ = 'mlagents_envs.communicator_objects.training_analytics_pb2' + # @@protoc_insertion_point(class_scope:communicator_objects.TrainingBehaviorInitialized) + )) +_sym_db.RegisterMessage(TrainingBehaviorInitialized) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\"Unity.MLAgents.CommunicatorObjects')) +# @@protoc_insertion_point(module_scope) diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/training_analytics_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/training_analytics_pb2.pyi new file mode 100644 index 0000000000..a347de6874 --- /dev/null +++ b/ml-agents-envs/mlagents_envs/communicator_objects/training_analytics_pb2.pyi @@ -0,0 +1,97 @@ +# @generated by generate_proto_mypy_stubs.py. Do not edit! +import sys +from google.protobuf.descriptor import ( + Descriptor as google___protobuf___descriptor___Descriptor, +) + +from google.protobuf.message import ( + Message as google___protobuf___message___Message, +) + +from typing import ( + Optional as typing___Optional, + Text as typing___Text, +) + +from typing_extensions import ( + Literal as typing_extensions___Literal, +) + + +builtin___bool = bool +builtin___bytes = bytes +builtin___float = float +builtin___int = int + + +class TrainingEnvironmentInitialized(google___protobuf___message___Message): + DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... + mlagents_version = ... # type: typing___Text + mlagents_envs_version = ... # type: typing___Text + python_version = ... # type: typing___Text + torch_version = ... # type: typing___Text + torch_device_type = ... # type: typing___Text + num_envs = ... # type: builtin___int + num_environment_parameters = ... # type: builtin___int + + def __init__(self, + *, + mlagents_version : typing___Optional[typing___Text] = None, + mlagents_envs_version : typing___Optional[typing___Text] = None, + python_version : typing___Optional[typing___Text] = None, + torch_version : typing___Optional[typing___Text] = None, + torch_device_type : typing___Optional[typing___Text] = None, + num_envs : typing___Optional[builtin___int] = None, + num_environment_parameters : typing___Optional[builtin___int] = None, + ) -> None: ... + @classmethod + def FromString(cls, s: builtin___bytes) -> TrainingEnvironmentInitialized: ... + def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... 
+ def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + if sys.version_info >= (3,): + def ClearField(self, field_name: typing_extensions___Literal[u"mlagents_envs_version",u"mlagents_version",u"num_environment_parameters",u"num_envs",u"python_version",u"torch_device_type",u"torch_version"]) -> None: ... + else: + def ClearField(self, field_name: typing_extensions___Literal[u"mlagents_envs_version",b"mlagents_envs_version",u"mlagents_version",b"mlagents_version",u"num_environment_parameters",b"num_environment_parameters",u"num_envs",b"num_envs",u"python_version",b"python_version",u"torch_device_type",b"torch_device_type",u"torch_version",b"torch_version"]) -> None: ... + +class TrainingBehaviorInitialized(google___protobuf___message___Message): + DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... + behavior_name = ... # type: typing___Text + trainer_type = ... # type: typing___Text + extrinsic_reward_enabled = ... # type: builtin___bool + gail_reward_enabled = ... # type: builtin___bool + curiosity_reward_enabled = ... # type: builtin___bool + rnd_reward_enabled = ... # type: builtin___bool + behavioral_cloning_enabled = ... # type: builtin___bool + recurrent_enabled = ... # type: builtin___bool + visual_encoder = ... # type: typing___Text + num_network_layers = ... # type: builtin___int + num_network_hidden_units = ... # type: builtin___int + trainer_threaded = ... # type: builtin___bool + self_play_enabled = ... # type: builtin___bool + curriculum_enabled = ... # type: builtin___bool + + def __init__(self, + *, + behavior_name : typing___Optional[typing___Text] = None, + trainer_type : typing___Optional[typing___Text] = None, + extrinsic_reward_enabled : typing___Optional[builtin___bool] = None, + gail_reward_enabled : typing___Optional[builtin___bool] = None, + curiosity_reward_enabled : typing___Optional[builtin___bool] = None, + rnd_reward_enabled : typing___Optional[builtin___bool] = None, + behavioral_cloning_enabled : typing___Optional[builtin___bool] = None, + recurrent_enabled : typing___Optional[builtin___bool] = None, + visual_encoder : typing___Optional[typing___Text] = None, + num_network_layers : typing___Optional[builtin___int] = None, + num_network_hidden_units : typing___Optional[builtin___int] = None, + trainer_threaded : typing___Optional[builtin___bool] = None, + self_play_enabled : typing___Optional[builtin___bool] = None, + curriculum_enabled : typing___Optional[builtin___bool] = None, + ) -> None: ... + @classmethod + def FromString(cls, s: builtin___bytes) -> TrainingBehaviorInitialized: ... + def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + if sys.version_info >= (3,): + def ClearField(self, field_name: typing_extensions___Literal[u"behavior_name",u"behavioral_cloning_enabled",u"curiosity_reward_enabled",u"curriculum_enabled",u"extrinsic_reward_enabled",u"gail_reward_enabled",u"num_network_hidden_units",u"num_network_layers",u"recurrent_enabled",u"rnd_reward_enabled",u"self_play_enabled",u"trainer_threaded",u"trainer_type",u"visual_encoder"]) -> None: ... 
+ else: + def ClearField(self, field_name: typing_extensions___Literal[u"behavior_name",b"behavior_name",u"behavioral_cloning_enabled",b"behavioral_cloning_enabled",u"curiosity_reward_enabled",b"curiosity_reward_enabled",u"curriculum_enabled",b"curriculum_enabled",u"extrinsic_reward_enabled",b"extrinsic_reward_enabled",u"gail_reward_enabled",b"gail_reward_enabled",u"num_network_hidden_units",b"num_network_hidden_units",u"num_network_layers",b"num_network_layers",u"recurrent_enabled",b"recurrent_enabled",u"rnd_reward_enabled",b"rnd_reward_enabled",u"self_play_enabled",b"self_play_enabled",u"trainer_threaded",b"trainer_threaded",u"trainer_type",b"trainer_type",u"visual_encoder",b"visual_encoder"]) -> None: ... diff --git a/ml-agents-envs/mlagents_envs/environment.py b/ml-agents-envs/mlagents_envs/environment.py index 14b4dddb53..5d844749c4 100644 --- a/ml-agents-envs/mlagents_envs/environment.py +++ b/ml-agents-envs/mlagents_envs/environment.py @@ -62,7 +62,8 @@ class UnityEnvironment(BaseEnv): # * 1.1.0 - support concatenated PNGs for compressed observations. # * 1.2.0 - support compression mapping for stacked compressed observations. # * 1.3.0 - support action spaces with both continuous and discrete actions. - API_VERSION = "1.3.0" + # * 1.4.0 - support training analytics sent from python trainer to the editor. + API_VERSION = "1.4.0" # Default port that the editor listens on. If an environment executable # isn't specified, this port will be used. @@ -120,6 +121,7 @@ def _get_capabilities_proto() -> UnityRLCapabilitiesProto: capabilities.concatenatedPngObservations = True capabilities.compressedChannelMapping = True capabilities.hybridActions = True + capabilities.trainingAnalytics = True return capabilities @staticmethod @@ -183,6 +185,7 @@ def __init__( self._worker_id = worker_id self._side_channel_manager = SideChannelManager(side_channels) self._log_folder = log_folder + self.academy_capabilities: UnityRLCapabilitiesProto = None # type: ignore # If the environment name is None, a new environment will not be launched # and the communicator will directly try to connect to an existing unity environment. @@ -239,6 +242,7 @@ def __init__( self._env_actions: Dict[str, ActionTuple] = {} self._is_first_message = True self._update_behavior_specs(aca_output) + self.academy_capabilities = aca_params.capabilities @staticmethod def _get_communicator(worker_id, base_port, timeout_wait): diff --git a/ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py b/ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py index 4e4c15d795..ce1715ba07 100644 --- a/ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py +++ b/ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py @@ -53,7 +53,7 @@ def on_message_received(self, msg: IncomingMessage) -> None: """ raise UnityCommunicationException( "The EngineConfigurationChannel received a message from Unity, " - + "this should not have happend." + + "this should not have happened." 
) def set_configuration_parameters( diff --git a/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py b/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py index 9629f1b5d5..ff516a5eb5 100644 --- a/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py +++ b/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py @@ -28,7 +28,7 @@ def __init__(self) -> None: def on_message_received(self, msg: IncomingMessage) -> None: raise UnityCommunicationException( "The EnvironmentParametersChannel received a message from Unity, " - + "this should not have happend." + + "this should not have happened." ) def set_float_parameter(self, key: str, value: float) -> None: diff --git a/ml-agents/mlagents/trainers/env_manager.py b/ml-agents/mlagents/trainers/env_manager.py index 905badd9f4..a4a90fdc70 100644 --- a/ml-agents/mlagents/trainers/env_manager.py +++ b/ml-agents/mlagents/trainers/env_manager.py @@ -12,6 +12,7 @@ from mlagents.trainers.policy import Policy from mlagents.trainers.agent_processor import AgentManager, AgentManagerQueue from mlagents.trainers.action_info import ActionInfo +from mlagents.trainers.settings import TrainerSettings from mlagents_envs.logging_util import get_logger AllStepResult = Dict[BehaviorName, Tuple[DecisionSteps, TerminalSteps]] @@ -76,6 +77,17 @@ def set_env_parameters(self, config: Dict = None) -> None: """ pass + def on_training_started( + self, behavior_name: str, trainer_settings: TrainerSettings + ) -> None: + """ + Handle training starting for a new behavior type. Generally nothing is necessary here. + :param behavior_name: + :param trainer_settings: + :return: + """ + pass + @property @abstractmethod def training_behaviors(self) -> Dict[BehaviorName, BehaviorSpec]: diff --git a/ml-agents/mlagents/trainers/learn.py b/ml-agents/mlagents/trainers/learn.py index 82b65f59b4..c5bf8daac0 100644 --- a/ml-agents/mlagents/trainers/learn.py +++ b/ml-agents/mlagents/trainers/learn.py @@ -28,7 +28,6 @@ from mlagents_envs.base_env import BaseEnv from mlagents.trainers.subprocess_env_manager import SubprocessEnvManager from mlagents_envs.side_channel.side_channel import SideChannel -from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig from mlagents_envs.timers import ( hierarchical_timer, get_timer_tree, @@ -109,17 +108,8 @@ def run_training(run_seed: int, options: RunOptions) -> None: env_settings.env_args, os.path.abspath(run_logs_dir), # Unity environment requires absolute path ) - engine_config = EngineConfig( - width=engine_settings.width, - height=engine_settings.height, - quality_level=engine_settings.quality_level, - time_scale=engine_settings.time_scale, - target_frame_rate=engine_settings.target_frame_rate, - capture_frame_rate=engine_settings.capture_frame_rate, - ) - env_manager = SubprocessEnvManager( - env_factory, engine_config, env_settings.num_envs - ) + + env_manager = SubprocessEnvManager(env_factory, options, env_settings.num_envs) env_parameter_manager = EnvironmentParameterManager( options.environment_parameters, run_seed, restore=checkpoint_settings.resume ) diff --git a/ml-agents/mlagents/trainers/subprocess_env_manager.py b/ml-agents/mlagents/trainers/subprocess_env_manager.py index 1eb5a4aa9f..4bad60c8b4 100644 --- a/ml-agents/mlagents/trainers/subprocess_env_manager.py +++ b/ml-agents/mlagents/trainers/subprocess_env_manager.py @@ -16,6 +16,7 @@ from mlagents_envs.base_env import BaseEnv, BehaviorName, BehaviorSpec from mlagents_envs
import logging_util from mlagents.trainers.env_manager import EnvManager, EnvironmentStep, AllStepResult +from mlagents.trainers.settings import TrainerSettings from mlagents_envs.timers import ( TimerNode, timed, @@ -23,7 +24,7 @@ reset_timers, get_timer_root, ) -from mlagents.trainers.settings import ParameterRandomizationSettings +from mlagents.trainers.settings import ParameterRandomizationSettings, RunOptions from mlagents.trainers.action_info import ActionInfo from mlagents_envs.side_channel.environment_parameters_channel import ( EnvironmentParametersChannel, @@ -33,9 +34,10 @@ EngineConfig, ) from mlagents_envs.side_channel.stats_side_channel import ( - StatsSideChannel, EnvironmentStats, + StatsSideChannel, ) +from mlagents.training_analytics_side_channel import TrainingAnalyticsSideChannel from mlagents_envs.side_channel.side_channel import SideChannel @@ -51,6 +53,7 @@ class EnvironmentCommand(enum.Enum): CLOSE = 5 ENV_EXITED = 6 CLOSED = 7 + TRAINING_STARTED = 8 class EnvironmentRequest(NamedTuple): @@ -112,17 +115,30 @@ def worker( step_queue: Queue, pickled_env_factory: str, worker_id: int, - engine_configuration: EngineConfig, + run_options: RunOptions, log_level: int = logging_util.INFO, ) -> None: env_factory: Callable[ [int, List[SideChannel]], UnityEnvironment ] = cloudpickle.loads(pickled_env_factory) env_parameters = EnvironmentParametersChannel() + + engine_config = EngineConfig( + width=run_options.engine_settings.width, + height=run_options.engine_settings.height, + quality_level=run_options.engine_settings.quality_level, + time_scale=run_options.engine_settings.time_scale, + target_frame_rate=run_options.engine_settings.target_frame_rate, + capture_frame_rate=run_options.engine_settings.capture_frame_rate, + ) engine_configuration_channel = EngineConfigurationChannel() - engine_configuration_channel.set_configuration(engine_configuration) + engine_configuration_channel.set_configuration(engine_config) + stats_channel = StatsSideChannel() - env: BaseEnv = None + training_analytics_channel: Optional[TrainingAnalyticsSideChannel] = None + if worker_id == 0: + training_analytics_channel = TrainingAnalyticsSideChannel() + env: UnityEnvironment = None # Set log level. On some platforms, the logger isn't common with the # main process, so we need to set it again. logging_util.set_log_level(log_level) @@ -137,9 +153,21 @@ def _generate_all_results() -> AllStepResult: return all_step_result try: - env = env_factory( - worker_id, [env_parameters, engine_configuration_channel, stats_channel] - ) + side_channels = [env_parameters, engine_configuration_channel, stats_channel] + if training_analytics_channel is not None: + side_channels.append(training_analytics_channel) + + env = env_factory(worker_id, side_channels) + if ( + not env.academy_capabilities + or not env.academy_capabilities.trainingAnalytics + ): + # Make sure we don't try to send training analytics if the environment doesn't know how to process + # them. This wouldn't be catastrophic, but would result in unknown SideChannel UUIDs being used. 
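+ # Also note the channel was only created above for worker_id == 0, so at most one environment per run sends training analytics.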
+ training_analytics_channel = None + if training_analytics_channel: + training_analytics_channel.environment_initialized(run_options) + while True: req: EnvironmentRequest = parent_conn.recv() if req.cmd == EnvironmentCommand.STEP: @@ -170,6 +198,12 @@ def _generate_all_results() -> AllStepResult: for k, v in req.payload.items(): if isinstance(v, ParameterRandomizationSettings): v.apply(k, env_parameters) + elif req.cmd == EnvironmentCommand.TRAINING_STARTED: + behavior_name, trainer_config = req.payload + if training_analytics_channel: + training_analytics_channel.training_started( + behavior_name, trainer_config + ) elif req.cmd == EnvironmentCommand.RESET: env.reset() all_step_result = _generate_all_results() @@ -210,7 +244,7 @@ class SubprocessEnvManager(EnvManager): def __init__( self, env_factory: Callable[[int, List[SideChannel]], BaseEnv], - engine_configuration: EngineConfig, + run_options: RunOptions, n_env: int = 1, ): super().__init__() @@ -220,7 +254,7 @@ def __init__( for worker_idx in range(n_env): self.env_workers.append( self.create_worker( - worker_idx, self.step_queue, env_factory, engine_configuration + worker_idx, self.step_queue, env_factory, run_options ) ) self.workers_alive += 1 @@ -230,7 +264,7 @@ def create_worker( worker_id: int, step_queue: Queue, env_factory: Callable[[int, List[SideChannel]], BaseEnv], - engine_configuration: EngineConfig, + run_options: RunOptions, ) -> UnityEnvWorker: parent_conn, child_conn = Pipe() @@ -244,7 +278,7 @@ def create_worker( step_queue, pickled_env_factory, worker_id, - engine_configuration, + run_options, logger.level, ), ) @@ -308,6 +342,20 @@ def set_env_parameters(self, config: Dict = None) -> None: for ew in self.env_workers: ew.send(EnvironmentCommand.ENVIRONMENT_PARAMETERS, config) + def on_training_started( + self, behavior_name: str, trainer_settings: TrainerSettings + ) -> None: + """ + Handle training starting for a new behavior type. Forwards the behavior name and trainer settings to each worker so its TrainingAnalyticsSideChannel can report them.
+ :param behavior_name: + :param trainer_settings: + :return: + """ + for ew in self.env_workers: + ew.send( + EnvironmentCommand.TRAINING_STARTED, (behavior_name, trainer_settings) + ) + @property def training_behaviors(self) -> Dict[BehaviorName, BehaviorSpec]: result: Dict[BehaviorName, BehaviorSpec] = {} diff --git a/ml-agents/mlagents/trainers/tests/simple_test_envs.py b/ml-agents/mlagents/trainers/tests/simple_test_envs.py index 01a196c8db..e7f44b0f56 100644 --- a/ml-agents/mlagents/trainers/tests/simple_test_envs.py +++ b/ml-agents/mlagents/trainers/tests/simple_test_envs.py @@ -72,6 +72,8 @@ def __init__( self.step_result: Dict[str, Tuple[DecisionSteps, TerminalSteps]] = {} self.agent_id: Dict[str, int] = {} self.step_size = step_size # defines the difficulty of the test + # Allow to be used as a UnityEnvironment during tests + self.academy_capabilities = None for name in self.names: self.agent_id[name] = 0 diff --git a/ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py b/ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py index 9b99d833c3..4fd9cc96c2 100644 --- a/ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py +++ b/ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py @@ -4,6 +4,7 @@ import pytest from queue import Empty as EmptyQueue +from mlagents.trainers.settings import RunOptions from mlagents.trainers.subprocess_env_manager import ( SubprocessEnvManager, EnvironmentResponse, @@ -12,7 +13,6 @@ ) from mlagents.trainers.env_manager import EnvironmentStep from mlagents_envs.base_env import BaseEnv -from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod from mlagents_envs.exception import UnityEnvironmentException from mlagents.trainers.tests.simple_test_envs import ( @@ -54,16 +54,13 @@ class SubprocessEnvManagerTest(unittest.TestCase): ) def test_environments_are_created(self, mock_create_worker): mock_create_worker.side_effect = create_worker_mock - env = SubprocessEnvManager(mock_env_factory, EngineConfig.default_config(), 2) + run_options = RunOptions() + env = SubprocessEnvManager(mock_env_factory, run_options, 2) # Creates two processes env.create_worker.assert_has_calls( [ - mock.call( - 0, env.step_queue, mock_env_factory, EngineConfig.default_config() - ), - mock.call( - 1, env.step_queue, mock_env_factory, EngineConfig.default_config() - ), + mock.call(0, env.step_queue, mock_env_factory, run_options), + mock.call(1, env.step_queue, mock_env_factory, run_options), ] ) self.assertEqual(len(env.env_workers), 2) @@ -73,9 +70,7 @@ def test_environments_are_created(self, mock_create_worker): ) def test_reset_passes_reset_params(self, mock_create_worker): mock_create_worker.side_effect = create_worker_mock - manager = SubprocessEnvManager( - mock_env_factory, EngineConfig.default_config(), 1 - ) + manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 1) params = {"test": "params"} manager._reset_env(params) manager.env_workers[0].send.assert_called_with( @@ -87,9 +82,7 @@ def test_reset_passes_reset_params(self, mock_create_worker): ) def test_reset_collects_results_from_all_envs(self, mock_create_worker): mock_create_worker.side_effect = create_worker_mock - manager = SubprocessEnvManager( - mock_env_factory, EngineConfig.default_config(), 4 - ) + manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 4) params = {"test": "params"} res = manager._reset_env(params) @@ -117,9 +110,7 @@ def 
create_worker_mock(worker_id, step_queue, env_factor, engine_c): ) mock_create_worker.side_effect = create_worker_mock - manager = SubprocessEnvManager( - mock_env_factory, EngineConfig.default_config(), 4 - ) + manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 4) res = manager.training_behaviors for env in manager.env_workers: @@ -134,9 +125,7 @@ def create_worker_mock(worker_id, step_queue, env_factor, engine_c): ) def test_step_takes_steps_for_all_non_waiting_envs(self, mock_create_worker): mock_create_worker.side_effect = create_worker_mock - manager = SubprocessEnvManager( - mock_env_factory, EngineConfig.default_config(), 3 - ) + manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 3) manager.step_queue = Mock() manager.step_queue.get_nowait.side_effect = [ EnvironmentResponse(EnvironmentCommand.STEP, 0, StepResponse(0, None, {})), @@ -176,9 +165,7 @@ def test_advance(self, mock_create_worker, training_behaviors_mock, step_mock): brain_name = "testbrain" action_info_dict = {brain_name: MagicMock()} mock_create_worker.side_effect = create_worker_mock - env_manager = SubprocessEnvManager( - mock_env_factory, EngineConfig.default_config(), 3 - ) + env_manager = SubprocessEnvManager(mock_env_factory, RunOptions(), 3) training_behaviors_mock.return_value = [brain_name] agent_manager_mock = mock.Mock() mock_policy = mock.Mock() @@ -219,9 +206,7 @@ def simple_env_factory(worker_id, config): env = SimpleEnvironment(["1D"], action_sizes=(0, 1)) return env - env_manager = SubprocessEnvManager( - simple_env_factory, EngineConfig.default_config(), num_envs - ) + env_manager = SubprocessEnvManager(simple_env_factory, RunOptions(), num_envs) # Run PPO using env_manager check_environment_trains( simple_env_factory(0, []), @@ -250,9 +235,7 @@ def failing_step_env_factory(_worker_id, _config): ) return env - env_manager = SubprocessEnvManager( - failing_step_env_factory, EngineConfig.default_config() - ) + env_manager = SubprocessEnvManager(failing_step_env_factory, RunOptions()) # Expect the exception raised to be routed back up to the top level. 
with pytest.raises(CustomTestOnlyException): check_environment_trains( @@ -275,9 +258,7 @@ def failing_env_factory(worker_id, config): time.sleep(0.5) raise UnityEnvironmentException() - env_manager = SubprocessEnvManager( - failing_env_factory, EngineConfig.default_config(), num_envs - ) + env_manager = SubprocessEnvManager(failing_env_factory, RunOptions(), num_envs) with pytest.raises(UnityEnvironmentException): env_manager.reset() env_manager.close() diff --git a/ml-agents/mlagents/trainers/trainer_controller.py b/ml-agents/mlagents/trainers/trainer_controller.py index 7f9808f5dd..c4a60f8b3a 100644 --- a/ml-agents/mlagents/trainers/trainer_controller.py +++ b/ml-agents/mlagents/trainers/trainer_controller.py @@ -130,6 +130,9 @@ def _create_trainer_and_manager( target=self.trainer_update_func, args=(trainer,), daemon=True ) self.trainer_threads.append(trainerthread) + env_manager.on_training_started( + brain_name, self.trainer_factory.trainer_config[brain_name] + ) policy = trainer.create_policy( parsed_behavior_id, diff --git a/ml-agents/mlagents/training_analytics_side_channel.py b/ml-agents/mlagents/training_analytics_side_channel.py new file mode 100644 index 0000000000..f964f13fac --- /dev/null +++ b/ml-agents/mlagents/training_analytics_side_channel.py @@ -0,0 +1,99 @@ +import sys +from typing import Optional +import uuid +import mlagents_envs +import mlagents.trainers +from mlagents import torch_utils +from mlagents.trainers.settings import RewardSignalType +from mlagents_envs.exception import UnityCommunicationException +from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage +from mlagents_envs.communicator_objects.training_analytics_pb2 import ( + TrainingEnvironmentInitialized, + TrainingBehaviorInitialized, +) +from google.protobuf.any_pb2 import Any + +from mlagents.trainers.settings import TrainerSettings, RunOptions + + +class TrainingAnalyticsSideChannel(SideChannel): + """ + Side channel that sends information about the training to the Unity environment so it can be logged. + """ + + def __init__(self) -> None: + # >>> uuid.uuid5(uuid.NAMESPACE_URL, "com.unity.ml-agents/TrainingAnalyticsSideChannel") + # UUID('b664a4a9-d86f-5a5f-95cb-e8353a7e8356') + super().__init__(uuid.UUID("b664a4a9-d86f-5a5f-95cb-e8353a7e8356")) + self.run_options: Optional[RunOptions] = None + + def on_message_received(self, msg: IncomingMessage) -> None: + raise UnityCommunicationException( + "The TrainingAnalyticsSideChannel received a message from Unity, " + + "this should not have happened." 
+ ) + + def environment_initialized(self, run_options: RunOptions) -> None: + self.run_options = run_options + # Tuple of (major, minor, patch) + vi = sys.version_info + env_params = run_options.environment_parameters + + msg = TrainingEnvironmentInitialized( + python_version=f"{vi[0]}.{vi[1]}.{vi[2]}", + mlagents_version=mlagents.trainers.__version__, + mlagents_envs_version=mlagents_envs.__version__, + torch_version=torch_utils.torch.__version__, + torch_device_type=torch_utils.default_device().type, + num_envs=run_options.env_settings.num_envs, + num_environment_parameters=len(env_params) if env_params else 0, + ) + + any_message = Any() + any_message.Pack(msg) + + env_init_msg = OutgoingMessage() + env_init_msg.set_raw_bytes(any_message.SerializeToString()) + super().queue_message_to_send(env_init_msg) + + def training_started(self, behavior_name: str, config: TrainerSettings) -> None: + msg = TrainingBehaviorInitialized( + behavior_name=behavior_name, + trainer_type=config.trainer_type.value, + extrinsic_reward_enabled=( + RewardSignalType.EXTRINSIC in config.reward_signals + ), + gail_reward_enabled=(RewardSignalType.GAIL in config.reward_signals), + curiosity_reward_enabled=( + RewardSignalType.CURIOSITY in config.reward_signals + ), + rnd_reward_enabled=(RewardSignalType.RND in config.reward_signals), + behavioral_cloning_enabled=config.behavioral_cloning is not None, + recurrent_enabled=config.network_settings.memory is not None, + visual_encoder=config.network_settings.vis_encode_type.value, + num_network_layers=config.network_settings.num_layers, + num_network_hidden_units=config.network_settings.hidden_units, + trainer_threaded=config.threaded, + self_play_enabled=config.self_play is not None, + curriculum_enabled=self._behavior_uses_curriculum(behavior_name), + ) + + any_message = Any() + any_message.Pack(msg) + + training_start_msg = OutgoingMessage() + training_start_msg.set_raw_bytes(any_message.SerializeToString()) + + super().queue_message_to_send(training_start_msg) + + def _behavior_uses_curriculum(self, behavior_name: str) -> bool: + if not self.run_options or not self.run_options.environment_parameters: + return False + + for param_settings in self.run_options.environment_parameters.values(): + for lesson in param_settings.curriculum: + cc = lesson.completion_criteria + if cc and cc.behavior == behavior_name: + return True + + return False diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto index 109b1f0c97..723cc2cd5c 100644 --- a/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto +++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto @@ -19,4 +19,7 @@ message UnityRLCapabilitiesProto { // support for hybrid action spaces (discrete + continuous) bool hybridActions = 4; + + // support for training analytics + bool trainingAnalytics = 5; } diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/training_analytics.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/training_analytics.proto new file mode 100644 index 0000000000..52f0ed6109 --- /dev/null +++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/training_analytics.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +option csharp_namespace = "Unity.MLAgents.CommunicatorObjects"; +package communicator_objects; + +message TrainingEnvironmentInitialized { + string mlagents_version = 1; 
+ string mlagents_envs_version = 2; + string python_version = 3; + string torch_version = 4; + string torch_device_type = 5; + int32 num_envs = 6; + int32 num_environment_parameters = 7; +} + +message TrainingBehaviorInitialized { + string behavior_name = 1; + string trainer_type = 2; + bool extrinsic_reward_enabled = 3; + bool gail_reward_enabled = 4; + bool curiosity_reward_enabled = 5; + bool rnd_reward_enabled = 6; + bool behavioral_cloning_enabled = 7; + bool recurrent_enabled = 8; + string visual_encoder = 9; + int32 num_network_layers = 10; + int32 num_network_hidden_units = 11; + bool trainer_threaded = 12; + bool self_play_enabled = 13; + bool curriculum_enabled = 14; +}
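The two messages above travel over the side channel wrapped in a google.protobuf Any, which is what lets the C# OnMessageReceived dispatch on the payload type and silently skip types it does not recognize. A minimal round-trip sketch of that wire format, assuming mlagents_envs with the generated training_analytics_pb2 module from this change is installed:

from google.protobuf.any_pb2 import Any
from mlagents_envs.communicator_objects.training_analytics_pb2 import (
    TrainingEnvironmentInitialized,
)

# Trainer side: pack the event into an Any so the receiver can dispatch on type.
event = TrainingEnvironmentInitialized(python_version="3.8.0", num_envs=4)
wrapper = Any()
wrapper.Pack(event)
raw_bytes = wrapper.SerializeToString()  # what OutgoingMessage.set_raw_bytes carries

# Unity side (mirrored here in Python): parse the Any, check the payload type,
# then unpack it; unknown types are simply ignored, as in OnMessageReceived.
received = Any.FromString(raw_bytes)
if received.Is(TrainingEnvironmentInitialized.DESCRIPTOR):
    unpacked = TrainingEnvironmentInitialized()
    received.Unpack(unpacked)
    assert unpacked.num_envs == 4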
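Relatedly, the channel ID hard-coded on both the C# and Python sides is not arbitrary: as the comment in TrainingAnalyticsSideChannel.__init__ notes, it is the name-based (version 5) UUID of the channel's identifier string, so the two sides agree on the GUID without any registration handshake. Reproducing the derivation from that comment:

import uuid

channel_id = uuid.uuid5(uuid.NAMESPACE_URL, "com.unity.ml-agents/TrainingAnalyticsSideChannel")
assert str(channel_id) == "b664a4a9-d86f-5a5f-95cb-e8353a7e8356"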