Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGTargetInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class SelectionDAGTargetInfo {
SelectionDAGTargetInfo(const SelectionDAGTargetInfo &) = delete;
SelectionDAGTargetInfo &operator=(const SelectionDAGTargetInfo &) = delete;
virtual ~SelectionDAGTargetInfo();
// Returns true if Opcode is a target-specific strict floating-point opcode,
// i.e. one whose first operand is the chain (see its use in
// combineVFMADD_VLWithVFNEG_VL). Default: no target strict-FP opcodes.
virtual bool isTargetStrictFPOpcode(unsigned Opcode) const { return false; }

/// Emit target-specific code that performs a memcpy.
/// This can be used by targets to provide code sequences for cases
Expand Down
7 changes: 7 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -2497,6 +2497,13 @@ class TargetLoweringBase {
setOperationAction(Opc, OrigVT, Promote);
AddPromotedToType(Opc, OrigVT, DestVT);
}
/// Convenience overload: mark every opcode in \p Ops as Promote for
/// \p OrigVT and record \p DestVT as the type each one is promoted to.
void setOperationPromotedToType(ArrayRef<unsigned> Ops, MVT OrigVT,
                                MVT DestVT) {
  for (unsigned Opc : Ops) {
    setOperationAction(Opc, OrigVT, Promote);
    AddPromotedToType(Opc, OrigVT, DestVT);
  }
}

/// Targets should invoke this method for each target independent node that
/// they want to provide a custom DAG combiner for by implementing the
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Support/RISCVISAInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ static const RISCVSupportedExtension SupportedExtensions[] = {
{"zve64f", RISCVExtensionVersion{1, 0}},
{"zve64d", RISCVExtensionVersion{1, 0}},

{"zvfhmin", RISCVExtensionVersion{1, 0}},

{"zicbom", RISCVExtensionVersion{1, 0}},
{"zicboz", RISCVExtensionVersion{1, 0}},
{"zicbop", RISCVExtensionVersion{1, 0}},
Expand Down Expand Up @@ -785,6 +787,7 @@ static const char *ImpliedExtsZk[] = {"zkn", "zkt", "zkr"};
static const char *ImpliedExtsZkn[] = {"zbkb", "zbkc", "zbkx", "zkne", "zknd", "zknh"};
static const char *ImpliedExtsZks[] = {"zbkb", "zbkc", "zbkx", "zksed", "zksh"};
static const char *ImpliedExtsZvfh[] = {"zve32f"};
static const char *ImpliedExtsZvfhmin[] = {"zve32f"};

struct ImpliedExtsEntry {
StringLiteral Name;
Expand Down Expand Up @@ -814,6 +817,7 @@ static constexpr ImpliedExtsEntry ImpliedExts[] = {
{{"zve64f"}, {ImpliedExtsZve64f}},
{{"zve64x"}, {ImpliedExtsZve64x}},
{{"zvfh"}, {ImpliedExtsZvfh}},
{{"zvfhmin"}, {ImpliedExtsZvfhmin}},
{{"zvl1024b"}, {ImpliedExtsZvl1024b}},
{{"zvl128b"}, {ImpliedExtsZvl128b}},
{{"zvl16384b"}, {ImpliedExtsZvl16384b}},
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/RISCV/RISCV.td
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,16 @@ def FeatureStdExtZvfh
"'Zvfh' (Vector Half-Precision Floating-Point)",
[FeatureStdExtZve32f]>;

def FeatureStdExtZvfhmin
: SubtargetFeature<"zvfhmin", "HasStdExtZvfhmin", "true",
"'Zvfhmin' (Vector Half-Precision Floating-Point Minimal)",
[FeatureStdExtZve32f]>;

def HasVInstructionsF16Minimal : Predicate<"Subtarget->hasVInstructionsF16Minimal()">,
AssemblerPredicate<(any_of FeatureStdExtZvfhmin, FeatureStdExtZvfh),
"'Zvfhmin' (Vector Half-Precision Floating-Point Minimal) or "
"'Zvfh' (Vector Half-Precision Floating-Point)">;

def FeatureStdExtZicbom
: SubtargetFeature<"zicbom", "HasStdExtZicbom", "true",
"'Zicbom' (Cache-Block Management Instructions)">;
Expand Down
197 changes: 196 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
addRegisterClass(MVT::i32, &RISCV::VGPRRegClass);

addRegisterClass(MVT::f32, &RISCV::VGPRRegClass);

if (Subtarget.hasVInstructionsF16Minimal()) {
addRegisterClass(MVT::f16, &RISCV::VGPRRegClass);
addRegisterClass(MVT::v8f16, &RISCV::VGPRRegClass);
addRegisterClass(MVT::v16f16, &RISCV::VGPRRegClass);
addRegisterClass(MVT::v32f16, &RISCV::VGPRRegClass);
}
}

// Compute derived properties from the register classes.
Expand Down Expand Up @@ -727,6 +734,67 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::SETGT, ISD::SETOGT, ISD::SETGE, ISD::SETOGE,
};

// Operations on f16/bf16 vectors that, under Zvfhmin/Zvfbfmin (which provide
// only loads/stores/conversions natively), are legalized by promoting the
// whole operation to the corresponding f32 vector type.
// Commented-out entries are candidates not yet handled.
// TODO: support more ops.
static const unsigned ZvfhminZvfbfminPromoteOps[] = {
ISD::FMINNUM,
ISD::FMAXNUM,
// ISD::FMINIMUMNUM,
// ISD::FMAXIMUMNUM,
ISD::FADD,
ISD::FSUB,
ISD::FMUL,
ISD::FMA,
ISD::FDIV,
ISD::FSQRT,
ISD::FCEIL,
ISD::FTRUNC,
ISD::FFLOOR,
ISD::FROUND,
ISD::FROUNDEVEN,
ISD::FRINT,
ISD::FNEARBYINT,
ISD::IS_FPCLASS,
ISD::SETCC,
ISD::FMAXIMUM,
ISD::FMINIMUM,
// Strict (constrained-FP) variants of the arithmetic ops above.
ISD::STRICT_FADD,
ISD::STRICT_FSUB,
ISD::STRICT_FMUL,
ISD::STRICT_FDIV,
ISD::STRICT_FSQRT,
ISD::STRICT_FMA,
ISD::VECREDUCE_FMIN,
ISD::VECREDUCE_FMAX,
// ISD::VECREDUCE_FMINIMUM,
// ISD::VECREDUCE_FMAXIMUM
};

// Vector-predicated (VP) counterpart of ZvfhminZvfbfminPromoteOps: VP ops on
// f16/bf16 vectors that are promoted to the matching f32 vector type when
// only the minimal half-precision extensions are available.
// Commented-out entries are candidates not yet handled.
// TODO: support more vp ops.
static const unsigned ZvfhminZvfbfminPromoteVPOps[] = {
ISD::VP_FADD,
ISD::VP_FSUB,
ISD::VP_FMUL,
ISD::VP_FDIV,
ISD::VP_FMA,
ISD::VP_REDUCE_FMIN,
ISD::VP_REDUCE_FMAX,
ISD::VP_SQRT,
ISD::VP_FMINNUM,
ISD::VP_FMAXNUM,
ISD::VP_FCEIL,
ISD::VP_FFLOOR,
ISD::VP_FROUND,
ISD::VP_FROUNDEVEN,
ISD::VP_FROUNDTOZERO,
ISD::VP_FRINT,
ISD::VP_FNEARBYINT,
ISD::VP_SETCC,
// ISD::VP_FMINIMUM,
// ISD::VP_FMAXIMUM,
// ISD::VP_REDUCE_FMINIMUM,
// ISD::VP_REDUCE_FMAXIMUM
};

// Sets common operation actions on RVV floating-point vector types.
const auto SetCommonVFPActions = [&](MVT VT) {
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
Expand Down Expand Up @@ -803,6 +871,19 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
continue;
SetCommonVFPActions(VT);
}
} else if (Subtarget.hasVInstructionsF16Minimal()) {
for (MVT VT : F16VecVTs) {
if (!isTypeLegal(VT))
continue;
setOperationAction(ISD::FP_ROUND, VT, Custom);
setOperationAction(ISD::FP_EXTEND, VT, Custom);
setOperationAction(ISD::VP_FP_ROUND, VT, Custom);
setOperationAction(ISD::VP_FP_EXTEND, VT, Custom);
setOperationAction(ISD::LOAD, VT, Custom);
setOperationAction(ISD::STORE, VT, Custom);
setOperationAction(ISD::VP_LOAD, VT, Custom);
setOperationAction(ISD::VP_STORE, VT, Custom);
}
}

if (Subtarget.hasVInstructionsF32()) {
Expand Down Expand Up @@ -963,6 +1044,42 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setTruncStoreAction(VT, OtherVT, Expand);
}

if (VT.getVectorElementType() == MVT::f16 &&
!Subtarget.hasVInstructionsF16()) {
setOperationAction(ISD::BITCAST, VT, Custom);
setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom);
setOperationAction(
{ISD::VP_MERGE, ISD::VP_SELECT, ISD::VSELECT, ISD::SELECT}, VT,
Custom);
setOperationAction({ISD::VP_SINT_TO_FP, ISD::VP_UINT_TO_FP}, VT,
Custom);
if (Subtarget.hasStdExtZfhmin()) {
setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
} else {
// We need to custom legalize f16 build vectors if Zfhmin isn't
// available.
setOperationAction(ISD::BUILD_VECTOR, MVT::f16, Custom);
}
setOperationAction(ISD::FNEG, VT, Expand);
setOperationAction(ISD::FABS, VT, Expand);
setOperationAction(ISD::FCOPYSIGN, VT, Expand);
MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
// Don't promote f16 vector operations to f32 if f32 vector type is
// not legal.
// TODO: could split the f16 vector into two vectors and do promotion.
if (!isTypeLegal(F32VecVT))
continue;
// Custom split nxv32[b]f16 since nxv32[b]f32 is not legal.
if (getLMUL(VT) == RISCVII::LMUL_8) {
setOperationAction(ZvfhminZvfbfminPromoteOps, VT, Custom);
setOperationAction(ZvfhminZvfbfminPromoteVPOps, VT, Custom);
} else {
MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
setOperationPromotedToType(ZvfhminZvfbfminPromoteOps, VT, F32VecVT);
setOperationPromotedToType(ZvfhminZvfbfminPromoteVPOps, VT, F32VecVT);
}
}

