#include "PatternTritonGPUOpToLLVM.h"
#include "Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

namespace {

using namespace mlir;
using namespace mlir::triton;

// Conversion pattern for triton::GetNumProgramsOp: the op is replaced by the
// value of a special register read through LLVM::NVIDIA::getSRegValue.
struct GetNumProgramsOpConversion
    : public ConvertOpToLLVMPattern<triton::GetNumProgramsOp> {
  using ConvertOpToLLVMPattern<
      triton::GetNumProgramsOp>::ConvertOpToLLVMPattern;

  // Replaces `op` with the value of the special register "nctaid.<axis>" or
  // "nclusterid.<axis>", where <axis> is 'x', 'y' or 'z'. Always returns
  // success(). `adaptor` is not used.
  LogicalResult
  matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // It is not easy to get the compute capability here, so numCTAs is used to
    // decide the semantics of GetNumProgramsOp: if numCTAs == 1 the op is
    // converted to "%nctaid", otherwise it is converted to "%nclusterid".
    auto moduleOp = op->getParentOfType<ModuleOp>();
    assert(moduleOp && "Parent ModuleOp not found for GetNumProgramsOp");
    int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);

    Location loc = op->getLoc();
    const auto axis = op.getAxisAsInt();
    assert(axis < 3 && "GetNumProgramsOp axis must be 0, 1 or 2");

    std::string sreg = numCTAs == 1 ? "nctaid." : "nclusterid.";
    // 'x', 'y' and 'z' are consecutive characters: 0 -> 'x', 1 -> 'y',
    // 2 -> 'z'.
    sreg.push_back(static_cast<char>('x' + axis));

    Value numPrograms = LLVM::NVIDIA::getSRegValue(rewriter, loc, sreg);
    rewriter.replaceOp(op, numPrograms);
    return success();
  }
};

// Conversion pattern that rewrites triton::nvidia_gpu::GetCanonicalWarpIdOp
// into a triton::nvgpu::CanonicalWarpIdOp created with an i32 result type.
struct GetCanonicalWarpIdConversion
    : public ConvertOpToLLVMPattern<triton::nvidia_gpu::GetCanonicalWarpIdOp> {
  using ConvertOpToLLVMPattern<
      triton::nvidia_gpu::GetCanonicalWarpIdOp>::ConvertOpToLLVMPattern;

  // Creates the replacement op at the location of `op`, substitutes it for
  // `op`, and always returns success(). The adaptor is not used.
  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::GetCanonicalWarpIdOp op,
                  OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    using LoweredOp = triton::nvgpu::CanonicalWarpIdOp;

    const Location loc = op->getLoc();
    auto i32Ty = rewriter.getI32Type();

    LoweredOp replacement = rewriter.create<LoweredOp>(loc, i32Ty);
    rewriter.replaceOp(op, replacement);
    return success();
  }
};
} // namespace

// Adds the conversion patterns defined in this file
// (GetNumProgramsOpConversion and GetCanonicalWarpIdConversion) to
// `patterns`, each constructed with `typeConverter` and `benefit`.
void mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  // The variadic form of add<> builds every listed pattern with the same
  // constructor arguments, in the order the types are listed.
  patterns.add<GetNumProgramsOpConversion, GetCanonicalWarpIdConversion>(
      typeConverter, benefit);
}
