Skip to content

Commit e2da469

Browse files
committed
Fix numerical issues.
1 parent 1677b97 commit e2da469

File tree

9 files changed

+200
-50
lines changed

9 files changed

+200
-50
lines changed

ext/stats/moments.nogo

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package stats
2+
3+
import "math"
4+
5+
type moment struct {
6+
m1, m2, m3, m4 kahan
7+
n int64
8+
}
9+
10+
func (w *moment) enqueue(x float64) {
11+
n := w.n + 1
12+
w.n = n
13+
y := x - w.m1.hi - w.m1.lo
14+
w.m1.add(y / float64(n))
15+
y = math.FMA(y, x, -w.m2.hi) - w.m2.lo
16+
w.m2.add(y / float64(n))
17+
y = math.FMA(y, x, -w.m3.hi) - w.m3.lo
18+
w.m3.add(y / float64(n))
19+
y = math.FMA(y, x, -w.m4.hi) - w.m4.lo
20+
w.m4.add(y / float64(n))
21+
}
22+
23+
func (w *moment) dequeue(x float64) {
24+
n := w.n - 1
25+
if n <= 0 {
26+
*w = moment{}
27+
return
28+
}
29+
w.n = n
30+
y := x - w.m1.hi + w.m1.lo
31+
w.m1.sub(y / float64(n))
32+
y = math.FMA(y, x, w.m2.hi) + w.m2.lo
33+
w.m2.sub(y / float64(n))
34+
y = math.FMA(y, x, w.m3.hi) + w.m3.lo
35+
w.m3.sub(y / float64(n))
36+
y = math.FMA(y, x, w.m4.hi) + w.m4.lo
37+
w.m4.sub(y / float64(n))
38+
}

ext/stats/percentile.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ import (
1111
"github.com/ncruces/sort/quick"
1212
)
1313