// We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed.
// setOperationAction({ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR}, VT,
// Custom);
Expand Down Expand Up @@ -1797,7 +1914,7 @@ static bool useRVVForFixedLengthVectorVT(MVT VT,
return false;
break;
case MVT::f16:
if (!Subtarget.hasVInstructionsF16())
if (!Subtarget.hasVInstructionsF16Minimal())
return false;
break;
case MVT::f32:
Expand Down Expand Up @@ -9498,6 +9615,45 @@ static unsigned negateFMAOpcode(unsigned Opcode, bool NegMul, bool NegAcc) {
return Opcode;
}

// Fold FNEG_VL operands into a VL FMA node by switching to the FMA variant
// with the corresponding signs (via negateFMAOpcode). Returns the rebuilt
// node, or an empty SDValue if no operand was negated.
static SDValue combineVFMADD_VLWithVFNEG_VL(SDNode *N, SelectionDAG &DAG) {
  // Strict-FP nodes carry the chain as operand 0; real operands follow it.
  const bool IsStrict =
      DAG.getSelectionDAGInfo().isTargetStrictFPOpcode(N->getOpcode());
  const unsigned Base = IsStrict ? 1 : 0;
  SDValue MulLHS = N->getOperand(Base + 0);
  SDValue MulRHS = N->getOperand(Base + 1);
  SDValue Addend = N->getOperand(Base + 2);
  SDValue Mask = N->getOperand(Base + 3);
  SDValue VL = N->getOperand(Base + 4);

  // If Op is an FNEG_VL governed by the same mask and VL, strip the negation
  // in place and report success.
  auto stripNegation = [&Mask, &VL](SDValue &Op) -> bool {
    if (Op.getOpcode() != RISCVISD::FNEG_VL || Op.getOperand(1) != Mask ||
        Op.getOperand(2) != VL)
      return false;
    Op = Op.getOperand(0);
    return true;
  };

  const bool NegLHS = stripNegation(MulLHS);
  const bool NegRHS = stripNegation(MulRHS);
  const bool NegAdd = stripNegation(Addend);

  // Nothing stripped: leave the node alone.
  if (!NegLHS && !NegRHS && !NegAdd)
    return SDValue();

  // Two negated multiplicands cancel each other (XOR of the flags); the
  // addend's sign flips independently.
  const unsigned NewOpc =
      negateFMAOpcode(N->getOpcode(), NegLHS != NegRHS, NegAdd);
  SDLoc DL(N);
  if (IsStrict)
    return DAG.getNode(NewOpc, DL, N->getVTList(),
                       {N->getOperand(0), MulLHS, MulRHS, Addend, Mask, VL});
  return DAG.getNode(NewOpc, DL, N->getValueType(0), MulLHS, MulRHS, Addend,
                     Mask, VL);
}

static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
Expand Down Expand Up @@ -9668,6 +9824,39 @@ static SDValue tryDemorganOfBooleanCondition(SDValue Cond, SelectionDAG &DAG) {
return DAG.getNode(Opc, SDLoc(Cond), VT, Setcc, Xor.getOperand(0));
}

// DAG combine for VL FMA nodes: fold negated operands into the FMA opcode.
// Returns the combined node, or an empty SDValue when no combine applies.
//
// BUG FIX: the function previously fell off the end without a return when the
// f32/Zvfhmin guard below was not taken — undefined behavior for a non-void
// function. All paths now return an SDValue.
static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
                                       const RISCVSubtarget &Subtarget) {
  if (SDValue V = combineVFMADD_VLWithVFNEG_VL(N, DAG))
    return V;

  // Guard for a future f16->f32 widening combine: when the result is an f32
  // scalable vector and only minimal f16 vector support is present (Zvfhmin
  // without Zvfh), no further combine is attempted.
  if (N->getValueType(0).isScalableVector() &&
      N->getValueType(0).getVectorElementType() == MVT::f32 &&
      (Subtarget.hasVInstructionsF16Minimal() &&
       !Subtarget.hasVInstructionsF16()))
    return SDValue();

  // No combine performed.
  return SDValue();
}

