Mercurial > hg > Members > yuuhi > OpenCL
diff fft_Example/fft_kernelstring.cc @ 7:ea2e7ce9d5bb
add sample.pgm
author | Yuhi TOMARI <yuhi@cr.ie.u-ryukyu.ac.jp> |
---|---|
date | Tue, 05 Feb 2013 15:19:02 +0900 |
parents | 3602b23914ad |
children |
line wrap: on
line diff
--- a/fft_Example/fft_kernelstring.cc Tue Feb 05 15:12:19 2013 +0900 +++ b/fft_Example/fft_kernelstring.cc Tue Feb 05 15:19:02 2013 +0900 @@ -61,131 +61,132 @@ #define max(A,B) ((A) > (B) ? (A) : (B)) #define min(A,B) ((A) < (B) ? (A) : (B)) -static string +static string num2str(int num) { - char temp[200]; - sprintf(temp, "%d", num); - return string(temp); + char temp[200]; + sprintf(temp, "%d", num); + return string(temp); } -// For any n, this function decomposes n into factors for loacal memory tranpose +// For any n, this function decomposes n into factors for loacal memory tranpose // based fft. Factors (radices) are sorted such that the first one (radixArray[0]) // is the largest. This base radix determines the number of registers used by each // work item and product of remaining radices determine the size of work group needed. // To make things concrete with and example, suppose n = 1024. It is decomposed into -// 1024 = 16 x 16 x 4. Hence kernel uses float2 a[16], for local in-register fft and +// 1024 = 16 x 16 x 4. Hence kernel uses float2 a[16], for local in-register fft and // needs 16 x 4 = 64 work items per work group. So kernel first performance 64 length -// 16 ffts (64 work items working in parallel) following by transpose using local +// 16 ffts (64 work items working in parallel) following by transpose using local // memory followed by again 64 length 16 ffts followed by transpose using local memory -// followed by 256 length 4 ffts. For the last step since with size of work group is +// followed by 256 length 4 ffts. For the last step since with size of work group is // 64 and each work item can array for 16 values, 64 work items can compute 256 length -// 4 ffts by each work item computing 4 length 4 ffts. +// 4 ffts by each work item computing 4 length 4 ffts. // Similarly for n = 2048 = 8 x 8 x 8 x 4, each work group has 8 x 8 x 4 = 256 work // iterms which each computes 256 (in-parallel) length 8 ffts in-register, followed // by transpose using local memory, followed by 256 length 8 in-register ffts, followed // by transpose using local memory, followed by 256 length 8 in-register ffts, followed // by transpose using local memory, followed by 512 length 4 in-register ffts. Again, // for the last step, each work item computes two length 4 in-register ffts and thus -// 256 work items are needed to compute all 512 ffts. -// For n = 32 = 8 x 4, 4 work items first compute 4 in-register +// 256 work items are needed to compute all 512 ffts. +// For n = 32 = 8 x 4, 4 work items first compute 4 in-register // lenth 8 ffts, followed by transpose using local memory followed by 8 in-register // length 4 ffts, where each work item computes two length 4 ffts thus 4 work items -// can compute 8 length 4 ffts. However if work group size of say 64 is choosen, -// each work group can compute 64/ 4 = 16 size 32 ffts (batched transform). +// can compute 8 length 4 ffts. However if work group size of say 64 is choosen, +// each work group can compute 64/ 4 = 16 size 32 ffts (batched transform). // Users can play with these parameters to figure what gives best performance on // their particular device i.e. some device have less register space thus using -// smaller base radix can avoid spilling ... some has small local memory thus +// smaller base radix can avoid spilling ... some has small local memory thus // using smaller work group size may be required etc -static void +static void getRadixArray(unsigned int n, unsigned int *radixArray, unsigned int *numRadices, unsigned int maxRadix) { if(maxRadix > 1) - { - maxRadix = min(n, maxRadix); - unsigned int cnt = 0; - while(n > maxRadix) + { + maxRadix = min(n, maxRadix); + unsigned int cnt = 0; + while(n > maxRadix) + { + radixArray[cnt++] = maxRadix; + n /= maxRadix; + } + radixArray[cnt++] = n; + *numRadices = cnt; + return; + } + + switch(n) { - radixArray[cnt++] = maxRadix; - n /= maxRadix; - } - radixArray[cnt++] = n; - *numRadices = cnt; - return; - } + case 2: + *numRadices = 1; + radixArray[0] = 2; + break; + + case 4: + *numRadices = 1; + radixArray[0] = 4; + break; + + case 8: + *numRadices = 1; + radixArray[0] = 8; + break; + + case 16: + *numRadices = 2; + radixArray[0] = 8; radixArray[1] = 2; + break; - switch(n) - { - case 2: - *numRadices = 1; - radixArray[0] = 2; - break; - - case 4: - *numRadices = 1; - radixArray[0] = 4; - break; - - case 8: - *numRadices = 1; - radixArray[0] = 8; - break; - - case 16: - *numRadices = 2; - radixArray[0] = 8; radixArray[1] = 2; - break; - - case 32: - *numRadices = 2; - radixArray[0] = 8; radixArray[1] = 4; - break; - - case 64: - *numRadices = 2; - radixArray[0] = 8; radixArray[1] = 8; - break; - - case 128: - *numRadices = 3; - radixArray[0] = 8; radixArray[1] = 4; radixArray[2] = 4; - break; - - case 256: - *numRadices = 4; - radixArray[0] = 4; radixArray[1] = 4; radixArray[2] = 4; radixArray[3] = 4; - break; - - case 512: - *numRadices = 3; - radixArray[0] = 8; radixArray[1] = 8; radixArray[2] = 8; - break; - - case 1024: - *numRadices = 3; - radixArray[0] = 16; radixArray[1] = 16; radixArray[2] = 4; - break; - case 2048: - *numRadices = 4; - radixArray[0] = 8; radixArray[1] = 8; radixArray[2] = 8; radixArray[3] = 4; - break; - default: - *numRadices = 0; - return; - } + case 32: + *numRadices = 2; + radixArray[0] = 8; radixArray[1] = 4; + break; + + case 64: + *numRadices = 2; + radixArray[0] = 8; radixArray[1] = 8; + break; + + case 128: + *numRadices = 3; + radixArray[0] = 8; radixArray[1] = 4; radixArray[2] = 4; + break; + + case 256: + *numRadices = 4; + radixArray[0] = 4; radixArray[1] = 4; radixArray[2] = 4; radixArray[3] = 4; + break; + + case 512: + *numRadices = 3; + radixArray[0] = 8; radixArray[1] = 8; radixArray[2] = 8; + break; + + case 1024: + *numRadices = 3; + radixArray[0] = 16; radixArray[1] = 16; radixArray[2] = 4; + break; + case 2048: + *numRadices = 4; + radixArray[0] = 8; radixArray[1] = 8; radixArray[2] = 8; radixArray[3] = 4; + break; + default: + *numRadices = 0; + return; + } } static void insertHeader(string &kernelString, string &kernelName, clFFT_DataFormat dataFormat) { - if(dataFormat == clFFT_SplitComplexFormat) - kernelString += string("__kernel void ") + kernelName + string("(__global float *in_real, __global float *in_imag, __global float *out_real, __global float *out_imag, int dir, int S)\n"); - else - kernelString += string("__kernel void ") + kernelName + string("(__global float2 *in, __global float2 *out, int dir, int S)\n"); + if(dataFormat == clFFT_SplitComplexFormat) + kernelString += string("__kernel void ") + kernelName + string("(__global float *in_real, __global float *in_imag, __global float *out_real, __global float *out_imag, int dir, int S)\n"); + else + kernelString += string("__kernel void ") + kernelName + string("(__global float2 *in, __global float2 *out, int dir, int S)\n"); + printf("%s\n",kernelName.c_str()); } -static void +static void insertVariables(string &kStream, int maxRadix) { kStream += string(" int i, j, r, indexIn, indexOut, index, tid, bNum, xNum, k, l;\n"); @@ -202,1056 +203,1059 @@ static void formattedLoad(string &kernelString, int aIndex, int gIndex, clFFT_DataFormat dataFormat) { - if(dataFormat == clFFT_InterleavedComplexFormat) - kernelString += string(" a[") + num2str(aIndex) + string("] = in[") + num2str(gIndex) + string("];\n"); - else - { - kernelString += string(" a[") + num2str(aIndex) + string("].x = in_real[") + num2str(gIndex) + string("];\n"); - kernelString += string(" a[") + num2str(aIndex) + string("].y = in_imag[") + num2str(gIndex) + string("];\n"); - } + if(dataFormat == clFFT_InterleavedComplexFormat) + kernelString += string(" a[") + num2str(aIndex) + string("] = in[") + num2str(gIndex) + string("];\n"); + else + { + kernelString += string(" a[") + num2str(aIndex) + string("].x = in_real[") + num2str(gIndex) + string("];\n"); + kernelString += string(" a[") + num2str(aIndex) + string("].y = in_imag[") + num2str(gIndex) + string("];\n"); + } } static void formattedStore(string &kernelString, int aIndex, int gIndex, clFFT_DataFormat dataFormat) { - if(dataFormat == clFFT_InterleavedComplexFormat) - kernelString += string(" out[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("];\n"); - else - { - kernelString += string(" out_real[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("].x;\n"); - kernelString += string(" out_imag[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("].y;\n"); - } + if(dataFormat == clFFT_InterleavedComplexFormat) + kernelString += string(" out[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("];\n"); + else + { + kernelString += string(" out_real[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("].x;\n"); + kernelString += string(" out_imag[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("].y;\n"); + } } static int insertGlobalLoadsAndTranspose(string &kernelString, int N, int numWorkItemsPerXForm, int numXFormsPerWG, int R0, int mem_coalesce_width, clFFT_DataFormat dataFormat) { - int log2NumWorkItemsPerXForm = (int) log2(numWorkItemsPerXForm); - int groupSize = numWorkItemsPerXForm * numXFormsPerWG; - int i, j; - int lMemSize = 0; - - if(numXFormsPerWG > 1) - kernelString += string(" s = S & ") + num2str(numXFormsPerWG - 1) + string(";\n"); - + int log2NumWorkItemsPerXForm = (int) log2(numWorkItemsPerXForm); + int groupSize = numWorkItemsPerXForm * numXFormsPerWG; + int i, j; + int lMemSize = 0; + + if(numXFormsPerWG > 1) + kernelString += string(" s = S & ") + num2str(numXFormsPerWG - 1) + string(";\n"); + if(numWorkItemsPerXForm >= mem_coalesce_width) - { - if(numXFormsPerWG > 1) - { - kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm-1) + string(";\n"); - kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n"); - kernelString += string(" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n"); - kernelString += string(" offset = mad24( mad24(groupId, ") + num2str(numXFormsPerWG) + string(", jj), ") + num2str(N) + string(", ii );\n"); - if(dataFormat == clFFT_InterleavedComplexFormat) - { - kernelString += string(" in += offset;\n"); - kernelString += string(" out += offset;\n"); - } - else - { - kernelString += string(" in_real += offset;\n"); - kernelString += string(" in_imag += offset;\n"); - kernelString += string(" out_real += offset;\n"); - kernelString += string(" out_imag += offset;\n"); - } - for(i = 0; i < R0; i++) - formattedLoad(kernelString, i, i*numWorkItemsPerXForm, dataFormat); - kernelString += string(" }\n"); - } - else - { - kernelString += string(" ii = lId;\n"); - kernelString += string(" jj = 0;\n"); - kernelString += string(" offset = mad24(groupId, ") + num2str(N) + string(", ii);\n"); - if(dataFormat == clFFT_InterleavedComplexFormat) - { - kernelString += string(" in += offset;\n"); - kernelString += string(" out += offset;\n"); - } - else - { - kernelString += string(" in_real += offset;\n"); - kernelString += string(" in_imag += offset;\n"); - kernelString += string(" out_real += offset;\n"); - kernelString += string(" out_imag += offset;\n"); - } - for(i = 0; i < R0; i++) - formattedLoad(kernelString, i, i*numWorkItemsPerXForm, dataFormat); - } - } - else if( N >= mem_coalesce_width ) - { - int numInnerIter = N / mem_coalesce_width; - int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width ); - - kernelString += string(" ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n"); - kernelString += string(" jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n"); - kernelString += string(" lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); - kernelString += string(" offset = mad24( groupId, ") + num2str(numXFormsPerWG) + string(", jj);\n"); - kernelString += string(" offset = mad24( offset, ") + num2str(N) + string(", ii );\n"); - if(dataFormat == clFFT_InterleavedComplexFormat) - { - kernelString += string(" in += offset;\n"); - kernelString += string(" out += offset;\n"); - } - else - { - kernelString += string(" in_real += offset;\n"); - kernelString += string(" in_imag += offset;\n"); - kernelString += string(" out_real += offset;\n"); - kernelString += string(" out_imag += offset;\n"); - } - - kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n"); - for(i = 0; i < numOuterIter; i++ ) { - kernelString += string(" if( jj < s ) {\n"); - for(j = 0; j < numInnerIter; j++ ) - formattedLoad(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N, dataFormat); - kernelString += string(" }\n"); - if(i != numOuterIter - 1) - kernelString += string(" jj += ") + num2str(groupSize / mem_coalesce_width) + string(";\n"); + if(numXFormsPerWG > 1) + { + kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm-1) + string(";\n"); + kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n"); + kernelString += string(" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n"); + kernelString += string(" offset = mad24( mad24(groupId, ") + num2str(numXFormsPerWG) + string(", jj), ") + num2str(N) + string(", ii );\n"); + if(dataFormat == clFFT_InterleavedComplexFormat) + { + kernelString += string(" in += offset;\n"); + kernelString += string(" out += offset;\n"); + } + else + { + kernelString += string(" in_real += offset;\n"); + kernelString += string(" in_imag += offset;\n"); + kernelString += string(" out_real += offset;\n"); + kernelString += string(" out_imag += offset;\n"); + } + for(i = 0; i < R0; i++) + formattedLoad(kernelString, i, i*numWorkItemsPerXForm, dataFormat); + kernelString += string(" }\n"); + } + else + { + kernelString += string(" ii = lId;\n"); + kernelString += string(" jj = 0;\n"); + kernelString += string(" offset = mad24(groupId, ") + num2str(N) + string(", ii);\n"); + if(dataFormat == clFFT_InterleavedComplexFormat) + { + kernelString += string(" in += offset;\n"); + kernelString += string(" out += offset;\n"); + } + else + { + kernelString += string(" in_real += offset;\n"); + kernelString += string(" in_imag += offset;\n"); + kernelString += string(" out_real += offset;\n"); + kernelString += string(" out_imag += offset;\n"); + } + for(i = 0; i < R0; i++) + formattedLoad(kernelString, i, i*numWorkItemsPerXForm, dataFormat); + } } - kernelString += string("}\n "); - kernelString += string("else {\n"); - for(i = 0; i < numOuterIter; i++ ) + else if( N >= mem_coalesce_width ) { - for(j = 0; j < numInnerIter; j++ ) - formattedLoad(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N, dataFormat); - } - kernelString += string("}\n"); - - kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n"); - kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n"); - kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii);\n"); - - for( i = 0; i < numOuterIter; i++ ) - { - for( j = 0; j < numInnerIter; j++ ) - { - kernelString += string(" lMemStore[") + num2str(j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm )) + string("] = a[") + - num2str(i * numInnerIter + j) + string("].x;\n"); - } - } - kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); - - for( i = 0; i < R0; i++ ) - kernelString += string(" a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n"); - kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + int numInnerIter = N / mem_coalesce_width; + int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width ); - for( i = 0; i < numOuterIter; i++ ) - { - for( j = 0; j < numInnerIter; j++ ) - { - kernelString += string(" lMemStore[") + num2str(j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm )) + string("] = a[") + - num2str(i * numInnerIter + j) + string("].y;\n"); - } - } - kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); - - for( i = 0; i < R0; i++ ) - kernelString += string(" a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n"); - kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); - - lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG; - } - else - { - kernelString += string(" offset = mad24( groupId, ") + num2str(N * numXFormsPerWG) + string(", lId );\n"); - if(dataFormat == clFFT_InterleavedComplexFormat) - { - kernelString += string(" in += offset;\n"); - kernelString += string(" out += offset;\n"); - } - else - { - kernelString += string(" in_real += offset;\n"); - kernelString += string(" in_imag += offset;\n"); - kernelString += string(" out_real += offset;\n"); - kernelString += string(" out_imag += offset;\n"); - } - - kernelString += string(" ii = lId & ") + num2str(N-1) + string(";\n"); - kernelString += string(" jj = lId >> ") + num2str((int)log2(N)) + string(";\n"); - kernelString += string(" lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); - - kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n"); - for( i = 0; i < R0; i++ ) - { - kernelString += string(" if(jj < s )\n"); - formattedLoad(kernelString, i, i*groupSize, dataFormat); - if(i != R0 - 1) - kernelString += string(" jj += ") + num2str(groupSize / N) + string(";\n"); - } - kernelString += string("}\n"); - kernelString += string("else {\n"); - for( i = 0; i < R0; i++ ) - { - formattedLoad(kernelString, i, i*groupSize, dataFormat); - } - kernelString += string("}\n"); - - if(numWorkItemsPerXForm > 1) - { + kernelString += string(" ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n"); + kernelString += string(" jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n"); + kernelString += string(" lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); + kernelString += string(" offset = mad24( groupId, ") + num2str(numXFormsPerWG) + string(", jj);\n"); + kernelString += string(" offset = mad24( offset, ") + num2str(N) + string(", ii );\n"); + if(dataFormat == clFFT_InterleavedComplexFormat) + { + kernelString += string(" in += offset;\n"); + kernelString += string(" out += offset;\n"); + } + else + { + kernelString += string(" in_real += offset;\n"); + kernelString += string(" in_imag += offset;\n"); + kernelString += string(" out_real += offset;\n"); + kernelString += string(" out_imag += offset;\n"); + } + + kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n"); + for(i = 0; i < numOuterIter; i++ ) + { + kernelString += string(" if( jj < s ) {\n"); + for(j = 0; j < numInnerIter; j++ ) + formattedLoad(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N, dataFormat); + kernelString += string(" }\n"); + if(i != numOuterIter - 1) + kernelString += string(" jj += ") + num2str(groupSize / mem_coalesce_width) + string(";\n"); + } + kernelString += string("}\n "); + kernelString += string("else {\n"); + for(i = 0; i < numOuterIter; i++ ) + { + for(j = 0; j < numInnerIter; j++ ) + formattedLoad(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N, dataFormat); + } + kernelString += string("}\n"); + kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n"); kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n"); - kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); - } - else - { - kernelString += string(" ii = 0;\n"); - kernelString += string(" jj = lId;\n"); - kernelString += string(" lMemLoad = sMem + mul24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(");\n"); - } + kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii);\n"); + + for( i = 0; i < numOuterIter; i++ ) + { + for( j = 0; j < numInnerIter; j++ ) + { + kernelString += string(" lMemStore[") + num2str(j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm )) + string("] = a[") + + num2str(i * numInnerIter + j) + string("].x;\n"); + } + } + kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + + for( i = 0; i < R0; i++ ) + kernelString += string(" a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n"); + kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + + for( i = 0; i < numOuterIter; i++ ) + { + for( j = 0; j < numInnerIter; j++ ) + { + kernelString += string(" lMemStore[") + num2str(j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm )) + string("] = a[") + + num2str(i * numInnerIter + j) + string("].y;\n"); + } + } + kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + + for( i = 0; i < R0; i++ ) + kernelString += string(" a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n"); + kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + + lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG; + } + else + { + kernelString += string(" offset = mad24( groupId, ") + num2str(N * numXFormsPerWG) + string(", lId );\n"); + if(dataFormat == clFFT_InterleavedComplexFormat) + { + kernelString += string(" in += offset;\n"); + kernelString += string(" out += offset;\n"); + } + else + { + kernelString += string(" in_real += offset;\n"); + kernelString += string(" in_imag += offset;\n"); + kernelString += string(" out_real += offset;\n"); + kernelString += string(" out_imag += offset;\n"); + } - - for( i = 0; i < R0; i++ ) - kernelString += string(" lMemStore[") + num2str(i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) + string("] = a[") + num2str(i) + string("].x;\n"); - kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); - - for( i = 0; i < R0; i++ ) - kernelString += string(" a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n"); - kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); - - for( i = 0; i < R0; i++ ) - kernelString += string(" lMemStore[") + num2str(i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) + string("] = a[") + num2str(i) + string("].y;\n"); - kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); - - for( i = 0; i < R0; i++ ) - kernelString += string(" a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n"); - kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); - - lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG; - } - - return lMemSize; + kernelString += string(" ii = lId & ") + num2str(N-1) + string(";\n"); + kernelString += string(" jj = lId >> ") + num2str((int)log2(N)) + string(";\n"); + kernelString += string(" lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); + + kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n"); + for( i = 0; i < R0; i++ ) + { + kernelString += string(" if(jj < s )\n"); + formattedLoad(kernelString, i, i*groupSize, dataFormat); + if(i != R0 - 1) + kernelString += string(" jj += ") + num2str(groupSize / N) + string(";\n"); + } + kernelString += string("}\n"); + kernelString += string("else {\n"); + for( i = 0; i < R0; i++ ) + { + formattedLoad(kernelString, i, i*groupSize, dataFormat); + } + kernelString += string("}\n"); + + if(numWorkItemsPerXForm > 1) + { + kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n"); + kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n"); + kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); + } + else + { + kernelString += string(" ii = 0;\n"); + kernelString += string(" jj = lId;\n"); + kernelString += string(" lMemLoad = sMem + mul24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(");\n"); + } + + + for( i = 0; i < R0; i++ ) + kernelString += string(" lMemStore[") + num2str(i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) + string("] = a[") + num2str(i) + string("].x;\n"); + kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + + for( i = 0; i < R0; i++ ) + kernelString += string(" a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n"); + kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + + for( i = 0; i < R0; i++ ) + kernelString += string(" lMemStore[") + num2str(i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) + string("] = a[") + num2str(i) + string("].y;\n"); + kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + + for( i = 0; i < R0; i++ ) + kernelString += string(" a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n"); + kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + + lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG; + } + + return lMemSize; } static int insertGlobalStoresAndTranspose(string &kernelString, int N, int maxRadix, int Nr, int numWorkItemsPerXForm, int numXFormsPerWG, int mem_coalesce_width, clFFT_DataFormat dataFormat) { - int groupSize = numWorkItemsPerXForm * numXFormsPerWG; - int i, j, k, ind; - int lMemSize = 0; - int numIter = maxRadix / Nr; - string indent = string(""); - + int groupSize = numWorkItemsPerXForm * numXFormsPerWG; + int i, j, k, ind; + int lMemSize = 0; + int numIter = maxRadix / Nr; + string indent = string(""); + if( numWorkItemsPerXForm >= mem_coalesce_width ) - { - if(numXFormsPerWG > 1) - { - kernelString += string(" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n"); - indent = string(" "); - } - for(i = 0; i < maxRadix; i++) - { - j = i % numIter; - k = i / numIter; - ind = j * Nr + k; - formattedStore(kernelString, ind, i*numWorkItemsPerXForm, dataFormat); - } - if(numXFormsPerWG > 1) - kernelString += string(" }\n"); - } + { + if(numXFormsPerWG > 1) + { + kernelString += string(" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n"); + indent = string(" "); + } + for(i = 0; i < maxRadix; i++) + { + j = i % numIter; + k = i / numIter; + ind = j * Nr + k; + formattedStore(kernelString, ind, i*numWorkItemsPerXForm, dataFormat); + } + if(numXFormsPerWG > 1) + kernelString += string(" }\n"); + } else if( N >= mem_coalesce_width ) - { - int numInnerIter = N / mem_coalesce_width; - int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width ); - - kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); - kernelString += string(" ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n"); - kernelString += string(" jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n"); - kernelString += string(" lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); - - for( i = 0; i < maxRadix; i++ ) - { - j = i % numIter; - k = i / numIter; - ind = j * Nr + k; - kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].x;\n"); - } - kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); - - for( i = 0; i < numOuterIter; i++ ) - for( j = 0; j < numInnerIter; j++ ) - kernelString += string(" a[") + num2str(i*numInnerIter + j) + string("].x = lMemStore[") + num2str(j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) + string("];\n"); - kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); - - for( i = 0; i < maxRadix; i++ ) - { - j = i % numIter; - k = i / numIter; - ind = j * Nr + k; - kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].y;\n"); - } - kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); - - for( i = 0; i < numOuterIter; i++ ) - for( j = 0; j < numInnerIter; j++ ) - kernelString += string(" a[") + num2str(i*numInnerIter + j) + string("].y = lMemStore[") + num2str(j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) + string("];\n"); - kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); - - kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n"); - for(i = 0; i < numOuterIter; i++ ) { - kernelString += string(" if( jj < s ) {\n"); - for(j = 0; j < numInnerIter; j++ ) - formattedStore(kernelString, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat); - kernelString += string(" }\n"); - if(i != numOuterIter - 1) - kernelString += string(" jj += ") + num2str(groupSize / mem_coalesce_width) + string(";\n"); + int numInnerIter = N / mem_coalesce_width; + int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width ); + + kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); + kernelString += string(" ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n"); + kernelString += string(" jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n"); + kernelString += string(" lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); + + for( i = 0; i < maxRadix; i++ ) + { + j = i % numIter; + k = i / numIter; + ind = j * Nr + k; + kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].x;\n"); + } + kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + + for( i = 0; i < numOuterIter; i++ ) + for( j = 0; j < numInnerIter; j++ ) + kernelString += string(" a[") + num2str(i*numInnerIter + j) + string("].x = lMemStore[") + num2str(j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) + string("];\n"); + kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + + for( i = 0; i < maxRadix; i++ ) + { + j = i % numIter; + k = i / numIter; + ind = j * Nr + k; + kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].y;\n"); + } + kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + + for( i = 0; i < numOuterIter; i++ ) + for( j = 0; j < numInnerIter; j++ ) + kernelString += string(" a[") + num2str(i*numInnerIter + j) + string("].y = lMemStore[") + num2str(j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) + string("];\n"); + kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + + kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n"); + for(i = 0; i < numOuterIter; i++ ) + { + kernelString += string(" if( jj < s ) {\n"); + for(j = 0; j < numInnerIter; j++ ) + formattedStore(kernelString, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat); + kernelString += string(" }\n"); + if(i != numOuterIter - 1) + kernelString += string(" jj += ") + num2str(groupSize / mem_coalesce_width) + string(";\n"); + } + kernelString += string("}\n"); + kernelString += string("else {\n"); + for(i = 0; i < numOuterIter; i++ ) + { + for(j = 0; j < numInnerIter; j++ ) + formattedStore(kernelString, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat); + } + kernelString += string("}\n"); + + lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG; } - kernelString += string("}\n"); - kernelString += string("else {\n"); - for(i = 0; i < numOuterIter; i++ ) - { - for(j = 0; j < numInnerIter; j++ ) - formattedStore(kernelString, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat); - } - kernelString += string("}\n"); - - lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG; - } else - { - kernelString += string(" lMemLoad = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); - - kernelString += string(" ii = lId & ") + num2str(N - 1) + string(";\n"); - kernelString += string(" jj = lId >> ") + num2str((int) log2(N)) + string(";\n"); - kernelString += string(" lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); - - for( i = 0; i < maxRadix; i++ ) - { - j = i % numIter; - k = i / numIter; - ind = j * Nr + k; - kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].x;\n"); - } - kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); - - for( i = 0; i < maxRadix; i++ ) - kernelString += string(" a[") + num2str(i) + string("].x = lMemStore[") + num2str(i*( groupSize / N )*( N + numWorkItemsPerXForm )) + string("];\n"); - kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); - - for( i = 0; i < maxRadix; i++ ) - { - j = i % numIter; - k = i / numIter; - ind = j * Nr + k; - kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].y;\n"); - } - kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); - - for( i = 0; i < maxRadix; i++ ) - kernelString += string(" a[") + num2str(i) + string("].y = lMemStore[") + num2str(i*( groupSize / N )*( N + numWorkItemsPerXForm )) + string("];\n"); - kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); - - kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n"); - for( i = 0; i < maxRadix; i++ ) { - kernelString += string(" if(jj < s ) {\n"); - formattedStore(kernelString, i, i*groupSize, dataFormat); - kernelString += string(" }\n"); - if( i != maxRadix - 1) - kernelString += string(" jj +=") + num2str(groupSize / N) + string(";\n"); - } - kernelString += string("}\n"); - kernelString += string("else {\n"); - for( i = 0; i < maxRadix; i++ ) - { - formattedStore(kernelString, i, i*groupSize, dataFormat); - } - kernelString += string("}\n"); - - lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG; - } - - return lMemSize; + kernelString += string(" lMemLoad = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); + + kernelString += string(" ii = lId & ") + num2str(N - 1) + string(";\n"); + kernelString += string(" jj = lId >> ") + num2str((int) log2(N)) + string(";\n"); + kernelString += string(" lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); + + for( i = 0; i < maxRadix; i++ ) + { + j = i % numIter; + k = i / numIter; + ind = j * Nr + k; + kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].x;\n"); + } + kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + + for( i = 0; i < maxRadix; i++ ) + kernelString += string(" a[") + num2str(i) + string("].x = lMemStore[") + num2str(i*( groupSize / N )*( N + numWorkItemsPerXForm )) + string("];\n"); + kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + + for( i = 0; i < maxRadix; i++ ) + { + j = i % numIter; + k = i / numIter; + ind = j * Nr + k; + kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].y;\n"); + } + kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + + for( i = 0; i < maxRadix; i++ ) + kernelString += string(" a[") + num2str(i) + string("].y = lMemStore[") + num2str(i*( groupSize / N )*( N + numWorkItemsPerXForm )) + string("];\n"); + kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); + + kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n"); + for( i = 0; i < maxRadix; i++ ) + { + kernelString += string(" if(jj < s ) {\n"); + formattedStore(kernelString, i, i*groupSize, dataFormat); + kernelString += string(" }\n"); + if( i != maxRadix - 1) + kernelString += string(" jj +=") + num2str(groupSize / N) + string(";\n"); + } + kernelString += string("}\n"); + kernelString += string("else {\n"); + for( i = 0; i < maxRadix; i++ ) + { + formattedStore(kernelString, i, i*groupSize, dataFormat); + } + kernelString += string("}\n"); + + lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG; + } + + return lMemSize; } -static void +static void insertfftKernel(string &kernelString, int Nr, int numIter) { - int i; - for(i = 0; i < numIter; i++) - { - kernelString += string(" fftKernel") + num2str(Nr) + string("(a+") + num2str(i*Nr) + string(", dir);\n"); - } + int i; + for(i = 0; i < numIter; i++) + { + kernelString += string(" fftKernel") + num2str(Nr) + string("(a+") + num2str(i*Nr) + string(", dir);\n"); + } } static void insertTwiddleKernel(string &kernelString, int Nr, int numIter, int Nprev, int len, int numWorkItemsPerXForm) { - int z, k; - int logNPrev = (int)log2(Nprev); - - for(z = 0; z < numIter; z++) - { - if(z == 0) - { - if(Nprev > 1) - kernelString += string(" angf = (float) (ii >> ") + num2str(logNPrev) + string(");\n"); - else - kernelString += string(" angf = (float) ii;\n"); - } - else - { - if(Nprev > 1) - kernelString += string(" angf = (float) ((") + num2str(z*numWorkItemsPerXForm) + string(" + ii) >>") + num2str(logNPrev) + string(");\n"); - else - kernelString += string(" angf = (float) (") + num2str(z*numWorkItemsPerXForm) + string(" + ii);\n"); - } - - for(k = 1; k < Nr; k++) { - int ind = z*Nr + k; - //float fac = (float) (2.0 * M_PI * (double) k / (double) len); - kernelString += string(" ang = dir * ( 2.0f * M_PI * ") + num2str(k) + string(".0f / ") + num2str(len) + string(".0f )") + string(" * angf;\n"); - kernelString += string(" w = (float2)(native_cos(ang), native_sin(ang));\n"); - kernelString += string(" a[") + num2str(ind) + string("] = complexMul(a[") + num2str(ind) + string("], w);\n"); - } - } + int z, k; + int logNPrev = (int)log2(Nprev); + + for(z = 0; z < numIter; z++) + { + if(z == 0) + { + if(Nprev > 1) + kernelString += string(" angf = (float) (ii >> ") + num2str(logNPrev) + string(");\n"); + else + kernelString += string(" angf = (float) ii;\n"); + } + else + { + if(Nprev > 1) + kernelString += string(" angf = (float) ((") + num2str(z*numWorkItemsPerXForm) + string(" + ii) >>") + num2str(logNPrev) + string(");\n"); + else + kernelString += string(" angf = (float) (") + num2str(z*numWorkItemsPerXForm) + string(" + ii);\n"); + } + + for(k = 1; k < Nr; k++) { + int ind = z*Nr + k; + //float fac = (float) (2.0 * M_PI * (double) k / (double) len); + kernelString += string(" ang = dir * ( 2.0f * M_PI * ") + num2str(k) + string(".0f / ") + num2str(len) + string(".0f )") + string(" * angf;\n"); + kernelString += string(" w = (float2)(native_cos(ang), native_sin(ang));\n"); + kernelString += string(" a[") + num2str(ind) + string("] = complexMul(a[") + num2str(ind) + string("], w);\n"); + } + } } static int getPadding(int numWorkItemsPerXForm, int Nprev, int numWorkItemsReq, int numXFormsPerWG, int Nr, int numBanks, int *offset, int *midPad) { - if((numWorkItemsPerXForm <= Nprev) || (Nprev >= numBanks)) - *offset = 0; - else { - int numRowsReq = ((numWorkItemsPerXForm < numBanks) ? numWorkItemsPerXForm : numBanks) / Nprev; - int numColsReq = 1; - if(numRowsReq > Nr) - numColsReq = numRowsReq / Nr; - numColsReq = Nprev * numColsReq; - *offset = numColsReq; - } - - if(numWorkItemsPerXForm >= numBanks || numXFormsPerWG == 1) - *midPad = 0; - else { - int bankNum = ( (numWorkItemsReq + *offset) * Nr ) & (numBanks - 1); - if( bankNum >= numWorkItemsPerXForm ) - *midPad = 0; - else - *midPad = numWorkItemsPerXForm - bankNum; - } - - int lMemSize = ( numWorkItemsReq + *offset) * Nr * numXFormsPerWG + *midPad * (numXFormsPerWG - 1); - return lMemSize; + if((numWorkItemsPerXForm <= Nprev) || (Nprev >= numBanks)) + *offset = 0; + else { + int numRowsReq = ((numWorkItemsPerXForm < numBanks) ? numWorkItemsPerXForm : numBanks) / Nprev; + int numColsReq = 1; + if(numRowsReq > Nr) + numColsReq = numRowsReq / Nr; + numColsReq = Nprev * numColsReq; + *offset = numColsReq; + } + + if(numWorkItemsPerXForm >= numBanks || numXFormsPerWG == 1) + *midPad = 0; + else { + int bankNum = ( (numWorkItemsReq + *offset) * Nr ) & (numBanks - 1); + if( bankNum >= numWorkItemsPerXForm ) + *midPad = 0; + else + *midPad = numWorkItemsPerXForm - bankNum; + } + + int lMemSize = ( numWorkItemsReq + *offset) * Nr * numXFormsPerWG + *midPad * (numXFormsPerWG - 1); + return lMemSize; } -static void +static void insertLocalStores(string &kernelString, int numIter, int Nr, int numWorkItemsPerXForm, int numWorkItemsReq, int offset, string &comp) { - int z, k; + int z, k; - for(z = 0; z < numIter; z++) { - for(k = 0; k < Nr; k++) { - int index = k*(numWorkItemsReq + offset) + z*numWorkItemsPerXForm; - kernelString += string(" lMemStore[") + num2str(index) + string("] = a[") + num2str(z*Nr + k) + string("].") + comp + string(";\n"); - } - } - kernelString += string(" barrier(CLK_LOCAL_MEM_FENCE);\n"); + for(z = 0; z < numIter; z++) { + for(k = 0; k < Nr; k++) { + int index = k*(numWorkItemsReq + offset) + z*numWorkItemsPerXForm; + kernelString += string(" lMemStore[") + num2str(index) + string("] = a[") + num2str(z*Nr + k) + string("].") + comp + string(";\n"); + } + } + kernelString += string(" barrier(CLK_LOCAL_MEM_FENCE);\n"); } -static void +static void insertLocalLoads(string &kernelString, int n, int Nr, int Nrn, int Nprev, int Ncurr, int numWorkItemsPerXForm, int numWorkItemsReq, int offset, string &comp) { - int numWorkItemsReqN = n / Nrn; - int interBlockHNum = max( Nprev / numWorkItemsPerXForm, 1 ); - int interBlockHStride = numWorkItemsPerXForm; - int vertWidth = max(numWorkItemsPerXForm / Nprev, 1); - vertWidth = min( vertWidth, Nr); - int vertNum = Nr / vertWidth; - int vertStride = ( n / Nr + offset ) * vertWidth; - int iter = max( numWorkItemsReqN / numWorkItemsPerXForm, 1); - int intraBlockHStride = (numWorkItemsPerXForm / (Nprev*Nr)) > 1 ? (numWorkItemsPerXForm / (Nprev*Nr)) : 1; - intraBlockHStride *= Nprev; - - int stride = numWorkItemsReq / Nrn; - int i; - for(i = 0; i < iter; i++) { - int ii = i / (interBlockHNum * vertNum); - int zz = i % (interBlockHNum * vertNum); - int jj = zz % interBlockHNum; - int kk = zz / interBlockHNum; - int z; - for(z = 0; z < Nrn; z++) { - int st = kk * vertStride + jj * interBlockHStride + ii * intraBlockHStride + z * stride; - kernelString += string(" a[") + num2str(i*Nrn + z) + string("].") + comp + string(" = lMemLoad[") + num2str(st) + string("];\n"); - } - } - kernelString += string(" barrier(CLK_LOCAL_MEM_FENCE);\n"); + int numWorkItemsReqN = n / Nrn; + int interBlockHNum = max( Nprev / numWorkItemsPerXForm, 1 ); + int interBlockHStride = numWorkItemsPerXForm; + int vertWidth = max(numWorkItemsPerXForm / Nprev, 1); + vertWidth = min( vertWidth, Nr); + int vertNum = Nr / vertWidth; + int vertStride = ( n / Nr + offset ) * vertWidth; + int iter = max( numWorkItemsReqN / numWorkItemsPerXForm, 1); + int intraBlockHStride = (numWorkItemsPerXForm / (Nprev*Nr)) > 1 ? (numWorkItemsPerXForm / (Nprev*Nr)) : 1; + intraBlockHStride *= Nprev; + + int stride = numWorkItemsReq / Nrn; + int i; + for(i = 0; i < iter; i++) { + int ii = i / (interBlockHNum * vertNum); + int zz = i % (interBlockHNum * vertNum); + int jj = zz % interBlockHNum; + int kk = zz / interBlockHNum; + int z; + for(z = 0; z < Nrn; z++) { + int st = kk * vertStride + jj * interBlockHStride + ii * intraBlockHStride + z * stride; + kernelString += string(" a[") + num2str(i*Nrn + z) + string("].") + comp + string(" = lMemLoad[") + num2str(st) + string("];\n"); + } + } + kernelString += string(" barrier(CLK_LOCAL_MEM_FENCE);\n"); } static void insertLocalLoadIndexArithmatic(string &kernelString, int Nprev, int Nr, int numWorkItemsReq, int numWorkItemsPerXForm, int numXFormsPerWG, int offset, int midPad) -{ - int Ncurr = Nprev * Nr; - int logNcurr = (int)log2(Ncurr); - int logNprev = (int)log2(Nprev); - int incr = (numWorkItemsReq + offset) * Nr + midPad; - - if(Ncurr < numWorkItemsPerXForm) - { - if(Nprev == 1) - kernelString += string(" j = ii & ") + num2str(Ncurr - 1) + string(";\n"); - else - kernelString += string(" j = (ii & ") + num2str(Ncurr - 1) + string(") >> ") + num2str(logNprev) + string(";\n"); - - if(Nprev == 1) - kernelString += string(" i = ii >> ") + num2str(logNcurr) + string(";\n"); - else - kernelString += string(" i = mad24(ii >> ") + num2str(logNcurr) + string(", ") + num2str(Nprev) + string(", ii & ") + num2str(Nprev - 1) + string(");\n"); - } - else - { - if(Nprev == 1) - kernelString += string(" j = ii;\n"); - else - kernelString += string(" j = ii >> ") + num2str(logNprev) + string(";\n"); - if(Nprev == 1) - kernelString += string(" i = 0;\n"); - else - kernelString += string(" i = ii & ") + num2str(Nprev - 1) + string(";\n"); - } +{ + int Ncurr = Nprev * Nr; + int logNcurr = (int)log2(Ncurr); + int logNprev = (int)log2(Nprev); + int incr = (numWorkItemsReq + offset) * Nr + midPad; + + if(Ncurr < numWorkItemsPerXForm) + { + if(Nprev == 1) + kernelString += string(" j = ii & ") + num2str(Ncurr - 1) + string(";\n"); + else + kernelString += string(" j = (ii & ") + num2str(Ncurr - 1) + string(") >> ") + num2str(logNprev) + string(";\n"); + + if(Nprev == 1) + kernelString += string(" i = ii >> ") + num2str(logNcurr) + string(";\n"); + else + kernelString += string(" i = mad24(ii >> ") + num2str(logNcurr) + string(", ") + num2str(Nprev) + string(", ii & ") + num2str(Nprev - 1) + string(");\n"); + } + else + { + if(Nprev == 1) + kernelString += string(" j = ii;\n"); + else + kernelString += string(" j = ii >> ") + num2str(logNprev) + string(";\n"); + if(Nprev == 1) + kernelString += string(" i = 0;\n"); + else + kernelString += string(" i = ii & ") + num2str(Nprev - 1) + string(";\n"); + } if(numXFormsPerWG > 1) - kernelString += string(" i = mad24(jj, ") + num2str(incr) + string(", i);\n"); + kernelString += string(" i = mad24(jj, ") + num2str(incr) + string(", i);\n"); - kernelString += string(" lMemLoad = sMem + mad24(j, ") + num2str(numWorkItemsReq + offset) + string(", i);\n"); + kernelString += string(" lMemLoad = sMem + mad24(j, ") + num2str(numWorkItemsReq + offset) + string(", i);\n"); } static void insertLocalStoreIndexArithmatic(string &kernelString, int numWorkItemsReq, int numXFormsPerWG, int Nr, int offset, int midPad) { - if(numXFormsPerWG == 1) { - kernelString += string(" lMemStore = sMem + ii;\n"); - } - else { - kernelString += string(" lMemStore = sMem + mad24(jj, ") + num2str((numWorkItemsReq + offset)*Nr + midPad) + string(", ii);\n"); - } + if(numXFormsPerWG == 1) { + kernelString += string(" lMemStore = sMem + ii;\n"); + } + else { + kernelString += string(" lMemStore = sMem + mad24(jj, ") + num2str((numWorkItemsReq + offset)*Nr + midPad) + string(", ii);\n"); + } } static void createLocalMemfftKernelString(cl_fft_plan *plan) { - unsigned int radixArray[10]; - unsigned int numRadix; - - unsigned int n = plan->n.x; - - assert(n <= plan->max_work_item_per_workgroup * plan->max_radix && "signal lenght too big for local mem fft\n"); - - getRadixArray(n, radixArray, &numRadix, 0); - assert(numRadix > 0 && "no radix array supplied\n"); - - if(n/radixArray[0] > plan->max_work_item_per_workgroup) - getRadixArray(n, radixArray, &numRadix, plan->max_radix); + unsigned int radixArray[10]; + unsigned int numRadix; + + unsigned int n = plan->n.x; + + assert(n <= plan->max_work_item_per_workgroup * plan->max_radix && "signal lenght too big for local mem fft\n"); + + getRadixArray(n, radixArray, &numRadix, 0); + assert(numRadix > 0 && "no radix array supplied\n"); + + if(n/radixArray[0] > plan->max_work_item_per_workgroup) + getRadixArray(n, radixArray, &numRadix, plan->max_radix); + + assert(radixArray[0] <= plan->max_radix && "max radix choosen is greater than allowed\n"); + assert(n/radixArray[0] <= plan->max_work_item_per_workgroup && "required work items per xform greater than maximum work items allowed per work group for local mem fft\n"); + + unsigned int tmpLen = 1; + unsigned int i; + for(i = 0; i < numRadix; i++) + { + assert( radixArray[i] && !( (radixArray[i] - 1) & radixArray[i] ) ); + tmpLen *= radixArray[i]; + } + assert(tmpLen == n && "product of radices choosen doesnt match the length of signal\n"); + + int offset, midPad; + string localString(""), kernelName(""); + + clFFT_DataFormat dataFormat = plan->format; + string *kernelString = plan->kernel_string; + + + cl_fft_kernel_info **kInfo = &plan->kernel_info; + int kCount = 0; + + while(*kInfo) + { + kInfo = &(*kInfo)->next; + kCount++; + } + + kernelName = string("fft") + num2str(kCount); + + *kInfo = (cl_fft_kernel_info *) malloc(sizeof(cl_fft_kernel_info)); + (*kInfo)->kernel = 0; + (*kInfo)->lmem_size = 0; + (*kInfo)->num_workgroups = 0; + (*kInfo)->num_workitems_per_workgroup = 0; + (*kInfo)->dir = cl_fft_kernel_x; + (*kInfo)->in_place_possible = 1; + (*kInfo)->next = NULL; + (*kInfo)->kernel_name = (char *) malloc(sizeof(char)*(kernelName.size()+1)); + strcpy((*kInfo)->kernel_name, kernelName.c_str()); - assert(radixArray[0] <= plan->max_radix && "max radix choosen is greater than allowed\n"); - assert(n/radixArray[0] <= plan->max_work_item_per_workgroup && "required work items per xform greater than maximum work items allowed per work group for local mem fft\n"); - - unsigned int tmpLen = 1; - unsigned int i; - for(i = 0; i < numRadix; i++) - { - assert( radixArray[i] && !( (radixArray[i] - 1) & radixArray[i] ) ); - tmpLen *= radixArray[i]; - } - assert(tmpLen == n && "product of radices choosen doesnt match the length of signal\n"); - - int offset, midPad; - string localString(""), kernelName(""); - - clFFT_DataFormat dataFormat = plan->format; - string *kernelString = plan->kernel_string; - - - cl_fft_kernel_info **kInfo = &plan->kernel_info; - int kCount = 0; - - while(*kInfo) - { - kInfo = &(*kInfo)->next; - kCount++; - } - - kernelName = string("fft") + num2str(kCount); - - *kInfo = (cl_fft_kernel_info *) malloc(sizeof(cl_fft_kernel_info)); - (*kInfo)->kernel = 0; - (*kInfo)->lmem_size = 0; - (*kInfo)->num_workgroups = 0; - (*kInfo)->num_workitems_per_workgroup = 0; - (*kInfo)->dir = cl_fft_kernel_x; - (*kInfo)->in_place_possible = 1; - (*kInfo)->next = NULL; - (*kInfo)->kernel_name = (char *) malloc(sizeof(char)*(kernelName.size()+1)); - strcpy((*kInfo)->kernel_name, kernelName.c_str()); - - unsigned int numWorkItemsPerXForm = n / radixArray[0]; - unsigned int numWorkItemsPerWG = numWorkItemsPerXForm <= 64 ? 64 : numWorkItemsPerXForm; - assert(numWorkItemsPerWG <= plan->max_work_item_per_workgroup); - int numXFormsPerWG = numWorkItemsPerWG / numWorkItemsPerXForm; - (*kInfo)->num_workgroups = 1; + unsigned int numWorkItemsPerXForm = n / radixArray[0]; + unsigned int numWorkItemsPerWG = numWorkItemsPerXForm <= 64 ? 64 : numWorkItemsPerXForm; + assert(numWorkItemsPerWG <= plan->max_work_item_per_workgroup); + int numXFormsPerWG = numWorkItemsPerWG / numWorkItemsPerXForm; + (*kInfo)->num_workgroups = 1; (*kInfo)->num_xforms_per_workgroup = numXFormsPerWG; - (*kInfo)->num_workitems_per_workgroup = numWorkItemsPerWG; - - unsigned int *N = radixArray; - unsigned int maxRadix = N[0]; - unsigned int lMemSize = 0; - - insertVariables(localString, maxRadix); - - lMemSize = insertGlobalLoadsAndTranspose(localString, n, numWorkItemsPerXForm, numXFormsPerWG, maxRadix, plan->min_mem_coalesce_width, dataFormat); - (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size; - - string xcomp = string("x"); - string ycomp = string("y"); - - unsigned int Nprev = 1; - unsigned int len = n; - unsigned int r; - for(r = 0; r < numRadix; r++) - { - int numIter = N[0] / N[r]; - int numWorkItemsReq = n / N[r]; - int Ncurr = Nprev * N[r]; - insertfftKernel(localString, N[r], numIter); - - if(r < (numRadix - 1)) { - insertTwiddleKernel(localString, N[r], numIter, Nprev, len, numWorkItemsPerXForm); - lMemSize = getPadding(numWorkItemsPerXForm, Nprev, numWorkItemsReq, numXFormsPerWG, N[r], plan->num_local_mem_banks, &offset, &midPad); - (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size; - insertLocalStoreIndexArithmatic(localString, numWorkItemsReq, numXFormsPerWG, N[r], offset, midPad); - insertLocalLoadIndexArithmatic(localString, Nprev, N[r], numWorkItemsReq, numWorkItemsPerXForm, numXFormsPerWG, offset, midPad); - insertLocalStores(localString, numIter, N[r], numWorkItemsPerXForm, numWorkItemsReq, offset, xcomp); - insertLocalLoads(localString, n, N[r], N[r+1], Nprev, Ncurr, numWorkItemsPerXForm, numWorkItemsReq, offset, xcomp); - insertLocalStores(localString, numIter, N[r], numWorkItemsPerXForm, numWorkItemsReq, offset, ycomp); - insertLocalLoads(localString, n, N[r], N[r+1], Nprev, Ncurr, numWorkItemsPerXForm, numWorkItemsReq, offset, ycomp); - Nprev = Ncurr; - len = len / N[r]; - } - } - - lMemSize = insertGlobalStoresAndTranspose(localString, n, maxRadix, N[numRadix - 1], numWorkItemsPerXForm, numXFormsPerWG, plan->min_mem_coalesce_width, dataFormat); - (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size; - - insertHeader(*kernelString, kernelName, dataFormat); - *kernelString += string("{\n"); - if((*kInfo)->lmem_size) + (*kInfo)->num_workitems_per_workgroup = numWorkItemsPerWG; + + unsigned int *N = radixArray; + unsigned int maxRadix = N[0]; + unsigned int lMemSize = 0; + + insertVariables(localString, maxRadix); + + lMemSize = insertGlobalLoadsAndTranspose(localString, n, numWorkItemsPerXForm, numXFormsPerWG, maxRadix, plan->min_mem_coalesce_width, dataFormat); + (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size; + + string xcomp = string("x"); + string ycomp = string("y"); + + unsigned int Nprev = 1; + unsigned int len = n; + unsigned int r; + for(r = 0; r < numRadix; r++) + { + int numIter = N[0] / N[r]; + int numWorkItemsReq = n / N[r]; + int Ncurr = Nprev * N[r]; + insertfftKernel(localString, N[r], numIter); + + if(r < (numRadix - 1)) { + insertTwiddleKernel(localString, N[r], numIter, Nprev, len, numWorkItemsPerXForm); + lMemSize = getPadding(numWorkItemsPerXForm, Nprev, numWorkItemsReq, numXFormsPerWG, N[r], plan->num_local_mem_banks, &offset, &midPad); + (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size; + insertLocalStoreIndexArithmatic(localString, numWorkItemsReq, numXFormsPerWG, N[r], offset, midPad); + insertLocalLoadIndexArithmatic(localString, Nprev, N[r], numWorkItemsReq, numWorkItemsPerXForm, numXFormsPerWG, offset, midPad); + insertLocalStores(localString, numIter, N[r], numWorkItemsPerXForm, numWorkItemsReq, offset, xcomp); + insertLocalLoads(localString, n, N[r], N[r+1], Nprev, Ncurr, numWorkItemsPerXForm, numWorkItemsReq, offset, xcomp); + insertLocalStores(localString, numIter, N[r], numWorkItemsPerXForm, numWorkItemsReq, offset, ycomp); + insertLocalLoads(localString, n, N[r], N[r+1], Nprev, Ncurr, numWorkItemsPerXForm, numWorkItemsReq, offset, ycomp); + Nprev = Ncurr; + len = len / N[r]; + } + } + + lMemSize = insertGlobalStoresAndTranspose(localString, n, maxRadix, N[numRadix - 1], numWorkItemsPerXForm, numXFormsPerWG, plan->min_mem_coalesce_width, dataFormat); + (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size; + + insertHeader(*kernelString, kernelName, dataFormat); + *kernelString += string("{\n"); + if((*kInfo)->lmem_size) *kernelString += string(" __local float sMem[") + num2str((*kInfo)->lmem_size) + string("];\n"); - *kernelString += localString; - *kernelString += string("}\n"); + *kernelString += localString; + *kernelString += string("}\n"); } // For n larger than what can be computed using local memory fft, global transposes // multiple kernel launces is needed. For these sizes, n can be decomposed using // much larger base radices i.e. say n = 262144 = 128 x 64 x 32. Thus three kernel // launches will be needed, first computing 64 x 32, length 128 ffts, second computing -// 128 x 32 length 64 ffts, and finally a kernel computing 128 x 64 length 32 ffts. -// Each of these base radices can futher be divided into factors so that each of these -// base ffts can be computed within one kernel launch using in-register ffts and local -// memory transposes i.e for the first kernel above which computes 64 x 32 ffts on length -// 128, 128 can be decomposed into 128 = 16 x 8 i.e. 8 work items can compute 8 length -// 16 ffts followed by transpose using local memory followed by each of these eight -// work items computing 2 length 8 ffts thus computing 16 length 8 ffts in total. This +// 128 x 32 length 64 ffts, and finally a kernel computing 128 x 64 length 32 ffts. +// Each of these base radices can futher be divided into factors so that each of these +// base ffts can be computed within one kernel launch using in-register ffts and local +// memory transposes i.e for the first kernel above which computes 64 x 32 ffts on length +// 128, 128 can be decomposed into 128 = 16 x 8 i.e. 8 work items can compute 8 length +// 16 ffts followed by transpose using local memory followed by each of these eight +// work items computing 2 length 8 ffts thus computing 16 length 8 ffts in total. This // means only 8 work items are needed for computing one length 128 fft. If we choose // work group size of say 64, we can compute 64/8 = 8 length 128 ffts within one -// work group. Since we need to compute 64 x 32 length 128 ffts in first kernel, this -// means we need to launch 64 x 32 / 8 = 256 work groups with 64 work items in each +// work group. Since we need to compute 64 x 32 length 128 ffts in first kernel, this +// means we need to launch 64 x 32 / 8 = 256 work groups with 64 work items in each // work group where each work group is computing 8 length 128 ffts where each length // 128 fft is computed by 8 work items. Same logic can be applied to other two kernels -// in this example. Users can play with difference base radices and difference +// in this example. Users can play with difference base radices and difference // decompositions of base radices to generates different kernels and see which gives // best performance. Following function is just fixed to use 128 as base radix void getGlobalRadixInfo(int n, int *radix, int *R1, int *R2, int *numRadices) { - int baseRadix = min(n, 128); - - int numR = 0; - int N = n; - while(N > baseRadix) - { - N /= baseRadix; - numR++; - } - - for(int i = 0; i < numR; i++) - radix[i] = baseRadix; - - radix[numR] = N; - numR++; - *numRadices = numR; - - for(int i = 0; i < numR; i++) - { - int B = radix[i]; - if(B <= 8) - { - R1[i] = B; - R2[i] = 1; - continue; - } - - int r1 = 2; - int r2 = B / r1; - while(r2 > r1) - { - r1 *=2; - r2 = B / r1; - } - R1[i] = r1; - R2[i] = r2; - } + int baseRadix = min(n, 128); + + int numR = 0; + int N = n; + while(N > baseRadix) + { + N /= baseRadix; + numR++; + } + + for(int i = 0; i < numR; i++) + radix[i] = baseRadix; + + radix[numR] = N; + numR++; + *numRadices = numR; + + for(int i = 0; i < numR; i++) + { + int B = radix[i]; + if(B <= 8) + { + R1[i] = B; + R2[i] = 1; + continue; + } + + int r1 = 2; + int r2 = B / r1; + while(r2 > r1) + { + r1 *=2; + r2 = B / r1; + } + R1[i] = r1; + R2[i] = r2; + } } static void createGlobalFFTKernelString(cl_fft_plan *plan, int n, int BS, cl_fft_kernel_dir dir, int vertBS) -{ - int i, j, k, t; - int radixArr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; +{ + int i, j, k, t; + int radixArr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; int R1Arr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; int R2Arr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; - int radix, R1, R2; - int numRadices; - - int maxThreadsPerBlock = plan->max_work_item_per_workgroup; - int maxArrayLen = plan->max_radix; - int batchSize = plan->min_mem_coalesce_width; - clFFT_DataFormat dataFormat = plan->format; - int vertical = (dir == cl_fft_kernel_x) ? 0 : 1; - - getGlobalRadixInfo(n, radixArr, R1Arr, R2Arr, &numRadices); - - int numPasses = numRadices; - - string localString(""), kernelName(""); - string *kernelString = plan->kernel_string; - cl_fft_kernel_info **kInfo = &plan->kernel_info; - int kCount = 0; - - while(*kInfo) - { - kInfo = &(*kInfo)->next; - kCount++; - } - - int N = n; - int m = (int)log2(n); - int Rinit = vertical ? BS : 1; - batchSize = vertical ? min(BS, batchSize) : batchSize; - int passNum; - - for(passNum = 0; passNum < numPasses; passNum++) - { - - localString.clear(); - kernelName.clear(); - - radix = radixArr[passNum]; - R1 = R1Arr[passNum]; - R2 = R2Arr[passNum]; - - int strideI = Rinit; - for(i = 0; i < numPasses; i++) - if(i != passNum) - strideI *= radixArr[i]; - - int strideO = Rinit; - for(i = 0; i < passNum; i++) - strideO *= radixArr[i]; - - int threadsPerXForm = R2; - batchSize = R2 == 1 ? plan->max_work_item_per_workgroup : batchSize; - batchSize = min(batchSize, strideI); - int threadsPerBlock = batchSize * threadsPerXForm; - threadsPerBlock = min(threadsPerBlock, maxThreadsPerBlock); - batchSize = threadsPerBlock / threadsPerXForm; - assert(R2 <= R1); - assert(R1*R2 == radix); - assert(R1 <= maxArrayLen); - assert(threadsPerBlock <= maxThreadsPerBlock); - - int numIter = R1 / R2; - int gInInc = threadsPerBlock / batchSize; - - - int lgStrideO = (int)log2(strideO); - int numBlocksPerXForm = strideI / batchSize; - int numBlocks = numBlocksPerXForm; - if(!vertical) - numBlocks *= BS; - else - numBlocks *= vertBS; - - kernelName = string("fft") + num2str(kCount); - *kInfo = (cl_fft_kernel_info *) malloc(sizeof(cl_fft_kernel_info)); - (*kInfo)->kernel = 0; - if(R2 == 1) - (*kInfo)->lmem_size = 0; - else - { - if(strideO == 1) - (*kInfo)->lmem_size = (radix + 1)*batchSize; - else - (*kInfo)->lmem_size = threadsPerBlock*R1; - } - (*kInfo)->num_workgroups = numBlocks; - (*kInfo)->num_xforms_per_workgroup = 1; - (*kInfo)->num_workitems_per_workgroup = threadsPerBlock; - (*kInfo)->dir = dir; - if( (passNum == (numPasses - 1)) && (numPasses & 1) ) - (*kInfo)->in_place_possible = 1; - else - (*kInfo)->in_place_possible = 0; - (*kInfo)->next = NULL; - (*kInfo)->kernel_name = (char *) malloc(sizeof(char)*(kernelName.size()+1)); - strcpy((*kInfo)->kernel_name, kernelName.c_str()); - - insertVariables(localString, R1); - - if(vertical) - { - localString += string("xNum = groupId >> ") + num2str((int)log2(numBlocksPerXForm)) + string(";\n"); - localString += string("groupId = groupId & ") + num2str(numBlocksPerXForm - 1) + string(";\n"); - localString += string("indexIn = mad24(groupId, ") + num2str(batchSize) + string(", xNum << ") + num2str((int)log2(n*BS)) + string(");\n"); - localString += string("tid = mul24(groupId, ") + num2str(batchSize) + string(");\n"); - localString += string("i = tid >> ") + num2str(lgStrideO) + string(";\n"); - localString += string("j = tid & ") + num2str(strideO - 1) + string(";\n"); - int stride = radix*Rinit; - for(i = 0; i < passNum; i++) - stride *= radixArr[i]; - localString += string("indexOut = mad24(i, ") + num2str(stride) + string(", j + ") + string("(xNum << ") + num2str((int) log2(n*BS)) + string("));\n"); - localString += string("bNum = groupId;\n"); - } - else - { - int lgNumBlocksPerXForm = (int)log2(numBlocksPerXForm); - localString += string("bNum = groupId & ") + num2str(numBlocksPerXForm - 1) + string(";\n"); - localString += string("xNum = groupId >> ") + num2str(lgNumBlocksPerXForm) + string(";\n"); - localString += string("indexIn = mul24(bNum, ") + num2str(batchSize) + string(");\n"); - localString += string("tid = indexIn;\n"); - localString += string("i = tid >> ") + num2str(lgStrideO) + string(";\n"); - localString += string("j = tid & ") + num2str(strideO - 1) + string(";\n"); - int stride = radix*Rinit; - for(i = 0; i < passNum; i++) - stride *= radixArr[i]; - localString += string("indexOut = mad24(i, ") + num2str(stride) + string(", j);\n"); - localString += string("indexIn += (xNum << ") + num2str(m) + string(");\n"); - localString += string("indexOut += (xNum << ") + num2str(m) + string(");\n"); - } - - // Load Data - int lgBatchSize = (int)log2(batchSize); - localString += string("tid = lId;\n"); - localString += string("i = tid & ") + num2str(batchSize - 1) + string(";\n"); - localString += string("j = tid >> ") + num2str(lgBatchSize) + string(";\n"); - localString += string("indexIn += mad24(j, ") + num2str(strideI) + string(", i);\n"); + int radix, R1, R2; + int numRadices; + + int maxThreadsPerBlock = plan->max_work_item_per_workgroup; + int maxArrayLen = plan->max_radix; + int batchSize = plan->min_mem_coalesce_width; + clFFT_DataFormat dataFormat = plan->format; + int vertical = (dir == cl_fft_kernel_x) ? 0 : 1; + + getGlobalRadixInfo(n, radixArr, R1Arr, R2Arr, &numRadices); + + int numPasses = numRadices; + + string localString(""), kernelName(""); + string *kernelString = plan->kernel_string; + cl_fft_kernel_info **kInfo = &plan->kernel_info; + int kCount = 0; + + while(*kInfo) + { + kInfo = &(*kInfo)->next; + kCount++; + } + + int N = n; + int m = (int)log2(n); + int Rinit = vertical ? BS : 1; + batchSize = vertical ? min(BS, batchSize) : batchSize; + int passNum; + + for(passNum = 0; passNum < numPasses; passNum++) + { + + localString.clear(); + kernelName.clear(); + + radix = radixArr[passNum]; + R1 = R1Arr[passNum]; + R2 = R2Arr[passNum]; + + int strideI = Rinit; + for(i = 0; i < numPasses; i++) + if(i != passNum) + strideI *= radixArr[i]; + + int strideO = Rinit; + for(i = 0; i < passNum; i++) + strideO *= radixArr[i]; + + int threadsPerXForm = R2; + batchSize = R2 == 1 ? plan->max_work_item_per_workgroup : batchSize; + batchSize = min(batchSize, strideI); + int threadsPerBlock = batchSize * threadsPerXForm; + threadsPerBlock = min(threadsPerBlock, maxThreadsPerBlock); + batchSize = threadsPerBlock / threadsPerXForm; + assert(R2 <= R1); + assert(R1*R2 == radix); + assert(R1 <= maxArrayLen); + assert(threadsPerBlock <= maxThreadsPerBlock); + + int numIter = R1 / R2; + int gInInc = threadsPerBlock / batchSize; + + + int lgStrideO = (int)log2(strideO); + int numBlocksPerXForm = strideI / batchSize; + int numBlocks = numBlocksPerXForm; + if(!vertical) + numBlocks *= BS; + else + numBlocks *= vertBS; + + kernelName = string("fft") + num2str(kCount); + *kInfo = (cl_fft_kernel_info *) malloc(sizeof(cl_fft_kernel_info)); + (*kInfo)->kernel = 0; + if(R2 == 1) + (*kInfo)->lmem_size = 0; + else + { + if(strideO == 1) + (*kInfo)->lmem_size = (radix + 1)*batchSize; + else + (*kInfo)->lmem_size = threadsPerBlock*R1; + } + (*kInfo)->num_workgroups = numBlocks; + (*kInfo)->num_xforms_per_workgroup = 1; + (*kInfo)->num_workitems_per_workgroup = threadsPerBlock; + (*kInfo)->dir = dir; + if( (passNum == (numPasses - 1)) && (numPasses & 1) ) + (*kInfo)->in_place_possible = 1; + else + (*kInfo)->in_place_possible = 0; + (*kInfo)->next = NULL; + (*kInfo)->kernel_name = (char *) malloc(sizeof(char)*(kernelName.size()+1)); + strcpy((*kInfo)->kernel_name, kernelName.c_str()); + + insertVariables(localString, R1); + + if(vertical) + { + localString += string("xNum = groupId >> ") + num2str((int)log2(numBlocksPerXForm)) + string(";\n"); + localString += string("groupId = groupId & ") + num2str(numBlocksPerXForm - 1) + string(";\n"); + localString += string("indexIn = mad24(groupId, ") + num2str(batchSize) + string(", xNum << ") + num2str((int)log2(n*BS)) + string(");\n"); + localString += string("tid = mul24(groupId, ") + num2str(batchSize) + string(");\n"); + localString += string("i = tid >> ") + num2str(lgStrideO) + string(";\n"); + localString += string("j = tid & ") + num2str(strideO - 1) + string(";\n"); + int stride = radix*Rinit; + for(i = 0; i < passNum; i++) + stride *= radixArr[i]; + localString += string("indexOut = mad24(i, ") + num2str(stride) + string(", j + ") + string("(xNum << ") + num2str((int) log2(n*BS)) + string("));\n"); + localString += string("bNum = groupId;\n"); + } + else + { + int lgNumBlocksPerXForm = (int)log2(numBlocksPerXForm); + localString += string("bNum = groupId & ") + num2str(numBlocksPerXForm - 1) + string(";\n"); + localString += string("xNum = groupId >> ") + num2str(lgNumBlocksPerXForm) + string(";\n"); + localString += string("indexIn = mul24(bNum, ") + num2str(batchSize) + string(");\n"); + localString += string("tid = indexIn;\n"); + localString += string("i = tid >> ") + num2str(lgStrideO) + string(";\n"); + localString += string("j = tid & ") + num2str(strideO - 1) + string(";\n"); + int stride = radix*Rinit; + for(i = 0; i < passNum; i++) + stride *= radixArr[i]; + localString += string("indexOut = mad24(i, ") + num2str(stride) + string(", j);\n"); + localString += string("indexIn += (xNum << ") + num2str(m) + string(");\n"); + localString += string("indexOut += (xNum << ") + num2str(m) + string(");\n"); + } + + // Load Data + int lgBatchSize = (int)log2(batchSize); + localString += string("tid = lId;\n"); + localString += string("i = tid & ") + num2str(batchSize - 1) + string(";\n"); + localString += string("j = tid >> ") + num2str(lgBatchSize) + string(";\n"); + localString += string("indexIn += mad24(j, ") + num2str(strideI) + string(", i);\n"); - if(dataFormat == clFFT_SplitComplexFormat) - { - localString += string("in_real += indexIn;\n"); - localString += string("in_imag += indexIn;\n"); - for(j = 0; j < R1; j++) - localString += string("a[") + num2str(j) + string("].x = in_real[") + num2str(j*gInInc*strideI) + string("];\n"); - for(j = 0; j < R1; j++) - localString += string("a[") + num2str(j) + string("].y = in_imag[") + num2str(j*gInInc*strideI) + string("];\n"); - } - else - { - localString += string("in += indexIn;\n"); - for(j = 0; j < R1; j++) - localString += string("a[") + num2str(j) + string("] = in[") + num2str(j*gInInc*strideI) + string("];\n"); - } - - localString += string("fftKernel") + num2str(R1) + string("(a, dir);\n"); - - if(R2 > 1) - { - // twiddle - for(k = 1; k < R1; k++) - { - localString += string("ang = dir*(2.0f*M_PI*") + num2str(k) + string("/") + num2str(radix) + string(")*j;\n"); - localString += string("w = (float2)(native_cos(ang), native_sin(ang));\n"); - localString += string("a[") + num2str(k) + string("] = complexMul(a[") + num2str(k) + string("], w);\n"); - } - - // shuffle - numIter = R1 / R2; - localString += string("indexIn = mad24(j, ") + num2str(threadsPerBlock*numIter) + string(", i);\n"); - localString += string("lMemStore = sMem + tid;\n"); - localString += string("lMemLoad = sMem + indexIn;\n"); - for(k = 0; k < R1; k++) - localString += string("lMemStore[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].x;\n"); - localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n"); - for(k = 0; k < numIter; k++) - for(t = 0; t < R2; t++) - localString += string("a[") + num2str(k*R2+t) + string("].x = lMemLoad[") + num2str(t*batchSize + k*threadsPerBlock) + string("];\n"); - localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n"); - for(k = 0; k < R1; k++) - localString += string("lMemStore[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].y;\n"); - localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n"); - for(k = 0; k < numIter; k++) - for(t = 0; t < R2; t++) - localString += string("a[") + num2str(k*R2+t) + string("].y = lMemLoad[") + num2str(t*batchSize + k*threadsPerBlock) + string("];\n"); - localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n"); - - for(j = 0; j < numIter; j++) - localString += string("fftKernel") + num2str(R2) + string("(a + ") + num2str(j*R2) + string(", dir);\n"); - } - - // twiddle - if(passNum < (numPasses - 1)) - { - localString += string("l = ((bNum << ") + num2str(lgBatchSize) + string(") + i) >> ") + num2str(lgStrideO) + string(";\n"); - localString += string("k = j << ") + num2str((int)log2(R1/R2)) + string(";\n"); - localString += string("ang1 = dir*(2.0f*M_PI/") + num2str(N) + string(")*l;\n"); - for(t = 0; t < R1; t++) - { - localString += string("ang = ang1*(k + ") + num2str((t%R2)*R1 + (t/R2)) + string(");\n"); - localString += string("w = (float2)(native_cos(ang), native_sin(ang));\n"); - localString += string("a[") + num2str(t) + string("] = complexMul(a[") + num2str(t) + string("], w);\n"); - } - } - - // Store Data - if(strideO == 1) - { - - localString += string("lMemStore = sMem + mad24(i, ") + num2str(radix + 1) + string(", j << ") + num2str((int)log2(R1/R2)) + string(");\n"); - localString += string("lMemLoad = sMem + mad24(tid >> ") + num2str((int)log2(radix)) + string(", ") + num2str(radix+1) + string(", tid & ") + num2str(radix-1) + string(");\n"); - - for(i = 0; i < R1/R2; i++) - for(j = 0; j < R2; j++) - localString += string("lMemStore[ ") + num2str(i + j*R1) + string("] = a[") + num2str(i*R2+j) + string("].x;\n"); - localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n"); - if(threadsPerBlock >= radix) - { - for(i = 0; i < R1; i++) - localString += string("a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i*(radix+1)*(threadsPerBlock/radix)) + string("];\n"); - } + if(dataFormat == clFFT_SplitComplexFormat) + { + localString += string("in_real += indexIn;\n"); + localString += string("in_imag += indexIn;\n"); + for(j = 0; j < R1; j++) + localString += string("a[") + num2str(j) + string("].x = in_real[") + num2str(j*gInInc*strideI) + string("];\n"); + for(j = 0; j < R1; j++) + localString += string("a[") + num2str(j) + string("].y = in_imag[") + num2str(j*gInInc*strideI) + string("];\n"); + } else - { - int innerIter = radix/threadsPerBlock; - int outerIter = R1/innerIter; - for(i = 0; i < outerIter; i++) - for(j = 0; j < innerIter; j++) - localString += string("a[") + num2str(i*innerIter+j) + string("].x = lMemLoad[") + num2str(j*threadsPerBlock + i*(radix+1)) + string("];\n"); - } - localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n"); - - for(i = 0; i < R1/R2; i++) - for(j = 0; j < R2; j++) - localString += string("lMemStore[ ") + num2str(i + j*R1) + string("] = a[") + num2str(i*R2+j) + string("].y;\n"); - localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n"); - if(threadsPerBlock >= radix) - { - for(i = 0; i < R1; i++) - localString += string("a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i*(radix+1)*(threadsPerBlock/radix)) + string("];\n"); - } + { + localString += string("in += indexIn;\n"); + for(j = 0; j < R1; j++) + localString += string("a[") + num2str(j) + string("] = in[") + num2str(j*gInInc*strideI) + string("];\n"); + } + + localString += string("fftKernel") + num2str(R1) + string("(a, dir);\n"); + + if(R2 > 1) + { + // twiddle + for(k = 1; k < R1; k++) + { + localString += string("ang = dir*(2.0f*M_PI*") + num2str(k) + string("/") + num2str(radix) + string(")*j;\n"); + localString += string("w = (float2)(native_cos(ang), native_sin(ang));\n"); + localString += string("a[") + num2str(k) + string("] = complexMul(a[") + num2str(k) + string("], w);\n"); + } + + // shuffle + numIter = R1 / R2; + localString += string("indexIn = mad24(j, ") + num2str(threadsPerBlock*numIter) + string(", i);\n"); + localString += string("lMemStore = sMem + tid;\n"); + localString += string("lMemLoad = sMem + indexIn;\n"); + for(k = 0; k < R1; k++) + localString += string("lMemStore[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].x;\n"); + localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n"); + for(k = 0; k < numIter; k++) + for(t = 0; t < R2; t++) + localString += string("a[") + num2str(k*R2+t) + string("].x = lMemLoad[") + num2str(t*batchSize + k*threadsPerBlock) + string("];\n"); + localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n"); + for(k = 0; k < R1; k++) + localString += string("lMemStore[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].y;\n"); + localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n"); + for(k = 0; k < numIter; k++) + for(t = 0; t < R2; t++) + localString += string("a[") + num2str(k*R2+t) + string("].y = lMemLoad[") + num2str(t*batchSize + k*threadsPerBlock) + string("];\n"); + localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n"); + + for(j = 0; j < numIter; j++) + localString += string("fftKernel") + num2str(R2) + string("(a + ") + num2str(j*R2) + string(", dir);\n"); + } + + // twiddle + if(passNum < (numPasses - 1)) + { + localString += string("l = ((bNum << ") + num2str(lgBatchSize) + string(") + i) >> ") + num2str(lgStrideO) + string(";\n"); + localString += string("k = j << ") + num2str((int)log2(R1/R2)) + string(";\n"); + localString += string("ang1 = dir*(2.0f*M_PI/") + num2str(N) + string(")*l;\n"); + for(t = 0; t < R1; t++) + { + localString += string("ang = ang1*(k + ") + num2str((t%R2)*R1 + (t/R2)) + string(");\n"); + localString += string("w = (float2)(native_cos(ang), native_sin(ang));\n"); + localString += string("a[") + num2str(t) + string("] = complexMul(a[") + num2str(t) + string("], w);\n"); + } + } + + // Store Data + if(strideO == 1) + { + + localString += string("lMemStore = sMem + mad24(i, ") + num2str(radix + 1) + string(", j << ") + num2str((int)log2(R1/R2)) + string(");\n"); + localString += string("lMemLoad = sMem + mad24(tid >> ") + num2str((int)log2(radix)) + string(", ") + num2str(radix+1) + string(", tid & ") + num2str(radix-1) + string(");\n"); + + for(i = 0; i < R1/R2; i++) + for(j = 0; j < R2; j++) + localString += string("lMemStore[ ") + num2str(i + j*R1) + string("] = a[") + num2str(i*R2+j) + string("].x;\n"); + localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n"); + if(threadsPerBlock >= radix) + { + for(i = 0; i < R1; i++) + localString += string("a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i*(radix+1)*(threadsPerBlock/radix)) + string("];\n"); + } + else + { + int innerIter = radix/threadsPerBlock; + int outerIter = R1/innerIter; + for(i = 0; i < outerIter; i++) + for(j = 0; j < innerIter; j++) + localString += string("a[") + num2str(i*innerIter+j) + string("].x = lMemLoad[") + num2str(j*threadsPerBlock + i*(radix+1)) + string("];\n"); + } + localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n"); + + for(i = 0; i < R1/R2; i++) + for(j = 0; j < R2; j++) + localString += string("lMemStore[ ") + num2str(i + j*R1) + string("] = a[") + num2str(i*R2+j) + string("].y;\n"); + localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n"); + if(threadsPerBlock >= radix) + { + for(i = 0; i < R1; i++) + localString += string("a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i*(radix+1)*(threadsPerBlock/radix)) + string("];\n"); + } + else + { + int innerIter = radix/threadsPerBlock; + int outerIter = R1/innerIter; + for(i = 0; i < outerIter; i++) + for(j = 0; j < innerIter; j++) + localString += string("a[") + num2str(i*innerIter+j) + string("].y = lMemLoad[") + num2str(j*threadsPerBlock + i*(radix+1)) + string("];\n"); + } + localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n"); + + localString += string("indexOut += tid;\n"); + if(dataFormat == clFFT_SplitComplexFormat) { + localString += string("out_real += indexOut;\n"); + localString += string("out_imag += indexOut;\n"); + for(k = 0; k < R1; k++) + localString += string("out_real[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].x;\n"); + for(k = 0; k < R1; k++) + localString += string("out_imag[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].y;\n"); + } + else { + localString += string("out += indexOut;\n"); + for(k = 0; k < R1; k++) + localString += string("out[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("];\n"); + } + + } else - { - int innerIter = radix/threadsPerBlock; - int outerIter = R1/innerIter; - for(i = 0; i < outerIter; i++) - for(j = 0; j < innerIter; j++) - localString += string("a[") + num2str(i*innerIter+j) + string("].y = lMemLoad[") + num2str(j*threadsPerBlock + i*(radix+1)) + string("];\n"); - } - localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n"); - - localString += string("indexOut += tid;\n"); - if(dataFormat == clFFT_SplitComplexFormat) { - localString += string("out_real += indexOut;\n"); - localString += string("out_imag += indexOut;\n"); - for(k = 0; k < R1; k++) - localString += string("out_real[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].x;\n"); - for(k = 0; k < R1; k++) - localString += string("out_imag[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].y;\n"); - } - else { - localString += string("out += indexOut;\n"); - for(k = 0; k < R1; k++) - localString += string("out[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("];\n"); - } - - } - else - { - localString += string("indexOut += mad24(j, ") + num2str(numIter*strideO) + string(", i);\n"); - if(dataFormat == clFFT_SplitComplexFormat) { - localString += string("out_real += indexOut;\n"); - localString += string("out_imag += indexOut;\n"); - for(k = 0; k < R1; k++) - localString += string("out_real[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("].x;\n"); - for(k = 0; k < R1; k++) - localString += string("out_imag[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("].y;\n"); - } - else { - localString += string("out += indexOut;\n"); - for(k = 0; k < R1; k++) - localString += string("out[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("];\n"); - } - } - - insertHeader(*kernelString, kernelName, dataFormat); - *kernelString += string("{\n"); - if((*kInfo)->lmem_size) - *kernelString += string(" __local float sMem[") + num2str((*kInfo)->lmem_size) + string("];\n"); - *kernelString += localString; - *kernelString += string("}\n"); - - N /= radix; - kInfo = &(*kInfo)->next; - kCount++; - } + { + localString += string("indexOut += mad24(j, ") + num2str(numIter*strideO) + string(", i);\n"); + if(dataFormat == clFFT_SplitComplexFormat) { + localString += string("out_real += indexOut;\n"); + localString += string("out_imag += indexOut;\n"); + for(k = 0; k < R1; k++) + localString += string("out_real[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("].x;\n"); + for(k = 0; k < R1; k++) + localString += string("out_imag[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("].y;\n"); + } + else { + localString += string("out += indexOut;\n"); + for(k = 0; k < R1; k++) + localString += string("out[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("];\n"); + } + } + + insertHeader(*kernelString, kernelName, dataFormat); + *kernelString += string("{\n"); + if((*kInfo)->lmem_size) + *kernelString += string(" __local float sMem[") + num2str((*kInfo)->lmem_size) + string("];\n"); + *kernelString += localString; + *kernelString += string("}\n"); + + N /= radix; + kInfo = &(*kInfo)->next; + kCount++; + } } void FFT1D(cl_fft_plan *plan, cl_fft_kernel_dir dir) -{ +{ unsigned int radixArray[10]; unsigned int numRadix; - - switch(dir) - { - case cl_fft_kernel_x: - if(plan->n.x > plan->max_localmem_fft_size) - { - createGlobalFFTKernelString(plan, plan->n.x, 1, cl_fft_kernel_x, 1); - } - else if(plan->n.x > 1) - { - getRadixArray(plan->n.x, radixArray, &numRadix, 0); - if(plan->n.x / radixArray[0] <= plan->max_work_item_per_workgroup) - { - createLocalMemfftKernelString(plan); - } - else - { - getRadixArray(plan->n.x, radixArray, &numRadix, plan->max_radix); - if(plan->n.x / radixArray[0] <= plan->max_work_item_per_workgroup) - createLocalMemfftKernelString(plan); - else - createGlobalFFTKernelString(plan, plan->n.x, 1, cl_fft_kernel_x, 1); - } - } - break; - - case cl_fft_kernel_y: - if(plan->n.y > 1) - createGlobalFFTKernelString(plan, plan->n.y, plan->n.x, cl_fft_kernel_y, 1); - break; - - case cl_fft_kernel_z: - if(plan->n.z > 1) - createGlobalFFTKernelString(plan, plan->n.z, plan->n.x*plan->n.y, cl_fft_kernel_z, 1); - default: - return; - } + + switch(dir) + { + case cl_fft_kernel_x: + if(plan->n.x > plan->max_localmem_fft_size) + { + createGlobalFFTKernelString(plan, plan->n.x, 1, cl_fft_kernel_x, 1); + } + else if(plan->n.x > 1) + { + getRadixArray(plan->n.x, radixArray, &numRadix, 0); + if(plan->n.x / radixArray[0] <= plan->max_work_item_per_workgroup) + { + createLocalMemfftKernelString(plan); + } + else + { + getRadixArray(plan->n.x, radixArray, &numRadix, plan->max_radix); + if(plan->n.x / radixArray[0] <= plan->max_work_item_per_workgroup) + createLocalMemfftKernelString(plan); + else + createGlobalFFTKernelString(plan, plan->n.x, 1, cl_fft_kernel_x, 1); + } + } + break; + + case cl_fft_kernel_y: + if(plan->n.y > 1) + createGlobalFFTKernelString(plan, plan->n.y, plan->n.x, cl_fft_kernel_y, 1); + + + break; + + case cl_fft_kernel_z: + if(plan->n.z > 1) + createGlobalFFTKernelString(plan, plan->n.z, plan->n.x*plan->n.y, cl_fft_kernel_z, 1); + default: + return; + } } +