// { dg-do compile }
// { dg-options "-std=c++14 -O2 -ftemplate-depth=1000000" }
|
|
3
|
|
4 template <class T, int Dim0, int Dim1, int Dim2> struct Tensor3;
|
|
5 template <class A, class T, int Dim0, int Dim1, int Dim2, char i, char j,
|
|
6 char k>
|
|
7 struct Tensor3_Expr;
|
|
8
|
|
9 template <class T, int Dim0, int Dim1, int Dim2, int Dim3> struct Tensor4;
|
|
10 template <class A, class T, int Dim0, int Dim1, int Dim2, int Dim3, char i,
|
|
11 char j, char k, char l>
|
|
12 struct Tensor4_Expr;
|
|
13
|
|
// Empty tag type: carries an index label (a char) and an extent (an int)
// purely in its template arguments, so overloads can deduce both.
template <char Label, int Extent>
struct Index
{
};
|
|
// Compile-time integer wrapped in a type, used to select overloads by value.
template <int Value>
struct Number
{
  // User-provided (empty) default constructor.
  Number() {}

  // Converts back to the wrapped integer.
  operator int() const
  {
    return Value;
  }
};
|
|
21
|
|
22 template <class T, int Tensor_Dim0, int Tensor_Dim1, int Tensor_Dim2>
|
|
23 struct Tensor3
|
|
24 {
|
|
25 T data[Tensor_Dim0][Tensor_Dim1][Tensor_Dim2];
|
|
26
|
|
27 T operator()(const int N1, const int N2, const int N3) const
|
|
28 {
|
|
29 return data[N1][N2][N3];
|
|
30 }
|
|
31
|
|
32 template <char i, char j, char k, int Dim0, int Dim1, int Dim2>
|
|
33 Tensor3_Expr<const Tensor3<T, Tensor_Dim0, Tensor_Dim1, Tensor_Dim2>, T,
|
|
34 Dim0, Dim1, Dim2, i, j, k>
|
|
35 operator()(const Index<i, Dim0>, const Index<j, Dim1>,
|
|
36 const Index<k, Dim2>) const
|
|
37 {
|
|
38 return Tensor3_Expr<const Tensor3<T, Tensor_Dim0, Tensor_Dim1, Tensor_Dim2>,
|
|
39 T, Dim0, Dim1, Dim2, i, j, k>(*this);
|
|
40 }
|
|
41 };
|
|
42
|
|
// Generic rank-3 expression wrapper: keeps its own copy of the wrapped
// callable and forwards three-index element reads to it.
template <class A, class T, int Dim0, int Dim1, int Dim2,
          char i, char j, char k>
struct Tensor3_Expr
{
  A iter;

  // Copies the wrapped object into the expression.
  Tensor3_Expr(const A &source) : iter(source) {}

  // Element read, delegated to the wrapped object and returned as T.
  T operator()(const int N1, const int N2, const int N3) const
  {
    return iter(N1, N2, N3);
  }
};
|
|
55
|
|
56 template <class A, class T, int Tensor_Dim0, int Tensor_Dim1, int Tensor_Dim2,
|
|
57 int Dim0, int Dim1, int Dim2, char i, char j, char k>
|
|
58 struct Tensor3_Expr<Tensor3<A, Tensor_Dim0, Tensor_Dim1, Tensor_Dim2>, T, Dim0,
|
|
59 Dim1, Dim2, i, j, k>
|
|
60 {
|
|
61 Tensor3<A, Tensor_Dim0, Tensor_Dim1, Tensor_Dim2> &iter;
|
|
62
|
|
63 Tensor3_Expr(Tensor3<A, Tensor_Dim0, Tensor_Dim1, Tensor_Dim2> &a) : iter(a)
|
|
64 {}
|
|
65 T operator()(const int N1, const int N2, const int N3) const
|
|
66 {
|
|
67 return iter(N1, N2, N3);
|
|
68 }
|
|
69 };
|
|
70
|
|
71 template <class A, class B, class T, class U, int Dim0, int Dim1, int Dim23,
|
|
72 int Dim4, int Dim5, char i, char j, char k, char l, char m>
|
|
73 struct Tensor3_times_Tensor3_21
|
|
74 {
|
|
75 Tensor3_Expr<A, T, Dim0, Dim1, Dim23, i, j, k> iterA;
|
|
76 Tensor3_Expr<B, U, Dim23, Dim4, Dim5, k, l, m> iterB;
|
|
77
|
|
78 template <int CurrentDim>
|
|
79 T eval(const int N1, const int N2, const int N3, const int N4,
|
|
80 const Number<CurrentDim> &) const
|
|
81 {
|
|
82 return iterA(N1, N2, CurrentDim - 1) * iterB(CurrentDim - 1, N3, N4)
|
|
83 + eval(N1, N2, N3, N4, Number<CurrentDim - 1>());
|
|
84 }
|
|
85 T eval(const int N1, const int N2, const int N3, const int N4,
|
|
86 const Number<1> &) const
|
|
87 {
|
|
88 return iterA(N1, N2, 0) * iterB(0, N3, N4);
|
|
89 }
|
|
90
|
|
91 Tensor3_times_Tensor3_21(
|
|
92 const Tensor3_Expr<A, T, Dim0, Dim1, Dim23, i, j, k> &a,
|
|
93 const Tensor3_Expr<B, U, Dim23, Dim4, Dim5, k, l, m> &b)
|
|
94 : iterA(a), iterB(b)
|
|
95 {}
|
|
96 T operator()(const int &N1, const int &N2, const int &N3,
|
|
97 const int &N4) const
|
|
98 {
|
|
99 return eval(N1, N2, N3, N4, Number<Dim23>());
|
|
100 }
|
|
101 };
|
|
102
|
|
103 template <class A, class B, class T, class U, int Dim0, int Dim1, int Dim23,
|
|
104 int Dim4, int Dim5, char i, char j, char k, char l, char m>
|
|
105 Tensor4_Expr<Tensor3_times_Tensor3_21<A, B, T, U, Dim0, Dim1, Dim23, Dim4,
|
|
106 Dim5, i, j, k, l, m>,
|
|
107 T, Dim0, Dim1, Dim4, Dim5, i, j, l, m>
|
|
108 operator*(const Tensor3_Expr<A, T, Dim0, Dim1, Dim23, i, j, k> &a,
|
|
109 const Tensor3_Expr<B, U, Dim23, Dim4, Dim5, k, l, m> &b)
|
|
110 {
|
|
111 using TensorExpr = Tensor3_times_Tensor3_21<A, B, T, U, Dim0, Dim1, Dim23,
|
|
112 Dim4, Dim5, i, j, k, l, m>;
|
|
113 return Tensor4_Expr<TensorExpr, T, Dim0, Dim1, Dim4, Dim5, i, j, l, m>(
|
|
114 TensorExpr(a, b));
|
|
115 };
|
|
116
|
|
117 template <class T, int Tensor_Dim0, int Tensor_Dim1, int Tensor_Dim2,
|
|
118 int Tensor_Dim3>
|
|
119 struct Tensor4
|
|
120 {
|
|
121 T data[Tensor_Dim0][Tensor_Dim1][Tensor_Dim2][Tensor_Dim3];
|
|
122
|
|
123 Tensor4() {}
|
|
124 T &operator()(const int N1, const int N2, const int N3, const int N4)
|
|
125 {
|
|
126 return data[N1][N2][N3][N4];
|
|
127 }
|
|
128
|
|
129 template <char i, char j, char k, char l, int Dim0, int Dim1, int Dim2,
|
|
130 int Dim3>
|
|
131 Tensor4_Expr<Tensor4<T, Tensor_Dim0, Tensor_Dim1, Tensor_Dim2, Tensor_Dim3>,
|
|
132 T, Dim0, Dim1, Dim2, Dim3, i, j, k, l>
|
|
133 operator()(const Index<i, Dim0>, const Index<j, Dim1>, const Index<k, Dim2>,
|
|
134 const Index<l, Dim3>)
|
|
135 {
|
|
136 return Tensor4_Expr<
|
|
137 Tensor4<T, Tensor_Dim0, Tensor_Dim1, Tensor_Dim2, Tensor_Dim3>, T, Dim0,
|
|
138 Dim1, Dim2, Dim3, i, j, k, l>(*this);
|
|
139 };
|
|
140 };
|
|
141
|
|
// Generic rank-4 expression wrapper: keeps its own copy of the wrapped
// callable and forwards four-index element reads to it.
template <class A, class T, int Dim0, int Dim1, int Dim2, int Dim3,
          char i, char j, char k, char l>
