; ---------------------------------------------------------------------------
; cuDNN numeric constants, mirrored by hand from the cuDNN C headers.
; ---------------------------------------------------------------------------

; status code meaning "no error"
CUDNN_STATUS_SUCCESS                             = 0

; maximum supported number of tensor dimensions
CUDNN_DIM_MAX                                    = 8

; softmax algorithm
CUDNN_SOFTMAX_FAST                               = 0  ; straightforward implementation
CUDNN_SOFTMAX_ACCURATE                           = 1  ; subtract max from every point to avoid overflow
CUDNN_SOFTMAX_LOG                                = 2

; softmax mode
CUDNN_SOFTMAX_MODE_INSTANCE                      = 0  ; compute the softmax over all C, H, W for each N
CUDNN_SOFTMAX_MODE_CHANNEL                       = 1  ; compute the softmax over all C for each H, W, N

; pooling mode
CUDNN_POOLING_MAX                                = 0
CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING      = 1  ; count for average includes padded values
CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING      = 2  ; count for average does not include padded values
CUDNN_POOLING_MAX_DETERMINISTIC                  = 3  ; author's note: about 50% slower than POOLING_MAX

; NaN propagation
CUDNN_NOT_PROPAGATE_NAN                          = 0
CUDNN_PROPAGATE_NAN                              = 1

; tensor memory layout
CUDNN_TENSOR_NCHW                                = 0
CUDNN_TENSOR_NHWC                                = 1

; convolution mode
CUDNN_CONVOLUTION                                = 0
CUDNN_CROSS_CORRELATION                          = 1

; forward-algorithm selection preference (legacy heuristic API)
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST             = 1

; forward convolution algorithms
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM         = 0
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM = 1
CUDNN_CONVOLUTION_FWD_ALGO_GEMM                  = 2
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT                = 3
CUDNN_CONVOLUTION_FWD_ALGO_FFT                   = 4
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING            = 5
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD              = 6
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED     = 7
CUDNN_CONVOLUTION_FWD_ALGO_COUNT                 = 8  ; number of entries above

; activation functions
CUDNN_ACTIVATION_SIGMOID                         = 0
CUDNN_ACTIVATION_RELU                            = 1
CUDNN_ACTIVATION_TANH                            = 2
CUDNN_ACTIVATION_CLIPPED_RELU                    = 3
CUDNN_ACTIVATION_ELU                             = 4
CUDNN_ACTIVATION_IDENTITY                        = 5
CUDNN_ACTIVATION_SWISH                           = 6
; CUDNN_DTYPE is not a member of the cuDNN data-type enum below.
; NOTE(review): presumably the data type this project selects
; (0 has the same value as CUDNN_DATA_FLOAT) - confirm against its users.
CUDNN_DTYPE                   = 0

; tensor element data types; values 5, 7 and 8 are deprecated in cuDNN and
; intentionally left undefined here
CUDNN_DATA_FLOAT              = 0
CUDNN_DATA_DOUBLE             = 1
CUDNN_DATA_HALF               = 2
CUDNN_DATA_INT8               = 3
CUDNN_DATA_INT32              = 4
;CUDNN_DATA_INT8x4            = 5   (CUDNN_DEPRECATED_ENUM)
CUDNN_DATA_UINT8              = 6
;CUDNN_DATA_UINT8x4           = 7   (CUDNN_DEPRECATED_ENUM)
;CUDNN_DATA_INT8x32           = 8   (CUDNN_DEPRECATED_ENUM)
CUDNN_DATA_BFLOAT16           = 9
CUDNN_DATA_INT64              = 10
CUDNN_DATA_BOOLEAN            = 11
CUDNN_DATA_FP8_E4M3           = 12
CUDNN_DATA_FP8_E5M2           = 13
CUDNN_DATA_FAST_FLOAT_FOR_FP8 = 14
CUDNN_DATA_FP8_E8M0           = 15
CUDNN_DATA_FP4_E2M1           = 16

; math-type selectors, for cudnnSetConvolutionMathType & cudnnSetRNNMatrixMathType
CUDNN_DEFAULT_MATH                    = 0
CUDNN_TENSOR_OP_MATH                  = 1
CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION = 2
CUDNN_FMA_MATH                        = 3

; copy-direction codes passed as the last argument of cudaMemcpy
cudaMemcpyHostToHost     = 0
cudaMemcpyHostToDevice   = 1
cudaMemcpyDeviceToHost   = 2
cudaMemcpyDeviceToDevice = 3
cudaMemcpyDefault        = 4

