Search in sources:

Example 1 with ReduceFunc

Use of edu.iu.dsc.tws.api.tset.fn.ReduceFunc in the project twister2 by DSC-SPIDAL.

The following example is taken from the class KMeansTsetJob, method execute.

/**
 * Worker entry point: builds and runs an iterative K-Means job with the TSet batch API.
 * Loads data points and initial centroids from CSV, then repeatedly maps points to
 * nearest centers, all-reduces the partial sums, and averages to produce new centers.
 */
@Override
public void execute(WorkerEnvironment workerEnv) {
    BatchEnvironment env = TSetEnvironment.initBatch(workerEnv);
    int workerId = env.getWorkerID();
    LOG.info("TSet worker starting: " + workerId);
    // Job parameters supplied through the configuration.
    Config config = env.getConfig();
    int parallelism = config.getIntegerValue(DataObjectConstants.PARALLELISM_VALUE);
    int dimension = config.getIntegerValue(DataObjectConstants.DIMENSIONS);
    int numFiles = config.getIntegerValue(DataObjectConstants.NUMBER_OF_FILES);
    int dsize = config.getIntegerValue(DataObjectConstants.DSIZE);
    int csize = config.getIntegerValue(DataObjectConstants.CSIZE);
    int iterations = config.getIntegerValue(DataObjectConstants.ARGS_ITERATIONS);
    // Worker id is appended so each worker reads/writes its own input directory.
    String dataDirectory = config.getStringValue(DataObjectConstants.DINPUT_DIRECTORY) + workerId;
    String centroidDirectory = config.getStringValue(DataObjectConstants.CINPUT_DIRECTORY) + workerId;
    String type = config.getStringValue(DataObjectConstants.FILE_TYPE);
    // Generate the synthetic input files before building the graph.
    KMeansUtils.generateDataPoints(env.getConfig(), dimension, numFiles, dsize, csize, dataDirectory, centroidDirectory, type);
    long startTime = System.currentTimeMillis();
    // Data points: "split" mode — each worker loads its dsize/parallelism share.
    SourceTSet<String[]> pointSource = env.createCSVSource(dataDirectory, dsize, parallelism, "split");
    ComputeTSet<double[][]> points = pointSource.direct().compute(new ComputeFunc<Iterator<String[]>, double[][]>() {

        private double[][] localPoints = new double[dsize / parallelism][dimension];

        @Override
        public double[][] compute(Iterator<String[]> input) {
            for (int i = 0; i < dsize / parallelism && input.hasNext(); i++) {
                String[] value = input.next();
                // Bound by BOTH the row length and the configured dimension:
                // a row longer than `dimension` would otherwise overflow localPoints[i].
                for (int j = 0; j < value.length && j < dimension; j++) {
                    localPoints[i][j] = Double.parseDouble(value[j]);
                }
            }
            return localPoints;
        }
    });
    // NOTE(review): the CachedTSet returned here is discarded; the uncached
    // `points` TSet is what feeds the map below. Confirm whether the cached
    // handle was meant to be used instead.
    points.setName("dataSource").cache();
    // Centroids: "complete" mode — every worker loads all csize centers.
    SourceTSet<String[]> centerSource = env.createCSVSource(centroidDirectory, csize, parallelism, "complete");
    ComputeTSet<double[][]> centers = centerSource.direct().compute(new ComputeFunc<Iterator<String[]>, double[][]>() {

        private double[][] localCenters = new double[csize][dimension];

        @Override
        public double[][] compute(Iterator<String[]> input) {
            for (int i = 0; i < csize && input.hasNext(); i++) {
                String[] value = input.next();
                // Bound by BOTH limits: a row shorter than `dimension` would
                // otherwise overflow value[j] with ArrayIndexOutOfBoundsException.
                for (int j = 0; j < value.length && j < dimension; j++) {
                    localCenters[i][j] = Double.parseDouble(value[j]);
                }
            }
            return localCenters;
        }
    });
    CachedTSet<double[][]> cachedCenters = centers.cache();
    long endTimeData = System.currentTimeMillis();
    // One K-Means step: assign points to centers, sum partial centroids across
    // workers via allReduce, then divide by counts to get the new centers.
    ComputeTSet<double[][]> kmeansTSet = points.direct().map(new KMeansMap());
    ComputeTSet<double[][]> reduced = kmeansTSet.allReduce((ReduceFunc<double[][]>) (t1, t2) -> {
        double[][] newCentroids = new double[t1.length][t1[0].length];
        for (int j = 0; j < t1.length; j++) {
            for (int k = 0; k < t1[0].length; k++) {
                double newVal = t1[j][k] + t2[j][k];
                newCentroids[j][k] = newVal;
            }
        }
        return newCentroids;
    }).map(new AverageCenters());
    // KMeansMap reads the current centers through this named input.
    kmeansTSet.addInput("centers", cachedCenters);
    // Lazy cache: evaluated once per iteration, feeding the result back into
    // cachedCenters so the next iteration sees the updated centroids.
    CachedTSet<double[][]> cached = reduced.lazyCache();
    for (int i = 0; i < iterations; i++) {
        env.evalAndUpdate(cached, cachedCenters);
    }
    env.finishEval(cached);
    long endTime = System.currentTimeMillis();
    if (workerId == 0) {
        LOG.info("Data Load time : " + (endTimeData - startTime) + "\n" + "Total Time : " + (endTime - startTime) + "Compute Time : " + (endTime - endTimeData));
        LOG.info("Final Centroids After\t" + iterations + "\titerations\t");
        cachedCenters.direct().forEach(i -> LOG.info(Arrays.deepToString(i)));
    }
}
Also used : BatchEnvironment(edu.iu.dsc.tws.tset.env.BatchEnvironment) Config(edu.iu.dsc.tws.api.config.Config) Iterator(java.util.Iterator) ReduceFunc(edu.iu.dsc.tws.api.tset.fn.ReduceFunc)

Aggregations

Config (edu.iu.dsc.tws.api.config.Config)1 ReduceFunc (edu.iu.dsc.tws.api.tset.fn.ReduceFunc)1 BatchEnvironment (edu.iu.dsc.tws.tset.env.BatchEnvironment)1 Iterator (java.util.Iterator)1