diff --git a/llvm/test/TableGen/DecoderEmitterFnTable.td b/llvm/test/TableGen/DecoderEmitterFnTable.td
new file mode 100644
index 0000000000000..7bed18c19a513
--- /dev/null
+++ b/llvm/test/TableGen/DecoderEmitterFnTable.td
@@ -0,0 +1,84 @@
+// RUN: llvm-tblgen -gen-disassembler -use-fn-table-in-decode-to-mcinst -I %p/../../include %s | FileCheck %s
+
+include "llvm/Target/Target.td"
+
+def archInstrInfo : InstrInfo { }
+
+def arch : Target {
+  let InstructionSet = archInstrInfo;
+}
+
+let Namespace = "arch" in {
+  def R0 : Register<"r0">;
+  def R1 : Register<"r1">;
+  def R2 : Register<"r2">;
+  def R3 : Register<"r3">;
+}
+def Regs : RegisterClass<"Regs", [i32], 32, (add R0, R1, R2, R3)>;
+
+class TestInstruction : Instruction {
+  let Size = 1;
+  let OutOperandList = (outs);
+  field bits<8> Inst;
+  field bits<8> SoftFail = 0;
+}
+
+// Define instructions to generate 4 cases in decodeToMCInst.
+// Lower 2 bits define the number of operands. Each register operand
+// needs 2 bits to encode.
+
+// An instruction with no inputs. Encoded with lower 2 bits = 0 and upper
+// 6 bits = 0 as well.
+def Inst0 : TestInstruction {
+  let Inst = 0x0;
+  let InOperandList = (ins);
+  let AsmString = "Inst0";
+}
+
+// An instruction with a single input. Encoded with lower 2 bits = 1 and the
+// single input in bits 2-3.
+def Inst1 : TestInstruction {
+  bits<2> r0;
+  let Inst{1-0} = 1;
+  let Inst{3-2} = r0;
+  let InOperandList = (ins Regs:$r0);
+  let AsmString = "Inst1";
+}
+
+// An instruction with two inputs. Encoded with lower 2 bits = 2 and the
+// inputs in bits 2-3 and 4-5.
+def Inst2 : TestInstruction {
+  bits<2> r0;
+  bits<2> r1;
+  let Inst{1-0} = 2;
+  let Inst{3-2} = r0;
+  let Inst{5-4} = r1;
+  let InOperandList = (ins Regs:$r0, Regs:$r1);
+  let AsmString = "Inst2";
+}
+
+// An instruction with three inputs. Encoded with lower 2 bits = 3 and the
+// inputs in bits 2-3 and 4-5 and 6-7.
+def Inst3 : TestInstruction {
+  bits<2> r0;
+  bits<2> r1;
+  bits<2> r2;
+  let Inst{1-0} = 3;
+  let Inst{3-2} = r0;
+  let Inst{5-4} = r1;
+  let Inst{7-6} = r2;
+  let InOperandList = (ins Regs:$r0, Regs:$r1, Regs:$r2);
+  let AsmString = "Inst3";
+}
+
+// CHECK-LABEL: DecodeStatus decodeFn0(DecodeStatus S, InsnType insn, MCInst &MI, uint64_t Address, const MCDisassembler *Decoder, bool &DecodeComplete)
+// CHECK-LABEL: DecodeStatus decodeFn1(DecodeStatus S, InsnType insn, MCInst &MI, uint64_t Address, const MCDisassembler *Decoder, bool &DecodeComplete)
+// CHECK-LABEL: DecodeStatus decodeFn2(DecodeStatus S, InsnType insn, MCInst &MI, uint64_t Address, const MCDisassembler *Decoder, bool &DecodeComplete)
+// CHECK-LABEL: DecodeStatus decodeFn3(DecodeStatus S, InsnType insn, MCInst &MI, uint64_t Address, const MCDisassembler *Decoder, bool &DecodeComplete)
+// CHECK-LABEL: decodeToMCInst(unsigned Idx, DecodeStatus S, InsnType insn, MCInst &MI, uint64_t Address, const MCDisassembler *Decoder, bool &DecodeComplete)
+// CHECK: static constexpr DecodeFnTy decodeFnTable[]
+// CHECK-NEXT: decodeFn0,
+// CHECK-NEXT: decodeFn1,
+// CHECK-NEXT: decodeFn2,
+// CHECK-NEXT: decodeFn3,
+// CHECK: return decodeFnTable[Idx](S, insn, MI, Address, Decoder, DecodeComplete)
diff --git a/llvm/utils/TableGen/DecoderEmitter.cpp b/llvm/utils/TableGen/DecoderEmitter.cpp
index 2e8ff2aa47d96..af25975f7c7ec 100644
--- a/llvm/utils/TableGen/DecoderEmitter.cpp
+++ b/llvm/utils/TableGen/DecoderEmitter.cpp
@@ -83,6 +83,14 @@ static cl::opt<bool> LargeTable(
              "in the table instead of the default 16 bits."),
     cl::init(false), cl::cat(DisassemblerEmitterCat));
 
