diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index b5ea0203..16a724d0 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -135,7 +135,7 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph: symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False, ) - print(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s") + logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s") # NOTE: At this point, we _could_ choose to implement something similar to # NodeDict._fetch_all() and AdjListOuterDict._fetch_all() to iterate through @@ -195,7 +195,7 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph: ), ) - print(f"ADB Graph '{G.adb_graph.name}' load took {time.time() - start_time}s") + logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s") start_time = time.time() @@ -240,6 +240,6 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph: key_to_id=vertex_ids_to_index, ) - print(f"NXCG Graph construction took {time.time() - start_time}s") + logger.info(f"NXCG Graph construction took {time.time() - start_time}s") return G.nxcg_graph diff --git a/tests/conftest.py b/tests/conftest.py index f807426f..a1465d26 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -120,17 +120,3 @@ def create_grid_graph(graph_cls: type[nxadb.Graph]) -> nxadb.Graph: return graph_cls( incoming_graph_data=grid_graph, name="GridGraph", write_async=False ) - - -# Taken from: -# https://stackoverflow.com/questions/16571150/how-to-capture-stdout-output-from-a-python-function-call -class Capturing(list[str]): - def __enter__(self): - self._stdout = sys.stdout - sys.stdout = self._stringio = StringIO() - return self - - def __exit__(self, *args): - self.extend(self._stringio.getvalue().splitlines()) - del self._stringio # free up some memory - sys.stdout = self._stdout diff --git a/tests/test.py b/tests/test.py index 145feb6a..79277180 100644 --- a/tests/test.py +++ b/tests/test.py @@ -15,7 +15,7 @@ from nx_arangodb.classes.dict.adj import AdjListOuterDict, EdgeAttrDict, EdgeKeyDict from nx_arangodb.classes.dict.node import NodeAttrDict, NodeDict -from .conftest import Capturing, create_grid_graph, create_line_graph, db, run_gpu_tests +from .conftest import create_grid_graph, create_line_graph, db, run_gpu_tests G_NX = nx.karate_club_graph() G_NX_digraph = nx.DiGraph(G_NX) @@ -344,38 +344,53 @@ def test_gpu_pagerank(graph_cls: type[nxadb.Graph]) -> None: assert nxadb.convert.GPU_AVAILABLE is True assert nx.config.backends.arangodb.use_gpu is True + assert graph.nxcg_graph is None - res_gpu = None - res_cpu = None - - # Measure GPU execution time + # 1. GPU start_gpu = time.time() + res_gpu = nx.pagerank(graph) + gpu_time = time.time() - start_gpu - # Note: While this works, we should use the logger or some alternative - # approach testing this. Via stdout is not the best way to test this. - with Capturing() as output_gpu: - res_gpu = nx.pagerank(graph) + assert graph.nxcg_graph is not None + assert graph.nxcg_graph.number_of_nodes() == 250000 + assert graph.nxcg_graph.number_of_edges() == 499000 - assert any( - "NXCG Graph construction took" in line for line in output_gpu - ), "Expected output not found in GPU execution" + # 2. GPU (cached) + assert graph.use_nxcg_cache is True - gpu_time = time.time() - start_gpu + start_gpu_cached = time.time() + res_gpu_cached = nx.pagerank(graph) + gpu_cached_time = time.time() - start_gpu_cached - # Disable GPU and measure CPU execution time - nx.config.backends.arangodb.use_gpu = False - start_cpu = time.time() - with Capturing() as output_cpu: - res_cpu = nx.pagerank(graph) + assert gpu_cached_time < gpu_time + assert_pagerank(res_gpu, res_gpu_cached, 10) + + # 3. GPU (disable cache) + graph.use_nxcg_cache = False - output_cpu_list = list(output_cpu) - assert len(output_cpu_list) == 1 - assert "Graph 'GridGraph' load took" in output_cpu_list[0] + start_gpu_no_cache = time.time() + res_gpu_no_cache = nx.pagerank(graph) + gpu_no_cache_time = time.time() - start_gpu_no_cache + assert gpu_cached_time < gpu_no_cache_time + assert_pagerank(res_gpu_cached, res_gpu_no_cache, 10) + + # 4. CPU + assert graph.nxcg_graph is not None + graph.clear_nxcg_cache() + assert graph.nxcg_graph is None + nx.config.backends.arangodb.use_gpu = False + + start_cpu = time.time() + res_cpu = nx.pagerank(graph) cpu_time = time.time() - start_cpu - assert gpu_time < cpu_time, "GPU execution should be faster than CPU execution" - assert_pagerank(res_gpu, res_cpu, 10) + assert graph.nxcg_graph is None + + m = "GPU execution should be faster than CPU execution" + assert gpu_time < cpu_time, m + assert gpu_no_cache_time < cpu_time, m + assert_pagerank(res_gpu_no_cache, res_cpu, 10) @pytest.mark.parametrize(