Skip to content

StackCheck: Check both under and overflow #3091

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 53 additions & 24 deletions src/passes/StackCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,16 @@

namespace wasm {

// The base is where the stack begins. As it goes down, that is the highest
// valid address.
static Name STACK_BASE("__stack_base");
// The limit is the farthest it can grow to, which is the lowest valid address.
static Name STACK_LIMIT("__stack_limit");
// Old version, which sets the limit.
// TODO: remove this
static Name SET_STACK_LIMIT("__set_stack_limit");
// New version, which sets the base and the limit.
static Name SET_STACK_LIMITS("__set_stack_limits");

static void importStackOverflowHandler(Module& module, Name name) {
ImportInfo info(module);
Expand All @@ -55,34 +63,43 @@ static void addExportedFunction(Module& module, Function* function) {
module.addExport(export_);
}

static void generateSetStackLimitFunction(Module& module) {
static void generateSetStackLimitFunctions(Module& module) {
Builder builder(module);
Function* function =
// One-parameter version
Function* limitFunc =
builder.makeFunction(SET_STACK_LIMIT, Signature(Type::i32, Type::none), {});
LocalGet* getArg = builder.makeLocalGet(0, Type::i32);
Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg);
function->body = store;
addExportedFunction(module, function);
limitFunc->body = store;
addExportedFunction(module, limitFunc);
// Two-parameter version
Function* limitsFunc = builder.makeFunction(
SET_STACK_LIMITS, Signature({Type::i32, Type::i32}, Type::none), {});
LocalGet* getBase = builder.makeLocalGet(0, Type::i32);
Expression* storeBase = builder.makeGlobalSet(STACK_BASE, getBase);
LocalGet* getLimit = builder.makeLocalGet(1, Type::i32);
Expression* storeLimit = builder.makeGlobalSet(STACK_LIMIT, getLimit);
limitsFunc->body = builder.makeBlock({storeBase, storeLimit});
addExportedFunction(module, limitsFunc);
}

struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> {
EnforceStackLimit(Global* stackPointer,
Global* stackLimit,
Builder& builder,
Name handler)
: stackPointer(stackPointer), stackLimit(stackLimit), builder(builder),
handler(handler) {}
struct EnforceStackLimits : public WalkerPass<PostWalker<EnforceStackLimits>> {
EnforceStackLimits(Global* stackPointer,
Global* stackBase,
Global* stackLimit,
Builder& builder,
Name handler)
: stackPointer(stackPointer), stackBase(stackBase), stackLimit(stackLimit),
builder(builder), handler(handler) {}

bool isFunctionParallel() override { return true; }

Pass* create() override {
return new EnforceStackLimit(stackPointer, stackLimit, builder, handler);
return new EnforceStackLimits(
stackPointer, stackBase, stackLimit, builder, handler);
}

Expression* stackBoundsCheck(Function* func,
Expression* value,
Global* stackPointer,
Global* stackLimit) {
Expression* stackBoundsCheck(Function* func, Expression* value) {
// Add a local to store the value of the expression. We need the value
// twice: once to check if it has overflowed, and again to assign to store
// it.
Expand All @@ -95,12 +112,18 @@ struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> {
} else {
handlerExpr = builder.makeUnreachable();
}
// (if (i32.lt_u (local.tee $newSP (...val...)) (global.get $__stack_limit))
// If it is >= the base or <= the limit, then error.
auto check = builder.makeIf(
builder.makeBinary(
BinaryOp::LtUInt32,
builder.makeLocalTee(newSP, value, stackPointer->type),
builder.makeGlobalGet(stackLimit->name, stackLimit->type)),
BinaryOp::OrInt32,
builder.makeBinary(
BinaryOp::GtUInt32,
builder.makeLocalTee(newSP, value, stackPointer->type),
builder.makeGlobalGet(stackBase->name, stackBase->type)),
builder.makeBinary(
BinaryOp::LtUInt32,
builder.makeLocalGet(newSP, stackPointer->type),
builder.makeGlobalGet(stackLimit->name, stackLimit->type))),
handlerExpr);
// (global.set $__stack_pointer (local.get $newSP))
auto newSet = builder.makeGlobalSet(
Expand All @@ -110,13 +133,13 @@ struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> {

void visitGlobalSet(GlobalSet* curr) {
if (getModule()->getGlobalOrNull(curr->name) == stackPointer) {
replaceCurrent(
stackBoundsCheck(getFunction(), curr->value, stackPointer, stackLimit));
replaceCurrent(stackBoundsCheck(getFunction(), curr->value));
}
}

private:
Global* stackPointer;
Global* stackBase;
Global* stackLimit;
Builder& builder;
Name handler;
Expand All @@ -139,16 +162,22 @@ struct StackCheck : public Pass {
}

Builder builder(*module);
Global* stackBase = builder.makeGlobal(STACK_BASE,
stackPointer->type,
builder.makeConst(int32_t(0)),
Builder::Mutable);
module->addGlobal(stackBase);

Global* stackLimit = builder.makeGlobal(STACK_LIMIT,
stackPointer->type,
builder.makeConst(int32_t(0)),
Builder::Mutable);
module->addGlobal(stackLimit);

PassRunner innerRunner(module);
EnforceStackLimit(stackPointer, stackLimit, builder, handler)
EnforceStackLimits(stackPointer, stackBase, stackLimit, builder, handler)
.run(&innerRunner, module);
generateSetStackLimitFunction(*module);
generateSetStackLimitFunctions(*module);
}
};

Expand Down
50 changes: 37 additions & 13 deletions test/lld/basic_safe_stack.wat.out
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
(type $none_=>_none (func))
(type $i32_=>_none (func (param i32)))
(type $i32_=>_i32 (func (param i32) (result i32)))
(type $i32_i32_=>_none (func (param i32 i32)))
(import "env" "__handle_stack_overflow" (func $__handle_stack_overflow))
(memory $0 2)
(table $0 1 1 funcref)
(global $global$0 (mut i32) (i32.const 66112))
(global $global$1 i32 (i32.const 568))
(global $__stack_base (mut i32) (i32.const 0))
(global $__stack_limit (mut i32) (i32.const 0))
(export "memory" (memory $0))
(export "__wasm_call_ctors" (func $__wasm_call_ctors))
Expand All @@ -15,18 +17,25 @@
(export "main" (func $main))
(export "__data_end" (global $global$1))
(export "__set_stack_limit" (func $__set_stack_limit))
(export "__set_stack_limits" (func $__set_stack_limits))
(export "__growWasmMemory" (func $__growWasmMemory))
(func $__wasm_call_ctors
(nop)
)
(func $stackRestore (param $0 i32)
(local $1 i32)
(if
(i32.lt_u
(local.tee $1
(local.get $0)
(i32.or
(i32.gt_u
(local.tee $1
(local.get $0)
)
(global.get $__stack_base)
)
(i32.lt_u
(local.get $1)
(global.get $__stack_limit)
)
(global.get $__stack_limit)
)
(call $__handle_stack_overflow)
)
Expand All @@ -40,19 +49,25 @@
(local $3 i32)
(block
(if
(i32.lt_u
(local.tee $3
(local.tee $1
(i32.and
(i32.sub
(global.get $global$0)
(local.get $0)
(i32.or
(i32.gt_u
(local.tee $3
(local.tee $1
(i32.and
(i32.sub
(global.get $global$0)
(local.get $0)
)
(i32.const -16)
)
(i32.const -16)
)
)
(global.get $__stack_base)
)
(i32.lt_u
(local.get $3)
(global.get $__stack_limit)
)
(global.get $__stack_limit)
)
(call $__handle_stack_overflow)
)
Expand All @@ -70,6 +85,14 @@
(local.get $0)
)
)
(func $__set_stack_limits (param $0 i32) (param $1 i32)
(global.set $__stack_base
(local.get $0)
)
(global.set $__stack_limit
(local.get $1)
)
)
(func $__growWasmMemory (param $newSize i32) (result i32)
(memory.grow
(local.get $newSize)
Expand All @@ -95,6 +118,7 @@
"stackAlloc",
"main",
"__set_stack_limit",
"__set_stack_limits",
"__growWasmMemory"
],
"namedGlobals": {
Expand Down
Loading