|
25 | 25 | EmptyValue, LongType, SetType, UTF8Type,
|
26 | 26 | cql_typename, int8_pack, int64_pack, lookup_casstype,
|
27 | 27 | lookup_casstype_simple, parse_casstype_args,
|
28 |
| - int32_pack, Int32Type, ListType, MapType |
| 28 | + int32_pack, Int32Type, ListType, MapType, VectorType, |
| 29 | + FloatType |
29 | 30 | )
|
30 | 31 | from cassandra.encoder import cql_quote
|
31 | 32 | from cassandra.pool import Host
|
@@ -188,6 +189,12 @@ class BarType(FooType):
|
188 | 189 | self.assertEqual(UTF8Type, ctype.subtypes[2])
|
189 | 190 | self.assertEqual([b'city', None, b'zip'], ctype.names)
|
190 | 191 |
|
| 192 | + def test_parse_casstype_vector(self): |
| 193 | + ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 3)") |
| 194 | + self.assertTrue(issubclass(ctype, VectorType)) |
| 195 | + self.assertEqual(3, ctype.vector_size) |
| 196 | + self.assertEqual(FloatType, ctype.subtype) |
| 197 | + |
191 | 198 | def test_empty_value(self):
|
192 | 199 | self.assertEqual(str(EmptyValue()), 'EMPTY')
|
193 | 200 |
|
@@ -301,6 +308,19 @@ def test_cql_quote(self):
|
301 | 308 | self.assertEqual(cql_quote('test'), "'test'")
|
302 | 309 | self.assertEqual(cql_quote(0), '0')
|
303 | 310 |
|
| 311 | + def test_vector_round_trip(self): |
| 312 | + base = [3.4, 2.9, 41.6, 12.0] |
| 313 | + ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") |
| 314 | + base_bytes = ctype.serialize(base, 0) |
| 315 | + self.assertEqual(16, len(base_bytes)) |
| 316 | + result = ctype.deserialize(base_bytes, 0) |
| 317 | + self.assertEqual(len(base), len(result)) |
| 318 | + for idx in range(0,len(base)): |
| 319 | + self.assertAlmostEqual(base[idx], result[idx], places=5) |
| 320 | + |
| 321 | + def test_vector_cql_parameterized_type(self): |
| 322 | + ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") |
| 323 | + self.assertEqual(ctype.cql_parameterized_type(), "org.apache.cassandra.db.marshal.VectorType<float, 4>") |
304 | 324 |
|
305 | 325 | ZERO = datetime.timedelta(0)
|
306 | 326 |
|
|
0 commit comments