@@ -35,16 +35,16 @@ func (f *fixturesLoader) prepareFieldValue(v any) any {
35
35
return v
36
36
}
37
37
38
- func (f * fixturesLoader ) mssqlTableHasIdentityColumn (db * sql.DB , tableName string ) (bool , error ) {
39
- row := db .QueryRow (`SELECT COUNT(*) FROM sys.identity_columns WHERE OBJECT_ID = OBJECT_ID(?)` , tableName )
38
+ func (f * fixturesLoader ) mssqlTableHasIdentityColumn (q * sql.Tx , tableName string ) (bool , error ) {
39
+ row := q .QueryRow (`SELECT COUNT(*) FROM sys.identity_columns WHERE OBJECT_ID = OBJECT_ID(?)` , tableName )
40
40
var count int
41
41
if err := row .Scan (& count ); err != nil {
42
42
return false , err
43
43
}
44
44
return count > 0 , nil
45
45
}
46
46
47
- func (f * fixturesLoader ) loadFixtures (file string ) error {
47
+ func (f * fixturesLoader ) loadFixtures (tx * sql. Tx , file string ) error {
48
48
data , err := os .ReadFile (file )
49
49
if err != nil {
50
50
return fmt .Errorf ("failed to read file %q: %w" , file , err )
@@ -57,25 +57,14 @@ func (f *fixturesLoader) loadFixtures(file string) error {
57
57
58
58
tableName , _ , _ := strings .Cut (filepath .Base (file ), "." )
59
59
tableNameQuoted := f .quoteObject (tableName )
60
- _ , err = f . engine . Table ( tableName ). Where ( "1=1" ). Delete ( ) // sqlite3 doesn't support truncate
60
+ _ , err = tx . Exec ( fmt . Sprintf ( "DELETE FROM %s" , tableNameQuoted ) ) // sqlite3 doesn't support truncate
61
61
if err != nil {
62
62
return err
63
63
}
64
64
65
- goDB := f .engine .DB ().DB
66
- tx , err := goDB .Begin ()
67
- if err != nil {
68
- return err
69
- }
70
- defer func () {
71
- if tx != nil {
72
- _ = tx .Rollback ()
73
- }
74
- }()
75
-
76
65
switch f .engine .Dialect ().URI ().DBType {
77
66
case schemas .MSSQL :
78
- hasIdentityColumn , err := f .mssqlTableHasIdentityColumn (goDB , tableName )
67
+ hasIdentityColumn , err := f .mssqlTableHasIdentityColumn (tx , tableName )
79
68
if err != nil {
80
69
return err
81
70
}
@@ -84,6 +73,7 @@ func (f *fixturesLoader) loadFixtures(file string) error {
84
73
if err != nil {
85
74
return err
86
75
}
76
+ defer func () { _ , err = tx .Exec (fmt .Sprintf ("SET IDENTITY_INSERT %s OFF" , tableNameQuoted )) }()
87
77
}
88
78
}
89
79
@@ -112,16 +102,20 @@ func (f *fixturesLoader) loadFixtures(file string) error {
112
102
sqlBuf = sqlBuf [:0 ]
113
103
sqlArguments = sqlArguments [:0 ]
114
104
}
115
- err = tx .Commit ()
116
- tx = nil
117
- return err
105
+ return nil
118
106
}
119
107
120
108
func (f * fixturesLoader ) Load () error {
109
+ goDB := f .engine .DB ().DB
110
+
121
111
switch f .engine .Dialect ().URI ().DBType {
122
112
case schemas .SQLITE :
123
113
f .quoteObject = func (s string ) string { return fmt .Sprintf (`"%s"` , s ) }
124
114
f .paramPlaceholder = func (idx int ) string { return "?" }
115
+ if _ , err := goDB .Exec ("PRAGMA defer_foreign_keys = ON" ); err != nil {
116
+ return err
117
+ }
118
+ defer func () { _ , _ = goDB .Exec ("PRAGMA defer_foreign_keys = OFF" ) }()
125
119
case schemas .POSTGRES :
126
120
f .quoteObject = func (s string ) string { return fmt .Sprintf (`"%s"` , s ) }
127
121
f .paramPlaceholder = func (idx int ) string { return fmt .Sprintf (`$%d` , idx ) }
@@ -141,13 +135,20 @@ func (f *fixturesLoader) Load() error {
141
135
f .opts .Files = append (f .opts .Files , e .Name ())
142
136
}
143
137
}
138
+
139
+ tx , err := goDB .Begin ()
140
+ if err != nil {
141
+ return err
142
+ }
143
+ defer func () { _ = tx .Rollback () }()
144
+
144
145
for _ , file := range f .opts .Files {
145
146
if ! filepath .IsAbs (file ) {
146
147
file = filepath .Join (f .opts .Dir , file )
147
148
}
148
- if err := f .loadFixtures (file ); err != nil {
149
+ if err := f .loadFixtures (tx , file ); err != nil {
149
150
return fmt .Errorf ("failed to load fixtures from %s: %w" , file , err )
150
151
}
151
152
}
152
- return nil
153
+ return tx . Commit ()
153
154
}
0 commit comments