Skip to content

Commit 89020a0

Browse files
authored
StackCheck: Check both under and overflow (#3091)
See emscripten-core/emscripten#9039 (comment) The valid stack area is a region [A, B] in memory. Previously we just checked that new stack positions S were S >= A, which prevented us from growing too much (the stack grows down). But that only worked if the growth was small enough to not overflow and become a big unsigned value. This PR makes us check the other way too, which requires us to know where the stack starts out at. This still supports the old way of just passing in the growth limit. We can remove it after the roll. In principle this can all be done on the LLVM side too after emscripten-core/emscripten#12057 but I'm not sure of the details there, and this is easy to fix here and get testing up (which can help with later LLVM work). This helps emscripten-core/emscripten#11860 by allowing us to clean up some fastcomp-specific stuff in tests.
1 parent ef7ab77 commit 89020a0

6 files changed

+269
-89
lines changed

src/passes/StackCheck.cpp

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,16 @@
3131

3232
namespace wasm {
3333

34+
// The base is where the stack begins. As it goes down, that is the highest
35+
// valid address.
36+
static Name STACK_BASE("__stack_base");
37+
// The limit is the farthest it can grow to, which is the lowest valid address.
3438
static Name STACK_LIMIT("__stack_limit");
39+
// Old version, which sets the limit.
40+
// TODO: remove this
3541
static Name SET_STACK_LIMIT("__set_stack_limit");
42+
// New version, which sets the base and the limit.
43+
static Name SET_STACK_LIMITS("__set_stack_limits");
3644

3745
static void importStackOverflowHandler(Module& module, Name name) {
3846
ImportInfo info(module);
@@ -55,34 +63,43 @@ static void addExportedFunction(Module& module, Function* function) {
5563
module.addExport(export_);
5664
}
5765

58-
static void generateSetStackLimitFunction(Module& module) {
66+
static void generateSetStackLimitFunctions(Module& module) {
5967
Builder builder(module);
60-
Function* function =
68+
// One-parameter version
69+
Function* limitFunc =
6170
builder.makeFunction(SET_STACK_LIMIT, Signature(Type::i32, Type::none), {});
6271
LocalGet* getArg = builder.makeLocalGet(0, Type::i32);
6372
Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg);
64-
function->body = store;
65-
addExportedFunction(module, function);
73+
limitFunc->body = store;
74+
addExportedFunction(module, limitFunc);
75+
// Two-parameter version
76+
Function* limitsFunc = builder.makeFunction(
77+
SET_STACK_LIMITS, Signature({Type::i32, Type::i32}, Type::none), {});
78+
LocalGet* getBase = builder.makeLocalGet(0, Type::i32);
79+
Expression* storeBase = builder.makeGlobalSet(STACK_BASE, getBase);
80+
LocalGet* getLimit = builder.makeLocalGet(1, Type::i32);
81+
Expression* storeLimit = builder.makeGlobalSet(STACK_LIMIT, getLimit);
82+
limitsFunc->body = builder.makeBlock({storeBase, storeLimit});
83+
addExportedFunction(module, limitsFunc);
6684
}
6785

68-
struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> {
69-
EnforceStackLimit(Global* stackPointer,
70-
Global* stackLimit,
71-
Builder& builder,
72-
Name handler)
73-
: stackPointer(stackPointer), stackLimit(stackLimit), builder(builder),
74-
handler(handler) {}
86+
struct EnforceStackLimits : public WalkerPass<PostWalker<EnforceStackLimits>> {
87+
EnforceStackLimits(Global* stackPointer,
88+
Global* stackBase,
89+
Global* stackLimit,
90+
Builder& builder,
91+
Name handler)
92+
: stackPointer(stackPointer), stackBase(stackBase), stackLimit(stackLimit),
93+
builder(builder), handler(handler) {}
7594

7695
bool isFunctionParallel() override { return true; }
7796

7897
Pass* create() override {
79-
return new EnforceStackLimit(stackPointer, stackLimit, builder, handler);
98+
return new EnforceStackLimits(
99+
stackPointer, stackBase, stackLimit, builder, handler);
80100
}
81101

82-
Expression* stackBoundsCheck(Function* func,
83-
Expression* value,
84-
Global* stackPointer,
85-
Global* stackLimit) {
102+
Expression* stackBoundsCheck(Function* func, Expression* value) {
86103
// Add a local to store the value of the expression. We need the value
87104
// twice: once to check if it has overflowed, and again to assign to store
88105
// it.
@@ -95,12 +112,18 @@ struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> {
95112
} else {
96113
handlerExpr = builder.makeUnreachable();
97114
}
98-
// (if (i32.lt_u (local.tee $newSP (...val...)) (global.get $__stack_limit))
115+
// If it is >= the base or <= the limit, then error.
99116
auto check = builder.makeIf(
100117
builder.makeBinary(
101-
BinaryOp::LtUInt32,
102-
builder.makeLocalTee(newSP, value, stackPointer->type),
103-
builder.makeGlobalGet(stackLimit->name, stackLimit->type)),
118+
BinaryOp::OrInt32,
119+
builder.makeBinary(
120+
BinaryOp::GtUInt32,
121+
builder.makeLocalTee(newSP, value, stackPointer->type),
122+
builder.makeGlobalGet(stackBase->name, stackBase->type)),
123+
builder.makeBinary(
124+
BinaryOp::LtUInt32,
125+
builder.makeLocalGet(newSP, stackPointer->type),
126+
builder.makeGlobalGet(stackLimit->name, stackLimit->type))),
104127
handlerExpr);
105128
// (global.set $__stack_pointer (local.get $newSP))
106129
auto newSet = builder.makeGlobalSet(
@@ -110,13 +133,13 @@ struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> {
110133

111134
void visitGlobalSet(GlobalSet* curr) {
112135
if (getModule()->getGlobalOrNull(curr->name) == stackPointer) {
113-
replaceCurrent(
114-
stackBoundsCheck(getFunction(), curr->value, stackPointer, stackLimit));
136+
replaceCurrent(stackBoundsCheck(getFunction(), curr->value));
115137
}
116138
}
117139

118140
private:
119141
Global* stackPointer;
142+
Global* stackBase;
120143
Global* stackLimit;
121144
Builder& builder;
122145
Name handler;
@@ -139,16 +162,22 @@ struct StackCheck : public Pass {
139162
}
140163

141164
Builder builder(*module);
165+
Global* stackBase = builder.makeGlobal(STACK_BASE,
166+
stackPointer->type,
167+
builder.makeConst(int32_t(0)),
168+
Builder::Mutable);
169+
module->addGlobal(stackBase);
170+
142171
Global* stackLimit = builder.makeGlobal(STACK_LIMIT,
143172
stackPointer->type,
144173
builder.makeConst(int32_t(0)),
145174
Builder::Mutable);
146175
module->addGlobal(stackLimit);
147176

148177
PassRunner innerRunner(module);
149-
EnforceStackLimit(stackPointer, stackLimit, builder, handler)
178+
EnforceStackLimits(stackPointer, stackBase, stackLimit, builder, handler)
150179
.run(&innerRunner, module);
151-
generateSetStackLimitFunction(*module);
180+
generateSetStackLimitFunctions(*module);
152181
}
153182
};
154183

test/lld/basic_safe_stack.wat.out

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
(type $none_=>_none (func))
33
(type $i32_=>_none (func (param i32)))
44
(type $i32_=>_i32 (func (param i32) (result i32)))
5+
(type $i32_i32_=>_none (func (param i32 i32)))
56
(import "env" "__handle_stack_overflow" (func $__handle_stack_overflow))
67
(memory $0 2)
78
(table $0 1 1 funcref)
89
(global $global$0 (mut i32) (i32.const 66112))
910
(global $global$1 i32 (i32.const 568))
11+
(global $__stack_base (mut i32) (i32.const 0))
1012
(global $__stack_limit (mut i32) (i32.const 0))
1113
(export "memory" (memory $0))
1214
(export "__wasm_call_ctors" (func $__wasm_call_ctors))
@@ -15,18 +17,25 @@
1517
(export "main" (func $main))
1618
(export "__data_end" (global $global$1))
1719
(export "__set_stack_limit" (func $__set_stack_limit))
20+
(export "__set_stack_limits" (func $__set_stack_limits))
1821
(export "__growWasmMemory" (func $__growWasmMemory))
1922
(func $__wasm_call_ctors
2023
(nop)
2124
)
2225
(func $stackRestore (param $0 i32)
2326
(local $1 i32)
2427
(if
25-
(i32.lt_u
26-
(local.tee $1
27-
(local.get $0)
28+
(i32.or
29+
(i32.gt_u
30+
(local.tee $1
31+
(local.get $0)
32+
)
33+
(global.get $__stack_base)
34+
)
35+
(i32.lt_u
36+
(local.get $1)
37+
(global.get $__stack_limit)
2838
)
29-
(global.get $__stack_limit)
3039
)
3140
(call $__handle_stack_overflow)
3241
)
@@ -40,19 +49,25 @@
4049
(local $3 i32)
4150
(block
4251
(if
43-
(i32.lt_u
44-
(local.tee $3
45-
(local.tee $1
46-
(i32.and
47-
(i32.sub
48-
(global.get $global$0)
49-
(local.get $0)
52+
(i32.or
53+
(i32.gt_u
54+
(local.tee $3
55+
(local.tee $1
56+
(i32.and
57+
(i32.sub
58+
(global.get $global$0)
59+
(local.get $0)
60+
)
61+
(i32.const -16)
5062
)
51-
(i32.const -16)
5263
)
5364
)
65+
(global.get $__stack_base)
66+
)
67+
(i32.lt_u
68+
(local.get $3)
69+
(global.get $__stack_limit)
5470
)
55-
(global.get $__stack_limit)
5671
)
5772
(call $__handle_stack_overflow)
5873
)
@@ -70,6 +85,14 @@
7085
(local.get $0)
7186
)
7287
)
88+
(func $__set_stack_limits (param $0 i32) (param $1 i32)
89+
(global.set $__stack_base
90+
(local.get $0)
91+
)
92+
(global.set $__stack_limit
93+
(local.get $1)
94+
)
95+
)
7396
(func $__growWasmMemory (param $newSize i32) (result i32)
7497
(memory.grow
7598
(local.get $newSize)
@@ -95,6 +118,7 @@
95118
"stackAlloc",
96119
"main",
97120
"__set_stack_limit",
121+
"__set_stack_limits",
98122
"__growWasmMemory"
99123
],
100124
"namedGlobals": {

0 commit comments

Comments
 (0)