use of edu.iu.dsc.tws.task.impl.ComputeConnection in project twister2 by DSC-SPIDAL.
the class SvmSgdAdvancedRunner method executeWeightVectorLoadingTaskGraph.
/**
 * Loads the weight vector in a distributed mode.
 *
 * <p>A batch graph is assembled on the shared {@code trainingBuilder}: a
 * {@code DataObjectSource} reading from
 * {@code svmJobParameters.getWeightVectorDataDir()} is connected through a
 * direct edge ({@code Context.TWISTER2_DIRECT_EDGE}) to a {@code DataObjectSink}.
 * Both tasks are registered with {@code dataStreamerParallelism} parallel
 * instances. The graph is then planned and executed, and the output registered
 * under the sink's task name is returned.
 *
 * @return twister2 DataObject containing the loaded weight vector
 * @throws NullPointerException if the executor returns no output for the sink
 */
public DataObject<Object> executeWeightVectorLoadingTaskGraph() {
  DataObjectSource sourceTask = new DataObjectSource(
      Context.TWISTER2_DIRECT_EDGE, this.svmJobParameters.getWeightVectorDataDir());
  DataObjectSink sinkTask = new DataObjectSink();

  trainingBuilder.addSource(
      Constants.SimpleGraphConfig.DATA_OBJECT_SOURCE, sourceTask, dataStreamerParallelism);
  ComputeConnection sinkConnection = trainingBuilder.addCompute(
      Constants.SimpleGraphConfig.DATA_OBJECT_SINK, sinkTask, dataStreamerParallelism);
  sinkConnection.direct(Constants.SimpleGraphConfig.DATA_OBJECT_SOURCE)
      .viaEdge(Context.TWISTER2_DIRECT_EDGE)
      .withDataType(MessageTypes.OBJECT);
  trainingBuilder.setMode(OperationMode.BATCH);

  ComputeGraph weightVectorTaskGraph = trainingBuilder.build();
  weightVectorTaskGraph.setGraphName("weight-vector-loading-graph");

  ExecutionPlan executionPlan = taskExecutor.plan(weightVectorTaskGraph);
  taskExecutor.execute(weightVectorTaskGraph, executionPlan);

  DataObject<Object> data = taskExecutor.getOutput(
      weightVectorTaskGraph, executionPlan, Constants.SimpleGraphConfig.DATA_OBJECT_SINK);
  if (data == null) {
    // Exception type and message kept as-is so existing callers/log scrapers still match.
    throw new NullPointerException("Something Went Wrong in Loading Weight Vector");
  }
  // This graph loads the weight vector, not the training data set.
  LOG.info("Weight Vector Total Partitions : " + data.getPartitions().length);
  return data;
}
use of edu.iu.dsc.tws.task.impl.ComputeConnection in project twister2 by DSC-SPIDAL.
the class SvmSgdOnlineRunner method buildStreamingTrainingTG.
/**
 * Assembles the streaming training graph on the shared {@code trainingBuilder}.
 *
 * <p>Wiring, in order:
 * <ol>
 *   <li>source {@code ITERATIVE_STREAMING_DATASTREAMER_SOURCE}
 *       ({@code IterativeStreamingDataStreamer})</li>
 *   <li>direct edge {@code STREAMING_EDGE}, {@code DOUBLE_ARRAY} data, into the task
 *       registered as {@code ITERATIVE_STREAMING_SVM_COMPUTE}</li>
 *   <li>allreduce over {@code "window-sink-edge"} with a {@code ReduceAggregator},
 *       {@code DOUBLE_ARRAY} data, into the task registered as {@code "window-sink"}</li>
 *   <li>allreduce over {@code "window-evaluation-edge"} with an
 *       {@code IterativeAccuracyReduceFunction}, {@code DOUBLE} data, into
 *       {@code "window-evaluation-sink"}</li>
 * </ol>
 *
 * <p>NOTE(review): the instance returned by {@code getWindowSinkInstance()} is registered
 * under the {@code ITERATIVE_STREAMING_SVM_COMPUTE} name while the
 * {@code IterativeStreamingCompute} instance is registered as {@code "window-sink"}.
 * The names look crossed relative to the task types; the wiring is kept exactly as is.
 *
 * <p>Side effects: assigns the {@code iterativeStreamingDataStreamer} and
 * {@code iterativeStreamingCompute} members.
 *
 * @return the graph built by {@code trainingBuilder} in {@code STREAMING} mode
 */
