diff --git a/cassandra/cqlengine/statements.py b/cassandra/cqlengine/statements.py index d92d0b2452..4782fdccd8 100644 --- a/cassandra/cqlengine/statements.py +++ b/cassandra/cqlengine/statements.py @@ -821,7 +821,9 @@ def update_context_id(self, i): self.context_counter += conditional.get_context_size() def add_update(self, column, value, operation=None, previous=None): - value = column.to_database(value) + # For remove all values are None, no need to convert them + if operation != 'remove': + value = column.to_database(value) col_type = type(column) container_update_type = ContainerUpdateClause.type_map.get(col_type) if container_update_type: diff --git a/tests/integration/cqlengine/base.py b/tests/integration/cqlengine/base.py index bdb62aa2a3..c504c7279c 100644 --- a/tests/integration/cqlengine/base.py +++ b/tests/integration/cqlengine/base.py @@ -30,6 +30,8 @@ class TestQueryUpdateModel(Model): text_set = columns.Set(columns.Text, required=False) text_list = columns.List(columns.Text, required=False) text_map = columns.Map(columns.Text, columns.Text, required=False) + bin_map = columns.Map(columns.BigInt, columns.Bytes, required=False, default={}) + class BaseCassEngTestCase(unittest.TestCase): diff --git a/tests/integration/cqlengine/operators/test_where_operators.py b/tests/integration/cqlengine/operators/test_where_operators.py index 1e0134dbac..e04a377c88 100644 --- a/tests/integration/cqlengine/operators/test_where_operators.py +++ b/tests/integration/cqlengine/operators/test_where_operators.py @@ -80,7 +80,7 @@ def test_is_not_null_to_cql(self): self.assertEqual( str(TestQueryUpdateModel.filter(IsNotNull("text"), partition=uuid4())), ('SELECT "cluster", "count", "text", "text_set", ' - '"text_list", "text_map" FROM cqlengine_test.test_query_update_model ' + '"text_list", "text_map", "bin_map" FROM cqlengine_test.test_query_update_model ' 'WHERE "text" IS NOT NULL AND "partition" = %(0)s LIMIT 10000') ) diff --git a/tests/integration/cqlengine/query/test_updates.py b/tests/integration/cqlengine/query/test_updates.py index fb6082bfe2..f92e4fc53f 100644 --- a/tests/integration/cqlengine/query/test_updates.py +++ b/tests/integration/cqlengine/query/test_updates.py @@ -246,20 +246,30 @@ def test_map_update_remove(self): TestQueryUpdateModel.objects.create( partition=partition, cluster=cluster, - text_map={"foo": '1', "bar": '2'} + text_map={"foo": '1', "bar": '2'}, + bin_map={123: b'1', 789: b'2'} ) TestQueryUpdateModel.objects(partition=partition, cluster=cluster).update( text_map__remove={"bar"}, - text_map__update={"foz": '4', "foo": '2'} + text_map__update={"foz": '4', "foo": '2'}, + bin_map__remove={789}, + bin_map__update={456: b'4', 123: b'2'} ) obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) self.assertEqual(obj.text_map, {"foo": '2', "foz": '4'}) + self.assertEqual(obj.bin_map, {123: b'2', 456: b'4'}) TestQueryUpdateModel.objects(partition=partition, cluster=cluster).update( - text_map__remove={"foo", "foz"} + text_map__remove={"foo", "foz"}, + bin_map__remove={123, 456} + ) + rec = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) + self.assertEqual( + rec.text_map, + {} ) self.assertEqual( - TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster).text_map, + rec.bin_map, {} ) diff --git a/tests/integration/cqlengine/statements/test_base_statement.py b/tests/integration/cqlengine/statements/test_base_statement.py index ef5f3b2585..94f1aeaa30 100644 --- a/tests/integration/cqlengine/statements/test_base_statement.py +++ b/tests/integration/cqlengine/statements/test_base_statement.py @@ -65,7 +65,7 @@ def _verify_statement(self, original): for assignment in original.assignments: self.assertEqual(response[assignment.field], assignment.value) - self.assertEqual(len(response), 7) + self.assertEqual(len(response), 8) def test_insert_statement_execute(self): """ @@ -92,6 +92,7 @@ def test_insert_statement_execute(self): st.add_assignment(Column(db_field='text_set'), set(("foo_update", "bar_update"))) st.add_assignment(Column(db_field='text_list'), ["foo_update", "bar_update"]) st.add_assignment(Column(db_field='text_map'), {"foo": '3', "bar": '4'}) + st.add_assignment(Column(db_field='bin_map'), {123: b'3', 456: b'4'}) execute(st) self._verify_statement(st) @@ -150,6 +151,7 @@ def _insert_statement(self, partition, cluster): st.add_assignment(Column(db_field='text_set'), set(("foo", "bar"))) st.add_assignment(Column(db_field='text_list'), ["foo", "bar"]) st.add_assignment(Column(db_field='text_map'), {"foo": '1', "bar": '2'}) + st.add_assignment(Column(db_field='bin_map'), {123: b'1', 456: b'2'}) execute(st) self._verify_statement(st)