Skip to content

Commit d21ee6e

Browse files
committed
fix
1 parent ce7d574 commit d21ee6e

File tree

1 file changed

+21
-21
lines changed

1 file changed

+21
-21
lines changed

models/unittest/fixtures_loader.go

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,16 @@ func (f *fixturesLoader) prepareFieldValue(v any) any {
3535
return v
3636
}
3737

38-
func (f *fixturesLoader) mssqlTableHasIdentityColumn(db *sql.DB, tableName string) (bool, error) {
39-
row := db.QueryRow(`SELECT COUNT(*) FROM sys.identity_columns WHERE OBJECT_ID = OBJECT_ID(?)`, tableName)
38+
func (f *fixturesLoader) mssqlTableHasIdentityColumn(q *sql.Tx, tableName string) (bool, error) {
39+
row := q.QueryRow(`SELECT COUNT(*) FROM sys.identity_columns WHERE OBJECT_ID = OBJECT_ID(?)`, tableName)
4040
var count int
4141
if err := row.Scan(&count); err != nil {
4242
return false, err
4343
}
4444
return count > 0, nil
4545
}
4646

47-
func (f *fixturesLoader) loadFixtures(file string) error {
47+
func (f *fixturesLoader) loadFixtures(tx *sql.Tx, file string) error {
4848
data, err := os.ReadFile(file)
4949
if err != nil {
5050
return fmt.Errorf("failed to read file %q: %w", file, err)
@@ -57,25 +57,14 @@ func (f *fixturesLoader) loadFixtures(file string) error {
5757

5858
tableName, _, _ := strings.Cut(filepath.Base(file), ".")
5959
tableNameQuoted := f.quoteObject(tableName)
60-
_, err = f.engine.Table(tableName).Where("1=1").Delete() // sqlite3 doesn't support truncate
60+
_, err = tx.Exec(fmt.Sprintf("DELETE FROM %s", tableNameQuoted)) // sqlite3 doesn't support truncate
6161
if err != nil {
6262
return err
6363
}
6464

65-
goDB := f.engine.DB().DB
66-
tx, err := goDB.Begin()
67-
if err != nil {
68-
return err
69-
}
70-
defer func() {
71-
if tx != nil {
72-
_ = tx.Rollback()
73-
}
74-
}()
75-
7665
switch f.engine.Dialect().URI().DBType {
7766
case schemas.MSSQL:
78-
hasIdentityColumn, err := f.mssqlTableHasIdentityColumn(goDB, tableName)
67+
hasIdentityColumn, err := f.mssqlTableHasIdentityColumn(tx, tableName)
7968
if err != nil {
8069
return err
8170
}
@@ -112,16 +101,20 @@ func (f *fixturesLoader) loadFixtures(file string) error {
112101
sqlBuf = sqlBuf[:0]
113102
sqlArguments = sqlArguments[:0]
114103
}
115-
err = tx.Commit()
116-
tx = nil
117-
return err
104+
return nil
118105
}
119106

120107
func (f *fixturesLoader) Load() error {
108+
goDB := f.engine.DB().DB
109+
121110
switch f.engine.Dialect().URI().DBType {
122111
case schemas.SQLITE:
123112
f.quoteObject = func(s string) string { return fmt.Sprintf(`"%s"`, s) }
124113
f.paramPlaceholder = func(idx int) string { return "?" }
114+
if _, err := goDB.Exec("PRAGMA defer_foreign_keys = ON"); err != nil {
115+
return err
116+
}
117+
defer func() { _, _ = goDB.Exec("PRAGMA defer_foreign_keys = OFF") }()
125118
case schemas.POSTGRES:
126119
f.quoteObject = func(s string) string { return fmt.Sprintf(`"%s"`, s) }
127120
f.paramPlaceholder = func(idx int) string { return fmt.Sprintf(`$%d`, idx) }
@@ -141,13 +134,20 @@ func (f *fixturesLoader) Load() error {
141134
f.opts.Files = append(f.opts.Files, e.Name())
142135
}
143136
}
137+
138+
tx, err := goDB.Begin()
139+
if err != nil {
140+
return err
141+
}
142+
defer func() { _ = tx.Rollback() }()
143+
144144
for _, file := range f.opts.Files {
145145
if !filepath.IsAbs(file) {
146146
file = filepath.Join(f.opts.Dir, file)
147147
}
148-
if err := f.loadFixtures(file); err != nil {
148+
if err := f.loadFixtures(tx, file); err != nil {
149149
return fmt.Errorf("failed to load fixtures from %s: %w", file, err)
150150
}
151151
}
152-
return nil
152+
return tx.Commit()
153153
}

0 commit comments

Comments
 (0)