; GPU2mem hostDst, devVar, nbytes
;   Copies nbytes from the device pointer stored in variable devVar to hostDst
;   with cudaMemcpy(..., cudaMemcpyDeviceToHost).
;   The cudaMemcpy return value is not examined by this macro.
macro GPU2mem hostDst,devVar,nbytes {
        invoke  cudaMemcpy, hostDst, [devVar], nbytes, cudaMemcpyDeviceToHost
}

; NewGPUData devVar, nbytes, hostSrc
;   Allocates nbytes with cudaMalloc, passing devVar as the location that
;   receives the device pointer. If cudaMalloc returns non-zero, an error
;   message naming devVar is shown through Msg and the upload is skipped.
;   Otherwise nbytes are copied from hostSrc to the new buffer with
;   cudaMemcpy(..., cudaMemcpyHostToDevice); that call's return value is not
;   examined by this macro.
macro NewGPUData devVar,nbytes,hostSrc {
  local .allocFailed,.finished
        invoke  cudaMalloc, devVar, nbytes
        test    eax,eax                         ; zero = allocation succeeded
        jnz     .allocFailed
        invoke  cudaMemcpy, [devVar], hostSrc, nbytes, cudaMemcpyHostToDevice
        jmp     .finished
.allocFailed:
        Msg     'error cudaMalloc '#`devVar
.finished:
}

; mset var1 = value1, var2 = value2, ...
;   For every argument of the form "name = value" emits  mov [name],value.
;   Arguments that do not contain "=" produce no code.
macro mset [ar] {
  forward
        match v == n , ar \{
                mov     [v],n
        \}
}

