Skip to content

Add bitwise Int compiletime operations #8377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,7 @@ class Definitions {
tpnme.Plus, tpnme.Minus, tpnme.Times, tpnme.Div, tpnme.Mod,
tpnme.Lt, tpnme.Gt, tpnme.Ge, tpnme.Le,
tpnme.Abs, tpnme.Negate, tpnme.Min, tpnme.Max, tpnme.ToString,
tpnme.Xor, tpnme.BitwiseAnd, tpnme.BitwiseOr, tpnme.ASR, tpnme.LSL, tpnme.LSR
)
private val compiletimePackageBooleanTypes: Set[Name] = Set(tpnme.Not, tpnme.Xor, tpnme.And, tpnme.Or)
private val compiletimePackageStringTypes: Set[Name] = Set(tpnme.Plus)
Expand Down
44 changes: 23 additions & 21 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -208,27 +208,29 @@ object StdNames {
final val IOOBException: N = "IndexOutOfBoundsException"
final val FunctionXXL: N = "FunctionXXL"

final val Abs: N = "Abs"
final val And: N = "&&"
final val Div: N = "/"
final val Equals: N = "=="
final val Ge: N = ">="
final val Gt: N = ">"
final val Le: N = "<="
final val Lt: N = "<"
final val Max: N = "Max"
final val Min: N = "Min"
final val Minus: N = "-"
final val Mod: N = "%"
final val Negate: N = "Negate"
final val Not: N = "!"
final val NotEquals: N = "!="
final val Or: N = "||"
final val Plus: N = "+"
final val S: N = "S"
final val Times: N = "*"
final val ToString: N = "ToString"
final val Xor: N = "^"
final val Abs: N = "Abs"
final val And: N = "&&"
final val BitwiseAnd: N = "BitwiseAnd"
final val BitwiseOr: N = "BitwiseOr"
final val Div: N = "/"
final val Equals: N = "=="
final val Ge: N = ">="
final val Gt: N = ">"
final val Le: N = "<="
final val Lt: N = "<"
final val Max: N = "Max"
final val Min: N = "Min"
final val Minus: N = "-"
final val Mod: N = "%"
final val Negate: N = "Negate"
final val Not: N = "!"
final val NotEquals: N = "!="
final val Or: N = "||"
final val Plus: N = "+"
final val S: N = "S"
final val Times: N = "*"
final val ToString: N = "ToString"
final val Xor: N = "^"

final val ClassfileAnnotation: N = "ClassfileAnnotation"
final val ClassManifest: N = "ClassManifest"
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3842,6 +3842,12 @@ object Types {
case tpnme.Gt if nArgs == 2 => constantFold2(intValue, _ > _)
case tpnme.Ge if nArgs == 2 => constantFold2(intValue, _ >= _)
case tpnme.Le if nArgs == 2 => constantFold2(intValue, _ <= _)
case tpnme.Xor if nArgs == 2 => constantFold2(intValue, _ ^ _)
case tpnme.BitwiseAnd if nArgs == 2 => constantFold2(intValue, _ & _)
case tpnme.BitwiseOr if nArgs == 2 => constantFold2(intValue, _ | _)
case tpnme.ASR if nArgs == 2 => constantFold2(intValue, _ >> _)
case tpnme.LSL if nArgs == 2 => constantFold2(intValue, _ << _)
case tpnme.LSR if nArgs == 2 => constantFold2(intValue, _ >>> _)
case tpnme.Min if nArgs == 2 => constantFold2(intValue, _ min _)
case tpnme.Max if nArgs == 2 => constantFold2(intValue, _ max _)
case _ => None
Expand Down
47 changes: 45 additions & 2 deletions library/src/scala/compiletime/ops/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,40 @@ package object ops {
@infix type /[X <: Int, Y <: Int] <: Int

/** Remainder of the division of `X` by `Y`.
* ```scala
* ```scala
* val mod: 5 % 2 = 1
* ```
*/
@infix type %[X <: Int, Y <: Int] <: Int

/** Binary left shift of `X` by `Y`.
* ```scala
* val lshift: 1 << 2 = 4
* ```
*/
@infix type <<[X <: Int, Y <: Int] <: Int

/** Binary right shift of `X` by `Y`.
* ```scala
* val rshift: 10 >> 1 = 5
* ```
*/
@infix type >>[X <: Int, Y <: Int] <: Int

/** Binary right shift of `X` by `Y`, filling the left with zeros.
* ```scala
* val rshiftzero: 10 >>> 1 = 5
* ```
*/
@infix type >>>[X <: Int, Y <: Int] <: Int

/** Bitwise xor of `X` and `Y`.
* ```scala
* val xor: 10 ^ 30 = 20
* ```
*/
@infix type ^[X <: Int, Y <: Int] <: Int

/** Less-than comparison of two `Int` singleton types.
* ```scala
* val lt1: 4 < 2 = false
Expand Down Expand Up @@ -100,6 +128,21 @@ package object ops {
*/
@infix type <=[X <: Int, Y <: Int] <: Boolean

/** Bitwise and of `X` and `Y`.
* ```scala
* val and1: BitwiseAnd[4, 4] = 4
* val and2: BitwiseAnd[10, 5] = 0
* ```
*/
type BitwiseAnd[X <: Int, Y <: Int] <: Int

/** Bitwise or of `X` and `Y`.
* ```scala
* val or: BitwiseOr[10, 11] = 11
* ```
*/
type BitwiseOr[X <: Int, Y <: Int] <: Int

/** Absolute value of an `Int` singleton type.
* ```scala
* val abs: Abs[-1] = 1
Expand All @@ -124,7 +167,7 @@ package object ops {

/** Maximum of two `Int` singleton types.
* ```scala
* val abs: Abs[-1] = 1
* val max: Max[-1, 1] = 1
* ```
*/
type Max[X <: Int, Y <: Int] <: Int
Expand Down
30 changes: 30 additions & 0 deletions tests/neg/singleton-ops-int.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,34 @@ object Test {
val t49: ToString[-1] = "-1"
val t50: ToString[0] = "-0" // error
val t51: ToString[200] = "100" // error

val t52: 1 ^ 2 = 3
val t53: 1 ^ 3 = 3 // error
val t54: -1 ^ -2 = 1
val t55: -1 ^ -3 = 1 // error

val t56: BitwiseOr[1, 2] = 3
val t57: BitwiseOr[10, 12] = 13 // error
val t58: BitwiseOr[-11, 12] = -3
val t59: BitwiseOr[-111, -10] = 0 // error

val t60: BitwiseAnd[1, 1] = 1
val t61: BitwiseAnd[1, 2] = 0
val t62: BitwiseAnd[-1, -3] = 3 // error
val t63: BitwiseAnd[-1, -1] = 1 // error

val t64: 1 << 1 = 2
val t65: 1 << 2 = 4
val t66: 1 << 3 = 8
val t67: 1 << 4 = 0 // error

val t68: 100 >> 2 = 25
val t69: 123456789 >> 71 = 964506
val t70: -7 >> 3 = -1
val t71: -7 >> 3 = 0 // error

val t72: -1 >>> 10000 = 65535
val t73: -7 >>> 3 = 536870911
val t74: -7 >>> 3 = -1 // error

}
3 changes: 2 additions & 1 deletion tests/pos/singleton-ops-composition.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import scala.compiletime.ops.boolean._
import scala.compiletime.ops.int._
import scala.compiletime.ops.int.{^ => ^^,_} // must rename int.^ or get clash with boolean.^

object Test {
val t0: 1 + 2 * 3 = 7
val t1: (2 * 7 + 1) % 10 = 5
val t3: 1 * 1 + 2 * 2 + 3 * 3 + 4 * 4 = 30
val t4: true && false || true && true || false ^ false = true
val t5: BitwiseOr[100 << 2 >>> 2 >> 2 ^^ 3, BitwiseAnd[7, 7]] = 31
}
84 changes: 84 additions & 0 deletions tests/run/singleton-ops-flags.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package example {

import compiletime.S
import compiletime.ops.int.<<

object TastyFlags:

final val EmptyFlags = baseFlags
final val Erased = EmptyFlags.next
final val Internal = Erased.next
final val Inline = Internal.next
final val InlineProxy = Inline.next
final val Opaque = InlineProxy.next
final val Scala2x = Opaque.next
final val Extension = Scala2x.next
final val Given = Extension.next
final val Exported = Given.next
final val NoInits = Exported.next
final val TastyMacro = NoInits.next
final val Enum = TastyMacro.next
final val Open = Enum.next

type LastFlag = Open.idx.type

def (s: FlagSet).debug: String =
if s == EmptyFlags then "EmptyFlags"
else s.toSingletonSets[LastFlag].map ( [n <: Int] => (flag: SingletonFlagSet[n]) => flag match {
case Erased => "Erased"
case Internal => "Internal"
case Inline => "Inline"
case InlineProxy => "InlineProxy"
case Opaque => "Opaque"
case Scala2x => "Scala2x"
case Extension => "Extension"
case Given => "Given"
case Exported => "Exported"
case NoInits => "NoInits"
case TastyMacro => "TastyMacro"
case Enum => "Enum"
case Open => "Open"
}) mkString(" | ")

object opaques:

opaque type FlagSet = Int
opaque type EmptyFlagSet <: FlagSet = 0
opaque type SingletonFlagSet[N <: Int] <: FlagSet = 1 << N

opaque type SingletonSets[N <: Int] = Int

private def [N <: Int](n: N).shift: 1 << N = ( 1 << n ).asInstanceOf
private def [N <: Int](n: N).succ : S[N] = ( n + 1 ).asInstanceOf

final val baseFlags: EmptyFlagSet = 0

def (s: EmptyFlagSet).next: SingletonFlagSet[0] = 1
def [N <: Int: ValueOf](s: SingletonFlagSet[N]).next: SingletonFlagSet[S[N]] = valueOf[N].succ.shift
def [N <: Int: ValueOf](s: SingletonFlagSet[N]).idx: N = valueOf[N]
def [N <: Int](s: FlagSet).toSingletonSets: SingletonSets[N] = s
def (s: FlagSet) | (t: FlagSet): FlagSet = s | t

def [A, N <: Int: ValueOf](ss: SingletonSets[N]).map(f: [t <: Int] => (s: SingletonFlagSet[t]) => A): List[A] =
val maxFlag = valueOf[N]
val buf = List.newBuilder[A]
var current = 0
while (current <= maxFlag) {
val flag = current.shift
if ((flag & ss) != 0) {
buf += f(flag)
}
current += 1
}
buf.result

end opaques

export opaques._

}


import example.TastyFlags._

@main def Test = assert((Open | Given | Inline | Erased).debug == "Erased | Inline | Given | Open")