Usage of edu.iu.dsc.tws.examples.ml.svm.aggregate.IterativeAccuracyReduceFunction in the twister2 project (DSC-SPIDAL).
In class SvmSgdIterativeRunner, method buildSvmSgdTestingTG:
/**
 * Builds the prediction (testing) task graph.
 *
 * <p>The graph contains two vertices, both created with {@code dataStreamerParallelism}:
 * a prediction data-streamer source and an accuracy-reduce compute task. They are joined by an
 * all-reduce edge that combines values with {@link IterativeAccuracyReduceFunction} and carries
 * {@code MessageTypes.DOUBLE}. As a side effect the {@code iterativePredictionDataStreamer} and
 * {@code iterativeSVMAccuracyReduce} fields are (re)assigned.
 *
 * @return the graph produced by {@code testingBuilder}, named
 *     {@code IterativeSVMConstants.ITERATIVE_PREDICTION_TASK_GRAPH}
 */
private ComputeGraph buildSvmSgdTestingTG() {
  // Task instances, configured from the job parameters, the operation mode and the current model.
  this.iterativePredictionDataStreamer = new IterativePredictionDataStreamer(
      this.svmJobParameters.getFeatures(),
      this.operationMode,
      this.svmJobParameters.isDummy(),
      this.binaryBatchModel);
  this.iterativeSVMAccuracyReduce = new IterativeSVMAccuracyReduce(this.operationMode);

  // Vertices.
  testingBuilder.addSource(
      Constants.SimpleGraphConfig.PREDICTION_SOURCE_TASK,
      this.iterativePredictionDataStreamer,
      dataStreamerParallelism);
  final ComputeConnection accuracyConnection = testingBuilder.addCompute(
      Constants.SimpleGraphConfig.PREDICTION_REDUCE_TASK,
      this.iterativeSVMAccuracyReduce,
      dataStreamerParallelism);

  // Edge: all-reduce from the prediction source into the accuracy task.
  accuracyConnection
      .allreduce(Constants.SimpleGraphConfig.PREDICTION_SOURCE_TASK)
      .viaEdge(Constants.SimpleGraphConfig.PREDICTION_EDGE)
      .withReductionFunction(new IterativeAccuracyReduceFunction())
      .withDataType(MessageTypes.DOUBLE);

  // Graph-level settings.
  testingBuilder.setMode(operationMode);
  testingBuilder.setTaskGraphName(IterativeSVMConstants.ITERATIVE_PREDICTION_TASK_GRAPH);
  return testingBuilder.build();
}
Usage of edu.iu.dsc.tws.examples.ml.svm.aggregate.IterativeAccuracyReduceFunction in the twister2 project (DSC-SPIDAL).
In class SvmSgdOnlineRunner, method buildStreamingTrainingTG:
/**
 * Builds the streaming (online) training task graph.
 *
 * <p>All vertices are created with {@code dataStreamerParallelism}. The vertices and edges are:
 * <ul>
 *   <li>{@code ITERATIVE_STREAMING_DATASTREAMER_SOURCE}: an {@code IterativeStreamingDataStreamer}
 *       source.</li>
 *   <li>{@code ITERATIVE_STREAMING_SVM_COMPUTE}: the windowed sink returned by
 *       {@code getWindowSinkInstance()}. It is fed by a direct edge ({@code STREAMING_EDGE},
 *       {@code DOUBLE_ARRAY}) from the source.</li>
 *   <li>{@code "window-sink"}: an {@code IterativeStreamingCompute}. It is fed by an all-reduce
 *       edge ({@code "window-sink-edge"}, {@code ReduceAggregator}, {@code DOUBLE_ARRAY}) from the
 *       windowed sink.</li>
 *   <li>{@code "window-evaluation-sink"}: an {@code IterativeStreamingSinkEvaluator}. It is fed by
 *       an all-reduce edge ({@code "window-evaluation-edge"},
 *       {@code IterativeAccuracyReduceFunction}, {@code DOUBLE}) from {@code "window-sink"}.</li>
 * </ul>
 *
 * <p>As a side effect the {@code iterativeStreamingDataStreamer} and
 * {@code iterativeStreamingCompute} fields are (re)assigned.
 *
 * @return the graph produced by {@code trainingBuilder}, named
 *     {@code IterativeSVMConstants.ITERATIVE_STREAMING_TRAINING_TASK_GRAPH}
 */
private ComputeGraph buildStreamingTrainingTG() {
  // Each vertex name is referenced both when the vertex is added and as the source of the
  // downstream edge. Keep them in one place so the two uses cannot drift apart.
  // NOTE(review): the vertex named ITERATIVE_STREAMING_SVM_COMPUTE hosts the windowed sink,
  // while "window-sink" hosts the compute task. The names look swapped, but they are kept
  // as-is because other code may look these vertices up by name.
  final String sourceTask = Constants.SimpleGraphConfig.ITERATIVE_STREAMING_DATASTREAMER_SOURCE;
  final String windowedSinkTask = Constants.SimpleGraphConfig.ITERATIVE_STREAMING_SVM_COMPUTE;
  final String streamingComputeTask = "window-sink";
  final String streamingComputeEdge = "window-sink-edge";
  final String evaluationTask = "window-evaluation-sink";
  final String evaluationEdge = "window-evaluation-edge";

  // Task instances.
  iterativeStreamingDataStreamer = new IterativeStreamingDataStreamer(
      this.svmJobParameters.getFeatures(), OperationMode.STREAMING,
      this.svmJobParameters.isDummy(), this.binaryBatchModel);
  BaseWindowedSink baseWindowedSink = getWindowSinkInstance();
  iterativeStreamingCompute = new IterativeStreamingCompute(
      OperationMode.STREAMING, new ReduceAggregator(), this.svmJobParameters);
  IterativeStreamingSinkEvaluator iterativeStreamingSinkEvaluator =
      new IterativeStreamingSinkEvaluator();

  // Vertices. Each connection is named after the vertex it belongs to.
  trainingBuilder.addSource(sourceTask, iterativeStreamingDataStreamer, dataStreamerParallelism);
  ComputeConnection windowedSinkConnection = trainingBuilder.addCompute(
      windowedSinkTask, baseWindowedSink, dataStreamerParallelism);
  ComputeConnection streamingComputeConnection = trainingBuilder.addCompute(
      streamingComputeTask, iterativeStreamingCompute, dataStreamerParallelism);
  ComputeConnection evaluationConnection = trainingBuilder.addCompute(
      evaluationTask, iterativeStreamingSinkEvaluator, dataStreamerParallelism);

  // Edges.
  windowedSinkConnection
      .direct(sourceTask)
      .viaEdge(Constants.SimpleGraphConfig.STREAMING_EDGE)
      .withDataType(MessageTypes.DOUBLE_ARRAY);
  streamingComputeConnection
      .allreduce(windowedSinkTask)
      .viaEdge(streamingComputeEdge)
      .withReductionFunction(new ReduceAggregator())
      .withDataType(MessageTypes.DOUBLE_ARRAY);
  evaluationConnection
      .allreduce(streamingComputeTask)
      .viaEdge(evaluationEdge)
      .withReductionFunction(new IterativeAccuracyReduceFunction())
      .withDataType(MessageTypes.DOUBLE);

  // Graph-level settings.
  trainingBuilder.setMode(OperationMode.STREAMING);
  trainingBuilder.setTaskGraphName(IterativeSVMConstants.ITERATIVE_STREAMING_TRAINING_TASK_GRAPH);
  return trainingBuilder.build();
}
Aggregations