Skip to content

Commit ea9e0ae

Browse files
committed
improve DefaultArg
1 parent c62106f commit ea9e0ae

File tree

4 files changed

+32
-14
lines changed

4 files changed

+32
-14
lines changed

modules/testlogger/testlogger.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func (w *testLoggerWriterCloser) Reset() {
9393
func PrintCurrentTest(t testing.TB, skip ...int) func() {
9494
t.Helper()
9595
start := time.Now()
96-
actualSkip := util.DefArgZero(skip) + 1
96+
actualSkip := util.DefaultArg(skip) + 1
9797
_, filename, line, _ := runtime.Caller(actualSkip)
9898

9999
if log.CanColorStdout {

modules/util/util.go

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -230,18 +230,23 @@ func IfZero[T comparable](v, def T) T {
230230
return v
231231
}
232232

233-
// DefArg helps the "optional argument" in Golang: func(foo string, optionalArg ...int)
234-
// it returns the first non-zero value from the given optional argument,
235-
// or the default value if there is no optional argument.
236-
func DefArg[T any](defArgs []T, def T) (ret T) {
237-
if len(defArgs) == 1 {
238-
return defArgs[0]
233+
// DefaultArg helps the "optional argument" in Golang:
234+
//
235+
// func foo(optionalArg ...int) { return DefaultArg(optionalArg) }
236+
// calling `foo()` gets 0, calling `foo(100)` gets 100
237+
// func bar(optionalArg ...int) { return DefaultArg(optionalArg, 42) }
238+
// calling `bar()` gets 42, calling `bar(100)` gets 100
239+
//
240+
// Passing more than 1 item to `optionalArg` or `def` is undefined behavior.
241+
// At the moment it only returns the first argument.
242+
func DefaultArg[T any](optionalArg []T, def ...T) (ret T) {
243+
if len(optionalArg) >= 1 {
244+
return optionalArg[0]
239245
}
240-
return def
241-
}
242-
243-
func DefArgZero[T any](defArgs []T) (ret T) {
244-
return DefArg(defArgs, ret)
246+
if len(def) >= 1 {
247+
return def[0]
248+
}
249+
return ret
245250
}
246251

247252
func ReserveLineBreakForTextarea(input string) string {

modules/util/util_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,16 @@ func TestReserveLineBreakForTextarea(t *testing.T) {
240240
assert.Equal(t, ReserveLineBreakForTextarea("test\r\ndata"), "test\ndata")
241241
assert.Equal(t, ReserveLineBreakForTextarea("test\r\ndata\r\n"), "test\ndata\n")
242242
}
243+
244+
func TestDefaultArg(t *testing.T) {
245+
foo := func(other any, optionalArg ...int) int {
246+
return DefaultArg(optionalArg)
247+
}
248+
bar := func(other any, optionalArg ...int) int {
249+
return DefaultArg(optionalArg, 42)
250+
}
251+
assert.Equal(t, 0, foo(nil))
252+
assert.Equal(t, 100, foo(nil, 100))
253+
assert.Equal(t, 42, bar(nil))
254+
assert.Equal(t, 100, bar(nil, 100))
255+
}

tests/test_utils.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ func PrepareCleanPackageData(t testing.TB) {
262262

263263
func PrepareTestEnv(t testing.TB, skip ...int) func() {
264264
t.Helper()
265-
deferFn := PrintCurrentTest(t, util.DefArgZero(skip)+1)
265+
deferFn := PrintCurrentTest(t, util.DefaultArg(skip)+1)
266266

267267
// load database fixtures
268268
assert.NoError(t, unittest.LoadFixtures())
@@ -276,7 +276,7 @@ func PrepareTestEnv(t testing.TB, skip ...int) func() {
276276

277277
func PrintCurrentTest(t testing.TB, skip ...int) func() {
278278
t.Helper()
279-
return testlogger.PrintCurrentTest(t, util.DefArgZero(skip)+1)
279+
return testlogger.PrintCurrentTest(t, util.DefaultArg(skip)+1)
280280
}
281281

282282
// Printf takes a format and args and prints the string to os.Stdout

0 commit comments

Comments
 (0)