comparison fft_Example/fft_kernelstring.cc @ 7:ea2e7ce9d5bb

add sample.pgm
author Yuhi TOMARI <yuhi@cr.ie.u-ryukyu.ac.jp>
date Tue, 05 Feb 2013 15:19:02 +0900
parents 3602b23914ad
children
comparison
equal deleted inserted replaced
6:db074091ed0b 7:ea2e7ce9d5bb
59 using namespace std; 59 using namespace std;
60 60
61 #define max(A,B) ((A) > (B) ? (A) : (B)) 61 #define max(A,B) ((A) > (B) ? (A) : (B))
62 #define min(A,B) ((A) < (B) ? (A) : (B)) 62 #define min(A,B) ((A) < (B) ? (A) : (B))
63 63
64 static string 64 static string
65 num2str(int num) 65 num2str(int num)
66 { 66 {
67 char temp[200]; 67 char temp[200];
68 sprintf(temp, "%d", num); 68 sprintf(temp, "%d", num);
69 return string(temp); 69 return string(temp);
70 } 70 }
71 71
72 // For any n, this function decomposes n into factors for loacal memory tranpose 72 // For any n, this function decomposes n into factors for loacal memory tranpose
73 // based fft. Factors (radices) are sorted such that the first one (radixArray[0]) 73 // based fft. Factors (radices) are sorted such that the first one (radixArray[0])
74 // is the largest. This base radix determines the number of registers used by each 74 // is the largest. This base radix determines the number of registers used by each
75 // work item and product of remaining radices determine the size of work group needed. 75 // work item and product of remaining radices determine the size of work group needed.
76 // To make things concrete with and example, suppose n = 1024. It is decomposed into 76 // To make things concrete with and example, suppose n = 1024. It is decomposed into
77 // 1024 = 16 x 16 x 4. Hence kernel uses float2 a[16], for local in-register fft and 77 // 1024 = 16 x 16 x 4. Hence kernel uses float2 a[16], for local in-register fft and
78 // needs 16 x 4 = 64 work items per work group. So kernel first performance 64 length 78 // needs 16 x 4 = 64 work items per work group. So kernel first performance 64 length
79 // 16 ffts (64 work items working in parallel) following by transpose using local 79 // 16 ffts (64 work items working in parallel) following by transpose using local
80 // memory followed by again 64 length 16 ffts followed by transpose using local memory 80 // memory followed by again 64 length 16 ffts followed by transpose using local memory
81 // followed by 256 length 4 ffts. For the last step since with size of work group is 81 // followed by 256 length 4 ffts. For the last step since with size of work group is
82 // 64 and each work item can array for 16 values, 64 work items can compute 256 length 82 // 64 and each work item can array for 16 values, 64 work items can compute 256 length
83 // 4 ffts by each work item computing 4 length 4 ffts. 83 // 4 ffts by each work item computing 4 length 4 ffts.
84 // Similarly for n = 2048 = 8 x 8 x 8 x 4, each work group has 8 x 8 x 4 = 256 work 84 // Similarly for n = 2048 = 8 x 8 x 8 x 4, each work group has 8 x 8 x 4 = 256 work
85 // iterms which each computes 256 (in-parallel) length 8 ffts in-register, followed 85 // iterms which each computes 256 (in-parallel) length 8 ffts in-register, followed
86 // by transpose using local memory, followed by 256 length 8 in-register ffts, followed 86 // by transpose using local memory, followed by 256 length 8 in-register ffts, followed
87 // by transpose using local memory, followed by 256 length 8 in-register ffts, followed 87 // by transpose using local memory, followed by 256 length 8 in-register ffts, followed
88 // by transpose using local memory, followed by 512 length 4 in-register ffts. Again, 88 // by transpose using local memory, followed by 512 length 4 in-register ffts. Again,
89 // for the last step, each work item computes two length 4 in-register ffts and thus 89 // for the last step, each work item computes two length 4 in-register ffts and thus
90 // 256 work items are needed to compute all 512 ffts. 90 // 256 work items are needed to compute all 512 ffts.
91 // For n = 32 = 8 x 4, 4 work items first compute 4 in-register 91 // For n = 32 = 8 x 4, 4 work items first compute 4 in-register
92 // lenth 8 ffts, followed by transpose using local memory followed by 8 in-register 92 // lenth 8 ffts, followed by transpose using local memory followed by 8 in-register
93 // length 4 ffts, where each work item computes two length 4 ffts thus 4 work items 93 // length 4 ffts, where each work item computes two length 4 ffts thus 4 work items
94 // can compute 8 length 4 ffts. However if work group size of say 64 is choosen, 94 // can compute 8 length 4 ffts. However if work group size of say 64 is choosen,
95 // each work group can compute 64/ 4 = 16 size 32 ffts (batched transform). 95 // each work group can compute 64/ 4 = 16 size 32 ffts (batched transform).
96 // Users can play with these parameters to figure what gives best performance on 96 // Users can play with these parameters to figure what gives best performance on
97 // their particular device i.e. some device have less register space thus using 97 // their particular device i.e. some device have less register space thus using
98 // smaller base radix can avoid spilling ... some has small local memory thus 98 // smaller base radix can avoid spilling ... some has small local memory thus
99 // using smaller work group size may be required etc 99 // using smaller work group size may be required etc
100 100
101 static void 101 static void
102 getRadixArray(unsigned int n, unsigned int *radixArray, unsigned int *numRadices, unsigned int maxRadix) 102 getRadixArray(unsigned int n, unsigned int *radixArray, unsigned int *numRadices, unsigned int maxRadix)
103 { 103 {
104 if(maxRadix > 1) 104 if(maxRadix > 1)
105 { 105 {
106 maxRadix = min(n, maxRadix); 106 maxRadix = min(n, maxRadix);
107 unsigned int cnt = 0; 107 unsigned int cnt = 0;
108 while(n > maxRadix) 108 while(n > maxRadix)
109 { 109 {
110 radixArray[cnt++] = maxRadix; 110 radixArray[cnt++] = maxRadix;
111 n /= maxRadix; 111 n /= maxRadix;
112 } 112 }
113 radixArray[cnt++] = n; 113 radixArray[cnt++] = n;
114 *numRadices = cnt; 114 *numRadices = cnt;
115 return; 115 return;
116 } 116 }
117 117
118 switch(n) 118 switch(n)
119 { 119 {
120 case 2: 120 case 2:
121 *numRadices = 1; 121 *numRadices = 1;
122 radixArray[0] = 2; 122 radixArray[0] = 2;
123 break; 123 break;
124 124
125 case 4: 125 case 4:
126 *numRadices = 1; 126 *numRadices = 1;
127 radixArray[0] = 4; 127 radixArray[0] = 4;
128 break; 128 break;
129 129
130 case 8: 130 case 8:
131 *numRadices = 1; 131 *numRadices = 1;
132 radixArray[0] = 8; 132 radixArray[0] = 8;
133 break; 133 break;
134 134
135 case 16: 135 case 16:
136 *numRadices = 2; 136 *numRadices = 2;
137 radixArray[0] = 8; radixArray[1] = 2; 137 radixArray[0] = 8; radixArray[1] = 2;
138 break; 138 break;
139 139
140 case 32: 140 case 32:
141 *numRadices = 2; 141 *numRadices = 2;
142 radixArray[0] = 8; radixArray[1] = 4; 142 radixArray[0] = 8; radixArray[1] = 4;
143 break; 143 break;
144 144
145 case 64: 145 case 64:
146 *numRadices = 2; 146 *numRadices = 2;
147 radixArray[0] = 8; radixArray[1] = 8; 147 radixArray[0] = 8; radixArray[1] = 8;
148 break; 148 break;
149 149
150 case 128: 150 case 128:
151 *numRadices = 3; 151 *numRadices = 3;
152 radixArray[0] = 8; radixArray[1] = 4; radixArray[2] = 4; 152 radixArray[0] = 8; radixArray[1] = 4; radixArray[2] = 4;
153 break; 153 break;
154 154
155 case 256: 155 case 256:
156 *numRadices = 4; 156 *numRadices = 4;
157 radixArray[0] = 4; radixArray[1] = 4; radixArray[2] = 4; radixArray[3] = 4; 157 radixArray[0] = 4; radixArray[1] = 4; radixArray[2] = 4; radixArray[3] = 4;
158 break; 158 break;
159 159
160 case 512: 160 case 512:
161 *numRadices = 3; 161 *numRadices = 3;
162 radixArray[0] = 8; radixArray[1] = 8; radixArray[2] = 8; 162 radixArray[0] = 8; radixArray[1] = 8; radixArray[2] = 8;
163 break; 163 break;
164 164
165 case 1024: 165 case 1024:
166 *numRadices = 3; 166 *numRadices = 3;
167 radixArray[0] = 16; radixArray[1] = 16; radixArray[2] = 4; 167 radixArray[0] = 16; radixArray[1] = 16; radixArray[2] = 4;
168 break; 168 break;
169 case 2048: 169 case 2048:
170 *numRadices = 4; 170 *numRadices = 4;
171 radixArray[0] = 8; radixArray[1] = 8; radixArray[2] = 8; radixArray[3] = 4; 171 radixArray[0] = 8; radixArray[1] = 8; radixArray[2] = 8; radixArray[3] = 4;
172 break; 172 break;
173 default: 173 default:
174 *numRadices = 0; 174 *numRadices = 0;
175 return; 175 return;
176 } 176 }
177 } 177 }
178 178
179 static void 179 static void
180 insertHeader(string &kernelString, string &kernelName, clFFT_DataFormat dataFormat) 180 insertHeader(string &kernelString, string &kernelName, clFFT_DataFormat dataFormat)
181 { 181 {
182 if(dataFormat == clFFT_SplitComplexFormat) 182 if(dataFormat == clFFT_SplitComplexFormat)
183 kernelString += string("__kernel void ") + kernelName + string("(__global float *in_real, __global float *in_imag, __global float *out_real, __global float *out_imag, int dir, int S)\n"); 183 kernelString += string("__kernel void ") + kernelName + string("(__global float *in_real, __global float *in_imag, __global float *out_real, __global float *out_imag, int dir, int S)\n");
184 else 184 else
185 kernelString += string("__kernel void ") + kernelName + string("(__global float2 *in, __global float2 *out, int dir, int S)\n"); 185 kernelString += string("__kernel void ") + kernelName + string("(__global float2 *in, __global float2 *out, int dir, int S)\n");
186 } 186 printf("%s\n",kernelName.c_str());
187 187 }
188 static void 188
189 static void
189 insertVariables(string &kStream, int maxRadix) 190 insertVariables(string &kStream, int maxRadix)
190 { 191 {
191 kStream += string(" int i, j, r, indexIn, indexOut, index, tid, bNum, xNum, k, l;\n"); 192 kStream += string(" int i, j, r, indexIn, indexOut, index, tid, bNum, xNum, k, l;\n");
192 kStream += string(" int s, ii, jj, offset;\n"); 193 kStream += string(" int s, ii, jj, offset;\n");
193 kStream += string(" float2 w;\n"); 194 kStream += string(" float2 w;\n");
200 } 201 }
201 202
202 static void 203 static void
203 formattedLoad(string &kernelString, int aIndex, int gIndex, clFFT_DataFormat dataFormat) 204 formattedLoad(string &kernelString, int aIndex, int gIndex, clFFT_DataFormat dataFormat)
204 { 205 {
205 if(dataFormat == clFFT_InterleavedComplexFormat) 206 if(dataFormat == clFFT_InterleavedComplexFormat)
206 kernelString += string(" a[") + num2str(aIndex) + string("] = in[") + num2str(gIndex) + string("];\n"); 207 kernelString += string(" a[") + num2str(aIndex) + string("] = in[") + num2str(gIndex) + string("];\n");
207 else 208 else
208 { 209 {
209 kernelString += string(" a[") + num2str(aIndex) + string("].x = in_real[") + num2str(gIndex) + string("];\n"); 210 kernelString += string(" a[") + num2str(aIndex) + string("].x = in_real[") + num2str(gIndex) + string("];\n");
210 kernelString += string(" a[") + num2str(aIndex) + string("].y = in_imag[") + num2str(gIndex) + string("];\n"); 211 kernelString += string(" a[") + num2str(aIndex) + string("].y = in_imag[") + num2str(gIndex) + string("];\n");
211 } 212 }
212 } 213 }
213 214
214 static void 215 static void
215 formattedStore(string &kernelString, int aIndex, int gIndex, clFFT_DataFormat dataFormat) 216 formattedStore(string &kernelString, int aIndex, int gIndex, clFFT_DataFormat dataFormat)
216 { 217 {
217 if(dataFormat == clFFT_InterleavedComplexFormat) 218 if(dataFormat == clFFT_InterleavedComplexFormat)
218 kernelString += string(" out[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("];\n"); 219 kernelString += string(" out[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("];\n");
219 else 220 else
220 { 221 {
221 kernelString += string(" out_real[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("].x;\n"); 222 kernelString += string(" out_real[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("].x;\n");
222 kernelString += string(" out_imag[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("].y;\n"); 223 kernelString += string(" out_imag[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("].y;\n");
223 } 224 }
224 } 225 }
225 226
226 static int 227 static int
227 insertGlobalLoadsAndTranspose(string &kernelString, int N, int numWorkItemsPerXForm, int numXFormsPerWG, int R0, int mem_coalesce_width, clFFT_DataFormat dataFormat) 228 insertGlobalLoadsAndTranspose(string &kernelString, int N, int numWorkItemsPerXForm, int numXFormsPerWG, int R0, int mem_coalesce_width, clFFT_DataFormat dataFormat)
228 { 229 {
229 int log2NumWorkItemsPerXForm = (int) log2(numWorkItemsPerXForm); 230 int log2NumWorkItemsPerXForm = (int) log2(numWorkItemsPerXForm);
230 int groupSize = numWorkItemsPerXForm * numXFormsPerWG; 231 int groupSize = numWorkItemsPerXForm * numXFormsPerWG;
231 int i, j; 232 int i, j;
232 int lMemSize = 0; 233 int lMemSize = 0;
233 234
234 if(numXFormsPerWG > 1) 235 if(numXFormsPerWG > 1)
235 kernelString += string(" s = S & ") + num2str(numXFormsPerWG - 1) + string(";\n"); 236 kernelString += string(" s = S & ") + num2str(numXFormsPerWG - 1) + string(";\n");
236 237
237 if(numWorkItemsPerXForm >= mem_coalesce_width) 238 if(numWorkItemsPerXForm >= mem_coalesce_width)
238 { 239 {
239 if(numXFormsPerWG > 1) 240 if(numXFormsPerWG > 1)
240 { 241 {
241 kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm-1) + string(";\n"); 242 kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm-1) + string(";\n");
242 kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n"); 243 kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n");
243 kernelString += string(" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n"); 244 kernelString += string(" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n");
244 kernelString += string(" offset = mad24( mad24(groupId, ") + num2str(numXFormsPerWG) + string(", jj), ") + num2str(N) + string(", ii );\n"); 245 kernelString += string(" offset = mad24( mad24(groupId, ") + num2str(numXFormsPerWG) + string(", jj), ") + num2str(N) + string(", ii );\n");
245 if(dataFormat == clFFT_InterleavedComplexFormat) 246 if(dataFormat == clFFT_InterleavedComplexFormat)
246 { 247 {
247 kernelString += string(" in += offset;\n"); 248 kernelString += string(" in += offset;\n");
248 kernelString += string(" out += offset;\n"); 249 kernelString += string(" out += offset;\n");
249 } 250 }
250 else 251 else
251 { 252 {
252 kernelString += string(" in_real += offset;\n"); 253 kernelString += string(" in_real += offset;\n");
253 kernelString += string(" in_imag += offset;\n"); 254 kernelString += string(" in_imag += offset;\n");
254 kernelString += string(" out_real += offset;\n"); 255 kernelString += string(" out_real += offset;\n");
255 kernelString += string(" out_imag += offset;\n"); 256 kernelString += string(" out_imag += offset;\n");
256 } 257 }
257 for(i = 0; i < R0; i++) 258 for(i = 0; i < R0; i++)
258 formattedLoad(kernelString, i, i*numWorkItemsPerXForm, dataFormat); 259 formattedLoad(kernelString, i, i*numWorkItemsPerXForm, dataFormat);
259 kernelString += string(" }\n"); 260 kernelString += string(" }\n");
260 } 261 }
261 else 262 else
262 { 263 {
263 kernelString += string(" ii = lId;\n"); 264 kernelString += string(" ii = lId;\n");
264 kernelString += string(" jj = 0;\n"); 265 kernelString += string(" jj = 0;\n");
265 kernelString += string(" offset = mad24(groupId, ") + num2str(N) + string(", ii);\n"); 266 kernelString += string(" offset = mad24(groupId, ") + num2str(N) + string(", ii);\n");
266 if(dataFormat == clFFT_InterleavedComplexFormat) 267 if(dataFormat == clFFT_InterleavedComplexFormat)
267 { 268 {
268 kernelString += string(" in += offset;\n"); 269 kernelString += string(" in += offset;\n");
269 kernelString += string(" out += offset;\n"); 270 kernelString += string(" out += offset;\n");
270 } 271 }
271 else 272 else
272 { 273 {
273 kernelString += string(" in_real += offset;\n"); 274 kernelString += string(" in_real += offset;\n");
274 kernelString += string(" in_imag += offset;\n"); 275 kernelString += string(" in_imag += offset;\n");
275 kernelString += string(" out_real += offset;\n"); 276 kernelString += string(" out_real += offset;\n");
276 kernelString += string(" out_imag += offset;\n"); 277 kernelString += string(" out_imag += offset;\n");
277 } 278 }
278 for(i = 0; i < R0; i++) 279 for(i = 0; i < R0; i++)
279 formattedLoad(kernelString, i, i*numWorkItemsPerXForm, dataFormat); 280 formattedLoad(kernelString, i, i*numWorkItemsPerXForm, dataFormat);
280 } 281 }
281 } 282 }
282 else if( N >= mem_coalesce_width ) 283 else if( N >= mem_coalesce_width )
283 { 284 {
284 int numInnerIter = N / mem_coalesce_width; 285 int numInnerIter = N / mem_coalesce_width;
285 int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width ); 286 int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width );
286 287
287 kernelString += string(" ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n"); 288 kernelString += string(" ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n");
288 kernelString += string(" jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n"); 289 kernelString += string(" jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n");
289 kernelString += string(" lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); 290 kernelString += string(" lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
290 kernelString += string(" offset = mad24( groupId, ") + num2str(numXFormsPerWG) + string(", jj);\n"); 291 kernelString += string(" offset = mad24( groupId, ") + num2str(numXFormsPerWG) + string(", jj);\n");
291 kernelString += string(" offset = mad24( offset, ") + num2str(N) + string(", ii );\n"); 292 kernelString += string(" offset = mad24( offset, ") + num2str(N) + string(", ii );\n");
292 if(dataFormat == clFFT_InterleavedComplexFormat) 293 if(dataFormat == clFFT_InterleavedComplexFormat)
293 { 294 {
294 kernelString += string(" in += offset;\n"); 295 kernelString += string(" in += offset;\n");
295 kernelString += string(" out += offset;\n"); 296 kernelString += string(" out += offset;\n");
296 } 297 }
297 else 298 else
298 { 299 {
299 kernelString += string(" in_real += offset;\n"); 300 kernelString += string(" in_real += offset;\n");
300 kernelString += string(" in_imag += offset;\n"); 301 kernelString += string(" in_imag += offset;\n");
301 kernelString += string(" out_real += offset;\n"); 302 kernelString += string(" out_real += offset;\n");
302 kernelString += string(" out_imag += offset;\n"); 303 kernelString += string(" out_imag += offset;\n");
303 } 304 }
304 305
305 kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n"); 306 kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
306 for(i = 0; i < numOuterIter; i++ ) 307 for(i = 0; i < numOuterIter; i++ )
307 { 308 {
308 kernelString += string(" if( jj < s ) {\n"); 309 kernelString += string(" if( jj < s ) {\n");
309 for(j = 0; j < numInnerIter; j++ ) 310 for(j = 0; j < numInnerIter; j++ )
310 formattedLoad(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N, dataFormat); 311 formattedLoad(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N, dataFormat);
311 kernelString += string(" }\n"); 312 kernelString += string(" }\n");
312 if(i != numOuterIter - 1) 313 if(i != numOuterIter - 1)
313 kernelString += string(" jj += ") + num2str(groupSize / mem_coalesce_width) + string(";\n"); 314 kernelString += string(" jj += ") + num2str(groupSize / mem_coalesce_width) + string(";\n");
314 } 315 }
315 kernelString += string("}\n "); 316 kernelString += string("}\n ");
316 kernelString += string("else {\n"); 317 kernelString += string("else {\n");
317 for(i = 0; i < numOuterIter; i++ ) 318 for(i = 0; i < numOuterIter; i++ )
318 { 319 {
319 for(j = 0; j < numInnerIter; j++ ) 320 for(j = 0; j < numInnerIter; j++ )
320 formattedLoad(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N, dataFormat); 321 formattedLoad(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N, dataFormat);
321 } 322 }
322 kernelString += string("}\n"); 323 kernelString += string("}\n");
323 324
324 kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n");
325 kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n");
326 kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii);\n");
327
328 for( i = 0; i < numOuterIter; i++ )
329 {
330 for( j = 0; j < numInnerIter; j++ )
331 {
332 kernelString += string(" lMemStore[") + num2str(j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm )) + string("] = a[") +
333 num2str(i * numInnerIter + j) + string("].x;\n");
334 }
335 }
336 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
337
338 for( i = 0; i < R0; i++ )
339 kernelString += string(" a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
340 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
341
342 for( i = 0; i < numOuterIter; i++ )
343 {
344 for( j = 0; j < numInnerIter; j++ )
345 {
346 kernelString += string(" lMemStore[") + num2str(j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm )) + string("] = a[") +
347 num2str(i * numInnerIter + j) + string("].y;\n");
348 }
349 }
350 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
351
352 for( i = 0; i < R0; i++ )
353 kernelString += string(" a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
354 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
355
356 lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
357 }
358 else
359 {
360 kernelString += string(" offset = mad24( groupId, ") + num2str(N * numXFormsPerWG) + string(", lId );\n");
361 if(dataFormat == clFFT_InterleavedComplexFormat)
362 {
363 kernelString += string(" in += offset;\n");
364 kernelString += string(" out += offset;\n");
365 }
366 else
367 {
368 kernelString += string(" in_real += offset;\n");
369 kernelString += string(" in_imag += offset;\n");
370 kernelString += string(" out_real += offset;\n");
371 kernelString += string(" out_imag += offset;\n");
372 }
373
374 kernelString += string(" ii = lId & ") + num2str(N-1) + string(";\n");
375 kernelString += string(" jj = lId >> ") + num2str((int)log2(N)) + string(";\n");
376 kernelString += string(" lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
377
378 kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
379 for( i = 0; i < R0; i++ )
380 {
381 kernelString += string(" if(jj < s )\n");
382 formattedLoad(kernelString, i, i*groupSize, dataFormat);
383 if(i != R0 - 1)
384 kernelString += string(" jj += ") + num2str(groupSize / N) + string(";\n");
385 }
386 kernelString += string("}\n");
387 kernelString += string("else {\n");
388 for( i = 0; i < R0; i++ )
389 {
390 formattedLoad(kernelString, i, i*groupSize, dataFormat);
391 }
392 kernelString += string("}\n");
393
394 if(numWorkItemsPerXForm > 1)
395 {
396 kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n"); 325 kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n");
397 kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n"); 326 kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n");
398 kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); 327 kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii);\n");
399 } 328
400 else 329 for( i = 0; i < numOuterIter; i++ )
401 { 330 {
402 kernelString += string(" ii = 0;\n"); 331 for( j = 0; j < numInnerIter; j++ )
403 kernelString += string(" jj = lId;\n"); 332 {
404 kernelString += string(" lMemLoad = sMem + mul24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(");\n"); 333 kernelString += string(" lMemStore[") + num2str(j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm )) + string("] = a[") +
405 } 334 num2str(i * numInnerIter + j) + string("].x;\n");
406 335 }
407 336 }
408 for( i = 0; i < R0; i++ ) 337 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
409 kernelString += string(" lMemStore[") + num2str(i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) + string("] = a[") + num2str(i) + string("].x;\n"); 338
410 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); 339 for( i = 0; i < R0; i++ )
411 340 kernelString += string(" a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
412 for( i = 0; i < R0; i++ ) 341 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
413 kernelString += string(" a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n"); 342
414 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); 343 for( i = 0; i < numOuterIter; i++ )
415 344 {
416 for( i = 0; i < R0; i++ ) 345 for( j = 0; j < numInnerIter; j++ )
417 kernelString += string(" lMemStore[") + num2str(i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) + string("] = a[") + num2str(i) + string("].y;\n"); 346 {
418 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); 347 kernelString += string(" lMemStore[") + num2str(j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm )) + string("] = a[") +
419 348 num2str(i * numInnerIter + j) + string("].y;\n");
420 for( i = 0; i < R0; i++ ) 349 }
421 kernelString += string(" a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n"); 350 }
422 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); 351 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
423 352
424 lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG; 353 for( i = 0; i < R0; i++ )
425 } 354 kernelString += string(" a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
426 355 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
427 return lMemSize; 356
357 lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
358 }
359 else
360 {
361 kernelString += string(" offset = mad24( groupId, ") + num2str(N * numXFormsPerWG) + string(", lId );\n");
362 if(dataFormat == clFFT_InterleavedComplexFormat)
363 {
364 kernelString += string(" in += offset;\n");
365 kernelString += string(" out += offset;\n");
366 }
367 else
368 {
369 kernelString += string(" in_real += offset;\n");
370 kernelString += string(" in_imag += offset;\n");
371 kernelString += string(" out_real += offset;\n");
372 kernelString += string(" out_imag += offset;\n");
373 }
374
375 kernelString += string(" ii = lId & ") + num2str(N-1) + string(";\n");
376 kernelString += string(" jj = lId >> ") + num2str((int)log2(N)) + string(";\n");
377 kernelString += string(" lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
378
379 kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
380 for( i = 0; i < R0; i++ )
381 {
382 kernelString += string(" if(jj < s )\n");
383 formattedLoad(kernelString, i, i*groupSize, dataFormat);
384 if(i != R0 - 1)
385 kernelString += string(" jj += ") + num2str(groupSize / N) + string(";\n");
386 }
387 kernelString += string("}\n");
388 kernelString += string("else {\n");
389 for( i = 0; i < R0; i++ )
390 {
391 formattedLoad(kernelString, i, i*groupSize, dataFormat);
392 }
393 kernelString += string("}\n");
394
395 if(numWorkItemsPerXForm > 1)
396 {
397 kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n");
398 kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n");
399 kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
400 }
401 else
402 {
403 kernelString += string(" ii = 0;\n");
404 kernelString += string(" jj = lId;\n");
405 kernelString += string(" lMemLoad = sMem + mul24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(");\n");
406 }
407
408
409 for( i = 0; i < R0; i++ )
410 kernelString += string(" lMemStore[") + num2str(i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) + string("] = a[") + num2str(i) + string("].x;\n");
411 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
412
413 for( i = 0; i < R0; i++ )
414 kernelString += string(" a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
415 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
416
417 for( i = 0; i < R0; i++ )
418 kernelString += string(" lMemStore[") + num2str(i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) + string("] = a[") + num2str(i) + string("].y;\n");
419 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
420
421 for( i = 0; i < R0; i++ )
422 kernelString += string(" a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
423 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
424
425 lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
426 }
427
428 return lMemSize;
428 } 429 }
429 430
430 static int 431 static int
431 insertGlobalStoresAndTranspose(string &kernelString, int N, int maxRadix, int Nr, int numWorkItemsPerXForm, int numXFormsPerWG, int mem_coalesce_width, clFFT_DataFormat dataFormat) 432 insertGlobalStoresAndTranspose(string &kernelString, int N, int maxRadix, int Nr, int numWorkItemsPerXForm, int numXFormsPerWG, int mem_coalesce_width, clFFT_DataFormat dataFormat)
432 { 433 {
433 int groupSize = numWorkItemsPerXForm * numXFormsPerWG; 434 int groupSize = numWorkItemsPerXForm * numXFormsPerWG;
434 int i, j, k, ind; 435 int i, j, k, ind;
435 int lMemSize = 0; 436 int lMemSize = 0;
436 int numIter = maxRadix / Nr; 437 int numIter = maxRadix / Nr;
437 string indent = string(""); 438 string indent = string("");
438 439
439 if( numWorkItemsPerXForm >= mem_coalesce_width ) 440 if( numWorkItemsPerXForm >= mem_coalesce_width )
440 { 441 {
441 if(numXFormsPerWG > 1) 442 if(numXFormsPerWG > 1)
442 { 443 {
443 kernelString += string(" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n"); 444 kernelString += string(" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n");
444 indent = string(" "); 445 indent = string(" ");
445 } 446 }
446 for(i = 0; i < maxRadix; i++) 447 for(i = 0; i < maxRadix; i++)
447 { 448 {
448 j = i % numIter; 449 j = i % numIter;
449 k = i / numIter; 450 k = i / numIter;
450 ind = j * Nr + k; 451 ind = j * Nr + k;
451 formattedStore(kernelString, ind, i*numWorkItemsPerXForm, dataFormat); 452 formattedStore(kernelString, ind, i*numWorkItemsPerXForm, dataFormat);
452 } 453 }
453 if(numXFormsPerWG > 1) 454 if(numXFormsPerWG > 1)
454 kernelString += string(" }\n"); 455 kernelString += string(" }\n");
455 } 456 }
456 else if( N >= mem_coalesce_width ) 457 else if( N >= mem_coalesce_width )
457 { 458 {
458 int numInnerIter = N / mem_coalesce_width; 459 int numInnerIter = N / mem_coalesce_width;
459 int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width ); 460 int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width );
460 461
461 kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); 462 kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
462 kernelString += string(" ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n"); 463 kernelString += string(" ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n");
463 kernelString += string(" jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n"); 464 kernelString += string(" jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n");
464 kernelString += string(" lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); 465 kernelString += string(" lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
465 466
466 for( i = 0; i < maxRadix; i++ ) 467 for( i = 0; i < maxRadix; i++ )
467 { 468 {
468 j = i % numIter; 469 j = i % numIter;
469 k = i / numIter; 470 k = i / numIter;
470 ind = j * Nr + k; 471 ind = j * Nr + k;
471 kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].x;\n"); 472 kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].x;\n");
472 } 473 }
473 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); 474 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
474 475
475 for( i = 0; i < numOuterIter; i++ ) 476 for( i = 0; i < numOuterIter; i++ )
476 for( j = 0; j < numInnerIter; j++ ) 477 for( j = 0; j < numInnerIter; j++ )
477 kernelString += string(" a[") + num2str(i*numInnerIter + j) + string("].x = lMemStore[") + num2str(j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) + string("];\n"); 478 kernelString += string(" a[") + num2str(i*numInnerIter + j) + string("].x = lMemStore[") + num2str(j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) + string("];\n");
478 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); 479 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
479 480
480 for( i = 0; i < maxRadix; i++ ) 481 for( i = 0; i < maxRadix; i++ )
481 { 482 {
482 j = i % numIter; 483 j = i % numIter;
483 k = i / numIter; 484 k = i / numIter;
484 ind = j * Nr + k; 485 ind = j * Nr + k;
485 kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].y;\n"); 486 kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].y;\n");
486 } 487 }
487 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); 488 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
488 489
489 for( i = 0; i < numOuterIter; i++ ) 490 for( i = 0; i < numOuterIter; i++ )
490 for( j = 0; j < numInnerIter; j++ ) 491 for( j = 0; j < numInnerIter; j++ )
491 kernelString += string(" a[") + num2str(i*numInnerIter + j) + string("].y = lMemStore[") + num2str(j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) + string("];\n"); 492 kernelString += string(" a[") + num2str(i*numInnerIter + j) + string("].y = lMemStore[") + num2str(j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) + string("];\n");
492 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); 493 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
493 494
494 kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n"); 495 kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
495 for(i = 0; i < numOuterIter; i++ ) 496 for(i = 0; i < numOuterIter; i++ )
496 { 497 {
497 kernelString += string(" if( jj < s ) {\n"); 498 kernelString += string(" if( jj < s ) {\n");
498 for(j = 0; j < numInnerIter; j++ ) 499 for(j = 0; j < numInnerIter; j++ )
499 formattedStore(kernelString, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat); 500 formattedStore(kernelString, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat);
500 kernelString += string(" }\n"); 501 kernelString += string(" }\n");
501 if(i != numOuterIter - 1) 502 if(i != numOuterIter - 1)
502 kernelString += string(" jj += ") + num2str(groupSize / mem_coalesce_width) + string(";\n"); 503 kernelString += string(" jj += ") + num2str(groupSize / mem_coalesce_width) + string(";\n");
503 } 504 }
504 kernelString += string("}\n"); 505 kernelString += string("}\n");
505 kernelString += string("else {\n"); 506 kernelString += string("else {\n");
506 for(i = 0; i < numOuterIter; i++ ) 507 for(i = 0; i < numOuterIter; i++ )
507 { 508 {
508 for(j = 0; j < numInnerIter; j++ ) 509 for(j = 0; j < numInnerIter; j++ )
509 formattedStore(kernelString, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat); 510 formattedStore(kernelString, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat);
510 } 511 }
511 kernelString += string("}\n"); 512 kernelString += string("}\n");
512 513
513 lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG; 514 lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
514 } 515 }
515 else 516 else
516 { 517 {
517 kernelString += string(" lMemLoad = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); 518 kernelString += string(" lMemLoad = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
518 519
519 kernelString += string(" ii = lId & ") + num2str(N - 1) + string(";\n"); 520 kernelString += string(" ii = lId & ") + num2str(N - 1) + string(";\n");
520 kernelString += string(" jj = lId >> ") + num2str((int) log2(N)) + string(";\n"); 521 kernelString += string(" jj = lId >> ") + num2str((int) log2(N)) + string(";\n");
521 kernelString += string(" lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); 522 kernelString += string(" lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
522 523
523 for( i = 0; i < maxRadix; i++ ) 524 for( i = 0; i < maxRadix; i++ )
524 { 525 {
525 j = i % numIter; 526 j = i % numIter;
526 k = i / numIter; 527 k = i / numIter;
527 ind = j * Nr + k; 528 ind = j * Nr + k;
528 kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].x;\n"); 529 kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].x;\n");
529 } 530 }
530 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); 531 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
531 532
532 for( i = 0; i < maxRadix; i++ ) 533 for( i = 0; i < maxRadix; i++ )
533 kernelString += string(" a[") + num2str(i) + string("].x = lMemStore[") + num2str(i*( groupSize / N )*( N + numWorkItemsPerXForm )) + string("];\n"); 534 kernelString += string(" a[") + num2str(i) + string("].x = lMemStore[") + num2str(i*( groupSize / N )*( N + numWorkItemsPerXForm )) + string("];\n");
534 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); 535 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
535 536
536 for( i = 0; i < maxRadix; i++ ) 537 for( i = 0; i < maxRadix; i++ )
537 { 538 {
538 j = i % numIter; 539 j = i % numIter;
539 k = i / numIter; 540 k = i / numIter;
540 ind = j * Nr + k; 541 ind = j * Nr + k;
541 kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].y;\n"); 542 kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].y;\n");
542 } 543 }
543 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); 544 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
544 545
545 for( i = 0; i < maxRadix; i++ ) 546 for( i = 0; i < maxRadix; i++ )
546 kernelString += string(" a[") + num2str(i) + string("].y = lMemStore[") + num2str(i*( groupSize / N )*( N + numWorkItemsPerXForm )) + string("];\n"); 547 kernelString += string(" a[") + num2str(i) + string("].y = lMemStore[") + num2str(i*( groupSize / N )*( N + numWorkItemsPerXForm )) + string("];\n");
547 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n"); 548 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
548 549
549 kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n"); 550 kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
550 for( i = 0; i < maxRadix; i++ ) 551 for( i = 0; i < maxRadix; i++ )
551 { 552 {
552 kernelString += string(" if(jj < s ) {\n"); 553 kernelString += string(" if(jj < s ) {\n");
553 formattedStore(kernelString, i, i*groupSize, dataFormat); 554 formattedStore(kernelString, i, i*groupSize, dataFormat);
554 kernelString += string(" }\n"); 555 kernelString += string(" }\n");
555 if( i != maxRadix - 1) 556 if( i != maxRadix - 1)
556 kernelString += string(" jj +=") + num2str(groupSize / N) + string(";\n"); 557 kernelString += string(" jj +=") + num2str(groupSize / N) + string(";\n");
557 } 558 }
558 kernelString += string("}\n"); 559 kernelString += string("}\n");
559 kernelString += string("else {\n"); 560 kernelString += string("else {\n");
560 for( i = 0; i < maxRadix; i++ ) 561 for( i = 0; i < maxRadix; i++ )
561 { 562 {
562 formattedStore(kernelString, i, i*groupSize, dataFormat); 563 formattedStore(kernelString, i, i*groupSize, dataFormat);
563 } 564 }
564 kernelString += string("}\n"); 565 kernelString += string("}\n");
565 566
566 lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG; 567 lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
567 } 568 }
568 569
569 return lMemSize; 570 return lMemSize;
570 } 571 }
571 572
572 static void 573 static void
573 insertfftKernel(string &kernelString, int Nr, int numIter) 574 insertfftKernel(string &kernelString, int Nr, int numIter)
574 { 575 {
575 int i; 576 int i;
576 for(i = 0; i < numIter; i++) 577 for(i = 0; i < numIter; i++)
577 { 578 {
578 kernelString += string(" fftKernel") + num2str(Nr) + string("(a+") + num2str(i*Nr) + string(", dir);\n"); 579 kernelString += string(" fftKernel") + num2str(Nr) + string("(a+") + num2str(i*Nr) + string(", dir);\n");
579 } 580 }
580 } 581 }
581 582
582 static void 583 static void
583 insertTwiddleKernel(string &kernelString, int Nr, int numIter, int Nprev, int len, int numWorkItemsPerXForm) 584 insertTwiddleKernel(string &kernelString, int Nr, int numIter, int Nprev, int len, int numWorkItemsPerXForm)
584 { 585 {
585 int z, k; 586 int z, k;
586 int logNPrev = (int)log2(Nprev); 587 int logNPrev = (int)log2(Nprev);
587 588
588 for(z = 0; z < numIter; z++) 589 for(z = 0; z < numIter; z++)
589 { 590 {
590 if(z == 0) 591 if(z == 0)
591 { 592 {
592 if(Nprev > 1) 593 if(Nprev > 1)
593 kernelString += string(" angf = (float) (ii >> ") + num2str(logNPrev) + string(");\n"); 594 kernelString += string(" angf = (float) (ii >> ") + num2str(logNPrev) + string(");\n");
594 else 595 else
595 kernelString += string(" angf = (float) ii;\n"); 596 kernelString += string(" angf = (float) ii;\n");
596 } 597 }
597 else 598 else
598 { 599 {
599 if(Nprev > 1) 600 if(Nprev > 1)
600 kernelString += string(" angf = (float) ((") + num2str(z*numWorkItemsPerXForm) + string(" + ii) >>") + num2str(logNPrev) + string(");\n"); 601 kernelString += string(" angf = (float) ((") + num2str(z*numWorkItemsPerXForm) + string(" + ii) >>") + num2str(logNPrev) + string(");\n");
601 else 602 else
602 kernelString += string(" angf = (float) (") + num2str(z*numWorkItemsPerXForm) + string(" + ii);\n"); 603 kernelString += string(" angf = (float) (") + num2str(z*numWorkItemsPerXForm) + string(" + ii);\n");
603 } 604 }
604 605
605 for(k = 1; k < Nr; k++) { 606 for(k = 1; k < Nr; k++) {
606 int ind = z*Nr + k; 607 int ind = z*Nr + k;
607 //float fac = (float) (2.0 * M_PI * (double) k / (double) len); 608 //float fac = (float) (2.0 * M_PI * (double) k / (double) len);
608 kernelString += string(" ang = dir * ( 2.0f * M_PI * ") + num2str(k) + string(".0f / ") + num2str(len) + string(".0f )") + string(" * angf;\n"); 609 kernelString += string(" ang = dir * ( 2.0f * M_PI * ") + num2str(k) + string(".0f / ") + num2str(len) + string(".0f )") + string(" * angf;\n");
609 kernelString += string(" w = (float2)(native_cos(ang), native_sin(ang));\n"); 610 kernelString += string(" w = (float2)(native_cos(ang), native_sin(ang));\n");
610 kernelString += string(" a[") + num2str(ind) + string("] = complexMul(a[") + num2str(ind) + string("], w);\n"); 611 kernelString += string(" a[") + num2str(ind) + string("] = complexMul(a[") + num2str(ind) + string("], w);\n");
611 } 612 }
612 } 613 }
613 } 614 }
614 615
// Computes the padding layout of the local-memory transpose buffer for one stage.
//   *offset : extra elements appended to each row of numWorkItemsReq elements
//             (0 when numWorkItemsPerXForm <= Nprev or Nprev >= numBanks).
//   *midPad : extra elements inserted between consecutive transforms in the buffer,
//             chosen from the bank index ((numWorkItemsReq + *offset) * Nr) & (numBanks - 1)
//             (0 when numWorkItemsPerXForm >= numBanks or only one transform per group).
// numBanks is used as a bit mask, so it is assumed to be a power of two.
// Returns the required buffer size in elements:
//   (numWorkItemsReq + *offset) * Nr * numXFormsPerWG + *midPad * (numXFormsPerWG - 1).
static int
getPadding(int numWorkItemsPerXForm, int Nprev, int numWorkItemsReq, int numXFormsPerWG, int Nr, int numBanks, int *offset, int *midPad)
{
    int rowPad = 0;
    if (numWorkItemsPerXForm > Nprev && Nprev < numBanks)
    {
        const int activeRows = (numWorkItemsPerXForm < numBanks ? numWorkItemsPerXForm : numBanks) / Nprev;
        const int colsPerRow = (activeRows > Nr) ? activeRows / Nr : 1;
        rowPad = Nprev * colsPerRow;
    }
    *offset = rowPad;

    const int xformStride = (numWorkItemsReq + rowPad) * Nr;

    int gap = 0;
    if (numWorkItemsPerXForm < numBanks && numXFormsPerWG != 1)
    {
        const int bank = xformStride & (numBanks - 1);
        if (bank < numWorkItemsPerXForm)
            gap = numWorkItemsPerXForm - bank;
    }
    *midPad = gap;

    return xformStride * numXFormsPerWG + gap * (numXFormsPerWG - 1);
}
642 643
643 644
644 static void 645 static void
645 insertLocalStores(string &kernelString, int numIter, int Nr, int numWorkItemsPerXForm, int numWorkItemsReq, int offset, string &comp) 646 insertLocalStores(string &kernelString, int numIter, int Nr, int numWorkItemsPerXForm, int numWorkItemsReq, int offset, string &comp)
646 { 647 {
647 int z, k; 648 int z, k;
648 649
649 for(z = 0; z < numIter; z++) { 650 for(z = 0; z < numIter; z++) {
650 for(k = 0; k < Nr; k++) { 651 for(k = 0; k < Nr; k++) {
651 int index = k*(numWorkItemsReq + offset) + z*numWorkItemsPerXForm; 652 int index = k*(numWorkItemsReq + offset) + z*numWorkItemsPerXForm;
652 kernelString += string(" lMemStore[") + num2str(index) + string("] = a[") + num2str(z*Nr + k) + string("].") + comp + string(";\n"); 653 kernelString += string(" lMemStore[") + num2str(index) + string("] = a[") + num2str(z*Nr + k) + string("].") + comp + string(";\n");
653 } 654 }
654 } 655 }
655 kernelString += string(" barrier(CLK_LOCAL_MEM_FENCE);\n"); 656 kernelString += string(" barrier(CLK_LOCAL_MEM_FENCE);\n");
656 } 657 }
657 658
658 static void 659 static void
659 insertLocalLoads(string &kernelString, int n, int Nr, int Nrn, int Nprev, int Ncurr, int numWorkItemsPerXForm, int numWorkItemsReq, int offset, string &comp) 660 insertLocalLoads(string &kernelString, int n, int Nr, int Nrn, int Nprev, int Ncurr, int numWorkItemsPerXForm, int numWorkItemsReq, int offset, string &comp)
660 { 661 {
661 int numWorkItemsReqN = n / Nrn; 662 int numWorkItemsReqN = n / Nrn;
662 int interBlockHNum = max( Nprev / numWorkItemsPerXForm, 1 ); 663 int interBlockHNum = max( Nprev / numWorkItemsPerXForm, 1 );
663 int interBlockHStride = numWorkItemsPerXForm; 664 int interBlockHStride = numWorkItemsPerXForm;
664 int vertWidth = max(numWorkItemsPerXForm / Nprev, 1); 665 int vertWidth = max(numWorkItemsPerXForm / Nprev, 1);
665 vertWidth = min( vertWidth, Nr); 666 vertWidth = min( vertWidth, Nr);
666 int vertNum = Nr / vertWidth; 667 int vertNum = Nr / vertWidth;
667 int vertStride = ( n / Nr + offset ) * vertWidth; 668 int vertStride = ( n / Nr + offset ) * vertWidth;
668 int iter = max( numWorkItemsReqN / numWorkItemsPerXForm, 1); 669 int iter = max( numWorkItemsReqN / numWorkItemsPerXForm, 1);
669 int intraBlockHStride = (numWorkItemsPerXForm / (Nprev*Nr)) > 1 ? (numWorkItemsPerXForm / (Nprev*Nr)) : 1; 670 int intraBlockHStride = (numWorkItemsPerXForm / (Nprev*Nr)) > 1 ? (numWorkItemsPerXForm / (Nprev*Nr)) : 1;
670 intraBlockHStride *= Nprev; 671 intraBlockHStride *= Nprev;
671 672
672 int stride = numWorkItemsReq / Nrn; 673 int stride = numWorkItemsReq / Nrn;
673 int i; 674 int i;
674 for(i = 0; i < iter; i++) { 675 for(i = 0; i < iter; i++) {
675 int ii = i / (interBlockHNum * vertNum); 676 int ii = i / (interBlockHNum * vertNum);
676 int zz = i % (interBlockHNum * vertNum); 677 int zz = i % (interBlockHNum * vertNum);
677 int jj = zz % interBlockHNum; 678 int jj = zz % interBlockHNum;
678 int kk = zz / interBlockHNum; 679 int kk = zz / interBlockHNum;
679 int z; 680 int z;
680 for(z = 0; z < Nrn; z++) { 681 for(z = 0; z < Nrn; z++) {
681 int st = kk * vertStride + jj * interBlockHStride + ii * intraBlockHStride + z * stride; 682 int st = kk * vertStride + jj * interBlockHStride + ii * intraBlockHStride + z * stride;
682 kernelString += string(" a[") + num2str(i*Nrn + z) + string("].") + comp + string(" = lMemLoad[") + num2str(st) + string("];\n"); 683 kernelString += string(" a[") + num2str(i*Nrn + z) + string("].") + comp + string(" = lMemLoad[") + num2str(st) + string("];\n");
683 } 684 }
684 } 685 }
685 kernelString += string(" barrier(CLK_LOCAL_MEM_FENCE);\n"); 686 kernelString += string(" barrier(CLK_LOCAL_MEM_FENCE);\n");
686 } 687 }
687 688
688 static void 689 static void
689 insertLocalLoadIndexArithmatic(string &kernelString, int Nprev, int Nr, int numWorkItemsReq, int numWorkItemsPerXForm, int numXFormsPerWG, int offset, int midPad) 690 insertLocalLoadIndexArithmatic(string &kernelString, int Nprev, int Nr, int numWorkItemsReq, int numWorkItemsPerXForm, int numXFormsPerWG, int offset, int midPad)
690 { 691 {
691 int Ncurr = Nprev * Nr; 692 int Ncurr = Nprev * Nr;
692 int logNcurr = (int)log2(Ncurr); 693 int logNcurr = (int)log2(Ncurr);
693 int logNprev = (int)log2(Nprev); 694 int logNprev = (int)log2(Nprev);
694 int incr = (numWorkItemsReq + offset) * Nr + midPad; 695 int incr = (numWorkItemsReq + offset) * Nr + midPad;
695 696
696 if(Ncurr < numWorkItemsPerXForm) 697 if(Ncurr < numWorkItemsPerXForm)
697 { 698 {
698 if(Nprev == 1) 699 if(Nprev == 1)
699 kernelString += string(" j = ii & ") + num2str(Ncurr - 1) + string(";\n"); 700 kernelString += string(" j = ii & ") + num2str(Ncurr - 1) + string(";\n");
700 else 701 else
701 kernelString += string(" j = (ii & ") + num2str(Ncurr - 1) + string(") >> ") + num2str(logNprev) + string(";\n"); 702 kernelString += string(" j = (ii & ") + num2str(Ncurr - 1) + string(") >> ") + num2str(logNprev) + string(";\n");
702 703
703 if(Nprev == 1) 704 if(Nprev == 1)
704 kernelString += string(" i = ii >> ") + num2str(logNcurr) + string(";\n"); 705 kernelString += string(" i = ii >> ") + num2str(logNcurr) + string(";\n");
705 else 706 else
706 kernelString += string(" i = mad24(ii >> ") + num2str(logNcurr) + string(", ") + num2str(Nprev) + string(", ii & ") + num2str(Nprev - 1) + string(");\n"); 707 kernelString += string(" i = mad24(ii >> ") + num2str(logNcurr) + string(", ") + num2str(Nprev) + string(", ii & ") + num2str(Nprev - 1) + string(");\n");
707 } 708 }
708 else 709 else
709 { 710 {
710 if(Nprev == 1) 711 if(Nprev == 1)
711 kernelString += string(" j = ii;\n"); 712 kernelString += string(" j = ii;\n");
712 else 713 else
713 kernelString += string(" j = ii >> ") + num2str(logNprev) + string(";\n"); 714 kernelString += string(" j = ii >> ") + num2str(logNprev) + string(";\n");
714 if(Nprev == 1) 715 if(Nprev == 1)
715 kernelString += string(" i = 0;\n"); 716 kernelString += string(" i = 0;\n");
716 else 717 else
717 kernelString += string(" i = ii & ") + num2str(Nprev - 1) + string(";\n"); 718 kernelString += string(" i = ii & ") + num2str(Nprev - 1) + string(";\n");
718 } 719 }
719 720
720 if(numXFormsPerWG > 1) 721 if(numXFormsPerWG > 1)
721 kernelString += string(" i = mad24(jj, ") + num2str(incr) + string(", i);\n"); 722 kernelString += string(" i = mad24(jj, ") + num2str(incr) + string(", i);\n");
722 723
723 kernelString += string(" lMemLoad = sMem + mad24(j, ") + num2str(numWorkItemsReq + offset) + string(", i);\n"); 724 kernelString += string(" lMemLoad = sMem + mad24(j, ") + num2str(numWorkItemsReq + offset) + string(", i);\n");
724 } 725 }
725 726
726 static void 727 static void
727 insertLocalStoreIndexArithmatic(string &kernelString, int numWorkItemsReq, int numXFormsPerWG, int Nr, int offset, int midPad) 728 insertLocalStoreIndexArithmatic(string &kernelString, int numWorkItemsReq, int numXFormsPerWG, int Nr, int offset, int midPad)
728 { 729 {
729 if(numXFormsPerWG == 1) { 730 if(numXFormsPerWG == 1) {
730 kernelString += string(" lMemStore = sMem + ii;\n"); 731 kernelString += string(" lMemStore = sMem + ii;\n");
731 } 732 }
732 else { 733 else {
733 kernelString += string(" lMemStore = sMem + mad24(jj, ") + num2str((numWorkItemsReq + offset)*Nr + midPad) + string(", ii);\n"); 734 kernelString += string(" lMemStore = sMem + mad24(jj, ") + num2str((numWorkItemsReq + offset)*Nr + midPad) + string(", ii);\n");
734 } 735 }
735 } 736 }
736 737
737 738
738 static void 739 static void
739 createLocalMemfftKernelString(cl_fft_plan *plan) 740 createLocalMemfftKernelString(cl_fft_plan *plan)
740 { 741 {
741 unsigned int radixArray[10]; 742 unsigned int radixArray[10];
742 unsigned int numRadix; 743 unsigned int numRadix;
743 744
744 unsigned int n = plan->n.x; 745 unsigned int n = plan->n.x;
745 746
746 assert(n <= plan->max_work_item_per_workgroup * plan->max_radix && "signal lenght too big for local mem fft\n"); 747 assert(n <= plan->max_work_item_per_workgroup * plan->max_radix && "signal lenght too big for local mem fft\n");
747 748
748 getRadixArray(n, radixArray, &numRadix, 0); 749 getRadixArray(n, radixArray, &numRadix, 0);
749 assert(numRadix > 0 && "no radix array supplied\n"); 750 assert(numRadix > 0 && "no radix array supplied\n");
750 751
751 if(n/radixArray[0] > plan->max_work_item_per_workgroup) 752 if(n/radixArray[0] > plan->max_work_item_per_workgroup)
752 getRadixArray(n, radixArray, &numRadix, plan->max_radix); 753 getRadixArray(n, radixArray, &numRadix, plan->max_radix);
753 754
754 assert(radixArray[0] <= plan->max_radix && "max radix choosen is greater than allowed\n"); 755 assert(radixArray[0] <= plan->max_radix && "max radix choosen is greater than allowed\n");
755 assert(n/radixArray[0] <= plan->max_work_item_per_workgroup && "required work items per xform greater than maximum work items allowed per work group for local mem fft\n"); 756 assert(n/radixArray[0] <= plan->max_work_item_per_workgroup && "required work items per xform greater than maximum work items allowed per work group for local mem fft\n");
756 757
757 unsigned int tmpLen = 1; 758 unsigned int tmpLen = 1;
758 unsigned int i; 759 unsigned int i;
759 for(i = 0; i < numRadix; i++) 760 for(i = 0; i < numRadix; i++)
760 { 761 {
761 assert( radixArray[i] && !( (radixArray[i] - 1) & radixArray[i] ) ); 762 assert( radixArray[i] && !( (radixArray[i] - 1) & radixArray[i] ) );
762 tmpLen *= radixArray[i]; 763 tmpLen *= radixArray[i];
763 } 764 }
764 assert(tmpLen == n && "product of radices choosen doesnt match the length of signal\n"); 765 assert(tmpLen == n && "product of radices choosen doesnt match the length of signal\n");
765 766
766 int offset, midPad; 767 int offset, midPad;
767 string localString(""), kernelName(""); 768 string localString(""), kernelName("");
768 769
769 clFFT_DataFormat dataFormat = plan->format; 770 clFFT_DataFormat dataFormat = plan->format;
770 string *kernelString = plan->kernel_string; 771 string *kernelString = plan->kernel_string;
771 772
772 773
773 cl_fft_kernel_info **kInfo = &plan->kernel_info; 774 cl_fft_kernel_info **kInfo = &plan->kernel_info;
774 int kCount = 0; 775 int kCount = 0;
775 776
776 while(*kInfo) 777 while(*kInfo)
777 { 778 {
778 kInfo = &(*kInfo)->next; 779 kInfo = &(*kInfo)->next;
779 kCount++; 780 kCount++;
780 } 781 }
781 782
782 kernelName = string("fft") + num2str(kCount); 783 kernelName = string("fft") + num2str(kCount);
783 784
784 *kInfo = (cl_fft_kernel_info *) malloc(sizeof(cl_fft_kernel_info)); 785 *kInfo = (cl_fft_kernel_info *) malloc(sizeof(cl_fft_kernel_info));
785 (*kInfo)->kernel = 0; 786 (*kInfo)->kernel = 0;
786 (*kInfo)->lmem_size = 0; 787 (*kInfo)->lmem_size = 0;
787 (*kInfo)->num_workgroups = 0; 788 (*kInfo)->num_workgroups = 0;
788 (*kInfo)->num_workitems_per_workgroup = 0; 789 (*kInfo)->num_workitems_per_workgroup = 0;
789 (*kInfo)->dir = cl_fft_kernel_x; 790 (*kInfo)->dir = cl_fft_kernel_x;
790 (*kInfo)->in_place_possible = 1; 791 (*kInfo)->in_place_possible = 1;
791 (*kInfo)->next = NULL; 792 (*kInfo)->next = NULL;
792 (*kInfo)->kernel_name = (char *) malloc(sizeof(char)*(kernelName.size()+1)); 793 (*kInfo)->kernel_name = (char *) malloc(sizeof(char)*(kernelName.size()+1));
793 strcpy((*kInfo)->kernel_name, kernelName.c_str()); 794 strcpy((*kInfo)->kernel_name, kernelName.c_str());
794 795
795 unsigned int numWorkItemsPerXForm = n / radixArray[0]; 796 unsigned int numWorkItemsPerXForm = n / radixArray[0];
796 unsigned int numWorkItemsPerWG = numWorkItemsPerXForm <= 64 ? 64 : numWorkItemsPerXForm; 797 unsigned int numWorkItemsPerWG = numWorkItemsPerXForm <= 64 ? 64 : numWorkItemsPerXForm;
797 assert(numWorkItemsPerWG <= plan->max_work_item_per_workgroup); 798 assert(numWorkItemsPerWG <= plan->max_work_item_per_workgroup);
798 int numXFormsPerWG = numWorkItemsPerWG / numWorkItemsPerXForm; 799 int numXFormsPerWG = numWorkItemsPerWG / numWorkItemsPerXForm;
799 (*kInfo)->num_workgroups = 1; 800 (*kInfo)->num_workgroups = 1;
800 (*kInfo)->num_xforms_per_workgroup = numXFormsPerWG; 801 (*kInfo)->num_xforms_per_workgroup = numXFormsPerWG;
801 (*kInfo)->num_workitems_per_workgroup = numWorkItemsPerWG; 802 (*kInfo)->num_workitems_per_workgroup = numWorkItemsPerWG;
802 803
803 unsigned int *N = radixArray; 804 unsigned int *N = radixArray;
804 unsigned int maxRadix = N[0]; 805 unsigned int maxRadix = N[0];
805 unsigned int lMemSize = 0; 806 unsigned int lMemSize = 0;
806 807
807 insertVariables(localString, maxRadix); 808 insertVariables(localString, maxRadix);
808 809
809 lMemSize = insertGlobalLoadsAndTranspose(localString, n, numWorkItemsPerXForm, numXFormsPerWG, maxRadix, plan->min_mem_coalesce_width, dataFormat); 810 lMemSize = insertGlobalLoadsAndTranspose(localString, n, numWorkItemsPerXForm, numXFormsPerWG, maxRadix, plan->min_mem_coalesce_width, dataFormat);
810 (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size; 811 (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size;
811 812
812 string xcomp = string("x"); 813 string xcomp = string("x");
813 string ycomp = string("y"); 814 string ycomp = string("y");
814 815
815 unsigned int Nprev = 1; 816 unsigned int Nprev = 1;
816 unsigned int len = n; 817 unsigned int len = n;
817 unsigned int r; 818 unsigned int r;
818 for(r = 0; r < numRadix; r++) 819 for(r = 0; r < numRadix; r++)
819 { 820 {
820 int numIter = N[0] / N[r]; 821 int numIter = N[0] / N[r];
821 int numWorkItemsReq = n / N[r]; 822 int numWorkItemsReq = n / N[r];
822 int Ncurr = Nprev * N[r]; 823 int Ncurr = Nprev * N[r];
823 insertfftKernel(localString, N[r], numIter); 824 insertfftKernel(localString, N[r], numIter);
824 825
825 if(r < (numRadix - 1)) { 826 if(r < (numRadix - 1)) {
826 insertTwiddleKernel(localString, N[r], numIter, Nprev, len, numWorkItemsPerXForm); 827 insertTwiddleKernel(localString, N[r], numIter, Nprev, len, numWorkItemsPerXForm);
827 lMemSize = getPadding(numWorkItemsPerXForm, Nprev, numWorkItemsReq, numXFormsPerWG, N[r], plan->num_local_mem_banks, &offset, &midPad); 828 lMemSize = getPadding(numWorkItemsPerXForm, Nprev, numWorkItemsReq, numXFormsPerWG, N[r], plan->num_local_mem_banks, &offset, &midPad);
828 (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size; 829 (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size;
829 insertLocalStoreIndexArithmatic(localString, numWorkItemsReq, numXFormsPerWG, N[r], offset, midPad); 830 insertLocalStoreIndexArithmatic(localString, numWorkItemsReq, numXFormsPerWG, N[r], offset, midPad);
830 insertLocalLoadIndexArithmatic(localString, Nprev, N[r], numWorkItemsReq, numWorkItemsPerXForm, numXFormsPerWG, offset, midPad); 831 insertLocalLoadIndexArithmatic(localString, Nprev, N[r], numWorkItemsReq, numWorkItemsPerXForm, numXFormsPerWG, offset, midPad);
831 insertLocalStores(localString, numIter, N[r], numWorkItemsPerXForm, numWorkItemsReq, offset, xcomp); 832 insertLocalStores(localString, numIter, N[r], numWorkItemsPerXForm, numWorkItemsReq, offset, xcomp);
832 insertLocalLoads(localString, n, N[r], N[r+1], Nprev, Ncurr, numWorkItemsPerXForm, numWorkItemsReq, offset, xcomp); 833 insertLocalLoads(localString, n, N[r], N[r+1], Nprev, Ncurr, numWorkItemsPerXForm, numWorkItemsReq, offset, xcomp);
833 insertLocalStores(localString, numIter, N[r], numWorkItemsPerXForm, numWorkItemsReq, offset, ycomp); 834 insertLocalStores(localString, numIter, N[r], numWorkItemsPerXForm, numWorkItemsReq, offset, ycomp);
834 insertLocalLoads(localString, n, N[r], N[r+1], Nprev, Ncurr, numWorkItemsPerXForm, numWorkItemsReq, offset, ycomp); 835 insertLocalLoads(localString, n, N[r], N[r+1], Nprev, Ncurr, numWorkItemsPerXForm, numWorkItemsReq, offset, ycomp);
835 Nprev = Ncurr; 836 Nprev = Ncurr;
836 len = len / N[r]; 837 len = len / N[r];
837 } 838 }
838 } 839 }
839 840
840 lMemSize = insertGlobalStoresAndTranspose(localString, n, maxRadix, N[numRadix - 1], numWorkItemsPerXForm, numXFormsPerWG, plan->min_mem_coalesce_width, dataFormat); 841 lMemSize = insertGlobalStoresAndTranspose(localString, n, maxRadix, N[numRadix - 1], numWorkItemsPerXForm, numXFormsPerWG, plan->min_mem_coalesce_width, dataFormat);
841 (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size; 842 (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size;
842 843
843 insertHeader(*kernelString, kernelName, dataFormat); 844 insertHeader(*kernelString, kernelName, dataFormat);
844 *kernelString += string("{\n"); 845 *kernelString += string("{\n");
845 if((*kInfo)->lmem_size) 846 if((*kInfo)->lmem_size)
846 *kernelString += string(" __local float sMem[") + num2str((*kInfo)->lmem_size) + string("];\n"); 847 *kernelString += string(" __local float sMem[") + num2str((*kInfo)->lmem_size) + string("];\n");
847 *kernelString += localString; 848 *kernelString += localString;
848 *kernelString += string("}\n"); 849 *kernelString += string("}\n");
849 } 850 }
850 851
851 // For n larger than what can be computed using local memory fft, global transposes and 852 // For n larger than what can be computed using local memory fft, global transposes and
852 // multiple kernel launches are needed. For these sizes, n can be decomposed using 853 // multiple kernel launches are needed. For these sizes, n can be decomposed using
853 // much larger base radices i.e. say n = 262144 = 128 x 64 x 32. Thus three kernel 854 // much larger base radices i.e. say n = 262144 = 128 x 64 x 32. Thus three kernel
854 // launches will be needed, first computing 64 x 32, length 128 ffts, second computing 855 // launches will be needed, first computing 64 x 32, length 128 ffts, second computing
855 // 128 x 32 length 64 ffts, and finally a kernel computing 128 x 64 length 32 ffts. 856 // 128 x 32 length 64 ffts, and finally a kernel computing 128 x 64 length 32 ffts.
856 // Each of these base radices can futher be divided into factors so that each of these 857 // Each of these base radices can futher be divided into factors so that each of these
857 // base ffts can be computed within one kernel launch using in-register ffts and local 858 // base ffts can be computed within one kernel launch using in-register ffts and local
858 // memory transposes i.e for the first kernel above which computes 64 x 32 ffts on length 859 // memory transposes i.e for the first kernel above which computes 64 x 32 ffts on length
859 // 128, 128 can be decomposed into 128 = 16 x 8 i.e. 8 work items can compute 8 length 860 // 128, 128 can be decomposed into 128 = 16 x 8 i.e. 8 work items can compute 8 length
860 // 16 ffts followed by transpose using local memory followed by each of these eight 861 // 16 ffts followed by transpose using local memory followed by each of these eight
861 // work items computing 2 length 8 ffts thus computing 16 length 8 ffts in total. This 862 // work items computing 2 length 8 ffts thus computing 16 length 8 ffts in total. This
862 // means only 8 work items are needed for computing one length 128 fft. If we choose 863 // means only 8 work items are needed for computing one length 128 fft. If we choose
863 // work group size of say 64, we can compute 64/8 = 8 length 128 ffts within one 864 // work group size of say 64, we can compute 64/8 = 8 length 128 ffts within one
864 // work group. Since we need to compute 64 x 32 length 128 ffts in first kernel, this 865 // work group. Since we need to compute 64 x 32 length 128 ffts in first kernel, this
865 // means we need to launch 64 x 32 / 8 = 256 work groups with 64 work items in each 866 // means we need to launch 64 x 32 / 8 = 256 work groups with 64 work items in each
866 // work group where each work group is computing 8 length 128 ffts where each length 867 // work group where each work group is computing 8 length 128 ffts where each length
867 // 128 fft is computed by 8 work items. Same logic can be applied to other two kernels 868 // 128 fft is computed by 8 work items. Same logic can be applied to other two kernels
868 // in this example. Users can play with difference base radices and difference 869 // in this example. Users can play with difference base radices and difference
869 // decompositions of base radices to generates different kernels and see which gives 870 // decompositions of base radices to generates different kernels and see which gives
870 // best performance. Following function is just fixed to use 128 as base radix 871 // best performance. Following function is just fixed to use 128 as base radix
871 872
872 void 873 void
873 getGlobalRadixInfo(int n, int *radix, int *R1, int *R2, int *numRadices) 874 getGlobalRadixInfo(int n, int *radix, int *R1, int *R2, int *numRadices)
874 { 875 {
875 int baseRadix = min(n, 128); 876 int baseRadix = min(n, 128);
876 877
877 int numR = 0; 878 int numR = 0;
878 int N = n; 879 int N = n;
879 while(N > baseRadix) 880 while(N > baseRadix)
880 { 881 {
881 N /= baseRadix; 882 N /= baseRadix;
882 numR++; 883 numR++;
883 } 884 }
884 885
885 for(int i = 0; i < numR; i++) 886 for(int i = 0; i < numR; i++)
886 radix[i] = baseRadix; 887 radix[i] = baseRadix;
887 888
888 radix[numR] = N; 889 radix[numR] = N;
889 numR++; 890 numR++;
890 *numRadices = numR; 891 *numRadices = numR;
891 892
892 for(int i = 0; i < numR; i++) 893 for(int i = 0; i < numR; i++)
893 { 894 {
894 int B = radix[i]; 895 int B = radix[i];
895 if(B <= 8) 896 if(B <= 8)
896 { 897 {
897 R1[i] = B; 898 R1[i] = B;
898 R2[i] = 1; 899 R2[i] = 1;
899 continue; 900 continue;
900 } 901 }
901 902
902 int r1 = 2; 903 int r1 = 2;
903 int r2 = B / r1; 904 int r2 = B / r1;
904 while(r2 > r1) 905 while(r2 > r1)
905 { 906 {
906 r1 *=2; 907 r1 *=2;
907 r2 = B / r1; 908 r2 = B / r1;
908 } 909 }
909 R1[i] = r1; 910 R1[i] = r1;
910 R2[i] = r2; 911 R2[i] = r2;
911 } 912 }
912 } 913 }
913 914
// Appends to *plan->kernel_string one OpenCL kernel per radix pass of an
// n-point FFT decomposed by getGlobalRadixInfo. Each pass reads its input from
// and writes its output to global memory (in/out or in_real/in_imag and
// out_real/out_imag). For each emitted kernel, a cl_fft_kernel_info node is
// malloc'ed and appended to the end of the plan->kernel_info list.
//
//   plan   - supplies the device limits (max_work_item_per_workgroup,
//            max_radix, min_mem_coalesce_width), the data format, and the
//            kernel string / kernel-info list that this function extends.
//   n      - transform length. The generated index arithmetic uses shifts and
//            masks derived from log2, so n is presumably a power of two
//            (TODO confirm).
//   BS     - dir == cl_fft_kernel_x: multiplies the work-group count.
//            Otherwise: used as Rinit, the starting stride between consecutive
//            points of one transform.
//   dir    - cl_fft_kernel_x selects horizontal addressing; any other value
//            selects vertical (strided) addressing.
//   vertBS - multiplies the work-group count in the vertical case.
static void
createGlobalFFTKernelString(cl_fft_plan *plan, int n, int BS, cl_fft_kernel_dir dir, int vertBS)
{
    int i, j, k, t;
    // Per-pass radix and its R1 x R2 split; at most 10 passes fit here.
    int radixArr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
    int R1Arr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
    int R2Arr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
    int radix, R1, R2;
    int numRadices;

    int maxThreadsPerBlock = plan->max_work_item_per_workgroup;
    int maxArrayLen = plan->max_radix;
    int batchSize = plan->min_mem_coalesce_width;
    clFFT_DataFormat dataFormat = plan->format;
    int vertical = (dir == cl_fft_kernel_x) ? 0 : 1;

    getGlobalRadixInfo(n, radixArr, R1Arr, R2Arr, &numRadices);

    int numPasses = numRadices;

    string localString(""), kernelName("");
    string *kernelString = plan->kernel_string;
    cl_fft_kernel_info **kInfo = &plan->kernel_info;
    int kCount = 0;

    // Walk to the tail of the existing kernel-info list. kCount numbers the
    // new kernels ("fft<kCount>") after the ones already present.
    while(*kInfo)
    {
        kInfo = &(*kInfo)->next;
        kCount++;
    }

    int N = n;               // length still to be processed; divided by each pass's radix
    int m = (int)log2(n);    // shift used to offset to transform xNum (horizontal case)
    int Rinit = vertical ? BS : 1;
    batchSize = vertical ? min(BS, batchSize) : batchSize;
    int passNum;

    for(passNum = 0; passNum < numPasses; passNum++)
    {

        localString.clear();
        kernelName.clear();

        radix = radixArr[passNum];
        R1 = R1Arr[passNum];
        R2 = R2Arr[passNum];

        // Input stride: Rinit times every radix except this pass's.
        int strideI = Rinit;
        for(i = 0; i < numPasses; i++)
            if(i != passNum)
                strideI *= radixArr[i];

        // Output stride: Rinit times the radices of the earlier passes.
        int strideO = Rinit;
        for(i = 0; i < passNum; i++)
            strideO *= radixArr[i];

        // One transform of this pass is handled by R2 work items.
        // NOTE: batchSize is not reset per pass, so each pass starts from the
        // value left by the previous pass.
        int threadsPerXForm = R2;
        batchSize = R2 == 1 ? plan->max_work_item_per_workgroup : batchSize;
        batchSize = min(batchSize, strideI);
        int threadsPerBlock = batchSize * threadsPerXForm;
        threadsPerBlock = min(threadsPerBlock, maxThreadsPerBlock);
        batchSize = threadsPerBlock / threadsPerXForm;
        assert(R2 <= R1);
        assert(R1*R2 == radix);
        assert(R1 <= maxArrayLen);
        assert(threadsPerBlock <= maxThreadsPerBlock);

        int numIter = R1 / R2;
        int gInInc = threadsPerBlock / batchSize;


        int lgStrideO = (int)log2(strideO);
        int numBlocksPerXForm = strideI / batchSize;
        int numBlocks = numBlocksPerXForm;
        if(!vertical)
            numBlocks *= BS;
        else
            numBlocks *= vertBS;

        // Describe the kernel being generated.
        // NOTE(review): neither malloc result below is checked for NULL.
        kernelName = string("fft") + num2str(kCount);
        *kInfo = (cl_fft_kernel_info *) malloc(sizeof(cl_fft_kernel_info));
        (*kInfo)->kernel = 0;
        // Number of floats for the sMem array declared in the kernel header
        // below; zero when no local-memory exchange is generated.
        if(R2 == 1)
            (*kInfo)->lmem_size = 0;
        else
        {
            if(strideO == 1)
                (*kInfo)->lmem_size = (radix + 1)*batchSize;
            else
                (*kInfo)->lmem_size = threadsPerBlock*R1;
        }
        (*kInfo)->num_workgroups = numBlocks;
        (*kInfo)->num_xforms_per_workgroup = 1;
        (*kInfo)->num_workitems_per_workgroup = threadsPerBlock;
        (*kInfo)->dir = dir;
        // Only the last pass of an odd number of passes is flagged in-place capable.
        if( (passNum == (numPasses - 1)) && (numPasses & 1) )
            (*kInfo)->in_place_possible = 1;
        else
            (*kInfo)->in_place_possible = 0;
        (*kInfo)->next = NULL;
        (*kInfo)->kernel_name = (char *) malloc(sizeof(char)*(kernelName.size()+1));
        strcpy((*kInfo)->kernel_name, kernelName.c_str());

        insertVariables(localString, R1);

        // Generated code: work-group base indices (indexIn / indexOut) and the
        // block number bNum used by the inter-pass twiddle below.
        if(vertical)
        {
            localString += string("xNum = groupId >> ") + num2str((int)log2(numBlocksPerXForm)) + string(";\n");
            localString += string("groupId = groupId & ") + num2str(numBlocksPerXForm - 1) + string(";\n");
            localString += string("indexIn = mad24(groupId, ") + num2str(batchSize) + string(", xNum << ") + num2str((int)log2(n*BS)) + string(");\n");
            localString += string("tid = mul24(groupId, ") + num2str(batchSize) + string(");\n");
            localString += string("i = tid >> ") + num2str(lgStrideO) + string(";\n");
            localString += string("j = tid & ") + num2str(strideO - 1) + string(";\n");
            int stride = radix*Rinit;
            for(i = 0; i < passNum; i++)
                stride *= radixArr[i];
            localString += string("indexOut = mad24(i, ") + num2str(stride) + string(", j + ") + string("(xNum << ") + num2str((int) log2(n*BS)) + string("));\n");
            localString += string("bNum = groupId;\n");
        }
        else
        {
            int lgNumBlocksPerXForm = (int)log2(numBlocksPerXForm);
            localString += string("bNum = groupId & ") + num2str(numBlocksPerXForm - 1) + string(";\n");
            localString += string("xNum = groupId >> ") + num2str(lgNumBlocksPerXForm) + string(";\n");
            localString += string("indexIn = mul24(bNum, ") + num2str(batchSize) + string(");\n");
            localString += string("tid = indexIn;\n");
            localString += string("i = tid >> ") + num2str(lgStrideO) + string(";\n");
            localString += string("j = tid & ") + num2str(strideO - 1) + string(";\n");
            int stride = radix*Rinit;
            for(i = 0; i < passNum; i++)
                stride *= radixArr[i];
            localString += string("indexOut = mad24(i, ") + num2str(stride) + string(", j);\n");
            localString += string("indexIn += (xNum << ") + num2str(m) + string(");\n");
            localString += string("indexOut += (xNum << ") + num2str(m) + string(");\n");
        }

        // Load Data: each work item reads R1 values spaced gInInc*strideI apart
        // into a[], then runs a length-R1 FFT on them.
        int lgBatchSize = (int)log2(batchSize);
        localString += string("tid = lId;\n");
        localString += string("i = tid & ") + num2str(batchSize - 1) + string(";\n");
        localString += string("j = tid >> ") + num2str(lgBatchSize) + string(";\n");
        localString += string("indexIn += mad24(j, ") + num2str(strideI) + string(", i);\n");

        if(dataFormat == clFFT_SplitComplexFormat)
        {
            localString += string("in_real += indexIn;\n");
            localString += string("in_imag += indexIn;\n");
            for(j = 0; j < R1; j++)
                localString += string("a[") + num2str(j) + string("].x = in_real[") + num2str(j*gInInc*strideI) + string("];\n");
            for(j = 0; j < R1; j++)
                localString += string("a[") + num2str(j) + string("].y = in_imag[") + num2str(j*gInInc*strideI) + string("];\n");
        }
        else
        {
            localString += string("in += indexIn;\n");
            for(j = 0; j < R1; j++)
                localString += string("a[") + num2str(j) + string("] = in[") + num2str(j*gInInc*strideI) + string("];\n");
        }

        localString += string("fftKernel") + num2str(R1) + string("(a, dir);\n");

        // When the radix is split (R2 > 1), the generated code:
        //   1. applies the intra-radix twiddles;
        //   2. exchanges data through sMem, real and imaginary parts separately;
        //   3. runs R1/R2 FFTs of length R2.
        if(R2 > 1)
        {
            // twiddle
            for(k = 1; k < R1; k++)
            {
                localString += string("ang = dir*(2.0f*M_PI*") + num2str(k) + string("/") + num2str(radix) + string(")*j;\n");
                localString += string("w = (float2)(native_cos(ang), native_sin(ang));\n");
                localString += string("a[") + num2str(k) + string("] = complexMul(a[") + num2str(k) + string("], w);\n");
            }

            // shuffle
            numIter = R1 / R2;
            localString += string("indexIn = mad24(j, ") + num2str(threadsPerBlock*numIter) + string(", i);\n");
            localString += string("lMemStore = sMem + tid;\n");
            localString += string("lMemLoad = sMem + indexIn;\n");
            for(k = 0; k < R1; k++)
                localString += string("lMemStore[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].x;\n");
            localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
            for(k = 0; k < numIter; k++)
                for(t = 0; t < R2; t++)
                    localString += string("a[") + num2str(k*R2+t) + string("].x = lMemLoad[") + num2str(t*batchSize + k*threadsPerBlock) + string("];\n");
            localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
            for(k = 0; k < R1; k++)
                localString += string("lMemStore[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].y;\n");
            localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
            for(k = 0; k < numIter; k++)
                for(t = 0; t < R2; t++)
                    localString += string("a[") + num2str(k*R2+t) + string("].y = lMemLoad[") + num2str(t*batchSize + k*threadsPerBlock) + string("];\n");
            localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");

            for(j = 0; j < numIter; j++)
                localString += string("fftKernel") + num2str(R2) + string("(a + ") + num2str(j*R2) + string(", dir);\n");
        }

        // twiddle: inter-pass twiddle factors based on the remaining length N;
        // not emitted for the last pass.
        if(passNum < (numPasses - 1))
        {
            localString += string("l = ((bNum << ") + num2str(lgBatchSize) + string(") + i) >> ") + num2str(lgStrideO) + string(";\n");
            localString += string("k = j << ") + num2str((int)log2(R1/R2)) + string(";\n");
            localString += string("ang1 = dir*(2.0f*M_PI/") + num2str(N) + string(")*l;\n");
            for(t = 0; t < R1; t++)
            {
                localString += string("ang = ang1*(k + ") + num2str((t%R2)*R1 + (t/R2)) + string(");\n");
                localString += string("w = (float2)(native_cos(ang), native_sin(ang));\n");
                localString += string("a[") + num2str(t) + string("] = complexMul(a[") + num2str(t) + string("], w);\n");
            }
        }

        // Store Data.
        // strideO == 1 (first pass): results are first rearranged through sMem,
        // laid out in rows of radix + 1 floats, and then written to
        // out[indexOut + tid + k*threadsPerBlock].
        // Otherwise: each a[k] is written directly to global memory with
        // stride strideO.
        if(strideO == 1)
        {

            localString += string("lMemStore = sMem + mad24(i, ") + num2str(radix + 1) + string(", j << ") + num2str((int)log2(R1/R2)) + string(");\n");
            localString += string("lMemLoad = sMem + mad24(tid >> ") + num2str((int)log2(radix)) + string(", ") + num2str(radix+1) + string(", tid & ") + num2str(radix-1) + string(");\n");

            for(i = 0; i < R1/R2; i++)
                for(j = 0; j < R2; j++)
                    localString += string("lMemStore[ ") + num2str(i + j*R1) + string("] = a[") + num2str(i*R2+j) + string("].x;\n");
            localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
            if(threadsPerBlock >= radix)
            {
                for(i = 0; i < R1; i++)
                    localString += string("a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i*(radix+1)*(threadsPerBlock/radix)) + string("];\n");
            }
            else
            {
                int innerIter = radix/threadsPerBlock;
                int outerIter = R1/innerIter;
                for(i = 0; i < outerIter; i++)
                    for(j = 0; j < innerIter; j++)
                        localString += string("a[") + num2str(i*innerIter+j) + string("].x = lMemLoad[") + num2str(j*threadsPerBlock + i*(radix+1)) + string("];\n");
            }
            localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");

            for(i = 0; i < R1/R2; i++)
                for(j = 0; j < R2; j++)
                    localString += string("lMemStore[ ") + num2str(i + j*R1) + string("] = a[") + num2str(i*R2+j) + string("].y;\n");
            localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
            if(threadsPerBlock >= radix)
            {
                for(i = 0; i < R1; i++)
                    localString += string("a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i*(radix+1)*(threadsPerBlock/radix)) + string("];\n");
            }
            else
            {
                int innerIter = radix/threadsPerBlock;
                int outerIter = R1/innerIter;
                for(i = 0; i < outerIter; i++)
                    for(j = 0; j < innerIter; j++)
                        localString += string("a[") + num2str(i*innerIter+j) + string("].y = lMemLoad[") + num2str(j*threadsPerBlock + i*(radix+1)) + string("];\n");
            }
            localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");

            localString += string("indexOut += tid;\n");
            if(dataFormat == clFFT_SplitComplexFormat) {
                localString += string("out_real += indexOut;\n");
                localString += string("out_imag += indexOut;\n");
                for(k = 0; k < R1; k++)
                    localString += string("out_real[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].x;\n");
                for(k = 0; k < R1; k++)
                    localString += string("out_imag[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].y;\n");
            }
            else {
                localString += string("out += indexOut;\n");
                for(k = 0; k < R1; k++)
                    localString += string("out[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("];\n");
            }

        }
        else
        {
            localString += string("indexOut += mad24(j, ") + num2str(numIter*strideO) + string(", i);\n");
            if(dataFormat == clFFT_SplitComplexFormat) {
                localString += string("out_real += indexOut;\n");
                localString += string("out_imag += indexOut;\n");
                for(k = 0; k < R1; k++)
                    localString += string("out_real[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("].x;\n");
                for(k = 0; k < R1; k++)
                    localString += string("out_imag[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("].y;\n");
            }
            else {
                localString += string("out += indexOut;\n");
                for(k = 0; k < R1; k++)
                    localString += string("out[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("];\n");
            }
        }

        // Assemble the kernel: header, optional sMem declaration, then the body.
        insertHeader(*kernelString, kernelName, dataFormat);
        *kernelString += string("{\n");
        if((*kInfo)->lmem_size)
            *kernelString += string(" __local float sMem[") + num2str((*kInfo)->lmem_size) + string("];\n");
        *kernelString += localString;
        *kernelString += string("}\n");

        // Advance to the next pass and the next list slot.
        N /= radix;
        kInfo = &(*kInfo)->next;
        kCount++;
    }
}
1214 1215
1215 void FFT1D(cl_fft_plan *plan, cl_fft_kernel_dir dir) 1216 void FFT1D(cl_fft_plan *plan, cl_fft_kernel_dir dir)
1216 { 1217 {
1217 unsigned int radixArray[10]; 1218 unsigned int radixArray[10];
1218 unsigned int numRadix; 1219 unsigned int numRadix;
1219 1220
1220 switch(dir) 1221 switch(dir)
1221 { 1222 {
1222 case cl_fft_kernel_x: 1223 case cl_fft_kernel_x:
1223 if(plan->n.x > plan->max_localmem_fft_size) 1224 if(plan->n.x > plan->max_localmem_fft_size)
1224 { 1225 {
1225 createGlobalFFTKernelString(plan, plan->n.x, 1, cl_fft_kernel_x, 1); 1226 createGlobalFFTKernelString(plan, plan->n.x, 1, cl_fft_kernel_x, 1);
1226 } 1227 }
1227 else if(plan->n.x > 1) 1228 else if(plan->n.x > 1)
1228 { 1229 {
1229 getRadixArray(plan->n.x, radixArray, &numRadix, 0); 1230 getRadixArray(plan->n.x, radixArray, &numRadix, 0);
1230 if(plan->n.x / radixArray[0] <= plan->max_work_item_per_workgroup) 1231 if(plan->n.x / radixArray[0] <= plan->max_work_item_per_workgroup)
1231 { 1232 {
1232 createLocalMemfftKernelString(plan); 1233 createLocalMemfftKernelString(plan);
1233 } 1234 }
1234 else 1235 else
1235 { 1236 {
1236 getRadixArray(plan->n.x, radixArray, &numRadix, plan->max_radix); 1237 getRadixArray(plan->n.x, radixArray, &numRadix, plan->max_radix);
1237 if(plan->n.x / radixArray[0] <= plan->max_work_item_per_workgroup) 1238 if(plan->n.x / radixArray[0] <= plan->max_work_item_per_workgroup)
1238 createLocalMemfftKernelString(plan); 1239 createLocalMemfftKernelString(plan);
1239 else 1240 else
1240 createGlobalFFTKernelString(plan, plan->n.x, 1, cl_fft_kernel_x, 1); 1241 createGlobalFFTKernelString(plan, plan->n.x, 1, cl_fft_kernel_x, 1);
1241 } 1242 }
1242 } 1243 }
1243 break; 1244 break;
1244 1245
1245 case cl_fft_kernel_y: 1246 case cl_fft_kernel_y:
1246 if(plan->n.y > 1) 1247 if(plan->n.y > 1)
1247 createGlobalFFTKernelString(plan, plan->n.y, plan->n.x, cl_fft_kernel_y, 1); 1248 createGlobalFFTKernelString(plan, plan->n.y, plan->n.x, cl_fft_kernel_y, 1);
1248 break; 1249
1249 1250
1250 case cl_fft_kernel_z: 1251 break;
1251 if(plan->n.z > 1) 1252
1252 createGlobalFFTKernelString(plan, plan->n.z, plan->n.x*plan->n.y, cl_fft_kernel_z, 1); 1253 case cl_fft_kernel_z:
1253 default: 1254 if(plan->n.z > 1)
1254 return; 1255 createGlobalFFTKernelString(plan, plan->n.z, plan->n.x*plan->n.y, cl_fft_kernel_z, 1);
1255 } 1256 default:
1256 } 1257 return;
1257 1258 }
1259 }
1260
1261