Skip to content

Commit 5e62091

Browse files
committed
Simplify TypedFormatChecker
1 parent e65c5e6 commit 5e62091

File tree

3 files changed

+81
-203
lines changed

3 files changed

+81
-203
lines changed

compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala

Lines changed: 78 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,57 @@ package dotty.tools.dotc
22
package transform.localopt
33

44
import scala.annotation.tailrec
5-
import scala.collection.mutable.{ListBuffer, Stack}
6-
import scala.reflect.{ClassTag, classTag}
5+
import scala.collection.mutable.ListBuffer
76
import scala.util.chaining.*
87
import scala.util.matching.Regex.Match
98

109
import java.util.{Calendar, Date, Formattable}
1110

1211
import PartialFunction.cond
1312

13+
import dotty.tools.dotc.ast.tpd.{Match => _, *}
14+
import dotty.tools.dotc.core.Contexts._
15+
import dotty.tools.dotc.core.Symbols._
16+
import dotty.tools.dotc.core.Types._
17+
import dotty.tools.dotc.core.Phases.typerPhase
18+
1419
/** Formatter string checker. */
15-
abstract class FormatChecker(using reporter: InterpolationReporter):
20+
class TypedFormatChecker(args: List[Tree])(using Context)(using reporter: InterpolationReporter):
21+
22+
val argTypes = args.map(_.tpe)
23+
val actuals = ListBuffer.empty[Tree]
24+
25+
// count of args, for checking indexes
26+
val argc = argTypes.length
1627

1728
// Pick the first runtime type which the i'th arg can satisfy.
1829
// If conversion is required, implementation must emit it.
19-
def argType(argi: Int, types: ClassTag[?]*): ClassTag[?]
30+
def argType(argi: Int, types: Type*): Type =
31+
require(argi < argc, s"$argi out of range picking from $types")
32+
val tpe = argTypes(argi)
33+
types.find(t => argConformsTo(argi, tpe, t))
34+
.orElse(types.find(t => argConvertsTo(argi, tpe, t)))
35+
.getOrElse {
36+
reporter.argError(s"Found: ${tpe.show}, Required: ${types.mkString(", ")}", argi)
37+
actuals += args(argi)
38+
types.head
39+
}
2040

21-
// count of args, for checking indexes
22-
def argc: Int
41+
object formattableTypes:
42+
val FormattableType = requiredClassRef("java.util.Formattable")
43+
val BigIntType = requiredClassRef("scala.math.BigInt")
44+
val BigDecimalType = requiredClassRef("scala.math.BigDecimal")
45+
val CalendarType = requiredClassRef("java.util.Calendar")
46+
val DateType = requiredClassRef("java.util.Date")
47+
import formattableTypes.*
48+
def argConformsTo(argi: Int, arg: Type, target: Type): Boolean = (arg <:< target).tap(if _ then actuals += args(argi))
49+
def argConvertsTo(argi: Int, arg: Type, target: Type): Boolean =
50+
import typer.Implicits.SearchSuccess
51+
atPhase(typerPhase) {
52+
ctx.typer.inferView(args(argi), target) match
53+
case SearchSuccess(view, ref, _, _) => actuals += view ; true
54+
case _ => false
55+
}
2356

2457
// match a conversion specifier
2558
val formatPattern = """%(?:(\d+)\$)?([-#+ 0,(<]+)?(\d+)?(\.\d+)?([tT]?[%a-zA-Z])?""".r
@@ -51,7 +84,7 @@ abstract class FormatChecker(using reporter: InterpolationReporter):
5184
def insertStringConversion(): Unit =
5285
amended += "%s" + part
5386
convert += Conversion(formatPattern.findAllMatchIn("%s").next(), n) // improve
54-
argType(n-1, classTag[Any])
87+
argType(n-1, defn.AnyType)
5588
def errorLeading(op: Conversion) = op.errorAt(Spec)(s"conversions must follow a splice; ${Conversion.literalHelp}")
5689
def accept(op: Conversion): Unit =
5790
if !op.isLeading then errorLeading(op)
@@ -66,11 +99,7 @@ abstract class FormatChecker(using reporter: InterpolationReporter):
6699
val cv = Conversion(matches.next(), n)
67100
if cv.isLiteral then insertStringConversion()
68101
else if cv.isIndexed then
69-
if cv.index.getOrElse(-1) == n then accept(cv)
70-
else
71-
// either some other arg num, or '<'
72-
//c.warning(op.groupPos(Index), "Index is not this arg")
73-
insertStringConversion()
102+
if cv.index.getOrElse(-1) == n then accept(cv) else insertStringConversion()
74103
else if !cv.isError then accept(cv)
75104

