Skip to content

Commit 48fa2ef

Browse files
elazarggvanrossum
authored andcommitted
Support matrix multiplication operator (@) (#2287)
Closes #705.
1 parent 9a6cce0 commit 48fa2ef

File tree

7 files changed

+43
-7
lines changed

7 files changed

+43
-7
lines changed

mypy/fastparse.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,6 @@ def from_operator(self, op: ast35.operator) -> str:
157157
op_name = ASTConverter.op_map.get(type(op))
158158
if op_name is None:
159159
raise RuntimeError('Unknown operator ' + str(type(op)))
160-
elif op_name == '@':
161-
raise RuntimeError('mypy does not support the MatMult operator')
162160
else:
163161
return op_name
164162

mypy/lex.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,13 +195,13 @@ def lex(string: Union[str, bytes], first_line: int = 1,
195195

196196
# List of regular expressions that match non-alphabetical operators
197197
operators = [re.compile('[-+*/<>.%&|^~]'),
198-
re.compile('==|!=|<=|>=|\\*\\*|//|<<|>>|<>')]
198+
re.compile('==|!=|<=|>=|\\*\\*|@|//|<<|>>|<>')]
199199

200200
# List of regular expressions that match punctuator tokens
201201
punctuators = [re.compile('[=,()@`]|(->)'),
202202
re.compile('\\['),
203203
re.compile(']'),
204-
re.compile('([-+*/%&|^]|\\*\\*|//|<<|>>)=')]
204+
re.compile('([-+*/%@&|^]|\\*\\*|//|<<|>>)=')]
205205

206206

207207
# Map single-character string escape sequences to corresponding characters.

mypy/nodes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1328,6 +1328,7 @@ def accept(self, visitor: NodeVisitor[T]) -> T:
13281328
'%': '__mod__',
13291329
'//': '__floordiv__',
13301330
'**': '__pow__',
1331+
'@': '__matmul__',
13311332
'&': '__and__',
13321333
'|': '__or__',
13331334
'^': '__xor__',
@@ -1349,7 +1350,7 @@ def accept(self, visitor: NodeVisitor[T]) -> T:
13491350

13501351

13511352
ops_with_inplace_method = {
1352-
'+', '-', '*', '/', '%', '//', '**', '&', '|', '^', '<<', '>>'}
1353+
'+', '-', '*', '/', '%', '//', '**', '@', '&', '|', '^', '<<', '>>'}
13531354

13541355
inplace_operator_methods = set(
13551356
'__i' + op_methods[op][2:] for op in ops_with_inplace_method)
@@ -1362,6 +1363,7 @@ def accept(self, visitor: NodeVisitor[T]) -> T:
13621363
'__mod__': '__rmod__',
13631364
'__floordiv__': '__rfloordiv__',
13641365
'__pow__': '__rpow__',
1366+
'__matmul__': '__rmatmul__',
13651367
'__and__': '__rand__',
13661368
'__or__': '__ror__',
13671369
'__xor__': '__rxor__',

mypy/parse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
'**': 16,
4545
'-u': 15, '+u': 15, '~': 15, # unary operators (-, + and ~)
4646
'<cast>': 14,
47-
'*': 13, '/': 13, '//': 13, '%': 13,
47+
'*': 13, '/': 13, '//': 13, '%': 13, '@': 13,
4848
'+': 12, '-': 12,
4949
'>>': 11, '<<': 11,
5050
'&': 10,
@@ -61,7 +61,7 @@
6161

6262

6363
op_assign = set([
64-
'+=', '-=', '*=', '/=', '//=', '%=', '**=', '|=', '&=', '^=', '>>=',
64+
'+=', '-=', '*=', '/=', '//=', '%=', '**=', '@=', '|=', '&=', '^=', '>>=',
6565
'<<='])
6666

6767
op_comp = set([

test-data/unit/check-expressions.test

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,21 @@ main:3: error: Unsupported operand types for * ("A" and "C")
182182
main:4: error: Incompatible types in assignment (expression has type "C", variable has type "A")
183183
main:5: error: Unsupported left operand type for * ("B")
184184

185+
[case testMatMul]
186+
a, b, c = None, None, None # type: (A, B, C)
187+
c = a @ c # E: Unsupported operand types for @ ("A" and "C")
188+
a = a @ b # E: Incompatible types in assignment (expression has type "C", variable has type "A")
189+
c = b @ a # E: Unsupported left operand type for @ ("B")
190+
c = a @ b
191+
192+
class A:
193+
def __matmul__(self, x: 'B') -> 'C':
194+
pass
195+
class B:
196+
pass
197+
class C:
198+
pass
199+
185200
[case testDiv]
186201

187202
a, b, c = None, None, None # type: (A, B, C)

test-data/unit/check-fastparse.test

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,10 @@ def f(a):
179179
pass
180180
[out]
181181
main:3: error: invalid type comment
182+
183+
[case testFastParseMatMul]
184+
# flags: --fast-parser
185+
from typing import Any
186+
x = None # type: Any
187+
x @ 1
188+
x @= 1

test-data/unit/check-statements.test

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,17 @@ class C: pass
228228
main:3: error: Unsupported operand types for * ("A" and "A")
229229
main:4: error: Unsupported left operand type for * ("C")
230230

231+
[case testMatMulAssign]
232+
a, c = None, None # type: (A, C)
233+
a @= a # E: Unsupported operand types for @ ("A" and "A")
234+
c @= a # E: Unsupported left operand type for @ ("C")
235+
a @= c
236+
237+
class A:
238+
def __matmul__(self, x: 'C') -> 'A': pass
239+
240+
class C: pass
241+
231242
[case testDivAssign]
232243

233244
a, c = None, None # type: (A, C)
@@ -295,11 +306,14 @@ import typing
295306
class A:
296307
def __iadd__(self, x: int) -> 'A': pass
297308
def __imul__(self, x: str) -> 'A': pass
309+
def __imatmul__(self, x: str) -> 'A': pass
298310
a = A()
299311
a += 1
300312
a *= ''
313+
a @= ''
301314
a += '' # E: Argument 1 to "__iadd__" of "A" has incompatible type "str"; expected "int"
302315
a *= 1 # E: Argument 1 to "__imul__" of "A" has incompatible type "int"; expected "str"
316+
a @= 1 # E: Argument 1 to "__imatmul__" of "A" has incompatible type "int"; expected "str"
303317

304318
[case testInplaceSetitem]
305319
class A(object):

0 commit comments

Comments
 (0)