- 
                Notifications
    You must be signed in to change notification settings 
- Fork 4.4k
Moving domain randomization to C# #4065
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 38 commits
8d0244b
              eb0c495
              e9d8350
              a1c771e
              e856f7b
              daa5688
              9756a2c
              ec2493f
              aa4ebd9
              460a2ea
              54b6959
              b4469ca
              3d26047
              f18da6a
              7f116cd
              46f6491
              5286675
              9dbcc4b
              d3e0d9c
              4c111e4
              38f48f1
              988452a
              cd06ce7
              e40951f
              1b9f2d5
              aff9c00
              63a24cc
              da3cb2d
              fae0ca3
              681c7ea
              c40c4d0
              87a3cfc
              320233b
              94e1d20
              9d18e7b
              38e1115
              de6ba28
              bf52d2f
              915d102
              af11e36
              61e3165
              49ee082
              cf838d9
              fd5420f
              4afc7b1
              4c6cf57
              9088b4c
              c5da1c3
              1b45bab
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| using System; | ||
| using System.Collections.Generic; | ||
| using Unity.MLAgents.Inference.Utils; | ||
| using UnityEngine; | ||
| using Random=System.Random; | ||
|  | ||
| namespace Unity.MLAgents | ||
| { | ||
|  | ||
| /// <summary> | ||
| /// Takes a list of floats that encode a sampling distribution and returns the sampling function. | ||
| /// </summary> | ||
| internal sealed class SamplerFactory | ||
| { | ||
| /// <summary> | ||
| /// Constructor. | ||
| /// </summary> | ||
| internal SamplerFactory() | ||
|          | ||
| { | ||
| } | ||
|  | ||
| public Func<float> CreateUniformSampler(float min, float max, int seed) | ||
| { | ||
| Random distr = new Random(seed); | ||
| return () => min + (float)distr.NextDouble() * (max - min); | ||
| } | ||
|  | ||
| public Func<float> CreateGaussianSampler(float mean, float stddev, int seed) | ||
| { | ||
| RandomNormal distr = new RandomNormal(seed, mean, stddev); | ||
| return () => (float)distr.NextDouble(); | ||
| } | ||
|  | ||
| public Func<float> CreateMultiRangeUniformSampler(IList<float> intervals, int seed) | ||
| { | ||
| //RNG | ||
| Random distr = new Random(seed); | ||
| // Will be used to normalize intervalFuncs | ||
| float sumIntervalSizes = 0; | ||
| //The number of intervals | ||
| int numIntervals = (int)(intervals.Count/2); | ||
| // List that will store interval lengths | ||
| float[] intervalSizes = new float[numIntervals]; | ||
| // List that will store uniform distributions | ||
| IList<Func<float>> intervalFuncs = new Func<float>[numIntervals]; | ||
| // Collect all intervals and store as uniform distrus | ||
| // Collect all interval sizes | ||
| for(int i = 0; i < numIntervals; i++) | ||
| { | ||
| var min = intervals[2 * i]; | ||
| var max = intervals[2 * i + 1]; | ||
| var intervalSize = max - min; | ||
| sumIntervalSizes += intervalSize; | ||
| intervalSizes[i] = intervalSize; | ||
| intervalFuncs[i] = () => min + (float)distr.NextDouble() * intervalSize; | ||
| } | ||
| // Normalize interval lengths | ||
| for(int i = 0; i < numIntervals; i++) | ||
| { | ||
| intervalSizes[i] = intervalSizes[i] / sumIntervalSizes; | ||
| } | ||
| // Build cmf for intervals | ||
| for(int i = 1; i < numIntervals; i++) | ||
| { | ||
| intervalSizes[i] += intervalSizes[i - 1]; | ||
| } | ||
| Multinomial intervalDistr = new Multinomial(seed + 1); | ||
| float MultiRange() | ||
| { | ||
| int sampledInterval = intervalDistr.Sample(intervalSizes); | ||
| return intervalFuncs[sampledInterval].Invoke(); | ||
| } | ||
| return MultiRange; | ||
| } | ||
| } | ||
| } | ||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| using System; | ||
| using NUnit.Framework; | ||
| using System.IO; | ||
| using System.Collections.Generic; | ||
| using UnityEngine; | ||
| using Unity.MLAgents.SideChannels; | ||
|  | ||
| namespace Unity.MLAgents.Tests | ||
| { | ||
| public class SamplerTests | ||
| { | ||
| const int k_Seed = 1337; | ||
| const double k_Epsilon = 0.0001; | ||
| EnvironmentParametersChannel m_Channel = new EnvironmentParametersChannel(); | ||
|  | ||
| public SamplerTests() | ||
| { | ||
| SideChannelsManager.RegisterSideChannel(m_Channel); | ||
| } | ||
|  | ||
| [Test] | ||
| public void UniformSamplerTest() | ||
| { | ||
| float min_value = 1.0f; | ||
| float max_value = 2.0f; | ||
| string parameter = "parameter1"; | ||
| Assert.AreEqual(m_Channel.GetWithDefault(parameter, 1.0f), 1.0f); | ||
| using (var outgoingMsg = new OutgoingMessage()) | ||
| { | ||
| outgoingMsg.WriteString(parameter); | ||
| // 1 indicates this meessage is a Sampler | ||
| outgoingMsg.WriteInt32(1); | ||
| outgoingMsg.WriteInt32(k_Seed); | ||
| outgoingMsg.WriteInt32((int)SamplerType.Uniform); | ||
| outgoingMsg.WriteFloat32(min_value); | ||
| outgoingMsg.WriteFloat32(max_value); | ||
| byte[] message = GetByteMessage(m_Channel, outgoingMsg); | ||
| SideChannelsManager.ProcessSideChannelData(message); | ||
| } | ||
| Assert.AreEqual(m_Channel.GetWithDefault(parameter, 1.0f), 1.208888f, k_Epsilon); | ||
| Assert.AreEqual(m_Channel.GetWithDefault(parameter, 1.0f), 1.118017f, k_Epsilon); | ||
| } | ||
|  | ||
| [Test] | ||
| public void GaussianSamplerTest() | ||
| { | ||
| float mean = 3.0f; | ||
| float stddev = 0.2f; | ||
| string parameter = "parameter2"; | ||
| Assert.AreEqual(m_Channel.GetWithDefault(parameter, 1.0f), 1.0f); | ||
| using (var outgoingMsg = new OutgoingMessage()) | ||
| { | ||
| outgoingMsg.WriteString(parameter); | ||
| // 1 indicates this meessage is a Sampler | ||
| outgoingMsg.WriteInt32(1); | ||
| outgoingMsg.WriteInt32(k_Seed); | ||
| outgoingMsg.WriteInt32((int)SamplerType.Gaussian); | ||
| outgoingMsg.WriteFloat32(mean); | ||
| outgoingMsg.WriteFloat32(stddev); | ||
| byte[] message = GetByteMessage(m_Channel, outgoingMsg); | ||
| SideChannelsManager.ProcessSideChannelData(message); | ||
| } | ||
| Assert.AreEqual(m_Channel.GetWithDefault(parameter, 1.0f), 2.936162f, k_Epsilon); | ||
| Assert.AreEqual(m_Channel.GetWithDefault(parameter, 1.0f), 2.951348f, k_Epsilon); | ||
| } | ||
|  | ||
| [Test] | ||
| public void MultiRangeUniformSamplerTest() | ||
| { | ||
| float[] intervals = new float[4]; | ||
| intervals[0] = 1.2f; | ||
| intervals[1] = 2f; | ||
| intervals[2] = 3.2f; | ||
| intervals[3] = 4.1f; | ||
| string parameter = "parameter3"; | ||
| Assert.AreEqual(m_Channel.GetWithDefault(parameter, 1.0f), 1.0f); | ||
| using (var outgoingMsg = new OutgoingMessage()) | ||
| { | ||
| outgoingMsg.WriteString(parameter); | ||
| // 1 indicates this meessage is a Sampler | ||
| outgoingMsg.WriteInt32(1); | ||
| outgoingMsg.WriteInt32(k_Seed); | ||
| outgoingMsg.WriteInt32((int)SamplerType.MultiRangeUniform); | ||
| outgoingMsg.WriteFloatList(intervals); | ||
| byte[] message = GetByteMessage(m_Channel, outgoingMsg); | ||
| SideChannelsManager.ProcessSideChannelData(message); | ||
| } | ||
| Assert.AreEqual(m_Channel.GetWithDefault(parameter, 1.0f), 3.387999f, k_Epsilon); | ||
| Assert.AreEqual(m_Channel.GetWithDefault(parameter, 1.0f), 1.294413f, k_Epsilon); | ||
| } | ||
|  | ||
| internal static byte[] GetByteMessage(SideChannel sideChannel, OutgoingMessage msg) | ||
| { | ||
| byte[] message = msg.ToByteArray(); | ||
| using (var memStream = new MemoryStream()) | ||
| { | ||
| using (var binaryWriter = new BinaryWriter(memStream)) | ||
| { | ||
| binaryWriter.Write(sideChannel.ChannelId.ToByteArray()); | ||
| binaryWriter.Write(message.Length); | ||
| binaryWriter.Write(message); | ||
| } | ||
| return memStream.ToArray(); | ||
| } | ||
| } | ||
| } | ||
| } | 
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Uh oh!
There was an error while loading. Please reload this page.