Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ before_install: &before_install
- pip install -r requirements-dev.txt

install:
- pip install "mypy==0.782"
- python setup.py install

script:
- mypy --config-file mypy.ini
- CI_PYTHON_VERSION="$TRAVIS_PYTHON_VERSION" sh tests/run_cpu_tests.sh

after_success:
Expand All @@ -51,12 +53,11 @@ jobs:
- stage: Lint check
python: "3.7"
before_install: # Nothing to do
install: pip install flake8 "black==19.10b0" "isort==4.3.21" "mypy==0.782"
install: pip install flake8 "black==19.10b0" "isort==4.3.21"
script:
- flake8 .
- black --check .
- isort -rc -c .
- mypy --config-file mypy.ini
after_success: # Nothing to do

# GitHub Pages Deployment: https://docs.travis-ci.com/user/deployment/pages/
Expand Down
10 changes: 6 additions & 4 deletions ignite/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import collections.abc as collections
import logging
import random
from typing import Any, Callable, Optional, Tuple, Type, Union
from typing import Any, Callable, Optional, Tuple, Type, Union, cast

import torch

Expand Down Expand Up @@ -41,11 +41,13 @@ def apply_to_type(
if isinstance(input_, (str, bytes)):
return input_
if isinstance(input_, collections.Mapping):
return type(input_)({k: apply_to_type(sample, input_type, func) for k, sample in input_.items()})
return cast(Callable, type(input_))(
{k: apply_to_type(sample, input_type, func) for k, sample in input_.items()}
)
if isinstance(input_, tuple) and hasattr(input_, "_fields"): # namedtuple
return type(input_)(*(apply_to_type(sample, input_type, func) for sample in input_))
return cast(Callable, type(input_))(*(apply_to_type(sample, input_type, func) for sample in input_))
if isinstance(input_, collections.Sequence):
return type(input_)([apply_to_type(sample, input_type, func) for sample in input_])
return cast(Callable, type(input_))([apply_to_type(sample, input_type, func) for sample in input_])
raise TypeError(("input must contain {}, dicts or lists; found {}".format(input_type, type(input_))))


Expand Down
5 changes: 2 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,5 @@ ignore_errors = True

ignore_errors = True

[mypy-ignite.utils.*]

ignore_errors = True
[mypy-numpy.*]
ignore_missing_imports = True