Skip to content

Commit 4882524

Browse files
committed
Add missing assert_directive and assert_schema
Replicates graphql/graphql-js@958eb96
1 parent 4e891e2 commit 4882524

File tree

10 files changed

+65
-46
lines changed

10 files changed

+65
-46
lines changed

graphql/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@
105105
is_introspection_type,
106106
is_specified_directive,
107107
# Assertions
108+
assert_schema,
109+
assert_directive,
108110
assert_type,
109111
assert_scalar_type,
110112
assert_object_type,
@@ -419,6 +421,8 @@
419421
"is_specified_scalar_type",
420422
"is_introspection_type",
421423
"is_specified_directive",
424+
"assert_schema",
425+
"assert_directive",
422426
"assert_type",
423427
"assert_scalar_type",
424428
"assert_object_type",

graphql/type/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from .schema import (
77
# Predicate
88
is_schema,
9+
# Assertion
10+
assert_schema,
911
# GraphQL Schema definition
1012
GraphQLSchema,
1113
)
@@ -95,6 +97,8 @@
9597
from .directives import (
9698
# Predicate
9799
is_directive,
100+
# Assertion
101+
assert_directive,
98102
# Directives Definition
99103
GraphQLDirective,
100104
# Built-in Directives defined by the Spec
@@ -134,6 +138,7 @@
134138

135139
__all__ = [
136140
"is_schema",
141+
"assert_schema",
137142
"GraphQLSchema",
138143
"is_type",
139144
"is_scalar_type",
@@ -210,6 +215,7 @@
210215
"GraphQLResolveInfo",
211216
"ResponsePath",
212217
"is_directive",
218+
"assert_directive",
213219
"is_specified_directive",
214220
"specified_directives",
215221
"GraphQLDirective",

graphql/type/definition.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def is_type(type_: Any) -> bool:
142142
def assert_type(type_: Any) -> GraphQLType:
143143
if not is_type(type_):
144144
raise TypeError(f"Expected {type_} to be a GraphQL type.")
145-
return type_
145+
return cast(GraphQLType, type_)
146146

147147

148148
# These types wrap and modify other types
@@ -170,7 +170,7 @@ def is_wrapping_type(type_: Any) -> bool:
170170
def assert_wrapping_type(type_: Any) -> GraphQLWrappingType:
171171
if not is_wrapping_type(type_):
172172
raise TypeError(f"Expected {type_} to be a GraphQL wrapping type.")
173-
return type_
173+
return cast(GraphQLWrappingType, type_)
174174

175175

176176
# These named types do not include modifiers like List or NonNull.
@@ -229,7 +229,7 @@ def is_named_type(type_: Any) -> bool:
229229
def assert_named_type(type_: Any) -> GraphQLNamedType:
230230
if not is_named_type(type_):
231231
raise TypeError(f"Expected {type_} to be a GraphQL named type.")
232-
return type_
232+
return cast(GraphQLNamedType, type_)
233233

234234

235235
@overload
@@ -358,7 +358,7 @@ def is_scalar_type(type_: Any) -> bool:
358358
def assert_scalar_type(type_: Any) -> GraphQLScalarType:
359359
if not is_scalar_type(type_):
360360
raise TypeError(f"Expected {type_} to be a GraphQL Scalar type.")
361-
return type_
361+
return cast(GraphQLScalarType, type_)
362362

363363

364364
GraphQLArgumentMap = Dict[str, "GraphQLArgument"]
@@ -652,7 +652,7 @@ def is_object_type(type_: Any) -> bool:
652652
def assert_object_type(type_: Any) -> GraphQLObjectType:
653653
if not is_object_type(type_):
654654
raise TypeError(f"Expected {type_} to be a GraphQL Object type.")
655-
return type_
655+
return cast(GraphQLObjectType, type_)
656656

657657

658658
class GraphQLInterfaceType(GraphQLNamedType):
@@ -742,7 +742,7 @@ def is_interface_type(type_: Any) -> bool:
742742
def assert_interface_type(type_: Any) -> GraphQLInterfaceType:
743743
if not is_interface_type(type_):
744744
raise TypeError(f"Expected {type_} to be a GraphQL Interface type.")
745-
return type_
745+
return cast(GraphQLInterfaceType, type_)
746746

747747

748748
GraphQLTypeList = Sequence[GraphQLObjectType]
@@ -831,7 +831,7 @@ def is_union_type(type_: Any) -> bool:
831831
def assert_union_type(type_: Any) -> GraphQLUnionType:
832832
if not is_union_type(type_):
833833
raise TypeError(f"Expected {type_} to be a GraphQL Union type.")
834-
return type_
834+
return cast(GraphQLUnionType, type_)
835835

836836

837837
GraphQLEnumValueMap = Dict[str, "GraphQLEnumValue"]
@@ -978,7 +978,7 @@ def is_enum_type(type_: Any) -> bool:
978978
def assert_enum_type(type_: Any) -> GraphQLEnumType:
979979
if not is_enum_type(type_):
980980
raise TypeError(f"Expected {type_} to be a GraphQL Enum type.")
981-
return type_
981+
return cast(GraphQLEnumType, type_)
982982

983983

984984
class GraphQLEnumValue:
@@ -1113,7 +1113,7 @@ def is_input_object_type(type_: Any) -> bool:
11131113
def assert_input_object_type(type_: Any) -> GraphQLInputObjectType:
11141114
if not is_input_object_type(type_):
11151115
raise TypeError(f"Expected {type_} to be a GraphQL Input Object type.")
1116-
return type_
1116+
return cast(GraphQLInputObjectType, type_)
11171117

11181118

11191119
class GraphQLInputField:
@@ -1188,7 +1188,7 @@ def is_list_type(type_: Any) -> bool:
11881188
def assert_list_type(type_: Any) -> GraphQLList:
11891189
if not is_list_type(type_):
11901190
raise TypeError(f"Expected {type_} to be a GraphQL List type.")
1191-
return type_
1191+
return cast(GraphQLList, type_)
11921192

11931193

11941194
GNT = TypeVar("GNT", bound="GraphQLNullableType")
@@ -1233,7 +1233,7 @@ def is_non_null_type(type_: Any) -> bool:
12331233
def assert_non_null_type(type_: Any) -> GraphQLNonNull:
12341234
if not is_non_null_type(type_):
12351235
raise TypeError(f"Expected {type_} to be a GraphQL Non-Null type.")
1236-
return type_
1236+
return cast(GraphQLNonNull, type_)
12371237

12381238

12391239
# These types can all accept null as a value.
@@ -1266,7 +1266,7 @@ def is_nullable_type(type_: Any) -> bool:
12661266
def assert_nullable_type(type_: Any) -> GraphQLNullableType:
12671267
if not is_nullable_type(type_):
12681268
raise TypeError(f"Expected {type_} to be a GraphQL nullable type.")
1269-
return type_
1269+
return cast(GraphQLNullableType, type_)
12701270

12711271

12721272
@overload
@@ -1310,7 +1310,7 @@ def is_input_type(type_: Any) -> bool:
13101310
def assert_input_type(type_: Any) -> GraphQLInputType:
13111311
if not is_input_type(type_):
13121312
raise TypeError(f"Expected {type_} to be a GraphQL input type.")
1313-
return type_
1313+
return cast(GraphQLInputType, type_)
13141314

13151315

13161316
# These types may be used as output types as the result of fields.
@@ -1342,7 +1342,7 @@ def is_output_type(type_: Any) -> bool:
13421342
def assert_output_type(type_: Any) -> GraphQLOutputType:
13431343
if not is_output_type(type_):
13441344
raise TypeError(f"Expected {type_} to be a GraphQL output type.")
1345-
return type_
1345+
return cast(GraphQLOutputType, type_)
13461346

13471347

13481348
# These types may describe types which may be leaf values.
@@ -1359,7 +1359,7 @@ def is_leaf_type(type_: Any) -> bool:
13591359
def assert_leaf_type(type_: Any) -> GraphQLLeafType:
13601360
if not is_leaf_type(type_):
13611361
raise TypeError(f"Expected {type_} to be a GraphQL leaf type.")
1362-
return type_
1362+
return cast(GraphQLLeafType, type_)
13631363

13641364

13651365
# These types may describe the parent context of a selection set.
@@ -1376,7 +1376,7 @@ def is_composite_type(type_: Any) -> bool:
13761376
def assert_composite_type(type_: Any) -> GraphQLType:
13771377
if not is_composite_type(type_):
13781378
raise TypeError(f"Expected {type_} to be a GraphQL composite type.")
1379-
return type_
1379+
return cast(GraphQLType, type_)
13801380

13811381

13821382
# These types may describe abstract types.
@@ -1393,4 +1393,4 @@ def is_abstract_type(type_: Any) -> bool:
13931393
def assert_abstract_type(type_: Any) -> GraphQLAbstractType:
13941394
if not is_abstract_type(type_):
13951395
raise TypeError(f"Expected {type_} to be a GraphQL composite type.")
1396-
return type_
1396+
return cast(GraphQLAbstractType, type_)

graphql/type/directives.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from typing import Any, Dict, Sequence, cast
22

33
from ..language import ast, DirectiveLocation
4+
from ..pyutils import inspect
45
from .definition import GraphQLArgument, GraphQLInputType, GraphQLNonNull, is_input_type
56
from .scalars import GraphQLBoolean, GraphQLString
67

78
__all__ = [
89
"is_directive",
10+
"assert_directive",
911
"is_specified_directive",
1012
"specified_directives",
1113
"GraphQLDirective",
@@ -17,11 +19,6 @@
1719
]
1820

1921

20-
def is_directive(directive: Any) -> bool:
21-
"""Test if the given value is a GraphQL directive."""
22-
return isinstance(directive, GraphQLDirective)
23-
24-
2522
class GraphQLDirective:
2623
"""GraphQL Directive
2724
@@ -90,6 +87,17 @@ def __repr__(self):
9087
return f"<{self.__class__.__name__}({self})>"
9188

9289

90+
def is_directive(directive: Any) -> bool:
91+
"""Test if the given value is a GraphQL directive."""
92+
return isinstance(directive, GraphQLDirective)
93+
94+
95+
def assert_directive(directive: Any) -> GraphQLDirective:
96+
if not is_directive(directive):
97+
raise TypeError(f"Expected {inspect(directive)} to be a GraphQL directive.")
98+
return cast(GraphQLDirective, directive)
99+
100+
93101
# Used to conditionally include fields or fragments.
94102
GraphQLIncludeDirective = GraphQLDirective(
95103
name="include",

graphql/type/schema.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from ..error import GraphQLError
55
from ..language import ast
6+
from ..pyutils import inspect
67
from .definition import (
78
GraphQLAbstractType,
89
GraphQLInterfaceType,
@@ -21,17 +22,12 @@
2122
from .directives import GraphQLDirective, specified_directives, is_directive
2223
from .introspection import introspection_types
2324

24-
__all__ = ["GraphQLSchema", "is_schema"]
25+
__all__ = ["GraphQLSchema", "is_schema", "assert_schema"]
2526

2627

2728
TypeMap = Dict[str, GraphQLNamedType]
2829

2930

30-
def is_schema(schema: Any) -> bool:
31-
"""Test if the given value is a GraphQL schema."""
32-
return isinstance(schema, GraphQLSchema)
33-
34-
3531
class GraphQLSchema:
3632
"""Schema Definition
3733
@@ -177,6 +173,17 @@ def validation_errors(self):
177173
return self._validation_errors
178174

179175

176+
def is_schema(schema: Any) -> bool:
177+
"""Test if the given value is a GraphQL schema."""
178+
return isinstance(schema, GraphQLSchema)
179+
180+
181+
def assert_schema(schema: Any) -> GraphQLSchema:
182+
if not is_schema(schema):
183+
raise TypeError(f"Expected {inspect(schema)} to be a GraphQL schema.")
184+
return cast(GraphQLSchema, schema)
185+
186+
180187
def type_map_reducer(map_: TypeMap, type_: GraphQLNamedType = None) -> TypeMap:
181188
"""Reducer function for creating the type map from given types."""
182189
if not type_:

graphql/type/validate.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any, Callable, List, Optional, Sequence, Set, Union, cast
33

44
from ..error import GraphQLError
5+
from ..pyutils import inspect
56
from ..language import (
67
EnumValueDefinitionNode,
78
FieldDefinitionNode,
@@ -28,12 +29,11 @@
2829
is_union_type,
2930
is_required_argument,
3031
)
31-
from ..pyutils import inspect
3232
from ..utilities.assert_valid_name import is_valid_name_error
3333
from ..utilities.type_comparators import is_equal_type, is_type_sub_type_of
3434
from .directives import GraphQLDirective, is_directive
3535
from .introspection import is_introspection_type
36-
from .schema import GraphQLSchema, is_schema
36+
from .schema import GraphQLSchema, assert_schema
3737

3838
__all__ = ["validate_schema", "assert_valid_schema"]
3939

@@ -48,8 +48,7 @@ def validate_schema(schema: GraphQLSchema) -> List[GraphQLError]:
4848
list if no errors were encountered and the Schema is valid.
4949
"""
5050
# First check to ensure the provided value is in fact a GraphQLSchema.
51-
if not is_schema(schema):
52-
raise TypeError(f"Expected {inspect(schema)} to be a GraphQL schema.")
51+
assert_schema(schema)
5352

5453
# If this Schema has already been validated, return the previous results.
5554
# noinspection PyProtectedMember
@@ -70,7 +69,7 @@ def validate_schema(schema: GraphQLSchema) -> List[GraphQLError]:
7069
return errors
7170

7271

73-
def assert_valid_schema(schema: GraphQLSchema):
72+
def assert_valid_schema(schema: GraphQLSchema) -> None:
7473
"""Utility function which asserts a schema is valid.
7574
7675
Throws a TypeError if the schema is invalid.

graphql/utilities/extend_schema.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@
4141
GraphQLSchema,
4242
GraphQLType,
4343
GraphQLUnionType,
44+
assert_schema,
4445
is_enum_type,
4546
is_input_object_type,
4647
is_interface_type,
4748
is_list_type,
4849
is_non_null_type,
4950
is_object_type,
5051
is_scalar_type,
51-
is_schema,
5252
is_union_type,
5353
is_introspection_type,
5454
is_specified_scalar_type,
@@ -80,9 +80,7 @@ def extend_schema(
8080
schema is valid. Set `assume_valid` to true to assume the produced schema is valid.
8181
Set `assume_valid_sdl` to True to assume it is already a valid SDL document.
8282
"""
83-
84-
if not is_schema(schema):
85-
raise TypeError("Must provide valid GraphQLSchema")
83+
assert_schema(schema)
8684

8785
if not isinstance(document_ast, DocumentNode):
8886
"Must provide valid Document AST"

tests/execution/test_abstract_async.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from collections import namedtuple
21
from typing import NamedTuple
32

43
from pytest import mark

tests/utilities/test_build_ast_schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
GraphQLDeprecatedDirective,
99
GraphQLIncludeDirective,
1010
GraphQLSkipDirective,
11+
assert_directive,
1112
assert_enum_type,
1213
assert_input_object_type,
1314
assert_interface_type,
@@ -708,8 +709,7 @@ def correctly_assign_ast_nodes():
708709
test_interface = assert_interface_type(schema.get_type("TestInterface"))
709710
test_type = assert_object_type(schema.get_type("TestType"))
710711
test_scalar = assert_scalar_type(schema.get_type("TestScalar"))
711-
test_directive = schema.get_directive("test")
712-
assert test_directive
712+
test_directive = assert_directive(schema.get_directive("test"))
713713

714714
restored_schema_ast = DocumentNode(
715715
definitions=[

0 commit comments

Comments
 (0)