2
|
1
|
|
2 //
|
|
3 // File: fft_kernelstring.cpp
|
|
4 //
|
|
5 // Version: <1.0>
|
|
6 //
|
|
7 // Disclaimer: IMPORTANT: This Apple software is supplied to you by Apple Inc. ("Apple")
|
|
8 // in consideration of your agreement to the following terms, and your use,
|
|
9 // installation, modification or redistribution of this Apple software
|
|
10 // constitutes acceptance of these terms. If you do not agree with these
|
|
11 // terms, please do not use, install, modify or redistribute this Apple
|
|
12 // software.
|
|
13 //
|
|
14 // In consideration of your agreement to abide by the following terms, and
|
|
15 // subject to these terms, Apple grants you a personal, non - exclusive
|
|
16 // license, under Apple's copyrights in this original Apple software ( the
|
|
17 // "Apple Software" ), to use, reproduce, modify and redistribute the Apple
|
|
18 // Software, with or without modifications, in source and / or binary forms;
|
|
19 // provided that if you redistribute the Apple Software in its entirety and
|
|
20 // without modifications, you must retain this notice and the following text
|
|
21 // and disclaimers in all such redistributions of the Apple Software. Neither
|
|
22 // the name, trademarks, service marks or logos of Apple Inc. may be used to
|
|
23 // endorse or promote products derived from the Apple Software without specific
|
|
24 // prior written permission from Apple. Except as expressly stated in this
|
|
25 // notice, no other rights or licenses, express or implied, are granted by
|
|
26 // Apple herein, including but not limited to any patent rights that may be
|
|
27 // infringed by your derivative works or by other works in which the Apple
|
|
28 // Software may be incorporated.
|
|
29 //
|
|
30 // The Apple Software is provided by Apple on an "AS IS" basis. APPLE MAKES NO
|
|
31 // WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION THE IMPLIED
|
|
32 // WARRANTIES OF NON - INFRINGEMENT, MERCHANTABILITY AND FITNESS FOR A
|
|
33 // PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND OPERATION
|
|
34 // ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
|
|
35 //
|
|
36 // IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL OR
|
|
37 // CONSEQUENTIAL DAMAGES ( INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
|
38 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
|
39 // INTERRUPTION ) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, MODIFICATION
|
|
40 // AND / OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED AND WHETHER
|
|
41 // UNDER THEORY OF CONTRACT, TORT ( INCLUDING NEGLIGENCE ), STRICT LIABILITY OR
|
|
42 // OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
43 //
|
|
44 // Copyright ( C ) 2008 Apple Inc. All Rights Reserved.
|
|
45 //
|
|
46 ////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
47
|
|
48
|
|
49 #include <stdio.h>
|
|
50 #include <stdlib.h>
|
|
51 #include <math.h>
|
|
52 #include <iostream>
|
|
53 #include <sstream>
|
|
54 #include <string>
|
|
55 #include <assert.h>
|
|
56 #include "fft_internal.h"
|
|
57 #include "clFFT.h"
|
|
58
|
|
using namespace std;