// DAG combine for RISCVISD::FMUL_VL. Currently performs no rewrite; it only
// bails out early for f32 scalable vectors when just minimal f16 vector
// support (Zvfhmin without Zvfh) is available, reserving the slot for a
// future f16->f32 widening-multiply combine.
//
// BUG FIX: the function previously returned nothing on the fall-through
// path — undefined behavior for a non-void function. All paths now return.
static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG,
                                      const RISCVSubtarget &Subtarget) {
  if (N->getValueType(0).isScalableVector() &&
      N->getValueType(0).getVectorElementType() == MVT::f32 &&
      (Subtarget.hasVInstructionsF16Minimal() &&
       !Subtarget.hasVInstructionsF16()))
    return SDValue();

  // No combine performed.
  return SDValue();
}

// DAG combine for RISCVISD::FADD_VL / FSUB_VL. Currently performs no
// rewrite; it only bails out early for f32 scalable vectors when just
// minimal f16 vector support (Zvfhmin without Zvfh) is available, reserving
// the slot for a future f16->f32 widening add/sub combine.
//
// BUG FIX: the function previously returned nothing on the fall-through
// path — undefined behavior for a non-void function. All paths now return.
static SDValue performFADDSUB_VLCombine(SDNode *N, SelectionDAG &DAG,
                                        const RISCVSubtarget &Subtarget) {
  if (N->getValueType(0).isScalableVector() &&
      N->getValueType(0).getVectorElementType() == MVT::f32 &&
      (Subtarget.hasVInstructionsF16Minimal() &&
       !Subtarget.hasVInstructionsF16()))
    return SDValue();

  // No combine performed.
  return SDValue();
}

