blob: 1ea7108d5e6e7086b1532ab4b0dbc5f9c8c8018d [file] [log] [blame]
Tim Murray25207df2015-01-12 16:47:56 -08001/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.renderscript;
18
19import android.annotation.IntDef;
20import java.lang.annotation.Retention;
21import java.lang.annotation.RetentionPolicy;
22
/**
 * Intrinsic that exposes BLAS (Basic Linear Algebra Subprograms) routines
 * operating on RenderScript {@link Allocation}s.
 *
 * @hide
 **/
29public final class ScriptIntrinsicBLAS extends ScriptIntrinsic {
30 private Allocation mLUT;
31
32 private ScriptIntrinsicBLAS(long id, RenderScript rs) {
33 super(id, rs);
34 }
35
36 private static final int RsBlas_sdsdot = 1;
37 private static final int RsBlas_dsdot = 2;
38 private static final int RsBlas_sdot = 3;
39 private static final int RsBlas_ddot = 4;
40 private static final int RsBlas_cdotu_sub = 5;
41 private static final int RsBlas_cdotc_sub = 6;
42 private static final int RsBlas_zdotu_sub = 7;
43 private static final int RsBlas_zdotc_sub = 8;
44 private static final int RsBlas_snrm2 = 9;
45 private static final int RsBlas_sasum = 10;
46 private static final int RsBlas_dnrm2 = 11;
47 private static final int RsBlas_dasum = 12;
48 private static final int RsBlas_scnrm2 = 13;
49 private static final int RsBlas_scasum = 14;
50 private static final int RsBlas_dznrm2 = 15;
51 private static final int RsBlas_dzasum = 16;
52 private static final int RsBlas_isamax = 17;
53 private static final int RsBlas_idamax = 18;
54 private static final int RsBlas_icamax = 19;
55 private static final int RsBlas_izamax = 20;
56 private static final int RsBlas_sswap = 21;
57 private static final int RsBlas_scopy = 22;
58 private static final int RsBlas_saxpy = 23;
59 private static final int RsBlas_dswap = 24;
60 private static final int RsBlas_dcopy = 25;
61 private static final int RsBlas_daxpy = 26;
62 private static final int RsBlas_cswap = 27;
63 private static final int RsBlas_ccopy = 28;
64 private static final int RsBlas_caxpy = 29;
65 private static final int RsBlas_zswap = 30;
66 private static final int RsBlas_zcopy = 31;
67 private static final int RsBlas_zaxpy = 32;
68 private static final int RsBlas_srotg = 33;
69 private static final int RsBlas_srotmg = 34;
70 private static final int RsBlas_srot = 35;
71 private static final int RsBlas_srotm = 36;
72 private static final int RsBlas_drotg = 37;
73 private static final int RsBlas_drotmg = 38;
74 private static final int RsBlas_drot = 39;
75 private static final int RsBlas_drotm = 40;
76 private static final int RsBlas_sscal = 41;
77 private static final int RsBlas_dscal = 42;
78 private static final int RsBlas_cscal = 43;
79 private static final int RsBlas_zscal = 44;
80 private static final int RsBlas_csscal = 45;
81 private static final int RsBlas_zdscal = 46;
82 private static final int RsBlas_sgemv = 47;
83 private static final int RsBlas_sgbmv = 48;
84 private static final int RsBlas_strmv = 49;
85 private static final int RsBlas_stbmv = 50;
86 private static final int RsBlas_stpmv = 51;
87 private static final int RsBlas_strsv = 52;
88 private static final int RsBlas_stbsv = 53;
89 private static final int RsBlas_stpsv = 54;
90 private static final int RsBlas_dgemv = 55;
91 private static final int RsBlas_dgbmv = 56;
92 private static final int RsBlas_dtrmv = 57;
93 private static final int RsBlas_dtbmv = 58;
94 private static final int RsBlas_dtpmv = 59;
95 private static final int RsBlas_dtrsv = 60;
96 private static final int RsBlas_dtbsv = 61;
97 private static final int RsBlas_dtpsv = 62;
98 private static final int RsBlas_cgemv = 63;
99 private static final int RsBlas_cgbmv = 64;
100 private static final int RsBlas_ctrmv = 65;
101 private static final int RsBlas_ctbmv = 66;
102 private static final int RsBlas_ctpmv = 67;
103 private static final int RsBlas_ctrsv = 68;
104 private static final int RsBlas_ctbsv = 69;
105 private static final int RsBlas_ctpsv = 70;
106 private static final int RsBlas_zgemv = 71;
107 private static final int RsBlas_zgbmv = 72;
108 private static final int RsBlas_ztrmv = 73;
109 private static final int RsBlas_ztbmv = 74;
110 private static final int RsBlas_ztpmv = 75;
111 private static final int RsBlas_ztrsv = 76;
112 private static final int RsBlas_ztbsv = 77;
113 private static final int RsBlas_ztpsv = 78;
114 private static final int RsBlas_ssymv = 79;
115 private static final int RsBlas_ssbmv = 80;
116 private static final int RsBlas_sspmv = 81;
117 private static final int RsBlas_sger = 82;
118 private static final int RsBlas_ssyr = 83;
119 private static final int RsBlas_sspr = 84;
120 private static final int RsBlas_ssyr2 = 85;
121 private static final int RsBlas_sspr2 = 86;
122 private static final int RsBlas_dsymv = 87;
123 private static final int RsBlas_dsbmv = 88;
124 private static final int RsBlas_dspmv = 89;
125 private static final int RsBlas_dger = 90;
126 private static final int RsBlas_dsyr = 91;
127 private static final int RsBlas_dspr = 92;
128 private static final int RsBlas_dsyr2 = 93;
129 private static final int RsBlas_dspr2 = 94;
130 private static final int RsBlas_chemv = 95;
131 private static final int RsBlas_chbmv = 96;
132 private static final int RsBlas_chpmv = 97;
133 private static final int RsBlas_cgeru = 98;
134 private static final int RsBlas_cgerc = 99;
135 private static final int RsBlas_cher = 100;
136 private static final int RsBlas_chpr = 101;
137 private static final int RsBlas_cher2 = 102;
138 private static final int RsBlas_chpr2 = 103;
139 private static final int RsBlas_zhemv = 104;
140 private static final int RsBlas_zhbmv = 105;
141 private static final int RsBlas_zhpmv = 106;
142 private static final int RsBlas_zgeru = 107;
143 private static final int RsBlas_zgerc = 108;
144 private static final int RsBlas_zher = 109;
145 private static final int RsBlas_zhpr = 110;
146 private static final int RsBlas_zher2 = 111;
147 private static final int RsBlas_zhpr2 = 112;
148 private static final int RsBlas_sgemm = 113;
149 private static final int RsBlas_ssymm = 114;
150 private static final int RsBlas_ssyrk = 115;
151 private static final int RsBlas_ssyr2k = 116;
152 private static final int RsBlas_strmm = 117;
153 private static final int RsBlas_strsm = 118;
154 private static final int RsBlas_dgemm = 119;
155 private static final int RsBlas_dsymm = 120;
156 private static final int RsBlas_dsyrk = 121;
157 private static final int RsBlas_dsyr2k = 122;
158 private static final int RsBlas_dtrmm = 123;
159 private static final int RsBlas_dtrsm = 124;
160 private static final int RsBlas_cgemm = 125;
161 private static final int RsBlas_csymm = 126;
162 private static final int RsBlas_csyrk = 127;
163 private static final int RsBlas_csyr2k = 128;
164 private static final int RsBlas_ctrmm = 129;
165 private static final int RsBlas_ctrsm = 130;
166 private static final int RsBlas_zgemm = 131;
167 private static final int RsBlas_zsymm = 132;
168 private static final int RsBlas_zsyrk = 133;
169 private static final int RsBlas_zsyr2k = 134;
170 private static final int RsBlas_ztrmm = 135;
171 private static final int RsBlas_ztrsm = 136;
172 private static final int RsBlas_chemm = 137;
173 private static final int RsBlas_cherk = 138;
174 private static final int RsBlas_cher2k = 139;
175 private static final int RsBlas_zhemm = 140;
176 private static final int RsBlas_zherk = 141;
177 private static final int RsBlas_zher2k = 142;
178
Tim Murray9cb16a22015-04-01 11:07:16 -0700179 // BLAS extensions start here
180 private static final int RsBlas_bnnm = 1000;
181
Tim Murray25207df2015-01-12 16:47:56 -0800182 /**
183 */
184 public static ScriptIntrinsicBLAS create(RenderScript rs) {
185 long id = rs.nScriptIntrinsicCreate(13, Element.U32(rs).getID(rs));
186 return new ScriptIntrinsicBLAS(id, rs);
187 }
188
189 @IntDef({NO_TRANSPOSE, TRANSPOSE, CONJ_TRANSPOSE})
190 @Retention(RetentionPolicy.SOURCE)
191 public @interface Transpose {}
192
193 @IntDef({UPPER, LOWER})
194 @Retention(RetentionPolicy.SOURCE)
195 public @interface Uplo {}
196
197 @IntDef({NON_UNIT, UNIT})
198 @Retention(RetentionPolicy.SOURCE)
199 public @interface Diag {}
200
201 @IntDef({LEFT, RIGHT})
202 @Retention(RetentionPolicy.SOURCE)
203 public @interface Side {}
204
205 public static final int NO_TRANSPOSE = 111;
206 public static final int TRANSPOSE = 112;
207 public static final int CONJ_TRANSPOSE = 113;
208
209 public static final int UPPER = 121;
210 public static final int LOWER = 122;
211
212 public static final int NON_UNIT = 131;
213 public static final int UNIT = 132;
214
215 public static final int LEFT = 141;
216 public static final int RIGHT = 142;
217
218 static void validateSide(@Side int Side) {
219 if (Side != LEFT && Side != RIGHT) {
220 throw new RSRuntimeException("Invalid side passed to BLAS");
221 }
222 }
223
224 static void validateTranspose(@Transpose int Trans) {
225 if (Trans != NO_TRANSPOSE && Trans != TRANSPOSE &&
226 Trans != CONJ_TRANSPOSE) {
227 throw new RSRuntimeException("Invalid transpose passed to BLAS");
228 }
229 }
230
231 static void validateConjTranspose(@Transpose int Trans) {
232 if (Trans != NO_TRANSPOSE &&
233 Trans != CONJ_TRANSPOSE) {
234 throw new RSRuntimeException("Invalid transpose passed to BLAS");
235 }
236 }
237
238 static void validateDiag(@Diag int Diag) {
239 if (Diag != NON_UNIT && Diag != UNIT) {
240 throw new RSRuntimeException("Invalid diag passed to BLAS");
241 }
242 }
243
244 static void validateUplo(@Uplo int Uplo) {
Miao Wangb530d8e2015-04-24 11:19:53 -0700245 if (Uplo != UPPER && Uplo != LOWER) {
Tim Murray25207df2015-01-12 16:47:56 -0800246 throw new RSRuntimeException("Invalid uplo passed to BLAS");
247 }
248 }
249
250
/*
 * Level 2 BLAS: matrix-vector routines.
 */
254
255 static void validateGEMV(Element e, int TransA, Allocation A, Allocation X, int incX, Allocation Y, int incY) {
256 validateTranspose(TransA);
257 int M = A.getType().getY();
258 int N = A.getType().getX();
259 if (!A.getType().getElement().isCompatible(e) ||
260 !X.getType().getElement().isCompatible(e) ||
261 !Y.getType().getElement().isCompatible(e)) {
262 throw new RSRuntimeException("Called BLAS with wrong Element type");
263 }
264 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
265 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
266 }
267
268 if (incX <= 0 || incY <= 0) {
269 throw new RSRuntimeException("Vector increments must be greater than 0");
270 }
271 int expectedXDim = -1, expectedYDim = -1;
272 if (TransA == NO_TRANSPOSE) {
273 expectedXDim = 1 + (N - 1) * incX;
274 expectedYDim = 1 + (M - 1) * incY;
275 } else {
276 expectedXDim = 1 + (M - 1) * incX;
277 expectedYDim = 1 + (N - 1) * incY;
278 }
279 if (X.getType().getX() != expectedXDim ||
Miao Wang2b6fad92015-04-23 15:06:09 -0700280 Y.getType().getX() != expectedYDim) {
Tim Murray25207df2015-01-12 16:47:56 -0800281 throw new RSRuntimeException("Incorrect vector dimensions for GEMV");
282 }
283 }
Miao Wang6517eb62015-05-07 17:56:05 -0700284 public void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800285 validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY);
286 int M = A.getType().getY();
287 int N = A.getType().getX();
288 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
289 }
Miao Wang6517eb62015-05-07 17:56:05 -0700290 public void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800291 validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY);
292 int M = A.getType().getY();
293 int N = A.getType().getX();
294 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
295 }
Miao Wang6517eb62015-05-07 17:56:05 -0700296 public void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800297 validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY);
298 int M = A.getType().getY();
299 int N = A.getType().getX();
300 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
301 }
Miao Wang6517eb62015-05-07 17:56:05 -0700302 public void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800303 validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY);
304 int M = A.getType().getY();
305 int N = A.getType().getX();
306 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
307 }
308
Miao Wang6517eb62015-05-07 17:56:05 -0700309 public void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800310 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
311 validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY);
312 if (KL < 0 || KU < 0) {
313 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
314 }
315 int M = A.getType().getY();
316 int N = A.getType().getX();
317 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU);
318 }
Miao Wang6517eb62015-05-07 17:56:05 -0700319 public void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800320 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
321 validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY);
322 if (KL < 0 || KU < 0) {
323 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
324 }
325 int M = A.getType().getY();
326 int N = A.getType().getX();
327 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU);
328 }
Miao Wang6517eb62015-05-07 17:56:05 -0700329 public void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800330 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
331 validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY);
332 if (KL < 0 || KU < 0) {
333 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
334 }
335 int M = A.getType().getY();
336 int N = A.getType().getX();
337 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
338 }
Miao Wang6517eb62015-05-07 17:56:05 -0700339 public void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800340 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
341 validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY);
342 if (KL < 0 || KU < 0) {
343 throw new RSRuntimeException("KL and KU must be greater than or equal to 0");
344 }
345 int M = A.getType().getY();
346 int N = A.getType().getX();
347 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU);
348 }
349
Miao Wang2b6fad92015-04-23 15:06:09 -0700350 static void validateTRMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800351 validateTranspose(TransA);
Miao Wang2b6fad92015-04-23 15:06:09 -0700352 validateUplo(Uplo);
353 validateDiag(Diag);
Tim Murray25207df2015-01-12 16:47:56 -0800354 int N = A.getType().getY();
355 if (A.getType().getX() != N) {
356 throw new RSRuntimeException("A must be a square matrix for TRMV");
357 }
358 if (!A.getType().getElement().isCompatible(e) ||
359 !X.getType().getElement().isCompatible(e)) {
360 throw new RSRuntimeException("Called BLAS with wrong Element type");
361 }
362 if (X.getType().getY() > 1) {
363 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
364 }
365
366 if (incX <= 0) {
367 throw new RSRuntimeException("Vector increments must be greater than 0");
368 }
369 int expectedXDim = 1 + (N - 1) * incX;
370 if (X.getType().getX() != expectedXDim) {
371 throw new RSRuntimeException("Incorrect vector dimensions for TRMV");
372 }
373 }
374
375 static int validateTPMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
376 validateTranspose(TransA);
377 validateUplo(Uplo);
378 validateDiag(Diag);
379 if (!Ap.getType().getElement().isCompatible(e) ||
380 !X.getType().getElement().isCompatible(e)) {
381 throw new RSRuntimeException("Called BLAS with wrong Element type");
382 }
383 if (X.getType().getY() > 1) {
384 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
385 }
386
387 if (Ap.getType().getY() > 1) {
388 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
389 }
390
391 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
Miao Wang2b6fad92015-04-23 15:06:09 -0700392 //is it really doing anything?
Tim Murray25207df2015-01-12 16:47:56 -0800393 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
394 throw new RSRuntimeException("Invalid dimension for Ap");
395 }
Miao Wang2b6fad92015-04-23 15:06:09 -0700396 if (incX <= 0) {
397 throw new RSRuntimeException("Vector increments must be greater than 0");
398 }
Tim Murray25207df2015-01-12 16:47:56 -0800399 int expectedXDim = 1 + (N - 1) * incX;
400 if (X.getType().getX() != expectedXDim) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700401 throw new RSRuntimeException("Incorrect vector dimensions for TPMV");
Tim Murray25207df2015-01-12 16:47:56 -0800402 }
403
404 return N;
405 }
406
Miao Wang6517eb62015-05-07 17:56:05 -0700407 public void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700408 validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800409 int N = A.getType().getY();
410 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
411 }
Miao Wang6517eb62015-05-07 17:56:05 -0700412 public void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700413 validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800414 int N = A.getType().getY();
415 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
416 }
Miao Wang6517eb62015-05-07 17:56:05 -0700417 public void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700418 validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800419 int N = A.getType().getY();
420 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
421 }
Miao Wang6517eb62015-05-07 17:56:05 -0700422 public void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700423 validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800424 int N = A.getType().getY();
425 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
426 }
Miao Wang2b6fad92015-04-23 15:06:09 -0700427
Miao Wang6517eb62015-05-07 17:56:05 -0700428 public void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700429 // TBMV has the same requirements as TRMV + K >= 0
430 if (K < 0) {
431 throw new RSRuntimeException("K must be greater than or equal to 0");
432 }
433 validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800434 int N = A.getType().getY();
435 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
436 }
Miao Wang6517eb62015-05-07 17:56:05 -0700437 public void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700438 // TBMV has the same requirements as TRMV + K >= 0
439 if (K < 0) {
440 throw new RSRuntimeException("K must be greater than or equal to 0");
441 }
442 validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800443 int N = A.getType().getY();
444 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
445 }
Miao Wang6517eb62015-05-07 17:56:05 -0700446 public void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700447 // TBMV has the same requirements as TRMV + K >= 0
448 if (K < 0) {
449 throw new RSRuntimeException("K must be greater than or equal to 0");
450 }
451 validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800452 int N = A.getType().getY();
453 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
454 }
Miao Wang6517eb62015-05-07 17:56:05 -0700455 public void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700456 // TBMV has the same requirements as TRMV + K >= 0
457 if (K < 0) {
458 throw new RSRuntimeException("K must be greater than or equal to 0");
459 }
460 validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800461 int N = A.getType().getY();
462 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
463 }
Miao Wang6517eb62015-05-07 17:56:05 -0700464 public void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800465 int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
466 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
467 }
Miao Wang6517eb62015-05-07 17:56:05 -0700468 public void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800469 int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
470 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
471 }
Miao Wang6517eb62015-05-07 17:56:05 -0700472 public void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800473 int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
474 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
475 }
Miao Wang6517eb62015-05-07 17:56:05 -0700476 public void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800477 int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
478 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
479 }
Miao Wang6517eb62015-05-07 17:56:05 -0700480 public void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800481 // TRSV is the same as TRMV
Miao Wang2b6fad92015-04-23 15:06:09 -0700482 validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800483 int N = A.getType().getY();
484 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
485
486 }
Miao Wang6517eb62015-05-07 17:56:05 -0700487 public void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800488 // TRSV is the same as TRMV
Miao Wang2b6fad92015-04-23 15:06:09 -0700489 validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800490 int N = A.getType().getY();
491 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
492
493 }
Miao Wang6517eb62015-05-07 17:56:05 -0700494 public void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800495 // TRSV is the same as TRMV
Miao Wang2b6fad92015-04-23 15:06:09 -0700496 validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800497 int N = A.getType().getY();
498 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
499
500 }
Miao Wang6517eb62015-05-07 17:56:05 -0700501 public void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800502 // TRSV is the same as TRMV
Miao Wang2b6fad92015-04-23 15:06:09 -0700503 validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800504 int N = A.getType().getY();
505 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
506
507 }
Miao Wang6517eb62015-05-07 17:56:05 -0700508 public void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700509 // TBSV is the same as TRMV + K >= 0
510 validateTRMV(Element.F32(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800511 int N = A.getType().getY();
512 if (K < 0) {
513 throw new RSRuntimeException("Number of diagonals must be positive");
514 }
515 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
516 }
Miao Wang6517eb62015-05-07 17:56:05 -0700517 public void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700518 // TBSV is the same as TRMV + K >= 0
519 validateTRMV(Element.F64(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800520 int N = A.getType().getY();
521 if (K < 0) {
522 throw new RSRuntimeException("Number of diagonals must be positive");
523 }
524 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
525 }
Miao Wang6517eb62015-05-07 17:56:05 -0700526 public void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700527 // TBSV is the same as TRMV + K >= 0
528 validateTRMV(Element.F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800529 int N = A.getType().getY();
530 if (K < 0) {
531 throw new RSRuntimeException("Number of diagonals must be positive");
532 }
533 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
534 }
Miao Wang6517eb62015-05-07 17:56:05 -0700535 public void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700536 // TBSV is the same as TRMV + K >= 0
537 validateTRMV(Element.F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
Tim Murray25207df2015-01-12 16:47:56 -0800538 int N = A.getType().getY();
539 if (K < 0) {
540 throw new RSRuntimeException("Number of diagonals must be positive");
541 }
542 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
543 }
Miao Wang6517eb62015-05-07 17:56:05 -0700544 public void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800545 // TPSV is same as TPMV
546 int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
547 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
548 }
Miao Wang6517eb62015-05-07 17:56:05 -0700549 public void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800550 // TPSV is same as TPMV
551 int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
552 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0);
553 }
Miao Wang6517eb62015-05-07 17:56:05 -0700554 public void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800555 // TPSV is same as TPMV
556 int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
557 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
558 }
Miao Wang6517eb62015-05-07 17:56:05 -0700559 public void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) {
Tim Murray25207df2015-01-12 16:47:56 -0800560 // TPSV is same as TPMV
561 int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
562 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0);
563 }
564
565 /**
566 * Level 2, S and D only
567 */
568 static int validateSYMV(Element e, @Uplo int Uplo, Allocation A, Allocation X, Allocation Y, int incX, int incY) {
569 validateUplo(Uplo);
570 int N = A.getType().getY();
571 if (A.getType().getX() != N) {
572 throw new RSRuntimeException("A must be a square matrix for SYMV");
573 }
574 if (!A.getType().getElement().isCompatible(e) ||
575 !X.getType().getElement().isCompatible(e) ||
576 !Y.getType().getElement().isCompatible(e) ) {
577 throw new RSRuntimeException("Called BLAS with wrong Element type");
578 }
579 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
580 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
581 }
582
583 if (incX <= 0 || incY <= 0) {
584 throw new RSRuntimeException("Vector increments must be greater than 0");
585 }
586 int expectedXDim = 1 + (N - 1) * incX;
587 if (X.getType().getX() != expectedXDim) {
588 throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
589 }
590 int expectedYDim = 1 + (N - 1) * incY;
591 if (Y.getType().getX() != expectedYDim) {
592 throw new RSRuntimeException("Incorrect vector dimensions for SYMV");
593 }
594 return N;
595 }
596 static int validateSPMV(Element e, @Uplo int Uplo, Allocation Ap, Allocation X, int incX, Allocation Y, int incY) {
597 validateUplo(Uplo);
598 if (!Ap.getType().getElement().isCompatible(e) ||
599 !X.getType().getElement().isCompatible(e) ||
600 !Y.getType().getElement().isCompatible(e)) {
601 throw new RSRuntimeException("Called BLAS with wrong Element type");
602 }
603 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
604 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
605 }
606
607 if (Ap.getType().getY() > 1) {
608 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
609 }
610
611 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
612 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
613 throw new RSRuntimeException("Invalid dimension for Ap");
614 }
Miao Wang2b6fad92015-04-23 15:06:09 -0700615 if (incX <= 0 || incY <= 0) {
616 throw new RSRuntimeException("Vector increments must be greater than 0");
617 }
Tim Murray25207df2015-01-12 16:47:56 -0800618 int expectedXDim = 1 + (N - 1) * incX;
619 if (X.getType().getX() != expectedXDim) {
620 throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
621 }
622 int expectedYDim = 1 + (N - 1) * incY;
623 if (Y.getType().getX() != expectedYDim) {
624 throw new RSRuntimeException("Incorrect vector dimensions for SPMV");
625 }
626
627 return N;
628 }
629 static void validateGER(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
630 if (!A.getType().getElement().isCompatible(e) ||
631 !X.getType().getElement().isCompatible(e) ||
632 !Y.getType().getElement().isCompatible(e) ) {
633 throw new RSRuntimeException("Called BLAS with wrong Element type");
634 }
635
636 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
637 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
638 }
639
640 int M = A.getType().getY();
641 int N = A.getType().getX();
642
643 if (N < 1 || M < 1) {
644 throw new RSRuntimeException("M and N must be 1 or greater for GER");
645 }
Miao Wang2b6fad92015-04-23 15:06:09 -0700646 if (incX <= 0 || incY <= 0) {
647 throw new RSRuntimeException("Vector increments must be greater than 0");
648 }
649 int expectedXDim = 1 + (M - 1) * incX;
Tim Murray25207df2015-01-12 16:47:56 -0800650 if (X.getType().getX() != expectedXDim) {
651 throw new RSRuntimeException("Incorrect vector dimensions for GER");
652 }
653 int expectedYDim = 1 + (N - 1) * incY;
654 if (Y.getType().getX() != expectedYDim) {
655 throw new RSRuntimeException("Incorrect vector dimensions for GER");
656 }
657
658
659 }
660 static int validateSYR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation A) {
661 validateUplo(Uplo);
662 if (!A.getType().getElement().isCompatible(e) ||
663 !X.getType().getElement().isCompatible(e)) {
664 throw new RSRuntimeException("Called BLAS with wrong Element type");
665 }
666
667 int N = A.getType().getX();
668
669 if (X.getType().getY() > 1) {
670 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
671 }
672 if (N != A.getType().getY()) {
673 throw new RSRuntimeException("A must be a symmetric matrix");
674 }
Miao Wang2b6fad92015-04-23 15:06:09 -0700675 if (incX <= 0) {
676 throw new RSRuntimeException("Vector increments must be greater than 0");
677 }
Tim Murray25207df2015-01-12 16:47:56 -0800678 int expectedXDim = 1 + (N - 1) * incX;
679 if (X.getType().getX() != expectedXDim) {
680 throw new RSRuntimeException("Incorrect vector dimensions for SYR");
681 }
682 return N;
683 }
684 static int validateSPR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Ap) {
685 validateUplo(Uplo);
686 if (!Ap.getType().getElement().isCompatible(e) ||
687 !X.getType().getElement().isCompatible(e)) {
688 throw new RSRuntimeException("Called BLAS with wrong Element type");
689 }
690 if (X.getType().getY() > 1) {
691 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
692 }
693
694 if (Ap.getType().getY() > 1) {
695 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
696 }
697
698 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
699 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
700 throw new RSRuntimeException("Invalid dimension for Ap");
701 }
Miao Wang2b6fad92015-04-23 15:06:09 -0700702 if (incX <= 0) {
703 throw new RSRuntimeException("Vector increments must be greater than 0");
704 }
Tim Murray25207df2015-01-12 16:47:56 -0800705 int expectedXDim = 1 + (N - 1) * incX;
706 if (X.getType().getX() != expectedXDim) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700707 throw new RSRuntimeException("Incorrect vector dimensions for SPR");
Tim Murray25207df2015-01-12 16:47:56 -0800708 }
709
710 return N;
711 }
712
713 static int validateSYR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
714 validateUplo(Uplo);
715 if (!A.getType().getElement().isCompatible(e) ||
716 !X.getType().getElement().isCompatible(e) ||
717 !Y.getType().getElement().isCompatible(e)) {
718 throw new RSRuntimeException("Called BLAS with wrong Element type");
719 }
720
721 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
722 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
723 }
724
725 int N = A.getType().getX();
726
727 if (N != A.getType().getY()) {
728 throw new RSRuntimeException("A must be a symmetric matrix");
729 }
Miao Wang2b6fad92015-04-23 15:06:09 -0700730 if (incX <= 0 || incY <= 0) {
731 throw new RSRuntimeException("Vector increments must be greater than 0");
732 }
Tim Murray25207df2015-01-12 16:47:56 -0800733 int expectedXDim = 1 + (N - 1) * incX;
734 int expectedYDim = 1 + (N - 1) * incY;
735 if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
736 throw new RSRuntimeException("Incorrect vector dimensions for SYR");
737 }
738 return N;
739
740 }
741 static int validateSPR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
742 validateUplo(Uplo);
743 if (!Ap.getType().getElement().isCompatible(e) ||
744 !X.getType().getElement().isCompatible(e) ||
745 !Y.getType().getElement().isCompatible(e)) {
746 throw new RSRuntimeException("Called BLAS with wrong Element type");
747 }
748 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
749 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
750 }
751
752 if (Ap.getType().getY() > 1) {
753 throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1");
754 }
755
756 int N = (int)Math.sqrt((double)Ap.getType().getX() * 2);
757 if (Ap.getType().getX() != ((N * (N+1)) / 2)) {
758 throw new RSRuntimeException("Invalid dimension for Ap");
759 }
Miao Wang2b6fad92015-04-23 15:06:09 -0700760 if (incX <= 0 || incY <= 0) {
761 throw new RSRuntimeException("Vector increments must be greater than 0");
762 }
Tim Murray25207df2015-01-12 16:47:56 -0800763 int expectedXDim = 1 + (N - 1) * incX;
764 int expectedYDim = 1 + (N - 1) * incY;
765 if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700766 throw new RSRuntimeException("Incorrect vector dimensions for SPR2");
Tim Murray25207df2015-01-12 16:47:56 -0800767 }
768
769 return N;
770 }
771
Miao Wang6517eb62015-05-07 17:56:05 -0700772 public void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800773 int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
774 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
775 }
Miao Wang6517eb62015-05-07 17:56:05 -0700776 public void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700777 // SBMV is the same as SYMV + K >= 0
778 if (K < 0) {
779 throw new RSRuntimeException("K must be greater than or equal to 0");
780 }
Tim Murray25207df2015-01-12 16:47:56 -0800781 int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY);
782 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
783 }
Miao Wang6517eb62015-05-07 17:56:05 -0700784 public void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800785 int N = validateSPMV(Element.F32(mRS), Uplo, Ap, X, incX, Y, incY);
786 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
787 }
Miao Wang6517eb62015-05-07 17:56:05 -0700788 public void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800789 int M = A.getType().getY();
790 int N = A.getType().getX();
Miao Wang2b6fad92015-04-23 15:06:09 -0700791 validateGER(Element.F32(mRS), X, incX, Y, incY, A);
Tim Murray25207df2015-01-12 16:47:56 -0800792 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
793 }
Miao Wang6517eb62015-05-07 17:56:05 -0700794 public void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800795 int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A);
796 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0);
797 }
Miao Wang6517eb62015-05-07 17:56:05 -0700798 public void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
Tim Murray25207df2015-01-12 16:47:56 -0800799 int N = validateSPR(Element.F32(mRS), Uplo, X, incX, Ap);
800 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0);
801 }
Miao Wang6517eb62015-05-07 17:56:05 -0700802 public void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800803 int N = validateSYR2(Element.F32(mRS), Uplo, X, incX, Y, incY, A);
804 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0);
805 }
Miao Wang6517eb62015-05-07 17:56:05 -0700806 public void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
Tim Murray25207df2015-01-12 16:47:56 -0800807 int N = validateSPR2(Element.F32(mRS), Uplo, X, incX, Y, incY, Ap);
808 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0);
809 }
Miao Wang6517eb62015-05-07 17:56:05 -0700810 public void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800811 int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
812 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
813 }
Miao Wang6517eb62015-05-07 17:56:05 -0700814 public void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) {
Miao Wang2b6fad92015-04-23 15:06:09 -0700815 // SBMV is the same as SYMV + K >= 0
816 if (K < 0) {
817 throw new RSRuntimeException("K must be greater than or equal to 0");
818 }
Tim Murray25207df2015-01-12 16:47:56 -0800819 int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY);
820 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
821 }
Miao Wang6517eb62015-05-07 17:56:05 -0700822 public void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800823 int N = validateSPMV(Element.F64(mRS), Uplo, Ap, X, incX, Y, incY);
824 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0);
825 }
Miao Wang6517eb62015-05-07 17:56:05 -0700826 public void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800827 int M = A.getType().getY();
828 int N = A.getType().getX();
Miao Wang2b6fad92015-04-23 15:06:09 -0700829 validateGER(Element.F64(mRS), X, incX, Y, incY, A);
Tim Murray25207df2015-01-12 16:47:56 -0800830 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0);
831 }
Miao Wang6517eb62015-05-07 17:56:05 -0700832 public void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800833 int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A);
834 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0);
835 }
Miao Wang6517eb62015-05-07 17:56:05 -0700836 public void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
Tim Murray25207df2015-01-12 16:47:56 -0800837 int N = validateSPR(Element.F64(mRS), Uplo, X, incX, Ap);
838 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0);
839 }
Miao Wang6517eb62015-05-07 17:56:05 -0700840 public void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800841 int N = validateSYR2(Element.F64(mRS), Uplo, X, incX, Y, incY, A);
842 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0);
843 }
Miao Wang6517eb62015-05-07 17:56:05 -0700844 public void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
Tim Murray25207df2015-01-12 16:47:56 -0800845 int N = validateSPR2(Element.F64(mRS), Uplo, X, incX, Y, incY, Ap);
846 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0);
847 }
848
849
850 /**
851 * Level 2, C and Z only
852 */
853
854 static void validateGERU(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
855 if (!A.getType().getElement().isCompatible(e) ||
856 !X.getType().getElement().isCompatible(e) ||
857 !Y.getType().getElement().isCompatible(e)) {
858 throw new RSRuntimeException("Called BLAS with wrong Element type");
859 }
860 if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
861 throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1");
862 }
863
864 int M = A.getType().getY();
865 int N = A.getType().getX();
Miao Wang2b6fad92015-04-23 15:06:09 -0700866 if (incX <= 0 || incY <= 0) {
867 throw new RSRuntimeException("Vector increments must be greater than 0");
868 }
869 int expectedXDim = 1 + (M - 1) * incX;
Tim Murray25207df2015-01-12 16:47:56 -0800870 if (X.getType().getX() != expectedXDim) {
871 throw new RSRuntimeException("Incorrect vector dimensions for GERU");
872 }
873 int expectedYDim = 1 + (N - 1) * incY;
874 if (Y.getType().getX() != expectedYDim) {
875 throw new RSRuntimeException("Incorrect vector dimensions for GERU");
876 }
877
878 }
879
Miao Wang6517eb62015-05-07 17:56:05 -0700880 public void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800881 // HEMV is the same as SYR2 validation-wise
882 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
883 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
884 }
Miao Wang6517eb62015-05-07 17:56:05 -0700885 public void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800886 // HBMV is the same as SYR2 validation-wise
887 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
888 if (K < 0) {
889 throw new RSRuntimeException("K must be 0 or greater for HBMV");
890 }
891 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
892 }
Miao Wang6517eb62015-05-07 17:56:05 -0700893 public void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800894 // HPMV is the same as SPR2
895 int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
896 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
897 }
Miao Wang6517eb62015-05-07 17:56:05 -0700898 public void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800899 validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A);
900 int M = A.getType().getY();
901 int N = A.getType().getX();
902 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
903 }
Miao Wang6517eb62015-05-07 17:56:05 -0700904 public void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800905 // same as GERU
906 validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A);
907 int M = A.getType().getY();
908 int N = A.getType().getX();
909 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
910 }
Miao Wang6517eb62015-05-07 17:56:05 -0700911 public void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800912 // same as SYR
Miao Wang2b6fad92015-04-23 15:06:09 -0700913 int N = validateSYR(Element.F32_2(mRS), Uplo, X, incX, A);
Tim Murray25207df2015-01-12 16:47:56 -0800914 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
915 }
Miao Wang6517eb62015-05-07 17:56:05 -0700916 public void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) {
Tim Murray25207df2015-01-12 16:47:56 -0800917 // equivalent to SPR for validation
918 int N = validateSPR(Element.F32_2(mRS), Uplo, X, incX, Ap);
919 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0);
920 }
Miao Wang6517eb62015-05-07 17:56:05 -0700921 public void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800922 // same as SYR2
923 int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A);
924 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
925 }
Miao Wang6517eb62015-05-07 17:56:05 -0700926 public void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
Tim Murray25207df2015-01-12 16:47:56 -0800927 // same as SPR2
928 int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
929 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0);
930 }
Miao Wang6517eb62015-05-07 17:56:05 -0700931 public void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800932 // HEMV is the same as SYR2 validation-wise
933 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
934 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
935 }
Miao Wang6517eb62015-05-07 17:56:05 -0700936 public void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800937 // HBMV is the same as SYR2 validation-wise
938 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
939 if (K < 0) {
940 throw new RSRuntimeException("K must be 0 or greater for HBMV");
941 }
942 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
943 }
Miao Wang6517eb62015-05-07 17:56:05 -0700944 public void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) {
Tim Murray25207df2015-01-12 16:47:56 -0800945 // HPMV is the same as SPR2
946 int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
947 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0);
948 }
Miao Wang6517eb62015-05-07 17:56:05 -0700949 public void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800950 validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A);
951 int M = A.getType().getY();
952 int N = A.getType().getX();
953 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
954 }
Miao Wang6517eb62015-05-07 17:56:05 -0700955 public void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800956 // same as GERU
957 validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A);
958 int M = A.getType().getY();
959 int N = A.getType().getX();
960 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
961 }
Miao Wang6517eb62015-05-07 17:56:05 -0700962 public void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800963 // same as SYR
Miao Wangcc711792015-04-29 18:14:55 -0700964 int N = validateSYR(Element.F64_2(mRS), Uplo, X, incX, A);
Tim Murray25207df2015-01-12 16:47:56 -0800965 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0);
966 }
Miao Wang6517eb62015-05-07 17:56:05 -0700967 public void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) {
Tim Murray25207df2015-01-12 16:47:56 -0800968 // equivalent to SPR for validation
969 int N = validateSPR(Element.F64_2(mRS), Uplo, X, incX, Ap);
970 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0);
971 }
Miao Wang6517eb62015-05-07 17:56:05 -0700972 public void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
Tim Murray25207df2015-01-12 16:47:56 -0800973 // same as SYR2
974 int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A);
975 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0);
976 }
Miao Wang6517eb62015-05-07 17:56:05 -0700977 public void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
Tim Murray25207df2015-01-12 16:47:56 -0800978 // same as SPR2
979 int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
980 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0);
981 }
982
983
984 /**
985 * Level 3 BLAS
986 */
987
988 static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) {
Miao Wangb530d8e2015-04-24 11:19:53 -0700989 int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
Tim Murray25207df2015-01-12 16:47:56 -0800990 if ((A != null && !A.getType().getElement().isCompatible(e)) ||
991 (B != null && !B.getType().getElement().isCompatible(e)) ||
992 (C != null && !C.getType().getElement().isCompatible(e))) {
993 throw new RSRuntimeException("Called BLAS with wrong Element type");
994 }
Miao Wangb530d8e2015-04-24 11:19:53 -0700995 if (C == null) {
996 //since matrix C is used to store the result, it cannot be null.
997 throw new RSRuntimeException("Allocation C cannot be null");
Tim Murray25207df2015-01-12 16:47:56 -0800998 }
Miao Wangb530d8e2015-04-24 11:19:53 -0700999 cM = C.getType().getY();
1000 cN = C.getType().getX();
1001
Tim Murray25207df2015-01-12 16:47:56 -08001002 if (Side == RIGHT) {
Miao Wangb530d8e2015-04-24 11:19:53 -07001003 if ((A == null && B != null) || (A != null && B == null)) {
1004 throw new RSRuntimeException("Provided Matrix A without Matrix B, or vice versa");
1005 }
Tim Murray25207df2015-01-12 16:47:56 -08001006 if (B != null) {
Miao Wangb530d8e2015-04-24 11:19:53 -07001007 bM = A.getType().getY();
1008 bN = A.getType().getX();
Tim Murray25207df2015-01-12 16:47:56 -08001009 }
1010 if (A != null) {
Miao Wangb530d8e2015-04-24 11:19:53 -07001011 aM = B.getType().getY();
1012 aN = B.getType().getX();
Tim Murray25207df2015-01-12 16:47:56 -08001013 }
1014 } else {
1015 if (A != null) {
Miao Wange1cf0952015-04-30 10:47:42 -07001016 if (TransA == TRANSPOSE || TransA == CONJ_TRANSPOSE) {
Miao Wangb530d8e2015-04-24 11:19:53 -07001017 aN = A.getType().getY();
1018 aM = A.getType().getX();
Tim Murray25207df2015-01-12 16:47:56 -08001019 } else {
Miao Wangb530d8e2015-04-24 11:19:53 -07001020 aM = A.getType().getY();
1021 aN = A.getType().getX();
Tim Murray25207df2015-01-12 16:47:56 -08001022 }
1023 }
1024 if (B != null) {
Miao Wange1cf0952015-04-30 10:47:42 -07001025 if (TransB == TRANSPOSE || TransB == CONJ_TRANSPOSE) {
Miao Wangb530d8e2015-04-24 11:19:53 -07001026 bN = B.getType().getY();
1027 bM = B.getType().getX();
Tim Murray25207df2015-01-12 16:47:56 -08001028 } else {
Miao Wangb530d8e2015-04-24 11:19:53 -07001029 bM = B.getType().getY();
1030 bN = B.getType().getX();
Tim Murray25207df2015-01-12 16:47:56 -08001031 }
1032 }
1033 }
1034 if (A != null && B != null && C != null) {
Miao Wangb530d8e2015-04-24 11:19:53 -07001035 if (aN != bM || aM != cM || bN != cN) {
Tim Murray25207df2015-01-12 16:47:56 -08001036 throw new RSRuntimeException("Called BLAS with invalid dimensions");
1037 }
1038 } else if (A != null && C != null) {
Miao Wangb530d8e2015-04-24 11:19:53 -07001039 // A and C only, for SYRK
1040 if (cM != cN) {
1041 throw new RSRuntimeException("Matrix C is not symmetric");
1042 }
1043 if (TransA != NO_TRANSPOSE) {
1044 if (aN != cM) {
1045 throw new RSRuntimeException("Called BLAS with invalid dimensions");
1046 }
1047 } else {
1048 if (aM != cM) {
1049 throw new RSRuntimeException("Called BLAS with invalid dimensions");
1050 }
Tim Murray25207df2015-01-12 16:47:56 -08001051 }
1052 } else if (A != null && B != null) {
1053 // A and B only
Miao Wangb530d8e2015-04-24 11:19:53 -07001054 if (aN != bM) {
1055 throw new RSRuntimeException("Called BLAS with invalid dimensions");
1056 }
Tim Murray25207df2015-01-12 16:47:56 -08001057 }
1058
1059 }
1060
1061 public void SGEMM(@Transpose int TransA, @Transpose int TransB, float alpha, Allocation A,
1062 Allocation B, float beta, Allocation C) {
1063 validateTranspose(TransA);
1064 validateTranspose(TransB);
1065 validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C);
1066
1067 int M = -1, N = -1, K = -1;
Miao Wangb530d8e2015-04-24 11:19:53 -07001068 if (TransA != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001069 M = A.getType().getX();
1070 K = A.getType().getY();
1071 } else {
1072 M = A.getType().getY();
1073 K = A.getType().getX();
1074 }
Miao Wangb530d8e2015-04-24 11:19:53 -07001075 if (TransB != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001076 N = B.getType().getY();
1077 } else {
1078 N = B.getType().getX();
1079 }
1080 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS),
1081 beta, C.getID(mRS), 0, 0, 0, 0);
1082 }
1083 public void DGEMM(@Transpose int TransA, @Transpose int TransB, double alpha, Allocation A,
1084 Allocation B, double beta, Allocation C) {
1085 validateTranspose(TransA);
1086 validateTranspose(TransB);
1087 validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C);
1088 int M = -1, N = -1, K = -1;
Miao Wangb530d8e2015-04-24 11:19:53 -07001089 if (TransA != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001090 M = A.getType().getX();
1091 K = A.getType().getY();
1092 } else {
1093 M = A.getType().getY();
1094 K = A.getType().getX();
1095 }
Miao Wangb530d8e2015-04-24 11:19:53 -07001096 if (TransB != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001097 N = B.getType().getY();
1098 } else {
1099 N = B.getType().getX();
1100 }
1101 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS),
1102 beta, C.getID(mRS), 0, 0, 0, 0);
1103 }
1104 public void CGEMM(@Transpose int TransA, @Transpose int TransB, Float2 alpha, Allocation A,
1105 Allocation B, Float2 beta, Allocation C) {
1106 validateTranspose(TransA);
1107 validateTranspose(TransB);
1108 validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C);
1109 int M = -1, N = -1, K = -1;
Miao Wangb530d8e2015-04-24 11:19:53 -07001110 if (TransA != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001111 M = A.getType().getX();
1112 K = A.getType().getY();
1113 } else {
1114 M = A.getType().getY();
1115 K = A.getType().getX();
1116 }
Miao Wangb530d8e2015-04-24 11:19:53 -07001117 if (TransB != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001118 N = B.getType().getY();
1119 } else {
1120 N = B.getType().getX();
1121 }
1122 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1123 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1124 }
1125
1126 public void ZGEMM(@Transpose int TransA, @Transpose int TransB, Double2 alpha, Allocation A,
1127 Allocation B, Double2 beta, Allocation C) {
1128 validateTranspose(TransA);
1129 validateTranspose(TransB);
1130 validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C);
1131 int M = -1, N = -1, K = -1;
Miao Wangb530d8e2015-04-24 11:19:53 -07001132 if (TransA != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001133 M = A.getType().getX();
1134 K = A.getType().getY();
1135 } else {
1136 M = A.getType().getY();
1137 K = A.getType().getX();
1138 }
Miao Wangb530d8e2015-04-24 11:19:53 -07001139 if (TransB != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001140 N = B.getType().getY();
1141 } else {
1142 N = B.getType().getX();
1143 }
1144 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1145 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1146 }
1147
1148 public void SSYMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A,
1149 Allocation B, float beta, Allocation C) {
1150 validateSide(Side);
1151 validateUplo(Uplo);
Miao Wangb530d8e2015-04-24 11:19:53 -07001152 //For SYMM, Matrix A should be symmetric
1153 if (A.getType().getX() != A.getType().getY()) {
1154 throw new RSRuntimeException("Matrix A is not symmetric");
1155 }
Tim Murray25207df2015-01-12 16:47:56 -08001156 validateL3(Element.F32(mRS), 0, 0, Side, A, B, C);
1157 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
1158 beta, C.getID(mRS), 0, 0, 0, 0);
1159 }
1160 public void DSYMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A,
1161 Allocation B, double beta, Allocation C) {
1162 validateSide(Side);
1163 validateUplo(Uplo);
Miao Wangb530d8e2015-04-24 11:19:53 -07001164 if (A.getType().getX() != A.getType().getY()) {
1165 throw new RSRuntimeException("Matrix A is not symmetric");
1166 }
Tim Murray25207df2015-01-12 16:47:56 -08001167 validateL3(Element.F64(mRS), 0, 0, Side, A, B, C);
1168 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS),
1169 beta, C.getID(mRS), 0, 0, 0, 0);
1170 }
1171 public void CSYMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A,
1172 Allocation B, Float2 beta, Allocation C) {
1173 validateSide(Side);
1174 validateUplo(Uplo);
Miao Wangb530d8e2015-04-24 11:19:53 -07001175 if (A.getType().getX() != A.getType().getY()) {
1176 throw new RSRuntimeException("Matrix A is not symmetric");
1177 }
Tim Murray25207df2015-01-12 16:47:56 -08001178 validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C);
1179 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1180 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1181 }
1182 public void ZSYMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A,
1183 Allocation B, Double2 beta, Allocation C) {
1184 validateSide(Side);
1185 validateUplo(Uplo);
Miao Wangb530d8e2015-04-24 11:19:53 -07001186 if (A.getType().getX() != A.getType().getY()) {
1187 throw new RSRuntimeException("Matrix A is not symmetric");
1188 }
Tim Murray25207df2015-01-12 16:47:56 -08001189 validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C);
1190 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS),
1191 beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
1192 }
1193
1194 public void SSYRK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) {
1195 validateTranspose(Trans);
1196 validateUplo(Uplo);
1197 validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C);
1198 int K = -1;
Miao Wangb530d8e2015-04-24 11:19:53 -07001199 if (Trans != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001200 K = A.getType().getY();
1201 } else {
1202 K = A.getType().getX();
1203 }
1204
1205 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0);
1206 }
1207
1208 public void DSYRK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) {
1209 validateTranspose(Trans);
1210 validateUplo(Uplo);
1211 validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C);
1212 int K = -1;
Miao Wangb530d8e2015-04-24 11:19:53 -07001213 if (Trans != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001214 K = A.getType().getY();
1215 } else {
1216 K = A.getType().getX();
1217 }
1218 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0);
1219 }
Miao Wang333bcc02015-04-22 15:57:57 -07001220 public void CSYRK(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Float2 beta, Allocation C) {
Tim Murray25207df2015-01-12 16:47:56 -08001221 validateTranspose(Trans);
1222 validateUplo(Uplo);
1223 validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C);
1224 int K = -1;
Miao Wangb530d8e2015-04-24 11:19:53 -07001225 if (Trans != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001226 K = A.getType().getY();
1227 } else {
1228 K = A.getType().getX();
1229 }
Miao Wang333bcc02015-04-22 15:57:57 -07001230 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), 0, beta.x, beta.y,
Tim Murray25207df2015-01-12 16:47:56 -08001231 C.getID(mRS), 0, 0, 0, 0);
1232 }
Miao Wang333bcc02015-04-22 15:57:57 -07001233 public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Double2 beta, Allocation C) {
Tim Murray25207df2015-01-12 16:47:56 -08001234 validateTranspose(Trans);
1235 validateUplo(Uplo);
1236 validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C);
1237 int K = -1;
Miao Wangb530d8e2015-04-24 11:19:53 -07001238 if (Trans != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001239 K = A.getType().getY();
1240 } else {
1241 K = A.getType().getX();
1242 }
Miao Wang333bcc02015-04-22 15:57:57 -07001243 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), 0, beta.x, beta.y,
Tim Murray25207df2015-01-12 16:47:56 -08001244 C.getID(mRS), 0, 0, 0, 0);
1245 }
1246
1247 static void validateSYR2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) {
1248 validateTranspose(Trans);
1249 if (!A.getType().getElement().isCompatible(e) ||
1250 !B.getType().getElement().isCompatible(e) ||
1251 !C.getType().getElement().isCompatible(e)) {
1252 throw new RSRuntimeException("Called BLAS with wrong Element type");
1253 }
1254 int Cdim = -1;
1255 // A is n x k if no transpose, k x n if transpose
1256 // C is n x n
1257 if (Trans == TRANSPOSE) {
1258 // check columns versus C
1259 Cdim = A.getType().getX();
1260 } else {
1261 // check rows versus C
1262 Cdim = A.getType().getY();
1263 }
Miao Wangb530d8e2015-04-24 11:19:53 -07001264 if (C.getType().getX() != Cdim || C.getType().getY() != Cdim) {
Tim Murray25207df2015-01-12 16:47:56 -08001265 throw new RSRuntimeException("Invalid symmetric matrix in SYR2K");
1266 }
1267 // A dims == B dims
1268 if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) {
1269 throw new RSRuntimeException("Invalid A and B in SYR2K");
1270 }
1271 }
1272 public void SSYR2K(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, Allocation B, float beta, Allocation C) {
1273 validateUplo(Uplo);
1274 validateSYR2K(Element.F32(mRS), Trans, A, B, C);
1275 int K = -1;
Miao Wange1cf0952015-04-30 10:47:42 -07001276 if (Trans != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001277 K = A.getType().getY();
1278 } else {
1279 K = A.getType().getX();
1280 }
1281 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0);
1282 }
1283 public void DSYR2K(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, Allocation B, double beta, Allocation C) {
1284 validateUplo(Uplo);
1285 validateSYR2K(Element.F64(mRS), Trans, A, B, C);
1286 int K = -1;
Miao Wange1cf0952015-04-30 10:47:42 -07001287 if (Trans != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001288 K = A.getType().getY();
1289 } else {
1290 K = A.getType().getX();
1291 }
Miao Wang328919a2015-04-30 17:14:28 -07001292 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0);
Tim Murray25207df2015-01-12 16:47:56 -08001293 }
1294 public void CSYR2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) {
1295 validateUplo(Uplo);
1296 validateSYR2K(Element.F32_2(mRS), Trans, A, B, C);
1297 int K = -1;
Miao Wange1cf0952015-04-30 10:47:42 -07001298 if (Trans != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001299 K = A.getType().getY();
1300 } else {
1301 K = A.getType().getX();
1302 }
Miao Wang328919a2015-04-30 17:14:28 -07001303 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
Tim Murray25207df2015-01-12 16:47:56 -08001304 }
1305 public void ZSYR2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) {
1306 validateUplo(Uplo);
1307 validateSYR2K(Element.F64_2(mRS), Trans, A, B, C);
1308 int K = -1;
Miao Wange1cf0952015-04-30 10:47:42 -07001309 if (Trans != NO_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001310 K = A.getType().getY();
1311 } else {
1312 K = A.getType().getX();
1313 }
Miao Wang328919a2015-04-30 17:14:28 -07001314 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
Tim Murray25207df2015-01-12 16:47:56 -08001315 }
1316
1317 static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
1318 validateSide(Side);
1319 validateTranspose(TransA);
Miao Wangb530d8e2015-04-24 11:19:53 -07001320 int aM = -1, aN = -1, bM = -1, bN = -1;
Tim Murray25207df2015-01-12 16:47:56 -08001321 if (!A.getType().getElement().isCompatible(e) ||
1322 !B.getType().getElement().isCompatible(e)) {
1323 throw new RSRuntimeException("Called BLAS with wrong Element type");
1324 }
Miao Wangb530d8e2015-04-24 11:19:53 -07001325
1326 aM = A.getType().getY();
1327 aN = A.getType().getX();
1328 if (aM != aN) {
1329 throw new RSRuntimeException("Called TRMM with a non-symmetric matrix A");
Tim Murray25207df2015-01-12 16:47:56 -08001330 }
Miao Wangb530d8e2015-04-24 11:19:53 -07001331
1332 bM = B.getType().getY();
1333 bN = B.getType().getX();
Tim Murray25207df2015-01-12 16:47:56 -08001334 if (Side == LEFT) {
Miao Wangb530d8e2015-04-24 11:19:53 -07001335 if (aN != bM) {
Tim Murray25207df2015-01-12 16:47:56 -08001336 throw new RSRuntimeException("Called TRMM with invalid matrices");
1337 }
1338 } else {
Miao Wangb530d8e2015-04-24 11:19:53 -07001339 if (bN != aM) {
Tim Murray25207df2015-01-12 16:47:56 -08001340 throw new RSRuntimeException("Called TRMM with invalid matrices");
1341 }
1342 }
1343 }
1344 public void STRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) {
1345 validateUplo(Uplo);
1346 validateDiag(Diag);
1347 validateTRMM(Element.F32(mRS), Side, TransA, A, B);
1348 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1349 alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0);
1350 }
1351 public void DTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) {
1352 validateUplo(Uplo);
1353 validateDiag(Diag);
1354 validateTRMM(Element.F64(mRS), Side, TransA, A, B);
Miao Wang328919a2015-04-30 17:14:28 -07001355 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1356 alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0);
Tim Murray25207df2015-01-12 16:47:56 -08001357 }
1358 public void CTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) {
1359 validateUplo(Uplo);
1360 validateDiag(Diag);
1361 validateTRMM(Element.F32_2(mRS), Side, TransA, A, B);
Miao Wang328919a2015-04-30 17:14:28 -07001362 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
Tim Murray25207df2015-01-12 16:47:56 -08001363 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1364 }
1365 public void ZTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) {
1366 validateUplo(Uplo);
1367 validateDiag(Diag);
1368 validateTRMM(Element.F64_2(mRS), Side, TransA, A, B);
Miao Wang328919a2015-04-30 17:14:28 -07001369 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
Tim Murray25207df2015-01-12 16:47:56 -08001370 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1371 }
1372
1373 static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) {
Miao Wangb530d8e2015-04-24 11:19:53 -07001374 int adim = -1, bM = -1, bN = -1;
Tim Murray25207df2015-01-12 16:47:56 -08001375 validateSide(Side);
1376 validateTranspose(TransA);
1377 if (!A.getType().getElement().isCompatible(e) ||
1378 !B.getType().getElement().isCompatible(e)) {
1379 throw new RSRuntimeException("Called BLAS with wrong Element type");
1380 }
1381 adim = A.getType().getX();
1382 if (adim != A.getType().getY()) {
1383 // this may be unnecessary, the restriction could potentially be relaxed
1384 // A needs to contain at least that symmetric matrix but could theoretically be larger
1385 // for now we assume adapters are sufficient, will reevaluate in the future
1386 throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A");
1387 }
Miao Wangb530d8e2015-04-24 11:19:53 -07001388 bM = B.getType().getY();
1389 bN = B.getType().getX();
Tim Murray25207df2015-01-12 16:47:56 -08001390 if (Side == LEFT) {
1391 // A is M*M
Miao Wangb530d8e2015-04-24 11:19:53 -07001392 if (adim != bM) {
Tim Murray25207df2015-01-12 16:47:56 -08001393 throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
1394 }
1395 } else {
1396 // A is N*N
Miao Wangb530d8e2015-04-24 11:19:53 -07001397 if (adim != bN) {
Tim Murray25207df2015-01-12 16:47:56 -08001398 throw new RSRuntimeException("Called TRSM with invalid matrix dimensions");
1399 }
1400 }
1401 }
1402 public void STRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) {
1403 validateUplo(Uplo);
1404 validateDiag(Diag);
1405 validateTRSM(Element.F32(mRS), Side, TransA, A, B);
1406 mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
1407 alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0);
1408 }
1409 public void DTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) {
1410 validateUplo(Uplo);
1411 validateDiag(Diag);
1412 validateTRSM(Element.F64(mRS), Side, TransA, A, B);
Miao Wang328919a2015-04-30 17:14:28 -07001413 mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
Tim Murray25207df2015-01-12 16:47:56 -08001414 alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0);
1415 }
1416 public void CTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) {
1417 validateUplo(Uplo);
1418 validateDiag(Diag);
1419 validateTRSM(Element.F32_2(mRS), Side, TransA, A, B);
Miao Wang328919a2015-04-30 17:14:28 -07001420 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
Tim Murray25207df2015-01-12 16:47:56 -08001421 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1422 }
1423 public void ZTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) {
1424 validateUplo(Uplo);
1425 validateDiag(Diag);
1426 validateTRSM(Element.F64_2(mRS), Side, TransA, A, B);
Miao Wang328919a2015-04-30 17:14:28 -07001427 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0,
Tim Murray25207df2015-01-12 16:47:56 -08001428 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0);
1429 }
1430
1431 static void validateHEMM(Element e, @Side int Side, Allocation A, Allocation B, Allocation C) {
1432 validateSide(Side);
1433
1434 if (!A.getType().getElement().isCompatible(e) ||
1435 !B.getType().getElement().isCompatible(e) ||
1436 !C.getType().getElement().isCompatible(e)) {
1437 throw new RSRuntimeException("Called BLAS with wrong Element type");
1438 }
1439
1440 // A must be square; can potentially be relaxed similar to TRSM
1441 int adim = A.getType().getX();
1442 if (adim != A.getType().getY()) {
1443 throw new RSRuntimeException("Called HEMM with non-square A");
1444 }
1445 if ((Side == LEFT && adim != B.getType().getY()) ||
1446 (Side == RIGHT && adim != B.getType().getX())) {
1447 throw new RSRuntimeException("Called HEMM with invalid B");
1448 }
1449 if (B.getType().getX() != C.getType().getX() ||
1450 B.getType().getY() != C.getType().getY()) {
1451 throw new RSRuntimeException("Called HEMM with mismatched B and C");
1452 }
1453 }
Miao Wang333bcc02015-04-22 15:57:57 -07001454 public void CHEMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) {
Tim Murray25207df2015-01-12 16:47:56 -08001455 validateUplo(Uplo);
1456 validateHEMM(Element.F32_2(mRS), Side, A, B, C);
1457 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
Miao Wang333bcc02015-04-22 15:57:57 -07001458 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
Tim Murray25207df2015-01-12 16:47:56 -08001459 }
Miao Wang333bcc02015-04-22 15:57:57 -07001460 public void ZHEMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) {
Tim Murray25207df2015-01-12 16:47:56 -08001461 validateUplo(Uplo);
Miao Wangb530d8e2015-04-24 11:19:53 -07001462 validateHEMM(Element.F64_2(mRS), Side, A, B, C);
Tim Murray25207df2015-01-12 16:47:56 -08001463 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0,
Miao Wang333bcc02015-04-22 15:57:57 -07001464 alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0);
Tim Murray25207df2015-01-12 16:47:56 -08001465 }
1466
1467 static void validateHERK(Element e, @Transpose int Trans, Allocation A, Allocation C) {
1468 if (!A.getType().getElement().isCompatible(e) ||
1469 !C.getType().getElement().isCompatible(e)) {
1470 throw new RSRuntimeException("Called BLAS with wrong Element type");
1471 }
1472 validateConjTranspose(Trans);
1473 int cdim = C.getType().getX();
1474 if (cdim != C.getType().getY()) {
1475 throw new RSRuntimeException("Called HERK with non-square C");
1476 }
1477 if (Trans == NO_TRANSPOSE) {
Miao Wangb530d8e2015-04-24 11:19:53 -07001478 if (cdim != A.getType().getY()) {
Tim Murray25207df2015-01-12 16:47:56 -08001479 throw new RSRuntimeException("Called HERK with invalid A");
1480 }
1481 } else {
Miao Wangb530d8e2015-04-24 11:19:53 -07001482 if (cdim != A.getType().getX()) {
Tim Murray25207df2015-01-12 16:47:56 -08001483 throw new RSRuntimeException("Called HERK with invalid A");
1484 }
1485 }
1486 }
1487 public void CHERK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) {
1488 validateUplo(Uplo);
1489 validateHERK(Element.F32_2(mRS), Trans, A, C);
1490 int k = 0;
Miao Wangb530d8e2015-04-24 11:19:53 -07001491 if (Trans == CONJ_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001492 k = A.getType().getY();
1493 } else {
1494 k = A.getType().getX();
1495 }
1496 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k,
1497 alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0);
1498 }
1499 public void ZHERK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) {
1500 validateUplo(Uplo);
1501 validateHERK(Element.F64_2(mRS), Trans, A, C);
1502 int k = 0;
Miao Wangb530d8e2015-04-24 11:19:53 -07001503 if (Trans == CONJ_TRANSPOSE) {
Tim Murray25207df2015-01-12 16:47:56 -08001504 k = A.getType().getY();
1505 } else {
1506 k = A.getType().getX();
1507 }
1508 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k,
1509 alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0);
1510 }
1511
1512 static void validateHER2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) {
1513 if (!A.getType().getElement().isCompatible(e) ||
1514 !B.getType().getElement().isCompatible(e) ||
1515 !C.getType().getElement().isCompatible(e)) {
1516 throw new RSRuntimeException("Called BLAS with wrong Element type");
1517 }
1518 validateConjTranspose(Trans);
1519 int cdim = C.getType().getX();
1520 if (cdim != C.getType().getY()) {
1521 throw new RSRuntimeException("Called HER2K with non-square C");
1522 }
1523 if (Trans == NO_TRANSPOSE) {
1524 if (A.getType().getY() != cdim) {
1525 throw new RSRuntimeException("Called HER2K with invalid matrices");
1526 }
1527 } else {
1528 if (A.getType().getX() != cdim) {
1529 throw new RSRuntimeException("Called HER2K with invalid matrices");
1530 }
1531 }
1532 if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) {
1533 throw new RSRuntimeException("Called HER2K with invalid A and B matrices");
1534 }
1535 }
1536 public void CHER2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, float beta, Allocation C) {
1537 validateUplo(Uplo);
1538 validateHER2K(Element.F32_2(mRS), Trans, A, B, C);
1539 int k = 0;
1540 if (Trans == NO_TRANSPOSE) {
1541 k = A.getType().getX();
1542 } else {
1543 k = A.getType().getY();
1544 }
1545 mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y,
1546 A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
1547 }
1548 public void ZHER2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, double beta, Allocation C) {
1549 validateUplo(Uplo);
1550 validateHER2K(Element.F64_2(mRS), Trans, A, B, C);
1551 int k = 0;
1552 if (Trans == NO_TRANSPOSE) {
1553 k = A.getType().getX();
1554 } else {
1555 k = A.getType().getY();
1556 }
1557 mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y,
1558 A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0);
1559 }
1560
1561
Tim Murray9cb16a22015-04-01 11:07:16 -07001562 /**
1563 *
1564 * 8-bit GEMM-like operation for neural networks
1565 *
1566 * @hide
1567 **/
1568 public void BNNM(Allocation A, int a_offset, Allocation B, int b_offset, Allocation C, int c_offset, int c_mult) {
1569 validateL3(Element.U8(mRS), NO_TRANSPOSE, TRANSPOSE, 0, A, B, C);
1570
1571 int M = -1, N = -1, K = -1;
1572 M = A.getType().getY();
1573 N = B.getType().getY();
1574 K = A.getType().getX();
1575
1576
1577 mRS.nScriptIntrinsicBLAS_BNNM(getID(mRS), M, N, K, A.getID(mRS), a_offset, B.getID(mRS), b_offset, C.getID(mRS), c_offset, c_mult);
1578
1579 }
Tim Murray25207df2015-01-12 16:47:56 -08001580
1581}