Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cassandra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def emit(self, record):

logging.getLogger('cassandra').addHandler(NullHandler())

__version_info__ = (3, 27, 0)
# The pre-release component must be a string: a bare ``0b1`` is a binary
# integer literal equal to 1, which would render the version as '3.28.1'
# instead of the intended beta tag '3.28.0b1'.
__version_info__ = (3, 28, '0b1')
__version__ = '.'.join(map(str, __version_info__))


Expand Down
37 changes: 34 additions & 3 deletions cassandra/cqltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,15 @@ def parse_casstype_args(typestring):
else:
names.append(None)

ctype = lookup_casstype_simple(tok)
try:
ctype = int(tok)
except ValueError:
ctype = lookup_casstype_simple(tok)
types.append(ctype)

# return the first (outer) type, which will have all parameters applied
return args[0][0][0]


def lookup_casstype(casstype):
"""
Given a Cassandra type as a string (possibly including parameters), hand
Expand All @@ -259,6 +261,7 @@ def lookup_casstype(casstype):
try:
return parse_casstype_args(casstype)
except (ValueError, AssertionError, IndexError) as e:
log.debug("Exception in parse_casstype_args: %s" % e)
raise ValueError("Don't know how to parse type string %r: %s" % (casstype, e))


Expand Down Expand Up @@ -296,7 +299,7 @@ class _CassandraType(object):
"""

def __repr__(self):
return '<%s( %r )>' % (self.cql_parameterized_type(), self.val)
return '<%s>' % (self.cql_parameterized_type())

@classmethod
def from_binary(cls, byts, protocol_version):
Expand Down Expand Up @@ -1421,3 +1424,31 @@ def serialize(cls, v, protocol_version):
buf.write(int8_pack(cls._encode_precision(bound.precision)))

return buf.getvalue()

class VectorType(_CassandraType):
typename = 'org.apache.cassandra.db.marshal.VectorType'
vector_size = 0
subtype = None

@classmethod
def apply_parameters(cls, params, names):
    """
    Build a parameterized subclass of this vector type.

    `params` must hold exactly two entries: the element type (anything
    `lookup_casstype` accepts) followed by the vector dimension. `names`
    is accepted for signature compatibility and is not used.

    Returns a new class derived from `cls` whose `subtype` and
    `vector_size` attributes are set from `params`.
    """
    assert len(params) == 2
    element_type = lookup_casstype(params[0])
    dimension = params[1]
    # The generated class name embeds the dimension after the base type name.
    class_name = '%s(%s)' % (cls.cass_parameterized_type_with([]), dimension)
    attributes = {'vector_size': dimension, 'subtype': element_type}
    return type(class_name, (cls,), attributes)

@classmethod
def deserialize(cls, byts, protocol_version):
    """
    Decode a serialized vector into a list of `cls.vector_size` elements.

    The payload is treated as `vector_size` equally sized elements stored
    back to back. The per-element width is derived from the payload length
    instead of being fixed at 4 bytes, so vectors of 4-byte elements decode
    exactly as before and other fixed-width element types are handled too.

    NOTE(review): element types whose serialized form is not a fixed width
    are not handled by this slicing scheme -- confirm the wire format
    before using such subtypes.

    Raises ValueError if the payload cannot be split into `vector_size`
    non-empty, equally sized elements.
    """
    count = cls.vector_size
    if not count:
        # A zero-dimension vector has no elements to decode.
        return []
    width, leftover = divmod(len(byts), count)
    if leftover or not width:
        raise ValueError(
            "Cannot split %d bytes into %s equally sized vector elements"
            % (len(byts), count))
    decode = cls.subtype.deserialize
    return [decode(byts[start:start + width], protocol_version)
            for start in range(0, count * width, width)]
Copy link
Collaborator Author

@absurdfarce absurdfarce May 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note the explicit assumption here that each element in the vector has a serialized form that is 4 bytes long. This is fine for the initial rollout (which will only include server-side support of float vectors) but needs to be generalized... although it should be noted there's some significant conversation on the Java equivalent of this PR about how to generalize that.


@classmethod
def serialize(cls, v, protocol_version):
    """
    Encode the iterable `v` as the concatenation of its serialized elements.

    Each element is serialized with `cls.subtype.serialize`, in order.

    Raises ValueError if the number of elements in `v` differs from
    `cls.vector_size`; encoding such a value would produce a payload that
    does not match the declared vector dimension.
    """
    encode = cls.subtype.serialize
    # Materialize first so that any iterable (not only sized sequences)
    # can be length-checked before the payload is assembled.
    chunks = [encode(item, protocol_version) for item in v]
    if len(chunks) != cls.vector_size:
        raise ValueError(
            "Expected %s elements for vector type, got %d"
            % (cls.vector_size, len(chunks)))
    return b''.join(chunks)

@classmethod
def cql_parameterized_type(cls):
    """
    Return the textual form of this vector type: the type name followed by
    the element type name and the dimension in angle brackets.
    """
    element_name = cls.subtype.typename
    dimension = cls.vector_size
    return "%s<%s, %s>" % (cls.typename, element_name, dimension)
22 changes: 21 additions & 1 deletion tests/unit/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
EmptyValue, LongType, SetType, UTF8Type,
cql_typename, int8_pack, int64_pack, lookup_casstype,
lookup_casstype_simple, parse_casstype_args,
int32_pack, Int32Type, ListType, MapType
int32_pack, Int32Type, ListType, MapType, VectorType,
FloatType
)
from cassandra.encoder import cql_quote
from cassandra.pool import Host
Expand Down Expand Up @@ -190,6 +191,12 @@ class BarType(FooType):
self.assertEqual(UTF8Type, ctype.subtypes[2])
self.assertEqual([b'city', None, b'zip'], ctype.names)

def test_parse_casstype_vector(self):
    """Parsing a VectorType string yields a VectorType subclass with subtype and size set."""
    type_string = "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 3)"
    parsed = parse_casstype_args(type_string)
    self.assertTrue(issubclass(parsed, VectorType))
    self.assertEqual(3, parsed.vector_size)
    self.assertEqual(FloatType, parsed.subtype)

def test_empty_value(self):
self.assertEqual(str(EmptyValue()), 'EMPTY')

Expand Down Expand Up @@ -303,6 +310,19 @@ def test_cql_quote(self):
self.assertEqual(cql_quote('test'), "'test'")
self.assertEqual(cql_quote(0), '0')

def test_vector_round_trip(self):
    """Serializing then deserializing a 4-element float vector returns the original values."""
    base = [3.4, 2.9, 41.6, 12.0]
    ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)")
    base_bytes = ctype.serialize(base, 0)
    self.assertEqual(16, len(base_bytes))
    result = ctype.deserialize(base_bytes, 0)
    self.assertEqual(len(base), len(result))
    # Lengths were just checked to match, so zip pairs every element;
    # values are compared approximately rather than for exact equality.
    for expected, actual in zip(base, result):
        self.assertAlmostEqual(expected, actual, places=5)

def test_vector_cql_parameterized_type(self):
    """The CQL rendering of a parsed float vector names the element type and the dimension."""
    type_string = "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)"
    parsed = parse_casstype_args(type_string)
    rendered = parsed.cql_parameterized_type()
    self.assertEqual(rendered, "org.apache.cassandra.db.marshal.VectorType<float, 4>")

ZERO = datetime.timedelta(0)

Expand Down