Use of edu.iu.dsc.tws.examples.ml.svm.test.PredictionReduceTask in project twister2 by DSC-SPIDAL.
In class SvmSgdAdvancedRunner, method executeTestingTaskGraph.
/**
 * Builds, plans and executes the testing (prediction) task graph, then returns the result
 * retrieved from it.
 *
 * <p>The graph consists of a {@code PredictionSourceTask} running with
 * {@code dataStreamerParallelism} instances whose output is reduced into a
 * {@code PredictionReduceTask} running with {@code reduceParallelism} instances. The reduce
 * runs over the {@code PREDICTION_EDGE} edge and uses a {@code PredictionAggregator} as the
 * reduction function. Testing is therefore performed in parallel over the partitions of the
 * testing data.
 *
 * <p>Before execution the source task is given two inputs:
 * <ul>
 *   <li>{@code TEST_DATA}: the previously loaded testing data ({@code testingData});</li>
 *   <li>{@code FINAL_WEIGHT_VECTOR}: the weight vector produced by the training task graph
 *       ({@code trainedWeightVector}).</li>
 * </ul>
 *
 * <p>Side effects: the {@code predictionSourceTask} and {@code predictionReduceTask} fields
 * are replaced with fresh instances, and both tasks are added to {@code testingBuilder}.
 *
 * @return the object obtained from the executed graph via
 *     {@code retrieveTestingAccuracyObject}; presumably this wraps the testing accuracy
 *     (confirm against that helper)
 */
public DataObject<Object> executeTestingTaskGraph() {
  predictionSourceTask = new PredictionSourceTask(svmJobParameters.isDummy(),
      this.binaryBatchModel, operationMode);
  predictionReduceTask = new PredictionReduceTask(operationMode);

  testingBuilder.addSource(Constants.SimpleGraphConfig.PREDICTION_SOURCE_TASK,
      predictionSourceTask, dataStreamerParallelism);
  ComputeConnection predictionReduceConnection = testingBuilder.addCompute(
      Constants.SimpleGraphConfig.PREDICTION_REDUCE_TASK, predictionReduceTask,
      reduceParallelism);
  predictionReduceConnection
      .reduce(Constants.SimpleGraphConfig.PREDICTION_SOURCE_TASK)
      .viaEdge(Constants.SimpleGraphConfig.PREDICTION_EDGE)
      .withReductionFunction(new PredictionAggregator())
      .withDataType(MessageTypes.OBJECT);
  testingBuilder.setMode(operationMode);

  ComputeGraph predictionGraph = testingBuilder.build();
  predictionGraph.setGraphName("testing-graph");
  ExecutionPlan predictionPlan = taskExecutor.plan(predictionGraph);

  // The source task needs both the test data set and the weights learned during training.
  taskExecutor.addInput(predictionGraph, predictionPlan,
      Constants.SimpleGraphConfig.PREDICTION_SOURCE_TASK,
      Constants.SimpleGraphConfig.TEST_DATA, testingData);
  taskExecutor.addInput(predictionGraph, predictionPlan,
      Constants.SimpleGraphConfig.PREDICTION_SOURCE_TASK,
      Constants.SimpleGraphConfig.FINAL_WEIGHT_VECTOR, trainedWeightVector);

  taskExecutor.execute(predictionGraph, predictionPlan);
  return retrieveTestingAccuracyObject(predictionGraph, predictionPlan);
}
Aggregations