private ComputeGraph buildStreamingTrainingTG() {
  final String windowTaskName = "window-sink";
  final String evaluationTaskName = "window-evaluation-sink";

  // Task instances (two of them are kept on the runner for later use).
  iterativeStreamingDataStreamer = new IterativeStreamingDataStreamer(
      this.svmJobParameters.getFeatures(),
      OperationMode.STREAMING,
      this.svmJobParameters.isDummy(),
      this.binaryBatchModel);
  BaseWindowedSink windowedTask = getWindowSinkInstance();
  iterativeStreamingCompute = new IterativeStreamingCompute(
      OperationMode.STREAMING, new ReduceAggregator(), this.svmJobParameters);
  IterativeStreamingSinkEvaluator evaluator = new IterativeStreamingSinkEvaluator();

  // Vertices.
  trainingBuilder.addSource(
      Constants.SimpleGraphConfig.ITERATIVE_STREAMING_DATASTREAMER_SOURCE,
      iterativeStreamingDataStreamer,
      dataStreamerParallelism);
  ComputeConnection windowedConnection = trainingBuilder.addCompute(
      Constants.SimpleGraphConfig.ITERATIVE_STREAMING_SVM_COMPUTE,
      windowedTask,
      dataStreamerParallelism);
  ComputeConnection reduceConnection = trainingBuilder.addCompute(
      windowTaskName, iterativeStreamingCompute, dataStreamerParallelism);
  ComputeConnection evaluationConnection = trainingBuilder.addCompute(
      evaluationTaskName, evaluator, dataStreamerParallelism);

  // Edges.
  windowedConnection
      .direct(Constants.SimpleGraphConfig.ITERATIVE_STREAMING_DATASTREAMER_SOURCE)
      .viaEdge(Constants.SimpleGraphConfig.STREAMING_EDGE)
      .withDataType(MessageTypes.DOUBLE_ARRAY);
  reduceConnection
      .allreduce(Constants.SimpleGraphConfig.ITERATIVE_STREAMING_SVM_COMPUTE)
      .viaEdge("window-sink-edge")
      .withReductionFunction(new ReduceAggregator())
      .withDataType(MessageTypes.DOUBLE_ARRAY);
  evaluationConnection
      .allreduce(windowTaskName)
      .viaEdge("window-evaluation-edge")
      .withReductionFunction(new IterativeAccuracyReduceFunction())
      .withDataType(MessageTypes.DOUBLE);

  trainingBuilder.setMode(OperationMode.STREAMING);
  trainingBuilder.setTaskGraphName(
      IterativeSVMConstants.ITERATIVE_STREAMING_TRAINING_TASK_GRAPH);
  return trainingBuilder.build();
}
use of edu.iu.dsc.tws.task.impl.ComputeConnection in project twister2 by DSC-SPIDAL.
the class SvmSgdIterativeRunner method buildWeightVectorTG.
/**
 * Builds the three-stage graph that loads the weight vector.
 *
 * <p>A fresh {@code ComputeGraphBuilder} (created from {@code config}) is used rather than
 * the shared training builder. The stages are chained with direct edges named
 * {@code Context.TWISTER2_DIRECT_EDGE}:
 * {@code WEIGHT_VECTOR_OBJECT_SOURCE} --(OBJECT)--&gt;
 * {@code WEIGHT_VECTOR_OBJECT_COMPUTE} --(DOUBLE_ARRAY)--&gt;
 * {@code WEIGHT_VECTOR_OBJECT_SINK}. Every stage is registered with
 * {@code dataStreamerParallelism} instances.
 *
 * @return the built graph, named {@code WEIGHT_VECTOR_LOADING_TASK_GRAPH} and running in
 *         the runner's {@code operationMode}
 */
