diff --git a/docs/source/api.rst b/docs/source/api.rst index 3481480fe..7d5a6307d 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -740,15 +740,13 @@ A :class:`neo4j.Result` is attached to an active connection, through a :class:`n .. automethod:: graph - **This is experimental.** (See :ref:`filter-warnings-ref`) - .. automethod:: value .. automethod:: values .. automethod:: data -See https://neo4j.com/docs/driver-manual/current/cypher-workflow/#driver-type-mapping for more about type mapping. +See https://neo4j.com/docs/python-manual/current/cypher-workflow/#python-driver-type-mapping for more about type mapping. Graph diff --git a/neo4j/work/result.py b/neo4j/work/result.py index bb3dba02c..4808ee731 100644 --- a/neo4j/work/result.py +++ b/neo4j/work/result.py @@ -24,8 +24,8 @@ from neo4j.data import DataDehydrator from neo4j.io import ConnectionErrorHandler +from neo4j.meta import experimental from neo4j.work.summary import ResultSummary -from neo4j.exceptions import ResultConsumedError class Result: @@ -335,6 +335,8 @@ def graph(self): :returns: a result graph :rtype: :class:`neo4j.graph.Graph` + + **This is experimental.** (See :ref:`filter-warnings-ref`) """ self._buffer_all() return self._hydrant.graph @@ -372,3 +374,28 @@ def data(self, *keys): :rtype: list """ return [record.data(*keys) for record in self] + + @experimental("pandas support is experimental and might be changed or " + "removed in future versions") + def to_df(self): + """Convert (the rest of) the result to a pandas DataFrame. + + This method is only available if the `pandas` library is installed. + + ``tx.run("UNWIND range(1, 10) AS n RETURN n, n+1 as m").to_df()``, for + instance will return a DataFrame with two columns: ``n`` and ``m`` and + 10 rows. + + :rtype: :py:class:`pandas.DataFrame` + :raises ImportError: if `pandas` library is not available. + + .. versionadded:: 5.0 + This method was backported from 5.0 for preview purposes. + + **This is experimental.** + ``pandas`` support might be changed or removed in future versions + without warning. (See :ref:`filter-warnings-ref`) + """ + import pandas as pd + + return pd.DataFrame(self.values(), columns=self._keys) diff --git a/tests/requirements.txt b/tests/requirements.txt index b9405952f..56b4d66e3 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -5,3 +5,4 @@ pytest-benchmark pytest-cov pytest-mock teamcity-messages +pandas>=1.0.0 diff --git a/tests/unit/work/test_result.py b/tests/unit/work/test_result.py index c4b5aa13d..c33331f6a 100644 --- a/tests/unit/work/test_result.py +++ b/tests/unit/work/test_result.py @@ -21,6 +21,7 @@ from unittest import mock +import pandas as pd import pytest from neo4j import ( @@ -31,16 +32,20 @@ SummaryCounters, Version, ) -from neo4j.data import DataHydrator +from neo4j.data import ( + DataHydrator, + Node, + Relationship, +) +from neo4j.packstream import Structure from neo4j.work.result import Result class Records: def __init__(self, fields, records): - assert all(len(fields) == len(r) for r in records) - self.fields = fields - # self.records = [{"record_values": r} for r in records] - self.records = records + self.fields = tuple(fields) + self.records = tuple(records) + assert all(len(self.fields) == len(r) for r in self.records) def __len__(self): return self.records.__len__() @@ -422,3 +427,54 @@ def test_data(num_records): assert result.data("hello", "world") == expected_data for record in records: assert record.data.called_once_with("hello", "world") + + +@pytest.mark.parametrize( + ("keys", "values", "types", "instances"), + ( + (["i"], zip(range(5)), ["int64"], None), + (["x"], zip((n - .5) / 5 for n in range(5)), ["float64"], None), + (["s"], zip(("foo", "bar", "baz", "foobar")), ["object"], None), + (["l"], zip(([1, 2], [3, 4])), ["object"], None), + ( + ["n"], + zip(( + Structure(b"N", 0, ["LABEL_A"], {"a": 1, "b": 2}), + Structure(b"N", 2, ["LABEL_B"], {"a": 1, "c": 1.2}), + Structure(b"N", 1, ["LABEL_A", "LABEL_B"], {"a": [1, "a"]}), + )), + ["object"], + [Node] + ), + ( + ["r"], + zip(( + Structure(b"R", 0, 1, 2, "TYPE", {"a": 1, "b": 2}), + Structure(b"R", 420, 1337, 69, "HYPE", {"all memes": True}), + )), + ["object"], + [Relationship] + ), + ) +) +def test_to_df(keys, values, types, instances): + values = list(values) + connection = ConnectionStub(records=Records(keys, values)) + result = Result(connection, DataHydrator(), 1, noop, noop) + result._run("CYPHER", {}, None, None, "r", None) + df = result.to_df() + + assert isinstance(df, pd.DataFrame) + assert df.keys().to_list() == keys + assert len(df) == len(values) + assert df.dtypes.to_list() == types + + expected_df = pd.DataFrame( + {k: [v[i] for v in values] for i, k in enumerate(keys)} + ) + + if instances: + for i, k in enumerate(keys): + assert all(isinstance(v, instances[i]) for v in df[k]) + else: + assert df.equals(expected_df)