Skip to content

Commit 0dff178

Browse files
authored
DeNaN improvements (#2888)
Instead of instrumenting every local.get, instrument parameters on arrival at a function once on entry. After that, every local will always contain a de-naned value (since we would denan on a local.set). This is more efficient and also less confusing I think. Also avoid doing anything to values that fall through as they have already been fixed up.
1 parent 501b0a0 commit 0dff178

File tree

4 files changed

+212
-7
lines changed

4 files changed

+212
-7
lines changed

src/ir/properties.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,9 @@ inline Index getZeroExtBits(Expression* curr) {
200200
}
201201

202202
// Returns a falling-through value, that is, it looks through a local.tee
203-
// and other operations that receive a value and let it flow through them.
203+
// and other operations that receive a value and let it flow through them. If
204+
// there is no value falling through, returns the node itself (as that is the
205+
// value that trivially falls through, with 0 steps in the middle).
204206
inline Expression* getFallthrough(Expression* curr,
205207
const PassOptions& passOptions,
206208
FeatureSet features) {
@@ -241,6 +243,25 @@ inline Expression* getFallthrough(Expression* curr,
241243
return curr;
242244
}
243245

246+
// Returns whether the resulting value here must fall through without being
247+
// modified. For example, a tee always does so. That is, this returns false if
248+
// and only if the return value may have some computation performed on it to
249+
// change it from the inputs the instruction receives.
250+
// This differs from getFallthrough() which returns a single value that falls
251+
// through - here if more than one value can fall through, like in if-else,
252+
// we can return true. That is, there we care about a value falling through and
253+
// for us to get that actual value to look at; here we just care whether the
254+
// value falls through without being changed, even if it might be one of
255+
// several options.
256+
inline bool isResultFallthrough(Expression* curr) {
257+
// Note that we don't check if there is a return value here; the node may be
258+
// unreachable, for example, but then there is no meaningful answer to give
259+
// anyhow.
260+
return curr->is<LocalSet>() || curr->is<Block>() || curr->is<If>() ||
261+
curr->is<Loop>() || curr->is<Try>() || curr->is<Select>() ||
262+
curr->is<Break>();
263+
}
264+
244265
} // namespace Properties
245266

246267
} // namespace wasm

src/passes/DeNaN.cpp

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
// differ on wasm's nondeterminism around NaNs.
2323
//
2424

25+
#include "ir/properties.h"
2526
#include "pass.h"
2627
#include "wasm-builder.h"
2728
#include "wasm.h"
@@ -33,7 +34,18 @@ struct DeNaN : public WalkerPass<
3334
void visitExpression(Expression* expr) {
3435
// If the expression returns a floating-point value, ensure it is not a
3536
// NaN. If we can do this at compile time, do it now, which is useful for
36-
// initializations of global (which we can't do a function call in).
37+
// initializations of global (which we can't do a function call in). Note
38+
// that we don't instrument local.gets, which would cause problems if we
39+
// ran this pass more than once (the added functions use gets, and we don't
40+
// want to instrument them).
41+
if (expr->is<LocalGet>()) {
42+
return;
43+
}
44+
// If the result just falls through without being modified, then we've
45+
// already fixed it up earlier.
46+
if (Properties::isResultFallthrough(expr)) {
47+
return;
48+
}
3749
Builder builder(*getModule());
3850
Expression* replacement = nullptr;
3951
auto* c = expr->dynCast<Const>();
@@ -61,6 +73,38 @@ struct DeNaN : public WalkerPass<
6173
}
6274
}
6375