private ComputeGraph buildWeightVectorTG() {
  DataFileReplicatedReadSource readSource = new DataFileReplicatedReadSource(
      Context.TWISTER2_DIRECT_EDGE, this.svmJobParameters.getWeightVectorDataDir(), 1);
  IterativeSVMWeightVectorObjectCompute computeTask =
      new IterativeSVMWeightVectorObjectCompute(
          Context.TWISTER2_DIRECT_EDGE, 1, this.svmJobParameters.getFeatures());
  IterativeSVMWeightVectorObjectDirectSink sinkTask =
      new IterativeSVMWeightVectorObjectDirectSink();

  ComputeGraphBuilder builder = ComputeGraphBuilder.newBuilder(config);

  // Register source -> compute -> sink.
  builder.addSource(
      Constants.SimpleGraphConfig.WEIGHT_VECTOR_OBJECT_SOURCE,
      readSource,
      dataStreamerParallelism);
  ComputeConnection toCompute = builder.addCompute(
      Constants.SimpleGraphConfig.WEIGHT_VECTOR_OBJECT_COMPUTE,
      computeTask,
      dataStreamerParallelism);
  ComputeConnection toSink = builder.addCompute(
      Constants.SimpleGraphConfig.WEIGHT_VECTOR_OBJECT_SINK,
      sinkTask,
      dataStreamerParallelism);

  // Connect the stages with direct edges.
  toCompute
      .direct(Constants.SimpleGraphConfig.WEIGHT_VECTOR_OBJECT_SOURCE)
      .viaEdge(Context.TWISTER2_DIRECT_EDGE)
      .withDataType(MessageTypes.OBJECT);
  toSink
      .direct(Constants.SimpleGraphConfig.WEIGHT_VECTOR_OBJECT_COMPUTE)
      .viaEdge(Context.TWISTER2_DIRECT_EDGE)
      .withDataType(MessageTypes.DOUBLE_ARRAY);

  builder.setMode(operationMode);
  builder.setTaskGraphName(IterativeSVMConstants.WEIGHT_VECTOR_LOADING_TASK_GRAPH);
  return builder.build();
}
use of edu.iu.dsc.tws.task.impl.ComputeConnection in project twister2 by DSC-SPIDAL.
the class SvmSgdIterativeRunner method buildSvmSgdIterativeTrainingTG.
/**
 * Builds the iterative SGD training graph on the shared {@code trainingBuilder}.
 *
 * <p>The graph has two vertices: an {@code IterativeDataStream} source registered as
 * {@code ITERATIVE_DATASTREAMER_SOURCE} and an {@code IterativeSVMWeightVectorReduce}
 * task registered as {@code ITERATIVE_SVM_REDUCE}. They are linked by an allreduce over
 * {@code REDUCE_EDGE} carrying {@code DOUBLE_ARRAY} data and combined with an
 * {@code IterativeWeightVectorReduceFunction}.
 *
 * <p>Side effects: assigns the {@code iterativeDataStream} and
 * {@code iterativeSVMRiterativeSVMWeightVectorReduce} members.
 *
 * @return the built graph, named {@code ITERATIVE_TRAINING_TASK_GRAPH} and running in
 *         the runner's {@code operationMode}
 */
