Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ removed when training with a player. The Editor still requires it to be clamped
Changed the namespace and file names of classes in com.unity.ml-agents.extensions. (#4849)

#### ml-agents / ml-agents-envs / gym-unity (Python)
- Added a `--torch-device` commandline option to `mlagents-learn`, which sets the default
[`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device) used for training. (#4888)

### Bug Fixes
#### com.unity.ml-agents (C#)
Expand Down
1 change: 1 addition & 0 deletions ml-agents/mlagents/torch_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from mlagents.torch_utils.torch import torch as torch # noqa
from mlagents.torch_utils.torch import nn # noqa
from mlagents.torch_utils.torch import set_torch_config # noqa
from mlagents.torch_utils.torch import default_device # noqa
37 changes: 30 additions & 7 deletions ml-agents/mlagents/torch_utils/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
from distutils.version import LooseVersion
import pkg_resources
from mlagents.torch_utils import cpu_utils
from mlagents.trainers.settings import TorchSettings
from mlagents_envs.logging_util import get_logger


logger = get_logger(__name__)


def assert_torch_installed():
Expand Down Expand Up @@ -32,14 +37,32 @@ def assert_torch_installed():
torch.set_num_threads(cpu_utils.get_num_threads_to_use())
os.environ["KMP_BLOCKTIME"] = "0"

if torch.cuda.is_available():
torch.set_default_tensor_type(torch.cuda.FloatTensor)
device = torch.device("cuda")
else:
torch.set_default_tensor_type(torch.FloatTensor)
device = torch.device("cpu")

_device = torch.device("cpu")


def set_torch_config(torch_settings: TorchSettings) -> None:
    """
    Select the module-wide default Torch device and matching default tensor type.

    :param torch_settings: Settings whose ``device`` is either a device string
        passed straight to ``torch.device`` or ``None``. With ``None``, "cuda" is
        used when ``torch.cuda.is_available()`` reports True, otherwise "cpu".
    """
    global _device

    requested = torch_settings.device
    if requested is None:
        # No explicit request: prefer the GPU whenever one is usable.
        requested = "cuda" if torch.cuda.is_available() else "cpu"
    _device = torch.device(requested)

    # Only CUDA devices get the CUDA tensor type; every other device type
    # falls back to the CPU float tensor.
    tensor_type = (
        torch.cuda.FloatTensor if _device.type == "cuda" else torch.FloatTensor
    )
    torch.set_default_tensor_type(tensor_type)
    logger.info(f"default Torch device: {_device}")


# Initialize to default settings at import time: with device=None,
# set_torch_config picks "cuda" when torch.cuda.is_available() is True and
# "cpu" otherwise, so the module has a usable default device before any
# caller configures one explicitly.
set_torch_config(TorchSettings(device=None))

# Module-level alias for torch.nn.
nn = torch.nn


def default_device():
    """
    Return the ``torch.device`` currently stored in the module-level ``_device``.

    ``_device`` is rebound by ``set_torch_config``, so this reflects the most
    recent configuration rather than a value fixed at import time.
    """
    # The stale `return device` was dropped: it shadowed this return and
    # referred to the old global that set_torch_config never updates.
    return _device
9 changes: 9 additions & 0 deletions ml-agents/mlagents/trainers/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,15 @@ def _create_parser() -> argparse.ArgumentParser:
help="Whether to run the Unity executable in no-graphics mode (i.e. without initializing "
"the graphics driver. Use this only if your agents don't use visual observations.",
)

torch_conf = argparser.add_argument_group(title="Torch Configuration")
torch_conf.add_argument(
"--torch-device",
default=None,
dest="device",
action=DetectDefault,
help='Settings for the default torch.device used in training, for example, "cpu", "cuda", or "cuda:0"',
)
return argparser


Expand Down
1 change: 1 addition & 0 deletions ml-agents/mlagents/trainers/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
:param run_options: Command line arguments for training.
"""
with hierarchical_timer("run_training.setup"):
torch_utils.set_torch_config(options.torch_settings)
checkpoint_settings = options.checkpoint_settings
env_settings = options.env_settings
engine_settings = options.engine_settings
Expand Down
9 changes: 9 additions & 0 deletions ml-agents/mlagents/trainers/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,11 @@ class EngineSettings:
no_graphics: bool = parser.get_default("no_graphics")


@attr.s(auto_attribs=True)
class TorchSettings:
    """Torch-specific run options."""

    # Device string handed to torch.device (e.g. "cpu", "cuda", "cuda:0").
    # None leaves the choice to set_torch_config, which uses "cuda" when
    # available and "cpu" otherwise.
    # The default is looked up under "device" because --torch-device is
    # registered with dest="device"; get_default returns None for a dest it
    # does not know, so the previous "torch_device" key never matched.
    # NOTE(review): assumes `parser` is the parser built by cli_utils - confirm.
    device: Optional[str] = parser.get_default("device")


@attr.s(auto_attribs=True)
class RunOptions(ExportableSettings):
default_settings: Optional[TrainerSettings] = None
Expand All @@ -743,6 +748,7 @@ class RunOptions(ExportableSettings):
engine_settings: EngineSettings = attr.ib(factory=EngineSettings)
environment_parameters: Optional[Dict[str, EnvironmentParameterSettings]] = None
checkpoint_settings: CheckpointSettings = attr.ib(factory=CheckpointSettings)
torch_settings: TorchSettings = attr.ib(factory=TorchSettings)

# These are options that are relevant to the run itself, and not the engine or environment.
# They will be left here.
Expand Down Expand Up @@ -784,6 +790,7 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions":
"checkpoint_settings": {},
"env_settings": {},
"engine_settings": {},
"torch_settings": {},
}
if config_path is not None:
configured_dict.update(load_config(config_path))
Expand All @@ -808,6 +815,8 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions":
configured_dict["env_settings"][key] = val
elif key in attr.fields_dict(EngineSettings):
configured_dict["engine_settings"][key] = val
elif key in attr.fields_dict(TorchSettings):
configured_dict["torch_settings"][key] = val
else: # Base options
configured_dict[key] = val

Expand Down
41 changes: 41 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_torch_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest
from unittest import mock

import torch # noqa I201

from mlagents.torch_utils import set_torch_config, default_device
from mlagents.trainers.settings import TorchSettings


@pytest.mark.parametrize(
    "device_str, expected_type, expected_index, expected_tensor_type",
    [
        ("cpu", "cpu", None, torch.FloatTensor),
        ("cuda", "cuda", None, torch.cuda.FloatTensor),
        ("cuda:42", "cuda", 42, torch.cuda.FloatTensor),
        ("opengl", "opengl", None, torch.FloatTensor),
    ],
)
@mock.patch.object(torch, "set_default_tensor_type")
def test_set_torch_device(
    mock_set_default_tensor_type,
    device_str,
    expected_type,
    expected_index,
    expected_tensor_type,
):
    """
    Check that set_torch_config stores the requested device and requests the
    matching default tensor type.

    torch.set_default_tensor_type is patched, so the CUDA cases only verify
    which tensor type would have been requested.
    """
    try:
        set_torch_config(TorchSettings(device=device_str))

        selected = default_device()
        assert selected.type == expected_type
        # Covers both the "no index" (None) and the explicit-index cases.
        assert selected.index == expected_index
        mock_set_default_tensor_type.assert_called_once_with(expected_tensor_type)
    finally:
        # Restore the defaults so later tests do not inherit this device.
        set_torch_config(TorchSettings(device=None))