Skip to content

Commit 6a3f900

Browse files
committed
Improved typing
1 parent 697ce76 commit 6a3f900

19 files changed

+235
-221
lines changed

graphql/execution/executor.py

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@
4747
def subscribe(*args, **kwargs):
4848
# type: (*Any, **Any) -> Union[ExecutionResult, Observable]
4949
allow_subscriptions = kwargs.pop("allow_subscriptions", True)
50-
return execute(*args, allow_subscriptions=allow_subscriptions, **kwargs)
50+
return execute( # type: ignore
51+
*args, allow_subscriptions=allow_subscriptions, **kwargs
52+
)
5153

5254

5355
def execute(
@@ -116,7 +118,7 @@ def execute(
116118
allow_subscriptions,
117119
)
118120

119-
def executor(v):
121+
def promise_executor(v):
120122
# type: (Optional[Any]) -> Union[OrderedDict, Promise, Observable]
121123
return execute_operation(exe_context, exe_context.operation, root)
122124

@@ -135,7 +137,9 @@ def on_resolve(data):
135137

136138
return ExecutionResult(data=data, errors=exe_context.errors)
137139

138-
promise = Promise.resolve(None).then(executor).catch(on_rejected).then(on_resolve)
140+
promise = (
141+
Promise.resolve(None).then(promise_executor).catch(on_rejected).then(on_resolve)
142+
)
139143

140144
if not return_promise:
141145
exe_context.executor.wait_until_finished()
@@ -151,7 +155,7 @@ def on_resolve(data):
151155
def execute_operation(
152156
exe_context, # type: ExecutionContext
153157
operation, # type: OperationDefinition
154-
root_value, # type: Union[None, Data, type]
158+
root_value, # type: Any
155159
):
156160
# type: (...) -> Union[OrderedDict, Promise]
157161
type = get_operation_root_type(exe_context.schema, operation)
@@ -224,7 +228,7 @@ def execute_fields(
224228
parent_type, # type: GraphQLObjectType
225229
source_value, # type: Any
226230
fields, # type: DefaultOrderedDict
227-
path, # type: Union[List[Union[int, str]], List[str]]
231+
path, # type: List[Union[int, str]]
228232
info, # type: Optional[ResolveInfo]
229233
):
230234
# type: (...) -> Union[OrderedDict, Promise]
@@ -257,39 +261,43 @@ def execute_fields(
257261
def subscribe_fields(
258262
exe_context, # type: ExecutionContext
259263
parent_type, # type: GraphQLObjectType
260-
source_value, # type: Union[None, Data, type]
264+
source_value, # type: Any
261265
fields, # type: DefaultOrderedDict
262266
):
263267
# type: (...) -> Observable
264-
exe_context = SubscriberExecutionContext(exe_context)
268+
subscriber_exe_context = SubscriberExecutionContext(exe_context)
265269

266270
def on_error(error):
267-
exe_context.report_error(error)
271+
subscriber_exe_context.report_error(error)
268272

269273
def map_result(
270274
data # type: Union[Dict[str, None], Dict[str, OrderedDict], Dict[str, str]]
271275
):
272276
# type: (...) -> ExecutionResult
273-
if exe_context.errors:
274-
result = ExecutionResult(data=data, errors=exe_context.errors)
277+
if subscriber_exe_context.errors:
278+
result = ExecutionResult(data=data, errors=subscriber_exe_context.errors)
275279
else:
276280
result = ExecutionResult(data=data)
277-
exe_context.reset()
281+
subscriber_exe_context.reset()
278282
return result
279283

280-
observables = []
284+
observables = [] # type: List[Observable]
281285

282286
# assert len(fields) == 1, "Can only subscribe one element at a time."
283287

284288
for response_name, field_asts in fields.items():
285289
result = subscribe_field(
286-
exe_context, parent_type, source_value, field_asts, [response_name]
290+
subscriber_exe_context,
291+
parent_type,
292+
source_value,
293+
field_asts,
294+
[response_name],
287295
)
288296
if result is Undefined:
289297
continue
290298

291299
def catch_error(error):
292-
exe_context.errors.append(error)
300+
subscriber_exe_context.errors.append(error)
293301
return Observable.just(None)
294302

295303
# Map observable results
@@ -305,10 +313,10 @@ def catch_error(error):
305313
def resolve_field(
306314
exe_context, # type: ExecutionContext
307315
parent_type, # type: GraphQLObjectType
308-
source, # type: Union[None, Cat, Dog]
316+
source, # type: Any
309317
field_asts, # type: List[Field]
310318
parent_info, # type: Optional[ResolveInfo]
311-
field_path, # type: Union[List[Union[int, str]], List[str]]
319+
field_path, # type: List[Union[int, str]]
312320
):
313321
# type: (...) -> Any
314322
field_ast = field_asts[0]
@@ -360,7 +368,7 @@ def resolve_field(
360368
def subscribe_field(
361369
exe_context, # type: SubscriberExecutionContext
362370
parent_type, # type: GraphQLObjectType
363-
source, # type: Union[None, Data, type]
371+
source, # type: Any
364372
field_asts, # type: List[Field]
365373
path, # type: List[str]
366374
):
@@ -430,12 +438,12 @@ def subscribe_field(
430438

431439
def resolve_or_error(
432440
resolve_fn, # type: Callable
433-
source, # type: Union[None, Cat, Dog]
441+
source, # type: Any
434442
info, # type: ResolveInfo
435443
args, # type: Dict
436-
executor, # type: Union[BaseExecutor, SyncExecutor]
444+
executor, # type: Any
437445
):
438-
# type: (...) -> Union[List[Union[Cat, Dog]], bool, str]
446+
# type: (...) -> Any
439447
try:
440448
return executor.execute(resolve_fn, source, info, **args)
441449
except Exception as e:
@@ -444,7 +452,7 @@ def resolve_or_error(
444452
info.parent_type.name, info.field_name
445453
)
446454
)
447-
e.stack = sys.exc_info()[2]
455+
e.stack = sys.exc_info()[2] # type: ignore
448456
return e
449457

450458

@@ -453,10 +461,10 @@ def complete_value_catching_error(
453461
return_type, # type: Any
454462
field_asts, # type: List[Field]
455463
info, # type: ResolveInfo
456-
path, # type: Union[List[Union[int, str]], List[str]]
464+
path, # type: List[Union[int, str]]
457465
result, # type: Any
458466
):
459-
# type: (...) -> Union[bool, str]
467+
# type: (...) -> Any
460468
# If the field type is non-nullable, then it is resolved without any
461469
# protection from errors.
462470
if isinstance(return_type, GraphQLNonNull):
@@ -472,7 +480,7 @@ def complete_value_catching_error(
472480

473481
def handle_error(error):
474482
# type: (Union[GraphQLError, GraphQLLocatedError]) -> Optional[Any]
475-
traceback = completed._traceback
483+
traceback = completed._traceback # type: ignore
476484
exe_context.report_error(error, traceback)
477485
return None
478486

@@ -490,10 +498,10 @@ def complete_value(
490498
return_type, # type: Any
491499
field_asts, # type: List[Field]
492500
info, # type: ResolveInfo
493-
path, # type: Union[List[Union[int, str]], List[str]]
501+
path, # type: List[Union[int, str]]
494502
result, # type: Any
495503
):
496-
# type: (...) -> Union[bool, str]
504+
# type: (...) -> Any
497505
"""
498506
Implements the instructions for completeValue as defined in the
499507
"Field entries" section of the spec.
@@ -566,10 +574,10 @@ def complete_list_value(
566574
return_type, # type: GraphQLList
567575
field_asts, # type: List[Field]
568576
info, # type: ResolveInfo
569-
path, # type: List[str]
577+
path, # type: List[Union[int, str]]
570578
result, # type: Any
571579
):
572-
# type: (...) -> Any
580+
# type: (...) -> List[Any]
573581
"""
574582
Complete a list value by completing each item in the list with the inner type
575583
"""
@@ -597,10 +605,10 @@ def complete_list_value(
597605

598606
def complete_leaf_value(
599607
return_type, # type: Union[GraphQLEnumType, GraphQLScalarType]
600-
path, # type: Union[List[Union[int, str]], List[str]]
601-
result, # type: Union[int, str]
608+
path, # type: List[Union[int, str]]
609+
result, # type: Any
602610
):
603-
# type: (...) -> Union[int, str]
611+
# type: (...) -> Union[int, str, float, bool]
604612
"""
605613
Complete a Scalar or Enum by serializing to a valid value, returning null if serialization is not possible.
606614
"""
@@ -625,12 +633,12 @@ def complete_abstract_value(
625633
path, # type: List[Union[int, str]]
626634
result, # type: Any
627635
):
628-
# type: (...) -> OrderedDict
636+
# type: (...) -> Dict[str, Any]
629637
"""
630638
Complete an value of an abstract type by determining the runtime type of that value, then completing based
631639
on that type.
632640
"""
633-
runtime_type = None
641+
runtime_type = None # type: Union[str, GraphQLObjectType, None]
634642

635643
# Field type must be Object, Interface or Union and expect sub-selections.
636644
if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)):
@@ -640,7 +648,7 @@ def complete_abstract_value(
640648
runtime_type = get_default_resolve_type_fn(result, info, return_type)
641649

642650
if isinstance(runtime_type, string_types):
643-
runtime_type = info.schema.get_type(runtime_type)
651+
runtime_type = info.schema.get_type(runtime_type) # type: ignore
644652

645653
if not isinstance(runtime_type, GraphQLObjectType):
646654
raise GraphQLError(
@@ -671,22 +679,23 @@ def get_default_resolve_type_fn(
671679
info, # type: ResolveInfo
672680
abstract_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType]
673681
):
674-
# type: (...) -> GraphQLObjectType
682+
# type: (...) -> Optional[GraphQLObjectType]
675683
possible_types = info.schema.get_possible_types(abstract_type)
676684
for type in possible_types:
677685
if callable(type.is_type_of) and type.is_type_of(value, info):
678686
return type
687+
return None
679688

680689

681690
def complete_object_value(
682691
exe_context, # type: ExecutionContext
683692
return_type, # type: GraphQLObjectType
684693
field_asts, # type: List[Field]
685694
info, # type: ResolveInfo
686-
path, # type: Union[List[Union[int, str]], List[str]]
695+
path, # type: List[Union[int, str]]
687696
result, # type: Any
688697
):
689-
# type: (...) -> Union[OrderedDict, Promise]
698+
# type: (...) -> Dict[str, Any]
690699
"""
691700
Complete an Object value by evaluating all sub-selections.
692701
"""
@@ -708,7 +717,7 @@ def complete_nonnull_value(
708717
return_type, # type: GraphQLNonNull
709718
field_asts, # type: List[Field]
710719
info, # type: ResolveInfo
711-
path, # type: Union[List[Union[int, str]], List[str]]
720+
path, # type: List[Union[int, str]]
712721
result, # type: Any
713722
):
714723
# type: (...) -> Any

graphql/execution/executors/asyncio.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66

77
if False: # flake8: noqa
88
from asyncio.unix_events import _UnixSelectorEventLoop
9-
from typing import Optional, Any, Callable
9+
from typing import Optional, Any, Callable, List
1010

1111
try:
1212
from asyncio import ensure_future
1313
except ImportError:
1414
# ensure_future is only implemented in Python 3.4.4+
15-
def ensure_future(coro_or_future, loop=None):
15+
def ensure_future(coro_or_future, loop=None): # type: ignore
1616
"""Wrap a coroutine or an awaitable in a future.
1717
1818
If the argument is a Future, it is returned directly.
@@ -39,7 +39,7 @@ def ensure_future(coro_or_future, loop=None):
3939
def isasyncgen(obj):
4040
False
4141

42-
def asyncgen_to_observable(asyncgen):
42+
def asyncgen_to_observable(asyncgen, loop=None):
4343
pass
4444

4545

@@ -49,7 +49,7 @@ def __init__(self, loop=None):
4949
if loop is None:
5050
loop = get_event_loop()
5151
self.loop = loop
52-
self.futures = []
52+
self.futures = [] # type: List[Future]
5353

5454
def wait_until_finished(self):
5555
# type: () -> None

graphql/execution/executors/thread.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .utils import process
66

77
if False: # flake8: noqa
8-
from typing import Any, Callable
8+
from typing import Any, Callable, List
99

1010

1111
class ThreadExecutor(object):
@@ -14,7 +14,7 @@ class ThreadExecutor(object):
1414

1515
def __init__(self, pool=False):
1616
# type: (bool) -> None
17-
self.threads = []
17+
self.threads = [] # type: List[Thread]
1818
if pool:
1919
self.execute = self.execute_in_pool
2020
self.pool = ThreadPool(processes=pool)
@@ -26,7 +26,8 @@ def wait_until_finished(self):
2626
while self.threads:
2727
threads = self.threads
2828
self.threads = []
29-
[thread.join() for thread in threads]
29+
for thread in threads:
30+
thread.join()
3031

3132
def clean(self):
3233
self.threads = []

graphql/execution/executors/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
if False: # flake8: noqa
44
from ..base import ResolveInfo
55
from promise import Promise
6-
from typing import Callable, Dict, Tuple, Union
6+
from typing import Callable, Dict, Tuple, Union, Any
77

88

99
def process(
@@ -18,5 +18,5 @@ def process(
1818
p.do_resolve(val)
1919
except Exception as e:
2020
traceback = exc_info()[2]
21-
e.stack = traceback
21+
e.stack = traceback # type: ignore
2222
p.do_reject(e, traceback=traceback)

0 commit comments

Comments
 (0)