private ComputeGraph buildSvmSgdIterativeTrainingTG() {
  iterativeDataStream = new IterativeDataStream(
      this.svmJobParameters.getFeatures(),
      this.operationMode,
      this.svmJobParameters.isDummy(),
      this.binaryBatchModel);
  iterativeSVMRiterativeSVMWeightVectorReduce =
      new IterativeSVMWeightVectorReduce(this.operationMode);

  trainingBuilder.addSource(
      Constants.SimpleGraphConfig.ITERATIVE_DATASTREAMER_SOURCE,
      iterativeDataStream,
      dataStreamerParallelism);
  ComputeConnection reduceConnection = trainingBuilder.addCompute(
      Constants.SimpleGraphConfig.ITERATIVE_SVM_REDUCE,
      iterativeSVMRiterativeSVMWeightVectorReduce,
      dataStreamerParallelism);

  reduceConnection
      .allreduce(Constants.SimpleGraphConfig.ITERATIVE_DATASTREAMER_SOURCE)
      .viaEdge(Constants.SimpleGraphConfig.REDUCE_EDGE)
      .withReductionFunction(new IterativeWeightVectorReduceFunction())
      .withDataType(MessageTypes.DOUBLE_ARRAY);

  trainingBuilder.setMode(operationMode);
  trainingBuilder.setTaskGraphName(IterativeSVMConstants.ITERATIVE_TRAINING_TASK_GRAPH);
  return trainingBuilder.build();
}
use of edu.iu.dsc.tws.task.impl.ComputeConnection in project twister2 by DSC-SPIDAL.
the class SvmSgdIterativeRunner method generateGenericDataPointLoader.
/**
 * Builds a generic source -&gt; compute -&gt; sink graph for loading data points.
 *
 * <p>A fresh {@code ComputeGraphBuilder} (created from {@code config}) is used. The three
 * tasks are chained with direct edges named {@code Context.TWISTER2_DIRECT_EDGE}, both
 * carrying {@code MessageTypes.OBJECT}, and each task is registered with
 * {@code parallelism} instances. The graph runs in the runner's {@code operationMode}.
 *
 * @param samples              sample count handed to both the source and the compute task
 * @param parallelism          number of parallel instances per task; also handed to the
 *                             compute task's constructor
 * @param numOfFeatures        feature count handed to the compute task
 * @param dataSourcePathStr    path the source task reads from
 * @param dataObjectSourceStr  task name under which the source is registered
 * @param dataObjectComputeStr task name under which the compute task is registered
 * @param dataObjectSinkStr    task name under which the sink is registered
 * @param graphName            name assigned to the resulting graph
 * @return the built data-point loading graph
 */
private ComputeGraph generateGenericDataPointLoader(int samples, int parallelism, int numOfFeatures, String dataSourcePathStr, String dataObjectSourceStr, String dataObjectComputeStr, String dataObjectSinkStr, String graphName) {
  // Diamond operator instead of the raw type: keeps the assignment type-checked.
  SVMDataObjectSource<String, TextInputSplit> sourceTask =
      new SVMDataObjectSource<>(Context.TWISTER2_DIRECT_EDGE, dataSourcePathStr, samples);
  IterativeSVMDataObjectCompute dataObjectCompute = new IterativeSVMDataObjectCompute(
      Context.TWISTER2_DIRECT_EDGE, parallelism, samples, numOfFeatures, DELIMITER);
  IterativeSVMDataObjectDirectSink dataObjectSink = new IterativeSVMDataObjectDirectSink();

  ComputeGraphBuilder datapointsComputeGraphBuilder = ComputeGraphBuilder.newBuilder(config);
  datapointsComputeGraphBuilder.addSource(dataObjectSourceStr, sourceTask, parallelism);
  ComputeConnection datapointComputeConnection =
      datapointsComputeGraphBuilder.addCompute(dataObjectComputeStr, dataObjectCompute, parallelism);
  ComputeConnection computeConnectionSink =
      datapointsComputeGraphBuilder.addCompute(dataObjectSinkStr, dataObjectSink, parallelism);

  datapointComputeConnection.direct(dataObjectSourceStr)
      .viaEdge(Context.TWISTER2_DIRECT_EDGE)
      .withDataType(MessageTypes.OBJECT);
  computeConnectionSink.direct(dataObjectComputeStr)
      .viaEdge(Context.TWISTER2_DIRECT_EDGE)
      .withDataType(MessageTypes.OBJECT);

  datapointsComputeGraphBuilder.setMode(this.operationMode);
  datapointsComputeGraphBuilder.setTaskGraphName(graphName);
  // Build the data-point loading graph.
  return datapointsComputeGraphBuilder.build();
}
Aggregations