19
19
SPEC = re .compile (
20
20
r'^(?:(?P<void>V)|(?P<id>[iusfIUSF])(?:\((?P<start>\d+)-(?P<end>\d+)\)|'
21
21
r'(?P<width>\d+)(:?/(?P<llvm_width>\d+))?)'
22
- r'|(?P<reference>\d+))(?P<modifiers>[vShdnwusDMC ]*)(?P<force_width>x\d+)?'
22
+ r'|(?P<reference>\d+))(?P<index>\.\d+)?(?P< modifiers>[vShdnwusfDMC ]*)(?P<force_width>x\d+)?'
23
23
r'(?:(?P<pointer>Pm|Pc)(?P<llvm_pointer>/.*)?|(?P<bitcast>->.*))?$'
24
24
)
25
25
@@ -70,23 +70,32 @@ def lookup(raw):
70
70
{k : lookup (v ) for k , v in data .items ()})
71
71
72
72
class PlatformTypeInfo (object ):
73
- def __init__ (self , llvm_name , properties ):
74
- self .properties = properties
75
- self .llvm_name = llvm_name
73
+ def __init__ (self , llvm_name , properties , elems = None ):
74
+ if elems is None :
75
+ self .properties = properties
76
+ self .llvm_name = llvm_name
77
+ else :
78
+ assert properties is None and llvm_name is None
79
+ self .properties = {}
80
+ self .elems = elems
76
81
77
82
def __repr__ (self ):
78
83
return '<PlatformTypeInfo {}, {}>' .format (self .llvm_name , self .properties )
79
84
80
85
def __getattr__ (self , name ):
81
86
return self .properties [name ]
82
87
88
+ def __getitem__ (self , idx ):
89
+ return self .elems [idx ]
90
+
83
91
def vectorize (self , length , width_info ):
84
92
props = self .properties .copy ()
85
93
props .update (width_info )
86
94
return PlatformTypeInfo ('v{}{}' .format (length , self .llvm_name ), props )
87
95
88
- def pointer (self ):
89
- return PlatformTypeInfo ('p0{}' .format (self .llvm_name ), self .properties )
96
+ def pointer (self , llvm_elem ):
97
+ name = self .llvm_name if llvm_elem is None else llvm_elem .llvm_name
98
+ return PlatformTypeInfo ('p0{}' .format (name ), self .properties )
90
99
91
100
BITWIDTH_POINTER = '<pointer>'
92
101
@@ -128,6 +137,8 @@ def modify(self, spec, width, previous):
128
137
return Unsigned (self .bitwidth ())
129
138
elif spec == 's' :
130
139
return Signed (self .bitwidth ())
140
+ elif spec == 'f' :
141
+ return Float (self .bitwidth ())
131
142
elif spec == 'w' :
132
143
return self .__class__ (self .bitwidth () * 2 )
133
144
elif spec == 'n' :
@@ -283,7 +294,11 @@ def rust_name(self):
283
294
self ._elem .rust_name ())
284
295
285
296
def type_info (self , platform_info ):
286
- return self ._elem .type_info (platform_info ).pointer ()
297
+ if self ._llvm_elem is None :
298
+ llvm_elem = None
299
+ else :
300
+ llvm_elem = self ._llvm_elem .type_info (platform_info )
301
+ return self ._elem .type_info (platform_info ).pointer (llvm_elem )
287
302
288
303
def __eq__ (self , other ):
289
304
return isinstance (other , Pointer ) and self ._const == other ._const \
@@ -298,6 +313,14 @@ def __init__(self, flatten, elems):
298
313
def __repr__ (self ):
299
314
return '<Aggregate {}>' .format (self ._elems )
300
315
316
+ def modify (self , spec , width , previous ):
317
+ if spec .startswith ('.' ):
318
+ num = int (spec [1 :])
319
+ return self ._elems [num ]
320
+ else :
321
+ print (spec )
322
+ raise NotImplementedError ()
323
+
301
324
def compiler_ctor (self ):
302
325
return 'agg({}, vec![{}])' .format ('true' if self ._flatten else 'false' ,
303
326
', ' .join (elem .compiler_ctor () for elem in self ._elems ))
@@ -306,8 +329,7 @@ def rust_name(self):
306
329
return '({})' .format (', ' .join (elem .rust_name () for elem in self ._elems ))
307
330
308
331
def type_info (self , platform_info ):
309
- #return PlatformTypeInfo(None, None, self._llvm_name)
310
- return None
332
+ return PlatformTypeInfo (None , None , [elem .type_info (platform_info ) for elem in self ._elems ])
311
333
312
334
def __eq__ (self , other ):
313
335
return isinstance (other , Aggregate ) and self ._flatten == other ._flatten and \
@@ -349,7 +371,11 @@ def enumerate(self, width, previous):
349
371
id = match .group ('id' )
350
372
reference = match .group ('reference' )
351
373
352
- modifiers = list (match .group ('modifiers' ) or '' )
374
+ modifiers = []
375
+ index = match .group ('index' )
376
+ if index is not None :
377
+ modifiers .append (index )
378
+ modifiers += list (match .group ('modifiers' ) or '' )
353
379
force = match .group ('force_width' )
354
380
if force is not None :
355
381
modifiers .append (force )
@@ -407,16 +433,32 @@ def enumerate(self, width, previous):
407
433
else :
408
434
assert False , 'matched `{}`, but didn\' t understand it?' .format (spec )
409
435
elif spec .startswith ('(' ):
410
- assert bitcast is None
411
436
if spec .endswith (')' ):
412
- raise NotImplementedError ()
437
+ true_spec = spec [1 :- 1 ]
438
+ flatten = False
413
439
elif spec .endswith (')f' ):
414
440
true_spec = spec [1 :- 2 ]
415
441
flatten = True
442
+ else :
443
+ assert False , 'found unclosed aggregate `{}`' .format (spec )
416
444
417
445
for elems in itertools .product (* (TypeSpec (subspec ).enumerate (width , previous )
418
446
for subspec in true_spec .split (',' ))):
419
447
yield Aggregate (flatten , elems )
448
+ elif spec .startswith ('[' ):
449
+ if spec .endswith (']' ):
450
+ true_spec = spec [1 :- 1 ]
451
+ flatten = False
452
+ elif spec .endswith (']f' ):
453
+ true_spec = spec [1 :- 2 ]
454
+ flatten = True
455
+ else :
456
+ assert False , 'found unclosed aggregate `{}`' .format (spec )
457
+ elem_spec , count = true_spec .split (';' )
458
+
459
+ count = int (count )
460
+ for elem in TypeSpec (elem_spec ).enumerate (width , previous ):
461
+ yield Aggregate (flatten , [elem ] * count )
420
462
else :
421
463
assert False , 'Failed to parse `{}`' .format (spec )
422
464
@@ -514,7 +556,7 @@ def parse_args():
514
556
core_type := void | vector | scalar | aggregate | reference
515
557
516
558
modifier := 'v' | 'h' | 'd' | 'n' | 'w' | 'u' | 's' |
517
- 'x' number
559
+ 'x' number | '.' number
518
560
suffix := pointer | bitcast
519
561
pointer := 'Pm' llvm_pointer? | 'Pc' llvm_pointer?
520
562
llvm_pointer := '/' type
@@ -529,7 +571,7 @@ def parse_args():
529
571
scalar_type := 'U' | 'S' | 'F'
530
572
llvm_width := '/' number
531
573
532
- aggregate := '(' (type),* ')' 'f'?
574
+ aggregate := '(' (type),* ')' 'f'? | '[' type ';' number ']' 'f'?
533
575
534
576
reference := number
535
577
@@ -586,6 +628,12 @@ def parse_args():
586
628
- no `f` corresponds to `declare ... @llvm.foo({float, i32})`.
587
629
- having an `f` corresponds to `declare ... @llvm.foo(float, i32)`.
588
630
631
+ The `[type;number]` form is a just shorter way to write
632
+ `(...)`, except avoids doing a cartesian product of generic
633
+ types, e.g. `[S32;2]` is the same as `(S32, S32)`, while
634
+ `[I32;2]` is describing just the two types `(S32,S32)` and
635
+ `(U32,U32)` (i.e. doesn't include `(S32,U32)`, `(U32,S32)` as
636
+ `(I32,I32)` would).
589
637
590
638
(Currently aggregates can not contain other aggregates.)
591
639
@@ -604,13 +652,16 @@ def parse_args():
604
652
### Modifiers
605
653
606
654
- 'v': put a scalar into a vector of the current width (u32 -> u32x4, when width == 128)
655
+ - 'S': get the scalar element of a vector (u32x4 -> u32)
607
656
- 'h': half the length of the vector (u32x4 -> u32x2)
608
657
- 'd': double the length of the vector (u32x2 -> u32x4)
609
658
- 'n': narrow the element of the vector (u32x4 -> u16x4)
610
659
- 'w': widen the element of the vector (u16x4 -> u32x4)
611
- - 'u': force an integer (vector or scalar) to be unsigned (i32x4 -> u32x4)
612
- - 's': force an integer (vector or scalar) to be signed (u32x4 -> i32x4)
660
+ - 'u': force a number (vector or scalar) to be unsigned int (f32x4 -> u32x4)
661
+ - 's': force a number (vector or scalar) to be signed int (u32x4 -> i32x4)
662
+ - 'f': force a number (vector or scalar) to be float (u32x4 -> f32x4)
613
663
- 'x' number: force the type to be a vector of bitwidth `number`.
664
+ - '.' number: get the `number`th element of an aggregate
614
665
- 'D': dereference a pointer (*mut u32 -> u32)
615
666
- 'C': make a pointer const (*mut u32 -> *const u32)
616
667
- 'M': make a pointer mut (*const u32 -> *mut u32)
0 commit comments