
Float multiplication of numpy.ndarray in lambda is incorrectly analysed #8001

@jamesohortle


Note: if you are reporting a wrong signature of a function or a class in
the standard library, then the typeshed tracker is better suited
for this report: https://github.com/python/typeshed/issues

Please provide more information to help us understand the issue:

  • Are you reporting a bug, or opening a feature request?

Bug.

  • Please insert below the code you are checking with mypy,
    or a mock-up repro if the source is private. We would appreciate it
    if you tried to simplify your case to a minimal repro.
```python
from typing import NewType, Tuple, Deque

import numpy as np


class Point(np.ndarray):
    def __new__(cls, x: float, y: float) -> np.ndarray:
        return np.array((x, y), dtype=np.float32)


Nose = NewType("Nose", Tuple[Point, ...])


def average_position(positions: Deque) -> Nose:
    factor = 1.0 / len(positions)
    point_sum = [Point(0.0, 0.0) for _ in range(10)]
    for nose in positions:
        for i, point in enumerate(nose):
            point_sum[i] += point
    avg_pos = Nose(tuple(map(lambda p: factor * p, point_sum)))
    return avg_pos


def average_position_with_multiply(positions: Deque) -> Nose:
    factor = 1.0 / len(positions)
    point_sum = [Point(0.0, 0.0) for _ in range(10)]
    for nose in positions:
        for i, point in enumerate(nose):
            point_sum[i] += point
    avg_pos = Nose(tuple(map(lambda p: np.multiply(factor, p), point_sum)))
    return avg_pos
```

  • What is the actual behavior/output?

mypy reports the following errors:

```
bug.py:20: error: Argument 1 to "map" has incompatible type "Callable[[Point], float]"; expected "Callable[[Point], Point]"
bug.py:20: error: Incompatible return value type (got "float", expected "Point")
Found 2 errors in 1 file (checked 1 source file)
```

For the lambda in average_position(), mypy infers the type Callable[[Point], float], which appears to be incorrect. The equivalent lambda in average_position_with_multiply() calls the NumPy function np.multiply() instead and produces no error. Both functions return the same, correct values at runtime.
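This report does not confirm why the two versions differ. One plausible reading starts from the assumption that `ignore_missing_imports` makes `numpy`, and therefore `np.ndarray`, resolve to `Any`. `Point` would then have an `Any` base class, and mypy treats such a class as compatible with `float`. It would accept `p` as the argument of `float.__mul__` and report that method's `float` return type. Under the same assumption `np.multiply` is itself `Any`, so the second lambda returns `Any` and nothing is flagged.

At runtime the product is still an array. `float.__mul__` returns `NotImplemented` for an operand it cannot handle, and Python then calls the right operand's `__rmul__`. The sketch below shows that fallback. It uses a hypothetical pure-Python `Vec2` class in place of the NumPy-backed `Point`, so it runs without third-party packages:

```python
from __future__ import annotations


class Vec2:
    """Hypothetical stand-in for the NumPy-backed Point (illustration only)."""

    def __init__(self, x: float, y: float) -> None:
        self.x = x
        self.y = y

    def __mul__(self, k: float) -> Vec2:
        # Scale both components by a scalar.
        return Vec2(self.x * k, self.y * k)

    # float.__mul__(Vec2) returns NotImplemented, so Python falls back to this.
    __rmul__ = __mul__

    def __repr__(self) -> str:
        return f"Vec2({self.x}, {self.y})"


factor = 1.0 / 4
p = Vec2(2.0, 6.0)

print(float.__mul__(factor, p))  # NotImplemented: float cannot handle a Vec2
scaled = factor * p              # Python retries with Vec2.__rmul__
print(scaled)                    # Vec2(0.5, 1.5)
```

Aliasing `__rmul__` to `__mul__` is only valid because scaling by a scalar is commutative. NumPy's own `ndarray.__rmul__` plays the same role for the arrays in the report.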

  • What is the behavior/output you expect?
    There should be no error.

  • What are the versions of mypy and Python you are using?
    Do you see the same issue after installing mypy from Git master?

mypy==0.740
Python 3.8.0

  • What are the mypy flags you are using? (For example --strict-optional)

```ini
[mypy]
python_version = 3.8

[mypy-cv2,dlib,numpy]
ignore_missing_imports = True
```

I am not sure whether this is a genuine bug or something I have done wrong, but both functions run and produce correct output.
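One option for getting the first version past mypy is to move the multiplication into a small helper with an explicit return type and `typing.cast` the product. This assumes, as above, that `np.ndarray` is `Any` because of `ignore_missing_imports`. The sketch below uses a hypothetical pure-Python `Point` so it runs without NumPy. `scale` and `average` are illustrative names and are not part of the original code:

```python
from typing import List, NewType, Tuple, cast


class Point:
    """Hypothetical pure-Python stand-in for the NumPy-based Point in the report."""

    def __init__(self, x: float, y: float) -> None:
        self.x = x
        self.y = y

    def __rmul__(self, k: float) -> "Point":
        # Reached at runtime because float.__mul__(Point) returns NotImplemented.
        return Point(k * self.x, k * self.y)


Nose = NewType("Nose", Tuple[Point, ...])


def scale(factor: float, p: Point) -> Point:
    # cast() is a no-op at runtime; it only tells mypy to treat the product
    # as a Point, which matters when Point's real base class is hidden behind Any.
    return cast(Point, factor * p)


def average(point_sum: List[Point], count: int) -> Nose:
    factor = 1.0 / count
    # The lambda now has the type Callable[[Point], Point] that map() expects.
    return Nose(tuple(map(lambda p: scale(factor, p), point_sum)))


result = average([Point(2.0, 4.0), Point(6.0, 8.0)], 2)
print([(q.x, q.y) for q in result])  # [(1.0, 2.0), (3.0, 4.0)]
```

With this fully typed stand-in, the `cast` should be redundant, because mypy can resolve `Point.__rmul__` on its own. It is only useful when `Point` inherits from an untyped `np.ndarray`.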
