Skip to content

Commit 74d4666

Browse files
committed
Pipeines can not optionally be transactions (wrapped in MULTI/EXEC) or not by passing the transaction parameter. This fixes #23.
1 parent b33a6fc commit 74d4666

File tree

2 files changed

+60
-26
lines changed

2 files changed

+60
-26
lines changed

redis/client.py

+46-18
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import socket
44
import threading
55
import warnings
6+
from itertools import chain
67
from redis.exceptions import ConnectionError, ResponseError, InvalidResponse
78
from redis.exceptions import RedisError, AuthenticationError
89

@@ -246,8 +247,20 @@ def _get_db(self):
246247
return self.connection.db
247248
db = property(_get_db)
248249

249-
def pipeline(self):
250-
return Pipeline(self.connection, self.encoding, self.errors)
250+
def pipeline(self, transaction=True):
251+
"""
252+
Return a new pipeline object that can queue multiple commands for
253+
later execution. ``transaction`` indicates whether all commands
254+
should be executed atomically. Apart from multiple atomic operations,
255+
pipelines are useful for batch loading of data as they reduce the
256+
number of back and forth network operations between client and server.
257+
"""
258+
return Pipeline(
259+
self.connection,
260+
transaction,
261+
self.encoding,
262+
self.errors
263+
)
251264

252265

253266
#### COMMAND EXECUTION AND PROTOCOL PARSING ####
@@ -1032,16 +1045,16 @@ class Pipeline(Redis):
10321045
ResponseError exceptions, such as those raised when issuing a command
10331046
on a key of a different datatype.
10341047
"""
1035-
def __init__(self, connection, charset, errors):
1048+
def __init__(self, connection, transaction, charset, errors):
10361049
self.connection = connection
1050+
self.transaction = transaction
10371051
self.encoding = charset
10381052
self.errors = errors
10391053
self.subscribed = False # NOTE not in use, but necessary
10401054
self.reset()
10411055

10421056
def reset(self):
10431057
self.command_stack = []
1044-
self.execute_command('MULTI')
10451058

10461059
def _execute_command(self, command_name, command, **options):
10471060
"""
@@ -1066,19 +1079,20 @@ def _execute_command(self, command_name, command, **options):
10661079
self.command_stack.append((command_name, command, options))
10671080
return self
10681081

1069-
def _execute(self, commands):
1070-
# build up all commands into a single request to increase network perf
1071-
all_cmds = ''.join([c for _1, c, _2 in commands])
1082+
def _execute_transaction(self, commands):
1083+
# wrap the commands in MULTI ... EXEC statements to indicate an
1084+
# atomic operation
1085+
all_cmds = ''.join([c for _1, c, _2 in chain(
1086+
(('', 'MULTI\r\n', ''),),
1087+
commands,
1088+
(('', 'EXEC\r\n', ''),)
1089+
)])
10721090
self.connection.send(all_cmds, self)
1073-
# we only care about the last item in the response, which should be
1074-
# the EXEC command
1075-
for i in range(len(commands)-1):
1091+
# parse off the response for MULTI and all commands prior to EXEC
1092+
for i in range(len(commands)+1):
10761093
_ = self.parse_response('_')
1077-
# tell the response parse to catch errors and return them as
1078-
# part of the response
1094+
# parse the EXEC. we want errors returned as items in the response
10791095
response = self.parse_response('_', catch_errors=True)
1080-
# don't return the results of the MULTI or EXEC command
1081-
commands = [(c[0], c[2]) for c in commands[1:-1]]
10821096
if len(response) != len(commands):
10831097
raise ResponseError("Wrong number of response items from "
10841098
"pipline execution")
@@ -1087,20 +1101,34 @@ def _execute(self, commands):
10871101
for r, cmd in zip(response, commands):
10881102
if not isinstance(r, Exception):
10891103
if cmd[0] in self.RESPONSE_CALLBACKS:
1090-
r = self.RESPONSE_CALLBACKS[cmd[0]](r, **cmd[1])
1104+
r = self.RESPONSE_CALLBACKS[cmd[0]](r, **cmd[2])
10911105
data.append(r)
10921106
return data
10931107

1108+
def _execute_pipeline(self, commands):
1109+
# build up all commands into a single request to increase network perf
1110+
all_cmds = ''.join([c for _1, c, _2 in commands])
1111+
self.connection.send(all_cmds, self)
1112+
data = []
1113+
for command_name, _, options in commands:
1114+
data.append(
1115+
self.parse_response(command_name, catch_errors=True, **options)
1116+
)
1117+
return data
1118+
10941119
def execute(self):
10951120
"Execute all the commands in the current pipeline"
1096-
self.execute_command('EXEC')
10971121
stack = self.command_stack
10981122
self.reset()
1123+
if self.transaction:
1124+
execute = self._execute_transaction
1125+
else:
1126+
execute = self._execute_pipeline
10991127
try:
1100-
return self._execute(stack)
1128+
return execute(stack)
11011129
except ConnectionError:
11021130
self.connection.disconnect()
1103-
return self._execute(stack)
1131+
return execute(stack)
11041132

11051133
def select(self, *args, **kwargs):
11061134
raise RedisError("Cannot select a different database from a pipeline")

tests/pipeline.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ class PipelineTestCase(unittest.TestCase):
55
def setUp(self):
66
self.client = redis.Redis(host='localhost', port=6379, db=9)
77
self.client.flushdb()
8-
8+
99
def tearDown(self):
1010
self.client.flushdb()
11-
11+
1212
def test_pipeline(self):
1313
pipe = self.client.pipeline()
1414
pipe.set('a', 'a1').get('a').zadd('z', 'z1', 1).zadd('z', 'z2', 4)
@@ -23,14 +23,14 @@ def test_pipeline(self):
2323
[('z1', 2.0), ('z2', 4)],
2424
]
2525
)
26-
26+
2727
def test_pipeline_with_fresh_connection(self):
2828
redis.client.connection_manager.connections.clear()
2929
self.client = redis.Redis(host='localhost', port=6379, db=9)
3030
pipe = self.client.pipeline()
3131
pipe.set('a', 'b')
3232
self.assertEquals(pipe.execute(), [True])
33-
33+
3434
def test_invalid_command_in_pipeline(self):
3535
# all commands but the invalid one should be excuted correctly
3636
self.client['c'] = 'a'
@@ -53,10 +53,16 @@ def test_invalid_command_in_pipeline(self):
5353
self.assertEquals(pipe.set('z', 'zzz').execute(), [True])
5454
self.assertEquals(self.client['z'], 'zzz')
5555

56-
def test_pipe_cannot_select(self):
56+
def test_pipeline_cannot_select(self):
5757
pipe = self.client.pipeline()
5858
self.assertRaises(redis.RedisError,
5959
pipe.select, 'localhost', 6379, db=9)
60-
61-
62-
60+
61+
def test_pipeline_no_transaction(self):
62+
pipe = self.client.pipeline(transaction=False)
63+
pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1')
64+
self.assertEquals(pipe.execute(), [True, True, True])
65+
self.assertEquals(self.client['a'], 'a1')
66+
self.assertEquals(self.client['b'], 'b1')
67+
self.assertEquals(self.client['c'], 'c1')
68+

0 commit comments

Comments
 (0)