// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/shared_library/provider_api.h"
#include "core/providers/rocm/controlflow/scan.h"

#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/tensor/transpose.h"
#include "core/framework/ort_value.h"

// TODO: It's ugly to include a .cc file but this .cc file defines the implementation of some templates which we need.
#include "core/framework/ort_value_tensor_slicer.cc"

using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;

namespace onnxruntime {
namespace rocm {

template <>
Scan<8>::Scan(const OpKernelInfo& info) : onnxruntime::Scan<8>(info) {
  // Zero-fill callback handed to the shared Scan<8> logic. It clears `num_bytes` bytes starting at `buffer`
  // using hipMemset, so `buffer` is expected to be memory HIP can address (NOTE(review): presumably a GPU
  // allocation made by this ROCm kernel — confirm). HIP failures are converted into a non-OK Status.
  auto zero_fill = [](void* buffer, size_t num_bytes) -> Status {
    HIP_RETURN_IF_ERROR(hipMemset(buffer, 0, num_bytes));
    return Status::OK();
  };

  scan::detail::DeviceHelpers device_helpers;
  device_helpers.set_data_to_zero_func = zero_fill;

  // The base class stores its own copy of the helpers.
  SetDeviceHelpers(device_helpers);
}

template <>
Scan<9>::Scan(const OpKernelInfo& info) : onnxruntime::Scan<9>(info) {
  scan::detail::DeviceHelpers device_helpers;

  // Transpose callback used by the shared Scan<9> logic. It delegates to the ROCm Transpose kernel.
  // It captures `this` to reach this kernel's OpKernelInfo, so it must not be invoked after this kernel
  // instance has been destroyed.
  device_helpers.transpose_func = [this](const gsl::span<const size_t>& perm, const Tensor& src, Tensor& dst,
                                         Stream* stream) {
    // TODO: A fresh Transpose kernel is built on every call because construction is fairly lightweight.
    // Caching a single instance is an option if this turns out not to be performant enough.
    const OpKernelInfo& kernel_info = OpKernel::Info();
    rocm::Transpose transposer(kernel_info);
    return rocm::Transpose::DoTranspose(transposer, stream, perm, src, dst);
  };

  // The base class stores its own copy of the helpers.
  SetDeviceHelpers(device_helpers);
}

template <>
Status Scan<8>::Compute(OpKernelContext* ctx) const {
  // Forward directly to the base (CPU) implementation. The logic that drives the subgraph runs on the CPU
  // either way; this ROCm registration exists so that inputs/outputs can stay on the GPU where possible.
  // The explicit override is not strictly required, but it makes it easy to confirm while debugging that
  // this ROCm kernel is the one being invoked.
  return onnxruntime::Scan<8>::Compute(ctx);
}

template <>
Status Scan<9>::Compute(OpKernelContext* ctx) const {
  // Forward directly to the base (CPU) implementation. The logic that drives the subgraph runs on the CPU
  // either way; this ROCm registration exists so that inputs/outputs can stay on the GPU where possible.
  // The explicit override is not strictly required, but it makes it easy to confirm while debugging that
  // this ROCm kernel is the one being invoked.
  return onnxruntime::Scan<9>::Compute(ctx);
}

// Scan opset 8 (versioned range 8-8), implemented by Scan<8>.
// Input 0 ('sequence_lens') must be available on CPU, so it is requested as CPU input memory.
// 'V' is constrained to AllTensorTypes() here, while the opset 9+ registrations use fixed-size types only.
// NOTE(review): confirm that non-fixed-size tensor types actually work on this ROCm path.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan,
                                  kOnnxDomain,
                                  8, 8,
                                  kRocmExecutionProvider,
                                  (*KernelDefBuilder::Create())
                                      .InputMemoryType(OrtMemTypeCPUInput, 0)  // 'sequence_lens' needs to be on CPU
                                      .TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>())
                                      .TypeConstraint("V", DataTypeImpl::AllTensorTypes()),
                                  Scan<8>);

// Scan opsets 9-10, implemented by Scan<9>.
// Unlike the opset-8 registration, no input is pinned to CPU memory and 'V' is limited to fixed-size types.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan,
                                  kOnnxDomain,
                                  9, 10,
                                  kRocmExecutionProvider,
                                  (*KernelDefBuilder::Create())
                                      // 'I' is in the ONNX spec but is not used for any inputs or outputs
                                      // .TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>())
                                      .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
                                  Scan<9>);

// Scan opsets 11-15. Opset 11 adds support for negative axis values; the same Scan<9> implementation
// is registered for this range, with the same fixed-size 'V' constraint as opsets 9-10.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan,
                                  kOnnxDomain,
                                  11,
                                  15,
                                  kRocmExecutionProvider,
                                  (*KernelDefBuilder::Create())
                                      // 'I' is in the ONNX spec but is not used for any inputs or outputs
                                      // .TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>())
                                      .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
                                  Scan<9>);

// Scan opsets 16-18. Opset 16 adds BFloat16 to the type constraint "V".
// This registration uses the same AllFixedSizeTensorTypes() list as opsets 9-15
// (NOTE(review): this relies on that list already containing BFloat16 — confirm).
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan,
                                  kOnnxDomain,
                                  16, 18,
                                  kRocmExecutionProvider,
                                  (*KernelDefBuilder::Create())
                                      // 'I' is in the ONNX spec but is not used for any inputs or outputs
                                      // .TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>())
                                      .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
                                  Scan<9>);

// Scan opset 19 and later. Opset 19 adds the float8 types to the type constraint "V", so this registration
// switches to AllFixedSizeTensorTypesIRv9(). It is registered without an upper version bound
// (ONNX_OPERATOR_KERNEL_EX) and still uses the Scan<9> implementation.
ONNX_OPERATOR_KERNEL_EX(Scan,
                        kOnnxDomain,
                        19,
                        kRocmExecutionProvider,
                        (*KernelDefBuilder::Create())
                            // 'I' is in the ONNX spec but is not used for any inputs or outputs
                            // .TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>())
                            .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypesIRv9()),
                        Scan<9>);

}  // namespace rocm
}  // namespace onnxruntime
