Skip to content

Support matrix multiplication operator (@) #2287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 20, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,6 @@ def from_operator(self, op: ast35.operator) -> str:
op_name = ASTConverter.op_map.get(type(op))
if op_name is None:
raise RuntimeError('Unknown operator ' + str(type(op)))
elif op_name == '@':
raise RuntimeError('mypy does not support the MatMult operator')
else:
return op_name

Expand Down
4 changes: 2 additions & 2 deletions mypy/lex.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,13 @@ def lex(string: Union[str, bytes], first_line: int = 1,

# List of regular expressions that match non-alphabetical operators
operators = [re.compile('[-+*/<>.%&|^~]'),
re.compile('==|!=|<=|>=|\\*\\*|//|<<|>>|<>')]
re.compile('==|!=|<=|>=|\\*\\*|@|//|<<|>>|<>')]

# List of regular expressions that match punctuator tokens
punctuators = [re.compile('[=,()@`]|(->)'),
re.compile('\\['),
re.compile(']'),
re.compile('([-+*/%&|^]|\\*\\*|//|<<|>>)=')]
re.compile('([-+*/%@&|^]|\\*\\*|//|<<|>>)=')]


# Map single-character string escape sequences to corresponding characters.
Expand Down
4 changes: 3 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,7 @@ def accept(self, visitor: NodeVisitor[T]) -> T:
'%': '__mod__',
'//': '__floordiv__',
'**': '__pow__',
'@': '__matmul__',
'&': '__and__',
'|': '__or__',
'^': '__xor__',
Expand All @@ -1349,7 +1350,7 @@ def accept(self, visitor: NodeVisitor[T]) -> T:


ops_with_inplace_method = {
'+', '-', '*', '/', '%', '//', '**', '&', '|', '^', '<<', '>>'}
'+', '-', '*', '/', '%', '//', '**', '@', '&', '|', '^', '<<', '>>'}

inplace_operator_methods = set(
'__i' + op_methods[op][2:] for op in ops_with_inplace_method)
Expand All @@ -1362,6 +1363,7 @@ def accept(self, visitor: NodeVisitor[T]) -> T:
'__mod__': '__rmod__',
'__floordiv__': '__rfloordiv__',
'__pow__': '__rpow__',
'__matmul__': '__rmatmul__',
'__and__': '__rand__',
'__or__': '__ror__',
'__xor__': '__rxor__',
Expand Down
4 changes: 2 additions & 2 deletions mypy/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
'**': 16,
'-u': 15, '+u': 15, '~': 15, # unary operators (-, + and ~)
'<cast>': 14,
'*': 13, '/': 13, '//': 13, '%': 13,
'*': 13, '/': 13, '//': 13, '%': 13, '@': 13,
'+': 12, '-': 12,
'>>': 11, '<<': 11,
'&': 10,
Expand All @@ -61,7 +61,7 @@


op_assign = set([
'+=', '-=', '*=', '/=', '//=', '%=', '**=', '|=', '&=', '^=', '>>=',
'+=', '-=', '*=', '/=', '//=', '%=', '**=', '@=', '|=', '&=', '^=', '>>=',
'<<='])

op_comp = set([
Expand Down
15 changes: 15 additions & 0 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,21 @@ main:3: error: Unsupported operand types for * ("A" and "C")
main:4: error: Incompatible types in assignment (expression has type "C", variable has type "A")
main:5: error: Unsupported left operand type for * ("B")

[case testMatMul]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need separate tests for --fast-parser?

Copy link
Contributor Author

@elazarg elazarg Oct 19, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know. Do we? It already supports the operator. Looks like this work is delegated to typed_ast.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added a test to verify that the fast parser recognizes the operator. As far as I understand there's nothing more to test there, not more than there is for any other operator.

a, b, c = None, None, None # type: (A, B, C)
c = a @ c # E: Unsupported operand types for @ ("A" and "C")
a = a @ b # E: Incompatible types in assignment (expression has type "C", variable has type "A")
c = b @ a # E: Unsupported left operand type for @ ("B")
c = a @ b

class A:
def __matmul__(self, x: 'B') -> 'C':
pass
class B:
pass
class C:
pass

[case testDiv]

a, b, c = None, None, None # type: (A, B, C)
Expand Down
7 changes: 7 additions & 0 deletions test-data/unit/check-fastparse.test
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,10 @@ def f(a):
pass
[out]
main:3: error: invalid type comment

[case testFastParseMatMul]
# flags: --fast-parser
from typing import Any
x = None # type: Any
x @ 1
x @= 1
14 changes: 14 additions & 0 deletions test-data/unit/check-statements.test
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,17 @@ class C: pass
main:3: error: Unsupported operand types for * ("A" and "A")
main:4: error: Unsupported left operand type for * ("C")

[case testMatMulAssign]
a, c = None, None # type: (A, C)
a @= a # E: Unsupported operand types for @ ("A" and "A")
c @= a # E: Unsupported left operand type for @ ("C")
a @= c

class A:
def __matmul__(self, x: 'C') -> 'A': pass

class C: pass

[case testDivAssign]

a, c = None, None # type: (A, C)
Expand Down Expand Up @@ -295,11 +306,14 @@ import typing
class A:
def __iadd__(self, x: int) -> 'A': pass
def __imul__(self, x: str) -> 'A': pass
def __imatmul__(self, x: str) -> 'A': pass
a = A()
a += 1
a *= ''
a @= ''
a += '' # E: Argument 1 to "__iadd__" of "A" has incompatible type "str"; expected "int"
a *= 1 # E: Argument 1 to "__imul__" of "A" has incompatible type "int"; expected "str"
a @= 1 # E: Argument 1 to "__imatmul__" of "A" has incompatible type "int"; expected "str"

[case testInplaceSetitem]
class A(object):
Expand Down