struct Tensor4_Expr
{
  A iter;

  // Copies the wrapped object into the expression.
  Tensor4_Expr(const A &source) : iter(source) {}

  // Element read, delegated to the wrapped object and returned as T.
  T operator()(const int N1, const int N2, const int N3, const int N4) const
  {
    return iter(N1, N2, N3, N4);
  }
};
|
|
154
|
|
155 template <class A, class T, int Dim0, int Dim1, int Dim2, int Dim3, char i,
|
|
156 char j, char k, char l>
|
|
157 struct Tensor4_Expr<Tensor4<A, Dim0, Dim1, Dim2, Dim3>, T, Dim0, Dim1, Dim2,
|
|
158 Dim3, i, j, k, l>
|
|
159 {
|
|
160 Tensor4<A, Dim0, Dim1, Dim2, Dim3> &iter;
|
|
161
|
|
162 Tensor4_Expr(Tensor4<A, Dim0, Dim1, Dim2, Dim3> &a) : iter(a) {}
|
|
163 T operator()(const int N1, const int N2, const int N3, const int N4) const
|
|
164 {
|
|
165 return iter(N1, N2, N3, N4);
|
|
166 }
|
|
167
|
|
168 template <class B, class U, int Dim1_0, int Dim1_1, int Dim1_2, int Dim1_3,
|
|
169 char i_1, char j_1, char k_1, char l_1>
|
|
170 auto &operator=(const Tensor4_Expr<B, U, Dim1_0, Dim1_1, Dim1_2, Dim1_3, i_1,
|
|
171 j_1, k_1, l_1> &rhs)
|
|
172 {
|
|
173 for(int ii = 0; ii < Dim0; ++ii)
|
|
174 for(int jj = 0; jj < Dim1; ++jj)
|
|
175 for(int kk = 0; kk < Dim2; ++kk)
|
|
176 for(int ll = 0; ll < Dim3; ++ll)
|
|
177 {
|
|
178 iter(ii, jj, kk, ll) = rhs(ii, jj, kk, ll);
|
|
179 }
|
|
180 return *this;
|
|
181 }
|
|
182 };
|
|
183
|
|
184 int main()
|
|
185 {
|
|
186 Tensor3<float, 100, 100, 1000> t1;
|
|
187 Tensor3<float, 1000, 100, 100> t2;
|
|
188
|
|
189 Index<'l', 100> l;
|
|
190 Index<'m', 100> m;
|
|
191 Index<'k', 1000> k;
|
|
192 Index<'n', 100> n;
|
|
193 Index<'o', 100> o;
|
|
194
|
|
195 Tensor4<float, 100, 100, 100, 100> res;
|
|
196 res(l, m, n, o) = t1(l, m, k) * t2(k, n, o);
|
|
197 return 0;
|
|
198 }
|
|
199
|