Skip to content

Commit 82fb259

Browse files
author
Chris Elion
authored
Set torch device from commandline (#4888)
1 parent 56898ab commit 82fb259

File tree

8 files changed

+103
-14
lines changed

8 files changed

+103
-14
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ removed when training with a player. The Editor still requires it to be clamped
3030
`AddList()` is recommended, as it does not generate any additional memory allocations. (#4887)
3131

3232
#### ml-agents / ml-agents-envs / gym-unity (Python)
33+
- Added a `--torch-device` commandline option to `mlagents-learn`, which sets the default
34+
[`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device) used for training. (#4888)
35+
- The `--cpu` commandline option had no effect and was removed. Use `--torch-device=cpu` to force CPU training. (#4888)
3336

3437
### Bug Fixes
3538
#### com.unity.ml-agents (C#)

docs/Training-ML-Agents.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,8 @@ using the help utility:
188188
mlagents-learn --help
189189
```
190190

191-
These additional CLI arguments are grouped into environment, engine and checkpoint. The available settings and example values are shown below.
191+
These additional CLI arguments are grouped into environment, engine, checkpoint and torch.
192+
The available settings and example values are shown below.
192193

193194
#### Environment settings
194195

@@ -227,6 +228,13 @@ checkpoint_settings:
227228
inference: false
228229
```
229230
231+
#### Torch settings
232+
233+
```yaml
234+
torch_settings:
235+
device: cpu
236+
```
237+
230238
### Behavior Configurations
231239
232240
The primary section of the trainer config file is a
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from mlagents.torch_utils.torch import torch as torch # noqa
22
from mlagents.torch_utils.torch import nn # noqa
3+
from mlagents.torch_utils.torch import set_torch_config # noqa
34
from mlagents.torch_utils.torch import default_device # noqa

ml-agents/mlagents/torch_utils/torch.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
from distutils.version import LooseVersion
44
import pkg_resources
55
from mlagents.torch_utils import cpu_utils
6+
from mlagents.trainers.settings import TorchSettings
7+
from mlagents_envs.logging_util import get_logger
8+
9+
10+
logger = get_logger(__name__)
611

712

813
def assert_torch_installed():
@@ -32,14 +37,32 @@ def assert_torch_installed():
3237
torch.set_num_threads(cpu_utils.get_num_threads_to_use())
3338
os.environ["KMP_BLOCKTIME"] = "0"
3439

35-
if torch.cuda.is_available():
36-
torch.set_default_tensor_type(torch.cuda.FloatTensor)
37-
device = torch.device("cuda")
38-
else:
39-
torch.set_default_tensor_type(torch.FloatTensor)
40-
device = torch.device("cpu")
40+
41+
_device = torch.device("cpu")
42+
43+
44+
def set_torch_config(torch_settings: TorchSettings) -> None:
    """
    Choose the module-level default torch device and the matching default tensor type.

    :param torch_settings: Settings whose ``device`` attribute is either a device
        string accepted by ``torch.device`` or None. With None, "cuda" is picked
        when ``torch.cuda.is_available()`` is true and "cpu" otherwise.
    """
    global _device

    requested = torch_settings.device
    if requested is None:
        # Nothing requested explicitly: fall back to auto-detection.
        requested = "cuda" if torch.cuda.is_available() else "cpu"
    _device = torch.device(requested)

    # Keep the default tensor type in line with the device type: CUDA float
    # tensors for any cuda device, CPU float tensors for everything else.
    wants_cuda = _device.type == "cuda"
    tensor_type = torch.cuda.FloatTensor if wants_cuda else torch.FloatTensor
    torch.set_default_tensor_type(tensor_type)

    logger.info(f"default Torch device: {_device}")
59+
60+
61+
# Initialize to default settings
62+
set_torch_config(TorchSettings(device=None))
63+
4164
nn = torch.nn
4265

4366

4467
def default_device():
    """
    Return the module-level default torch device, i.e. the value last stored
    by set_torch_config.
    """
    return _device

ml-agents/mlagents/trainers/cli_utils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,6 @@ def _create_parser() -> argparse.ArgumentParser:
177177
"passed to the executable.",
178178
action=DetectDefault,
179179
)
180-
argparser.add_argument(
181-
"--cpu",
182-
default=False,
183-
action=DetectDefaultStoreTrue,
184-
help="Forces training using CPU only",
185-
)
186180
argparser.add_argument(
187181
"--torch",
188182
default=False,
@@ -252,6 +246,15 @@ def _create_parser() -> argparse.ArgumentParser:
252246
help="Whether to run the Unity executable in no-graphics mode (i.e. without initializing "
253247
"the graphics driver. Use this only if your agents don't use visual observations.",
254248
)
249+
250+
torch_conf = argparser.add_argument_group(title="Torch Configuration")
251+
torch_conf.add_argument(
252+
"--torch-device",
253+
default=None,
254+
dest="device",
255+
action=DetectDefault,
256+
help='Settings for the default torch.device used in training, for example, "cpu", "cuda", or "cuda:0"',
257+
)
255258
return argparser
256259

257260

ml-agents/mlagents/trainers/learn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
6262
:param run_options: Command line arguments for training.
6363
"""
6464
with hierarchical_timer("run_training.setup"):
65+
torch_utils.set_torch_config(options.torch_settings)
6566
checkpoint_settings = options.checkpoint_settings
6667
env_settings = options.env_settings
6768
engine_settings = options.engine_settings

ml-agents/mlagents/trainers/settings.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,11 @@ class EngineSettings:
733733
no_graphics: bool = parser.get_default("no_graphics")
734734

735735

736+
@attr.s(auto_attribs=True)
class TorchSettings:
    """
    Settings that control how torch is configured for a training run.
    """

    # Device string such as "cpu", "cuda" or "cuda:0"; None means the device is
    # auto-detected (cuda when available, otherwise cpu).
    # The default is looked up by the argparse `dest` of --torch-device, which is
    # "device". The previous key "torch_device" matched no argument, so it only
    # returned None by accident and would not follow a changed CLI default.
    device: Optional[str] = parser.get_default("device")
739+
740+
736741
@attr.s(auto_attribs=True)
737742
class RunOptions(ExportableSettings):
738743
default_settings: Optional[TrainerSettings] = None
@@ -743,6 +748,7 @@ class RunOptions(ExportableSettings):
743748
engine_settings: EngineSettings = attr.ib(factory=EngineSettings)
744749
environment_parameters: Optional[Dict[str, EnvironmentParameterSettings]] = None
745750
checkpoint_settings: CheckpointSettings = attr.ib(factory=CheckpointSettings)
751+
torch_settings: TorchSettings = attr.ib(factory=TorchSettings)
746752

747753
# These are options that are relevant to the run itself, and not the engine or environment.
748754
# They will be left here.
@@ -784,6 +790,7 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions":
784790
"checkpoint_settings": {},
785791
"env_settings": {},
786792
"engine_settings": {},
793+
"torch_settings": {},
787794
}
788795
if config_path is not None:
789796
configured_dict.update(load_config(config_path))
@@ -808,6 +815,8 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions":
808815
configured_dict["env_settings"][key] = val
809816
elif key in attr.fields_dict(EngineSettings):
810817
configured_dict["engine_settings"][key] = val
818+
elif key in attr.fields_dict(TorchSettings):
819+
configured_dict["torch_settings"][key] = val
811820
else: # Base options
812821
configured_dict[key] = val
813822

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import pytest
2+
from unittest import mock
3+
4+
import torch # noqa I201
5+
6+
from mlagents.torch_utils import set_torch_config, default_device
7+
from mlagents.trainers.settings import TorchSettings
8+
9+
10+
@pytest.mark.parametrize(
    "device_str, expected_type, expected_index, expected_tensor_type",
    [
        ("cpu", "cpu", None, torch.FloatTensor),
        ("cuda", "cuda", None, torch.cuda.FloatTensor),
        ("cuda:42", "cuda", 42, torch.cuda.FloatTensor),
        ("opengl", "opengl", None, torch.FloatTensor),
    ],
)
@mock.patch.object(torch, "set_default_tensor_type")
def test_set_torch_device(
    mock_set_default_tensor_type,
    device_str,
    expected_type,
    expected_index,
    expected_tensor_type,
):
    """
    Check that set_torch_config stores the requested device and requests the
    matching default tensor type.

    torch.set_default_tensor_type is patched for the whole test, so the real
    default tensor type is never modified.
    """
    try:
        set_torch_config(TorchSettings(device=device_str))

        device = default_device()
        assert device.type == expected_type
        # The expected index is None for device strings without an ordinal, so
        # a single equality check covers both the "no index" and "index" cases.
        assert device.index == expected_index
        mock_set_default_tensor_type.assert_called_once_with(expected_tensor_type)
    finally:
        # Restore the auto-detected default device even when an assertion
        # fails; any exception still propagates without an explicit re-raise.
        set_torch_config(TorchSettings(device=None))

0 commit comments

Comments
 (0)