|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +""" |
| 15 | +This file loads and executes yaml-encoded test cases from pipeline_e2e.yaml |
| 16 | +""" |
14 | 17 |
|
15 | 18 | from __future__ import annotations
|
16 | 19 | import os
|
|
36 | 39 |
|
37 | 40 | def yaml_loader(field="tests"):
|
38 | 41 | """
|
39 |
| - loads test cases or data from yaml file |
| 42 | + Helper to load test cases or data from yaml file |
40 | 43 | """
|
41 | 44 | with open(f"{test_dir_name}/pipeline_e2e.yaml") as f:
|
42 | 45 | test_cases = yaml.safe_load(f)
|
43 | 46 | return test_cases[field]
|
44 | 47 |
|
45 | 48 |
|
46 |
| -@pytest.fixture(scope="module") |
47 |
| -def event_loop(): |
48 |
| - """Change event_loop fixture to module level.""" |
49 |
| - import asyncio |
50 |
| - policy = asyncio.get_event_loop_policy() |
51 |
| - loop = policy.new_event_loop() |
52 |
| - yield loop |
53 |
| - loop.close() |
54 |
| - |
55 |
| - |
56 |
| -@pytest.fixture(scope="module") |
57 |
| -def client(): |
58 |
| - client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_ENTERPRISE_DB) |
59 |
| - data = yaml_loader("data") |
60 |
| - try: |
61 |
| - # setup data |
62 |
| - batch = client.batch() |
63 |
| - for collection_name, documents in data.items(): |
64 |
| - collection_ref = client.collection(collection_name) |
65 |
| - for document_id, document_data in documents.items(): |
66 |
| - document_ref = collection_ref.document(document_id) |
67 |
| - batch.set(document_ref, document_data) |
68 |
| - batch.commit() |
69 |
| - yield client |
70 |
| - finally: |
71 |
| - # clear data |
72 |
| - for collection_name, documents in data.items(): |
73 |
| - collection_ref = client.collection(collection_name) |
74 |
| - for document_id in documents: |
75 |
| - document_ref = collection_ref.document(document_id) |
76 |
| - document_ref.delete() |
77 |
| - |
78 |
| - |
79 |
| -@pytest.fixture(scope="module") |
80 |
| -def async_client(client): |
81 |
| - yield AsyncClient(project=client.project, database=client._database) |
82 |
| - |
83 |
| - |
84 |
| -def _apply_yaml_args(cls, client, yaml_args): |
85 |
| - """ |
86 |
| - Helper to instantiate a class with yaml arguments. The arguments will be applied |
87 |
| - as positional or keyword arguments, based on type |
88 |
| - """ |
89 |
| - if isinstance(yaml_args, dict): |
90 |
| - return cls(**parse_expressions(client, yaml_args)) |
91 |
| - elif isinstance(yaml_args, list): |
92 |
| - # yaml has an array of arguments. Treat as args |
93 |
| - return cls(*parse_expressions(client, yaml_args)) |
94 |
| - else: |
95 |
| - # yaml has a single argument |
96 |
| - return cls(parse_expressions(client, yaml_args)) |
97 |
| - |
98 |
| - |
99 |
| -def parse_pipeline(client, pipeline: list[dict[str, Any], str]): |
100 |
| - """ |
101 |
| - parse a yaml list of pipeline stages into firestore._pipeline_stages.Stage classes |
102 |
| - """ |
103 |
| - result_list = [] |
104 |
| - for stage in pipeline: |
105 |
| - # stage will be either a map of the stage_name and its args, or just the stage_name itself |
106 |
| - stage_name: str = stage if isinstance(stage, str) else list(stage.keys())[0] |
107 |
| - stage_cls: type[stages.Stage] = getattr(stages, stage_name) |
108 |
| - # find arguments if given |
109 |
| - if isinstance(stage, dict): |
110 |
| - stage_yaml_args = stage[stage_name] |
111 |
| - stage_obj = _apply_yaml_args(stage_cls, client, stage_yaml_args) |
112 |
| - else: |
113 |
| - # yaml has no arguments |
114 |
| - stage_obj = stage_cls() |
115 |
| - result_list.append(stage_obj) |
116 |
| - return client._pipeline_cls._create_with_stages(client, *result_list) |
117 |
| - |
118 |
| - |
119 |
| -def _is_expr_string(yaml_str): |
120 |
| - """ |
121 |
| - Returns true if a string represents a class in pipeline_expressions |
122 |
| - """ |
123 |
| - return ( |
124 |
| - isinstance(yaml_str, str) |
125 |
| - and yaml_str[0].isupper() |
126 |
| - and hasattr(pipeline_expressions, yaml_str) |
127 |
| - ) |
128 |
| - |
129 |
| - |
130 |
| -def _is_stage_string(yaml_str): |
131 |
| - """ |
132 |
| - Returns true if a string represents a class in pipeline_stages |
133 |
| - """ |
134 |
| - return ( |
135 |
| - isinstance(yaml_str, str) |
136 |
| - and yaml_str[0].isupper() |
137 |
| - and hasattr(stages, yaml_str) |
138 |
| - ) |
139 |
| - |
140 |
| - |
141 |
| -def parse_expressions(client, yaml_element: Any): |
142 |
| - """ |
143 |
| - Turn yaml objects into pipeline expressions or native python object arguments |
144 |
| - """ |
145 |
| - if isinstance(yaml_element, list): |
146 |
| - return [parse_expressions(client, v) for v in yaml_element] |
147 |
| - elif isinstance(yaml_element, dict): |
148 |
| - if len(yaml_element) == 1 and _is_expr_string(next(iter(yaml_element))): |
149 |
| - # build pipeline expressions if possible |
150 |
| - cls_str = next(iter(yaml_element)) |
151 |
| - cls = getattr(pipeline_expressions, cls_str) |
152 |
| - yaml_args = yaml_element[cls_str] |
153 |
| - return _apply_yaml_args(cls, client, yaml_args) |
154 |
| - elif len(yaml_element) == 1 and _is_stage_string(next(iter(yaml_element))): |
155 |
| - # build pipeline stage if possible (eg, for SampleOptions) |
156 |
| - cls_str = next(iter(yaml_element)) |
157 |
| - cls = getattr(stages, cls_str) |
158 |
| - yaml_args = yaml_element[cls_str] |
159 |
| - return _apply_yaml_args(cls, client, yaml_args) |
160 |
| - elif len(yaml_element) == 1 and list(yaml_element)[0] == "Pipeline": |
161 |
| - # find Pipeline objects for Union expressions |
162 |
| - other_ppl = yaml_element["Pipeline"] |
163 |
| - return parse_pipeline(client, other_ppl) |
164 |
| - else: |
165 |
| - # otherwise, return dict |
166 |
| - return { |
167 |
| - parse_expressions(client, k): parse_expressions(client, v) |
168 |
| - for k, v in yaml_element.items() |
169 |
| - } |
170 |
| - elif _is_expr_string(yaml_element): |
171 |
| - return getattr(pipeline_expressions, yaml_element)() |
172 |
| - else: |
173 |
| - return yaml_element |
174 |
| - |
175 |
| - |
176 | 49 | @pytest.mark.parametrize(
|
177 | 50 | "test_dict",
|
178 | 51 | [t for t in yaml_loader() if "assert_proto" in t],
|
@@ -268,3 +141,145 @@ async def test_pipeline_results_async(test_dict, async_client):
|
268 | 141 | assert got_results == expected_results
|
269 | 142 | if expected_count is not None:
|
270 | 143 | assert len(got_results) == expected_count
|
| 144 | + |
| 145 | + |
| 146 | +################################################################################# |
| 147 | +# Helpers & Fixtures |
| 148 | +################################################################################# |
| 149 | + |
| 150 | + |
| 151 | +def parse_pipeline(client, pipeline: list[dict[str, Any], str]): |
| 152 | + """ |
| 153 | + parse a yaml list of pipeline stages into firestore._pipeline_stages.Stage classes |
| 154 | + """ |
| 155 | + result_list = [] |
| 156 | + for stage in pipeline: |
| 157 | + # stage will be either a map of the stage_name and its args, or just the stage_name itself |
| 158 | + stage_name: str = stage if isinstance(stage, str) else list(stage.keys())[0] |
| 159 | + stage_cls: type[stages.Stage] = getattr(stages, stage_name) |
| 160 | + # find arguments if given |
| 161 | + if isinstance(stage, dict): |
| 162 | + stage_yaml_args = stage[stage_name] |
| 163 | + stage_obj = _apply_yaml_args(stage_cls, client, stage_yaml_args) |
| 164 | + else: |
| 165 | + # yaml has no arguments |
| 166 | + stage_obj = stage_cls() |
| 167 | + result_list.append(stage_obj) |
| 168 | + return client._pipeline_cls._create_with_stages(client, *result_list) |
| 169 | + |
| 170 | + |
| 171 | +def _parse_expressions(client, yaml_element: Any): |
| 172 | + """ |
| 173 | + Turn yaml objects into pipeline expressions or native python object arguments |
| 174 | + """ |
| 175 | + if isinstance(yaml_element, list): |
| 176 | + return [_parse_expressions(client, v) for v in yaml_element] |
| 177 | + elif isinstance(yaml_element, dict): |
| 178 | + if len(yaml_element) == 1 and _is_expr_string(next(iter(yaml_element))): |
| 179 | + # build pipeline expressions if possible |
| 180 | + cls_str = next(iter(yaml_element)) |
| 181 | + cls = getattr(pipeline_expressions, cls_str) |
| 182 | + yaml_args = yaml_element[cls_str] |
| 183 | + return _apply_yaml_args(cls, client, yaml_args) |
| 184 | + elif len(yaml_element) == 1 and _is_stage_string(next(iter(yaml_element))): |
| 185 | + # build pipeline stage if possible (eg, for SampleOptions) |
| 186 | + cls_str = next(iter(yaml_element)) |
| 187 | + cls = getattr(stages, cls_str) |
| 188 | + yaml_args = yaml_element[cls_str] |
| 189 | + return _apply_yaml_args(cls, client, yaml_args) |
| 190 | + elif len(yaml_element) == 1 and list(yaml_element)[0] == "Pipeline": |
| 191 | + # find Pipeline objects for Union expressions |
| 192 | + other_ppl = yaml_element["Pipeline"] |
| 193 | + return parse_pipeline(client, other_ppl) |
| 194 | + else: |
| 195 | + # otherwise, return dict |
| 196 | + return { |
| 197 | + _parse_expressions(client, k): _parse_expressions(client, v) |
| 198 | + for k, v in yaml_element.items() |
| 199 | + } |
| 200 | + elif _is_expr_string(yaml_element): |
| 201 | + return getattr(pipeline_expressions, yaml_element)() |
| 202 | + else: |
| 203 | + return yaml_element |
| 204 | + |
| 205 | + |
| 206 | +def _apply_yaml_args(cls, client, yaml_args): |
| 207 | + """ |
| 208 | + Helper to instantiate a class with yaml arguments. The arguments will be applied |
| 209 | + as positional or keyword arguments, based on type |
| 210 | + """ |
| 211 | + if isinstance(yaml_args, dict): |
| 212 | + return cls(**_parse_expressions(client, yaml_args)) |
| 213 | + elif isinstance(yaml_args, list): |
| 214 | + # yaml has an array of arguments. Treat as args |
| 215 | + return cls(*_parse_expressions(client, yaml_args)) |
| 216 | + else: |
| 217 | + # yaml has a single argument |
| 218 | + return cls(_parse_expressions(client, yaml_args)) |
| 219 | + |
| 220 | + |
| 221 | +def _is_expr_string(yaml_str): |
| 222 | + """ |
| 223 | + Returns true if a string represents a class in pipeline_expressions |
| 224 | + """ |
| 225 | + return ( |
| 226 | + isinstance(yaml_str, str) |
| 227 | + and yaml_str[0].isupper() |
| 228 | + and hasattr(pipeline_expressions, yaml_str) |
| 229 | + ) |
| 230 | + |
| 231 | + |
| 232 | +def _is_stage_string(yaml_str): |
| 233 | + """ |
| 234 | + Returns true if a string represents a class in pipeline_stages |
| 235 | + """ |
| 236 | + return ( |
| 237 | + isinstance(yaml_str, str) |
| 238 | + and yaml_str[0].isupper() |
| 239 | + and hasattr(stages, yaml_str) |
| 240 | + ) |
| 241 | + |
| 242 | + |
| 243 | +@pytest.fixture(scope="module") |
| 244 | +def event_loop(): |
| 245 | + """Change event_loop fixture to module level.""" |
| 246 | + import asyncio |
| 247 | + |
| 248 | + policy = asyncio.get_event_loop_policy() |
| 249 | + loop = policy.new_event_loop() |
| 250 | + yield loop |
| 251 | + loop.close() |
| 252 | + |
| 253 | + |
| 254 | +@pytest.fixture(scope="module") |
| 255 | +def client(): |
| 256 | + """ |
| 257 | + Build a client to use for requests |
| 258 | + """ |
| 259 | + client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_ENTERPRISE_DB) |
| 260 | + data = yaml_loader("data") |
| 261 | + try: |
| 262 | + # setup data |
| 263 | + batch = client.batch() |
| 264 | + for collection_name, documents in data.items(): |
| 265 | + collection_ref = client.collection(collection_name) |
| 266 | + for document_id, document_data in documents.items(): |
| 267 | + document_ref = collection_ref.document(document_id) |
| 268 | + batch.set(document_ref, document_data) |
| 269 | + batch.commit() |
| 270 | + yield client |
| 271 | + finally: |
| 272 | + # clear data |
| 273 | + for collection_name, documents in data.items(): |
| 274 | + collection_ref = client.collection(collection_name) |
| 275 | + for document_id in documents: |
| 276 | + document_ref = collection_ref.document(document_id) |
| 277 | + document_ref.delete() |
| 278 | + |
| 279 | + |
| 280 | +@pytest.fixture(scope="module") |
| 281 | +def async_client(client): |
| 282 | + """ |
| 283 | + Build an async client to use for AsyncPipeline requests |
| 284 | + """ |
| 285 | + yield AsyncClient(project=client.project, database=client._database) |
0 commit comments