|
| 1 | +/* |
| 2 | + * Scala (https://www.scala-lang.org) |
| 3 | + * |
| 4 | + * Copyright EPFL and Lightbend, Inc. |
| 5 | + * |
| 6 | + * Licensed under Apache License 2.0 |
| 7 | + * (http://www.apache.org/licenses/LICENSE-2.0). |
| 8 | + * |
| 9 | + * See the NOTICE file distributed with this work for |
| 10 | + * additional information regarding copyright ownership. |
| 11 | + */ |
| 12 | + |
| 13 | +import scala.collection.mutable |
| 14 | + |
| 15 | +object WrapFnGen { |
| 16 | + /** all 43 interfaces in java.util.function package */ |
| 17 | + private lazy val allJfn = Seq( |
| 18 | + "BiConsumer[T, U]: accept(T, U): Unit", |
| 19 | + "BiFunction[T, U, R]: apply(T, U): R", |
| 20 | + "BiPredicate[T, U]: test(T, U): Boolean", |
| 21 | + "BinaryOperator[T]: apply(T, T): T", |
| 22 | + "BooleanSupplier: getAsBoolean: Boolean", |
| 23 | + "Consumer[T]: accept(T): Unit", |
| 24 | + "DoubleBinaryOperator: applyAsDouble(Double, Double): Double", |
| 25 | + "DoubleConsumer: accept(Double): Unit", |
| 26 | + "DoubleFunction[R]: apply(Double): R", |
| 27 | + "DoublePredicate: test(Double): Boolean", |
| 28 | + "DoubleSupplier: getAsDouble: Double", |
| 29 | + "DoubleToIntFunction: applyAsInt(Double): Int", |
| 30 | + "DoubleToLongFunction: applyAsLong(Double): Long", |
| 31 | + "DoubleUnaryOperator: applyAsDouble(Double): Double", |
| 32 | + "Function[T, R]: apply(T): R", |
| 33 | + "IntBinaryOperator: applyAsInt(Int, Int): Int", |
| 34 | + "IntConsumer: accept(Int): Unit", |
| 35 | + "IntFunction[R]: apply(Int): R", |
| 36 | + "IntPredicate: test(Int): Boolean", |
| 37 | + "IntSupplier: getAsInt: Int", |
| 38 | + "IntToDoubleFunction: applyAsDouble(Int): Double", |
| 39 | + "IntToLongFunction: applyAsLong(Int): Long", |
| 40 | + "IntUnaryOperator: applyAsInt(Int): Int", |
| 41 | + "LongBinaryOperator: applyAsLong(Long, Long): Long", |
| 42 | + "LongConsumer: accept(Long): Unit", |
| 43 | + "LongFunction[R]: apply(Long): R", |
| 44 | + "LongPredicate: test(Long): Boolean", |
| 45 | + "LongSupplier: getAsLong: Long", |
| 46 | + "LongToDoubleFunction: applyAsDouble(Long): Double", |
| 47 | + "LongToIntFunction: applyAsInt(Long): Int", |
| 48 | + "LongUnaryOperator: applyAsLong(Long): Long", |
| 49 | + "ObjDoubleConsumer[T]: accept(T, Double): Unit", |
| 50 | + "ObjIntConsumer[T]: accept(T, Int): Unit", |
| 51 | + "ObjLongConsumer[T]: accept(T, Long): Unit", |
| 52 | + "Predicate[T]: test(T): Boolean", |
| 53 | + "Supplier[T]: get: T", |
| 54 | + "ToDoubleBiFunction[T, U]: applyAsDouble(T, U): Double", |
| 55 | + "ToDoubleFunction[T]: applyAsDouble(T): Double", |
| 56 | + "ToIntBiFunction[T, U]: applyAsInt(T, U): Int", |
| 57 | + "ToIntFunction[T]: applyAsInt(T): Int", |
| 58 | + "ToLongBiFunction[T, U]: applyAsLong(T, U): Long", |
| 59 | + "ToLongFunction[T]: applyAsLong(T): Long", |
| 60 | + "UnaryOperator[T]: apply(T): T", |
| 61 | + ).map(Jfn.apply) |
| 62 | + |
| 63 | + /** @param sig - ex: "BiConsumer[T,U]: accept(T,U): Unit" |
| 64 | + * or "DoubleToIntFunction: applyAsInt(Double): Int" */ |
| 65 | + case class Jfn(sig: String) { |
| 66 | + val Array( |
| 67 | + iface, // interface name included type args, ex: BiConsumer[T,U] | DoubleToIntFunction |
| 68 | + _method, // Temp val, ex: accept(T,U) | applyAsInt(Double) |
| 69 | + rType // java function return type, ex: Unit | Int |
| 70 | + ) = sig.split(':').map(_.trim) |
| 71 | + |
| 72 | + // interface name and java interface's type args, |
| 73 | + // ex: ("BiConsumer", "[T,U]") | ("DoubleToIntFunction", "") |
| 74 | + val (ifaceName, jtargs) = iface.span(_ != '[') |
| 75 | + |
| 76 | + // java method name and temp val, ex: "accept" -> "(T,U)" | "applyAsInt" -> "(Double)" |
| 77 | + val (jmethod, _targs) = _method.span(_ != '(') |
| 78 | + |
| 79 | + // java method's type args, ex: Seq("T", "U") | Seq("Double") |
| 80 | + val pTypes: Seq[String] = _targs.unwrapMe |
| 81 | + |
| 82 | + // arguments names, ex: Seq("x1", "x2") |
| 83 | + val args: Seq[String] = pTypes.indices.map { i => "x" + (i+1) } |
| 84 | + // ex: "(x1: T, x2: U)" | "(x1: Double)" |
| 85 | + val argsDecl: String = args.zip(pTypes).map { |
| 86 | + // Don't really need this case. Only here so the generated code is |
| 87 | + // exactly == the code gen by the old method using scala-compiler + scala-reflect |
| 88 | + case (p, t @ ("Double"|"Long"|"Int")) => s"$p: scala.$t" |
| 89 | + case (p, t) => s"$p: $t" |
| 90 | + }.mkString("(", ", ", ")") |
| 91 | + // ex: "(x1, x2)" |
| 92 | + val argsCall: String = args.mkString("(", ", ", ")") |
| 93 | + |
| 94 | + // arity of scala.Function |
| 95 | + val arity: Int = args.length |
| 96 | + |
| 97 | + // ex: "java.util.function.BiConsumer[T,U]" | "java.util.function.DoubleToIntFunction" |
| 98 | + val javaFn = s"java.util.function.$iface" |
| 99 | + |
| 100 | + // ex: "scala.Function2[T, U, Unit]" | "scala.Function1[Double, Int]" |
| 101 | + val scalaFn = s"scala.Function$arity[${(pTypes :+ rType).mkString(", ")}]" |
| 102 | + |
| 103 | + def fromJavaCls: String = |
| 104 | + s"""class FromJava$iface(jf: $javaFn) extends $scalaFn { |
| 105 | + | def apply$argsDecl = jf.$jmethod$argsCall |
| 106 | + |}""".stripMargin |
| 107 | + |
| 108 | + val richAsFnClsName = s"Rich${ifaceName}AsFunction$arity$jtargs" |
| 109 | + def richAsFnCls: String = |
| 110 | + s"""class $richAsFnClsName(private val underlying: $javaFn) extends AnyVal { |
| 111 | + | @inline def asScala: $scalaFn = new FromJava$iface(underlying) |
| 112 | + |}""".stripMargin |
| 113 | + |
| 114 | + def asJavaCls: String = |
| 115 | + s"""class AsJava$iface(sf: $scalaFn) extends $javaFn { |
| 116 | + | def $jmethod$argsDecl = sf.apply$argsCall |
| 117 | + |}""".stripMargin |
| 118 | + |
| 119 | + val richFnAsClsName = s"RichFunction${arity}As$iface" |
| 120 | + def richFnAsCls: String = |
| 121 | + s"""class $richFnAsClsName(private val underlying: $scalaFn) extends AnyVal { |
| 122 | + | @inline def asJava: $javaFn = new AsJava$iface(underlying) |
| 123 | + |}""".stripMargin |
| 124 | + |
| 125 | + def converterImpls: String = |
| 126 | + s"""$fromJavaCls\n |
| 127 | + |$richAsFnCls\n |
| 128 | + |$asJavaCls\n |
| 129 | + |$richFnAsCls\n |
| 130 | + |""".stripMargin |
| 131 | + |
| 132 | + /** @return "implicit def enrichAsJavaXX.." code */ |
| 133 | + def enrichAsJavaDef: String = { |
| 134 | + // This is especially tricky because functions are contravariant in their arguments |
| 135 | + // Need to prevent e.g. Any => String from "downcasting" itself to Int => String; we want the more exact conversion |
| 136 | + // Instead of foo[A](f: (Int, A) => Long): Fuu[A] = new Foo[A](f) |
| 137 | + // we want foo[X, A](f: (X, A) => Long)(implicit evX: Int =:= X): Fuu[A] = new Foo[A](f.asInstanceOf[(Int, A) => Long]) |
| 138 | + // Instead of bar[A](f: A => A): Brr[A] = new Foo[A](f) |
| 139 | + // we want bar[A, B](f: A => B)(implicit evB: A =:= B): Brr[A] = new Foo[A](f.asInstanceOf[A => B]) |
| 140 | + |
| 141 | + val finalTypes = Set("Double", "Long", "Int", "Boolean", "Unit") |
| 142 | + val An = "A(\\d+)".r |
| 143 | + val numberedA = mutable.Set.empty[Int] |
| 144 | + val evidences = mutable.ArrayBuffer.empty[(String, String)] // ex: "A0" -> "Double" |
| 145 | + numberedA ++= pTypes.collect{ case An(digits) if (digits.length < 10) => digits.toInt } |
| 146 | + val scalafnTnames = (pTypes :+ rType).zipWithIndex.map { |
| 147 | + case (pt, i) if i < pTypes.length && finalTypes(pt) || !finalTypes(pt) && pTypes.take(i).contains(pt) => |
| 148 | + val j = Iterator.from(i).dropWhile(numberedA).next() |
| 149 | + val genericName = s"A$j" |
| 150 | + numberedA += j |
| 151 | + evidences += (genericName -> pt) |
| 152 | + genericName |
| 153 | + case (pt, _) => pt |
| 154 | + } |
| 155 | + val scalafnTdefs = scalafnTnames.dropRight(if (finalTypes(rType)) 1 else 0).wrapMe() |
| 156 | + val scalaFnGeneric = s"scala.Function${scalafnTnames.length - 1}[${scalafnTnames.mkString(", ")}]" |
| 157 | + val evs = evidences |
| 158 | + .map { case (generic, specific) => s"ev$generic: =:=[$generic, $specific]" } |
| 159 | + .wrapMe("(implicit ", ")") |
| 160 | + val sf = if (evs.isEmpty) "sf" else s"sf.asInstanceOf[$scalaFn]" |
| 161 | + s"@inline implicit def enrichAsJava$ifaceName$scalafnTdefs(sf: $scalaFnGeneric)$evs: $richFnAsClsName = new $richFnAsClsName($sf)" |
| 162 | + } |
| 163 | + |
| 164 | + def asScalaFromDef = s"@inline def asScalaFrom$iface(jf: $javaFn): $scalaFn = new FromJava$iface(jf)" |
| 165 | + |
| 166 | + def asJavaDef = s"@inline def asJava$iface(sf: $scalaFn): $javaFn = new AsJava$iface(sf)" |
| 167 | + |
| 168 | + def enrichAsScalaDef = s"@inline implicit def enrichAsScalaFrom$iface(jf: $javaFn): $richAsFnClsName = new $richAsFnClsName(jf)" |
| 169 | + } |
| 170 | + |
| 171 | + def converters: String = { |
| 172 | + val groups = allJfn |
| 173 | + .map(jfn => jfn.jtargs.unwrapMe.length -> jfn.enrichAsJavaDef) |
| 174 | + .groupBy(_._1) |
| 175 | + .toSeq |
| 176 | + .sortBy(_._1) |
| 177 | + .reverse |
| 178 | + val maxPriority = groups.head._1 |
| 179 | + groups.map { case (priority, seq) => |
| 180 | + val parent = |
| 181 | + if (priority == maxPriority) "" |
| 182 | + else s" extends Priority${priority + 1}FunctionConverters" |
| 183 | + val me = |
| 184 | + if (priority == 0) "package object FunctionConverters" |
| 185 | + else s"trait Priority${priority}FunctionConverters" |
| 186 | + |
| 187 | + val enrichAsJava = seq.map(_._2) |
| 188 | + val (asXx, enrichAsScala) = |
| 189 | + if (priority != 0) Nil -> Nil |
| 190 | + else allJfn.map { jfn => jfn.asScalaFromDef + "\n\n" + jfn.asJavaDef } -> |
| 191 | + allJfn.map(_.enrichAsScalaDef) |
| 192 | + |
| 193 | + s"""$me$parent { |
| 194 | + | import functionConverterImpls._ |
| 195 | + |${asXx.mkString("\n\n\n").indentMe} |
| 196 | + |${enrichAsJava.mkString("\n\n").indentMe} |
| 197 | + |${enrichAsScala.mkString("\n\n").indentMe} |
| 198 | + |}""".stripMargin |
| 199 | + }.mkString("\n\n\n") |
| 200 | + } |
| 201 | + |
| 202 | + def code: String = |
| 203 | + s""" |
| 204 | + |/* |
| 205 | + | * Copyright EPFL and Lightbend, Inc. |
| 206 | + | * This file auto-generated by WrapFnGen.scala. Do not modify directly. |
| 207 | + | */ |
| 208 | + | |
| 209 | + |package scala.compat.java8 |
| 210 | + | |
| 211 | + |import language.implicitConversions |
| 212 | + | |
| 213 | + |package functionConverterImpls { |
| 214 | + |${allJfn.map(_.converterImpls).mkString("\n").indentMe} |
| 215 | + |} |
| 216 | + |\n |
| 217 | + |$converters |
| 218 | + |""".stripMargin |
| 219 | + |
| 220 | + implicit class StringExt(private val s: String) extends AnyVal { |
| 221 | + def indentMe: String = s.linesIterator.map(" " + _).mkString("\n") |
| 222 | + def unwrapMe: Seq[String] = s match { |
| 223 | + case "" => Nil |
| 224 | + case _ => s |
| 225 | + .substring(1, s.length - 1) // drop "(" and ")" or "[" and "]" |
| 226 | + .split(',').map(_.trim).toSeq |
| 227 | + } |
| 228 | + } |
| 229 | + |
| 230 | + implicit class WrapMe(private val s: Seq[String]) extends AnyVal { |
| 231 | + def wrapMe(start: String = "[", end: String = "]"): String = s match { |
| 232 | + case Nil => "" |
| 233 | + case _ => s.mkString(start, ", ", end) |
| 234 | + } |
| 235 | + } |
| 236 | +} |
0 commit comments