Skip to content

feat: support more types in switch statements #2926

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 67 additions & 37 deletions src/compiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2816,19 +2816,18 @@ export class Compiler extends DiagnosticEmitter {
let numCases = cases.length;

// Compile the condition (always executes)
let condExpr = this.compileExpression(statement.condition, Type.u32,
Constraints.ConvImplicit
);

let condExpr = this.compileExpression(statement.condition, Type.auto);
let condType = this.currentType;

// Shortcut if there are no cases
if (!numCases) return module.drop(condExpr);

// Assign the condition to a temporary local as we compare it multiple times
let outerFlow = this.currentFlow;
let tempLocal = outerFlow.getTempLocal(Type.u32);
let tempLocal = outerFlow.getTempLocal(condType);
let tempLocalIndex = tempLocal.index;
let breaks = new Array<ExpressionRef>(1 + numCases);
breaks[0] = module.local_set(tempLocalIndex, condExpr, false); // u32
breaks[0] = module.local_set(tempLocalIndex, condExpr, condType.isManaged);

// Make one br_if per labeled case and leave it to Binaryen to optimize the
// sequence of br_ifs to a br_table according to optimization levels
Expand All @@ -2841,14 +2840,24 @@ export class Compiler extends DiagnosticEmitter {
defaultIndex = i;
continue;
}
breaks[breakIndex++] = module.br(`case${i}|${label}`,
module.binary(BinaryOp.EqI32,
module.local_get(tempLocalIndex, TypeRef.I32),
this.compileExpression(assert(case_.label), Type.u32,
Constraints.ConvImplicit
)
)

// Compile the equality expression for this case
const left = statement.condition;
const leftExpr = module.local_get(tempLocalIndex, condType.toRef());
const leftType = condType;
const right = case_.label!;
const rightExpr = this.compileExpression(assert(case_.label), condType, Constraints.ConvImplicit);
const rightType = this.currentType;
const equalityExpr = this.compileCommutativeCompareBinaryExpressionFromParts(
Token.Equals_Equals,
left, leftExpr, leftType,
right, rightExpr, rightType,
condType,
statement
);

// Add it to the list of breaks
breaks[breakIndex++] = module.br(`case${i}|${label}`, equalityExpr);
}

// If there is a default case, break to it, otherwise break out of the switch
Expand Down Expand Up @@ -3800,32 +3809,53 @@ export class Compiler extends DiagnosticEmitter {
expression: BinaryExpression,
contextualType: Type,
): ExpressionRef {
let module = this.module;
let left = expression.left;
let right = expression.right;

const left = expression.left;
const leftExpr = this.compileExpression(left, contextualType);
const leftType = this.currentType;

const right = expression.right;
const rightExpr = this.compileExpression(right, leftType);
const rightType = this.currentType;

return this.compileCommutativeCompareBinaryExpressionFromParts(
expression.operator,
left, leftExpr, leftType,
right, rightExpr, rightType,
contextualType,
expression
);
}

let leftExpr: ExpressionRef;
let leftType: Type;
let rightExpr: ExpressionRef;
let rightType: Type;
let commonType: Type | null;
/**
* compile `==` `===` `!=` `!==` BinaryExpression, from previously compiled left and right expressions.
*
* This is split from `compileCommutativeCompareBinaryExpression` so that the logic can be reused
* for switch cases in `compileSwitchStatement`, where the left expression only should be compiled once.
*/
private compileCommutativeCompareBinaryExpressionFromParts(
operator: Token,
left: Expression,
leftExpr: ExpressionRef,
leftType: Type,
right: Expression,
rightExpr: ExpressionRef,
rightType: Type,
contextualType: Type,
reportNode: Node
): ExpressionRef {

let operator = expression.operator;
let module = this.module;
let operatorString = operatorTokenToString(operator);

leftExpr = this.compileExpression(left, contextualType);
leftType = this.currentType;

rightExpr = this.compileExpression(right, leftType);
rightType = this.currentType;

// check operator overload
const operatorKind = OperatorKind.fromBinaryToken(operator);
const leftOverload = leftType.lookupOverload(operatorKind, this.program);
const rightOverload = rightType.lookupOverload(operatorKind, this.program);
if (leftOverload && rightOverload && leftOverload != rightOverload) {
this.error(
DiagnosticCode.Ambiguous_operator_overload_0_conflicting_overloads_1_and_2, expression.range,
DiagnosticCode.Ambiguous_operator_overload_0_conflicting_overloads_1_and_2,
reportNode.range,
operatorString,
leftOverload.internalName,
rightOverload.internalName
Expand All @@ -3838,23 +3868,23 @@ export class Compiler extends DiagnosticEmitter {
leftOverload,
left, leftExpr, leftType,
right, rightExpr, rightType,
expression
reportNode
);
}
if (rightOverload) {
return this.compileCommutativeBinaryOverload(
rightOverload,
right, rightExpr, rightType,
left, leftExpr, leftType,
expression
reportNode
);
}
const signednessIsRelevant = false;
commonType = Type.commonType(leftType, rightType, contextualType, signednessIsRelevant);
const commonType = Type.commonType(leftType, rightType, contextualType, signednessIsRelevant);
if (!commonType) {
this.error(
DiagnosticCode.Operator_0_cannot_be_applied_to_types_1_and_2,
expression.range,
reportNode.range,
operatorString,
leftType.toString(),
rightType.toString()
Expand All @@ -3867,13 +3897,13 @@ export class Compiler extends DiagnosticEmitter {
if (isConstExpressionNaN(module, rightExpr) || isConstExpressionNaN(module, leftExpr)) {
this.warning(
DiagnosticCode._NaN_does_not_compare_equal_to_any_other_value_including_itself_Use_isNaN_x_instead,
expression.range
reportNode.range
);
}
if (isConstNegZero(rightExpr) || isConstNegZero(leftExpr)) {
this.warning(
DiagnosticCode.Comparison_with_0_0_is_sign_insensitive_Use_Object_is_x_0_0_if_the_sign_matters,
expression.range
reportNode.range
);
}
}
Expand All @@ -3887,10 +3917,10 @@ export class Compiler extends DiagnosticEmitter {
switch (operator) {
case Token.Equals_Equals_Equals:
case Token.Equals_Equals:
return this.makeEq(leftExpr, rightExpr, commonType, expression);
return this.makeEq(leftExpr, rightExpr, commonType, reportNode);
case Token.Exclamation_Equals_Equals:
case Token.Exclamation_Equals:
return this.makeNe(leftExpr, rightExpr, commonType, expression);
return this.makeNe(leftExpr, rightExpr, commonType, reportNode);
default:
assert(false);
return module.unreachable();
Expand Down
Loading