blob: 149c0beccb833ce241204185edc7a63b4ccfcdee [file] [log] [blame]
Tim Murray25207df2015-01-12 16:47:56 -08001/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.renderscript;
18
19import android.annotation.IntDef;
20import java.lang.annotation.Retention;
21import java.lang.annotation.RetentionPolicy;
22
23/**
24 *
25 * BLAS
26 *
Tim Murray25207df2015-01-12 16:47:56 -080027 **/
28public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
29 private Allocation mLUT;
30
/**
 * Wraps an already-created native intrinsic handle.
 *
 * The constructor is private; {@link #create(RenderScript)} is the factory
 * that calls it.
 *
 * @param id native handle forwarded to the superclass
 * @param rs RenderScript context forwarded to the superclass
 */
private ScriptIntrinsicBLAS(long id, RenderScript rs) {
    super(id, rs);
}
34
35 private static final int RsBlas_sdsdot = 1;
36 private static final int RsBlas_dsdot = 2;
37 private static final int RsBlas_sdot = 3;
38 private static final int RsBlas_ddot = 4;
39 private static final int RsBlas_cdotu_sub = 5;
40 private static final int RsBlas_cdotc_sub = 6;
41 private static final int RsBlas_zdotu_sub = 7;
42 private static final int RsBlas_zdotc_sub = 8;
43 private static final int RsBlas_snrm2 = 9;
44 private static final int RsBlas_sasum = 10;
45 private static final int RsBlas_dnrm2 = 11;
46 private static final int RsBlas_dasum = 12;
47 private static final int RsBlas_scnrm2 = 13;
48 private static final int RsBlas_scasum = 14;
49 private static final int RsBlas_dznrm2 = 15;
50 private static final int RsBlas_dzasum = 16;
51 private static final int RsBlas_isamax = 17;
52 private static final int RsBlas_idamax = 18;
53 private static final int RsBlas_icamax = 19;
54 private static final int RsBlas_izamax = 20;
55 private static final int RsBlas_sswap = 21;
56 private static final int RsBlas_scopy = 22;
57 private static final int RsBlas_saxpy = 23;
58 private static final int RsBlas_dswap = 24;
59 private static final int RsBlas_dcopy = 25;
60 private static final int RsBlas_daxpy = 26;
61 private static final int RsBlas_cswap = 27;
62 private static final int RsBlas_ccopy = 28;
63 private static final int RsBlas_caxpy = 29;
64 private static final int RsBlas_zswap = 30;
65 private static final int RsBlas_zcopy = 31;
66 private static final int RsBlas_zaxpy = 32;
67 private static final int RsBlas_srotg = 33;
68 private static final int RsBlas_srotmg = 34;
69 private static final int RsBlas_srot = 35;
70 private static final int RsBlas_srotm = 36;
71 private static final int RsBlas_drotg = 37;
72 private static final int RsBlas_drotmg = 38;
73 private static final int RsBlas_drot = 39;
74 private static final int RsBlas_drotm = 40;
75 private static final int RsBlas_sscal = 41;
76 private static final int RsBlas_dscal = 42;
77 private static final int RsBlas_cscal = 43;
78 private static final int RsBlas_zscal = 44;
79 private static final int RsBlas_csscal = 45;
80 private static final int RsBlas_zdscal = 46;
81 private static final int RsBlas_sgemv = 47;
82 private static final int RsBlas_sgbmv = 48;
83 private static final int RsBlas_strmv = 49;
84 private static final int RsBlas_stbmv = 50;
85 private static final int RsBlas_stpmv = 51;
86 private static final int RsBlas_strsv = 52;
87 private static final int RsBlas_stbsv = 53;
88 private static final int RsBlas_stpsv = 54;
89 private static final int RsBlas_dgemv = 55;
90 private static final int RsBlas_dgbmv = 56;
91 private static final int RsBlas_dtrmv = 57;
92 private static final int RsBlas_dtbmv = 58;
93 private static final int RsBlas_dtpmv = 59;
94 private static final int RsBlas_dtrsv = 60;
95 private static final int RsBlas_dtbsv = 61;
96 private static final int RsBlas_dtpsv = 62;
97 private static final int RsBlas_cgemv = 63;
98 private static final int RsBlas_cgbmv = 64;
99 private static final int RsBlas_ctrmv = 65;
100 private static final int RsBlas_ctbmv = 66;
101 private static final int RsBlas_ctpmv = 67;
102 private static final int RsBlas_ctrsv = 68;
103 private static final int RsBlas_ctbsv = 69;
104 private static final int RsBlas_ctpsv = 70;
105 private static final int RsBlas_zgemv = 71;
106 private static final int RsBlas_zgbmv = 72;
107 private static final int RsBlas_ztrmv = 73;
108 private static final int RsBlas_ztbmv = 74;
109 private static final int RsBlas_ztpmv = 75;
110 private static final int RsBlas_ztrsv = 76;
111 private static final int RsBlas_ztbsv = 77;
112 private static final int RsBlas_ztpsv = 78;
113 private static final int RsBlas_ssymv = 79;
114 private static final int RsBlas_ssbmv = 80;
115 private static final int RsBlas_sspmv = 81;
116 private static final int RsBlas_sger = 82;
117 private static final int RsBlas_ssyr = 83;
118 private static final int RsBlas_sspr = 84;
119 private static final int RsBlas_ssyr2 = 85;
120 private static final int RsBlas_sspr2 = 86;
121 private static final int RsBlas_dsymv = 87;
122 private static final int RsBlas_dsbmv = 88;
123 private static final int RsBlas_dspmv = 89;
124 private static final int RsBlas_dger = 90;
125 private static final int RsBlas_dsyr = 91;
126 private static final int RsBlas_dspr = 92;
127 private static final int RsBlas_dsyr2 = 93;
128 private static final int RsBlas_dspr2 = 94;
129 private static final int RsBlas_chemv = 95;
130 private static final int RsBlas_chbmv = 96;
131 private static final int RsBlas_chpmv = 97;
132 private static final int RsBlas_cgeru = 98;
133 private static final int RsBlas_cgerc = 99;
134 private static final int RsBlas_cher = 100;
135 private static final int RsBlas_chpr = 101;
136 private static final int RsBlas_cher2 = 102;
137 private static final int RsBlas_chpr2 = 103;
138 private static final int RsBlas_zhemv = 104;
139 private static final int RsBlas_zhbmv = 105;
140 private static final int RsBlas_zhpmv = 106;
141 private static final int RsBlas_zgeru = 107;
142 private static final int RsBlas_zgerc = 108;
143 private static final int RsBlas_zher = 109;
144 private static final int RsBlas_zhpr = 110;
145 private static final int RsBlas_zher2 = 111;
146 private static final int RsBlas_zhpr2 = 112;
147 private static final int RsBlas_sgemm = 113;
148 private static final int RsBlas_ssymm = 114;
149 private static final int RsBlas_ssyrk = 115;
150 private static final int RsBlas_ssyr2k = 116;
151 private static final int RsBlas_strmm = 117;
152 private static final int RsBlas_strsm = 118;
153 private static final int RsBlas_dgemm = 119;
154 private static final int RsBlas_dsymm = 120;
155 private static final int RsBlas_dsyrk = 121;
156 private static final int RsBlas_dsyr2k = 122;
157 private static final int RsBlas_dtrmm = 123;
158 private static final int RsBlas_dtrsm = 124;
159 private static final int RsBlas_cgemm = 125;
160 private static final int RsBlas_csymm = 126;
161 private static final int RsBlas_csyrk = 127;
162 private static final int RsBlas_csyr2k = 128;
163 private static final int RsBlas_ctrmm = 129;
164 private static final int RsBlas_ctrsm = 130;
165 private static final int RsBlas_zgemm = 131;
166 private static final int RsBlas_zsymm = 132;
167 private static final int RsBlas_zsyrk = 133;
168 private static final int RsBlas_zsyr2k = 134;
169 private static final int RsBlas_ztrmm = 135;
170 private static final int RsBlas_ztrsm = 136;
171 private static final int RsBlas_chemm = 137;
172 private static final int RsBlas_cherk = 138;
173 private static final int RsBlas_cher2k = 139;
174 private static final int RsBlas_zhemm = 140;
175 private static final int RsBlas_zherk = 141;
176 private static final int RsBlas_zher2k = 142;
177
Tim Murray9cb16a22015-04-01 11:07:16 -0700178 // BLAS extensions start here
179 private static final int RsBlas_bnnm = 1000;
180
Tim Murray25207df2015-01-12 16:47:56 -0800181 /**
182 */
183 public static ScriptIntrinsicBLAS create(RenderScript rs) {
184 long id = rs.nScriptIntrinsicCreate(13, Element.U32(rs).getID(rs));
185 return new ScriptIntrinsicBLAS(id, rs);
186 }
187
188 @IntDef({NO_TRANSPOSE, TRANSPOSE, CONJ_TRANSPOSE})
189 @Retention(RetentionPolicy.SOURCE)
190 public @interface Transpose {}
191
192 @IntDef({UPPER, LOWER})
193 @Retention(RetentionPolicy.SOURCE)
194 public @interface Uplo {}
195
196 @IntDef({NON_UNIT, UNIT})
197 @Retention(RetentionPolicy.SOURCE)
198 public @interface Diag {}
199
200 @IntDef({LEFT, RIGHT})
201 @Retention(RetentionPolicy.SOURCE)
202 public @interface Side {}
203
204 public static final int NO_TRANSPOSE = 111;
205 public static final int TRANSPOSE = 112;
206 public static final int CONJ_TRANSPOSE = 113;
207
208 public static final int UPPER = 121;
209 public static final int LOWER = 122;
210
211 public static final int NON_UNIT = 131;
212 public static final int UNIT = 132;
213
214 public static final int LEFT = 141;
215 public static final int RIGHT = 142;
216
217 static void validateSide(@Side int Side) {
218 if (Side != LEFT && Side != RIGHT) {
219 throw new RSRuntimeException("Invalid side passed to BLAS");
220 }
221 }
222
223 static void validateTranspose(@Transpose int Trans) {
224 if (Trans != NO_TRANSPOSE && Trans != TRANSPOSE &&
225 Trans != CONJ_TRANSPOSE) {
226 throw new RSRuntimeException("Invalid transpose passed to BLAS");
227 }
228 }
229
230 static void validateConjTranspose(@Transpose int Trans) {
231 if (Trans != NO_TRANSPOSE &&
232 Trans != CONJ_TRANSPOSE) {
233 throw new RSRuntimeException("Invalid transpose passed to BLAS");
234 }
235 }
236
237 static void validateDiag(@Diag int Diag) {
238 if (Diag != NON_UNIT && Diag != UNIT) {
239 throw new RSRuntimeException("Invalid diag passed to BLAS");
240 }
241 }
242
243 static void validateUplo(@Uplo int Uplo) {
Miao Wang37ae07c2015-04-24 11:19:53 -0700244 if (Uplo != UPPER && Uplo != LOWER) {
Tim Murray25207df2015-01-12 16:47:56 -0800245 throw new RSRuntimeException("Invalid uplo passed to BLAS");
246 }
247 }
248
249
250 /**
251 * Level 2 BLAS
252 */
253
254 static void validateGEMV(Element e, int TransA, Allocation A, Allocation X, int incX, Allocation Y, int incY) {
255 validateTranspose(TransA);
256 int M = A.getType().getY();
257 int N = A.getType().getX();
258 if (!A.getType().getElement().isCompatible(e) ||
259 !X.getType().getElement().isCompatible(e) ||
260 !Y.getType().getElement().isCompatible(e)) {
261 throw new RSRuntimeException("Called BLAS with wrong Element type");
262 }
263 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
264 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
265 }
266
267 if (incX <= 0 || incY <= 0) {
268 throw new RSRuntimeException("Vector increments must be greater than 0");
269 }
270 int expectedXDim = -1, expectedYDim = -1;
271 if (TransA == NO_TRANSPOSE) {
272 expectedXDim = 1 + (N - 1) * incX;
273 expectedYDim = 1 + (M - 1) * incY;
274 } else {
275 expectedXDim = 1 + (M - 1) * incX;
276 expectedYDim = 1 + (N - 1) * incY;
277 }
278 if (X.getType().getX() != expectedXDim ||
Miao Wang68ca43e2015-04-23 15:06:09 -0700279 Y.getType().getX() != expectedYDim) {
Tim Murray25207df2015-01-12 16:47:56 -0800280 throw new RSRuntimeException("Incorrect vector dimensions for GEMV");
281 }
282 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700283 public void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800284 validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY);
285 int M = A.getType().getY();
286 int N = A.getType().getX();
287 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
288 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700289 public void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800290 validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY);
291 int M = A.getType().getY();
292 int N = A.getType().getX();
293 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
294 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700295 public void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800296 validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY);
297 int M = A.getType().getY();
298 int N = A.getType().getX();
299 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
300 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700301 public void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800302 validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY);
303 int M = A.getType().getY();
304 int N = A.getType().getX();
305 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
306 }
307
Miao Wang89c3a5f2015-04-23 15:20:11 -0700308 public void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800309 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
310 validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY);
311 if (KL < 0 || KU < 0) {
312 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
313 }
314 int M = A.getType().getY();
315 int N = A.getType().getX();
316 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU);
317 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700318 public void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800319 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
320 validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY);
321 if (KL < 0 || KU < 0) {
322 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
323 }
324 int M = A.getType().getY();
325 int N = A.getType().getX();
326 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU);
327 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700328 public void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800329 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
330 validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY);
331 if (KL < 0 || KU < 0) {
332 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
333 }
334 int M = A.getType().getY();
335 int N = A.getType().getX();
336 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
337 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700338 public void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800339 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
340 validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY);
341 if (KL < 0 || KU < 0) {
342 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
343 }
344 int M = A.getType().getY();
345 int N = A.getType().getX();
346 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
347 }
348
Miao Wang68ca43e2015-04-23 15:06:09 -0700349 static void validateTRMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800350 validateTranspose(TransA);
Miao Wang68ca43e2015-04-23 15:06:09 -0700351 validateUplo(Uplo);
352 validateDiag(Diag);
Tim Murray25207df2015-01-12 16:47:56 -0800353 int N = A.getType().getY();
354 if (A.getType().getX() != N) {
355 throw new RSRuntimeException("A must be a square matrix for TRMV");
356 }
357 if (!A.getType().getElement().isCompatible(e) ||
358 !X.getType().getElement().isCompatible(e)) {
359 throw new RSRuntimeException("Called BLAS with wrong Element type");
360 }
361 if (X.getType().getY() > 1) {
362 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
363 }
364
365 if (incX <= 0) {
366 throw new RSRuntimeException("Vector increments must be greater than 0");
367 }
368 int expectedXDim = 1 + (N - 1) * incX;
369 if (X.getType().getX() != expectedXDim) {
370 throw new RSRuntimeException("Incorrect vector dimensions for TRMV");
371 }
372 }
373
374 static int validateTPMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
375 validateTranspose(TransA);
376 validateUplo(Uplo);
377 validateDiag(Diag);
378 if (!Ap.getType().getElement().isCompatible(e) ||
379 !X.getType().getElement().isCompatible(e)) {
380 throw new RSRuntimeException("Called BLAS with wrong Element type");
381 }
382 if (X.getType().getY() > 1) {
383 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
384 }
385
386 if (Ap.getType().getY() > 1) {
387 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
388 }
389
390 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
Miao Wang68ca43e2015-04-23 15:06:09 -0700391 //is it really doing anything?
Tim Murray25207df2015-01-12 16:47:56 -0800392 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
393 throw new RSRuntimeException("Invalid dimension for Ap");
394 }
Miao Wang68ca43e2015-04-23 15:06:09 -0700395 if (incX <= 0) {
396 throw new RSRuntimeException("Vector increments must be greater than 0");
397 }
Tim Murray25207df2015-01-12 16:47:56 -0800398 int expectedXDim = 1 + (N - 1) * incX;
399 if (X.getType().getX() != expectedXDim) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700400 throw new RSRuntimeException("Incorrect vector dimensions for TPMV");
Tim Murray25207df2015-01-12 16:47:56 -0800401 }
402
403 return N;
404 }
405
Miao Wang89c3a5f2015-04-23 15:20:11 -0700406 public void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700407 validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800408 int N = A.getType().getY();
409 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
410 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700411 public void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700412 validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800413 int N = A.getType().getY();
414 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
415 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700416 public void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700417 validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800418 int N = A.getType().getY();
419 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
420 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700421 public void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700422 validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800423 int N = A.getType().getY();
424 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
425 }
Miao Wang68ca43e2015-04-23 15:06:09 -0700426
Miao Wang89c3a5f2015-04-23 15:20:11 -0700427 public void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700428 // TBMV has the same requirements as TRMV + K >= 0
429 if (K < 0) {
430 throw new RSRuntimeException("K must be greater than or equal to 0");
431 }
432 validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800433 int N = A.getType().getY();
434 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
435 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700436 public void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700437 // TBMV has the same requirements as TRMV + K >= 0
438 if (K < 0) {
439 throw new RSRuntimeException("K must be greater than or equal to 0");
440 }
441 validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800442 int N = A.getType().getY();
443 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
444 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700445 public void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700446 // TBMV has the same requirements as TRMV + K >= 0
447 if (K < 0) {
448 throw new RSRuntimeException("K must be greater than or equal to 0");
449 }
450 validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800451 int N = A.getType().getY();
452 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
453 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700454 public void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700455 // TBMV has the same requirements as TRMV + K >= 0
456 if (K < 0) {
457 throw new RSRuntimeException("K must be greater than or equal to 0");
458 }
459 validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800460 int N = A.getType().getY();
461 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
462 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700463 public void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800464 int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
465 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
466 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700467 public void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800468 int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
469 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
470 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700471 public void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800472 int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
473 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
474 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700475 public void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800476 int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
477 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
478 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700479 public void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800480 // TRSV is the same as TRMV
Miao Wang68ca43e2015-04-23 15:06:09 -0700481 validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800482 int N = A.getType().getY();
483 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
484
485 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700486 public void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800487 // TRSV is the same as TRMV
Miao Wang68ca43e2015-04-23 15:06:09 -0700488 validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800489 int N = A.getType().getY();
490 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
491
492 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700493 public void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800494 // TRSV is the same as TRMV
Miao Wang68ca43e2015-04-23 15:06:09 -0700495 validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800496 int N = A.getType().getY();
497 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
498
499 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700500 public void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800501 // TRSV is the same as TRMV
Miao Wang68ca43e2015-04-23 15:06:09 -0700502 validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800503 int N = A.getType().getY();
504 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
505
506 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700507 public void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700508 // TBSV is the same as TRMV + K >= 0
509 validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800510 int N = A.getType().getY();
511 if (K < 0) {
512 throw new RSRuntimeException("Number of diagonals must be positive");
513 }
514 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
515 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700516 public void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700517 // TBSV is the same as TRMV + K >= 0
518 validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800519 int N = A.getType().getY();
520 if (K < 0) {
521 throw new RSRuntimeException("Number of diagonals must be positive");
522 }
523 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
524 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700525 public void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700526 // TBSV is the same as TRMV + K >= 0
527 validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800528 int N = A.getType().getY();
529 if (K < 0) {
530 throw new RSRuntimeException("Number of diagonals must be positive");
531 }
532 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
533 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700534 public void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700535 // TBSV is the same as TRMV + K >= 0
536 validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800537 int N = A.getType().getY();
538 if (K < 0) {
539 throw new RSRuntimeException("Number of diagonals must be positive");
540 }
541 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
542 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700543 public void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800544 // TPSV is same as TPMV
545 int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
546 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
547 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700548 public void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800549 // TPSV is same as TPMV
550 int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
551 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
552 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700553 public void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800554 // TPSV is same as TPMV
555 int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
556 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
557 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700558 public void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800559 // TPSV is same as TPMV
560 int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
561 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
562 }
563
564 /**
565 * Level 2, S and D only
566 */
567 static int validateSYMV(Element e, @Uplo int Uplo, Allocation A, Allocation X, Allocation Y, int incX, int incY) {
568 validateUplo(Uplo);
569 int N = A.getType().getY();
570 if (A.getType().getX() != N) {
571 throw new RSRuntimeException("A must be a square matrix for SYMV");
572 }
573 if (!A.getType().getElement().isCompatible(e) ||
574 !X.getType().getElement().isCompatible(e) ||
575 !Y.getType().getElement().isCompatible(e) ) {
576 throw new RSRuntimeException("Called BLAS with wrong Element type");
577 }
578 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
579 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
580 }
581
582 if (incX <= 0 || incY <= 0) {
583 throw new RSRuntimeException("Vector increments must be greater than 0");
584 }
585 int expectedXDim = 1 + (N - 1) * incX;
586 if (X.getType().getX() != expectedXDim) {
587 throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
588 }
589 int expectedYDim = 1 + (N - 1) * incY;
590 if (Y.getType().getX() != expectedYDim) {
591 throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
592 }
593 return N;
594 }
595 static int validateSPMV(Element e, @Uplo int Uplo, Allocation Ap, Allocation X, int incX, Allocation Y, int incY) {
596 validateUplo(Uplo);
597 if (!Ap.getType().getElement().isCompatible(e) ||
598 !X.getType().getElement().isCompatible(e) ||
599 !Y.getType().getElement().isCompatible(e)) {
600 throw new RSRuntimeException("Called BLAS with wrong Element type");
601 }
602 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
603 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
604 }
605
606 if (Ap.getType().getY() > 1) {
607 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
608 }
609
610 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
611 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
612 throw new RSRuntimeException("Invalid dimension for Ap");
613 }
Miao Wang68ca43e2015-04-23 15:06:09 -0700614 if (incX <= 0 || incY <= 0) {
615 throw new RSRuntimeException("Vector increments must be greater than 0");
616 }
Tim Murray25207df2015-01-12 16:47:56 -0800617 int expectedXDim = 1 + (N - 1) * incX;
618 if (X.getType().getX() != expectedXDim) {
619 throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
620 }
621 int expectedYDim = 1 + (N - 1) * incY;
622 if (Y.getType().getX() != expectedYDim) {
623 throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
624 }
625
626 return N;
627 }
628 static void validateGER(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
629 if (!A.getType().getElement().isCompatible(e) ||
630 !X.getType().getElement().isCompatible(e) ||
631 !Y.getType().getElement().isCompatible(e) ) {
632 throw new RSRuntimeException("Called BLAS with wrong Element type");
633 }
634
635 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
636 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
637 }
638
639 int M = A.getType().getY();
640 int N = A.getType().getX();
641
642 if (N < 1 || M < 1) {
643 throw new RSRuntimeException("M and N must be 1 or greater for GER");
644 }
Miao Wang68ca43e2015-04-23 15:06:09 -0700645 if (incX <= 0 || incY <= 0) {
646 throw new RSRuntimeException("Vector increments must be greater than 0");
647 }
648 int expectedXDim = 1 + (M - 1) * incX;
Tim Murray25207df2015-01-12 16:47:56 -0800649 if (X.getType().getX() != expectedXDim) {
650 throw new RSRuntimeException("Incorrect vector dimensions for GER");
651 }
652 int expectedYDim = 1 + (N - 1) * incY;
653 if (Y.getType().getX() != expectedYDim) {
654 throw new RSRuntimeException("Incorrect vector dimensions for GER");
655 }
656
657
658 }
659 static int validateSYR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation A) {
660 validateUplo(Uplo);
661 if (!A.getType().getElement().isCompatible(e) ||
662 !X.getType().getElement().isCompatible(e)) {
663 throw new RSRuntimeException("Called BLAS with wrong Element type");
664 }
665
666 int N = A.getType().getX();
667
668 if (X.getType().getY() > 1) {
669 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
670 }
671 if (N != A.getType().getY()) {
672 throw new RSRuntimeException("A must be a symmetric matrix");
673 }
Miao Wang68ca43e2015-04-23 15:06:09 -0700674 if (incX <= 0) {
675 throw new RSRuntimeException("Vector increments must be greater than 0");
676 }
Tim Murray25207df2015-01-12 16:47:56 -0800677 int expectedXDim = 1 + (N - 1) * incX;
678 if (X.getType().getX() != expectedXDim) {
679 throw new RSRuntimeException("Incorrect vector dimensions for SYR");
680 }
681 return N;
682 }
683 static int validateSPR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Ap) {
684 validateUplo(Uplo);
685 if (!Ap.getType().getElement().isCompatible(e) ||
686 !X.getType().getElement().isCompatible(e)) {
687 throw new RSRuntimeException("Called BLAS with wrong Element type");
688 }
689 if (X.getType().getY() > 1) {
690 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
691 }
692
693 if (Ap.getType().getY() > 1) {
694 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
695 }
696
697 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
698 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
699 throw new RSRuntimeException("Invalid dimension for Ap");
700 }
Miao Wang68ca43e2015-04-23 15:06:09 -0700701 if (incX <= 0) {
702 throw new RSRuntimeException("Vector increments must be greater than 0");
703 }
Tim Murray25207df2015-01-12 16:47:56 -0800704 int expectedXDim = 1 + (N - 1) * incX;
705 if (X.getType().getX() != expectedXDim) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700706 throw new RSRuntimeException("Incorrect vector dimensions for SPR");
Tim Murray25207df2015-01-12 16:47:56 -0800707 }
708
709 return N;
710 }
711
712 static int validateSYR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
713 validateUplo(Uplo);
714 if (!A.getType().getElement().isCompatible(e) ||
715 !X.getType().getElement().isCompatible(e) ||
716 !Y.getType().getElement().isCompatible(e)) {
717 throw new RSRuntimeException("Called BLAS with wrong Element type");
718 }
719
720 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
721 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
722 }
723
724 int N = A.getType().getX();
725
726 if (N != A.getType().getY()) {
727 throw new RSRuntimeException("A must be a symmetric matrix");
728 }
Miao Wang68ca43e2015-04-23 15:06:09 -0700729 if (incX <= 0 || incY <= 0) {
730 throw new RSRuntimeException("Vector increments must be greater than 0");
731 }
Tim Murray25207df2015-01-12 16:47:56 -0800732 int expectedXDim = 1 + (N - 1) * incX;
733 int expectedYDim = 1 + (N - 1) * incY;
734 if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
735 throw new RSRuntimeException("Incorrect vector dimensions for SYR");
736 }
737 return N;
738
739 }
740 static int validateSPR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
741 validateUplo(Uplo);
742 if (!Ap.getType().getElement().isCompatible(e) ||
743 !X.getType().getElement().isCompatible(e) ||
744 !Y.getType().getElement().isCompatible(e)) {
745 throw new RSRuntimeException("Called BLAS with wrong Element type");
746 }
747 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
748 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
749 }
750
751 if (Ap.getType().getY() > 1) {
752 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
753 }
754
755 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
756 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
757 throw new RSRuntimeException("Invalid dimension for Ap");
758 }
Miao Wang68ca43e2015-04-23 15:06:09 -0700759 if (incX <= 0 || incY <= 0) {
760 throw new RSRuntimeException("Vector increments must be greater than 0");
761 }
Tim Murray25207df2015-01-12 16:47:56 -0800762 int expectedXDim = 1 + (N - 1) * incX;
763 int expectedYDim = 1 + (N - 1) * incY;
764 if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700765 throw new RSRuntimeException("Incorrect vector dimensions for SPR2");
Tim Murray25207df2015-01-12 16:47:56 -0800766 }
767
768 return N;
769 }
770
Miao Wang89c3a5f2015-04-23 15:20:11 -0700771 public void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800772 int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
773 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
774 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700775 public void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700776 // SBMV is the same as SYMV + K >= 0
777 if (K < 0) {
778 throw new RSRuntimeException("K must be greater than or equal to 0");
779 }
Tim Murray25207df2015-01-12 16:47:56 -0800780 int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
781 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
782 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700783 public void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800784 int N = validateSPMV(Element.F32(mRS), Uplo, Ap, X, incX, Y, incY);
785 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
786 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700787 public void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800788 int M = A.getType().getY();
789 int N = A.getType().getX();
Miao Wang68ca43e2015-04-23 15:06:09 -0700790 validateGER(Element.F32(mRS), X, incX, Y, incY, A);
Tim Murray25207df2015-01-12 16:47:56 -0800791 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
792 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700793 public void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800794 int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A);
795 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0);
796 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700797 public void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
Tim Murray25207df2015-01-12 16:47:56 -0800798 int N = validateSPR(Element.F32(mRS), Uplo, X, incX, Ap);
799 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0);
800 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700801 public void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800802 int N = validateSYR2(Element.F32(mRS), Uplo, X, incX, Y, incY, A);
803 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0);
804 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700805 public void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
Tim Murray25207df2015-01-12 16:47:56 -0800806 int N = validateSPR2(Element.F32(mRS), Uplo, X, incX, Y, incY, Ap);
807 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0);
808 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700809 public void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800810 int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
811 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
812 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700813 public void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
Miao Wang68ca43e2015-04-23 15:06:09 -0700814 // SBMV is the same as SYMV + K >= 0
815 if (K < 0) {
816 throw new RSRuntimeException("K must be greater than or equal to 0");
817 }
Tim Murray25207df2015-01-12 16:47:56 -0800818 int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
819 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
820 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700821 public void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800822 int N = validateSPMV(Element.F64(mRS), Uplo, Ap, X, incX, Y, incY);
823 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
824 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700825 public void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800826 int M = A.getType().getY();
827 int N = A.getType().getX();
Miao Wang68ca43e2015-04-23 15:06:09 -0700828 validateGER(Element.F64(mRS), X, incX, Y, incY, A);
Tim Murray25207df2015-01-12 16:47:56 -0800829 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
830 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700831 public void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800832 int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A);
833 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0);
834 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700835 public void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
Tim Murray25207df2015-01-12 16:47:56 -0800836 int N = validateSPR(Element.F64(mRS), Uplo, X, incX, Ap);
837 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0);
838 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700839 public void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800840 int N = validateSYR2(Element.F64(mRS), Uplo, X, incX, Y, incY, A);
841 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0);
842 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700843 public void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
Tim Murray25207df2015-01-12 16:47:56 -0800844 int N = validateSPR2(Element.F64(mRS), Uplo, X, incX, Y, incY, Ap);
845 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0);
846 }
847
848
849 /**
850 * Level 2, C and Z only
851 */
852
853 static void validateGERU(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
854 if (!A.getType().getElement().isCompatible(e) ||
855 !X.getType().getElement().isCompatible(e) ||
856 !Y.getType().getElement().isCompatible(e)) {
857 throw new RSRuntimeException("Called BLAS with wrong Element type");
858 }
859 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
860 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
861 }
862
863 int M = A.getType().getY();
864 int N = A.getType().getX();
Miao Wang68ca43e2015-04-23 15:06:09 -0700865 if (incX <= 0 || incY <= 0) {
866 throw new RSRuntimeException("Vector increments must be greater than 0");
867 }
868 int expectedXDim = 1 + (M - 1) * incX;
Tim Murray25207df2015-01-12 16:47:56 -0800869 if (X.getType().getX() != expectedXDim) {
870 throw new RSRuntimeException("Incorrect vector dimensions for GERU");
871 }
872 int expectedYDim = 1 + (N - 1) * incY;
873 if (Y.getType().getX() != expectedYDim) {
874 throw new RSRuntimeException("Incorrect vector dimensions for GERU");
875 }
876
877 }
878
Miao Wang89c3a5f2015-04-23 15:20:11 -0700879 public void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800880 // HEMV is the same as SYR2 validation-wise
881 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
882 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
883 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700884 public void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800885 // HBMV is the same as SYR2 validation-wise
886 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
887 if (K < 0) {
888 throw new RSRuntimeException("K must be 0 or greater for HBMV");
889 }
890 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
891 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700892 public void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800893 // HPMV is the same as SPR2
894 int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
895 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
896 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700897 public void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800898 validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A);
899 int M = A.getType().getY();
900 int N = A.getType().getX();
901 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
902 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700903 public void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800904 // same as GERU
905 validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A);
906 int M = A.getType().getY();
907 int N = A.getType().getX();
908 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
909 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700910 public void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800911 // same as SYR
Miao Wang68ca43e2015-04-23 15:06:09 -0700912 int N = validateSYR(Element.F32_2(mRS), Uplo, X, incX, A);
Tim Murray25207df2015-01-12 16:47:56 -0800913 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
914 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700915 public void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
Tim Murray25207df2015-01-12 16:47:56 -0800916 // equivalent to SPR for validation
917 int N = validateSPR(Element.F32_2(mRS), Uplo, X, incX, Ap);
918 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0);
919 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700920 public void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800921 // same as SYR2
922 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
923 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
924 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700925 public void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
Tim Murray25207df2015-01-12 16:47:56 -0800926 // same as SPR2
927 int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
928 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0);
929 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700930 public void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800931 // HEMV is the same as SYR2 validation-wise
932 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
933 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
934 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700935 public void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800936 // HBMV is the same as SYR2 validation-wise
937 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
938 if (K < 0) {
939 throw new RSRuntimeException("K must be 0 or greater for HBMV");
940 }
941 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
942 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700943 public void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800944 // HPMV is the same as SPR2
945 int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
946 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
947 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700948 public void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800949 validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A);
950 int M = A.getType().getY();
951 int N = A.getType().getX();
952 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
953 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700954 public void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800955 // same as GERU
956 validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A);
957 int M = A.getType().getY();
958 int N = A.getType().getX();
959 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
960 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700961 public void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800962 // same as SYR
Miao Wangcecc00a2015-04-29 18:14:55 -0700963 int N = validateSYR(Element.F64_2(mRS), Uplo, X, incX, A);
Tim Murray25207df2015-01-12 16:47:56 -0800964 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
965 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700966 public void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
Tim Murray25207df2015-01-12 16:47:56 -0800967 // equivalent to SPR for validation
968 int N = validateSPR(Element.F64_2(mRS), Uplo, X, incX, Ap);
969 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0);
970 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700971 public void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800972 // same as SYR2
973 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
974 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
975 }
Miao Wang89c3a5f2015-04-23 15:20:11 -0700976 public void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
Tim Murray25207df2015-01-12 16:47:56 -0800977 // same as SPR2
978 int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
979 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0);
980 }
981
982
983 /**
984 * Level 3 BLAS
985 */
986
987 static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) {
Miao Wang37ae07c2015-04-24 11:19:53 -0700988 int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
Tim Murray25207df2015-01-12 16:47:56 -0800989 if ((A != null && !A.getType().getElement().isCompatible(e)) ||
990 (B != null && !B.getType().getElement().isCompatible(e)) ||
991 (C != null && !C.getType().getElement().isCompatible(e))) {
992 throw new RSRuntimeException("Called BLAS with wrong Element type");
993 }
Miao Wang37ae07c2015-04-24 11:19:53 -0700994 if (C == null) {
995 //since matrix C is used to store the result, it cannot be null.
996 throw new RSRuntimeException("Allocation C cannot be null");
Tim Murray25207df2015-01-12 16:47:56 -0800997 }
Miao Wang37ae07c2015-04-24 11:19:53 -0700998 cM = C.getType().getY();
999 cN = C.getType().getX();
1000
Tim Murray25207df2015-01-12 16:47:56 -08001001 if (Side == RIGHT) {
Miao Wang37ae07c2015-04-24 11:19:53 -07001002 if ((A == null && B != null) || (A != null && B == null)) {
1003 throw new RSRuntimeException("Provided Matrix A without Matrix B, or vice versa");
1004 }
Tim Murray25207df2015-01-12 16:47:56 -08001005 if (B != null) {
Miao Wang37ae07c2015-04-24 11:19:53 -07001006 bM = A.getType().getY();
1007 bN = A.getType().getX();
Tim Murray25207df2015-01-12 16:47:56 -08001008 }
1009 if (A != null) {
Miao Wang37ae07c2015-04-24 11:19:53 -07001010 aM = B.getType().getY();
1011 aN = B.getType().getX();
Tim Murray25207df2015-01-12 16:47:56 -08001012 }
1013 } else {
1014 if (A != null) {
Miao Wang1e940d82015-04-30 10:47:42 -07001015 if (TransA == TRANSPOSE || TransA == CONJ_TRANSPOSE) {
Miao Wang37ae07c2015-04-24 11:19:53 -07001016 aN = A.getType().getY();
1017 aM = A.getType().getX();
Tim Murray25207df2015-01-12 16:47:56 -08001018 } else {
Miao Wang37ae07c2015-04-24 11:19:53 -07001019 aM = A.getType().getY();
1020 aN = A.getType().getX();
Tim Murray25207df2015-01-12 16:47:56 -08001021 }
1022 }
1023 if (B != null) {
Miao Wang1e940d82015-04-30 10:47:42 -07001024 if (TransB == TRANSPOSE || TransB == CONJ_TRANSPOSE) {
Miao Wang37ae07c2015-04-24 11:19:53 -07001025 bN = B.getType().getY();
1026 bM = B.getType().getX();
Tim Murray25207df2015-01-12 16:47:56 -08001027 } else {
Miao Wang37ae07c2015-04-24 11:19:53 -07001028 bM = B.getType().getY();
1029 bN = B.getType().getX();
Tim Murray25207df2015-01-12 16:47:56 -08001030 }
1031 }
1032 }
1033 if (A != null && B != null && C != null) {
Miao Wang37ae07c2015-04-24 11:19:53 -07001034 if (aN != bM || aM != cM || bN != cN) {
Tim Murray25207df2015-01-12 16:47:56 -08001035 throw new RSRuntimeException("Called BLAS with invalid dimensions");
1036 }
1037 } else if (A != null && C != null) {
Miao Wang37ae07c2015-04-24 11:19:53 -07001038 // A and C only, for SYRK
1039 if (cM != cN) {
1040 throw new RSRuntimeException("Matrix C is not symmetric");
1041 }
1042 if (TransA != NO_TRANSPOSE) {
1043 if (aN != cM) {
1044 throw new RSRuntimeException("Called BLAS with invalid dimensions");
1045 }
1046 } else {
1047 if (aM != cM) {
1048 throw new RSRuntimeException("Called BLAS with invalid dimensions");
1049 }
Tim Murray25207df2015-01-12 16:47:56 -08001050 }
1051 } else if (A != null && B != null) {
1052 // A and B only
Miao Wang37ae07c2015-04-24 11:19:53 -07001053 if (aN != bM) {
1054 throw new RSRuntimeException("Called BLAS with invalid dimensions");
1055 }
Tim Murray25207df2015-01-12 16:47:56 -08001056 }
1057
1058 }
1059
1060 public void SGEMM(@Transpose int TransA, @Transpose int TransB, float alpha, Allocation A,
1061 Allocation B, float beta, Allocation C) {
1062 validateTranspose(TransA);
1063 validateTranspose(TransB);
1064 validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C);
1065
1066 int M = -1, N = -1, K = -1;
Miao Wang37ae07c2015-04-24 11:19:53 -07001067 if (TransA != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001068 M = A.getType().getX();
1069 K = A.getType().getY();
1070 } else {
1071 M = A.getType().getY();
1072 K = A.getType().getX();
1073 }
Miao Wang37ae07c2015-04-24 11:19:53 -07001074 if (TransB != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001075 N = B.getType().getY();
1076 } else {
1077 N = B.getType().getX();
1078 }
1079 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS),
1080 beta, C.getID(mRS), 0, 0, 0, 0);
1081 }
1082 public void DGEMM(@Transpose int TransA, @Transpose int TransB, double alpha, Allocation A,
1083 Allocation B, double beta, Allocation C) {
1084 validateTranspose(TransA);
1085 validateTranspose(TransB);
1086 validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C);
1087 int M = -1, N = -1, K = -1;
Miao Wang37ae07c2015-04-24 11:19:53 -07001088 if (TransA != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001089 M = A.getType().getX();
1090 K = A.getType().getY();
1091 } else {
1092 M = A.getType().getY();
1093 K = A.getType().getX();
1094 }
Miao Wang37ae07c2015-04-24 11:19:53 -07001095 if (TransB != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001096 N = B.getType().getY();
1097 } else {
1098 N = B.getType().getX();
1099 }
1100 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS),
1101 beta, C.getID(mRS), 0, 0, 0, 0);
1102 }
1103 public void CGEMM(@Transpose int TransA, @Transpose int TransB, Float2 alpha, Allocation A,
1104 Allocation B, Float2 beta, Allocation C) {
1105 validateTranspose(TransA);
1106 validateTranspose(TransB);
1107 validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C);
1108 int M = -1, N = -1, K = -1;
Miao Wang37ae07c2015-04-24 11:19:53 -07001109 if (TransA != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001110 M = A.getType().getX();
1111 K = A.getType().getY();
1112 } else {
1113 M = A.getType().getY();
1114 K = A.getType().getX();
1115 }
Miao Wang37ae07c2015-04-24 11:19:53 -07001116 if (TransB != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001117 N = B.getType().getY();
1118 } else {
1119 N = B.getType().getX();
1120 }
1121 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1122 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1123 }
1124
1125 public void ZGEMM(@Transpose int TransA, @Transpose int TransB, Double2 alpha, Allocation A,
1126 Allocation B, Double2 beta, Allocation C) {
1127 validateTranspose(TransA);
1128 validateTranspose(TransB);
1129 validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C);
1130 int M = -1, N = -1, K = -1;
Miao Wang37ae07c2015-04-24 11:19:53 -07001131 if (TransA != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001132 M = A.getType().getX();
1133 K = A.getType().getY();
1134 } else {
1135 M = A.getType().getY();
1136 K = A.getType().getX();
1137 }
Miao Wang37ae07c2015-04-24 11:19:53 -07001138 if (TransB != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001139 N = B.getType().getY();
1140 } else {
1141 N = B.getType().getX();
1142 }
1143 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1144 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1145 }
1146
1147 public void SSYMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A,
1148 Allocation B, float beta, Allocation C) {
1149 validateSide(Side);
1150 validateUplo(Uplo);
Miao Wang37ae07c2015-04-24 11:19:53 -07001151 //For SYMM, Matrix A should be symmetric
1152 if (A.getType().getX() != A.getType().getY()) {
1153 throw new RSRuntimeException("Matrix A is not symmetric");
1154 }
Tim Murray25207df2015-01-12 16:47:56 -08001155 validateL3(Element.F32(mRS), 0, 0, Side, A, B, C);
1156 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
1157 beta, C.getID(mRS), 0, 0, 0, 0);
1158 }
1159 public void DSYMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A,
1160 Allocation B, double beta, Allocation C) {
1161 validateSide(Side);
1162 validateUplo(Uplo);
Miao Wang37ae07c2015-04-24 11:19:53 -07001163 if (A.getType().getX() != A.getType().getY()) {
1164 throw new RSRuntimeException("Matrix A is not symmetric");
1165 }
Tim Murray25207df2015-01-12 16:47:56 -08001166 validateL3(Element.F64(mRS), 0, 0, Side, A, B, C);
1167 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
1168 beta, C.getID(mRS), 0, 0, 0, 0);
1169 }
1170 public void CSYMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A,
1171 Allocation B, Float2 beta, Allocation C) {
1172 validateSide(Side);
1173 validateUplo(Uplo);
Miao Wang37ae07c2015-04-24 11:19:53 -07001174 if (A.getType().getX() != A.getType().getY()) {
1175 throw new RSRuntimeException("Matrix A is not symmetric");
1176 }
Tim Murray25207df2015-01-12 16:47:56 -08001177 validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C);
1178 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1179 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1180 }
1181 public void ZSYMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A,
1182 Allocation B, Double2 beta, Allocation C) {
1183 validateSide(Side);
1184 validateUplo(Uplo);
Miao Wang37ae07c2015-04-24 11:19:53 -07001185 if (A.getType().getX() != A.getType().getY()) {
1186 throw new RSRuntimeException("Matrix A is not symmetric");
1187 }
Tim Murray25207df2015-01-12 16:47:56 -08001188 validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C);
1189 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1190 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1191 }
1192
1193 public void SSYRK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) {
1194 validateTranspose(Trans);
1195 validateUplo(Uplo);
1196 validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C);
1197 int K = -1;
Miao Wang37ae07c2015-04-24 11:19:53 -07001198 if (Trans != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001199 K = A.getType().getY();
1200 } else {
1201 K = A.getType().getX();
1202 }
1203
1204 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0);
1205 }
1206
1207 public void DSYRK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) {
1208 validateTranspose(Trans);
1209 validateUplo(Uplo);
1210 validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C);
1211 int K = -1;
Miao Wang37ae07c2015-04-24 11:19:53 -07001212 if (Trans != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001213 K = A.getType().getY();
1214 } else {
1215 K = A.getType().getX();
1216 }
1217 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0);
1218 }
Miao Wang4c472742015-04-22 15:57:57 -07001219 public void CSYRK(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Float2 beta, Allocation C) {
Tim Murray25207df2015-01-12 16:47:56 -08001220 validateTranspose(Trans);
1221 validateUplo(Uplo);
1222 validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C);
1223 int K = -1;
Miao Wang37ae07c2015-04-24 11:19:53 -07001224 if (Trans != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001225 K = A.getType().getY();
1226 } else {
1227 K = A.getType().getX();
1228 }
Miao Wang4c472742015-04-22 15:57:57 -07001229 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), 0, beta.x, beta.y,
Tim Murray25207df2015-01-12 16:47:56 -08001230 C.getID(mRS), 0, 0, 0, 0);
1231 }
Miao Wang4c472742015-04-22 15:57:57 -07001232 public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Double2 beta, Allocation C) {
Tim Murray25207df2015-01-12 16:47:56 -08001233 validateTranspose(Trans);
1234 validateUplo(Uplo);
1235 validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C);
1236 int K = -1;
Miao Wang37ae07c2015-04-24 11:19:53 -07001237 if (Trans != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001238 K = A.getType().getY();
1239 } else {
1240 K = A.getType().getX();
1241 }
Miao Wang4c472742015-04-22 15:57:57 -07001242 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), 0, beta.x, beta.y,
Tim Murray25207df2015-01-12 16:47:56 -08001243 C.getID(mRS), 0, 0, 0, 0);
1244 }
1245
1246 static void validateSYR2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) {
1247 validateTranspose(Trans);
1248 if (!A.getType().getElement().isCompatible(e) ||
1249 !B.getType().getElement().isCompatible(e) ||
1250 !C.getType().getElement().isCompatible(e)) {
1251 throw new RSRuntimeException("Called BLAS with wrong Element type");
1252 }
1253 int Cdim = -1;
1254 // A is n x k if no transpose, k x n if transpose
1255 // C is n x n
1256 if (Trans == TRANSPOSE) {
1257 // check columns versus C
1258 Cdim = A.getType().getX();
1259 } else {
1260 // check rows versus C
1261 Cdim = A.getType().getY();
1262 }
Miao Wang37ae07c2015-04-24 11:19:53 -07001263 if (C.getType().getX() != Cdim || C.getType().getY() != Cdim) {
Tim Murray25207df2015-01-12 16:47:56 -08001264 throw new RSRuntimeException("Invalid symmetric matrix in SYR2K");
1265 }
1266 // A dims == B dims
1267 if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) {
1268 throw new RSRuntimeException("Invalid A and B in SYR2K");
1269 }
1270 }
1271 public void SSYR2K(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, Allocation B, float beta, Allocation C) {
1272 validateUplo(Uplo);
1273 validateSYR2K(Element.F32(mRS), Trans, A, B, C);
1274 int K = -1;
Miao Wang1e940d82015-04-30 10:47:42 -07001275 if (Trans != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001276 K = A.getType().getY();
1277 } else {
1278 K = A.getType().getX();
1279 }
1280 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0);
1281 }
1282 public void DSYR2K(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, Allocation B, double beta, Allocation C) {
1283 validateUplo(Uplo);
1284 validateSYR2K(Element.F64(mRS), Trans, A, B, C);
1285 int K = -1;
Miao Wang1e940d82015-04-30 10:47:42 -07001286 if (Trans != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001287 K = A.getType().getY();
1288 } else {
1289 K = A.getType().getX();
1290 }
1291 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0);
1292 }
1293 public void CSYR2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) {
1294 validateUplo(Uplo);
1295 validateSYR2K(Element.F32_2(mRS), Trans, A, B, C);
1296 int K = -1;
Miao Wang1e940d82015-04-30 10:47:42 -07001297 if (Trans != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001298 K = A.getType().getY();
1299 } else {
1300 K = A.getType().getX();
1301 }
1302 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1303 }
1304 public void ZSYR2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) {
1305 validateUplo(Uplo);
1306 validateSYR2K(Element.F64_2(mRS), Trans, A, B, C);
1307 int K = -1;
Miao Wang1e940d82015-04-30 10:47:42 -07001308 if (Trans != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001309 K = A.getType().getY();
1310 } else {
1311 K = A.getType().getX();
1312 }
1313 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1314 }
1315
1316 static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
1317 validateSide(Side);
1318 validateTranspose(TransA);
Miao Wang37ae07c2015-04-24 11:19:53 -07001319 int aM = -1, aN = -1, bM = -1, bN = -1;
Tim Murray25207df2015-01-12 16:47:56 -08001320 if (!A.getType().getElement().isCompatible(e) ||
1321 !B.getType().getElement().isCompatible(e)) {
1322 throw new RSRuntimeException("Called BLAS with wrong Element type");
1323 }
Miao Wang37ae07c2015-04-24 11:19:53 -07001324
1325 aM = A.getType().getY();
1326 aN = A.getType().getX();
1327 if (aM != aN) {
1328 throw new RSRuntimeException("Called TRMM with a non-symmetric matrix A");
Tim Murray25207df2015-01-12 16:47:56 -08001329 }
Miao Wang37ae07c2015-04-24 11:19:53 -07001330
1331 bM = B.getType().getY();
1332 bN = B.getType().getX();
Tim Murray25207df2015-01-12 16:47:56 -08001333 if (Side == LEFT) {
Miao Wang37ae07c2015-04-24 11:19:53 -07001334 if (aN != bM) {
Tim Murray25207df2015-01-12 16:47:56 -08001335 throw new RSRuntimeException("Called TRMM with invalid matrices");
1336 }
1337 } else {
Miao Wang37ae07c2015-04-24 11:19:53 -07001338 if (bN != aM) {
Tim Murray25207df2015-01-12 16:47:56 -08001339 throw new RSRuntimeException("Called TRMM with invalid matrices");
1340 }
1341 }
1342 }
1343 public void STRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) {
1344 validateUplo(Uplo);
1345 validateDiag(Diag);
1346 validateTRMM(Element.F32(mRS), Side, TransA, A, B);
1347 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1348 alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0);
1349 }
1350 public void DTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) {
1351 validateUplo(Uplo);
1352 validateDiag(Diag);
1353 validateTRMM(Element.F64(mRS), Side, TransA, A, B);
1354 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1355 alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0);
1356 }
1357 public void CTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) {
1358 validateUplo(Uplo);
1359 validateDiag(Diag);
1360 validateTRMM(Element.F32_2(mRS), Side, TransA, A, B);
1361 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1362 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1363 }
1364 public void ZTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) {
1365 validateUplo(Uplo);
1366 validateDiag(Diag);
1367 validateTRMM(Element.F64_2(mRS), Side, TransA, A, B);
1368 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1369 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1370 }
1371
1372 static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
Miao Wang37ae07c2015-04-24 11:19:53 -07001373 int adim = -1, bM = -1, bN = -1;
Tim Murray25207df2015-01-12 16:47:56 -08001374 validateSide(Side);
1375 validateTranspose(TransA);
1376 if (!A.getType().getElement().isCompatible(e) ||
1377 !B.getType().getElement().isCompatible(e)) {
1378 throw new RSRuntimeException("Called BLAS with wrong Element type");
1379 }
1380 adim = A.getType().getX();
1381 if (adim != A.getType().getY()) {
1382 // this may be unnecessary, the restriction could potentially be relaxed
1383 // A needs to contain at least that symmetric matrix but could theoretically be larger
1384 // for now we assume adapters are sufficient, will reevaluate in the future
1385 throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A");
1386 }
Miao Wang37ae07c2015-04-24 11:19:53 -07001387 bM = B.getType().getY();
1388 bN = B.getType().getX();
Tim Murray25207df2015-01-12 16:47:56 -08001389 if (Side == LEFT) {
1390 // A is M*M
Miao Wang37ae07c2015-04-24 11:19:53 -07001391 if (adim != bM) {
Tim Murray25207df2015-01-12 16:47:56 -08001392 throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
1393 }
1394 } else {
1395 // A is N*N
Miao Wang37ae07c2015-04-24 11:19:53 -07001396 if (adim != bN) {
Tim Murray25207df2015-01-12 16:47:56 -08001397 throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
1398 }
1399 }
1400 }
1401 public void STRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) {
1402 validateUplo(Uplo);
1403 validateDiag(Diag);
1404 validateTRSM(Element.F32(mRS), Side, TransA, A, B);
1405 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1406 alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0);
1407 }
1408 public void DTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) {
1409 validateUplo(Uplo);
1410 validateDiag(Diag);
1411 validateTRSM(Element.F64(mRS), Side, TransA, A, B);
1412 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1413 alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0);
1414 }
1415 public void CTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) {
1416 validateUplo(Uplo);
1417 validateDiag(Diag);
1418 validateTRSM(Element.F32_2(mRS), Side, TransA, A, B);
1419 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1420 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1421 }
1422 public void ZTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) {
1423 validateUplo(Uplo);
1424 validateDiag(Diag);
1425 validateTRSM(Element.F64_2(mRS), Side, TransA, A, B);
1426 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1427 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1428 }
1429
1430 static void validateHEMM(Element e, @Side int Side, Allocation A, Allocation B, Allocation C) {
1431 validateSide(Side);
1432
1433 if (!A.getType().getElement().isCompatible(e) ||
1434 !B.getType().getElement().isCompatible(e) ||
1435 !C.getType().getElement().isCompatible(e)) {
1436 throw new RSRuntimeException("Called BLAS with wrong Element type");
1437 }
1438
1439 // A must be square; can potentially be relaxed similar to TRSM
1440 int adim = A.getType().getX();
1441 if (adim != A.getType().getY()) {
1442 throw new RSRuntimeException("Called HEMM with non-square A");
1443 }
1444 if ((Side == LEFT && adim != B.getType().getY()) ||
1445 (Side == RIGHT && adim != B.getType().getX())) {
1446 throw new RSRuntimeException("Called HEMM with invalid B");
1447 }
1448 if (B.getType().getX() != C.getType().getX() ||
1449 B.getType().getY() != C.getType().getY()) {
1450 throw new RSRuntimeException("Called HEMM with mismatched B and C");
1451 }
1452 }
Miao Wang4c472742015-04-22 15:57:57 -07001453 public void CHEMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) {
Tim Murray25207df2015-01-12 16:47:56 -08001454 validateUplo(Uplo);
1455 validateHEMM(Element.F32_2(mRS), Side, A, B, C);
1456 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
Miao Wang4c472742015-04-22 15:57:57 -07001457 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
Tim Murray25207df2015-01-12 16:47:56 -08001458 }
Miao Wang4c472742015-04-22 15:57:57 -07001459 public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) {
Tim Murray25207df2015-01-12 16:47:56 -08001460 validateUplo(Uplo);
Miao Wang37ae07c2015-04-24 11:19:53 -07001461 validateHEMM(Element.F64_2(mRS), Side, A, B, C);
Tim Murray25207df2015-01-12 16:47:56 -08001462 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
Miao Wang4c472742015-04-22 15:57:57 -07001463 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
Tim Murray25207df2015-01-12 16:47:56 -08001464 }
1465
1466 static void validateHERK(Element e, @Transpose int Trans, Allocation A, Allocation C) {
1467 if (!A.getType().getElement().isCompatible(e) ||
1468 !C.getType().getElement().isCompatible(e)) {
1469 throw new RSRuntimeException("Called BLAS with wrong Element type");
1470 }
1471 validateConjTranspose(Trans);
1472 int cdim = C.getType().getX();
1473 if (cdim != C.getType().getY()) {
1474 throw new RSRuntimeException("Called HERK with non-square C");
1475 }
1476 if (Trans == NO_TRANSPOSE) {
Miao Wang37ae07c2015-04-24 11:19:53 -07001477 if (cdim != A.getType().getY()) {
Tim Murray25207df2015-01-12 16:47:56 -08001478 throw new RSRuntimeException("Called HERK with invalid A");
1479 }
1480 } else {
Miao Wang37ae07c2015-04-24 11:19:53 -07001481 if (cdim != A.getType().getX()) {
Tim Murray25207df2015-01-12 16:47:56 -08001482 throw new RSRuntimeException("Called HERK with invalid A");
1483 }
1484 }
1485 }
1486 public void CHERK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) {
1487 validateUplo(Uplo);
1488 validateHERK(Element.F32_2(mRS), Trans, A, C);
1489 int k = 0;
Miao Wang37ae07c2015-04-24 11:19:53 -07001490 if (Trans == CONJ_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001491 k = A.getType().getY();
1492 } else {
1493 k = A.getType().getX();
1494 }
1495 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k,
1496 alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0);
1497 }
1498 public void ZHERK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) {
1499 validateUplo(Uplo);
1500 validateHERK(Element.F64_2(mRS), Trans, A, C);
1501 int k = 0;
Miao Wang37ae07c2015-04-24 11:19:53 -07001502 if (Trans == CONJ_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001503 k = A.getType().getY();
1504 } else {
1505 k = A.getType().getX();
1506 }
1507 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k,
1508 alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0);
1509 }
1510
1511 static void validateHER2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) {
1512 if (!A.getType().getElement().isCompatible(e) ||
1513 !B.getType().getElement().isCompatible(e) ||
1514 !C.getType().getElement().isCompatible(e)) {
1515 throw new RSRuntimeException("Called BLAS with wrong Element type");
1516 }
1517 validateConjTranspose(Trans);
1518 int cdim = C.getType().getX();
1519 if (cdim != C.getType().getY()) {
1520 throw new RSRuntimeException("Called HER2K with non-square C");
1521 }
1522 if (Trans == NO_TRANSPOSE) {
1523 if (A.getType().getY() != cdim) {
1524 throw new RSRuntimeException("Called HER2K with invalid matrices");
1525 }
1526 } else {
1527 if (A.getType().getX() != cdim) {
1528 throw new RSRuntimeException("Called HER2K with invalid matrices");
1529 }
1530 }
1531 if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) {
1532 throw new RSRuntimeException("Called HER2K with invalid A and B matrices");
1533 }
1534 }
1535 public void CHER2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, float beta, Allocation C) {
1536 validateUplo(Uplo);
1537 validateHER2K(Element.F32_2(mRS), Trans, A, B, C);
1538 int k = 0;
1539 if (Trans == NO_TRANSPOSE) {
1540 k = A.getType().getX();
1541 } else {
1542 k = A.getType().getY();
1543 }
1544 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y,
1545 A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
1546 }
1547 public void ZHER2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, double beta, Allocation C) {
1548 validateUplo(Uplo);
1549 validateHER2K(Element.F64_2(mRS), Trans, A, B, C);
1550 int k = 0;
1551 if (Trans == NO_TRANSPOSE) {
1552 k = A.getType().getX();
1553 } else {
1554 k = A.getType().getY();
1555 }
1556 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y,
1557 A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
1558 }
1559
1560
Tim Murray9cb16a22015-04-01 11:07:16 -07001561 /**
1562 *
1563 * 8-bit GEMM-like operation for neural networks
1564 *
Tim Murray9cb16a22015-04-01 11:07:16 -07001565 **/
1566 public void BNNM(Allocation A, int a_offset, Allocation B, int b_offset, Allocation C, int c_offset, int c_mult) {
1567 validateL3(Element.U8(mRS), NO_TRANSPOSE, TRANSPOSE, 0, A, B, C);
1568
1569 int M = -1, N = -1, K = -1;
1570 M = A.getType().getY();
1571 N = B.getType().getY();
1572 K = A.getType().getX();
1573
1574
1575 mRS.nScriptIntrinsicBLAS_BNNM(getID(mRS), M, N, K, A.getID(mRS), a_offset, B.getID(mRS), b_offset, C.getID(mRS), c_offset, c_mult);
1576
1577 }
Tim Murray25207df2015-01-12 16:47:56 -08001578
1579}