diff --git a/oracle/common.go b/oracle/common.go index efcb375..b6712ad 100644 --- a/oracle/common.go +++ b/oracle/common.go @@ -41,9 +41,8 @@ package oracle import ( "bytes" "database/sql" + "database/sql/driver" "encoding/json" - "fmt" - "math" "reflect" "strings" "time" @@ -52,50 +51,50 @@ import ( "github.com/google/uuid" "gorm.io/datatypes" "gorm.io/gorm" + "gorm.io/gorm/clause" "gorm.io/gorm/schema" ) // Extra data types for the data type that are not declared in the // default DataType list const ( - JSON schema.DataType = "json" Timestamp schema.DataType = "timestamp" TimestampWithTimeZone schema.DataType = "timestamp with time zone" ) // Helper function to get Oracle array type for a field -func getOracleArrayType(field *schema.Field, values []any) string { - switch field.DataType { - case schema.Bool: - return "TABLE OF NUMBER(1)" - case schema.Int, schema.Uint: - return "TABLE OF NUMBER" - case schema.Float: - return "TABLE OF NUMBER" - case JSON: - // PL/SQL does not yet allow declaring collections of JSON (TABLE OF JSON) directly. - // Workaround for JSON type - fallthrough - case schema.String: - if field.Size > 0 && field.Size <= 4000 { - return fmt.Sprintf("TABLE OF VARCHAR2(%d)", field.Size) - } else { - for _, value := range values { - if strValue, ok := value.(string); ok { - if len(strValue) > 4000 { - return "TABLE OF CLOB" - } - } +func getOracleArrayType(values []any) string { + arrayType := "TABLE OF VARCHAR2(4000)" + for _, val := range values { + if val == nil { + continue + } + switch v := val.(type) { + case bool: + arrayType = "TABLE OF NUMBER(1)" + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + arrayType = "TABLE OF NUMBER" + case time.Time: + arrayType = "TABLE OF TIMESTAMP WITH TIME ZONE" + case godror.Lob: + if v.IsClob { + return "TABLE OF CLOB" + } else { + return "TABLE OF BLOB" + } + case []byte: + // Store byte slices longer than 4000 bytes as BLOB + if len(v) > 4000 { + return "TABLE OF BLOB" + } + case string: + // Store strings longer than 4000 characters as CLOB + if len(v) > 4000 { + return "TABLE OF CLOB" } } - return "TABLE OF VARCHAR2(4000)" - case schema.Time: - return "TABLE OF TIMESTAMP WITH TIME ZONE" - case schema.Bytes: - return "TABLE OF BLOB" - default: - return "TABLE OF " + strings.ToUpper(string(field.DataType)) } + return arrayType } // Helper function to get all column names for a table @@ -131,6 +130,12 @@ func createTypedDestination(f *schema.Field) interface{} { return new(string) } + // To differentiate between bool fields stored as NUMBER(1) and bool fields stored as actual BOOLEAN type, + // check the struct's "type" tag. + if f.DataType == "boolean" { + return new(bool) + } + // If the field has a serializer, the field type may not be directly related to the column type in the database. // In this case, determine the destination type using the field's data type, which is the column type in the // database. @@ -204,13 +209,17 @@ func createTypedDestination(f *schema.Field) interface{} { case reflect.Float32, reflect.Float64: return new(float64) + + case reflect.Slice: + if ft.Elem().Kind() == reflect.Uint8 { // []byte + return new([]byte) + } } // Fallback return new(string) } -// Convert values for Oracle-specific types func convertValue(val interface{}) interface{} { if val == nil { return nil @@ -222,6 +231,10 @@ func convertValue(val interface{}) interface{} { rv = rv.Elem() val = rv.Interface() } + isNil := false + if rv.Kind() == reflect.Ptr && rv.IsNil() { + isNil = true + } switch v := val.(type) { case json.RawMessage: @@ -238,7 +251,7 @@ func convertValue(val interface{}) interface{} { case *uuid.UUID, *datatypes.UUID: // Convert nil pointer to a UUID to empty string so that it is stored in the database as NULL // rather than "00000000-0000-0000-0000-000000000000" - if rv.IsNil() { + if isNil { return "" } return val @@ -249,15 +262,31 @@ func convertValue(val interface{}) interface{} { return 0 } case string: - if len(v) > math.MaxInt16 { + // Store strings longer than 4000 characters as CLOB + if len(v) > 4000 { return godror.Lob{IsClob: true, Reader: strings.NewReader(v)} } return v case []byte: - if len(v) > math.MaxInt16 { + // Store byte slices longer than 4000 bytes as BLOB + if len(v) > 4000 { return godror.Lob{IsClob: false, Reader: bytes.NewReader(v)} } return v + case driver.Valuer: + // Unwrap driver.Valuer to its underlying type by recursing into + // convertValue until we get a non-Valuer type + if v == nil || isNil { + return val + } + unwrappedValue, err := v.Value() + if err != nil { + return val + } + return convertValue(unwrappedValue) + case clause.Expr: + // If we get a clause.Expr, convert it to nil; it should be handled elsewhere + return nil default: return val } @@ -285,6 +314,13 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{ targetType = field.FieldType.Elem() } + // When PL/SQL LOBs are returned, skip conversion. + // LOB addresses are freed by the driver after the query, so we cannot read their content + // from the return value. If you need to read stored LOB content, do it in a separate query. + if _, ok := value.(godror.Lob); ok { + return nil + } + switch targetType { case reflect.TypeOf(gorm.DeletedAt{}): if nullTime, ok := value.(sql.NullTime); ok { @@ -318,6 +354,16 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{ default: converted = value } + case reflect.TypeOf(uuid.UUID{}), reflect.TypeOf(datatypes.UUID{}): + uuidStr, ok := value.(string) + if !ok { + return nil + } + parsed, err := uuid.Parse(uuidStr) + if err != nil { + return nil + } + converted = parsed case reflect.TypeOf(time.Time{}): switch vv := value.(type) { @@ -388,6 +434,12 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{ } func isJSONField(f *schema.Field) bool { + // Support detecting JSON fields through the struct's "type" tag. + // Also support jsonb for compatibility with other databases. + if f.DataType == "json" || f.DataType == "jsonb" { + return true + } + _rawMsgT := reflect.TypeOf(json.RawMessage{}) _gormJSON := reflect.TypeOf(datatypes.JSON{}) if f == nil { diff --git a/oracle/create.go b/oracle/create.go index 0dcac30..3e7e65e 100644 --- a/oracle/create.go +++ b/oracle/create.go @@ -39,17 +39,27 @@ package oracle import ( + "bytes" "database/sql" "fmt" "reflect" "strings" + "github.com/godror/godror" "gorm.io/gorm" "gorm.io/gorm/callbacks" "gorm.io/gorm/clause" "gorm.io/gorm/schema" ) +// plsqlBindVariableMap is a helper struct to manage PL/SQL bind variables. +// It maps column names to their corresponding slice of real values, as well +// as recording columns that are LOBs. +type plsqlBindVariableMap struct { + lobColumns map[string]bool + variableMap map[string][]any +} + // Create overrides GORM's create callback for Oracle. // // Behavior: @@ -131,9 +141,14 @@ func Create(db *gorm.DB) { hasReturningInDryRun := db.DryRun && hasReturningClause needsReturning := stmtSchema != nil && len(stmtSchema.FieldsWithDefaultDBValue) > 0 && (!db.DryRun || hasReturningInDryRun) - if needsReturning && len(createValues.Values) > 1 { + // Pre-emptively map PL/SQL bind variables to check for LOBs + // If we have LOBs, we need to use PL/SQL for bulk inserts to ensure + // all values for a particular column are identically typed. + plsqlBindMap := mapPLSQLBindValues(createValues) + + if (needsReturning || len(plsqlBindMap.lobColumns) > 0) && len(createValues.Values) > 1 { // Multiple rows with RETURNING - use PL/SQL - buildBulkInsertPLSQL(db, createValues) + buildBulkInsertPLSQL(db, createValues, plsqlBindMap) } else if needsReturning { // Single row with RETURNING - use regular SQL with RETURNING buildSingleInsertSQL(db, createValues) @@ -144,6 +159,60 @@ func Create(db *gorm.DB) { } } +// mapPLSQLBindValues maps the bind variables for PL/SQL batch inserts. +// It frontloads the conversion of values to their real types, while also +// ensuring that columns that are LOBs are identified and typed consistently. +// Without this, subsets of batch inserts targeting string or []byte fields +// may overrun the maximum size for VARCHAR2 and cause inconsistent types during UNIONs. +func mapPLSQLBindValues(createValues clause.Values) plsqlBindVariableMap { + lobColumns := make(map[string]bool) + mappedVars := make(map[string][]any) + for i, column := range createValues.Columns { + for _, values := range createValues.Values { + value := convertValue(values[i]) + if _, ok := lobColumns[column.Name]; ok { + value = convertToLOB(value) + } else { + lob, isLob := value.(godror.Lob) + if isLob { + lobColumns[column.Name] = true + lobs := convertToLOBs(mappedVars[column.Name]) + mappedVars[column.Name] = lobs + value = lob + } + } + mappedVars[column.Name] = append(mappedVars[column.Name], value) + } + } + return plsqlBindVariableMap{ + variableMap: mappedVars, + lobColumns: lobColumns, + } +} + +// convertToLOBs converts an array of values to their respective LOB types (if needed). +func convertToLOBs(values []any) []any { + newVals := make([]any, len(values)) + for i, val := range values { + newVals[i] = convertToLOB(val) + } + return newVals +} + +// convertToLOB converts a value to its respective LOB type (if any). +func convertToLOB(val any) any { + if val == nil { + return val + } + switch v := val.(type) { + case string: + return godror.Lob{IsClob: true, Reader: strings.NewReader(v)} + case []byte: + return godror.Lob{IsClob: false, Reader: bytes.NewReader(v)} + } + return val +} + // validateCreateData checks for invalid data in the destination before processing func validateCreateData(stmt *gorm.Statement) error { if stmt.Dest == nil { @@ -175,7 +244,7 @@ func validateCreateData(stmt *gorm.Statement) error { } // Build PL/SQL block for bulk INSERT/MERGE with RETURNING -func buildBulkInsertPLSQL(db *gorm.DB, createValues clause.Values) { +func buildBulkInsertPLSQL(db *gorm.DB, createValues clause.Values, bindMap plsqlBindVariableMap) { sanitizeCreateValuesForBulkArrays(db.Statement, &createValues) stmt := db.Statement @@ -229,16 +298,16 @@ func buildBulkInsertPLSQL(db *gorm.DB, createValues clause.Values) { shouldUseMerge := ShouldUseRealConflict(createValues, onConflict, conflictColumns) if shouldUseMerge { - buildBulkMergePLSQL(db, createValues, onConflictClause) + buildBulkMergePLSQL(db, createValues, onConflictClause, bindMap) return } } // Original INSERT logic for when there's no conflict handling needed - buildBulkInsertOnlyPLSQL(db, createValues) + buildBulkInsertOnlyPLSQL(db, createValues, bindMap) } // Build PL/SQL block for bulk MERGE with RETURNING (OnConflict case) -func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClause clause.Clause) { +func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClause clause.Clause, bindMap plsqlBindVariableMap) { sanitizeCreateValuesForBulkArrays(db.Statement, &createValues) stmt := db.Statement @@ -267,18 +336,18 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau valuesColumnMap[strings.ToUpper(column.Name)] = true } - // Filter conflict columns to remove non unique columns + // Filter conflict columns to remove non unique columns and columns not a part of the INSERT var filteredConflictColumns []clause.Column for _, conflictCol := range conflictColumns { field := stmt.Schema.LookUpField(conflictCol.Name) - if valuesColumnMap[strings.ToUpper(conflictCol.Name)] && (field.Unique || field.AutoIncrement) { + if valuesColumnMap[strings.ToUpper(conflictCol.Name)] && fieldCanConflict(field, schema) { filteredConflictColumns = append(filteredConflictColumns, conflictCol) } } // Check if we have any usable conflict columns if len(filteredConflictColumns) == 0 { - buildBulkInsertOnlyPLSQL(db, createValues) + buildBulkInsertOnlyPLSQL(db, createValues, bindMap) return } @@ -294,13 +363,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau // Create array types and variables for each column for i, column := range createValues.Columns { - var arrayType string - if field := findFieldByDBName(schema, column.Name); field != nil { - arrayType = getOracleArrayType(field, pluck(createValues.Values, i)) - } else { - arrayType = "TABLE OF VARCHAR2(4000)" - } - + arrayType := getOracleArrayType(bindMap.variableMap[column.Name]) plsqlBuilder.WriteString(fmt.Sprintf(" TYPE t_col_%d_array IS %s;\n", i, arrayType)) plsqlBuilder.WriteString(fmt.Sprintf(" l_col_%d_array t_col_%d_array;\n", i, i)) } @@ -308,14 +371,14 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau plsqlBuilder.WriteString("BEGIN\n") // Initialize arrays with values - for i := range createValues.Columns { + for i, column := range createValues.Columns { plsqlBuilder.WriteString(fmt.Sprintf(" l_col_%d_array := t_col_%d_array(", i, i)) - for j, values := range createValues.Values { + for j, value := range bindMap.variableMap[column.Name] { if j > 0 { plsqlBuilder.WriteString(", ") } plsqlBuilder.WriteString(fmt.Sprintf(":%d", len(stmt.Vars)+1)) - stmt.Vars = append(stmt.Vars, convertValue(values[i])) + stmt.Vars = append(stmt.Vars, value) } plsqlBuilder.WriteString(");\n") } @@ -517,7 +580,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau if isJSONField(field) { if isRawMessageField(field) { // Column is a BLOB, return raw bytes; no JSON_SERIALIZE - stmt.Vars = append(stmt.Vars, sql.Out{Dest: new([]byte)}) + stmt.Vars = append(stmt.Vars, sql.Out{Dest: &godror.Lob{IsClob: false}}) plsqlBuilder.WriteString(fmt.Sprintf( " IF l_affected_records.COUNT > %d THEN :%d := l_affected_records(%d).", rowIdx, outParamIndex+1, rowIdx+1, @@ -526,7 +589,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau plsqlBuilder.WriteString("; END IF;\n") } else { // datatypes.JSON (text-based) -> serialize to CLOB - stmt.Vars = append(stmt.Vars, sql.Out{Dest: new(string)}) + stmt.Vars = append(stmt.Vars, sql.Out{Dest: &godror.Lob{IsClob: true}}) plsqlBuilder.WriteString(fmt.Sprintf( " IF l_affected_records.COUNT > %d THEN :%d := JSON_SERIALIZE(l_affected_records(%d).", rowIdx, outParamIndex+1, rowIdx+1, @@ -535,7 +598,16 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau plsqlBuilder.WriteString(" RETURNING CLOB); END IF;\n") } } else { - stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)}) + fieldType := createTypedDestination(field) + if bindMap.lobColumns[column] { + switch fieldType.(type) { + case *[]uint8: + fieldType = &godror.Lob{IsClob: false} + case *string: + fieldType = &godror.Lob{IsClob: true} + } + } + stmt.Vars = append(stmt.Vars, sql.Out{Dest: fieldType}) plsqlBuilder.WriteString(fmt.Sprintf(" IF l_affected_records.COUNT > %d THEN :%d := l_affected_records(%d).", rowIdx, outParamIndex+1, rowIdx+1)) db.QuoteTo(&plsqlBuilder, column) plsqlBuilder.WriteString("; END IF;\n") @@ -564,7 +636,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau } // Build PL/SQL block for bulk INSERT only (no conflict handling) -func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) { +func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values, bindMap plsqlBindVariableMap) { stmt := db.Statement schema := stmt.Schema @@ -577,13 +649,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) { // Create array types and variables for each column for i, column := range createValues.Columns { - var arrayType string - if field := findFieldByDBName(schema, column.Name); field != nil { - arrayType = getOracleArrayType(field, pluck(createValues.Values, i)) - } else { - arrayType = "TABLE OF VARCHAR2(4000)" - } - + arrayType := getOracleArrayType(bindMap.variableMap[column.Name]) plsqlBuilder.WriteString(fmt.Sprintf(" TYPE t_col_%d_array IS %s;\n", i, arrayType)) plsqlBuilder.WriteString(fmt.Sprintf(" l_col_%d_array t_col_%d_array;\n", i, i)) } @@ -591,14 +657,14 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) { plsqlBuilder.WriteString("BEGIN\n") // Initialize arrays with values - for i := range createValues.Columns { + for i, column := range createValues.Columns { plsqlBuilder.WriteString(fmt.Sprintf(" l_col_%d_array := t_col_%d_array(", i, i)) - for j, values := range createValues.Values { + for j, value := range bindMap.variableMap[column.Name] { if j > 0 { plsqlBuilder.WriteString(", ") } plsqlBuilder.WriteString(fmt.Sprintf(":%d", len(stmt.Vars)+1)) - stmt.Vars = append(stmt.Vars, convertValue(values[i])) + stmt.Vars = append(stmt.Vars, value) } plsqlBuilder.WriteString(");\n") } @@ -649,21 +715,30 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) { if isJSONField(field) { if isRawMessageField(field) { // Column is a BLOB, return raw bytes; no JSON_SERIALIZE - stmt.Vars = append(stmt.Vars, sql.Out{Dest: new([]byte)}) + stmt.Vars = append(stmt.Vars, sql.Out{Dest: &godror.Lob{IsClob: false}}) plsqlBuilder.WriteString(fmt.Sprintf( " IF l_inserted_records.COUNT > %d THEN :%d := l_inserted_records(%d).%s; END IF;\n", rowIdx, outParamIndex+1, rowIdx+1, quotedColumn, )) } else { // datatypes.JSON (text-based) -> serialize to CLOB - stmt.Vars = append(stmt.Vars, sql.Out{Dest: new(string)}) + stmt.Vars = append(stmt.Vars, sql.Out{Dest: &godror.Lob{IsClob: true}}) plsqlBuilder.WriteString(fmt.Sprintf( " IF l_inserted_records.COUNT > %d THEN :%d := JSON_SERIALIZE(l_inserted_records(%d).%s RETURNING CLOB); END IF;\n", rowIdx, outParamIndex+1, rowIdx+1, quotedColumn, )) } } else { - stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)}) + fieldType := createTypedDestination(field) + if bindMap.lobColumns[column] { + switch fieldType.(type) { + case *[]uint8: + fieldType = &godror.Lob{IsClob: false} + case *string: + fieldType = &godror.Lob{IsClob: true} + } + } + stmt.Vars = append(stmt.Vars, sql.Out{Dest: fieldType}) plsqlBuilder.WriteString(fmt.Sprintf( " IF l_inserted_records.COUNT > %d THEN :%d := l_inserted_records(%d).%s; END IF;\n", rowIdx, outParamIndex+1, rowIdx+1, quotedColumn, @@ -1016,3 +1091,31 @@ func pluck[T any, N int](data [][]T, col int) []T { } return out } + +// fieldCanConflict checks if a field can be used as a conflict target in PL/SQL MERGE statements. +// A field can be used as a conflict target if it is contains unique values. This includes primary key fields +// and unique fields. However, in cases of composite primary keys where the identity column is auto-incremented, +// even a primary key field cannot be used as a conflict target, as the auto-incremented primary key will ensure +// a unique row. +func fieldCanConflict(field *schema.Field, schema *schema.Schema) bool { + if field.PrimaryKey { + if schema != nil && schema.PrioritizedPrimaryField != nil && schema.PrioritizedPrimaryField.AutoIncrement { + // If the auto-incremented primary key is among the createValues, then it *can* be a conflict target + if schema.PrioritizedPrimaryField.Name == field.Name { + return true + } + return false + } + for _, primaryField := range schema.PrimaryFields { + if primaryField.AutoIncrement { + // If the auto-incremented primary key is among the createValues, then it *can* be a conflict target + if primaryField.Name == field.Name { + return true + } + return false + } + } + return true + } + return field.Unique +} diff --git a/tests/blob_test.go b/tests/blob_test.go index dc079e3..3ac5111 100644 --- a/tests/blob_test.go +++ b/tests/blob_test.go @@ -41,7 +41,6 @@ package tests import ( "bytes" "crypto/rand" - "strings" "testing" "time" @@ -63,22 +62,12 @@ type BlobVariantModel struct { LargeBlob []byte `gorm:"type:blob"` } -type BlobOneToManyModel struct { - ID uint `gorm:"primaryKey"` - Children []BlobChildModel `gorm:"foreignKey:ID"` -} - -type BlobChildModel struct { - ID uint `gorm:"primaryKey"` - Data []byte `gorm:"type:blob"` -} - func setupBlobTestTables(t *testing.T) { t.Log("Setting up BLOB test tables") - DB.Migrator().DropTable(&BlobTestModel{}, &BlobVariantModel{}, &BlobOneToManyModel{}, &BlobChildModel{}) + DB.Migrator().DropTable(&BlobTestModel{}, &BlobVariantModel{}) - err := DB.AutoMigrate(&BlobTestModel{}, &BlobVariantModel{}, &BlobOneToManyModel{}, &BlobChildModel{}) + err := DB.AutoMigrate(&BlobTestModel{}, &BlobVariantModel{}) if err != nil { t.Fatalf("Failed to migrate BLOB test tables: %v", err) } @@ -434,23 +423,6 @@ func TestBlobWithReturning(t *testing.T) { } } -func TestBlobOnConflict(t *testing.T) { - setupBlobTestTables(t) - - model := &BlobOneToManyModel{ - ID: 1, - Children: []BlobChildModel{ - { - Data: []byte(strings.Repeat("X", 32768)), - }, - }, - } - err := DB.Create(model).Error - if err != nil { - t.Fatalf("Failed to create BLOB record with ON CONFLICT: %v", err) - } -} - func TestBlobErrorHandling(t *testing.T) { setupBlobTestTables(t) diff --git a/tests/boolean_test.go b/tests/boolean_test.go index cd945b2..866f22e 100644 --- a/tests/boolean_test.go +++ b/tests/boolean_test.go @@ -40,8 +40,8 @@ package tests import ( "database/sql" - "testing" "strings" + "testing" ) type BooleanTest struct { @@ -135,23 +135,23 @@ func TestBooleanQueryFilters(t *testing.T) { } func TestBooleanNegativeInvalidDBValue(t *testing.T) { - DB.Migrator().DropTable(&BooleanTest{}) - DB.AutoMigrate(&BooleanTest{}) - - if err := DB.Exec(`INSERT INTO "BOOLEAN_TESTS" ("ID","FLAG") VALUES (2001, 2)`).Error; err != nil { - t.Fatalf("failed to insert invalid bool: %v", err) - } - - var got BooleanTest - err := DB.First(&got, 2001).Error - if err == nil { - t.Fatal("expected invalid boolean scan error, got nil") - } - - if !strings.Contains(err.Error(), "invalid") && - !strings.Contains(err.Error(), "convert") { - t.Fatalf("expected boolean conversion error, got: %v", err) - } + DB.Migrator().DropTable(&BooleanTest{}) + DB.AutoMigrate(&BooleanTest{}) + + if err := DB.Exec(`INSERT INTO "BOOLEAN_TESTS" ("ID","FLAG") VALUES (2001, 2)`).Error; err != nil { + t.Fatalf("failed to insert invalid bool: %v", err) + } + + var got BooleanTest + err := DB.First(&got, 2001).Error + if err == nil { + t.Fatal("expected invalid boolean scan error, got nil") + } + + if !strings.Contains(err.Error(), "invalid") && + !strings.Contains(err.Error(), "convert") { + t.Fatalf("expected boolean conversion error, got: %v", err) + } } func TestBooleanInsertWithIntValues(t *testing.T) { @@ -213,26 +213,26 @@ func TestBooleanDefaultValue(t *testing.T) { } func TestBooleanQueryMixedComparisons(t *testing.T) { - DB.Migrator().DropTable(&BooleanTest{}) - DB.AutoMigrate(&BooleanTest{}) - - DB.Create(&BooleanTest{Flag: true}) - DB.Create(&BooleanTest{Flag: false}) - - var gotNum []BooleanTest - - // FILTER USING NUMBER - if err := DB.Where("FLAG = 1").Find(&gotNum).Error; err != nil { - t.Fatal(err) - } - if len(gotNum) == 0 { - t.Errorf("expected at least 1 row for FLAG=1") - } - - var gotStr []BooleanTest - if err := DB.Where("FLAG = 'true'").Find(&gotStr).Error; err == nil { - t.Errorf("expected ORA-01722 when comparing NUMBER to string literal") - } + DB.Migrator().DropTable(&BooleanTest{}) + DB.AutoMigrate(&BooleanTest{}) + + DB.Create(&BooleanTest{Flag: true}) + DB.Create(&BooleanTest{Flag: false}) + + var gotNum []BooleanTest + + // FILTER USING NUMBER + if err := DB.Where("FLAG = 1").Find(&gotNum).Error; err != nil { + t.Fatal(err) + } + if len(gotNum) == 0 { + t.Errorf("expected at least 1 row for FLAG=1") + } + + var gotStr []BooleanTest + if err := DB.Where("FLAG = 'true'").Find(&gotStr).Error; err == nil { + t.Errorf("expected ORA-01722 when comparing NUMBER to string literal") + } } func TestBooleanStringCoercion(t *testing.T) { @@ -259,7 +259,6 @@ func TestBooleanStringCoercion(t *testing.T) { } } - func TestBooleanNullableColumn(t *testing.T) { t.Skip("Skipping until nullable bool bug is resolved") DB.Migrator().DropTable(&BooleanTest{}) @@ -280,3 +279,78 @@ func TestBooleanNullableColumn(t *testing.T) { t.Errorf("expected NULL, got %v", *got.Nullable) } } + +type Model struct { + ID uint `gorm:"primaryKey"` + Children []ChildModel `gorm:"foreignKey:ParentID"` +} + +type ChildModel struct { + ParentID uint + ID uint `gorm:"primaryKey"` + Data bool `gorm:"type:boolean"` +} + +func setupBooleanTestTables(t *testing.T) { + t.Log("Setting up boolean test tables") + + DB.Migrator().DropTable(&Model{}, &ChildModel{}) + + err := DB.AutoMigrate(&Model{}, &ChildModel{}) + if err != nil { + t.Fatalf("Failed to migrate boolean test tables: %v", err) + } + + t.Log("boolean test tables created successfully") +} + +func TestBooleanOnConflict(t *testing.T) { + type test struct { + model any + fn func(model any) error + } + tests := map[string]test{ + "OneToManySingle": { + model: &Model{ + ID: 1, + Children: []ChildModel{ + { + ID: 1, + Data: true, + }, + }, + }, + fn: func(model any) error { + return DB.Create(model).Error + }, + }, + "OneToManyBatch": { + model: &Model{ + ID: 1, + Children: []ChildModel{ + { + ID: 1, + Data: true, + }, + { + ID: 2, + Data: false, + }, + }, + }, + fn: func(model any) error { + return DB.Create(model).Error + }, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + setupBooleanTestTables(t) + err := tc.fn(tc.model) + if err != nil { + t.Fatalf("Failed to create boolean record with ON CONFLICT: %v", err) + } + }) + } +} diff --git a/tests/json_bulk_test.go b/tests/json_bulk_test.go index 1b4fdf5..d6457bb 100644 --- a/tests/json_bulk_test.go +++ b/tests/json_bulk_test.go @@ -39,6 +39,7 @@ package tests import ( + "database/sql/driver" "encoding/json" "errors" "fmt" @@ -962,3 +963,104 @@ func TestJSONRootArray(t *testing.T) { t.Fatalf("unexpected array content after appends: %#v", arr) } } + +func TestCustomJSON(t *testing.T) { + type CustomJSONModel struct { + Blah string `gorm:"primaryKey"` + Data AttributeMap `gorm:"type:json"` + } + + type test struct { + model any + fn func(model any) error + } + tests := map[string]test{ + "Single": { + model: []CustomJSONModel{ + { + Blah: "1", + Data: AttributeMap{"Data": strings.Repeat("X", 32768)}, + }, + }, + fn: func(model any) error { + return DB.Clauses(clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(model, 1000).Error + }, + }, + "SingleBatch": { + model: []CustomJSONModel{ + { + Blah: "1", + Data: AttributeMap{"Data": strings.Repeat("X", 32768)}, + }, + { + Blah: "2", + Data: AttributeMap{"Data": strings.Repeat("Y", 3)}, + }, + }, + fn: func(model any) error { + return DB.Clauses(clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(model, 1000).Error + }, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + DB.Migrator().DropTable(&CustomJSONModel{}) + if err := DB.Set("gorm:table_options", "TABLESPACE SYSAUX").AutoMigrate(&CustomJSONModel{}); err != nil { + t.Fatalf("migrate failed: %v", err) + } + err := tc.fn(tc.model) + if err != nil { + t.Fatalf("Failed to create CLOB record with ON CONFLICT: %v", err) + } + }) + } +} + +func scanBytes(src interface{}) ([]byte, bool) { + if stringer, ok := src.(fmt.Stringer); ok { + return []byte(stringer.String()), true + } + bytes, ok := src.([]byte) + if !ok { + return nil, false + } + return bytes, true +} + +type AttributeMap map[string]interface{} + +func (a AttributeMap) Value() (driver.Value, error) { + attrs := a + if attrs == nil { + attrs = AttributeMap{} + } + value, err := json.Marshal(attrs) + return value, err +} + +func (a *AttributeMap) Scan(src interface{}) error { + bytes, ok := scanBytes(src) + if !ok { + return fmt.Errorf("failed to scan attribute map") + } + var raw interface{} + err := json.Unmarshal(bytes, &raw) + if err != nil { + return err + } + + if raw == nil { + *a = map[string]interface{}{} + return nil + } + *a, ok = raw.(map[string]interface{}) + if !ok { + return fmt.Errorf("failed to convert attribute map from json") + } + return nil +} diff --git a/tests/lob_test.go b/tests/lob_test.go new file mode 100644 index 0000000..39b5fd3 --- /dev/null +++ b/tests/lob_test.go @@ -0,0 +1,371 @@ +/* +** Copyright (c) 2025 Oracle and/or its affiliates. +** +** The Universal Permissive License (UPL), Version 1.0 +** +** Subject to the condition set forth below, permission is hereby granted to any +** person obtaining a copy of this software, associated documentation and/or data +** (collectively the "Software"), free of charge and under any and all copyright +** rights in the Software, and any and all patent rights owned or freely +** licensable by each licensor hereunder covering either (i) the unmodified +** Software as contributed to or provided by such licensor, or (ii) the Larger +** Works (as defined below), to deal in both +** +** (a) the Software, and +** (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if +** one is included with the Software (each a "Larger Work" to which the Software +** is contributed by such licensors), +** +** without restriction, including without limitation the rights to copy, create +** derivative works of, display, perform, and distribute the Software and make, +** use, sell, offer for sale, import, export, have made, and have sold the +** Software and the Larger Work(s), and to sublicense the foregoing rights on +** either these or other terms. +** +** This license is subject to the following condition: +** The above copyright notice and either this complete permission notice or at +** a minimum a reference to the UPL must be included in all copies or +** substantial portions of the Software. +** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +** FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +** AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +** LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +** OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +** SOFTWARE. + */ + +package tests + +import ( + "strings" + "testing" + + "gorm.io/gorm/clause" +) + +type ClobOneToManyModel struct { + ID uint `gorm:"primaryKey"` + Children []ClobChildModel `gorm:"foreignKey:ParentID"` +} + +type ClobChildModel struct { + ParentID uint + Blah string `gorm:"primaryKey"` + Data string `gorm:"type:clob"` +} + +type ClobSingleModel struct { + Blah string `gorm:"primaryKey"` + Data string `gorm:"type:clob"` +} + +type BlobOneToManyModel struct { + ID uint `gorm:"primaryKey"` + Children []BlobChildModel `gorm:"foreignKey:ParentID"` +} + +type BlobChildModel struct { + ParentID uint + Blah string `gorm:"primaryKey"` + Data []byte `gorm:"type:blob"` +} + +type BlobSingleModel struct { + ID uint `gorm:"primaryKey"` + Data []byte `gorm:"type:blob"` +} + +func setupLobTestTables(t *testing.T) { + t.Log("Setting up LOB test tables") + + DB.Migrator().DropTable(&ClobOneToManyModel{}, &ClobChildModel{}, &ClobSingleModel{}, &BlobOneToManyModel{}, &BlobChildModel{}, &BlobSingleModel{}) + + err := DB.AutoMigrate(&ClobOneToManyModel{}, &ClobChildModel{}, &ClobSingleModel{}, &BlobOneToManyModel{}, &BlobChildModel{}, &BlobSingleModel{}) + if err != nil { + t.Fatalf("Failed to migrate LOB test tables: %v", err) + } + + t.Log("LOB test tables created successfully") +} + +func TestClobOnConflict(t *testing.T) { + type test struct { + model any + fn func(model any) error + } + tests := map[string]test{ + "OneToManySingle": { + model: &ClobOneToManyModel{ + ID: 1, + Children: []ClobChildModel{ + { + Blah: "1", + Data: strings.Repeat("X", 32768), + }, + }, + }, + fn: func(model any) error { + return DB.Create(model).Error + }, + }, + "OneToManyBatch": { + model: &ClobOneToManyModel{ + ID: 1, + Children: []ClobChildModel{ + { + Blah: "1", + Data: strings.Repeat("X", 32768), + }, + { + Blah: "2", + Data: strings.Repeat("Y", 3), + }, + }, + }, + fn: func(model any) error { + return DB.Create(model).Error + }, + }, + "Single": { + model: []ClobSingleModel{ + { + Blah: "1", + Data: strings.Repeat("X", 32768), + }, + }, + fn: func(model any) error { + return DB.Clauses(clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(model, 1000).Error + }, + }, + "SingleNotQuiteLOB": { + model: []ClobSingleModel{ + { + Blah: "1", + Data: strings.Repeat("X", 32767), + }, + }, + fn: func(model any) error { + return DB.Clauses(clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(model, 1000).Error + }, + }, + "SingleBatch": { + model: []ClobSingleModel{ + { + Blah: "1", + Data: strings.Repeat("X", 32768), + }, + { + Blah: "2", + Data: strings.Repeat("Y", 3), + }, + }, + fn: func(model any) error { + return DB.Clauses(clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(model, 1000).Error + }, + }, + "SingleBatchNotQuiteLOB": { + model: []ClobSingleModel{ + { + Blah: "1", + Data: strings.Repeat("X", 32767), + }, + { + Blah: "2", + Data: strings.Repeat("Y", 3), + }, + }, + fn: func(model any) error { + return DB.Clauses(clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(model, 1000).Error + }, + }, + "SingleBatchReverse": { + model: []ClobSingleModel{ + { + Blah: "1", + Data: strings.Repeat("Y", 3), + }, + { + Blah: "2", + Data: strings.Repeat("X", 32768), + }, + }, + fn: func(model any) error { + return DB.Clauses(clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(model, 1000).Error + }, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + setupLobTestTables(t) + err := tc.fn(tc.model) + if err != nil { + t.Fatalf("Failed to create CLOB record with ON CONFLICT: %v", err) + } + }) + } +} + +func TestBlobOnConflict(t *testing.T) { + type test struct { + model any + fn func(model any) error + } + tests := map[string]test{ + "OneToManySingle": { + model: &BlobOneToManyModel{ + ID: 1, + Children: []BlobChildModel{ + { + Blah: "1", + Data: []byte(strings.Repeat("X", 32768)), + }, + }, + }, + fn: func(model any) error { + return DB.Create(model).Error + }, + }, + "OneToManyBatch": { + model: &BlobOneToManyModel{ + ID: 1, + Children: []BlobChildModel{ + { + Blah: "1", + Data: []byte(strings.Repeat("X", 32768)), + }, + { + Blah: "2", + Data: []byte(strings.Repeat("Y", 3)), + }, + }, + }, + fn: func(model any) error { + return DB.Create(model).Error + }, + }, + "Single": { + model: []BlobSingleModel{ + { + ID: 1, + Data: []byte(strings.Repeat("X", 32768)), + }, + }, + fn: func(model any) error { + return DB.Clauses(clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(model, 1000).Error + }, + }, + "SingleNotQuiteLOB": { + model: []BlobSingleModel{ + { + ID: 1, + Data: []byte(strings.Repeat("X", 32767)), + }, + }, + fn: func(model any) error { + return DB.Clauses(clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(model, 1000).Error + }, + }, + "SingleBatch": { + model: []BlobSingleModel{ + { + ID: 1, + Data: []byte(strings.Repeat("X", 32768)), + }, + { + ID: 2, + Data: []byte(strings.Repeat("Y", 3)), + }, + }, + fn: func(model any) error { + return DB.Clauses(clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(model, 1000).Error + }, + }, + "SingleBatchNotQuiteLOB": { + model: []BlobSingleModel{ + { + ID: 1, + Data: []byte(strings.Repeat("X", 32767)), + }, + { + ID: 2, + Data: []byte(strings.Repeat("Y", 3)), + }, + }, + fn: func(model any) error { + return DB.Clauses(clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(model, 1000).Error + }, + }, + "SingleBatchReverse": { + model: []BlobSingleModel{ + { + ID: 1, + Data: []byte(strings.Repeat("Y", 3)), + }, + { + ID: 2, + Data: []byte(strings.Repeat("X", 32768)), + }, + }, + fn: func(model any) error { + return DB.Clauses(clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(model, 1000).Error + }, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + setupLobTestTables(t) + err := tc.fn(tc.model) + if err != nil { + t.Fatalf("Failed to create BLOB record with ON CONFLICT: %v", err) + } + }) + } +} + +func TestClobUpdateOnConflict(t *testing.T) { + setupLobTestTables(t) + model := []ClobSingleModel{ + { + Blah: "1", + Data: strings.Repeat("X", 32768), + }, + { + Blah: "2", + Data: strings.Repeat("Y", 3), + }, + } + + err := DB.Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(model, 1000).Error + if err != nil { + t.Fatalf("Failed to create CLOB record with ON CONFLICT: %v", err) + } + model[1].Data = strings.Repeat("Z", 5000) + err = DB.Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(model, 1000).Error + if err != nil { + t.Fatalf("Failed to update CLOB record with ON CONFLICT: %v", err) + } +} diff --git a/tests/clob_test.go b/tests/uuid_test.go similarity index 59% rename from tests/clob_test.go rename to tests/uuid_test.go index 7ba16f1..3b8c91a 100644 --- a/tests/clob_test.go +++ b/tests/uuid_test.go @@ -39,46 +39,82 @@ package tests import ( - "strings" "testing" -) -type ClobOneToManyModel struct { - ID uint `gorm:"primaryKey"` - Children []ClobChildModel `gorm:"foreignKey:ID"` -} + "github.com/google/uuid" + "gorm.io/gorm/clause" +) -type ClobChildModel struct { - ID uint `gorm:"primaryKey"` - Data string `gorm:"type:clob"` +type UUIDModel struct { + ID uint `gorm:"primaryKey"` + UUID *uuid.UUID `gorm:"type:VARCHAR2(36)"` + // Data string `gorm:"type:clob"` } -func setupClobTestTables(t *testing.T) { - t.Log("Setting up CLOB test tables") +func setupUUIDTestTables(t *testing.T) { + t.Log("Setting up UUID test tables") - DB.Migrator().DropTable(&ClobOneToManyModel{}, &ClobChildModel{}) + DB.Migrator().DropTable(&UUIDModel{}) - err := DB.AutoMigrate(&ClobOneToManyModel{}, &ClobChildModel{}) + err := DB.AutoMigrate(&UUIDModel{}) if err != nil { - t.Fatalf("Failed to migrate CLOB test tables: %v", err) + t.Fatalf("Failed to migrate UUID test tables: %v", err) } - t.Log("CLOB test tables created successfully") + t.Log("UUID test tables created successfully") } -func TestClobOnConflict(t *testing.T) { - setupClobTestTables(t) - - model := &ClobOneToManyModel{ - ID: 1, - Children: []ClobChildModel{ - { - Data: strings.Repeat("X", 32768), +func TestUUIDPLSQL(t *testing.T) { + myUUID := uuid.New() + type test struct { + model any + fn func(model any) error + } + tests := map[string]test{ + "InsertWithReturning": { + model: []UUIDModel{ + { + UUID: &myUUID, + }, + }, + fn: func(model any) error { + return DB.Create(model).Error + }, + }, + "InsertWithReturningNil": { + model: []UUIDModel{ + { + UUID: nil, + }, + }, + fn: func(model any) error { + return DB.Create(model).Error + }, + }, + "BatchInsert": { + model: []UUIDModel{ + { + UUID: &myUUID, + }, + { + UUID: nil, + }, + }, + fn: func(model any) error { + return DB.Clauses(clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(model, 1000).Error }, }, } - err := DB.Create(model).Error - if err != nil { - t.Fatalf("Failed to create BLOB record with ON CONFLICT: %v", err) + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + setupUUIDTestTables(t) + err := tc.fn(tc.model) + if err != nil { + t.Fatalf("Failed to create UUID record with PLSQL: %v", err) + } + }) } }