diff --git a/neo4j/exceptions.py b/neo4j/exceptions.py index aba256539..a9309d6c2 100644 --- a/neo4j/exceptions.py +++ b/neo4j/exceptions.py @@ -146,14 +146,16 @@ def is_fatal_during_discovery(self): return False def __str__(self): - return "{{code: {code}}} {{message: {message}}}".format(code=self.code, message=self.message) + if self.code or self.message: + return "{{code: {code}}} {{message: {message}}}".format( + code=self.code, message=self.message + ) + return super().__str__() class ClientError(Neo4jError): """ The Client sent a bad request - changing the request might yield a successful outcome. """ - def __str__(self): - return super(Neo4jError, self).__str__() class DatabaseError(Neo4jError): diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py index 2152a5b6d..81475c949 100644 --- a/tests/unit/test_exceptions.py +++ b/tests/unit/test_exceptions.py @@ -237,3 +237,59 @@ def test_transient_error_is_retriable_case_3(): assert isinstance(error, TransientError) assert error.is_retriable() is True + + +@pytest.mark.parametrize( + ("code", "message", "expected_cls", "expected_str"), + ( + ( + "Neo.ClientError.General.UnknownError", + "Test error message", + ClientError, + "{code: Neo.ClientError.General.UnknownError} " + "{message: Test error message}" + ), + ( + None, + "Test error message", + DatabaseError, + "{code: Neo.DatabaseError.General.UnknownError} " + "{message: Test error message}" + ), + ( + "", + "Test error message", + DatabaseError, + "{code: Neo.DatabaseError.General.UnknownError} " + "{message: Test error message}" + ), + ( + "Neo.ClientError.General.UnknownError", + None, + ClientError, + "{code: Neo.ClientError.General.UnknownError} " + "{message: An unknown error occurred}" + ), + ( + "Neo.ClientError.General.UnknownError", + "", + ClientError, + "{code: Neo.ClientError.General.UnknownError} " + "{message: An unknown error occurred}" + ), + ) +) +def test_neo4j_error_from_server_as_str(code, message, expected_cls, + expected_str): + error = Neo4jError.hydrate(code=code, message=message) + + assert type(error) == expected_cls + assert str(error) == expected_str + + +@pytest.mark.parametrize("cls", (Neo4jError, ClientError)) +def test_neo4j_error_from_code_as_str(cls): + error = cls("Generated somewhere in the driver") + + assert type(error)== cls + assert str(error) == "Generated somewhere in the driver"