@@ -2,24 +2,57 @@ package dotty.tools.dotc
2
2
package transform .localopt
3
3
4
4
import scala .annotation .tailrec
5
- import scala .collection .mutable .{ListBuffer , Stack }
6
- import scala .reflect .{ClassTag , classTag }
5
+ import scala .collection .mutable .ListBuffer
7
6
import scala .util .chaining .*
8
7
import scala .util .matching .Regex .Match
9
8
10
9
import java .util .{Calendar , Date , Formattable }
11
10
12
11
import PartialFunction .cond
13
12
13
+ import dotty .tools .dotc .ast .tpd .{Match => _ , * }
14
+ import dotty .tools .dotc .core .Contexts ._
15
+ import dotty .tools .dotc .core .Symbols ._
16
+ import dotty .tools .dotc .core .Types ._
17
+ import dotty .tools .dotc .core .Phases .typerPhase
18
+
14
19
/** Formatter string checker. */
15
- abstract class FormatChecker (using reporter : InterpolationReporter ):
20
+ class TypedFormatChecker (args : List [Tree ])(using Context )(using reporter : InterpolationReporter ):
21
+
22
+ val argTypes = args.map(_.tpe)
23
+ val actuals = ListBuffer .empty[Tree ]
24
+
25
+ // count of args, for checking indexes
26
+ val argc = argTypes.length
16
27
17
28
// Pick the first runtime type which the i'th arg can satisfy.
18
29
// If conversion is required, implementation must emit it.
19
- def argType (argi : Int , types : ClassTag [? ]* ): ClassTag [? ]
30
+ def argType (argi : Int , types : Type * ): Type =
31
+ require(argi < argc, s " $argi out of range picking from $types" )
32
+ val tpe = argTypes(argi)
33
+ types.find(t => argConformsTo(argi, tpe, t))
34
+ .orElse(types.find(t => argConvertsTo(argi, tpe, t)))
35
+ .getOrElse {
36
+ reporter.argError(s " Found: ${tpe.show}, Required: ${types.mkString(" , " )}" , argi)
37
+ actuals += args(argi)
38
+ types.head
39
+ }
20
40
21
- // count of args, for checking indexes
22
- def argc : Int
41
+ object formattableTypes :
42
+ val FormattableType = requiredClassRef(" java.util.Formattable" )
43
+ val BigIntType = requiredClassRef(" scala.math.BigInt" )
44
+ val BigDecimalType = requiredClassRef(" scala.math.BigDecimal" )
45
+ val CalendarType = requiredClassRef(" java.util.Calendar" )
46
+ val DateType = requiredClassRef(" java.util.Date" )
47
+ import formattableTypes .*
48
+ def argConformsTo (argi : Int , arg : Type , target : Type ): Boolean = (arg <:< target).tap(if _ then actuals += args(argi))
49
+ def argConvertsTo (argi : Int , arg : Type , target : Type ): Boolean =
50
+ import typer .Implicits .SearchSuccess
51
+ atPhase(typerPhase) {
52
+ ctx.typer.inferView(args(argi), target) match
53
+ case SearchSuccess (view, ref, _, _) => actuals += view ; true
54
+ case _ => false
55
+ }
23
56
24
57
// match a conversion specifier
25
58
val formatPattern = """ %(?:(\d+)\$)?([-#+ 0,(<]+)?(\d+)?(\.\d+)?([tT]?[%a-zA-Z])?""" .r
@@ -51,7 +84,7 @@ abstract class FormatChecker(using reporter: InterpolationReporter):
51
84
def insertStringConversion (): Unit =
52
85
amended += " %s" + part
53
86
convert += Conversion (formatPattern.findAllMatchIn(" %s" ).next(), n) // improve
54
- argType(n- 1 , classTag[ Any ] )
87
+ argType(n- 1 , defn. AnyType )
55
88
def errorLeading (op : Conversion ) = op.errorAt(Spec )(s " conversions must follow a splice; ${Conversion .literalHelp}" )
56
89
def accept (op : Conversion ): Unit =
57
90
if ! op.isLeading then errorLeading(op)
@@ -66,11 +99,7 @@ abstract class FormatChecker(using reporter: InterpolationReporter):
66
99
val cv = Conversion (matches.next(), n)
67
100
if cv.isLiteral then insertStringConversion()
68
101
else if cv.isIndexed then
69
- if cv.index.getOrElse(- 1 ) == n then accept(cv)
70
- else
71
- // either some other arg num, or '<'
72
- // c.warning(op.groupPos(Index), "Index is not this arg")
73
- insertStringConversion()
102
+ if cv.index.getOrElse(- 1 ) == n then accept(cv) else insertStringConversion()
74
103
else if ! cv.isError then accept(cv)
75
104
76
105
// any remaining conversions in this part must be either literals or indexed
@@ -128,12 +157,26 @@ abstract class FormatChecker(using reporter: InterpolationReporter):
128
157
// descriptor is at index 0 of the part string
129
158
def isLeading : Boolean = descriptor.at(Spec ) == 0
130
159
131
- // flags and index in specifier are ok
132
- private def goodies = goodFlags && goodIndex
133
-
134
- // true if passes. Default checks flags and index
160
+ // true if passes.
135
161
def verify : Boolean =
136
- kind match {
162
+ // various assertions
163
+ def goodies = goodFlags && goodIndex
164
+ def noFlags = flags.isEmpty or errorAt(Flags )(" flags not allowed" )
165
+ def noWidth = width.isEmpty or errorAt(Width )(" width not allowed" )
166
+ def noPrecision = precision.isEmpty or errorAt(Precision )(" precision not allowed" )
167
+ def only_- (msg : String ) =
168
+ val badFlags = flags.filterNot { case '-' | '<' => true case _ => false }
169
+ badFlags.isEmpty or badFlag(badFlags(0 ), s " Only '-' allowed for $msg" )
170
+ def goodFlags =
171
+ val badFlags = flags.filterNot(okFlags.contains)
172
+ for f <- badFlags do badFlag(f, s " Illegal flag ' $f' " )
173
+ badFlags.isEmpty
174
+ def goodIndex =
175
+ if index.nonEmpty && hasFlag('<' ) then warningAt(Index )(" Argument index ignored if '<' flag is present" )
176
+ val okRange = index.map(i => i > 0 && i <= argc).getOrElse(true )
177
+ okRange || hasFlag('<' ) or errorAt(Index )(" Argument index out of range" )
178
+ // begin verify
179
+ kind match
137
180
case StringXn => goodies
138
181
case BooleanXn => goodies
139
182
case HashXn => goodies
@@ -143,58 +186,55 @@ abstract class FormatChecker(using reporter: InterpolationReporter):
143
186
def x_comma = cc != 'd' && hasFlag(',' ) and badFlag(',' , " ',' only allowed for d conversion of integral types" )
144
187
goodies && noPrecision && ! d_# && ! x_comma
145
188
case FloatingPointXn =>
146
- goodies && (cc match {
189
+ goodies && (cc match
147
190
case 'a' | 'A' =>
148
191
val badFlags = " ,(" .filter(hasFlag)
149
192
noPrecision && badFlags.isEmpty or badFlags.foreach(badf => badFlag(badf, s " ' $badf' not allowed for a, A " ))
150
193
case _ => true
151
- } )
194
+ )
152
195
case DateTimeXn =>
153
196
def hasCC = op.length == 2 or errorAt(CC )(" Date/time conversion must have two characters" )
154
197
def goodCC = " HIklMSLNpzZsQBbhAaCYyjmdeRTrDFc" .contains(cc) or errorAt(CC , 1 )(s " ' $cc' doesn't seem to be a date or time conversion " )
155
198
goodies && hasCC && goodCC && noPrecision && only_-(" date/time conversions" )
156
199
case LiteralXn =>
157
- op match {
200
+ op match
158
201
case " %" => goodies && noPrecision and width.foreach(_ => warningAt(Width )(" width ignored on literal" ))
159
202
case " n" => noFlags && noWidth && noPrecision
160
- }
161
203
case ErrorXn =>
162
204
errorAt(CC )(s " illegal conversion character ' $cc' " )
163
205
false
164
- }
206
+ end verify
165
207
166
208
// is the specifier OK with the given arg
167
- def accepts (arg : ClassTag [ ? ] ): Boolean =
209
+ def accepts (arg : Type ): Boolean =
168
210
kind match
169
- case BooleanXn => arg == classTag[ Boolean ] orElse warningAt(CC )(" Boolean format is null test for non-Boolean" )
211
+ case BooleanXn => arg == defn. BooleanType orElse warningAt(CC )(" Boolean format is null test for non-Boolean" )
170
212
case IntegralXn =>
171
- arg == classTag[ BigInt ] || ! cond(cc) {
213
+ arg == BigIntType || ! cond(cc) {
172
214
case 'o' | 'x' | 'X' if hasAnyFlag(" + (" ) => " + (" .filter(hasFlag).foreach(bad => badFlag(bad, s " only use ' $bad' for BigInt conversions to o, x, X " )) ; true
173
215
}
174
216
case _ => true
175
217
176
218
// what arg type if any does the conversion accept
177
- def acceptableVariants : List [ClassTag [ ? ] ] =
178
- kind match {
179
- case StringXn => if hasFlag('#' ) then classTag[ Formattable ] :: Nil else classTag[ Any ] :: Nil
180
- case BooleanXn => classTag[ Boolean ] :: Conversion . FakeNullTag :: Nil
181
- case HashXn => classTag[ Any ] :: Nil
182
- case CharacterXn => classTag[ Char ] :: classTag[ Byte ] :: classTag[ Short ] :: classTag[ Int ] :: Nil
183
- case IntegralXn => classTag[ Int ] :: classTag[ Long ] :: classTag[ Byte ] :: classTag[ Short ] :: classTag[ BigInt ] :: Nil
184
- case FloatingPointXn => classTag[ Double ] :: classTag[ Float ] :: classTag[ BigDecimal ] :: Nil
185
- case DateTimeXn => classTag[ Long ] :: classTag[ Calendar ] :: classTag[ Date ] :: Nil
219
+ def acceptableVariants : List [Type ] =
220
+ kind match
221
+ case StringXn => if hasFlag('#' ) then FormattableType :: Nil else defn. AnyType :: Nil
222
+ case BooleanXn => defn. BooleanType :: defn. NullType :: Nil
223
+ case HashXn => defn. AnyType :: Nil
224
+ case CharacterXn => defn. CharType :: defn. ByteType :: defn. ShortType :: defn. IntType :: Nil
225
+ case IntegralXn => defn. IntType :: defn. LongType :: defn. ByteType :: defn. ShortType :: BigIntType :: Nil
226
+ case FloatingPointXn => defn. DoubleType :: defn. FloatType :: BigDecimalType :: Nil
227
+ case DateTimeXn => defn. LongType :: CalendarType :: DateType :: Nil
186
228
case LiteralXn => Nil
187
229
case ErrorXn => Nil
188
- }
189
230
190
231
// what flags does the conversion accept?
191
232
private def okFlags : String =
192
- kind match {
233
+ kind match
193
234
case StringXn => " -#<"
194
235
case BooleanXn | HashXn => " -<"
195
236
case LiteralXn => " -"
196
237
case _ => " -#+ 0,(<"
197
- }
198
238
199
239
def hasFlag (f : Char ) = flags.contains(f)
200
240
def hasAnyFlag (fs : String ) = fs.exists(hasFlag)
@@ -206,21 +246,6 @@ abstract class FormatChecker(using reporter: InterpolationReporter):
206
246
def errorAt (g : SpecGroup , i : Int = 0 )(msg : String ) = reporter.partError(msg, argi, descriptor.offset(g, i))
207
247
def warningAt (g : SpecGroup , i : Int = 0 )(msg : String ) = reporter.partWarning(msg, argi, descriptor.offset(g, i))
208
248
209
- // various assertions
210
- def noFlags = flags.isEmpty or errorAt(Flags )(" flags not allowed" )
211
- def noWidth = width.isEmpty or errorAt(Width )(" width not allowed" )
212
- def noPrecision = precision.isEmpty or errorAt(Precision )(" precision not allowed" )
213
- def only_- (msg : String ) =
214
- val badFlags = flags.filterNot { case '-' | '<' => true case _ => false }
215
- badFlags.isEmpty or badFlag(badFlags(0 ), s " Only '-' allowed for $msg" )
216
- def goodFlags =
217
- val badFlags = flags.filterNot(okFlags.contains)
218
- for f <- badFlags do badFlag(f, s " Illegal flag ' $f' " )
219
- badFlags.isEmpty
220
- def goodIndex =
221
- if index.nonEmpty && hasFlag('<' ) then warningAt(Index )(" Argument index ignored if '<' flag is present" )
222
- val okRange = index.map(i => i > 0 && i <= argc).getOrElse(true )
223
- okRange || hasFlag('<' ) or errorAt(Index )(" Argument index out of range" )
224
249
object Conversion :
225
250
def apply (m : Match , i : Int ): Conversion =
226
251
def kindOf (cc : Char ) = cc match
@@ -243,5 +268,5 @@ abstract class FormatChecker(using reporter: InterpolationReporter):
243
268
case None => new Conversion (m, i, ErrorXn ).tap(_.errorAt(Spec )(s " Missing conversion operator in ' ${m.matched}'; $literalHelp" ))
244
269
end apply
245
270
val literalHelp = " use %% for literal %, %n for newline"
246
- private val FakeNullTag : ClassTag [? ] = null
247
271
end Conversion
272
+ end TypedFormatChecker
0 commit comments