76+
void visitFunction(Function* func) {
77+
if (func->imported()) {
78+
return;
79+
}
80+
// Instrument all locals as they enter the function.
81+
Builder builder(*getModule());
82+
std::vector<Expression*> fixes;
83+
auto num = func->getNumParams();
84+
for (Index i = 0; i < num; i++) {
85+
if (func->getLocalType(i) == Type::f32) {
86+
fixes.push_back(builder.makeLocalSet(
87+
i,
88+
builder.makeCall(
89+
"deNan32", {builder.makeLocalGet(i, Type::f32)}, Type::f32)));
90+
} else if (func->getLocalType(i) == Type::f64) {
91+
fixes.push_back(builder.makeLocalSet(
92+
i,
93+
builder.makeCall(
94+
"deNan64", {builder.makeLocalGet(i, Type::f64)}, Type::f64)));
95+
}
96+
}
97+
if (!fixes.empty()) {
98+
fixes.push_back(func->body);
99+
func->body = builder.makeBlock(fixes);
100+
// Merge blocks so we don't add an unnecessary one.
101+
PassRunner runner(getModule(), getPassOptions());
102+
runner.setIsNested(true);
103+
runner.add("merge-blocks");
104+
runner.run();
105+
}
106+
}
107+
64108
void visitModule(Module* module) {
65109
// Add helper functions.
66110
Builder builder(*module);

test/passes/denan.txt

Lines changed: 118 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,139 @@
11
(module
22
(type $f32_=>_f32 (func (param f32) (result f32)))
33
(type $f64_=>_f64 (func (param f64) (result f64)))
4+
(type $i32_f32_i64_f64_=>_none (func (param i32 f32 i64 f64)))
5+
(type $f32_f64_=>_none (func (param f32 f64)))
46
(global $global$1 (mut f32) (f32.const 0))
57
(global $global$2 (mut f32) (f32.const 12.34000015258789))
68
(func $foo32 (param $x f32) (result f32)
9+
(local.set $x
10+
(call $deNan32
11+
(local.get $x)
12+
)
13+
)
714
(call $deNan32
815
(call $foo32
9-
(call $deNan32
10-
(local.get $x)
11-
)
16+
(local.get $x)
1217
)
1318
)
1419
)
1520
(func $foo64 (param $x f64) (result f64)
21+
(local.set $x
22+
(call $deNan64
23+
(local.get $x)
24+
)
25+
)
1626
(call $deNan64
1727
(call $foo64
18-
(call $deNan64
19-
(local.get $x)
28+
(local.get $x)
29+
)
30+
)
31+
)
32+
(func $various (param $x i32) (param $y f32) (param $z i64) (param $w f64)
33+
(local.set $y
34+
(call $deNan32
35+
(local.get $y)
36+
)
37+
)
38+
(local.set $w
39+
(call $deNan64
40+
(local.get $w)
41+
)
42+
)
43+
(nop)
44+
)
45+
(func $ignore-local.get (param $f f32) (param $d f64)
46+
(local.set $f
47+
(call $deNan32
48+
(local.get $f)
49+
)
50+
)
51+
(local.set $d
52+
(call $deNan64
53+
(local.get $d)
54+
)
55+
)
56+
(drop
57+
(local.get $f)
58+
)
59+
(drop
60+
(local.get $d)
61+
)
62+
(local.set $f
63+
(local.get $f)
64+
)
65+
(local.set $d
66+
(local.get $d)
67+
)
68+
(drop
69+
(local.get $f)
70+
)
71+
(drop
72+
(local.get $d)
73+
)
74+
(drop
75+
(call $deNan32
76+
(f32.abs
77+
(local.get $f)
78+
)
79+
)
80+
)
81+
(drop
82+
(call $deNan64
83+
(f64.abs
84+
(local.get $d)
85+
)
86+
)
87+
)
88+
(local.set $f
89+
(call $deNan32
90+
(f32.abs
91+
(local.get $f)
92+
)
93+
)
94+
)
95+
(local.set $d
96+
(call $deNan64
97+
(f64.abs
98+
(local.get $d)
99+
)
100+
)
101+
)
102+
(drop
103+
(local.get $f)
104+
)
105+
(drop
106+
(local.get $d)
107+
)
108+
)
109+
(func $tees (param $x f32) (result f32)
110+
(local.set $x
111+
(call $deNan32
112+
(local.get $x)
113+
)
114+
)
115+
(local.tee $x
116+
(local.tee $x
117+
(local.tee $x
118+
(local.tee $x
119+
(local.get $x)
120+
)
20121
)
21122
)
22123
)
23124
)
125+
(func $select (param $x f32) (result f32)
126+
(local.set $x
127+
(call $deNan32
128+
(local.get $x)
129+
)
130+
)
131+
(select
132+
(local.get $x)
133+
(local.get $x)
134+
(i32.const 1)
135+
)
136+
)
24137
(func $deNan32 (param $0 f32) (result f32)
25138
(if (result f32)
26139
(f32.eq

test/passes/denan.wast

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,31 @@
77
(func $foo64 (param $x f64) (result f64)
88
(call $foo64 (local.get $x))
99
)
10+
(func $various (param $x i32) (param $y f32) (param $z i64) (param $w f64)
11+
)
12+
(func $ignore-local.get (param $f f32) (param $d f64)
13+
(drop (local.get $f))
14+
(drop (local.get $d))
15+
(local.set $f (local.get $f))
16+
(local.set $d (local.get $d))
17+
(drop (local.get $f))
18+
(drop (local.get $d))
19+
(drop (f32.abs (local.get $f)))
20+
(drop (f64.abs (local.get $d)))
21+
(local.set $f (f32.abs (local.get $f)))
22+
(local.set $d (f64.abs (local.get $d)))
23+
(drop (local.get $f))
24+
(drop (local.get $d))
25+
)
26+
(func $tees (param $x f32) (result f32)
27+
(local.tee $x
28+
(local.tee $x
29+
(local.tee $x
30+
(local.tee $x
31+
(local.get $x))))))
32+
(func $select (param $x f32) (result f32)
33+
(select
34+
(local.get $x)
35+
(local.get $x)
36+
(i32.const 1)))
1037
)

0 commit comments

Comments
 (0)