Skip to content

Commit a5e6903

Browse files
committed
add support for Boolean Data type for server version above 23
1 parent 0f74e08 commit a5e6903

File tree

4 files changed

+105
-13
lines changed

4 files changed

+105
-13
lines changed

oracle/common.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func getOracleArrayType(values []any) string {
7171
}
7272
switch v := val.(type) {
7373
case bool:
74-
arrayType = "TABLE OF NUMBER(1)"
74+
arrayType = "TABLE OF BOOLEAN"
7575
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
7676
arrayType = "TABLE OF NUMBER"
7777
case time.Time:
@@ -132,7 +132,7 @@ func createTypedDestination(f *schema.Field) interface{} {
132132

133133
// To differentiate between bool fields stored as NUMBER(1) and bool fields stored as actual BOOLEAN type,
134134
// check the struct's "type" tag.
135-
if f.DataType == "boolean" {
135+
if string(f.DataType) == "bool" || string(f.DataType) == "boolean" {
136136
return new(bool)
137137
}
138138

@@ -199,7 +199,7 @@ func createTypedDestination(f *schema.Field) interface{} {
199199
return new(string)
200200

201201
case reflect.Bool:
202-
return new(int64)
202+
return new(bool)
203203

204204
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
205205
return new(int64)
@@ -470,6 +470,9 @@ func isRawMessageField(f *schema.Field) bool {
470470
func convertPrimitiveType(value interface{}, targetType reflect.Type) interface{} {
471471
switch targetType.Kind() {
472472
case reflect.Bool:
473+
if v, ok := value.(bool); ok {
474+
return v
475+
}
473476
if v, ok := value.(int64); ok {
474477
return v != 0
475478
}

oracle/migrator.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
642642
// Builds Oracle-compatible default values from string
643643
func (m Migrator) buildOracleDefault(defaultValue string) string {
644644
defaultValue = strings.TrimSpace(defaultValue)
645+
dialector := m.Dialector.(Dialector)
645646

646647
if defaultValue == "" {
647648
return ""
@@ -656,8 +657,14 @@ func (m Migrator) buildOracleDefault(defaultValue string) string {
656657
case "SYSDATE":
657658
return "DEFAULT SYSDATE"
658659
case "TRUE":
660+
if dialector.Config.ServerVersion >= 23 {
661+
return "DEFAULT TRUE"
662+
}
659663
return "DEFAULT 1"
660664
case "FALSE":
665+
if dialector.Config.ServerVersion >= 23 {
666+
return "DEFAULT FALSE"
667+
}
661668
return "DEFAULT 0"
662669
}
663670

@@ -689,8 +696,15 @@ func (m Migrator) buildOracleDefault(defaultValue string) string {
689696

690697
// Build Oracle-compatible default values from Go interface
691698
func (m Migrator) buildOracleDefaultFromInterface(value interface{}) string {
699+
dialector := m.Dialector.(Dialector)
692700
switch v := value.(type) {
693701
case bool:
702+
if dialector.Config.ServerVersion >= 23 {
703+
if v {
704+
return "DEFAULT TRUE"
705+
}
706+
return "DEFAULT FALSE"
707+
}
694708
if v {
695709
return "DEFAULT 1"
696710
}

oracle/oracle.go

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ type Config struct {
6767
Conn *sql.DB
6868
DefaultStringSize uint
6969
SkipQuoteIdentifiers bool
70+
ServerVersion int
7071
}
7172

7273
type Dialector struct {
@@ -132,6 +133,11 @@ func (d Dialector) Initialize(db *gorm.DB) (err error) {
132133
if err != nil {
133134
return err
134135
}
136+
version, err := GetServerVersion(db)
137+
if err != nil {
138+
return err
139+
}
140+
d.Config.ServerVersion = version
135141

136142
return nil
137143
}
@@ -188,7 +194,9 @@ func (d Dialector) getStringType(field *schema.Field) string {
188194
}
189195

190196
func (d Dialector) getBooleanType() string {
191-
// Oracle doesn't support BOOLEAN in CREATE TABLE, use NUMBER(1) instead
197+
if d.Config.ServerVersion >= 23 {
198+
return "BOOLEAN"
199+
}
192200
return "NUMBER(1)"
193201
}
194202

@@ -294,3 +302,40 @@ func (d Dialector) RollbackTo(tx *gorm.DB, name string) error {
294302
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
295303
return tx.Error
296304
}
305+
306+
// GetServerVersion retrieves the Oracle server version as an integer.
307+
func GetServerVersion(db *gorm.DB) (int, error) {
308+
sqlDB, err := db.DB()
309+
if err != nil {
310+
return 0, err
311+
}
312+
rows, err := sqlDB.Query("SELECT banner FROM v$version")
313+
if err != nil {
314+
return 0, err
315+
}
316+
defer rows.Close()
317+
for rows.Next() {
318+
var banner string
319+
err = rows.Scan(&banner)
320+
if err != nil {
321+
return 0, err
322+
}
323+
// Parse banner
324+
if strings.Contains(banner, "Oracle") && strings.Contains(banner, "Release") {
325+
parts := strings.Split(banner, " ")
326+
for i, p := range parts {
327+
if p == "Release" && i+1 < len(parts) {
328+
rel := parts[i+1]
329+
dotPos := strings.Index(rel, ".")
330+
if dotPos != -1 {
331+
version, perr := strconv.Atoi(rel[:dotPos])
332+
if perr == nil {
333+
return version, nil
334+
}
335+
}
336+
}
337+
}
338+
}
339+
}
340+
return 0, fmt.Errorf("could not find version in banners")
341+
}

tests/boolean_test.go

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,30 @@ func TestBooleanNegativeInvalidDBValue(t *testing.T) {
142142
t.Fatalf("failed to insert invalid bool: %v", err)
143143
}
144144

145+
// Check if using native BOOLEAN or NUMBER(1)
146+
var columnType string
147+
DB.Raw(`SELECT DATA_TYPE FROM USER_TAB_COLUMNS
148+
WHERE TABLE_NAME = 'BOOLEAN_TESTS' AND COLUMN_NAME = 'FLAG'`).Scan(&columnType)
145149
var got BooleanTest
146150
err := DB.First(&got, 2001).Error
147-
if err == nil {
148-
t.Fatal("expected invalid boolean scan error, got nil")
149-
}
150-
151-
if !strings.Contains(err.Error(), "invalid") &&
152-
!strings.Contains(err.Error(), "convert") {
153-
t.Fatalf("expected boolean conversion error, got: %v", err)
151+
if columnType == "BOOLEAN" {
152+
// Oracle database version above 23 native BOOLEAN accepts any number
153+
if err != nil {
154+
t.Fatalf("Oracle database version above 23 BOOLEAN should accept numeric value 2, got error: %v", err)
155+
}
156+
// Verify that 2 was treated as TRUE
157+
if !got.Flag {
158+
t.Errorf("expected FLAG=true for value 2 in BOOLEAN column, got false")
159+
}
160+
} else {
161+
// Oracle database version below 23 uses NUMBER(1), should reject value 2
162+
if err == nil {
163+
t.Fatal("expected invalid boolean scan error, got nil")
164+
}
165+
if !strings.Contains(err.Error(), "invalid") &&
166+
!strings.Contains(err.Error(), "convert") {
167+
t.Fatalf("expected boolean conversion error, got: %v", err)
168+
}
154169
}
155170
}
156171

@@ -229,9 +244,24 @@ func TestBooleanQueryMixedComparisons(t *testing.T) {
229244
t.Errorf("expected at least 1 row for FLAG=1")
230245
}
231246

247+
// Check if using native BOOLEAN or NUMBER(1)
248+
var columnType string
249+
DB.Raw(`SELECT DATA_TYPE FROM USER_TAB_COLUMNS
250+
WHERE TABLE_NAME = 'BOOLEAN_TESTS' AND COLUMN_NAME = 'FLAG'`).Scan(&columnType)
251+
232252
var gotStr []BooleanTest
233-
if err := DB.Where("FLAG = 'true'").Find(&gotStr).Error; err == nil {
234-
t.Errorf("expected ORA-01722 when comparing NUMBER to string literal")
253+
err := DB.Where("FLAG = 'true'").Find(&gotStr).Error
254+
255+
if columnType == "NUMBER" {
256+
// For NUMBER(1), expect error
257+
if err == nil {
258+
t.Errorf("expected ORA-01722 when comparing NUMBER to string literal")
259+
}
260+
} else {
261+
// For BOOLEAN, it's valid
262+
if err != nil {
263+
t.Errorf("unexpected error for BOOLEAN type: %v", err)
264+
}
235265
}
236266
}
237267

0 commit comments

Comments
 (0)