Use of edu.iu.dsc.tws.api.tset.fn.ReduceFunc in project twister2 by DSC-SPIDAL.
The class KMeansTsetJob, method execute:
@Override
public void execute(WorkerEnvironment workerEnv) {
  // Set up the batch TSet environment for this worker.
  BatchEnvironment env = TSetEnvironment.initBatch(workerEnv);
  int workerId = env.getWorkerID();
  LOG.info("TSet worker starting: " + workerId);

  // Read the K-Means job parameters from the configuration.
  Config config = env.getConfig();
  int parallelism = config.getIntegerValue(DataObjectConstants.PARALLELISM_VALUE);
  int dimension = config.getIntegerValue(DataObjectConstants.DIMENSIONS);
  int numFiles = config.getIntegerValue(DataObjectConstants.NUMBER_OF_FILES);
  int dsize = config.getIntegerValue(DataObjectConstants.DSIZE);
  int csize = config.getIntegerValue(DataObjectConstants.CSIZE);
  int iterations = config.getIntegerValue(DataObjectConstants.ARGS_ITERATIONS);
  String dataDirectory = config.getStringValue(DataObjectConstants.DINPUT_DIRECTORY) + workerId;
  String centroidDirectory = config.getStringValue(DataObjectConstants.CINPUT_DIRECTORY) + workerId;
  String type = config.getStringValue(DataObjectConstants.FILE_TYPE);

  // Generate the input data points and the initial centroids for this worker.
  KMeansUtils.generateDataPoints(env.getConfig(), dimension, numFiles, dsize, csize,
      dataDirectory, centroidDirectory, type);
  long startTime = System.currentTimeMillis();

  /*CachedTSet<double[][]> points =
      tc.createSource(new PointsSource(type), parallelismValue).setName("dataSource").cache();*/

  // Read the data points from CSV, splitting the rows across the parallel tasks,
  // and parse each task's partition into a local double[][] block.
  SourceTSet<String[]> pointSource = env.createCSVSource(dataDirectory, dsize, parallelism, "split");
  ComputeTSet<double[][]> points = pointSource.direct().compute(
      new ComputeFunc<Iterator<String[]>, double[][]>() {
        private double[][] localPoints = new double[dsize / parallelism][dimension];

        @Override
        public double[][] compute(Iterator<String[]> input) {
          for (int i = 0; i < dsize / parallelism && input.hasNext(); i++) {
            String[] value = input.next();
            for (int j = 0; j < value.length; j++) {
              localPoints[i][j] = Double.parseDouble(value[j]);
            }
          }
          return localPoints;
        }
      });
  points.setName("dataSource").cache();
  // CachedTSet<double[][]> centers = tc.createSource(new CenterSource(type), parallelism).cache();

  SourceTSet<String[]> centerSource = env.createCSVSource(centroidDirectory, csize, parallelism, "complete");
  ComputeTSet<double[][]> centers = centerSource.direct().compute(
      new ComputeFunc<Iterator<String[]>, double[][]>() {
        private double[][] localCenters = new double[csize][dimension];

        @Override
        public double[][] compute(Iterator<String[]> input) {
          for (int i = 0; i < csize && input.hasNext(); i++) {
            String[] value = input.next();
            for (int j = 0; j < dimension; j++) {
              localCenters[i][j] = Double.parseDouble(value[j]);
            }
          }
          return localCenters;
        }
      });
  CachedTSet<double[][]> cachedCenters = centers.cache();
  long endTimeData = System.currentTimeMillis();

  // Iterative step: KMeansMap produces this worker's partial centroid contributions
  // from its cached points; allReduce adds the contributions element-wise across
  // workers; AverageCenters (not shown here) then derives the updated centroids
  // from the combined sums.
  ComputeTSet<double[][]> kmeansTSet = points.direct().map(new KMeansMap());
  ComputeTSet<double[][]> reduced = kmeansTSet.allReduce((ReduceFunc<double[][]>) (t1, t2) -> {
    double[][] newCentroids = new double[t1.length][t1[0].length];
    for (int j = 0; j < t1.length; j++) {
      for (int k = 0; k < t1[0].length; k++) {
        newCentroids[j][k] = t1[j][k] + t2[j][k];
      }
    }
    return newCentroids;
  }).map(new AverageCenters());
  kmeansTSet.addInput("centers", cachedCenters);

  // Evaluate the graph lazily for the requested number of iterations, feeding the
  // newly computed centroids back in as the "centers" input of the next iteration.
  CachedTSet<double[][]> cached = reduced.lazyCache();
  for (int i = 0; i < iterations; i++) {
    env.evalAndUpdate(cached, cachedCenters);
  }
  env.finishEval(cached);
  long endTime = System.currentTimeMillis();
  if (workerId == 0) {
    LOG.info("Data Load time : " + (endTimeData - startTime) + "\n"
        + "Total Time : " + (endTime - startTime) + "\n"
        + "Compute Time : " + (endTime - endTimeData));
    LOG.info("Final Centroids After\t" + iterations + "\titerations\t");
    cachedCenters.direct().forEach(c -> LOG.info(Arrays.deepToString(c)));
  }
}
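The allReduce lambda above is the ReduceFunc usage this page highlights: it combines two workers' partial centroid matrices by element-wise addition, and the subsequent AverageCenters map (not shown) presumably turns the sums into the new centroid values. As a rough, standalone illustration of that combine step, the sketch below expresses the same element-wise sum with java.util.function.BinaryOperator instead of twister2's ReduceFunc; the class name and the sample data are hypothetical and only meant to show the folding behaviour.

import java.util.Arrays;
import java.util.List;
import java.util.function.BinaryOperator;

// Illustrative sketch only: mirrors the (t1, t2) -> newCentroids lambda from
// KMeansTsetJob, but outside twister2, using a plain BinaryOperator.
public class CentroidSumSketch {

  // Element-wise sum of two partial centroid matrices, assumed to have equal shape.
  static final BinaryOperator<double[][]> SUM = (t1, t2) -> {
    double[][] combined = new double[t1.length][t1[0].length];
    for (int j = 0; j < t1.length; j++) {
      for (int k = 0; k < t1[0].length; k++) {
        combined[j][k] = t1[j][k] + t2[j][k];
      }
    }
    return combined;
  };

  public static void main(String[] args) {
    // Hypothetical partial sums from two workers, for 2 centroids in 2 dimensions.
    List<double[][]> workerOutputs = Arrays.asList(
        new double[][]{{1.0, 2.0}, {3.0, 4.0}},
        new double[][]{{0.5, 0.5}, {1.0, 1.0}});

    // Fold the partials the way allReduce combines them across workers.
    double[][] total = workerOutputs.stream().reduce(SUM).orElseThrow();

    // Prints [[1.5, 2.5], [4.0, 5.0]]; in the job, AverageCenters would then
    // derive the updated centroids from such sums.
    System.out.println(Arrays.deepToString(total));
  }
}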