// Perform common combines for BR_CC and SELECT_CC condtions.
static bool combine_CC(SDValue &LHS, SDValue &RHS, SDValue &CC, const SDLoc &DL,
SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
Expand Down Expand Up @@ -10239,6 +10428,12 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), A, B, C, Mask,
VL);
}
case RISCVISD::STRICT_VFNMSUB_VL:
return performVFMADD_VLCombine(N, DAG, Subtarget);
case RISCVISD::FMUL_VL:
return performVFMUL_VLCombine(N, DAG, Subtarget);
case RISCVISD::FSUB_VL:
return performFADDSUB_VLCombine(N, DAG, Subtarget);
case ISD::STORE: {
auto *Store = cast<StoreSDNode>(N);
SDValue Val = Store->getValue();
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ enum NodeType : unsigned {
// result being sign extended to 64 bit. These saturate out of range inputs.
STRICT_FCVT_W_RV64 = ISD::FIRST_TARGET_STRICTFP_OPCODE,
STRICT_FCVT_WU_RV64,
STRICT_VFNMSUB_VL,

// WARNING: Do not add anything in the end unless you want the node to
// have memop! In fact, starting from FIRST_TARGET_MEMORY_OPCODE all
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/RISCV/RISCVSubtarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class RISCVSubtarget : public RISCVGenSubtargetInfo {
bool HasStdExtZve64f = false;
bool HasStdExtZve64d = false;
bool HasStdExtZvfh = false;
bool HasStdExtZvfhmin = false;
bool HasStdExtZfhmin = false;
bool HasStdExtZfh = false;
bool HasStdExtZfinx = false;
Expand Down Expand Up @@ -273,6 +274,9 @@ class RISCVSubtarget : public RISCVGenSubtargetInfo {
// Vector codegen related methods.
bool hasVInstructions() const { return HasStdExtZve32x; }
bool hasVInstructionsI64() const { return HasStdExtZve64x; }
// True when any vector f16 support is available: either the full Zvfh
// extension or the minimal Zvfhmin subset (loads/stores/conversions only).
bool hasVInstructionsF16Minimal() const {
return HasStdExtZvfhmin || HasStdExtZvfh;
}
bool hasVInstructionsF16() const { return HasStdExtZvfh && HasStdExtZfh; }
bool hasVInstructionsF32() const {
return HasStdExtZve32f && (HasStdExtF || HasStdExtZfinx);
Expand Down
20 changes: 20 additions & 0 deletions llvm/lib/Target/RISCV/VentusInstrInfoV.td
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,7 @@ def VLHI12 : VENTUS_VLI12<0b001, "vlh12">;
def VLBI12 : VENTUS_VLI12<0b000, "vlb12">;
def VLHUI12 : VENTUS_VLI12<0b101, "vlhu12">;
def VLBUI12 : VENTUS_VLI12<0b100, "vlbu12">;
def VLE16_V : VENTUS_VLI12<0b110, "vle16">;
}

let hasSideEffects = 0, mayLoad = 0, mayStore = 1 in {
Expand All @@ -805,6 +806,7 @@ def VSB : VENTUS_VS<0b000, "vsb">;
def VSWI12 : VENTUS_VSI12<0b110, "vsw12">;
def VSHI12 : VENTUS_VSI12<0b011, "vsh12">;
def VSBI12 : VENTUS_VSI12<0b111, "vsb12">;
def VSE16_V : VENTUS_VSI12<0b110, "vse16">;
}

let Predicates = [HasVInstructions] in {
Expand Down Expand Up @@ -1508,3 +1510,21 @@ let Predicates = [HasStdExtZfinx] in {
def : Pat<(f32 (bitconvert (i32 GPR:$src))), (VMV_V_X GPR:$src)>;
// def : Pat<(i32 (bitconvert GPRF32:$src)), (VFMV_V_F GPRF32:$src)>;
} // Predicates = [HasStdExtZfinx]

//===----------------------------------------------------------------------===//
// zvfhmin: half-precision vector load/store and conversion patterns
//===----------------------------------------------------------------------===//
// BUG FIX: the conversion patterns previously annotated the result-side
// operand with a different register class than the matched source operand of
// the same name (e.g. matching VReg_128:$vs2 but emitting VReg_256:$vs2).
// A named operand must use one consistent class; the source operand's class
// is the one that matters, so each result now reuses it.
multiclass ZvfhminPatterns {
  // 8 lanes: v8f16 occupies a 128-bit register, v8f32 a 256-bit one.
  def : Pat<(v8f16 (load VReg_128:$rs1)), (VLE16_V VReg_128:$rs1, 0)>;
  def : Pat<(store (v8f16 VReg_128:$vs3), VReg_128:$rs1),
            (VSE16_V VReg_128:$vs3, VReg_128:$rs1, 0)>;
  def : Pat<(v8f32 (fpextend (v8f16 VReg_128:$vs2))),
            (VFWCVT_F_F_V VReg_128:$vs2)>;
  def : Pat<(v8f16 (fpround (v8f32 VReg_256:$vs2))),
            (VFNCVT_F_F_W VReg_256:$vs2)>;

  // 16 lanes: v16f16 occupies a 256-bit register, v16f32 a 512-bit one.
  def : Pat<(v16f16 (load VReg_256:$rs1)), (VLE16_V VReg_256:$rs1, 0)>;
  def : Pat<(store (v16f16 VReg_256:$vs3), VReg_256:$rs1),
            (VSE16_V VReg_256:$vs3, VReg_256:$rs1, 0)>;
  def : Pat<(v16f32 (fpextend (v16f16 VReg_256:$vs2))),
            (VFWCVT_F_F_V VReg_256:$vs2)>;
  def : Pat<(v16f16 (fpround (v16f32 VReg_512:$vs2))),
            (VFNCVT_F_F_W VReg_512:$vs2)>;
}
defm : ZvfhminPatterns;
Loading