76105
// any remaining conversions in this part must be either literals or indexed
@@ -128,12 +157,26 @@ abstract class FormatChecker(using reporter: InterpolationReporter):
128157
// descriptor is at index 0 of the part string
129158
def isLeading: Boolean = descriptor.at(Spec) == 0
130159

131-
// flags and index in specifier are ok
132-
private def goodies = goodFlags && goodIndex
133-
134-
// true if passes. Default checks flags and index
160+
// true if passes.
135161
def verify: Boolean =
136-
kind match {
162+
// various assertions
163+
def goodies = goodFlags && goodIndex
164+
def noFlags = flags.isEmpty or errorAt(Flags)("flags not allowed")
165+
def noWidth = width.isEmpty or errorAt(Width)("width not allowed")
166+
def noPrecision = precision.isEmpty or errorAt(Precision)("precision not allowed")
167+
def only_-(msg: String) =
168+
val badFlags = flags.filterNot { case '-' | '<' => true case _ => false }
169+
badFlags.isEmpty or badFlag(badFlags(0), s"Only '-' allowed for $msg")
170+
def goodFlags =
171+
val badFlags = flags.filterNot(okFlags.contains)
172+
for f <- badFlags do badFlag(f, s"Illegal flag '$f'")
173+
badFlags.isEmpty
174+
def goodIndex =
175+
if index.nonEmpty && hasFlag('<') then warningAt(Index)("Argument index ignored if '<' flag is present")
176+
val okRange = index.map(i => i > 0 && i <= argc).getOrElse(true)
177+
okRange || hasFlag('<') or errorAt(Index)("Argument index out of range")
178+
// begin verify
179+
kind match
137180
case StringXn => goodies
138181
case BooleanXn => goodies
139182
case HashXn => goodies
@@ -143,58 +186,55 @@ abstract class FormatChecker(using reporter: InterpolationReporter):
143186
def x_comma = cc != 'd' && hasFlag(',') and badFlag(',', "',' only allowed for d conversion of integral types")
144187
goodies && noPrecision && !d_# && !x_comma
145188
case FloatingPointXn =>
146-
goodies && (cc match {
189+
goodies && (cc match
147190
case 'a' | 'A' =>
148191
val badFlags = ",(".filter(hasFlag)
149192
noPrecision && badFlags.isEmpty or badFlags.foreach(badf => badFlag(badf, s"'$badf' not allowed for a, A"))
150193
case _ => true
151-
})
194+
)
152195
case DateTimeXn =>
153196
def hasCC = op.length == 2 or errorAt(CC)("Date/time conversion must have two characters")
154197
def goodCC = "HIklMSLNpzZsQBbhAaCYyjmdeRTrDFc".contains(cc) or errorAt(CC, 1)(s"'$cc' doesn't seem to be a date or time conversion")
155198
goodies && hasCC && goodCC && noPrecision && only_-("date/time conversions")
156199
case LiteralXn =>
157-
op match {
200+
op match
158201
case "%" => goodies && noPrecision and width.foreach(_ => warningAt(Width)("width ignored on literal"))
159202
case "n" => noFlags && noWidth && noPrecision
160-
}
161203
case ErrorXn =>
162204
errorAt(CC)(s"illegal conversion character '$cc'")
163205
false
164-
}
206+
end verify
165207

