Skip to content

Commit 7241ae9

Browse files
committed
Support return aggregates in platform intrinsics.
This also involved adding `[TYPE;N]` syntax and aggregate indexing support to the generator script: it's the only way to be able to have a parameterised intrinsic that returns an aggregate, since one can't refer to previous elements of the current aggregate (and that was harder to implement).
1 parent c19e7b6 commit 7241ae9

File tree

2 files changed

+88
-18
lines changed

2 files changed

+88
-18
lines changed

src/etc/platform-intrinsics/generator.py

+67-16
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
SPEC = re.compile(
2020
r'^(?:(?P<void>V)|(?P<id>[iusfIUSF])(?:\((?P<start>\d+)-(?P<end>\d+)\)|'
2121
r'(?P<width>\d+)(:?/(?P<llvm_width>\d+))?)'
22-
r'|(?P<reference>\d+))(?P<modifiers>[vShdnwusDMC]*)(?P<force_width>x\d+)?'
22+
r'|(?P<reference>\d+))(?P<index>\.\d+)?(?P<modifiers>[vShdnwusfDMC]*)(?P<force_width>x\d+)?'
2323
r'(?:(?P<pointer>Pm|Pc)(?P<llvm_pointer>/.*)?|(?P<bitcast>->.*))?$'
2424
)
2525

@@ -70,23 +70,32 @@ def lookup(raw):
7070
{k: lookup(v) for k, v in data.items()})
7171

7272
class PlatformTypeInfo(object):
73-
def __init__(self, llvm_name, properties):
74-
self.properties = properties
75-
self.llvm_name = llvm_name
73+
def __init__(self, llvm_name, properties, elems = None):
74+
if elems is None:
75+
self.properties = properties
76+
self.llvm_name = llvm_name
77+
else:
78+
assert properties is None and llvm_name is None
79+
self.properties = {}
80+
self.elems = elems
7681

7782
def __repr__(self):
7883
return '<PlatformTypeInfo {}, {}>'.format(self.llvm_name, self.properties)
7984

8085
def __getattr__(self, name):
8186
return self.properties[name]
8287

88+
def __getitem__(self, idx):
89+
return self.elems[idx]
90+
8391
def vectorize(self, length, width_info):
8492
props = self.properties.copy()
8593
props.update(width_info)
8694
return PlatformTypeInfo('v{}{}'.format(length, self.llvm_name), props)
8795

88-
def pointer(self):
89-
return PlatformTypeInfo('p0{}'.format(self.llvm_name), self.properties)
96+
def pointer(self, llvm_elem):
97+
name = self.llvm_name if llvm_elem is None else llvm_elem.llvm_name
98+
return PlatformTypeInfo('p0{}'.format(name), self.properties)
9099

91100
BITWIDTH_POINTER = '<pointer>'
92101

@@ -128,6 +137,8 @@ def modify(self, spec, width, previous):
128137
return Unsigned(self.bitwidth())
129138
elif spec == 's':
130139
return Signed(self.bitwidth())
140+
elif spec == 'f':
141+
return Float(self.bitwidth())
131142
elif spec == 'w':
132143
return self.__class__(self.bitwidth() * 2)
133144
elif spec == 'n':
@@ -283,7 +294,11 @@ def rust_name(self):
283294
self._elem.rust_name())
284295

285296
def type_info(self, platform_info):
286-
return self._elem.type_info(platform_info).pointer()
297+
if self._llvm_elem is None:
298+
llvm_elem = None
299+
else:
300+
llvm_elem = self._llvm_elem.type_info(platform_info)
301+
return self._elem.type_info(platform_info).pointer(llvm_elem)
287302

288303
def __eq__(self, other):
289304
return isinstance(other, Pointer) and self._const == other._const \
@@ -298,6 +313,14 @@ def __init__(self, flatten, elems):
298313
def __repr__(self):
299314
return '<Aggregate {}>'.format(self._elems)
300315