; prInt a
;   Debug helper: formats operand a as an unsigned decimal ('%u') into the
;   buffer Temp with sprintf, then shows Temp in a MessageBox whose caption
;   is the source text of the operand followed by ' ='.
;     a - a numeric expression, a register, or a memory operand written as
;         [name]; it is passed to sprintf as "dword a".
;   The box is only shown for the operand forms tested below; sprintf is
;   called unconditionally.
macro prInt a { invoke  sprintf,Temp,'%u',dword a
        ; numeric expression or register: caption is the operand text itself
        if a eqtype 0 | a eqtype eax
        invoke  MessageBox,0,Temp,`a#' =',0
        end if
        ; memory operand: drop the surrounding brackets so the caption shows
        ; only what is written between them
        if a eqtype [0]
        match e c=], a \{ ; e takes the opening bracket, c the text before the closing bracket
        ; (disabled experiment kept from the original source)
        ;match c,c \\{ fguu equ `\c
        ;              display fguu \\}
        invoke  MessageBox,0,Temp,\`c\#' =',0  \}
        end if
} 

; GetImgSize imagesizbytes, a, b, c, d [, elemsize]
;   Computes  elemsize * [a] * [b] * [c] * [d]  and stores the result in
;   [imagesizbytes].
;     imagesizbytes - variable that receives the byte count (written from eax)
;     a, b, c, d    - variables holding the four dimension values
;     elemsize      - bytes per element; optional, defaults to 4 so existing
;                     five-argument callers assemble to the same code as
;                     before. Pass 8, 2, 1 ... for other element widths.
;   Clobbers eax, and edx as well because MUL writes the high half of each
;   product there (assuming dword operands). Only the low 32 bits are kept;
;   overflow is not detected.
macro GetImgSize imagesizbytes,a,b,c,d,elemsize=4 {
        mov     eax,elemsize
        mul     [a]
        mul     [b]
        mul     [c]
        mul     [d]
        mov     [imagesizbytes],eax
}

; cudaBgnTimer
;   Registers the qword variables EvntStop and EvntStart through miRpq,
;   creates one CUDA event in each of them and records EvntStart with 0 as
;   the stream argument.
;   None of the three return values is examined by this macro.
macro cudaBgnTimer {
        miRpq   EvntStop, EvntStart
        invoke  cudaEventCreate, EvntStop
        invoke  cudaEventCreate, EvntStart
        invoke  cudaEventRecord, [EvntStart], 0
}

; ifcuError [tt]
;   Place directly after an API call. When eax is non-zero, GetcuError is
;   called; when eax is zero nothing happens.
;     tt - optional value (default 0). If it is greater than 0 it is stored
;          into dword [Message] before GetcuError is called.
;   The skip target is a macro-local, dot-prefixed label rather than an
;   anonymous @@ label: an @@ inside a macro silently becomes the target of
;   any @f / @b written around the macro invocation in the calling code.
;   The dot prefix keeps the label inside the current global label's scope,
;   so the caller's own local labels are not disturbed either.
macro ifcuError tt=0 {
  local .noError
        test    eax,eax
        jz      .noError
        if tt > 0
                mov     dword [Message],tt
        end if
        call    GetcuError
.noError:
}
; miRpq name1, name2, ...
;   For each name, pushes the text "name dq 0" onto the symbolic constant
;   eq_miRpq. No data is emitted here.
;   NOTE(review): presumably a data section elsewhere expands eq_miRpq into
;   the actual declarations - confirm.
macro miRpq [ar] {
  forward
        eq_miRpq equ ar dq 0
}
; miRpd name1, name2, ...
;   For each name, pushes the text "name dd 0" onto the symbolic constant
;   eq_miRpd. No data is emitted here.
;   NOTE(review): presumably a data section elsewhere expands eq_miRpd into
;   the actual declarations - confirm.
macro miRpd [ar] {
  forward
        eq_miRpd equ ar dd 0
}

; SetTensor4 desc, format, dtype, n, c, h, w
;   Calls cudnnSetTensor4dDescriptor on the descriptor handle stored in
;   variable desc, forwarding the remaining six arguments unchanged and in
;   order, then runs ifcuError on the returned status.
;   Parameter names follow the cuDNN prototype (format, dataType, n, c, h, w).
macro SetTensor4 desc,format,dtype,n,c,h,w {
        invoke  cudnnSetTensor4dDescriptor, [desc], format, dtype, n, c, h, w
        ifcuError
}
; SetFilter4 desc, dtype, format, k, c, h, w
;   Calls cudnnSetFilter4dDescriptor on the descriptor handle stored in
;   variable desc, forwarding the remaining six arguments unchanged and in
;   order, then runs ifcuError on the returned status.
;   Parameter names follow the cuDNN prototype (dataType, format, k, c, h, w);
;   note that data type precedes format here, unlike the tensor descriptor.
macro SetFilter4 desc,dtype,format,k,c,h,w {
        invoke  cudnnSetFilter4dDescriptor, [desc], dtype, format, k, c, h, w
        ifcuError
}
; SetPooling2D desc, mode, nanOpt, winH, winW, padH, padW, strideH, strideW
;   Calls cudnnSetPooling2dDescriptor on the descriptor handle stored in
;   variable desc, forwarding the remaining eight arguments unchanged and in
;   order, then runs ifcuError on the returned status.
;   Parameter names follow the cuDNN prototype (mode, NaN option, window
;   height/width, vertical/horizontal padding, vertical/horizontal stride).
macro SetPooling2D desc,mode,nanOpt,winH,winW,padH,padW,strideH,strideW {
        invoke  cudnnSetPooling2dDescriptor, [desc], mode, nanOpt, winH, winW, padH, padW, strideH, strideW
        ifcuError
}
; SetConvolut2d desc, padH, padW, strideH, strideW, dilH, dilW, mode, computeType
;   Calls cudnnSetConvolution2dDescriptor on the descriptor handle stored in
;   variable desc, forwarding the remaining eight arguments unchanged and in
;   order, then runs ifcuError on the returned status.
;   Parameter names follow the cuDNN prototype (pad_h, pad_w, u, v,
;   dilation_h, dilation_w, mode, computeType).
;   Author's note on the dilH, dilW pair: "upscaleXY error if 2,2".
macro SetConvolut2d desc,padH,padW,strideH,strideW,dilH,dilW,mode,computeType {
        invoke  cudnnSetConvolution2dDescriptor, [desc], padH, padW, strideH, strideW, dilH, dilW, mode, computeType
        ifcuError
}


; @cudnn kind, name1, name2, ...
;   For every name:
;     1. pushes "name dq 0" onto the symbolic constant eq_miRpq (the same
;        registration miRpq performs),
;     2. calls the cuDNN create function selected by kind, passing name as
;        the location that receives the new descriptor handle,
;     3. runs ifcuError on the returned status.
;   kind must be one of:  Pooling  Activat  Tensor  Convolut  Filter
;   Any other kind stops assembly with a diagnostic. Previously such a kind
;   emitted no create call at all while ifcuError still tested whatever value
;   happened to be in eax, so a misspelled kind went unnoticed until runtime.
macro @cudnn wt,[ar] {
  forward
        eq_miRpq equ ar dq 0
        if wt eq Pooling
                invoke  cudnnCreatePoolingDescriptor,ar
        else if wt eq Activat
                invoke  cudnnCreateActivationDescriptor,ar
        else if wt eq Tensor
                invoke  cudnnCreateTensorDescriptor,ar
        else if wt eq Convolut
                invoke  cudnnCreateConvolutionDescriptor,ar
        else if wt eq Filter
                invoke  cudnnCreateFilterDescriptor,ar
        else
                display '@cudnn: unknown descriptor kind (expected Pooling, Activat, Tensor, Convolut or Filter)',13,10
                err
        end if
        ifcuError
}