14+
// Compatible with:
15+
// https://sqlite.org/src/file/ext/misc/percentile.c
16+
1417
const (
1518
median = iota
1619
percentile_100

ext/stats/stats.go

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
// - regr_count: count non-null pairs of variables
1818
// - regr_slope: slope of the least-squares-fit linear equation
1919
// - regr_intercept: y-intercept of the least-squares-fit linear equation
20-
// - regr_json: all regr stats in a JSON object
20+
// - regr_json: all regr stats as a JSON object
2121
// - percentile_disc: discrete quantile
2222
// - percentile_cont: continuous quantile
2323
// - percentile: continuous percentile
@@ -111,6 +111,17 @@ type variance struct {
111111
}
112112

113113
func (fn *variance) Value(ctx sqlite3.Context) {
114+
switch fn.n {
115+
case 1:
116+
switch fn.kind {
117+
case var_pop, stddev_pop:
118+
ctx.ResultFloat(0)
119+
}
120+
return
121+
case 0:
122+
return
123+
}
124+
114125
var r float64
115126
switch fn.kind {
116127
case var_pop:
@@ -151,6 +162,25 @@ type covariance struct {
151162
}
152163

153164
func (fn *covariance) Value(ctx sqlite3.Context) {
165+
if fn.kind == regr_count {
166+
ctx.ResultInt64(fn.regr_count())
167+
return
168+
}
169+
switch fn.n {
170+
case 1:
171+
switch fn.kind {
172+
case var_pop, stddev_pop, regr_sxx, regr_syy, regr_sxy:
173+
ctx.ResultFloat(0)
174+
return
175+
case regr_avgx, regr_avgy:
176+
break
177+
default:
178+
return
179+
}
180+
case 0:
181+
return
182+
}
183+
154184
var r float64
155185
switch fn.kind {
156186
case var_pop:
@@ -175,11 +205,9 @@ func (fn *covariance) Value(ctx sqlite3.Context) {
175205
r = fn.regr_slope()
176206
case regr_intercept:
177207
r = fn.regr_intercept()
178-
case regr_count:
179-
ctx.ResultInt64(fn.regr_count())
180-
return
181208
case regr_json:
182-
ctx.ResultText(fn.regr_json())
209+
var buf [128]byte
210+
ctx.ResultRawText(fn.regr_json(buf[:0]))
183211
return
184212
}
185213
ctx.ResultFloat(r)

ext/stats/stats_test.go

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,23 @@ func TestRegister_variance(t *testing.T) {
2929
t.Fatal(err)
3030
}
3131

32+
stmt, _, err := db.Prepare(`SELECT stddev_pop(x) FROM data`)
33+
if err != nil {
34+
t.Fatal(err)
35+
}
36+
if stmt.Step() {
37+
if got := stmt.ColumnType(0); got != sqlite3.NULL {
38+
t.Errorf("got %v, want NULL", got)
39+
}
40+
}
41+
stmt.Close()
42+
3243
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
3344
if err != nil {
3445
t.Fatal(err)
3546
}
3647

37-
stmt, _, err := db.Prepare(`
48+
stmt, _, err = db.Prepare(`
3849
SELECT
3950
sum(x), avg(x),
4051
var_samp(x), var_pop(x),
@@ -65,7 +76,11 @@ func TestRegister_variance(t *testing.T) {
6576
}
6677
stmt.Close()
6778

68-
stmt, _, err = db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
79+
stmt, _, err = db.Prepare(`
80+
SELECT
81+
var_samp(x) OVER (ROWS 1 PRECEDING),
82+
var_pop(x) OVER (ROWS 1 PRECEDING)
83+
FROM data`)
6984
if err != nil {
7085
t.Fatal(err)
7186
}
@@ -96,12 +111,26 @@ func TestRegister_covariance(t *testing.T) {
96111
t.Fatal(err)
97112
}
98113

114+
stmt, _, err := db.Prepare(`SELECT regr_count(y, x), regr_json(y, x) FROM data`)
115+
if err != nil {
116+
t.Fatal(err)
117+
}
118+
if stmt.Step() {
119+
if got := stmt.ColumnInt(0); got != 0 {
120+
t.Errorf("got %v, want 0", got)
121+
}
122+
if got := stmt.ColumnType(1); got != sqlite3.NULL {
123+
t.Errorf("got %v, want NULL", got)
124+
}
125+
}
126+
stmt.Close()
127+
99128
err = db.Exec(`INSERT INTO data (y, x) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
100129
if err != nil {
101130
t.Fatal(err)
102131
}
103132

104-
stmt, _, err := db.Prepare(`SELECT
133+
stmt, _, err = db.Prepare(`SELECT
105134
corr(y, x), covar_samp(y, x), covar_pop(y, x),
106135
regr_avgy(y, x), regr_avgx(y, x),
107136
regr_syy(y, x), regr_sxx(y, x), regr_sxy(y, x),
@@ -157,7 +186,12 @@ func TestRegister_covariance(t *testing.T) {
157186
}
158187
stmt.Close()
159188

160-
stmt, _, err = db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`)
189+
stmt, _, err = db.Prepare(`
190+
SELECT
191+
covar_samp(y, x) OVER (ROWS 1 PRECEDING),
192+
covar_pop(y, x) OVER (ROWS 1 PRECEDING),
193+
regr_avgx(y, x) OVER (ROWS 1 PRECEDING)
194+
FROM data`)
161195
if err != nil {
162196
t.Fatal(err)
163197
}
@@ -171,6 +205,9 @@ func TestRegister_covariance(t *testing.T) {
171205
t.Errorf("got %v, want %v", got, want[i])
172206
}
173207
}
208+
if stmt.Err() != nil {
209+
t.Fatal(stmt.Err())
210+
}
174211
stmt.Close()
175212
}
176213

ext/stats/welford.go

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,15 @@ package stats
33
import (
44
"math"
55
"strconv"
6-
"strings"
6+
7+
"github.com/ncruces/go-sqlite3/internal/util"
78
)
89

910
// Welford's algorithm with Kahan summation:
11+
// The effect of truncation in statistical computation [van Reeken, AJ 1970]
1012
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
1113
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
1214

13-
// See also:
14-
// https://duckdb.org/docs/sql/aggregates.html#statistical-aggregates
15-
1615
type welford struct {
1716
m1, m2 kahan
1817
n int64
@@ -39,17 +38,23 @@ func (w welford) stddev_samp() float64 {
3938
}
4039

4140
func (w *welford) enqueue(x float64) {
42-
w.n++
41+
n := w.n + 1
42+
w.n = n
4343
d1 := x - w.m1.hi - w.m1.lo
44-
w.m1.add(d1 / float64(w.n))
44+
w.m1.add(d1 / float64(n))
4545
d2 := x - w.m1.hi - w.m1.lo
4646
w.m2.add(d1 * d2)
4747
}
4848

4949
func (w *welford) dequeue(x float64) {
50-
w.n--
50+
n := w.n - 1
51+
if n <= 0 {
52+
*w = welford{}
53+
return
54+
}
55+
w.n = n
5156
d1 := x - w.m1.hi - w.m1.lo
52-
w.m1.sub(d1 / float64(w.n))
57+
w.m1.sub(d1 / float64(n))
5358
d2 := x - w.m1.hi - w.m1.lo
5459
w.m2.sub(d1 * d2)
5560
}
@@ -112,38 +117,35 @@ func (w welford2) regr_r2() float64 {
112117
return w.cov.hi * w.cov.hi / (w.m2y.hi * w.m2x.hi)
113118
}
114119

115-
func (w welford2) regr_json() string {
116-
var json strings.Builder
117-
var num [32]byte
118-
json.Grow(128)
119-
json.WriteString(`{"count":`)
120-
json.Write(strconv.AppendInt(num[:0], w.regr_count(), 10))
121-
json.WriteString(`,"avgy":`)
122-
json.Write(strconv.AppendFloat(num[:0], w.regr_avgy(), 'g', -1, 64))
123-
json.WriteString(`,"avgx":`)
124-
json.Write(strconv.AppendFloat(num[:0], w.regr_avgx(), 'g', -1, 64))
125-
json.WriteString(`,"syy":`)
126-
json.Write(strconv.AppendFloat(num[:0], w.regr_syy(), 'g', -1, 64))
127-
json.WriteString(`,"sxx":`)
128-
json.Write(strconv.AppendFloat(num[:0], w.regr_sxx(), 'g', -1, 64))
129-
json.WriteString(`,"sxy":`)
130-
json.Write(strconv.AppendFloat(num[:0], w.regr_sxy(), 'g', -1, 64))
131-
json.WriteString(`,"slope":`)
132-
json.Write(strconv.AppendFloat(num[:0], w.regr_slope(), 'g', -1, 64))
133-
json.WriteString(`,"intercept":`)
134-
json.Write(strconv.AppendFloat(num[:0], w.regr_intercept(), 'g', -1, 64))
135-
json.WriteString(`,"r2":`)
136-
json.Write(strconv.AppendFloat(num[:0], w.regr_r2(), 'g', -1, 64))
137-
json.WriteByte('}')
138-
return json.String()
120+
func (w welford2) regr_json(dst []byte) []byte {
121+
dst = append(dst, `{"count":`...)
122+
dst = strconv.AppendInt(dst, w.regr_count(), 10)
123+
dst = append(dst, `,"avgy":`...)
124+
dst = util.AppendNumber(dst, w.regr_avgy())
125+
dst = append(dst, `,"avgx":`...)
126+
dst = util.AppendNumber(dst, w.regr_avgx())
127+
dst = append(dst, `,"syy":`...)
128+
dst = util.AppendNumber(dst, w.regr_syy())
129+
dst = append(dst, `,"sxx":`...)
130+
dst = util.AppendNumber(dst, w.regr_sxx())
131+
dst = append(dst, `,"sxy":`...)
132+
dst = util.AppendNumber(dst, w.regr_sxy())
133+
dst = append(dst, `,"slope":`...)
134+
dst = util.AppendNumber(dst, w.regr_slope())
135+
dst = append(dst, `,"intercept":`...)
136+
dst = util.AppendNumber(dst, w.regr_intercept())
137+
dst = append(dst, `,"r2":`...)
138+
dst = util.AppendNumber(dst, w.regr_r2())
139+
return append(dst, '}')
139140
}
140141

141142
func (w *welford2) enqueue(y, x float64) {
142-
w.n++
143+
n := w.n + 1
144+
w.n = n
143145
d1y := y - w.m1y.hi - w.m1y.lo
144146
d1x := x - w.m1x.hi - w.m1x.lo
145-
w.m1y.add(d1y / float64(w.n))
146-
w.m1x.add(d1x / float64(w.n))
147+
w.m1y.add(d1y / float64(n))
148+
w.m1x.add(d1x / float64(n))
147149
d2y := y - w.m1y.hi - w.m1y.lo
148150
d2x := x - w.m1x.hi - w.m1x.lo
149151
w.m2y.add(d1y * d2y)
@@ -152,11 +154,16 @@ func (w *welford2) enqueue(y, x float64) {
152154
}
153155

154156
func (w *welford2) dequeue(y, x float64) {
155-
w.n--
157+
n := w.n - 1
158+
if n <= 0 {
159+
*w = welford2{}
160+
return
161+
}
162+
w.n = n
156163
d1y := y - w.m1y.hi - w.m1y.lo
157164
d1x := x - w.m1x.hi - w.m1x.lo
158-
w.m1y.sub(d1y / float64(w.n))
159-
w.m1x.sub(d1x / float64(w.n))
165+
w.m1y.sub(d1y / float64(n))
166+
w.m1x.sub(d1x / float64(n))
160167
d2y := y - w.m1y.hi - w.m1y.lo
161168
d2x := x - w.m1x.hi - w.m1x.lo
162169
w.m2y.sub(d1y * d2y)

ext/stats/welford_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@ func Test_welford(t *testing.T) {
3737
if s1.var_pop() != s2.var_pop() {
3838
t.Errorf("got %v, want %v", s1, s2)
3939
}
40+
41+
s1.dequeue(16)
42+
s1.dequeue(7)
43+
s1.dequeue(13)
44+
s1.enqueue(16)
45+
s1.enqueue(7)
46+
s1.enqueue(13)
47+
if s1.var_pop() != s2.var_pop() {
48+
t.Errorf("got %v, want %v", s1, s2)
49+
}
4050
}
4151

4252
func Test_covar(t *testing.T) {
@@ -65,6 +75,18 @@ func Test_covar(t *testing.T) {
6575
if c1.covar_pop() != c2.covar_pop() {
6676
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
6777
}
78+
79+
c1.dequeue(2, 60)
80+
c1.dequeue(5, 80)
81+
c1.dequeue(4, 75)
82+
c1.dequeue(7, 90)
83+
c1.enqueue(2, 60)
84+
c1.enqueue(5, 80)
85+
c1.enqueue(4, 75)
86+
c1.enqueue(7, 90)
87+
if c1.covar_pop() != c2.covar_pop() {
88+
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
89+
}
6890
}
6991

7092
func Test_correlation(t *testing.T) {

0 commit comments

Comments
 (0)