@@ -35,16 +35,16 @@ func (f *fixturesLoader) prepareFieldValue(v any) any {
35
35
return v
36
36
}
37
37
38
- func (f * fixturesLoader ) mssqlTableHasIdentityColumn (db * sql.DB , tableName string ) (bool , error ) {
39
- row := db .QueryRow (`SELECT COUNT(*) FROM sys.identity_columns WHERE OBJECT_ID = OBJECT_ID(?)` , tableName )
38
+ func (f * fixturesLoader ) mssqlTableHasIdentityColumn (q * sql.Tx , tableName string ) (bool , error ) {
39
+ row := q .QueryRow (`SELECT COUNT(*) FROM sys.identity_columns WHERE OBJECT_ID = OBJECT_ID(?)` , tableName )
40
40
var count int
41
41
if err := row .Scan (& count ); err != nil {
42
42
return false , err
43
43
}
44
44
return count > 0 , nil
45
45
}
46
46
47
- func (f * fixturesLoader ) loadFixtures (file string ) error {
47
+ func (f * fixturesLoader ) loadFixtures (tx * sql. Tx , file string ) error {
48
48
data , err := os .ReadFile (file )
49
49
if err != nil {
50
50
return fmt .Errorf ("failed to read file %q: %w" , file , err )
@@ -57,25 +57,14 @@ func (f *fixturesLoader) loadFixtures(file string) error {
57
57
58
58
tableName , _ , _ := strings .Cut (filepath .Base (file ), "." )
59
59
tableNameQuoted := f .quoteObject (tableName )
60
- _ , err = f . engine . Table ( tableName ). Where ( "1=1" ). Delete ( ) // sqlite3 doesn't support truncate
60
+ _ , err = tx . Exec ( fmt . Sprintf ( "DELETE FROM %s" , tableNameQuoted ) ) // sqlite3 doesn't support truncate
61
61
if err != nil {
62
62
return err
63
63
}
64
64
65
- goDB := f .engine .DB ().DB
66
- tx , err := goDB .Begin ()
67
- if err != nil {
68
- return err
69
- }
70
- defer func () {
71
- if tx != nil {
72
- _ = tx .Rollback ()
73
- }
74
- }()
75
-
76
65
switch f .engine .Dialect ().URI ().DBType {
77
66
case schemas .MSSQL :
78
- hasIdentityColumn , err := f .mssqlTableHasIdentityColumn (goDB , tableName )
67
+ hasIdentityColumn , err := f .mssqlTableHasIdentityColumn (tx , tableName )
79
68
if err != nil {
80
69
return err
81
70
}
@@ -112,16 +101,20 @@ func (f *fixturesLoader) loadFixtures(file string) error {
112
101
sqlBuf = sqlBuf [:0 ]
113
102
sqlArguments = sqlArguments [:0 ]
114
103
}
115
- err = tx .Commit ()
116
- tx = nil
117
- return err
104
+ return nil
118
105
}
119
106
120
107
func (f * fixturesLoader ) Load () error {
108
+ goDB := f .engine .DB ().DB
109
+
121
110
switch f .engine .Dialect ().URI ().DBType {
122
111
case schemas .SQLITE :
123
112
f .quoteObject = func (s string ) string { return fmt .Sprintf (`"%s"` , s ) }
124
113
f .paramPlaceholder = func (idx int ) string { return "?" }
114
+ if _ , err := goDB .Exec ("PRAGMA defer_foreign_keys = ON" ); err != nil {
115
+ return err
116
+ }
117
+ defer func () { _ , _ = goDB .Exec ("PRAGMA defer_foreign_keys = OFF" ) }()
125
118
case schemas .POSTGRES :
126
119
f .quoteObject = func (s string ) string { return fmt .Sprintf (`"%s"` , s ) }
127
120
f .paramPlaceholder = func (idx int ) string { return fmt .Sprintf (`$%d` , idx ) }
@@ -141,13 +134,20 @@ func (f *fixturesLoader) Load() error {
141
134
f .opts .Files = append (f .opts .Files , e .Name ())
142
135
}
143
136
}
137
+
138
+ tx , err := goDB .Begin ()
139
+ if err != nil {
140
+ return err
141
+ }
142
+ defer func () { _ = tx .Rollback () }()
143
+
144
144
for _ , file := range f .opts .Files {
145
145
if ! filepath .IsAbs (file ) {
146
146
file = filepath .Join (f .opts .Dir , file )
147
147
}
148
- if err := f .loadFixtures (file ); err != nil {
148
+ if err := f .loadFixtures (tx , file ); err != nil {
149
149
return fmt .Errorf ("failed to load fixtures from %s: %w" , file , err )
150
150
}
151
151
}
152
- return nil
152
+ return tx . Commit ()
153
153
}
0 commit comments