diff --git a/docker_db.sh b/docker_db.sh index 8bc6cb6eab43..9478f8aaecac 100755 --- a/docker_db.sh +++ b/docker_db.sh @@ -4,9 +4,17 @@ if command -v docker > /dev/null; then CONTAINER_CLI=$(command -v docker) HEALTCHECK_PATH="{{.State.Health.Status}}" PRIVILEGED_CLI="" + IS_PODMAN=false + if [[ "$(docker version | grep Podman)" == "" ]]; then + IS_DOCKER_RUNTIME=true + else + IS_DOCKER_RUNTIME=false + fi else CONTAINER_CLI=$(command -v podman) HEALTCHECK_PATH="{{.State.Healthcheck.Status}}" + IS_PODMAN=true + IS_DOCKER_RUNTIME=false # Only use sudo for podman if command -v sudo > /dev/null; then PRIVILEGED_CLI="sudo" @@ -15,6 +23,12 @@ else fi fi +if [[ "$(uname -s)" == "Darwin" ]]; then + IS_OSX=true +else + IS_OSX=false +fi + mysql() { mysql_9_4 } @@ -286,29 +300,93 @@ db2() { } db2_11_5() { - $PRIVILEGED_CLI $CONTAINER_CLI rm -f db2 || true - $PRIVILEGED_CLI $CONTAINER_CLI run --name db2 --privileged -e DB2INSTANCE=orm_test -e DB2INST1_PASSWORD=orm_test -e DBNAME=orm_test -e LICENSE=accept -e AUTOCONFIG=false -e ARCHIVE_LOGS=false -e TO_CREATE_SAMPLEDB=false -e REPODB=false -p 50000:50000 -d ${DB_IMAGE_DB2_11_5:-icr.io/db2_community/db2:11.5.9.0} + CONTAINER_OPTIONS="" + if [[ "$IS_OSX" == "true" ]]; then + # Thanks to Mohamed Asfour https://community.ibm.com/community/user/discussion/db2-luw-115xx-mac-m1-ready#bm017584d2-8d76-42a6-8f76-018dac8e78f2 + # This SO post explains what goes wrong on OSX: https://stackoverflow.com/questions/70175677/ibmcom-db2-docker-image-fails-on-m1 + # Also, use the $HOME directory as base directory to make volume mounts work on Colima on Mac + db2install=$HOME/db2install.sh + rm -f ${db2install} || true + cat <<'EOF' >${db2install} +#!/bin/bash +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - db2inst1 -c '/su - db2inst1 -c '. .profile \&\& /g" {} + +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - db2inst1 -c \"/su - db2inst1 -c \". .profile \&\& /g" {} + +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - \${DB2INSTANCE?} -c '/su - \${DB2INSTANCE?} -c '. .profile \&\& /g" {} + +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - \${DB2INSTANCE?} -c \"/su - \${DB2INSTANCE?} -c \". .profile \&\& /g" {} + +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - \${instance?} -c '/su - \${instance?} -c '. .profile \&\& /g" {} + +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - \${instance?} -c \"/su - \${instance?} -c \". .profile \&\& /g" {} + +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - \${instance_name?} -c '/su - \${instance_name?} -c '. .profile \&\& /g" {} + +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - \${instance_name?} -c \"/su - \${instance_name?} -c \". .profile \&\& /g" {} + +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - db2inst1 -c \\\\\"/su - db2inst1 -c \\\". .profile \&\& /g" {} + +. 
/var/db2_setup/lib/setup_db2_instance.sh +EOF + chmod 777 ${db2install} + if [[ "$IS_PODMAN" == "true" ]]; then + CONTAINER_OPTIONS='--platform=linux/amd64 -e IS_OSXFS=true -v '${db2install}':/db2install.sh --entrypoint=["/bin/bash","-c","/db2install.sh"]' + CONTAINER_ARGS= + else + CONTAINER_OPTIONS='--platform=linux/amd64 -e IS_OSXFS=true -v '${db2install}':/db2install.sh --entrypoint=/bin/bash' + CONTAINER_ARGS=" -c /db2install.sh" + fi + if [[ "$IS_PODMAN" == "false" ]]; then + export DOCKER_DEFAULT_PLATFORM=linux/amd64 + fi + fi + $CONTAINER_CLI rm -f db2 || true + $CONTAINER_CLI run --name db2 --privileged -e DB2INSTANCE=orm_test -e DB2INST1_PASSWORD=orm_test -e DBNAME=orm_test -e LICENSE=accept -e AUTOCONFIG=false -e ARCHIVE_LOGS=false -e TO_CREATE_SAMPLEDB=false -e REPODB=false -e BLU=false -e ENABLE_ORACLE_COMPATIBILITY=false -e UPDATEAVAIL=NO -e PERSISTENT_HOME=true -e HADR_ENABLED=false -p 50000:50000 $CONTAINER_OPTIONS -d ${DB_IMAGE_DB2_11_5:-icr.io/db2_community/db2:11.5.9.0} $CONTAINER_ARGS # Give the container some time to start OUTPUT= while [[ $OUTPUT != *"INSTANCE"* ]]; do echo "Waiting for DB2 to start..." sleep 10 - OUTPUT=$($PRIVILEGED_CLI $CONTAINER_CLI logs db2 2>&1) + OUTPUT=$($CONTAINER_CLI logs db2 2>&1) done - $PRIVILEGED_CLI $CONTAINER_CLI exec -t db2 su - orm_test bash -c ". /database/config/orm_test/sqllib/db2profile; /database/config/orm_test/sqllib/bin/db2 'connect to orm_test'; /database/config/orm_test/sqllib/bin/db2 'CREATE USER TEMPORARY TABLESPACE usr_tbsp MANAGED BY AUTOMATIC STORAGE'" + $CONTAINER_CLI exec -t db2 su - orm_test bash -c ". /database/config/orm_test/sqllib/db2profile; /database/config/orm_test/sqllib/bin/db2 'connect to orm_test'; /database/config/orm_test/sqllib/bin/db2 'CREATE USER TEMPORARY TABLESPACE usr_tbsp MANAGED BY AUTOMATIC STORAGE'" } db2_12_1() { - $PRIVILEGED_CLI $CONTAINER_CLI rm -f db2 || true - $PRIVILEGED_CLI $CONTAINER_CLI run --name db2 --privileged -e DB2INSTANCE=orm_test -e DB2INST1_PASSWORD=orm_test -e DBNAME=orm_test -e LICENSE=accept -e AUTOCONFIG=false -e ARCHIVE_LOGS=false -e TO_CREATE_SAMPLEDB=false -e REPODB=false -p 50000:50000 -d ${DB_IMAGE_DB2_11_5:-icr.io/db2_community/db2:12.1.2.0} + CONTAINER_OPTIONS="" + if [[ "$IS_OSX" == "true" ]]; then + # Thanks to Mohamed Asfour https://community.ibm.com/community/user/discussion/db2-luw-115xx-mac-m1-ready#bm017584d2-8d76-42a6-8f76-018dac8e78f2 + # This SO post explains what goes wrong on OSX: https://stackoverflow.com/questions/70175677/ibmcom-db2-docker-image-fails-on-m1 + # Also, use the $HOME directory as base directory to make volume mounts work on Colima on Mac + db2install=$HOME/db2install.sh + rm -f ${db2install} || true + cat <<'EOF' >${db2install} +#!/bin/bash +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - db2inst1 -c '/su - db2inst1 -c '. .profile \&\& /g" {} + +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - db2inst1 -c \"/su - db2inst1 -c \". .profile \&\& /g" {} + +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - \${DB2INSTANCE?} -c '/su - \${DB2INSTANCE?} -c '. .profile \&\& /g" {} + +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - \${DB2INSTANCE?} -c \"/su - \${DB2INSTANCE?} -c \". .profile \&\& /g" {} + +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - \${instance?} -c '/su - \${instance?} -c '. .profile \&\& /g" {} + +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - \${instance?} -c \"/su - \${instance?} -c \". 
.profile \&\& /g" {} + +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - \${instance_name?} -c '/su - \${instance_name?} -c '. .profile \&\& /g" {} + +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - \${instance_name?} -c \"/su - \${instance_name?} -c \". .profile \&\& /g" {} + +find /var/db2_setup -type f -not -path '*/\.*' -exec sed -i "s/su - db2inst1 -c \\\\\"/su - db2inst1 -c \\\". .profile \&\& /g" {} + +. /var/db2_setup/lib/setup_db2_instance.sh +EOF + chmod 777 ${db2install} + if [[ "$IS_PODMAN" == "true" ]]; then + CONTAINER_OPTIONS='--platform=linux/amd64 -e IS_OSXFS=true -v '${db2install}':/db2install.sh --entrypoint=["/bin/bash","-c","/db2install.sh"]' + CONTAINER_ARGS= + else + CONTAINER_OPTIONS='--platform=linux/amd64 -e IS_OSXFS=true -v '${db2install}':/db2install.sh --entrypoint=/bin/bash' + CONTAINER_ARGS=" -c /db2install.sh" + fi + if [[ "$IS_PODMAN" == "false" ]]; then + export DOCKER_DEFAULT_PLATFORM=linux/amd64 + fi + fi + $CONTAINER_CLI rm -f db2 || true + $CONTAINER_CLI run --name db2 --privileged -e DB2INSTANCE=orm_test -e DB2INST1_PASSWORD=orm_test -e DBNAME=orm_test -e LICENSE=accept -e AUTOCONFIG=false -e ARCHIVE_LOGS=false -e TO_CREATE_SAMPLEDB=false -e REPODB=false -e BLU=false -e ENABLE_ORACLE_COMPATIBILITY=false -e UPDATEAVAIL=NO -e PERSISTENT_HOME=true -e HADR_ENABLED=false -p 50000:50000 $CONTAINER_OPTIONS -d ${DB_IMAGE_DB2_12_1:-icr.io/db2_community/db2:12.1.2.0} $CONTAINER_ARGS # Give the container some time to start OUTPUT= while [[ $OUTPUT != *"INSTANCE"* ]]; do echo "Waiting for DB2 to start..." sleep 10 - OUTPUT=$($PRIVILEGED_CLI $CONTAINER_CLI logs db2 2>&1) + OUTPUT=$($CONTAINER_CLI logs db2 2>&1) done - $PRIVILEGED_CLI $CONTAINER_CLI exec -t db2 su - orm_test bash -c ". /database/config/orm_test/sqllib/db2profile; /database/config/orm_test/sqllib/bin/db2 'connect to orm_test'; /database/config/orm_test/sqllib/bin/db2 'CREATE USER TEMPORARY TABLESPACE usr_tbsp MANAGED BY AUTOMATIC STORAGE'" + $CONTAINER_CLI exec -t db2 su - orm_test bash -c ". /database/config/orm_test/sqllib/db2profile; /database/config/orm_test/sqllib/bin/db2 'connect to orm_test'; /database/config/orm_test/sqllib/bin/db2 'CREATE USER TEMPORARY TABLESPACE usr_tbsp MANAGED BY AUTOMATIC STORAGE'" } db2_spatial() { @@ -726,28 +804,30 @@ EOF\"" } disable_userland_proxy() { - if [[ "$HEALTCHECK_PATH" == "{{.State.Health.Status}}" ]]; then - if [[ ! -f /etc/docker/daemon.json ]]; then - echo "Didn't find /etc/docker/daemon.json but need to disable userland-proxy..." - echo "Stopping docker..." - sudo service docker stop - echo "Creating /etc/docker/daemon.json..." - sudo bash -c "echo '{\"userland-proxy\": false}' > /etc/docker/daemon.json" - echo "Starting docker..." - sudo service docker start - echo "Docker successfully started with userland proxies disabled" - elif ! grep -q userland-proxy /etc/docker/daemon.json; then - echo "Userland proxy is still enabled in /etc/docker/daemon.json, but need to disable it..." - export docker_daemon_json=$( /etc/docker/daemon.json" - echo "Starting docker..." - sudo service docker start - echo "Service status:" - sudo journalctl -xeu docker.service - echo "Docker successfully started with userland proxies disabled" + if [[ "$IS_DOCKER_RUNTIME" == "true" ]]; then + if [[ "$HEALTCHECK_PATH" == "{{.State.Health.Status}}" ]]; then + if [[ ! -f /etc/docker/daemon.json ]]; then + echo "Didn't find /etc/docker/daemon.json but need to disable userland-proxy..." + echo "Stopping docker..." 
+ sudo service docker stop + echo "Creating /etc/docker/daemon.json..." + sudo bash -c "echo '{\"userland-proxy\": false}' > /etc/docker/daemon.json" + echo "Starting docker..." + sudo service docker start + echo "Docker successfully started with userland proxies disabled" + elif ! grep -q userland-proxy /etc/docker/daemon.json; then + echo "Userland proxy is still enabled in /etc/docker/daemon.json, but need to disable it..." + export docker_daemon_json=$( /etc/docker/daemon.json" + echo "Starting docker..." + sudo service docker start + echo "Service status:" + sudo journalctl -xeu docker.service + echo "Docker successfully started with userland proxies disabled" + fi fi fi } diff --git a/documentation/src/main/asciidoc/userguide/chapters/query/extensions/Vector.adoc b/documentation/src/main/asciidoc/userguide/chapters/query/extensions/Vector.adoc index 7c4c93b18f8a..3e5d63555755 100644 --- a/documentation/src/main/asciidoc/userguide/chapters/query/extensions/Vector.adoc +++ b/documentation/src/main/asciidoc/userguide/chapters/query/extensions/Vector.adoc @@ -12,10 +12,21 @@ The Hibernate ORM Vector module contains support for mathematical vector types a This is useful for AI/ML topics like vector similarity search and Retrieval-Augmented Generation (RAG). The module comes with support for a special `vector` data type that essentially represents an array of bytes, floats, or doubles. -So far, both the PostgreSQL extension `pgvector` and the Oracle database 23ai+ `AI Vector Search` feature are supported, but in theory, -the vector specific functions could be implemented to work with every database that supports arrays. +Currently, the following databases are supported: -For further details, refer to the https://github.com/pgvector/pgvector#querying[pgvector documentation] or the https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/overview-node.html[AI Vector Search documentation]. +* PostgreSQL 13+ through the https://github.com/pgvector/pgvector#querying[`pgvector` extension] +* https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/overview-node.html[Oracle database 23ai+] +* https://mariadb.com/docs/server/reference/sql-structure/vectors/vector-overview[MariaDB 11.7+] +* https://dev.mysql.com/doc/refman/9.4/en/vector-functions.html[MySQL 9.0+] + +In theory, the vector-specific functions could be implemented to work with every database that supports arrays. + +[WARNING] +==== +Per the https://dev.mysql.com/doc/refman/9.4/en/vector-functions.html#function_distance[MySQL documentation], +the various vector distance functions for MySQL only work on MySQL cloud offerings like +https://dev.mysql.com/doc/heatwave/en/mys-hw-about-heatwave.html[HeatWave MySQL on OCI]. +==== [[vector-module-setup]] === Setup @@ -42,22 +53,32 @@ so no further configuration is necessary to make the features available. [[vector-module-usage]] ==== Usage -Annotate a persistent attribute with `@JdbcTypeCode(SqlTypes.VECTOR)` and specify the vector length with `@Array(length = ...)`. +Annotate a persistent attribute with one of the various vector type codes `@JdbcTypeCode` and specify the vector length with `@Array(length = ...)`. 
+Possible vector type codes and the compatible Java types are: + +* `@JdbcTypeCode(SqlTypes.VECTOR_BINARY)` for `byte[]` +* `@JdbcTypeCode(SqlTypes.VECTOR_INT8)` for `byte[]` +* `@JdbcTypeCode(SqlTypes.VECTOR_FLOAT16)` for `float[]` +* `@JdbcTypeCode(SqlTypes.VECTOR_FLOAT32)` for `float[]` +* `@JdbcTypeCode(SqlTypes.VECTOR_FLOAT64)` for `double[]` +* `@JdbcTypeCode(SqlTypes.VECTOR)` for `float[]` + +Hibernate ORM also provides support for sparse vectors through dedicated Java types: + +* `@JdbcTypeCode(SqlTypes.SPARSE_VECTOR_INT8)` for `SparseByteVector` +* `@JdbcTypeCode(SqlTypes.SPARSE_VECTOR_FLOAT32)` for `SparseFloatVector` +* `@JdbcTypeCode(SqlTypes.SPARSE_VECTOR_FLOAT64)` for `SparseDoubleVector` [WARNING] ==== -As Oracle AI Vector Search supports different types of elements (to ensure better performance and compatibility with embedding models), you can also use: - -- `@JdbcTypeCode(SqlTypes.VECTOR_INT8)` for `byte[]` -- `@JdbcTypeCode(SqlTypes.VECTOR_FLOAT32)` for `float[]` -- `@JdbcTypeCode(SqlTypes.VECTOR_FLOAT64)` for `double[]`. +Vector data type support depends on native support of the underlying database. ==== [[vector-module-usage-example]] ==== [source, java, indent=0] ---- -include::{example-dir-vector}/PGVectorTest.java[tags=usage-example] +include::{example-dir-vector}/FloatVectorTest.java[tags=usage-example] ---- ==== @@ -77,14 +98,21 @@ Expressions of the vector type can be used with various vector functions. | `euclidean_distance()` | Computes the https://en.wikipedia.org/wiki/Euclidean_distance[euclidean distance] between two vectors. Maps to the `<``-``>` operator for `pgvector` and maps to the `vector_distance(v1, v2, EUCLIDEAN)` function for `Oracle AI Vector Search`. +| `euclidean_squared_distance()` | Computes the https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance[squared euclidean distance] between two vectors. + | `l2_distance()` | Alias for `euclidean_distance()` +| `l2_squared_distance()` | Alias for `euclidean_squared_distance()` + | `taxicab_distance()` | Computes the https://en.wikipedia.org/wiki/Taxicab_geometry[taxicab distance] between two vectors. Maps to `vector_distance(v1, v2, MANHATTAN)` function for `Oracle AI Vector Search`. | `l1_distance()` | Alias for `taxicab_distance()` | `hamming_distance()` | Computes the https://en.wikipedia.org/wiki/Hamming_distance[hamming distance] between two vectors. Maps to `vector_distance(v1, v2, HAMMING)` function for `Oracle AI Vector Search`. +| `jaccard_distance()` | Computes the https://en.wikipedia.org/wiki/Jaccard_index[jaccard distance] between two vectors. Maps to the `<``%``>` operator for `pgvector` and maps to the +`vector_distance(v1, v2, JACCARD)` function for `Oracle AI Vector Search`. + | `inner_product()` | Computes the https://en.wikipedia.org/wiki/Inner_product_space[inner product] between two vectors | `negative_inner_product()` | Computes the negative inner product. Maps to the `<``#``>` operator for `pgvector` and maps to the @@ -93,6 +121,14 @@ Expressions of the vector type can be used with various vector functions. | `vector_dims()` | Determines the dimensions of a vector | `vector_norm()` | Computes the https://en.wikipedia.org/wiki/Euclidean_space#Euclidean_norm[Euclidean norm] of a vector + +| `l2_norm()` | Alias for `vector_norm()` + +| `l2_normalize()` | Normalizes each component of a vector by dividing it with the https://en.wikipedia.org/wiki/Euclidean_space#Euclidean_norm[Euclidean norm] of the vector. 
+ +| `binary_quantize()` | Reduces a vector of size N to a binary vector with N bits, using 0 for values <= 0 and 1 for values > 0. + +| `subvector()` | Creates a subvector from a given vector, a 1-based start index and a count. |=== In addition to these special vector functions, it is also possible to use vectors with the following builtin `pgvector` operators: @@ -113,7 +149,7 @@ which is `1 - inner_product( v1, v2 ) / ( vector_norm( v1 ) * vector_norm( v2 ) ==== [source, java, indent=0] ---- -include::{example-dir-vector}/PGVectorTest.java[tags=cosine-distance-example] +include::{example-dir-vector}/FloatVectorTest.java[tags=cosine-distance-example] ---- ==== @@ -128,7 +164,22 @@ The `l2_distance()` function is an alias. ==== [source, java, indent=0] ---- -include::{example-dir-vector}/PGVectorTest.java[tags=euclidean-distance-example] +include::{example-dir-vector}/FloatVectorTest.java[tags=euclidean-distance-example] +---- +==== + +[[vector-module-functions-euclidean-squared-distance]] +===== `euclidean_squared_distance()` and `l2_squared_distance()` + +Computes the https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance[squared euclidean distance] between two vectors, +which is `sum( (v1_i - v2_i)^2 )`, just like the regular `euclidean_distance`, but without the `sqrt`. +The `l2_squared_distance()` function is an alias. + +[[vector-module-functions-euclidean-squared-distance-example]] +==== +[source, java, indent=0] +---- +include::{example-dir-vector}/FloatVectorTest.java[tags=euclidean-squared-distance-example] ---- ==== @@ -143,7 +194,37 @@ The `l1_distance()` function is an alias. ==== [source, java, indent=0] ---- -include::{example-dir-vector}/PGVectorTest.java[tags=taxicab-distance-example] +include::{example-dir-vector}/FloatVectorTest.java[tags=taxicab-distance-example] +---- +==== + +[[vector-module-functions-hamming-distance]] +===== `hamming_distance()` + +Computes the https://en.wikipedia.org/wiki/Hamming_distance[hamming distance] between two binary vectors, +which is `bit_count(v1 ^ v2)`, i.e. the number of bits in which two vectors differ. +Maps to the `<``~``>` operator for `pgvector`. + +[[vector-module-functions-hamming-distance-example]] +==== +[source, java, indent=0] +---- +include::{example-dir-vector}/BinaryVectorTest.java[tags=hamming-distance-example] +---- +==== + +[[vector-module-functions-jaccard-distance]] +===== `jaccard_distance()` + +Computes the https://en.wikipedia.org/wiki/Jaccard_index[jaccard distance] between two binary vectors, +which is `1 - bit_count(v1 & v2) / bit_count(v1 | v2)`. +Maps to the `<``%``>` operator for `pgvector`. + +[[vector-module-functions-jaccard-distance-example]] +==== +[source, java, indent=0] +---- +include::{example-dir-vector}/BinaryVectorTest.java[tags=jaccard-distance-example] ---- ==== @@ -158,7 +239,7 @@ and the `inner_product()` function as well, but multiplies the result time `-1`. ==== [source, java, indent=0] ---- -include::{example-dir-vector}/PGVectorTest.java[tags=inner-product-example] +include::{example-dir-vector}/FloatVectorTest.java[tags=inner-product-example] ---- ==== @@ -171,24 +252,63 @@ Determines the dimensions of a vector.
==== [source, java, indent=0] ---- -include::{example-dir-vector}/PGVectorTest.java[tags=vector-dims-example] +include::{example-dir-vector}/FloatVectorTest.java[tags=vector-dims-example] ---- ==== [[vector-module-functions-vector-norm]] -===== `vector_norm()` +===== `vector_norm()` and `l2_norm()` Computes the https://en.wikipedia.org/wiki/Euclidean_space#Euclidean_norm[Euclidean norm] of a vector, which is `sqrt( sum( v_i^2 ) )`. +The `l2_norm()` function is an alias. [[vector-module-functions-vector-norm-example]] ==== [source, java, indent=0] ---- -include::{example-dir-vector}/PGVectorTest.java[tags=vector-norm-example] +include::{example-dir-vector}/FloatVectorTest.java[tags=vector-norm-example] +---- +==== + +[[vector-module-functions-l2-normalize]] +===== `l2_normalize()` + +Normalizes each component of a vector by dividing it with the https://en.wikipedia.org/wiki/Euclidean_space#Euclidean_norm[Euclidean norm] of the vector. + +[[vector-module-functions-l2-normalize-example]] +==== +[source, java, indent=0] +---- +include::{example-dir-vector}/FloatVectorTest.java[tags=l2-normalize-example] +---- +==== + +[[vector-module-functions-binary-quantize]] +===== `binary_quantize()` + +Reduces a vector of size N to a binary vector with N bits, using 0 for values <= 0 and 1 for values > 0. + +[[vector-module-functions-binary-quantize-example]] +==== +[source, java, indent=0] +---- +include::{example-dir-vector}/FloatVectorTest.java[tags=binary-quantize-example] ---- ==== +[[vector-module-functions-subvector]] +===== `subvector()` + +Creates a subvector from a given vector, a 1-based start index and a count. + +[[vector-module-functions-subvector-example]] +==== +[source, java, indent=0] +---- +include::{example-dir-vector}/FloatVectorTest.java[tags=subvector-example] +---- +====
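The sections above reference example tags in `FloatVectorTest` and `BinaryVectorTest`. For readers without those test sources at hand, here is a minimal sketch of the mapping and query usage the chapter describes; the `Book` entity, the `embedding` attribute, the 384-dimension length and the `queryEmbedding` variable are illustrative assumptions, not part of this change.

[source, java, indent=0]
----
@Entity
public class Book {
	@Id
	@GeneratedValue
	private Long id;

	// Stored as a native vector column (e.g. vector(384) with pgvector)
	@JdbcTypeCode(SqlTypes.VECTOR)
	@Array(length = 384)
	private float[] embedding;
}

// Order books by similarity to a query embedding using one of the distance functions above
List<Book> books = session.createSelectionQuery(
				"select b from Book b order by cosine_distance(b.embedding, :query)", Book.class )
		.setParameter( "query", queryEmbedding )
		.getResultList();
----

`@JdbcTypeCode` and `@Array` come from `org.hibernate.annotations`, and `cosine_distance()` is translated to the dialect-specific operator or function listed in the table above.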
diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/DB2LegacyDialect.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/DB2LegacyDialect.java index d84f4e3a8f64..692a6e870ebc 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/DB2LegacyDialect.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/DB2LegacyDialect.java @@ -12,6 +12,7 @@ import org.hibernate.boot.model.TypeContributions; import org.hibernate.community.dialect.sequence.LegacyDB2SequenceSupport; import org.hibernate.community.dialect.temptable.DB2LegacyLocalTemporaryTableStrategy; +import org.hibernate.dialect.DB2Dialect; import org.hibernate.dialect.DB2GetObjectExtractor; import org.hibernate.dialect.DatabaseVersion; import org.hibernate.dialect.Dialect; @@ -183,7 +184,7 @@ public DB2LegacyDialect() { } public DB2LegacyDialect(DialectResolutionInfo info) { - super( info ); + this( DB2Dialect.determinFullDatabaseVersion( info ) ); lockingSupport = buildLockingSupport(); } @@ -192,6 +193,11 @@ public DB2LegacyDialect(DatabaseVersion version) { lockingSupport = buildLockingSupport(); } + @Override + public DatabaseVersion determineDatabaseVersion(DialectResolutionInfo info) { + return DB2Dialect.determinFullDatabaseVersion( info ); + } + protected LockingSupport buildLockingSupport() { // Introduced in 11.5: https://www.ibm.com/docs/en/db2/11.5?topic=statement-concurrent-access-resolution-clause final boolean supportsSkipLocked = getVersion().isSameOrAfter( 11, 5 ); diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBCastingInetJdbcType.java
b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBCastingInetJdbcType.java index 81cfd92d3ee6..eab976f9cdc9 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBCastingInetJdbcType.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBCastingInetJdbcType.java @@ -10,7 +10,9 @@ import java.sql.ResultSet; import java.sql.SQLException; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.ValueBinder; @@ -35,6 +37,7 @@ public class GaussDBCastingInetJdbcType implements JdbcType { @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( "cast(" ); @@ -42,6 +45,11 @@ public void appendWriteExpression( appender.append( " as inet)" ); } + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } + @Override public int getJdbcTypeCode() { return SqlTypes.VARBINARY; diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBCastingIntervalSecondJdbcType.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBCastingIntervalSecondJdbcType.java index 221066afd170..2a9698ef9a55 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBCastingIntervalSecondJdbcType.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBCastingIntervalSecondJdbcType.java @@ -10,7 +10,9 @@ import java.sql.ResultSet; import java.sql.SQLException; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.metamodel.mapping.JdbcMappingContainer; import org.hibernate.sql.ast.SqlAstTranslator; @@ -77,6 +79,7 @@ public JdbcMappingContainer getExpressionType() { @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( '(' ); @@ -84,6 +87,11 @@ public void appendWriteExpression( appender.append( "*interval'1 second')" ); } + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } + @Override public int getJdbcTypeCode() { return SqlTypes.NUMERIC; diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBCastingJsonArrayJdbcType.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBCastingJsonArrayJdbcType.java index fcc995ccd7bf..97d8eba0c914 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBCastingJsonArrayJdbcType.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBCastingJsonArrayJdbcType.java @@ -4,7 +4,9 @@ */ package org.hibernate.community.dialect; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.descriptor.jdbc.JdbcType; import org.hibernate.type.descriptor.jdbc.JsonArrayJdbcType; @@ -27,6 +29,7 @@ public GaussDBCastingJsonArrayJdbcType(JdbcType 
elementJdbcType, boolean jsonb) @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( "cast(" ); @@ -39,4 +42,9 @@ public void appendWriteExpression( appender.append( "json)" ); } } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } } diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBCastingJsonJdbcType.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBCastingJsonJdbcType.java index 3d14060478bf..cc097a350521 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBCastingJsonJdbcType.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBCastingJsonJdbcType.java @@ -4,7 +4,9 @@ */ package org.hibernate.community.dialect; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.metamodel.mapping.EmbeddableMappingType; import org.hibernate.metamodel.spi.RuntimeModelCreationContext; import org.hibernate.sql.ast.spi.SqlAppender; @@ -46,6 +48,7 @@ public AggregateJdbcType resolveAggregateJdbcType( @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( "cast(" ); @@ -58,4 +61,9 @@ public void appendWriteExpression( appender.append( "json)" ); } } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } } diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBStructuredJdbcType.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBStructuredJdbcType.java index 5fc201fb86c8..65a351ef0fd6 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBStructuredJdbcType.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/GaussDBStructuredJdbcType.java @@ -8,9 +8,11 @@ import java.sql.PreparedStatement; import java.sql.SQLException; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.boot.model.naming.Identifier; import org.hibernate.dialect.Dialect; import org.hibernate.dialect.type.AbstractPostgreSQLStructJdbcType; +import org.hibernate.engine.jdbc.Size; import org.hibernate.metamodel.mapping.EmbeddableMappingType; import org.hibernate.metamodel.spi.RuntimeModelCreationContext; import org.hibernate.sql.ast.spi.SqlAppender; @@ -59,6 +61,7 @@ public AggregateJdbcType resolveAggregateJdbcType( @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( "cast(" ); @@ -68,6 +71,11 @@ public void appendWriteExpression( appender.append( ')' ); } + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } + @Override public ValueBinder getBinder(JavaType javaType) { return new BasicBinder<>( javaType, this ) { diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/DB2Dialect.java b/hibernate-core/src/main/java/org/hibernate/dialect/DB2Dialect.java index a50bac1e03cc..098cd379ea85 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/DB2Dialect.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/DB2Dialect.java @@ -6,6 +6,7 @@ import jakarta.persistence.TemporalType; 
import jakarta.persistence.Timeout; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.Timeouts; import org.hibernate.boot.model.FunctionContributions; import org.hibernate.boot.model.TypeContributions; @@ -113,7 +114,10 @@ import java.util.Date; import java.util.List; import java.util.TimeZone; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import static java.lang.Integer.parseInt; import static org.hibernate.exception.spi.TemplatedViolatedConstraintNameExtractor.extractUsingTemplate; import static org.hibernate.internal.util.JdbcExceptionHelper.extractErrorCode; import static org.hibernate.type.SqlTypes.BINARY; @@ -147,6 +151,8 @@ public class DB2Dialect extends Dialect { final static DatabaseVersion MINIMUM_VERSION = DatabaseVersion.make( 11, 1 ); + + private static final Pattern DB2_VERSION_PATTERN = Pattern.compile( "(?:ARI|DSN|QSQ|SQL)(\\d\\d)(\\d\\d)(\\d)\\d?" ); private static final int BIND_PARAMETERS_NUMBER_LIMIT = 32_767; private static final String FOR_READ_ONLY_SQL = " for read only with rs"; @@ -175,7 +181,7 @@ public DB2Dialect() { } public DB2Dialect(DialectResolutionInfo info) { - this( info.makeCopyOrDefault( MINIMUM_VERSION ) ); + this( determinFullDatabaseVersion( info ) ); registerKeywords( info ); } @@ -184,6 +190,42 @@ public DB2Dialect(DatabaseVersion version) { lockingSupport = buildLockingSupport(); } + @Override + public DatabaseVersion determineDatabaseVersion(DialectResolutionInfo info) { + return determinFullDatabaseVersion( info ); + } + + public static DatabaseVersion determinFullDatabaseVersion(DialectResolutionInfo info) { + String versionString = null; + final DatabaseMetaData databaseMetadata = info.getDatabaseMetadata(); + if ( databaseMetadata != null ) { + try { + versionString = databaseMetadata.getDatabaseProductVersion(); + } + catch (SQLException ex) { + // Ignore + } + } + final DatabaseVersion databaseVersion = versionString == null ? null : parseVersion( versionString ); + return databaseVersion != null ? 
databaseVersion : info.makeCopyOrDefault( MINIMUM_VERSION ); + } + + public static @Nullable DatabaseVersion parseVersion(String versionString) { + if ( versionString.length() != 9 ) { + // Not a 9-character product identifier; fall back to the default version detection + return null; + } + DatabaseVersion databaseVersion = null; + final Matcher matcher = DB2_VERSION_PATTERN.matcher( versionString ); + if ( matcher.find() ) { + int majorVersion = parseInt( matcher.group( 1 ) ); + int minorVersion = parseInt( matcher.group( 2 ) ); + int microVersion = parseInt( matcher.group( 3 ) ); + databaseVersion = new SimpleDatabaseVersion( majorVersion, minorVersion, microVersion ); + } + return databaseVersion; + } + protected LockingSupport buildLockingSupport() { // Introduced in 11.5: https://www.ibm.com/docs/en/db2/11.5?topic=statement-concurrent-access-resolution-clause final boolean supportsSkipLocked = getVersion().isSameOrAfter( 11, 5 );
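As a brief illustration of the version detection added above (the sample product-version strings are assumptions for illustration, not values taken from this change):

[source, java, indent=0]
----
// 9-character product identifier: ARI/DSN/QSQ/SQL prefix, 2-digit major, 2-digit minor, micro (plus an optional digit)
DatabaseVersion parsed = DB2Dialect.parseVersion( "SQL110590" ); // -> 11.5.9
// Any other shape returns null, and determinFullDatabaseVersion() then falls back to
// info.makeCopyOrDefault( MINIMUM_VERSION )
DatabaseVersion fallback = DB2Dialect.parseVersion( "DSN12015" ); // -> null (only 8 characters)
----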
sqlTypedMapping.toSize() + : null; + } + String castPattern = targetJdbcMapping.getJdbcType().castFromPattern( sourceMapping, targetSize ); + if ( castPattern == null ) { + final Size sourceSize = sourceMapping instanceof SqlTypedMapping sqlTypedMapping + ? sqlTypedMapping.toSize() + : null; + castPattern = sourceMapping.getJdbcType().castToPattern( targetJdbcMapping, sourceSize ); + if ( castPattern == null ) { + castPattern = dialect.castPattern( sourceType, targetType ); + } + } + new PatternRenderer( castPattern ).render( sqlAppender, arguments, walker ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/GenerateSeriesSetReturningFunctionTypeResolver.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/GenerateSeriesSetReturningFunctionTypeResolver.java index 3c6cc28e4a58..e9aef674978b 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/GenerateSeriesSetReturningFunctionTypeResolver.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/GenerateSeriesSetReturningFunctionTypeResolver.java @@ -88,6 +88,7 @@ public SelectableMapping[] resolveFunctionReturnType( null, null, null, + null, false, false, false, @@ -110,6 +111,7 @@ public SelectableMapping[] resolveFunctionReturnType( null, typedMapping.getColumnDefinition(), typedMapping.getLength(), + typedMapping.getArrayLength(), typedMapping.getPrecision(), typedMapping.getScale(), typedMapping.getTemporalPrecision(), @@ -134,6 +136,7 @@ public SelectableMapping[] resolveFunctionReturnType( null, null, null, + null, false, true, false, diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/NumberSeriesGenerateSeriesFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/NumberSeriesGenerateSeriesFunction.java index 28c1599f0511..b22e332d047e 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/NumberSeriesGenerateSeriesFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/NumberSeriesGenerateSeriesFunction.java @@ -360,6 +360,7 @@ protected SelectableMapping[] resolveIterationVariableBasedFunctionReturnType( null, null, null, + null, false, false, false, @@ -393,6 +394,7 @@ protected SelectableMapping[] resolveIterationVariableBasedFunctionReturnType( null, typedMapping.getColumnDefinition(), typedMapping.getLength(), + typedMapping.getArrayLength(), typedMapping.getPrecision(), typedMapping.getScale(), typedMapping.getTemporalPrecision(), @@ -417,6 +419,7 @@ protected SelectableMapping[] resolveIterationVariableBasedFunctionReturnType( null, null, null, + null, false, true, false, diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/SumReturnTypeResolver.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/SumReturnTypeResolver.java index 2cee65aeed1c..99354744dd9d 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/SumReturnTypeResolver.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/SumReturnTypeResolver.java @@ -90,6 +90,14 @@ public ReturnableType resolveFunctionReturnType( case NUMERIC: return BigInteger.class.isAssignableFrom( basicType.getJavaType() ) ? 
bigIntegerType : bigDecimalType; case VECTOR: + case VECTOR_BINARY: + case VECTOR_INT8: + case VECTOR_FLOAT16: + case VECTOR_FLOAT32: + case VECTOR_FLOAT64: + case SPARSE_VECTOR_INT8: + case SPARSE_VECTOR_FLOAT32: + case SPARSE_VECTOR_FLOAT64: return basicType; } return bigDecimalType; @@ -123,6 +131,14 @@ public BasicValuedMapping resolveFunctionReturnType( final Class argTypeClass = jdbcMapping.getJavaTypeDescriptor().getJavaTypeClass(); return BigInteger.class.isAssignableFrom( argTypeClass ) ? bigIntegerType : bigDecimalType; case VECTOR: + case VECTOR_BINARY: + case VECTOR_INT8: + case VECTOR_FLOAT16: + case VECTOR_FLOAT32: + case VECTOR_FLOAT64: + case SPARSE_VECTOR_INT8: + case SPARSE_VECTOR_FLOAT32: + case SPARSE_VECTOR_FLOAT64: return (BasicValuedMapping) jdbcMapping; } return bigDecimalType; diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/UnnestSetReturningFunctionTypeResolver.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/UnnestSetReturningFunctionTypeResolver.java index d4e58fb4b3dc..d751ff7d1220 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/UnnestSetReturningFunctionTypeResolver.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/UnnestSetReturningFunctionTypeResolver.java @@ -109,6 +109,7 @@ public SelectableMapping[] resolveFunctionReturnType( null, null, null, + null, false, false, false, @@ -136,6 +137,7 @@ public SelectableMapping[] resolveFunctionReturnType( null, selectableMapping.getColumnDefinition(), selectableMapping.getLength(), + selectableMapping.getArrayLength(), selectableMapping.getPrecision(), selectableMapping.getScale(), selectableMapping.getTemporalPrecision(), @@ -166,6 +168,7 @@ public SelectableMapping[] resolveFunctionReturnType( null, typedMapping.getColumnDefinition(), typedMapping.getLength(), + typedMapping.getArrayLength(), typedMapping.getPrecision(), typedMapping.getScale(), typedMapping.getTemporalPrecision(), @@ -190,6 +193,7 @@ public SelectableMapping[] resolveFunctionReturnType( null, null, null, + null, false, true, false, diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/H2UnnestFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/H2UnnestFunction.java index 8c6aadf5f5ba..3c0dea0ee1a9 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/array/H2UnnestFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/array/H2UnnestFunction.java @@ -211,6 +211,7 @@ public SelectableMapping[] resolveFunctionReturnType( null, null, null, + null, false, false, false, @@ -248,6 +249,7 @@ public SelectableMapping[] resolveFunctionReturnType( selectableMapping.getCustomWriteExpression(), selectableMapping.getColumnDefinition(), selectableMapping.getLength(), + selectableMapping.getArrayLength(), selectableMapping.getPrecision(), selectableMapping.getScale(), selectableMapping.getTemporalPrecision(), @@ -288,6 +290,7 @@ public SelectableMapping[] resolveFunctionReturnType( null, typedMapping.getColumnDefinition(), typedMapping.getLength(), + typedMapping.getArrayLength(), typedMapping.getPrecision(), typedMapping.getScale(), typedMapping.getTemporalPrecision(), @@ -312,6 +315,7 @@ public SelectableMapping[] resolveFunctionReturnType( null, null, null, + null, false, true, false, diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/H2JsonTableFunction.java 
b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/H2JsonTableFunction.java index 23460fe6c31d..3f4b68655851 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/H2JsonTableFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/H2JsonTableFunction.java @@ -826,6 +826,7 @@ protected void addSelectableMapping(List selectableMappings, null, null, null, + null, false, false, false, diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/JsonTableSetReturningFunctionTypeResolver.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/JsonTableSetReturningFunctionTypeResolver.java index cf73ca7515f3..af5acb17ef76 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/JsonTableSetReturningFunctionTypeResolver.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/JsonTableSetReturningFunctionTypeResolver.java @@ -134,6 +134,7 @@ protected void addSelectableMapping(List selectableMappings, null, null, null, + null, false, false, false, diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/OracleJsonTableFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/OracleJsonTableFunction.java index c6d1dec6b9c8..2b795a7a8288 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/json/OracleJsonTableFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/json/OracleJsonTableFunction.java @@ -134,6 +134,7 @@ protected void addSelectableMapping(List selectableMappings, null, null, null, + null, false, false, false, diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/DB2XmlTableFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/DB2XmlTableFunction.java index 907efaea4657..59405748709f 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/DB2XmlTableFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/DB2XmlTableFunction.java @@ -99,6 +99,7 @@ protected void addSelectableMapping(List selectableMappings, null, null, null, + null, false, false, false, diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/HANAXmlTableFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/HANAXmlTableFunction.java index 3408068c0900..dbcd1fb66619 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/HANAXmlTableFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/HANAXmlTableFunction.java @@ -454,6 +454,7 @@ protected void addSelectableMapping(List selectableMappings, null, null, null, + null, false, false, false, diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/OracleXmlTableFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/OracleXmlTableFunction.java index 7f476d98b8f6..9fec3fcb32d6 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/OracleXmlTableFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/OracleXmlTableFunction.java @@ -64,6 +64,7 @@ protected void addSelectableMapping(List selectableMappings, null, null, null, + null, false, false, false, diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/SybaseASEXmlTableFunction.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/SybaseASEXmlTableFunction.java index 
4b91671b5355..38cbc0dcd3c4 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/SybaseASEXmlTableFunction.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/SybaseASEXmlTableFunction.java @@ -152,6 +152,7 @@ else if ( arguments.xmlDocument() instanceof ColumnReference columnReference ) { null, null, null, + null, false, false, false, @@ -185,6 +186,7 @@ protected void addSelectableMapping(List selectableMappings, null, null, null, + null, false, false, false, diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/XmlTableSetReturningFunctionTypeResolver.java b/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/XmlTableSetReturningFunctionTypeResolver.java index 19b9a18bde8b..9b6c3d38955f 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/XmlTableSetReturningFunctionTypeResolver.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/function/xml/XmlTableSetReturningFunctionTypeResolver.java @@ -112,6 +112,7 @@ protected void addSelectableMapping(List selectableMappings, null, null, null, + null, false, false, false, diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/type/MariaDBCastingJsonArrayJdbcType.java b/hibernate-core/src/main/java/org/hibernate/dialect/type/MariaDBCastingJsonArrayJdbcType.java index b5278d4f7cf1..4e0981a224e7 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/type/MariaDBCastingJsonArrayJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/type/MariaDBCastingJsonArrayJdbcType.java @@ -4,7 +4,9 @@ */ package org.hibernate.dialect.type; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.descriptor.jdbc.JdbcType; import org.hibernate.type.descriptor.jdbc.JsonArrayJdbcType; @@ -21,10 +23,16 @@ public MariaDBCastingJsonArrayJdbcType(JdbcType elementJdbcType) { @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( "json_extract(" ); appender.append( writeExpression ); appender.append( ",'$')" ); } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/type/MariaDBCastingJsonJdbcType.java b/hibernate-core/src/main/java/org/hibernate/dialect/type/MariaDBCastingJsonJdbcType.java index d6f43349398a..c0a872a4d425 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/type/MariaDBCastingJsonJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/type/MariaDBCastingJsonJdbcType.java @@ -4,7 +4,9 @@ */ package org.hibernate.dialect.type; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.metamodel.mapping.EmbeddableMappingType; import org.hibernate.metamodel.spi.RuntimeModelCreationContext; import org.hibernate.sql.ast.spi.SqlAppender; @@ -35,10 +37,16 @@ public AggregateJdbcType resolveAggregateJdbcType( @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( "json_extract(" ); appender.append( writeExpression ); appender.append( ",'$')" ); } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } } diff --git 
a/hibernate-core/src/main/java/org/hibernate/dialect/type/MySQLCastingJsonArrayJdbcType.java b/hibernate-core/src/main/java/org/hibernate/dialect/type/MySQLCastingJsonArrayJdbcType.java index 9c0de9b9b066..ac6f57dac969 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/type/MySQLCastingJsonArrayJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/type/MySQLCastingJsonArrayJdbcType.java @@ -4,7 +4,9 @@ */ package org.hibernate.dialect.type; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.descriptor.jdbc.JdbcType; import org.hibernate.type.descriptor.jdbc.JsonArrayJdbcType; @@ -21,10 +23,16 @@ public MySQLCastingJsonArrayJdbcType(JdbcType elementJdbcType) { @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( "cast(" ); appender.append( writeExpression ); appender.append( " as json)" ); } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/type/MySQLCastingJsonJdbcType.java b/hibernate-core/src/main/java/org/hibernate/dialect/type/MySQLCastingJsonJdbcType.java index b63457bb5418..41bc1c25c6a7 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/type/MySQLCastingJsonJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/type/MySQLCastingJsonJdbcType.java @@ -4,7 +4,9 @@ */ package org.hibernate.dialect.type; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.metamodel.mapping.EmbeddableMappingType; import org.hibernate.metamodel.spi.RuntimeModelCreationContext; import org.hibernate.sql.ast.spi.SqlAppender; @@ -35,10 +37,16 @@ public AggregateJdbcType resolveAggregateJdbcType( @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( "cast(" ); appender.append( writeExpression ); appender.append( " as json)" ); } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/type/OracleArrayJdbcType.java b/hibernate-core/src/main/java/org/hibernate/dialect/type/OracleArrayJdbcType.java index acbe11e3b481..7712cd7e08d2 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/type/OracleArrayJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/type/OracleArrayJdbcType.java @@ -10,6 +10,7 @@ import java.sql.SQLException; import java.sql.Types; import java.util.Locale; +import java.util.Objects; import org.hibernate.HibernateException; import org.hibernate.boot.model.relational.Database; @@ -288,4 +289,16 @@ public String getFriendlyName() { public String toString() { return "OracleArrayTypeDescriptor(" + typeName + ")"; } + + @Override + public boolean equals(Object that) { + return super.equals( that ) + && that instanceof OracleArrayJdbcType jdbcType + && Objects.equals( typeName, jdbcType.typeName ); + } + + @Override + public int hashCode() { + return Objects.hashCode( typeName ) + 31 * super.hashCode(); + } } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLCastingInetJdbcType.java 
b/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLCastingInetJdbcType.java index 5f17e618ecbb..45b446cfc83f 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLCastingInetJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLCastingInetJdbcType.java @@ -10,7 +10,9 @@ import java.sql.ResultSet; import java.sql.SQLException; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.ValueBinder; @@ -32,6 +34,7 @@ public class PostgreSQLCastingInetJdbcType implements JdbcType { @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( "cast(" ); @@ -39,6 +42,11 @@ public void appendWriteExpression( appender.append( " as inet)" ); } + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } + @Override public int getJdbcTypeCode() { return SqlTypes.VARBINARY; diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLCastingIntervalSecondJdbcType.java b/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLCastingIntervalSecondJdbcType.java index fbb7c0e12ec1..7f036ddd45f6 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLCastingIntervalSecondJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLCastingIntervalSecondJdbcType.java @@ -10,7 +10,9 @@ import java.sql.ResultSet; import java.sql.SQLException; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.metamodel.mapping.JdbcMappingContainer; import org.hibernate.sql.ast.SqlAstTranslator; @@ -74,6 +76,7 @@ public JdbcMappingContainer getExpressionType() { @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( '(' ); @@ -81,6 +84,11 @@ public void appendWriteExpression( appender.append( "*interval'1 second')" ); } + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } + @Override public int getJdbcTypeCode() { return SqlTypes.NUMERIC; diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLCastingJsonArrayJdbcType.java b/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLCastingJsonArrayJdbcType.java index eae07c430894..12efa281b002 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLCastingJsonArrayJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLCastingJsonArrayJdbcType.java @@ -4,7 +4,9 @@ */ package org.hibernate.dialect.type; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.descriptor.jdbc.JdbcType; import org.hibernate.type.descriptor.jdbc.JsonArrayJdbcType; @@ -24,6 +26,7 @@ public PostgreSQLCastingJsonArrayJdbcType(JdbcType elementJdbcType, boolean json @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( 
"cast(" ); @@ -36,4 +39,9 @@ public void appendWriteExpression( appender.append( "json)" ); } } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLCastingJsonJdbcType.java b/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLCastingJsonJdbcType.java index 77ae4a56d165..ac707333eb83 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLCastingJsonJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLCastingJsonJdbcType.java @@ -4,7 +4,9 @@ */ package org.hibernate.dialect.type; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.metamodel.mapping.EmbeddableMappingType; import org.hibernate.metamodel.spi.RuntimeModelCreationContext; import org.hibernate.sql.ast.spi.SqlAppender; @@ -43,6 +45,7 @@ public AggregateJdbcType resolveAggregateJdbcType( @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( "cast(" ); @@ -55,4 +58,9 @@ public void appendWriteExpression( appender.append( "json)" ); } } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLStructCastingJdbcType.java b/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLStructCastingJdbcType.java index d911542c106e..eb1017168f8c 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLStructCastingJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/type/PostgreSQLStructCastingJdbcType.java @@ -8,8 +8,10 @@ import java.sql.PreparedStatement; import java.sql.SQLException; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.boot.model.naming.Identifier; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.metamodel.mapping.EmbeddableMappingType; import org.hibernate.metamodel.spi.RuntimeModelCreationContext; import org.hibernate.sql.ast.spi.SqlAppender; @@ -55,6 +57,7 @@ public AggregateJdbcType resolveAggregateJdbcType( @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( "cast(" ); @@ -64,6 +67,11 @@ public void appendWriteExpression( appender.append( ')' ); } + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } + @Override public ValueBinder getBinder(JavaType javaType) { return new BasicBinder<>( javaType, this ) { diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/type/SQLServerCastingXmlArrayJdbcType.java b/hibernate-core/src/main/java/org/hibernate/dialect/type/SQLServerCastingXmlArrayJdbcType.java index 7f0a4468cc2a..3e2201c7c4ee 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/type/SQLServerCastingXmlArrayJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/type/SQLServerCastingXmlArrayJdbcType.java @@ -4,7 +4,9 @@ */ package org.hibernate.dialect.type; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.descriptor.jdbc.JdbcType; import 
org.hibernate.type.descriptor.jdbc.XmlArrayJdbcType; @@ -21,10 +23,16 @@ public SQLServerCastingXmlArrayJdbcType(JdbcType elementJdbcType) { @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( "cast(" ); appender.append( writeExpression ); appender.append( " as xml)" ); } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/type/SQLServerCastingXmlJdbcType.java b/hibernate-core/src/main/java/org/hibernate/dialect/type/SQLServerCastingXmlJdbcType.java index a91b8464434f..5821c79a2d40 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/type/SQLServerCastingXmlJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/type/SQLServerCastingXmlJdbcType.java @@ -4,7 +4,9 @@ */ package org.hibernate.dialect.type; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.metamodel.mapping.EmbeddableMappingType; import org.hibernate.metamodel.spi.RuntimeModelCreationContext; import org.hibernate.sql.ast.spi.SqlAppender; @@ -35,10 +37,16 @@ public AggregateJdbcType resolveAggregateJdbcType( @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( "cast(" ); appender.append( writeExpression ); appender.append( " as xml)" ); } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } } diff --git a/hibernate-core/src/main/java/org/hibernate/mapping/Selectable.java b/hibernate-core/src/main/java/org/hibernate/mapping/Selectable.java index 7154eeaa2c60..b6d6c39810ad 100644 --- a/hibernate-core/src/main/java/org/hibernate/mapping/Selectable.java +++ b/hibernate-core/src/main/java/org/hibernate/mapping/Selectable.java @@ -6,7 +6,9 @@ import org.hibernate.Incubating; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.type.MappingContext; import org.hibernate.type.spi.TypeConfiguration; /** @@ -70,8 +72,14 @@ default String getWriteExpr() { : customWriteExpression; } - @Incubating + @Deprecated(forRemoval = true, since = "7.2") default String getWriteExpr(JdbcMapping jdbcMapping, Dialect dialect) { - return jdbcMapping.getJdbcType().wrapWriteExpression( getWriteExpr(), dialect ); + return jdbcMapping.getJdbcType().wrapWriteExpression( getWriteExpr(), null, dialect ); + } + + @Incubating + default String getWriteExpr(JdbcMapping jdbcMapping, Dialect dialect, MappingContext mappingContext) { + final Size size = this instanceof Column column ? 
column.getColumnSize( dialect, mappingContext ) : null; + return jdbcMapping.getJdbcType().wrapWriteExpression( getWriteExpr(), size, dialect ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/SelectableConsumer.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/SelectableConsumer.java index 9b27ddfc1a12..8908f9b341d8 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/SelectableConsumer.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/SelectableConsumer.java @@ -4,6 +4,8 @@ */ package org.hibernate.metamodel.mapping; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.function.BiConsumer; import java.util.function.IntFunction; @@ -116,40 +118,47 @@ public String getColumnDefinition() { } @Override - public Long getLength() { + public @Nullable Long getLength() { // we could probably use the details from `base`, but // this method should really never be called on this object throw new UnsupportedOperationException(); } @Override - public Integer getPrecision() { + public @Nullable Integer getArrayLength() { + // we could probably use the details from `base`, but + // this method should really never be called on this object + throw new UnsupportedOperationException(); + } + + @Override + public @Nullable Integer getPrecision() { // we could probably use the details from `base`, but // this method should really never be called on this object return null; } @Override - public Integer getScale() { + public @Nullable Integer getScale() { // we could probably use the details from `base`, but // this method should really never be called on this object return null; } @Override - public Integer getTemporalPrecision() { + public @Nullable Integer getTemporalPrecision() { // we could probably use the details from `base`, but // this method should really never be called on this object return null; } @Override - public String getCustomReadExpression() { + public @Nullable String getCustomReadExpression() { return null; } @Override - public String getCustomWriteExpression() { + public @Nullable String getCustomWriteExpression() { return null; } } @@ -179,37 +188,42 @@ public String getSelectionExpression() { } @Override - public String getCustomReadExpression() { + public @Nullable String getCustomReadExpression() { + return null; + } + + @Override + public @Nullable String getCustomWriteExpression() { return null; } @Override - public String getCustomWriteExpression() { + public @Nullable String getColumnDefinition() { return null; } @Override - public String getColumnDefinition() { + public @Nullable Long getLength() { return null; } @Override - public Long getLength() { + public @Nullable Integer getArrayLength() { return null; } @Override - public Integer getPrecision() { + public @Nullable Integer getPrecision() { return null; } @Override - public Integer getScale() { + public @Nullable Integer getScale() { return null; } @Override - public Integer getTemporalPrecision() { + public @Nullable Integer getTemporalPrecision() { return null; } diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/SelectableMapping.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/SelectableMapping.java index 1db168a318fc..a00188374396 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/SelectableMapping.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/SelectableMapping.java @@ -4,6 +4,7 @@ */ package 
org.hibernate.metamodel.mapping; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.Incubating; import org.hibernate.annotations.ColumnTransformer; @@ -34,14 +35,14 @@ default SelectablePath getSelectablePath() { * The selection's read expression accounting for formula treatment as well * as {@link ColumnTransformer#read()} */ - String getCustomReadExpression(); + @Nullable String getCustomReadExpression(); /** * The selection's write expression accounting {@link ColumnTransformer#write()} * * @apiNote Always null for formula mappings */ - String getCustomWriteExpression(); + @Nullable String getCustomWriteExpression(); default String getWriteExpression() { final String customWriteExpression = getCustomWriteExpression(); diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/SoftDeleteMapping.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/SoftDeleteMapping.java index 8a74f62f5a38..4f1b02afe8dc 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/SoftDeleteMapping.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/SoftDeleteMapping.java @@ -4,6 +4,7 @@ */ package org.hibernate.metamodel.mapping; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.annotations.SoftDeleteType; import org.hibernate.sql.ast.spi.SqlExpressionResolver; import org.hibernate.sql.ast.tree.expression.ColumnReference; @@ -99,12 +100,12 @@ default String getContainingTableExpression() { } @Override - default String getCustomReadExpression() { + default @Nullable String getCustomReadExpression() { return null; } @Override - default String getCustomWriteExpression() { + default @Nullable String getCustomWriteExpression() { return null; } @@ -134,27 +135,32 @@ default boolean isPartitioned() { } @Override - default String getColumnDefinition() { + default @Nullable String getColumnDefinition() { return null; } @Override - default Long getLength() { + default @Nullable Long getLength() { return null; } @Override - default Integer getPrecision() { + default @Nullable Integer getArrayLength() { return null; } @Override - default Integer getScale() { + default @Nullable Integer getPrecision() { return null; } @Override - default Integer getTemporalPrecision() { + default @Nullable Integer getScale() { + return null; + } + + @Override + default @Nullable Integer getTemporalPrecision() { return null; } } diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/SqlTypedMapping.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/SqlTypedMapping.java index 5c9515a59429..eba1296c8be1 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/SqlTypedMapping.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/SqlTypedMapping.java @@ -16,6 +16,7 @@ public interface SqlTypedMapping { @Nullable String getColumnDefinition(); @Nullable Long getLength(); + @Nullable Integer getArrayLength(); @Nullable Integer getPrecision(); @Nullable Integer getScale(); @Nullable Integer getTemporalPrecision(); @@ -26,6 +27,7 @@ default boolean isLob() { default Size toSize() { final Size size = new Size(); + size.setArrayLength( getArrayLength() ); size.setLength( getLength() ); if ( getTemporalPrecision() != null ) { size.setPrecision( getTemporalPrecision() ); diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/AbstractEmbeddableMapping.java 
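SqlTypedMapping above gains getArrayLength(), and toSize() copies it into the Size alongside length, precision and scale. The sketch below mirrors that aggregation with local stand-ins (SizeSketch and SqlTypedSketch are made-up names, not org.hibernate types); the hunk is truncated after the temporal-precision branch, so the precision/scale fallback shown here is an assumption.

// Sketch of the toSize()-style aggregation; not the real org.hibernate.engine.jdbc.Size.
public class ToSizeSketch {

    static final class SizeSketch {
        Long length;
        Integer arrayLength;
        Integer precision;
        Integer scale;
    }

    interface SqlTypedSketch {
        Long getLength();
        Integer getArrayLength();        // the new accessor threaded through above
        Integer getPrecision();
        Integer getScale();
        Integer getTemporalPrecision();

        // Copy the array length along with the other facets; a temporal precision,
        // when present, takes the precision slot (the rest is assumed).
        default SizeSketch toSize() {
            SizeSketch size = new SizeSketch();
            size.arrayLength = getArrayLength();
            size.length = getLength();
            size.precision = getTemporalPrecision() != null ? getTemporalPrecision() : getPrecision();
            size.scale = getScale();
            return size;
        }
    }

    public static void main(String[] args) {
        SqlTypedSketch varcharArrayColumn = new SqlTypedSketch() {
            public Long getLength() { return 255L; }
            public Integer getArrayLength() { return 10; }
            public Integer getPrecision() { return null; }
            public Integer getScale() { return null; }
            public Integer getTemporalPrecision() { return null; }
        };
        SizeSketch size = varcharArrayColumn.toSize();
        System.out.println("length=" + size.length + ", arrayLength=" + size.arrayLength);
    }
}

Carrying the array length on the Size object is presumably what makes it visible to the size-aware appendWriteExpression shown earlier.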
b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/AbstractEmbeddableMapping.java index 82bda15f6147..893f4cdd1d54 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/AbstractEmbeddableMapping.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/AbstractEmbeddableMapping.java @@ -332,6 +332,7 @@ protected boolean finishInitialization( final SelectablePath selectablePath; final String columnDefinition; final Long length; + final Integer arrayLength; final Integer precision; final Integer scale; final Integer temporalPrecision; @@ -340,6 +341,7 @@ protected boolean finishInitialization( if ( selectable instanceof Column column ) { columnDefinition = column.getSqlType(); length = column.getLength(); + arrayLength = column.getArrayLength(); precision = column.getPrecision(); scale = column.getScale(); temporalPrecision = column.getTemporalPrecision(); @@ -351,6 +353,7 @@ protected boolean finishInitialization( else { columnDefinition = null; length = null; + arrayLength = null; precision = null; scale = null; temporalPrecision = null; @@ -372,9 +375,14 @@ protected boolean finishInitialization( selectablePath, selectable.isFormula(), selectable.getCustomReadExpression(), - selectable.getWriteExpr( basicValue.getResolution().getJdbcMapping(), dialect ), + selectable.getWriteExpr( + basicValue.getResolution().getJdbcMapping(), + dialect, + creationProcess.getCreationContext().getBootModel() + ), columnDefinition, length, + arrayLength, precision, scale, temporalPrecision, diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/AnyDiscriminatorPart.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/AnyDiscriminatorPart.java index 25a8c03367dc..9cc9a2e09a36 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/AnyDiscriminatorPart.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/AnyDiscriminatorPart.java @@ -4,6 +4,7 @@ */ package org.hibernate.metamodel.mapping.internal; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.cache.MutableCacheKeyBuilder; import org.hibernate.engine.FetchStyle; import org.hibernate.engine.FetchTiming; @@ -56,12 +57,13 @@ public class AnyDiscriminatorPart implements DiscriminatorMapping, FetchOptions private final String table; private final String column; private final SelectablePath selectablePath; - private final String customReadExpression; - private final String customWriteExpression; - private final String columnDefinition; - private final Long length; - private final Integer precision; - private final Integer scale; + private final @Nullable String customReadExpression; + private final @Nullable String customWriteExpression; + private final @Nullable String columnDefinition; + private final @Nullable Long length; + private final @Nullable Integer arrayLength; + private final @Nullable Integer precision; + private final @Nullable Integer scale; private final boolean insertable; private final boolean updateable; @@ -70,6 +72,7 @@ public class AnyDiscriminatorPart implements DiscriminatorMapping, FetchOptions private final BasicType underlyingJdbcMapping; private final DiscriminatorConverter valueConverter; + @Deprecated(forRemoval = true, since = "7.2") public AnyDiscriminatorPart( NavigableRole partRole, DiscriminatedAssociationModelPart declaringType, @@ -86,6 +89,49 @@ public AnyDiscriminatorPart( boolean updateable, boolean partitioned, 
BasicType underlyingJdbcMapping, + Map valueToEntityNameMap, + ImplicitDiscriminatorStrategy implicitValueStrategy, + MappingMetamodelImplementor mappingMetamodel) { + this( + partRole, + declaringType, + table, + column, + selectablePath, + customReadExpression, + customWriteExpression, + columnDefinition, + length, + null, + precision, + scale, + insertable, + updateable, + partitioned, + underlyingJdbcMapping, + valueToEntityNameMap, + implicitValueStrategy, + mappingMetamodel + ); + } + + public AnyDiscriminatorPart( + NavigableRole partRole, + DiscriminatedAssociationModelPart declaringType, + String table, + String column, + SelectablePath selectablePath, + @Nullable String customReadExpression, + @Nullable String customWriteExpression, + @Nullable String columnDefinition, + @Nullable Long length, + @Nullable Integer arrayLength, + @Nullable Integer precision, + @Nullable Integer scale, + boolean insertable, + boolean updateable, + boolean partitioned, + BasicType underlyingJdbcMapping, Map valueToEntityNameMap, ImplicitDiscriminatorStrategy implicitValueStrategy, MappingMetamodelImplementor mappingMetamodel) { @@ -98,6 +144,7 @@ public AnyDiscriminatorPart( this.customWriteExpression = customWriteExpression; this.columnDefinition = columnDefinition; this.length = length; + this.arrayLength = arrayLength; this.precision = precision; this.scale = scale; this.insertable = insertable; @@ -184,37 +231,42 @@ public boolean isPartitioned() { } @Override - public String getCustomReadExpression() { + public @Nullable String getCustomReadExpression() { return customReadExpression; } @Override - public String getCustomWriteExpression() { + public @Nullable String getCustomWriteExpression() { return customWriteExpression; } @Override - public String getColumnDefinition() { + public @Nullable String getColumnDefinition() { return columnDefinition; } @Override - public Long getLength() { + public @Nullable Long getLength() { return length; } @Override - public Integer getPrecision() { + public @Nullable Integer getArrayLength() { + return arrayLength; + } + + @Override + public @Nullable Integer getPrecision() { return precision; } @Override - public Integer getTemporalPrecision() { + public @Nullable Integer getTemporalPrecision() { return null; } @Override - public Integer getScale() { + public @Nullable Integer getScale() { return scale; } diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/AnyKeyPart.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/AnyKeyPart.java index 3e5e07857ed7..58576a06115d 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/AnyKeyPart.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/AnyKeyPart.java @@ -6,6 +6,7 @@ import java.util.function.BiConsumer; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.engine.FetchStyle; import org.hibernate.engine.FetchTiming; import org.hibernate.engine.spi.SharedSessionContractImplementor; @@ -46,12 +47,13 @@ public class AnyKeyPart implements BasicValuedModelPart, FetchOptions { private final String column; private final SelectablePath selectablePath; private final DiscriminatedAssociationModelPart anyPart; - private final String customReadExpression; - private final String customWriteExpression; - private final String columnDefinition; - private final Long length; - private final Integer precision; - private final Integer scale; + private final @Nullable String 
customReadExpression; + private final @Nullable String customWriteExpression; + private final @Nullable String columnDefinition; + private final @Nullable Long length; + private final @Nullable Integer arrayLength; + private final @Nullable Integer precision; + private final @Nullable Integer scale; private final boolean nullable; private final boolean insertable; private final boolean updateable; @@ -75,6 +77,45 @@ public AnyKeyPart( boolean updateable, boolean partitioned, JdbcMapping jdbcMapping) { + this( + navigableRole, + anyPart, + table, + column, + selectablePath, + customReadExpression, + customWriteExpression, + columnDefinition, + length, + null, + precision, + scale, + nullable, + insertable, + updateable, + partitioned, + jdbcMapping + ); + } + + public AnyKeyPart( + NavigableRole navigableRole, + DiscriminatedAssociationModelPart anyPart, + String table, + String column, + SelectablePath selectablePath, + @Nullable String customReadExpression, + @Nullable String customWriteExpression, + @Nullable String columnDefinition, + @Nullable Long length, + @Nullable Integer arrayLength, + @Nullable Integer precision, + @Nullable Integer scale, + boolean nullable, + boolean insertable, + boolean updateable, + boolean partitioned, + JdbcMapping jdbcMapping) { this.navigableRole = navigableRole; this.table = table; this.column = column; @@ -84,6 +125,7 @@ public AnyKeyPart( this.customWriteExpression = customWriteExpression; this.columnDefinition = columnDefinition; this.length = length; + this.arrayLength = arrayLength; this.precision = precision; this.scale = scale; this.nullable = nullable; @@ -139,37 +181,42 @@ public boolean isPartitioned() { } @Override - public String getCustomReadExpression() { + public @Nullable String getCustomReadExpression() { return customReadExpression; } @Override - public String getCustomWriteExpression() { + public @Nullable String getCustomWriteExpression() { return customWriteExpression; } @Override - public String getColumnDefinition() { + public @Nullable String getColumnDefinition() { return columnDefinition; } @Override - public Long getLength() { + public @Nullable Long getLength() { return length; } @Override - public Integer getPrecision() { + public @Nullable Integer getArrayLength() { + return arrayLength; + } + + @Override + public @Nullable Integer getPrecision() { return precision; } @Override - public Integer getScale() { + public @Nullable Integer getScale() { return scale; } @Override - public Integer getTemporalPrecision() { + public @Nullable Integer getTemporalPrecision() { return null; } @@ -372,7 +419,8 @@ private SqlSelection resolveSqlSelection( this, getContainingTableExpression() ); - final SqlExpressionResolver expressionResolver = creationState.getSqlAstCreationState().getSqlExpressionResolver(); + final SqlExpressionResolver expressionResolver = creationState.getSqlAstCreationState() + .getSqlExpressionResolver(); return expressionResolver.resolveSqlSelection( expressionResolver.resolveSqlExpression( tableReference, diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/BasicAttributeMapping.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/BasicAttributeMapping.java index 261e16e005cc..8d670e8ad22a 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/BasicAttributeMapping.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/BasicAttributeMapping.java @@ -6,6 +6,7 @@ import java.util.function.BiConsumer; +import 
org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.engine.FetchStyle; import org.hibernate.engine.FetchTiming; import org.hibernate.engine.spi.SharedSessionContractImplementor; @@ -49,12 +50,13 @@ public class BasicAttributeMapping private final Integer temporalPrecision; private final SelectablePath selectablePath; private final boolean isFormula; - private final String customReadExpression; - private final String customWriteExpression; - private final String columnDefinition; - private final Long length; - private final Integer precision; - private final Integer scale; + private final @Nullable String customReadExpression; + private final @Nullable String customWriteExpression; + private final @Nullable String columnDefinition; + private final @Nullable Long length; + private final @Nullable Integer arrayLength; + private final @Nullable Integer precision; + private final @Nullable Integer scale; private final JdbcMapping jdbcMapping; private final boolean isLob; @@ -66,6 +68,7 @@ public class BasicAttributeMapping private final JavaType domainTypeDescriptor; + @Deprecated(forRemoval = true, since = "7.2") public BasicAttributeMapping( String attributeName, NavigableRole navigableRole, @@ -78,13 +81,71 @@ public BasicAttributeMapping( String mappedColumnExpression, SelectablePath selectablePath, boolean isFormula, - String customReadExpression, - String customWriteExpression, - String columnDefinition, - Long length, - Integer precision, - Integer scale, - Integer temporalPrecision, + @Nullable String customReadExpression, + @Nullable String customWriteExpression, + @Nullable String columnDefinition, + @Nullable Long length, + @Nullable Integer precision, + @Nullable Integer scale, + @Nullable Integer temporalPrecision, + boolean isLob, + boolean nullable, + boolean insertable, + boolean updateable, + boolean partitioned, + JdbcMapping jdbcMapping, + ManagedMappingType declaringType, + PropertyAccess propertyAccess) { + this( + attributeName, + navigableRole, + stateArrayPosition, + fetchableIndex, + attributeMetadata, + mappedFetchTiming, + mappedFetchStyle, + tableExpression, + mappedColumnExpression, + selectablePath, + isFormula, + customReadExpression, + customWriteExpression, + columnDefinition, + length, + null, + precision, + scale, + temporalPrecision, + isLob, + nullable, + insertable, + updateable, + partitioned, + jdbcMapping, + declaringType, + propertyAccess ); + } + + public BasicAttributeMapping( + String attributeName, + NavigableRole navigableRole, + int stateArrayPosition, + int fetchableIndex, + AttributeMetadata attributeMetadata, + FetchTiming mappedFetchTiming, + FetchStyle mappedFetchStyle, + String tableExpression, + String mappedColumnExpression, + SelectablePath selectablePath, + boolean isFormula, + @Nullable String customReadExpression, + @Nullable String customWriteExpression, + @Nullable String columnDefinition, + @Nullable Long length, + @Nullable Integer arrayLength, + @Nullable Integer precision, + @Nullable Integer scale, + @Nullable Integer temporalPrecision, boolean isLob, boolean nullable, boolean insertable, @@ -113,6 +174,7 @@ public BasicAttributeMapping( this.isFormula = isFormula; this.columnDefinition = columnDefinition; this.length = length; + this.arrayLength = arrayLength; this.precision = precision; this.scale = scale; this.isLob = isLob; @@ -184,6 +246,7 @@ else if ( original instanceof SingularAttributeMapping mapping ) { selectableMapping.getCustomWriteExpression(), selectableMapping.getColumnDefinition(), 
selectableMapping.getLength(), + selectableMapping.getArrayLength(), selectableMapping.getPrecision(), selectableMapping.getScale(), selectableMapping.getTemporalPrecision(), @@ -263,12 +326,12 @@ public boolean isPartitioned() { } @Override - public String getCustomReadExpression() { + public @Nullable String getCustomReadExpression() { return customReadExpression; } @Override - public String getCustomWriteExpression() { + public @Nullable String getCustomWriteExpression() { return customWriteExpression; } @@ -278,27 +341,32 @@ public String getWriteExpression() { } @Override - public String getColumnDefinition() { + public @Nullable String getColumnDefinition() { return columnDefinition; } @Override - public Long getLength() { + public @Nullable Long getLength() { return length; } @Override - public Integer getPrecision() { + public @Nullable Integer getArrayLength() { + return arrayLength; + } + + @Override + public @Nullable Integer getPrecision() { return precision; } @Override - public Integer getScale() { + public @Nullable Integer getScale() { return scale; } @Override - public Integer getTemporalPrecision() { + public @Nullable Integer getTemporalPrecision() { return temporalPrecision; } diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/BasicEntityIdentifierMappingImpl.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/BasicEntityIdentifierMappingImpl.java index db4b78e85184..501cc728f676 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/BasicEntityIdentifierMappingImpl.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/BasicEntityIdentifierMappingImpl.java @@ -8,6 +8,7 @@ import java.util.function.BiConsumer; import java.util.function.Supplier; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.cache.MutableCacheKeyBuilder; import org.hibernate.engine.FetchStyle; import org.hibernate.engine.FetchTiming; @@ -64,10 +65,11 @@ public class BasicEntityIdentifierMappingImpl implements BasicEntityIdentifierMa private final String rootTable; private final String pkColumnName; - private final String columnDefinition; - private final Long length; - private final Integer precision; - private final Integer scale; + private final @Nullable String columnDefinition; + private final @Nullable Long length; + private final @Nullable Integer arrayLength; + private final @Nullable Integer precision; + private final @Nullable Integer scale; private final boolean insertable; private final boolean updateable; @@ -75,6 +77,7 @@ public class BasicEntityIdentifierMappingImpl implements BasicEntityIdentifierMa private final SessionFactoryImplementor sessionFactory; + @Deprecated(forRemoval = true, since = "7.2") public BasicEntityIdentifierMappingImpl( EntityPersister entityPersister, Supplier instanceCreator, @@ -89,8 +92,42 @@ public BasicEntityIdentifierMappingImpl( boolean updateable, BasicType idType, MappingModelCreationProcess creationProcess) { + this( + entityPersister, + instanceCreator, + attributeName, + rootTable, + pkColumnName, + columnDefinition, + length, + null, + precision, + scale, + insertable, + updateable, + idType, + creationProcess + ); + } + + public BasicEntityIdentifierMappingImpl( + EntityPersister entityPersister, + Supplier instanceCreator, + String attributeName, + String rootTable, + String pkColumnName, + @Nullable String columnDefinition, + @Nullable Long length, + @Nullable Integer arrayLength, + @Nullable Integer 
precision, + @Nullable Integer scale, + boolean insertable, + boolean updateable, + BasicType idType, + MappingModelCreationProcess creationProcess) { this.columnDefinition = columnDefinition; this.length = length; + this.arrayLength = arrayLength; this.precision = precision; this.scale = scale; this.insertable = insertable; @@ -324,37 +361,42 @@ public boolean hasPartitionedSelectionMapping() { } @Override - public String getCustomReadExpression() { + public @Nullable String getCustomReadExpression() { return null; } @Override - public String getCustomWriteExpression() { + public @Nullable String getCustomWriteExpression() { return null; } @Override - public String getColumnDefinition() { + public @Nullable String getColumnDefinition() { return columnDefinition; } @Override - public Long getLength() { + public @Nullable Long getLength() { return length; } @Override - public Integer getPrecision() { + public @Nullable Integer getArrayLength() { + return arrayLength; + } + + @Override + public @Nullable Integer getPrecision() { return precision; } @Override - public Integer getTemporalPrecision() { + public @Nullable Integer getTemporalPrecision() { return null; } @Override - public Integer getScale() { + public @Nullable Integer getScale() { return scale; } diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/BasicValuedCollectionPart.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/BasicValuedCollectionPart.java index 603b050e0aa8..e8a08f331165 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/BasicValuedCollectionPart.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/BasicValuedCollectionPart.java @@ -6,6 +6,7 @@ import java.util.function.BiConsumer; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.engine.FetchStyle; import org.hibernate.engine.FetchTiming; import org.hibernate.engine.spi.SharedSessionContractImplementor; @@ -111,37 +112,42 @@ public boolean isUpdateable() { } @Override - public String getCustomReadExpression() { + public @Nullable String getCustomReadExpression() { return selectableMapping.getCustomReadExpression(); } @Override - public String getCustomWriteExpression() { + public @Nullable String getCustomWriteExpression() { return selectableMapping.getCustomWriteExpression(); } @Override - public String getColumnDefinition() { + public @Nullable String getColumnDefinition() { return selectableMapping.getColumnDefinition(); } @Override - public Long getLength() { + public @Nullable Long getLength() { return selectableMapping.getLength(); } @Override - public Integer getPrecision() { + public @Nullable Integer getArrayLength() { + return selectableMapping.getArrayLength(); + } + + @Override + public @Nullable Integer getPrecision() { return selectableMapping.getPrecision(); } @Override - public Integer getTemporalPrecision() { + public @Nullable Integer getTemporalPrecision() { return selectableMapping.getTemporalPrecision(); } @Override - public Integer getScale() { + public @Nullable Integer getScale() { return selectableMapping.getScale(); } diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/CaseStatementDiscriminatorMappingImpl.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/CaseStatementDiscriminatorMappingImpl.java index 6ea8b92d0e2d..1ef1b8025572 100644 --- 
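The selectable and discriminator mappings above make the nullability of their metadata getters explicit with Checker Framework @Nullable annotations: formula-backed or synthetic selectables simply report null for length, array length, precision and scale. A small sketch of that contract, assuming checker-qual (the artifact providing the annotation imported above) is on the classpath; SelectableSketch and NullableMetadataSketch are invented names, not the real SelectableMapping interface.

import org.checkerframework.checker.nullness.qual.Nullable;

// Sketch of the nullability contract made explicit above: size-related getters
// may return null, and callers must treat "no metadata" as a valid state.
public class NullableMetadataSketch {

    interface SelectableSketch {
        @Nullable Long getLength();
        @Nullable Integer getArrayLength();
    }

    static String describe(SelectableSketch selectable) {
        Long length = selectable.getLength();
        Integer arrayLength = selectable.getArrayLength();
        // Formula-backed selectables (see the implementations above) report null here.
        if (length == null && arrayLength == null) {
            return "no explicit size metadata";
        }
        return "length=" + length + ", arrayLength=" + arrayLength;
    }

    public static void main(String[] args) {
        SelectableSketch formulaBacked = new SelectableSketch() {
            public @Nullable Long getLength() { return null; }
            public @Nullable Integer getArrayLength() { return null; }
        };
        System.out.println(describe(formulaBacked)); // no explicit size metadata
    }
}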
a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/CaseStatementDiscriminatorMappingImpl.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/CaseStatementDiscriminatorMappingImpl.java @@ -8,6 +8,7 @@ import java.util.LinkedHashMap; import java.util.List; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.engine.FetchTiming; import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.metamodel.mapping.DiscriminatorType; @@ -141,37 +142,42 @@ private Expression createCaseSearchedExpression(TableGroup entityTableGroup) { } @Override - public String getCustomReadExpression() { + public @Nullable String getCustomReadExpression() { return null; } @Override - public String getCustomWriteExpression() { + public @Nullable String getCustomWriteExpression() { return null; } @Override - public String getColumnDefinition() { + public @Nullable String getColumnDefinition() { return null; } @Override - public Long getLength() { + public @Nullable Long getLength() { return null; } @Override - public Integer getPrecision() { + public @Nullable Integer getArrayLength() { return null; } @Override - public Integer getTemporalPrecision() { + public @Nullable Integer getPrecision() { return null; } @Override - public Integer getScale() { + public @Nullable Integer getTemporalPrecision() { + return null; + } + + @Override + public @Nullable Integer getScale() { return null; } diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/CollectionIdentifierDescriptorImpl.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/CollectionIdentifierDescriptorImpl.java index ae076700498a..631da553a63d 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/CollectionIdentifierDescriptorImpl.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/CollectionIdentifierDescriptorImpl.java @@ -6,6 +6,7 @@ import java.util.function.BiConsumer; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.cache.MutableCacheKeyBuilder; import org.hibernate.engine.FetchStyle; import org.hibernate.engine.FetchTiming; @@ -108,37 +109,42 @@ public boolean isNullable() { } @Override - public String getCustomReadExpression() { + public @Nullable String getCustomReadExpression() { return null; } @Override - public String getCustomWriteExpression() { + public @Nullable String getCustomWriteExpression() { return null; } @Override - public String getColumnDefinition() { + public @Nullable String getColumnDefinition() { return null; } @Override - public Long getLength() { + public @Nullable Long getLength() { return null; } @Override - public Integer getPrecision() { + public @Nullable Integer getArrayLength() { return null; } @Override - public Integer getScale() { + public @Nullable Integer getPrecision() { return null; } @Override - public Integer getTemporalPrecision() { + public @Nullable Integer getScale() { + return null; + } + + @Override + public @Nullable Integer getTemporalPrecision() { return null; } diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/DiscriminatedAssociationMapping.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/DiscriminatedAssociationMapping.java index 5265b7eda6ca..6359c245fdcb 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/DiscriminatedAssociationMapping.java +++ 
b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/DiscriminatedAssociationMapping.java @@ -88,6 +88,7 @@ public static DiscriminatedAssociationMapping from( metaColumn.getCustomWriteExpression(), metaColumn.getSqlType(), metaColumn.getLength(), + metaColumn.getArrayLength(), metaColumn.getPrecision(), metaColumn.getScale(), bootValueMapping.isColumnInsertable( 0 ), @@ -112,6 +113,7 @@ public static DiscriminatedAssociationMapping from( keyColumn.getCustomWriteExpression(), keyColumn.getSqlType(), keyColumn.getLength(), + keyColumn.getArrayLength(), keyColumn.getPrecision(), keyColumn.getScale(), bootValueMapping.isNullable(), diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/EmbeddableMappingTypeImpl.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/EmbeddableMappingTypeImpl.java index 1424e73be38b..b613738d123c 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/EmbeddableMappingTypeImpl.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/EmbeddableMappingTypeImpl.java @@ -476,6 +476,7 @@ private boolean finishInitialization( final SelectablePath selectablePath; final String columnDefinition; final Long length; + final Integer arrayLength; final Integer precision; final Integer scale; final Integer temporalPrecision; @@ -484,6 +485,7 @@ private boolean finishInitialization( if ( selectable instanceof Column column ) { columnDefinition = column.getSqlType(); length = column.getLength(); + arrayLength = column.getArrayLength(); precision = column.getPrecision(); scale = column.getScale(); temporalPrecision = column.getTemporalPrecision(); @@ -495,6 +497,7 @@ private boolean finishInitialization( else { columnDefinition = null; length = null; + arrayLength = null; precision = null; scale = null; temporalPrecision = null; @@ -515,9 +518,14 @@ private boolean finishInitialization( selectablePath, selectable.isFormula(), selectable.getCustomReadExpression(), - selectable.getWriteExpr( basicValue.getResolution().getJdbcMapping(), dialect ), + selectable.getWriteExpr( + basicValue.getResolution().getJdbcMapping(), + dialect, + creationProcess.getCreationContext().getBootModel() + ), columnDefinition, length, + arrayLength, precision, scale, temporalPrecision, @@ -715,6 +723,7 @@ private EmbeddableDiscriminatorMapping generateDiscriminatorMapping( final String columnDefinition; final String name; final Long length; + final Integer arrayLength; final Integer precision; final Integer scale; final boolean isFormula = discriminator.hasFormula(); @@ -726,6 +735,7 @@ private EmbeddableDiscriminatorMapping generateDiscriminatorMapping( ); columnDefinition = null; length = null; + arrayLength = null; precision = null; scale = null; } @@ -736,6 +746,7 @@ private EmbeddableDiscriminatorMapping generateDiscriminatorMapping( columnDefinition = column.getSqlType(); name = column.getName(); length = column.getLength(); + arrayLength = column.getArrayLength(); precision = column.getPrecision(); scale = column.getScale(); } @@ -751,6 +762,7 @@ private EmbeddableDiscriminatorMapping generateDiscriminatorMapping( columnDefinition, selectable.getCustomReadExpression(), length, + arrayLength, precision, scale, bootDescriptor.getDiscriminatorType() diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/EntityRowIdMappingImpl.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/EntityRowIdMappingImpl.java index 
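The constructor changes above (AnyDiscriminatorPart, AnyKeyPart, BasicAttributeMapping, BasicEntityIdentifierMappingImpl, and the later ones below) all follow one pattern: the old signature is kept, marked @Deprecated(forRemoval = true, since = "7.2"), and delegates to a new signature that threads the extra arrayLength argument through as null. A generic sketch of that pattern follows; ExampleMapping is a hypothetical class, not part of the change.

// Illustrative only: shows the deprecate-and-delegate pattern used above for the
// new, nullable arrayLength parameter.
public class ExampleMapping {

    private final Long length;
    private final Integer arrayLength;
    private final Integer precision;

    /** Old signature, kept for compatibility and scheduled for removal. */
    @Deprecated(forRemoval = true, since = "7.2")
    public ExampleMapping(Long length, Integer precision) {
        // Delegate with a null arrayLength, matching the behaviour before the change.
        this(length, null, precision);
    }

    /** New signature with the additional, nullable arrayLength. */
    public ExampleMapping(Long length, Integer arrayLength, Integer precision) {
        this.length = length;
        this.arrayLength = arrayLength;
        this.precision = precision;
    }

    String describe() {
        return "length=" + length + ", arrayLength=" + arrayLength + ", precision=" + precision;
    }

    public static void main(String[] args) {
        @SuppressWarnings("removal")
        ExampleMapping legacy = new ExampleMapping(255L, 2);
        System.out.println(legacy.describe());                          // arrayLength=null
        System.out.println(new ExampleMapping(255L, 16, 2).describe()); // arrayLength=16
    }
}

Old call sites keep compiling and behave as before (no array length), while new call sites can supply the value.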
b83e2cf56a9b..cf2e55651144 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/EntityRowIdMappingImpl.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/EntityRowIdMappingImpl.java @@ -6,6 +6,7 @@ import java.util.function.BiConsumer; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.cache.MutableCacheKeyBuilder; import org.hibernate.engine.FetchTiming; import org.hibernate.engine.spi.SessionFactoryImplementor; @@ -188,37 +189,42 @@ public String getSelectionExpression() { } @Override - public String getCustomReadExpression() { + public @Nullable String getCustomReadExpression() { return null; } @Override - public String getCustomWriteExpression() { + public @Nullable String getCustomWriteExpression() { return null; } @Override - public String getColumnDefinition() { + public @Nullable String getColumnDefinition() { return null; } @Override - public Long getLength() { + public @Nullable Long getLength() { return null; } @Override - public Integer getPrecision() { + public @Nullable Integer getArrayLength() { return null; } @Override - public Integer getScale() { + public @Nullable Integer getPrecision() { return null; } @Override - public Integer getTemporalPrecision() { + public @Nullable Integer getScale() { + return null; + } + + @Override + public @Nullable Integer getTemporalPrecision() { return null; } diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/EntityVersionMappingImpl.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/EntityVersionMappingImpl.java index cd824754b7b0..8c5cf1d2feeb 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/EntityVersionMappingImpl.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/EntityVersionMappingImpl.java @@ -7,6 +7,7 @@ import java.util.function.BiConsumer; import java.util.function.Supplier; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.cache.MutableCacheKeyBuilder; import org.hibernate.engine.FetchStyle; import org.hibernate.engine.FetchTiming; @@ -46,16 +47,18 @@ public class EntityVersionMappingImpl implements EntityVersionMapping, FetchOpti private final String columnTableExpression; private final String columnExpression; - private final String columnDefinition; - private final Long length; - private final Integer precision; - private final Integer scale; - private final Integer temporalPrecision; + private final @Nullable String columnDefinition; + private final @Nullable Long length; + private final @Nullable Integer arrayLength; + private final @Nullable Integer precision; + private final @Nullable Integer scale; + private final @Nullable Integer temporalPrecision; private final BasicType versionBasicType; private final VersionValue unsavedValueStrategy; + @Deprecated(forRemoval = true, since = "7.2") public EntityVersionMappingImpl( RootClass bootEntityDescriptor, Supplier templateInstanceAccess, @@ -69,9 +72,41 @@ public EntityVersionMappingImpl( Integer temporalPrecision, BasicType versionBasicType, EntityMappingType declaringType) { + this( + bootEntityDescriptor, + templateInstanceAccess, + attributeName, + columnTableExpression, + columnExpression, + columnDefinition, + length, + null, + precision, + scale, + temporalPrecision, + versionBasicType, + declaringType + ); + } + + public EntityVersionMappingImpl( + RootClass bootEntityDescriptor, + Supplier templateInstanceAccess, + String 
attributeName, + String columnTableExpression, + String columnExpression, + @Nullable String columnDefinition, + @Nullable Long length, + @Nullable Integer arrayLength, + @Nullable Integer precision, + @Nullable Integer scale, + @Nullable Integer temporalPrecision, + BasicType versionBasicType, + EntityMappingType declaringType) { this.attributeName = attributeName; this.columnDefinition = columnDefinition; this.length = length; + this.arrayLength = arrayLength; this.precision = precision; this.scale = scale; this.temporalPrecision = temporalPrecision; @@ -143,37 +178,42 @@ public boolean hasPartitionedSelectionMapping() { } @Override - public String getCustomReadExpression() { + public @Nullable String getCustomReadExpression() { return null; } @Override - public String getCustomWriteExpression() { + public @Nullable String getCustomWriteExpression() { return null; } @Override - public String getColumnDefinition() { + public @Nullable String getColumnDefinition() { return columnDefinition; } @Override - public Long getLength() { + public @Nullable Long getLength() { return length; } @Override - public Integer getPrecision() { + public @Nullable Integer getArrayLength() { + return arrayLength; + } + + @Override + public @Nullable Integer getPrecision() { return precision; } @Override - public Integer getScale() { + public @Nullable Integer getScale() { return scale; } @Override - public Integer getTemporalPrecision() { + public @Nullable Integer getTemporalPrecision() { return temporalPrecision; } diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/ExplicitColumnDiscriminatorMappingImpl.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/ExplicitColumnDiscriminatorMappingImpl.java index bb9441894818..ff01a75b4c77 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/ExplicitColumnDiscriminatorMappingImpl.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/ExplicitColumnDiscriminatorMappingImpl.java @@ -4,6 +4,7 @@ */ package org.hibernate.metamodel.mapping.internal; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.metamodel.mapping.DiscriminatorConverter; import org.hibernate.metamodel.mapping.DiscriminatorType; import org.hibernate.metamodel.mapping.EmbeddableDiscriminatorMapping; @@ -31,11 +32,12 @@ public class ExplicitColumnDiscriminatorMappingImpl extends AbstractDiscriminato private final String columnFormula; private final boolean isPhysical; private final boolean isUpdateable; - private final String columnDefinition; - private final String customReadExpression; - private final Long length; - private final Integer precision; - private final Integer scale; + private final @Nullable String columnDefinition; + private final @Nullable String customReadExpression; + private final @Nullable Long length; + private final @Nullable Integer arrayLength; + private final @Nullable Integer precision; + private final @Nullable Integer scale; public ExplicitColumnDiscriminatorMappingImpl( ManagedMappingType mappingType, @@ -51,6 +53,38 @@ public ExplicitColumnDiscriminatorMappingImpl( Integer precision, Integer scale, DiscriminatorType discriminatorType) { + this( + mappingType, + name, + tableExpression, + columnExpression, + isFormula, + isPhysical, + isUpdateable, + columnDefinition, + customReadExpression, + length, + null, + precision, + scale, + discriminatorType ); + } + + public ExplicitColumnDiscriminatorMappingImpl( + ManagedMappingType 
mappingType, + String name, + String tableExpression, + String columnExpression, + boolean isFormula, + boolean isPhysical, + boolean isUpdateable, + @Nullable String columnDefinition, + @Nullable String customReadExpression, + @Nullable Long length, + @Nullable Integer arrayLength, + @Nullable Integer precision, + @Nullable Integer scale, + DiscriminatorType discriminatorType) { //noinspection unchecked super( mappingType, (DiscriminatorType) discriminatorType, (BasicType) discriminatorType.getUnderlyingJdbcMapping() ); this.name = name; @@ -59,6 +93,7 @@ public ExplicitColumnDiscriminatorMappingImpl( this.columnDefinition = columnDefinition; this.customReadExpression = customReadExpression; this.length = length; + this.arrayLength = arrayLength; this.precision = precision; this.scale = scale; if ( isFormula ) { @@ -118,37 +153,42 @@ public String getSelectionExpression() { } @Override - public String getCustomReadExpression() { + public @Nullable String getCustomReadExpression() { return customReadExpression; } @Override - public String getCustomWriteExpression() { + public @Nullable String getCustomWriteExpression() { return null; } @Override - public String getColumnDefinition() { + public @Nullable String getColumnDefinition() { return columnDefinition; } @Override - public Long getLength() { + public @Nullable Long getLength() { return length; } @Override - public Integer getPrecision() { + public @Nullable Integer getArrayLength() { + return arrayLength; + } + + @Override + public @Nullable Integer getPrecision() { return precision; } @Override - public Integer getScale() { + public @Nullable Integer getScale() { return scale; } @Override - public Integer getTemporalPrecision() { + public @Nullable Integer getTemporalPrecision() { return null; } diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/MappingModelCreationHelper.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/MappingModelCreationHelper.java index b20c539a4997..9e932f897108 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/MappingModelCreationHelper.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/MappingModelCreationHelper.java @@ -197,6 +197,63 @@ public static CompositeIdentifierMapping buildNonEncapsulatedCompositeIdentifier // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Non-identifier attributes + @Deprecated(forRemoval = true, since = "7.2") + public static BasicAttributeMapping buildBasicAttributeMapping( + String attrName, + NavigableRole navigableRole, + int stateArrayPosition, + int fetchableIndex, + Property bootProperty, + ManagedMappingType declaringType, + BasicType attrType, + String tableExpression, + String attrColumnName, + SelectablePath selectablePath, + boolean isAttrFormula, + String readExpr, + String writeExpr, + String columnDefinition, + Long length, + Integer precision, + Integer scale, + Integer temporalPrecision, + boolean isLob, + boolean nullable, + boolean insertable, + boolean updateable, + PropertyAccess propertyAccess, + CascadeStyle cascadeStyle, + MappingModelCreationProcess creationProcess) { + return buildBasicAttributeMapping( + attrName, + navigableRole, + stateArrayPosition, + fetchableIndex, + bootProperty, + declaringType, + attrType, + tableExpression, + attrColumnName, + selectablePath, + isAttrFormula, + readExpr, + writeExpr, + columnDefinition, + length, + null, + precision, + scale, + temporalPrecision, + isLob, + 
nullable, + insertable, + updateable, + propertyAccess, + cascadeStyle, + creationProcess + ); + } + @SuppressWarnings("rawtypes") public static BasicAttributeMapping buildBasicAttributeMapping( String attrName, @@ -214,6 +271,7 @@ public static BasicAttributeMapping buildBasicAttributeMapping( String writeExpr, String columnDefinition, Long length, + Integer arrayLength, Integer precision, Integer scale, Integer temporalPrecision, @@ -267,6 +325,7 @@ public static BasicAttributeMapping buildBasicAttributeMapping( writeExpr, columnDefinition, length, + arrayLength, precision, scale, temporalPrecision, diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/SelectableMappingImpl.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/SelectableMappingImpl.java index fb43a87ec780..97e1d3b1bd93 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/SelectableMappingImpl.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/SelectableMappingImpl.java @@ -52,7 +52,48 @@ public SelectableMappingImpl( boolean partitioned, boolean isFormula, JdbcMapping jdbcMapping) { - super( columnDefinition, length, precision, scale, temporalPrecision, jdbcMapping ); + this( + containingTableExpression, + selectionExpression, + selectablePath, + customReadExpression, + customWriteExpression, + columnDefinition, + length, + null, + precision, + scale, + temporalPrecision, + isLob, + nullable, + insertable, + updateable, + partitioned, + isFormula, + jdbcMapping + ); + } + + public SelectableMappingImpl( + String containingTableExpression, + String selectionExpression, + SelectablePath selectablePath, + String customReadExpression, + String customWriteExpression, + String columnDefinition, + Long length, + Integer arrayLength, + Integer precision, + Integer scale, + Integer temporalPrecision, + boolean isLob, + boolean nullable, + boolean insertable, + boolean updateable, + boolean partitioned, + boolean isFormula, + JdbcMapping jdbcMapping) { + super( columnDefinition, length, arrayLength, precision, scale, temporalPrecision, jdbcMapping ); assert selectionExpression != null; // Save memory by using interned strings. Probability is high that we have multiple duplicate strings this.containingTableExpression = containingTableExpression == null ? null : containingTableExpression.intern(); @@ -166,6 +207,7 @@ public static SelectableMapping from( final String columnExpression; final String columnDefinition; final Long length; + final Integer arrayLength; final Integer precision; final Integer scale; final Integer temporalPrecision; @@ -176,6 +218,7 @@ public static SelectableMapping from( columnExpression = selectable.getTemplate( dialect, typeConfiguration ); columnDefinition = null; length = null; + arrayLength = null; precision = null; scale = null; temporalPrecision = null; @@ -188,6 +231,7 @@ public static SelectableMapping from( columnExpression = selectable.getText( dialect ); columnDefinition = column.getSqlType(); length = column.getLength(); + arrayLength = column.getArrayLength(); precision = column.getPrecision(); scale = column.getScale(); temporalPrecision = column.getTemporalPrecision(); @@ -203,9 +247,10 @@ public static SelectableMapping from( ? 
null : parentPath.append( selectableName ), selectable.getCustomReadExpression(), - selectable.getWriteExpr( jdbcMapping, dialect ), + selectable.getWriteExpr( jdbcMapping, dialect, creationContext.getBootModel() ), columnDefinition, length, + arrayLength, precision, scale, temporalPrecision, diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/SimpleForeignKeyDescriptor.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/SimpleForeignKeyDescriptor.java index b9533c21f7c0..c71899048fb3 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/SimpleForeignKeyDescriptor.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/SimpleForeignKeyDescriptor.java @@ -10,6 +10,7 @@ import java.util.function.BiConsumer; import java.util.function.IntFunction; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.Hibernate; import org.hibernate.bytecode.enhance.spi.interceptor.EnhancementAsProxyLazinessInterceptor; import org.hibernate.cache.MutableCacheKeyBuilder; @@ -642,37 +643,41 @@ public boolean isPartitioned() { } @Override - public String getCustomReadExpression() { + public @Nullable String getCustomReadExpression() { return keySide.getModelPart().getCustomReadExpression(); } @Override - public String getCustomWriteExpression() { + public @Nullable String getCustomWriteExpression() { return keySide.getModelPart().getCustomWriteExpression(); } @Override - public String getColumnDefinition() { + public @Nullable String getColumnDefinition() { return keySide.getModelPart().getColumnDefinition(); } @Override - public Long getLength() { + public @Nullable Long getLength() { return keySide.getModelPart().getLength(); } + @Override + public @Nullable Integer getArrayLength() { + return keySide.getModelPart().getArrayLength(); + } @Override - public Integer getPrecision() { + public @Nullable Integer getPrecision() { return keySide.getModelPart().getPrecision(); } @Override - public Integer getScale() { + public @Nullable Integer getScale() { return keySide.getModelPart().getScale(); } @Override - public Integer getTemporalPrecision() { + public @Nullable Integer getTemporalPrecision() { return keySide.getModelPart().getTemporalPrecision(); } diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/SqlTypedMappingImpl.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/SqlTypedMappingImpl.java index fe8556bbf376..dfbc47e79a36 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/SqlTypedMappingImpl.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/internal/SqlTypedMappingImpl.java @@ -16,13 +16,14 @@ public class SqlTypedMappingImpl implements SqlTypedMapping { private final @Nullable String columnDefinition; private final @Nullable Long length; + private final @Nullable Integer arrayLength; private final @Nullable Integer precision; private final @Nullable Integer scale; private final @Nullable Integer temporalPrecision; private final JdbcMapping jdbcMapping; public SqlTypedMappingImpl(JdbcMapping jdbcMapping) { - this( null, null, null, null, null, jdbcMapping ); + this( null, null, null, null, null, null, jdbcMapping ); } public SqlTypedMappingImpl( @@ -32,9 +33,21 @@ public SqlTypedMappingImpl( @Nullable Integer scale, @Nullable Integer temporalPrecision, JdbcMapping jdbcMapping) { + this( columnDefinition, length, null, precision, scale, temporalPrecision, 
jdbcMapping ); + } + + public SqlTypedMappingImpl( + @Nullable String columnDefinition, + @Nullable Long length, + @Nullable Integer arrayLength, + @Nullable Integer precision, + @Nullable Integer scale, + @Nullable Integer temporalPrecision, + JdbcMapping jdbcMapping) { // Save memory by using interned strings. Probability is high that we have multiple duplicate strings this.columnDefinition = columnDefinition == null ? null : columnDefinition.intern(); this.length = length; + this.arrayLength = arrayLength; this.precision = precision; this.scale = scale; this.temporalPrecision = temporalPrecision; @@ -51,6 +64,11 @@ public SqlTypedMappingImpl( return length; } + @Override + public @Nullable Integer getArrayLength() { + return arrayLength; + } + @Override public @Nullable Integer getPrecision() { return precision; diff --git a/hibernate-core/src/main/java/org/hibernate/persister/collection/AbstractCollectionPersister.java b/hibernate-core/src/main/java/org/hibernate/persister/collection/AbstractCollectionPersister.java index e15310d71a18..7482ca6ad683 100644 --- a/hibernate-core/src/main/java/org/hibernate/persister/collection/AbstractCollectionPersister.java +++ b/hibernate-core/src/main/java/org/hibernate/persister/collection/AbstractCollectionPersister.java @@ -381,7 +381,11 @@ public AbstractCollectionPersister( else { final Column col = (Column) selectable; elementColumnNames[j] = col.getQuotedName( dialect ); - elementColumnWriters[j] = col.getWriteExpr( elementBootDescriptor.getSelectableType( factory.getRuntimeMetamodels(), j ), dialect ); + elementColumnWriters[j] = col.getWriteExpr( + elementBootDescriptor.getSelectableType( factory.getRuntimeMetamodels(), j ), + dialect, + creationContext.getBootModel() + ); elementColumnReaders[j] = col.getReadExpr( dialect ); elementColumnReaderTemplates[j] = col.getTemplate( dialect, typeConfiguration ); elementColumnIsGettable[j] = true; diff --git a/hibernate-core/src/main/java/org/hibernate/persister/entity/AbstractEntityPersister.java b/hibernate-core/src/main/java/org/hibernate/persister/entity/AbstractEntityPersister.java index a0ba547c2c55..b711e4a918a3 100644 --- a/hibernate-core/src/main/java/org/hibernate/persister/entity/AbstractEntityPersister.java +++ b/hibernate-core/src/main/java/org/hibernate/persister/entity/AbstractEntityPersister.java @@ -5072,6 +5072,7 @@ protected EntityDiscriminatorMapping generateDiscriminatorMapping(PersistentClas final String discriminatorColumnExpression; final String columnDefinition; final Long length; + final Integer arrayLength; final Integer precision; final Integer scale; if ( getDiscriminatorFormulaTemplate() == null ) { @@ -5082,12 +5083,14 @@ protected EntityDiscriminatorMapping generateDiscriminatorMapping(PersistentClas if ( column == null ) { columnDefinition = null; length = null; + arrayLength = null; precision = null; scale = null; } else { columnDefinition = column.getSqlType(); length = column.getLength(); + arrayLength = column.getArrayLength(); precision = column.getPrecision(); scale = column.getScale(); } @@ -5096,6 +5099,7 @@ protected EntityDiscriminatorMapping generateDiscriminatorMapping(PersistentClas discriminatorColumnExpression = getDiscriminatorFormulaTemplate(); columnDefinition = null; length = null; + arrayLength = null; precision = null; scale = null; } @@ -5110,6 +5114,7 @@ protected EntityDiscriminatorMapping generateDiscriminatorMapping(PersistentClas columnDefinition, null, length, + arrayLength, precision, scale, getDiscriminatorDomainType() @@ -5257,11 
+5262,13 @@ protected EntityIdentifierMapping generateIdentifierMapping( } final String columnDefinition; final Long length; + final Integer arrayLength; final Integer precision; final Integer scale; if ( bootEntityDescriptor.getIdentifier() == null ) { columnDefinition = null; length = null; + arrayLength = null; precision = null; scale = null; } @@ -5269,6 +5276,7 @@ protected EntityIdentifierMapping generateIdentifierMapping( Column column = bootEntityDescriptor.getIdentifier().getColumns().get( 0 ); columnDefinition = column.getSqlType(); length = column.getLength(); + arrayLength = column.getArrayLength(); precision = column.getPrecision(); scale = column.getScale(); } @@ -5282,6 +5290,7 @@ protected EntityIdentifierMapping generateIdentifierMapping( rootTableKeyColumnNames[0], columnDefinition, length, + arrayLength, precision, scale, value.isColumnInsertable( 0 ), @@ -5328,6 +5337,7 @@ protected static EntityVersionMapping generateVersionMapping( column.getText( dialect ), column.getSqlType(), column.getLength(), + column.getArrayLength(), column.getPrecision(), column.getScale(), column.getTemporalPrecision(), @@ -5373,6 +5383,7 @@ protected AttributeMapping generateNonIdAttributeMapping( "?", column.getSqlType(), column.getLength(), + column.getArrayLength(), column.getPrecision(), column.getScale(), column.getTemporalPrecision(), @@ -5394,6 +5405,7 @@ protected AttributeMapping generateNonIdAttributeMapping( final String customWriteExpr; final String columnDefinition; final Long length; + final Integer arrayLength; final Integer precision; final Integer scale; final Integer temporalPrecision; @@ -5408,6 +5420,7 @@ protected AttributeMapping generateNonIdAttributeMapping( Column column = value.getColumns().get( 0 ); columnDefinition = column.getSqlType(); length = column.getLength(); + arrayLength = column.getArrayLength(); precision = column.getPrecision(); temporalPrecision = column.getTemporalPrecision(); scale = column.getScale(); @@ -5431,10 +5444,15 @@ protected AttributeMapping generateNonIdAttributeMapping( creationContext.getDialect(), creationContext.getTypeConfiguration() ); - customWriteExpr = selectable.getWriteExpr( (JdbcMapping) attrType, creationContext.getDialect() ); + customWriteExpr = selectable.getWriteExpr( + (JdbcMapping) attrType, + creationContext.getDialect(), + creationContext.getBootModel() + ); Column column = value.getColumns().get( 0 ); columnDefinition = column.getSqlType(); length = column.getLength(); + arrayLength = column.getArrayLength(); precision = column.getPrecision(); temporalPrecision = column.getTemporalPrecision(); scale = column.getScale(); @@ -5450,6 +5468,7 @@ protected AttributeMapping generateNonIdAttributeMapping( customWriteExpr = null; columnDefinition = null; length = null; + arrayLength = null; precision = null; temporalPrecision = null; scale = null; @@ -5474,6 +5493,7 @@ protected AttributeMapping generateNonIdAttributeMapping( customWriteExpr, columnDefinition, length, + arrayLength, precision, scale, temporalPrecision, diff --git a/hibernate-core/src/main/java/org/hibernate/persister/entity/JoinedSubclassEntityPersister.java b/hibernate-core/src/main/java/org/hibernate/persister/entity/JoinedSubclassEntityPersister.java index 99d45013dc7a..002fd126e1d1 100644 --- a/hibernate-core/src/main/java/org/hibernate/persister/entity/JoinedSubclassEntityPersister.java +++ b/hibernate-core/src/main/java/org/hibernate/persister/entity/JoinedSubclassEntityPersister.java @@ -1023,11 +1023,13 @@ protected EntityIdentifierMapping 
generateIdentifierMapping( final String columnDefinition; final Long length; + final Integer arrayLength; final Integer precision; final Integer scale; if ( bootEntityDescriptor.getIdentifier() == null ) { columnDefinition = null; length = null; + arrayLength = null; precision = null; scale = null; } @@ -1035,6 +1037,7 @@ protected EntityIdentifierMapping generateIdentifierMapping( final Column column = bootEntityDescriptor.getIdentifier().getColumns().get( 0 ); columnDefinition = column.getSqlType(); length = column.getLength(); + arrayLength = column.getArrayLength(); precision = column.getPrecision(); scale = column.getScale(); } @@ -1047,6 +1050,7 @@ protected EntityIdentifierMapping generateIdentifierMapping( tableKeyColumns[0][0], columnDefinition, length, + arrayLength, precision, scale, value.isColumnInsertable( 0 ), diff --git a/hibernate-core/src/main/java/org/hibernate/persister/entity/UnionSubclassEntityPersister.java b/hibernate-core/src/main/java/org/hibernate/persister/entity/UnionSubclassEntityPersister.java index 8608ecc1ff14..e357f9202903 100644 --- a/hibernate-core/src/main/java/org/hibernate/persister/entity/UnionSubclassEntityPersister.java +++ b/hibernate-core/src/main/java/org/hibernate/persister/entity/UnionSubclassEntityPersister.java @@ -467,6 +467,7 @@ private String getSelectClauseNullString(Column col, Dialect dialect) { new SqlTypedMappingImpl( col.getTypeName(), col.getLength(), + col.getArrayLength(), col.getPrecision(), col.getScale(), col.getTemporalPrecision(), diff --git a/hibernate-core/src/main/java/org/hibernate/persister/entity/mutation/EntityTableMapping.java b/hibernate-core/src/main/java/org/hibernate/persister/entity/mutation/EntityTableMapping.java index 6e3d99c6e706..25ba54d38ac6 100644 --- a/hibernate-core/src/main/java/org/hibernate/persister/entity/mutation/EntityTableMapping.java +++ b/hibernate-core/src/main/java/org/hibernate/persister/entity/mutation/EntityTableMapping.java @@ -9,6 +9,7 @@ import java.util.Objects; import java.util.function.Consumer; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.engine.spi.SharedSessionContractImplementor; import org.hibernate.internal.util.collections.ArrayHelper; import org.hibernate.jdbc.Expectation; @@ -377,37 +378,42 @@ public boolean isPartitioned() { } @Override - public String getColumnDefinition() { + public @Nullable String getColumnDefinition() { return null; } @Override - public Long getLength() { + public @Nullable Long getLength() { return null; } @Override - public Integer getPrecision() { + public @Nullable Integer getArrayLength() { return null; } @Override - public Integer getScale() { + public @Nullable Integer getPrecision() { return null; } @Override - public Integer getTemporalPrecision() { + public @Nullable Integer getScale() { return null; } @Override - public String getCustomReadExpression() { + public @Nullable Integer getTemporalPrecision() { return null; } @Override - public String getCustomWriteExpression() { + public @Nullable String getCustomReadExpression() { + return null; + } + + @Override + public @Nullable String getCustomWriteExpression() { return null; } } diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/produce/function/internal/SetReturningFunctionTypeResolverBuilder.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/produce/function/internal/SetReturningFunctionTypeResolverBuilder.java index 7b8db383ac5b..98ced5f56bd8 100644 --- 
a/hibernate-core/src/main/java/org/hibernate/query/sqm/produce/function/internal/SetReturningFunctionTypeResolverBuilder.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/produce/function/internal/SetReturningFunctionTypeResolverBuilder.java @@ -136,6 +136,7 @@ public SelectableMapping[] resolveFunctionReturnType( null, null, null, + null, false, true, false, @@ -158,6 +159,7 @@ public SelectableMapping[] resolveFunctionReturnType( null, null, null, + null, false, false, false, diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java index 7213d5d2fe44..904781604411 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java @@ -391,6 +391,7 @@ import org.hibernate.sql.results.graph.internal.ImmutableFetchList; import org.hibernate.sql.results.internal.SqlSelectionImpl; import org.hibernate.sql.results.internal.StandardEntityGraphTraversalStateImpl; +import org.hibernate.type.BasicPluralType; import org.hibernate.type.BasicType; import org.hibernate.type.BindableType; import org.hibernate.type.BottomType; @@ -6261,6 +6262,7 @@ private static SqlTypedMappingImpl precision(BasicType bindable, Object bindV if ( bindValue instanceof BigInteger bigInteger ) { int precision = bindValue.toString().length() - ( bigInteger.signum() < 0 ? 1 : 0 ); return new SqlTypedMappingImpl( + null, null, null, precision, @@ -6271,6 +6273,7 @@ private static SqlTypedMappingImpl precision(BasicType bindable, Object bindV } else if ( bindValue instanceof BigDecimal bigDecimal ) { return new SqlTypedMappingImpl( + null, null, null, bigDecimal.precision(), @@ -6484,12 +6487,29 @@ public Object visitCastTarget(SqmCastTarget target) { if ( targetType instanceof BasicType basicType ) { targetType = resolveSqlTypeIndicators( this, basicType, target.getNodeJavaType() ); } - return new CastTarget( - targetType.getJdbcMapping(), - target.getLength(), - target.getPrecision(), - target.getScale() - ); + if ( targetType instanceof BasicPluralType + && target.getLength() != null + && target.getScale() == null ) { + // Assume that the given length is the array length + return new CastTarget( + targetType.getJdbcMapping(), + null, + null, + target.getLength().intValue(), + null, + null + ); + } + else { + return new CastTarget( + targetType.getJdbcMapping(), + null, + target.getLength(), + null, + target.getPrecision(), + target.getScale() + ); + } } @Override diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/tuple/internal/AnonymousTupleBasicValuedModelPart.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/tuple/internal/AnonymousTupleBasicValuedModelPart.java index c8ddc0839937..40266a57e163 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/tuple/internal/AnonymousTupleBasicValuedModelPart.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/tuple/internal/AnonymousTupleBasicValuedModelPart.java @@ -6,6 +6,7 @@ import java.util.function.BiConsumer; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.Incubating; import org.hibernate.cache.MutableCacheKeyBuilder; import org.hibernate.engine.FetchStyle; @@ -72,6 +73,7 @@ public AnonymousTupleBasicValuedModelPart( null, null, null, + null, false, true, false, @@ -164,12 +166,12 @@ public String getSelectionExpression() { } 
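// Illustrative aside, not part of this patch: the visitCastTarget change above treats the HQL cast
// length as the array length when the target is a BasicPluralType. With the six-argument CastTarget
// constructor introduced later in this patch, such a target can also be built directly;
// "arrayJdbcMapping" is an assumed JdbcMapping for an array/vector type.
final CastTarget vectorTarget = new CastTarget(
		arrayJdbcMapping, // type
		null,             // sqlType
		null,             // length
		3,                // arrayLength, e.g. a 3-dimensional vector
		null,             // precision
		null              // scale
);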
@Override - public String getCustomReadExpression() { + public @Nullable String getCustomReadExpression() { return selectableMapping.getCustomReadExpression(); } @Override - public String getCustomWriteExpression() { + public @Nullable String getCustomWriteExpression() { return selectableMapping.getCustomWriteExpression(); } @@ -204,27 +206,32 @@ public boolean hasPartitionedSelectionMapping() { } @Override - public String getColumnDefinition() { + public @Nullable String getColumnDefinition() { return selectableMapping.getColumnDefinition(); } @Override - public Long getLength() { + public @Nullable Long getLength() { return selectableMapping.getLength(); } @Override - public Integer getPrecision() { + public @Nullable Integer getArrayLength() { + return selectableMapping.getArrayLength(); + } + + @Override + public @Nullable Integer getPrecision() { return selectableMapping.getPrecision(); } @Override - public Integer getScale() { + public @Nullable Integer getScale() { return selectableMapping.getScale(); } @Override - public Integer getTemporalPrecision() { + public @Nullable Integer getTemporalPrecision() { return selectableMapping.getTemporalPrecision(); } diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java index 46b3a75d00fd..440526445d30 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java @@ -468,6 +468,16 @@ public void appendSql(long value) { sqlBuffer.append( value ); } + @Override + public void appendSql(double value) { + sqlBuffer.append( value ); + } + + @Override + public void appendSql(float value) { + sqlBuffer.append( value ); + } + @Override public void appendSql(boolean value) { sqlBuffer.append( value ); @@ -5638,6 +5648,7 @@ protected void renderCasted(Expression expression) { parameter.getJdbcMapping(), sqlTypedMapping.getColumnDefinition(), sqlTypedMapping.getLength(), + sqlTypedMapping.getArrayLength(), sqlTypedMapping.getTemporalPrecision() != null ? sqlTypedMapping.getTemporalPrecision() : sqlTypedMapping.getPrecision(), @@ -5665,7 +5676,10 @@ protected void renderLiteral(Literal literal, boolean castParameter) { renderCasted( new LiteralAsParameter<>( literal, marker ) ); } else { - jdbcType.appendWriteExpression( marker, this, dialect ); + final Size size = literal.getJdbcMapping() instanceof SqlTypedMapping sqlTypedMapping + ? sqlTypedMapping.toSize() + : null; + jdbcType.appendWriteExpression( marker, size, this, dialect ); } } else { @@ -7014,7 +7028,12 @@ protected void renderParameterAsParameter(int position, JdbcParameter jdbcParame final JdbcType jdbcType = jdbcParameter.getExpressionType().getJdbcMapping( 0 ).getJdbcType(); assert jdbcType != null; final String parameterMarker = parameterMarkerStrategy.createMarker( position, jdbcType ); - jdbcType.appendWriteExpression( parameterMarker, this, dialect ); + final Size size = jdbcParameter.getExpressionType() instanceof SqlTypedMapping sqlTypedMapping + ? sqlTypedMapping.toSize() + : jdbcParameter instanceof SqlTypedMappingJdbcParameter parameter + ? 
parameter.getSqlTypedMapping().toSize() + : null; + jdbcType.appendWriteExpression( parameterMarker, size, this, dialect ); } protected final int addParameterBinder(JdbcParameter parameter) { diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/SqlAppender.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/SqlAppender.java index 9295629f9854..7da838354015 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/SqlAppender.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/SqlAppender.java @@ -46,6 +46,14 @@ default void appendSql(boolean value) { appendSql( String.valueOf( value ) ); } + default void appendSql(double value) { + appendSql( String.valueOf( value ) ); + } + + default void appendSql(float value) { + appendSql( String.valueOf( value ) ); + } + default void appendDoubleQuoteEscapedString(String value) { final StringBuilder sb = new StringBuilder( value.length() + 2 ); QuotingHelper.appendDoubleQuoteEscapedString( sb, value ); diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/SqlAstTranslatorWithMerge.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/SqlAstTranslatorWithMerge.java index a0232c239cf4..ee1cee398c5c 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/SqlAstTranslatorWithMerge.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/SqlAstTranslatorWithMerge.java @@ -11,6 +11,7 @@ import org.hibernate.sql.ast.tree.Statement; import org.hibernate.sql.exec.spi.JdbcOperation; import org.hibernate.sql.model.ast.ColumnValueBinding; +import org.hibernate.sql.model.ast.ColumnWriteFragment; import org.hibernate.sql.model.internal.OptionalTableUpdate; import org.hibernate.sql.model.jdbc.MergeOperation; @@ -155,7 +156,13 @@ private void renderMergeUsingQuery(OptionalTableUpdate optionalTableUpdate) { } protected void renderMergeUsingQuerySelection(ColumnValueBinding selectionBinding) { - renderCasted( selectionBinding.getValueExpression() ); + final ColumnWriteFragment valueExpression = selectionBinding.getValueExpression(); + if ( valueExpression.getExpressionType().getJdbcType().isWriteExpressionTyped( getDialect() ) ) { + valueExpression.accept( this ); + } + else { + renderCasted( valueExpression ); + } appendSql( " " ); appendSql( selectionBinding.getColumnReference().getColumnExpression() ); } diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/SqlAstTranslatorWithUpsert.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/SqlAstTranslatorWithUpsert.java index bf0ef1795e34..b914a4270dab 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/SqlAstTranslatorWithUpsert.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/SqlAstTranslatorWithUpsert.java @@ -12,6 +12,7 @@ import org.hibernate.sql.exec.spi.JdbcOperation; import org.hibernate.sql.model.MutationOperation; import org.hibernate.sql.model.ast.ColumnValueBinding; +import org.hibernate.sql.model.ast.ColumnWriteFragment; import org.hibernate.sql.model.internal.OptionalTableUpdate; import org.hibernate.sql.model.jdbc.DeleteOrUpsertOperation; import org.hibernate.sql.model.jdbc.UpsertOperation; @@ -112,14 +113,26 @@ protected void renderMergeSource(OptionalTableUpdate optionalTableUpdate) { columnList.append( ", " ); } columnList.append( keyBinding.getColumnReference().getColumnExpression() ); - renderCasted( keyBinding.getValueExpression() ); + final ColumnWriteFragment valueExpression = keyBinding.getValueExpression(); + if ( 
valueExpression.getExpressionType().getJdbcType().isWriteExpressionTyped( getDialect() ) ) { + valueExpression.accept( this ); + } + else { + renderCasted( valueExpression ); + } } for ( int i = 0; i < valueBindings.size(); i++ ) { appendSql( ", " ); columnList.append( ", " ); final ColumnValueBinding valueBinding = valueBindings.get( i ); columnList.append( valueBinding.getColumnReference().getColumnExpression() ); - renderCasted( valueBinding.getValueExpression() ); + final ColumnWriteFragment valueExpression = valueBinding.getValueExpression(); + if ( valueExpression.getExpressionType().getJdbcType().isWriteExpressionTyped( getDialect() ) ) { + valueExpression.accept( this ); + } + else { + renderCasted( valueExpression ); + } } appendSql( ") " ); diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/StringBuilderSqlAppender.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/StringBuilderSqlAppender.java index 83c062a3f7f4..7b8e1492044c 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/StringBuilderSqlAppender.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/StringBuilderSqlAppender.java @@ -48,6 +48,16 @@ public void appendSql(boolean value) { sb.append( value ); } + @Override + public void appendSql(double value) { + sb.append( value ); + } + + @Override + public void appendSql(float value) { + sb.append( value ); + } + @Override public Appendable append(CharSequence csq) { sb.append( csq ); diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/CastTarget.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/CastTarget.java index 09616d050367..4e33bec18f28 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/CastTarget.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/CastTarget.java @@ -4,6 +4,7 @@ */ package org.hibernate.sql.ast.tree.expression; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.metamodel.mapping.JdbcMapping; import org.hibernate.metamodel.mapping.JdbcMappingContainer; import org.hibernate.metamodel.mapping.SqlTypedMapping; @@ -15,33 +16,43 @@ */ public class CastTarget implements Expression, SqlAstNode, SqlTypedMapping { private final JdbcMapping type; - private final String sqlType; - private final Long length; - private final Integer precision; - private final Integer scale; + private final @Nullable String sqlType; + private final @Nullable Long length; + private final @Nullable Integer arrayLength; + private final @Nullable Integer precision; + private final @Nullable Integer scale; public CastTarget(JdbcMapping type) { - this( type, null, null, null ); + this( type, null, null, null, null, null ); } - public CastTarget(JdbcMapping type, Long length, Integer precision, Integer scale) { + public CastTarget(JdbcMapping type, @Nullable Long length, @Nullable Integer precision, @Nullable Integer scale) { this( type, null, length, precision, scale ); } - public CastTarget(JdbcMapping type, String sqlType, Long length, Integer precision, Integer scale) { + public CastTarget(JdbcMapping type, @Nullable Long length, @Nullable Integer arrayLength, @Nullable Integer precision, @Nullable Integer scale) { + this( type, null, length, arrayLength, precision, scale ); + } + + public CastTarget(JdbcMapping type, @Nullable String sqlType, @Nullable Long length, @Nullable Integer precision, @Nullable Integer scale) { + this( type, sqlType, length, null, precision, scale ); + } + + public 
CastTarget(JdbcMapping type, @Nullable String sqlType, @Nullable Long length, @Nullable Integer arrayLength, @Nullable Integer precision, @Nullable Integer scale) { this.type = type; this.sqlType = sqlType; this.length = length; + this.arrayLength = arrayLength; this.precision = precision; this.scale = scale; } - public String getSqlType() { + public @Nullable String getSqlType() { return sqlType; } @Override - public String getColumnDefinition() { + public @Nullable String getColumnDefinition() { return sqlType; } @@ -50,20 +61,28 @@ public JdbcMapping getJdbcMapping() { return type; } - public Long getLength() { + @Override + public @Nullable Long getLength() { return length; } - public Integer getPrecision() { + @Override + public @Nullable Integer getArrayLength() { + return arrayLength; + } + + @Override + public @Nullable Integer getPrecision() { return precision; } @Override - public Integer getTemporalPrecision() { + public @Nullable Integer getTemporalPrecision() { return null; } - public Integer getScale() { + @Override + public @Nullable Integer getScale() { return scale; } diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/LiteralAsParameter.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/LiteralAsParameter.java index 218428f8f47e..4057a0eb0043 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/LiteralAsParameter.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/LiteralAsParameter.java @@ -4,8 +4,10 @@ */ package org.hibernate.sql.ast.tree.expression; +import org.hibernate.engine.jdbc.Size; import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.metamodel.mapping.JdbcMappingContainer; +import org.hibernate.metamodel.mapping.SqlTypedMapping; import org.hibernate.sql.ast.SqlAstTranslator; import org.hibernate.sql.ast.spi.SqlAppender; @@ -27,6 +29,9 @@ public LiteralAsParameter(Literal literal, String parameterMarker) { @Override public void renderToSql(SqlAppender sqlAppender, SqlAstTranslator walker, SessionFactoryImplementor sessionFactory) { + final Size size = literal.getExpressionType() instanceof SqlTypedMapping sqlTypedMapping + ? 
sqlTypedMapping.toSize() + : null; literal.getJdbcMapping().getJdbcType().appendWriteExpression( parameterMarker, sqlAppender, diff --git a/hibernate-core/src/main/java/org/hibernate/type/BasicCollectionType.java b/hibernate-core/src/main/java/org/hibernate/type/BasicCollectionType.java index 70160c445fc5..ad34e08ec35c 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/BasicCollectionType.java +++ b/hibernate-core/src/main/java/org/hibernate/type/BasicCollectionType.java @@ -35,6 +35,16 @@ public BasicCollectionType( this.name = determineName( collectionTypeDescriptor, baseDescriptor ); } + public BasicCollectionType( + BasicType baseDescriptor, + JdbcType arrayJdbcType, + JavaType collectionTypeDescriptor, + String typeName) { + super( arrayJdbcType, collectionTypeDescriptor ); + this.baseDescriptor = baseDescriptor; + this.name = typeName; + } + private static String determineName(BasicCollectionJavaType collectionTypeDescriptor, BasicType baseDescriptor) { final String elementTypeName = determineElementTypeName( baseDescriptor ); switch ( collectionTypeDescriptor.getSemantics().getCollectionClassification() ) { diff --git a/hibernate-core/src/main/java/org/hibernate/type/BasicTypeRegistry.java b/hibernate-core/src/main/java/org/hibernate/type/BasicTypeRegistry.java index 9f9a881aed7d..5138ecd4befc 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/BasicTypeRegistry.java +++ b/hibernate-core/src/main/java/org/hibernate/type/BasicTypeRegistry.java @@ -15,11 +15,13 @@ import org.hibernate.boot.spi.BootstrapContext; import org.hibernate.internal.CoreLogging; import org.hibernate.internal.CoreMessageLogger; +import org.hibernate.tool.schema.extract.spi.ColumnTypeInformation; import org.hibernate.type.descriptor.java.BasicPluralJavaType; import org.hibernate.type.descriptor.java.ImmutableMutabilityPlan; import org.hibernate.type.descriptor.java.JavaType; import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry; import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.DelegatingJdbcTypeIndicators; import org.hibernate.type.descriptor.jdbc.JdbcType; import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; import org.hibernate.type.internal.BasicTypeImpl; @@ -170,8 +172,48 @@ private BasicType resolvedType(ArrayJdbcType arrayType, BasicPluralJavaTy typeConfiguration, indicators.getDialect(), elementType, - null, - indicators + new ColumnTypeInformation() { + @Override + public Boolean getNullable() { + return null; + } + + @Override + public int getTypeCode() { + return arrayType.getDefaultSqlTypeCode(); + } + + @Override + public String getTypeName() { + return null; + } + + @Override + public int getColumnSize() { + return 0; + } + + @Override + public int getDecimalDigits() { + return 0; + } + }, + new DelegatingJdbcTypeIndicators( indicators ) { + @Override + public Integer getExplicitJdbcTypeCode() { + return arrayType.getDefaultSqlTypeCode(); + } + + @Override + public int getPreferredSqlTypeCodeForArray() { + return arrayType.getDefaultSqlTypeCode(); + } + + @Override + public int getPreferredSqlTypeCodeForArray(int elementSqlTypeCode) { + return arrayType.getDefaultSqlTypeCode(); + } + } ); if ( resolvedType instanceof BasicPluralType ) { register( resolvedType ); diff --git a/hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java b/hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java index f2354b799d16..a465f4bee052 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java +++ 
b/hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java @@ -681,10 +681,10 @@ public class SqlTypes { /** - * A type code representing an {@code embedding vector} type for databases + * A type code representing a {@code vector} type for databases * like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL}, * {@link org.hibernate.dialect.OracleDialect Oracle 23ai} and {@link org.hibernate.dialect.MariaDBDialect MariaDB}. - * An embedding vector essentially is a {@code float[]} with a fixed size. + * A vector is essentially a {@code float[]} with a fixed length. * * @since 6.4 */ @@ -701,10 +701,39 @@ public class SqlTypes { public static final int VECTOR_FLOAT32 = 10_002; /** - * A type code representing a double-precision floating-point type for Oracle 23ai database. + * A type code representing a double-precision floating-point vector type for Oracle 23ai database. */ public static final int VECTOR_FLOAT64 = 10_003; + /** + * A type code representing a binary (bit) vector type for databases + * like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL} and + * {@link org.hibernate.dialect.OracleDialect Oracle 23ai}. + */ + public static final int VECTOR_BINARY = 10_004; + + /** + * A type code representing a half-precision floating-point vector type for databases + * like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL}. + */ + public static final int VECTOR_FLOAT16 = 10_005; + + /** + * A type code representing a sparse single-byte integer vector type for Oracle 23ai database. + */ + public static final int SPARSE_VECTOR_INT8 = 10_006; + + /** + * A type code representing a sparse single-precision floating-point vector type for Oracle 23ai database. + */ + public static final int SPARSE_VECTOR_FLOAT32 = 10_007; + + /** + * A type code representing a sparse double-precision floating-point vector type for Oracle 23ai database. + */ + public static final int SPARSE_VECTOR_FLOAT64 = 10_008; + + private SqlTypes() { } diff --git a/hibernate-core/src/main/java/org/hibernate/type/StandardBasicTypes.java b/hibernate-core/src/main/java/org/hibernate/type/StandardBasicTypes.java index fc521d5146f6..8c16770b3caf 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/StandardBasicTypes.java +++ b/hibernate-core/src/main/java/org/hibernate/type/StandardBasicTypes.java @@ -749,6 +749,14 @@ private StandardBasicTypes() { "byte_vector", byte[].class, SqlTypes.VECTOR_INT8 ); + /** + * The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR_FLOAT16 VECTOR_FLOAT16}, + * specifically for embedding half-precision floating-point (16-bits) vectors like provided by the PostgreSQL extension pgvector. + */ + public static final BasicTypeReference VECTOR_FLOAT16 = new BasicTypeReference<>( + "float16_vector", float[].class, SqlTypes.VECTOR_FLOAT16 + ); + /** * The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR}, * specifically for embedding single-precision floating-point (32-bits) vectors like provided by Oracle 23ai. @@ -765,6 +773,38 @@ private StandardBasicTypes() { "double_vector", double[].class, SqlTypes.VECTOR_FLOAT64 ); + /** + * The standard Hibernate type for mapping {@code byte[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR_BINARY VECTOR_BINARY}, + * specifically for embedding bit vectors like provided by Oracle 23ai.
+ */ + public static final BasicTypeReference VECTOR_BINARY = new BasicTypeReference<>( + "binary_vector", byte[].class, SqlTypes.VECTOR_BINARY + ); + +// /** +// * The standard Hibernate type for mapping {@code byte[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR_INT8 VECTOR_INT8}, +// * specifically for embedding integer vectors (8-bits) like provided by Oracle 23ai. +// */ +// public static final BasicTypeReference SPARSE_VECTOR_INT8 = new BasicTypeReference<>( +// "sparse_byte_vector", byte[].class, SqlTypes.SPARSE_VECTOR_INT8 +// ); +// +// /** +// * The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR}, +// * specifically for embedding single-precision floating-point (32-bits) vectors like provided by Oracle 23ai. +// */ +// public static final BasicTypeReference SPARSE_VECTOR_FLOAT32 = new BasicTypeReference<>( +// "sparse_float_vector", float[].class, SqlTypes.SPARSE_VECTOR_FLOAT32 +// ); +// +// /** +// * The standard Hibernate type for mapping {@code double[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR}, +// * specifically for embedding double-precision floating-point (64-bits) vectors like provided by Oracle 23ai. +// */ +// public static final BasicTypeReference SPARSE_VECTOR_FLOAT64 = new BasicTypeReference<>( +// "sparse_double_vector", double[].class, SqlTypes.SPARSE_VECTOR_FLOAT64 +// ); + public static void prime(TypeConfiguration typeConfiguration) { BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); @@ -1286,6 +1326,34 @@ public static void prime(TypeConfiguration typeConfiguration) { "byte_vector" ); + handle( + VECTOR_BINARY, + null, + basicTypeRegistry, + "bit_vector" + ); + +// handle( +// SPARSE_VECTOR_FLOAT32, +// null, +// basicTypeRegistry, +// "sparse_float_vector" +// ); +// +// handle( +// SPARSE_VECTOR_FLOAT64, +// null, +// basicTypeRegistry, +// "sparse_double_vector" +// ); +// +// handle( +// SPARSE_VECTOR_INT8, +// null, +// basicTypeRegistry, +// "sparse_byte_vector" +// ); + // Specialized version handlers diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/PrimitiveByteArrayJavaType.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/PrimitiveByteArrayJavaType.java index f0e41b167df6..d68bec284276 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/PrimitiveByteArrayJavaType.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/java/PrimitiveByteArrayJavaType.java @@ -116,6 +116,13 @@ public X unwrap(byte[] value, Class type, WrapperOptions options) { if ( Blob.class.isAssignableFrom( type ) ) { return (X) options.getLobCreator().createBlob( value ); } + if ( type.isAssignableFrom( Byte[].class ) ) { + final Byte[] array = new Byte[value.length]; + for ( int i = 0; i < value.length; i++ ) { + array[i] = value[i]; + } + return (X) array; + } throw unknownUnwrap( type ); } @@ -142,6 +149,13 @@ else if ( value instanceof Byte byteValue ) { // Support binding a single element as parameter value return new byte[]{ byteValue }; } + else if ( value instanceof Byte[] array ) { + final byte[] bytes = new byte[array.length]; + for ( int i = 0; i < array.length; i++ ) { + bytes[i] = array[i]; + } + return bytes; + } throw unknownWrap( value.getClass() ); } diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/ArrayJdbcType.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/ArrayJdbcType.java index 6f1d3fc74f63..0f243ea2aa58 100644 
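// Illustrative sketch, not part of this patch: assuming the new VECTOR_FLOAT16 code is wired up like
// the existing VECTOR codes, a half-precision pgvector column could be mapped along these lines.
// The entity and field names are made up; @Array(length = 3) supplies the dimension that feeds the
// arrayLength plumbing added throughout this patch.
import org.hibernate.annotations.Array;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.type.SqlTypes;

import jakarta.persistence.Entity;
import jakarta.persistence.Id;

@Entity
public class DocumentEmbedding {
	@Id
	private Long id;

	// stored as a half-precision vector of dimension 3
	@JdbcTypeCode(SqlTypes.VECTOR_FLOAT16)
	@Array(length = 3)
	private float[] embedding;
}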
--- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/ArrayJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/ArrayJdbcType.java @@ -81,7 +81,7 @@ public JavaType getJdbcRecommendedJavaTypeMapping( } } - private static JavaType elementJavaType(JavaType javaTypeDescriptor) { + protected static JavaType elementJavaType(JavaType javaTypeDescriptor) { if ( javaTypeDescriptor instanceof ByteArrayJavaType ) { // Special handling needed for Byte[], because that would conflict with the VARBINARY mapping return ByteJavaType.INSTANCE; diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/H2FormatJsonJdbcType.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/H2FormatJsonJdbcType.java index 62769c923674..37cdb64585d1 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/H2FormatJsonJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/H2FormatJsonJdbcType.java @@ -4,7 +4,9 @@ */ package org.hibernate.type.descriptor.jdbc; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.metamodel.mapping.EmbeddableMappingType; import org.hibernate.metamodel.spi.RuntimeModelCreationContext; import org.hibernate.sql.ast.spi.SqlAppender; @@ -41,8 +43,17 @@ public AggregateJdbcType resolveAggregateJdbcType( } @Override - public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { + public void appendWriteExpression( + String writeExpression, + @Nullable Size size, + SqlAppender appender, + Dialect dialect) { appender.append( writeExpression ); appender.append( " format json" ); } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } } diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/JdbcType.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/JdbcType.java index c3bb08aca6cc..cadfedf8a669 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/JdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/JdbcType.java @@ -9,10 +9,12 @@ import java.sql.SQLException; import java.sql.Types; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.Incubating; import org.hibernate.boot.model.relational.Database; import org.hibernate.dialect.Dialect; import org.hibernate.engine.jdbc.Size; +import org.hibernate.metamodel.mapping.JdbcMapping; import org.hibernate.query.sqm.CastType; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.sql.ast.spi.StringBuilderSqlAppender; @@ -179,8 +181,9 @@ default Expression wrapTopLevelSelectionExpression(Expression expression) { /** * Wraps the write expression to be able to write values with this JdbcType's ValueBinder. * @since 6.2 + * @deprecated Use {@link #wrapWriteExpression(String, Size, Dialect)} */ - @Incubating + @Deprecated(forRemoval = true, since = "7.2") default String wrapWriteExpression(String writeExpression, Dialect dialect) { final StringBuilder sb = new StringBuilder( writeExpression.length() ); appendWriteExpression( writeExpression, new StringBuilderSqlAppender( sb ), dialect ); @@ -190,12 +193,43 @@ default String wrapWriteExpression(String writeExpression, Dialect dialect) { /** * Append the write expression wrapped in a way to be able to write values with this JdbcType's ValueBinder. 
* @since 6.2 + * @deprecated Use {@link #appendWriteExpression(String, Size, SqlAppender, Dialect)} instead */ - @Incubating + @Deprecated(forRemoval = true, since = "7.2") default void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { appender.append( writeExpression ); } + /** + * Wraps the write expression to be able to write values with this JdbcType's ValueBinder. + * @since 7.2 + */ + @Incubating + default String wrapWriteExpression(String writeExpression, @Nullable Size size, Dialect dialect) { + final StringBuilder sb = new StringBuilder( writeExpression.length() ); + appendWriteExpression( writeExpression, size, new StringBuilderSqlAppender( sb ), dialect ); + return sb.toString(); + } + + /** + * Append the write expression wrapped in a way to be able to write values with this JdbcType's ValueBinder. + * @since 7.2 + */ + @Incubating + default void appendWriteExpression(String writeExpression, @Nullable Size size, SqlAppender appender, Dialect dialect) { + appendWriteExpression( writeExpression, appender, dialect ); + } + + /** + * Whether the write expression is typed. + * This is used to determine if a parameter expression needs a cast in e.g. a select item context. + * @since 7.2 + */ + @Incubating + default boolean isWriteExpressionTyped(Dialect dialect) { + return false; + } + default boolean isInteger() { int typeCode = getDdlTypeCode(); return isIntegral(typeCode) @@ -367,6 +401,32 @@ default String getExtraCreateTableInfo(JavaType javaType, String columnName, return ""; } + /** + * Returns the cast pattern from the given source type to this type, or {@code null} if not possible. + * + * @param sourceMapping The source type + * @param size The size of this target type + * @return The cast pattern or null + * @since 7.2 + */ + @Incubating + default @Nullable String castFromPattern(JdbcMapping sourceMapping, @Nullable Size size) { + return null; + } + + /** + * Returns the cast pattern from this type to the given target type, or {@code null} if not possible. 
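// Illustrative sketch, not part of this patch: inside a custom JdbcType implementation, the new
// Size-aware appendWriteExpression can be overridden (exactly as H2FormatJsonJdbcType does above)
// and the expression reported as typed via isWriteExpressionTyped, letting the merge/upsert
// translators skip the surrounding cast. "my_vector" is a made-up SQL type name; a real
// implementation would derive the dimension from the given Size.
@Override
public void appendWriteExpression(String writeExpression, @Nullable Size size, SqlAppender appender, Dialect dialect) {
	appender.append( "cast(" );
	appender.append( writeExpression );
	appender.append( " as my_vector)" );
}

@Override
public boolean isWriteExpressionTyped(Dialect dialect) {
	return true;
}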
+ * + * @param targetJdbcMapping The target type + * @param size The size of this source type + * @return The cast pattern or null + * @since 7.2 + */ + @Incubating + default @Nullable String castToPattern(JdbcMapping targetJdbcMapping, @Nullable Size size) { + return null; + } + @Incubating default boolean isComparable() { final int code = getDefaultSqlTypeCode(); diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/XmlAsStringArrayJdbcType.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/XmlAsStringArrayJdbcType.java index 4cdc0ad40d1b..337d38af5135 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/XmlAsStringArrayJdbcType.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/XmlAsStringArrayJdbcType.java @@ -174,4 +174,20 @@ protected X doExtract(CallableStatement statement, String name, WrapperOptions o }; } + + @Override + public boolean equals(Object that) { + return super.equals( that ) + && that instanceof XmlAsStringArrayJdbcType jdbcType + && ddlTypeCode == jdbcType.ddlTypeCode + && nationalized == jdbcType.nationalized; + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + Boolean.hashCode( nationalized ); + result = 31 * result + ddlTypeCode; + return result; + } } diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/XmlHelper.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/XmlHelper.java index 1e0abde4045c..07d0af20058a 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/XmlHelper.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/XmlHelper.java @@ -1028,6 +1028,16 @@ public void appendSql(boolean value) { sb.append( value ); } + @Override + public void appendSql(double value) { + sb.append( value ); + } + + @Override + public void appendSql(float value) { + sb.append( value ); + } + @Override public String toString() { return sb.toString(); diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterArray.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterArray.java index 2c126d538925..b09b2004c506 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterArray.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterArray.java @@ -22,7 +22,7 @@ public JdbcLiteralFormatterArray(JavaType javaType, JdbcLiteralFormatter e } @Override - public void appendJdbcLiteral(SqlAppender appender, Object value, Dialect dialect, WrapperOptions wrapperOptions) { + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) { dialect.appendArrayLiteral( appender, unwrapArray( value, wrapperOptions ), elementFormatter, wrapperOptions ); } diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterBinary.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterBinary.java index dd4b48788d7c..2e5696b27318 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterBinary.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterBinary.java @@ -22,7 +22,7 @@ public JdbcLiteralFormatterBinary(JavaType javaType) { } @Override - public void 
appendJdbcLiteral(SqlAppender appender, Object value, Dialect dialect, WrapperOptions wrapperOptions) { + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) { dialect.appendBinaryLiteral( appender, unwrap( value, byte[].class, wrapperOptions ) ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterBoolean.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterBoolean.java index 11ef85384bee..9259f2bdbf1f 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterBoolean.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterBoolean.java @@ -22,7 +22,7 @@ public JdbcLiteralFormatterBoolean(JavaType javaType) { } @Override - public void appendJdbcLiteral(SqlAppender appender, Object value, Dialect dialect, WrapperOptions wrapperOptions) { + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) { dialect.appendBooleanValueString( appender, unwrap( value, Boolean.class, wrapperOptions ) ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterCharacterData.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterCharacterData.java index de36c7bb499e..e5ab857cad29 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterCharacterData.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterCharacterData.java @@ -31,7 +31,7 @@ public JdbcLiteralFormatterCharacterData(JavaType javaType, boolean isNationa } @Override - public void appendJdbcLiteral(SqlAppender appender, Object value, Dialect dialect, WrapperOptions wrapperOptions) { + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) { final String literalValue = unwrap( value, String.class, wrapperOptions ); if ( isNationalized ) { appender.appendSql( NATIONALIZED_PREFIX ); diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterNumericData.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterNumericData.java index 74443a6fd738..5f67212c4ca2 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterNumericData.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterNumericData.java @@ -25,7 +25,7 @@ public JdbcLiteralFormatterNumericData(JavaType javaType, Class javaType, TemporalType precision } @Override - public void appendJdbcLiteral(SqlAppender appender, Object value, Dialect dialect, WrapperOptions options) { + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions options) { final TimeZone jdbcTimeZone = getJdbcTimeZone( options ); // for performance reasons, avoid conversions if we can if ( value instanceof java.util.Date date ) { diff --git a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterUUIDData.java b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterUUIDData.java index a0aa2035f32e..961d08be5970 100644 --- 
a/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterUUIDData.java +++ b/hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/internal/JdbcLiteralFormatterUUIDData.java @@ -23,7 +23,7 @@ public JdbcLiteralFormatterUUIDData(JavaType javaType) { } @Override - public void appendJdbcLiteral(SqlAppender appender, Object value, Dialect dialect, WrapperOptions wrapperOptions) { + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) { dialect.appendUUIDLiteral( appender, unwrap( value, UUID.class, wrapperOptions ) ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/type/format/StringJsonDocumentWriter.java b/hibernate-core/src/main/java/org/hibernate/type/format/StringJsonDocumentWriter.java index a7e600d6ac0e..47f7f7012ed0 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/format/StringJsonDocumentWriter.java +++ b/hibernate-core/src/main/java/org/hibernate/type/format/StringJsonDocumentWriter.java @@ -422,6 +422,16 @@ public void appendSql(boolean value) { sb.append( value ); } + @Override + public void appendSql(double value) { + sb.append( value ); + } + + @Override + public void appendSql(float value) { + sb.append( value ); + } + @Override public String toString() { return sb.toString(); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/dialect/PostgreSQLDialectTestCase.java b/hibernate-core/src/test/java/org/hibernate/orm/test/dialect/PostgreSQLDialectTestCase.java index e5812f246508..6cd45d703c94 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/dialect/PostgreSQLDialectTestCase.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/dialect/PostgreSQLDialectTestCase.java @@ -181,6 +181,7 @@ public void testTextVsVarchar() { null, null, null, + null, typeConfiguration.getBasicTypeForJavaType( String.class ) ), typeConfiguration diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/BasicListTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/BasicListTest.java index c69c69a93de0..51609c2af1b0 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/BasicListTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/BasicListTest.java @@ -145,7 +145,7 @@ public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); final String op = dialect.supportsDistinctFromPredicate() ? "IS NOT DISTINCT FROM" : "="; - final String param = integerListType.getJdbcType().wrapWriteExpression( ":data", dialect ); + final String param = integerListType.getJdbcType().wrapWriteExpression( ":data", null, dialect ); Query tq = em.createNativeQuery( "SELECT * FROM table_with_integer_list t WHERE the_list " + op + " " + param, TableWithIntegerList.class diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/BasicSortedSetTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/BasicSortedSetTest.java index fc6f249936d3..f83abed5a9d9 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/BasicSortedSetTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/BasicSortedSetTest.java @@ -146,7 +146,7 @@ public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); final String op = dialect.supportsDistinctFromPredicate() ? 
"IS NOT DISTINCT FROM" : "="; - final String param = integerSortedSetType.getJdbcType().wrapWriteExpression( ":data", dialect ); + final String param = integerSortedSetType.getJdbcType().wrapWriteExpression( ":data", null, dialect ); Query tq = em.createNativeQuery( "SELECT * FROM table_with_integer_sorted_set t WHERE the_sorted_set " + op + " " + param, TableWithIntegerSortedSet.class diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/BooleanArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/BooleanArrayTest.java index 0c9f383026d1..fb0bf020b3a3 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/BooleanArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/BooleanArrayTest.java @@ -152,7 +152,7 @@ public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); final String op = dialect.supportsDistinctFromPredicate() ? "IS NOT DISTINCT FROM" : "="; - final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", dialect ); + final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", null, dialect ); TypedQuery tq = em.createNativeQuery( "SELECT * FROM table_with_boolean_arrays t WHERE the_array " + op + " " + param, TableWithBooleanArrays.class diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/DateArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/DateArrayTest.java index ca77782791d6..72b0f9e98995 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/DateArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/DateArrayTest.java @@ -165,7 +165,7 @@ public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); final String op = dialect.supportsDistinctFromPredicate() ? "IS NOT DISTINCT FROM" : "="; - final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", dialect ); + final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", null, dialect ); TypedQuery tq = em.createNativeQuery( "SELECT * FROM table_with_date_arrays t WHERE the_array " + op + " " + param, TableWithDateArrays.class diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/DoubleArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/DoubleArrayTest.java index 415a7887c779..f3a28fa7f4e9 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/DoubleArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/DoubleArrayTest.java @@ -156,7 +156,7 @@ public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); final String op = dialect.supportsDistinctFromPredicate() ? 
"IS NOT DISTINCT FROM" : "="; - final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", dialect ); + final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", null, dialect ); TypedQuery tq = em.createNativeQuery( "SELECT * FROM table_with_double_arrays t WHERE the_array " + op + " " + param, TableWithDoubleArrays.class diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumArrayTest.java index e5a3be851e58..97884be8941e 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumArrayTest.java @@ -149,7 +149,7 @@ public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); final String op = dialect.supportsDistinctFromPredicate() ? "IS NOT DISTINCT FROM" : "="; - final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", dialect ); + final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", null, dialect ); TypedQuery tq = em.createNativeQuery( "SELECT * FROM table_with_enum_arrays t WHERE the_array " + op + " " + param, TableWithEnumArrays.class diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumSetConverterTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumSetConverterTest.java index 8c6baf50cd4a..36df8c0f8e51 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumSetConverterTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumSetConverterTest.java @@ -154,7 +154,7 @@ public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); final String op = dialect.supportsDistinctFromPredicate() ? "IS NOT DISTINCT FROM" : "="; - final String param = enumSetType.getJdbcType().wrapWriteExpression( ":data", dialect ); + final String param = enumSetType.getJdbcType().wrapWriteExpression( ":data", null, dialect ); Query tq = em.createNativeQuery( "SELECT * FROM table_with_enum_set_convert t WHERE the_set " + op + " " + param, TableWithEnumSetConverter.class diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumSetTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumSetTest.java index d21b6ffba8cf..2d89fa12a4c1 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumSetTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/EnumSetTest.java @@ -147,7 +147,7 @@ public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); final String op = dialect.supportsDistinctFromPredicate() ? 
"IS NOT DISTINCT FROM" : "="; - final String param = enumSetType.getJdbcType().wrapWriteExpression( ":data", dialect ); + final String param = enumSetType.getJdbcType().wrapWriteExpression( ":data", null, dialect ); Query tq = em.createNativeQuery( "SELECT * FROM table_with_enum_set t WHERE the_set " + op + " " + param, TableWithEnumSet.class diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/FloatArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/FloatArrayTest.java index b21c00c8104c..e73ae6e02feb 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/FloatArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/FloatArrayTest.java @@ -144,7 +144,7 @@ public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); final String op = dialect.supportsDistinctFromPredicate() ? "IS NOT DISTINCT FROM" : "="; - final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", dialect ); + final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", null, dialect ); TypedQuery tq = em.createNativeQuery( "SELECT * FROM table_with_float_arrays t WHERE the_array " + op + " " + param, TableWithFloatArrays.class diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/IntegerArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/IntegerArrayTest.java index f79cb450e125..3cb2ea9e1f6f 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/IntegerArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/IntegerArrayTest.java @@ -144,7 +144,7 @@ public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); final String op = dialect.supportsDistinctFromPredicate() ? "IS NOT DISTINCT FROM" : "="; - final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", dialect ); + final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", null, dialect ); TypedQuery tq = em.createNativeQuery( "SELECT * FROM table_with_integer_arrays t WHERE the_array " + op + " " + param, TableWithIntegerArrays.class diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/LongArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/LongArrayTest.java index 6fa16b15e0c8..733b9a1b49c4 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/LongArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/LongArrayTest.java @@ -149,7 +149,7 @@ public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); final String op = dialect.supportsDistinctFromPredicate() ? 
"IS NOT DISTINCT FROM" : "="; - final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", dialect ); + final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", null, dialect ); TypedQuery tq = em.createNativeQuery( "SELECT * FROM table_with_bigint_arrays t WHERE the_array " + op + " " + param, TableWithLongArrays.class diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/ShortArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/ShortArrayTest.java index 7679281ee475..005080e4ba2e 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/ShortArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/ShortArrayTest.java @@ -144,7 +144,7 @@ public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); final String op = dialect.supportsDistinctFromPredicate() ? "IS NOT DISTINCT FROM" : "="; - final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", dialect ); + final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", null, dialect ); TypedQuery tq = em.createNativeQuery( "SELECT * FROM table_with_short_arrays t WHERE the_array " + op + " " + param, TableWithShortArrays.class diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/StringArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/StringArrayTest.java index 0e2f750ce7b6..3d6ff6db0df9 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/StringArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/StringArrayTest.java @@ -136,7 +136,7 @@ public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); final String op = dialect.supportsDistinctFromPredicate() ? "IS NOT DISTINCT FROM" : "="; - final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", dialect ); + final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", null, dialect ); TypedQuery tq = em.createNativeQuery( "SELECT * FROM table_with_string_arrays t WHERE the_array " + op + " " + param, TableWithStringArrays.class diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/TimeArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/TimeArrayTest.java index 42382510a32b..59e88f3bd841 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/TimeArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/TimeArrayTest.java @@ -158,7 +158,7 @@ public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); final String op = dialect.supportsDistinctFromPredicate() ? 
"IS NOT DISTINCT FROM" : "="; - final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", dialect ); + final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", null, dialect ); TypedQuery tq = em.createNativeQuery( "SELECT * FROM table_with_time_arrays t WHERE the_array " + op + " " + param, TableWithTimeArrays.class diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/type/TimestampArrayTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/type/TimestampArrayTest.java index 135caf6c751a..5cb4d35b13ba 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/type/TimestampArrayTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/type/TimestampArrayTest.java @@ -163,7 +163,7 @@ public void testNativeQuery(SessionFactoryScope scope) { scope.inSession( em -> { final Dialect dialect = em.getDialect(); final String op = dialect.supportsDistinctFromPredicate() ? "IS NOT DISTINCT FROM" : "="; - final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", dialect ); + final String param = arrayType.getJdbcType().wrapWriteExpression( ":data", null, dialect ); TypedQuery tq = em.createNativeQuery( "SELECT * FROM table_with_timestamp_arrays t WHERE the_array " + op + " " + param, TableWithTimestampArrays.class diff --git a/hibernate-spatial/src/main/java/org/hibernate/spatial/dialect/postgis/AbstractCastingPostGISJdbcType.java b/hibernate-spatial/src/main/java/org/hibernate/spatial/dialect/postgis/AbstractCastingPostGISJdbcType.java index 14464a86fb75..7d964e70e295 100644 --- a/hibernate-spatial/src/main/java/org/hibernate/spatial/dialect/postgis/AbstractCastingPostGISJdbcType.java +++ b/hibernate-spatial/src/main/java/org/hibernate/spatial/dialect/postgis/AbstractCastingPostGISJdbcType.java @@ -10,7 +10,9 @@ import java.sql.SQLException; import java.sql.Types; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; import org.hibernate.spatial.GeometryLiteralFormatter; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.descriptor.ValueBinder; @@ -56,6 +58,7 @@ public JdbcLiteralFormatter getJdbcLiteralFormatter(JavaType javaType) @Override public void appendWriteExpression( String writeExpression, + @Nullable Size size, SqlAppender appender, Dialect dialect) { appender.append( getConstructorFunction() ); @@ -64,6 +67,11 @@ public void appendWriteExpression( appender.append( ')' ); } + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } + public Geometry toGeometry(String wkt) { if ( wkt == null ) { return null; diff --git a/hibernate-testing/src/main/java/org/hibernate/testing/orm/junit/DialectFeatureChecks.java b/hibernate-testing/src/main/java/org/hibernate/testing/orm/junit/DialectFeatureChecks.java index b968c4fbcf9d..dea094d480c9 100644 --- a/hibernate-testing/src/main/java/org/hibernate/testing/orm/junit/DialectFeatureChecks.java +++ b/hibernate-testing/src/main/java/org/hibernate/testing/orm/junit/DialectFeatureChecks.java @@ -13,9 +13,11 @@ import org.hibernate.boot.internal.MetadataBuilderImpl; import org.hibernate.boot.internal.NamedProcedureCallDefinitionImpl; import org.hibernate.boot.model.FunctionContributions; +import org.hibernate.boot.model.FunctionContributor; import org.hibernate.boot.model.IdentifierGeneratorDefinition; import org.hibernate.boot.model.NamedEntityGraphDefinition; import org.hibernate.boot.model.TypeContributions; +import 
org.hibernate.boot.model.TypeContributor; import org.hibernate.boot.model.TypeDefinition; import org.hibernate.boot.model.TypeDefinitionRegistry; import org.hibernate.boot.model.convert.spi.ConverterAutoApplyHandler; @@ -97,6 +99,7 @@ import org.hibernate.type.descriptor.java.StringJavaType; import org.hibernate.type.descriptor.jdbc.JdbcType; import org.hibernate.type.descriptor.jdbc.VarcharJdbcType; +import org.hibernate.type.descriptor.sql.spi.DdlTypeRegistry; import org.hibernate.type.internal.BasicTypeImpl; import org.hibernate.type.spi.TypeConfiguration; import org.hibernate.usertype.CompositeUserType; @@ -105,6 +108,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.ServiceLoader; import java.util.Set; import java.util.UUID; import java.util.function.Consumer; @@ -685,6 +689,7 @@ public boolean apply(Dialect dialect) { null, null, null, + null, new BasicTypeImpl<>( StringJavaType.INSTANCE, VarcharJdbcType.INSTANCE ) ), new TypeConfiguration() @@ -712,6 +717,7 @@ public boolean apply(Dialect dialect) { null, null, null, + null, new BasicTypeImpl<>( StringJavaType.INSTANCE, VarcharJdbcType.INSTANCE ) ), new TypeConfiguration() @@ -739,6 +745,7 @@ public boolean apply(Dialect dialect) { null, null, null, + null, new BasicTypeImpl<>( StringJavaType.INSTANCE, VarcharJdbcType.INSTANCE ) ), new TypeConfiguration() @@ -1081,6 +1088,138 @@ public boolean apply(Dialect dialect) { } } + public static class SupportsVectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.VECTOR ); + } + } + + public static class SupportsFloat16VectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.VECTOR_FLOAT16 ); + } + } + + public static class SupportsFloatVectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.VECTOR_FLOAT32 ); + } + } + + public static class SupportsDoubleVectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.VECTOR_FLOAT64 ); + } + } + + public static class SupportsByteVectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.VECTOR_INT8 ); + } + } + + public static class SupportsBinaryVectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.VECTOR_BINARY ); + } + } + + public static class SupportsSparseFloatVectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.SPARSE_VECTOR_FLOAT32 ); + } + } + + public static class SupportsSparseDoubleVectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.SPARSE_VECTOR_FLOAT64 ); + } + } + + public static class SupportsSparseByteVectorType implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesDdlType( dialect, SqlTypes.SPARSE_VECTOR_INT8 ); + } + } + + public static class SupportsCosineDistance implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "cosine_distance" ); + } + } + + public static class SupportsEuclideanDistance implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, 
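Usage sketch, not part of the patch: the new checks above plug into the existing RequiresDialectFeature annotation from hibernate-testing; the attribute name follows current usage in the test suite, adjust if it differs.

import org.hibernate.testing.orm.junit.DialectFeatureChecks;
import org.hibernate.testing.orm.junit.RequiresDialectFeature;
import org.junit.jupiter.api.Test;

@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsCosineDistance.class)
class CosineDistanceSmokeTest {
	@Test
	void runsOnlyWhereCosineDistanceIsRegistered() {
		// body would exercise cosine_distance(...) against the current dialect
	}
}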
"euclidean_distance" ); + } + } + + public static class SupportsEuclideanSquaredDistance implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "euclidean_squared_distance" ); + } + } + + public static class SupportsTaxicabDistance implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "taxicab_distance" ); + } + } + + public static class SupportsHammingDistance implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "hamming_distance" ); + } + } + + public static class SupportsJaccardDistance implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "jaccard_distance" ); + } + } + + public static class SupportsInnerProduct implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "inner_product" ); + } + } + + public static class SupportsVectorDims implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "vector_dims" ); + } + } + + public static class SupportsVectorNorm implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "vector_norm" ); + } + } + + public static class SupportsL2Norm implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "l2_norm" ); + } + } + + public static class SupportsL2Normalize implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "l2_normalize" ); + } + } + + public static class SupportsSubvector implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "subvector" ); + } + } + + public static class SupportsBinaryQuantize implements DialectFeatureCheck { + public boolean apply(Dialect dialect) { + return definesFunction( dialect, "binary_quantize" ); + } + } + public static class IsJtds implements DialectFeatureCheck { public boolean apply(Dialect dialect) { return dialect instanceof SybaseDialect && ( (SybaseDialect) dialect ).getDriverKind() == SybaseDriverKind.JTDS; @@ -1146,7 +1285,7 @@ public boolean apply(Dialect dialect) { } } - private static final HashMap FUNCTION_REGISTRIES = new HashMap<>(); + private static final HashMap FUNCTION_CONTRIBUTIONS = new HashMap<>(); public static boolean definesFunction(Dialect dialect, String functionName) { return getSqmFunctionRegistry( dialect ).findFunctionDescriptor( functionName ) != null; @@ -1156,6 +1295,11 @@ public static boolean definesSetReturningFunction(Dialect dialect, String functi return getSqmFunctionRegistry( dialect ).findSetReturningFunctionDescriptor( functionName ) != null; } + public static boolean definesDdlType(Dialect dialect, int typeCode) { + final DdlTypeRegistry ddlTypeRegistry = getFunctionContributions( dialect ).typeConfiguration.getDdlTypeRegistry(); + return ddlTypeRegistry.getDescriptor( typeCode ) != null; + } + public static class SupportsSubqueryInSelect implements DialectFeatureCheck { @Override public boolean apply(Dialect dialect) { @@ -1177,24 +1321,33 @@ public boolean apply(Dialect dialect) { } } - private static SqmFunctionRegistry getSqmFunctionRegistry(Dialect dialect) { - SqmFunctionRegistry sqmFunctionRegistry = FUNCTION_REGISTRIES.get( dialect ); - if ( sqmFunctionRegistry == null ) { + return getFunctionContributions( dialect 
).functionRegistry; + } + + private static FakeFunctionContributions getFunctionContributions(Dialect dialect) { + FakeFunctionContributions functionContributions = FUNCTION_CONTRIBUTIONS.get( dialect ); + if ( functionContributions == null ) { final TypeConfiguration typeConfiguration = new TypeConfiguration(); final SqmFunctionRegistry functionRegistry = new SqmFunctionRegistry(); typeConfiguration.scope( new FakeMetadataBuildingContext( typeConfiguration, functionRegistry ) ); final FakeTypeContributions typeContributions = new FakeTypeContributions( typeConfiguration ); - final FakeFunctionContributions functionContributions = new FakeFunctionContributions( + functionContributions = new FakeFunctionContributions( dialect, typeConfiguration, functionRegistry ); dialect.contribute( typeContributions, typeConfiguration.getServiceRegistry() ); dialect.initializeFunctionRegistry( functionContributions ); - FUNCTION_REGISTRIES.put( dialect, sqmFunctionRegistry = functionContributions.functionRegistry ); + for ( TypeContributor typeContributor : ServiceLoader.load( TypeContributor.class ) ) { + typeContributor.contribute( typeContributions, typeConfiguration.getServiceRegistry() ); + } + for ( FunctionContributor functionContributor : ServiceLoader.load( FunctionContributor.class ) ) { + functionContributor.contributeFunctions( functionContributions ); + } + FUNCTION_CONTRIBUTIONS.put( dialect, functionContributions ); } - return sqmFunctionRegistry; + return functionContributions; } public static class FakeTypeContributions implements TypeContributions { diff --git a/hibernate-vector/hibernate-vector.gradle b/hibernate-vector/hibernate-vector.gradle index 61e10169bc30..9419c07354c5 100644 --- a/hibernate-vector/hibernate-vector.gradle +++ b/hibernate-vector/hibernate-vector.gradle @@ -13,6 +13,8 @@ description = 'Hibernate\'s extensions for vector support' dependencies { api project( ':hibernate-core' ) + compileOnly jdbcLibs.mssql + testImplementation project( ':hibernate-testing' ) testImplementation project( path: ':hibernate-core', configuration: 'tests' ) } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/AbstractSparseVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/AbstractSparseVector.java new file mode 100644 index 000000000000..02506b8b11dd --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/AbstractSparseVector.java @@ -0,0 +1,140 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import org.hibernate.internal.util.collections.ArrayHelper; + +import java.util.AbstractList; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +/** + * Base class for sparse vectors. 
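Sketch, not part of the patch, of what the ServiceLoader lookup above discovers: any FunctionContributor listed in META-INF/services/org.hibernate.boot.model.FunctionContributor, for example the hibernate-vector contributors touched later in this diff. A minimal contributor could look like this; the class name is invented.

import org.hibernate.boot.model.FunctionContributions;
import org.hibernate.boot.model.FunctionContributor;

public class ExampleVectorFunctionContributor implements FunctionContributor {
	@Override
	public void contributeFunctions(FunctionContributions functionContributions) {
		// dialect-specific registrations would go here, e.g.
		// functionContributions.getFunctionRegistry().registerAlternateKey( "l2_norm", "vector_norm" );
	}
}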
+ * + * @since 7.1 + */ +public abstract class AbstractSparseVector extends AbstractList { + + protected static final int[] EMPTY_INT_ARRAY = new int[0]; + + public abstract int[] indices(); + + protected interface ElementParser { + V parse(String string, int start, int end); + } + protected record ParsedVector(int size, int[] indices, List elements) { + } + + protected static ParsedVector parseSparseVector(String string, ElementParser parser) { + if ( string == null || !string.startsWith( "[" ) || !string.endsWith( "]" ) ) { + throw invalidVector( string ); + } + final int lengthEndIndex = string.indexOf( ',', 2 ); + if ( lengthEndIndex == -1 ) { + throw invalidVector( string ); + } + final int indicesStartIndex = lengthEndIndex + 1; + if ( string.charAt( indicesStartIndex ) != '[' ) { + throw invalidVector( string ); + } + final int indicesEndIndex = string.indexOf( ']', indicesStartIndex + 1 ); + if ( indicesEndIndex == -1 ) { + throw invalidVector( string ); + } + final int commaIndex = indicesEndIndex + 1; + if ( string.charAt( commaIndex ) != ',' ) { + throw invalidVector( string ); + } + final int elementsStartIndex = commaIndex + 1; + if ( string.charAt( elementsStartIndex ) != '[' ) { + throw invalidVector( string ); + } + final int elementsEndIndex = string.indexOf( ']', elementsStartIndex + 1 ); + if ( elementsEndIndex == -1 ) { + throw invalidVector( string ); + } + if ( elementsEndIndex != string.length() - 2 ) { + throw invalidVector( string ); + } + final int size = Integer.parseInt( string, 1, lengthEndIndex, 10 ); + int start = indicesStartIndex + 1; + final List indicesList = new ArrayList<>(); + if ( start < indicesEndIndex ) { + for ( int i = start; i < indicesEndIndex; i++ ) { + if ( string.charAt( i ) == ',' ) { + indicesList.add( Integer.parseInt( string, start, i, 10 ) ); + start = i + 1; + } + } + indicesList.add( Integer.parseInt( string, start, indicesEndIndex, 10 ) ); + } + final int[] indices = ArrayHelper.toIntArray( indicesList ); + final List elements = new ArrayList<>( indices.length ); + start = elementsStartIndex + 1; + if ( start < elementsEndIndex ) { + for ( int i = start; i < elementsEndIndex; i++ ) { + if ( string.charAt( i ) == ',' ) { + elements.add( parser.parse( string, start, i ) ); + start = i + 1; + } + } + elements.add( parser.parse( string, start, elementsEndIndex ) ); + } + return new ParsedVector<>( size, indices, elements ); + } + + private static IllegalArgumentException invalidVector(String string) { + return new IllegalArgumentException( "Invalid sparse vector string: " + string ); + } + + protected static int[] validateIndices(int[] indices, int dataLength, int size) { + if ( indices == null ) { + throw new IllegalArgumentException( "indices cannot be null" ); + } + if ( indices.length != dataLength ) { + throw new IllegalArgumentException( "indices length does not match data length" ); + } + int previousIndex = -1; + for ( int i = 0; i < indices.length; i++ ) { + if ( indices[i] < 0 ) { + throw new IllegalArgumentException( "indices[" + i + "] < 0" ); + } + else if ( indices[i] < previousIndex ) { + throw new IllegalArgumentException( "Indices array is not sorted ascendingly." 
); + } + previousIndex = indices[i]; + } + if ( previousIndex >= size ) { + throw new IllegalArgumentException( "Indices array contains index " + previousIndex + " that is greater than or equal to size: " + size ); + } + return indices; + } + + @Override + public void clear() { + throw new UnsupportedOperationException( "Cannot remove from sparse vector" ); + } + + @Override + public E remove(int index) { + throw new UnsupportedOperationException( "Cannot remove from sparse vector" ); + } + + @Override + public boolean add(E aByte) { + throw new UnsupportedOperationException( "Cannot add to sparse vector" ); + } + + @Override + public void add(int index, E element) { + throw new UnsupportedOperationException( "Cannot add to sparse vector" ); + } + + @Override + public boolean addAll(int index, Collection c) { + throw new UnsupportedOperationException( "Cannot add to sparse vector" ); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/SparseByteVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/SparseByteVector.java new file mode 100644 index 000000000000..dac7de89676e --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/SparseByteVector.java @@ -0,0 +1,205 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import java.util.Arrays; +import java.util.List; + +/** + * {@link java.util.List} implementation for a sparse byte vector. + * + * @since 7.1 + */ +public class SparseByteVector extends AbstractSparseVector { + + private static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; + + private final int size; + private int[] indices = EMPTY_INT_ARRAY; + private byte[] data = EMPTY_BYTE_ARRAY; + + public SparseByteVector(int size) { + if ( size <= 0 ) { + throw new IllegalArgumentException( "size must be greater than zero" ); + } + this.size = size; + } + + public SparseByteVector(List list) { + if ( list instanceof SparseByteVector sparseVector ) { + size = sparseVector.size; + indices = sparseVector.indices.clone(); + data = sparseVector.data.clone(); + } + else { + if ( list == null ) { + throw new IllegalArgumentException( "list cannot be null" ); + } + if ( list.isEmpty() ) { + throw new IllegalArgumentException( "list cannot be empty" ); + } + int size = 0; + int[] indices = new int[list.size()]; + byte[] data = new byte[list.size()]; + for ( int i = 0; i < list.size(); i++ ) { + final Byte b = list.get( i ); + if ( b != null && b != 0 ) { + indices[size] = i; + data[size] = b; + size++; + } + } + this.size = list.size(); + this.indices = Arrays.copyOf( indices, size ); + this.data = Arrays.copyOf( data, size ); + } + } + + public SparseByteVector(byte[] denseVector) { + if ( denseVector == null ) { + throw new IllegalArgumentException( "denseVector cannot be null" ); + } + if ( denseVector.length == 0 ) { + throw new IllegalArgumentException( "denseVector cannot be empty" ); + } + int size = 0; + int[] indices = new int[denseVector.length]; + byte[] data = new byte[denseVector.length]; + for ( int i = 0; i < denseVector.length; i++ ) { + final byte b = denseVector[i]; + if ( b != 0 ) { + indices[size] = i; + data[size] = b; + size++; + } + } + this.size = denseVector.length; + this.indices = Arrays.copyOf( indices, size ); + this.data = Arrays.copyOf( data, size ); + } + + public SparseByteVector(int size, int[] indices, byte[] data) { + this( validateData( data, size ), validateIndices( indices, data.length, size ), size ); + } + + private 
SparseByteVector(byte[] data, int[] indices, int size) { + this.size = size; + this.indices = indices; + this.data = data; + } + + public SparseByteVector(String string) { + final ParsedVector parsedVector = + parseSparseVector( string, (s, start, end) -> Byte.parseByte( s.substring( start, end ) ) ); + this.size = parsedVector.size(); + this.indices = parsedVector.indices(); + this.data = toByteArray( parsedVector.elements() ); + } + + private static byte[] toByteArray(List elements) { + final byte[] result = new byte[elements.size()]; + for ( int i = 0; i < elements.size(); i++ ) { + result[i] = elements.get(i); + } + return result; + } + + private static byte[] validateData(byte[] data, int size) { + if ( size == 0 ) { + throw new IllegalArgumentException( "size cannot be 0" ); + } + if ( data == null ) { + throw new IllegalArgumentException( "data cannot be null" ); + } + if ( size < data.length ) { + throw new IllegalArgumentException( "size cannot be smaller than data size" ); + } + for ( int i = 0; i < data.length; i++ ) { + if ( data[i] == 0 ) { + throw new IllegalArgumentException( "data[" + i + "] == 0" ); + } + } + return data; + } + + @Override + public SparseByteVector clone() { + return new SparseByteVector( data.clone(), indices.clone(), size ); + } + + @Override + public Byte get(int index) { + final int foundIndex = Arrays.binarySearch( indices, index ); + return foundIndex < 0 ? 0 : data[foundIndex]; + } + + @Override + public Byte set(int index, Byte element) { + final int foundIndex = Arrays.binarySearch( indices, index ); + if ( foundIndex < 0 ) { + if ( element != null && element != 0 ) { + final int[] newIndices = new int[indices.length + 1]; + final byte[] newData = new byte[data.length + 1]; + final int insertionPoint = -foundIndex - 1; + System.arraycopy( indices, 0, newIndices, 0, insertionPoint ); + System.arraycopy( data, 0, newData, 0, insertionPoint ); + newIndices[insertionPoint] = index; + newData[insertionPoint] = element; + System.arraycopy( indices, insertionPoint, newIndices, insertionPoint + 1, indices.length - insertionPoint ); + System.arraycopy( data, insertionPoint, newData, insertionPoint + 1, data.length - insertionPoint ); + this.indices = newIndices; + this.data = newData; + } + return null; + } + else { + final byte oldValue = data[foundIndex]; + if ( element != null && element != 0 ) { + data[foundIndex] = element; + } + else { + final int[] newIndices = new int[indices.length - 1]; + final byte[] newData = new byte[data.length - 1]; + System.arraycopy( indices, 0, newIndices, 0, foundIndex ); + System.arraycopy( data, 0, newData, 0, foundIndex ); + System.arraycopy( indices, foundIndex + 1, newIndices, foundIndex, indices.length - foundIndex - 1 ); + System.arraycopy( data, foundIndex + 1, newData, foundIndex, data.length - foundIndex - 1 ); + this.indices = newIndices; + this.data = newData; + } + return oldValue; + } + } + + public byte[] toDenseVector() { + final byte[] result = new byte[this.size]; + for ( int i = 0; i < indices.length; i++ ) { + result[indices[i]] = data[i]; + } + return result; + } + + @Override + public int[] indices() { + return indices; + } + + public byte[] bytes() { + return data; + } + + @Override + public int size() { + return size; + } + + @Override + public String toString() { + return "[" + size + + "," + Arrays.toString( indices ) + + "," + Arrays.toString( data ) + + "]"; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/SparseDoubleVector.java 
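Illustrative usage of the new sparse vector API, not part of the patch; values are invented. add/remove throw UnsupportedOperationException, while set(index, value) updates the sparse representation in place.

import org.hibernate.vector.SparseByteVector;

public class SparseVectorExample {
	public static void main(String[] args) {
		// logical dimension 8 with non-zero bytes at indices 1 and 5
		SparseByteVector v = new SparseByteVector( 8, new int[] { 1, 5 }, new byte[] { 3, 7 } );
		byte[] dense = v.toDenseVector();   // [0, 3, 0, 0, 0, 7, 0, 0]
		int[] indices = v.indices();        // [1, 5]
		byte[] nonZero = v.bytes();         // [3, 7]
		v.set( 2, (byte) 4 );               // adds a new non-zero entry at index 2

		// the string form mirrors toString(), i.e. "[size,[indices],[values]]";
		// a compact, whitespace-free literal is the safe input for the parsing constructor
		SparseByteVector parsed = new SparseByteVector( "[8,[1,5],[3,7]]" );
	}
}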
b/hibernate-vector/src/main/java/org/hibernate/vector/SparseDoubleVector.java new file mode 100644 index 000000000000..6fcb9f67b8fe --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/SparseDoubleVector.java @@ -0,0 +1,185 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import java.util.Arrays; +import java.util.List; + +/** + * {@link List} implementation for a sparse byte vector. + * + * @since 7.1 + */ +public class SparseDoubleVector extends AbstractSparseVector { + + private static final double[] EMPTY_FLOAT_ARRAY = new double[0]; + + private final int size; + private int[] indices = EMPTY_INT_ARRAY; + private double[] data = EMPTY_FLOAT_ARRAY; + + public SparseDoubleVector(int size) { + this.size = size; + } + + public SparseDoubleVector(List list) { + if ( list instanceof SparseDoubleVector sparseVector ) { + size = sparseVector.size; + indices = sparseVector.indices.clone(); + data = sparseVector.data.clone(); + } + else { + int size = 0; + int[] indices = new int[list.size()]; + double[] data = new double[list.size()]; + for ( int i = 0; i < list.size(); i++ ) { + final Double b = list.get( i ); + if ( b != null && b != 0 ) { + indices[size] = i; + data[size] = b; + size++; + } + } + this.size = list.size(); + this.indices = Arrays.copyOf( indices, size ); + this.data = Arrays.copyOf( data, size ); + } + } + + public SparseDoubleVector(double[] denseVector) { + int size = 0; + int[] indices = new int[denseVector.length]; + double[] data = new double[denseVector.length]; + for ( int i = 0; i < denseVector.length; i++ ) { + final double b = denseVector[i]; + if ( b != 0 ) { + indices[size] = i; + data[size] = b; + size++; + } + } + this.size = denseVector.length; + this.indices = Arrays.copyOf( indices, size ); + this.data = Arrays.copyOf( data, size ); + } + + public SparseDoubleVector(int size, int[] indices, double[] data) { + this( validateData( data ), validateIndices( indices, data.length, size ), size ); + } + + private SparseDoubleVector(double[] data, int[] indices, int size) { + this.size = size; + this.indices = indices; + this.data = data; + } + + public SparseDoubleVector(String string) { + final ParsedVector parsedVector = + parseSparseVector( string, (s, start, end) -> Double.parseDouble( s.substring( start, end ) ) ); + this.size = parsedVector.size(); + this.indices = parsedVector.indices(); + this.data = toDoubleArray( parsedVector.elements() ); + } + + private static double[] toDoubleArray(List elements) { + final double[] result = new double[elements.size()]; + for ( int i = 0; i < elements.size(); i++ ) { + result[i] = elements.get(i); + } + return result; + } + + private static double[] validateData(double[] data) { + if ( data == null ) { + throw new IllegalArgumentException( "data cannot be null" ); + } + for ( int i = 0; i < data.length; i++ ) { + if ( data[i] == 0 ) { + throw new IllegalArgumentException( "data[" + i + "] == 0" ); + } + } + return data; + } + + @Override + public SparseDoubleVector clone() { + return new SparseDoubleVector( data.clone(), indices.clone(), size ); + } + + @Override + public Double get(int index) { + final int foundIndex = Arrays.binarySearch( indices, index ); + return foundIndex < 0 ? 
0 : data[foundIndex]; + } + + @Override + public Double set(int index, Double element) { + final int foundIndex = Arrays.binarySearch( indices, index ); + if ( foundIndex < 0 ) { + if ( element != null && element != 0 ) { + final int[] newIndices = new int[indices.length + 1]; + final double[] newData = new double[data.length + 1]; + final int insertionPoint = -foundIndex - 1; + System.arraycopy( indices, 0, newIndices, 0, insertionPoint ); + System.arraycopy( data, 0, newData, 0, insertionPoint ); + newIndices[insertionPoint] = index; + newData[insertionPoint] = element; + System.arraycopy( indices, insertionPoint, newIndices, insertionPoint + 1, indices.length - insertionPoint ); + System.arraycopy( data, insertionPoint, newData, insertionPoint + 1, data.length - insertionPoint ); + this.indices = newIndices; + this.data = newData; + } + return null; + } + else { + final double oldValue = data[foundIndex]; + if ( element != null && element != 0 ) { + data[foundIndex] = element; + } + else { + final int[] newIndices = new int[indices.length - 1]; + final double[] newData = new double[data.length - 1]; + System.arraycopy( indices, 0, newIndices, 0, foundIndex ); + System.arraycopy( data, 0, newData, 0, foundIndex ); + System.arraycopy( indices, foundIndex + 1, newIndices, foundIndex, indices.length - foundIndex - 1 ); + System.arraycopy( data, foundIndex + 1, newData, foundIndex, data.length - foundIndex - 1 ); + this.indices = newIndices; + this.data = newData; + } + return oldValue; + } + } + + public double[] toDenseVector() { + final double[] result = new double[this.size]; + for ( int i = 0; i < indices.length; i++ ) { + result[indices[i]] = data[i]; + } + return result; + } + + @Override + public int[] indices() { + return indices; + } + + public double[] doubles() { + return data; + } + + @Override + public int size() { + return size; + } + + @Override + public String toString() { + return "[" + size + + "," + Arrays.toString( indices ) + + "," + Arrays.toString( data ) + + "]"; + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/SparseFloatVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/SparseFloatVector.java new file mode 100644 index 000000000000..92010ba84f43 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/SparseFloatVector.java @@ -0,0 +1,185 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import java.util.Arrays; +import java.util.List; + +/** + * {@link List} implementation for a sparse byte vector. 
+ * + * @since 7.1 + */ +public class SparseFloatVector extends AbstractSparseVector { + + private static final float[] EMPTY_FLOAT_ARRAY = new float[0]; + + private final int size; + private int[] indices = EMPTY_INT_ARRAY; + private float[] data = EMPTY_FLOAT_ARRAY; + + public SparseFloatVector(int size) { + this.size = size; + } + + public SparseFloatVector(List list) { + if ( list instanceof SparseFloatVector sparseVector ) { + size = sparseVector.size; + indices = sparseVector.indices.clone(); + data = sparseVector.data.clone(); + } + else { + int size = 0; + int[] indices = new int[list.size()]; + float[] data = new float[list.size()]; + for ( int i = 0; i < list.size(); i++ ) { + final Float b = list.get( i ); + if ( b != null && b != 0 ) { + indices[size] = i; + data[size] = b; + size++; + } + } + this.size = list.size(); + this.indices = Arrays.copyOf( indices, size ); + this.data = Arrays.copyOf( data, size ); + } + } + + public SparseFloatVector(float[] denseVector) { + int size = 0; + int[] indices = new int[denseVector.length]; + float[] data = new float[denseVector.length]; + for ( int i = 0; i < denseVector.length; i++ ) { + final float b = denseVector[i]; + if ( b != 0 ) { + indices[size] = i; + data[size] = b; + size++; + } + } + this.size = denseVector.length; + this.indices = Arrays.copyOf( indices, size ); + this.data = Arrays.copyOf( data, size ); + } + + public SparseFloatVector(int size, int[] indices, float[] data) { + this( validateData( data ), validateIndices( indices, data.length, size ), size ); + } + + private SparseFloatVector(float[] data, int[] indices, int size) { + this.size = size; + this.indices = indices; + this.data = data; + } + + public SparseFloatVector(String string) { + final ParsedVector parsedVector = + parseSparseVector( string, (s, start, end) -> Float.parseFloat( s.substring( start, end ) ) ); + this.size = parsedVector.size(); + this.indices = parsedVector.indices(); + this.data = toFloatArray( parsedVector.elements() ); + } + + private static float[] toFloatArray(List elements) { + final float[] result = new float[elements.size()]; + for ( int i = 0; i < elements.size(); i++ ) { + result[i] = elements.get(i); + } + return result; + } + + private static float[] validateData(float[] data) { + if ( data == null ) { + throw new IllegalArgumentException( "data cannot be null" ); + } + for ( int i = 0; i < data.length; i++ ) { + if ( data[i] == 0 ) { + throw new IllegalArgumentException( "data[" + i + "] == 0" ); + } + } + return data; + } + + @Override + public SparseFloatVector clone() { + return new SparseFloatVector( data.clone(), indices.clone(), size ); + } + + @Override + public Float get(int index) { + final int foundIndex = Arrays.binarySearch( indices, index ); + return foundIndex < 0 ? 
0 : data[foundIndex]; + } + + @Override + public Float set(int index, Float element) { + final int foundIndex = Arrays.binarySearch( indices, index ); + if ( foundIndex < 0 ) { + if ( element != null && element != 0 ) { + final int[] newIndices = new int[indices.length + 1]; + final float[] newData = new float[data.length + 1]; + final int insertionPoint = -foundIndex - 1; + System.arraycopy( indices, 0, newIndices, 0, insertionPoint ); + System.arraycopy( data, 0, newData, 0, insertionPoint ); + newIndices[insertionPoint] = index; + newData[insertionPoint] = element; + System.arraycopy( indices, insertionPoint, newIndices, insertionPoint + 1, indices.length - insertionPoint ); + System.arraycopy( data, insertionPoint, newData, insertionPoint + 1, data.length - insertionPoint ); + this.indices = newIndices; + this.data = newData; + } + return null; + } + else { + final float oldValue = data[foundIndex]; + if ( element != null && element != 0 ) { + data[foundIndex] = element; + } + else { + final int[] newIndices = new int[indices.length - 1]; + final float[] newData = new float[data.length - 1]; + System.arraycopy( indices, 0, newIndices, 0, foundIndex ); + System.arraycopy( data, 0, newData, 0, foundIndex ); + System.arraycopy( indices, foundIndex + 1, newIndices, foundIndex, indices.length - foundIndex - 1 ); + System.arraycopy( data, foundIndex + 1, newData, foundIndex, data.length - foundIndex - 1 ); + this.indices = newIndices; + this.data = newData; + } + return oldValue; + } + } + + public float[] toDenseVector() { + final float[] result = new float[this.size]; + for ( int i = 0; i < indices.length; i++ ) { + result[indices[i]] = data[i]; + } + return result; + } + + @Override + public int[] indices() { + return indices; + } + + public float[] floats() { + return data; + } + + @Override + public int size() { + return size; + } + + @Override + public String toString() { + return "[" + size + + "," + Arrays.toString( indices ) + + "," + Arrays.toString( data ) + + "]"; + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/VectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/VectorJdbcType.java deleted file mode 100644 index b50144fdbbbd..000000000000 --- a/hibernate-vector/src/main/java/org/hibernate/vector/VectorJdbcType.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright Red Hat Inc. 
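A hedged mapping sketch, not part of the patch: the dense float[] mapping follows the documented hibernate-vector pattern, while mapping the new SparseFloatVector through SqlTypes.SPARSE_VECTOR_FLOAT32 is an assumption based on the type codes referenced in the new dialect feature checks; the dimension and names are invented.

import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import org.hibernate.annotations.Array;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.type.SqlTypes;
import org.hibernate.vector.SparseFloatVector;

@Entity
public class DocumentEmbedding {
	@Id
	Long id;

	// dense vector column, e.g. vector(384) where the dialect defines a VECTOR DDL type
	@Column(name = "embedding")
	@JdbcTypeCode(SqlTypes.VECTOR)
	@Array(length = 384)
	float[] embedding;

	// sparse variant (assumed mapping; availability depends on dialect support)
	@Column(name = "sparse_embedding")
	@JdbcTypeCode(SqlTypes.SPARSE_VECTOR_FLOAT32)
	@Array(length = 384)
	SparseFloatVector sparseEmbedding;
}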
and Hibernate Authors - */ -package org.hibernate.vector; - -import java.sql.CallableStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.BitSet; - -import org.hibernate.dialect.Dialect; -import org.hibernate.sql.ast.spi.SqlAppender; -import org.hibernate.type.SqlTypes; -import org.hibernate.type.descriptor.ValueExtractor; -import org.hibernate.type.descriptor.WrapperOptions; -import org.hibernate.type.descriptor.java.JavaType; -import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; -import org.hibernate.type.descriptor.jdbc.BasicExtractor; -import org.hibernate.type.descriptor.jdbc.JdbcType; -import org.hibernate.type.spi.TypeConfiguration; - -public class VectorJdbcType extends ArrayJdbcType { - - private static final float[] EMPTY = new float[0]; - public VectorJdbcType(JdbcType elementJdbcType) { - super( elementJdbcType ); - } - - @Override - public int getDefaultSqlTypeCode() { - return SqlTypes.VECTOR; - } - - @Override - public JavaType getJdbcRecommendedJavaTypeMapping( - Integer precision, - Integer scale, - TypeConfiguration typeConfiguration) { - return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class ); - } - - @Override - public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { - appender.append( "cast(" ); - appender.append( writeExpression ); - appender.append( " as vector)" ); - } - - @Override - public ValueExtractor getExtractor(JavaType javaTypeDescriptor) { - return new BasicExtractor<>( javaTypeDescriptor, this ) { - @Override - protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { - return javaTypeDescriptor.wrap( getFloatArray( rs.getString( paramIndex ) ), options ); - } - - @Override - protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { - return javaTypeDescriptor.wrap( getFloatArray( statement.getString( index ) ), options ); - } - - @Override - protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException { - return javaTypeDescriptor.wrap( getFloatArray( statement.getString( name ) ), options ); - } - - private float[] getFloatArray(String string) { - if ( string == null ) { - return null; - } - if ( string.length() == 2 ) { - return EMPTY; - } - final BitSet commaPositions = new BitSet(); - int size = 1; - for ( int i = 1; i < string.length(); i++ ) { - final char c = string.charAt( i ); - if ( c == ',' ) { - commaPositions.set( i ); - size++; - } - } - final float[] result = new float[size]; - int floatStartIndex = 1; - int commaIndex; - int index = 0; - while ( ( commaIndex = commaPositions.nextSetBit( floatStartIndex ) ) != -1 ) { - result[index++] = Float.parseFloat( string.substring( floatStartIndex, commaIndex ) ); - floatStartIndex = commaIndex + 1; - } - result[index] = Float.parseFloat( string.substring( floatStartIndex, string.length() - 1 ) ); - return result; - } - }; - } -} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/AbstractDB2VectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/AbstractDB2VectorJdbcType.java new file mode 100644 index 000000000000..bb11fd4db259 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/AbstractDB2VectorJdbcType.java @@ -0,0 +1,149 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.checkerframework.checker.nullness.qual.Nullable; +import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.ValueBinder; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.BasicPluralJavaType; +import org.hibernate.type.descriptor.java.ByteJavaType; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.java.PrimitiveByteArrayJavaType; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.BasicBinder; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.JdbcType; + +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +/** + * Specialized type mapping for generic vector {@link SqlTypes#VECTOR} SQL data type for DB2. + */ +public abstract class AbstractDB2VectorJdbcType extends ArrayJdbcType { + + public AbstractDB2VectorJdbcType(JdbcType elementJdbcType) { + super( elementJdbcType ); + } + + @Override + public @Nullable String castToPattern(JdbcMapping targetJdbcMapping, @Nullable Size size) { + return targetJdbcMapping.getJdbcType().isStringLike() ? "vector_serialize(?1 returning ?2)" : null; + } + + @Override + public @Nullable String castFromPattern(JdbcMapping sourceMapping, @Nullable Size size) { + return sourceMapping.getJdbcType().isStringLike() ? 
"vector(?1," + getVectorParameters( size ) + ")" : null; + } + + @Override + public void appendWriteExpression( + String writeExpression, + @Nullable Size size, + SqlAppender appender, + Dialect dialect) { + appender.append( "vector(" ); + appender.append( writeExpression ); + appender.append( ',' ); + appender.append( getVectorParameters( size ) ); + appender.append( ')' ); + } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } + + public abstract String getVectorParameters(@Nullable Size size); + + @Override + public JdbcLiteralFormatter getJdbcLiteralFormatter(JavaType javaTypeDescriptor) { + final JavaType elementJavaType; + if ( javaTypeDescriptor instanceof PrimitiveByteArrayJavaType ) { + // Special handling needed for Byte[], because that would conflict with the VARBINARY mapping + //noinspection unchecked + elementJavaType = (JavaType) ByteJavaType.INSTANCE; + } + else if ( javaTypeDescriptor instanceof BasicPluralJavaType ) { + //noinspection unchecked + elementJavaType = ((BasicPluralJavaType) javaTypeDescriptor).getElementJavaType(); + } + else { + throw new IllegalArgumentException( "not a BasicPluralJavaType" ); + } + return new DB2JdbcLiteralFormatterVector<>( + javaTypeDescriptor, + getElementJdbcType().getJdbcLiteralFormatter( elementJavaType ), + this + ); + } + + @Override + public String toString() { + return "DB2VectorTypeDescriptor"; + } + + @Override + public ValueBinder getBinder(final JavaType javaTypeDescriptor) { + return new BasicBinder<>( javaTypeDescriptor, this ) { + @Override + protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) + throws SQLException { + st.setString( index, + ((AbstractDB2VectorJdbcType) getJdbcType()).getStringVector( value, getJavaType(), options ) ); + } + + @Override + protected void doBind(CallableStatement st, X value, String name, WrapperOptions options) + throws SQLException { + st.setString( name, + ((AbstractDB2VectorJdbcType) getJdbcType()).getStringVector( value, getJavaType(), options ) ); + } + + }; + } + + @Override + public ValueExtractor getExtractor(final JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + return getJavaType().wrap( + ((AbstractDB2VectorJdbcType) getJdbcType()).getVectorArray( rs.getString( paramIndex ) ), + options ); + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + return getJavaType().wrap( + ((AbstractDB2VectorJdbcType) getJdbcType()).getVectorArray( statement.getString( index ) ), + options ); + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) + throws SQLException { + return getJavaType().wrap( + ((AbstractDB2VectorJdbcType) getJdbcType()).getVectorArray( statement.getString( name ) ), + options ); + } + + }; + } + + protected abstract Object getVectorArray(String string); + + protected abstract String getStringVector(T vector, JavaType javaTypeDescriptor, WrapperOptions options); + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/AbstractOracleSparseVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/AbstractOracleSparseVectorJdbcType.java new file mode 100644 index 000000000000..cab0cf11516d --- /dev/null +++ 
b/hibernate-vector/src/main/java/org/hibernate/vector/internal/AbstractOracleSparseVectorJdbcType.java @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.ValueBinder; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.BasicBinder; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.vector.AbstractSparseVector; + +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.SQLException; + +public abstract class AbstractOracleSparseVectorJdbcType extends AbstractOracleVectorJdbcType { + + public AbstractOracleSparseVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { + super( elementJdbcType, isVectorSupported ); + } + + @Override + public JdbcLiteralFormatter getJdbcLiteralFormatter(JavaType javaTypeDescriptor) { + return new OracleJdbcLiteralFormatterSparseVector<>( javaTypeDescriptor, getVectorParameters() ); + } + + @Override + public ValueBinder getBinder(final JavaType javaTypeDescriptor) { + return new BasicBinder<>( javaTypeDescriptor, this ) { + @Override + protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) + throws SQLException { + if ( isVectorSupported ) { + st.setObject( index, getBindValue( value, options ) ); + } + else { + st.setString( index, stringVector( value, options ) ); + } + } + + @Override + protected void doBind(CallableStatement st, X value, String name, WrapperOptions options) + throws SQLException { + if ( isVectorSupported ) { + st.setObject( name, getBindValue( value, options ) ); + } + else { + st.setString( name, stringVector( value, options ) ); + } + } + + private String stringVector(X value, WrapperOptions options) { + return ((AbstractOracleSparseVectorJdbcType) getJdbcType()).getStringVector( value, getJavaType(), options ); + } + + @Override + public Object getBindValue(X value, WrapperOptions options) { + return ((AbstractOracleSparseVectorJdbcType) getJdbcType()).getBindValue( getJavaType(), value, options ); + } + }; + } + + protected abstract Object getBindValue(JavaType javaType, X value, WrapperOptions options); + + @Override + protected String getStringVector(T vector, JavaType javaTypeDescriptor, WrapperOptions options) { + return javaTypeDescriptor.unwrap( vector, AbstractSparseVector.class, options ).toString(); + } + + @Override + protected Class getNativeJavaType() { + return Object.class; + } + + @Override + protected int getNativeTypeCode() { + return SqlTypes.OTHER; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/AbstractOracleVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/AbstractOracleVectorJdbcType.java similarity index 77% rename from hibernate-vector/src/main/java/org/hibernate/vector/AbstractOracleVectorJdbcType.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/AbstractOracleVectorJdbcType.java index 48407feab48f..f39165273ecc 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/AbstractOracleVectorJdbcType.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/AbstractOracleVectorJdbcType.java @@ -2,29 +2,28 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. 
and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import java.sql.CallableStatement; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.metamodel.mapping.JdbcMapping; import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.ValueBinder; import org.hibernate.type.descriptor.ValueExtractor; import org.hibernate.type.descriptor.WrapperOptions; -import org.hibernate.type.descriptor.java.BasicPluralJavaType; -import org.hibernate.type.descriptor.java.ByteJavaType; import org.hibernate.type.descriptor.java.JavaType; -import org.hibernate.type.descriptor.java.PrimitiveByteArrayJavaType; import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; import org.hibernate.type.descriptor.jdbc.BasicBinder; import org.hibernate.type.descriptor.jdbc.BasicExtractor; import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; import org.hibernate.type.descriptor.jdbc.JdbcType; -import org.hibernate.type.descriptor.jdbc.internal.JdbcLiteralFormatterArray; /** * Specialized type mapping for generic vector {@link SqlTypes#VECTOR} SQL data type for Oracle. @@ -43,31 +42,47 @@ public AbstractOracleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSu this.isVectorSupported = isVectorSupported; } - public abstract void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect); + @Override + public @Nullable String castToPattern(JdbcMapping targetJdbcMapping, @Nullable Size size) { + return targetJdbcMapping.getJdbcType().isStringLike() ? "from_vector(?1 returning ?2)" : null; + } @Override - public int getDefaultSqlTypeCode() { - return SqlTypes.VECTOR; + public @Nullable String castFromPattern(JdbcMapping sourceMapping, @Nullable Size size) { + return sourceMapping.getJdbcType().isStringLike() ? 
"to_vector(?1," + getVectorParameters() + ")" : null; } @Override - public JdbcLiteralFormatter getJdbcLiteralFormatter(JavaType javaTypeDescriptor) { - final JavaType elementJavaType; - if ( javaTypeDescriptor instanceof PrimitiveByteArrayJavaType ) { - // Special handling needed for Byte[], because that would conflict with the VARBINARY mapping - //noinspection unchecked - elementJavaType = (JavaType) ByteJavaType.INSTANCE; - } - else if ( javaTypeDescriptor instanceof BasicPluralJavaType ) { - //noinspection unchecked - elementJavaType = ( (BasicPluralJavaType) javaTypeDescriptor ).getElementJavaType(); + public void appendWriteExpression( + String writeExpression, + @Nullable Size size, + SqlAppender appender, + Dialect dialect) { + if ( isVectorSupported ) { + appender.append( writeExpression ); } else { - throw new IllegalArgumentException( "not a BasicPluralJavaType" ); + appender.append( "to_vector(" ); + appender.append( writeExpression ); + appender.append( ',' ); + appender.append( getVectorParameters() ); + appender.append( ')' ); } - return new JdbcLiteralFormatterArray<>( + } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return !isVectorSupported; + } + + public abstract String getVectorParameters(); + + @Override + public JdbcLiteralFormatter getJdbcLiteralFormatter(JavaType javaTypeDescriptor) { + return new OracleJdbcLiteralFormatterVector<>( javaTypeDescriptor, - getElementJdbcType().getJdbcLiteralFormatter( elementJavaType ) + getElementJdbcType().getJdbcLiteralFormatter( elementJavaType( javaTypeDescriptor ) ), + getVectorParameters().replace( ",sparse", "" ) ); } @@ -76,7 +91,6 @@ public String toString() { return "OracleVectorTypeDescriptor"; } - @Override public ValueBinder getBinder(final JavaType javaTypeDescriptor) { return new BasicBinder<>( javaTypeDescriptor, this ) { @@ -142,7 +156,7 @@ protected X doExtract(CallableStatement statement, String name, WrapperOptions o }; } - protected abstract T getVectorArray(String string); + protected abstract Object getVectorArray(String string); protected abstract String getStringVector(T vector, JavaType javaTypeDescriptor, WrapperOptions options); diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/PGVectorFunctionContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/CockroachFunctionContributor.java similarity index 72% rename from hibernate-vector/src/main/java/org/hibernate/vector/PGVectorFunctionContributor.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/CockroachFunctionContributor.java index 7b52fcafa592..ca4f7acca60d 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/PGVectorFunctionContributor.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/CockroachFunctionContributor.java @@ -2,24 +2,24 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. 
and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import org.hibernate.boot.model.FunctionContributions; import org.hibernate.boot.model.FunctionContributor; import org.hibernate.dialect.CockroachDialect; import org.hibernate.dialect.Dialect; -import org.hibernate.dialect.PostgreSQLDialect; -public class PGVectorFunctionContributor implements FunctionContributor { +public class CockroachFunctionContributor implements FunctionContributor { @Override public void contributeFunctions(FunctionContributions functionContributions) { final Dialect dialect = functionContributions.getDialect(); - if (dialect instanceof PostgreSQLDialect || dialect instanceof CockroachDialect) { + if ( dialect instanceof CockroachDialect && dialect.getVersion().isSameOrAfter( 24, 2 ) ) { final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions ); vectorFunctionFactory.cosineDistance( "?1<=>?2" ); vectorFunctionFactory.euclideanDistance( "?1<->?2" ); + vectorFunctionFactory.euclideanSquaredDistance( "(?1<->?2)^2" ); vectorFunctionFactory.l1Distance( "l1_distance(?1,?2)" ); vectorFunctionFactory.innerProduct( "(?1<#>?2)*-1" ); @@ -27,6 +27,8 @@ public void contributeFunctions(FunctionContributions functionContributions) { vectorFunctionFactory.vectorDimensions(); vectorFunctionFactory.vectorNorm(); + + functionContributions.getFunctionRegistry().registerAlternateKey( "l2_norm", "vector_norm" ); } } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/PGVectorTypeContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/CockroachTypeContributor.java similarity index 52% rename from hibernate-vector/src/main/java/org/hibernate/vector/PGVectorTypeContributor.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/CockroachTypeContributor.java index 1cfcf5257487..4680b984c17d 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/PGVectorTypeContributor.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/CockroachTypeContributor.java @@ -2,16 +2,12 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. 
and Hibernate Authors */ -package org.hibernate.vector; - -import java.lang.reflect.Type; +package org.hibernate.vector.internal; import org.hibernate.boot.model.TypeContributions; import org.hibernate.boot.model.TypeContributor; import org.hibernate.dialect.CockroachDialect; import org.hibernate.dialect.Dialect; -import org.hibernate.dialect.PostgreSQLDialect; -import org.hibernate.engine.jdbc.Size; import org.hibernate.engine.jdbc.spi.JdbcServices; import org.hibernate.service.ServiceRegistry; import org.hibernate.type.BasicArrayType; @@ -22,49 +18,52 @@ import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry; import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; -import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl; import org.hibernate.type.spi.TypeConfiguration; -public class PGVectorTypeContributor implements TypeContributor { - - private static final Type[] VECTOR_JAVA_TYPES = { - Float[].class, - float[].class - }; +public class CockroachTypeContributor implements TypeContributor { @Override public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) { final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect(); - if ( dialect instanceof PostgreSQLDialect || - dialect instanceof CockroachDialect ) { + if ( dialect instanceof CockroachDialect && dialect.getVersion().isSameOrAfter( 24, 2 ) ) { final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration(); final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry(); final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry(); final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); final BasicType floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ); - final ArrayJdbcType vectorJdbcType = new VectorJdbcType( jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ) ); - jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, vectorJdbcType ); - for ( Type vectorJavaType : VECTOR_JAVA_TYPES ) { - basicTypeRegistry.register( - new BasicArrayType<>( - floatBasicType, - vectorJdbcType, - javaTypeRegistry.getDescriptor( vectorJavaType ) - ), - StandardBasicTypes.VECTOR.getName() - ); - } + final ArrayJdbcType genericVectorJdbcType = new PGVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR, + "vector" + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, genericVectorJdbcType ); + final ArrayJdbcType floatVectorJdbcType = new PGVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR_FLOAT32, + "vector" + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT32, floatVectorJdbcType ); + basicTypeRegistry.register( + new BasicArrayType<>( + floatBasicType, + genericVectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR.getName() + ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + floatVectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR_FLOAT32.getName() + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) + ); typeConfiguration.getDdlTypeRegistry().addDescriptor( - new DdlTypeImpl( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) { - @Override - public String getTypeName(Size size) { - return 
getTypeName( - size.getArrayLength() == null ? null : size.getArrayLength().longValue(), - null, - null - ); - } - } + new VectorDdlType( SqlTypes.VECTOR_FLOAT32, "vector($l)", "vector", dialect ) ); } } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2ByteVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2ByteVectorJdbcType.java new file mode 100644 index 000000000000..3f6920a587ce --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2ByteVectorJdbcType.java @@ -0,0 +1,51 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.checkerframework.checker.nullness.qual.Nullable; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.JdbcType; + +import java.util.Arrays; + +/** + * Specialized type mapping for single-byte integer vector {@link SqlTypes#VECTOR_INT8} SQL data type for DB2. + */ +public class DB2ByteVectorJdbcType extends AbstractDB2VectorJdbcType { + + public DB2ByteVectorJdbcType(JdbcType elementJdbcType) { + super( elementJdbcType ); + } + + @Override + public String getVectorParameters(@Nullable Size size) { + assert size != null; + return size.getArrayLength() + ",int8"; + } + + @Override + public String getFriendlyName() { + return "VECTOR_INT8"; + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.VECTOR_INT8; + } + + @Override + protected byte[] getVectorArray(String string) { + return VectorHelper.parseByteVector( string ); + } + + @Override + protected String getStringVector(T vector, JavaType javaTypeDescriptor, WrapperOptions options) { + return Arrays.toString( javaTypeDescriptor.unwrap( vector, byte[].class, options ) ); + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2FloatVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2FloatVectorJdbcType.java new file mode 100644 index 000000000000..6f2914c46efa --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2FloatVectorJdbcType.java @@ -0,0 +1,51 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.checkerframework.checker.nullness.qual.Nullable; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.JdbcType; + +import java.util.Arrays; + +/** + * Specialized type mapping for single-precision floating-point vector {@link SqlTypes#VECTOR_FLOAT32} SQL data type for DB2. 
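For orientation, a minimal mapping sketch (not part of this patch): the entity and attribute names below are invented, but @JdbcTypeCode, @Array and the SqlTypes constants are the existing Hibernate annotations these DB2 contributors are meant to be used with, so a byte[] or float[] attribute picks up the vector JDBC and DDL types registered above.

import org.hibernate.annotations.Array;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.type.SqlTypes;

import jakarta.persistence.Entity;
import jakarta.persistence.Id;

// Hypothetical entity, for illustration only.
@Entity
class EmbeddingEntity {
    @Id
    Long id;

    // Expected to map to DDL "vector(3,int8)" on DB2 12.1.2+ via DB2ByteVectorJdbcType
    @JdbcTypeCode(SqlTypes.VECTOR_INT8)
    @Array(length = 3)
    byte[] int8Embedding;

    // Expected to map to DDL "vector(3,float32)" via DB2FloatVectorJdbcType
    @JdbcTypeCode(SqlTypes.VECTOR_FLOAT32)
    @Array(length = 3)
    float[] float32Embedding;
}

With the VectorDdlType registrations above, schema export should render these columns as vector(3,int8) and vector(3,float32) respectively.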
+ */ + +public class DB2FloatVectorJdbcType extends AbstractDB2VectorJdbcType { + + public DB2FloatVectorJdbcType(JdbcType elementJdbcType) { + super( elementJdbcType ); + } + + @Override + public String getVectorParameters(@Nullable Size size) { + assert size != null; + return size.getArrayLength() + ",float32"; + } + + @Override + public String getFriendlyName() { + return "VECTOR_FLOAT32"; + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.VECTOR_FLOAT32; + } + + @Override + protected float[] getVectorArray(String string) { + return VectorHelper.parseFloatVector( string ); + } + + @Override + protected String getStringVector(T vector, JavaType javaTypeDescriptor, WrapperOptions options) { + return Arrays.toString( javaTypeDescriptor.unwrap( vector, float[].class, options ) ); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2JdbcLiteralFormatterVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2JdbcLiteralFormatterVector.java new file mode 100644 index 000000000000..f5c91c997ea3 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2JdbcLiteralFormatterVector.java @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.spi.BasicJdbcLiteralFormatter; + +public class DB2JdbcLiteralFormatterVector extends BasicJdbcLiteralFormatter { + + private final JdbcLiteralFormatter elementFormatter; + private final AbstractDB2VectorJdbcType db2VectorJdbcType; + + public DB2JdbcLiteralFormatterVector(JavaType javaType, JdbcLiteralFormatter elementFormatter, AbstractDB2VectorJdbcType db2VectorJdbcType) { + super( javaType ); + //noinspection unchecked + this.elementFormatter = (JdbcLiteralFormatter) elementFormatter; + this.db2VectorJdbcType = db2VectorJdbcType; + } + + @Override + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) { + final Object[] objects = unwrapArray( value, wrapperOptions ); + appender.append( "vector('" ); + char separator = '['; + for ( Object o : objects ) { + appender.append( separator ); + elementFormatter.appendJdbcLiteral( appender, o, dialect, wrapperOptions ); + separator = ','; + } + appender.append( "]'," ); + appender.append( db2VectorJdbcType.getVectorParameters( new Size().setArrayLength( objects.length ) ) ); + appender.append( ')' ); + } + + private Object[] unwrapArray(Object value, WrapperOptions wrapperOptions) { + return unwrap( value, Object[].class, wrapperOptions ); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2VectorFunctionContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2VectorFunctionContributor.java new file mode 100644 index 000000000000..8ff485fb94fc --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2VectorFunctionContributor.java @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.boot.model.FunctionContributions; +import org.hibernate.boot.model.FunctionContributor; +import org.hibernate.dialect.DB2Dialect; +import org.hibernate.dialect.Dialect; +import org.hibernate.type.BasicType; +import org.hibernate.type.spi.TypeConfiguration; + + +public class DB2VectorFunctionContributor implements FunctionContributor { + + @Override + public void contributeFunctions(FunctionContributions functionContributions) { + final Dialect dialect = functionContributions.getDialect(); + if ( dialect instanceof DB2Dialect db2Dialect && db2Dialect.getDB2Version().isSameOrAfter( 12, 1, 2 ) ) { + final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions ); + + vectorFunctionFactory.cosineDistance( "vector_distance(?1,?2,COSINE)" ); + vectorFunctionFactory.euclideanDistance( "vector_distance(?1,?2,EUCLIDEAN)" ); + vectorFunctionFactory.euclideanSquaredDistance( "vector_distance(?1,?2,EUCLIDEAN_SQUARED)" ); + vectorFunctionFactory.l1Distance( "vector_distance(?1,?2,MANHATTAN)" ); + vectorFunctionFactory.hammingDistance( "vector_distance(?1,?2,HAMMING)" ); + + vectorFunctionFactory.innerProduct( "vector_distance(?1,?2,DOT)*-1" ); + vectorFunctionFactory.negativeInnerProduct( "vector_distance(?1,?2,DOT)" ); + + final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration(); + final BasicType integerType = typeConfiguration.getBasicTypeForJavaType( Integer.class ); + final BasicType doubleType = typeConfiguration.getBasicTypeForJavaType( Double.class ); + vectorFunctionFactory.registerNamedVectorFunction("vector_dimension_count", integerType, 1 ); + functionContributions.getFunctionRegistry().registerAlternateKey( "vector_dims", "vector_dimension_count" ); + vectorFunctionFactory.registerPatternVectorFunction( "vector_norm", "vector_norm(?1,EUCLIDEAN)", doubleType, 1 ); + } + } + + @Override + public int ordinal() { + return 200; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2VectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2VectorJdbcType.java new file mode 100644 index 000000000000..7fb256c23703 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2VectorJdbcType.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.jdbc.JdbcType; + +/** + * Specialized type mapping for generic vector {@link SqlTypes#VECTOR} SQL data type for DB2. + */ +public class DB2VectorJdbcType extends DB2FloatVectorJdbcType { + + public DB2VectorJdbcType(JdbcType elementJdbcType) { + super( elementJdbcType ); + } + + @Override + public String getFriendlyName() { + return "VECTOR"; + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.VECTOR; + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2VectorTypeContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2VectorTypeContributor.java new file mode 100644 index 000000000000..83a2a62d1b3c --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/DB2VectorTypeContributor.java @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.boot.model.TypeContributions; +import org.hibernate.boot.model.TypeContributor; +import org.hibernate.dialect.DB2Dialect; +import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.spi.JdbcServices; +import org.hibernate.service.ServiceRegistry; +import org.hibernate.type.BasicArrayType; +import org.hibernate.type.BasicType; +import org.hibernate.type.BasicTypeRegistry; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.StandardBasicTypes; +import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; +import org.hibernate.type.spi.TypeConfiguration; + +public class DB2VectorTypeContributor implements TypeContributor { + + @Override + public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) { + final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect(); + if ( dialect instanceof DB2Dialect db2Dialect && db2Dialect.getDB2Version().isSameOrAfter( 12, 1, 2 ) ) { + final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration(); + final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry(); + final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry(); + final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); + final BasicType floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ); + final ArrayJdbcType genericVectorJdbcType = new DB2VectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ) + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, genericVectorJdbcType ); + final ArrayJdbcType floatVectorJdbcType = new DB2FloatVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ) + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT32, floatVectorJdbcType ); + final JdbcType byteVectorJdbcType = new DB2ByteVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.TINYINT ) + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_INT8, byteVectorJdbcType ); + basicTypeRegistry.register( + new BasicArrayType<>( + floatBasicType, + genericVectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR.getName() + ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + floatVectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR_FLOAT32.getName() + ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.BYTE ), + byteVectorJdbcType, + javaTypeRegistry.getDescriptor( byte[].class ) + ), + StandardBasicTypes.VECTOR_INT8.getName() + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR, "vector($l,float32)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_INT8, "vector($l,int8)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_FLOAT32, "vector($l,float32)", "vector", dialect ) + ); + } + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/HANAJdbcLiteralFormatterVector.java 
b/hibernate-vector/src/main/java/org/hibernate/vector/internal/HANAJdbcLiteralFormatterVector.java new file mode 100644 index 000000000000..d849866e8441 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/HANAJdbcLiteralFormatterVector.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.spi.BasicJdbcLiteralFormatter; + +public class HANAJdbcLiteralFormatterVector extends BasicJdbcLiteralFormatter { + + private final JdbcLiteralFormatter elementFormatter; + private final String typeName; + + public HANAJdbcLiteralFormatterVector(JavaType javaType, JdbcLiteralFormatter elementFormatter, String typeName) { + super( javaType ); + //noinspection unchecked + this.elementFormatter = (JdbcLiteralFormatter) elementFormatter; + this.typeName = typeName; + } + + @Override + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) { + appender.appendSql( "to_" ); + appender.appendSql( typeName ); + appender.appendSql( "('" ); + final Object[] objects = unwrap( value, Object[].class, wrapperOptions ); + char separator = '['; + for ( Object o : objects ) { + appender.appendSql( separator ); + elementFormatter.appendJdbcLiteral( appender, o, dialect, wrapperOptions ); + separator = ','; + } + appender.appendSql( "]')" ); + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/HANAVectorFunctionContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/HANAVectorFunctionContributor.java new file mode 100644 index 000000000000..9f0a7b097933 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/HANAVectorFunctionContributor.java @@ -0,0 +1,70 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc.
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.boot.model.FunctionContributions; +import org.hibernate.boot.model.FunctionContributor; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.HANADialect; +import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; +import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers; +import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; +import org.hibernate.type.spi.TypeConfiguration; + +import static org.hibernate.query.sqm.produce.function.FunctionParameterType.INTEGER; + +public class HANAVectorFunctionContributor implements FunctionContributor { + + @Override + public void contributeFunctions(FunctionContributions functionContributions) { + final Dialect dialect = functionContributions.getDialect(); + if ( dialect instanceof HANADialect hanaDialect && hanaDialect.isCloud() ) { + final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions ); + + vectorFunctionFactory.cosineDistance( "cosine_similarity(?1,?2)" ); + vectorFunctionFactory.euclideanDistance( "l2distance(?1,?2)" ); + vectorFunctionFactory.euclideanSquaredDistance( "power(l2distance(?1,?2),2)" ); + + final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration(); + vectorFunctionFactory.registerPatternVectorFunction( + "vector_dims", + "cardinality(?1)", + typeConfiguration.getBasicTypeForJavaType( Integer.class ), + 1 + ); + vectorFunctionFactory.registerNamedVectorFunction( + "l2norm", + typeConfiguration.getBasicTypeForJavaType( Double.class ), + 1 + ); + functionContributions.getFunctionRegistry().registerAlternateKey( "vector_norm", "l2norm" ); + functionContributions.getFunctionRegistry().registerAlternateKey( "l2_norm", "l2norm" ); + + functionContributions.getFunctionRegistry().namedDescriptorBuilder( "subvector" ) + .setArgumentsValidator( StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 3 ), + VectorArgumentValidator.INSTANCE + ) ) + .setArgumentTypeResolver( StandardFunctionArgumentTypeResolvers.byArgument( + VectorArgumentTypeResolver.INSTANCE, + StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, INTEGER ), + StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, INTEGER ) + ) ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.useArgType( 1 ) ) + .register(); + functionContributions.getFunctionRegistry().namedDescriptorBuilder( "l2normalize" ) + .setArgumentsValidator( VectorArgumentValidator.INSTANCE ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.useArgType( 1 ) ) + .register(); + functionContributions.getFunctionRegistry().registerAlternateKey( "l2_normalize", "l2normalize" ); + } + } + + @Override + public int ordinal() { + return 200; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/HANAVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/HANAVectorJdbcType.java new file mode 100644 index 000000000000..4231872b6f25 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/HANAVectorJdbcType.java @@ -0,0 +1,148 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.checkerframework.checker.nullness.qual.Nullable; +import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.ValueBinder; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.BasicBinder; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.spi.TypeConfiguration; + +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Arrays; + +import static org.hibernate.vector.internal.VectorHelper.parseFloatVector; + +public class HANAVectorJdbcType extends ArrayJdbcType { + + private final int sqlType; + private final String typeName; + + public HANAVectorJdbcType(JdbcType elementJdbcType, int sqlType, String typeName) { + super( elementJdbcType ); + this.sqlType = sqlType; + this.typeName = typeName; + } + + @Override + public int getDefaultSqlTypeCode() { + return sqlType; + } + + @Override + public JavaType getJdbcRecommendedJavaTypeMapping( + Integer precision, + Integer scale, + TypeConfiguration typeConfiguration) { + return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class ); + } + + @Override + public JdbcLiteralFormatter getJdbcLiteralFormatter(JavaType javaTypeDescriptor) { + return new HANAJdbcLiteralFormatterVector<>( + javaTypeDescriptor, + getElementJdbcType().getJdbcLiteralFormatter( elementJavaType( javaTypeDescriptor ) ), + typeName + ); + } + + @Override + public @Nullable String castToPattern(JdbcMapping targetJdbcMapping, @Nullable Size size) { + final JdbcType jdbcType = targetJdbcMapping.getJdbcType(); + return jdbcType.isString() + ? jdbcType.isLob() ? "to_nclob(?1)" : "to_nvarchar(?1)" + : null; + } + + @Override + public void appendWriteExpression( + String writeExpression, + @Nullable Size size, + SqlAppender appender, + Dialect dialect) { + appender.append( "to_" ); + appender.append( typeName ); + appender.append( '('); + appender.append( writeExpression ); + appender.append( ')' ); + } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } + + @Override + public @Nullable String castFromPattern(JdbcMapping sourceMapping, @Nullable Size size) { + return sourceMapping.getJdbcType().isStringLike() ? 
"to_" + typeName + "(?1)" : null; + } + + @Override + public ValueExtractor getExtractor(JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseFloatVector( rs.getString( paramIndex ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseFloatVector( statement.getString( index ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseFloatVector( statement.getString( name ) ), options ); + } + }; + } + + @Override + public ValueBinder getBinder(final JavaType javaTypeDescriptor) { + return new BasicBinder<>( javaTypeDescriptor, this ) { + + @Override + protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) throws SQLException { + st.setString( index, getBindValue( value, options ) ); + } + + @Override + protected void doBind(CallableStatement st, X value, String name, WrapperOptions options) + throws SQLException { + st.setString( name, getBindValue( value, options ) ); + } + + @Override + public String getBindValue(X value, WrapperOptions options) { + return Arrays.toString( getJavaType().unwrap( value, float[].class, options ) ); + } + }; + } + + @Override + public boolean equals(Object that) { + return super.equals( that ) + && that instanceof HANAVectorJdbcType vectorJdbcType + && sqlType == vectorJdbcType.sqlType; + } + + @Override + public int hashCode() { + return sqlType + 31 * super.hashCode(); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/HANAVectorTypeContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/HANAVectorTypeContributor.java new file mode 100644 index 000000000000..32fdd9c9c9df --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/HANAVectorTypeContributor.java @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.boot.model.TypeContributions; +import org.hibernate.boot.model.TypeContributor; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.HANADialect; +import org.hibernate.engine.jdbc.spi.JdbcServices; +import org.hibernate.service.ServiceRegistry; +import org.hibernate.type.BasicArrayType; +import org.hibernate.type.BasicType; +import org.hibernate.type.BasicTypeRegistry; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.StandardBasicTypes; +import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; +import org.hibernate.type.spi.TypeConfiguration; + +public class HANAVectorTypeContributor implements TypeContributor { + + @Override + public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) { + final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect(); + if ( dialect instanceof HANADialect hanaDialect && hanaDialect.isCloud() ) { + final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration(); + final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry(); + final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry(); + final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); + final BasicType floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ); + final ArrayJdbcType genericVectorJdbcType = new HANAVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR, + "real_vector" + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, genericVectorJdbcType ); + final ArrayJdbcType floatVectorJdbcType = new HANAVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR_FLOAT32, + "real_vector" + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT32, floatVectorJdbcType ); + final ArrayJdbcType float16VectorJdbcType = new HANAVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR_FLOAT16, + "half_vector" + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT16, float16VectorJdbcType ); + + basicTypeRegistry.register( + new BasicArrayType<>( + floatBasicType, + genericVectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR.getName() + ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + floatVectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR_FLOAT32.getName() + ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + float16VectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR_FLOAT16.getName() + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR, "real_vector($l)", "real_vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_FLOAT32, "real_vector($l)", "real_vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_FLOAT16, "half_vector($l)", "half_vector", dialect ) + ); + } + } +} diff --git 
a/hibernate-vector/src/main/java/org/hibernate/vector/MariaDBFunctionContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBFunctionContributor.java similarity index 86% rename from hibernate-vector/src/main/java/org/hibernate/vector/MariaDBFunctionContributor.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBFunctionContributor.java index ac14aa3d48cd..f707dbb69d1e 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/MariaDBFunctionContributor.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBFunctionContributor.java @@ -2,7 +2,7 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import org.hibernate.boot.model.FunctionContributions; import org.hibernate.boot.model.FunctionContributor; @@ -13,7 +13,7 @@ public class MariaDBFunctionContributor implements FunctionContributor { @Override public void contributeFunctions(FunctionContributions functionContributions) { final Dialect dialect = functionContributions.getDialect(); - if ( dialect instanceof MariaDBDialect ) { + if ( dialect instanceof MariaDBDialect && dialect.getVersion().isSameOrAfter( 11, 7 ) ) { final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions ); vectorFunctionFactory.cosineDistance( "vec_distance_cosine(?1,?2)" ); diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBJdbcLiteralFormatterVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBJdbcLiteralFormatterVector.java new file mode 100644 index 000000000000..bfd8ef4a50f0 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBJdbcLiteralFormatterVector.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.spi.BasicJdbcLiteralFormatter; + +public class MariaDBJdbcLiteralFormatterVector extends BasicJdbcLiteralFormatter { + + private final JdbcLiteralFormatter elementFormatter; + + public MariaDBJdbcLiteralFormatterVector(JavaType javaType, JdbcLiteralFormatter elementFormatter) { + super( javaType ); + //noinspection unchecked + this.elementFormatter = (JdbcLiteralFormatter) elementFormatter; + } + + @Override + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) { + final Object[] objects = unwrap( value, Object[].class, wrapperOptions ); + appender.appendSql( "vec_fromtext('" ); + char separator = '['; + for ( Object o : objects ) { + appender.appendSql( separator ); + elementFormatter.appendJdbcLiteral( appender, o, dialect, wrapperOptions ); + separator = ','; + } + appender.appendSql( "]')" ); + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/MariaDBTypeContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBTypeContributor.java similarity index 54% rename from hibernate-vector/src/main/java/org/hibernate/vector/MariaDBTypeContributor.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBTypeContributor.java index 78a3540db69d..b56662a1185d 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/MariaDBTypeContributor.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBTypeContributor.java @@ -2,13 +2,12 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. 
and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import org.hibernate.boot.model.TypeContributions; import org.hibernate.boot.model.TypeContributor; import org.hibernate.dialect.Dialect; import org.hibernate.dialect.MariaDBDialect; -import org.hibernate.engine.jdbc.Size; import org.hibernate.engine.jdbc.spi.JdbcServices; import org.hibernate.service.ServiceRegistry; import org.hibernate.type.BasicArrayType; @@ -19,50 +18,50 @@ import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry; import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; -import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl; import org.hibernate.type.spi.TypeConfiguration; -import java.lang.reflect.Type; - public class MariaDBTypeContributor implements TypeContributor { - private static final Type[] VECTOR_JAVA_TYPES = { - Float[].class, - float[].class - }; - @Override public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) { final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect(); - if ( dialect instanceof MariaDBDialect ) { + if ( dialect instanceof MariaDBDialect && dialect.getVersion().isSameOrAfter( 11, 7 ) ) { final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration(); final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry(); final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry(); final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); final BasicType floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ); - final ArrayJdbcType vectorJdbcType = new BinaryVectorJdbcType( jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ) ); - jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, vectorJdbcType ); - for ( Type vectorJavaType : VECTOR_JAVA_TYPES ) { - basicTypeRegistry.register( - new BasicArrayType<>( - floatBasicType, - vectorJdbcType, - javaTypeRegistry.getDescriptor( vectorJavaType ) - ), - StandardBasicTypes.VECTOR.getName() - ); - } + final ArrayJdbcType genericVectorJdbcType = new MariaDBVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, genericVectorJdbcType ); + final ArrayJdbcType floatVectorJdbcType = new MariaDBVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR_FLOAT32 + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT32, floatVectorJdbcType ); + basicTypeRegistry.register( + new BasicArrayType<>( + floatBasicType, + genericVectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR.getName() + ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + floatVectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR_FLOAT32.getName() + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) + ); typeConfiguration.getDdlTypeRegistry().addDescriptor( - new DdlTypeImpl( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) { - @Override - public String getTypeName(Size size) { - return getTypeName( - size.getArrayLength() == null ? 
null : size.getArrayLength().longValue(), - null, - null - ); - } - } + new VectorDdlType( SqlTypes.VECTOR_FLOAT32, "vector($l)", "vector", dialect ) ); } } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/BinaryVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBVectorJdbcType.java similarity index 52% rename from hibernate-vector/src/main/java/org/hibernate/vector/BinaryVectorJdbcType.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBVectorJdbcType.java index 2f25d70edbd8..b6cd2390f59d 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/BinaryVectorJdbcType.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBVectorJdbcType.java @@ -2,11 +2,13 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.metamodel.mapping.JdbcMapping; import org.hibernate.sql.ast.spi.SqlAppender; -import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.ValueBinder; import org.hibernate.type.descriptor.ValueExtractor; import org.hibernate.type.descriptor.WrapperOptions; @@ -14,6 +16,7 @@ import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; import org.hibernate.type.descriptor.jdbc.BasicBinder; import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; import org.hibernate.type.descriptor.jdbc.JdbcType; import org.hibernate.type.spi.TypeConfiguration; @@ -22,15 +25,18 @@ import java.sql.ResultSet; import java.sql.SQLException; -public class BinaryVectorJdbcType extends ArrayJdbcType { +public class MariaDBVectorJdbcType extends ArrayJdbcType { - public BinaryVectorJdbcType(JdbcType elementJdbcType) { + private final int sqlType; + + public MariaDBVectorJdbcType(JdbcType elementJdbcType, int sqlType) { super( elementJdbcType ); + this.sqlType = sqlType; } @Override public int getDefaultSqlTypeCode() { - return SqlTypes.VECTOR; + return sqlType; } @Override @@ -42,26 +48,48 @@ public JavaType getJdbcRecommendedJavaTypeMapping( } @Override - public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { + public JdbcLiteralFormatter getJdbcLiteralFormatter(JavaType javaTypeDescriptor) { + return new MariaDBJdbcLiteralFormatterVector<>( + javaTypeDescriptor, + getElementJdbcType().getJdbcLiteralFormatter( elementJavaType( javaTypeDescriptor ) ) + ); + } + + @Override + public void appendWriteExpression( + String writeExpression, + @Nullable Size size, + SqlAppender appender, + Dialect dialect) { appender.append( writeExpression ); } + @Override + public @Nullable String castFromPattern(JdbcMapping sourceMapping, @Nullable Size size) { + return sourceMapping.getJdbcType().isStringLike() ? "vec_fromtext(?1)" : null; + } + + @Override + public @Nullable String castToPattern(JdbcMapping targetJdbcMapping, @Nullable Size size) { + return targetJdbcMapping.getJdbcType().isStringLike() ? 
"vec_totext(?1)" : null; + } + @Override public ValueExtractor getExtractor(JavaType javaTypeDescriptor) { return new BasicExtractor<>( javaTypeDescriptor, this ) { @Override protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { - return javaTypeDescriptor.wrap( rs.getObject( paramIndex, float[].class ), options ); + return getJavaType().wrap( rs.getObject( paramIndex, float[].class ), options ); } @Override protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { - return javaTypeDescriptor.wrap( statement.getObject( index, float[].class ), options ); + return getJavaType().wrap( statement.getObject( index, float[].class ), options ); } @Override protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException { - return javaTypeDescriptor.wrap( statement.getObject( name, float[].class ), options ); + return getJavaType().wrap( statement.getObject( name, float[].class ), options ); } }; @@ -73,19 +101,31 @@ public ValueBinder getBinder(final JavaType javaTypeDescriptor) { @Override protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) throws SQLException { - st.setObject( index, value ); + st.setObject( index, getBindValue( value, options ) ); } @Override protected void doBind(CallableStatement st, X value, String name, WrapperOptions options) throws SQLException { - st.setObject( name, value, java.sql.Types.ARRAY ); + st.setObject( name, getBindValue( value, options ), java.sql.Types.ARRAY ); } @Override public Object getBindValue(X value, WrapperOptions options) { - return value; + return getJavaType().unwrap( value, float[].class, options ); } }; } + + @Override + public boolean equals(Object that) { + return super.equals( that ) + && that instanceof MariaDBVectorJdbcType vectorJdbcType + && sqlType == vectorJdbcType.sqlType; + } + + @Override + public int hashCode() { + return sqlType + 31 * super.hashCode(); + } } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLFunctionContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLFunctionContributor.java new file mode 100644 index 000000000000..0d93fb55128f --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLFunctionContributor.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.boot.model.FunctionContributions; +import org.hibernate.boot.model.FunctionContributor; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.MySQLDialect; + +public class MySQLFunctionContributor implements FunctionContributor { + @Override + public void contributeFunctions(FunctionContributions functionContributions) { + final Dialect dialect = functionContributions.getDialect(); + if ( dialect instanceof MySQLDialect mySQLDialect && mySQLDialect.getMySQLVersion().isSameOrAfter( 9, 0 ) ) { + final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions ); + + vectorFunctionFactory.cosineDistance( "distance(?1,?2,'cosine')" ); + vectorFunctionFactory.euclideanDistance( "distance(?1,?2,'euclidean')" ); + vectorFunctionFactory.innerProduct( "distance(?1,?2,'dot')*-1" ); + vectorFunctionFactory.negativeInnerProduct( "distance(?1,?2,'dot')" ); + + vectorFunctionFactory.registerNamedVectorFunction( + "vector_dim", + functionContributions.getTypeConfiguration().getBasicTypeForJavaType( Integer.class ), + 1 + ); + functionContributions.getFunctionRegistry().registerAlternateKey( "vector_dims", "vector_dim" ); + } + } + + @Override + public int ordinal() { + return 200; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLJdbcLiteralFormatterVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLJdbcLiteralFormatterVector.java new file mode 100644 index 000000000000..b70b1e7b837b --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLJdbcLiteralFormatterVector.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.spi.BasicJdbcLiteralFormatter; + +public class MySQLJdbcLiteralFormatterVector extends BasicJdbcLiteralFormatter { + + private final JdbcLiteralFormatter elementFormatter; + + public MySQLJdbcLiteralFormatterVector(JavaType javaType, JdbcLiteralFormatter elementFormatter) { + super( javaType ); + //noinspection unchecked + this.elementFormatter = (JdbcLiteralFormatter) elementFormatter; + } + + @Override + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) { + final Object[] objects = unwrap( value, Object[].class, wrapperOptions ); + appender.appendSql( "string_to_vector('" ); + char separator = '['; + for ( Object o : objects ) { + appender.appendSql( separator ); + elementFormatter.appendJdbcLiteral( appender, o, dialect, wrapperOptions ); + separator = ','; + } + appender.appendSql( "]')" ); + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLTypeContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLTypeContributor.java new file mode 100644 index 000000000000..d6bbfeba9528 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLTypeContributor.java @@ -0,0 +1,68 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.boot.model.TypeContributions; +import org.hibernate.boot.model.TypeContributor; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.MySQLDialect; +import org.hibernate.engine.jdbc.spi.JdbcServices; +import org.hibernate.service.ServiceRegistry; +import org.hibernate.type.BasicArrayType; +import org.hibernate.type.BasicType; +import org.hibernate.type.BasicTypeRegistry; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.StandardBasicTypes; +import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; +import org.hibernate.type.spi.TypeConfiguration; + +public class MySQLTypeContributor implements TypeContributor { + + @Override + public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) { + final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect(); + if ( dialect instanceof MySQLDialect mySQLDialect && mySQLDialect.getMySQLVersion().isSameOrAfter( 9, 0 ) ) { + final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration(); + final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry(); + final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry(); + final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); + final BasicType floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ); + final ArrayJdbcType genericVectorJdbcType = new MySQLVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, genericVectorJdbcType ); + final ArrayJdbcType floatVectorJdbcType = new MySQLVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR_FLOAT32 + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT32, floatVectorJdbcType ); + basicTypeRegistry.register( + new BasicArrayType<>( + floatBasicType, + genericVectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR.getName() + ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + floatVectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR_FLOAT32.getName() + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_FLOAT32, "vector($l)", "vector", dialect ) + ); + } + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLVectorJdbcType.java new file mode 100644 index 000000000000..8c0d6714dfb2 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLVectorJdbcType.java @@ -0,0 +1,141 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.checkerframework.checker.nullness.qual.Nullable; +import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.ValueBinder; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.BasicBinder; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.spi.TypeConfiguration; + +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Arrays; + +import static org.hibernate.vector.internal.VectorHelper.parseFloatVector; + +public class MySQLVectorJdbcType extends ArrayJdbcType { + + private final int sqlType; + + public MySQLVectorJdbcType(JdbcType elementJdbcType, int sqlType) { + super( elementJdbcType ); + this.sqlType = sqlType; + } + + @Override + public int getDefaultSqlTypeCode() { + return sqlType; + } + + @Override + public JavaType getJdbcRecommendedJavaTypeMapping( + Integer precision, + Integer scale, + TypeConfiguration typeConfiguration) { + return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class ); + } + + @Override + public JdbcLiteralFormatter getJdbcLiteralFormatter(JavaType javaTypeDescriptor) { + return new MySQLJdbcLiteralFormatterVector<>( + javaTypeDescriptor, + getElementJdbcType().getJdbcLiteralFormatter( elementJavaType( javaTypeDescriptor ) ) + ); + } + + @Override + public void appendWriteExpression( + String writeExpression, + @Nullable Size size, + SqlAppender appender, + Dialect dialect) { + appender.append( "string_to_vector(" ); + appender.append( writeExpression ); + appender.append( ')' ); + } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } + + @Override + public @Nullable String castFromPattern(JdbcMapping sourceMapping, @Nullable Size size) { + return sourceMapping.getJdbcType().isStringLike() ? "string_to_vector(?1)" : null; + } + + @Override + public @Nullable String castToPattern(JdbcMapping targetJdbcMapping, @Nullable Size size) { + return targetJdbcMapping.getJdbcType().isStringLike() ? 
"vector_to_string(?1)" : null; + } + + @Override + public ValueExtractor getExtractor(JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + return getJavaType().wrap( parseFloatVector( rs.getBytes( paramIndex ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + return getJavaType().wrap( parseFloatVector( statement.getBytes( index ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException { + return getJavaType().wrap( parseFloatVector( statement.getBytes( name ) ), options ); + } + + }; + } + + @Override + public ValueBinder getBinder(final JavaType javaTypeDescriptor) { + return new BasicBinder<>( javaTypeDescriptor, this ) { + + @Override + protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) throws SQLException { + st.setString( index, getBindValue( value, options ) ); + } + + @Override + protected void doBind(CallableStatement st, X value, String name, WrapperOptions options) + throws SQLException { + st.setString( name, getBindValue( value, options ) ); + } + + @Override + public String getBindValue(X value, WrapperOptions options) { + return Arrays.toString( getJavaType().unwrap( value, float[].class, options ) ); + } + }; + } + + @Override + public boolean equals(Object that) { + return super.equals( that ) + && that instanceof MySQLVectorJdbcType vectorJdbcType + && sqlType == vectorJdbcType.sqlType; + } + + @Override + public int hashCode() { + return sqlType + 31 * super.hashCode(); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleBinaryVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleBinaryVectorJdbcType.java new file mode 100644 index 000000000000..65067976cb50 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleBinaryVectorJdbcType.java @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.OracleTypes; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.JdbcType; + +import java.util.Arrays; + +/** + * Specialized type mapping for binary vector {@link SqlTypes#VECTOR_BINARY} SQL data type for Oracle. 
+ */ +public class OracleBinaryVectorJdbcType extends AbstractOracleVectorJdbcType { + + public OracleBinaryVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { + super( elementJdbcType, isVectorSupported ); + } + + @Override + public String getVectorParameters() { + return "*,binary"; + } + + @Override + public String getFriendlyName() { + return "VECTOR_BINARY"; + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.VECTOR_BINARY; + } + + @Override + protected byte[] getVectorArray(String string) { + return VectorHelper.parseByteVector( string ); + } + + @Override + protected String getStringVector(T vector, JavaType javaTypeDescriptor, WrapperOptions options) { + return Arrays.toString( javaTypeDescriptor.unwrap( vector, byte[].class, options ) ); + } + + protected Class getNativeJavaType(){ + return byte[].class; + } + + protected int getNativeTypeCode(){ + return OracleTypes.VECTOR_BINARY; + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleByteVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleByteVectorJdbcType.java similarity index 53% rename from hibernate-vector/src/main/java/org/hibernate/vector/OracleByteVectorJdbcType.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleByteVectorJdbcType.java index 76379fed45bc..6521cae13346 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/OracleByteVectorJdbcType.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleByteVectorJdbcType.java @@ -2,14 +2,11 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import java.util.Arrays; -import java.util.BitSet; -import org.hibernate.dialect.Dialect; import org.hibernate.dialect.OracleTypes; -import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.WrapperOptions; import org.hibernate.type.descriptor.java.JavaType; @@ -22,18 +19,13 @@ */ public class OracleByteVectorJdbcType extends AbstractOracleVectorJdbcType { - - private static final byte[] EMPTY = new byte[0]; - public OracleByteVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { super( elementJdbcType, isVectorSupported ); } @Override - public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { - appender.append( "to_vector(" ); - appender.append( writeExpression ); - appender.append( ", *, INT8)" ); + public String getVectorParameters() { + return "*,int8"; } @Override @@ -48,31 +40,7 @@ public int getDefaultSqlTypeCode() { @Override protected byte[] getVectorArray(String string) { - if ( string == null ) { - return null; - } - if ( string.length() == 2 ) { - return EMPTY; - } - final BitSet commaPositions = new BitSet(); - int size = 1; - for ( int i = 1; i < string.length(); i++ ) { - final char c = string.charAt( i ); - if ( c == ',' ) { - commaPositions.set( i ); - size++; - } - } - final byte[] result = new byte[size]; - int doubleStartIndex = 1; - int commaIndex; - int index = 0; - while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) { - result[index++] = Byte.parseByte( string.substring( doubleStartIndex, commaIndex ) ); - doubleStartIndex = commaIndex + 1; - } - result[index] = Byte.parseByte( string.substring( doubleStartIndex, string.length() - 1 ) ); - return result; + return VectorHelper.parseByteVector( 
string ); } @Override @@ -82,10 +50,10 @@ protected String getStringVector(T vector, JavaType javaTypeDescriptor, W protected Class getNativeJavaType(){ return byte[].class; - }; + } protected int getNativeTypeCode(){ return OracleTypes.VECTOR_INT8; - }; + } } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleDoubleVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleDoubleVectorJdbcType.java similarity index 53% rename from hibernate-vector/src/main/java/org/hibernate/vector/OracleDoubleVectorJdbcType.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleDoubleVectorJdbcType.java index 9a2c07318ffb..c32ffaf4da58 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/OracleDoubleVectorJdbcType.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleDoubleVectorJdbcType.java @@ -2,14 +2,11 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import java.util.Arrays; -import java.util.BitSet; -import org.hibernate.dialect.Dialect; import org.hibernate.dialect.OracleTypes; -import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.WrapperOptions; import org.hibernate.type.descriptor.java.JavaType; @@ -22,18 +19,13 @@ */ public class OracleDoubleVectorJdbcType extends AbstractOracleVectorJdbcType { - private static final double[] EMPTY = new double[0]; - public OracleDoubleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { super( elementJdbcType, isVectorSupported ); } - @Override - public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { - appender.append( "to_vector(" ); - appender.append( writeExpression ); - appender.append( ", *, FLOAT64)" ); + public String getVectorParameters() { + return "*,float64"; } @Override @@ -48,31 +40,7 @@ public int getDefaultSqlTypeCode() { @Override protected double[] getVectorArray(String string) { - if ( string == null ) { - return null; - } - if ( string.length() == 2 ) { - return EMPTY; - } - final BitSet commaPositions = new BitSet(); - int size = 1; - for ( int i = 1; i < string.length(); i++ ) { - final char c = string.charAt( i ); - if ( c == ',' ) { - commaPositions.set( i ); - size++; - } - } - final double[] result = new double[size]; - int doubleStartIndex = 1; - int commaIndex; - int index = 0; - while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) { - result[index++] = Double.parseDouble( string.substring( doubleStartIndex, commaIndex ) ); - doubleStartIndex = commaIndex + 1; - } - result[index] = Double.parseDouble( string.substring( doubleStartIndex, string.length() - 1 ) ); - return result; + return VectorHelper.parseDoubleVector( string ); } @Override diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleFloatVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleFloatVectorJdbcType.java similarity index 53% rename from hibernate-vector/src/main/java/org/hibernate/vector/OracleFloatVectorJdbcType.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleFloatVectorJdbcType.java index acb06905c4b9..5fea12a422d1 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/OracleFloatVectorJdbcType.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleFloatVectorJdbcType.java 
@@ -2,14 +2,11 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import java.util.Arrays; -import java.util.BitSet; -import org.hibernate.dialect.Dialect; import org.hibernate.dialect.OracleTypes; -import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.WrapperOptions; import org.hibernate.type.descriptor.java.JavaType; @@ -23,18 +20,13 @@ public class OracleFloatVectorJdbcType extends AbstractOracleVectorJdbcType { - - private static final float[] EMPTY = new float[0]; - public OracleFloatVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { super( elementJdbcType, isVectorSupported ); } @Override - public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { - appender.append( "to_vector(" ); - appender.append( writeExpression ); - appender.append( ", *, FLOAT32)" ); + public String getVectorParameters() { + return "*,float32"; } @Override @@ -49,31 +41,7 @@ public int getDefaultSqlTypeCode() { @Override protected float[] getVectorArray(String string) { - if ( string == null ) { - return null; - } - if ( string.length() == 2 ) { - return EMPTY; - } - final BitSet commaPositions = new BitSet(); - int size = 1; - for ( int i = 1; i < string.length(); i++ ) { - final char c = string.charAt( i ); - if ( c == ',' ) { - commaPositions.set( i ); - size++; - } - } - final float[] result = new float[size]; - int doubleStartIndex = 1; - int commaIndex; - int index = 0; - while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) { - result[index++] = Float.parseFloat( string.substring( doubleStartIndex, commaIndex ) ); - doubleStartIndex = commaIndex + 1; - } - result[index] = Float.parseFloat( string.substring( doubleStartIndex, string.length() - 1 ) ); - return result; + return VectorHelper.parseFloatVector( string ); } @Override diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleJdbcLiteralFormatterSparseVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleJdbcLiteralFormatterSparseVector.java new file mode 100644 index 000000000000..b8ea7688199d --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleJdbcLiteralFormatterSparseVector.java @@ -0,0 +1,70 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.spi.BasicJdbcLiteralFormatter; +import org.hibernate.vector.AbstractSparseVector; +import org.hibernate.vector.SparseByteVector; +import org.hibernate.vector.SparseDoubleVector; +import org.hibernate.vector.SparseFloatVector; + +public class OracleJdbcLiteralFormatterSparseVector extends BasicJdbcLiteralFormatter { + + private final String vectorParameters; + + public OracleJdbcLiteralFormatterSparseVector(JavaType javaType, String vectorParameters) { + super( javaType ); + this.vectorParameters = vectorParameters; + } + + @Override + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) { + appender.appendSql( "to_vector('" ); + final AbstractSparseVector sparseVector = unwrap( value, AbstractSparseVector.class, wrapperOptions ); + appender.appendSql( '[' ); + appender.appendSql( sparseVector.size() ); + appender.appendSql( ',' ); + char separator = '['; + for ( int index : sparseVector.indices() ) { + appender.appendSql( separator ); + appender.appendSql( index ); + separator = ','; + } + appender.appendSql( "]," ); + separator = '['; + if ( sparseVector instanceof SparseFloatVector floatVector ) { + for ( float f : floatVector.floats() ) { + appender.appendSql( separator ); + appender.appendSql( f ); + separator = ','; + } + } + else if ( sparseVector instanceof SparseDoubleVector doubleVector ) { + for ( double d : doubleVector.doubles() ) { + appender.appendSql( separator ); + appender.appendSql( d ); + separator = ','; + } + } + else if ( sparseVector instanceof SparseByteVector byteVector ) { + for ( byte b : byteVector.bytes() ) { + appender.appendSql( separator ); + appender.appendSql( b ); + separator = ','; + } + } + else { + throw new IllegalArgumentException( "Unsupported sparse vector type: " + sparseVector.getClass().getName() ); + } + appender.appendSql( "]]'," ); + appender.appendSql( vectorParameters ); + appender.appendSql( ')' ); + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleJdbcLiteralFormatterVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleJdbcLiteralFormatterVector.java new file mode 100644 index 000000000000..27df530953df --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleJdbcLiteralFormatterVector.java @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.spi.BasicJdbcLiteralFormatter; + +public class OracleJdbcLiteralFormatterVector extends BasicJdbcLiteralFormatter { + + private final JdbcLiteralFormatter elementFormatter; + private final String vectorParameters; + + public OracleJdbcLiteralFormatterVector(JavaType javaType, JdbcLiteralFormatter elementFormatter, String vectorParameters) { + super( javaType ); + //noinspection unchecked + this.elementFormatter = (JdbcLiteralFormatter) elementFormatter; + this.vectorParameters = vectorParameters; + } + + @Override + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) { + final Object[] objects = unwrap( value, Object[].class, wrapperOptions ); + appender.append( "to_vector('" ); + char separator = '['; + for ( Object o : objects ) { + appender.append( separator ); + elementFormatter.appendJdbcLiteral( appender, o, dialect, wrapperOptions ); + separator = ','; + } + appender.append( "]'," ); + appender.append( vectorParameters ); + appender.append( ')' ); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseByteVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseByteVectorJdbcType.java new file mode 100644 index 000000000000..9c1f77c19f30 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseByteVectorJdbcType.java @@ -0,0 +1,115 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import oracle.sql.VECTOR; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.vector.SparseByteVector; + +import java.sql.CallableStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +/** + * Specialized type mapping for sparse single-byte integer vector {@link SqlTypes#SPARSE_VECTOR_INT8} SQL data type for Oracle. 
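+ * When the JDBC driver exposes {@link oracle.sql.VECTOR} support, values are bound and extracted natively as {@code VECTOR.SparseByteArray}; otherwise the textual sparse vector representation is used as a fallback.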
+ */ +public class OracleSparseByteVectorJdbcType extends AbstractOracleSparseVectorJdbcType { + + public OracleSparseByteVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { + super( elementJdbcType, isVectorSupported ); + } + + @Override + public String getVectorParameters() { + return "*,int8,sparse"; + } + + @Override + public String getFriendlyName() { + return "SPARSE_VECTOR_INT8"; + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.SPARSE_VECTOR_INT8; + } + + @Override + protected Object getBindValue(JavaType javaType, X value, WrapperOptions options) { + if ( isVectorSupported ) { + final SparseByteVector sparseVector = javaType.unwrap( value, SparseByteVector.class, options ); + return VECTOR.SparseByteArray.of( sparseVector.size(), sparseVector.indices(), sparseVector.bytes() ); + } + else { + return getStringVector( value, javaType, options ); + } + } + + @Override + public ValueExtractor getExtractor(final JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( rs.getObject( paramIndex, VECTOR.SparseByteArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( rs.getString( paramIndex ) ), options ); + } + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( statement.getObject( index, VECTOR.SparseByteArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( statement.getString( index ) ), options ); + } + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) + throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( statement.getObject( name, VECTOR.SparseByteArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( statement.getString( name ) ), options ); + } + } + + private Object wrapNativeValue(VECTOR.SparseByteArray nativeValue) { + return nativeValue == null + ? null + : new SparseByteVector( nativeValue.length(), nativeValue.indices(), nativeValue.values() ); + } + + private Object wrapStringValue(String value) { + return ((AbstractOracleVectorJdbcType) getJdbcType() ).getVectorArray( value ); + } + + }; + } + + @Override + protected SparseByteVector getVectorArray(String string) { + if ( string == null ) { + return null; + } + return new SparseByteVector( string ); + } + + @Override + protected Class getNativeJavaType() { + return VECTOR.SparseByteArray.class; + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseDoubleVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseDoubleVectorJdbcType.java new file mode 100644 index 000000000000..901057cbd753 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseDoubleVectorJdbcType.java @@ -0,0 +1,115 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import oracle.sql.VECTOR; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.vector.SparseDoubleVector; + +import java.sql.CallableStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +/** + * Specialized type mapping for sparse double-precision floating-point vector {@link SqlTypes#SPARSE_VECTOR_FLOAT64} SQL data type for Oracle. + */ +public class OracleSparseDoubleVectorJdbcType extends AbstractOracleSparseVectorJdbcType { + + public OracleSparseDoubleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { + super( elementJdbcType, isVectorSupported ); + } + + @Override + public String getVectorParameters() { + return "*,float64,sparse"; + } + + @Override + public String getFriendlyName() { + return "SPARSE_VECTOR_FLOAT64"; + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.SPARSE_VECTOR_FLOAT64; + } + + @Override + protected Object getBindValue(JavaType javaType, X value, WrapperOptions options) { + if ( isVectorSupported ) { + final SparseDoubleVector sparseVector = javaType.unwrap( value, SparseDoubleVector.class, options ); + return VECTOR.SparseDoubleArray.of( sparseVector.size(), sparseVector.indices(), sparseVector.doubles() ); + } + else { + return getStringVector( value, javaType, options ); + } + } + + @Override + public ValueExtractor getExtractor(final JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( rs.getObject( paramIndex, VECTOR.SparseDoubleArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( rs.getString( paramIndex ) ), options ); + } + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( statement.getObject( index, VECTOR.SparseDoubleArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( statement.getString( index ) ), options ); + } + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) + throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( statement.getObject( name, VECTOR.SparseDoubleArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( statement.getString( name ) ), options ); + } + } + + private Object wrapNativeValue(VECTOR.SparseDoubleArray nativeValue) { + return nativeValue == null + ? 
null + : new SparseDoubleVector( nativeValue.length(), nativeValue.indices(), nativeValue.values() ); + } + + private Object wrapStringValue(String value) { + return ((AbstractOracleVectorJdbcType) getJdbcType() ).getVectorArray( value ); + } + + }; + } + + @Override + protected SparseDoubleVector getVectorArray(String string) { + if ( string == null ) { + return null; + } + return new SparseDoubleVector( string ); + } + + @Override + protected Class getNativeJavaType() { + return VECTOR.SparseDoubleArray.class; + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseFloatVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseFloatVectorJdbcType.java new file mode 100644 index 000000000000..2425c15f9bf7 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleSparseFloatVectorJdbcType.java @@ -0,0 +1,115 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import oracle.sql.VECTOR; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.vector.SparseFloatVector; + +import java.sql.CallableStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +/** + * Specialized type mapping for sparse single-precision floating-point vector {@link SqlTypes#SPARSE_VECTOR_FLOAT32} SQL data type for Oracle. + */ +public class OracleSparseFloatVectorJdbcType extends AbstractOracleSparseVectorJdbcType { + + public OracleSparseFloatVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { + super( elementJdbcType, isVectorSupported ); + } + + @Override + public String getVectorParameters() { + return "*,float32,sparse"; + } + + @Override + public String getFriendlyName() { + return "SPARSE_VECTOR_FLOAT32"; + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.SPARSE_VECTOR_FLOAT32; + } + + @Override + protected Object getBindValue(JavaType javaType, X value, WrapperOptions options) { + if ( isVectorSupported ) { + final SparseFloatVector sparseVector = javaType.unwrap( value, SparseFloatVector.class, options ); + return VECTOR.SparseFloatArray.of( sparseVector.size(), sparseVector.indices(), sparseVector.floats() ); + } + else { + return getStringVector( value, javaType, options ); + } + } + + @Override + public ValueExtractor getExtractor(final JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( rs.getObject( paramIndex, VECTOR.SparseFloatArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( rs.getString( paramIndex ) ), options ); + } + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( statement.getObject( index, VECTOR.SparseFloatArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( statement.getString( index ) ), options ); + } + } + + @Override + protected 
X doExtract(CallableStatement statement, String name, WrapperOptions options) + throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( wrapNativeValue( statement.getObject( name, VECTOR.SparseFloatArray.class ) ), options ); + } + else { + return getJavaType().wrap( wrapStringValue( statement.getString( name ) ), options ); + } + } + + private Object wrapNativeValue(VECTOR.SparseFloatArray nativeValue) { + return nativeValue == null + ? null + : new SparseFloatVector( nativeValue.length(), nativeValue.indices(), nativeValue.values() ); + } + + private Object wrapStringValue(String value) { + return ((AbstractOracleVectorJdbcType) getJdbcType() ).getVectorArray( value ); + } + + }; + } + + @Override + protected SparseFloatVector getVectorArray(String string) { + if ( string == null ) { + return null; + } + return new SparseFloatVector( string ); + } + + @Override + protected Class getNativeJavaType() { + return VECTOR.SparseFloatArray.class; + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorFunctionContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorFunctionContributor.java similarity index 75% rename from hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorFunctionContributor.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorFunctionContributor.java index 69572ac79c7d..770e8c2cacda 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorFunctionContributor.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorFunctionContributor.java @@ -2,7 +2,7 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import org.hibernate.boot.model.FunctionContributions; import org.hibernate.boot.model.FunctionContributor; @@ -14,19 +14,22 @@ public class OracleVectorFunctionContributor implements FunctionContributor { @Override public void contributeFunctions(FunctionContributions functionContributions) { final Dialect dialect = functionContributions.getDialect(); - if ( dialect instanceof OracleDialect ) { + if ( dialect instanceof OracleDialect && dialect.getVersion().isSameOrAfter( 23, 4 ) ) { final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions ); vectorFunctionFactory.cosineDistance( "vector_distance(?1,?2,COSINE)" ); vectorFunctionFactory.euclideanDistance( "vector_distance(?1,?2,EUCLIDEAN)" ); + vectorFunctionFactory.euclideanSquaredDistance( "vector_distance(?1,?2,EUCLIDEAN_SQUARED)" ); vectorFunctionFactory.l1Distance( "vector_distance(?1,?2,MANHATTAN)" ); vectorFunctionFactory.hammingDistance( "vector_distance(?1,?2,HAMMING)" ); + vectorFunctionFactory.jaccardDistance( "vector_distance(?1,?2,JACCARD)" ); vectorFunctionFactory.innerProduct( "vector_distance(?1,?2,DOT)*-1" ); vectorFunctionFactory.negativeInnerProduct( "vector_distance(?1,?2,DOT)" ); vectorFunctionFactory.vectorDimensions(); vectorFunctionFactory.vectorNorm(); + functionContributions.getFunctionRegistry().registerAlternateKey( "l2_norm", "vector_norm" ); } } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorJdbcType.java similarity index 75% rename from hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorJdbcType.java rename to 
hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorJdbcType.java index cdfa8dd219f9..f4b9e74cceac 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorJdbcType.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorJdbcType.java @@ -2,11 +2,9 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; -import org.hibernate.dialect.Dialect; import org.hibernate.dialect.OracleTypes; -import org.hibernate.sql.ast.spi.SqlAppender; import org.hibernate.type.SqlTypes; import org.hibernate.type.descriptor.jdbc.JdbcType; @@ -20,7 +18,6 @@ */ public class OracleVectorJdbcType extends OracleFloatVectorJdbcType { - public OracleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { super( elementJdbcType, isVectorSupported ); } @@ -31,10 +28,8 @@ public String getFriendlyName() { } @Override - public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { - appender.append( "to_vector(" ); - appender.append( writeExpression ); - appender.append( ", *, *)" ); + public String getVectorParameters() { + return "*,*"; } @Override diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorTypeContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorTypeContributor.java similarity index 55% rename from hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorTypeContributor.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorTypeContributor.java index 480bb58ef372..4e6347b97dfe 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorTypeContributor.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/OracleVectorTypeContributor.java @@ -2,24 +2,22 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. 
and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import org.hibernate.boot.model.TypeContributions; import org.hibernate.boot.model.TypeContributor; import org.hibernate.dialect.Dialect; import org.hibernate.dialect.OracleDialect; -import org.hibernate.engine.jdbc.Size; import org.hibernate.engine.jdbc.spi.JdbcServices; -import org.hibernate.internal.util.StringHelper; import org.hibernate.service.ServiceRegistry; import org.hibernate.type.BasicArrayType; +import org.hibernate.type.BasicCollectionType; import org.hibernate.type.BasicTypeRegistry; import org.hibernate.type.SqlTypes; import org.hibernate.type.StandardBasicTypes; import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry; import org.hibernate.type.descriptor.jdbc.JdbcType; import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; -import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl; import org.hibernate.type.spi.TypeConfiguration; public class OracleVectorTypeContributor implements TypeContributor { @@ -57,7 +55,30 @@ public void contribute(TypeContributions typeContributions, ServiceRegistry serv isVectorSupported ); jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_INT8, byteVectorJdbcType ); + final JdbcType bitVectorJdbcType = new OracleBinaryVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.TINYINT ), + isVectorSupported + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_BINARY, bitVectorJdbcType ); + final JdbcType sparseByteVectorJdbcType = new OracleSparseByteVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.TINYINT ), + isVectorSupported + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.SPARSE_VECTOR_INT8, sparseByteVectorJdbcType ); + final JdbcType sparseFloatVectorJdbcType = new OracleSparseFloatVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + isVectorSupported + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.SPARSE_VECTOR_FLOAT32, sparseFloatVectorJdbcType ); + final JdbcType sparseDoubleVectorJdbcType = new OracleSparseDoubleVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.DOUBLE ), + isVectorSupported + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.SPARSE_VECTOR_FLOAT64, sparseDoubleVectorJdbcType ); + javaTypeRegistry.addDescriptor( SparseByteVectorJavaType.INSTANCE ); + javaTypeRegistry.addDescriptor( SparseFloatVectorJavaType.INSTANCE ); + javaTypeRegistry.addDescriptor( SparseDoubleVectorJavaType.INSTANCE ); // Resolving basic types after jdbc types are registered. 
basicTypeRegistry.register( @@ -92,62 +113,66 @@ public void contribute(TypeContributions typeContributions, ServiceRegistry serv ), StandardBasicTypes.VECTOR_INT8.getName() ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.BYTE ), + bitVectorJdbcType, + javaTypeRegistry.getDescriptor( byte[].class ) + ), + StandardBasicTypes.VECTOR_BINARY.getName() + ); + basicTypeRegistry.register( + new BasicCollectionType<>( + basicTypeRegistry.resolve( StandardBasicTypes.BYTE ), + sparseByteVectorJdbcType, + SparseByteVectorJavaType.INSTANCE, + "sparse_byte_vector" + ) + ); + basicTypeRegistry.register( + new BasicCollectionType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + sparseFloatVectorJdbcType, + SparseFloatVectorJavaType.INSTANCE, + "sparse_float_vector" + ) + ); + basicTypeRegistry.register( + new BasicCollectionType<>( + basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE ), + sparseDoubleVectorJdbcType, + SparseDoubleVectorJavaType.INSTANCE, + "sparse_double_vector" + ) + ); typeConfiguration.getDdlTypeRegistry().addDescriptor( - new DdlTypeImpl( SqlTypes.VECTOR, "vector($l, *)", "vector", dialect ) { - @Override - public String getTypeName(Size size) { - return OracleVectorTypeContributor.replace( - "vector($l, *)", - size.getArrayLength() == null ? null : size.getArrayLength().longValue() - ); - } - } + new VectorDdlType( SqlTypes.VECTOR, "vector($l,*)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_INT8, "vector($l,int8)", "vector", dialect ) ); typeConfiguration.getDdlTypeRegistry().addDescriptor( - new DdlTypeImpl( SqlTypes.VECTOR_INT8, "vector($l, INT8)", "vector", dialect ) { - @Override - public String getTypeName(Size size) { - return OracleVectorTypeContributor.replace( - "vector($l, INT8)", - size.getArrayLength() == null ? null : size.getArrayLength().longValue() - ); - } - } + new VectorDdlType( SqlTypes.VECTOR_FLOAT32, "vector($l,float32)", "vector", dialect ) ); typeConfiguration.getDdlTypeRegistry().addDescriptor( - new DdlTypeImpl( SqlTypes.VECTOR_FLOAT32, "vector($l, FLOAT32)", "vector", dialect ) { - @Override - public String getTypeName(Size size) { - return OracleVectorTypeContributor.replace( - "vector($l, FLOAT32)", - size.getArrayLength() == null ? null : size.getArrayLength().longValue() - ); - } - } + new VectorDdlType( SqlTypes.VECTOR_FLOAT64, "vector($l,float64)", "vector", dialect ) ); typeConfiguration.getDdlTypeRegistry().addDescriptor( - new DdlTypeImpl( SqlTypes.VECTOR_FLOAT64, "vector($l, FLOAT64)", "vector", dialect ) { - @Override - public String getTypeName(Size size) { - return OracleVectorTypeContributor.replace( - "vector($l, FLOAT64)", - size.getArrayLength() == null ? 
null : size.getArrayLength().longValue() - ); - } - } + new VectorDdlType( SqlTypes.VECTOR_BINARY, "vector($l,binary)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.SPARSE_VECTOR_INT8, "vector($l,int8,sparse)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.SPARSE_VECTOR_FLOAT32, "vector($l,float32,sparse)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.SPARSE_VECTOR_FLOAT64, "vector($l,float64,sparse)", "vector", dialect ) ); } } - - /** - * Replace vector dimension with the length or * for undefined length - */ - private static String replace(String type, Long size) { - return StringHelper.replaceOnce( type, "$l", size != null ? size.toString() : "*" ); - } - private boolean isVectorSupportedByDriver(OracleDialect dialect) { int majorVersion = dialect.getDriverMajorVersion(); int minorVersion = dialect.getDriverMinorVersion(); diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGBinaryVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGBinaryVectorJdbcType.java new file mode 100644 index 000000000000..7969eb8cadb1 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGBinaryVectorJdbcType.java @@ -0,0 +1,109 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.checkerframework.checker.nullness.qual.Nullable; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.ValueBinder; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.BasicBinder; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.spi.TypeConfiguration; + +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Types; + +import static org.hibernate.vector.internal.VectorHelper.parseBitString; +import static org.hibernate.vector.internal.VectorHelper.toBitString; + +public class PGBinaryVectorJdbcType extends ArrayJdbcType { + + public PGBinaryVectorJdbcType(JdbcType elementJdbcType) { + super( elementJdbcType ); + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.VECTOR_BINARY; + } + + @Override + public JdbcLiteralFormatter getJdbcLiteralFormatter(JavaType javaTypeDescriptor) { + return new PGVectorJdbcLiteralFormatterBinaryVector<>( javaTypeDescriptor ); + } + + @Override + public JavaType getJdbcRecommendedJavaTypeMapping( + Integer precision, + Integer scale, + TypeConfiguration typeConfiguration) { + return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( byte[].class ); + } + +// @Override +// public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { +// appender.append( "cast(" ); +// appender.append( writeExpression ); +// appender.append( " as varbit)" ); +// } +// +// @Override +// public boolean 
isWriteExpressionTyped(Dialect dialect) { +// return true; +// } + + @Override + public @Nullable String castFromPattern(JdbcMapping sourceMapping, @Nullable Size size) { + return sourceMapping.getJdbcType().isStringLike() ? "cast(?1 as varbit)" : null; + } + + @Override + public ValueBinder getBinder(final JavaType javaTypeDescriptor) { + return new BasicBinder<>( javaTypeDescriptor, this ) { + @Override + protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) + throws SQLException { + st.setObject( index, toBitString( getJavaType().unwrap( value, byte[].class, options ) ), Types.OTHER ); + } + + @Override + protected void doBind(CallableStatement st, X value, String name, WrapperOptions options) + throws SQLException { + st.setObject( name, toBitString( getJavaType().unwrap( value, byte[].class, options ) ), Types.OTHER ); + } + + }; + } + + @Override + public ValueExtractor getExtractor(JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseBitString( rs.getString( paramIndex ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseBitString( statement.getString( index ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseBitString( statement.getString( name ) ), options ); + } + }; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGSparseFloatVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGSparseFloatVectorJdbcType.java new file mode 100644 index 000000000000..1d64e32cbc43 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGSparseFloatVectorJdbcType.java @@ -0,0 +1,184 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.checkerframework.checker.nullness.qual.Nullable; +import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.ValueBinder; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.BasicBinder; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.spi.TypeConfiguration; +import org.hibernate.vector.SparseFloatVector; + +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +public class PGSparseFloatVectorJdbcType extends ArrayJdbcType { + + public PGSparseFloatVectorJdbcType(JdbcType elementJdbcType) { + super( elementJdbcType ); + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.SPARSE_VECTOR_FLOAT32; + } + + @Override + public JavaType getJdbcRecommendedJavaTypeMapping( + Integer precision, + Integer scale, + TypeConfiguration typeConfiguration) { + return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class ); + } + + @Override + public JdbcLiteralFormatter getJdbcLiteralFormatter(JavaType javaTypeDescriptor) { + return new PGVectorJdbcLiteralFormatterSparseVector<>( javaTypeDescriptor ); + } + + @Override + public void appendWriteExpression( + String writeExpression, + @Nullable Size size, + SqlAppender appender, + Dialect dialect) { + appender.append( "cast(" ); + appender.append( writeExpression ); + appender.append( " as sparsevec)" ); + } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } + + @Override + public @Nullable String castFromPattern(JdbcMapping sourceMapping, @Nullable Size size) { + return sourceMapping.getJdbcType().isStringLike() ? 
"cast(?1 as sparsevec)" : null; + } + + @Override + public ValueBinder getBinder(final JavaType javaTypeDescriptor) { + return new BasicBinder<>( javaTypeDescriptor, this ) { + + @Override + protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) throws SQLException { + st.setString( index, getString( value, options ) ); + } + + @Override + protected void doBind(CallableStatement st, X value, String name, WrapperOptions options) + throws SQLException { + st.setString( name, getString( value, options ) ); + } + + @Override + public Object getBindValue(X value, WrapperOptions options) { + return getString( value, options ); + } + + private String getString(X value, WrapperOptions options) { + final SparseFloatVector vector = getJavaType().unwrap( value, SparseFloatVector.class, options ); + final int size = vector.size(); + final int[] indices = vector.indices(); + final float[] floats = vector.floats(); + final StringBuilder sb = new StringBuilder( indices.length * 50 ); + char separator = '{'; + for ( int i = 0; i < indices.length; i++ ) { + sb.append( separator ); + // The sparvec format is 1 based + sb.append( indices[i] + 1 ); + sb.append( ':' ); + sb.append( floats[i] ); + separator = ','; + } + sb.append("}/"); + sb.append( size ); + return sb.toString(); + } + }; + } + + @Override + public ValueExtractor getExtractor(JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseSparseFloatVector( rs.getString( paramIndex ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseSparseFloatVector( statement.getString( index ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseSparseFloatVector( statement.getString( name ) ), options ); + } + }; + } + + /** + * Parses the pgvector sparsevec format `{idx1:val1,idx2:val2}/size`. 
+ */ + private static @Nullable SparseFloatVector parseSparseFloatVector(@Nullable String string) { + if ( string == null ) { + return null; + } + + final int slashIndex = string.lastIndexOf( '/' ); + if ( string.charAt( 0 ) != '{' || slashIndex == -1 || string.charAt( slashIndex - 1 ) != '}' ) { + throw new IllegalArgumentException( "Invalid sparse vector string: " + string ); + } + final int size = Integer.parseInt( string, slashIndex + 1, string.length(), 10 ); + final int end = slashIndex - 1; + final int count = countValues( string, end ); + final int[] indices = new int[count]; + final float[] values = new float[count]; + int start = 1; + int index = 0; + for ( int i = start; i < end; i++ ) { + final char c = string.charAt( i ); + if ( c == ':' ) { + // Indices are 1 based in this format, but we need a zero base + indices[index] = Integer.parseInt( string, start, i, 10 ) - 1; + start = i + 1; + } + else if ( c == ',' ) { + values[index++] = Float.parseFloat( string.substring( start, i ) ); + start = i + 1; + } + } + if ( start != end ) { + values[index] = Float.parseFloat( string.substring( start, end ) ); + assert count == index + 1; + } + return new SparseFloatVector( size, indices, values ); + } + + private static int countValues(String string, int end) { + int count = 0; + for ( int i = 1; i < end; i++ ) { + if ( string.charAt( i ) == ':' ) { + count++; + } + } + return count; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorDimsFunction.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorDimsFunction.java new file mode 100644 index 000000000000..d6d8cb0119ce --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorDimsFunction.java @@ -0,0 +1,55 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.metamodel.model.domain.ReturnableType; +import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor; +import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; +import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; +import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.sql.ast.tree.SqlAstNode; +import org.hibernate.sql.ast.tree.expression.Expression; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.spi.TypeConfiguration; + +import java.util.List; + +public class PGVectorDimsFunction extends AbstractSqmSelfRenderingFunctionDescriptor { + public PGVectorDimsFunction(TypeConfiguration typeConfiguration) { + super( + "vector_dims", + StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 1 ), + VectorArgumentValidator.INSTANCE + ), + StandardFunctionReturnTypeResolvers.invariant( typeConfiguration.getBasicTypeForJavaType( Integer.class ) ), + VectorArgumentTypeResolver.INSTANCE + ); + } + + @Override + public void render(SqlAppender sqlAppender, List sqlAstArguments, ReturnableType returnType, SqlAstTranslator walker) { + final Expression expression = (Expression) sqlAstArguments.get( 0 ); + final int sqlTypeCode = + expression.getExpressionType().getSingleJdbcMapping().getJdbcType().getDefaultSqlTypeCode(); + if ( sqlTypeCode == SqlTypes.SPARSE_VECTOR_FLOAT32 ) { + sqlAppender.append( "cast(split_part(cast(" ); + expression.accept( walker ); + sqlAppender.append( " as text),'/',2) as integer)" ); + } + else { + if ( sqlTypeCode == SqlTypes.VECTOR_BINARY ) { + sqlAppender.append( "length" ); + } + else { + sqlAppender.append( "vector_dims" ); + } + sqlAppender.append( '(' ); + expression.accept( walker ); + sqlAppender.append( ')' ); + } + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorFunctionContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorFunctionContributor.java new file mode 100644 index 000000000000..66306832861d --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorFunctionContributor.java @@ -0,0 +1,78 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.boot.model.FunctionContributions; +import org.hibernate.boot.model.FunctionContributor; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.PostgreSQLDialect; +import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; +import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers; +import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; +import org.hibernate.type.StandardBasicTypes; +import org.hibernate.type.spi.TypeConfiguration; + +import static org.hibernate.query.sqm.produce.function.FunctionParameterType.INTEGER; + +public class PGVectorFunctionContributor implements FunctionContributor { + + @Override + public void contributeFunctions(FunctionContributions functionContributions) { + final Dialect dialect = functionContributions.getDialect(); + if ( dialect instanceof PostgreSQLDialect ) { + final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions ); + + vectorFunctionFactory.cosineDistance( "?1<=>?2" ); + vectorFunctionFactory.euclideanDistance( "?1<->?2" ); + vectorFunctionFactory.euclideanSquaredDistance( "(?1<->?2)^2" ); + vectorFunctionFactory.l1Distance( "l1_distance(?1,?2)" ); + vectorFunctionFactory.hammingDistance( "?1<~>?2" ); + vectorFunctionFactory.jaccardDistance( "?1<%>?2" ); + + vectorFunctionFactory.innerProduct( "(?1<#>?2)*-1" ); + vectorFunctionFactory.negativeInnerProduct( "?1<#>?2" ); + + final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration(); + functionContributions.getFunctionRegistry() + .register( "vector_dims", new PGVectorDimsFunction( typeConfiguration ) ); + functionContributions.getFunctionRegistry() + .register( "vector_norm", new PGVectorNormFunction( typeConfiguration ) ); + + functionContributions.getFunctionRegistry().namedDescriptorBuilder( "binary_quantize" ) + .setArgumentsValidator( StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 1 ), + VectorArgumentValidator.INSTANCE + ) ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( + typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.VECTOR_BINARY ) + ) ) + .register(); + functionContributions.getFunctionRegistry().namedDescriptorBuilder( "subvector" ) + .setArgumentsValidator( StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 3 ), + VectorArgumentValidator.INSTANCE + ) ) + .setArgumentTypeResolver( StandardFunctionArgumentTypeResolvers.byArgument( + VectorArgumentTypeResolver.INSTANCE, + StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, INTEGER ), + StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, INTEGER ) + ) ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.useArgType( 1 ) ) + .register(); + functionContributions.getFunctionRegistry().registerAlternateKey( "l2_norm", "vector_norm" ); + functionContributions.getFunctionRegistry().namedDescriptorBuilder( "l2_normalize" ) + .setArgumentsValidator( VectorArgumentValidator.INSTANCE ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.useArgType( 1 ) ) + .register(); + } + } + + @Override + public int ordinal() { + return 200; + } +} diff --git 
a/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorJdbcLiteralFormatterBinaryVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorJdbcLiteralFormatterBinaryVector.java new file mode 100644 index 000000000000..237305493084 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorJdbcLiteralFormatterBinaryVector.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.spi.BasicJdbcLiteralFormatter; + +public class PGVectorJdbcLiteralFormatterBinaryVector extends BasicJdbcLiteralFormatter { + + public PGVectorJdbcLiteralFormatterBinaryVector(JavaType javaType) { + super( javaType ); + } + + @Override + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) { + appender.append( "cast('" ); + appender.append( VectorHelper.toBitString( unwrap( value, byte[].class, wrapperOptions ) ) ); + appender.append( "' as varbit)" ); + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorJdbcLiteralFormatterSparseVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorJdbcLiteralFormatterSparseVector.java new file mode 100644 index 000000000000..baefc9f974a9 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorJdbcLiteralFormatterSparseVector.java @@ -0,0 +1,67 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.spi.BasicJdbcLiteralFormatter; +import org.hibernate.vector.AbstractSparseVector; +import org.hibernate.vector.SparseByteVector; +import org.hibernate.vector.SparseDoubleVector; +import org.hibernate.vector.SparseFloatVector; + +public class PGVectorJdbcLiteralFormatterSparseVector extends BasicJdbcLiteralFormatter { + + public PGVectorJdbcLiteralFormatterSparseVector(JavaType javaType) { + super( javaType ); + } + + @Override + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) { + appender.append( "cast('" ); + final AbstractSparseVector sparseVector = unwrap( value, AbstractSparseVector.class, wrapperOptions ); + char separator = '{'; + final int[] indices = sparseVector.indices(); + if ( sparseVector instanceof SparseFloatVector floatVector ) { + final float[] floats = floatVector.floats(); + for ( int i = 0; i < floats.length; i++ ) { + appender.appendSql( separator ); + appender.appendSql( indices[i] + 1 ); + appender.appendSql( ':' ); + appender.appendSql( floats[i] ); + separator = ','; + } + } + else if ( sparseVector instanceof SparseDoubleVector doubleVector ) { + final double[] doubles = doubleVector.doubles(); + for ( int i = 0; i < doubles.length; i++ ) { + appender.appendSql( separator ); + appender.appendSql( indices[i] + 1 ); + appender.appendSql( ':' ); + appender.appendSql( doubles[i] ); + separator = ','; + } + } + else if ( sparseVector instanceof SparseByteVector byteVector ) { + final byte[] bytes = byteVector.bytes(); + for ( int i = 0; i < bytes.length; i++ ) { + appender.appendSql( separator ); + appender.appendSql( indices[i] + 1 ); + appender.appendSql( ':' ); + appender.appendSql( bytes[i] ); + separator = ','; + } + } + else { + throw new IllegalArgumentException( "Unsupported sparse vector type: " + sparseVector.getClass().getName() ); + } + appender.append( "}/" ); + appender.appendSql( sparseVector.size() ); + appender.append( "' as sparsevec)" ); + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorJdbcLiteralFormatterVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorJdbcLiteralFormatterVector.java new file mode 100644 index 000000000000..fe1205a1d95b --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorJdbcLiteralFormatterVector.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.spi.BasicJdbcLiteralFormatter; + +public class PGVectorJdbcLiteralFormatterVector extends BasicJdbcLiteralFormatter { + + private final JdbcLiteralFormatter elementFormatter; + + public PGVectorJdbcLiteralFormatterVector(JavaType javaType, JdbcLiteralFormatter elementFormatter) { + super( javaType ); + //noinspection unchecked + this.elementFormatter = (JdbcLiteralFormatter) elementFormatter; + } + + @Override + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) { + final Object[] objects = unwrap( value, Object[].class, wrapperOptions ); + appender.appendSql( "cast('" ); + char separator = '['; + for ( Object o : objects ) { + appender.appendSql( separator ); + elementFormatter.appendJdbcLiteral( appender, o, dialect, wrapperOptions ); + separator = ','; + } + appender.appendSql( "]' as vector)" ); + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorJdbcType.java new file mode 100644 index 000000000000..1118f22c096d --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorJdbcType.java @@ -0,0 +1,113 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.checkerframework.checker.nullness.qual.Nullable; +import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.spi.TypeConfiguration; + +import java.sql.CallableStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +import static org.hibernate.vector.internal.VectorHelper.parseFloatVector; + +public class PGVectorJdbcType extends ArrayJdbcType { + + private final int sqlType; + private final String typeName; + + public PGVectorJdbcType(JdbcType elementJdbcType, int sqlType, String typeName) { + super( elementJdbcType ); + this.sqlType = sqlType; + this.typeName = typeName; + } + + @Override + public int getDefaultSqlTypeCode() { + return sqlType; + } + + @Override + public JavaType getJdbcRecommendedJavaTypeMapping( + Integer precision, + Integer scale, + TypeConfiguration typeConfiguration) { + return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class ); + } + + @Override + public JdbcLiteralFormatter getJdbcLiteralFormatter(JavaType javaTypeDescriptor) { + return new PGVectorJdbcLiteralFormatterVector<>( + javaTypeDescriptor, + getElementJdbcType().getJdbcLiteralFormatter( elementJavaType( javaTypeDescriptor ) ) + ); + } + + @Override + public void appendWriteExpression( + String writeExpression, 
+ @Nullable Size size, + SqlAppender appender, + Dialect dialect) { + appender.append( "cast(" ); + appender.append( writeExpression ); + appender.append( " as " ); + appender.append( typeName ); + appender.append( ')' ); + } + + @Override + public boolean isWriteExpressionTyped(Dialect dialect) { + return true; + } + + @Override + public @Nullable String castFromPattern(JdbcMapping sourceMapping, @Nullable Size size) { + return sourceMapping.getJdbcType().isStringLike() ? "cast(?1 as " + typeName + ")" : null; + } + + @Override + public ValueExtractor getExtractor(JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseFloatVector( rs.getString( paramIndex ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseFloatVector( statement.getString( index ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseFloatVector( statement.getString( name ) ), options ); + } + }; + } + + @Override + public boolean equals(Object that) { + return super.equals( that ) + && that instanceof PGVectorJdbcType vectorJdbcType + && sqlType == vectorJdbcType.sqlType; + } + + @Override + public int hashCode() { + return sqlType + 31 * super.hashCode(); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorNormFunction.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorNormFunction.java new file mode 100644 index 000000000000..814a2b07c318 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorNormFunction.java @@ -0,0 +1,46 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.metamodel.model.domain.ReturnableType; +import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor; +import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; +import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; +import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.sql.ast.tree.SqlAstNode; +import org.hibernate.sql.ast.tree.expression.Expression; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.spi.TypeConfiguration; + +import java.util.List; + +public class PGVectorNormFunction extends AbstractSqmSelfRenderingFunctionDescriptor { + public PGVectorNormFunction(TypeConfiguration typeConfiguration) { + super( + "vector_norm", + StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 1 ), + VectorArgumentValidator.INSTANCE + ), + StandardFunctionReturnTypeResolvers.invariant( typeConfiguration.getBasicTypeForJavaType( Double.class ) ), + VectorArgumentTypeResolver.INSTANCE + ); + } + + @Override + public void render(SqlAppender sqlAppender, List sqlAstArguments, ReturnableType returnType, SqlAstTranslator walker) { + final Expression expression = (Expression) sqlAstArguments.get( 0 ); + sqlAppender.append( + switch ( expression.getExpressionType().getSingleJdbcMapping().getJdbcType().getDefaultSqlTypeCode() ) { + case SqlTypes.SPARSE_VECTOR_FLOAT32, SqlTypes.VECTOR_FLOAT16 -> "l2_norm"; + default -> "vector_norm"; + } + ); + sqlAppender.append( '(' ); + expression.accept( walker ); + sqlAppender.append( ')' ); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorTypeContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorTypeContributor.java new file mode 100644 index 000000000000..352bc4bb28ad --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorTypeContributor.java @@ -0,0 +1,122 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.boot.model.TypeContributions; +import org.hibernate.boot.model.TypeContributor; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.PostgreSQLDialect; +import org.hibernate.engine.jdbc.spi.JdbcServices; +import org.hibernate.service.ServiceRegistry; +import org.hibernate.type.BasicArrayType; +import org.hibernate.type.BasicCollectionType; +import org.hibernate.type.BasicType; +import org.hibernate.type.BasicTypeRegistry; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.StandardBasicTypes; +import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; +import org.hibernate.type.spi.TypeConfiguration; + +public class PGVectorTypeContributor implements TypeContributor { + + @Override + public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) { + final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect(); + if ( dialect instanceof PostgreSQLDialect ) { + final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration(); + final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry(); + final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry(); + final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); + final BasicType floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ); + final ArrayJdbcType genericVectorJdbcType = new PGVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR, + "vector" + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, genericVectorJdbcType ); + final ArrayJdbcType floatVectorJdbcType = new PGVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR_FLOAT32, + "vector" + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT32, floatVectorJdbcType ); + final ArrayJdbcType float16VectorJdbcType = new PGVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + SqlTypes.VECTOR_FLOAT16, + "halfvec" + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT16, float16VectorJdbcType ); + final JdbcType bitVectorJdbcType = new PGBinaryVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.TINYINT ) + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_BINARY, bitVectorJdbcType ); + final JdbcType sparseFloatVectorJdbcType = new PGSparseFloatVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ) + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.SPARSE_VECTOR_FLOAT32, sparseFloatVectorJdbcType ); + + javaTypeRegistry.addDescriptor( SparseFloatVectorJavaType.INSTANCE ); + + basicTypeRegistry.register( + new BasicArrayType<>( + floatBasicType, + genericVectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR.getName() + ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + floatVectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR_FLOAT32.getName() + ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + float16VectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + 
StandardBasicTypes.VECTOR_FLOAT16.getName() + ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.BYTE ), + bitVectorJdbcType, + javaTypeRegistry.getDescriptor( byte[].class ) + ), + StandardBasicTypes.VECTOR_BINARY.getName() + ); + basicTypeRegistry.register( + new BasicCollectionType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + sparseFloatVectorJdbcType, + SparseFloatVectorJavaType.INSTANCE, + "sparse_float_vector" + ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_FLOAT32, "vector($l)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_BINARY, "bit($l)", "bit", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_FLOAT16, "halfvec($l)", "halfvec", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.SPARSE_VECTOR_FLOAT32, "sparsevec($l)", "sparsevec", dialect ) + ); + } + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/SQLServerCastingVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SQLServerCastingVectorJdbcType.java new file mode 100644 index 000000000000..3171f2f8e779 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SQLServerCastingVectorJdbcType.java @@ -0,0 +1,158 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.checkerframework.checker.nullness.qual.Nullable; +import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.engine.spi.SessionFactoryImplementor; +import org.hibernate.metamodel.mapping.JdbcMapping; +import org.hibernate.metamodel.mapping.JdbcMappingContainer; +import org.hibernate.sql.ast.SqlAstTranslator; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.sql.ast.tree.expression.Expression; +import org.hibernate.sql.ast.tree.expression.SelfRenderingExpression; +import org.hibernate.type.descriptor.ValueBinder; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.BasicBinder; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.spi.TypeConfiguration; + +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Arrays; + +import static org.hibernate.vector.internal.VectorHelper.parseFloatVector; + +public class SQLServerCastingVectorJdbcType extends ArrayJdbcType { + + private final int sqlType; + + public SQLServerCastingVectorJdbcType(JdbcType elementJdbcType, int sqlType) { + super( elementJdbcType ); + this.sqlType = sqlType; + } + + @Override + public int getDefaultSqlTypeCode() { + return sqlType; + } + + @Override + public JavaType getJdbcRecommendedJavaTypeMapping( + Integer precision, + Integer scale, + TypeConfiguration typeConfiguration) { + 
return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class ); + } + + @Override + public JdbcLiteralFormatter getJdbcLiteralFormatter(JavaType javaTypeDescriptor) { + return new SQLServerJdbcLiteralFormatterVector<>( + javaTypeDescriptor, + getElementJdbcType().getJdbcLiteralFormatter( elementJavaType( javaTypeDescriptor ) ) + ); + } + + @Override + public @Nullable String castFromPattern(JdbcMapping sourceMapping, @Nullable Size size) { + return sourceMapping.getJdbcType().isStringLike() ? "cast(?1 as vector(" + size.getArrayLength() + "))" : null; + } + + @Override + public Expression wrapTopLevelSelectionExpression(Expression expression) { + return new SelfRenderingExpression() { + @Override + public void renderToSql( + SqlAppender sqlAppender, + SqlAstTranslator walker, + SessionFactoryImplementor sessionFactory) { + sqlAppender.append( "cast(" ); + expression.accept( walker ); + sqlAppender.append( " as nvarchar(max))" ); + } + + @Override + public JdbcMappingContainer getExpressionType() { + return expression.getExpressionType(); + } + }; + } + + @Override + public void appendWriteExpression( + String writeExpression, + @Nullable Size size, + SqlAppender appender, + Dialect dialect) { + appender.appendSql( "cast(" ); + appender.appendSql( writeExpression ); + appender.appendSql( " as vector(" ); + appender.appendSql( size.getArrayLength() ); + appender.appendSql( "))" ); + } + + @Override + public ValueExtractor getExtractor(JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseFloatVector( rs.getString( paramIndex ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseFloatVector( statement.getString( index ) ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( parseFloatVector( statement.getString( name ) ), options ); + } + + }; + } + + @Override + public ValueBinder getBinder(final JavaType javaTypeDescriptor) { + return new BasicBinder<>( javaTypeDescriptor, this ) { + + @Override + protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) throws SQLException { + st.setString( index, getBindValue( value, options ) ); + } + + @Override + protected void doBind(CallableStatement st, X value, String name, WrapperOptions options) + throws SQLException { + st.setString( name, getBindValue( value, options ) ); + } + + @Override + public String getBindValue(X value, WrapperOptions options) { + return Arrays.toString( getJavaType().unwrap( value, float[].class, options ) ); + } + }; + } + + @Override + public boolean equals(Object that) { + return super.equals( that ) + && that instanceof SQLServerCastingVectorJdbcType vectorJdbcType + && sqlType == vectorJdbcType.sqlType; + } + + @Override + public int hashCode() { + return sqlType + 31 * super.hashCode(); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/SQLServerJdbcLiteralFormatterVector.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SQLServerJdbcLiteralFormatterVector.java new file mode 100644 index 000000000000..77543af76ab5 --- /dev/null +++ 
b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SQLServerJdbcLiteralFormatterVector.java @@ -0,0 +1,38 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.spi.BasicJdbcLiteralFormatter; + +public class SQLServerJdbcLiteralFormatterVector extends BasicJdbcLiteralFormatter { + + private final JdbcLiteralFormatter elementFormatter; + + public SQLServerJdbcLiteralFormatterVector(JavaType javaType, JdbcLiteralFormatter elementFormatter) { + super( javaType ); + //noinspection unchecked + this.elementFormatter = (JdbcLiteralFormatter) elementFormatter; + } + + @Override + public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) { + final Object[] objects = unwrap( value, Object[].class, wrapperOptions ); + appender.appendSql( "cast('" ); + char separator = '['; + for ( Object o : objects ) { + appender.appendSql( separator ); + elementFormatter.appendJdbcLiteral( appender, o, dialect, wrapperOptions ); + separator = ','; + } + appender.appendSql( "]' as vector(" ); + appender.appendSql( objects.length ); + appender.appendSql( "))" ); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/SQLServerTypeContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SQLServerTypeContributor.java new file mode 100644 index 000000000000..9c095e51871c --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SQLServerTypeContributor.java @@ -0,0 +1,108 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.HibernateError; +import org.hibernate.boot.model.TypeContributions; +import org.hibernate.boot.model.TypeContributor; +import org.hibernate.boot.registry.classloading.spi.ClassLoaderService; +import org.hibernate.boot.registry.classloading.spi.ClassLoadingException; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.SQLServerDialect; +import org.hibernate.engine.jdbc.spi.JdbcServices; +import org.hibernate.service.ServiceRegistry; +import org.hibernate.type.BasicArrayType; +import org.hibernate.type.BasicType; +import org.hibernate.type.BasicTypeRegistry; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.StandardBasicTypes; +import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; +import org.hibernate.type.spi.TypeConfiguration; + +import java.lang.reflect.InvocationTargetException; + +public class SQLServerTypeContributor implements TypeContributor { + + @Override + public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) { + final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect(); + if ( dialect instanceof SQLServerDialect && dialect.getVersion().isSameOrAfter( 17 ) ) { + final boolean supportsDriverType = supportsDriverType( serviceRegistry ); + final String vectorJdbcType = supportsDriverType ? "org.hibernate.vector.internal.SQLServerVectorJdbcType" + : "org.hibernate.vector.internal.SQLServerCastingVectorJdbcType"; + final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration(); + final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry(); + final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry(); + final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); + final BasicType floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ); + final JdbcType floatJdbcType = jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ); + final ArrayJdbcType genericVectorJdbcType = create( + serviceRegistry, + vectorJdbcType, + floatJdbcType, + SqlTypes.VECTOR + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, genericVectorJdbcType ); + final ArrayJdbcType floatVectorJdbcType = create( + serviceRegistry, + vectorJdbcType, + floatJdbcType, + SqlTypes.VECTOR_FLOAT32 + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT32, floatVectorJdbcType ); + basicTypeRegistry.register( + new BasicArrayType<>( + floatBasicType, + genericVectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR.getName() + ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + floatVectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR_FLOAT32.getName() + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new VectorDdlType( SqlTypes.VECTOR_FLOAT32, "vector($l)", "vector", dialect ) + ); + } + } + + private static boolean supportsDriverType(ServiceRegistry serviceRegistry) { + final ClassLoaderService classLoaderService = serviceRegistry.requireService( 
ClassLoaderService.class ); + try { + classLoaderService.classForName( "microsoft.sql.Vector" ); + return true; + } + catch (ClassLoadingException ex) { + return false; + } + } + + private static X create(ServiceRegistry serviceRegistry, String className, JdbcType elementType, int sqlType) { + final ClassLoaderService classLoaderService = serviceRegistry.requireService( ClassLoaderService.class ); + try { + return classLoaderService.classForName( className ) + .getConstructor( JdbcType.class, int.class ) + .newInstance( elementType, sqlType ); + } + catch (NoSuchMethodException e) { + throw new HibernateError( "Class does not have a (JdbcType, int) constructor", e ); + } + catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + throw new HibernateError( "Could not construct JdbcType", e ); + } + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/SQLServerVectorFunctionContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SQLServerVectorFunctionContributor.java new file mode 100644 index 000000000000..1940e2350d06 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SQLServerVectorFunctionContributor.java @@ -0,0 +1,49 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.boot.model.FunctionContributions; +import org.hibernate.boot.model.FunctionContributor; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.SQLServerDialect; +import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; +import org.hibernate.type.BasicType; +import org.hibernate.type.spi.TypeConfiguration; + +public class SQLServerVectorFunctionContributor implements FunctionContributor { + + @Override + public void contributeFunctions(FunctionContributions functionContributions) { + final Dialect dialect = functionContributions.getDialect(); + if ( dialect instanceof SQLServerDialect && dialect.getVersion().isSameOrAfter( 17 ) ) { + final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions ); + + vectorFunctionFactory.cosineDistance( "vector_distance('cosine',?1,?2)" ); + vectorFunctionFactory.euclideanDistance( "vector_distance('euclidean',?1,?2)" ); + vectorFunctionFactory.euclideanSquaredDistance( "square(vector_distance('euclidean',?1,?2))" ); + + vectorFunctionFactory.innerProduct( "vector_distance('dot',?1,?2)*-1" ); + vectorFunctionFactory.negativeInnerProduct( "vector_distance('dot',?1,?2)" ); + + final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration(); + final BasicType integerType = typeConfiguration.getBasicTypeForJavaType( Integer.class ); + final BasicType doubleType = typeConfiguration.getBasicTypeForJavaType( Double.class ); + + vectorFunctionFactory.registerPatternVectorFunction( "vector_dims", "vectorproperty(?1,'Dimensions')", integerType, 1 ); + vectorFunctionFactory.registerPatternVectorFunction( "vector_norm", "vector_norm(?1,'norm2')", doubleType, 1 ); + functionContributions.getFunctionRegistry().registerAlternateKey( "l2_norm", "vector_norm" ); + functionContributions.getFunctionRegistry().patternDescriptorBuilder( "l2_normalize", "vector_normalize(?1,'norm2')" ) + .setArgumentsValidator( VectorArgumentValidator.INSTANCE ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.useArgType( 1 ) ) 
.register(); + } + } + + @Override + public int ordinal() { + return 200; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/SQLServerVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SQLServerVectorJdbcType.java new file mode 100644 index 000000000000..04e0f1fafc17 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SQLServerVectorJdbcType.java @@ -0,0 +1,107 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import microsoft.sql.Vector; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.sql.ast.tree.expression.Expression; +import org.hibernate.type.descriptor.ValueBinder; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.BasicBinder; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcType; + +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +public class SQLServerVectorJdbcType extends SQLServerCastingVectorJdbcType { + + public SQLServerVectorJdbcType(JdbcType elementJdbcType, int sqlType) { + super( elementJdbcType, sqlType ); + } + + @Override + public Expression wrapTopLevelSelectionExpression(Expression expression) { + return expression; + } + + @Override + public void appendWriteExpression( + String writeExpression, + @Nullable Size size, + SqlAppender appender, + Dialect dialect) { + appender.append( writeExpression ); + } + + @Override + public ValueExtractor getExtractor(JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + return getValue( rs.getObject( paramIndex, Vector.class ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + return getValue( statement.getObject( index, Vector.class ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException { + return getValue( statement.getObject( name, Vector.class ), options ); + } + + private X getValue(Vector vector, WrapperOptions options) { + if ( vector == null ) { + return null; + } + return getJavaType().wrap( vector.getData(), options ); + } + + }; + } + + @Override + public ValueBinder getBinder(final JavaType javaTypeDescriptor) { + return new BasicBinder<>( javaTypeDescriptor, this ) { + + @Override + protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) throws SQLException { + st.setObject( index, getBindValue( value, options ), microsoft.sql.Types.VECTOR ); + } + + @Override + protected void doBind(CallableStatement st, X value, String name, WrapperOptions options) + throws SQLException { + st.setObject( name, getBindValue( value, options ), microsoft.sql.Types.VECTOR ); + } + + @Override + protected void doBindNull(PreparedStatement st, int index, WrapperOptions options) throws SQLException { + st.setNull( index, microsoft.sql.Types.VECTOR ); + } + 
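Editor's note, purely as an illustrative usage sketch and not part of the patch: once this module's type and function contributors are on the classpath, an application could map and query a float vector roughly as follows. The entity, column, and dimension below are hypothetical; the @JdbcTypeCode/@Array annotations, SqlTypes.VECTOR_FLOAT32, and the cosine_distance HQL function are the pieces contributed above.

import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import org.hibernate.Session;
import org.hibernate.annotations.Array;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.type.SqlTypes;
import java.util.List;

@Entity
class EmbeddingEntity {
	@Id
	Long id;

	// Hypothetical mapping: stored as a 3-dimensional float vector column on dialects
	// that register the vector DDL type (e.g. pgvector, SQL Server 17+)
	@Column(name = "embedding")
	@JdbcTypeCode(SqlTypes.VECTOR_FLOAT32)
	@Array(length = 3)
	float[] embedding;
}

class VectorQueryExample {
	// Orders rows by cosine distance to a probe vector; cosine_distance is one of the
	// functions registered by the vector function contributors in this module
	static List<EmbeddingEntity> nearest(Session session, float[] probe) {
		return session.createSelectionQuery(
						"from EmbeddingEntity e order by cosine_distance(e.embedding, :probe)",
						EmbeddingEntity.class )
				.setParameter( "probe", probe )
				.setMaxResults( 10 )
				.getResultList();
	}
}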
+ @Override + protected void doBindNull(CallableStatement st, String name, WrapperOptions options) throws SQLException { + st.setNull( name, microsoft.sql.Types.VECTOR ); + } + + @Override + public Object getBindValue(X value, WrapperOptions options) { + final Float[] floats = getJavaType().unwrap( value, Float[].class, options ); + return new Vector( floats.length, Vector.VectorDimensionType.FLOAT32, floats ); + } + }; + } + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseByteVectorJavaType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseByteVectorJavaType.java new file mode 100644 index 000000000000..47c353b80e4b --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseByteVectorJavaType.java @@ -0,0 +1,116 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.tool.schema.extract.spi.ColumnTypeInformation; +import org.hibernate.type.BasicCollectionType; +import org.hibernate.type.BasicType; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.AbstractClassJavaType; +import org.hibernate.type.descriptor.java.BasicPluralJavaType; +import org.hibernate.type.descriptor.java.ByteJavaType; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.java.MutableMutabilityPlan; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.descriptor.jdbc.JdbcTypeIndicators; +import org.hibernate.type.spi.TypeConfiguration; +import org.hibernate.vector.SparseByteVector; + +import java.util.Arrays; +import java.util.List; + + +public class SparseByteVectorJavaType extends AbstractClassJavaType implements BasicPluralJavaType { + + public static final SparseByteVectorJavaType INSTANCE = new SparseByteVectorJavaType(); + + public SparseByteVectorJavaType() { + super( SparseByteVector.class, new SparseVectorMutabilityPlan() ); + } + + @Override + public JavaType getElementJavaType() { + return ByteJavaType.INSTANCE; + } + + @Override + public BasicType resolveType(TypeConfiguration typeConfiguration, Dialect dialect, BasicType elementType, ColumnTypeInformation columnTypeInformation, JdbcTypeIndicators stdIndicators) { + final int arrayTypeCode = stdIndicators.getPreferredSqlTypeCodeForArray( elementType.getJdbcType().getDefaultSqlTypeCode() ); + final JdbcType arrayJdbcType = typeConfiguration.getJdbcTypeRegistry() + .resolveTypeConstructorDescriptor( arrayTypeCode, elementType, columnTypeInformation ); + if ( elementType.getValueConverter() != null ) { + throw new IllegalArgumentException( "Can't convert element type of sparse vector" ); + } + return typeConfiguration.getBasicTypeRegistry() + .resolve( this, arrayJdbcType, + () -> new BasicCollectionType<>( elementType, arrayJdbcType, this, "sparse_byte_vector" ) ); + } + + @Override + public JdbcType getRecommendedJdbcType(JdbcTypeIndicators indicators) { + return indicators.getJdbcType( SqlTypes.SPARSE_VECTOR_INT8 ); + } + + @Override + public X unwrap(SparseByteVector value, Class type, WrapperOptions options) { + if ( value == null ) { + return null; + } + else if ( type.isInstance( value ) ) { + //noinspection unchecked + return (X) value; + } + else if ( byte[].class.isAssignableFrom( type ) ) { + return (X) value.toDenseVector(); + } + else if ( Object[].class.isAssignableFrom( 
type ) ) { + //noinspection unchecked + return (X) value.toArray(); + } + else if ( String.class.isAssignableFrom( type ) ) { + //noinspection unchecked + return (X) value.toString(); + } + else { + throw unknownUnwrap( type ); + } + } + + @Override + public SparseByteVector wrap(X value, WrapperOptions options) { + if ( value == null ) { + return null; + } + else if (value instanceof SparseByteVector vector) { + return vector; + } + else if (value instanceof List list) { + //noinspection unchecked + return new SparseByteVector( (List) list ); + } + else if (value instanceof Object[] array) { + //noinspection unchecked + return new SparseByteVector( (List) (List) Arrays.asList( array ) ); + } + else if (value instanceof byte[] vector) { + return new SparseByteVector( vector ); + } + else if (value instanceof String vector) { + return new SparseByteVector( vector ); + } + else { + throw unknownWrap( value.getClass() ); + } + } + + private static class SparseVectorMutabilityPlan extends MutableMutabilityPlan { + @Override + protected SparseByteVector deepCopyNotNull(SparseByteVector value) { + return value.clone(); + } + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseDoubleVectorJavaType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseDoubleVectorJavaType.java new file mode 100644 index 000000000000..d43ae847671a --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseDoubleVectorJavaType.java @@ -0,0 +1,116 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.tool.schema.extract.spi.ColumnTypeInformation; +import org.hibernate.type.BasicCollectionType; +import org.hibernate.type.BasicType; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.AbstractClassJavaType; +import org.hibernate.type.descriptor.java.BasicPluralJavaType; +import org.hibernate.type.descriptor.java.DoubleJavaType; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.java.MutableMutabilityPlan; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.descriptor.jdbc.JdbcTypeIndicators; +import org.hibernate.type.spi.TypeConfiguration; +import org.hibernate.vector.SparseDoubleVector; + +import java.util.Arrays; +import java.util.List; + + +public class SparseDoubleVectorJavaType extends AbstractClassJavaType implements BasicPluralJavaType { + + public static final SparseDoubleVectorJavaType INSTANCE = new SparseDoubleVectorJavaType(); + + public SparseDoubleVectorJavaType() { + super( SparseDoubleVector.class, new SparseVectorMutabilityPlan() ); + } + + @Override + public JavaType getElementJavaType() { + return DoubleJavaType.INSTANCE; + } + + @Override + public BasicType resolveType(TypeConfiguration typeConfiguration, Dialect dialect, BasicType elementType, ColumnTypeInformation columnTypeInformation, JdbcTypeIndicators stdIndicators) { + final int arrayTypeCode = stdIndicators.getPreferredSqlTypeCodeForArray( elementType.getJdbcType().getDefaultSqlTypeCode() ); + final JdbcType arrayJdbcType = typeConfiguration.getJdbcTypeRegistry() + .resolveTypeConstructorDescriptor( arrayTypeCode, elementType, columnTypeInformation ); + if ( elementType.getValueConverter() != null ) { + throw new IllegalArgumentException( "Can't convert element type 
of sparse vector" ); + return typeConfiguration.getBasicTypeRegistry() + .resolve( this, arrayJdbcType, + () -> new BasicCollectionType<>( elementType, arrayJdbcType, this, "sparse_double_vector" ) ); + } + + @Override + public JdbcType getRecommendedJdbcType(JdbcTypeIndicators indicators) { + return indicators.getJdbcType( SqlTypes.SPARSE_VECTOR_FLOAT64 ); + } + + @Override + public X unwrap(SparseDoubleVector value, Class type, WrapperOptions options) { + if ( value == null ) { + return null; + } + else if ( type.isInstance( value ) ) { + //noinspection unchecked + return (X) value; + } + else if ( double[].class.isAssignableFrom( type ) ) { + return (X) value.toDenseVector(); + } + else if ( Object[].class.isAssignableFrom( type ) ) { + //noinspection unchecked + return (X) value.toArray(); + } + else if ( String.class.isAssignableFrom( type ) ) { + //noinspection unchecked + return (X) value.toString(); + } + else { + throw unknownUnwrap( type ); + } + } + + @Override + public SparseDoubleVector wrap(X value, WrapperOptions options) { + if ( value == null ) { + return null; + } + else if (value instanceof SparseDoubleVector vector) { + return vector; + } + else if (value instanceof List list) { + //noinspection unchecked + return new SparseDoubleVector( (List) list ); + } + else if (value instanceof Object[] array) { + //noinspection unchecked + return new SparseDoubleVector( (List) (List) Arrays.asList( array ) ); + } + else if (value instanceof double[] vector) { + return new SparseDoubleVector( vector ); + } + else if (value instanceof String vector) { + return new SparseDoubleVector( vector ); + } + else { + throw unknownWrap( value.getClass() ); + } + } + + private static class SparseVectorMutabilityPlan extends MutableMutabilityPlan { + @Override + protected SparseDoubleVector deepCopyNotNull(SparseDoubleVector value) { + return value.clone(); + } + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseFloatVectorJavaType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseFloatVectorJavaType.java new file mode 100644 index 000000000000..df1dad7444c8 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/SparseFloatVectorJavaType.java @@ -0,0 +1,116 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.tool.schema.extract.spi.ColumnTypeInformation; +import org.hibernate.type.BasicCollectionType; +import org.hibernate.type.BasicType; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.AbstractClassJavaType; +import org.hibernate.type.descriptor.java.BasicPluralJavaType; +import org.hibernate.type.descriptor.java.FloatJavaType; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.java.MutableMutabilityPlan; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.descriptor.jdbc.JdbcTypeIndicators; +import org.hibernate.type.spi.TypeConfiguration; +import org.hibernate.vector.SparseFloatVector; + +import java.util.Arrays; +import java.util.List; + + +public class SparseFloatVectorJavaType extends AbstractClassJavaType implements BasicPluralJavaType { + + public static final SparseFloatVectorJavaType INSTANCE = new SparseFloatVectorJavaType(); + + public SparseFloatVectorJavaType() { + super( SparseFloatVector.class, new SparseVectorMutabilityPlan() ); + } + + @Override + public JavaType getElementJavaType() { + return FloatJavaType.INSTANCE; + } + + @Override + public BasicType resolveType(TypeConfiguration typeConfiguration, Dialect dialect, BasicType elementType, ColumnTypeInformation columnTypeInformation, JdbcTypeIndicators stdIndicators) { + final int arrayTypeCode = stdIndicators.getPreferredSqlTypeCodeForArray( elementType.getJdbcType().getDefaultSqlTypeCode() ); + final JdbcType arrayJdbcType = typeConfiguration.getJdbcTypeRegistry() + .resolveTypeConstructorDescriptor( arrayTypeCode, elementType, columnTypeInformation ); + if ( elementType.getValueConverter() != null ) { + throw new IllegalArgumentException( "Can't convert element type of sparse vector" ); + } + return typeConfiguration.getBasicTypeRegistry() + .resolve( this, arrayJdbcType, + () -> new BasicCollectionType<>( elementType, arrayJdbcType, this, "sparse_float_vector" ) ); + } + + @Override + public JdbcType getRecommendedJdbcType(JdbcTypeIndicators indicators) { + return indicators.getJdbcType( SqlTypes.SPARSE_VECTOR_FLOAT32 ); + } + + @Override + public X unwrap(SparseFloatVector value, Class type, WrapperOptions options) { + if ( value == null ) { + return null; + } + else if ( type.isInstance( value ) ) { + //noinspection unchecked + return (X) value; + } + else if ( float[].class.isAssignableFrom( type ) ) { + return (X) value.toDenseVector(); + } + else if ( Object[].class.isAssignableFrom( type ) ) { + //noinspection unchecked + return (X) value.toArray(); + } + else if ( String.class.isAssignableFrom( type ) ) { + //noinspection unchecked + return (X) value.toString(); + } + else { + throw unknownUnwrap( type ); + } + } + + @Override + public SparseFloatVector wrap(X value, WrapperOptions options) { + if ( value == null ) { + return null; + } + else if (value instanceof SparseFloatVector vector) { + return vector; + } + else if (value instanceof List list) { + //noinspection unchecked + return new SparseFloatVector( (List) list ); + } + else if (value instanceof Object[] array) { + //noinspection unchecked + return new SparseFloatVector( (List) (List) Arrays.asList( array ) ); + } + else if (value instanceof float[] vector) { + return new SparseFloatVector( vector ); + } + else if (value instanceof String vector) { + return new 
SparseFloatVector( vector ); + } + else { + throw unknownWrap( value.getClass() ); + } + } + + private static class SparseVectorMutabilityPlan extends MutableMutabilityPlan { + @Override + protected SparseFloatVector deepCopyNotNull(SparseFloatVector value) { + return value.clone(); + } + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/VectorArgumentTypeResolver.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorArgumentTypeResolver.java similarity index 81% rename from hibernate-vector/src/main/java/org/hibernate/vector/VectorArgumentTypeResolver.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorArgumentTypeResolver.java index 4af45fd9f44e..d867679a5acf 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/VectorArgumentTypeResolver.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorArgumentTypeResolver.java @@ -2,7 +2,7 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import java.util.List; @@ -21,11 +21,18 @@ */ public class VectorArgumentTypeResolver implements AbstractFunctionArgumentTypeResolver { - public static final FunctionArgumentTypeResolver INSTANCE = new VectorArgumentTypeResolver(); + public static final FunctionArgumentTypeResolver INSTANCE = new VectorArgumentTypeResolver( 0 ); + public static final FunctionArgumentTypeResolver DISTANCE_INSTANCE = new VectorArgumentTypeResolver( 0, 1 ); + + private final int[] vectorIndices; + + public VectorArgumentTypeResolver(int... vectorIndices) { + this.vectorIndices = vectorIndices; + } @Override public @Nullable MappingModelExpressible resolveFunctionArgumentType(List> arguments, int argumentIndex, SqmToSqlAstConverter converter) { - for ( int i = 0; i < arguments.size(); i++ ) { + for ( int i : vectorIndices ) { if ( i != argumentIndex ) { final SqmTypedNode node = arguments.get( i ); if ( node instanceof SqmExpression ) { diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/VectorArgumentValidator.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorArgumentValidator.java similarity index 69% rename from hibernate-vector/src/main/java/org/hibernate/vector/VectorArgumentValidator.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorArgumentValidator.java index 4bd1632c50be..b15f703d2b97 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/VectorArgumentValidator.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorArgumentValidator.java @@ -2,17 +2,17 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. 
and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import java.util.List; +import org.hibernate.type.BasicType; import org.hibernate.type.BindingContext; import org.hibernate.query.sqm.SqmExpressible; import org.hibernate.query.sqm.produce.function.ArgumentsValidator; import org.hibernate.query.sqm.produce.function.FunctionArgumentException; import org.hibernate.query.sqm.tree.SqmTypedNode; import org.hibernate.query.sqm.tree.domain.SqmDomainType; -import org.hibernate.type.BasicPluralType; import org.hibernate.type.SqlTypes; /** @@ -20,14 +20,21 @@ */ public class VectorArgumentValidator implements ArgumentsValidator { - public static final ArgumentsValidator INSTANCE = new VectorArgumentValidator(); + public static final ArgumentsValidator INSTANCE = new VectorArgumentValidator( 0 ); + public static final ArgumentsValidator DISTANCE_INSTANCE = new VectorArgumentValidator( 0, 1 ); + + private final int[] vectorIndices; + + public VectorArgumentValidator(int... vectorIndices) { + this.vectorIndices = vectorIndices; + } @Override public void validate( List> arguments, String functionName, BindingContext bindingContext) { - for ( int i = 0; i < arguments.size(); i++ ) { + for ( int i : vectorIndices ) { final SqmExpressible expressible = arguments.get( i ).getExpressible(); if ( expressible != null ) { final SqmDomainType type = expressible.getSqmType(); @@ -46,9 +53,10 @@ public void validate( } private static boolean isVectorType(SqmExpressible vectorType) { - return vectorType instanceof BasicPluralType basicPluralType - && switch ( basicPluralType.getJdbcType().getDefaultSqlTypeCode() ) { - case SqlTypes.VECTOR, SqlTypes.VECTOR_INT8, SqlTypes.VECTOR_FLOAT32, SqlTypes.VECTOR_FLOAT64 -> true; + return vectorType instanceof BasicType basicType + && switch ( basicType.getJdbcType().getDefaultSqlTypeCode() ) { + case SqlTypes.VECTOR, SqlTypes.VECTOR_INT8, SqlTypes.VECTOR_FLOAT16, SqlTypes.VECTOR_FLOAT32, SqlTypes.VECTOR_FLOAT64, + SqlTypes.VECTOR_BINARY, SqlTypes.SPARSE_VECTOR_INT8, SqlTypes.SPARSE_VECTOR_FLOAT32, SqlTypes.SPARSE_VECTOR_FLOAT64-> true; default -> false; }; } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorDdlType.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorDdlType.java new file mode 100644 index 000000000000..d14d360351f6 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorDdlType.java @@ -0,0 +1,46 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.hibernate.dialect.Dialect; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl; + +/** + * DDL type for vector types. 
+ * + * @since 7.1 + */ +public class VectorDdlType extends DdlTypeImpl { + + public VectorDdlType(int sqlTypeCode, boolean isLob, String typeNamePattern, String castTypeNamePattern, String castTypeName, Dialect dialect) { + super( sqlTypeCode, isLob, typeNamePattern, castTypeNamePattern, castTypeName, dialect ); + } + + public VectorDdlType(int sqlTypeCode, String typeNamePattern, String castTypeNamePattern, String castTypeName, Dialect dialect) { + super( sqlTypeCode, typeNamePattern, castTypeNamePattern, castTypeName, dialect ); + } + + public VectorDdlType(int sqlTypeCode, boolean isLob, String typeNamePattern, String castTypeName, Dialect dialect) { + super( sqlTypeCode, isLob, typeNamePattern, castTypeName, dialect ); + } + + public VectorDdlType(int sqlTypeCode, String typeNamePattern, String castTypeName, Dialect dialect) { + super( sqlTypeCode, typeNamePattern, castTypeName, dialect ); + } + + public VectorDdlType(int sqlTypeCode, String typeNamePattern, Dialect dialect) { + super( sqlTypeCode, typeNamePattern, dialect ); + } + + @Override + public String getTypeName(Size size) { + return getTypeName( + size.getArrayLength() == null ? null : size.getArrayLength().longValue(), + null, + null + ); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/VectorFunctionFactory.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorFunctionFactory.java similarity index 78% rename from hibernate-vector/src/main/java/org/hibernate/vector/VectorFunctionFactory.java rename to hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorFunctionFactory.java index 71679aeac05a..a0fd4a6b5ce4 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/VectorFunctionFactory.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorFunctionFactory.java @@ -2,7 +2,7 @@ * SPDX-License-Identifier: Apache-2.0 * Copyright Red Hat Inc. 
and Hibernate Authors */ -package org.hibernate.vector; +package org.hibernate.vector.internal; import org.hibernate.boot.model.FunctionContributions; import org.hibernate.query.sqm.function.SqmFunctionRegistry; @@ -41,6 +41,11 @@ public void euclideanDistance(String pattern) { functionRegistry.registerAlternateKey( "l2_distance", "euclidean_distance" ); } + public void euclideanSquaredDistance(String pattern) { + registerVectorDistanceFunction( "euclidean_squared_distance", pattern ); + functionRegistry.registerAlternateKey( "l2_squared_distance", "euclidean_squared_distance" ); + } + public void l1Distance(String pattern) { registerVectorDistanceFunction( "l1_distance", pattern ); functionRegistry.registerAlternateKey( "taxicab_distance", "l1_distance" ); @@ -58,6 +63,10 @@ public void hammingDistance(String pattern) { registerVectorDistanceFunction( "hamming_distance", pattern ); } + public void jaccardDistance(String pattern) { + registerVectorDistanceFunction( "jaccard_distance", pattern ); + } + public void vectorDimensions() { registerNamedVectorFunction( "vector_dims", integerType, 1 ); } @@ -70,9 +79,9 @@ public void registerVectorDistanceFunction(String functionName, String pattern) functionRegistry.patternDescriptorBuilder( functionName, pattern ) .setArgumentsValidator( StandardArgumentsValidators.composite( StandardArgumentsValidators.exactly( 2 ), - VectorArgumentValidator.INSTANCE + VectorArgumentValidator.DISTANCE_INSTANCE ) ) - .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.DISTANCE_INSTANCE ) .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) ) .register(); } @@ -88,4 +97,15 @@ public void registerNamedVectorFunction(String functionName, BasicType return .register(); } + public void registerPatternVectorFunction(String functionName, String pattern, BasicType returnType, int argumentCount) { + functionRegistry.patternDescriptorBuilder( functionName, pattern ) + .setArgumentsValidator( StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( argumentCount ), + VectorArgumentValidator.INSTANCE + ) ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( returnType ) ) + .register(); + } + } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorHelper.java b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorHelper.java new file mode 100644 index 000000000000..6ca866d625c2 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/internal/VectorHelper.java @@ -0,0 +1,174 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector.internal; + +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.util.BitSet; + +/** + * Helper for vector related functionality. 
+ * + * @since 7.1 + */ +public class VectorHelper { + + private static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; + private static final float[] EMPTY_FLOAT_ARRAY = new float[0]; + private static final double[] EMPTY_DOUBLE_ARRAY = new double[0]; + + public static @Nullable byte[] parseByteVector(@Nullable String string) { + if ( string == null ) { + return null; + } + if ( string.length() == 2 ) { + return EMPTY_BYTE_ARRAY; + } + final BitSet commaPositions = new BitSet(); + int size = 1; + for ( int i = 1; i < string.length(); i++ ) { + final char c = string.charAt( i ); + if ( c == ',' ) { + commaPositions.set( i ); + size++; + } + } + final byte[] result = new byte[size]; + int doubleStartIndex = 1; + int commaIndex; + int index = 0; + while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) { + result[index++] = Byte.parseByte( string.substring( doubleStartIndex, commaIndex ) ); + doubleStartIndex = commaIndex + 1; + } + result[index] = Byte.parseByte( string.substring( doubleStartIndex, string.length() - 1 ) ); + return result; + } + + public static @Nullable float[] parseFloatVector(@Nullable String string) { + if ( string == null ) { + return null; + } + if ( string.length() == 2 ) { + return EMPTY_FLOAT_ARRAY; + } + final BitSet commaPositions = new BitSet(); + int size = 1; + for ( int i = 1; i < string.length(); i++ ) { + final char c = string.charAt( i ); + if ( c == ',' ) { + commaPositions.set( i ); + size++; + } + } + final float[] result = new float[size]; + int doubleStartIndex = 1; + int commaIndex; + int index = 0; + while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) { + result[index++] = Float.parseFloat( string.substring( doubleStartIndex, commaIndex ) ); + doubleStartIndex = commaIndex + 1; + } + result[index] = Float.parseFloat( string.substring( doubleStartIndex, string.length() - 1 ) ); + return result; + } + + public static @Nullable double[] parseDoubleVector(@Nullable String string) { + if ( string == null ) { + return null; + } + if ( string.length() == 2 ) { + return EMPTY_DOUBLE_ARRAY; + } + final BitSet commaPositions = new BitSet(); + int size = 1; + for ( int i = 1; i < string.length(); i++ ) { + final char c = string.charAt( i ); + if ( c == ',' ) { + commaPositions.set( i ); + size++; + } + } + final double[] result = new double[size]; + int doubleStartIndex = 1; + int commaIndex; + int index = 0; + while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) { + result[index++] = Double.parseDouble( string.substring( doubleStartIndex, commaIndex ) ); + doubleStartIndex = commaIndex + 1; + } + result[index] = Double.parseDouble( string.substring( doubleStartIndex, string.length() - 1 ) ); + return result; + } + + public static @Nullable float[] parseFloatVector(@Nullable byte[] bytes) { + if ( bytes == null ) { + return null; + } + if ( bytes.length == 0 ) { + return EMPTY_FLOAT_ARRAY; + } + if ( (bytes.length & 3) != 0 ) { + throw new IllegalArgumentException( + "Invalid byte array length. 
Expected a multiple of 4 but got: " + bytes.length ); + } + final float[] result = new float[bytes.length >> 2]; + for ( int i = 0, resultLength = result.length; i < resultLength; i++ ) { + final int offset = i << 2; + final int asInt = (bytes[offset] & 0xFF) + | ((bytes[offset + 1] & 0xFF) << 8) + | ((bytes[offset + 2] & 0xFF) << 16) + | ((bytes[offset + 3] & 0xFF) << 24); + result[i] = Float.intBitsToFloat( asInt ); + } + return result; + } + + public static byte[] parseBitString(String bitString) { + assert new BigInteger( "1" + bitString, 2 ).bitLength() == bitString.length() + 1; + final int fullBytesCount = bitString.length() >> 3; + final int fullBytesStartPosition = ((bitString.length() & 7) == 0 ? 0 : 1); + final int byteCount = fullBytesCount + fullBytesStartPosition; + final byte[] bytes = new byte[byteCount]; + final int fullBytesBitCount = fullBytesCount << 3; + final int leadingBits = bitString.length() - fullBytesBitCount; + if ( leadingBits > 0 ) { + for (int i = 0; i < leadingBits; i++ ) { + bytes[0] |= (byte) (((bitString.charAt( i ) - 48)) << (7 - i)); + } + } + for ( int i = fullBytesStartPosition; i < fullBytesCount; i ++ ) { + bytes[i] = (byte) ( + ((bitString.charAt( i * 8 + 0 ) - 48) << 7) + | ((bitString.charAt( i * 8 + 1 ) - 48) << 6) + | ((bitString.charAt( i * 8 + 2 ) - 48) << 5) + | ((bitString.charAt( i * 8 + 3 ) - 48) << 4) + | ((bitString.charAt( i * 8 + 4 ) - 48) << 3) + | ((bitString.charAt( i * 8 + 5 ) - 48) << 2) + | ((bitString.charAt( i * 8 + 6 ) - 48) << 1) + | ((bitString.charAt( i * 8 + 7 ) - 48) << 0) + ); + } + return bytes; + } + + public static String toBitString(byte[] bytes) { + final byte[] bitBytes = new byte[bytes.length * 8]; + for ( int i = 0; i < bytes.length; i++ ) { + final byte b = bytes[i]; + bitBytes[i * 8 + 0] = (byte) (((b >>> 7) & 1) + 48); + bitBytes[i * 8 + 1] = (byte) (((b >>> 6) & 1) + 48); + bitBytes[i * 8 + 2] = (byte) (((b >>> 5) & 1) + 48); + bitBytes[i * 8 + 3] = (byte) (((b >>> 4) & 1) + 48); + bitBytes[i * 8 + 4] = (byte) (((b >>> 3) & 1) + 48); + bitBytes[i * 8 + 5] = (byte) (((b >>> 2) & 1) + 48); + bitBytes[i * 8 + 6] = (byte) (((b >>> 1) & 1) + 48); + bitBytes[i * 8 + 7] = (byte) (((b >>> 0) & 1) + 48); + } + return new String( bitBytes, StandardCharsets.UTF_8 ); + } +} diff --git a/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor b/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor index 6103956ccbd7..bf3db90f772e 100644 --- a/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor +++ b/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor @@ -1,3 +1,8 @@ -org.hibernate.vector.PGVectorFunctionContributor -org.hibernate.vector.OracleVectorFunctionContributor -org.hibernate.vector.MariaDBFunctionContributor +org.hibernate.vector.internal.PGVectorFunctionContributor +org.hibernate.vector.internal.OracleVectorFunctionContributor +org.hibernate.vector.internal.MariaDBFunctionContributor +org.hibernate.vector.internal.MySQLFunctionContributor +org.hibernate.vector.internal.DB2VectorFunctionContributor +org.hibernate.vector.internal.CockroachFunctionContributor +org.hibernate.vector.internal.HANAVectorFunctionContributor +org.hibernate.vector.internal.SQLServerVectorFunctionContributor \ No newline at end of file diff --git a/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor 
b/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor index 11605464c824..ea1206dedc20 100644 --- a/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor +++ b/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor @@ -1,3 +1,8 @@ -org.hibernate.vector.PGVectorTypeContributor -org.hibernate.vector.OracleVectorTypeContributor -org.hibernate.vector.MariaDBTypeContributor +org.hibernate.vector.internal.PGVectorTypeContributor +org.hibernate.vector.internal.OracleVectorTypeContributor +org.hibernate.vector.internal.MariaDBTypeContributor +org.hibernate.vector.internal.MySQLTypeContributor +org.hibernate.vector.internal.DB2VectorTypeContributor +org.hibernate.vector.internal.CockroachTypeContributor +org.hibernate.vector.internal.HANAVectorTypeContributor +org.hibernate.vector.internal.SQLServerTypeContributor \ No newline at end of file diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/BinaryVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/BinaryVectorTest.java new file mode 100644 index 000000000000..e568ba187cdf --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/BinaryVectorTest.java @@ -0,0 +1,275 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.dialect.OracleDialect; +import org.hibernate.dialect.PostgreSQLDialect; +import org.hibernate.dialect.PostgresPlusDialect; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialectFeature; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SkipForDialect; +import org.hibernate.type.SqlTypes; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.hibernate.vector.VectorTestHelper.cosineDistanceBinary; +import static org.hibernate.vector.VectorTestHelper.euclideanDistanceBinary; +import static org.hibernate.vector.VectorTestHelper.euclideanNormBinary; +import static org.hibernate.vector.VectorTestHelper.euclideanSquaredDistanceBinary; +import static org.hibernate.vector.VectorTestHelper.hammingDistanceBinary; +import static org.hibernate.vector.VectorTestHelper.innerProductBinary; +import static org.hibernate.vector.VectorTestHelper.jaccardDistanceBinary; +import static org.hibernate.vector.VectorTestHelper.taxicabDistanceBinary; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +@DomainModel(annotatedClasses = BinaryVectorTest.VectorEntity.class) +@SessionFactory +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsBinaryVectorType.class) +@SkipForDialect(dialectClass = PostgresPlusDialect.class, reason = "Test database does not have the extension enabled") +public class BinaryVectorTest { + + private static final byte[] V1 = new byte[]{ 1, 2, 3 }; + private static final byte[] V2 = new byte[]{ 4, 5, 6 }; + + @BeforeEach + public 
void prepareData(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.persist( new VectorEntity( 1L, V1 ) ); + em.persist( new VectorEntity( 2L, V2 ) ); + } ); + } + + @AfterEach + public void cleanup(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.createMutationQuery( "delete from VectorEntity" ).executeUpdate(); + } ); + } + + @Test + public void testRead(SessionFactoryScope scope) { + scope.inTransaction( em -> { + VectorEntity tableRecord; + tableRecord = em.find( VectorEntity.class, 1L ); + assertArrayEquals( new byte[]{ 1, 2, 3 }, tableRecord.getTheVector() ); + + tableRecord = em.find( VectorEntity.class, 2L ); + assertArrayEquals( new byte[]{ 4, 5, 6 }, tableRecord.getTheVector() ); + } ); + } + + @Test + public void testCast(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final String literal = VectorTestHelper.vectorBinaryStringLiteral( new byte[] {1, 1, 1}, em ); + final Tuple vector = em.createSelectionQuery( "select cast(e.theVector as string), cast('" + literal + "' as binary_vector(3)) from VectorEntity e where e.id = 1", Tuple.class ) + .getSingleResult(); + assertEquals( VectorTestHelper.vectorBinaryStringLiteral( V1, em ), vector.get( 0, String.class ) ); + assertArrayEquals( new byte[]{ 1, 1, 1 }, vector.get( 1, byte[].class ) ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsCosineDistance.class) + @SkipForDialect(dialectClass = PostgreSQLDialect.class, matchSubTypes = true, reason = "Not supported with bit vectors") + public void testCosineDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( cosineDistanceBinary( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( cosineDistanceBinary( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.0000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanSquaredDistance.class) + @SkipForDialect(dialectClass = PostgreSQLDialect.class, matchSubTypes = true, reason = "Not supported with bit vectors") + public void testEuclideanDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanDistanceBinary( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanDistanceBinary( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanDistance.class) + @SkipForDialect(dialectClass = PostgreSQLDialect.class, matchSubTypes = true, reason = "Not supported with bit vectors") + public void testEuclideanSquaredDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 
}; + final List results = em.createSelectionQuery( "select e.id, euclidean_squared_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanSquaredDistanceBinary( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanSquaredDistanceBinary( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsTaxicabDistance.class) + @SkipForDialect(dialectClass = PostgreSQLDialect.class, matchSubTypes = true, reason = "Not supported with bit vectors") + public void testTaxicabDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( taxicabDistanceBinary( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( taxicabDistanceBinary( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsInnerProduct.class) + @SkipForDialect(dialectClass = PostgreSQLDialect.class, matchSubTypes = true, reason = "Not supported with bit vectors") + public void testInnerProduct(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( innerProductBinary( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( innerProductBinary( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( innerProductBinary( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + assertEquals( innerProductBinary( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsHammingDistance.class) + public void testHammingDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::hamming-distance-example[] + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + //end::hamming-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( hammingDistanceBinary( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( hammingDistanceBinary( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = 
DialectFeatureChecks.SupportsJaccardDistance.class) + public void testJaccardDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::jaccard-distance-example[] + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, jaccard_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + //end::jaccard-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( jaccardDistanceBinary( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( jaccardDistanceBinary( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorDims.class) + public void testVectorDims(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final List results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( V1.length * 8, results.get( 0 ).get( 1 ) ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( V2.length * 8, results.get( 1 ).get( 1 ) ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorNorm.class) + @SkipForDialect(dialectClass = PostgreSQLDialect.class, matchSubTypes = true, reason = "Not supported with bit vectors") + @SkipForDialect(dialectClass = OracleDialect.class, reason = "Oracle 23.9 bug") + public void testVectorNorm(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final List results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanNormBinary( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanNormBinary( V2 ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Entity( name = "VectorEntity" ) + public static class VectorEntity { + + @Id + private Long id; + + @Column( name = "the_vector" ) + @JdbcTypeCode(SqlTypes.VECTOR_BINARY) + @Array(length = 24) + private byte[] theVector; + + public VectorEntity() { + } + + public VectorEntity(Long id, byte[] theVector) { + this.id = id; + this.theVector = theVector; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public byte[] getTheVector() { + return theVector; + } + + public void setTheVector(byte[] theVector) { + this.theVector = theVector; + } + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/OracleByteVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/ByteVectorTest.java similarity index 72% rename from hibernate-vector/src/test/java/org/hibernate/vector/OracleByteVectorTest.java rename to hibernate-vector/src/test/java/org/hibernate/vector/ByteVectorTest.java index f10764997053..f7ac36ba6a6a 100644 --- a/hibernate-vector/src/test/java/org/hibernate/vector/OracleByteVectorTest.java +++ b/hibernate-vector/src/test/java/org/hibernate/vector/ByteVectorTest.java @@ -9,13 +9,15 @@ import org.hibernate.annotations.Array; import 
org.hibernate.annotations.JdbcTypeCode; import org.hibernate.dialect.OracleDialect; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; +import org.hibernate.testing.orm.junit.RequiresDialectFeature; import org.hibernate.testing.orm.junit.SkipForDialect; import org.hibernate.type.SqlTypes; import org.hibernate.testing.orm.junit.DomainModel; -import org.hibernate.testing.orm.junit.RequiresDialect; import org.hibernate.testing.orm.junit.SessionFactory; import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.vector.internal.VectorHelper; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -25,16 +27,23 @@ import jakarta.persistence.Id; import jakarta.persistence.Tuple; +import static org.hibernate.vector.VectorTestHelper.cosineDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanNorm; +import static org.hibernate.vector.VectorTestHelper.euclideanSquaredDistance; +import static org.hibernate.vector.VectorTestHelper.hammingDistance; +import static org.hibernate.vector.VectorTestHelper.innerProduct; +import static org.hibernate.vector.VectorTestHelper.taxicabDistance; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Hassan AL Meftah */ -@DomainModel(annotatedClasses = OracleByteVectorTest.VectorEntity.class) +@DomainModel(annotatedClasses = ByteVectorTest.VectorEntity.class) @SessionFactory -@RequiresDialect(value = OracleDialect.class, majorVersion = 23, minorVersion = 4) -public class OracleByteVectorTest { +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsByteVectorType.class) +public class ByteVectorTest { private static final byte[] V1 = new byte[]{ 1, 2, 3 }; private static final byte[] V2 = new byte[]{ 4, 5, 6 }; @@ -67,14 +76,23 @@ public void testRead(SessionFactoryScope scope) { } @Test + public void testCast(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final Tuple vector = em.createSelectionQuery( "select cast(e.theVector as string), cast('[1, 1, 1]' as byte_vector(3)) from VectorEntity e where e.id = 1", Tuple.class ) + .getSingleResult(); + assertArrayEquals( new byte[]{ 1, 2, 3 }, VectorHelper.parseByteVector( vector.get( 0, String.class ) ) ); + assertArrayEquals( new byte[]{ 1, 1, 1 }, vector.get( 1, byte[].class ) ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsCosineDistance.class) public void testCosineDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::cosine-distance-example[] final byte[] vector = new byte[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) - .setParameter( "vec", vector, byte[].class ) + .setParameter( "vec", vector ) .getResultList(); - //end::cosine-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D ); @@ -84,14 +102,13 @@ public void testCosineDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanSquaredDistance.class) public void testEuclideanDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::euclidean-distance-example[] final byte[] vector 
= new byte[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::euclidean-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); @@ -101,14 +118,29 @@ public void testEuclideanDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanDistance.class) + public void testEuclideanSquaredDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, euclidean_squared_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanSquaredDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanSquaredDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsTaxicabDistance.class) public void testTaxicabDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::taxicab-distance-example[] final byte[] vector = new byte[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::taxicab-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -118,14 +150,13 @@ public void testTaxicabDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsInnerProduct.class) public void testInnerProduct(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::inner-product-example[] final byte[] vector = new byte[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::inner-product-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -137,14 +168,13 @@ public void testInnerProduct(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsHammingDistance.class) public void testHammingDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::inner-product-example[] final byte[] vector = new byte[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::inner-product-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); 
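// Note: VectorTestHelper, which the static imports above refer to, is not part of this excerpt; it
// evidently centralizes the reference math that previously lived as private helpers in each test
// (those helpers are removed further down in this diff). Below is a minimal sketch of the reference
// function backing the new euclidean_squared_distance assertions: the method name is taken from the
// imports, while the body is an assumption modeled on the removed euclideanDistance helper. It would
// live on a test helper class and needs only java.lang.Math.
public static double euclideanSquaredDistance(byte[] f1, byte[] f2) {
	assert f1.length == f2.length;
	double result = 0;
	for ( int i = 0; i < f1.length; i++ ) {
		// squared Euclidean (L2) distance: sum of squared component differences, without the final square root
		result += Math.pow( (double) f1[i] - f2[i], 2 );
	}
	return result;
}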
@@ -154,12 +184,11 @@ public void testHammingDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorDims.class) public void testVectorDims(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::vector-dims-example[] final List results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class ) .getResultList(); - //end::vector-dims-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( V1.length, results.get( 0 ).get( 1 ) ); @@ -169,13 +198,12 @@ public void testVectorDims(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorNorm.class) @SkipForDialect(dialectClass = OracleDialect.class, reason = "Oracle 23.9 bug") public void testVectorNorm(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::vector-norm-example[] final List results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class ) .getResultList(); - //end::vector-norm-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -184,74 +212,16 @@ public void testVectorNorm(SessionFactoryScope scope) { } ); } - private static double cosineDistance(byte[] f1, byte[] f2) { - return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); - } - - private static double euclideanDistance(byte[] f1, byte[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += Math.pow( (double) f1[i] - f2[i], 2 ); - } - return Math.sqrt( result ); - } - - private static double taxicabDistance(byte[] f1, byte[] f2) { - return norm( f1 ) - norm( f2 ); - } - - private static double innerProduct(byte[] f1, byte[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += ( (double) f1[i] ) * ( (double) f2[i] ); - } - return result; - } - - public static double hammingDistance(byte[] f1, byte[] f2) { - assert f1.length == f2.length; - int distance = 0; - for (int i = 0; i < f1.length; i++) { - if (!(f1[i] == f2[i])) { - distance++; - } - } - return distance; - } - - - private static double euclideanNorm(byte[] f) { - double result = 0; - for ( double v : f ) { - result += Math.pow( v, 2 ); - } - return Math.sqrt( result ); - } - - private static double norm(byte[] f) { - double result = 0; - for ( double v : f ) { - result += Math.abs( v ); - } - return result; - } - @Entity( name = "VectorEntity" ) public static class VectorEntity { @Id private Long id; - //tag::usage-example[] @Column( name = "the_vector" ) @JdbcTypeCode(SqlTypes.VECTOR_INT8) @Array(length = 3) private byte[] theVector; - //end::usage-example[] - - public VectorEntity() { } diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/OracleDoubleVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/DoubleVectorTest.java similarity index 72% rename from hibernate-vector/src/test/java/org/hibernate/vector/OracleDoubleVectorTest.java rename to hibernate-vector/src/test/java/org/hibernate/vector/DoubleVectorTest.java index 9aee2b8f0dd5..2dc4839a94f0 100644 --- a/hibernate-vector/src/test/java/org/hibernate/vector/OracleDoubleVectorTest.java +++ b/hibernate-vector/src/test/java/org/hibernate/vector/DoubleVectorTest.java @@ 
-9,13 +9,15 @@ import org.hibernate.annotations.Array; import org.hibernate.annotations.JdbcTypeCode; import org.hibernate.dialect.OracleDialect; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; +import org.hibernate.testing.orm.junit.RequiresDialectFeature; import org.hibernate.testing.orm.junit.SkipForDialect; import org.hibernate.type.SqlTypes; import org.hibernate.testing.orm.junit.DomainModel; -import org.hibernate.testing.orm.junit.RequiresDialect; import org.hibernate.testing.orm.junit.SessionFactory; import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.vector.internal.VectorHelper; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -25,16 +27,23 @@ import jakarta.persistence.Id; import jakarta.persistence.Tuple; +import static org.hibernate.vector.VectorTestHelper.cosineDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanNorm; +import static org.hibernate.vector.VectorTestHelper.euclideanSquaredDistance; +import static org.hibernate.vector.VectorTestHelper.hammingDistance; +import static org.hibernate.vector.VectorTestHelper.innerProduct; +import static org.hibernate.vector.VectorTestHelper.taxicabDistance; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Hassan AL Meftah */ -@DomainModel(annotatedClasses = OracleDoubleVectorTest.VectorEntity.class) +@DomainModel(annotatedClasses = DoubleVectorTest.VectorEntity.class) @SessionFactory -@RequiresDialect(value = OracleDialect.class, majorVersion = 23, minorVersion = 4) -public class OracleDoubleVectorTest { +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsDoubleVectorType.class) +public class DoubleVectorTest { private static final double[] V1 = new double[]{ 1, 2, 3 }; private static final double[] V2 = new double[]{ 4, 5, 6 }; @@ -67,14 +76,23 @@ public void testRead(SessionFactoryScope scope) { } @Test + public void testCast(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final Tuple vector = em.createSelectionQuery( "select cast(e.theVector as string), cast('[1, 1, 1]' as double_vector(3)) from VectorEntity e where e.id = 1", Tuple.class ) + .getSingleResult(); + assertArrayEquals( new double[]{ 1, 2, 3 }, VectorHelper.parseDoubleVector( vector.get( 0, String.class ) ) ); + assertArrayEquals( new double[]{ 1, 1, 1 }, vector.get( 1, double[].class ) ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsCosineDistance.class) public void testCosineDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::cosine-distance-example[] final double[] vector = new double[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::cosine-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D ); @@ -84,14 +102,13 @@ public void testCosineDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanDistance.class) public void testEuclideanDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - 
//tag::euclidean-distance-example[] final double[] vector = new double[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::euclidean-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.00002D ); @@ -101,14 +118,29 @@ public void testEuclideanDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanSquaredDistance.class) + public void testEuclideanSquaredDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final double[] vector = new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, euclidean_squared_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanSquaredDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.00002D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanSquaredDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.00002D); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsTaxicabDistance.class) public void testTaxicabDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::taxicab-distance-example[] final double[] vector = new double[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::taxicab-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -118,14 +150,13 @@ public void testTaxicabDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsInnerProduct.class) public void testInnerProduct(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::inner-product-example[] final double[] vector = new double[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::inner-product-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -137,14 +168,13 @@ public void testInnerProduct(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsHammingDistance.class) public void testHammingDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::inner-product-example[] final double[] vector = new double[]{ 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) .setParameter( "vec", vector ) .getResultList(); - //end::inner-product-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( 
hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -153,12 +183,11 @@ public void testHammingDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorDims.class) public void testVectorDims(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::vector-dims-example[] final List results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class ) .getResultList(); - //end::vector-dims-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( V1.length, results.get( 0 ).get( 1 ) ); @@ -168,13 +197,12 @@ public void testVectorDims(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorNorm.class) @SkipForDialect(dialectClass = OracleDialect.class, reason = "Oracle 23.9 bug") public void testVectorNorm(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::vector-norm-example[] final List results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class ) .getResultList(); - //end::vector-norm-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -183,72 +211,16 @@ public void testVectorNorm(SessionFactoryScope scope) { } ); } - - private static double cosineDistance(double[] f1, double[] f2) { - return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); - } - - private static double euclideanDistance(double[] f1, double[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += Math.pow( (double) f1[i] - f2[i], 2 ); - } - return Math.sqrt( result ); - } - - private static double taxicabDistance(double[] f1, double[] f2) { - return norm( f1 ) - norm( f2 ); - } - - public static double hammingDistance(double[] f1, double[] f2) { - assert f1.length == f2.length; - int distance = 0; - for (int i = 0; i < f1.length; i++) { - if (!(f1[i] == f2[i])) { - distance++; - } - } - return distance; - } - - private static double innerProduct(double[] f1, double[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += ( (double) f1[i] ) * ( (double) f2[i] ); - } - return result; - } - - private static double euclideanNorm(double[] f) { - double result = 0; - for ( double v : f ) { - result += Math.pow( v, 2 ); - } - return Math.sqrt( result ); - } - - private static double norm(double[] f) { - double result = 0; - for ( double v : f ) { - result += Math.abs( v ); - } - return result; - } - @Entity( name = "VectorEntity" ) public static class VectorEntity { @Id private Long id; - //tag::usage-example[] @Column( name = "the_vector" ) @JdbcTypeCode(SqlTypes.VECTOR_FLOAT64) @Array(length = 3) private double[] theVector; - //end::usage-example[] public VectorEntity() { } diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/Float16VectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/Float16VectorTest.java new file mode 100644 index 000000000000..5f130ab68148 --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/Float16VectorTest.java @@ -0,0 +1,123 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors + */ +package org.hibernate.vector; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialectFeature; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.type.SqlTypes; +import org.hibernate.vector.internal.VectorHelper; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.hibernate.vector.VectorTestHelper.euclideanNormalize; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +@DomainModel(annotatedClasses = Float16VectorTest.VectorEntity.class) +@SessionFactory +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsFloat16VectorType.class) +public class Float16VectorTest extends FloatVectorTest { + + @BeforeEach + @Override + public void prepareData(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.persist( new VectorEntity( 1L, V1 ) ); + em.persist( new VectorEntity( 2L, V2 ) ); + } ); + } + + @Test + @Override + public void testRead(SessionFactoryScope scope) { + scope.inTransaction( em -> { + VectorEntity tableRecord; + tableRecord = em.find( VectorEntity.class, 1L ); + assertArrayEquals( new float[] { 1, 2, 3 }, tableRecord.getTheVector(), 0 ); + + tableRecord = em.find( VectorEntity.class, 2L ); + assertArrayEquals( new float[] { 4, 5, 6 }, tableRecord.getTheVector(), 0 ); + } ); + } + + @Test + @Override + public void testCast(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final Tuple vector = em.createSelectionQuery( "select cast(e.theVector as string), cast('[1, 1, 1]' as float16_vector(3)) from VectorEntity e where e.id = 1", Tuple.class ) + .getSingleResult(); + assertArrayEquals( new float[]{ 1, 2, 3 }, VectorHelper.parseFloatVector( vector.get( 0, String.class ) ) ); + assertArrayEquals( new float[]{ 1, 1, 1 }, vector.get( 1, float[].class ) ); + } ); + } + + // Due to lower precision (float16/half-precision floating-point) type usage, + // we have to give a higher allowed delta since we can't easily calculate with the same precision in Java yet + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsL2Normalize.class) + @Override + public void testL2Normalize(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final List results = em.createSelectionQuery( + "select e.id, l2_normalize(e.theVector) from VectorEntity e order by e.id", + Tuple.class + ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertArrayEquals( euclideanNormalize( V1 ), results.get( 0 ).get( 1, float[].class ), 0.0002f ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertArrayEquals( euclideanNormalize( V2 ), results.get( 1 ).get( 1, float[].class ), 0.0002f ); + } ); + } + + @Entity(name = "VectorEntity") + public static class VectorEntity { + + @Id + private Long id; + + @Column(name = "the_vector") + @JdbcTypeCode(SqlTypes.VECTOR_FLOAT16) + @Array(length = 3) + private float[] theVector; + + + public VectorEntity() { + } + + public VectorEntity(Long id, float[] theVector) { + 
this.id = id; + this.theVector = theVector; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public float[] getTheVector() { + return theVector; + } + + public void setTheVector(float[] theVector) { + this.theVector = theVector; + } + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/Float32VectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/Float32VectorTest.java new file mode 100644 index 000000000000..23181fabb6a1 --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/Float32VectorTest.java @@ -0,0 +1,99 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialectFeature; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.type.SqlTypes; +import org.hibernate.vector.internal.VectorHelper; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +@DomainModel(annotatedClasses = Float32VectorTest.VectorEntity.class) +@SessionFactory +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsFloatVectorType.class) +public class Float32VectorTest extends FloatVectorTest { + + @BeforeEach + @Override + public void prepareData(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.persist( new VectorEntity( 1L, V1 ) ); + em.persist( new VectorEntity( 2L, V2 ) ); + } ); + } + + @Test + @Override + public void testRead(SessionFactoryScope scope) { + scope.inTransaction( em -> { + VectorEntity tableRecord; + tableRecord = em.find( VectorEntity.class, 1L ); + assertArrayEquals( new float[] { 1, 2, 3 }, tableRecord.getTheVector(), 0 ); + + tableRecord = em.find( VectorEntity.class, 2L ); + assertArrayEquals( new float[] { 4, 5, 6 }, tableRecord.getTheVector(), 0 ); + } ); + } + + @Test + @Override + public void testCast(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final Tuple vector = em.createSelectionQuery( "select cast(e.theVector as string), cast('[1, 1, 1]' as float_vector(3)) from VectorEntity e where e.id = 1", Tuple.class ) + .getSingleResult(); + assertArrayEquals( new float[]{ 1, 2, 3 }, VectorHelper.parseFloatVector( vector.get( 0, String.class ) ) ); + assertArrayEquals( new float[]{ 1, 1, 1 }, vector.get( 1, float[].class ) ); + } ); + } + + @Entity(name = "VectorEntity") + public static class VectorEntity { + + @Id + private Long id; + + @Column(name = "the_vector") + @JdbcTypeCode(SqlTypes.VECTOR_FLOAT32) + @Array(length = 3) + private float[] theVector; + + + public VectorEntity() { + } + + public VectorEntity(Long id, float[] theVector) { + this.id = id; + this.theVector = theVector; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public float[] getTheVector() { + return theVector; + } + + public void setTheVector(float[] theVector) { + this.theVector = theVector; + } + } +} diff --git 
a/hibernate-vector/src/test/java/org/hibernate/vector/OracleGenericVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/FloatVectorTest.java similarity index 52% rename from hibernate-vector/src/test/java/org/hibernate/vector/OracleGenericVectorTest.java rename to hibernate-vector/src/test/java/org/hibernate/vector/FloatVectorTest.java index 0f024f9fa766..cb096971daed 100644 --- a/hibernate-vector/src/test/java/org/hibernate/vector/OracleGenericVectorTest.java +++ b/hibernate-vector/src/test/java/org/hibernate/vector/FloatVectorTest.java @@ -4,41 +4,51 @@ */ package org.hibernate.vector; -import java.util.List; - +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; import org.hibernate.annotations.Array; import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.dialect.MySQLDialect; import org.hibernate.dialect.OracleDialect; +import org.hibernate.dialect.PostgreSQLDialect; import org.hibernate.testing.orm.junit.SkipForDialect; -import org.hibernate.type.SqlTypes; - +import org.hibernate.dialect.PostgresPlusDialect; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; import org.hibernate.testing.orm.junit.DomainModel; -import org.hibernate.testing.orm.junit.RequiresDialect; +import org.hibernate.testing.orm.junit.RequiresDialectFeature; import org.hibernate.testing.orm.junit.SessionFactory; import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SkipForDialect; +import org.hibernate.type.SqlTypes; +import org.hibernate.vector.internal.VectorHelper; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import jakarta.persistence.Column; -import jakarta.persistence.Entity; -import jakarta.persistence.Id; -import jakarta.persistence.Tuple; +import java.util.List; +import static org.hibernate.vector.VectorTestHelper.cosineDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanNorm; +import static org.hibernate.vector.VectorTestHelper.euclideanNormalize; +import static org.hibernate.vector.VectorTestHelper.euclideanSquaredDistance; +import static org.hibernate.vector.VectorTestHelper.hammingDistance; +import static org.hibernate.vector.VectorTestHelper.innerProduct; +import static org.hibernate.vector.VectorTestHelper.taxicabDistance; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -/** - * @author Hassan AL Meftah - */ -@DomainModel(annotatedClasses = OracleGenericVectorTest.VectorEntity.class) +@DomainModel(annotatedClasses = FloatVectorTest.VectorEntity.class) @SessionFactory -@RequiresDialect(value = OracleDialect.class, majorVersion = 23, minorVersion = 4) -public class OracleGenericVectorTest { +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorType.class) +@SkipForDialect(dialectClass = PostgresPlusDialect.class, reason = "Test database does not have the extension enabled") +public class FloatVectorTest { - private static final float[] V1 = new float[] { 1, 2, 3 }; - private static final float[] V2 = new float[] { 4, 5, 6 }; + protected static final float[] V1 = new float[] { 1, 2, 3 }; + protected static final float[] V2 = new float[] { 4, 5, 6 }; @BeforeEach public void prepareData(SessionFactoryScope scope) { @@ -68,6 +78,20 @@ public void testRead(SessionFactoryScope scope) 
{ } @Test + public void testCast(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::vector-cast-example[] + final Tuple vector = em.createSelectionQuery( "select cast(e.theVector as string), cast('[1, 1, 1]' as vector(3)) from VectorEntity e where e.id = 1", Tuple.class ) + .getSingleResult(); + //end::vector-cast-example[] + assertArrayEquals( new float[]{ 1, 2, 3 }, VectorHelper.parseFloatVector( vector.get( 0, String.class ) ) ); + assertArrayEquals( new float[]{ 1, 1, 1 }, vector.get( 1, float[].class ) ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsCosineDistance.class) + @SkipForDialect(dialectClass = MySQLDialect.class, reason = "Only MySQL HeatWave supports this function") public void testCosineDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { //tag::cosine-distance-example[] @@ -88,6 +112,8 @@ public void testCosineDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanDistance.class) + @SkipForDialect(dialectClass = MySQLDialect.class, reason = "Only MySQL HeatWave supports this function") public void testEuclideanDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { //tag::euclidean-distance-example[] @@ -108,6 +134,28 @@ public void testEuclideanDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanSquaredDistance.class) + public void testEuclideanSquaredDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::euclidean-squared-distance-example[] + final float[] vector = new float[] { 1, 1, 1 }; + final List results = em.createSelectionQuery( + "select e.id, euclidean_squared_distance(e.theVector, :vec) from VectorEntity e order by e.id", + Tuple.class + ) + .setParameter( "vec", vector ) + .getResultList(); + //end::euclidean-squared-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanSquaredDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanSquaredDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsTaxicabDistance.class) public void testTaxicabDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { //tag::taxicab-distance-example[] @@ -128,6 +176,8 @@ public void testTaxicabDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsInnerProduct.class) + @SkipForDialect(dialectClass = MySQLDialect.class, reason = "Only MySQL HeatWave supports this function") public void testInnerProduct(SessionFactoryScope scope) { scope.inTransaction( em -> { //tag::inner-product-example[] @@ -150,9 +200,11 @@ public void testInnerProduct(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsHammingDistance.class) + @SkipForDialect(dialectClass = PostgreSQLDialect.class, matchSubTypes = true, reason = "Only supported with bit vectors") public void testHammingDistance(SessionFactoryScope scope) { scope.inTransaction( em -> { - //tag::inner-product-example[] + //tag::hamming-distance-example[] final float[] vector = new float[] { 1, 1, 1 }; final List results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by 
e.id", @@ -160,7 +212,7 @@ public void testHammingDistance(SessionFactoryScope scope) { ) .setParameter( "vec", vector ) .getResultList(); - //end::inner-product-example[] + //end::hamming-distance-example[] assertEquals( 2, results.size() ); assertEquals( 1L, results.get( 0 ).get( 0 ) ); assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); @@ -170,6 +222,7 @@ public void testHammingDistance(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorDims.class) public void testVectorDims(SessionFactoryScope scope) { scope.inTransaction( em -> { //tag::vector-dims-example[] @@ -188,6 +241,7 @@ public void testVectorDims(SessionFactoryScope scope) { } @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorNorm.class) @SkipForDialect(dialectClass = OracleDialect.class, reason = "Oracle 23.9 bug") public void testVectorNorm(SessionFactoryScope scope) { scope.inTransaction( em -> { @@ -206,59 +260,83 @@ public void testVectorNorm(SessionFactoryScope scope) { } ); } - - private static double cosineDistance(float[] f1, float[] f2) { - return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); - } - - private static double euclideanDistance(float[] f1, float[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += Math.pow( (double) f1[i] - f2[i], 2 ); - } - return Math.sqrt( result ); - } - - private static double taxicabDistance(float[] f1, float[] f2) { - return norm( f1 ) - norm( f2 ); - } - - private static double innerProduct(float[] f1, float[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += ( (double) f1[i] ) * ( (double) f2[i] ); - } - return result; + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsL2Norm.class) + @SkipForDialect(dialectClass = OracleDialect.class, reason = "Oracle 23.9 bug") + public void testL2Norm(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::l2-norm-example[] + final List results = em.createSelectionQuery( + "select e.id, l2_norm(e.theVector) from VectorEntity e order by e.id", + Tuple.class + ) + .getResultList(); + //end::l2-norm-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); } - public static double hammingDistance(float[] f1, float[] f2) { - assert f1.length == f2.length; - int distance = 0; - for ( int i = 0; i < f1.length; i++ ) { - if ( !( f1[i] == f2[i] ) ) { - distance++; - } - } - return distance; + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsL2Normalize.class) + public void testL2Normalize(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::l2-normalize-example[] + final List results = em.createSelectionQuery( + "select e.id, l2_normalize(e.theVector) from VectorEntity e order by e.id", + Tuple.class + ) + .getResultList(); + //end::l2-normalize-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertArrayEquals( euclideanNormalize( V1 ), results.get( 0 ).get( 1, float[].class ), 0.0000001f ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertArrayEquals( euclideanNormalize( V2 ), 
results.get( 1 ).get( 1, float[].class ), 0.0000001f ); + } ); } - - private static double euclideanNorm(float[] f) { - double result = 0; - for ( double v : f ) { - result += Math.pow( v, 2 ); - } - return Math.sqrt( result ); + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsSubvector.class) + public void testSubvector(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::subvector-example[] + final List results = em.createSelectionQuery( + "select e.id, subvector(e.theVector, 1, 1) from VectorEntity e order by e.id", + Tuple.class + ) + .getResultList(); + //end::subvector-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( 1, results.get( 0 ).get( 1, float[].class ).length ); + assertEquals( V1[0], results.get( 0 ).get( 1, float[].class )[0], 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( 1, results.get( 1 ).get( 1, float[].class ).length ); + assertEquals( V2[0], results.get( 1 ).get( 1, float[].class )[0], 0D ); + } ); } - private static double norm(float[] f) { - double result = 0; - for ( double v : f ) { - result += Math.abs( v ); - } - return result; + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsBinaryQuantize.class) + public void testBinaryQuantize(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::binary-quantize-example[] + final List results = em.createSelectionQuery( + "select e.id, binary_quantize(e.theVector) from VectorEntity e order by e.id", + Tuple.class + ) + .getResultList(); + //end::binary-quantize-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertArrayEquals( new byte[]{(byte) 0b11100000}, results.get( 0 ).get( 1, byte[].class ) ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertArrayEquals( new byte[]{(byte) 0b11100000}, results.get( 1 ).get( 1, byte[].class ) ); + } ); } @Entity(name = "VectorEntity") diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/MariaDBTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/MariaDBTest.java deleted file mode 100644 index 78afe7fd5291..000000000000 --- a/hibernate-vector/src/test/java/org/hibernate/vector/MariaDBTest.java +++ /dev/null @@ -1,167 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright Red Hat Inc. 
and Hibernate Authors - */ -package org.hibernate.vector; - -import jakarta.persistence.Column; -import jakarta.persistence.Entity; -import jakarta.persistence.Id; -import jakarta.persistence.Tuple; -import org.hibernate.annotations.Array; -import org.hibernate.annotations.JdbcTypeCode; -import org.hibernate.dialect.MariaDBDialect; -import org.hibernate.testing.orm.junit.DomainModel; -import org.hibernate.testing.orm.junit.RequiresDialect; -import org.hibernate.testing.orm.junit.SessionFactory; -import org.hibernate.testing.orm.junit.SessionFactoryScope; -import org.hibernate.type.SqlTypes; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; - -/** - * @author Diego Dupin - */ -@DomainModel(annotatedClasses = MariaDBTest.VectorEntity.class) -@SessionFactory -@RequiresDialect(value = MariaDBDialect.class, majorVersion = 11, minorVersion = 7) -public class MariaDBTest { - - private static final float[] V1 = new float[]{ 1, 2, 3 }; - private static final float[] V2 = new float[]{ 4, 5, 6 }; - - @BeforeEach - public void prepareData(SessionFactoryScope scope) { - scope.inTransaction( em -> { - em.persist( new VectorEntity( 1L, V1 ) ); - em.persist( new VectorEntity( 2L, V2 ) ); - } ); - } - - @AfterEach - public void cleanup(SessionFactoryScope scope) { - scope.inTransaction( em -> { - em.createMutationQuery( "delete from VectorEntity" ).executeUpdate(); - } ); - } - - @Test - public void testRead(SessionFactoryScope scope) { - scope.inTransaction( em -> { - VectorEntity tableRecord; - tableRecord = em.find( VectorEntity.class, 1L ); - assertArrayEquals( new float[]{ 1, 2, 3 }, tableRecord.getTheVector(), 0 ); - - tableRecord = em.find( VectorEntity.class, 2L ); - assertArrayEquals( new float[]{ 4, 5, 6 }, tableRecord.getTheVector(), 0 ); - } ); - } - - @Test - public void testCosineDistance(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::cosine-distance-example[] - final float[] vector = new float[]{ 1, 1, 1 }; - final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) - .setParameter( "vec", vector ) - .getResultList(); - //end::cosine-distance-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0.0000000000000002D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0.0000000000000002D ); - } ); - } - - @Test - public void testEuclideanDistance(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::euclidean-distance-example[] - final float[] vector = new float[]{ 1, 1, 1 }; - final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) - .setParameter( "vec", vector ) - .getResultList(); - //end::euclidean-distance-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D ); - } 
); - } - - private static double cosineDistance(float[] f1, float[] f2) { - return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); - } - - private static double euclideanDistance(float[] f1, float[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += Math.pow( (double) f1[i] - f2[i], 2 ); - } - return Math.sqrt( result ); - } - - private static double innerProduct(float[] f1, float[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += ( (double) f1[i] ) * ( (double) f2[i] ); - } - return result; - } - - private static double euclideanNorm(float[] f) { - double result = 0; - for ( float v : f ) { - result += Math.pow( v, 2 ); - } - return Math.sqrt( result ); - } - - @Entity( name = "VectorEntity" ) - public static class VectorEntity { - - @Id - private Long id; - - //tag::usage-example[] - @Column( name = "the_vector" ) - @JdbcTypeCode(SqlTypes.VECTOR) - @Array(length = 3) - private float[] theVector; - //end::usage-example[] - - public VectorEntity() { - } - - public VectorEntity(Long id, float[] theVector) { - this.id = id; - this.theVector = theVector; - } - - public Long getId() { - return id; - } - - public void setId(Long id) { - this.id = id; - } - - public float[] getTheVector() { - return theVector; - } - - public void setTheVector(float[] theVector) { - this.theVector = theVector; - } - } -} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/OracleFloatVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/OracleFloatVectorTest.java deleted file mode 100644 index 7f34f305a472..000000000000 --- a/hibernate-vector/src/test/java/org/hibernate/vector/OracleFloatVectorTest.java +++ /dev/null @@ -1,301 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright Red Hat Inc. 
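// Editor's note, not part of this patch: the reference implementation deleted above computes cosine
// distance as 1 - (f1 . f2) / (||f1|| * ||f2||). A worked example with the fixed test data makes the
// asserted values concrete; the class and variable names below are illustrative only.
public final class CosineDistanceExample {
	public static void main(String[] args) {
		final double inner = 1 * 1 + 2 * 1 + 3 * 1;             // V1 . {1,1,1} = 6
		final double norms = Math.sqrt( 14 ) * Math.sqrt( 3 );  // ||V1|| * ||{1,1,1}|| ~ 6.4807
		System.out.println( 1d - inner / norms );               // ~ 0.0742, the value compared for id 1
	}
}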
and Hibernate Authors - */ -package org.hibernate.vector; - -import java.util.List; - -import org.hibernate.annotations.Array; -import org.hibernate.annotations.JdbcTypeCode; -import org.hibernate.dialect.OracleDialect; -import org.hibernate.testing.orm.junit.SkipForDialect; -import org.hibernate.type.SqlTypes; - -import org.hibernate.testing.orm.junit.DomainModel; -import org.hibernate.testing.orm.junit.RequiresDialect; -import org.hibernate.testing.orm.junit.SessionFactory; -import org.hibernate.testing.orm.junit.SessionFactoryScope; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import jakarta.persistence.Column; -import jakarta.persistence.Entity; -import jakarta.persistence.Id; -import jakarta.persistence.Tuple; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; - -/** - * @author Hassan AL Meftah - */ -@DomainModel(annotatedClasses = OracleFloatVectorTest.VectorEntity.class) -@SessionFactory -@RequiresDialect(value = OracleDialect.class, majorVersion = 23, minorVersion = 4) -public class OracleFloatVectorTest { - - private static final float[] V1 = new float[] { 1, 2, 3 }; - private static final float[] V2 = new float[] { 4, 5, 6 }; - - @BeforeEach - public void prepareData(SessionFactoryScope scope) { - scope.inTransaction( em -> { - em.persist( new VectorEntity( 1L, V1 ) ); - em.persist( new VectorEntity( 2L, V2 ) ); - } ); - } - - @AfterEach - public void cleanup(SessionFactoryScope scope) { - scope.inTransaction( em -> { - em.createMutationQuery( "delete from VectorEntity" ).executeUpdate(); - } ); - } - - @Test - public void testRead(SessionFactoryScope scope) { - scope.inTransaction( em -> { - VectorEntity tableRecord; - tableRecord = em.find( VectorEntity.class, 1L ); - assertArrayEquals( new float[] { 1, 2, 3 }, tableRecord.getTheVector(), 0 ); - - tableRecord = em.find( VectorEntity.class, 2L ); - assertArrayEquals( new float[] { 4, 5, 6 }, tableRecord.getTheVector(), 0 ); - } ); - } - - @Test - public void testCosineDistance(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::cosine-distance-example[] - final float[] vector = new float[] { 1, 1, 1 }; - final List results = em.createSelectionQuery( - "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", - Tuple.class - ) - .setParameter( "vec", vector ) - .getResultList(); - //end::cosine-distance-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.0000001D ); - } ); - } - - @Test - public void testEuclideanDistance(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::euclidean-distance-example[] - final float[] vector = new float[] { 1, 1, 1 }; - final List results = em.createSelectionQuery( - "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", - Tuple.class - ) - .setParameter( "vec", vector ) - .getResultList(); - //end::euclidean-distance-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - 
assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D ); - } ); - } - - @Test - public void testTaxicabDistance(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::taxicab-distance-example[] - final float[] vector = new float[] { 1, 1, 1 }; - final List results = em.createSelectionQuery( - "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", - Tuple.class - ) - .setParameter( "vec", vector ) - .getResultList(); - //end::taxicab-distance-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); - } ); - } - - @Test - public void testInnerProduct(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::inner-product-example[] - final float[] vector = new float[] { 1, 1, 1 }; - final List results = em.createSelectionQuery( - "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", - Tuple.class - ) - .setParameter( "vec", vector ) - .getResultList(); - //end::inner-product-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); - assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); - assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D ); - } ); - } - - @Test - public void testHammingDistance(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::inner-product-example[] - final float[] vector = new float[] { 1, 1, 1 }; - final List results = em.createSelectionQuery( - "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", - Tuple.class - ) - .setParameter( "vec", vector ) - .getResultList(); - //end::inner-product-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( hammingDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); - } ); - } - - @Test - public void testVectorDims(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::vector-dims-example[] - final List results = em.createSelectionQuery( - "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", - Tuple.class - ) - .getResultList(); - //end::vector-dims-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( V1.length, results.get( 0 ).get( 1 ) ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( V2.length, results.get( 1 ).get( 1 ) ); - } ); - } - - @Test - @SkipForDialect(dialectClass = OracleDialect.class, reason = "Oracle 23.9 bug") - public void testVectorNorm(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::vector-norm-example[] - final List results = em.createSelectionQuery( - "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", - 
Tuple.class - ) - .getResultList(); - //end::vector-norm-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, double.class ), 0D ); - } ); - } - - - private static double cosineDistance(float[] f1, float[] f2) { - return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); - } - - private static double euclideanDistance(float[] f1, float[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += Math.pow( (double) f1[i] - f2[i], 2 ); - } - return Math.sqrt( result ); - } - - private static double taxicabDistance(float[] f1, float[] f2) { - return norm( f1 ) - norm( f2 ); - } - - private static double innerProduct(float[] f1, float[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += ( (double) f1[i] ) * ( (double) f2[i] ); - } - return result; - } - - public static double hammingDistance(float[] f1, float[] f2) { - assert f1.length == f2.length; - int distance = 0; - for ( int i = 0; i < f1.length; i++ ) { - if ( !( f1[i] == f2[i] ) ) { - distance++; - } - } - return distance; - } - - - private static double euclideanNorm(float[] f) { - double result = 0; - for ( double v : f ) { - result += Math.pow( v, 2 ); - } - return Math.sqrt( result ); - } - - private static double norm(float[] f) { - double result = 0; - for ( double v : f ) { - result += Math.abs( v ); - } - return result; - } - - @Entity(name = "VectorEntity") - public static class VectorEntity { - - @Id - private Long id; - - //tag::usage-example[] - @Column(name = "the_vector") - @JdbcTypeCode(SqlTypes.VECTOR_FLOAT32) - @Array(length = 3) - private float[] theVector; - //end::usage-example[] - - - public VectorEntity() { - } - - public VectorEntity(Long id, float[] theVector) { - this.id = id; - this.theVector = theVector; - } - - public Long getId() { - return id; - } - - public void setId(Long id) { - this.id = id; - } - - public float[] getTheVector() { - return theVector; - } - - public void setTheVector(float[] theVector) { - this.theVector = theVector; - } - } -} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/PGVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/PGVectorTest.java index d3c44c4eebc2..9c3c2f734972 100644 --- a/hibernate-vector/src/test/java/org/hibernate/vector/PGVectorTest.java +++ b/hibernate-vector/src/test/java/org/hibernate/vector/PGVectorTest.java @@ -4,28 +4,25 @@ */ package org.hibernate.vector; -import java.util.List; - +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; import org.hibernate.annotations.Array; import org.hibernate.annotations.JdbcTypeCode; import org.hibernate.dialect.CockroachDialect; import org.hibernate.dialect.PostgreSQLDialect; -import org.hibernate.testing.orm.junit.RequiresDialects; -import org.hibernate.testing.orm.junit.SkipForDialect; -import org.hibernate.type.SqlTypes; - import org.hibernate.testing.orm.junit.DomainModel; import org.hibernate.testing.orm.junit.RequiresDialect; import org.hibernate.testing.orm.junit.SessionFactory; import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.SkipForDialect; +import 
org.hibernate.type.SqlTypes; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import jakarta.persistence.Column; -import jakarta.persistence.Entity; -import jakarta.persistence.Id; -import jakarta.persistence.Tuple; +import java.util.List; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -35,10 +32,8 @@ */ @DomainModel(annotatedClasses = PGVectorTest.VectorEntity.class) @SessionFactory -@RequiresDialects({ - @RequiresDialect(value = PostgreSQLDialect.class, matchSubTypes = false), - @RequiresDialect(value = CockroachDialect.class, majorVersion = 24, minorVersion = 2) -}) +@RequiresDialect(value = PostgreSQLDialect.class, matchSubTypes = false) +@RequiresDialect(value = CockroachDialect.class, majorVersion = 24, minorVersion = 2) public class PGVectorTest { private static final float[] V1 = new float[]{ 1, 2, 3 }; @@ -59,118 +54,6 @@ public void cleanup(SessionFactoryScope scope) { } ); } - @Test - public void testRead(SessionFactoryScope scope) { - scope.inTransaction( em -> { - VectorEntity tableRecord; - tableRecord = em.find( VectorEntity.class, 1L ); - assertArrayEquals( new float[]{ 1, 2, 3 }, tableRecord.getTheVector(), 0 ); - - tableRecord = em.find( VectorEntity.class, 2L ); - assertArrayEquals( new float[]{ 4, 5, 6 }, tableRecord.getTheVector(), 0 ); - } ); - } - - @Test - public void testCosineDistance(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::cosine-distance-example[] - final float[] vector = new float[]{ 1, 1, 1 }; - final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) - .setParameter( "vec", vector ) - .getResultList(); - //end::cosine-distance-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0.0000000000000002D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0.0000000000000002D ); - } ); - } - - @Test - public void testEuclideanDistance(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::euclidean-distance-example[] - final float[] vector = new float[]{ 1, 1, 1 }; - final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) - .setParameter( "vec", vector ) - .getResultList(); - //end::euclidean-distance-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D ); - } ); - } - - @Test - public void testTaxicabDistance(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::taxicab-distance-example[] - final float[] vector = new float[]{ 1, 1, 1 }; - final List results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) - .setParameter( "vec", vector ) - .getResultList(); - //end::taxicab-distance-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( taxicabDistance( V1, 
vector ), results.get( 0 ).get( 1, Double.class ), 0D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D ); - } ); - } - - @Test - public void testInnerProduct(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::inner-product-example[] - final float[] vector = new float[]{ 1, 1, 1 }; - final List results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) - .setParameter( "vec", vector ) - .getResultList(); - //end::inner-product-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0D ); - assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, Double.class ), 0D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D ); - assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, Double.class ), 0D ); - } ); - } - - @Test - public void testVectorDims(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::vector-dims-example[] - final List results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class ) - .getResultList(); - //end::vector-dims-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( V1.length, results.get( 0 ).get( 1 ) ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( V2.length, results.get( 1 ).get( 1 ) ); - } ); - } - - @Test - public void testVectorNorm(SessionFactoryScope scope) { - scope.inTransaction( em -> { - //tag::vector-norm-example[] - final List results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class ) - .getResultList(); - //end::vector-norm-example[] - assertEquals( 2, results.size() ); - assertEquals( 1L, results.get( 0 ).get( 0 ) ); - assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, Double.class ), 0D ); - assertEquals( 2L, results.get( 1 ).get( 0 ) ); - assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, Double.class ), 0D ); - } ); - } - @Test @SkipForDialect(dialectClass = CockroachDialect.class, reason = "CockroachDB does not currently support the sum() function on vector type" ) public void testVectorSum(SessionFactoryScope scope) { @@ -227,60 +110,16 @@ public void testMultiplication(SessionFactoryScope scope) { } ); } - private static double cosineDistance(float[] f1, float[] f2) { - return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); - } - - private static double euclideanDistance(float[] f1, float[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += Math.pow( (double) f1[i] - f2[i], 2 ); - } - return Math.sqrt( result ); - } - - private static double taxicabDistance(float[] f1, float[] f2) { - return norm( f1 ) - norm( f2 ); - } - - private static double innerProduct(float[] f1, float[] f2) { - assert f1.length == f2.length; - double result = 0; - for ( int i = 0; i < f1.length; i++ ) { - result += ( (double) f1[i] ) * ( (double) f2[i] ); - } - return result; - } - - private static double euclideanNorm(float[] f) { - double result = 0; - for ( float v : f ) { - result += 
Math.pow( v, 2 ); - } - return Math.sqrt( result ); - } - - private static double norm(float[] f) { - double result = 0; - for ( float v : f ) { - result += Math.abs( v ); - } - return result; - } - @Entity( name = "VectorEntity" ) public static class VectorEntity { @Id private Long id; - //tag::usage-example[] @Column( name = "the_vector" ) @JdbcTypeCode(SqlTypes.VECTOR) @Array(length = 3) private float[] theVector; - //end::usage-example[] public VectorEntity() { } diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/SparseByteVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/SparseByteVectorTest.java new file mode 100644 index 000000000000..ee78cd363f77 --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/SparseByteVectorTest.java @@ -0,0 +1,242 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialectFeature; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.type.SqlTypes; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.hibernate.vector.VectorTestHelper.cosineDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanNorm; +import static org.hibernate.vector.VectorTestHelper.euclideanSquaredDistance; +import static org.hibernate.vector.VectorTestHelper.hammingDistance; +import static org.hibernate.vector.VectorTestHelper.innerProduct; +import static org.hibernate.vector.VectorTestHelper.taxicabDistance; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +@DomainModel(annotatedClasses = SparseByteVectorTest.VectorEntity.class) +@SessionFactory +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsSparseByteVectorType.class) +public class SparseByteVectorTest { + + private static final byte[] V1 = new byte[]{ 0, 2, 3 }; + private static final byte[] V2 = new byte[]{ 0, 5, 6 }; + + @BeforeEach + public void prepareData(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.persist( new VectorEntity( 1L, new SparseByteVector( V1 ) ) ); + em.persist( new VectorEntity( 2L, new SparseByteVector( V2 ) ) ); + } ); + } + + @AfterEach + public void cleanup(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.createMutationQuery( "delete from VectorEntity" ).executeUpdate(); + } ); + } + + @Test + public void testRead(SessionFactoryScope scope) { + scope.inTransaction( em -> { + VectorEntity tableRecord; + tableRecord = em.find( VectorEntity.class, 1L ); + assertArrayEquals( new byte[]{ 0, 2, 3 }, tableRecord.getTheVector().toDenseVector() ); + + tableRecord = em.find( VectorEntity.class, 2L ); + assertArrayEquals( new byte[]{ 0, 5, 6 }, tableRecord.getTheVector().toDenseVector() ); + } ); + } + + @Test + public void testCast(SessionFactoryScope scope) { + scope.inTransaction( em -> { 
+ final String literal = VectorTestHelper.vectorSparseStringLiteral( new byte[] {1, 1, 1}, em ); + final Tuple vector = em.createSelectionQuery( "select cast(e.theVector as string), cast('" + literal + "' as sparse_byte_vector(3)) from VectorEntity e where e.id = 1", Tuple.class ) + .getSingleResult(); + assertEquals( VectorTestHelper.vectorSparseStringLiteral( V1, em ), vector.get( 0, String.class ) ); + assertEquals( new SparseByteVector( new byte[]{ 1, 1, 1 } ), vector.get( 1, SparseByteVector.class ) ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsCosineDistance.class) + public void testCosineDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseByteVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.0000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanDistance.class) + public void testEuclideanDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseByteVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanSquaredDistance.class) + public void testEuclideanSquaredDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, euclidean_squared_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseByteVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanSquaredDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanSquaredDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsTaxicabDistance.class) + public void testTaxicabDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseByteVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 
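// Editor's sketch, not part of this patch: euclidean_squared_distance is new in these sparse tests and
// has no counterpart among the per-class helpers deleted elsewhere in this change. Consistent with the
// deleted euclideanDistance helper, the expected value is presumably the squared Euclidean distance,
// i.e. the sum of squared component differences without the final square root (hypothetical code, not
// the actual VectorTestHelper implementation).
public final class EuclideanSquaredDistanceSketch {
	public static double euclideanSquaredDistance(byte[] f1, byte[] f2) {
		assert f1.length == f2.length;
		double result = 0;
		for ( int i = 0; i < f1.length; i++ ) {
			result += Math.pow( (double) f1[i] - f2[i], 2 );
		}
		return result;
	}
}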
).get( 0 ) ); + assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsInnerProduct.class) + public void testInnerProduct(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseByteVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsHammingDistance.class) + public void testHammingDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseByteVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( hammingDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorDims.class) + public void testVectorDims(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final List results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( V1.length, results.get( 0 ).get( 1 ) ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( V2.length, results.get( 1 ).get( 1 ) ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorNorm.class) + public void testVectorNorm(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final List results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Entity( name = "VectorEntity" ) + public static class VectorEntity { + + @Id + private Long id; + + @Column( name = "the_vector" ) + @JdbcTypeCode(SqlTypes.SPARSE_VECTOR_INT8) + @Array(length = 3) + private 
SparseByteVector theVector; + + public VectorEntity() { + } + + public VectorEntity(Long id, SparseByteVector theVector) { + this.id = id; + this.theVector = theVector; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public SparseByteVector getTheVector() { + return theVector; + } + + public void setTheVector(SparseByteVector theVector) { + this.theVector = theVector; + } + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/SparseByteVectorUnitTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/SparseByteVectorUnitTest.java new file mode 100644 index 000000000000..9fc656be8801 --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/SparseByteVectorUnitTest.java @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +public class SparseByteVectorUnitTest { + + @Test + public void testEmpty() { + final SparseByteVector bytes = new SparseByteVector( 3 ); + bytes.set( 1, (byte) 3 ); + assertArrayEquals( new Object[] {(byte) 0, (byte) 3, (byte) 0}, bytes.toArray() ); + } + + @Test + public void testInsertBefore() { + final SparseByteVector bytes = new SparseByteVector( 3, new int[] {1}, new byte[] {3} ); + bytes.set( 0, (byte) 2 ); + assertArrayEquals( new Object[] {(byte) 2, (byte) 3, (byte) 0}, bytes.toArray() ); + } + + @Test + public void testInsertAfter() { + final SparseByteVector bytes = new SparseByteVector( 3, new int[] {1}, new byte[] {3} ); + bytes.set( 2, (byte) 2 ); + assertArrayEquals( new Object[] {(byte) 0, (byte) 3, (byte) 2}, bytes.toArray() ); + } + + @Test + public void testReplace() { + final SparseByteVector bytes = new SparseByteVector( 3, new int[] {0, 1, 2}, new byte[] {3, 3, 3} ); + bytes.set( 2, (byte) 2 ); + assertArrayEquals( new Object[] {(byte) 3, (byte) 3, (byte) 2}, bytes.toArray() ); + } + + @Test + public void testFromDenseVector() { + final SparseByteVector bytes = new SparseByteVector( new byte[] {0, 3, 0} ); + assertArrayEquals( new Object[] {(byte) 0, (byte) 3, (byte) 0}, bytes.toArray() ); + } + + @Test + public void testFromDenseVectorList() { + final SparseByteVector bytes = new SparseByteVector( List.of( (byte) 0, (byte) 3, (byte) 0 ) ); + assertArrayEquals( new Object[] {(byte) 0, (byte) 3, (byte) 0}, bytes.toArray() ); + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/SparseDoubleVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/SparseDoubleVectorTest.java new file mode 100644 index 000000000000..7feba32ef2a0 --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/SparseDoubleVectorTest.java @@ -0,0 +1,244 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
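// Editor's sketch, not part of this patch: SparseByteVectorUnitTest above pins down the contract of the
// sparse vector types - a fixed logical size, parallel index/value arrays for the non-zero components,
// and a set() that either replaces an existing entry or inserts a new one while keeping the indices
// sorted. The hypothetical class below illustrates that insert-or-replace behaviour; the real
// SparseByteVector implementation may differ.
import java.util.Arrays;

public final class SparseByteVectorSketch {
	private final int size;
	private int[] indices; // sorted positions of the non-zero components
	private byte[] values; // parallel to indices

	public SparseByteVectorSketch(int size) {
		this.size = size;
		this.indices = new int[0];
		this.values = new byte[0];
	}

	public void set(int index, byte value) {
		final int position = Arrays.binarySearch( indices, index );
		if ( position >= 0 ) {
			// Index already tracked: replace the stored value (testReplace).
			values[position] = value;
		}
		else {
			// Index not tracked yet: insert where the index array stays sorted
			// (testInsertBefore / testInsertAfter).
			final int insertionPoint = -( position + 1 );
			final int[] newIndices = new int[indices.length + 1];
			final byte[] newValues = new byte[values.length + 1];
			System.arraycopy( indices, 0, newIndices, 0, insertionPoint );
			System.arraycopy( values, 0, newValues, 0, insertionPoint );
			newIndices[insertionPoint] = index;
			newValues[insertionPoint] = value;
			System.arraycopy( indices, insertionPoint, newIndices, insertionPoint + 1, indices.length - insertionPoint );
			System.arraycopy( values, insertionPoint, newValues, insertionPoint + 1, values.length - insertionPoint );
			indices = newIndices;
			values = newValues;
		}
	}

	public byte[] toDenseVector() {
		final byte[] dense = new byte[size];
		for ( int i = 0; i < indices.length; i++ ) {
			dense[indices[i]] = values[i];
		}
		return dense;
	}
}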
and Hibernate Authors + */ +package org.hibernate.vector; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.testing.orm.junit.DialectFeatureChecks; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialectFeature; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.type.SqlTypes; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.hibernate.vector.VectorTestHelper.cosineDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanDistance; +import static org.hibernate.vector.VectorTestHelper.euclideanNorm; +import static org.hibernate.vector.VectorTestHelper.euclideanSquaredDistance; +import static org.hibernate.vector.VectorTestHelper.hammingDistance; +import static org.hibernate.vector.VectorTestHelper.innerProduct; +import static org.hibernate.vector.VectorTestHelper.normalizeVectorString; +import static org.hibernate.vector.VectorTestHelper.taxicabDistance; +import static org.hibernate.vector.VectorTestHelper.vectorSparseStringLiteral; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +@DomainModel(annotatedClasses = SparseDoubleVectorTest.VectorEntity.class) +@SessionFactory +@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsSparseDoubleVectorType.class) +public class SparseDoubleVectorTest { + + private static final double[] V1 = new double[]{ 0, 2, 3 }; + private static final double[] V2 = new double[]{ 0, 5, 6 }; + + @BeforeEach + public void prepareData(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.persist( new VectorEntity( 1L, new SparseDoubleVector( V1 ) ) ); + em.persist( new VectorEntity( 2L, new SparseDoubleVector( V2 ) ) ); + } ); + } + + @AfterEach + public void cleanup(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.createMutationQuery( "delete from VectorEntity" ).executeUpdate(); + } ); + } + + @Test + public void testRead(SessionFactoryScope scope) { + scope.inTransaction( em -> { + VectorEntity tableRecord; + tableRecord = em.find( VectorEntity.class, 1L ); + assertArrayEquals( new double[]{ 0, 2, 3 }, tableRecord.getTheVector().toDenseVector() ); + + tableRecord = em.find( VectorEntity.class, 2L ); + assertArrayEquals( new double[]{ 0, 5, 6 }, tableRecord.getTheVector().toDenseVector() ); + } ); + } + + @Test + public void testCast(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final String literal = vectorSparseStringLiteral( new double[] {1, 1, 1}, em ); + final Tuple vector = em.createSelectionQuery( "select cast(e.theVector as string), cast('" + literal + "' as sparse_double_vector(3)) from VectorEntity e where e.id = 1", Tuple.class ) + .getSingleResult(); + assertEquals( vectorSparseStringLiteral( V1, em ), normalizeVectorString( vector.get( 0, String.class ) ) ); + assertEquals( new SparseDoubleVector( new double[]{ 1, 1, 1 } ), vector.get( 1, SparseDoubleVector.class ) ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsCosineDistance.class) + public void testCosineDistance(SessionFactoryScope scope) { + scope.inTransaction( 
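// Editor's note, not part of this patch: the cast round-trip in testCast above relies on a
// dialect-specific textual form produced by vectorSparseStringLiteral, which is why the helper takes
// the EntityManager and the database output is run through normalizeVectorString. As one concrete
// example, pgvector's sparsevec type uses "{index:value,...}/dimensions" with 1-based indices; other
// dialects use different encodings. The method below is purely illustrative and only builds that
// pgvector-style form.
public final class SparseLiteralSketch {
	public static String sparseLiteral(double[] dense) {
		final StringBuilder sb = new StringBuilder( "{" );
		String separator = "";
		for ( int i = 0; i < dense.length; i++ ) {
			if ( dense[i] != 0d ) {
				sb.append( separator ).append( i + 1 ).append( ':' ).append( dense[i] );
				separator = ",";
			}
		}
		// For V1 = {0, 2, 3} this builds "{2:2.0,3:3.0}/3".
		return sb.append( "}/" ).append( dense.length ).toString();
	}
}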
em -> { + final double[] vector = new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseDoubleVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.0000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanDistance.class) + public void testEuclideanDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final double[] vector = new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseDoubleVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanSquaredDistance.class) + public void testEuclideanSquaredDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final double[] vector = new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, euclidean_squared_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseDoubleVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanSquaredDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanSquaredDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsTaxicabDistance.class) + public void testTaxicabDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final double[] vector = new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseDoubleVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsInnerProduct.class) + public void testInnerProduct(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final double[] vector = new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) 
+ .setParameter( "vec", new SparseDoubleVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsHammingDistance.class) + public void testHammingDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final double[] vector = new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", new SparseDoubleVector( vector ) ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( hammingDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorDims.class) + public void testVectorDims(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final List results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( V1.length, results.get( 0 ).get( 1 ) ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( V2.length, results.get( 1 ).get( 1 ) ); + } ); + } + + @Test + @RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorNorm.class) + public void testVectorNorm(SessionFactoryScope scope) { + scope.inTransaction( em -> { + final List results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class ) + .getResultList(); + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Entity( name = "VectorEntity" ) + public static class VectorEntity { + + @Id + private Long id; + + @Column( name = "the_vector" ) + @JdbcTypeCode(SqlTypes.SPARSE_VECTOR_FLOAT64) + @Array(length = 3) + private SparseDoubleVector theVector; + + public VectorEntity() { + } + + public VectorEntity(Long id, SparseDoubleVector theVector) { + this.id = id; + this.theVector = theVector; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public SparseDoubleVector getTheVector() { + return theVector; + } + + public void setTheVector(SparseDoubleVector theVector) { + this.theVector = theVector; + } + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/SparseDoubleVectorUnitTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/SparseDoubleVectorUnitTest.java new file mode 100644 index 000000000000..adf14bb2cef2 --- 
/dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/SparseDoubleVectorUnitTest.java @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +public class SparseDoubleVectorUnitTest { + + @Test + public void testEmpty() { + final SparseDoubleVector doubles = new SparseDoubleVector( 3 ); + doubles.set( 1, (double) 3 ); + assertArrayEquals( new Object[] {(double) 0, (double) 3, (double) 0}, doubles.toArray() ); + } + + @Test + public void testInsertBefore() { + final SparseDoubleVector doubles = new SparseDoubleVector( 3, new int[] {1}, new double[] {3} ); + doubles.set( 0, (double) 2 ); + assertArrayEquals( new Object[] {(double) 2, (double) 3, (double) 0}, doubles.toArray() ); + } + + @Test + public void testInsertAfter() { + final SparseDoubleVector doubles = new SparseDoubleVector( 3, new int[] {1}, new double[] {3} ); + doubles.set( 2, (double) 2 ); + assertArrayEquals( new Object[] {(double) 0, (double) 3, (double) 2}, doubles.toArray() ); + } + + @Test + public void testReplace() { + final SparseDoubleVector doubles = new SparseDoubleVector( 3, new int[] {0, 1, 2}, new double[] {3, 3, 3} ); + doubles.set( 2, (double) 2 ); + assertArrayEquals( new Object[] {(double) 3, (double) 3, (double) 2}, doubles.toArray() ); + } + + @Test + public void testFromDenseVector() { + final SparseDoubleVector doubles = new SparseDoubleVector( new double[] {0, 3, 0} ); + assertArrayEquals( new Object[] {(double) 0, (double) 3, (double) 0}, doubles.toArray() ); + } + + @Test + public void testFromDenseVectorList() { + final SparseDoubleVector doubles = new SparseDoubleVector( List.of( (double) 0, (double) 3, (double) 0 ) ); + assertArrayEquals( new Object[] {(double) 0, (double) 3, (double) 0}, doubles.toArray() ); + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/SparseFloatVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/SparseFloatVectorTest.java new file mode 100644 index 000000000000..7bc2b787ba0c --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/SparseFloatVectorTest.java @@ -0,0 +1,249 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. 
and Hibernate Authors
+ */
+package org.hibernate.vector;
+
+import jakarta.persistence.Column;
+import jakarta.persistence.Entity;
+import jakarta.persistence.Id;
+import jakarta.persistence.Tuple;
+import org.hibernate.annotations.Array;
+import org.hibernate.annotations.JdbcTypeCode;
+import org.hibernate.dialect.PostgreSQLDialect;
+import org.hibernate.dialect.PostgresPlusDialect;
+import org.hibernate.testing.orm.junit.DialectFeatureChecks;
+import org.hibernate.testing.orm.junit.DomainModel;
+import org.hibernate.testing.orm.junit.RequiresDialectFeature;
+import org.hibernate.testing.orm.junit.SessionFactory;
+import org.hibernate.testing.orm.junit.SessionFactoryScope;
+import org.hibernate.testing.orm.junit.SkipForDialect;
+import org.hibernate.type.SqlTypes;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+
+import static org.hibernate.vector.VectorTestHelper.cosineDistance;
+import static org.hibernate.vector.VectorTestHelper.euclideanDistance;
+import static org.hibernate.vector.VectorTestHelper.euclideanNorm;
+import static org.hibernate.vector.VectorTestHelper.euclideanSquaredDistance;
+import static org.hibernate.vector.VectorTestHelper.hammingDistance;
+import static org.hibernate.vector.VectorTestHelper.innerProduct;
+import static org.hibernate.vector.VectorTestHelper.normalizeVectorString;
+import static org.hibernate.vector.VectorTestHelper.taxicabDistance;
+import static org.hibernate.vector.VectorTestHelper.vectorSparseStringLiteral;
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+@DomainModel(annotatedClasses = SparseFloatVectorTest.VectorEntity.class)
+@SessionFactory
+@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsSparseFloatVectorType.class)
+@SkipForDialect(dialectClass = PostgresPlusDialect.class, reason = "Test database does not have the extension enabled")
+public class SparseFloatVectorTest {
+
+	private static final float[] V1 = new float[]{ 0, 2, 3 };
+	private static final float[] V2 = new float[]{ 0, 5, 6 };
+
+	@BeforeEach
+	public void prepareData(SessionFactoryScope scope) {
+		scope.inTransaction( em -> {
+			em.persist( new VectorEntity( 1L, new SparseFloatVector( V1 ) ) );
+			em.persist( new VectorEntity( 2L, new SparseFloatVector( V2 ) ) );
+		} );
+	}
+
+	@AfterEach
+	public void cleanup(SessionFactoryScope scope) {
+		scope.inTransaction( em -> {
+			em.createMutationQuery( "delete from VectorEntity" ).executeUpdate();
+		} );
+	}
+
+	@Test
+	public void testRead(SessionFactoryScope scope) {
+		scope.inTransaction( em -> {
+			VectorEntity tableRecord;
+			tableRecord = em.find( VectorEntity.class, 1L );
+			assertArrayEquals( new float[]{ 0, 2, 3 }, tableRecord.getTheVector().toDenseVector() );
+
+			tableRecord = em.find( VectorEntity.class, 2L );
+			assertArrayEquals( new float[]{ 0, 5, 6 }, tableRecord.getTheVector().toDenseVector() );
+		} );
+	}
+
+	@Test
+	public void testCast(SessionFactoryScope scope) {
+		scope.inTransaction( em -> {
+			final String literal = vectorSparseStringLiteral( new float[] {1, 1, 1}, em );
+			final Tuple vector = em.createSelectionQuery( "select cast(e.theVector as string), cast('" + literal + "' as sparse_float_vector(3)) from VectorEntity e where e.id = 1", Tuple.class )
+					.getSingleResult();
+			assertEquals( vectorSparseStringLiteral( V1, em ), normalizeVectorString( vector.get( 0, String.class ) ) );
+			assertEquals( new SparseFloatVector( new float[]{ 1, 1, 1 } ), vector.get( 1, SparseFloatVector.class ) );
+		} );
+	}
+
+	@Test
+	@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsCosineDistance.class)
+	public void testCosineDistance(SessionFactoryScope scope) {
+		scope.inTransaction( em -> {
+			final float[] vector = new float[]{ 1, 1, 1 };
+			final List<Tuple> results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
+					.setParameter( "vec", new SparseFloatVector( vector ) )
+					.getResultList();
+			assertEquals( 2, results.size() );
+			assertEquals( 1L, results.get( 0 ).get( 0 ) );
+			assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D );
+			assertEquals( 2L, results.get( 1 ).get( 0 ) );
+			assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.0000001D );
+		} );
+	}
+
+	@Test
+	@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanDistance.class)
+	public void testEuclideanDistance(SessionFactoryScope scope) {
+		scope.inTransaction( em -> {
+			final float[] vector = new float[]{ 1, 1, 1 };
+			final List<Tuple> results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
+					.setParameter( "vec", new SparseFloatVector( vector ) )
+					.getResultList();
+			assertEquals( 2, results.size() );
+			assertEquals( 1L, results.get( 0 ).get( 0 ) );
+			assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D );
+			assertEquals( 2L, results.get( 1 ).get( 0 ) );
+			assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D );
+		} );
+	}
+
+	@Test
+	@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsEuclideanSquaredDistance.class)
+	public void testEuclideanSquaredDistance(SessionFactoryScope scope) {
+		scope.inTransaction( em -> {
+			final float[] vector = new float[]{ 1, 1, 1 };
+			final List<Tuple> results = em.createSelectionQuery( "select e.id, euclidean_squared_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
+					.setParameter( "vec", new SparseFloatVector( vector ) )
+					.getResultList();
+			assertEquals( 2, results.size() );
+			assertEquals( 1L, results.get( 0 ).get( 0 ) );
+			assertEquals( euclideanSquaredDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D );
+			assertEquals( 2L, results.get( 1 ).get( 0 ) );
+			assertEquals( euclideanSquaredDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D );
+		} );
+	}
+
+	@Test
+	@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsTaxicabDistance.class)
+	public void testTaxicabDistance(SessionFactoryScope scope) {
+		scope.inTransaction( em -> {
+			final float[] vector = new float[]{ 1, 1, 1 };
+			final List<Tuple> results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
+					.setParameter( "vec", new SparseFloatVector( vector ) )
+					.getResultList();
+			assertEquals( 2, results.size() );
+			assertEquals( 1L, results.get( 0 ).get( 0 ) );
+			assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D );
+			assertEquals( 2L, results.get( 1 ).get( 0 ) );
+			assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D );
+		} );
+	}
+
+	@Test
+	@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsInnerProduct.class)
+	public void testInnerProduct(SessionFactoryScope scope) {
+		scope.inTransaction( em -> {
+			final float[] vector = new float[]{ 1, 1, 1 };
+			final List<Tuple> results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
+					.setParameter( "vec", new SparseFloatVector( vector ) )
+					.getResultList();
+			assertEquals( 2, results.size() );
+			assertEquals( 1L, results.get( 0 ).get( 0 ) );
+			assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D );
+			assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D );
+			assertEquals( 2L, results.get( 1 ).get( 0 ) );
+			assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D );
+			assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D );
+		} );
+	}
+
+	@Test
+	@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsHammingDistance.class)
+	@SkipForDialect(dialectClass = PostgreSQLDialect.class, matchSubTypes = true, reason = "Not supported with sparse vectors")
+	public void testHammingDistance(SessionFactoryScope scope) {
+		scope.inTransaction( em -> {
+			final float[] vector = new float[]{ 1, 1, 1 };
+			final List<Tuple> results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
+					.setParameter( "vec", new SparseFloatVector( vector ) )
+					.getResultList();
+			assertEquals( 2, results.size() );
+			assertEquals( 1L, results.get( 0 ).get( 0 ) );
+			assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D );
+			assertEquals( 2L, results.get( 1 ).get( 0 ) );
+			assertEquals( hammingDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D );
+		} );
+	}
+
+	@Test
+	@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorDims.class)
+	public void testVectorDims(SessionFactoryScope scope) {
+		scope.inTransaction( em -> {
+			final List<Tuple> results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class )
+					.getResultList();
+			assertEquals( 2, results.size() );
+			assertEquals( 1L, results.get( 0 ).get( 0 ) );
+			assertEquals( V1.length, results.get( 0 ).get( 1 ) );
+			assertEquals( 2L, results.get( 1 ).get( 0 ) );
+			assertEquals( V2.length, results.get( 1 ).get( 1 ) );
+		} );
+	}
+
+	@Test
+	@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsVectorNorm.class)
+	public void testVectorNorm(SessionFactoryScope scope) {
+		scope.inTransaction( em -> {
+			final List<Tuple> results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class )
+					.getResultList();
+			assertEquals( 2, results.size() );
+			assertEquals( 1L, results.get( 0 ).get( 0 ) );
+			assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D );
+			assertEquals( 2L, results.get( 1 ).get( 0 ) );
+			assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, double.class ), 0D );
+		} );
+	}
+
+	@Entity( name = "VectorEntity" )
+	public static class VectorEntity {
+
+		@Id
+		private Long id;
+
+		@Column( name = "the_vector" )
+		@JdbcTypeCode(SqlTypes.SPARSE_VECTOR_FLOAT32)
+		@Array(length = 3)
+		private SparseFloatVector theVector;
+
+		public VectorEntity() {
+		}
+
+		public VectorEntity(Long id, SparseFloatVector theVector) {
+			this.id = id;
+			this.theVector = theVector;
+		}
+
+		public Long getId() {
+			return id;
+		}
+
+		public void setId(Long id) {
+			this.id = id;
+		}
+
+		public SparseFloatVector getTheVector() {
+			return theVector;
+		}
+
+		public void setTheVector(SparseFloatVector theVector) {
+			this.theVector = theVector;
+		}
+	}
+}
diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/SparseFloatVectorUnitTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/SparseFloatVectorUnitTest.java
new file mode 100644
index 000000000000..20eb1dd3c3f8
--- /dev/null
+++ b/hibernate-vector/src/test/java/org/hibernate/vector/SparseFloatVectorUnitTest.java
@@ -0,0 +1,54 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright Red Hat Inc. and Hibernate Authors
+ */
+package org.hibernate.vector;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+
+public class SparseFloatVectorUnitTest {
+
+	@Test
+	public void testEmpty() {
+		final SparseFloatVector floats = new SparseFloatVector( 3 );
+		floats.set( 1, (float) 3 );
+		assertArrayEquals( new Object[] {(float) 0, (float) 3, (float) 0}, floats.toArray() );
+	}
+
+	@Test
+	public void testInsertBefore() {
+		final SparseFloatVector floats = new SparseFloatVector( 3, new int[] {1}, new float[] {3} );
+		floats.set( 0, (float) 2 );
+		assertArrayEquals( new Object[] {(float) 2, (float) 3, (float) 0}, floats.toArray() );
+	}
+
+	@Test
+	public void testInsertAfter() {
+		final SparseFloatVector floats = new SparseFloatVector( 3, new int[] {1}, new float[] {3} );
+		floats.set( 2, (float) 2 );
+		assertArrayEquals( new Object[] {(float) 0, (float) 3, (float) 2}, floats.toArray() );
+	}
+
+	@Test
+	public void testReplace() {
+		final SparseFloatVector floats = new SparseFloatVector( 3, new int[] {0, 1, 2}, new float[] {3, 3, 3} );
+		floats.set( 2, (float) 2 );
+		assertArrayEquals( new Object[] {(float) 3, (float) 3, (float) 2}, floats.toArray() );
+	}
+
+	@Test
+	public void testFromDenseVector() {
+		final SparseFloatVector floats = new SparseFloatVector( new float[] {0, 3, 0} );
+		assertArrayEquals( new Object[] {(float) 0, (float) 3, (float) 0}, floats.toArray() );
+	}
+
+	@Test
+	public void testFromDenseVectorList() {
+		final SparseFloatVector floats = new SparseFloatVector( List.of( (float) 0, (float) 3, (float) 0 ) );
+		assertArrayEquals( new Object[] {(float) 0, (float) 3, (float) 0}, floats.toArray() );
+	}
+}
diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/VectorTestHelper.java b/hibernate-vector/src/test/java/org/hibernate/vector/VectorTestHelper.java
new file mode 100644
index 000000000000..3edc584d1315
--- /dev/null
+++ b/hibernate-vector/src/test/java/org/hibernate/vector/VectorTestHelper.java
@@ -0,0 +1,327 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright Red Hat Inc. and Hibernate Authors
+ */
+package org.hibernate.vector;
+
+import org.hibernate.engine.spi.SessionFactoryImplementor;
+import org.hibernate.engine.spi.SessionImplementor;
+import org.hibernate.type.SqlTypes;
+import org.hibernate.type.StandardBasicTypes;
+import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter;
+import org.hibernate.type.spi.TypeConfiguration;
+
+public class VectorTestHelper {
+
+	public static double cosineDistance(float[] f1, float[] f2) {
+		return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
+	}
+
+	public static double cosineDistance(double[] f1, double[] f2) {
+		return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
+	}
+
+	public static double cosineDistance(byte[] f1, byte[] f2) {
+		return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
+	}
+
+	public static double cosineDistanceBinary(byte[] f1, byte[] f2) {
+		return 1D - innerProductBinary( f1, f2 ) / ( euclideanNormBinary( f1 ) * euclideanNormBinary( f2 ) );
+	}
+
+	public static double euclideanDistance(float[] f1, float[] f2) {
+		return Math.sqrt( euclideanSquaredDistance( f1, f2 ) );
+	}
+
+	public static double euclideanDistance(double[] f1, double[] f2) {
+		return Math.sqrt( euclideanSquaredDistance( f1, f2 ) );
+	}
+
+	public static double euclideanDistance(byte[] f1, byte[] f2) {
+		return Math.sqrt( euclideanSquaredDistance( f1, f2 ) );
+	}
+
+	public static double euclideanDistanceBinary(byte[] f1, byte[] f2) {
+		// On bit level, the two distance functions are equivalent
+		return Math.sqrt( hammingDistanceBinary( f1, f2 ) );
+	}
+
+	public static double euclideanSquaredDistance(float[] f1, float[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += Math.pow( (double) f1[i] - f2[i], 2 );
+		}
+		return result;
+	}
+
+	public static double euclideanSquaredDistance(double[] f1, double[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += Math.pow( (double) f1[i] - f2[i], 2 );
+		}
+		return result;
+	}
+
+	public static double euclideanSquaredDistance(byte[] f1, byte[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += Math.pow( (double) f1[i] - f2[i], 2 );
+		}
+		return result;
+	}
+
+	public static double euclideanSquaredDistanceBinary(byte[] f1, byte[] f2) {
+		// On bit level, the two distance functions are equivalent
+		return hammingDistanceBinary( f1, f2 );
+	}
+
+	public static double taxicabDistance(float[] f1, float[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += Math.abs( f1[i] - f2[i] );
+		}
+		return result;
+	}
+
+	public static double taxicabDistance(double[] f1, double[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += Math.abs( f1[i] - f2[i] );
+		}
+		return result;
+	}
+
+	public static double taxicabDistance(byte[] f1, byte[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += Math.abs( f1[i] - f2[i] );
+		}
+		return result;
+	}
+
+	public static double taxicabDistanceBinary(byte[] f1, byte[] f2) {
+		// On bit level, the two distance functions are equivalent
+		return hammingDistanceBinary( f1, f2 );
+	}
+
+	public static double innerProduct(float[] f1, float[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += ( (double) f1[i] ) * ( (double) f2[i] );
+		}
+		return result;
+	}
+
+	public static double innerProduct(double[] f1, double[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += ( (double) f1[i] ) * ( (double) f2[i] );
+		}
+		return result;
+	}
+
+	public static double innerProduct(byte[] f1, byte[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += ( (double) f1[i] ) * ( (double) f2[i] );
+		}
+		return result;
+	}
+
+	public static double innerProductBinary(byte[] f1, byte[] f2) {
+		assert f1.length == f2.length;
+		double result = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			result += Integer.bitCount( f1[i] & f2[i] );
+		}
+		return result;
+	}
+
+	public static double hammingDistance(float[] f1, float[] f2) {
+		assert f1.length == f2.length;
+		int distance = 0;
+		for ( int i = 0; i < f1.length; i++ ) {
+			if ( !( f1[i] == f2[i] ) ) {
+				distance++;
+			}
+		}
+		return distance;
+	}
+
+	public static double hammingDistance(double[] f1, double[] f2) {
+		assert f1.length == f2.length;
+		int distance = 0;
+		for (int i = 0; i < f1.length; i++) {
+			if (!(f1[i] == f2[i])) {
+				distance++;
+			}
+		}
+		return distance;
+	}
+
+	public static double hammingDistance(byte[] f1, byte[] f2) {
+		assert f1.length == f2.length;
+		int distance = 0;
+		for (int i = 0; i < f1.length; i++) {
+			if (!(f1[i] == f2[i])) {
+				distance++;
+			}
+		}
+		return distance;
+	}
+
+	public static double hammingDistanceBinary(byte[] f1, byte[] f2) {
+		assert f1.length == f2.length;
+		int distance = 0;
+		for (int i = 0; i < f1.length; i++) {
+			distance += Integer.bitCount( f1[i] ^ f2[i] );
+		}
+		return distance;
+	}
+
+	public static double euclideanNorm(float[] f) {
+		double result = 0;
+		for ( float v : f ) {
+			result += Math.pow( v, 2 );
+		}
+		return Math.sqrt( result );
+	}
+
+	public static double euclideanNorm(double[] f) {
+		double result = 0;
+		for ( double v : f ) {
+			result += Math.pow( v, 2 );
+		}
+		return Math.sqrt( result );
+	}
+
+	public static double euclideanNorm(byte[] f) {
+		double result = 0;
+		for ( byte v : f ) {
+			result += Math.pow( v, 2 );
+		}
+		return Math.sqrt( result );
+	}
+
+	public static float[] euclideanNormalize(float[] f) {
+		final double norm = euclideanNorm( f );
+		final float[] result = new float[f.length];
+		for ( int i = 0; i < f.length; i++ ) {
+			result[i] = (float) (f[i] / norm);
+		}
+		return result;
+	}
+
+	public static float[] euclideanNormalize(double[] f) {
+		final double norm = euclideanNorm( f );
+		final float[] result = new float[f.length];
+		for ( int i = 0; i < f.length; i++ ) {
+			result[i] = (float) (f[i] / norm);
+		}
+		return result;
+	}
+
+	public static float[] euclideanNormalize(byte[] f) {
+		final double norm = euclideanNorm( f );
+		final float[] result = new float[f.length];
+		for ( int i = 0; i < f.length; i++ ) {
+			result[i] = (float) (f[i] / norm);
+		}
+		return result;
+	}
+
+	public static double euclideanNormBinary(byte[] f) {
+		double result = 0;
+		for ( byte v : f ) {
+			result += Integer.bitCount( v );
+		}
+		return Math.sqrt( result );
+	}
+
+	public static double jaccardDistanceBinary(byte[] f1, byte[] f2) {
+		assert f1.length == f2.length;
+		int intersectionSum = 0;
+		int unionSum = 0;
+		for (int i = 0; i < f1.length; i++) {
+			intersectionSum += Integer.bitCount( f1[i] & f2[i] );
+			unionSum += Integer.bitCount( f1[i] | f2[i] );
+		}
+		return 1d - (double) intersectionSum / unionSum;
+	}
+
+	public static String normalizeVectorString(String vector) {
+		return vector.replace( "E+000", "" )
+				.replace( ".0", "" );
+	}
+
+	public static String vectorSparseStringLiteral(float[] vector, SessionImplementor session) {
+		final SessionFactoryImplementor sessionFactory = session.getFactory();
+		final TypeConfiguration typeConfiguration = sessionFactory.getTypeConfiguration();
+		final JdbcLiteralFormatter<SparseFloatVector> literalFormatter = typeConfiguration.getJdbcTypeRegistry()
+				.getDescriptor( SqlTypes.SPARSE_VECTOR_FLOAT32 )
+				.getJdbcLiteralFormatter( typeConfiguration.getJavaTypeRegistry().getDescriptor( SparseFloatVector.class ) );
+		final String jdbcLiteral = literalFormatter.toJdbcLiteral(
+				new SparseFloatVector( vector ),
+				sessionFactory.getJdbcServices().getDialect(),
+				session
+		);
+		final int start = jdbcLiteral.indexOf( '\'' );
+		final int end = jdbcLiteral.indexOf( '\'', start + 1 );
+		return jdbcLiteral.substring( start + 1, end ).replace( ".0", "" );
+	}
+
+	public static String vectorSparseStringLiteral(double[] vector, SessionImplementor session) {
+		final SessionFactoryImplementor sessionFactory = session.getFactory();
+		final TypeConfiguration typeConfiguration = sessionFactory.getTypeConfiguration();
+		final JdbcLiteralFormatter<SparseDoubleVector> literalFormatter = typeConfiguration.getJdbcTypeRegistry()
+				.getDescriptor( SqlTypes.SPARSE_VECTOR_FLOAT64 )
+				.getJdbcLiteralFormatter( typeConfiguration.getJavaTypeRegistry().getDescriptor( SparseDoubleVector.class ) );
+		final String jdbcLiteral = literalFormatter.toJdbcLiteral(
+				new SparseDoubleVector( vector ),
+				sessionFactory.getJdbcServices().getDialect(),
+				session
+		);
+		final int start = jdbcLiteral.indexOf( '\'' );
+		final int end = jdbcLiteral.indexOf( '\'', start + 1 );
+		return jdbcLiteral.substring( start + 1, end ).replace( ".0", "" );
+	}
+
+	public static String vectorSparseStringLiteral(byte[] vector, SessionImplementor session) {
+		final SessionFactoryImplementor sessionFactory = session.getFactory();
+		final TypeConfiguration typeConfiguration = sessionFactory.getTypeConfiguration();
+		final JdbcLiteralFormatter<SparseByteVector> literalFormatter = typeConfiguration.getJdbcTypeRegistry()
+				.getDescriptor( SqlTypes.SPARSE_VECTOR_INT8 )
+				.getJdbcLiteralFormatter( typeConfiguration.getJavaTypeRegistry().getDescriptor( SparseByteVector.class ) );
+		final String jdbcLiteral = literalFormatter.toJdbcLiteral(
+				new SparseByteVector( vector ),
+				sessionFactory.getJdbcServices().getDialect(),
+				session
+		);
+		final int start = jdbcLiteral.indexOf( '\'' );
+		final int end = jdbcLiteral.indexOf( '\'', start + 1 );
+		return jdbcLiteral.substring( start + 1, end );
+	}
+
+	public static String vectorBinaryStringLiteral(byte[] vector, SessionImplementor session) {
+		final SessionFactoryImplementor sessionFactory = session.getFactory();
+		final JdbcLiteralFormatter<byte[]> literalFormatter = sessionFactory.getTypeConfiguration()
+				.getBasicTypeRegistry().resolve( StandardBasicTypes.VECTOR_BINARY ).getJdbcLiteralFormatter();
+		final String jdbcLiteral = literalFormatter.toJdbcLiteral(
+				vector,
+				sessionFactory.getJdbcServices().getDialect(),
+				session
+		);
+		final int start = jdbcLiteral.indexOf( '\'' );
+		final int end = jdbcLiteral.indexOf( '\'', start + 1 );
+		return jdbcLiteral.substring( start + 1, end );
+	}
+}