blob: 4efb45b92aedf542426dba5a198a7e4deffd7bd2 [file] [log] [blame]
Jason Sams423ebcb2012-08-10 15:40:53 -07001/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.renderscript;
18
19import java.lang.reflect.Method;
Jason Sams08a81582012-09-18 12:32:10 -070020import java.util.ArrayList;
Jason Sams423ebcb2012-08-10 15:40:53 -070021
22/**
Jason Sams08a81582012-09-18 12:32:10 -070023 * ScriptGroup creates a groups of scripts which are executed
24 * together based upon upon one execution call as if they were
25 * all part of a single script. The scripts may be connected
26 * internally or to an external allocation. For the internal
27 * connections the intermediate results are not observable after
28 * the execution of the script.
29 * <p>
30 * The external connections are grouped into inputs and outputs.
31 * All outputs are produced by a script kernel and placed into a
32 * user supplied allocation. Inputs are similar but supply the
33 * input of a kernal. Inputs bounds to a script are set directly
34 * upon the script.
35 *
Jason Sams423ebcb2012-08-10 15:40:53 -070036 **/
Jason Sams08a81582012-09-18 12:32:10 -070037public final class ScriptGroup extends BaseObj {
Jason Sams423ebcb2012-08-10 15:40:53 -070038 IO mOutputs[];
39 IO mInputs[];
40
41 static class IO {
Jason Sams08a81582012-09-18 12:32:10 -070042 Script.KernelID mKID;
Jason Sams423ebcb2012-08-10 15:40:53 -070043 Allocation mAllocation;
Jason Sams423ebcb2012-08-10 15:40:53 -070044
Jason Sams08a81582012-09-18 12:32:10 -070045 IO(Script.KernelID s) {
46 mKID = s;
Jason Sams423ebcb2012-08-10 15:40:53 -070047 }
48 }
49
Jason Sams08a81582012-09-18 12:32:10 -070050 static class ConnectLine {
51 ConnectLine(Type t, Script.KernelID from, Script.KernelID to) {
52 mFrom = from;
53 mToK = to;
Jason Sams423ebcb2012-08-10 15:40:53 -070054 mAllocationType = t;
55 }
56
Jason Sams08a81582012-09-18 12:32:10 -070057 ConnectLine(Type t, Script.KernelID from, Script.FieldID to) {
58 mFrom = from;
59 mToF = to;
60 mAllocationType = t;
Jason Sams423ebcb2012-08-10 15:40:53 -070061 }
Jason Sams08a81582012-09-18 12:32:10 -070062
63 Script.FieldID mToF;
64 Script.KernelID mToK;
65 Script.KernelID mFrom;
66 Type mAllocationType;
Jason Sams423ebcb2012-08-10 15:40:53 -070067 }
68
69 static class Node {
70 Script mScript;
Jason Sams08a81582012-09-18 12:32:10 -070071 ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>();
72 ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>();
73 ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>();
Jason Sams423ebcb2012-08-10 15:40:53 -070074 boolean mSeen;
75
76 Node mNext;
77
78 Node(Script s) {
79 mScript = s;
80 }
Jason Sams423ebcb2012-08-10 15:40:53 -070081 }
82
83
84 ScriptGroup(int id, RenderScript rs) {
85 super(id, rs);
86 }
87
Jason Sams08a81582012-09-18 12:32:10 -070088 /**
89 * Sets an input of the ScriptGroup. This specifies an
90 * Allocation to be used for the kernels which require a kernel
91 * input and that input is provided external to the group.
92 *
93 * @param s The ID of the kernel where the allocation should be
94 * connected.
95 * @param a The allocation to connect.
96 */
97 public void setInput(Script.KernelID s, Allocation a) {
Jason Sams423ebcb2012-08-10 15:40:53 -070098 for (int ct=0; ct < mInputs.length; ct++) {
Jason Sams08a81582012-09-18 12:32:10 -070099 if (mInputs[ct].mKID == s) {
Jason Sams423ebcb2012-08-10 15:40:53 -0700100 mInputs[ct].mAllocation = a;
Jason Sams08a81582012-09-18 12:32:10 -0700101 mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a));
Jason Sams423ebcb2012-08-10 15:40:53 -0700102 return;
103 }
104 }
105 throw new RSIllegalArgumentException("Script not found");
106 }
107
Jason Sams08a81582012-09-18 12:32:10 -0700108 /**
109 * Sets an output of the ScriptGroup. This specifies an
110 * Allocation to be used for the kernels which require a kernel
111 * output and that output is provided external to the group.
112 *
113 * @param s The ID of the kernel where the allocation should be
114 * connected.
115 * @param a The allocation to connect.
116 */
117 public void setOutput(Script.KernelID s, Allocation a) {
Jason Sams423ebcb2012-08-10 15:40:53 -0700118 for (int ct=0; ct < mOutputs.length; ct++) {
Jason Sams08a81582012-09-18 12:32:10 -0700119 if (mOutputs[ct].mKID == s) {
Jason Sams423ebcb2012-08-10 15:40:53 -0700120 mOutputs[ct].mAllocation = a;
Jason Sams08a81582012-09-18 12:32:10 -0700121 mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a));
Jason Sams423ebcb2012-08-10 15:40:53 -0700122 return;
123 }
124 }
125 throw new RSIllegalArgumentException("Script not found");
126 }
127
Jason Sams08a81582012-09-18 12:32:10 -0700128 /**
129 * Execute the ScriptGroup. This will run all the kernels in
130 * the script. The state of the connecting lines will not be
131 * observable after this operation.
132 */
Jason Sams423ebcb2012-08-10 15:40:53 -0700133 public void execute() {
Jason Sams08a81582012-09-18 12:32:10 -0700134 mRS.nScriptGroupExecute(getID(mRS));
Jason Sams423ebcb2012-08-10 15:40:53 -0700135 }
136
137
Jason Sams08a81582012-09-18 12:32:10 -0700138 /**
139 * Create a ScriptGroup. There are two steps to creating a
140 * ScriptGoup.
141 * <p>
142 * First all the Kernels to be used by the group should be
143 * added. Once this is done the kernels should be connected.
144 * Kernels cannot be added once a connection has been made.
145 * <p>
146 * Second, add connections. There are two forms of connections.
147 * Kernel to Kernel and Kernel to Field. Kernel to Kernel is
148 * higher performance and should be used where possible. The
149 * line of connections cannot form a loop. If a loop is detected
150 * an exception is thrown.
151 * <p>
152 * Once all the connections are made a call to create will
153 * return the ScriptGroup object.
154 *
155 */
156 public static final class Builder {
157 private RenderScript mRS;
158 private ArrayList<Node> mNodes = new ArrayList<Node>();
159 private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>();
160 private int mKernelCount;
Jason Sams423ebcb2012-08-10 15:40:53 -0700161
Jason Sams08a81582012-09-18 12:32:10 -0700162 /**
163 * Create a builder for generating a ScriptGroup.
164 *
165 *
166 * @param rs The Renderscript context.
167 */
Jason Sams423ebcb2012-08-10 15:40:53 -0700168 public Builder(RenderScript rs) {
169 mRS = rs;
170 }
171
172 private void validateRecurse(Node n, int depth) {
173 n.mSeen = true;
Jason Sams423ebcb2012-08-10 15:40:53 -0700174
Jason Sams08a81582012-09-18 12:32:10 -0700175 //android.util.Log.v("RSR", " validateRecurse outputCount " + n.mOutputs.size());
176 for (int ct=0; ct < n.mOutputs.size(); ct++) {
177 final ConnectLine cl = n.mOutputs.get(ct);
178 if (cl.mToK != null) {
179 Node tn = findNode(cl.mToK.mScript);
180 if (tn.mSeen) {
Jason Sams423ebcb2012-08-10 15:40:53 -0700181 throw new RSInvalidStateException("Loops in group not allowed.");
182 }
Jason Sams08a81582012-09-18 12:32:10 -0700183 validateRecurse(tn, depth + 1);
184 }
185 if (cl.mToF != null) {
186 Node tn = findNode(cl.mToF.mScript);
187 if (tn.mSeen) {
188 throw new RSInvalidStateException("Loops in group not allowed.");
189 }
190 validateRecurse(tn, depth + 1);
Jason Sams423ebcb2012-08-10 15:40:53 -0700191 }
192 }
193 }
194
195 private void validate() {
Jason Sams08a81582012-09-18 12:32:10 -0700196 //android.util.Log.v("RSR", "validate");
Jason Sams423ebcb2012-08-10 15:40:53 -0700197
Jason Sams08a81582012-09-18 12:32:10 -0700198 for (int ct=0; ct < mNodes.size(); ct++) {
199 for (int ct2=0; ct2 < mNodes.size(); ct2++) {
200 mNodes.get(ct2).mSeen = false;
201 }
202 Node n = mNodes.get(ct);
203 if (n.mInputs.size() == 0) {
Jason Sams423ebcb2012-08-10 15:40:53 -0700204 validateRecurse(n, 0);
205 }
Jason Sams423ebcb2012-08-10 15:40:53 -0700206 }
207 }
208
Jason Sams08a81582012-09-18 12:32:10 -0700209 private Node findNode(Script s) {
210 for (int ct=0; ct < mNodes.size(); ct++) {
211 if (s == mNodes.get(ct).mScript) {
212 return mNodes.get(ct);
Jason Sams423ebcb2012-08-10 15:40:53 -0700213 }
Jason Sams423ebcb2012-08-10 15:40:53 -0700214 }
215 return null;
216 }
217
Jason Sams08a81582012-09-18 12:32:10 -0700218 private Node findNode(Script.KernelID k) {
219 for (int ct=0; ct < mNodes.size(); ct++) {
220 Node n = mNodes.get(ct);
221 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
222 if (k == n.mKernels.get(ct2)) {
223 return n;
Jason Sams423ebcb2012-08-10 15:40:53 -0700224 }
225 }
Jason Sams423ebcb2012-08-10 15:40:53 -0700226 }
Jason Sams08a81582012-09-18 12:32:10 -0700227 return null;
228 }
Jason Sams423ebcb2012-08-10 15:40:53 -0700229
Jason Sams08a81582012-09-18 12:32:10 -0700230 /**
231 * Adds a Kernel to the group.
232 *
233 *
234 * @param k The kernel to add.
235 *
236 * @return Builder Returns this.
237 */
238 public Builder addKernel(Script.KernelID k) {
239 if (mLines.size() != 0) {
240 throw new RSInvalidStateException(
241 "Kernels may not be added once connections exist.");
Jason Sams423ebcb2012-08-10 15:40:53 -0700242 }
Jason Sams08a81582012-09-18 12:32:10 -0700243
244 //android.util.Log.v("RSR", "addKernel 1 k=" + k);
245 if (findNode(k) != null) {
246 return this;
247 }
248 //android.util.Log.v("RSR", "addKernel 2 ");
249 mKernelCount++;
250 Node n = findNode(k.mScript);
251 if (n == null) {
252 //android.util.Log.v("RSR", "addKernel 3 ");
253 n = new Node(k.mScript);
254 mNodes.add(n);
255 }
256 n.mKernels.add(k);
257 return this;
258 }
259
260 /**
261 * Adds a connection to the group.
262 *
263 *
264 * @param t The type of the connection. This is used to
265 * determine the kernel launch sizes on the source side
266 * of this connection.
267 * @param from The source for the connection.
268 * @param to The destination of the connection.
269 *
270 * @return Builder Returns this
271 */
272 public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) {
273 //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
274
275 Node nf = findNode(from);
276 if (nf == null) {
277 throw new RSInvalidStateException("From kernel not found.");
278 }
279
280 Node nt = findNode(to.mScript);
281 if (nt == null) {
282 throw new RSInvalidStateException("To script not found.");
283 }
284
285 ConnectLine cl = new ConnectLine(t, from, to);
286 mLines.add(new ConnectLine(t, from, to));
287
288 nf.mOutputs.add(cl);
289 nt.mInputs.add(cl);
Jason Sams423ebcb2012-08-10 15:40:53 -0700290
291 validate();
292 return this;
293 }
294
Jason Sams08a81582012-09-18 12:32:10 -0700295 /**
296 * Adds a connection to the group.
297 *
298 *
299 * @param t The type of the connection. This is used to
300 * determine the kernel launch sizes for both sides of
301 * this connection.
302 * @param from The source for the connection.
303 * @param to The destination of the connection.
304 *
305 * @return Builder Returns this
306 */
307 public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) {
308 //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
309
310 Node nf = findNode(from);
311 if (nf == null) {
312 throw new RSInvalidStateException("From kernel not found.");
313 }
314
315 Node nt = findNode(to);
316 if (nt == null) {
317 throw new RSInvalidStateException("To script not found.");
318 }
319
320 ConnectLine cl = new ConnectLine(t, from, to);
321 mLines.add(new ConnectLine(t, from, to));
322
323 nf.mOutputs.add(cl);
324 nt.mInputs.add(cl);
325
326 validate();
327 return this;
328 }
329
330
331
332 /**
333 * Creates the Script group.
334 *
335 *
336 * @return ScriptGroup The new ScriptGroup
337 */
Jason Sams423ebcb2012-08-10 15:40:53 -0700338 public ScriptGroup create() {
Jason Sams08a81582012-09-18 12:32:10 -0700339 ArrayList<IO> inputs = new ArrayList<IO>();
340 ArrayList<IO> outputs = new ArrayList<IO>();
Jason Sams423ebcb2012-08-10 15:40:53 -0700341
Jason Sams08a81582012-09-18 12:32:10 -0700342 int[] kernels = new int[mKernelCount];
343 int idx = 0;
344 for (int ct=0; ct < mNodes.size(); ct++) {
345 Node n = mNodes.get(ct);
346 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
347 final Script.KernelID kid = n.mKernels.get(ct2);
348 kernels[idx++] = kid.getID(mRS);
Jason Sams423ebcb2012-08-10 15:40:53 -0700349
Jason Sams08a81582012-09-18 12:32:10 -0700350 boolean hasInput = false;
351 boolean hasOutput = false;
352 for (int ct3=0; ct3 < n.mInputs.size(); ct3++) {
353 if (n.mInputs.get(ct3).mToK == kid) {
354 hasInput = true;
355 }
356 }
357 for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) {
358 if (n.mOutputs.get(ct3).mFrom == kid) {
359 hasOutput = true;
360 }
361 }
362 if (!hasInput) {
363 inputs.add(new IO(kid));
364 }
365 if (!hasOutput) {
366 outputs.add(new IO(kid));
367 }
368
369 }
370 }
371 if (idx != mKernelCount) {
372 throw new RSRuntimeException("Count mismatch, should not happen.");
373 }
374
375 int[] src = new int[mLines.size()];
376 int[] dstk = new int[mLines.size()];
377 int[] dstf = new int[mLines.size()];
378 int[] types = new int[mLines.size()];
379
380 for (int ct=0; ct < mLines.size(); ct++) {
381 ConnectLine cl = mLines.get(ct);
382 src[ct] = cl.mFrom.getID(mRS);
383 if (cl.mToK != null) {
384 dstk[ct] = cl.mToK.getID(mRS);
385 }
386 if (cl.mToF != null) {
387 dstf[ct] = cl.mToF.getID(mRS);
388 }
389 types[ct] = cl.mAllocationType.getID(mRS);
390 }
391
392 int id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types);
393 if (id == 0) {
394 throw new RSRuntimeException("Object creation error, should not happen.");
395 }
396
397 ScriptGroup sg = new ScriptGroup(id, mRS);
398 sg.mOutputs = new IO[outputs.size()];
399 for (int ct=0; ct < outputs.size(); ct++) {
400 sg.mOutputs[ct] = outputs.get(ct);
401 }
402
403 sg.mInputs = new IO[inputs.size()];
404 for (int ct=0; ct < inputs.size(); ct++) {
405 sg.mInputs[ct] = inputs.get(ct);
406 }
407
Jason Sams423ebcb2012-08-10 15:40:53 -0700408 return sg;
409 }
410
411 }
412
413
414}
415
416