blob: 51096a019b93d4f8d44558cfa5751c2108bcf56a [file] [log] [blame]
Tim Murray25207df2015-01-12 16:47:56 -08001/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.renderscript;
18
19import android.annotation.IntDef;
20import java.lang.annotation.Retention;
21import java.lang.annotation.RetentionPolicy;
22
/**
 * RenderScript intrinsic exposing BLAS (Basic Linear Algebra Subprograms)
 * routines that operate on {@link Allocation}s.
 **/
28public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
29 private Allocation mLUT;
30
31 private ScriptIntrinsicBLAS(long id, RenderScript rs) {
32 super(id, rs);
33 }
34
35 private static final int RsBlas_sdsdot = 1;
36 private static final int RsBlas_dsdot = 2;
37 private static final int RsBlas_sdot = 3;
38 private static final int RsBlas_ddot = 4;
39 private static final int RsBlas_cdotu_sub = 5;
40 private static final int RsBlas_cdotc_sub = 6;
41 private static final int RsBlas_zdotu_sub = 7;
42 private static final int RsBlas_zdotc_sub = 8;
43 private static final int RsBlas_snrm2 = 9;
44 private static final int RsBlas_sasum = 10;
45 private static final int RsBlas_dnrm2 = 11;
46 private static final int RsBlas_dasum = 12;
47 private static final int RsBlas_scnrm2 = 13;
48 private static final int RsBlas_scasum = 14;
49 private static final int RsBlas_dznrm2 = 15;
50 private static final int RsBlas_dzasum = 16;
51 private static final int RsBlas_isamax = 17;
52 private static final int RsBlas_idamax = 18;
53 private static final int RsBlas_icamax = 19;
54 private static final int RsBlas_izamax = 20;
55 private static final int RsBlas_sswap = 21;
56 private static final int RsBlas_scopy = 22;
57 private static final int RsBlas_saxpy = 23;
58 private static final int RsBlas_dswap = 24;
59 private static final int RsBlas_dcopy = 25;
60 private static final int RsBlas_daxpy = 26;
61 private static final int RsBlas_cswap = 27;
62 private static final int RsBlas_ccopy = 28;
63 private static final int RsBlas_caxpy = 29;
64 private static final int RsBlas_zswap = 30;
65 private static final int RsBlas_zcopy = 31;
66 private static final int RsBlas_zaxpy = 32;
67 private static final int RsBlas_srotg = 33;
68 private static final int RsBlas_srotmg = 34;
69 private static final int RsBlas_srot = 35;
70 private static final int RsBlas_srotm = 36;
71 private static final int RsBlas_drotg = 37;
72 private static final int RsBlas_drotmg = 38;
73 private static final int RsBlas_drot = 39;
74 private static final int RsBlas_drotm = 40;
75 private static final int RsBlas_sscal = 41;
76 private static final int RsBlas_dscal = 42;
77 private static final int RsBlas_cscal = 43;
78 private static final int RsBlas_zscal = 44;
79 private static final int RsBlas_csscal = 45;
80 private static final int RsBlas_zdscal = 46;
81 private static final int RsBlas_sgemv = 47;
82 private static final int RsBlas_sgbmv = 48;
83 private static final int RsBlas_strmv = 49;
84 private static final int RsBlas_stbmv = 50;
85 private static final int RsBlas_stpmv = 51;
86 private static final int RsBlas_strsv = 52;
87 private static final int RsBlas_stbsv = 53;
88 private static final int RsBlas_stpsv = 54;
89 private static final int RsBlas_dgemv = 55;
90 private static final int RsBlas_dgbmv = 56;
91 private static final int RsBlas_dtrmv = 57;
92 private static final int RsBlas_dtbmv = 58;
93 private static final int RsBlas_dtpmv = 59;
94 private static final int RsBlas_dtrsv = 60;
95 private static final int RsBlas_dtbsv = 61;
96 private static final int RsBlas_dtpsv = 62;
97 private static final int RsBlas_cgemv = 63;
98 private static final int RsBlas_cgbmv = 64;
99 private static final int RsBlas_ctrmv = 65;
100 private static final int RsBlas_ctbmv = 66;
101 private static final int RsBlas_ctpmv = 67;
102 private static final int RsBlas_ctrsv = 68;
103 private static final int RsBlas_ctbsv = 69;
104 private static final int RsBlas_ctpsv = 70;
105 private static final int RsBlas_zgemv = 71;
106 private static final int RsBlas_zgbmv = 72;
107 private static final int RsBlas_ztrmv = 73;
108 private static final int RsBlas_ztbmv = 74;
109 private static final int RsBlas_ztpmv = 75;
110 private static final int RsBlas_ztrsv = 76;
111 private static final int RsBlas_ztbsv = 77;
112 private static final int RsBlas_ztpsv = 78;
113 private static final int RsBlas_ssymv = 79;
114 private static final int RsBlas_ssbmv = 80;
115 private static final int RsBlas_sspmv = 81;
116 private static final int RsBlas_sger = 82;
117 private static final int RsBlas_ssyr = 83;
118 private static final int RsBlas_sspr = 84;
119 private static final int RsBlas_ssyr2 = 85;
120 private static final int RsBlas_sspr2 = 86;
121 private static final int RsBlas_dsymv = 87;
122 private static final int RsBlas_dsbmv = 88;
123 private static final int RsBlas_dspmv = 89;
124 private static final int RsBlas_dger = 90;
125 private static final int RsBlas_dsyr = 91;
126 private static final int RsBlas_dspr = 92;
127 private static final int RsBlas_dsyr2 = 93;
128 private static final int RsBlas_dspr2 = 94;
129 private static final int RsBlas_chemv = 95;
130 private static final int RsBlas_chbmv = 96;
131 private static final int RsBlas_chpmv = 97;
132 private static final int RsBlas_cgeru = 98;
133 private static final int RsBlas_cgerc = 99;
134 private static final int RsBlas_cher = 100;
135 private static final int RsBlas_chpr = 101;
136 private static final int RsBlas_cher2 = 102;
137 private static final int RsBlas_chpr2 = 103;
138 private static final int RsBlas_zhemv = 104;
139 private static final int RsBlas_zhbmv = 105;
140 private static final int RsBlas_zhpmv = 106;
141 private static final int RsBlas_zgeru = 107;
142 private static final int RsBlas_zgerc = 108;
143 private static final int RsBlas_zher = 109;
144 private static final int RsBlas_zhpr = 110;
145 private static final int RsBlas_zher2 = 111;
146 private static final int RsBlas_zhpr2 = 112;
147 private static final int RsBlas_sgemm = 113;
148 private static final int RsBlas_ssymm = 114;
149 private static final int RsBlas_ssyrk = 115;
150 private static final int RsBlas_ssyr2k = 116;
151 private static final int RsBlas_strmm = 117;
152 private static final int RsBlas_strsm = 118;
153 private static final int RsBlas_dgemm = 119;
154 private static final int RsBlas_dsymm = 120;
155 private static final int RsBlas_dsyrk = 121;
156 private static final int RsBlas_dsyr2k = 122;
157 private static final int RsBlas_dtrmm = 123;
158 private static final int RsBlas_dtrsm = 124;
159 private static final int RsBlas_cgemm = 125;
160 private static final int RsBlas_csymm = 126;
161 private static final int RsBlas_csyrk = 127;
162 private static final int RsBlas_csyr2k = 128;
163 private static final int RsBlas_ctrmm = 129;
164 private static final int RsBlas_ctrsm = 130;
165 private static final int RsBlas_zgemm = 131;
166 private static final int RsBlas_zsymm = 132;
167 private static final int RsBlas_zsyrk = 133;
168 private static final int RsBlas_zsyr2k = 134;
169 private static final int RsBlas_ztrmm = 135;
170 private static final int RsBlas_ztrsm = 136;
171 private static final int RsBlas_chemm = 137;
172 private static final int RsBlas_cherk = 138;
173 private static final int RsBlas_cher2k = 139;
174 private static final int RsBlas_zhemm = 140;
175 private static final int RsBlas_zherk = 141;
176 private static final int RsBlas_zher2k = 142;
177
Tim Murray9cb16a22015-04-01 11:07:16 -0700178 // BLAS extensions start here
179 private static final int RsBlas_bnnm = 1000;
180
Tim Murray25207df2015-01-12 16:47:56 -0800181 /**
182 */
183 public static ScriptIntrinsicBLAS create(RenderScript rs) {
184 long id = rs.nScriptIntrinsicCreate(13, Element.U32(rs).getID(rs));
185 return new ScriptIntrinsicBLAS(id, rs);
186 }
187
188 @IntDef({NO_TRANSPOSE, TRANSPOSE, CONJ_TRANSPOSE})
189 @Retention(RetentionPolicy.SOURCE)
190 public @interface Transpose {}
191
192 @IntDef({UPPER, LOWER})
193 @Retention(RetentionPolicy.SOURCE)
194 public @interface Uplo {}
195
196 @IntDef({NON_UNIT, UNIT})
197 @Retention(RetentionPolicy.SOURCE)
198 public @interface Diag {}
199
200 @IntDef({LEFT, RIGHT})
201 @Retention(RetentionPolicy.SOURCE)
202 public @interface Side {}
203
204 public static final int NO_TRANSPOSE = 111;
205 public static final int TRANSPOSE = 112;
206 public static final int CONJ_TRANSPOSE = 113;
207
208 public static final int UPPER = 121;
209 public static final int LOWER = 122;
210
211 public static final int NON_UNIT = 131;
212 public static final int UNIT = 132;
213
214 public static final int LEFT = 141;
215 public static final int RIGHT = 142;
216
217 static void validateSide(@Side int Side) {
218 if (Side != LEFT && Side != RIGHT) {
219 throw new RSRuntimeException("Invalid side passed to BLAS");
220 }
221 }
222
223 static void validateTranspose(@Transpose int Trans) {
224 if (Trans != NO_TRANSPOSE && Trans != TRANSPOSE &&
225 Trans != CONJ_TRANSPOSE) {
226 throw new RSRuntimeException("Invalid transpose passed to BLAS");
227 }
228 }
229
230 static void validateConjTranspose(@Transpose int Trans) {
231 if (Trans != NO_TRANSPOSE &&
232 Trans != CONJ_TRANSPOSE) {
233 throw new RSRuntimeException("Invalid transpose passed to BLAS");
234 }
235 }
236
237 static void validateDiag(@Diag int Diag) {
238 if (Diag != NON_UNIT && Diag != UNIT) {
239 throw new RSRuntimeException("Invalid diag passed to BLAS");
240 }
241 }
242
243 static void validateUplo(@Uplo int Uplo) {
244 if (Uplo != LEFT && Uplo != RIGHT) {
245 throw new RSRuntimeException("Invalid uplo passed to BLAS");
246 }
247 }
248
249
250 /**
251 * Level 2 BLAS
252 */
253
254 static void validateGEMV(Element e, int TransA, Allocation A, Allocation X, int incX, Allocation Y, int incY) {
255 validateTranspose(TransA);
256 int M = A.getType().getY();
257 int N = A.getType().getX();
258 if (!A.getType().getElement().isCompatible(e) ||
259 !X.getType().getElement().isCompatible(e) ||
260 !Y.getType().getElement().isCompatible(e)) {
261 throw new RSRuntimeException("Called BLAS with wrong Element type");
262 }
263 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
264 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
265 }
266
267 if (incX <= 0 || incY <= 0) {
268 throw new RSRuntimeException("Vector increments must be greater than 0");
269 }
270 int expectedXDim = -1, expectedYDim = -1;
271 if (TransA == NO_TRANSPOSE) {
272 expectedXDim = 1 + (N - 1) * incX;
273 expectedYDim = 1 + (M - 1) * incY;
274 } else {
275 expectedXDim = 1 + (M - 1) * incX;
276 expectedYDim = 1 + (N - 1) * incY;
277 }
278 if (X.getType().getX() != expectedXDim ||
279 Y.getType().getY() != expectedXDim) {
280 throw new RSRuntimeException("Incorrect vector dimensions for GEMV");
281 }
282 }
283 void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
284 validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY);
285 int M = A.getType().getY();
286 int N = A.getType().getX();
287 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
288 }
289 void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
290 validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY);
291 int M = A.getType().getY();
292 int N = A.getType().getX();
293 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
294 }
295 void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
296 validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY);
297 int M = A.getType().getY();
298 int N = A.getType().getX();
299 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
300 }
301 void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
302 validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY);
303 int M = A.getType().getY();
304 int N = A.getType().getX();
305 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
306 }
307
308 void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
309 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
310 validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY);
311 if (KL < 0 || KU < 0) {
312 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
313 }
314 int M = A.getType().getY();
315 int N = A.getType().getX();
316 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU);
317 }
318 void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
319 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
320 validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY);
321 if (KL < 0 || KU < 0) {
322 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
323 }
324 int M = A.getType().getY();
325 int N = A.getType().getX();
326 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU);
327 }
328 void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
329 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
330 validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY);
331 if (KL < 0 || KU < 0) {
332 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
333 }
334 int M = A.getType().getY();
335 int N = A.getType().getX();
336 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
337 }
338 void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
339 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
340 validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY);
341 if (KL < 0 || KU < 0) {
342 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
343 }
344 int M = A.getType().getY();
345 int N = A.getType().getX();
346 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
347 }
348
349 static void validateTRMV(Element e, @Transpose int TransA, Allocation A, Allocation X, int incX) {
350 validateTranspose(TransA);
351 int N = A.getType().getY();
352 if (A.getType().getX() != N) {
353 throw new RSRuntimeException("A must be a square matrix for TRMV");
354 }
355 if (!A.getType().getElement().isCompatible(e) ||
356 !X.getType().getElement().isCompatible(e)) {
357 throw new RSRuntimeException("Called BLAS with wrong Element type");
358 }
359 if (X.getType().getY() > 1) {
360 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
361 }
362
363 if (incX <= 0) {
364 throw new RSRuntimeException("Vector increments must be greater than 0");
365 }
366 int expectedXDim = 1 + (N - 1) * incX;
367 if (X.getType().getX() != expectedXDim) {
368 throw new RSRuntimeException("Incorrect vector dimensions for TRMV");
369 }
370 }
371
372 static int validateTPMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
373 validateTranspose(TransA);
374 validateUplo(Uplo);
375 validateDiag(Diag);
376 if (!Ap.getType().getElement().isCompatible(e) ||
377 !X.getType().getElement().isCompatible(e)) {
378 throw new RSRuntimeException("Called BLAS with wrong Element type");
379 }
380 if (X.getType().getY() > 1) {
381 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
382 }
383
384 if (Ap.getType().getY() > 1) {
385 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
386 }
387
388 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
389 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
390 throw new RSRuntimeException("Invalid dimension for Ap");
391 }
392
393 int expectedXDim = 1 + (N - 1) * incX;
394 if (X.getType().getX() != expectedXDim) {
395 throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
396 }
397
398 return N;
399 }
400
401 void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
402 validateTRMV(Element.F32(mRS), TransA, A, X, incX);
403 int N = A.getType().getY();
404 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
405 }
406 void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
407 validateTRMV(Element.F64(mRS), TransA, A, X, incX);
408 int N = A.getType().getY();
409 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
410 }
411 void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
412 validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
413 int N = A.getType().getY();
414 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
415 }
416 void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
417 validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
418 int N = A.getType().getY();
419 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
420 }
421 void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
422 // TBMV has the same requirements as TRMV
423 validateTRMV(Element.F32(mRS), TransA, A, X, incX);
424 int N = A.getType().getY();
425 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
426 }
427 void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
428 // TBMV has the same requirements as TRMV
429 validateTRMV(Element.F64(mRS), TransA, A, X, incX);
430 int N = A.getType().getY();
431 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
432 }
433 void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
434 // TBMV has the same requirements as TRMV
435 validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
436 int N = A.getType().getY();
437 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
438 }
439 void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
440 // TBMV has the same requirements as TRMV
441 validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
442 int N = A.getType().getY();
443 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
444 }
445 void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
446 int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
447 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
448 }
449 void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
450 int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
451 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
452 }
453 void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
454 int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
455 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
456 }
457 void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
458 int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
459 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
460 }
461 void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
462 // TRSV is the same as TRMV
463 validateTRMV(Element.F32(mRS), TransA, A, X, incX);
464 int N = A.getType().getY();
465 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
466
467 }
468 void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
469 // TRSV is the same as TRMV
470 validateTRMV(Element.F64(mRS), TransA, A, X, incX);
471 int N = A.getType().getY();
472 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
473
474 }
475 void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
476 // TRSV is the same as TRMV
477 validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
478 int N = A.getType().getY();
479 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
480
481 }
482 void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
483 // TRSV is the same as TRMV
484 validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
485 int N = A.getType().getY();
486 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
487
488 }
489 void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
490 // TBSV is the same as TRMV
491 validateTRMV(Element.F32(mRS), TransA, A, X, incX);
492 int N = A.getType().getY();
493 if (K < 0) {
494 throw new RSRuntimeException("Number of diagonals must be positive");
495 }
496 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
497 }
498 void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
499 // TBSV is the same as TRMV
500 validateTRMV(Element.F64(mRS), TransA, A, X, incX);
501 int N = A.getType().getY();
502 if (K < 0) {
503 throw new RSRuntimeException("Number of diagonals must be positive");
504 }
505 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
506 }
507 void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
508 // TBSV is the same as TRMV
509 validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
510 int N = A.getType().getY();
511 if (K < 0) {
512 throw new RSRuntimeException("Number of diagonals must be positive");
513 }
514 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
515 }
516 void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
517 // TBSV is the same as TRMV
518 validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
519 int N = A.getType().getY();
520 if (K < 0) {
521 throw new RSRuntimeException("Number of diagonals must be positive");
522 }
523 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
524 }
525 void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
526 // TPSV is same as TPMV
527 int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
528 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
529 }
530 void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
531 // TPSV is same as TPMV
532 int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
533 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
534 }
535 void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
536 // TPSV is same as TPMV
537 int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
538 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
539 }
540 void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
541 // TPSV is same as TPMV
542 int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
543 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
544 }
545
546 /**
547 * Level 2, S and D only
548 */
549 static int validateSYMV(Element e, @Uplo int Uplo, Allocation A, Allocation X, Allocation Y, int incX, int incY) {
550 validateUplo(Uplo);
551 int N = A.getType().getY();
552 if (A.getType().getX() != N) {
553 throw new RSRuntimeException("A must be a square matrix for SYMV");
554 }
555 if (!A.getType().getElement().isCompatible(e) ||
556 !X.getType().getElement().isCompatible(e) ||
557 !Y.getType().getElement().isCompatible(e) ) {
558 throw new RSRuntimeException("Called BLAS with wrong Element type");
559 }
560 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
561 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
562 }
563
564 if (incX <= 0 || incY <= 0) {
565 throw new RSRuntimeException("Vector increments must be greater than 0");
566 }
567 int expectedXDim = 1 + (N - 1) * incX;
568 if (X.getType().getX() != expectedXDim) {
569 throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
570 }
571 int expectedYDim = 1 + (N - 1) * incY;
572 if (Y.getType().getX() != expectedYDim) {
573 throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
574 }
575 return N;
576 }
577 static int validateSPMV(Element e, @Uplo int Uplo, Allocation Ap, Allocation X, int incX, Allocation Y, int incY) {
578 validateUplo(Uplo);
579 if (!Ap.getType().getElement().isCompatible(e) ||
580 !X.getType().getElement().isCompatible(e) ||
581 !Y.getType().getElement().isCompatible(e)) {
582 throw new RSRuntimeException("Called BLAS with wrong Element type");
583 }
584 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
585 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
586 }
587
588 if (Ap.getType().getY() > 1) {
589 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
590 }
591
592 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
593 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
594 throw new RSRuntimeException("Invalid dimension for Ap");
595 }
596
597 int expectedXDim = 1 + (N - 1) * incX;
598 if (X.getType().getX() != expectedXDim) {
599 throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
600 }
601 int expectedYDim = 1 + (N - 1) * incY;
602 if (Y.getType().getX() != expectedYDim) {
603 throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
604 }
605
606 return N;
607 }
608 static void validateGER(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
609 if (!A.getType().getElement().isCompatible(e) ||
610 !X.getType().getElement().isCompatible(e) ||
611 !Y.getType().getElement().isCompatible(e) ) {
612 throw new RSRuntimeException("Called BLAS with wrong Element type");
613 }
614
615 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
616 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
617 }
618
619 int M = A.getType().getY();
620 int N = A.getType().getX();
621
622 if (N < 1 || M < 1) {
623 throw new RSRuntimeException("M and N must be 1 or greater for GER");
624 }
625
626 int expectedXDim = 1 + (N - 1) * incX;
627 if (X.getType().getX() != expectedXDim) {
628 throw new RSRuntimeException("Incorrect vector dimensions for GER");
629 }
630 int expectedYDim = 1 + (N - 1) * incY;
631 if (Y.getType().getX() != expectedYDim) {
632 throw new RSRuntimeException("Incorrect vector dimensions for GER");
633 }
634
635
636 }
637 static int validateSYR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation A) {
638 validateUplo(Uplo);
639 if (!A.getType().getElement().isCompatible(e) ||
640 !X.getType().getElement().isCompatible(e)) {
641 throw new RSRuntimeException("Called BLAS with wrong Element type");
642 }
643
644 int N = A.getType().getX();
645
646 if (X.getType().getY() > 1) {
647 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
648 }
649 if (N != A.getType().getY()) {
650 throw new RSRuntimeException("A must be a symmetric matrix");
651 }
652
653 int expectedXDim = 1 + (N - 1) * incX;
654 if (X.getType().getX() != expectedXDim) {
655 throw new RSRuntimeException("Incorrect vector dimensions for SYR");
656 }
657 return N;
658 }
659 static int validateSPR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Ap) {
660 validateUplo(Uplo);
661 if (!Ap.getType().getElement().isCompatible(e) ||
662 !X.getType().getElement().isCompatible(e)) {
663 throw new RSRuntimeException("Called BLAS with wrong Element type");
664 }
665 if (X.getType().getY() > 1) {
666 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
667 }
668
669 if (Ap.getType().getY() > 1) {
670 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
671 }
672
673 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
674 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
675 throw new RSRuntimeException("Invalid dimension for Ap");
676 }
677
678 int expectedXDim = 1 + (N - 1) * incX;
679 if (X.getType().getX() != expectedXDim) {
680 throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
681 }
682
683 return N;
684 }
685
686 static int validateSYR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
687 validateUplo(Uplo);
688 if (!A.getType().getElement().isCompatible(e) ||
689 !X.getType().getElement().isCompatible(e) ||
690 !Y.getType().getElement().isCompatible(e)) {
691 throw new RSRuntimeException("Called BLAS with wrong Element type");
692 }
693
694 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
695 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
696 }
697
698 int N = A.getType().getX();
699
700 if (N != A.getType().getY()) {
701 throw new RSRuntimeException("A must be a symmetric matrix");
702 }
703
704 int expectedXDim = 1 + (N - 1) * incX;
705 int expectedYDim = 1 + (N - 1) * incY;
706 if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
707 throw new RSRuntimeException("Incorrect vector dimensions for SYR");
708 }
709 return N;
710
711 }
712 static int validateSPR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
713 validateUplo(Uplo);
714 if (!Ap.getType().getElement().isCompatible(e) ||
715 !X.getType().getElement().isCompatible(e) ||
716 !Y.getType().getElement().isCompatible(e)) {
717 throw new RSRuntimeException("Called BLAS with wrong Element type");
718 }
719 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
720 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
721 }
722
723 if (Ap.getType().getY() > 1) {
724 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
725 }
726
727 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
728 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
729 throw new RSRuntimeException("Invalid dimension for Ap");
730 }
731
732 int expectedXDim = 1 + (N - 1) * incX;
733 int expectedYDim = 1 + (N - 1) * incY;
734 if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
735 throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
736 }
737
738 return N;
739 }
740
741 void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
742 int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
743 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
744 }
745 void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
746 // SBMV is the same as SYMV
747 int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
748 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
749 }
750 void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) {
751 int N = validateSPMV(Element.F32(mRS), Uplo, Ap, X, incX, Y, incY);
752 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
753 }
754 void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
755 int M = A.getType().getY();
756 int N = A.getType().getX();
757 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
758 }
759 void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
760 int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A);
761 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0);
762 }
763 void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
764 int N = validateSPR(Element.F32(mRS), Uplo, X, incX, Ap);
765 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0);
766 }
767 void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
768 int N = validateSYR2(Element.F32(mRS), Uplo, X, incX, Y, incY, A);
769 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0);
770 }
771 void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
772 int N = validateSPR2(Element.F32(mRS), Uplo, X, incX, Y, incY, Ap);
773 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0);
774 }
775 void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
776 int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
777 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
778 }
779 void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
780 // SBMV is the same as SYMV
781 int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
782 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
783 }
784 void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) {
785 int N = validateSPMV(Element.F64(mRS), Uplo, Ap, X, incX, Y, incY);
786 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
787 }
788 void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
789 int M = A.getType().getY();
790 int N = A.getType().getX();
791 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
792 }
793 void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
794 int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A);
795 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0);
796 }
797 void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
798 int N = validateSPR(Element.F64(mRS), Uplo, X, incX, Ap);
799 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0);
800 }
801 void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
802 int N = validateSYR2(Element.F64(mRS), Uplo, X, incX, Y, incY, A);
803 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0);
804 }
805 void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
806 int N = validateSPR2(Element.F64(mRS), Uplo, X, incX, Y, incY, Ap);
807 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0);
808 }
809
810
811 /**
812 * Level 2, C and Z only
813 */
814
815 static void validateGERU(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
816 if (!A.getType().getElement().isCompatible(e) ||
817 !X.getType().getElement().isCompatible(e) ||
818 !Y.getType().getElement().isCompatible(e)) {
819 throw new RSRuntimeException("Called BLAS with wrong Element type");
820 }
821 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
822 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
823 }
824
825 int M = A.getType().getY();
826 int N = A.getType().getX();
827
828 int expectedXDim = 1 + (N - 1) * incX;
829 if (X.getType().getX() != expectedXDim) {
830 throw new RSRuntimeException("Incorrect vector dimensions for GERU");
831 }
832 int expectedYDim = 1 + (N - 1) * incY;
833 if (Y.getType().getX() != expectedYDim) {
834 throw new RSRuntimeException("Incorrect vector dimensions for GERU");
835 }
836
837 }
838
839 void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
840 // HEMV is the same as SYR2 validation-wise
841 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
842 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
843 }
844 void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
845 // HBMV is the same as SYR2 validation-wise
846 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
847 if (K < 0) {
848 throw new RSRuntimeException("K must be 0 or greater for HBMV");
849 }
850 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
851 }
852 void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
853 // HPMV is the same as SPR2
854 int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
855 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
856 }
857 void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
858 validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A);
859 int M = A.getType().getY();
860 int N = A.getType().getX();
861 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
862 }
863 void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
864 // same as GERU
865 validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A);
866 int M = A.getType().getY();
867 int N = A.getType().getX();
868 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
869 }
870 void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
871 // same as SYR
872 int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A);
873 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
874 }
875 void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
876 // equivalent to SPR for validation
877 int N = validateSPR(Element.F32_2(mRS), Uplo, X, incX, Ap);
878 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0);
879 }
880 void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
881 // same as SYR2
882 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
883 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
884 }
885 void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
886 // same as SPR2
887 int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
888 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0);
889 }
890 void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
891 // HEMV is the same as SYR2 validation-wise
892 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
893 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
894 }
895 void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
896 // HBMV is the same as SYR2 validation-wise
897 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
898 if (K < 0) {
899 throw new RSRuntimeException("K must be 0 or greater for HBMV");
900 }
901 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
902 }
903 void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
904 // HPMV is the same as SPR2
905 int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
906 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
907 }
908 void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
909 validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A);
910 int M = A.getType().getY();
911 int N = A.getType().getX();
912 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
913 }
914 void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
915 // same as GERU
916 validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A);
917 int M = A.getType().getY();
918 int N = A.getType().getX();
919 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
920 }
921 void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
922 // same as SYR
923 int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A);
924 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
925 }
926 void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
927 // equivalent to SPR for validation
928 int N = validateSPR(Element.F64_2(mRS), Uplo, X, incX, Ap);
929 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0);
930 }
931 void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
932 // same as SYR2
933 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
934 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
935 }
936 void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
937 // same as SPR2
938 int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
939 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0);
940 }
941
942
943 /**
944 * Level 3 BLAS
945 */
946
947 static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) {
948 int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1;
949 if ((A != null && !A.getType().getElement().isCompatible(e)) ||
950 (B != null && !B.getType().getElement().isCompatible(e)) ||
951 (C != null && !C.getType().getElement().isCompatible(e))) {
952 throw new RSRuntimeException("Called BLAS with wrong Element type");
953 }
954 if (C != null) {
955 cX = C.getType().getY();
956 cY = C.getType().getX();
957 }
958 if (Side == RIGHT) {
959 if (B != null) {
960 bX = A.getType().getY();
961 bY = A.getType().getX();
962 }
963 if (A != null) {
964 aX = B.getType().getY();
965 aY = B.getType().getX();
966 }
967 } else {
968 if (A != null) {
969 if (TransA == TRANSPOSE) {
970 aY = A.getType().getY();
971 aX = A.getType().getX();
972 } else {
973 aX = A.getType().getY();
974 aY = A.getType().getX();
975 }
976 }
977 if (B != null) {
978 if (TransB == TRANSPOSE) {
979 bY = B.getType().getY();
980 bX = B.getType().getX();
981 } else {
982 bX = B.getType().getY();
983 bY = B.getType().getX();
984 }
985 }
986 }
987 if (A != null && B != null && C != null) {
988 if (aY != bX || aX != cX || bY != cY) {
989 throw new RSRuntimeException("Called BLAS with invalid dimensions");
990 }
991 } else if (A != null && C != null) {
992 // A and C only
993 if (aX != cY || aY != cX) {
994 throw new RSRuntimeException("Called BLAS with invalid dimensions");
995 }
996 } else if (A != null && B != null) {
997 // A and B only
998 }
999
1000 }
1001
1002 public void SGEMM(@Transpose int TransA, @Transpose int TransB, float alpha, Allocation A,
1003 Allocation B, float beta, Allocation C) {
1004 validateTranspose(TransA);
1005 validateTranspose(TransB);
1006 validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C);
1007
1008 int M = -1, N = -1, K = -1;
1009 if (TransA == TRANSPOSE) {
1010 M = A.getType().getX();
1011 K = A.getType().getY();
1012 } else {
1013 M = A.getType().getY();
1014 K = A.getType().getX();
1015 }
1016 if (TransB == TRANSPOSE) {
1017 N = B.getType().getY();
1018 } else {
1019 N = B.getType().getX();
1020 }
1021 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS),
1022 beta, C.getID(mRS), 0, 0, 0, 0);
1023 }
1024 public void DGEMM(@Transpose int TransA, @Transpose int TransB, double alpha, Allocation A,
1025 Allocation B, double beta, Allocation C) {
1026 validateTranspose(TransA);
1027 validateTranspose(TransB);
1028 validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C);
1029 int M = -1, N = -1, K = -1;
1030 if (TransA == TRANSPOSE) {
1031 M = A.getType().getX();
1032 K = A.getType().getY();
1033 } else {
1034 M = A.getType().getY();
1035 K = A.getType().getX();
1036 }
1037 if (TransB == TRANSPOSE) {
1038 N = B.getType().getY();
1039 } else {
1040 N = B.getType().getX();
1041 }
1042 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS),
1043 beta, C.getID(mRS), 0, 0, 0, 0);
1044 }
1045 public void CGEMM(@Transpose int TransA, @Transpose int TransB, Float2 alpha, Allocation A,
1046 Allocation B, Float2 beta, Allocation C) {
1047 validateTranspose(TransA);
1048 validateTranspose(TransB);
1049 validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C);
1050 int M = -1, N = -1, K = -1;
1051 if (TransA == TRANSPOSE) {
1052 M = A.getType().getX();
1053 K = A.getType().getY();
1054 } else {
1055 M = A.getType().getY();
1056 K = A.getType().getX();
1057 }
1058 if (TransB == TRANSPOSE) {
1059 N = B.getType().getY();
1060 } else {
1061 N = B.getType().getX();
1062 }
1063 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1064 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1065 }
1066
1067 public void ZGEMM(@Transpose int TransA, @Transpose int TransB, Double2 alpha, Allocation A,
1068 Allocation B, Double2 beta, Allocation C) {
1069 validateTranspose(TransA);
1070 validateTranspose(TransB);
1071 validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C);
1072 int M = -1, N = -1, K = -1;
1073 if (TransA == TRANSPOSE) {
1074 M = A.getType().getX();
1075 K = A.getType().getY();
1076 } else {
1077 M = A.getType().getY();
1078 K = A.getType().getX();
1079 }
1080 if (TransB == TRANSPOSE) {
1081 N = B.getType().getY();
1082 } else {
1083 N = B.getType().getX();
1084 }
1085 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1086 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1087 }
1088
1089 public void SSYMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A,
1090 Allocation B, float beta, Allocation C) {
1091 validateSide(Side);
1092 validateUplo(Uplo);
1093 validateL3(Element.F32(mRS), 0, 0, Side, A, B, C);
1094 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
1095 beta, C.getID(mRS), 0, 0, 0, 0);
1096 }
1097 public void DSYMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A,
1098 Allocation B, double beta, Allocation C) {
1099 validateSide(Side);
1100 validateUplo(Uplo);
1101 validateL3(Element.F64(mRS), 0, 0, Side, A, B, C);
1102 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
1103 beta, C.getID(mRS), 0, 0, 0, 0);
1104 }
1105 public void CSYMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A,
1106 Allocation B, Float2 beta, Allocation C) {
1107 validateSide(Side);
1108 validateUplo(Uplo);
1109 validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C);
1110 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1111 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1112 }
1113 public void ZSYMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A,
1114 Allocation B, Double2 beta, Allocation C) {
1115 validateSide(Side);
1116 validateUplo(Uplo);
1117 validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C);
1118 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1119 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1120 }
1121
1122 public void SSYRK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) {
1123 validateTranspose(Trans);
1124 validateUplo(Uplo);
1125 validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C);
1126 int K = -1;
1127 if (Trans == TRANSPOSE) {
1128 K = A.getType().getY();
1129 } else {
1130 K = A.getType().getX();
1131 }
1132
1133 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0);
1134 }
1135
1136 public void DSYRK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) {
1137 validateTranspose(Trans);
1138 validateUplo(Uplo);
1139 validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C);
1140 int K = -1;
1141 if (Trans == TRANSPOSE) {
1142 K = A.getType().getY();
1143 } else {
1144 K = A.getType().getX();
1145 }
1146 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0);
1147 }
Miao Wang4c472742015-04-22 15:57:57 -07001148 public void CSYRK(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Float2 beta, Allocation C) {
Tim Murray25207df2015-01-12 16:47:56 -08001149 validateTranspose(Trans);
1150 validateUplo(Uplo);
1151 validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C);
1152 int K = -1;
1153 if (Trans == TRANSPOSE) {
1154 K = A.getType().getY();
1155 } else {
1156 K = A.getType().getX();
1157 }
Miao Wang4c472742015-04-22 15:57:57 -07001158 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), 0, beta.x, beta.y,
Tim Murray25207df2015-01-12 16:47:56 -08001159 C.getID(mRS), 0, 0, 0, 0);
1160 }
Miao Wang4c472742015-04-22 15:57:57 -07001161 public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Double2 beta, Allocation C) {
Tim Murray25207df2015-01-12 16:47:56 -08001162 validateTranspose(Trans);
1163 validateUplo(Uplo);
1164 validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C);
1165 int K = -1;
1166 if (Trans == TRANSPOSE) {
1167 K = A.getType().getY();
1168 } else {
1169 K = A.getType().getX();
1170 }
Miao Wang4c472742015-04-22 15:57:57 -07001171 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), 0, beta.x, beta.y,
Tim Murray25207df2015-01-12 16:47:56 -08001172 C.getID(mRS), 0, 0, 0, 0);
1173 }
1174
1175 static void validateSYR2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) {
1176 validateTranspose(Trans);
1177 if (!A.getType().getElement().isCompatible(e) ||
1178 !B.getType().getElement().isCompatible(e) ||
1179 !C.getType().getElement().isCompatible(e)) {
1180 throw new RSRuntimeException("Called BLAS with wrong Element type");
1181 }
1182 int Cdim = -1;
1183 // A is n x k if no transpose, k x n if transpose
1184 // C is n x n
1185 if (Trans == TRANSPOSE) {
1186 // check columns versus C
1187 Cdim = A.getType().getX();
1188 } else {
1189 // check rows versus C
1190 Cdim = A.getType().getY();
1191 }
1192 if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) {
1193 throw new RSRuntimeException("Invalid symmetric matrix in SYR2K");
1194 }
1195 // A dims == B dims
1196 if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) {
1197 throw new RSRuntimeException("Invalid A and B in SYR2K");
1198 }
1199 }
1200 public void SSYR2K(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, Allocation B, float beta, Allocation C) {
1201 validateUplo(Uplo);
1202 validateSYR2K(Element.F32(mRS), Trans, A, B, C);
1203 int K = -1;
1204 if (Trans == TRANSPOSE) {
1205 K = A.getType().getY();
1206 } else {
1207 K = A.getType().getX();
1208 }
1209 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0);
1210 }
1211 public void DSYR2K(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, Allocation B, double beta, Allocation C) {
1212 validateUplo(Uplo);
1213 validateSYR2K(Element.F64(mRS), Trans, A, B, C);
1214 int K = -1;
1215 if (Trans == TRANSPOSE) {
1216 K = A.getType().getY();
1217 } else {
1218 K = A.getType().getX();
1219 }
1220 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0);
1221 }
1222 public void CSYR2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) {
1223 validateUplo(Uplo);
1224 validateSYR2K(Element.F32_2(mRS), Trans, A, B, C);
1225 int K = -1;
1226 if (Trans == TRANSPOSE) {
1227 K = A.getType().getY();
1228 } else {
1229 K = A.getType().getX();
1230 }
1231 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1232 }
1233 public void ZSYR2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) {
1234 validateUplo(Uplo);
1235 validateSYR2K(Element.F64_2(mRS), Trans, A, B, C);
1236 int K = -1;
1237 if (Trans == TRANSPOSE) {
1238 K = A.getType().getY();
1239 } else {
1240 K = A.getType().getX();
1241 }
1242 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1243 }
1244
1245 static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
1246 validateSide(Side);
1247 validateTranspose(TransA);
1248 int aX = -1, aY = -1, bX = -1, bY = -1;
1249 if (!A.getType().getElement().isCompatible(e) ||
1250 !B.getType().getElement().isCompatible(e)) {
1251 throw new RSRuntimeException("Called BLAS with wrong Element type");
1252 }
1253 if (TransA == TRANSPOSE) {
1254 aY = A.getType().getY();
1255 aX = A.getType().getX();
1256 } else {
1257 aY = A.getType().getX();
1258 aX = A.getType().getY();
1259 }
1260 bX = B.getType().getY();
1261 bY = B.getType().getX();
1262 if (Side == LEFT) {
1263 if (aX == 0 || aY != bX) {
1264 throw new RSRuntimeException("Called TRMM with invalid matrices");
1265 }
1266 } else {
1267 if (bY != aX || aY == 0) {
1268 throw new RSRuntimeException("Called TRMM with invalid matrices");
1269 }
1270 }
1271 }
1272 public void STRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) {
1273 validateUplo(Uplo);
1274 validateDiag(Diag);
1275 validateTRMM(Element.F32(mRS), Side, TransA, A, B);
1276 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1277 alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0);
1278 }
1279 public void DTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) {
1280 validateUplo(Uplo);
1281 validateDiag(Diag);
1282 validateTRMM(Element.F64(mRS), Side, TransA, A, B);
1283 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1284 alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0);
1285 }
1286 public void CTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) {
1287 validateUplo(Uplo);
1288 validateDiag(Diag);
1289 validateTRMM(Element.F32_2(mRS), Side, TransA, A, B);
1290 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1291 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1292 }
1293 public void ZTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) {
1294 validateUplo(Uplo);
1295 validateDiag(Diag);
1296 validateTRMM(Element.F64_2(mRS), Side, TransA, A, B);
1297 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1298 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1299 }
1300
1301 static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
1302 int adim = -1, bX = -1, bY = -1;
1303 validateSide(Side);
1304 validateTranspose(TransA);
1305 if (!A.getType().getElement().isCompatible(e) ||
1306 !B.getType().getElement().isCompatible(e)) {
1307 throw new RSRuntimeException("Called BLAS with wrong Element type");
1308 }
1309 adim = A.getType().getX();
1310 if (adim != A.getType().getY()) {
1311 // this may be unnecessary, the restriction could potentially be relaxed
1312 // A needs to contain at least that symmetric matrix but could theoretically be larger
1313 // for now we assume adapters are sufficient, will reevaluate in the future
1314 throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A");
1315 }
1316 bX = B.getType().getY();
1317 bY = B.getType().getX();
1318 if (Side == LEFT) {
1319 // A is M*M
1320 if (adim != bY) {
1321 throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
1322 }
1323 } else {
1324 // A is N*N
1325 if (adim != bX) {
1326 throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
1327 }
1328 }
1329 }
1330 public void STRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) {
1331 validateUplo(Uplo);
1332 validateDiag(Diag);
1333 validateTRSM(Element.F32(mRS), Side, TransA, A, B);
1334 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1335 alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0);
1336 }
1337 public void DTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) {
1338 validateUplo(Uplo);
1339 validateDiag(Diag);
1340 validateTRSM(Element.F64(mRS), Side, TransA, A, B);
1341 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1342 alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0);
1343 }
1344 public void CTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) {
1345 validateUplo(Uplo);
1346 validateDiag(Diag);
1347 validateTRSM(Element.F32_2(mRS), Side, TransA, A, B);
1348 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1349 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1350 }
1351 public void ZTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) {
1352 validateUplo(Uplo);
1353 validateDiag(Diag);
1354 validateTRSM(Element.F64_2(mRS), Side, TransA, A, B);
1355 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1356 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1357 }
1358
1359 static void validateHEMM(Element e, @Side int Side, Allocation A, Allocation B, Allocation C) {
1360 validateSide(Side);
1361
1362 if (!A.getType().getElement().isCompatible(e) ||
1363 !B.getType().getElement().isCompatible(e) ||
1364 !C.getType().getElement().isCompatible(e)) {
1365 throw new RSRuntimeException("Called BLAS with wrong Element type");
1366 }
1367
1368 // A must be square; can potentially be relaxed similar to TRSM
1369 int adim = A.getType().getX();
1370 if (adim != A.getType().getY()) {
1371 throw new RSRuntimeException("Called HEMM with non-square A");
1372 }
1373 if ((Side == LEFT && adim != B.getType().getY()) ||
1374 (Side == RIGHT && adim != B.getType().getX())) {
1375 throw new RSRuntimeException("Called HEMM with invalid B");
1376 }
1377 if (B.getType().getX() != C.getType().getX() ||
1378 B.getType().getY() != C.getType().getY()) {
1379 throw new RSRuntimeException("Called HEMM with mismatched B and C");
1380 }
1381 }
Miao Wang4c472742015-04-22 15:57:57 -07001382 public void CHEMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) {
Tim Murray25207df2015-01-12 16:47:56 -08001383 validateUplo(Uplo);
1384 validateHEMM(Element.F32_2(mRS), Side, A, B, C);
1385 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
Miao Wang4c472742015-04-22 15:57:57 -07001386 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
Tim Murray25207df2015-01-12 16:47:56 -08001387 }
Miao Wang4c472742015-04-22 15:57:57 -07001388 public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) {
Tim Murray25207df2015-01-12 16:47:56 -08001389 validateUplo(Uplo);
1390 validateHEMM(Element.F32_2(mRS), Side, A, B, C);
1391 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
Miao Wang4c472742015-04-22 15:57:57 -07001392 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
Tim Murray25207df2015-01-12 16:47:56 -08001393 }
1394
1395 static void validateHERK(Element e, @Transpose int Trans, Allocation A, Allocation C) {
1396 if (!A.getType().getElement().isCompatible(e) ||
1397 !C.getType().getElement().isCompatible(e)) {
1398 throw new RSRuntimeException("Called BLAS with wrong Element type");
1399 }
1400 validateConjTranspose(Trans);
1401 int cdim = C.getType().getX();
1402 if (cdim != C.getType().getY()) {
1403 throw new RSRuntimeException("Called HERK with non-square C");
1404 }
1405 if (Trans == NO_TRANSPOSE) {
1406 if (cdim != A.getType().getX()) {
1407 throw new RSRuntimeException("Called HERK with invalid A");
1408 }
1409 } else {
1410 if (cdim != A.getType().getY()) {
1411 throw new RSRuntimeException("Called HERK with invalid A");
1412 }
1413 }
1414 }
1415 public void CHERK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) {
1416 validateUplo(Uplo);
1417 validateHERK(Element.F32_2(mRS), Trans, A, C);
1418 int k = 0;
1419 if (Trans == TRANSPOSE) {
1420 k = A.getType().getY();
1421 } else {
1422 k = A.getType().getX();
1423 }
1424 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k,
1425 alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0);
1426 }
1427 public void ZHERK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) {
1428 validateUplo(Uplo);
1429 validateHERK(Element.F64_2(mRS), Trans, A, C);
1430 int k = 0;
1431 if (Trans == TRANSPOSE) {
1432 k = A.getType().getY();
1433 } else {
1434 k = A.getType().getX();
1435 }
1436 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k,
1437 alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0);
1438 }
1439
1440 static void validateHER2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) {
1441 if (!A.getType().getElement().isCompatible(e) ||
1442 !B.getType().getElement().isCompatible(e) ||
1443 !C.getType().getElement().isCompatible(e)) {
1444 throw new RSRuntimeException("Called BLAS with wrong Element type");
1445 }
1446 validateConjTranspose(Trans);
1447 int cdim = C.getType().getX();
1448 if (cdim != C.getType().getY()) {
1449 throw new RSRuntimeException("Called HER2K with non-square C");
1450 }
1451 if (Trans == NO_TRANSPOSE) {
1452 if (A.getType().getY() != cdim) {
1453 throw new RSRuntimeException("Called HER2K with invalid matrices");
1454 }
1455 } else {
1456 if (A.getType().getX() != cdim) {
1457 throw new RSRuntimeException("Called HER2K with invalid matrices");
1458 }
1459 }
1460 if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) {
1461 throw new RSRuntimeException("Called HER2K with invalid A and B matrices");
1462 }
1463 }
1464 public void CHER2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, float beta, Allocation C) {
1465 validateUplo(Uplo);
1466 validateHER2K(Element.F32_2(mRS), Trans, A, B, C);
1467 int k = 0;
1468 if (Trans == NO_TRANSPOSE) {
1469 k = A.getType().getX();
1470 } else {
1471 k = A.getType().getY();
1472 }
1473 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y,
1474 A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
1475 }
1476 public void ZHER2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, double beta, Allocation C) {
1477 validateUplo(Uplo);
1478 validateHER2K(Element.F64_2(mRS), Trans, A, B, C);
1479 int k = 0;
1480 if (Trans == NO_TRANSPOSE) {
1481 k = A.getType().getX();
1482 } else {
1483 k = A.getType().getY();
1484 }
1485 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y,
1486 A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
1487 }
1488
1489
Tim Murray9cb16a22015-04-01 11:07:16 -07001490 /**
1491 *
1492 * 8-bit GEMM-like operation for neural networks
1493 *
Tim Murray9cb16a22015-04-01 11:07:16 -07001494 **/
1495 public void BNNM(Allocation A, int a_offset, Allocation B, int b_offset, Allocation C, int c_offset, int c_mult) {
1496 validateL3(Element.U8(mRS), NO_TRANSPOSE, TRANSPOSE, 0, A, B, C);
1497
1498 int M = -1, N = -1, K = -1;
1499 M = A.getType().getY();
1500 N = B.getType().getY();
1501 K = A.getType().getX();
1502
1503
1504 mRS.nScriptIntrinsicBLAS_BNNM(getID(mRS), M, N, K, A.getID(mRS), a_offset, B.getID(mRS), b_offset, C.getID(mRS), c_offset, c_mult);
1505
1506 }
Tim Murray25207df2015-01-12 16:47:56 -08001507
1508}