blob: 16b703356b47796e7e9671d62c9f2ec4ff41dd58 [file] [log] [blame]
Tim Murray25207df2015-01-12 16:47:56 -08001/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.renderscript;
18
19import android.annotation.IntDef;
20import java.lang.annotation.Retention;
21import java.lang.annotation.RetentionPolicy;
22
/**
 * Intrinsic exposing BLAS (Basic Linear Algebra Subprograms) routines that are
 * executed by the RenderScript runtime.
 *
 * @hide
 **/
29public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
30 private Allocation mLUT;
31
32 private ScriptIntrinsicBLAS(long id, RenderScript rs) {
33 super(id, rs);
34 }
35
36 private static final int RsBlas_sdsdot = 1;
37 private static final int RsBlas_dsdot = 2;
38 private static final int RsBlas_sdot = 3;
39 private static final int RsBlas_ddot = 4;
40 private static final int RsBlas_cdotu_sub = 5;
41 private static final int RsBlas_cdotc_sub = 6;
42 private static final int RsBlas_zdotu_sub = 7;
43 private static final int RsBlas_zdotc_sub = 8;
44 private static final int RsBlas_snrm2 = 9;
45 private static final int RsBlas_sasum = 10;
46 private static final int RsBlas_dnrm2 = 11;
47 private static final int RsBlas_dasum = 12;
48 private static final int RsBlas_scnrm2 = 13;
49 private static final int RsBlas_scasum = 14;
50 private static final int RsBlas_dznrm2 = 15;
51 private static final int RsBlas_dzasum = 16;
52 private static final int RsBlas_isamax = 17;
53 private static final int RsBlas_idamax = 18;
54 private static final int RsBlas_icamax = 19;
55 private static final int RsBlas_izamax = 20;
56 private static final int RsBlas_sswap = 21;
57 private static final int RsBlas_scopy = 22;
58 private static final int RsBlas_saxpy = 23;
59 private static final int RsBlas_dswap = 24;
60 private static final int RsBlas_dcopy = 25;
61 private static final int RsBlas_daxpy = 26;
62 private static final int RsBlas_cswap = 27;
63 private static final int RsBlas_ccopy = 28;
64 private static final int RsBlas_caxpy = 29;
65 private static final int RsBlas_zswap = 30;
66 private static final int RsBlas_zcopy = 31;
67 private static final int RsBlas_zaxpy = 32;
68 private static final int RsBlas_srotg = 33;
69 private static final int RsBlas_srotmg = 34;
70 private static final int RsBlas_srot = 35;
71 private static final int RsBlas_srotm = 36;
72 private static final int RsBlas_drotg = 37;
73 private static final int RsBlas_drotmg = 38;
74 private static final int RsBlas_drot = 39;
75 private static final int RsBlas_drotm = 40;
76 private static final int RsBlas_sscal = 41;
77 private static final int RsBlas_dscal = 42;
78 private static final int RsBlas_cscal = 43;
79 private static final int RsBlas_zscal = 44;
80 private static final int RsBlas_csscal = 45;
81 private static final int RsBlas_zdscal = 46;
82 private static final int RsBlas_sgemv = 47;
83 private static final int RsBlas_sgbmv = 48;
84 private static final int RsBlas_strmv = 49;
85 private static final int RsBlas_stbmv = 50;
86 private static final int RsBlas_stpmv = 51;
87 private static final int RsBlas_strsv = 52;
88 private static final int RsBlas_stbsv = 53;
89 private static final int RsBlas_stpsv = 54;
90 private static final int RsBlas_dgemv = 55;
91 private static final int RsBlas_dgbmv = 56;
92 private static final int RsBlas_dtrmv = 57;
93 private static final int RsBlas_dtbmv = 58;
94 private static final int RsBlas_dtpmv = 59;
95 private static final int RsBlas_dtrsv = 60;
96 private static final int RsBlas_dtbsv = 61;
97 private static final int RsBlas_dtpsv = 62;
98 private static final int RsBlas_cgemv = 63;
99 private static final int RsBlas_cgbmv = 64;
100 private static final int RsBlas_ctrmv = 65;
101 private static final int RsBlas_ctbmv = 66;
102 private static final int RsBlas_ctpmv = 67;
103 private static final int RsBlas_ctrsv = 68;
104 private static final int RsBlas_ctbsv = 69;
105 private static final int RsBlas_ctpsv = 70;
106 private static final int RsBlas_zgemv = 71;
107 private static final int RsBlas_zgbmv = 72;
108 private static final int RsBlas_ztrmv = 73;
109 private static final int RsBlas_ztbmv = 74;
110 private static final int RsBlas_ztpmv = 75;
111 private static final int RsBlas_ztrsv = 76;
112 private static final int RsBlas_ztbsv = 77;
113 private static final int RsBlas_ztpsv = 78;
114 private static final int RsBlas_ssymv = 79;
115 private static final int RsBlas_ssbmv = 80;
116 private static final int RsBlas_sspmv = 81;
117 private static final int RsBlas_sger = 82;
118 private static final int RsBlas_ssyr = 83;
119 private static final int RsBlas_sspr = 84;
120 private static final int RsBlas_ssyr2 = 85;
121 private static final int RsBlas_sspr2 = 86;
122 private static final int RsBlas_dsymv = 87;
123 private static final int RsBlas_dsbmv = 88;
124 private static final int RsBlas_dspmv = 89;
125 private static final int RsBlas_dger = 90;
126 private static final int RsBlas_dsyr = 91;
127 private static final int RsBlas_dspr = 92;
128 private static final int RsBlas_dsyr2 = 93;
129 private static final int RsBlas_dspr2 = 94;
130 private static final int RsBlas_chemv = 95;
131 private static final int RsBlas_chbmv = 96;
132 private static final int RsBlas_chpmv = 97;
133 private static final int RsBlas_cgeru = 98;
134 private static final int RsBlas_cgerc = 99;
135 private static final int RsBlas_cher = 100;
136 private static final int RsBlas_chpr = 101;
137 private static final int RsBlas_cher2 = 102;
138 private static final int RsBlas_chpr2 = 103;
139 private static final int RsBlas_zhemv = 104;
140 private static final int RsBlas_zhbmv = 105;
141 private static final int RsBlas_zhpmv = 106;
142 private static final int RsBlas_zgeru = 107;
143 private static final int RsBlas_zgerc = 108;
144 private static final int RsBlas_zher = 109;
145 private static final int RsBlas_zhpr = 110;
146 private static final int RsBlas_zher2 = 111;
147 private static final int RsBlas_zhpr2 = 112;
148 private static final int RsBlas_sgemm = 113;
149 private static final int RsBlas_ssymm = 114;
150 private static final int RsBlas_ssyrk = 115;
151 private static final int RsBlas_ssyr2k = 116;
152 private static final int RsBlas_strmm = 117;
153 private static final int RsBlas_strsm = 118;
154 private static final int RsBlas_dgemm = 119;
155 private static final int RsBlas_dsymm = 120;
156 private static final int RsBlas_dsyrk = 121;
157 private static final int RsBlas_dsyr2k = 122;
158 private static final int RsBlas_dtrmm = 123;
159 private static final int RsBlas_dtrsm = 124;
160 private static final int RsBlas_cgemm = 125;
161 private static final int RsBlas_csymm = 126;
162 private static final int RsBlas_csyrk = 127;
163 private static final int RsBlas_csyr2k = 128;
164 private static final int RsBlas_ctrmm = 129;
165 private static final int RsBlas_ctrsm = 130;
166 private static final int RsBlas_zgemm = 131;
167 private static final int RsBlas_zsymm = 132;
168 private static final int RsBlas_zsyrk = 133;
169 private static final int RsBlas_zsyr2k = 134;
170 private static final int RsBlas_ztrmm = 135;
171 private static final int RsBlas_ztrsm = 136;
172 private static final int RsBlas_chemm = 137;
173 private static final int RsBlas_cherk = 138;
174 private static final int RsBlas_cher2k = 139;
175 private static final int RsBlas_zhemm = 140;
176 private static final int RsBlas_zherk = 141;
177 private static final int RsBlas_zher2k = 142;
178
Tim Murray9cb16a22015-04-01 11:07:16 -0700179 // BLAS extensions start here
180 private static final int RsBlas_bnnm = 1000;
181
Tim Murray25207df2015-01-12 16:47:56 -0800182 /**
183 */
184 public static ScriptIntrinsicBLAS create(RenderScript rs) {
185 long id = rs.nScriptIntrinsicCreate(13, Element.U32(rs).getID(rs));
186 return new ScriptIntrinsicBLAS(id, rs);
187 }
188
189 @IntDef({NO_TRANSPOSE, TRANSPOSE, CONJ_TRANSPOSE})
190 @Retention(RetentionPolicy.SOURCE)
191 public @interface Transpose {}
192
193 @IntDef({UPPER, LOWER})
194 @Retention(RetentionPolicy.SOURCE)
195 public @interface Uplo {}
196
197 @IntDef({NON_UNIT, UNIT})
198 @Retention(RetentionPolicy.SOURCE)
199 public @interface Diag {}
200
201 @IntDef({LEFT, RIGHT})
202 @Retention(RetentionPolicy.SOURCE)
203 public @interface Side {}
204
205 public static final int NO_TRANSPOSE = 111;
206 public static final int TRANSPOSE = 112;
207 public static final int CONJ_TRANSPOSE = 113;
208
209 public static final int UPPER = 121;
210 public static final int LOWER = 122;
211
212 public static final int NON_UNIT = 131;
213 public static final int UNIT = 132;
214
215 public static final int LEFT = 141;
216 public static final int RIGHT = 142;
217
218 static void validateSide(@Side int Side) {
219 if (Side != LEFT && Side != RIGHT) {
220 throw new RSRuntimeException("Invalid side passed to BLAS");
221 }
222 }
223
224 static void validateTranspose(@Transpose int Trans) {
225 if (Trans != NO_TRANSPOSE && Trans != TRANSPOSE &&
226 Trans != CONJ_TRANSPOSE) {
227 throw new RSRuntimeException("Invalid transpose passed to BLAS");
228 }
229 }
230
231 static void validateConjTranspose(@Transpose int Trans) {
232 if (Trans != NO_TRANSPOSE &&
233 Trans != CONJ_TRANSPOSE) {
234 throw new RSRuntimeException("Invalid transpose passed to BLAS");
235 }
236 }
237
238 static void validateDiag(@Diag int Diag) {
239 if (Diag != NON_UNIT && Diag != UNIT) {
240 throw new RSRuntimeException("Invalid diag passed to BLAS");
241 }
242 }
243
244 static void validateUplo(@Uplo int Uplo) {
245 if (Uplo != LEFT && Uplo != RIGHT) {
246 throw new RSRuntimeException("Invalid uplo passed to BLAS");
247 }
248 }
249
250
251 /**
252 * Level 2 BLAS
253 */
254
255 static void validateGEMV(Element e, int TransA, Allocation A, Allocation X, int incX, Allocation Y, int incY) {
256 validateTranspose(TransA);
257 int M = A.getType().getY();
258 int N = A.getType().getX();
259 if (!A.getType().getElement().isCompatible(e) ||
260 !X.getType().getElement().isCompatible(e) ||
261 !Y.getType().getElement().isCompatible(e)) {
262 throw new RSRuntimeException("Called BLAS with wrong Element type");
263 }
264 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
265 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
266 }
267
268 if (incX <= 0 || incY <= 0) {
269 throw new RSRuntimeException("Vector increments must be greater than 0");
270 }
271 int expectedXDim = -1, expectedYDim = -1;
272 if (TransA == NO_TRANSPOSE) {
273 expectedXDim = 1 + (N - 1) * incX;
274 expectedYDim = 1 + (M - 1) * incY;
275 } else {
276 expectedXDim = 1 + (M - 1) * incX;
277 expectedYDim = 1 + (N - 1) * incY;
278 }
279 if (X.getType().getX() != expectedXDim ||
280 Y.getType().getY() != expectedXDim) {
281 throw new RSRuntimeException("Incorrect vector dimensions for GEMV");
282 }
283 }
284 void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
285 validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY);
286 int M = A.getType().getY();
287 int N = A.getType().getX();
288 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
289 }
290 void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
291 validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY);
292 int M = A.getType().getY();
293 int N = A.getType().getX();
294 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
295 }
296 void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
297 validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY);
298 int M = A.getType().getY();
299 int N = A.getType().getX();
300 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
301 }
302 void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
303 validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY);
304 int M = A.getType().getY();
305 int N = A.getType().getX();
306 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
307 }
308
309 void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
310 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
311 validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY);
312 if (KL < 0 || KU < 0) {
313 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
314 }
315 int M = A.getType().getY();
316 int N = A.getType().getX();
317 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU);
318 }
319 void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
320 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
321 validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY);
322 if (KL < 0 || KU < 0) {
323 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
324 }
325 int M = A.getType().getY();
326 int N = A.getType().getX();
327 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU);
328 }
329 void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
330 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
331 validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY);
332 if (KL < 0 || KU < 0) {
333 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
334 }
335 int M = A.getType().getY();
336 int N = A.getType().getX();
337 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
338 }
339 void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
340 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
341 validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY);
342 if (KL < 0 || KU < 0) {
343 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
344 }
345 int M = A.getType().getY();
346 int N = A.getType().getX();
347 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
348 }
349
350 static void validateTRMV(Element e, @Transpose int TransA, Allocation A, Allocation X, int incX) {
351 validateTranspose(TransA);
352 int N = A.getType().getY();
353 if (A.getType().getX() != N) {
354 throw new RSRuntimeException("A must be a square matrix for TRMV");
355 }
356 if (!A.getType().getElement().isCompatible(e) ||
357 !X.getType().getElement().isCompatible(e)) {
358 throw new RSRuntimeException("Called BLAS with wrong Element type");
359 }
360 if (X.getType().getY() > 1) {
361 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
362 }
363
364 if (incX <= 0) {
365 throw new RSRuntimeException("Vector increments must be greater than 0");
366 }
367 int expectedXDim = 1 + (N - 1) * incX;
368 if (X.getType().getX() != expectedXDim) {
369 throw new RSRuntimeException("Incorrect vector dimensions for TRMV");
370 }
371 }
372
373 static int validateTPMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
374 validateTranspose(TransA);
375 validateUplo(Uplo);
376 validateDiag(Diag);
377 if (!Ap.getType().getElement().isCompatible(e) ||
378 !X.getType().getElement().isCompatible(e)) {
379 throw new RSRuntimeException("Called BLAS with wrong Element type");
380 }
381 if (X.getType().getY() > 1) {
382 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
383 }
384
385 if (Ap.getType().getY() > 1) {
386 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
387 }
388
389 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
390 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
391 throw new RSRuntimeException("Invalid dimension for Ap");
392 }
393
394 int expectedXDim = 1 + (N - 1) * incX;
395 if (X.getType().getX() != expectedXDim) {
396 throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
397 }
398
399 return N;
400 }
401
402 void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
403 validateTRMV(Element.F32(mRS), TransA, A, X, incX);
404 int N = A.getType().getY();
405 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
406 }
407 void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
408 validateTRMV(Element.F64(mRS), TransA, A, X, incX);
409 int N = A.getType().getY();
410 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
411 }
412 void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
413 validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
414 int N = A.getType().getY();
415 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
416 }
417 void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
418 validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
419 int N = A.getType().getY();
420 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
421 }
422 void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
423 // TBMV has the same requirements as TRMV
424 validateTRMV(Element.F32(mRS), TransA, A, X, incX);
425 int N = A.getType().getY();
426 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
427 }
428 void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
429 // TBMV has the same requirements as TRMV
430 validateTRMV(Element.F64(mRS), TransA, A, X, incX);
431 int N = A.getType().getY();
432 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
433 }
434 void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
435 // TBMV has the same requirements as TRMV
436 validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
437 int N = A.getType().getY();
438 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
439 }
440 void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
441 // TBMV has the same requirements as TRMV
442 validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
443 int N = A.getType().getY();
444 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
445 }
446 void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
447 int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
448 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
449 }
450 void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
451 int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
452 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
453 }
454 void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
455 int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
456 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
457 }
458 void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
459 int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
460 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
461 }
462 void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
463 // TRSV is the same as TRMV
464 validateTRMV(Element.F32(mRS), TransA, A, X, incX);
465 int N = A.getType().getY();
466 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
467
468 }
469 void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
470 // TRSV is the same as TRMV
471 validateTRMV(Element.F64(mRS), TransA, A, X, incX);
472 int N = A.getType().getY();
473 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
474
475 }
476 void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
477 // TRSV is the same as TRMV
478 validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
479 int N = A.getType().getY();
480 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
481
482 }
483 void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
484 // TRSV is the same as TRMV
485 validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
486 int N = A.getType().getY();
487 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
488
489 }
490 void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
491 // TBSV is the same as TRMV
492 validateTRMV(Element.F32(mRS), TransA, A, X, incX);
493 int N = A.getType().getY();
494 if (K < 0) {
495 throw new RSRuntimeException("Number of diagonals must be positive");
496 }
497 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
498 }
499 void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
500 // TBSV is the same as TRMV
501 validateTRMV(Element.F64(mRS), TransA, A, X, incX);
502 int N = A.getType().getY();
503 if (K < 0) {
504 throw new RSRuntimeException("Number of diagonals must be positive");
505 }
506 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
507 }
508 void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
509 // TBSV is the same as TRMV
510 validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
511 int N = A.getType().getY();
512 if (K < 0) {
513 throw new RSRuntimeException("Number of diagonals must be positive");
514 }
515 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
516 }
517 void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
518 // TBSV is the same as TRMV
519 validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
520 int N = A.getType().getY();
521 if (K < 0) {
522 throw new RSRuntimeException("Number of diagonals must be positive");
523 }
524 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
525 }
526 void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
527 // TPSV is same as TPMV
528 int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
529 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
530 }
531 void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
532 // TPSV is same as TPMV
533 int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
534 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
535 }
536 void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
537 // TPSV is same as TPMV
538 int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
539 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
540 }
541 void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
542 // TPSV is same as TPMV
543 int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
544 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
545 }
546
547 /**
548 * Level 2, S and D only
549 */
550 static int validateSYMV(Element e, @Uplo int Uplo, Allocation A, Allocation X, Allocation Y, int incX, int incY) {
551 validateUplo(Uplo);
552 int N = A.getType().getY();
553 if (A.getType().getX() != N) {
554 throw new RSRuntimeException("A must be a square matrix for SYMV");
555 }
556 if (!A.getType().getElement().isCompatible(e) ||
557 !X.getType().getElement().isCompatible(e) ||
558 !Y.getType().getElement().isCompatible(e) ) {
559 throw new RSRuntimeException("Called BLAS with wrong Element type");
560 }
561 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
562 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
563 }
564
565 if (incX <= 0 || incY <= 0) {
566 throw new RSRuntimeException("Vector increments must be greater than 0");
567 }
568 int expectedXDim = 1 + (N - 1) * incX;
569 if (X.getType().getX() != expectedXDim) {
570 throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
571 }
572 int expectedYDim = 1 + (N - 1) * incY;
573 if (Y.getType().getX() != expectedYDim) {
574 throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
575 }
576 return N;
577 }
578 static int validateSPMV(Element e, @Uplo int Uplo, Allocation Ap, Allocation X, int incX, Allocation Y, int incY) {
579 validateUplo(Uplo);
580 if (!Ap.getType().getElement().isCompatible(e) ||
581 !X.getType().getElement().isCompatible(e) ||
582 !Y.getType().getElement().isCompatible(e)) {
583 throw new RSRuntimeException("Called BLAS with wrong Element type");
584 }
585 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
586 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
587 }
588
589 if (Ap.getType().getY() > 1) {
590 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
591 }
592
593 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
594 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
595 throw new RSRuntimeException("Invalid dimension for Ap");
596 }
597
598 int expectedXDim = 1 + (N - 1) * incX;
599 if (X.getType().getX() != expectedXDim) {
600 throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
601 }
602 int expectedYDim = 1 + (N - 1) * incY;
603 if (Y.getType().getX() != expectedYDim) {
604 throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
605 }
606
607 return N;
608 }
609 static void validateGER(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
610 if (!A.getType().getElement().isCompatible(e) ||
611 !X.getType().getElement().isCompatible(e) ||
612 !Y.getType().getElement().isCompatible(e) ) {
613 throw new RSRuntimeException("Called BLAS with wrong Element type");
614 }
615
616 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
617 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
618 }
619
620 int M = A.getType().getY();
621 int N = A.getType().getX();
622
623 if (N < 1 || M < 1) {
624 throw new RSRuntimeException("M and N must be 1 or greater for GER");
625 }
626
627 int expectedXDim = 1 + (N - 1) * incX;
628 if (X.getType().getX() != expectedXDim) {
629 throw new RSRuntimeException("Incorrect vector dimensions for GER");
630 }
631 int expectedYDim = 1 + (N - 1) * incY;
632 if (Y.getType().getX() != expectedYDim) {
633 throw new RSRuntimeException("Incorrect vector dimensions for GER");
634 }
635
636
637 }
638 static int validateSYR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation A) {
639 validateUplo(Uplo);
640 if (!A.getType().getElement().isCompatible(e) ||
641 !X.getType().getElement().isCompatible(e)) {
642 throw new RSRuntimeException("Called BLAS with wrong Element type");
643 }
644
645 int N = A.getType().getX();
646
647 if (X.getType().getY() > 1) {
648 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
649 }
650 if (N != A.getType().getY()) {
651 throw new RSRuntimeException("A must be a symmetric matrix");
652 }
653
654 int expectedXDim = 1 + (N - 1) * incX;
655 if (X.getType().getX() != expectedXDim) {
656 throw new RSRuntimeException("Incorrect vector dimensions for SYR");
657 }
658 return N;
659 }
660 static int validateSPR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Ap) {
661 validateUplo(Uplo);
662 if (!Ap.getType().getElement().isCompatible(e) ||
663 !X.getType().getElement().isCompatible(e)) {
664 throw new RSRuntimeException("Called BLAS with wrong Element type");
665 }
666 if (X.getType().getY() > 1) {
667 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
668 }
669
670 if (Ap.getType().getY() > 1) {
671 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
672 }
673
674 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
675 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
676 throw new RSRuntimeException("Invalid dimension for Ap");
677 }
678
679 int expectedXDim = 1 + (N - 1) * incX;
680 if (X.getType().getX() != expectedXDim) {
681 throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
682 }
683
684 return N;
685 }
686
687 static int validateSYR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
688 validateUplo(Uplo);
689 if (!A.getType().getElement().isCompatible(e) ||
690 !X.getType().getElement().isCompatible(e) ||
691 !Y.getType().getElement().isCompatible(e)) {
692 throw new RSRuntimeException("Called BLAS with wrong Element type");
693 }
694
695 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
696 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
697 }
698
699 int N = A.getType().getX();
700
701 if (N != A.getType().getY()) {
702 throw new RSRuntimeException("A must be a symmetric matrix");
703 }
704
705 int expectedXDim = 1 + (N - 1) * incX;
706 int expectedYDim = 1 + (N - 1) * incY;
707 if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
708 throw new RSRuntimeException("Incorrect vector dimensions for SYR");
709 }
710 return N;
711
712 }
713 static int validateSPR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
714 validateUplo(Uplo);
715 if (!Ap.getType().getElement().isCompatible(e) ||
716 !X.getType().getElement().isCompatible(e) ||
717 !Y.getType().getElement().isCompatible(e)) {
718 throw new RSRuntimeException("Called BLAS with wrong Element type");
719 }
720 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
721 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
722 }
723
724 if (Ap.getType().getY() > 1) {
725 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
726 }
727
728 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
729 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
730 throw new RSRuntimeException("Invalid dimension for Ap");
731 }
732
733 int expectedXDim = 1 + (N - 1) * incX;
734 int expectedYDim = 1 + (N - 1) * incY;
735 if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
736 throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
737 }
738
739 return N;
740 }
741
742 void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
743 int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
744 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
745 }
746 void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
747 // SBMV is the same as SYMV
748 int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
749 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
750 }
751 void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) {
752 int N = validateSPMV(Element.F32(mRS), Uplo, Ap, X, incX, Y, incY);
753 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
754 }
755 void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
756 int M = A.getType().getY();
757 int N = A.getType().getX();
758 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
759 }
760 void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
761 int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A);
762 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0);
763 }
764 void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
765 int N = validateSPR(Element.F32(mRS), Uplo, X, incX, Ap);
766 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0);
767 }
768 void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
769 int N = validateSYR2(Element.F32(mRS), Uplo, X, incX, Y, incY, A);
770 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0);
771 }
772 void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
773 int N = validateSPR2(Element.F32(mRS), Uplo, X, incX, Y, incY, Ap);
774 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0);
775 }
776 void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
777 int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
778 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
779 }
780 void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
781 // SBMV is the same as SYMV
782 int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
783 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
784 }
785 void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) {
786 int N = validateSPMV(Element.F64(mRS), Uplo, Ap, X, incX, Y, incY);
787 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
788 }
789 void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
790 int M = A.getType().getY();
791 int N = A.getType().getX();
792 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
793 }
794 void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
795 int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A);
796 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0);
797 }
798 void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
799 int N = validateSPR(Element.F64(mRS), Uplo, X, incX, Ap);
800 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0);
801 }
802 void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
803 int N = validateSYR2(Element.F64(mRS), Uplo, X, incX, Y, incY, A);
804 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0);
805 }
806 void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
807 int N = validateSPR2(Element.F64(mRS), Uplo, X, incX, Y, incY, Ap);
808 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0);
809 }
810
811
812 /**
813 * Level 2, C and Z only
814 */
815
816 static void validateGERU(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
817 if (!A.getType().getElement().isCompatible(e) ||
818 !X.getType().getElement().isCompatible(e) ||
819 !Y.getType().getElement().isCompatible(e)) {
820 throw new RSRuntimeException("Called BLAS with wrong Element type");
821 }
822 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
823 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
824 }
825
826 int M = A.getType().getY();
827 int N = A.getType().getX();
828
829 int expectedXDim = 1 + (N - 1) * incX;
830 if (X.getType().getX() != expectedXDim) {
831 throw new RSRuntimeException("Incorrect vector dimensions for GERU");
832 }
833 int expectedYDim = 1 + (N - 1) * incY;
834 if (Y.getType().getX() != expectedYDim) {
835 throw new RSRuntimeException("Incorrect vector dimensions for GERU");
836 }
837
838 }
839
840 void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
841 // HEMV is the same as SYR2 validation-wise
842 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
843 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
844 }
845 void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
846 // HBMV is the same as SYR2 validation-wise
847 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
848 if (K < 0) {
849 throw new RSRuntimeException("K must be 0 or greater for HBMV");
850 }
851 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
852 }
853 void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
854 // HPMV is the same as SPR2
855 int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
856 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
857 }
858 void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
859 validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A);
860 int M = A.getType().getY();
861 int N = A.getType().getX();
862 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
863 }
864 void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
865 // same as GERU
866 validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A);
867 int M = A.getType().getY();
868 int N = A.getType().getX();
869 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
870 }
871 void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
872 // same as SYR
873 int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A);
874 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
875 }
876 void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
877 // equivalent to SPR for validation
878 int N = validateSPR(Element.F32_2(mRS), Uplo, X, incX, Ap);
879 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0);
880 }
881 void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
882 // same as SYR2
883 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
884 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
885 }
886 void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
887 // same as SPR2
888 int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
889 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0);
890 }
891 void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
892 // HEMV is the same as SYR2 validation-wise
893 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
894 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
895 }
896 void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
897 // HBMV is the same as SYR2 validation-wise
898 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
899 if (K < 0) {
900 throw new RSRuntimeException("K must be 0 or greater for HBMV");
901 }
902 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
903 }
904 void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
905 // HPMV is the same as SPR2
906 int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
907 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
908 }
909 void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
910 validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A);
911 int M = A.getType().getY();
912 int N = A.getType().getX();
913 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
914 }
915 void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
916 // same as GERU
917 validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A);
918 int M = A.getType().getY();
919 int N = A.getType().getX();
920 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
921 }
922 void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
923 // same as SYR
924 int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A);
925 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
926 }
927 void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
928 // equivalent to SPR for validation
929 int N = validateSPR(Element.F64_2(mRS), Uplo, X, incX, Ap);
930 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0);
931 }
932 void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
933 // same as SYR2
934 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
935 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
936 }
937 void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
938 // same as SPR2
939 int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
940 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0);
941 }
942
943
944 /**
945 * Level 3 BLAS
946 */
947
948 static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) {
949 int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1;
950 if ((A != null && !A.getType().getElement().isCompatible(e)) ||
951 (B != null && !B.getType().getElement().isCompatible(e)) ||
952 (C != null && !C.getType().getElement().isCompatible(e))) {
953 throw new RSRuntimeException("Called BLAS with wrong Element type");
954 }
955 if (C != null) {
956 cX = C.getType().getY();
957 cY = C.getType().getX();
958 }
959 if (Side == RIGHT) {
960 if (B != null) {
961 bX = A.getType().getY();
962 bY = A.getType().getX();
963 }
964 if (A != null) {
965 aX = B.getType().getY();
966 aY = B.getType().getX();
967 }
968 } else {
969 if (A != null) {
970 if (TransA == TRANSPOSE) {
971 aY = A.getType().getY();
972 aX = A.getType().getX();
973 } else {
974 aX = A.getType().getY();
975 aY = A.getType().getX();
976 }
977 }
978 if (B != null) {
979 if (TransB == TRANSPOSE) {
980 bY = B.getType().getY();
981 bX = B.getType().getX();
982 } else {
983 bX = B.getType().getY();
984 bY = B.getType().getX();
985 }
986 }
987 }
988 if (A != null && B != null && C != null) {
989 if (aY != bX || aX != cX || bY != cY) {
990 throw new RSRuntimeException("Called BLAS with invalid dimensions");
991 }
992 } else if (A != null && C != null) {
993 // A and C only
994 if (aX != cY || aY != cX) {
995 throw new RSRuntimeException("Called BLAS with invalid dimensions");
996 }
997 } else if (A != null && B != null) {
998 // A and B only
999 }
1000
1001 }
1002
1003 public void SGEMM(@Transpose int TransA, @Transpose int TransB, float alpha, Allocation A,
1004 Allocation B, float beta, Allocation C) {
1005 validateTranspose(TransA);
1006 validateTranspose(TransB);
1007 validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C);
1008
1009 int M = -1, N = -1, K = -1;
1010 if (TransA == TRANSPOSE) {
1011 M = A.getType().getX();
1012 K = A.getType().getY();
1013 } else {
1014 M = A.getType().getY();
1015 K = A.getType().getX();
1016 }
1017 if (TransB == TRANSPOSE) {
1018 N = B.getType().getY();
1019 } else {
1020 N = B.getType().getX();
1021 }
1022 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS),
1023 beta, C.getID(mRS), 0, 0, 0, 0);
1024 }
1025 public void DGEMM(@Transpose int TransA, @Transpose int TransB, double alpha, Allocation A,
1026 Allocation B, double beta, Allocation C) {
1027 validateTranspose(TransA);
1028 validateTranspose(TransB);
1029 validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C);
1030 int M = -1, N = -1, K = -1;
1031 if (TransA == TRANSPOSE) {
1032 M = A.getType().getX();
1033 K = A.getType().getY();
1034 } else {
1035 M = A.getType().getY();
1036 K = A.getType().getX();
1037 }
1038 if (TransB == TRANSPOSE) {
1039 N = B.getType().getY();
1040 } else {
1041 N = B.getType().getX();
1042 }
1043 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS),
1044 beta, C.getID(mRS), 0, 0, 0, 0);
1045 }
1046 public void CGEMM(@Transpose int TransA, @Transpose int TransB, Float2 alpha, Allocation A,
1047 Allocation B, Float2 beta, Allocation C) {
1048 validateTranspose(TransA);
1049 validateTranspose(TransB);
1050 validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C);
1051 int M = -1, N = -1, K = -1;
1052 if (TransA == TRANSPOSE) {
1053 M = A.getType().getX();
1054 K = A.getType().getY();
1055 } else {
1056 M = A.getType().getY();
1057 K = A.getType().getX();
1058 }
1059 if (TransB == TRANSPOSE) {
1060 N = B.getType().getY();
1061 } else {
1062 N = B.getType().getX();
1063 }
1064 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1065 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1066 }
1067
1068 public void ZGEMM(@Transpose int TransA, @Transpose int TransB, Double2 alpha, Allocation A,
1069 Allocation B, Double2 beta, Allocation C) {
1070 validateTranspose(TransA);
1071 validateTranspose(TransB);
1072 validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C);
1073 int M = -1, N = -1, K = -1;
1074 if (TransA == TRANSPOSE) {
1075 M = A.getType().getX();
1076 K = A.getType().getY();
1077 } else {
1078 M = A.getType().getY();
1079 K = A.getType().getX();
1080 }
1081 if (TransB == TRANSPOSE) {
1082 N = B.getType().getY();
1083 } else {
1084 N = B.getType().getX();
1085 }
1086 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1087 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1088 }
1089
1090 public void SSYMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A,
1091 Allocation B, float beta, Allocation C) {
1092 validateSide(Side);
1093 validateUplo(Uplo);
1094 validateL3(Element.F32(mRS), 0, 0, Side, A, B, C);
1095 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
1096 beta, C.getID(mRS), 0, 0, 0, 0);
1097 }
1098 public void DSYMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A,
1099 Allocation B, double beta, Allocation C) {
1100 validateSide(Side);
1101 validateUplo(Uplo);
1102 validateL3(Element.F64(mRS), 0, 0, Side, A, B, C);
1103 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
1104 beta, C.getID(mRS), 0, 0, 0, 0);
1105 }
1106 public void CSYMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A,
1107 Allocation B, Float2 beta, Allocation C) {
1108 validateSide(Side);
1109 validateUplo(Uplo);
1110 validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C);
1111 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1112 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1113 }
1114 public void ZSYMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A,
1115 Allocation B, Double2 beta, Allocation C) {
1116 validateSide(Side);
1117 validateUplo(Uplo);
1118 validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C);
1119 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1120 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1121 }
1122
1123 public void SSYRK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) {
1124 validateTranspose(Trans);
1125 validateUplo(Uplo);
1126 validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C);
1127 int K = -1;
1128 if (Trans == TRANSPOSE) {
1129 K = A.getType().getY();
1130 } else {
1131 K = A.getType().getX();
1132 }
1133
1134 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0);
1135 }
1136
1137 public void DSYRK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) {
1138 validateTranspose(Trans);
1139 validateUplo(Uplo);
1140 validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C);
1141 int K = -1;
1142 if (Trans == TRANSPOSE) {
1143 K = A.getType().getY();
1144 } else {
1145 K = A.getType().getX();
1146 }
1147 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0);
1148 }
1149 public void CSYRK(@Uplo int Uplo, @Transpose int Trans, float alphaX, float alphaY, Allocation A, float betaX, float betaY, Allocation C) {
1150 validateTranspose(Trans);
1151 validateUplo(Uplo);
1152 validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C);
1153 int K = -1;
1154 if (Trans == TRANSPOSE) {
1155 K = A.getType().getY();
1156 } else {
1157 K = A.getType().getX();
1158 }
1159 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY,
1160 C.getID(mRS), 0, 0, 0, 0);
1161 }
1162 public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, double alphaX, double alphaY, Allocation A, double betaX, double betaY, Allocation C) {
1163 validateTranspose(Trans);
1164 validateUplo(Uplo);
1165 validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C);
1166 int K = -1;
1167 if (Trans == TRANSPOSE) {
1168 K = A.getType().getY();
1169 } else {
1170 K = A.getType().getX();
1171 }
1172 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY,
1173 C.getID(mRS), 0, 0, 0, 0);
1174 }
1175
1176 static void validateSYR2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) {
1177 validateTranspose(Trans);
1178 if (!A.getType().getElement().isCompatible(e) ||
1179 !B.getType().getElement().isCompatible(e) ||
1180 !C.getType().getElement().isCompatible(e)) {
1181 throw new RSRuntimeException("Called BLAS with wrong Element type");
1182 }
1183 int Cdim = -1;
1184 // A is n x k if no transpose, k x n if transpose
1185 // C is n x n
1186 if (Trans == TRANSPOSE) {
1187 // check columns versus C
1188 Cdim = A.getType().getX();
1189 } else {
1190 // check rows versus C
1191 Cdim = A.getType().getY();
1192 }
1193 if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) {
1194 throw new RSRuntimeException("Invalid symmetric matrix in SYR2K");
1195 }
1196 // A dims == B dims
1197 if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) {
1198 throw new RSRuntimeException("Invalid A and B in SYR2K");
1199 }
1200 }
1201 public void SSYR2K(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, Allocation B, float beta, Allocation C) {
1202 validateUplo(Uplo);
1203 validateSYR2K(Element.F32(mRS), Trans, A, B, C);
1204 int K = -1;
1205 if (Trans == TRANSPOSE) {
1206 K = A.getType().getY();
1207 } else {
1208 K = A.getType().getX();
1209 }
1210 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0);
1211 }
1212 public void DSYR2K(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, Allocation B, double beta, Allocation C) {
1213 validateUplo(Uplo);
1214 validateSYR2K(Element.F64(mRS), Trans, A, B, C);
1215 int K = -1;
1216 if (Trans == TRANSPOSE) {
1217 K = A.getType().getY();
1218 } else {
1219 K = A.getType().getX();
1220 }
1221 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0);
1222 }
1223 public void CSYR2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) {
1224 validateUplo(Uplo);
1225 validateSYR2K(Element.F32_2(mRS), Trans, A, B, C);
1226 int K = -1;
1227 if (Trans == TRANSPOSE) {
1228 K = A.getType().getY();
1229 } else {
1230 K = A.getType().getX();
1231 }
1232 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1233 }
1234 public void ZSYR2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) {
1235 validateUplo(Uplo);
1236 validateSYR2K(Element.F64_2(mRS), Trans, A, B, C);
1237 int K = -1;
1238 if (Trans == TRANSPOSE) {
1239 K = A.getType().getY();
1240 } else {
1241 K = A.getType().getX();
1242 }
1243 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1244 }
1245
1246 static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
1247 validateSide(Side);
1248 validateTranspose(TransA);
1249 int aX = -1, aY = -1, bX = -1, bY = -1;
1250 if (!A.getType().getElement().isCompatible(e) ||
1251 !B.getType().getElement().isCompatible(e)) {
1252 throw new RSRuntimeException("Called BLAS with wrong Element type");
1253 }
1254 if (TransA == TRANSPOSE) {
1255 aY = A.getType().getY();
1256 aX = A.getType().getX();
1257 } else {
1258 aY = A.getType().getX();
1259 aX = A.getType().getY();
1260 }
1261 bX = B.getType().getY();
1262 bY = B.getType().getX();
1263 if (Side == LEFT) {
1264 if (aX == 0 || aY != bX) {
1265 throw new RSRuntimeException("Called TRMM with invalid matrices");
1266 }
1267 } else {
1268 if (bY != aX || aY == 0) {
1269 throw new RSRuntimeException("Called TRMM with invalid matrices");
1270 }
1271 }
1272 }
1273 public void STRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) {
1274 validateUplo(Uplo);
1275 validateDiag(Diag);
1276 validateTRMM(Element.F32(mRS), Side, TransA, A, B);
1277 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1278 alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0);
1279 }
1280 public void DTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) {
1281 validateUplo(Uplo);
1282 validateDiag(Diag);
1283 validateTRMM(Element.F64(mRS), Side, TransA, A, B);
1284 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1285 alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0);
1286 }
1287 public void CTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) {
1288 validateUplo(Uplo);
1289 validateDiag(Diag);
1290 validateTRMM(Element.F32_2(mRS), Side, TransA, A, B);
1291 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1292 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1293 }
1294 public void ZTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) {
1295 validateUplo(Uplo);
1296 validateDiag(Diag);
1297 validateTRMM(Element.F64_2(mRS), Side, TransA, A, B);
1298 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1299 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1300 }
1301
1302 static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
1303 int adim = -1, bX = -1, bY = -1;
1304 validateSide(Side);
1305 validateTranspose(TransA);
1306 if (!A.getType().getElement().isCompatible(e) ||
1307 !B.getType().getElement().isCompatible(e)) {
1308 throw new RSRuntimeException("Called BLAS with wrong Element type");
1309 }
1310 adim = A.getType().getX();
1311 if (adim != A.getType().getY()) {
1312 // this may be unnecessary, the restriction could potentially be relaxed
1313 // A needs to contain at least that symmetric matrix but could theoretically be larger
1314 // for now we assume adapters are sufficient, will reevaluate in the future
1315 throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A");
1316 }
1317 bX = B.getType().getY();
1318 bY = B.getType().getX();
1319 if (Side == LEFT) {
1320 // A is M*M
1321 if (adim != bY) {
1322 throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
1323 }
1324 } else {
1325 // A is N*N
1326 if (adim != bX) {
1327 throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
1328 }
1329 }
1330 }
1331 public void STRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) {
1332 validateUplo(Uplo);
1333 validateDiag(Diag);
1334 validateTRSM(Element.F32(mRS), Side, TransA, A, B);
1335 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1336 alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0);
1337 }
1338 public void DTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) {
1339 validateUplo(Uplo);
1340 validateDiag(Diag);
1341 validateTRSM(Element.F64(mRS), Side, TransA, A, B);
1342 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1343 alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0);
1344 }
1345 public void CTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) {
1346 validateUplo(Uplo);
1347 validateDiag(Diag);
1348 validateTRSM(Element.F32_2(mRS), Side, TransA, A, B);
1349 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1350 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1351 }
1352 public void ZTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) {
1353 validateUplo(Uplo);
1354 validateDiag(Diag);
1355 validateTRSM(Element.F64_2(mRS), Side, TransA, A, B);
1356 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1357 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1358 }
1359
1360 static void validateHEMM(Element e, @Side int Side, Allocation A, Allocation B, Allocation C) {
1361 validateSide(Side);
1362
1363 if (!A.getType().getElement().isCompatible(e) ||
1364 !B.getType().getElement().isCompatible(e) ||
1365 !C.getType().getElement().isCompatible(e)) {
1366 throw new RSRuntimeException("Called BLAS with wrong Element type");
1367 }
1368
1369 // A must be square; can potentially be relaxed similar to TRSM
1370 int adim = A.getType().getX();
1371 if (adim != A.getType().getY()) {
1372 throw new RSRuntimeException("Called HEMM with non-square A");
1373 }
1374 if ((Side == LEFT && adim != B.getType().getY()) ||
1375 (Side == RIGHT && adim != B.getType().getX())) {
1376 throw new RSRuntimeException("Called HEMM with invalid B");
1377 }
1378 if (B.getType().getX() != C.getType().getX() ||
1379 B.getType().getY() != C.getType().getY()) {
1380 throw new RSRuntimeException("Called HEMM with mismatched B and C");
1381 }
1382 }
1383 public void CHEMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A, Allocation B, float beta, Allocation C) {
1384 validateUplo(Uplo);
1385 validateHEMM(Element.F32_2(mRS), Side, A, B, C);
1386 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
1387 alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
1388 }
1389 public void ZHEMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A, Allocation B, double beta, Allocation C) {
1390 validateUplo(Uplo);
1391 validateHEMM(Element.F32_2(mRS), Side, A, B, C);
1392 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
1393 alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
1394 }
1395
1396 static void validateHERK(Element e, @Transpose int Trans, Allocation A, Allocation C) {
1397 if (!A.getType().getElement().isCompatible(e) ||
1398 !C.getType().getElement().isCompatible(e)) {
1399 throw new RSRuntimeException("Called BLAS with wrong Element type");
1400 }
1401 validateConjTranspose(Trans);
1402 int cdim = C.getType().getX();
1403 if (cdim != C.getType().getY()) {
1404 throw new RSRuntimeException("Called HERK with non-square C");
1405 }
1406 if (Trans == NO_TRANSPOSE) {
1407 if (cdim != A.getType().getX()) {
1408 throw new RSRuntimeException("Called HERK with invalid A");
1409 }
1410 } else {
1411 if (cdim != A.getType().getY()) {
1412 throw new RSRuntimeException("Called HERK with invalid A");
1413 }
1414 }
1415 }
1416 public void CHERK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) {
1417 validateUplo(Uplo);
1418 validateHERK(Element.F32_2(mRS), Trans, A, C);
1419 int k = 0;
1420 if (Trans == TRANSPOSE) {
1421 k = A.getType().getY();
1422 } else {
1423 k = A.getType().getX();
1424 }
1425 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k,
1426 alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0);
1427 }
1428 public void ZHERK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) {
1429 validateUplo(Uplo);
1430 validateHERK(Element.F64_2(mRS), Trans, A, C);
1431 int k = 0;
1432 if (Trans == TRANSPOSE) {
1433 k = A.getType().getY();
1434 } else {
1435 k = A.getType().getX();
1436 }
1437 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k,
1438 alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0);
1439 }
1440
1441 static void validateHER2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) {
1442 if (!A.getType().getElement().isCompatible(e) ||
1443 !B.getType().getElement().isCompatible(e) ||
1444 !C.getType().getElement().isCompatible(e)) {
1445 throw new RSRuntimeException("Called BLAS with wrong Element type");
1446 }
1447 validateConjTranspose(Trans);
1448 int cdim = C.getType().getX();
1449 if (cdim != C.getType().getY()) {
1450 throw new RSRuntimeException("Called HER2K with non-square C");
1451 }
1452 if (Trans == NO_TRANSPOSE) {
1453 if (A.getType().getY() != cdim) {
1454 throw new RSRuntimeException("Called HER2K with invalid matrices");
1455 }
1456 } else {
1457 if (A.getType().getX() != cdim) {
1458 throw new RSRuntimeException("Called HER2K with invalid matrices");
1459 }
1460 }
1461 if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) {
1462 throw new RSRuntimeException("Called HER2K with invalid A and B matrices");
1463 }
1464 }
1465 public void CHER2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, float beta, Allocation C) {
1466 validateUplo(Uplo);
1467 validateHER2K(Element.F32_2(mRS), Trans, A, B, C);
1468 int k = 0;
1469 if (Trans == NO_TRANSPOSE) {
1470 k = A.getType().getX();
1471 } else {
1472 k = A.getType().getY();
1473 }
1474 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y,
1475 A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
1476 }
1477 public void ZHER2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, double beta, Allocation C) {
1478 validateUplo(Uplo);
1479 validateHER2K(Element.F64_2(mRS), Trans, A, B, C);
1480 int k = 0;
1481 if (Trans == NO_TRANSPOSE) {
1482 k = A.getType().getX();
1483 } else {
1484 k = A.getType().getY();
1485 }
1486 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y,
1487 A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
1488 }
1489
1490
Tim Murray9cb16a22015-04-01 11:07:16 -07001491 /**
1492 *
1493 * 8-bit GEMM-like operation for neural networks
1494 *
1495 * @hide
1496 **/
1497 public void BNNM(Allocation A, int a_offset, Allocation B, int b_offset, Allocation C, int c_offset, int c_mult) {
1498 validateL3(Element.U8(mRS), NO_TRANSPOSE, TRANSPOSE, 0, A, B, C);
1499
1500 int M = -1, N = -1, K = -1;
1501 M = A.getType().getY();
1502 N = B.getType().getY();
1503 K = A.getType().getX();
1504
1505
1506 mRS.nScriptIntrinsicBLAS_BNNM(getID(mRS), M, N, K, A.getID(mRS), a_offset, B.getID(mRS), b_offset, C.getID(mRS), c_offset, c_mult);
1507
1508 }
Tim Murray25207df2015-01-12 16:47:56 -08001509
1510}