316+
def modify(self, spec, width, previous):
317+
if spec.startswith('.'):
318+
num = int(spec[1:])
319+
return self._elems[num]
320+
else:
321+
print(spec)
322+
raise NotImplementedError()
323+
301324
def compiler_ctor(self):
302325
return 'agg({}, vec![{}])'.format('true' if self._flatten else 'false',
303326
', '.join(elem.compiler_ctor() for elem in self._elems))
@@ -306,8 +329,7 @@ def rust_name(self):
306329
return '({})'.format(', '.join(elem.rust_name() for elem in self._elems))
307330

308331
def type_info(self, platform_info):
309-
#return PlatformTypeInfo(None, None, self._llvm_name)
310-
return None
332+
return PlatformTypeInfo(None, None, [elem.type_info(platform_info) for elem in self._elems])
311333

312334
def __eq__(self, other):
313335
return isinstance(other, Aggregate) and self._flatten == other._flatten and \
@@ -349,7 +371,11 @@ def enumerate(self, width, previous):
349371
id = match.group('id')
350372
reference = match.group('reference')
351373

352-
modifiers = list(match.group('modifiers') or '')
374+
modifiers = []
375+
index = match.group('index')
376+
if index is not None:
377+
modifiers.append(index)
378+
modifiers += list(match.group('modifiers') or '')
353379
force = match.group('force_width')
354380
if force is not None:
355381
modifiers.append(force)
@@ -407,16 +433,32 @@ def enumerate(self, width, previous):
407433
else:
408434
assert False, 'matched `{}`, but didn\'t understand it?'.format(spec)
409435
elif spec.startswith('('):
410-
assert bitcast is None
411436
if spec.endswith(')'):
412-
raise NotImplementedError()
437+
true_spec = spec[1:-1]
438+
flatten = False
413439
elif spec.endswith(')f'):
414440
true_spec = spec[1:-2]
415441
flatten = True
442+
else:
443+
assert False, 'found unclosed aggregate `{}`'.format(spec)
416444

417445
for elems in itertools.product(*(TypeSpec(subspec).enumerate(width, previous)
418446
for subspec in true_spec.split(','))):
419447
yield Aggregate(flatten, elems)
448+
elif spec.startswith('['):
449+
if spec.endswith(']'):
450+
true_spec = spec[1:-1]
451+
flatten = False
452+
elif spec.endswith(']f'):
453+
true_spec = spec[1:-2]
454+
flatten = True
455+
else:
456+
assert False, 'found unclosed aggregate `{}`'.format(spec)
457+
elem_spec, count = true_spec.split(';')
458+
459+
count = int(count)
460+
for elem in TypeSpec(elem_spec).enumerate(width, previous):
461+
yield Aggregate(flatten, [elem] * count)
420462
else:
421463
assert False, 'Failed to parse `{}`'.format(spec)
422464

