diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index ab6ef40cca..fb08f6c330 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -11,6 +11,8 @@ and this project adheres to
### Major Changes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
+- The Parameter Randomization feature has been refactored to enable sampling of new parameters per episode to improve robustness. The
+ `resampling-interval` parameter has been removed and the config structure updated. More information [here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-ML-Agents.md). (#4065)
### Minor Changes
#### com.unity.ml-agents (C#)
diff --git a/com.unity.ml-agents/Runtime/Sampler.cs b/com.unity.ml-agents/Runtime/Sampler.cs
new file mode 100644
index 0000000000..fc48f3c271
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sampler.cs
@@ -0,0 +1,70 @@
+using System;
+using System.Collections.Generic;
+using Unity.MLAgents.Inference.Utils;
+using UnityEngine;
+using Random=System.Random;
+
+namespace Unity.MLAgents
+{
+
+ /// <summary>
+ /// Takes a list of floats that encode a sampling distribution and returns the sampling function.
+ /// </summary>
+ internal static class SamplerFactory
+ {
+
+ public static Func<float> CreateUniformSampler(float min, float max, int seed)
+ {
+ Random distr = new Random(seed);
+ return () => min + (float)distr.NextDouble() * (max - min);
+ }
+
+ public static Func<float> CreateGaussianSampler(float mean, float stddev, int seed)
+ {
+ RandomNormal distr = new RandomNormal(seed, mean, stddev);
+ return () => (float)distr.NextDouble();
+ }
+
+ public static Func<float> CreateMultiRangeUniformSampler(IList<float> intervals, int seed)
+ {
+ //RNG
+ Random distr = new Random(seed);
+ // Will be used to normalize intervalFuncs
+ float sumIntervalSizes = 0;
+ //The number of intervals
+ int numIntervals = (int)(intervals.Count/2);
+ // List that will store interval lengths
+ float[] intervalSizes = new float[numIntervals];
+ // List that will store uniform distributions
+ IList<Func<float>> intervalFuncs = new Func<float>[numIntervals];
+ // Collect all intervals and store as uniform distributions
+ // Collect all interval sizes
+ for(int i = 0; i < numIntervals; i++)
+ {
+ var min = intervals[2 * i];
+ var max = intervals[2 * i + 1];
+ var intervalSize = max - min;
+ sumIntervalSizes += intervalSize;
+ intervalSizes[i] = intervalSize;
+ intervalFuncs[i] = () => min + (float)distr.NextDouble() * intervalSize;
+ }
+ // Normalize interval lengths
+ for(int i = 0; i < numIntervals; i++)
+ {
+ intervalSizes[i] = intervalSizes[i] / sumIntervalSizes;
+ }
+ // Build cmf for intervals
+ for(int i = 1; i < numIntervals; i++)
+ {
+ intervalSizes[i] += intervalSizes[i - 1];
+ }
+ Multinomial intervalDistr = new Multinomial(seed + 1);
+ float MultiRange()
+ {
+ int sampledInterval = intervalDistr.Sample(intervalSizes);
+ return intervalFuncs[sampledInterval].Invoke();
+ }
+ return MultiRange;
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Sampler.cs.meta b/com.unity.ml-agents/Runtime/Sampler.cs.meta
new file mode 100644
index 0000000000..950e28c5b6
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sampler.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 39ce0ea5a8b2e47f696f6efc807029f6
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs b/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs
index 5c9fd509b6..d28f84d2da 100644
--- a/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs
+++ b/com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs
@@ -9,7 +9,30 @@ namespace Unity.MLAgents.SideChannels
///
internal enum EnvironmentDataTypes
{
- Float = 0
+ Float = 0,
+ Sampler = 1
+ }
+
+ /// <summary>
+ /// The types of distributions from which to sample reset parameters.
+ /// </summary>
+ internal enum SamplerType
+ {
+ /// <summary>
+ /// Samples a reset parameter from a uniform distribution.
+ /// </summary>
+ Uniform = 0,
+
+ /// <summary>
+ /// Samples a reset parameter from a Gaussian distribution.
+ /// </summary>
+ Gaussian = 1,
+
+ /// <summary>
+ /// Samples a reset parameter from a MultiRangeUniform distribution.
+ /// </summary>
+ MultiRangeUniform = 2
+
}
///
@@ -18,7 +41,7 @@ internal enum EnvironmentDataTypes
///
internal class EnvironmentParametersChannel : SideChannel
{
- Dictionary<string, float> m_Parameters = new Dictionary<string, float>();
+ Dictionary<string, Func<float>> m_Parameters = new Dictionary<string, Func<float>>();
 Dictionary<string, Action<float>> m_RegisteredActions =
 new Dictionary<string, Action<float>>();
@@ -42,12 +65,40 @@ protected override void OnMessageReceived(IncomingMessage msg)
{
var value = msg.ReadFloat32();
- m_Parameters[key] = value;
+ m_Parameters[key] = () => value;
 Action<float> action;
m_RegisteredActions.TryGetValue(key, out action);
action?.Invoke(value);
}
+ else if ((int)EnvironmentDataTypes.Sampler == type)
+ {
+ int seed = msg.ReadInt32();
+ int samplerType = msg.ReadInt32();
+ Func<float> sampler = () => 0.0f;
+ if ((int)SamplerType.Uniform == samplerType)
+ {
+ float min = msg.ReadFloat32();
+ float max = msg.ReadFloat32();
+ sampler = SamplerFactory.CreateUniformSampler(min, max, seed);
+ }
+ else if ((int)SamplerType.Gaussian == samplerType)
+ {
+ float mean = msg.ReadFloat32();
+ float stddev = msg.ReadFloat32();
+
+ sampler = SamplerFactory.CreateGaussianSampler(mean, stddev, seed);
+ }
+ else if ((int)SamplerType.MultiRangeUniform == samplerType)
+ {
+ IList<float> intervals = msg.ReadFloatList();
+ sampler = SamplerFactory.CreateMultiRangeUniformSampler(intervals, seed);
+ }
+ else {
+ Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");
+ }
+ m_Parameters[key] = sampler;
+ }
else
{
Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");
@@ -63,9 +114,9 @@ protected override void OnMessageReceived(IncomingMessage msg)
///
public float GetWithDefault(string key, float defaultValue)
{
- float valueOut;
+ Func<float> valueOut;
bool hasKey = m_Parameters.TryGetValue(key, out valueOut);
- return hasKey ? valueOut : defaultValue;
+ return hasKey ? valueOut.Invoke() : defaultValue;
}
///
diff --git a/com.unity.ml-agents/Tests/Editor/SamplerTests.cs b/com.unity.ml-agents/Tests/Editor/SamplerTests.cs
new file mode 100644
index 0000000000..14307e6733
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/SamplerTests.cs
@@ -0,0 +1,109 @@
+using System;
+using NUnit.Framework;
+using System.IO;
+using System.Collections.Generic;
+using UnityEngine;
+using Unity.MLAgents.SideChannels;
+
+namespace Unity.MLAgents.Tests
+{
+ public class SamplerTests
+ {
+ const int k_Seed = 1337;
+ const double k_Epsilon = 0.0001;
+ EnvironmentParametersChannel m_Channel;
+
+ public SamplerTests()
+ {
+ m_Channel = SideChannelsManager.GetSideChannel<EnvironmentParametersChannel>();
+ // if running test on its own
+ if (m_Channel == null)
+ {
+ m_Channel = new EnvironmentParametersChannel();
+ SideChannelsManager.RegisterSideChannel(m_Channel);
+ }
+ }
+ [Test]
+ public void UniformSamplerTest()
+ {
+ float min_value = 1.0f;
+ float max_value = 2.0f;
+ string parameter = "parameter1";
+ using (var outgoingMsg = new OutgoingMessage())
+ {
+ outgoingMsg.WriteString(parameter);
+ // 1 indicates this message is a Sampler
+ outgoingMsg.WriteInt32(1);
+ outgoingMsg.WriteInt32(k_Seed);
+ outgoingMsg.WriteInt32((int)SamplerType.Uniform);
+ outgoingMsg.WriteFloat32(min_value);
+ outgoingMsg.WriteFloat32(max_value);
+ byte[] message = GetByteMessage(m_Channel, outgoingMsg);
+ SideChannelsManager.ProcessSideChannelData(message);
+ }
+ Assert.AreEqual(1.208888f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
+ Assert.AreEqual(1.118017f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
+ }
+
+ [Test]
+ public void GaussianSamplerTest()
+ {
+ float mean = 3.0f;
+ float stddev = 0.2f;
+ string parameter = "parameter2";
+ using (var outgoingMsg = new OutgoingMessage())
+ {
+ outgoingMsg.WriteString(parameter);
+ // 1 indicates this message is a Sampler
+ outgoingMsg.WriteInt32(1);
+ outgoingMsg.WriteInt32(k_Seed);
+ outgoingMsg.WriteInt32((int)SamplerType.Gaussian);
+ outgoingMsg.WriteFloat32(mean);
+ outgoingMsg.WriteFloat32(stddev);
+ byte[] message = GetByteMessage(m_Channel, outgoingMsg);
+ SideChannelsManager.ProcessSideChannelData(message);
+ }
+ Assert.AreEqual(2.936162f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
+ Assert.AreEqual(2.951348f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
+ }
+
+ [Test]
+ public void MultiRangeUniformSamplerTest()
+ {
+ float[] intervals = new float[4];
+ intervals[0] = 1.2f;
+ intervals[1] = 2f;
+ intervals[2] = 3.2f;
+ intervals[3] = 4.1f;
+ string parameter = "parameter3";
+ using (var outgoingMsg = new OutgoingMessage())
+ {
+ outgoingMsg.WriteString(parameter);
+ // 1 indicates this message is a Sampler
+ outgoingMsg.WriteInt32(1);
+ outgoingMsg.WriteInt32(k_Seed);
+ outgoingMsg.WriteInt32((int)SamplerType.MultiRangeUniform);
+ outgoingMsg.WriteFloatList(intervals);
+ byte[] message = GetByteMessage(m_Channel, outgoingMsg);
+ SideChannelsManager.ProcessSideChannelData(message);
+ }
+ Assert.AreEqual(3.387999f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
+ Assert.AreEqual(1.294413f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
+ }
+
+ internal static byte[] GetByteMessage(SideChannel sideChannel, OutgoingMessage msg)
+ {
+ byte[] message = msg.ToByteArray();
+ using (var memStream = new MemoryStream())
+ {
+ using (var binaryWriter = new BinaryWriter(memStream))
+ {
+ binaryWriter.Write(sideChannel.ChannelId.ToByteArray());
+ binaryWriter.Write(message.Length);
+ binaryWriter.Write(message);
+ }
+ return memStream.ToArray();
+ }
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/SamplerTests.cs.meta b/com.unity.ml-agents/Tests/Editor/SamplerTests.cs.meta
new file mode 100644
index 0000000000..ef0d54e72a
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/SamplerTests.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 7e6609c51018d4132beda8ddedd46d91
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/config/ppo/3DBall_randomize.yaml b/config/ppo/3DBall_randomize.yaml
index b3c6c13f21..2f3608b880 100644
--- a/config/ppo/3DBall_randomize.yaml
+++ b/config/ppo/3DBall_randomize.yaml
@@ -26,16 +26,13 @@ behaviors:
threaded: true
parameter_randomization:
- resampling-interval: 5000
mass:
- sampler-type: uniform
- min_value: 0.5
- max_value: 10
- gravity:
- sampler-type: uniform
- min_value: 7
- max_value: 12
+ sampler_type: uniform
+ sampler_parameters:
+ min_value: 0.5
+ max_value: 10
scale:
- sampler-type: uniform
- min_value: 0.75
- max_value: 3
+ sampler_type: uniform
+ sampler_parameters:
+ min_value: 0.75
+ max_value: 3
diff --git a/docs/Training-ML-Agents.md b/docs/Training-ML-Agents.md
index 2f93938466..b61fc3b24b 100644
--- a/docs/Training-ML-Agents.md
+++ b/docs/Training-ML-Agents.md
@@ -435,97 +435,57 @@ behaviors:
# < Same as above>
parameter_randomization:
- resampling-interval: 5000
mass:
- sampler-type: "uniform"
- min_value: 0.5
- max_value: 10
+ sampler_type: uniform
+ sampler_parameters:
+ min_value: 0.5
+ max_value: 10
- gravity:
- sampler-type: "multirange_uniform"
- intervals: [[7, 10], [15, 20]]
+ length:
+ sampler_type: multirangeuniform
+ sampler_parameters:
+ intervals: [[7, 10], [15, 20]]
scale:
- sampler-type: "uniform"
- min_value: 0.75
- max_value: 3
+ sampler_type: gaussian
+ sampler_parameters:
+ mean: 2
+ st_dev: .3
```
-Note that `mass`, `gravity` and `scale` are the names of the environment
-parameters that will be sampled. If a parameter specified in the file doesn't
-exist in the environment, then this parameter will be ignored.
+Note that `mass`, `length` and `scale` are the names of the environment
+parameters that will be sampled. These are used as keys by the `EnvironmentParameters`
+class to sample new parameters via the function `GetWithDefault`.
| **Setting** | **Description** |
| :--------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `resampling-interval` | Number of steps for the agent to train under a particular environment configuration before resetting the environment with a new sample of `Environment Parameters`. |
-| `sampler-type` | Type of sampler use for this `Environment Parameter`. This is a string that should exist in the `Sampler Factory` (explained below). |
-| `sampler-type-sub-arguments` | Specify the sub-arguments depending on the `sampler-type`. In the example above, this would correspond to the `intervals` under the `sampler-type` `multirange_uniform` for the `Environment Parameter` called `gravity`. The key name should match the name of the corresponding argument in the sampler definition (explained) below) |
+| `sampler_type` | A string identifier for the type of sampler to use for this `Environment Parameter`. |
+| `sampler_parameters` | The parameters for a given `sampler_type`. Samplers of different types can have different `sampler_parameters` |
-#### Included Sampler Types
+#### Supported Sampler Types
-Below is a list of included `sampler-type` as part of the toolkit.
+Below is a list of the `sampler_type` values supported by the toolkit.
- `uniform` - Uniform sampler
- - Uniformly samples a single float value between defined endpoints. The
- sub-arguments for this sampler to specify the interval endpoints are as
- below. The sampling is done in the range of [`min_value`, `max_value`).
- - **sub-arguments** - `min_value`, `max_value`
+ - Uniformly samples a single float value from a range with a given minimum
+ and maximum value (inclusive).
+ - **parameters** - `min_value`, `max_value`
- `gaussian` - Gaussian sampler
- - Samples a single float value from the distribution characterized by the mean
- and standard deviation. The sub-arguments to specify the Gaussian
- distribution to use are as below.
- - **sub-arguments** - `mean`, `st_dev`
+ - Samples a single float value from a normal distribution with a given mean
+ and standard deviation.
+ - **parameters** - `mean`, `st_dev`
- `multirange_uniform` - Multirange uniform sampler
- - Uniformly samples a single float value between the specified intervals.
- Samples by first performing a weight pick of an interval from the list of
- intervals (weighted based on interval width) and samples uniformly from the
- selected interval (half-closed interval, same as the uniform sampler). This
- sampler can take an arbitrary number of intervals in a list in the following
- format: [[`interval_1_min`, `interval_1_max`], [`interval_2_min`,
+ - First, samples an interval from a set of intervals in proportion to relative
+ length of the intervals. Then, uniformly samples a single float value from the
+ sampled interval (inclusive). This sampler can take an arbitrary number of
+ intervals in a list in the following format:
+ [[`interval_1_min`, `interval_1_max`], [`interval_2_min`,
`interval_2_max`], ...]
- - **sub-arguments** - `intervals`
+ - **parameters** - `intervals`
The implementation of the samplers can be found in the
-[sampler_class.py file](../ml-agents/mlagents/trainers/sampler_class.py).
-
-#### Defining a New Sampler Type
-
-If you want to define your own sampler type, you must first inherit the
-_Sampler_ base class (included in the `sampler_class` file) and preserve the
-interface. Once the class for the required method is specified, it must be
-registered in the Sampler Factory.
-
-This can be done by subscribing to the _register_sampler_ method of the
-`SamplerFactory`. The command is as follows:
-
-`SamplerFactory.register_sampler(*custom_sampler_string_key*, *custom_sampler_object*)`
-
-Once the Sampler Factory reflects the new register, the new sampler type can be
-used for sample any `Environment Parameter`. For example, lets say a new sampler
-type was implemented as below and we register the `CustomSampler` class with the
-string `custom-sampler` in the Sampler Factory.
-
-```python
-class CustomSampler(Sampler):
-
- def __init__(self, argA, argB, argC):
- self.possible_vals = [argA, argB, argC]
-
- def sample_all(self):
- return np.random.choice(self.possible_vals)
-```
-
-Now we need to specify the new sampler type in the sampler YAML file. For
-example, we use this new sampler type for the `Environment Parameter` _mass_.
-
-```yaml
-mass:
- sampler-type: "custom-sampler"
- argB: 1
- argA: 2
- argC: 3
-```
+[Samplers.cs file](../com.unity.ml-agents/Runtime/Sampler.cs).
#### Training with Environment Parameter Randomization
diff --git a/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py b/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
index 958364b675..2d379cbb3f 100644
--- a/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
+++ b/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
@@ -2,6 +2,7 @@
from mlagents_envs.exception import UnityCommunicationException
import uuid
from enum import IntEnum
+from typing import List, Tuple
class EnvironmentParametersChannel(SideChannel):
@@ -13,6 +14,12 @@ class EnvironmentParametersChannel(SideChannel):
class EnvironmentDataTypes(IntEnum):
FLOAT = 0
+ SAMPLER = 1
+
+ class SamplerTypes(IntEnum):
+ UNIFORM = 0
+ GAUSSIAN = 1
+ MULTIRANGEUNIFORM = 2
def __init__(self) -> None:
channel_id = uuid.UUID(("534c891e-810f-11ea-a9d0-822485860400"))
@@ -35,3 +42,59 @@ def set_float_parameter(self, key: str, value: float) -> None:
msg.write_int32(self.EnvironmentDataTypes.FLOAT)
msg.write_float32(value)
super().queue_message_to_send(msg)
+
+ def set_uniform_sampler_parameters(
+ self, key: str, min_value: float, max_value: float, seed: int
+ ) -> None:
+ """
+ Sets a uniform environment parameter sampler.
+ :param key: The string identifier of the parameter.
+ :param min_value: The minimum of the sampling distribution.
+ :param max_value: The maximum of the sampling distribution.
+ :param seed: The random seed to initialize the sampler.
+ """
+ msg = OutgoingMessage()
+ msg.write_string(key)
+ msg.write_int32(self.EnvironmentDataTypes.SAMPLER)
+ msg.write_int32(seed)
+ msg.write_int32(self.SamplerTypes.UNIFORM)
+ msg.write_float32(min_value)
+ msg.write_float32(max_value)
+ super().queue_message_to_send(msg)
+
+ def set_gaussian_sampler_parameters(
+ self, key: str, mean: float, st_dev: float, seed: int
+ ) -> None:
+ """
+ Sets a gaussian environment parameter sampler.
+ :param key: The string identifier of the parameter.
+ :param mean: The mean of the sampling distribution.
+ :param st_dev: The standard deviation of the sampling distribution.
+ :param seed: The random seed to initialize the sampler.
+ """
+ msg = OutgoingMessage()
+ msg.write_string(key)
+ msg.write_int32(self.EnvironmentDataTypes.SAMPLER)
+ msg.write_int32(seed)
+ msg.write_int32(self.SamplerTypes.GAUSSIAN)
+ msg.write_float32(mean)
+ msg.write_float32(st_dev)
+ super().queue_message_to_send(msg)
+
+ def set_multirangeuniform_sampler_parameters(
+ self, key: str, intervals: List[Tuple[float, float]], seed: int
+ ) -> None:
+ """
+ Sets a multirangeuniform environment parameter sampler.
+ :param key: The string identifier of the parameter.
+ :param intervals: The lists of min and max that define each uniform distribution.
+ :param seed: The random seed to initialize the sampler.
+ """
+ msg = OutgoingMessage()
+ msg.write_string(key)
+ msg.write_int32(self.EnvironmentDataTypes.SAMPLER)
+ msg.write_int32(seed)
+ msg.write_int32(self.SamplerTypes.MULTIRANGEUNIFORM)
+ flattened_intervals = [value for interval in intervals for value in interval]
+ msg.write_float32_list(flattened_intervals)
+ super().queue_message_to_send(msg)
diff --git a/ml-agents/mlagents/trainers/learn.py b/ml-agents/mlagents/trainers/learn.py
index 33c2c72b33..ead4b92cd8 100644
--- a/ml-agents/mlagents/trainers/learn.py
+++ b/ml-agents/mlagents/trainers/learn.py
@@ -22,8 +22,6 @@
)
from mlagents.trainers.cli_utils import parser
from mlagents_envs.environment import UnityEnvironment
-from mlagents.trainers.sampler_class import SamplerManager
-from mlagents.trainers.exception import SamplerException
from mlagents.trainers.settings import RunOptions
from mlagents.trainers.training_status import GlobalTrainingStatus
from mlagents_envs.base_env import BaseEnv
@@ -133,9 +131,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
maybe_meta_curriculum = try_create_meta_curriculum(
options.curriculum, env_manager, restore=checkpoint_settings.resume
)
- sampler_manager, resampling_interval = create_sampler_manager(
- options.parameter_randomization, run_seed
- )
+ maybe_add_samplers(options.parameter_randomization, env_manager, run_seed)
trainer_factory = TrainerFactory(
options.behaviors,
write_path,
@@ -154,8 +150,6 @@ def run_training(run_seed: int, options: RunOptions) -> None:
maybe_meta_curriculum,
not checkpoint_settings.inference,
run_seed,
- sampler_manager,
- resampling_interval,
)
# Begin training
@@ -197,26 +191,21 @@ def write_timing_tree(output_dir: str) -> None:
)
-def create_sampler_manager(sampler_config, run_seed=None):
- resample_interval = None
+def maybe_add_samplers(
+ sampler_config: Optional[Dict], env: SubprocessEnvManager, run_seed: int
+) -> None:
+ """
+ Adds samplers to env if sampler config provided and sets seed if not configured.
+ :param sampler_config: validated dict of sampler configs. None if not included.
+ :param env: env manager to pass samplers via reset
+ :param run_seed: Random seed used for training.
+ """
if sampler_config is not None:
- if "resampling-interval" in sampler_config:
- # Filter arguments that do not exist in the environment
- resample_interval = sampler_config.pop("resampling-interval")
- if (resample_interval <= 0) or (not isinstance(resample_interval, int)):
- raise SamplerException(
- "Specified resampling-interval is not valid. Please provide"
- " a positive integer value for resampling-interval"
- )
-
- else:
- raise SamplerException(
- "Resampling interval was not specified in the sampler file."
- " Please specify it with the 'resampling-interval' key in the sampler config file."
- )
-
- sampler_manager = SamplerManager(sampler_config, run_seed)
- return sampler_manager, resample_interval
+ # If the seed is not specified in yaml, this will grab the run seed
+ for offset, v in enumerate(sampler_config.values()):
+ if v.seed == -1:
+ v.seed = run_seed + offset
+ env.reset(config=sampler_config)
def try_create_meta_curriculum(
diff --git a/ml-agents/mlagents/trainers/sampler_class.py b/ml-agents/mlagents/trainers/sampler_class.py
deleted file mode 100644
index f1a7c20327..0000000000
--- a/ml-agents/mlagents/trainers/sampler_class.py
+++ /dev/null
@@ -1,193 +0,0 @@
-import numpy as np
-from typing import Union, Optional, Type, List, Dict, Any
-from abc import ABC, abstractmethod
-
-from mlagents.trainers.exception import SamplerException
-
-
-class Sampler(ABC):
- @abstractmethod
- def sample_parameter(self) -> float:
- pass
-
-
-class UniformSampler(Sampler):
- """
- Uniformly draws a single sample in the range [min_value, max_value).
- """
-
- def __init__(
- self,
- min_value: Union[int, float],
- max_value: Union[int, float],
- seed: Optional[int] = None,
- ):
- """
- :param min_value: minimum value of the range to be sampled uniformly from
- :param max_value: maximum value of the range to be sampled uniformly from
- :param seed: Random seed used for making draws from the uniform sampler
- """
- self.min_value = min_value
- self.max_value = max_value
- # Draw from random state to allow for consistent reset parameter draw for a seed
- self.random_state = np.random.RandomState(seed)
-
- def sample_parameter(self) -> float:
- """
- Draws and returns a sample from the specified interval
- """
- return self.random_state.uniform(self.min_value, self.max_value)
-
-
-class MultiRangeUniformSampler(Sampler):
- """
- Draws a single sample uniformly from the intervals provided. The sampler
- first picks an interval based on a weighted selection, with the weights
- assigned to an interval based on its range. After picking the range,
- it proceeds to pick a value uniformly in that range.
- """
-
- def __init__(
- self, intervals: List[List[Union[int, float]]], seed: Optional[int] = None
- ):
- """
- :param intervals: List of intervals to draw uniform samples from
- :param seed: Random seed used for making uniform draws from the specified intervals
- """
- self.intervals = intervals
- # Measure the length of the intervals
- interval_lengths = [abs(x[1] - x[0]) for x in self.intervals]
- cum_interval_length = sum(interval_lengths)
- # Assign weights to an interval proportionate to the interval size
- self.interval_weights = [x / cum_interval_length for x in interval_lengths]
- # Draw from random state to allow for consistent reset parameter draw for a seed
- self.random_state = np.random.RandomState(seed)
-
- def sample_parameter(self) -> float:
- """
- Selects an interval to pick and then draws a uniform sample from the picked interval
- """
- cur_min, cur_max = self.intervals[
- self.random_state.choice(len(self.intervals), p=self.interval_weights)
- ]
- return self.random_state.uniform(cur_min, cur_max)
-
-
-class GaussianSampler(Sampler):
- """
- Draw a single sample value from a normal (gaussian) distribution.
- This sampler is characterized by the mean and the standard deviation.
- """
-
- def __init__(
- self,
- mean: Union[float, int],
- st_dev: Union[float, int],
- seed: Optional[int] = None,
- ):
- """
- :param mean: Specifies the mean of the gaussian distribution to draw from
- :param st_dev: Specifies the standard devation of the gaussian distribution to draw from
- :param seed: Random seed used for making gaussian draws from the sample
- """
- self.mean = mean
- self.st_dev = st_dev
- # Draw from random state to allow for consistent reset parameter draw for a seed
- self.random_state = np.random.RandomState(seed)
-
- def sample_parameter(self) -> float:
- """
- Returns a draw from the specified Gaussian distribution
- """
- return self.random_state.normal(self.mean, self.st_dev)
-
-
-class SamplerFactory:
- """
- Maintain a directory of all samplers available.
- Add new samplers using the register_sampler method.
- """
-
- NAME_TO_CLASS = {
- "uniform": UniformSampler,
- "gaussian": GaussianSampler,
- "multirange_uniform": MultiRangeUniformSampler,
- }
-
- @staticmethod
- def register_sampler(name: str, sampler_cls: Type[Sampler]) -> None:
- """
- Registers the sampe in the Sampler Factory to be used later
- :param name: String name to set as key for the sampler_cls in the factory
- :param sampler_cls: Sampler object to associate to the name in the factory
- """
- SamplerFactory.NAME_TO_CLASS[name] = sampler_cls
-
- @staticmethod
- def init_sampler_class(
- name: str, params: Dict[str, Any], seed: Optional[int] = None
- ) -> Sampler:
- """
- Initializes the sampler class associated with the name with the params
- :param name: Name of the sampler in the factory to initialize
- :param params: Parameters associated to the sampler attached to the name
- :param seed: Random seed to be used to set deterministic random draws for the sampler
- """
- if name not in SamplerFactory.NAME_TO_CLASS:
- raise SamplerException(
- name + " sampler is not registered in the SamplerFactory."
- " Use the register_sample method to register the string"
- " associated to your sampler in the SamplerFactory."
- )
- sampler_cls = SamplerFactory.NAME_TO_CLASS[name]
- params["seed"] = seed
- try:
- return sampler_cls(**params)
- except TypeError:
- raise SamplerException(
- "The sampler class associated to the " + name + " key in the factory "
- "was not provided the required arguments. Please ensure that the sampler "
- "config file consists of the appropriate keys for this sampler class."
- )
-
-
-class SamplerManager:
- def __init__(
- self, reset_param_dict: Dict[str, Any], seed: Optional[int] = None
- ) -> None:
- """
- :param reset_param_dict: Arguments needed for initializing the samplers
- :param seed: Random seed to be used for drawing samples from the samplers
- """
- self.reset_param_dict = reset_param_dict if reset_param_dict else {}
- assert isinstance(self.reset_param_dict, dict)
- self.samplers: Dict[str, Sampler] = {}
- for param_name, cur_param_dict in self.reset_param_dict.items():
- if "sampler-type" not in cur_param_dict:
- raise SamplerException(
- "'sampler_type' argument hasn't been supplied for the {0} parameter".format(
- param_name
- )
- )
- sampler_name = cur_param_dict.pop("sampler-type")
- param_sampler = SamplerFactory.init_sampler_class(
- sampler_name, cur_param_dict, seed
- )
-
- self.samplers[param_name] = param_sampler
-
- def is_empty(self) -> bool:
- """
- Check for if sampler_manager is empty.
- """
- return not bool(self.samplers)
-
- def sample_all(self) -> Dict[str, float]:
- """
- Loop over all samplers and draw a sample from each one for generating
- next set of reset parameter values.
- """
- res = {}
- for param_name, param_sampler in list(self.samplers.items()):
- res[param_name] = param_sampler.sample_parameter()
- return res
diff --git a/ml-agents/mlagents/trainers/settings.py b/ml-agents/mlagents/trainers/settings.py
index 5c7830e4b8..49a8e8036b 100644
--- a/ml-agents/mlagents/trainers/settings.py
+++ b/ml-agents/mlagents/trainers/settings.py
@@ -1,15 +1,23 @@
import attr
import cattr
-from typing import Dict, Optional, List, Any, DefaultDict, Mapping
+from typing import Dict, Optional, List, Any, DefaultDict, Mapping, Tuple
from enum import Enum
import collections
import argparse
+import abc
from mlagents.trainers.cli_utils import StoreConfigFile, DetectDefault, parser
from mlagents.trainers.cli_utils import load_config
from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.models import ScheduleType, EncoderType
+from mlagents_envs import logging_util
+from mlagents_envs.side_channel.environment_parameters_channel import (
+ EnvironmentParametersChannel,
+)
+
+logger = logging_util.get_logger(__name__)
+
def check_and_structure(key: str, value: Any, class_type: type) -> Any:
attr_fields_dict = attr.fields_dict(class_type)
@@ -151,6 +159,148 @@ class CuriositySettings(RewardSignalSettings):
learning_rate: float = 3e-4
+class ParameterRandomizationType(Enum):
+ UNIFORM: str = "uniform"
+ GAUSSIAN: str = "gaussian"
+ MULTIRANGEUNIFORM: str = "multirangeuniform"
+
+ def to_settings(self) -> type:
+ _mapping = {
+ ParameterRandomizationType.UNIFORM: UniformSettings,
+ ParameterRandomizationType.GAUSSIAN: GaussianSettings,
+ ParameterRandomizationType.MULTIRANGEUNIFORM: MultiRangeUniformSettings,
+ }
+ return _mapping[self]
+
+
+@attr.s(auto_attribs=True)
+class ParameterRandomizationSettings(abc.ABC):
+ seed: int = parser.get_default("seed")
+
+ @staticmethod
+ def structure(d: Mapping, t: type) -> Any:
+ """
+ Helper method to structure a Dict of ParameterRandomizationSettings classes. Meant to be registered with
+ cattr.register_structure_hook() and called with cattr.structure(). This is needed to handle
+ the special Enum selection of ParameterRandomizationSettings classes.
+ """
+ if not isinstance(d, Mapping):
+ raise TrainerConfigError(
+ f"Unsupported parameter randomization configuration {d}."
+ )
+ d_final: Dict[str, List[float]] = {}
+ for environment_parameter, environment_parameter_config in d.items():
+ if environment_parameter == "resampling-interval":
+ logger.warning(
+ "The resampling-interval is no longer necessary for parameter randomization. It is being ignored."
+ )
+ continue
+ if "sampler_type" not in environment_parameter_config:
+ raise TrainerConfigError(
+ f"Sampler configuration for {environment_parameter} does not contain sampler_type."
+ )
+ if "sampler_parameters" not in environment_parameter_config:
+ raise TrainerConfigError(
+ f"Sampler configuration for {environment_parameter} does not contain sampler_parameters."
+ )
+ enum_key = ParameterRandomizationType(
+ environment_parameter_config["sampler_type"]
+ )
+ t = enum_key.to_settings()
+ d_final[environment_parameter] = strict_to_cls(
+ environment_parameter_config["sampler_parameters"], t
+ )
+ return d_final
+
+ @abc.abstractmethod
+ def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None:
+ """
+ Helper method to send sampler settings over EnvironmentParametersChannel
+ Calls the appropriate sampler type set method.
+ :param key: environment parameter to be sampled
+ :param env_channel: The EnvironmentParametersChannel to communicate sampler settings to environment
+ """
+ pass
+
+
+@attr.s(auto_attribs=True)
+class UniformSettings(ParameterRandomizationSettings):
+ min_value: float = attr.ib()
+ max_value: float = 1.0
+
+ @min_value.default
+ def _min_value_default(self):
+ return 0.0
+
+ @min_value.validator
+ def _check_min_value(self, attribute, value):
+ if self.min_value > self.max_value:
+ raise TrainerConfigError(
+ "Minimum value is greater than maximum value in uniform sampler."
+ )
+
+ def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None:
+ """
+ Helper method to send sampler settings over EnvironmentParametersChannel
+ Calls the uniform sampler type set method.
+ :param key: environment parameter to be sampled
+ :param env_channel: The EnvironmentParametersChannel to communicate sampler settings to environment
+ """
+ env_channel.set_uniform_sampler_parameters(
+ key, self.min_value, self.max_value, self.seed
+ )
+
+
+@attr.s(auto_attribs=True)
+class GaussianSettings(ParameterRandomizationSettings):
+ mean: float = 1.0
+ st_dev: float = 1.0
+
+ def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None:
+ """
+ Helper method to send sampler settings over EnvironmentParametersChannel
+ Calls the gaussian sampler type set method.
+ :param key: environment parameter to be sampled
+ :param env_channel: The EnvironmentParametersChannel to communicate sampler settings to environment
+ """
+ env_channel.set_gaussian_sampler_parameters(
+ key, self.mean, self.st_dev, self.seed
+ )
+
+
+@attr.s(auto_attribs=True)
+class MultiRangeUniformSettings(ParameterRandomizationSettings):
+ intervals: List[Tuple[float, float]] = attr.ib()
+
+ @intervals.default
+ def _intervals_default(self):
+ return [[0.0, 1.0]]
+
+ @intervals.validator
+ def _check_intervals(self, attribute, value):
+ for interval in self.intervals:
+ if len(interval) != 2:
+ raise TrainerConfigError(
+ f"The sampling interval {interval} must contain exactly two values."
+ )
+ min_value, max_value = interval
+ if min_value > max_value:
+ raise TrainerConfigError(
+ f"Minimum value is greater than maximum value in interval {interval}."
+ )
+
+ def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None:
+ """
+ Helper method to send sampler settings over EnvironmentParametersChannel
+ Calls the multirangeuniform sampler type set method.
+ :param key: environment parameter to be sampled
+ :param env_channel: The EnvironmentParametersChannel to communicate sampler settings to environment
+ """
+ env_channel.set_multirangeuniform_sampler_parameters(
+ key, self.intervals, self.seed
+ )
+
+
@attr.s(auto_attribs=True)
class SelfPlaySettings:
save_steps: int = 20000
@@ -303,7 +453,7 @@ class RunOptions(ExportableSettings):
)
env_settings: EnvironmentSettings = attr.ib(factory=EnvironmentSettings)
engine_settings: EngineSettings = attr.ib(factory=EngineSettings)
- parameter_randomization: Optional[Dict] = None
+ parameter_randomization: Optional[Dict[str, ParameterRandomizationSettings]] = None
curriculum: Optional[Dict[str, CurriculumSettings]] = None
checkpoint_settings: CheckpointSettings = attr.ib(factory=CheckpointSettings)
@@ -314,6 +464,10 @@ class RunOptions(ExportableSettings):
cattr.register_structure_hook(EnvironmentSettings, strict_to_cls)
cattr.register_structure_hook(EngineSettings, strict_to_cls)
cattr.register_structure_hook(CheckpointSettings, strict_to_cls)
+ cattr.register_structure_hook(
+ Dict[str, ParameterRandomizationSettings],
+ ParameterRandomizationSettings.structure,
+ )
cattr.register_structure_hook(CurriculumSettings, strict_to_cls)
cattr.register_structure_hook(TrainerSettings, TrainerSettings.structure)
cattr.register_structure_hook(
diff --git a/ml-agents/mlagents/trainers/simple_env_manager.py b/ml-agents/mlagents/trainers/simple_env_manager.py
index b335efd6fc..98cdfbbe99 100644
--- a/ml-agents/mlagents/trainers/simple_env_manager.py
+++ b/ml-agents/mlagents/trainers/simple_env_manager.py
@@ -5,6 +5,7 @@
from mlagents_envs.timers import timed
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.brain import BrainParameters
+from mlagents.trainers.settings import ParameterRandomizationSettings
from mlagents_envs.side_channel.environment_parameters_channel import (
EnvironmentParametersChannel,
)
@@ -44,7 +45,10 @@ def _reset_env(
) -> List[EnvironmentStep]: # type: ignore
if config is not None:
for k, v in config.items():
- self.env_params.set_float_parameter(k, v)
+ if isinstance(v, float):
+ self.env_params.set_float_parameter(k, v)
+ elif isinstance(v, ParameterRandomizationSettings):
+ v.apply(k, self.env_params)
self.env.reset()
all_step_result = self._generate_all_results()
self.previous_step = EnvironmentStep(all_step_result, 0, {}, {})
diff --git a/ml-agents/mlagents/trainers/subprocess_env_manager.py b/ml-agents/mlagents/trainers/subprocess_env_manager.py
index 0687cbdf14..8bf2e4e771 100644
--- a/ml-agents/mlagents/trainers/subprocess_env_manager.py
+++ b/ml-agents/mlagents/trainers/subprocess_env_manager.py
@@ -23,6 +23,7 @@
get_timer_root,
)
from mlagents.trainers.brain import BrainParameters
+from mlagents.trainers.settings import ParameterRandomizationSettings
from mlagents.trainers.action_info import ActionInfo
from mlagents_envs.side_channel.environment_parameters_channel import (
EnvironmentParametersChannel,
@@ -175,7 +176,10 @@ def external_brains():
_send_response(EnvironmentCommand.EXTERNAL_BRAINS, external_brains())
elif req.cmd == EnvironmentCommand.RESET:
for k, v in req.payload.items():
- env_parameters.set_float_parameter(k, v)
+ if isinstance(v, float):
+ env_parameters.set_float_parameter(k, v)
+ elif isinstance(v, ParameterRandomizationSettings):
+ v.apply(k, env_parameters)
env.reset()
all_step_result = _generate_all_results()
_send_response(EnvironmentCommand.RESET, all_step_result)
diff --git a/ml-agents/mlagents/trainers/tests/test_config_conversion.py b/ml-agents/mlagents/trainers/tests/test_config_conversion.py
index 00bfc42ac0..49a1489f12 100644
--- a/ml-agents/mlagents/trainers/tests/test_config_conversion.py
+++ b/ml-agents/mlagents/trainers/tests/test_config_conversion.py
@@ -152,12 +152,20 @@ def test_convert_behaviors(trainer_type, use_recurrent):
assert RewardSignalType.CURIOSITY in trainer_settings.reward_signals
+@mock.patch("mlagents.trainers.upgrade_config.convert_samplers")
@mock.patch("mlagents.trainers.upgrade_config.convert_behaviors")
@mock.patch("mlagents.trainers.upgrade_config.remove_nones")
@mock.patch("mlagents.trainers.upgrade_config.write_to_yaml_file")
@mock.patch("mlagents.trainers.upgrade_config.parse_args")
@mock.patch("mlagents.trainers.upgrade_config.load_config")
-def test_main(mock_load, mock_parse, yaml_write_mock, remove_none_mock, mock_convert):
+def test_main(
+ mock_load,
+ mock_parse,
+ yaml_write_mock,
+ remove_none_mock,
+ mock_convert_behaviors,
+ mock_convert_samplers,
+):
test_output_file = "test.yaml"
mock_load.side_effect = [
yaml.safe_load(PPO_CONFIG),
@@ -171,7 +179,8 @@ def test_main(mock_load, mock_parse, yaml_write_mock, remove_none_mock, mock_con
sampler="test",
)
mock_parse.return_value = mock_args
- mock_convert.return_value = "test_converted_config"
+ mock_convert_behaviors.return_value = "test_converted_config"
+ mock_convert_samplers.return_value = "test_converted_sampler_config"
dict_without_nones = mock.Mock(name="nonones")
remove_none_mock.return_value = dict_without_nones
@@ -181,7 +190,7 @@ def test_main(mock_load, mock_parse, yaml_write_mock, remove_none_mock, mock_con
yaml_write_mock.assert_called_with(dict_without_nones, test_output_file)
assert saved_dict["behaviors"] == "test_converted_config"
assert saved_dict["curriculum"] == "test_curriculum_config"
- assert saved_dict["parameter_randomization"] == "test_sampler_config"
+ assert saved_dict["parameter_randomization"] == "test_converted_sampler_config"
def test_remove_nones():
diff --git a/ml-agents/mlagents/trainers/tests/test_learn.py b/ml-agents/mlagents/trainers/tests/test_learn.py
index 5a67036ade..167fe157e3 100644
--- a/ml-agents/mlagents/trainers/tests/test_learn.py
+++ b/ml-agents/mlagents/trainers/tests/test_learn.py
@@ -7,6 +7,7 @@
from mlagents.trainers.cli_utils import DetectDefault
from mlagents_envs.exception import UnityEnvironmentException
from mlagents.trainers.stats import StatsReporter
+from mlagents.trainers.settings import UniformSettings
def basic_options(extra_args=None):
@@ -45,7 +46,10 @@ def basic_options(extra_args=None):
MOCK_SAMPLER_CURRICULUM_YAML = """
parameter_randomization:
- sampler1: foo
+ sampler1:
+ sampler_type: uniform
+ sampler_parameters:
+ min_value: 0.2
curriculum:
behavior1:
@@ -61,7 +65,6 @@ def basic_options(extra_args=None):
@patch("mlagents.trainers.learn.write_run_options")
@patch("mlagents.trainers.learn.handle_existing_directories")
@patch("mlagents.trainers.learn.TrainerFactory")
-@patch("mlagents.trainers.learn.SamplerManager")
@patch("mlagents.trainers.learn.SubprocessEnvManager")
@patch("mlagents.trainers.learn.create_environment_factory")
@patch("mlagents.trainers.settings.load_config")
@@ -69,7 +72,6 @@ def test_run_training(
load_config,
create_environment_factory,
subproc_env_mock,
- sampler_manager_mock,
trainer_factory_mock,
handle_dir_mock,
write_run_options_mock,
@@ -87,14 +89,7 @@ def test_run_training(
options = basic_options()
learn.run_training(0, options)
mock_init.assert_called_once_with(
- trainer_factory_mock.return_value,
- "results/ppo",
- "ppo",
- None,
- True,
- 0,
- sampler_manager_mock.return_value,
- None,
+ trainer_factory_mock.return_value, "results/ppo", "ppo", None, True, 0
)
handle_dir_mock.assert_called_once_with(
"results/ppo", False, False, "results/notuselessrun"
@@ -216,7 +211,7 @@ def test_yaml_args(mock_file):
@patch("builtins.open", new_callable=mock_open, read_data=MOCK_SAMPLER_CURRICULUM_YAML)
def test_sampler_configs(mock_file):
opt = parse_command_line(["mytrainerpath"])
- assert opt.parameter_randomization == {"sampler1": "foo"}
+ assert isinstance(opt.parameter_randomization["sampler1"], UniformSettings)
assert len(opt.curriculum.keys()) == 2
diff --git a/ml-agents/mlagents/trainers/tests/test_sampler_class.py b/ml-agents/mlagents/trainers/tests/test_sampler_class.py
deleted file mode 100644
index 29954c6599..0000000000
--- a/ml-agents/mlagents/trainers/tests/test_sampler_class.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import pytest
-
-from mlagents.trainers.sampler_class import SamplerManager
-from mlagents.trainers.sampler_class import (
- UniformSampler,
- MultiRangeUniformSampler,
- GaussianSampler,
-)
-from mlagents.trainers.exception import TrainerError
-
-
-def sampler_config_1():
- return {
- "mass": {"sampler-type": "uniform", "min_value": 5, "max_value": 10},
- "gravity": {
- "sampler-type": "multirange_uniform",
- "intervals": [[8, 11], [15, 20]],
- },
- }
-
-
-def check_value_in_intervals(val, intervals):
- check_in_bounds = [a <= val <= b for a, b in intervals]
- return any(check_in_bounds)
-
-
-def test_sampler_config_1():
- config = sampler_config_1()
- sampler = SamplerManager(config)
-
- assert sampler.is_empty() is False
- assert isinstance(sampler.samplers["mass"], UniformSampler)
- assert isinstance(sampler.samplers["gravity"], MultiRangeUniformSampler)
-
- cur_sample = sampler.sample_all()
-
- # Check uniform sampler for mass
- assert sampler.samplers["mass"].min_value == config["mass"]["min_value"]
- assert sampler.samplers["mass"].max_value == config["mass"]["max_value"]
- assert config["mass"]["min_value"] <= cur_sample["mass"]
- assert config["mass"]["max_value"] >= cur_sample["mass"]
-
- # Check multirange_uniform sampler for gravity
- assert sampler.samplers["gravity"].intervals == config["gravity"]["intervals"]
- assert check_value_in_intervals(
- cur_sample["gravity"], sampler.samplers["gravity"].intervals
- )
-
-
-def sampler_config_2():
- return {"angle": {"sampler-type": "gaussian", "mean": 0, "st_dev": 1}}
-
-
-def test_sampler_config_2():
- config = sampler_config_2()
- sampler = SamplerManager(config)
- assert sampler.is_empty() is False
- assert isinstance(sampler.samplers["angle"], GaussianSampler)
-
- # Check angle gaussian sampler
- assert sampler.samplers["angle"].mean == config["angle"]["mean"]
- assert sampler.samplers["angle"].st_dev == config["angle"]["st_dev"]
-
-
-def test_empty_samplers():
- empty_sampler = SamplerManager({})
- assert empty_sampler.is_empty()
- empty_cur_sample = empty_sampler.sample_all()
- assert empty_cur_sample == {}
-
- none_sampler = SamplerManager(None)
- assert none_sampler.is_empty()
- none_cur_sample = none_sampler.sample_all()
- assert none_cur_sample == {}
-
-
-def incorrect_uniform_sampler():
- # Do not specify required arguments to uniform sampler
- return {"mass": {"sampler-type": "uniform", "min-value": 10}}
-
-
-def incorrect_sampler_config():
- # Do not specify 'sampler-type' key
- return {"mass": {"min-value": 2, "max-value": 30}}
-
-
-def test_incorrect_uniform_sampler():
- config = incorrect_uniform_sampler()
- with pytest.raises(TrainerError):
- SamplerManager(config)
-
-
-def test_incorrect_sampler():
- config = incorrect_sampler_config()
- with pytest.raises(TrainerError):
- SamplerManager(config)
diff --git a/ml-agents/mlagents/trainers/tests/test_settings.py b/ml-agents/mlagents/trainers/tests/test_settings.py
index 6a8b2b9355..14928599fa 100644
--- a/ml-agents/mlagents/trainers/tests/test_settings.py
+++ b/ml-agents/mlagents/trainers/tests/test_settings.py
@@ -11,6 +11,10 @@
RewardSignalType,
RewardSignalSettings,
CuriositySettings,
+ ParameterRandomizationSettings,
+ UniformSettings,
+ GaussianSettings,
+ MultiRangeUniformSettings,
TrainerType,
strict_to_cls,
)
@@ -149,3 +153,85 @@ def test_reward_signal_structure():
RewardSignalSettings.structure(
"notadict", Dict[RewardSignalType, RewardSignalSettings]
)
+
+
+def test_parameter_randomization_structure():
+ """
+ Tests the ParameterRandomizationSettings structure method and all validators.
+ """
+ parameter_randomization_dict = {
+ "mass": {
+ "sampler_type": "uniform",
+ "sampler_parameters": {"min_value": 1.0, "max_value": 2.0},
+ },
+ "scale": {
+ "sampler_type": "gaussian",
+ "sampler_parameters": {"mean": 1.0, "st_dev": 2.0},
+ },
+ "length": {
+ "sampler_type": "multirangeuniform",
+ "sampler_parameters": {"intervals": [[1.0, 2.0], [3.0, 4.0]]},
+ },
+ }
+ parameter_randomization_distributions = ParameterRandomizationSettings.structure(
+ parameter_randomization_dict, Dict[str, ParameterRandomizationSettings]
+ )
+ assert isinstance(parameter_randomization_distributions["mass"], UniformSettings)
+ assert isinstance(parameter_randomization_distributions["scale"], GaussianSettings)
+ assert isinstance(
+ parameter_randomization_distributions["length"], MultiRangeUniformSettings
+ )
+
+ # Check invalid distribution type
+ invalid_distribution_dict = {
+ "mass": {
+ "sampler_type": "beta",
+ "sampler_parameters": {"alpha": 1.0, "beta": 2.0},
+ }
+ }
+ with pytest.raises(ValueError):
+ ParameterRandomizationSettings.structure(
+ invalid_distribution_dict, Dict[str, ParameterRandomizationSettings]
+ )
+
+ # Check min less than max in uniform
+ invalid_distribution_dict = {
+ "mass": {
+ "sampler_type": "uniform",
+ "sampler_parameters": {"min_value": 2.0, "max_value": 1.0},
+ }
+ }
+ with pytest.raises(TrainerConfigError):
+ ParameterRandomizationSettings.structure(
+ invalid_distribution_dict, Dict[str, ParameterRandomizationSettings]
+ )
+
+ # Check min less than max in multirange
+ invalid_distribution_dict = {
+ "mass": {
+ "sampler_type": "multirangeuniform",
+ "sampler_parameters": {"intervals": [[2.0, 1.0]]},
+ }
+ }
+ with pytest.raises(TrainerConfigError):
+ ParameterRandomizationSettings.structure(
+ invalid_distribution_dict, Dict[str, ParameterRandomizationSettings]
+ )
+
+ # Check multirange has valid intervals
+ invalid_distribution_dict = {
+ "mass": {
+ "sampler_type": "multirangeuniform",
+ "sampler_parameters": {"intervals": [[1.0, 2.0], [3.0]]},
+ }
+ }
+ with pytest.raises(TrainerConfigError):
+ ParameterRandomizationSettings.structure(
+ invalid_distribution_dict, Dict[str, ParameterRandomizationSettings]
+ )
+
+ # Check non-Dict input
+ with pytest.raises(TrainerConfigError):
+ ParameterRandomizationSettings.structure(
+ "notadict", Dict[str, ParameterRandomizationSettings]
+ )
diff --git a/ml-agents/mlagents/trainers/tests/test_simple_rl.py b/ml-agents/mlagents/trainers/tests/test_simple_rl.py
index 0407601957..67d3c66617 100644
--- a/ml-agents/mlagents/trainers/tests/test_simple_rl.py
+++ b/ml-agents/mlagents/trainers/tests/test_simple_rl.py
@@ -13,7 +13,6 @@
from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.trainer_util import TrainerFactory
from mlagents.trainers.simple_env_manager import SimpleEnvManager
-from mlagents.trainers.sampler_class import SamplerManager
from mlagents.trainers.demo_loader import write_demo
from mlagents.trainers.stats import StatsReporter, StatsWriter, StatsSummary
from mlagents.trainers.settings import (
@@ -138,8 +137,6 @@ def _check_environment_trains(
meta_curriculum=meta_curriculum,
train=True,
training_seed=seed,
- sampler_manager=SamplerManager(None),
- resampling_interval=None,
)
# Begin training
diff --git a/ml-agents/mlagents/trainers/tests/test_trainer_controller.py b/ml-agents/mlagents/trainers/tests/test_trainer_controller.py
index cf2f872531..8a0280dc4b 100644
--- a/ml-agents/mlagents/trainers/tests/test_trainer_controller.py
+++ b/ml-agents/mlagents/trainers/tests/test_trainer_controller.py
@@ -4,7 +4,6 @@
from mlagents.tf_utils import tf
from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.ghost.controller import GhostController
-from mlagents.trainers.sampler_class import SamplerManager
@pytest.fixture
@@ -18,8 +17,6 @@ def basic_trainer_controller():
meta_curriculum=None,
train=True,
training_seed=99,
- sampler_manager=SamplerManager({}),
- resampling_interval=None,
)
@@ -36,8 +33,6 @@ def test_initialization_seed(numpy_random_seed, tensorflow_set_seed):
meta_curriculum=None,
train=True,
training_seed=seed,
- sampler_manager=SamplerManager({}),
- resampling_interval=None,
)
numpy_random_seed.assert_called_with(seed)
tensorflow_set_seed.assert_called_with(seed)
diff --git a/ml-agents/mlagents/trainers/trainer_controller.py b/ml-agents/mlagents/trainers/trainer_controller.py
index 3a8a74d15f..fdb73bca03 100644
--- a/ml-agents/mlagents/trainers/trainer_controller.py
+++ b/ml-agents/mlagents/trainers/trainer_controller.py
@@ -17,7 +17,6 @@
UnityCommunicationException,
UnityCommunicatorStoppedException,
)
-from mlagents.trainers.sampler_class import SamplerManager
from mlagents_envs.timers import (
hierarchical_timer,
timed,
@@ -42,8 +41,6 @@ def __init__(
meta_curriculum: Optional[MetaCurriculum],
train: bool,
training_seed: int,
- sampler_manager: SamplerManager,
- resampling_interval: Optional[int],
):
"""
:param output_path: Path to save the model.
@@ -52,8 +49,6 @@ def __init__(
:param meta_curriculum: MetaCurriculum object which stores information about all curricula.
:param train: Whether to train model, or only run inference.
:param training_seed: Seed to use for Numpy and Tensorflow random number generation.
- :param sampler_manager: SamplerManager object handles samplers for resampling the reset parameters.
- :param resampling_interval: Specifies number of simulation steps after which reset parameters are resampled.
:param threaded: Whether or not to run trainers in a separate thread. Disable for testing/debugging.
"""
self.trainers: Dict[str, Trainer] = {}
@@ -64,8 +59,6 @@ def __init__(
self.run_id = run_id
self.train_model = train
self.meta_curriculum = meta_curriculum
- self.sampler_manager = sampler_manager
- self.resampling_interval = resampling_interval
self.ghost_controller = self.trainer_factory.ghost_controller
self.trainer_threads: List[threading.Thread] = []
@@ -142,12 +135,10 @@ def _reset_env(self, env: EnvManager) -> None:
A Data structure corresponding to the initial reset state of the
environment.
"""
- sampled_reset_param = self.sampler_manager.sample_all()
new_meta_curriculum_config = (
self.meta_curriculum.get_config() if self.meta_curriculum else {}
)
- sampled_reset_param.update(new_meta_curriculum_config)
- env.reset(config=sampled_reset_param)
+ env.reset(config=new_meta_curriculum_config)
def _not_done_training(self) -> bool:
return (
@@ -207,7 +198,6 @@ def _create_trainers_and_managers(
def start_learning(self, env_manager: EnvManager) -> None:
self._create_output_path(self.output_path)
tf.reset_default_graph()
- global_step = 0
last_brain_behavior_ids: Set[str] = set()
try:
# Initial reset
@@ -219,8 +209,7 @@ def start_learning(self, env_manager: EnvManager) -> None:
last_brain_behavior_ids = external_brain_behavior_ids
n_steps = self.advance(env_manager)
for _ in range(n_steps):
- global_step += 1
- self.reset_env_if_ready(env_manager, global_step)
+ self.reset_env_if_ready(env_manager)
# Stop advancing trainers
self.join_threads()
except (
@@ -258,7 +247,7 @@ def end_trainer_episodes(
if changed:
self.trainers[brain_name].reward_buffer.clear()
- def reset_env_if_ready(self, env: EnvManager, steps: int) -> None:
+ def reset_env_if_ready(self, env: EnvManager) -> None:
if self.meta_curriculum:
# Get the sizes of the reward buffers.
reward_buff_sizes = {
@@ -274,16 +263,9 @@ def reset_env_if_ready(self, env: EnvManager, steps: int) -> None:
# If any lessons were incremented or the environment is
# ready to be reset
meta_curriculum_reset = any(lessons_incremented.values())
- # Check if we are performing generalization training and we have finished the
- # specified number of steps for the lesson
- generalization_reset = (
- not self.sampler_manager.is_empty()
- and (steps != 0)
- and (self.resampling_interval)
- and (steps % self.resampling_interval == 0)
- )
+ # If ghost trainer swapped teams
ghost_controller_reset = self.ghost_controller.should_reset()
- if meta_curriculum_reset or generalization_reset or ghost_controller_reset:
+ if meta_curriculum_reset or ghost_controller_reset:
self.end_trainer_episodes(env, lessons_incremented)
@timed
diff --git a/ml-agents/mlagents/trainers/upgrade_config.py b/ml-agents/mlagents/trainers/upgrade_config.py
index 7425708ddb..4263e6cf1f 100644
--- a/ml-agents/mlagents/trainers/upgrade_config.py
+++ b/ml-agents/mlagents/trainers/upgrade_config.py
@@ -82,6 +82,23 @@ def remove_nones(config: Dict[Any, Any]) -> Dict[str, Any]:
return new_config
+# Takes a sampler configuration in the old format and converts it to the new sampler structure
+def convert_samplers(old_sampler_config: Dict[str, Any]) -> Dict[str, Any]:
+ new_sampler_config: Dict[str, Any] = {}
+ for parameter, parameter_config in old_sampler_config.items():
+ if parameter == "resampling-interval":
+ print(
+ "resampling-interval is no longer necessary for parameter randomization and is being ignored."
+ )
+ continue
+ new_sampler_config[parameter] = {}
+ new_sampler_config[parameter]["sampler_type"] = parameter_config["sampler-type"]
+ new_samp_parameters = dict(parameter_config) # Copy dict
+ new_samp_parameters.pop("sampler-type")
+ new_sampler_config[parameter]["sampler_parameters"] = new_samp_parameters
+ return new_sampler_config
+
+
def parse_args():
argparser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -124,7 +141,8 @@ def main() -> None:
full_config["curriculum"] = curriculum_config_dict
if args.sampler is not None:
- sampler_config_dict = load_config(args.sampler)
+ old_sampler_config_dict = load_config(args.sampler)
+ sampler_config_dict = convert_samplers(old_sampler_config_dict)
full_config["parameter_randomization"] = sampler_config_dict
# Convert config to dict