!
!     Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
!
! NVIDIA CORPORATION and its licensors retain all intellectual property
! and proprietary rights in and to this software, related documentation
! and any modifications thereto.  Any use, reproduction, disclosure or
! distribution of this software and related documentation without an express
! license agreement from NVIDIA CORPORATION is strictly prohibited.
!

!> Smoke test for 2-D cuFFT transforms driven from an OpenMP program.
!>
!> Runs a C2C forward/inverse round trip and an R2C/C2R round trip on
!> all-ones data, with every plan bound to the CUDA stream returned by
!> ompx_get_cuda_stream for the default OpenMP device. Prints PASSED only
!> if every cuFFT/CUDA call returned success (accumulated status == 0)
!> and every error metric is below tol.
!>
!> NOTE(review): the host allocatable arrays are handed straight to cuFFT
!> with no OpenMP target data mapping; this presumably relies on a
!> unified/managed-memory build (e.g. -gpu=mem:unified) - TODO confirm.
program cufft2dompTest
  use cufft
  use cudafor ! for cudaDeviceSynchronize
  use omp_lib

  implicit none
  integer, parameter :: m=768, n=512
  ! Pass/fail threshold applied to every error metric in res.
  real, parameter :: tol = 1.e-6
  complex, allocatable :: a(:,:),b(:,:),c(:,:)
  real   , allocatable :: r(:,:),q(:,:)
  integer :: iplan1, iplan2, iplan3, ierr
  real, dimension(4) :: res

  allocate( a(m,n), b(m,n), c(m,n) )
  allocate( r(m,n), q(m,n) )

  a = 1; r = 1

  ! cuFFT takes plan sizes in C (row-major) order: the first size is the
  ! slowest-varying dimension. For a column-major Fortran array x(m,n)
  ! that is n, so the sizes are passed as (n, m).
  ierr = cufftPlan2D(iplan1,n,m,CUFFT_C2C)
  ierr = ierr + cufftSetStream(iplan1,ompx_get_cuda_stream(omp_get_default_device(), .false.))

  ierr = ierr + cufftExecC2C(iplan1,a,b,CUFFT_FORWARD)
  ierr = ierr + cufftExecC2C(iplan1,b,c,CUFFT_INVERSE)

  ! Block until the queued cuFFT work has finished before reading results.
  ierr = ierr + cudaDeviceSynchronize()

  ! For all-ones input the unnormalized forward transform is m*n in the
  ! DC bin and zero elsewhere, so both quantities below should vanish.
  ! abs() makes sure a large negative error cannot slip under tol.
  res(1) = abs( maxval( real(b) ) - sum( real(b) ) )
  res(2) = maxval( abs( aimag(b) ) )
  ! cuFFT does not normalize: forward + inverse scales the data by m*n.
  res(3) = maxval( abs( a - c / (m*n) ) )

  ! Check forward answer
  write(*,*) 'Max error C2C FWD: ', cmplx( res(1), res(2) )

  ! Check inverse answer
  write(*,*) 'Max error C2C INV: ', res(3)

  ! Real transform. R2C writes only the Hermitian half-spectrum,
  ! (m/2+1) x n complex values packed contiguously, which fits inside b;
  ! C2R reads that same packed layout back into q.
  ierr = ierr + cufftPlan2D(iplan2,n,m,CUFFT_R2C)
  ierr = ierr + cufftPlan2D(iplan3,n,m,CUFFT_C2R)
  ierr = ierr + cufftSetStream(iplan2,ompx_get_cuda_stream(omp_get_default_device(), .false.))
  ierr = ierr + cufftSetStream(iplan3,ompx_get_cuda_stream(omp_get_default_device(), .false.))

  ierr = ierr + cufftExecR2C(iplan2,r,b)
  ierr = ierr + cufftExecC2R(iplan3,b,q)

  ierr = ierr + cudaDeviceSynchronize()

  ! Round trip is again unnormalized, so divide by m*n before comparing.
  res(4) = maxval( abs( r - q / (m*n) ) )

  ! Check R2C + C2R answer
  write(*,*) 'Max error R2C/C2R: ', res(4)

  ierr = ierr + cufftDestroy(iplan1)
  ierr = ierr + cufftDestroy(iplan2)
  ierr = ierr + cufftDestroy(iplan3)

  deallocate( a, b, c )
  deallocate( r, q )

  ! ierr is the sum of every status code returned above; 0 means all succeeded.
  if (ierr.eq.0 .and. maxval( res ) .lt. tol) then
    print *,"Test PASSED"
  else
    print *,"Test FAILED, max error =", maxval( res ), " ierr =", ierr
  endif

end program cufft2dompTest
