blob: 8943f7529b9a858b38f68b0b362fd3f55f2a3c32 [file] [log] [blame]
/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17package android.renderscript;
18
19import java.lang.reflect.Method;
Jason Sams08a81582012-09-18 12:32:10 -070020import java.util.ArrayList;
Jason Sams423ebcb2012-08-10 15:40:53 -070021
22/**
Jason Sams08a81582012-09-18 12:32:10 -070023 * ScriptGroup creates a groups of scripts which are executed
24 * together based upon upon one execution call as if they were
25 * all part of a single script. The scripts may be connected
26 * internally or to an external allocation. For the internal
27 * connections the intermediate results are not observable after
28 * the execution of the script.
29 * <p>
30 * The external connections are grouped into inputs and outputs.
31 * All outputs are produced by a script kernel and placed into a
32 * user supplied allocation. Inputs are similar but supply the
33 * input of a kernal. Inputs bounds to a script are set directly
34 * upon the script.
Tim Murray2a603892012-10-10 14:21:46 -070035 * <p>
36 * A ScriptGroup must contain at least one kernel. A ScriptGroup
37 * must contain only a single directed acyclic graph (DAG) of
38 * script kernels and connections. Attempting to create a
39 * ScriptGroup with multiple DAGs or attempting to create
40 * a cycle within a ScriptGroup will throw an exception.
Jason Sams08a81582012-09-18 12:32:10 -070041 *
Jason Sams423ebcb2012-08-10 15:40:53 -070042 **/
Jason Sams08a81582012-09-18 12:32:10 -070043public final class ScriptGroup extends BaseObj {
Jason Sams423ebcb2012-08-10 15:40:53 -070044 IO mOutputs[];
45 IO mInputs[];
46
47 static class IO {
Jason Sams08a81582012-09-18 12:32:10 -070048 Script.KernelID mKID;
Jason Sams423ebcb2012-08-10 15:40:53 -070049 Allocation mAllocation;
Jason Sams423ebcb2012-08-10 15:40:53 -070050
Jason Sams08a81582012-09-18 12:32:10 -070051 IO(Script.KernelID s) {
52 mKID = s;
Jason Sams423ebcb2012-08-10 15:40:53 -070053 }
54 }
55
Jason Sams08a81582012-09-18 12:32:10 -070056 static class ConnectLine {
57 ConnectLine(Type t, Script.KernelID from, Script.KernelID to) {
58 mFrom = from;
59 mToK = to;
Jason Sams423ebcb2012-08-10 15:40:53 -070060 mAllocationType = t;
61 }
62
Jason Sams08a81582012-09-18 12:32:10 -070063 ConnectLine(Type t, Script.KernelID from, Script.FieldID to) {
64 mFrom = from;
65 mToF = to;
66 mAllocationType = t;
Jason Sams423ebcb2012-08-10 15:40:53 -070067 }
Jason Sams08a81582012-09-18 12:32:10 -070068
69 Script.FieldID mToF;
70 Script.KernelID mToK;
71 Script.KernelID mFrom;
72 Type mAllocationType;
Jason Sams423ebcb2012-08-10 15:40:53 -070073 }
74
75 static class Node {
76 Script mScript;
Jason Sams08a81582012-09-18 12:32:10 -070077 ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>();
78 ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>();
79 ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>();
Jason Sams423ebcb2012-08-10 15:40:53 -070080 boolean mSeen;
Tim Murray2a603892012-10-10 14:21:46 -070081 int dagNumber;
Jason Sams423ebcb2012-08-10 15:40:53 -070082
83 Node mNext;
84
85 Node(Script s) {
86 mScript = s;
87 }
Jason Sams423ebcb2012-08-10 15:40:53 -070088 }
89
90
91 ScriptGroup(int id, RenderScript rs) {
92 super(id, rs);
93 }
94
Jason Sams08a81582012-09-18 12:32:10 -070095 /**
96 * Sets an input of the ScriptGroup. This specifies an
97 * Allocation to be used for the kernels which require a kernel
98 * input and that input is provided external to the group.
99 *
100 * @param s The ID of the kernel where the allocation should be
101 * connected.
102 * @param a The allocation to connect.
103 */
104 public void setInput(Script.KernelID s, Allocation a) {
Jason Sams423ebcb2012-08-10 15:40:53 -0700105 for (int ct=0; ct < mInputs.length; ct++) {
Jason Sams08a81582012-09-18 12:32:10 -0700106 if (mInputs[ct].mKID == s) {
Jason Sams423ebcb2012-08-10 15:40:53 -0700107 mInputs[ct].mAllocation = a;
Jason Sams08a81582012-09-18 12:32:10 -0700108 mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a));
Jason Sams423ebcb2012-08-10 15:40:53 -0700109 return;
110 }
111 }
112 throw new RSIllegalArgumentException("Script not found");
113 }
114
Jason Sams08a81582012-09-18 12:32:10 -0700115 /**
116 * Sets an output of the ScriptGroup. This specifies an
117 * Allocation to be used for the kernels which require a kernel
118 * output and that output is provided external to the group.
119 *
120 * @param s The ID of the kernel where the allocation should be
121 * connected.
122 * @param a The allocation to connect.
123 */
124 public void setOutput(Script.KernelID s, Allocation a) {
Jason Sams423ebcb2012-08-10 15:40:53 -0700125 for (int ct=0; ct < mOutputs.length; ct++) {
Jason Sams08a81582012-09-18 12:32:10 -0700126 if (mOutputs[ct].mKID == s) {
Jason Sams423ebcb2012-08-10 15:40:53 -0700127 mOutputs[ct].mAllocation = a;
Jason Sams08a81582012-09-18 12:32:10 -0700128 mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a));
Jason Sams423ebcb2012-08-10 15:40:53 -0700129 return;
130 }
131 }
132 throw new RSIllegalArgumentException("Script not found");
133 }
134
Jason Sams08a81582012-09-18 12:32:10 -0700135 /**
136 * Execute the ScriptGroup. This will run all the kernels in
137 * the script. The state of the connecting lines will not be
138 * observable after this operation.
139 */
Jason Sams423ebcb2012-08-10 15:40:53 -0700140 public void execute() {
Jason Sams08a81582012-09-18 12:32:10 -0700141 mRS.nScriptGroupExecute(getID(mRS));
Jason Sams423ebcb2012-08-10 15:40:53 -0700142 }
143
144
Jason Sams08a81582012-09-18 12:32:10 -0700145 /**
146 * Create a ScriptGroup. There are two steps to creating a
147 * ScriptGoup.
148 * <p>
149 * First all the Kernels to be used by the group should be
150 * added. Once this is done the kernels should be connected.
151 * Kernels cannot be added once a connection has been made.
152 * <p>
153 * Second, add connections. There are two forms of connections.
154 * Kernel to Kernel and Kernel to Field. Kernel to Kernel is
155 * higher performance and should be used where possible. The
156 * line of connections cannot form a loop. If a loop is detected
157 * an exception is thrown.
158 * <p>
159 * Once all the connections are made a call to create will
160 * return the ScriptGroup object.
161 *
162 */
163 public static final class Builder {
164 private RenderScript mRS;
165 private ArrayList<Node> mNodes = new ArrayList<Node>();
166 private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>();
167 private int mKernelCount;
Jason Sams423ebcb2012-08-10 15:40:53 -0700168
Jason Sams08a81582012-09-18 12:32:10 -0700169 /**
170 * Create a builder for generating a ScriptGroup.
171 *
172 *
173 * @param rs The Renderscript context.
174 */
Jason Sams423ebcb2012-08-10 15:40:53 -0700175 public Builder(RenderScript rs) {
176 mRS = rs;
177 }
178
Tim Murray2a603892012-10-10 14:21:46 -0700179 private void validateCycleRecurse(Node n, int depth) {
Jason Sams423ebcb2012-08-10 15:40:53 -0700180 n.mSeen = true;
Jason Sams423ebcb2012-08-10 15:40:53 -0700181
Tim Murray2a603892012-10-10 14:21:46 -0700182 //android.util.Log.v("RSR", " validateCycleRecurse outputCount " + n.mOutputs.size());
Jason Sams08a81582012-09-18 12:32:10 -0700183 for (int ct=0; ct < n.mOutputs.size(); ct++) {
184 final ConnectLine cl = n.mOutputs.get(ct);
185 if (cl.mToK != null) {
186 Node tn = findNode(cl.mToK.mScript);
187 if (tn.mSeen) {
Jason Sams423ebcb2012-08-10 15:40:53 -0700188 throw new RSInvalidStateException("Loops in group not allowed.");
189 }
Tim Murray2a603892012-10-10 14:21:46 -0700190 validateCycleRecurse(tn, depth + 1);
Jason Sams08a81582012-09-18 12:32:10 -0700191 }
192 if (cl.mToF != null) {
193 Node tn = findNode(cl.mToF.mScript);
194 if (tn.mSeen) {
195 throw new RSInvalidStateException("Loops in group not allowed.");
196 }
Tim Murray2a603892012-10-10 14:21:46 -0700197 validateCycleRecurse(tn, depth + 1);
Jason Sams423ebcb2012-08-10 15:40:53 -0700198 }
199 }
200 }
201
Tim Murray2a603892012-10-10 14:21:46 -0700202 private void validateCycle() {
203 //android.util.Log.v("RSR", "validateCycle");
Jason Sams423ebcb2012-08-10 15:40:53 -0700204
Jason Sams08a81582012-09-18 12:32:10 -0700205 for (int ct=0; ct < mNodes.size(); ct++) {
206 for (int ct2=0; ct2 < mNodes.size(); ct2++) {
207 mNodes.get(ct2).mSeen = false;
208 }
209 Node n = mNodes.get(ct);
210 if (n.mInputs.size() == 0) {
Tim Murray2a603892012-10-10 14:21:46 -0700211 validateCycleRecurse(n, 0);
212 }
213 }
214 }
215
216 private void mergeDAGs(int valueUsed, int valueKilled) {
217 for (int ct=0; ct < mNodes.size(); ct++) {
218 if (mNodes.get(ct).dagNumber == valueKilled)
219 mNodes.get(ct).dagNumber = valueUsed;
220 }
221 }
222
223 private void validateDAGRecurse(Node n, int dagNumber) {
224 // combine DAGs if this node has been seen already
225 if (n.dagNumber != 0 && n.dagNumber != dagNumber) {
226 mergeDAGs(n.dagNumber, dagNumber);
227 return;
228 }
229
230 n.dagNumber = dagNumber;
231 for (int ct=0; ct < n.mOutputs.size(); ct++) {
232 final ConnectLine cl = n.mOutputs.get(ct);
233 if (cl.mToK != null) {
234 Node tn = findNode(cl.mToK.mScript);
235 validateDAGRecurse(tn, dagNumber);
236 }
237 if (cl.mToF != null) {
238 Node tn = findNode(cl.mToF.mScript);
239 validateDAGRecurse(tn, dagNumber);
240 }
241 }
242 }
243
244 private void validateDAG() {
245 for (int ct=0; ct < mNodes.size(); ct++) {
246 Node n = mNodes.get(ct);
247 if (n.mInputs.size() == 0) {
248 if (n.mOutputs.size() == 0 && mNodes.size() > 1) {
249 throw new RSInvalidStateException("Groups cannot contain unconnected scripts");
250 }
251 validateDAGRecurse(n, ct+1);
252 }
253 }
254 int dagNumber = mNodes.get(0).dagNumber;
255 for (int ct=0; ct < mNodes.size(); ct++) {
256 if (mNodes.get(ct).dagNumber != dagNumber) {
257 throw new RSInvalidStateException("Multiple DAGs in group not allowed.");
Jason Sams423ebcb2012-08-10 15:40:53 -0700258 }
Jason Sams423ebcb2012-08-10 15:40:53 -0700259 }
260 }
261
Jason Sams08a81582012-09-18 12:32:10 -0700262 private Node findNode(Script s) {
263 for (int ct=0; ct < mNodes.size(); ct++) {
264 if (s == mNodes.get(ct).mScript) {
265 return mNodes.get(ct);
Jason Sams423ebcb2012-08-10 15:40:53 -0700266 }
Jason Sams423ebcb2012-08-10 15:40:53 -0700267 }
268 return null;
269 }
270
Jason Sams08a81582012-09-18 12:32:10 -0700271 private Node findNode(Script.KernelID k) {
272 for (int ct=0; ct < mNodes.size(); ct++) {
273 Node n = mNodes.get(ct);
274 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
275 if (k == n.mKernels.get(ct2)) {
276 return n;
Jason Sams423ebcb2012-08-10 15:40:53 -0700277 }
278 }
Jason Sams423ebcb2012-08-10 15:40:53 -0700279 }
Jason Sams08a81582012-09-18 12:32:10 -0700280 return null;
281 }
Jason Sams423ebcb2012-08-10 15:40:53 -0700282
Jason Sams08a81582012-09-18 12:32:10 -0700283 /**
284 * Adds a Kernel to the group.
285 *
286 *
287 * @param k The kernel to add.
288 *
289 * @return Builder Returns this.
290 */
291 public Builder addKernel(Script.KernelID k) {
292 if (mLines.size() != 0) {
293 throw new RSInvalidStateException(
294 "Kernels may not be added once connections exist.");
Jason Sams423ebcb2012-08-10 15:40:53 -0700295 }
Jason Sams08a81582012-09-18 12:32:10 -0700296
297 //android.util.Log.v("RSR", "addKernel 1 k=" + k);
298 if (findNode(k) != null) {
299 return this;
300 }
301 //android.util.Log.v("RSR", "addKernel 2 ");
302 mKernelCount++;
303 Node n = findNode(k.mScript);
304 if (n == null) {
305 //android.util.Log.v("RSR", "addKernel 3 ");
306 n = new Node(k.mScript);
307 mNodes.add(n);
308 }
309 n.mKernels.add(k);
310 return this;
311 }
312
313 /**
314 * Adds a connection to the group.
315 *
316 *
317 * @param t The type of the connection. This is used to
318 * determine the kernel launch sizes on the source side
319 * of this connection.
320 * @param from The source for the connection.
321 * @param to The destination of the connection.
322 *
323 * @return Builder Returns this
324 */
325 public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) {
326 //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
327
328 Node nf = findNode(from);
329 if (nf == null) {
330 throw new RSInvalidStateException("From kernel not found.");
331 }
332
333 Node nt = findNode(to.mScript);
334 if (nt == null) {
335 throw new RSInvalidStateException("To script not found.");
336 }
337
338 ConnectLine cl = new ConnectLine(t, from, to);
339 mLines.add(new ConnectLine(t, from, to));
340
341 nf.mOutputs.add(cl);
342 nt.mInputs.add(cl);
Jason Sams423ebcb2012-08-10 15:40:53 -0700343
Tim Murray2a603892012-10-10 14:21:46 -0700344 validateCycle();
Jason Sams423ebcb2012-08-10 15:40:53 -0700345 return this;
346 }
347
Jason Sams08a81582012-09-18 12:32:10 -0700348 /**
349 * Adds a connection to the group.
350 *
351 *
352 * @param t The type of the connection. This is used to
353 * determine the kernel launch sizes for both sides of
354 * this connection.
355 * @param from The source for the connection.
356 * @param to The destination of the connection.
357 *
358 * @return Builder Returns this
359 */
360 public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) {
361 //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
362
363 Node nf = findNode(from);
364 if (nf == null) {
365 throw new RSInvalidStateException("From kernel not found.");
366 }
367
368 Node nt = findNode(to);
369 if (nt == null) {
370 throw new RSInvalidStateException("To script not found.");
371 }
372
373 ConnectLine cl = new ConnectLine(t, from, to);
374 mLines.add(new ConnectLine(t, from, to));
375
376 nf.mOutputs.add(cl);
377 nt.mInputs.add(cl);
378
Tim Murray2a603892012-10-10 14:21:46 -0700379 validateCycle();
Jason Sams08a81582012-09-18 12:32:10 -0700380 return this;
381 }
382
383
384
385 /**
386 * Creates the Script group.
387 *
388 *
389 * @return ScriptGroup The new ScriptGroup
390 */
Jason Sams423ebcb2012-08-10 15:40:53 -0700391 public ScriptGroup create() {
Tim Murray2a603892012-10-10 14:21:46 -0700392
393 if (mNodes.size() == 0) {
394 throw new RSInvalidStateException("Empty script groups are not allowed");
395 }
396
397 // reset DAG numbers in case we're building a second group
398 for (int ct=0; ct < mNodes.size(); ct++) {
399 mNodes.get(ct).dagNumber = 0;
400 }
401 validateDAG();
402
Jason Sams08a81582012-09-18 12:32:10 -0700403 ArrayList<IO> inputs = new ArrayList<IO>();
404 ArrayList<IO> outputs = new ArrayList<IO>();
Jason Sams423ebcb2012-08-10 15:40:53 -0700405
Jason Sams08a81582012-09-18 12:32:10 -0700406 int[] kernels = new int[mKernelCount];
407 int idx = 0;
408 for (int ct=0; ct < mNodes.size(); ct++) {
409 Node n = mNodes.get(ct);
410 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
411 final Script.KernelID kid = n.mKernels.get(ct2);
412 kernels[idx++] = kid.getID(mRS);
Jason Sams423ebcb2012-08-10 15:40:53 -0700413
Jason Sams08a81582012-09-18 12:32:10 -0700414 boolean hasInput = false;
415 boolean hasOutput = false;
416 for (int ct3=0; ct3 < n.mInputs.size(); ct3++) {
417 if (n.mInputs.get(ct3).mToK == kid) {
418 hasInput = true;
419 }
420 }
421 for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) {
422 if (n.mOutputs.get(ct3).mFrom == kid) {
423 hasOutput = true;
424 }
425 }
426 if (!hasInput) {
427 inputs.add(new IO(kid));
428 }
429 if (!hasOutput) {
430 outputs.add(new IO(kid));
431 }
432
433 }
434 }
435 if (idx != mKernelCount) {
436 throw new RSRuntimeException("Count mismatch, should not happen.");
437 }
438
439 int[] src = new int[mLines.size()];
440 int[] dstk = new int[mLines.size()];
441 int[] dstf = new int[mLines.size()];
442 int[] types = new int[mLines.size()];
443
444 for (int ct=0; ct < mLines.size(); ct++) {
445 ConnectLine cl = mLines.get(ct);
446 src[ct] = cl.mFrom.getID(mRS);
447 if (cl.mToK != null) {
448 dstk[ct] = cl.mToK.getID(mRS);
449 }
450 if (cl.mToF != null) {
451 dstf[ct] = cl.mToF.getID(mRS);
452 }
453 types[ct] = cl.mAllocationType.getID(mRS);
454 }
455
456 int id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types);
457 if (id == 0) {
458 throw new RSRuntimeException("Object creation error, should not happen.");
459 }
460
461 ScriptGroup sg = new ScriptGroup(id, mRS);
462 sg.mOutputs = new IO[outputs.size()];
463 for (int ct=0; ct < outputs.size(); ct++) {
464 sg.mOutputs[ct] = outputs.get(ct);
465 }
466
467 sg.mInputs = new IO[inputs.size()];
468 for (int ct=0; ct < inputs.size(); ct++) {
469 sg.mInputs[ct] = inputs.get(ct);
470 }
471
Jason Sams423ebcb2012-08-10 15:40:53 -0700472 return sg;
473 }
474
475 }
476
477
478}
479
480