// Function-like max/min helpers used throughout this file.
// Each argument appears twice in the expansion, so it may be evaluated twice:
// do not pass expressions with side effects (e.g. max(i++, j)).
#define max(A,B) ((A) > (B) ? (A) : (B))
#define min(A,B) ((A) < (B) ? (A) : (B))
|
|
63
|
|
// Converts a signed integer to its decimal text representation.
//
// num:     value to format.
// returns: a std::string holding the "%d" rendering of num.
static std::string
num2str(int num)
{
    // 32 bytes is ample for any int (sign + digits + NUL); snprintf bounds the
    // write so the buffer can never be overrun.
    char text[32];
    snprintf(text, sizeof text, "%d", num);
    return std::string(text);
}
|
|
71
|
|
// For any n, this function decomposes n into factors for a local memory transpose
// based fft. Factors (radices) are sorted such that the first one (radixArray[0])
// is the largest. This base radix determines the number of registers used by each
// work item, and the product of the remaining radices determines the size of the
// work group needed.
// To make things concrete with an example, suppose n = 1024. It is decomposed into
// 1024 = 16 x 16 x 4. Hence the kernel uses float2 a[16] for the local in-register fft
// and needs 16 x 4 = 64 work items per work group. So the kernel first performs 64
// length 16 ffts (64 work items working in parallel), followed by a transpose using
// local memory, followed by again 64 length 16 ffts, followed by a transpose using
// local memory, followed by 256 length 4 ffts. For the last step, since the size of
// the work group is 64 and each work item holds an array of 16 values, 64 work items
// can compute 256 length 4 ffts by each work item computing 4 length 4 ffts.
// Similarly for n = 2048 = 8 x 8 x 8 x 4, each work group has 8 x 8 x 4 = 256 work
// items, which compute 256 (in-parallel) length 8 ffts in-register, followed
// by a transpose using local memory, followed by 256 length 8 in-register ffts, followed
// by a transpose using local memory, followed by 256 length 8 in-register ffts, followed
// by a transpose using local memory, followed by 512 length 4 in-register ffts. Again,
// for the last step, each work item computes two length 4 in-register ffts and thus
// 256 work items are needed to compute all 512 ffts.
// For n = 32 = 8 x 4, 4 work items first compute 4 in-register
// length 8 ffts, followed by a transpose using local memory, followed by 8 in-register
// length 4 ffts, where each work item computes two length 4 ffts; thus 4 work items
// can compute 8 length 4 ffts. However, if a work group size of, say, 64 is chosen,
// each work group can compute 64 / 4 = 16 size 32 ffts (batched transform).
// Users can play with these parameters to figure out what gives the best performance on
// their particular device, i.e. some devices have less register space, thus using a
// smaller base radix can avoid spilling ... some have small local memory, thus
// using a smaller work group size may be required, etc.
|
|
100
|
|
// Splits the transform length n into a list of radices.
//
// n:          transform length to decompose.
// radixArray: output; receives the radices. The caller must supply enough room
//             (this function does not know the array's capacity).
// numRadices: output; number of entries written, or 0 when no decomposition
//             is available for n.
// maxRadix:   when greater than 1, n is repeatedly divided by min(n, maxRadix)
//             and the final quotient becomes the last radix; otherwise a
//             built-in table of decompositions for 2..2048 is consulted.
static void
getRadixArray(unsigned int n, unsigned int *radixArray, unsigned int *numRadices, unsigned int maxRadix)
{
    if(maxRadix > 1)
    {
        // Caller-imposed cap: peel off the capped radix until the remainder fits.
        const unsigned int radix = (n < maxRadix) ? n : maxRadix;
        unsigned int used = 0;

        for(; n > radix; n /= radix)
            radixArray[used++] = radix;

        radixArray[used++] = n;
        *numRadices = used;
        return;
    }

    // Hand-picked decompositions for the supported lengths.
    static const struct
    {
        unsigned int length;
        unsigned int count;
        unsigned int radices[4];
    } presets[] = {
        {    2, 1, {  2,  0, 0, 0 } },
        {    4, 1, {  4,  0, 0, 0 } },
        {    8, 1, {  8,  0, 0, 0 } },
        {   16, 2, {  8,  2, 0, 0 } },
        {   32, 2, {  8,  4, 0, 0 } },
        {   64, 2, {  8,  8, 0, 0 } },
        {  128, 3, {  8,  4, 4, 0 } },
        {  256, 4, {  4,  4, 4, 4 } },
        {  512, 3, {  8,  8, 8, 0 } },
        { 1024, 3, { 16, 16, 4, 0 } },
        { 2048, 4, {  8,  8, 8, 4 } },
    };
    const unsigned int numPresets = (unsigned int)(sizeof presets / sizeof presets[0]);

    for(unsigned int p = 0; p < numPresets; p++)
    {
        if(presets[p].length != n)
            continue;

        *numRadices = presets[p].count;
        for(unsigned int r = 0; r < presets[p].count; r++)
            radixArray[r] = presets[p].radices[r];
        return;
    }

    // Unsupported length: report "no radices" and leave radixArray untouched.
    *numRadices = 0;
}
|
|
178
|
|
179 static void
|
|
180 insertHeader(string &kernelString, string &kernelName, clFFT_DataFormat dataFormat)
|
|
181 {
|
|
182 if(dataFormat == clFFT_SplitComplexFormat)
|
|
183 kernelString += string("__kernel void ") + kernelName + string("(__global float *in_real, __global float *in_imag, __global float *out_real, __global float *out_imag, int dir, int S)\n");
|
|
184 else
|
|
185 kernelString += string("__kernel void ") + kernelName + string("(__global float2 *in, __global float2 *out, int dir, int S)\n");
|
|
186 }
|
|
187
|
|
188 static void
|
|
189 insertVariables(string &kStream, int maxRadix)
|
|
190 {
|
|
191 kStream += string(" int i, j, r, indexIn, indexOut, index, tid, bNum, xNum, k, l;\n");
|
|
192 kStream += string(" int s, ii, jj, offset;\n");
|
|
193 kStream += string(" float2 w;\n");
|
|
194 kStream += string(" float ang, angf, ang1;\n");
|
|
195 kStream += string(" __local float *lMemStore, *lMemLoad;\n");
|
|
196 kStream += string(" float2 a[") + num2str(maxRadix) + string("];\n");
|
|
197 kStream += string(" int lId = get_local_id( 0 );\n");
|
|
198 kStream += string(" int groupId = get_group_id( 0 );\n");
|
|
199 }
|
|
200
|
|
201 static void
|
|
202 formattedLoad(string &kernelString, int aIndex, int gIndex, clFFT_DataFormat dataFormat)
|
|
203 {
|
|
204 if(dataFormat == clFFT_InterleavedComplexFormat)
|
|
205 kernelString += string(" a[") + num2str(aIndex) + string("] = in[") + num2str(gIndex) + string("];\n");
|
|
206 else
|
|
207 {
|
|
208 kernelString += string(" a[") + num2str(aIndex) + string("].x = in_real[") + num2str(gIndex) + string("];\n");
|
|
209 kernelString += string(" a[") + num2str(aIndex) + string("].y = in_imag[") + num2str(gIndex) + string("];\n");
|
|
210 }
|
|
211 }
|
|
212
|
|
213 static void
|
|
214 formattedStore(string &kernelString, int aIndex, int gIndex, clFFT_DataFormat dataFormat)
|
|
215 {
|
|
216 if(dataFormat == clFFT_InterleavedComplexFormat)
|
|
217 kernelString += string(" out[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("];\n");
|
|
218 else
|
|
219 {
|
|
220 kernelString += string(" out_real[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("].x;\n");
|
|
221 kernelString += string(" out_imag[") + num2str(gIndex) + string("] = a[") + num2str(aIndex) + string("].y;\n");
|
|
222 }
|
|
223 }
|
|
224
|
|
225 static int
|
|
226 insertGlobalLoadsAndTranspose(string &kernelString, int N, int numWorkItemsPerXForm, int numXFormsPerWG, int R0, int mem_coalesce_width, clFFT_DataFormat dataFormat)
|
|
227 {
|
|
228 int log2NumWorkItemsPerXForm = (int) log2(numWorkItemsPerXForm);
|
|
229 int groupSize = numWorkItemsPerXForm * numXFormsPerWG;
|
|
230 int i, j;
|
|
231 int lMemSize = 0;
|
|
232
|
|
233 if(numXFormsPerWG > 1)
|
|
234 kernelString += string(" s = S & ") + num2str(numXFormsPerWG - 1) + string(";\n");
|
|
235
|
|
236 if(numWorkItemsPerXForm >= mem_coalesce_width)
|
|
237 {
|
|
238 if(numXFormsPerWG > 1)
|
|
239 {
|
|
240 kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm-1) + string(";\n");
|
|
241 kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n");
|
|
242 kernelString += string(" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n");
|
|
243 kernelString += string(" offset = mad24( mad24(groupId, ") + num2str(numXFormsPerWG) + string(", jj), ") + num2str(N) + string(", ii );\n");
|
|
244 if(dataFormat == clFFT_InterleavedComplexFormat)
|
|
245 {
|
|
246 kernelString += string(" in += offset;\n");
|
|
247 kernelString += string(" out += offset;\n");
|
|
248 }
|
|
249 else
|
|
250 {
|
|
251 kernelString += string(" in_real += offset;\n");
|
|
252 kernelString += string(" in_imag += offset;\n");
|
|
253 kernelString += string(" out_real += offset;\n");
|
|
254 kernelString += string(" out_imag += offset;\n");
|
|
255 }
|
|
256 for(i = 0; i < R0; i++)
|
|
257 formattedLoad(kernelString, i, i*numWorkItemsPerXForm, dataFormat);
|
|
258 kernelString += string(" }\n");
|
|
259 }
|
|
260 else
|
|
261 {
|
|
262 kernelString += string(" ii = lId;\n");
|
|
263 kernelString += string(" jj = 0;\n");
|
|
264 kernelString += string(" offset = mad24(groupId, ") + num2str(N) + string(", ii);\n");
|
|
265 if(dataFormat == clFFT_InterleavedComplexFormat)
|
|
266 {
|
|
267 kernelString += string(" in += offset;\n");
|
|
268 kernelString += string(" out += offset;\n");
|
|
269 }
|
|
270 else
|
|
271 {
|
|
272 kernelString += string(" in_real += offset;\n");
|
|
273 kernelString += string(" in_imag += offset;\n");
|
|
274 kernelString += string(" out_real += offset;\n");
|
|
275 kernelString += string(" out_imag += offset;\n");
|
|
276 }
|
|
277 for(i = 0; i < R0; i++)
|
|
278 formattedLoad(kernelString, i, i*numWorkItemsPerXForm, dataFormat);
|
|
279 }
|
|
280 }
|
|
281 else if( N >= mem_coalesce_width )
|
|
282 {
|
|
283 int numInnerIter = N / mem_coalesce_width;
|
|
284 int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width );
|
|
285
|
|
286 kernelString += string(" ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n");
|
|
287 kernelString += string(" jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n");
|
|
288 kernelString += string(" lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
|
|
289 kernelString += string(" offset = mad24( groupId, ") + num2str(numXFormsPerWG) + string(", jj);\n");
|
|
290 kernelString += string(" offset = mad24( offset, ") + num2str(N) + string(", ii );\n");
|
|
291 if(dataFormat == clFFT_InterleavedComplexFormat)
|
|
292 {
|
|
293 kernelString += string(" in += offset;\n");
|
|
294 kernelString += string(" out += offset;\n");
|
|
295 }
|
|
296 else
|
|
297 {
|
|
298 kernelString += string(" in_real += offset;\n");
|
|
299 kernelString += string(" in_imag += offset;\n");
|
|
300 kernelString += string(" out_real += offset;\n");
|
|
301 kernelString += string(" out_imag += offset;\n");
|
|
302 }
|
|
303
|
|
304 kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
|
|
305 for(i = 0; i < numOuterIter; i++ )
|
|
306 {
|
|
307 kernelString += string(" if( jj < s ) {\n");
|
|
308 for(j = 0; j < numInnerIter; j++ )
|
|
309 formattedLoad(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N, dataFormat);
|
|
310 kernelString += string(" }\n");
|
|
311 if(i != numOuterIter - 1)
|
|
312 kernelString += string(" jj += ") + num2str(groupSize / mem_coalesce_width) + string(";\n");
|
|
313 }
|
|
314 kernelString += string("}\n ");
|
|
315 kernelString += string("else {\n");
|
|
316 for(i = 0; i < numOuterIter; i++ )
|
|
317 {
|
|
318 for(j = 0; j < numInnerIter; j++ )
|
|
319 formattedLoad(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * N, dataFormat);
|
|
320 }
|
|
321 kernelString += string("}\n");
|
|
322
|
|
323 kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n");
|
|
324 kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n");
|
|
325 kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii);\n");
|
|
326
|
|
327 for( i = 0; i < numOuterIter; i++ )
|
|
328 {
|
|
329 for( j = 0; j < numInnerIter; j++ )
|
|
330 {
|
|
331 kernelString += string(" lMemStore[") + num2str(j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm )) + string("] = a[") +
|
|
332 num2str(i * numInnerIter + j) + string("].x;\n");
|
|
333 }
|
|
334 }
|
|
335 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
|
|
336
|
|
337 for( i = 0; i < R0; i++ )
|
|
338 kernelString += string(" a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
|
|
339 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
|
|
340
|
|
341 for( i = 0; i < numOuterIter; i++ )
|
|
342 {
|
|
343 for( j = 0; j < numInnerIter; j++ )
|
|
344 {
|
|
345 kernelString += string(" lMemStore[") + num2str(j * mem_coalesce_width + i * ( groupSize / mem_coalesce_width ) * (N + numWorkItemsPerXForm )) + string("] = a[") +
|
|
346 num2str(i * numInnerIter + j) + string("].y;\n");
|
|
347 }
|
|
348 }
|
|
349 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
|
|
350
|
|
351 for( i = 0; i < R0; i++ )
|
|
352 kernelString += string(" a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
|
|
353 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
|
|
354
|
|
355 lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
|
|
356 }
|
|
357 else
|
|
358 {
|
|
359 kernelString += string(" offset = mad24( groupId, ") + num2str(N * numXFormsPerWG) + string(", lId );\n");
|
|
360 if(dataFormat == clFFT_InterleavedComplexFormat)
|
|
361 {
|
|
362 kernelString += string(" in += offset;\n");
|
|
363 kernelString += string(" out += offset;\n");
|
|
364 }
|
|
365 else
|
|
366 {
|
|
367 kernelString += string(" in_real += offset;\n");
|
|
368 kernelString += string(" in_imag += offset;\n");
|
|
369 kernelString += string(" out_real += offset;\n");
|
|
370 kernelString += string(" out_imag += offset;\n");
|
|
371 }
|
|
372
|
|
373 kernelString += string(" ii = lId & ") + num2str(N-1) + string(";\n");
|
|
374 kernelString += string(" jj = lId >> ") + num2str((int)log2(N)) + string(";\n");
|
|
375 kernelString += string(" lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
|
|
376
|
|
377 kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
|
|
378 for( i = 0; i < R0; i++ )
|
|
379 {
|
|
380 kernelString += string(" if(jj < s )\n");
|
|
381 formattedLoad(kernelString, i, i*groupSize, dataFormat);
|
|
382 if(i != R0 - 1)
|
|
383 kernelString += string(" jj += ") + num2str(groupSize / N) + string(";\n");
|
|
384 }
|
|
385 kernelString += string("}\n");
|
|
386 kernelString += string("else {\n");
|
|
387 for( i = 0; i < R0; i++ )
|
|
388 {
|
|
389 formattedLoad(kernelString, i, i*groupSize, dataFormat);
|
|
390 }
|
|
391 kernelString += string("}\n");
|
|
392
|
|
393 if(numWorkItemsPerXForm > 1)
|
|
394 {
|
|
395 kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n");
|
|
396 kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n");
|
|
397 kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
|
|
398 }
|
|
399 else
|
|
400 {
|
|
401 kernelString += string(" ii = 0;\n");
|
|
402 kernelString += string(" jj = lId;\n");
|
|
403 kernelString += string(" lMemLoad = sMem + mul24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(");\n");
|
|
404 }
|
|
405
|
|
406
|
|
407 for( i = 0; i < R0; i++ )
|
|
408 kernelString += string(" lMemStore[") + num2str(i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) + string("] = a[") + num2str(i) + string("].x;\n");
|
|
409 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
|
|
410
|
|
411 for( i = 0; i < R0; i++ )
|
|
412 kernelString += string(" a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
|
|
413 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
|
|
414
|
|
415 for( i = 0; i < R0; i++ )
|
|
416 kernelString += string(" lMemStore[") + num2str(i * ( groupSize / N ) * ( N + numWorkItemsPerXForm )) + string("] = a[") + num2str(i) + string("].y;\n");
|
|
417 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
|
|
418
|
|
419 for( i = 0; i < R0; i++ )
|
|
420 kernelString += string(" a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
|
|
421 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
|
|
422
|
|
423 lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
|
|
424 }
|
|
425
|
|
426 return lMemSize;
|
|
427 }
|
|
428
|
|
429 static int
|
|
430 insertGlobalStoresAndTranspose(string &kernelString, int N, int maxRadix, int Nr, int numWorkItemsPerXForm, int numXFormsPerWG, int mem_coalesce_width, clFFT_DataFormat dataFormat)
|
|
431 {
|
|
432 int groupSize = numWorkItemsPerXForm * numXFormsPerWG;
|
|
433 int i, j, k, ind;
|
|
434 int lMemSize = 0;
|
|
435 int numIter = maxRadix / Nr;
|
|
436 string indent = string("");
|
|
437
|
|
438 if( numWorkItemsPerXForm >= mem_coalesce_width )
|
|
439 {
|
|
440 if(numXFormsPerWG > 1)
|
|
441 {
|
|
442 kernelString += string(" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n");
|
|
443 indent = string(" ");
|
|
444 }
|
|
445 for(i = 0; i < maxRadix; i++)
|
|
446 {
|
|
447 j = i % numIter;
|
|
448 k = i / numIter;
|
|
449 ind = j * Nr + k;
|
|
450 formattedStore(kernelString, ind, i*numWorkItemsPerXForm, dataFormat);
|
|
451 }
|
|
452 if(numXFormsPerWG > 1)
|
|
453 kernelString += string(" }\n");
|
|
454 }
|
|
455 else if( N >= mem_coalesce_width )
|
|
456 {
|
|
457 int numInnerIter = N / mem_coalesce_width;
|
|
458 int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width );
|
|
459
|
|
460 kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
|
|
461 kernelString += string(" ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n");
|
|
462 kernelString += string(" jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n");
|
|
463 kernelString += string(" lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
|
|
464
|
|
465 for( i = 0; i < maxRadix; i++ )
|
|
466 {
|
|
467 j = i % numIter;
|
|
468 k = i / numIter;
|
|
469 ind = j * Nr + k;
|
|
470 kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].x;\n");
|
|
471 }
|
|
472 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
|
|
473
|
|
474 for( i = 0; i < numOuterIter; i++ )
|
|
475 for( j = 0; j < numInnerIter; j++ )
|
|
476 kernelString += string(" a[") + num2str(i*numInnerIter + j) + string("].x = lMemStore[") + num2str(j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) + string("];\n");
|
|
477 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
|
|
478
|
|
479 for( i = 0; i < maxRadix; i++ )
|
|
480 {
|
|
481 j = i % numIter;
|
|
482 k = i / numIter;
|
|
483 ind = j * Nr + k;
|
|
484 kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].y;\n");
|
|
485 }
|
|
486 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
|
|
487
|
|
488 for( i = 0; i < numOuterIter; i++ )
|
|
489 for( j = 0; j < numInnerIter; j++ )
|
|
490 kernelString += string(" a[") + num2str(i*numInnerIter + j) + string("].y = lMemStore[") + num2str(j*mem_coalesce_width + i*( groupSize / mem_coalesce_width )*(N + numWorkItemsPerXForm)) + string("];\n");
|
|
491 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
|
|
492
|
|
493 kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
|
|
494 for(i = 0; i < numOuterIter; i++ )
|
|
495 {
|
|
496 kernelString += string(" if( jj < s ) {\n");
|
|
497 for(j = 0; j < numInnerIter; j++ )
|
|
498 formattedStore(kernelString, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat);
|
|
499 kernelString += string(" }\n");
|
|
500 if(i != numOuterIter - 1)
|
|
501 kernelString += string(" jj += ") + num2str(groupSize / mem_coalesce_width) + string(";\n");
|
|
502 }
|
|
503 kernelString += string("}\n");
|
|
504 kernelString += string("else {\n");
|
|
505 for(i = 0; i < numOuterIter; i++ )
|
|
506 {
|
|
507 for(j = 0; j < numInnerIter; j++ )
|
|
508 formattedStore(kernelString, i*numInnerIter + j, j*mem_coalesce_width + i*(groupSize/mem_coalesce_width)*N, dataFormat);
|
|
509 }
|
|
510 kernelString += string("}\n");
|
|
511
|
|
512 lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
|
|
513 }
|
|
514 else
|
|
515 {
|
|
516 kernelString += string(" lMemLoad = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
|
|
517
|
|
518 kernelString += string(" ii = lId & ") + num2str(N - 1) + string(";\n");
|
|
519 kernelString += string(" jj = lId >> ") + num2str((int) log2(N)) + string(";\n");
|
|
520 kernelString += string(" lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
|
|
521
|
|
522 for( i = 0; i < maxRadix; i++ )
|
|
523 {
|
|
524 j = i % numIter;
|
|
525 k = i / numIter;
|
|
526 ind = j * Nr + k;
|
|
527 kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].x;\n");
|
|
528 }
|
|
529 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
|
|
530
|
|
531 for( i = 0; i < maxRadix; i++ )
|
|
532 kernelString += string(" a[") + num2str(i) + string("].x = lMemStore[") + num2str(i*( groupSize / N )*( N + numWorkItemsPerXForm )) + string("];\n");
|
|
533 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
|
|
534
|
|
535 for( i = 0; i < maxRadix; i++ )
|
|
536 {
|
|
537 j = i % numIter;
|
|
538 k = i / numIter;
|
|
539 ind = j * Nr + k;
|
|
540 kernelString += string(" lMemLoad[") + num2str(i*numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].y;\n");
|
|
541 }
|
|
542 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
|
|
543
|
|
544 for( i = 0; i < maxRadix; i++ )
|
|
545 kernelString += string(" a[") + num2str(i) + string("].y = lMemStore[") + num2str(i*( groupSize / N )*( N + numWorkItemsPerXForm )) + string("];\n");
|
|
546 kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
|
|
547
|
|
548 kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
|
|
549 for( i = 0; i < maxRadix; i++ )
|
|
550 {
|
|
551 kernelString += string(" if(jj < s ) {\n");
|
|
552 formattedStore(kernelString, i, i*groupSize, dataFormat);
|
|
553 kernelString += string(" }\n");
|
|
554 if( i != maxRadix - 1)
|
|
555 kernelString += string(" jj +=") + num2str(groupSize / N) + string(";\n");
|
|
556 }
|
|
557 kernelString += string("}\n");
|
|
558 kernelString += string("else {\n");
|
|
559 for( i = 0; i < maxRadix; i++ )
|
|
560 {
|
|
561 formattedStore(kernelString, i, i*groupSize, dataFormat);
|
|
562 }
|
|
563 kernelString += string("}\n");
|
|
564
|
|
565 lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
|
|
566 }
|
|
567
|
|
568 return lMemSize;
|
|
569 }
|
|
570
|
|
571 static void
|
|
572 insertfftKernel(string &kernelString, int Nr, int numIter)
|
|
573 {
|
|
574 int i;
|
|
575 for(i = 0; i < numIter; i++)
|
|
576 {
|
|
577 kernelString += string(" fftKernel") + num2str(Nr) + string("(a+") + num2str(i*Nr) + string(", dir);\n");
|
|
578 }
|
|
579 }
|
|
580
|
|
581 static void
|
|
582 insertTwiddleKernel(string &kernelString, int Nr, int numIter, int Nprev, int len, int numWorkItemsPerXForm)
|
|
583 {
|
|
584 int z, k;
|
|
585 int logNPrev = (int)log2(Nprev);
|
|
586
|
|
587 for(z = 0; z < numIter; z++)
|
|
588 {
|
|
589 if(z == 0)
|
|
590 {
|
|
591 if(Nprev > 1)
|
|
592 kernelString += string(" angf = (float) (ii >> ") + num2str(logNPrev) + string(");\n");
|
|
593 else
|
|
594 kernelString += string(" angf = (float) ii;\n");
|
|
595 }
|
|
596 else
|
|
597 {
|
|
598 if(Nprev > 1)
|
|
599 kernelString += string(" angf = (float) ((") + num2str(z*numWorkItemsPerXForm) + string(" + ii) >>") + num2str(logNPrev) + string(");\n");
|
|
600 else
|
|
601 kernelString += string(" angf = (float) (") + num2str(z*numWorkItemsPerXForm) + string(" + ii);\n");
|
|
602 }
|
|
603
|
|
604 for(k = 1; k < Nr; k++) {
|
|
605 int ind = z*Nr + k;
|
|
606 //float fac = (float) (2.0 * M_PI * (double) k / (double) len);
|
|
607 kernelString += string(" ang = dir * ( 2.0f * M_PI * ") + num2str(k) + string(".0f / ") + num2str(len) + string(".0f )") + string(" * angf;\n");
|
|
608 kernelString += string(" w = (float2)(native_cos(ang), native_sin(ang));\n");
|
|
609 kernelString += string(" a[") + num2str(ind) + string("] = complexMul(a[") + num2str(ind) + string("], w);\n");
|
|
610 }
|
|
611 }
|
|
612 }
|
|
613
|
|
// Computes the two padding amounts used when laying transforms out in "sMem"
// and the resulting total size.
//
// numWorkItemsPerXForm: work items cooperating on one transform.
// Nprev:                product of the radices already processed.
// numWorkItemsReq:      work items the current stage requires.
// numXFormsPerWG:       transforms handled by one work group.
// Nr:                   radix of the current stage.
// numBanks:             bank count; used as a mask (numBanks - 1), so it is
//                       presumably a power of two -- confirm with the caller.
// offset:               output; padding added to each row of numWorkItemsReq.
// midPad:               output; padding inserted between consecutive transforms.
// returns:              (numWorkItemsReq + *offset) * Nr * numXFormsPerWG
//                       + *midPad * (numXFormsPerWG - 1).
static int
getPadding(int numWorkItemsPerXForm, int Nprev, int numWorkItemsReq, int numXFormsPerWG, int Nr, int numBanks, int *offset, int *midPad)
{
    // Row padding is only needed when a transform spans more work items than
    // Nprev and Nprev is smaller than the bank count.
    int rowPad = 0;
    if(numWorkItemsPerXForm > Nprev && Nprev < numBanks)
    {
        const int span = (numWorkItemsPerXForm < numBanks) ? numWorkItemsPerXForm : numBanks;
        const int rows = span / Nprev;
        const int cols = (rows > Nr) ? rows / Nr : 1;
        rowPad = Nprev * cols;
    }
    *offset = rowPad;

    const int xformSize = (numWorkItemsReq + rowPad) * Nr;

    // Padding between transforms, only relevant when several narrow transforms
    // share a work group.
    int gap = 0;
    if(numWorkItemsPerXForm < numBanks && numXFormsPerWG != 1)
    {
        const int bank = xformSize & (numBanks - 1);
        if(bank < numWorkItemsPerXForm)
            gap = numWorkItemsPerXForm - bank;
    }
    *midPad = gap;

    return xformSize * numXFormsPerWG + gap * (numXFormsPerWG - 1);
}
|
|
641
|
|
642
|
|
643 static void
|
|
644 insertLocalStores(string &kernelString, int numIter, int Nr, int numWorkItemsPerXForm, int numWorkItemsReq, int offset, string &comp)
|
|
645 {
|
|
646 int z, k;
|
|
647
|
|
648 for(z = 0; z < numIter; z++) {
|
|
649 for(k = 0; k < Nr; k++) {
|
|
650 int index = k*(numWorkItemsReq + offset) + z*numWorkItemsPerXForm;
|
|
651 kernelString += string(" lMemStore[") + num2str(index) + string("] = a[") + num2str(z*Nr + k) + string("].") + comp + string(";\n");
|
|
652 }
|
|
653 }
|
|
654 kernelString += string(" barrier(CLK_LOCAL_MEM_FENCE);\n");
|
|
655 }
|
|
656
|
|
657 static void
|
|
658 insertLocalLoads(string &kernelString, int n, int Nr, int Nrn, int Nprev, int Ncurr, int numWorkItemsPerXForm, int numWorkItemsReq, int offset, string &comp)
|
|
659 {
|
|
660 int numWorkItemsReqN = n / Nrn;
|
|
661 int interBlockHNum = max( Nprev / numWorkItemsPerXForm, 1 );
|
|
662 int interBlockHStride = numWorkItemsPerXForm;
|
|
663 int vertWidth = max(numWorkItemsPerXForm / Nprev, 1);
|
|
664 vertWidth = min( vertWidth, Nr);
|
|
665 int vertNum = Nr / vertWidth;
|
|
666 int vertStride = ( n / Nr + offset ) * vertWidth;
|
|
667 int iter = max( numWorkItemsReqN / numWorkItemsPerXForm, 1);
|
|
668 int intraBlockHStride = (numWorkItemsPerXForm / (Nprev*Nr)) > 1 ? (numWorkItemsPerXForm / (Nprev*Nr)) : 1;
|
|
669 intraBlockHStride *= Nprev;
|
|
670
|
|
671 int stride = numWorkItemsReq / Nrn;
|
|
672 int i;
|
|
673 for(i = 0; i < iter; i++) {
|
|
674 int ii = i / (interBlockHNum * vertNum);
|
|
675 int zz = i % (interBlockHNum * vertNum);
|
|
676 int jj = zz % interBlockHNum;
|
|
677 int kk = zz / interBlockHNum;
|
|
678 int z;
|
|
679 for(z = 0; z < Nrn; z++) {
|
|
680 int st = kk * vertStride + jj * interBlockHStride + ii * intraBlockHStride + z * stride;
|
|
681 kernelString += string(" a[") + num2str(i*Nrn + z) + string("].") + comp + string(" = lMemLoad[") + num2str(st) + string("];\n");
|
|
682 }
|
|
683 }
|
|
684 kernelString += string(" barrier(CLK_LOCAL_MEM_FENCE);\n");
|
|
685 }
|
|
686
|
|
687 static void
|
|
688 insertLocalLoadIndexArithmatic(string &kernelString, int Nprev, int Nr, int numWorkItemsReq, int numWorkItemsPerXForm, int numXFormsPerWG, int offset, int midPad)
|
|
689 {
|
|
690 int Ncurr = Nprev * Nr;
|
|
691 int logNcurr = (int)log2(Ncurr);
|
|
692 int logNprev = (int)log2(Nprev);
|
|
693 int incr = (numWorkItemsReq + offset) * Nr + midPad;
|
|
694
|
|
695 if(Ncurr < numWorkItemsPerXForm)
|
|
696 {
|
|
697 if(Nprev == 1)
|
|
698 kernelString += string(" j = ii & ") + num2str(Ncurr - 1) + string(";\n");
|
|
699 else
|
|
700 kernelString += string(" j = (ii & ") + num2str(Ncurr - 1) + string(") >> ") + num2str(logNprev) + string(";\n");
|
|
701
|
|
702 if(Nprev == 1)
|
|
703 kernelString += string(" i = ii >> ") + num2str(logNcurr) + string(";\n");
|
|
704 else
|
|
705 kernelString += string(" i = mad24(ii >> ") + num2str(logNcurr) + string(", ") + num2str(Nprev) + string(", ii & ") + num2str(Nprev - 1) + string(");\n");
|
|
706 }
|
|
707 else
|
|
708 {
|
|
709 if(Nprev == 1)
|
|
710 kernelString += string(" j = ii;\n");
|
|
711 else
|
|
712 kernelString += string(" j = ii >> ") + num2str(logNprev) + string(";\n");
|
|
713 if(Nprev == 1)
|
|
714 kernelString += string(" i = 0;\n");
|
|
715 else
|
|
716 kernelString += string(" i = ii & ") + num2str(Nprev - 1) + string(";\n");
|
|
717 }
|
|
718
|
|
719 if(numXFormsPerWG > 1)
|
|
720 kernelString += string(" i = mad24(jj, ") + num2str(incr) + string(", i);\n");
|
|
721
|
|
722 kernelString += string(" lMemLoad = sMem + mad24(j, ") + num2str(numWorkItemsReq + offset) + string(", i);\n");
|
|
723 }
|
|
724
|
|
725 static void
|
|
726 insertLocalStoreIndexArithmatic(string &kernelString, int numWorkItemsReq, int numXFormsPerWG, int Nr, int offset, int midPad)
|
|
727 {
|
|
728 if(numXFormsPerWG == 1) {
|
|
729 kernelString += string(" lMemStore = sMem + ii;\n");
|
|
730 }
|
|
731 else {
|
|
732 kernelString += string(" lMemStore = sMem + mad24(jj, ") + num2str((numWorkItemsReq + offset)*Nr + midPad) + string(", ii);\n");
|
|
733 }
|
|
734 }
|
|
735
|
|
736
|
|
737 static void
|
|
738 createLocalMemfftKernelString(cl_fft_plan *plan)
|
|
739 {
|
|
740 unsigned int radixArray[10];
|
|
741 unsigned int numRadix;
|
|
742
|
|
743 unsigned int n = plan->n.x;
|
|
744
|
|
745 assert(n <= plan->max_work_item_per_workgroup * plan->max_radix && "signal lenght too big for local mem fft\n");
|
|
746
|
|
747 getRadixArray(n, radixArray, &numRadix, 0);
|
|
748 assert(numRadix > 0 && "no radix array supplied\n");
|
|
749
|
|
750 if(n/radixArray[0] > plan->max_work_item_per_workgroup)
|
|
751 getRadixArray(n, radixArray, &numRadix, plan->max_radix);
|
|
752
|
|
753 assert(radixArray[0] <= plan->max_radix && "max radix choosen is greater than allowed\n");
|
|
754 assert(n/radixArray[0] <= plan->max_work_item_per_workgroup && "required work items per xform greater than maximum work items allowed per work group for local mem fft\n");
|
|
755
|
|
756 unsigned int tmpLen = 1;
|
|
757 unsigned int i;
|
|
758 for(i = 0; i < numRadix; i++)
|
|
759 {
|
|
760 assert( radixArray[i] && !( (radixArray[i] - 1) & radixArray[i] ) );
|
|
761 tmpLen *= radixArray[i];
|
|
762 }
|
|
763 assert(tmpLen == n && "product of radices choosen doesnt match the length of signal\n");
|
|
764
|
|
765 int offset, midPad;
|
|
766 string localString(""), kernelName("");
|
|
767
|
|
768 clFFT_DataFormat dataFormat = plan->format;
|
|
769 string *kernelString = plan->kernel_string;
|
|
770
|
|
771
|
|
772 cl_fft_kernel_info **kInfo = &plan->kernel_info;
|
|
773 int kCount = 0;
|
|
774
|
|
775 while(*kInfo)
|
|
776 {
|
|
777 kInfo = &(*kInfo)->next;
|
|
778 kCount++;
|
|
779 }
|
|
780
|
|
781 kernelName = string("fft") + num2str(kCount);
|
|
782
|
|
783 *kInfo = (cl_fft_kernel_info *) malloc(sizeof(cl_fft_kernel_info));
|
|
784 (*kInfo)->kernel = 0;
|
|
785 (*kInfo)->lmem_size = 0;
|
|
786 (*kInfo)->num_workgroups = 0;
|
|
787 (*kInfo)->num_workitems_per_workgroup = 0;
|
|
788 (*kInfo)->dir = cl_fft_kernel_x;
|
|
789 (*kInfo)->in_place_possible = 1;
|
|
790 (*kInfo)->next = NULL;
|
|
791 (*kInfo)->kernel_name = (char *) malloc(sizeof(char)*(kernelName.size()+1));
|
|
792 strcpy((*kInfo)->kernel_name, kernelName.c_str());
|
|
793
|
|
794 unsigned int numWorkItemsPerXForm = n / radixArray[0];
|
|
795 unsigned int numWorkItemsPerWG = numWorkItemsPerXForm <= 64 ? 64 : numWorkItemsPerXForm;
|
|
796 assert(numWorkItemsPerWG <= plan->max_work_item_per_workgroup);
|
|
797 int numXFormsPerWG = numWorkItemsPerWG / numWorkItemsPerXForm;
|
|
798 (*kInfo)->num_workgroups = 1;
|
|
799 (*kInfo)->num_xforms_per_workgroup = numXFormsPerWG;
|
|
800 (*kInfo)->num_workitems_per_workgroup = numWorkItemsPerWG;
|
|
801
|
|
802 unsigned int *N = radixArray;
|
|
803 unsigned int maxRadix = N[0];
|
|
804 unsigned int lMemSize = 0;
|
|
805
|
|
806 insertVariables(localString, maxRadix);
|
|
807
|
|
808 lMemSize = insertGlobalLoadsAndTranspose(localString, n, numWorkItemsPerXForm, numXFormsPerWG, maxRadix, plan->min_mem_coalesce_width, dataFormat);
|
|
809 (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size;
|
|
810
|
|
811 string xcomp = string("x");
|
|
812 string ycomp = string("y");
|
|
813
|
|
814 unsigned int Nprev = 1;
|
|
815 unsigned int len = n;
|
|
816 unsigned int r;
|
|
817 for(r = 0; r < numRadix; r++)
|
|
818 {
|
|
819 int numIter = N[0] / N[r];
|
|
820 int numWorkItemsReq = n / N[r];
|
|
821 int Ncurr = Nprev * N[r];
|
|
822 insertfftKernel(localString, N[r], numIter);
|
|
823
|
|
824 if(r < (numRadix - 1)) {
|
|
825 insertTwiddleKernel(localString, N[r], numIter, Nprev, len, numWorkItemsPerXForm);
|
|
826 lMemSize = getPadding(numWorkItemsPerXForm, Nprev, numWorkItemsReq, numXFormsPerWG, N[r], plan->num_local_mem_banks, &offset, &midPad);
|
|
827 (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size;
|
|
828 insertLocalStoreIndexArithmatic(localString, numWorkItemsReq, numXFormsPerWG, N[r], offset, midPad);
|
|
829 insertLocalLoadIndexArithmatic(localString, Nprev, N[r], numWorkItemsReq, numWorkItemsPerXForm, numXFormsPerWG, offset, midPad);
|
|
830 insertLocalStores(localString, numIter, N[r], numWorkItemsPerXForm, numWorkItemsReq, offset, xcomp);
|
|
831 insertLocalLoads(localString, n, N[r], N[r+1], Nprev, Ncurr, numWorkItemsPerXForm, numWorkItemsReq, offset, xcomp);
|
|
832 insertLocalStores(localString, numIter, N[r], numWorkItemsPerXForm, numWorkItemsReq, offset, ycomp);
|
|
833 insertLocalLoads(localString, n, N[r], N[r+1], Nprev, Ncurr, numWorkItemsPerXForm, numWorkItemsReq, offset, ycomp);
|
|
834 Nprev = Ncurr;
|
|
835 len = len / N[r];
|
|
836 }
|
|
837 }
|
|
838
|
|
839 lMemSize = insertGlobalStoresAndTranspose(localString, n, maxRadix, N[numRadix - 1], numWorkItemsPerXForm, numXFormsPerWG, plan->min_mem_coalesce_width, dataFormat);
|
|
840 (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size;
|
|
841
|
|
842 insertHeader(*kernelString, kernelName, dataFormat);
|
|
843 *kernelString += string("{\n");
|
|
844 if((*kInfo)->lmem_size)
|
|
845 *kernelString += string(" __local float sMem[") + num2str((*kInfo)->lmem_size) + string("];\n");
|
|
846 *kernelString += localString;
|
|
847 *kernelString += string("}\n");
|
|
848 }
|
|
849
|
|
// For n larger than what can be computed using the local memory fft, global transposes
// and multiple kernel launches are needed. For these sizes, n can be decomposed using
// much larger base radices, e.g. n = 262144 = 128 x 64 x 32. Thus three kernel
// launches will be needed: the first computing 64 x 32 length 128 ffts, the second computing
// 128 x 32 length 64 ffts, and finally a kernel computing 128 x 64 length 32 ffts.
// Each of these base radices can further be divided into factors so that each of these
// base ffts can be computed within one kernel launch using in-register ffts and local
// memory transposes, i.e. for the first kernel above, which computes 64 x 32 ffts of length
// 128, 128 can be decomposed into 128 = 16 x 8, i.e. 8 work items can compute 8 length
// 16 ffts followed by a transpose using local memory, followed by each of these eight
// work items computing 2 length 8 ffts, thus computing 16 length 8 ffts in total. This
// means only 8 work items are needed for computing one length 128 fft. If we choose a
// work group size of, say, 64, we can compute 64/8 = 8 length 128 ffts within one
// work group. Since we need to compute 64 x 32 length 128 ffts in the first kernel, this
// means we need to launch 64 x 32 / 8 = 256 work groups with 64 work items in each
// work group, where each work group is computing 8 length 128 ffts and each length
// 128 fft is computed by 8 work items. The same logic can be applied to the other two kernels
// in this example. Users can play with different base radices and different
// decompositions of base radices to generate different kernels and see which gives the
// best performance. The following function is simply fixed to use 128 as the base radix.
|
|
870
|
|
// Splits n into a list of base radices and factors each base radix in two.
//
//   n          - transform length
//   radix      - out: base radices; every entry but the last is min(n, 128),
//                the last is what remains after repeatedly dividing n by it
//   R1, R2     - out: per-radix factors. A radix B <= 8 yields (B, 1);
//                otherwise R1 is the first power of two (starting at 2) for
//                which B / R1 <= R1, and R2 = B / R1
//   numRadices - out: number of entries written to radix, R1 and R2
//
// The caller must supply arrays large enough for every entry written.
void
getGlobalRadixInfo(int n, int *radix, int *R1, int *R2, int *numRadices)
{
    const int baseRadix = (n < 128) ? n : 128;

    // Peel off one base radix per division until the remainder fits.
    int count = 0;
    int rest = n;
    for(; rest > baseRadix; rest /= baseRadix)
        radix[count++] = baseRadix;

    radix[count++] = rest;
    *numRadices = count;

    for(int idx = 0; idx < count; idx++)
    {
        const int B = radix[idx];
        int r1 = B;
        int r2 = 1;

        if(B > 8)
        {
            // Double r1 until the complementary factor no longer exceeds it.
            r1 = 2;
            for(r2 = B / r1; r2 > r1; r2 = B / r1)
                r1 *= 2;
        }

        R1[idx] = r1;
        R2[idx] = r2;
    }
}
|
|
912
|
|
913 static void
|
|
914 createGlobalFFTKernelString(cl_fft_plan *plan, int n, int BS, cl_fft_kernel_dir dir, int vertBS)
|
|
915 {
|
|
916 int i, j, k, t;
|
|
917 int radixArr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
|
|
918 int R1Arr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
|
|
919 int R2Arr[10] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
|
|
920 int radix, R1, R2;
|
|
921 int numRadices;
|
|
922
|
|
923 int maxThreadsPerBlock = plan->max_work_item_per_workgroup;
|
|
924 int maxArrayLen = plan->max_radix;
|
|
925 int batchSize = plan->min_mem_coalesce_width;
|
|
926 clFFT_DataFormat dataFormat = plan->format;
|
|
927 int vertical = (dir == cl_fft_kernel_x) ? 0 : 1;
|
|
928
|
|
929 getGlobalRadixInfo(n, radixArr, R1Arr, R2Arr, &numRadices);
|
|
930
|
|
931 int numPasses = numRadices;
|
|
932
|
|
933 string localString(""), kernelName("");
|
|
934 string *kernelString = plan->kernel_string;
|
|
935 cl_fft_kernel_info **kInfo = &plan->kernel_info;
|
|
936 int kCount = 0;
|
|
937
|
|
938 while(*kInfo)
|
|
939 {
|
|
940 kInfo = &(*kInfo)->next;
|
|
941 kCount++;
|
|
942 }
|
|
943
|
|
944 int N = n;
|
|
945 int m = (int)log2(n);
|
|
946 int Rinit = vertical ? BS : 1;
|
|
947 batchSize = vertical ? min(BS, batchSize) : batchSize;
|
|
948 int passNum;
|
|
949
|
|
950 for(passNum = 0; passNum < numPasses; passNum++)
|
|
951 {
|
|
952
|
|
953 localString.clear();
|
|
954 kernelName.clear();
|
|
955
|
|
956 radix = radixArr[passNum];
|
|
957 R1 = R1Arr[passNum];
|
|
958 R2 = R2Arr[passNum];
|
|
959
|
|
960 int strideI = Rinit;
|
|
961 for(i = 0; i < numPasses; i++)
|
|
962 if(i != passNum)
|
|
963 strideI *= radixArr[i];
|
|
964
|
|
965 int strideO = Rinit;
|
|
966 for(i = 0; i < passNum; i++)
|
|
967 strideO *= radixArr[i];
|
|
968
|
|
969 int threadsPerXForm = R2;
|
|
970 batchSize = R2 == 1 ? plan->max_work_item_per_workgroup : batchSize;
|
|
971 batchSize = min(batchSize, strideI);
|
|
972 int threadsPerBlock = batchSize * threadsPerXForm;
|
|
973 threadsPerBlock = min(threadsPerBlock, maxThreadsPerBlock);
|
|
974 batchSize = threadsPerBlock / threadsPerXForm;
|
|
975 assert(R2 <= R1);
|
|
976 assert(R1*R2 == radix);
|
|
977 assert(R1 <= maxArrayLen);
|
|
978 assert(threadsPerBlock <= maxThreadsPerBlock);
|
|
979
|
|
980 int numIter = R1 / R2;
|
|
981 int gInInc = threadsPerBlock / batchSize;
|
|
982
|
|
983
|
|
984 int lgStrideO = (int)log2(strideO);
|
|
985 int numBlocksPerXForm = strideI / batchSize;
|
|
986 int numBlocks = numBlocksPerXForm;
|
|
987 if(!vertical)
|
|
988 numBlocks *= BS;
|
|
989 else
|
|
990 numBlocks *= vertBS;
|
|
991
|
|
992 kernelName = string("fft") + num2str(kCount);
|
|
993 *kInfo = (cl_fft_kernel_info *) malloc(sizeof(cl_fft_kernel_info));
|
|
994 (*kInfo)->kernel = 0;
|
|
995 if(R2 == 1)
|
|
996 (*kInfo)->lmem_size = 0;
|
|
997 else
|
|
998 {
|
|
999 if(strideO == 1)
|
|
1000 (*kInfo)->lmem_size = (radix + 1)*batchSize;
|
|
1001 else
|
|
1002 (*kInfo)->lmem_size = threadsPerBlock*R1;
|
|
1003 }
|
|
1004 (*kInfo)->num_workgroups = numBlocks;
|
|
1005 (*kInfo)->num_xforms_per_workgroup = 1;
|
|
1006 (*kInfo)->num_workitems_per_workgroup = threadsPerBlock;
|
|
1007 (*kInfo)->dir = dir;
|
|
1008 if( (passNum == (numPasses - 1)) && (numPasses & 1) )
|
|
1009 (*kInfo)->in_place_possible = 1;
|
|
1010 else
|
|
1011 (*kInfo)->in_place_possible = 0;
|
|
1012 (*kInfo)->next = NULL;
|
|
1013 (*kInfo)->kernel_name = (char *) malloc(sizeof(char)*(kernelName.size()+1));
|
|
1014 strcpy((*kInfo)->kernel_name, kernelName.c_str());
|
|
1015
|
|
1016 insertVariables(localString, R1);
|
|
1017
|
|
1018 if(vertical)
|
|
1019 {
|
|
1020 localString += string("xNum = groupId >> ") + num2str((int)log2(numBlocksPerXForm)) + string(";\n");
|
|
1021 localString += string("groupId = groupId & ") + num2str(numBlocksPerXForm - 1) + string(";\n");
|
|
1022 localString += string("indexIn = mad24(groupId, ") + num2str(batchSize) + string(", xNum << ") + num2str((int)log2(n*BS)) + string(");\n");
|
|
1023 localString += string("tid = mul24(groupId, ") + num2str(batchSize) + string(");\n");
|
|
1024 localString += string("i = tid >> ") + num2str(lgStrideO) + string(";\n");
|
|
1025 localString += string("j = tid & ") + num2str(strideO - 1) + string(";\n");
|
|
1026 int stride = radix*Rinit;
|
|
1027 for(i = 0; i < passNum; i++)
|
|
1028 stride *= radixArr[i];
|
|
1029 localString += string("indexOut = mad24(i, ") + num2str(stride) + string(", j + ") + string("(xNum << ") + num2str((int) log2(n*BS)) + string("));\n");
|
|
1030 localString += string("bNum = groupId;\n");
|
|
1031 }
|
|
1032 else
|
|
1033 {
|
|
1034 int lgNumBlocksPerXForm = (int)log2(numBlocksPerXForm);
|
|
1035 localString += string("bNum = groupId & ") + num2str(numBlocksPerXForm - 1) + string(";\n");
|
|
1036 localString += string("xNum = groupId >> ") + num2str(lgNumBlocksPerXForm) + string(";\n");
|
|
1037 localString += string("indexIn = mul24(bNum, ") + num2str(batchSize) + string(");\n");
|
|
1038 localString += string("tid = indexIn;\n");
|
|
1039 localString += string("i = tid >> ") + num2str(lgStrideO) + string(";\n");
|
|
1040 localString += string("j = tid & ") + num2str(strideO - 1) + string(";\n");
|
|
1041 int stride = radix*Rinit;
|
|
1042 for(i = 0; i < passNum; i++)
|
|
1043 stride *= radixArr[i];
|
|
1044 localString += string("indexOut = mad24(i, ") + num2str(stride) + string(", j);\n");
|
|
1045 localString += string("indexIn += (xNum << ") + num2str(m) + string(");\n");
|
|
1046 localString += string("indexOut += (xNum << ") + num2str(m) + string(");\n");
|
|
1047 }
|
|
1048
|
|
1049 // Load Data
|
|
1050 int lgBatchSize = (int)log2(batchSize);
|
|
1051 localString += string("tid = lId;\n");
|
|
1052 localString += string("i = tid & ") + num2str(batchSize - 1) + string(";\n");
|
|
1053 localString += string("j = tid >> ") + num2str(lgBatchSize) + string(";\n");
|
|
1054 localString += string("indexIn += mad24(j, ") + num2str(strideI) + string(", i);\n");
|
|
1055
|
|
1056 if(dataFormat == clFFT_SplitComplexFormat)
|
|
1057 {
|
|
1058 localString += string("in_real += indexIn;\n");
|
|
1059 localString += string("in_imag += indexIn;\n");
|
|
1060 for(j = 0; j < R1; j++)
|
|
1061 localString += string("a[") + num2str(j) + string("].x = in_real[") + num2str(j*gInInc*strideI) + string("];\n");
|
|
1062 for(j = 0; j < R1; j++)
|
|
1063 localString += string("a[") + num2str(j) + string("].y = in_imag[") + num2str(j*gInInc*strideI) + string("];\n");
|
|
1064 }
|
|
1065 else
|
|
1066 {
|
|
1067 localString += string("in += indexIn;\n");
|
|
1068 for(j = 0; j < R1; j++)
|
|
1069 localString += string("a[") + num2str(j) + string("] = in[") + num2str(j*gInInc*strideI) + string("];\n");
|
|
1070 }
|
|
1071
|
|
1072 localString += string("fftKernel") + num2str(R1) + string("(a, dir);\n");
|
|
1073
|
|
1074 if(R2 > 1)
|
|
1075 {
|
|
1076 // twiddle
|
|
1077 for(k = 1; k < R1; k++)
|
|
1078 {
|
|
1079 localString += string("ang = dir*(2.0f*M_PI*") + num2str(k) + string("/") + num2str(radix) + string(")*j;\n");
|
|
1080 localString += string("w = (float2)(native_cos(ang), native_sin(ang));\n");
|
|
1081 localString += string("a[") + num2str(k) + string("] = complexMul(a[") + num2str(k) + string("], w);\n");
|
|
1082 }
|
|
1083
|
|
1084 // shuffle
|
|
1085 numIter = R1 / R2;
|
|
1086 localString += string("indexIn = mad24(j, ") + num2str(threadsPerBlock*numIter) + string(", i);\n");
|
|
1087 localString += string("lMemStore = sMem + tid;\n");
|
|
1088 localString += string("lMemLoad = sMem + indexIn;\n");
|
|
1089 for(k = 0; k < R1; k++)
|
|
1090 localString += string("lMemStore[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].x;\n");
|
|
1091 localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
|
|
1092 for(k = 0; k < numIter; k++)
|
|
1093 for(t = 0; t < R2; t++)
|
|
1094 localString += string("a[") + num2str(k*R2+t) + string("].x = lMemLoad[") + num2str(t*batchSize + k*threadsPerBlock) + string("];\n");
|
|
1095 localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
|
|
1096 for(k = 0; k < R1; k++)
|
|
1097 localString += string("lMemStore[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].y;\n");
|
|
1098 localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
|
|
1099 for(k = 0; k < numIter; k++)
|
|
1100 for(t = 0; t < R2; t++)
|
|
1101 localString += string("a[") + num2str(k*R2+t) + string("].y = lMemLoad[") + num2str(t*batchSize + k*threadsPerBlock) + string("];\n");
|
|
1102 localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
|
|
1103
|
|
1104 for(j = 0; j < numIter; j++)
|
|
1105 localString += string("fftKernel") + num2str(R2) + string("(a + ") + num2str(j*R2) + string(", dir);\n");
|
|
1106 }
|
|
1107
|
|
1108 // twiddle
|
|
1109 if(passNum < (numPasses - 1))
|
|
1110 {
|
|
1111 localString += string("l = ((bNum << ") + num2str(lgBatchSize) + string(") + i) >> ") + num2str(lgStrideO) + string(";\n");
|
|
1112 localString += string("k = j << ") + num2str((int)log2(R1/R2)) + string(";\n");
|
|
1113 localString += string("ang1 = dir*(2.0f*M_PI/") + num2str(N) + string(")*l;\n");
|
|
1114 for(t = 0; t < R1; t++)
|
|
1115 {
|
|
1116 localString += string("ang = ang1*(k + ") + num2str((t%R2)*R1 + (t/R2)) + string(");\n");
|
|
1117 localString += string("w = (float2)(native_cos(ang), native_sin(ang));\n");
|
|
1118 localString += string("a[") + num2str(t) + string("] = complexMul(a[") + num2str(t) + string("], w);\n");
|
|
1119 }
|
|
1120 }
|
|
1121
|
|
1122 // Store Data
|
|
1123 if(strideO == 1)
|
|
1124 {
|
|
1125
|
|
1126 localString += string("lMemStore = sMem + mad24(i, ") + num2str(radix + 1) + string(", j << ") + num2str((int)log2(R1/R2)) + string(");\n");
|
|
1127 localString += string("lMemLoad = sMem + mad24(tid >> ") + num2str((int)log2(radix)) + string(", ") + num2str(radix+1) + string(", tid & ") + num2str(radix-1) + string(");\n");
|
|
1128
|
|
1129 for(i = 0; i < R1/R2; i++)
|
|
1130 for(j = 0; j < R2; j++)
|
|
1131 localString += string("lMemStore[ ") + num2str(i + j*R1) + string("] = a[") + num2str(i*R2+j) + string("].x;\n");
|
|
1132 localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
|
|
1133 if(threadsPerBlock >= radix)
|
|
1134 {
|
|
1135 for(i = 0; i < R1; i++)
|
|
1136 localString += string("a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i*(radix+1)*(threadsPerBlock/radix)) + string("];\n");
|
|
1137 }
|
|
1138 else
|
|
1139 {
|
|
1140 int innerIter = radix/threadsPerBlock;
|
|
1141 int outerIter = R1/innerIter;
|
|
1142 for(i = 0; i < outerIter; i++)
|
|
1143 for(j = 0; j < innerIter; j++)
|
|
1144 localString += string("a[") + num2str(i*innerIter+j) + string("].x = lMemLoad[") + num2str(j*threadsPerBlock + i*(radix+1)) + string("];\n");
|
|
1145 }
|
|
1146 localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
|
|
1147
|
|
1148 for(i = 0; i < R1/R2; i++)
|
|
1149 for(j = 0; j < R2; j++)
|
|
1150 localString += string("lMemStore[ ") + num2str(i + j*R1) + string("] = a[") + num2str(i*R2+j) + string("].y;\n");
|
|
1151 localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
|
|
1152 if(threadsPerBlock >= radix)
|
|
1153 {
|
|
1154 for(i = 0; i < R1; i++)
|
|
1155 localString += string("a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i*(radix+1)*(threadsPerBlock/radix)) + string("];\n");
|
|
1156 }
|
|
1157 else
|
|
1158 {
|
|
1159 int innerIter = radix/threadsPerBlock;
|
|
1160 int outerIter = R1/innerIter;
|
|
1161 for(i = 0; i < outerIter; i++)
|
|
1162 for(j = 0; j < innerIter; j++)
|
|
1163 localString += string("a[") + num2str(i*innerIter+j) + string("].y = lMemLoad[") + num2str(j*threadsPerBlock + i*(radix+1)) + string("];\n");
|
|
1164 }
|
|
1165 localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
|
|
1166
|
|
1167 localString += string("indexOut += tid;\n");
|
|
1168 if(dataFormat == clFFT_SplitComplexFormat) {
|
|
1169 localString += string("out_real += indexOut;\n");
|
|
1170 localString += string("out_imag += indexOut;\n");
|
|
1171 for(k = 0; k < R1; k++)
|
|
1172 localString += string("out_real[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].x;\n");
|
|
1173 for(k = 0; k < R1; k++)
|
|
1174 localString += string("out_imag[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("].y;\n");
|
|
1175 }
|
|
1176 else {
|
|
1177 localString += string("out += indexOut;\n");
|
|
1178 for(k = 0; k < R1; k++)
|
|
1179 localString += string("out[") + num2str(k*threadsPerBlock) + string("] = a[") + num2str(k) + string("];\n");
|
|
1180 }
|
|
1181
|
|
1182 }
|
|
1183 else
|
|
1184 {
|
|
1185 localString += string("indexOut += mad24(j, ") + num2str(numIter*strideO) + string(", i);\n");
|
|
1186 if(dataFormat == clFFT_SplitComplexFormat) {
|
|
1187 localString += string("out_real += indexOut;\n");
|
|
1188 localString += string("out_imag += indexOut;\n");
|
|
1189 for(k = 0; k < R1; k++)
|
|
1190 localString += string("out_real[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("].x;\n");
|
|
1191 for(k = 0; k < R1; k++)
|
|
1192 localString += string("out_imag[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("].y;\n");
|
|
1193 }
|
|
1194 else {
|
|
1195 localString += string("out += indexOut;\n");
|
|
1196 for(k = 0; k < R1; k++)
|
|
1197 localString += string("out[") + num2str(((k%R2)*R1 + (k/R2))*strideO) + string("] = a[") + num2str(k) + string("];\n");
|
|
1198 }
|
|
1199 }
|
|
1200
|
|
1201 insertHeader(*kernelString, kernelName, dataFormat);
|
|
1202 *kernelString += string("{\n");
|
|
1203 if((*kInfo)->lmem_size)
|
|
1204 *kernelString += string(" __local float sMem[") + num2str((*kInfo)->lmem_size) + string("];\n");
|
|
1205 *kernelString += localString;
|
|
1206 *kernelString += string("}\n");
|
|
1207
|
|
1208 N /= radix;
|
|
1209 kInfo = &(*kInfo)->next;
|
|
1210 kCount++;
|
|
1211 }
|
|
1212 }
|
|
1213
|
|
1214 void FFT1D(cl_fft_plan *plan, cl_fft_kernel_dir dir)
|
|
1215 {
|
|
1216 unsigned int radixArray[10];
|
|
1217 unsigned int numRadix;
|
|
1218
|
|
1219 switch(dir)
|
|
1220 {
|
|
1221 case cl_fft_kernel_x:
|
|
1222 if(plan->n.x > plan->max_localmem_fft_size)
|
|
1223 {
|
|
1224 createGlobalFFTKernelString(plan, plan->n.x, 1, cl_fft_kernel_x, 1);
|
|
1225 }
|
|
1226 else if(plan->n.x > 1)
|
|
1227 {
|
|
1228 getRadixArray(plan->n.x, radixArray, &numRadix, 0);
|
|
1229 if(plan->n.x / radixArray[0] <= plan->max_work_item_per_workgroup)
|
|
1230 {
|
|
1231 createLocalMemfftKernelString(plan);
|
|
1232 }
|
|
1233 else
|
|
1234 {
|
|
1235 getRadixArray(plan->n.x, radixArray, &numRadix, plan->max_radix);
|
|
1236 if(plan->n.x / radixArray[0] <= plan->max_work_item_per_workgroup)
|
|
1237 createLocalMemfftKernelString(plan);
|
|
1238 else
|
|
1239 createGlobalFFTKernelString(plan, plan->n.x, 1, cl_fft_kernel_x, 1);
|
|
1240 }
|
|
1241 }
|
|
1242 break;
|
|
1243
|
|
1244 case cl_fft_kernel_y:
|
|
1245 if(plan->n.y > 1)
|
|
1246 createGlobalFFTKernelString(plan, plan->n.y, plan->n.x, cl_fft_kernel_y, 1);
|
|
1247 break;
|
|
1248
|
|
1249 case cl_fft_kernel_z:
|
|
1250 if(plan->n.z > 1)
|
|
1251 createGlobalFFTKernelString(plan, plan->n.z, plan->n.x*plan->n.y, cl_fft_kernel_z, 1);
|
|
1252 default:
|
|
1253 return;
|
|
1254 }
|
|
1255 }
|
|
1256
|