@@ -514,7 +556,7 @@ def parse_args():
514556
core_type := void | vector | scalar | aggregate | reference
515557
516558
modifier := 'v' | 'h' | 'd' | 'n' | 'w' | 'u' | 's' |
517-
'x' number
559+
'x' number | '.' number
518560
suffix := pointer | bitcast
519561
pointer := 'Pm' llvm_pointer? | 'Pc' llvm_pointer?
520562
llvm_pointer := '/' type
@@ -529,7 +571,7 @@ def parse_args():
529571
scalar_type := 'U' | 'S' | 'F'
530572
llvm_width := '/' number
531573
532-
aggregate := '(' (type),* ')' 'f'?
574+
aggregate := '(' (type),* ')' 'f'? | '[' type ';' number ']' 'f'?
533575
534576
reference := number
535577
@@ -586,6 +628,12 @@ def parse_args():
586628
- no `f` corresponds to `declare ... @llvm.foo({float, i32})`.
587629
- having an `f` corresponds to `declare ... @llvm.foo(float, i32)`.
588630
631+
The `[type;number]` form is a just shorter way to write
632+
`(...)`, except avoids doing a cartesian product of generic
633+
types, e.g. `[S32;2]` is the same as `(S32, S32)`, while
634+
`[I32;2]` is describing just the two types `(S32,S32)` and
635+
`(U32,U32)` (i.e. doesn't include `(S32,U32)`, `(U32,S32)` as
636+
`(I32,I32)` would).
589637
590638
(Currently aggregates can not contain other aggregates.)
591639
@@ -604,13 +652,16 @@ def parse_args():
604652
### Modifiers
605653
606654
- 'v': put a scalar into a vector of the current width (u32 -> u32x4, when width == 128)
655+
- 'S': get the scalar element of a vector (u32x4 -> u32)
607656
- 'h': half the length of the vector (u32x4 -> u32x2)
608657
- 'd': double the length of the vector (u32x2 -> u32x4)
609658
- 'n': narrow the element of the vector (u32x4 -> u16x4)
610659
- 'w': widen the element of the vector (u16x4 -> u32x4)
611-
- 'u': force an integer (vector or scalar) to be unsigned (i32x4 -> u32x4)
612-
- 's': force an integer (vector or scalar) to be signed (u32x4 -> i32x4)
660+
- 'u': force a number (vector or scalar) to be unsigned int (f32x4 -> u32x4)
661+
- 's': force a number (vector or scalar) to be signed int (u32x4 -> i32x4)
662+
- 'f': force a number (vector or scalar) to be float (u32x4 -> f32x4)
613663
- 'x' number: force the type to be a vector of bitwidth `number`.
664+
- '.' number: get the `number`th element of an aggregate
614665
- 'D': dereference a pointer (*mut u32 -> u32)
615666
- 'C': make a pointer const (*mut u32 -> *const u32)
616667
- 'M': make a pointer mut (*const u32 -> *mut u32)

src/librustc_trans/trans/intrinsic.rs

+21-2
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,12 @@ pub fn trans_intrinsic_call<'a, 'blk, 'tcx>(mut bcx: Block<'blk, 'tcx>,
965965
vec![Type::vector(&elem,
966966
length as u64)]
967967
}
968-
Aggregate(false, _) => unimplemented!(),
968+
Aggregate(false, ref contents) => {
969+
let elems = contents.iter()
970+
.map(|t| one(ty_to_type(ccx, t, any_changes_needed)))
971+
.collect::<Vec<_>>();
972+
vec![Type::struct_(ccx, &elems, false)]
973+
}
969974
Aggregate(true, ref contents) => {
970975
*any_changes_needed = true;
971976
contents.iter()
@@ -1049,14 +1054,28 @@ pub fn trans_intrinsic_call<'a, 'blk, 'tcx>(mut bcx: Block<'blk, 'tcx>,
10491054
};
10501055
assert_eq!(inputs.len(), llargs.len());
10511056

1052-
match intr.definition {
1057+
let val = match intr.definition {
10531058
intrinsics::IntrinsicDef::Named(name) => {
10541059
let f = declare::declare_cfn(ccx,
10551060
name,
10561061
Type::func(&inputs, &outputs),
10571062
tcx.mk_nil());
10581063
Call(bcx, f, &llargs, None, call_debug_location)
10591064
}
1065+
};
1066+
1067+
match intr.output {
1068+
intrinsics::Type::Aggregate(flatten, ref elems) => {
1069+
// the output is a tuple so we need to munge it properly
1070+
assert!(!flatten);
1071+
1072+
for i in 0..elems.len() {
1073+
let val = ExtractValue(bcx, val, i);
1074+
Store(bcx, val, StructGEP(bcx, llresult, i));
1075+
}
1076+
C_nil(ccx)
1077+
}
1078+
_ => val,
10601079
}
10611080
}
10621081
};

0 commit comments

Comments
 (0)