Example 1 with PredictionAggregator

Use of edu.iu.dsc.tws.examples.ml.svm.test.PredictionAggregator in project twister2 by DSC-SPIDAL.

From the class SvmSgdAdvancedRunner, method executeTestingTaskGraph.

/**
 * Executes the testing task graph on the loaded testing data, using the final
 * weight vector obtained from the training task graph.
 * Testing runs in parallel: the testing data is loaded in parallel according to
 * the given parallelism, and the predictions are computed in parallel as well.
 * The test results from all testing data partitions are then combined by the reduce task.
 *
 * @return the data object holding the accuracy value obtained
 */
public DataObject<Object> executeTestingTaskGraph() {
    DataObject<Object> data = null;
    predictionSourceTask = new PredictionSourceTask(svmJobParameters.isDummy(), this.binaryBatchModel, operationMode);
    predictionReduceTask = new PredictionReduceTask(operationMode);
    testingBuilder.addSource(Constants.SimpleGraphConfig.PREDICTION_SOURCE_TASK, predictionSourceTask, dataStreamerParallelism);
    ComputeConnection predictionReduceConnection = testingBuilder.addCompute(Constants.SimpleGraphConfig.PREDICTION_REDUCE_TASK, predictionReduceTask, reduceParallelism);
    predictionReduceConnection.reduce(Constants.SimpleGraphConfig.PREDICTION_SOURCE_TASK)
        .viaEdge(Constants.SimpleGraphConfig.PREDICTION_EDGE)
        .withReductionFunction(new PredictionAggregator())
        .withDataType(MessageTypes.OBJECT);
    testingBuilder.setMode(operationMode);
    ComputeGraph predictionGraph = testingBuilder.build();
    predictionGraph.setGraphName("testing-graph");
    ExecutionPlan predictionPlan = taskExecutor.plan(predictionGraph);
    // adding test data set
    taskExecutor.addInput(predictionGraph, predictionPlan, Constants.SimpleGraphConfig.PREDICTION_SOURCE_TASK, Constants.SimpleGraphConfig.TEST_DATA, testingData);
    // adding final weight vector
    taskExecutor.addInput(predictionGraph, predictionPlan, Constants.SimpleGraphConfig.PREDICTION_SOURCE_TASK, Constants.SimpleGraphConfig.FINAL_WEIGHT_VECTOR, trainedWeightVector);
    taskExecutor.execute(predictionGraph, predictionPlan);
    data = retrieveTestingAccuracyObject(predictionGraph, predictionPlan);
    return data;
}
Also used : ExecutionPlan(edu.iu.dsc.tws.api.compute.executor.ExecutionPlan) ComputeGraph(edu.iu.dsc.tws.api.compute.graph.ComputeGraph) PredictionReduceTask(edu.iu.dsc.tws.examples.ml.svm.test.PredictionReduceTask) DataObject(edu.iu.dsc.tws.api.dataset.DataObject) PredictionAggregator(edu.iu.dsc.tws.examples.ml.svm.test.PredictionAggregator) PredictionSourceTask(edu.iu.dsc.tws.examples.ml.svm.test.PredictionSourceTask) ComputeConnection(edu.iu.dsc.tws.task.impl.ComputeConnection)
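
The snippet above passes a PredictionAggregator to withReductionFunction, but the body of that class is not shown here. As a rough illustration of what a pairwise reduction over per-partition test results could look like, here is a minimal sketch in plain Java with no twister2 dependency. The names PartialAccuracyReduce, Partial, COMBINE and reduce are hypothetical, and the assumption that each partition reports a (correct, total) pair is made purely for illustration; the real aggregator may combine its values differently.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.BinaryOperator;

public class PartialAccuracyReduce {

    /** Hypothetical per-partition result: correct predictions and samples seen. */
    static final class Partial {
        final long correct;
        final long total;

        Partial(long correct, long total) {
            this.correct = correct;
            this.total = total;
        }

        /** Fraction of correct predictions; 0.0 when no samples were seen. */
        double accuracy() {
            return total == 0 ? 0.0 : (double) correct / total;
        }
    }

    /**
     * Pairwise combiner (assumption: counts are simply summed). Summing counts
     * is associative and commutative, so the order in which partitions are
     * merged does not change the final result.
     */
    static final BinaryOperator<Partial> COMBINE =
        (a, b) -> new Partial(a.correct + b.correct, a.total + b.total);

    /** Folds all partition results into one, starting from an empty result. */
    static Partial reduce(List<Partial> partials) {
        return partials.stream().reduce(new Partial(0, 0), COMBINE);
    }

    public static void main(String[] args) {
        // Three made-up partitions of a test set.
        List<Partial> partials = Arrays.asList(
            new Partial(45, 50), new Partial(30, 50), new Partial(20, 25));
        Partial result = reduce(partials);
        System.out.println(result.correct + "/" + result.total
            + " accuracy=" + result.accuracy());
    }
}
```

Carrying counts rather than per-partition accuracy ratios keeps the merged result correct when partitions differ in size, which is why the sketch reduces (correct, total) pairs and divides only at the end.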
