Skip to content
Draft
73 changes: 53 additions & 20 deletions cassandra/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,26 +720,59 @@ def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata,
self.column_types = [c[3] for c in column_metadata]
col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata]

def decode_val(val, col_md, col_desc):
uses_ce = column_encryption_policy and column_encryption_policy.contains_column(col_desc)
col_type = column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3]
raw_bytes = column_encryption_policy.decrypt(col_desc, val) if uses_ce else val
return col_type.from_binary(raw_bytes, protocol_version)

def decode_row(row):
return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs))

try:
self.parsed_rows = [decode_row(row) for row in rows]
except Exception:
for row in rows:
for val, col_md, col_desc in zip(row, column_metadata, col_descs):
try:
decode_val(val, col_md, col_desc)
except Exception as e:
raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2],
col_md[3].cql_parameterized_type(),
str(e)))
# Optimize by checking column_encryption_policy once and creating appropriate decode path
if column_encryption_policy:
# Pre-compute encryption info for each column to avoid repeated lookups
column_encryption_info = [
(column_encryption_policy.contains_column(col_desc), col_desc)
for col_desc in col_descs
]

def decode_val_with_encryption(val, col_md, uses_ce, col_desc):
if uses_ce:
col_type = column_encryption_policy.column_type(col_desc)
raw_bytes = column_encryption_policy.decrypt(col_desc, val)
else:
col_type = col_md[3]
raw_bytes = val
return col_type.from_binary(raw_bytes, protocol_version)

def decode_row(row):
return tuple(
decode_val_with_encryption(val, col_md, uses_ce, col_desc)
for val, col_md, (uses_ce, col_desc) in zip(row, column_metadata, column_encryption_info)
)

try:
self.parsed_rows = [decode_row(row) for row in rows]
except Exception:
for row in rows:
for val, col_md, (uses_ce, col_desc) in zip(row, column_metadata, column_encryption_info):
try:
decode_val_with_encryption(val, col_md, uses_ce, col_desc)
except Exception as e:
raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2],
col_md[3].cql_parameterized_type(),
str(e)))
else:
# Simple path without encryption - just decode raw bytes directly
def decode_val_simple(val, col_type):
return col_type.from_binary(val, protocol_version)

def decode_row(row):
return tuple(decode_val_simple(val, col_md[3]) for val, col_md in zip(row, column_metadata))

try:
self.parsed_rows = [decode_row(row) for row in rows]
except Exception:
for row in rows:
for val, col_md in zip(row, column_metadata):
try:
decode_val_simple(val, col_md[3])
except Exception as e:
raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2],
col_md[3].cql_parameterized_type(),
str(e)))

def recv_results_prepared(self, f, protocol_version, user_type_map):
self.query_id = read_binary_string(f)
Expand Down
152 changes: 152 additions & 0 deletions tests/unit/test_protocol_decode_optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest.mock import Mock, MagicMock
import io

from cassandra import ProtocolVersion
from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS
from cassandra.cqltypes import Int32Type, UTF8Type
from cassandra.policies import ColDesc
from cassandra.marshal import int32_pack


class DecodeOptimizationTest(unittest.TestCase):
"""
Tests to verify the optimization of column_encryption_policy checks
in recv_results_rows. The optimization should avoid checking the policy
for every value and instead check once per recv_results_rows call.
"""

def _create_mock_result_metadata(self):
"""Create mock result metadata for testing"""
return [
('keyspace1', 'table1', 'col1', Int32Type),
('keyspace1', 'table1', 'col2', UTF8Type),
]

def _create_mock_result_message(self):
"""Create a mock result message with data"""
msg = ResultMessage(kind=RESULT_KIND_ROWS)
msg.column_metadata = self._create_mock_result_metadata()
msg.recv_results_metadata = Mock()
msg.recv_row = Mock(side_effect=[
[int32_pack(42), b'hello'],
[int32_pack(100), b'world'],
])
return msg

def _create_mock_stream(self):
"""Create a mock stream for reading rows"""
# Pack rowcount (2 rows)
data = int32_pack(2)
return io.BytesIO(data)

def test_decode_without_encryption_policy(self):
"""
Test that decoding works correctly without column encryption policy.
This should use the optimized simple path.
"""
msg = self._create_mock_result_message()
f = self._create_mock_stream()

msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, None)

# Verify results
self.assertEqual(len(msg.parsed_rows), 2)
self.assertEqual(msg.parsed_rows[0][0], 42)
self.assertEqual(msg.parsed_rows[0][1], 'hello')
self.assertEqual(msg.parsed_rows[1][0], 100)
self.assertEqual(msg.parsed_rows[1][1], 'world')

def test_decode_with_encryption_policy_no_encrypted_columns(self):
"""
Test that decoding works with encryption policy when no columns are encrypted.
"""
msg = self._create_mock_result_message()
f = self._create_mock_stream()

# Create mock encryption policy that has no encrypted columns
mock_policy = Mock()
mock_policy.contains_column = Mock(return_value=False)

msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)

# Verify results
self.assertEqual(len(msg.parsed_rows), 2)
self.assertEqual(msg.parsed_rows[0][0], 42)
self.assertEqual(msg.parsed_rows[0][1], 'hello')

# Verify contains_column was called only once per column (optimization check)
# Should be called 2 times total (once per column, not per value per row)
self.assertEqual(mock_policy.contains_column.call_count, 2)

def test_decode_with_encryption_policy_with_encrypted_column(self):
"""
Test that decoding works with encryption policy when one column is encrypted.
"""
msg = self._create_mock_result_message()
f = self._create_mock_stream()

# Create mock encryption policy where first column is encrypted
mock_policy = Mock()
def contains_column_side_effect(col_desc):
return col_desc.col == 'col1'
mock_policy.contains_column = Mock(side_effect=contains_column_side_effect)
mock_policy.column_type = Mock(return_value=Int32Type)
mock_policy.decrypt = Mock(side_effect=lambda col_desc, val: val)

msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)

# Verify results
self.assertEqual(len(msg.parsed_rows), 2)
self.assertEqual(msg.parsed_rows[0][0], 42)
self.assertEqual(msg.parsed_rows[0][1], 'hello')

# Verify contains_column was called only once per column (optimization)
self.assertEqual(mock_policy.contains_column.call_count, 2)

# Verify decrypt was called for each encrypted value (2 rows * 1 encrypted column)
self.assertEqual(mock_policy.decrypt.call_count, 2)

def test_optimization_efficiency(self):
"""
Verify that the optimization reduces the number of policy checks.
With the old code, contains_column would be called for every value.
With the new code, it's called once per column.
"""
msg = self._create_mock_result_message()

# Create more rows to make the optimization more apparent
msg.recv_row = Mock(side_effect=[
[int32_pack(i), f'text{i}'.encode()] for i in range(100)
])

# Create mock stream with 100 rows
f = io.BytesIO(int32_pack(100))

mock_policy = Mock()
mock_policy.contains_column = Mock(return_value=False)

msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)

# With optimization: contains_column called once per column = 2 calls
# Without optimization: would be called per value = 100 rows * 2 columns = 200 calls
self.assertEqual(mock_policy.contains_column.call_count, 2,
"Optimization failed: contains_column should be called once per column, not per value")


if __name__ == '__main__':
unittest.main()