+static cl::opt<bool> UseFnTableInDecodetoMCInst(
+    "use-fn-table-in-decode-to-mcinst",
+    cl::desc(
+        "Use a table of function pointers instead of a switch case in the\n"
+        "generated `decodeToMCInst` function. Helps improve compile time\n"
+        "of the generated code."),
+    cl::init(false), cl::cat(DisassemblerEmitterCat));
+
 STATISTIC(NumEncodings, "Number of encodings considered");
 STATISTIC(NumEncodingsLackingDisasm,
           "Number of encodings without disassembler info");
@@ -1066,31 +1074,67 @@ void DecoderEmitter::emitPredicateFunction(formatted_raw_ostream &OS,
 void DecoderEmitter::emitDecoderFunction(formatted_raw_ostream &OS,
                                          DecoderSet &Decoders,
                                          indent Indent) const {
-  // The decoder function is just a big switch statement based on the
-  // input decoder index.
-  OS << Indent << "template <typename InsnType>\n";
-  OS << Indent << "static DecodeStatus decodeToMCInst(DecodeStatus S,"
-     << " unsigned Idx, InsnType insn, MCInst &MI,\n";
-  OS << Indent << "                                   uint64_t "
-     << "Address, const MCDisassembler *Decoder, bool &DecodeComplete) {\n";
-  Indent += 2;
-  OS << Indent << "DecodeComplete = true;\n";
+  // The decoder function is just a big switch statement or a table of function
+  // pointers based on the input decoder index.
+
   // TODO: When InsnType is large, using uint64_t limits all fields to 64 bits
   // It would be better for emitBinaryParser to use a 64-bit tmp whenever
   // possible but fall back to an InsnType-sized tmp for truly large fields.
-  OS << Indent
-     << "using TmpType = "
-        "std::conditional_t<std::is_integral<InsnType>::"
-        "value, InsnType, uint64_t>;\n";
-  OS << Indent << "TmpType tmp;\n";
-  OS << Indent << "switch (Idx) {\n";
-  OS << Indent << "default: llvm_unreachable(\"Invalid index!\");\n";
-  for (const auto &[Index, Decoder] : enumerate(Decoders)) {
-    OS << Indent << "case " << Index << ":\n";
-    OS << Decoder;
-    OS << Indent + 2 << "return S;\n";
+  StringRef TmpTypeDecl =
+      "using TmpType = std::conditional_t<std::is_integral<InsnType>::value, "
+      "InsnType, uint64_t>;\n";
+  StringRef DecodeParams =
+      "DecodeStatus S, InsnType insn, MCInst &MI, uint64_t Address, const "
+      "MCDisassembler *Decoder, bool &DecodeComplete";
+
+  if (UseFnTableInDecodetoMCInst) {
+    // Emit a function for each case first.
+    for (const auto &[Index, Decoder] : enumerate(Decoders)) {
+      OS << Indent << "template <typename InsnType>\n";
+      OS << Indent << "DecodeStatus decodeFn" << Index << "(" << DecodeParams
+         << ") {\n";
+      Indent += 2;
+      OS << Indent << TmpTypeDecl;
+      OS << Indent << "[[maybe_unused]] TmpType tmp;\n";
+      OS << Decoder;
+      OS << Indent << "return S;\n";
+      Indent -= 2;
+      OS << Indent << "}\n\n";
+    }
+  }
+
+  OS << Indent << "// Handling " << Decoders.size() << " cases.\n";
+  OS << Indent << "template <typename InsnType>\n";
+  OS << Indent << "static DecodeStatus decodeToMCInst(unsigned Idx, "
+     << DecodeParams << ") {\n";
+  Indent += 2;
+  OS << Indent << "DecodeComplete = true;\n";
+
+  if (UseFnTableInDecodetoMCInst) {
+    // Build a table of function pointers.
+    OS << Indent << "using DecodeFnTy = DecodeStatus (*)(" << DecodeParams
+       << ");\n";
+    OS << Indent << "static constexpr DecodeFnTy decodeFnTable[] = {\n";
+    for (size_t Index : llvm::seq(Decoders.size()))
+      OS << Indent + 2 << "decodeFn" << Index << ",\n";
+    OS << Indent << "};\n";
+    OS << Indent << "if (Idx >= " << Decoders.size() << ")\n";
+    OS << Indent + 2 << "llvm_unreachable(\"Invalid index!\");\n";
+    OS << Indent
+       << "return decodeFnTable[Idx](S, insn, MI, Address, Decoder, "
+          "DecodeComplete);\n";
+  } else {
+    OS << Indent << TmpTypeDecl;
+    OS << Indent << "TmpType tmp;\n";
+    OS << Indent << "switch (Idx) {\n";
+    OS << Indent << "default: llvm_unreachable(\"Invalid index!\");\n";
+    for (const auto &[Index, Decoder] : enumerate(Decoders)) {
+      OS << Indent << "case " << Index << ":\n";
+      OS << Decoder;
+      OS << Indent + 2 << "return S;\n";
+    }
+    OS << Indent << "}\n";
   }
-  OS << Indent << "}\n";
   Indent -= 2;
   OS << Indent << "}\n";
 }
@@ -1267,7 +1311,8 @@ std::pair<bool, unsigned> FilterChooser::getDecoderIndex(DecoderSet &Decoders,
   // FIXME: emitDecoder() function can take a buffer directly rather than
   // a stream.
   raw_svector_ostream S(Decoder);
-  bool HasCompleteDecoder = emitDecoder(S, indent(4), Opc);
+  indent Indent(UseFnTableInDecodetoMCInst ? 2 : 4);
+  bool HasCompleteDecoder = emitDecoder(S, Indent, Opc);
 
   // Using the full decoder string as the key value here is a bit
   // heavyweight, but is effective. If the string comparisons become a
@@ -2371,7 +2416,7 @@ static DecodeStatus decodeInstruction(const uint8_t DecodeTable[], MCInst &MI,
        << "      makeUp(insn, Len);";
   }
   OS << R"(
-      S = decodeToMCInst(S, DecodeIdx, insn, MI, Address, DisAsm, DecodeComplete);
+      S = decodeToMCInst(DecodeIdx, S, insn, MI, Address, DisAsm, DecodeComplete);
       assert(DecodeComplete);
 
       LLVM_DEBUG(dbgs() << Loc << ": OPC_Decode: opcode " << Opc
@@ -2393,7 +2438,7 @@ static DecodeStatus decodeInstruction(const uint8_t DecodeTable[], MCInst &MI,
       MCInst TmpMI;
       TmpMI.setOpcode(Opc);
       bool DecodeComplete;
-      S = decodeToMCInst(S, DecodeIdx, insn, TmpMI, Address, DisAsm, DecodeComplete);
+      S = decodeToMCInst(DecodeIdx, S, insn, TmpMI, Address, DisAsm, DecodeComplete);
       LLVM_DEBUG(dbgs() << Loc << ": OPC_TryDecode: opcode " << Opc
                         << ", using decoder " << DecodeIdx << ": ");
 