166208
// is the specifier OK with the given arg
167-
def accepts(arg: ClassTag[?]): Boolean =
209+
def accepts(arg: Type): Boolean =
168210
kind match
169-
case BooleanXn => arg == classTag[Boolean] orElse warningAt(CC)("Boolean format is null test for non-Boolean")
211+
case BooleanXn => arg == defn.BooleanType orElse warningAt(CC)("Boolean format is null test for non-Boolean")
170212
case IntegralXn =>
171-
arg == classTag[BigInt] || !cond(cc) {
213+
arg == BigIntType || !cond(cc) {
172214
case 'o' | 'x' | 'X' if hasAnyFlag("+ (") => "+ (".filter(hasFlag).foreach(bad => badFlag(bad, s"only use '$bad' for BigInt conversions to o, x, X")) ; true
173215
}
174216
case _ => true
175217

176218
// what arg type if any does the conversion accept
177-
def acceptableVariants: List[ClassTag[?]] =
178-
kind match {
179-
case StringXn => if hasFlag('#') then classTag[Formattable] :: Nil else classTag[Any] :: Nil
180-
case BooleanXn => classTag[Boolean] :: Conversion.FakeNullTag :: Nil
181-
case HashXn => classTag[Any] :: Nil
182-
case CharacterXn => classTag[Char] :: classTag[Byte] :: classTag[Short] :: classTag[Int] :: Nil
183-
case IntegralXn => classTag[Int] :: classTag[Long] :: classTag[Byte] :: classTag[Short] :: classTag[BigInt] :: Nil
184-
case FloatingPointXn => classTag[Double] :: classTag[Float] :: classTag[BigDecimal] :: Nil
185-
case DateTimeXn => classTag[Long] :: classTag[Calendar] :: classTag[Date] :: Nil
219+
def acceptableVariants: List[Type] =
220+
kind match
221+
case StringXn => if hasFlag('#') then FormattableType :: Nil else defn.AnyType :: Nil
222+
case BooleanXn => defn.BooleanType :: defn.NullType :: Nil
223+
case HashXn => defn.AnyType :: Nil
224+
case CharacterXn => defn.CharType :: defn.ByteType :: defn.ShortType :: defn.IntType :: Nil
225+
case IntegralXn => defn.IntType :: defn.LongType :: defn.ByteType :: defn.ShortType :: BigIntType :: Nil
226+
case FloatingPointXn => defn.DoubleType :: defn.FloatType :: BigDecimalType :: Nil
227+
case DateTimeXn => defn.LongType :: CalendarType :: DateType :: Nil
186228
case LiteralXn => Nil
187229
case ErrorXn => Nil
188-
}
189230

190231
// what flags does the conversion accept?
191232
private def okFlags: String =
192-
kind match {
233+
kind match
193234
case StringXn => "-#<"
194235
case BooleanXn | HashXn => "-<"
195236
case LiteralXn => "-"
196237
case _ => "-#+ 0,(<"
197-
}
198238

199239
def hasFlag(f: Char) = flags.contains(f)
200240
def hasAnyFlag(fs: String) = fs.exists(hasFlag)
@@ -206,21 +246,6 @@ abstract class FormatChecker(using reporter: InterpolationReporter):
206246
def errorAt(g: SpecGroup, i: Int = 0)(msg: String) = reporter.partError(msg, argi, descriptor.offset(g, i))
207247
def warningAt(g: SpecGroup, i: Int = 0)(msg: String) = reporter.partWarning(msg, argi, descriptor.offset(g, i))
208248

209-
// various assertions
210-
def noFlags = flags.isEmpty or errorAt(Flags)("flags not allowed")
211-
def noWidth = width.isEmpty or errorAt(Width)("width not allowed")
212-
def noPrecision = precision.isEmpty or errorAt(Precision)("precision not allowed")
213-
def only_-(msg: String) =
214-
val badFlags = flags.filterNot { case '-' | '<' => true case _ => false }
215-
badFlags.isEmpty or badFlag(badFlags(0), s"Only '-' allowed for $msg")
216-
def goodFlags =
217-
val badFlags = flags.filterNot(okFlags.contains)
218-
for f <- badFlags do badFlag(f, s"Illegal flag '$f'")
219-
badFlags.isEmpty
220-
def goodIndex =
221-
if index.nonEmpty && hasFlag('<') then warningAt(Index)("Argument index ignored if '<' flag is present")
222-
val okRange = index.map(i => i > 0 && i <= argc).getOrElse(true)
223-
okRange || hasFlag('<') or errorAt(Index)("Argument index out of range")
224249
object Conversion:
225250
def apply(m: Match, i: Int): Conversion =
226251
def kindOf(cc: Char) = cc match
@@ -243,5 +268,5 @@ abstract class FormatChecker(using reporter: InterpolationReporter):
243268
case None => new Conversion(m, i, ErrorXn).tap(_.errorAt(Spec)(s"Missing conversion operator in '${m.matched}'; $literalHelp"))
244269
end apply
245270
val literalHelp = "use %% for literal %, %n for newline"
246-
private val FakeNullTag: ClassTag[?] = null
247271
end Conversion
272+
end TypedFormatChecker

compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala

Lines changed: 3 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,11 @@
11
package dotty.tools.dotc
22
package transform.localopt
33

4-
import dotty.tools.dotc.ast.Trees._
5-
import dotty.tools.dotc.ast.tpd
6-
import dotty.tools.dotc.core.Decorators._
4+
import dotty.tools.dotc.ast.tpd.*
75
import dotty.tools.dotc.core.Constants.Constant
8-
import dotty.tools.dotc.core.Contexts._
9-
import dotty.tools.dotc.core.StdNames._
10-
import dotty.tools.dotc.core.NameKinds._
11-
import dotty.tools.dotc.core.Symbols._
12-
import dotty.tools.dotc.core.Types._
13-
import dotty.tools.dotc.core.Phases.typerPhase
14-
import dotty.tools.dotc.typer.ProtoTypes._
15-
16-
import scala.StringContext.processEscapes
17-
import scala.annotation.tailrec
18-
import scala.collection.mutable.{ListBuffer, Stack}
19-
import scala.reflect.{ClassTag, classTag}
20-
import scala.util.chaining._
21-
import scala.util.matching.Regex.Match
6+
import dotty.tools.dotc.core.Contexts.*
227

238
object FormatInterpolatorTransform:
24-
import tpd._
259

2610
class PartsReporter(fun: Tree, args0: Tree, parts: List[Tree], args: List[Tree])(using Context) extends InterpolationReporter:
2711
private var reported = false
@@ -50,67 +34,6 @@ object FormatInterpolatorTransform:
5034
reported = false
5135
def restoreReported(): Unit = reported = oldReported
5236
end PartsReporter
53-
object tags:
54-
import java.util.{Calendar, Date, Formattable}
55-
val StringTag = classTag[String]
56-
val FormattableTag = classTag[Formattable]
57-
val BigIntTag = classTag[BigInt]
58-
val BigDecimalTag = classTag[BigDecimal]
59-
val CalendarTag = classTag[Calendar]
60-
val DateTag = classTag[Date]
61-
class FormattableTypes(using Context):
62-
val FormattableType = requiredClassRef("java.util.Formattable")
63-
val BigIntType = requiredClassRef("scala.math.BigInt")
64-
val BigDecimalType = requiredClassRef("scala.math.BigDecimal")
65-
val CalendarType = requiredClassRef("java.util.Calendar")
66-
val DateType = requiredClassRef("java.util.Date")
67-
class TypedFormatChecker(val args: List[Tree])(using Context, InterpolationReporter) extends FormatChecker:
68-
val reporter = summon[InterpolationReporter]
69-
val argTypes = args.map(_.tpe)
70-
val actuals = ListBuffer.empty[Tree]
71-
val argc = argTypes.length
72-
def argType(argi: Int, types: Seq[ClassTag[?]]) =
73-
require(argi < argc, s"$argi out of range picking from $types")
74-
val tpe = argTypes(argi)
75-
types.find(t => argConformsTo(argi, tpe, argTypeOf(t)))
76-
.orElse(types.find(t => argConvertsTo(argi, tpe, argTypeOf(t))))
77-
.getOrElse {
78-
reporter.argError(s"Found: ${tpe.show}, Required: ${types.mkString(", ")}", argi)
79-
actuals += args(argi)
80-
types.head
81-
}
82-
final lazy val fmtTypes = FormattableTypes()
83-
import tags.*, fmtTypes.*
84-
def argConformsTo(argi: Int, arg: Type, target: Type): Boolean =
85-
(arg <:< target).tap(if _ then actuals += args(argi))
86-
def argConvertsTo(argi: Int, arg: Type, target: Type): Boolean =
87-
import typer.Implicits.SearchSuccess
88-
atPhase(typerPhase) {
89-
ctx.typer.inferView(args(argi), target) match
90-
case SearchSuccess(view, ref, _, _) => actuals += view ; true
91-
case _ => false
92-
}
93-
def argTypeOf(tag: ClassTag[?]): Type = tag match
94-
case StringTag => defn.StringType
95-
case ClassTag.Boolean => defn.BooleanType
96-
case ClassTag.Byte => defn.ByteType
97-
case ClassTag.Char => defn.CharType
98-
case ClassTag.Short => defn.ShortType
99-
case ClassTag.Int => defn.IntType
100-
case ClassTag.Long => defn.LongType
101-
case ClassTag.Float => defn.FloatType
102-
case ClassTag.Double => defn.DoubleType
103-
case ClassTag.Any => defn.AnyType
104-
case ClassTag.AnyRef => defn.AnyRefType
105-
case FormattableTag => FormattableType
106-
case BigIntTag => BigIntType
107-
case BigDecimalTag => BigDecimalType
108-
case CalendarTag => CalendarType
109-
case DateTag => DateType
110-
case null => defn.NullType
111-
case _ => reporter.strCtxError(s"Unknown type for format $tag")
112-
defn.AnyType
113-
end TypedFormatChecker
11437

11538
/** For f"${arg}%xpart", check format conversions and return (format, args)
11639
* suitable for String.format(format, args).
@@ -146,7 +69,7 @@ object FormatInterpolatorTransform:
14669
if reporter.hasReported then (literally(parts.mkString), args0)
14770
else
14871
assert(checker.argc == checker.actuals.size, s"Expected ${checker.argc}, actuals size is ${checker.actuals.size} for [${parts.mkString(", ")}]")
149-
(literally(checked.mkString), tpd.SeqLiteral(checker.actuals.toList, elemtpt))
72+
(literally(checked.mkString), SeqLiteral(checker.actuals.toList, elemtpt))
15073
end checked
15174
end FormatInterpolatorTransform
15275

0 commit comments

Comments
 (0)