blob: 90d2300e553f9733318d21dd6018bb17ce844b32 [file] [log] [blame]
Tim Murray3d8215a2015-02-18 00:44:08 +00001/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.renderscript;
18
19import android.annotation.IntDef;
20import java.lang.annotation.Retention;
21import java.lang.annotation.RetentionPolicy;
22
/**
 * Exposes BLAS (Basic Linear Algebra Subprograms) routines as a RenderScript intrinsic.
 *
 * @hide
 **/
29public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
30 private Allocation mLUT;
31
32 private ScriptIntrinsicBLAS(long id, RenderScript rs) {
33 super(id, rs);
34 }
35
36 private static final int RsBlas_sdsdot = 1;
37 private static final int RsBlas_dsdot = 2;
38 private static final int RsBlas_sdot = 3;
39 private static final int RsBlas_ddot = 4;
40 private static final int RsBlas_cdotu_sub = 5;
41 private static final int RsBlas_cdotc_sub = 6;
42 private static final int RsBlas_zdotu_sub = 7;
43 private static final int RsBlas_zdotc_sub = 8;
44 private static final int RsBlas_snrm2 = 9;
45 private static final int RsBlas_sasum = 10;
46 private static final int RsBlas_dnrm2 = 11;
47 private static final int RsBlas_dasum = 12;
48 private static final int RsBlas_scnrm2 = 13;
49 private static final int RsBlas_scasum = 14;
50 private static final int RsBlas_dznrm2 = 15;
51 private static final int RsBlas_dzasum = 16;
52 private static final int RsBlas_isamax = 17;
53 private static final int RsBlas_idamax = 18;
54 private static final int RsBlas_icamax = 19;
55 private static final int RsBlas_izamax = 20;
56 private static final int RsBlas_sswap = 21;
57 private static final int RsBlas_scopy = 22;
58 private static final int RsBlas_saxpy = 23;
59 private static final int RsBlas_dswap = 24;
60 private static final int RsBlas_dcopy = 25;
61 private static final int RsBlas_daxpy = 26;
62 private static final int RsBlas_cswap = 27;
63 private static final int RsBlas_ccopy = 28;
64 private static final int RsBlas_caxpy = 29;
65 private static final int RsBlas_zswap = 30;
66 private static final int RsBlas_zcopy = 31;
67 private static final int RsBlas_zaxpy = 32;
68 private static final int RsBlas_srotg = 33;
69 private static final int RsBlas_srotmg = 34;
70 private static final int RsBlas_srot = 35;
71 private static final int RsBlas_srotm = 36;
72 private static final int RsBlas_drotg = 37;
73 private static final int RsBlas_drotmg = 38;
74 private static final int RsBlas_drot = 39;
75 private static final int RsBlas_drotm = 40;
76 private static final int RsBlas_sscal = 41;
77 private static final int RsBlas_dscal = 42;
78 private static final int RsBlas_cscal = 43;
79 private static final int RsBlas_zscal = 44;
80 private static final int RsBlas_csscal = 45;
81 private static final int RsBlas_zdscal = 46;
82 private static final int RsBlas_sgemv = 47;
83 private static final int RsBlas_sgbmv = 48;
84 private static final int RsBlas_strmv = 49;
85 private static final int RsBlas_stbmv = 50;
86 private static final int RsBlas_stpmv = 51;
87 private static final int RsBlas_strsv = 52;
88 private static final int RsBlas_stbsv = 53;
89 private static final int RsBlas_stpsv = 54;
90 private static final int RsBlas_dgemv = 55;
91 private static final int RsBlas_dgbmv = 56;
92 private static final int RsBlas_dtrmv = 57;
93 private static final int RsBlas_dtbmv = 58;
94 private static final int RsBlas_dtpmv = 59;
95 private static final int RsBlas_dtrsv = 60;
96 private static final int RsBlas_dtbsv = 61;
97 private static final int RsBlas_dtpsv = 62;
98 private static final int RsBlas_cgemv = 63;
99 private static final int RsBlas_cgbmv = 64;
100 private static final int RsBlas_ctrmv = 65;
101 private static final int RsBlas_ctbmv = 66;
102 private static final int RsBlas_ctpmv = 67;
103 private static final int RsBlas_ctrsv = 68;
104 private static final int RsBlas_ctbsv = 69;
105 private static final int RsBlas_ctpsv = 70;
106 private static final int RsBlas_zgemv = 71;
107 private static final int RsBlas_zgbmv = 72;
108 private static final int RsBlas_ztrmv = 73;
109 private static final int RsBlas_ztbmv = 74;
110 private static final int RsBlas_ztpmv = 75;
111 private static final int RsBlas_ztrsv = 76;
112 private static final int RsBlas_ztbsv = 77;
113 private static final int RsBlas_ztpsv = 78;
114 private static final int RsBlas_ssymv = 79;
115 private static final int RsBlas_ssbmv = 80;
116 private static final int RsBlas_sspmv = 81;
117 private static final int RsBlas_sger = 82;
118 private static final int RsBlas_ssyr = 83;
119 private static final int RsBlas_sspr = 84;
120 private static final int RsBlas_ssyr2 = 85;
121 private static final int RsBlas_sspr2 = 86;
122 private static final int RsBlas_dsymv = 87;
123 private static final int RsBlas_dsbmv = 88;
124 private static final int RsBlas_dspmv = 89;
125 private static final int RsBlas_dger = 90;
126 private static final int RsBlas_dsyr = 91;
127 private static final int RsBlas_dspr = 92;
128 private static final int RsBlas_dsyr2 = 93;
129 private static final int RsBlas_dspr2 = 94;
130 private static final int RsBlas_chemv = 95;
131 private static final int RsBlas_chbmv = 96;
132 private static final int RsBlas_chpmv = 97;
133 private static final int RsBlas_cgeru = 98;
134 private static final int RsBlas_cgerc = 99;
135 private static final int RsBlas_cher = 100;
136 private static final int RsBlas_chpr = 101;
137 private static final int RsBlas_cher2 = 102;
138 private static final int RsBlas_chpr2 = 103;
139 private static final int RsBlas_zhemv = 104;
140 private static final int RsBlas_zhbmv = 105;
141 private static final int RsBlas_zhpmv = 106;
142 private static final int RsBlas_zgeru = 107;
143 private static final int RsBlas_zgerc = 108;
144 private static final int RsBlas_zher = 109;
145 private static final int RsBlas_zhpr = 110;
146 private static final int RsBlas_zher2 = 111;
147 private static final int RsBlas_zhpr2 = 112;
148 private static final int RsBlas_sgemm = 113;
149 private static final int RsBlas_ssymm = 114;
150 private static final int RsBlas_ssyrk = 115;
151 private static final int RsBlas_ssyr2k = 116;
152 private static final int RsBlas_strmm = 117;
153 private static final int RsBlas_strsm = 118;
154 private static final int RsBlas_dgemm = 119;
155 private static final int RsBlas_dsymm = 120;
156 private static final int RsBlas_dsyrk = 121;
157 private static final int RsBlas_dsyr2k = 122;
158 private static final int RsBlas_dtrmm = 123;
159 private static final int RsBlas_dtrsm = 124;
160 private static final int RsBlas_cgemm = 125;
161 private static final int RsBlas_csymm = 126;
162 private static final int RsBlas_csyrk = 127;
163 private static final int RsBlas_csyr2k = 128;
164 private static final int RsBlas_ctrmm = 129;
165 private static final int RsBlas_ctrsm = 130;
166 private static final int RsBlas_zgemm = 131;
167 private static final int RsBlas_zsymm = 132;
168 private static final int RsBlas_zsyrk = 133;
169 private static final int RsBlas_zsyr2k = 134;
170 private static final int RsBlas_ztrmm = 135;
171 private static final int RsBlas_ztrsm = 136;
172 private static final int RsBlas_chemm = 137;
173 private static final int RsBlas_cherk = 138;
174 private static final int RsBlas_cher2k = 139;
175 private static final int RsBlas_zhemm = 140;
176 private static final int RsBlas_zherk = 141;
177 private static final int RsBlas_zher2k = 142;
178
179 /**
180 */
181 public static ScriptIntrinsicBLAS create(RenderScript rs) {
182 long id = rs.nScriptIntrinsicCreate(13, Element.U32(rs).getID(rs));
183 return new ScriptIntrinsicBLAS(id, rs);
184 }
185
186 @IntDef({NO_TRANSPOSE, TRANSPOSE, CONJ_TRANSPOSE})
187 @Retention(RetentionPolicy.SOURCE)
188 public @interface Transpose {}
189
190 @IntDef({UPPER, LOWER})
191 @Retention(RetentionPolicy.SOURCE)
192 public @interface Uplo {}
193
194 @IntDef({NON_UNIT, UNIT})
195 @Retention(RetentionPolicy.SOURCE)
196 public @interface Diag {}
197
198 @IntDef({LEFT, RIGHT})
199 @Retention(RetentionPolicy.SOURCE)
200 public @interface Side {}
201
202 public static final int NO_TRANSPOSE = 111;
203 public static final int TRANSPOSE = 112;
204 public static final int CONJ_TRANSPOSE = 113;
205
206 public static final int UPPER = 121;
207 public static final int LOWER = 122;
208
209 public static final int NON_UNIT = 131;
210 public static final int UNIT = 132;
211
212 public static final int LEFT = 141;
213 public static final int RIGHT = 142;
214
215 static void validateSide(@Side int Side) {
216 if (Side != LEFT && Side != RIGHT) {
217 throw new RSRuntimeException("Invalid side passed to BLAS");
218 }
219 }
220
221 static void validateTranspose(@Transpose int Trans) {
222 if (Trans != NO_TRANSPOSE && Trans != TRANSPOSE &&
223 Trans != CONJ_TRANSPOSE) {
224 throw new RSRuntimeException("Invalid transpose passed to BLAS");
225 }
226 }
227
228 static void validateConjTranspose(@Transpose int Trans) {
229 if (Trans != NO_TRANSPOSE &&
230 Trans != CONJ_TRANSPOSE) {
231 throw new RSRuntimeException("Invalid transpose passed to BLAS");
232 }
233 }
234
235 static void validateDiag(@Diag int Diag) {
236 if (Diag != NON_UNIT && Diag != UNIT) {
237 throw new RSRuntimeException("Invalid diag passed to BLAS");
238 }
239 }
240
241 static void validateUplo(@Uplo int Uplo) {
242 if (Uplo != LEFT && Uplo != RIGHT) {
243 throw new RSRuntimeException("Invalid uplo passed to BLAS");
244 }
245 }
246
247
248 /**
249 * Level 2 BLAS
250 */
251
252 static void validateGEMV(Element e, int TransA, Allocation A, Allocation X, int incX, Allocation Y, int incY) {
253 validateTranspose(TransA);
254 int M = A.getType().getY();
255 int N = A.getType().getX();
256 if (!A.getType().getElement().isCompatible(e) ||
257 !X.getType().getElement().isCompatible(e) ||
258 !Y.getType().getElement().isCompatible(e)) {
259 throw new RSRuntimeException("Called BLAS with wrong Element type");
260 }
261 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
262 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
263 }
264
265 if (incX <= 0 || incY <= 0) {
266 throw new RSRuntimeException("Vector increments must be greater than 0");
267 }
268 int expectedXDim = -1, expectedYDim = -1;
269 if (TransA == NO_TRANSPOSE) {
270 expectedXDim = 1 + (N - 1) * incX;
271 expectedYDim = 1 + (M - 1) * incY;
272 } else {
273 expectedXDim = 1 + (M - 1) * incX;
274 expectedYDim = 1 + (N - 1) * incY;
275 }
276 if (X.getType().getX() != expectedXDim ||
277 Y.getType().getY() != expectedXDim) {
278 throw new RSRuntimeException("Incorrect vector dimensions for GEMV");
279 }
280 }
281 void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
282 validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY);
283 int M = A.getType().getY();
284 int N = A.getType().getX();
285 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
286 }
287 void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
288 validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY);
289 int M = A.getType().getY();
290 int N = A.getType().getX();
291 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
292 }
293 void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
294 validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY);
295 int M = A.getType().getY();
296 int N = A.getType().getX();
297 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
298 }
299 void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
300 validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY);
301 int M = A.getType().getY();
302 int N = A.getType().getX();
303 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
304 }
305
306 void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
307 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
308 validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY);
309 if (KL < 0 || KU < 0) {
310 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
311 }
312 int M = A.getType().getY();
313 int N = A.getType().getX();
314 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU);
315 }
316 void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
317 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
318 validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY);
319 if (KL < 0 || KU < 0) {
320 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
321 }
322 int M = A.getType().getY();
323 int N = A.getType().getX();
324 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU);
325 }
326 void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
327 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
328 validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY);
329 if (KL < 0 || KU < 0) {
330 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
331 }
332 int M = A.getType().getY();
333 int N = A.getType().getX();
334 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
335 }
336 void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
337 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
338 validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY);
339 if (KL < 0 || KU < 0) {
340 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
341 }
342 int M = A.getType().getY();
343 int N = A.getType().getX();
344 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
345 }
346
347 static void validateTRMV(Element e, @Transpose int TransA, Allocation A, Allocation X, int incX) {
348 validateTranspose(TransA);
349 int N = A.getType().getY();
350 if (A.getType().getX() != N) {
351 throw new RSRuntimeException("A must be a square matrix for TRMV");
352 }
353 if (!A.getType().getElement().isCompatible(e) ||
354 !X.getType().getElement().isCompatible(e)) {
355 throw new RSRuntimeException("Called BLAS with wrong Element type");
356 }
357 if (X.getType().getY() > 1) {
358 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
359 }
360
361 if (incX <= 0) {
362 throw new RSRuntimeException("Vector increments must be greater than 0");
363 }
364 int expectedXDim = 1 + (N - 1) * incX;
365 if (X.getType().getX() != expectedXDim) {
366 throw new RSRuntimeException("Incorrect vector dimensions for TRMV");
367 }
368 }
369
370 static int validateTPMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
371 validateTranspose(TransA);
372 validateUplo(Uplo);
373 validateDiag(Diag);
374 if (!Ap.getType().getElement().isCompatible(e) ||
375 !X.getType().getElement().isCompatible(e)) {
376 throw new RSRuntimeException("Called BLAS with wrong Element type");
377 }
378 if (X.getType().getY() > 1) {
379 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
380 }
381
382 if (Ap.getType().getY() > 1) {
383 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
384 }
385
386 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
387 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
388 throw new RSRuntimeException("Invalid dimension for Ap");
389 }
390
391 int expectedXDim = 1 + (N - 1) * incX;
392 if (X.getType().getX() != expectedXDim) {
393 throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
394 }
395
396 return N;
397 }
398
399 void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
400 validateTRMV(Element.F32(mRS), TransA, A, X, incX);
401 int N = A.getType().getY();
402 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
403 }
404 void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
405 validateTRMV(Element.F64(mRS), TransA, A, X, incX);
406 int N = A.getType().getY();
407 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
408 }
409 void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
410 validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
411 int N = A.getType().getY();
412 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
413 }
414 void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
415 validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
416 int N = A.getType().getY();
417 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
418 }
419 void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
420 // TBMV has the same requirements as TRMV
421 validateTRMV(Element.F32(mRS), TransA, A, X, incX);
422 int N = A.getType().getY();
423 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
424 }
425 void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
426 // TBMV has the same requirements as TRMV
427 validateTRMV(Element.F64(mRS), TransA, A, X, incX);
428 int N = A.getType().getY();
429 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
430 }
431 void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
432 // TBMV has the same requirements as TRMV
433 validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
434 int N = A.getType().getY();
435 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
436 }
437 void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
438 // TBMV has the same requirements as TRMV
439 validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
440 int N = A.getType().getY();
441 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
442 }
443 void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
444 int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
445 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
446 }
447 void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
448 int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
449 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
450 }
451 void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
452 int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
453 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
454 }
455 void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
456 int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
457 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
458 }
459 void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
460 // TRSV is the same as TRMV
461 validateTRMV(Element.F32(mRS), TransA, A, X, incX);
462 int N = A.getType().getY();
463 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
464
465 }
466 void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
467 // TRSV is the same as TRMV
468 validateTRMV(Element.F64(mRS), TransA, A, X, incX);
469 int N = A.getType().getY();
470 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
471
472 }
473 void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
474 // TRSV is the same as TRMV
475 validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
476 int N = A.getType().getY();
477 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
478
479 }
480 void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
481 // TRSV is the same as TRMV
482 validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
483 int N = A.getType().getY();
484 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
485
486 }
487 void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
488 // TBSV is the same as TRMV
489 validateTRMV(Element.F32(mRS), TransA, A, X, incX);
490 int N = A.getType().getY();
491 if (K < 0) {
492 throw new RSRuntimeException("Number of diagonals must be positive");
493 }
494 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
495 }
496 void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
497 // TBSV is the same as TRMV
498 validateTRMV(Element.F64(mRS), TransA, A, X, incX);
499 int N = A.getType().getY();
500 if (K < 0) {
501 throw new RSRuntimeException("Number of diagonals must be positive");
502 }
503 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
504 }
505 void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
506 // TBSV is the same as TRMV
507 validateTRMV(Element.F32_2(mRS), TransA, A, X, incX);
508 int N = A.getType().getY();
509 if (K < 0) {
510 throw new RSRuntimeException("Number of diagonals must be positive");
511 }
512 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
513 }
514 void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
515 // TBSV is the same as TRMV
516 validateTRMV(Element.F64_2(mRS), TransA, A, X, incX);
517 int N = A.getType().getY();
518 if (K < 0) {
519 throw new RSRuntimeException("Number of diagonals must be positive");
520 }
521 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
522 }
523 void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
524 // TPSV is same as TPMV
525 int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
526 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
527 }
528 void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
529 // TPSV is same as TPMV
530 int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
531 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
532 }
533 void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
534 // TPSV is same as TPMV
535 int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
536 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
537 }
538 void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
539 // TPSV is same as TPMV
540 int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
541 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
542 }
543
544 /**
545 * Level 2, S and D only
546 */
547 static int validateSYMV(Element e, @Uplo int Uplo, Allocation A, Allocation X, Allocation Y, int incX, int incY) {
548 validateUplo(Uplo);
549 int N = A.getType().getY();
550 if (A.getType().getX() != N) {
551 throw new RSRuntimeException("A must be a square matrix for SYMV");
552 }
553 if (!A.getType().getElement().isCompatible(e) ||
554 !X.getType().getElement().isCompatible(e) ||
555 !Y.getType().getElement().isCompatible(e) ) {
556 throw new RSRuntimeException("Called BLAS with wrong Element type");
557 }
558 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
559 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
560 }
561
562 if (incX <= 0 || incY <= 0) {
563 throw new RSRuntimeException("Vector increments must be greater than 0");
564 }
565 int expectedXDim = 1 + (N - 1) * incX;
566 if (X.getType().getX() != expectedXDim) {
567 throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
568 }
569 int expectedYDim = 1 + (N - 1) * incY;
570 if (Y.getType().getX() != expectedYDim) {
571 throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
572 }
573 return N;
574 }
575 static int validateSPMV(Element e, @Uplo int Uplo, Allocation Ap, Allocation X, int incX, Allocation Y, int incY) {
576 validateUplo(Uplo);
577 if (!Ap.getType().getElement().isCompatible(e) ||
578 !X.getType().getElement().isCompatible(e) ||
579 !Y.getType().getElement().isCompatible(e)) {
580 throw new RSRuntimeException("Called BLAS with wrong Element type");
581 }
582 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
583 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
584 }
585
586 if (Ap.getType().getY() > 1) {
587 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
588 }
589
590 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
591 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
592 throw new RSRuntimeException("Invalid dimension for Ap");
593 }
594
595 int expectedXDim = 1 + (N - 1) * incX;
596 if (X.getType().getX() != expectedXDim) {
597 throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
598 }
599 int expectedYDim = 1 + (N - 1) * incY;
600 if (Y.getType().getX() != expectedYDim) {
601 throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
602 }
603
604 return N;
605 }
606 static void validateGER(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
607 if (!A.getType().getElement().isCompatible(e) ||
608 !X.getType().getElement().isCompatible(e) ||
609 !Y.getType().getElement().isCompatible(e) ) {
610 throw new RSRuntimeException("Called BLAS with wrong Element type");
611 }
612
613 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
614 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
615 }
616
617 int M = A.getType().getY();
618 int N = A.getType().getX();
619
620 if (N < 1 || M < 1) {
621 throw new RSRuntimeException("M and N must be 1 or greater for GER");
622 }
623
624 int expectedXDim = 1 + (N - 1) * incX;
625 if (X.getType().getX() != expectedXDim) {
626 throw new RSRuntimeException("Incorrect vector dimensions for GER");
627 }
628 int expectedYDim = 1 + (N - 1) * incY;
629 if (Y.getType().getX() != expectedYDim) {
630 throw new RSRuntimeException("Incorrect vector dimensions for GER");
631 }
632
633
634 }
635 static int validateSYR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation A) {
636 validateUplo(Uplo);
637 if (!A.getType().getElement().isCompatible(e) ||
638 !X.getType().getElement().isCompatible(e)) {
639 throw new RSRuntimeException("Called BLAS with wrong Element type");
640 }
641
642 int N = A.getType().getX();
643
644 if (X.getType().getY() > 1) {
645 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
646 }
647 if (N != A.getType().getY()) {
648 throw new RSRuntimeException("A must be a symmetric matrix");
649 }
650
651 int expectedXDim = 1 + (N - 1) * incX;
652 if (X.getType().getX() != expectedXDim) {
653 throw new RSRuntimeException("Incorrect vector dimensions for SYR");
654 }
655 return N;
656 }
657 static int validateSPR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Ap) {
658 validateUplo(Uplo);
659 if (!Ap.getType().getElement().isCompatible(e) ||
660 !X.getType().getElement().isCompatible(e)) {
661 throw new RSRuntimeException("Called BLAS with wrong Element type");
662 }
663 if (X.getType().getY() > 1) {
664 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
665 }
666
667 if (Ap.getType().getY() > 1) {
668 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
669 }
670
671 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
672 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
673 throw new RSRuntimeException("Invalid dimension for Ap");
674 }
675
676 int expectedXDim = 1 + (N - 1) * incX;
677 if (X.getType().getX() != expectedXDim) {
678 throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
679 }
680
681 return N;
682 }
683
684 static int validateSYR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
685 validateUplo(Uplo);
686 if (!A.getType().getElement().isCompatible(e) ||
687 !X.getType().getElement().isCompatible(e) ||
688 !Y.getType().getElement().isCompatible(e)) {
689 throw new RSRuntimeException("Called BLAS with wrong Element type");
690 }
691
692 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
693 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
694 }
695
696 int N = A.getType().getX();
697
698 if (N != A.getType().getY()) {
699 throw new RSRuntimeException("A must be a symmetric matrix");
700 }
701
702 int expectedXDim = 1 + (N - 1) * incX;
703 int expectedYDim = 1 + (N - 1) * incY;
704 if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
705 throw new RSRuntimeException("Incorrect vector dimensions for SYR");
706 }
707 return N;
708
709 }
710 static int validateSPR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
711 validateUplo(Uplo);
712 if (!Ap.getType().getElement().isCompatible(e) ||
713 !X.getType().getElement().isCompatible(e) ||
714 !Y.getType().getElement().isCompatible(e)) {
715 throw new RSRuntimeException("Called BLAS with wrong Element type");
716 }
717 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
718 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
719 }
720
721 if (Ap.getType().getY() > 1) {
722 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
723 }
724
725 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
726 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
727 throw new RSRuntimeException("Invalid dimension for Ap");
728 }
729
730 int expectedXDim = 1 + (N - 1) * incX;
731 int expectedYDim = 1 + (N - 1) * incY;
732 if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
733 throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
734 }
735
736 return N;
737 }
738
739 void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
740 int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
741 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
742 }
743 void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
744 // SBMV is the same as SYMV
745 int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
746 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
747 }
748 void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) {
749 int N = validateSPMV(Element.F32(mRS), Uplo, Ap, X, incX, Y, incY);
750 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
751 }
752 void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
753 int M = A.getType().getY();
754 int N = A.getType().getX();
755 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
756 }
757 void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
758 int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A);
759 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0);
760 }
761 void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
762 int N = validateSPR(Element.F32(mRS), Uplo, X, incX, Ap);
763 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0);
764 }
765 void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
766 int N = validateSYR2(Element.F32(mRS), Uplo, X, incX, Y, incY, A);
767 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0);
768 }
769 void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
770 int N = validateSPR2(Element.F32(mRS), Uplo, X, incX, Y, incY, Ap);
771 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0);
772 }
773 void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
774 int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
775 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
776 }
777 void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
778 // SBMV is the same as SYMV
779 int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
780 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
781 }
782 void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) {
783 int N = validateSPMV(Element.F64(mRS), Uplo, Ap, X, incX, Y, incY);
784 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
785 }
786 void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
787 int M = A.getType().getY();
788 int N = A.getType().getX();
789 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
790 }
791 void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
792 int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A);
793 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0);
794 }
795 void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
796 int N = validateSPR(Element.F64(mRS), Uplo, X, incX, Ap);
797 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0);
798 }
799 void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
800 int N = validateSYR2(Element.F64(mRS), Uplo, X, incX, Y, incY, A);
801 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0);
802 }
803 void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
804 int N = validateSPR2(Element.F64(mRS), Uplo, X, incX, Y, incY, Ap);
805 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0);
806 }
807
808
809 /**
810 * Level 2, C and Z only
811 */
812
813 static void validateGERU(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
814 if (!A.getType().getElement().isCompatible(e) ||
815 !X.getType().getElement().isCompatible(e) ||
816 !Y.getType().getElement().isCompatible(e)) {
817 throw new RSRuntimeException("Called BLAS with wrong Element type");
818 }
819 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
820 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
821 }
822
823 int M = A.getType().getY();
824 int N = A.getType().getX();
825
826 int expectedXDim = 1 + (N - 1) * incX;
827 if (X.getType().getX() != expectedXDim) {
828 throw new RSRuntimeException("Incorrect vector dimensions for GERU");
829 }
830 int expectedYDim = 1 + (N - 1) * incY;
831 if (Y.getType().getX() != expectedYDim) {
832 throw new RSRuntimeException("Incorrect vector dimensions for GERU");
833 }
834
835 }
836
837 void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
838 // HEMV is the same as SYR2 validation-wise
839 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
840 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
841 }
842 void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
843 // HBMV is the same as SYR2 validation-wise
844 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
845 if (K < 0) {
846 throw new RSRuntimeException("K must be 0 or greater for HBMV");
847 }
848 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
849 }
850 void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
851 // HPMV is the same as SPR2
852 int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
853 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
854 }
855 void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
856 validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A);
857 int M = A.getType().getY();
858 int N = A.getType().getX();
859 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
860 }
861 void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
862 // same as GERU
863 validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A);
864 int M = A.getType().getY();
865 int N = A.getType().getX();
866 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
867 }
868 void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
869 // same as SYR
870 int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A);
871 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
872 }
873 void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
874 // equivalent to SPR for validation
875 int N = validateSPR(Element.F32_2(mRS), Uplo, X, incX, Ap);
876 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0);
877 }
878 void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
879 // same as SYR2
880 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
881 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
882 }
883 void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
884 // same as SPR2
885 int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
886 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0);
887 }
888 void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
889 // HEMV is the same as SYR2 validation-wise
890 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
891 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
892 }
893 void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
894 // HBMV is the same as SYR2 validation-wise
895 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
896 if (K < 0) {
897 throw new RSRuntimeException("K must be 0 or greater for HBMV");
898 }
899 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
900 }
901 void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
902 // HPMV is the same as SPR2
903 int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
904 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
905 }
906 void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
907 validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A);
908 int M = A.getType().getY();
909 int N = A.getType().getX();
910 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
911 }
912 void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
913 // same as GERU
914 validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A);
915 int M = A.getType().getY();
916 int N = A.getType().getX();
917 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
918 }
919 void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
920 // same as SYR
921 int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A);
922 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
923 }
924 void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
925 // equivalent to SPR for validation
926 int N = validateSPR(Element.F64_2(mRS), Uplo, X, incX, Ap);
927 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0);
928 }
929 void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
930 // same as SYR2
931 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
932 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
933 }
934 void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
935 // same as SPR2
936 int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
937 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0);
938 }
939
940
941 /**
942 * Level 3 BLAS
943 */
944
945 static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) {
946 int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1;
947 if ((A != null && !A.getType().getElement().isCompatible(e)) ||
948 (B != null && !B.getType().getElement().isCompatible(e)) ||
949 (C != null && !C.getType().getElement().isCompatible(e))) {
950 throw new RSRuntimeException("Called BLAS with wrong Element type");
951 }
952 if (C != null) {
953 cX = C.getType().getY();
954 cY = C.getType().getX();
955 }
956 if (Side == RIGHT) {
957 if (B != null) {
958 bX = A.getType().getY();
959 bY = A.getType().getX();
960 }
961 if (A != null) {
962 aX = B.getType().getY();
963 aY = B.getType().getX();
964 }
965 } else {
966 if (A != null) {
967 if (TransA == TRANSPOSE) {
968 aY = A.getType().getY();
969 aX = A.getType().getX();
970 } else {
971 aX = A.getType().getY();
972 aY = A.getType().getX();
973 }
974 }
975 if (B != null) {
976 if (TransB == TRANSPOSE) {
977 bY = B.getType().getY();
978 bX = B.getType().getX();
979 } else {
980 bX = B.getType().getY();
981 bY = B.getType().getX();
982 }
983 }
984 }
985 if (A != null && B != null && C != null) {
986 if (aY != bX || aX != cX || bY != cY) {
987 throw new RSRuntimeException("Called BLAS with invalid dimensions");
988 }
989 } else if (A != null && C != null) {
990 // A and C only
991 if (aX != cY || aY != cX) {
992 throw new RSRuntimeException("Called BLAS with invalid dimensions");
993 }
994 } else if (A != null && B != null) {
995 // A and B only
996 }
997
998 }
999
1000 public void SGEMM(@Transpose int TransA, @Transpose int TransB, float alpha, Allocation A,
1001 Allocation B, float beta, Allocation C) {
1002 validateTranspose(TransA);
1003 validateTranspose(TransB);
1004 validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C);
1005
1006 int M = -1, N = -1, K = -1;
1007 if (TransA == TRANSPOSE) {
1008 M = A.getType().getX();
1009 K = A.getType().getY();
1010 } else {
1011 M = A.getType().getY();
1012 K = A.getType().getX();
1013 }
1014 if (TransB == TRANSPOSE) {
1015 N = B.getType().getY();
1016 } else {
1017 N = B.getType().getX();
1018 }
1019 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS),
1020 beta, C.getID(mRS), 0, 0, 0, 0);
1021 }
1022 public void DGEMM(@Transpose int TransA, @Transpose int TransB, double alpha, Allocation A,
1023 Allocation B, double beta, Allocation C) {
1024 validateTranspose(TransA);
1025 validateTranspose(TransB);
1026 validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C);
1027 int M = -1, N = -1, K = -1;
1028 if (TransA == TRANSPOSE) {
1029 M = A.getType().getX();
1030 K = A.getType().getY();
1031 } else {
1032 M = A.getType().getY();
1033 K = A.getType().getX();
1034 }
1035 if (TransB == TRANSPOSE) {
1036 N = B.getType().getY();
1037 } else {
1038 N = B.getType().getX();
1039 }
1040 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS),
1041 beta, C.getID(mRS), 0, 0, 0, 0);
1042 }
1043 public void CGEMM(@Transpose int TransA, @Transpose int TransB, Float2 alpha, Allocation A,
1044 Allocation B, Float2 beta, Allocation C) {
1045 validateTranspose(TransA);
1046 validateTranspose(TransB);
1047 validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C);
1048 int M = -1, N = -1, K = -1;
1049 if (TransA == TRANSPOSE) {
1050 M = A.getType().getX();
1051 K = A.getType().getY();
1052 } else {
1053 M = A.getType().getY();
1054 K = A.getType().getX();
1055 }
1056 if (TransB == TRANSPOSE) {
1057 N = B.getType().getY();
1058 } else {
1059 N = B.getType().getX();
1060 }
1061 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1062 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1063 }
1064
1065 public void ZGEMM(@Transpose int TransA, @Transpose int TransB, Double2 alpha, Allocation A,
1066 Allocation B, Double2 beta, Allocation C) {
1067 validateTranspose(TransA);
1068 validateTranspose(TransB);
1069 validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C);
1070 int M = -1, N = -1, K = -1;
1071 if (TransA == TRANSPOSE) {
1072 M = A.getType().getX();
1073 K = A.getType().getY();
1074 } else {
1075 M = A.getType().getY();
1076 K = A.getType().getX();
1077 }
1078 if (TransB == TRANSPOSE) {
1079 N = B.getType().getY();
1080 } else {
1081 N = B.getType().getX();
1082 }
1083 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1084 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1085 }
1086
1087 public void SSYMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A,
1088 Allocation B, float beta, Allocation C) {
1089 validateSide(Side);
1090 validateUplo(Uplo);
1091 validateL3(Element.F32(mRS), 0, 0, Side, A, B, C);
1092 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
1093 beta, C.getID(mRS), 0, 0, 0, 0);
1094 }
1095 public void DSYMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A,
1096 Allocation B, double beta, Allocation C) {
1097 validateSide(Side);
1098 validateUplo(Uplo);
1099 validateL3(Element.F64(mRS), 0, 0, Side, A, B, C);
1100 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
1101 beta, C.getID(mRS), 0, 0, 0, 0);
1102 }
1103 public void CSYMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A,
1104 Allocation B, Float2 beta, Allocation C) {
1105 validateSide(Side);
1106 validateUplo(Uplo);
1107 validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C);
1108 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1109 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1110 }
1111 public void ZSYMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A,
1112 Allocation B, Double2 beta, Allocation C) {
1113 validateSide(Side);
1114 validateUplo(Uplo);
1115 validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C);
1116 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1117 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1118 }
1119
1120 public void SSYRK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) {
1121 validateTranspose(Trans);
1122 validateUplo(Uplo);
1123 validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C);
1124 int K = -1;
1125 if (Trans == TRANSPOSE) {
1126 K = A.getType().getY();
1127 } else {
1128 K = A.getType().getX();
1129 }
1130
1131 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0);
1132 }
1133
1134 public void DSYRK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) {
1135 validateTranspose(Trans);
1136 validateUplo(Uplo);
1137 validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C);
1138 int K = -1;
1139 if (Trans == TRANSPOSE) {
1140 K = A.getType().getY();
1141 } else {
1142 K = A.getType().getX();
1143 }
1144 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0);
1145 }
1146 public void CSYRK(@Uplo int Uplo, @Transpose int Trans, float alphaX, float alphaY, Allocation A, float betaX, float betaY, Allocation C) {
1147 validateTranspose(Trans);
1148 validateUplo(Uplo);
1149 validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C);
1150 int K = -1;
1151 if (Trans == TRANSPOSE) {
1152 K = A.getType().getY();
1153 } else {
1154 K = A.getType().getX();
1155 }
1156 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY,
1157 C.getID(mRS), 0, 0, 0, 0);
1158 }
1159 public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, double alphaX, double alphaY, Allocation A, double betaX, double betaY, Allocation C) {
1160 validateTranspose(Trans);
1161 validateUplo(Uplo);
1162 validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C);
1163 int K = -1;
1164 if (Trans == TRANSPOSE) {
1165 K = A.getType().getY();
1166 } else {
1167 K = A.getType().getX();
1168 }
1169 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY,
1170 C.getID(mRS), 0, 0, 0, 0);
1171 }
1172
1173 static void validateSYR2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) {
1174 validateTranspose(Trans);
1175 if (!A.getType().getElement().isCompatible(e) ||
1176 !B.getType().getElement().isCompatible(e) ||
1177 !C.getType().getElement().isCompatible(e)) {
1178 throw new RSRuntimeException("Called BLAS with wrong Element type");
1179 }
1180 int Cdim = -1;
1181 // A is n x k if no transpose, k x n if transpose
1182 // C is n x n
1183 if (Trans == TRANSPOSE) {
1184 // check columns versus C
1185 Cdim = A.getType().getX();
1186 } else {
1187 // check rows versus C
1188 Cdim = A.getType().getY();
1189 }
1190 if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) {
1191 throw new RSRuntimeException("Invalid symmetric matrix in SYR2K");
1192 }
1193 // A dims == B dims
1194 if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) {
1195 throw new RSRuntimeException("Invalid A and B in SYR2K");
1196 }
1197 }
1198 public void SSYR2K(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, Allocation B, float beta, Allocation C) {
1199 validateUplo(Uplo);
1200 validateSYR2K(Element.F32(mRS), Trans, A, B, C);
1201 int K = -1;
1202 if (Trans == TRANSPOSE) {
1203 K = A.getType().getY();
1204 } else {
1205 K = A.getType().getX();
1206 }
1207 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0);
1208 }
1209 public void DSYR2K(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, Allocation B, double beta, Allocation C) {
1210 validateUplo(Uplo);
1211 validateSYR2K(Element.F64(mRS), Trans, A, B, C);
1212 int K = -1;
1213 if (Trans == TRANSPOSE) {
1214 K = A.getType().getY();
1215 } else {
1216 K = A.getType().getX();
1217 }
1218 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0);
1219 }
1220 public void CSYR2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) {
1221 validateUplo(Uplo);
1222 validateSYR2K(Element.F32_2(mRS), Trans, A, B, C);
1223 int K = -1;
1224 if (Trans == TRANSPOSE) {
1225 K = A.getType().getY();
1226 } else {
1227 K = A.getType().getX();
1228 }
1229 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1230 }
1231 public void ZSYR2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) {
1232 validateUplo(Uplo);
1233 validateSYR2K(Element.F64_2(mRS), Trans, A, B, C);
1234 int K = -1;
1235 if (Trans == TRANSPOSE) {
1236 K = A.getType().getY();
1237 } else {
1238 K = A.getType().getX();
1239 }
1240 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1241 }
1242
1243 static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
1244 validateSide(Side);
1245 validateTranspose(TransA);
1246 int aX = -1, aY = -1, bX = -1, bY = -1;
1247 if (!A.getType().getElement().isCompatible(e) ||
1248 !B.getType().getElement().isCompatible(e)) {
1249 throw new RSRuntimeException("Called BLAS with wrong Element type");
1250 }
1251 if (TransA == TRANSPOSE) {
1252 aY = A.getType().getY();
1253 aX = A.getType().getX();
1254 } else {
1255 aY = A.getType().getX();
1256 aX = A.getType().getY();
1257 }
1258 bX = B.getType().getY();
1259 bY = B.getType().getX();
1260 if (Side == LEFT) {
1261 if (aX == 0 || aY != bX) {
1262 throw new RSRuntimeException("Called TRMM with invalid matrices");
1263 }
1264 } else {
1265 if (bY != aX || aY == 0) {
1266 throw new RSRuntimeException("Called TRMM with invalid matrices");
1267 }
1268 }
1269 }
1270 public void STRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) {
1271 validateUplo(Uplo);
1272 validateDiag(Diag);
1273 validateTRMM(Element.F32(mRS), Side, TransA, A, B);
1274 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1275 alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0);
1276 }
1277 public void DTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) {
1278 validateUplo(Uplo);
1279 validateDiag(Diag);
1280 validateTRMM(Element.F64(mRS), Side, TransA, A, B);
1281 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1282 alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0);
1283 }
1284 public void CTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) {
1285 validateUplo(Uplo);
1286 validateDiag(Diag);
1287 validateTRMM(Element.F32_2(mRS), Side, TransA, A, B);
1288 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1289 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1290 }
1291 public void ZTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) {
1292 validateUplo(Uplo);
1293 validateDiag(Diag);
1294 validateTRMM(Element.F64_2(mRS), Side, TransA, A, B);
1295 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1296 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1297 }
1298
1299 static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
1300 int adim = -1, bX = -1, bY = -1;
1301 validateSide(Side);
1302 validateTranspose(TransA);
1303 if (!A.getType().getElement().isCompatible(e) ||
1304 !B.getType().getElement().isCompatible(e)) {
1305 throw new RSRuntimeException("Called BLAS with wrong Element type");
1306 }
1307 adim = A.getType().getX();
1308 if (adim != A.getType().getY()) {
1309 // this may be unnecessary, the restriction could potentially be relaxed
1310 // A needs to contain at least that symmetric matrix but could theoretically be larger
1311 // for now we assume adapters are sufficient, will reevaluate in the future
1312 throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A");
1313 }
1314 bX = B.getType().getY();
1315 bY = B.getType().getX();
1316 if (Side == LEFT) {
1317 // A is M*M
1318 if (adim != bY) {
1319 throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
1320 }
1321 } else {
1322 // A is N*N
1323 if (adim != bX) {
1324 throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
1325 }
1326 }
1327 }
1328 public void STRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) {
1329 validateUplo(Uplo);
1330 validateDiag(Diag);
1331 validateTRSM(Element.F32(mRS), Side, TransA, A, B);
1332 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1333 alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0);
1334 }
1335 public void DTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) {
1336 validateUplo(Uplo);
1337 validateDiag(Diag);
1338 validateTRSM(Element.F64(mRS), Side, TransA, A, B);
1339 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1340 alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0);
1341 }
1342 public void CTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) {
1343 validateUplo(Uplo);
1344 validateDiag(Diag);
1345 validateTRSM(Element.F32_2(mRS), Side, TransA, A, B);
1346 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1347 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1348 }
1349 public void ZTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) {
1350 validateUplo(Uplo);
1351 validateDiag(Diag);
1352 validateTRSM(Element.F64_2(mRS), Side, TransA, A, B);
1353 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1354 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1355 }
1356
1357 static void validateHEMM(Element e, @Side int Side, Allocation A, Allocation B, Allocation C) {
1358 validateSide(Side);
1359
1360 if (!A.getType().getElement().isCompatible(e) ||
1361 !B.getType().getElement().isCompatible(e) ||
1362 !C.getType().getElement().isCompatible(e)) {
1363 throw new RSRuntimeException("Called BLAS with wrong Element type");
1364 }
1365
1366 // A must be square; can potentially be relaxed similar to TRSM
1367 int adim = A.getType().getX();
1368 if (adim != A.getType().getY()) {
1369 throw new RSRuntimeException("Called HEMM with non-square A");
1370 }
1371 if ((Side == LEFT && adim != B.getType().getY()) ||
1372 (Side == RIGHT && adim != B.getType().getX())) {
1373 throw new RSRuntimeException("Called HEMM with invalid B");
1374 }
1375 if (B.getType().getX() != C.getType().getX() ||
1376 B.getType().getY() != C.getType().getY()) {
1377 throw new RSRuntimeException("Called HEMM with mismatched B and C");
1378 }
1379 }
1380 public void CHEMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A, Allocation B, float beta, Allocation C) {
1381 validateUplo(Uplo);
1382 validateHEMM(Element.F32_2(mRS), Side, A, B, C);
1383 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
1384 alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
1385 }
1386 public void ZHEMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A, Allocation B, double beta, Allocation C) {
1387 validateUplo(Uplo);
1388 validateHEMM(Element.F32_2(mRS), Side, A, B, C);
1389 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
1390 alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
1391 }
1392
1393 static void validateHERK(Element e, @Transpose int Trans, Allocation A, Allocation C) {
1394 if (!A.getType().getElement().isCompatible(e) ||
1395 !C.getType().getElement().isCompatible(e)) {
1396 throw new RSRuntimeException("Called BLAS with wrong Element type");
1397 }
1398 validateConjTranspose(Trans);
1399 int cdim = C.getType().getX();
1400 if (cdim != C.getType().getY()) {
1401 throw new RSRuntimeException("Called HERK with non-square C");
1402 }
1403 if (Trans == NO_TRANSPOSE) {
1404 if (cdim != A.getType().getX()) {
1405 throw new RSRuntimeException("Called HERK with invalid A");
1406 }
1407 } else {
1408 if (cdim != A.getType().getY()) {
1409 throw new RSRuntimeException("Called HERK with invalid A");
1410 }
1411 }
1412 }
1413 public void CHERK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) {
1414 validateUplo(Uplo);
1415 validateHERK(Element.F32_2(mRS), Trans, A, C);
1416 int k = 0;
1417 if (Trans == TRANSPOSE) {
1418 k = A.getType().getY();
1419 } else {
1420 k = A.getType().getX();
1421 }
1422 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k,
1423 alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0);
1424 }
1425 public void ZHERK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) {
1426 validateUplo(Uplo);
1427 validateHERK(Element.F64_2(mRS), Trans, A, C);
1428 int k = 0;
1429 if (Trans == TRANSPOSE) {
1430 k = A.getType().getY();
1431 } else {
1432 k = A.getType().getX();
1433 }
1434 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k,
1435 alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0);
1436 }
1437
1438 static void validateHER2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) {
1439 if (!A.getType().getElement().isCompatible(e) ||
1440 !B.getType().getElement().isCompatible(e) ||
1441 !C.getType().getElement().isCompatible(e)) {
1442 throw new RSRuntimeException("Called BLAS with wrong Element type");
1443 }
1444 validateConjTranspose(Trans);
1445 int cdim = C.getType().getX();
1446 if (cdim != C.getType().getY()) {
1447 throw new RSRuntimeException("Called HER2K with non-square C");
1448 }
1449 if (Trans == NO_TRANSPOSE) {
1450 if (A.getType().getY() != cdim) {
1451 throw new RSRuntimeException("Called HER2K with invalid matrices");
1452 }
1453 } else {
1454 if (A.getType().getX() != cdim) {
1455 throw new RSRuntimeException("Called HER2K with invalid matrices");
1456 }
1457 }
1458 if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) {
1459 throw new RSRuntimeException("Called HER2K with invalid A and B matrices");
1460 }
1461 }
1462 public void CHER2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, float beta, Allocation C) {
1463 validateUplo(Uplo);
1464 validateHER2K(Element.F32_2(mRS), Trans, A, B, C);
1465 int k = 0;
1466 if (Trans == NO_TRANSPOSE) {
1467 k = A.getType().getX();
1468 } else {
1469 k = A.getType().getY();
1470 }
1471 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y,
1472 A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
1473 }
1474 public void ZHER2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, double beta, Allocation C) {
1475 validateUplo(Uplo);
1476 validateHER2K(Element.F64_2(mRS), Trans, A, B, C);
1477 int k = 0;
1478 if (Trans == NO_TRANSPOSE) {
1479 k = A.getType().getX();
1480 } else {
1481 k = A.getType().getY();
1482 }
1483 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y,
1484 A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
1485 }
1486
1487
1488
1489}