@@ -178,6 +178,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
178178 bool selectRsqrt (Register ResVReg, const SPIRVType *ResType,
179179 MachineInstr &I) const ;
180180
181+ bool selectIntegerDot (Register ResVReg, const SPIRVType *ResType,
182+ MachineInstr &I) const ;
183+
181184 void renderImm32 (MachineInstrBuilder &MIB, const MachineInstr &I,
182185 int OpIdx) const ;
183186 void renderFImm32 (MachineInstrBuilder &MIB, const MachineInstr &I,
@@ -380,6 +383,20 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
380383 MIB.addImm (V);
381384 return MIB.constrainAllUses (TII, TRI, RBI);
382385 }
386+
387+ case TargetOpcode::G_FDOTPROD: {
388+ MachineBasicBlock &BB = *I.getParent ();
389+ return BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRV::OpDot))
390+ .addDef (ResVReg)
391+ .addUse (GR.getSPIRVTypeID (ResType))
392+ .addUse (I.getOperand (1 ).getReg ())
393+ .addUse (I.getOperand (2 ).getReg ())
394+ .constrainAllUses (TII, TRI, RBI);
395+ }
396+ case TargetOpcode::G_SDOTPROD:
397+ case TargetOpcode::G_UDOTPROD:
398+ return selectIntegerDot (ResVReg, ResType, I);
399+
383400 case TargetOpcode::G_MEMMOVE:
384401 case TargetOpcode::G_MEMCPY:
385402 case TargetOpcode::G_MEMSET:
@@ -1366,6 +1383,67 @@ bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
13661383 .constrainAllUses (TII, TRI, RBI);
13671384}
13681385
1386+ // Since there is no integer dot implementation, expand by piecewise multiplying
1387+ // and adding the results, making use of FMA operations where possible.
1388+ bool SPIRVInstructionSelector::selectIntegerDot (Register ResVReg,
1389+ const SPIRVType *ResType,
1390+ MachineInstr &I) const {
1391+ assert (I.getNumOperands () == 3 );
1392+ assert (I.getOperand (1 ).isReg ());
1393+ assert (I.getOperand (2 ).isReg ());
1394+ MachineBasicBlock &BB = *I.getParent ();
1395+
1396+ // Multiply the vectors, then sum the results
1397+ Register Vec0 = I.getOperand (1 ).getReg ();
1398+ Register Vec1 = I.getOperand (2 ).getReg ();
1399+ Register TmpVec = MRI->createVirtualRegister (&SPIRV::IDRegClass);
1400+ SPIRVType *VecType = GR.getSPIRVTypeForVReg (Vec0);
1401+
1402+ bool Result = BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRV::OpIMulV))
1403+ .addDef (TmpVec)
1404+ .addUse (GR.getSPIRVTypeID (VecType))
1405+ .addUse (Vec0)
1406+ .addUse (Vec1)
1407+ .constrainAllUses (TII, TRI, RBI);
1408+
1409+ assert (GR.getScalarOrVectorComponentCount (VecType) > 1 &&
1410+ " dot product requires a vector of at least 2 components" );
1411+
1412+ Register Res = MRI->createVirtualRegister (&SPIRV::IDRegClass);
1413+ Result |= BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRV::OpCompositeExtract))
1414+ .addDef (Res)
1415+ .addUse (GR.getSPIRVTypeID (ResType))
1416+ .addUse (TmpVec)
1417+ .addImm (0 )
1418+ .constrainAllUses (TII, TRI, RBI);
1419+
1420+ for (unsigned i = 1 ; i < GR.getScalarOrVectorComponentCount (VecType); i++) {
1421+ Register Elt = MRI->createVirtualRegister (&SPIRV::IDRegClass);
1422+
1423+ Result |=
1424+ BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRV::OpCompositeExtract))
1425+ .addDef (Elt)
1426+ .addUse (GR.getSPIRVTypeID (ResType))
1427+ .addUse (TmpVec)
1428+ .addImm (i)
1429+ .constrainAllUses (TII, TRI, RBI);
1430+
1431+ Register Sum = i < GR.getScalarOrVectorComponentCount (VecType) - 1
1432+ ? MRI->createVirtualRegister (&SPIRV::IDRegClass)
1433+ : ResVReg;
1434+
1435+ Result |= BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRV::OpIAddS))
1436+ .addDef (Sum)
1437+ .addUse (GR.getSPIRVTypeID (ResType))
1438+ .addUse (Res)
1439+ .addUse (Elt)
1440+ .constrainAllUses (TII, TRI, RBI);
1441+ Res = Sum;
1442+ }
1443+
1444+ return Result;
1445+ }
1446+
13691447bool SPIRVInstructionSelector::selectBitreverse (Register ResVReg,
13701448 const SPIRVType *ResType,
13711449 MachineInstr &I) const {
0 commit comments