Skip to content

Commit 67d2df7

Browse files
committed
fix
1 parent ce7d574 commit 67d2df7

File tree

1 file changed

+22
-21
lines changed

1 file changed

+22
-21
lines changed

models/unittest/fixtures_loader.go

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,16 @@ func (f *fixturesLoader) prepareFieldValue(v any) any {
3535
return v
3636
}
3737

38-
func (f *fixturesLoader) mssqlTableHasIdentityColumn(db *sql.DB, tableName string) (bool, error) {
39-
row := db.QueryRow(`SELECT COUNT(*) FROM sys.identity_columns WHERE OBJECT_ID = OBJECT_ID(?)`, tableName)
38+
func (f *fixturesLoader) mssqlTableHasIdentityColumn(q *sql.Tx, tableName string) (bool, error) {
39+
row := q.QueryRow(`SELECT COUNT(*) FROM sys.identity_columns WHERE OBJECT_ID = OBJECT_ID(?)`, tableName)
4040
var count int
4141
if err := row.Scan(&count); err != nil {
4242
return false, err
4343
}
4444
return count > 0, nil
4545
}
4646

47-
func (f *fixturesLoader) loadFixtures(file string) error {
47+
func (f *fixturesLoader) loadFixtures(tx *sql.Tx, file string) error {
4848
data, err := os.ReadFile(file)
4949
if err != nil {
5050
return fmt.Errorf("failed to read file %q: %w", file, err)
@@ -57,25 +57,14 @@ func (f *fixturesLoader) loadFixtures(file string) error {
5757

5858
tableName, _, _ := strings.Cut(filepath.Base(file), ".")
5959
tableNameQuoted := f.quoteObject(tableName)
60-
_, err = f.engine.Table(tableName).Where("1=1").Delete() // sqlite3 doesn't support truncate
60+
_, err = tx.Exec(fmt.Sprintf("DELETE FROM %s", tableNameQuoted)) // sqlite3 doesn't support truncate
6161
if err != nil {
6262
return err
6363
}
6464

65-
goDB := f.engine.DB().DB
66-
tx, err := goDB.Begin()
67-
if err != nil {
68-
return err
69-
}
70-
defer func() {
71-
if tx != nil {
72-
_ = tx.Rollback()
73-
}
74-
}()
75-
7665
switch f.engine.Dialect().URI().DBType {
7766
case schemas.MSSQL:
78-
hasIdentityColumn, err := f.mssqlTableHasIdentityColumn(goDB, tableName)
67+
hasIdentityColumn, err := f.mssqlTableHasIdentityColumn(tx, tableName)
7968
if err != nil {
8069
return err
8170
}
@@ -84,6 +73,7 @@ func (f *fixturesLoader) loadFixtures(file string) error {
8473
if err != nil {
8574
return err
8675
}
76+
defer func() { _, err = tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", tableNameQuoted)) }()
8777
}
8878
}
8979

@@ -112,16 +102,20 @@ func (f *fixturesLoader) loadFixtures(file string) error {
112102
sqlBuf = sqlBuf[:0]
113103
sqlArguments = sqlArguments[:0]
114104
}
115-
err = tx.Commit()
116-
tx = nil
117-
return err
105+
return nil
118106
}
119107

120108
func (f *fixturesLoader) Load() error {
109+
goDB := f.engine.DB().DB
110+
121111
switch f.engine.Dialect().URI().DBType {
122112
case schemas.SQLITE:
123113
f.quoteObject = func(s string) string { return fmt.Sprintf(`"%s"`, s) }
124114
f.paramPlaceholder = func(idx int) string { return "?" }
115+
if _, err := goDB.Exec("PRAGMA defer_foreign_keys = ON"); err != nil {
116+
return err
117+
}
118+
defer func() { _, _ = goDB.Exec("PRAGMA defer_foreign_keys = OFF") }()
125119
case schemas.POSTGRES:
126120
f.quoteObject = func(s string) string { return fmt.Sprintf(`"%s"`, s) }
127121
f.paramPlaceholder = func(idx int) string { return fmt.Sprintf(`$%d`, idx) }
@@ -141,13 +135,20 @@ func (f *fixturesLoader) Load() error {
141135
f.opts.Files = append(f.opts.Files, e.Name())
142136
}
143137
}
138+
139+
tx, err := goDB.Begin()
140+
if err != nil {
141+
return err
142+
}
143+
defer func() { _ = tx.Rollback() }()
144+
144145
for _, file := range f.opts.Files {
145146
if !filepath.IsAbs(file) {
146147
file = filepath.Join(f.opts.Dir, file)
147148
}
148-
if err := f.loadFixtures(file); err != nil {
149+
if err := f.loadFixtures(tx, file); err != nil {
149150
return fmt.Errorf("failed to load fixtures from %s: %w", file, err)
150151
}
151152
}
152-
return nil
153+
return tx.Commit()
153154
}

0 commit comments

Comments
 (0)