Skip to content

Commit acfb2c0

Browse files
committed
Added experimental dataloader
1 parent 23f4b48 commit acfb2c0

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
from promise import Promise
2+
from promise.dataloader import DataLoader
3+
4+
from graphql import GraphQLObjectType, GraphQLField, GraphQLID, GraphQLArgument, GraphQLNonNull, GraphQLSchema, parse, execute
5+
6+
7+
def test_batches_correctly():
8+
9+
Business = GraphQLObjectType('Business', lambda: {
10+
'id': GraphQLField(GraphQLID, resolver=lambda root, args, context, info: root),
11+
})
12+
13+
Query = GraphQLObjectType('Query', lambda: {
14+
'getBusiness': GraphQLField(Business,
15+
args={
16+
'id': GraphQLArgument(GraphQLNonNull(GraphQLID)),
17+
},
18+
resolver=lambda root, args, context, info: context.business_data_loader.load(args.get('id'))
19+
),
20+
})
21+
22+
schema = GraphQLSchema(query=Query)
23+
24+
25+
doc = '''
26+
{
27+
business1: getBusiness(id: "1") {
28+
id
29+
}
30+
business2: getBusiness(id: "2") {
31+
id
32+
}
33+
}
34+
'''
35+
doc_ast = parse(doc)
36+
37+
38+
load_calls = []
39+
40+
class BusinessDataLoader(DataLoader):
41+
def batch_load_fn(self, keys):
42+
load_calls.append(keys)
43+
return Promise.resolve(keys)
44+
45+
class Context(object):
46+
business_data_loader = BusinessDataLoader()
47+
48+
49+
result = execute(schema, doc_ast, None, context_value=Context())
50+
assert not result.errors
51+
assert result.data == {
52+
'business1': {
53+
'id': '1'
54+
},
55+
'business2': {
56+
'id': '2'
57+
},
58+
}
59+
assert load_calls == [['1','2']]
60+
61+
62+
def test_batches_multiple_together():
63+
64+
Location = GraphQLObjectType('Location', lambda: {
65+
'id': GraphQLField(GraphQLID, resolver=lambda root, args, context, info: root),
66+
})
67+
68+
Business = GraphQLObjectType('Business', lambda: {
69+
'id': GraphQLField(GraphQLID, resolver=lambda root, args, context, info: root),
70+
'location': GraphQLField(Location,
71+
resolver=lambda root, args, context, info: context.location_data_loader.load('location-{}'.format(root))
72+
),
73+
})
74+
75+
Query = GraphQLObjectType('Query', lambda: {
76+
'getBusiness': GraphQLField(Business,
77+
args={
78+
'id': GraphQLArgument(GraphQLNonNull(GraphQLID)),
79+
},
80+
resolver=lambda root, args, context, info: context.business_data_loader.load(args.get('id'))
81+
),
82+
})
83+
84+
schema = GraphQLSchema(query=Query)
85+
86+
87+
doc = '''
88+
{
89+
business1: getBusiness(id: "1") {
90+
id
91+
location {
92+
id
93+
}
94+
}
95+
business2: getBusiness(id: "2") {
96+
id
97+
location {
98+
id
99+
}
100+
}
101+
}
102+
'''
103+
doc_ast = parse(doc)
104+
105+
106+
business_load_calls = []
107+
108+
class BusinessDataLoader(DataLoader):
109+
def batch_load_fn(self, keys):
110+
business_load_calls.append(keys)
111+
return Promise.resolve(keys)
112+
113+
location_load_calls = []
114+
115+
class LocationDataLoader(DataLoader):
116+
def batch_load_fn(self, keys):
117+
location_load_calls.append(keys)
118+
return Promise.resolve(keys)
119+
120+
class Context(object):
121+
business_data_loader = BusinessDataLoader()
122+
location_data_loader = LocationDataLoader()
123+
124+
125+
result = execute(schema, doc_ast, None, context_value=Context())
126+
assert not result.errors
127+
assert result.data == {
128+
'business1': {
129+
'id': '1',
130+
'location': {
131+
'id': 'location-1'
132+
}
133+
},
134+
'business2': {
135+
'id': '2',
136+
'location': {
137+
'id': 'location-2'
138+
}
139+
},
140+
}
141+
assert business_load_calls == [['1','2']]
142+
assert location_load_calls == [['location-1','location-2']]

0 commit comments

Comments
 (0)