Skip to content

Commit 8df814c

Browse files
committed
cleaned up test file
1 parent 1ff2053 commit 8df814c

File tree

1 file changed

+146
-131
lines changed

1 file changed

+146
-131
lines changed

tests/system/test_pipeline_acceptance.py

Lines changed: 146 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
"""
15+
This file loads and executes yaml-encoded test cases from pipeline_e2e.yaml
16+
"""
1417

1518
from __future__ import annotations
1619
import os
@@ -36,143 +39,13 @@
3639

3740
def yaml_loader(field="tests"):
3841
"""
39-
loads test cases or data from yaml file
42+
Helper to load test cases or data from yaml file
4043
"""
4144
with open(f"{test_dir_name}/pipeline_e2e.yaml") as f:
4245
test_cases = yaml.safe_load(f)
4346
return test_cases[field]
4447

4548

46-
@pytest.fixture(scope="module")
47-
def event_loop():
48-
"""Change event_loop fixture to module level."""
49-
import asyncio
50-
policy = asyncio.get_event_loop_policy()
51-
loop = policy.new_event_loop()
52-
yield loop
53-
loop.close()
54-
55-
56-
@pytest.fixture(scope="module")
57-
def client():
58-
client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_ENTERPRISE_DB)
59-
data = yaml_loader("data")
60-
try:
61-
# setup data
62-
batch = client.batch()
63-
for collection_name, documents in data.items():
64-
collection_ref = client.collection(collection_name)
65-
for document_id, document_data in documents.items():
66-
document_ref = collection_ref.document(document_id)
67-
batch.set(document_ref, document_data)
68-
batch.commit()
69-
yield client
70-
finally:
71-
# clear data
72-
for collection_name, documents in data.items():
73-
collection_ref = client.collection(collection_name)
74-
for document_id in documents:
75-
document_ref = collection_ref.document(document_id)
76-
document_ref.delete()
77-
78-
79-
@pytest.fixture(scope="module")
80-
def async_client(client):
81-
yield AsyncClient(project=client.project, database=client._database)
82-
83-
84-
def _apply_yaml_args(cls, client, yaml_args):
85-
"""
86-
Helper to instantiate a class with yaml arguments. The arguments will be applied
87-
as positional or keyword arguments, based on type
88-
"""
89-
if isinstance(yaml_args, dict):
90-
return cls(**parse_expressions(client, yaml_args))
91-
elif isinstance(yaml_args, list):
92-
# yaml has an array of arguments. Treat as args
93-
return cls(*parse_expressions(client, yaml_args))
94-
else:
95-
# yaml has a single argument
96-
return cls(parse_expressions(client, yaml_args))
97-
98-
99-
def parse_pipeline(client, pipeline: list[dict[str, Any], str]):
100-
"""
101-
parse a yaml list of pipeline stages into firestore._pipeline_stages.Stage classes
102-
"""
103-
result_list = []
104-
for stage in pipeline:
105-
# stage will be either a map of the stage_name and its args, or just the stage_name itself
106-
stage_name: str = stage if isinstance(stage, str) else list(stage.keys())[0]
107-
stage_cls: type[stages.Stage] = getattr(stages, stage_name)
108-
# find arguments if given
109-
if isinstance(stage, dict):
110-
stage_yaml_args = stage[stage_name]
111-
stage_obj = _apply_yaml_args(stage_cls, client, stage_yaml_args)
112-
else:
113-
# yaml has no arguments
114-
stage_obj = stage_cls()
115-
result_list.append(stage_obj)
116-
return client._pipeline_cls._create_with_stages(client, *result_list)
117-
118-
119-
def _is_expr_string(yaml_str):
120-
"""
121-
Returns true if a string represents a class in pipeline_expressions
122-
"""
123-
return (
124-
isinstance(yaml_str, str)
125-
and yaml_str[0].isupper()
126-
and hasattr(pipeline_expressions, yaml_str)
127-
)
128-
129-
130-
def _is_stage_string(yaml_str):
131-
"""
132-
Returns true if a string represents a class in pipeline_stages
133-
"""
134-
return (
135-
isinstance(yaml_str, str)
136-
and yaml_str[0].isupper()
137-
and hasattr(stages, yaml_str)
138-
)
139-
140-
141-
def parse_expressions(client, yaml_element: Any):
142-
"""
143-
Turn yaml objects into pipeline expressions or native python object arguments
144-
"""
145-
if isinstance(yaml_element, list):
146-
return [parse_expressions(client, v) for v in yaml_element]
147-
elif isinstance(yaml_element, dict):
148-
if len(yaml_element) == 1 and _is_expr_string(next(iter(yaml_element))):
149-
# build pipeline expressions if possible
150-
cls_str = next(iter(yaml_element))
151-
cls = getattr(pipeline_expressions, cls_str)
152-
yaml_args = yaml_element[cls_str]
153-
return _apply_yaml_args(cls, client, yaml_args)
154-
elif len(yaml_element) == 1 and _is_stage_string(next(iter(yaml_element))):
155-
# build pipeline stage if possible (eg, for SampleOptions)
156-
cls_str = next(iter(yaml_element))
157-
cls = getattr(stages, cls_str)
158-
yaml_args = yaml_element[cls_str]
159-
return _apply_yaml_args(cls, client, yaml_args)
160-
elif len(yaml_element) == 1 and list(yaml_element)[0] == "Pipeline":
161-
# find Pipeline objects for Union expressions
162-
other_ppl = yaml_element["Pipeline"]
163-
return parse_pipeline(client, other_ppl)
164-
else:
165-
# otherwise, return dict
166-
return {
167-
parse_expressions(client, k): parse_expressions(client, v)
168-
for k, v in yaml_element.items()
169-
}
170-
elif _is_expr_string(yaml_element):
171-
return getattr(pipeline_expressions, yaml_element)()
172-
else:
173-
return yaml_element
174-
175-
17649
@pytest.mark.parametrize(
17750
"test_dict",
17851
[t for t in yaml_loader() if "assert_proto" in t],
@@ -268,3 +141,145 @@ async def test_pipeline_results_async(test_dict, async_client):
268141
assert got_results == expected_results
269142
if expected_count is not None:
270143
assert len(got_results) == expected_count
144+
145+
146+
#################################################################################
147+
# Helpers & Fixtures
148+
#################################################################################
149+
150+
151+
def parse_pipeline(client, pipeline: list[dict[str, Any], str]):
152+
"""
153+
parse a yaml list of pipeline stages into firestore._pipeline_stages.Stage classes
154+
"""
155+
result_list = []
156+
for stage in pipeline:
157+
# stage will be either a map of the stage_name and its args, or just the stage_name itself
158+
stage_name: str = stage if isinstance(stage, str) else list(stage.keys())[0]
159+
stage_cls: type[stages.Stage] = getattr(stages, stage_name)
160+
# find arguments if given
161+
if isinstance(stage, dict):
162+
stage_yaml_args = stage[stage_name]
163+
stage_obj = _apply_yaml_args(stage_cls, client, stage_yaml_args)
164+
else:
165+
# yaml has no arguments
166+
stage_obj = stage_cls()
167+
result_list.append(stage_obj)
168+
return client._pipeline_cls._create_with_stages(client, *result_list)
169+
170+
171+
def _parse_expressions(client, yaml_element: Any):
172+
"""
173+
Turn yaml objects into pipeline expressions or native python object arguments
174+
"""
175+
if isinstance(yaml_element, list):
176+
return [_parse_expressions(client, v) for v in yaml_element]
177+
elif isinstance(yaml_element, dict):
178+
if len(yaml_element) == 1 and _is_expr_string(next(iter(yaml_element))):
179+
# build pipeline expressions if possible
180+
cls_str = next(iter(yaml_element))
181+
cls = getattr(pipeline_expressions, cls_str)
182+
yaml_args = yaml_element[cls_str]
183+
return _apply_yaml_args(cls, client, yaml_args)
184+
elif len(yaml_element) == 1 and _is_stage_string(next(iter(yaml_element))):
185+
# build pipeline stage if possible (eg, for SampleOptions)
186+
cls_str = next(iter(yaml_element))
187+
cls = getattr(stages, cls_str)
188+
yaml_args = yaml_element[cls_str]
189+
return _apply_yaml_args(cls, client, yaml_args)
190+
elif len(yaml_element) == 1 and list(yaml_element)[0] == "Pipeline":
191+
# find Pipeline objects for Union expressions
192+
other_ppl = yaml_element["Pipeline"]
193+
return parse_pipeline(client, other_ppl)
194+
else:
195+
# otherwise, return dict
196+
return {
197+
_parse_expressions(client, k): _parse_expressions(client, v)
198+
for k, v in yaml_element.items()
199+
}
200+
elif _is_expr_string(yaml_element):
201+
return getattr(pipeline_expressions, yaml_element)()
202+
else:
203+
return yaml_element
204+
205+
206+
def _apply_yaml_args(cls, client, yaml_args):
207+
"""
208+
Helper to instantiate a class with yaml arguments. The arguments will be applied
209+
as positional or keyword arguments, based on type
210+
"""
211+
if isinstance(yaml_args, dict):
212+
return cls(**_parse_expressions(client, yaml_args))
213+
elif isinstance(yaml_args, list):
214+
# yaml has an array of arguments. Treat as args
215+
return cls(*_parse_expressions(client, yaml_args))
216+
else:
217+
# yaml has a single argument
218+
return cls(_parse_expressions(client, yaml_args))
219+
220+
221+
def _is_expr_string(yaml_str):
222+
"""
223+
Returns true if a string represents a class in pipeline_expressions
224+
"""
225+
return (
226+
isinstance(yaml_str, str)
227+
and yaml_str[0].isupper()
228+
and hasattr(pipeline_expressions, yaml_str)
229+
)
230+
231+
232+
def _is_stage_string(yaml_str):
233+
"""
234+
Returns true if a string represents a class in pipeline_stages
235+
"""
236+
return (
237+
isinstance(yaml_str, str)
238+
and yaml_str[0].isupper()
239+
and hasattr(stages, yaml_str)
240+
)
241+
242+
243+
@pytest.fixture(scope="module")
244+
def event_loop():
245+
"""Change event_loop fixture to module level."""
246+
import asyncio
247+
248+
policy = asyncio.get_event_loop_policy()
249+
loop = policy.new_event_loop()
250+
yield loop
251+
loop.close()
252+
253+
254+
@pytest.fixture(scope="module")
255+
def client():
256+
"""
257+
Build a client to use for requests
258+
"""
259+
client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_ENTERPRISE_DB)
260+
data = yaml_loader("data")
261+
try:
262+
# setup data
263+
batch = client.batch()
264+
for collection_name, documents in data.items():
265+
collection_ref = client.collection(collection_name)
266+
for document_id, document_data in documents.items():
267+
document_ref = collection_ref.document(document_id)
268+
batch.set(document_ref, document_data)
269+
batch.commit()
270+
yield client
271+
finally:
272+
# clear data
273+
for collection_name, documents in data.items():
274+
collection_ref = client.collection(collection_name)
275+
for document_id in documents:
276+
document_ref = collection_ref.document(document_id)
277+
document_ref.delete()
278+
279+
280+
@pytest.fixture(scope="module")
281+
def async_client(client):
282+
"""
283+
Build an async client to use for AsyncPipeline requests
284+
"""
285+
yield AsyncClient(project=client.project, database=client._database)

0 commit comments

Comments
 (0)