Search in sources :

Example 1 with DefaultDistributedInfo

use of com.alibaba.alink.common.io.directreader.DefaultDistributedInfo in project Alink by alibaba.

The class ApsIndexFunc4PullLine, method requestIndex.

/**
 * Builds the set of indices to request by running {@code NegSampleRunner} workers over
 * {@code data} in parallel and merging what they collect.
 *
 * <p>The input list is cut into {@code threadNum} contiguous sub-lists (parameter
 * {@code "threadNum"}, default 8). The boundaries come from {@code DistributedInfo.startPos}
 * and {@code DistributedInfo.localRowCnt}. Each worker gets its own sub-list, its own output
 * set and the seed {@code seed + i}, so the workers never write to a shared collection. The
 * per-worker sets are merged only after every worker has been joined.
 *
 * @param data the rows to sample against; read through {@code subList} views only
 * @return the union of the index sets produced by all workers
 * @throws IllegalStateException if {@code contextParams} has not been set
 * @throws Exception if joining a worker thread is interrupted or a parameter lookup fails
 */
@Override
protected Set<Long> requestIndex(List<Number[]> data) throws Exception {
    long startTimeConsumption = System.nanoTime();
    LOG.info("taskId: {}, localId: {}", getPatitionId(), getRuntimeContext().getIndexOfThisSubtask());
    LOG.info("taskId: {}, negInputSize: {}", getPatitionId(), data.size());
    // Fail fast with a diagnosable message instead of a bare RuntimeException.
    if (null == this.contextParams) {
        throw new IllegalStateException(
            "contextParams is null; cannot read 'negBound' and seeds for negative sampling");
    }
    Long[] nsPool = this.contextParams.getLongArray("negBound");
    int subTaskId = getRuntimeContext().getIndexOfThisSubtask();
    Long[] seeds = this.contextParams.get(ApsContext.SEEDS);
    int seed = seeds[subTaskId].intValue();
    int negaTime = params.get(HasNegative.NEGATIVE);
    int threadNum = params.getIntegerOrDefault("threadNum", 8);
    double sampleRatioPerPartition = params.get(LineParams.SAMPLE_RATIO_PER_PARTITION);
    Thread[] thread = new Thread[threadNum];
    // save the vertices which need to be pulled.
    // Generic array creation is unchecked; safe here because only HashSet<Long> instances are stored.
    @SuppressWarnings("unchecked")
    Set<Long>[] output = new Set[threadNum];
    DistributedInfo distributedInfo = new DefaultDistributedInfo();
    for (int i = 0; i < threadNum; ++i) {
        int start = (int) distributedInfo.startPos(i, threadNum, data.size());
        int end = (int) distributedInfo.localRowCnt(i, threadNum, data.size()) + start;
        LOG.info("taskId: {}, negStart: {}, end: {}", getPatitionId(), start, end);
        output[i] = new HashSet<>();
        // Offset the seed per worker so the workers do not share the same seed value.
        thread[i] = new NegSampleRunner(nsPool, seed + i, data.subList(start, end), output[i], negaTime, sampleRatioPerPartition);
        thread[i].start();
    }
    // Wait for every worker before touching its output set.
    for (int i = 0; i < threadNum; ++i) {
        thread[i].join();
    }
    Set<Long> outputMerger = new HashSet<>();
    for (int i = 0; i < threadNum; ++i) {
        outputMerger.addAll(output[i]);
    }
    LOG.info("taskId: {}, negOutputSize: {}", getPatitionId(), outputMerger.size());
    long endTimeConsumption = System.nanoTime();
    // Elapsed time is logged in milliseconds (nanoseconds / 1e6).
    LOG.info("taskId: {}, negTime: {}", getPatitionId(), (endTimeConsumption - startTimeConsumption) / 1000000.0);
    return outputMerger;
}
Also used : HashSet(java.util.HashSet) Set(java.util.Set) DefaultDistributedInfo(com.alibaba.alink.common.io.directreader.DefaultDistributedInfo) DistributedInfo(com.alibaba.alink.common.io.directreader.DistributedInfo) DefaultDistributedInfo(com.alibaba.alink.common.io.directreader.DefaultDistributedInfo) HashSet(java.util.HashSet)

Example 2 with DefaultDistributedInfo

use of com.alibaba.alink.common.io.directreader.DefaultDistributedInfo in project Alink by alibaba.

The class ApsContext, method updateLoopInfo.

/**
 * Wires the iteration bookkeeping for {@code loop} into this context.
 *
 * <p>Two things happen:
 * <ol>
 *   <li>A data set holding the current superstep number (emitted by subtask 0 only) is stored
 *       under {@code alinkApsStepNum}.</li>
 *   <li>{@code context} is replaced by a mapped version that writes
 *       {@code ALINK_APS_CUR_BLOCK} (or -1 when the position is past the end) and
 *       {@code ALINK_APS_HAS_NEXT_BLOCK} into the {@code Params}.</li>
 * </ol>
 *
 * @param loop the iterative data set whose superstep number drives the bookkeeping
 * @param <MT> payload type of the loop tuples; not inspected here
 * @return this context, for chaining
 */
public <MT> ApsContext updateLoopInfo(IterativeDataSet<Tuple2<Long, MT>> loop) {
    RichMapPartitionFunction<Tuple2<Long, MT>, Integer> superstepEmitter = new RichMapPartitionFunction<Tuple2<Long, MT>, Integer>() {

        private static final long serialVersionUID = 4816930852791283240L;

        @Override
        public void mapPartition(Iterable<Tuple2<Long, MT>> iterable, Collector<Integer> collector) throws Exception {
            // Only the first subtask reports, so the data set carries a single value.
            if (getRuntimeContext().getIndexOfThisSubtask() != 0) {
                return;
            }
            collector.collect(getIterationRuntimeContext().getSuperstepNumber());
        }
    };
    this.put(alinkApsStepNum, loop.mapPartition(superstepEmitter).returns(Types.INT));

    context = context.map(new RichMapFunction<Params, Params>() {

        private static final long serialVersionUID = -7796590264875170687L;

        IntCounter counter = new IntCounter();

        @Override
        public void open(Configuration parameters) throws Exception {
            getRuntimeContext().addAccumulator(alinkApsBreakAll, this.counter);
        }

        @Override
        public Params map(Params params) throws Exception {
            final int checkpointCount = params.getInteger(alinkApsNumCheckpoint);
            final int iterCount = params.getIntegerOrDefault(alinkApsNumIter, 1);
            final int miniBatchCount = params.getInteger(alinkApsNumMiniBatch);
            // Total number of steps across all iterations.
            final int totalSteps = miniBatchCount * iterCount;

            final int startPos;
            final int endPos;
            if (checkpointCount > 0) {
                // With checkpoints, this run covers only its own slice of the steps.
                int curCheckPoint = params.getInteger(alinkApsCurCheckpoint);
                DistributedInfo slicer = new DefaultDistributedInfo();
                startPos = (int) slicer.startPos(curCheckPoint, checkpointCount, totalSteps);
                endPos = (int) slicer.localRowCnt(curCheckPoint, checkpointCount, totalSteps) + startPos;
            } else {
                startPos = 0;
                endPos = totalSteps;
            }

            int curPos = params.getInteger(alinkApsStepNum) - 1 + startPos;
            LOG.info("taskId:{}, stepNum:{}", getRuntimeContext().getIndexOfThisSubtask(), curPos);

            // Past the slice end or past the overall end: no block to process.
            boolean pastEnd = curPos >= endPos || curPos >= totalSteps;
            int curBlock = pastEnd ? -1 : curPos % miniBatchCount;

            // A next block exists only if the position is before the last step of both ranges.
            boolean hasNextBlock = curPos < endPos - 1 && curPos < totalSteps - 1;

            if (endPos >= totalSteps) {
                this.counter.add(1);
            }
            params.set(ALINK_APS_HAS_NEXT_BLOCK, hasNextBlock);
            params.set(ALINK_APS_CUR_BLOCK, curBlock);
            return params;
        }
    });
    return this;
}
Also used : Configuration(org.apache.flink.configuration.Configuration) DefaultDistributedInfo(com.alibaba.alink.common.io.directreader.DefaultDistributedInfo) Params(org.apache.flink.ml.api.misc.param.Params) DefaultDistributedInfo(com.alibaba.alink.common.io.directreader.DefaultDistributedInfo) DistributedInfo(com.alibaba.alink.common.io.directreader.DistributedInfo) Tuple2(org.apache.flink.api.java.tuple.Tuple2) RichMapFunction(org.apache.flink.api.common.functions.RichMapFunction) IntCounter(org.apache.flink.api.common.accumulators.IntCounter)

Example 3 with DefaultDistributedInfo

use of com.alibaba.alink.common.io.directreader.DefaultDistributedInfo in project Alink by alibaba.

The class ApsFuncIndex4PullW2V, method requestIndex.

/**
 * Builds the set of indices to request by running {@code NegSampleRunner} workers over
 * {@code data} in parallel and merging what they collect.
 *
 * <p>The input list is cut into {@code threadNum} contiguous sub-lists (parameter
 * {@code "threadNum"}, default 8) using {@code DistributedInfo.startPos} and
 * {@code DistributedInfo.localRowCnt}. Each worker gets its own sub-list, its own output set
 * and the seed {@code seed + i}. When the {@code "metapathMode"} parameter is true, the
 * {@code "groupIdxes"} array from the context is also passed to every worker; otherwise
 * {@code null} is passed in its place.
 *
 * @param data the rows to sample against; read through {@code subList} views only
 * @return the union of the index sets produced by all workers
 * @throws IllegalStateException if {@code contextParams} has not been set
 * @throws Exception if joining a worker thread is interrupted or a parameter lookup fails
 */
@Override
protected Set<Long> requestIndex(List<int[]> data) throws Exception {
    long startTimeConsumption = System.nanoTime();
    LOG.info("taskId: {}, localId: {}", getPatitionId(), getRuntimeContext().getIndexOfThisSubtask());
    LOG.info("taskId: {}, negInputSize: {}", getPatitionId(), data.size());
    // Fail fast with a diagnosable message instead of a bare RuntimeException.
    if (null == this.contextParams) {
        throw new IllegalStateException(
            "contextParams is null; cannot read 'negBound', 'vocSize' and seeds for negative sampling");
    }
    Long[] seeds = this.contextParams.get(ApsContext.SEEDS);
    Long[] nsPool = this.contextParams.getLongArray("negBound");
    final boolean metapathMode = params.getBoolOrDefault("metapathMode", false);
    // Stays null unless metapath mode is enabled.
    long[] groupIdxStarts = null;
    if (metapathMode) {
        groupIdxStarts = ArrayUtils.toPrimitive(this.contextParams.getLongArray("groupIdxes"));
    }
    int vocSize = this.contextParams.getLong("vocSize").intValue();
    long seed = seeds[getPatitionId()];
    int threadNum = params.getIntegerOrDefault("threadNum", 8);
    Thread[] thread = new Thread[threadNum];
    // Generic array creation is unchecked; safe here because only HashSet<Long> instances are stored.
    @SuppressWarnings("unchecked")
    Set<Long>[] output = new Set[threadNum];
    DistributedInfo distributedInfo = new DefaultDistributedInfo();
    for (int i = 0; i < threadNum; ++i) {
        int start = (int) distributedInfo.startPos(i, threadNum, data.size());
        int end = (int) distributedInfo.localRowCnt(i, threadNum, data.size()) + start;
        LOG.info("taskId: {}, negStart: {}, end: {}", getPatitionId(), start, end);
        output[i] = new HashSet<>();
        // Offset the seed per worker so the workers do not share the same seed value.
        thread[i] = new NegSampleRunner(vocSize, nsPool, params, seed + i, data.subList(start, end), output[i], null, groupIdxStarts);
        thread[i].start();
    }
    // Wait for every worker before touching its output set.
    for (int i = 0; i < threadNum; ++i) {
        thread[i].join();
    }
    Set<Long> outputMerger = new HashSet<>();
    for (int i = 0; i < threadNum; ++i) {
        outputMerger.addAll(output[i]);
    }
    LOG.info("taskId: {}, negOutputSize: {}", getPatitionId(), outputMerger.size());
    long endTimeConsumption = System.nanoTime();
    // Elapsed time is logged in milliseconds (nanoseconds / 1e6).
    LOG.info("taskId: {}, negTime: {}", getPatitionId(), (endTimeConsumption - startTimeConsumption) / 1000000.0);
    return outputMerger;
}
Also used : HashSet(java.util.HashSet) Set(java.util.Set) DefaultDistributedInfo(com.alibaba.alink.common.io.directreader.DefaultDistributedInfo) DistributedInfo(com.alibaba.alink.common.io.directreader.DistributedInfo) DefaultDistributedInfo(com.alibaba.alink.common.io.directreader.DefaultDistributedInfo) HashSet(java.util.HashSet)

Example 4 with DefaultDistributedInfo

use of com.alibaba.alink.common.io.directreader.DefaultDistributedInfo in project Alink by alibaba.

The class BaseTuning, method split.

/**
 * Shuffles {@code data} and pairs every row with a group index.
 *
 * <p>Each parallel task first works out, from the broadcast per-partition counts, the global
 * position at which its own rows start and the total row count. Group boundaries are then
 * taken from {@code DistributedInfo.startPos} / {@code localRowCnt} for {@code k} groups over
 * that total. As the task walks its rows, the group index advances whenever the running
 * position reaches the current group's end.
 *
 * @param data the operator whose rows are to be grouped
 * @param k    number of groups passed to {@code DistributedInfo} (presumably the fold count;
 *             confirm against callers)
 * @return a data set of (group index, row) pairs
 */
private DataSet<Tuple2<Integer, Row>> split(BatchOperator<?> data, int k) {
    DataSet<Row> shuffled = shuffle(data.getDataSet());
    DataSet<Tuple2<Integer, Long>> partitionCounts = DataSetUtils.countElementsPerPartition(shuffled);

    RichMapPartitionFunction<Row, Tuple2<Integer, Row>> groupTagger = new RichMapPartitionFunction<Row, Tuple2<Integer, Row>>() {

        private static final long serialVersionUID = -902599228310615694L;

        long taskStart = 0L;

        long totalNumInstance = 0L;

        @Override
        public void open(Configuration parameters) {
            List<Tuple2<Integer, Long>> perPartition = getRuntimeContext().getBroadcastVariable("counts");
            final int self = getRuntimeContext().getIndexOfThisSubtask();
            for (Tuple2<Integer, Long> entry : perPartition) {
                totalNumInstance += entry.f1;
                // This task's offset is the sum of the counts of all higher-indexed partitions.
                if (entry.f0 > self) {
                    taskStart += entry.f1;
                }
            }
        }

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Tuple2<Integer, Row>> out) {
            final DistributedInfo slicer = new DefaultDistributedInfo();
            // Current group and the exclusive global position at which it ends.
            int group = -1;
            long groupEnd = -1L;

            // Locate the group that contains this task's first global position.
            for (int candidate = 0; candidate <= k; candidate++) {
                long candidateStart = slicer.startPos(candidate, k, totalNumInstance);
                long candidateSize = slicer.localRowCnt(candidate, k, totalNumInstance);
                if (taskStart > candidateStart) {
                    continue;
                }
                if (taskStart == candidateStart) {
                    group = candidate;
                    groupEnd = candidateStart + candidateSize;
                } else {
                    // Overshot: the start lies inside the previous group.
                    group = candidate - 1;
                    groupEnd = slicer.startPos(candidate - 1, k, totalNumInstance) + slicer.localRowCnt(candidate - 1, k, totalNumInstance);
                }
                break;
            }

            long position = taskStart;
            for (Row row : values) {
                if (position >= groupEnd) {
                    group++;
                    groupEnd = slicer.localRowCnt(group, k, totalNumInstance) + position;
                }
                out.collect(Tuple2.of(group, row));
                position++;
            }
        }
    };

    return shuffled.mapPartition(groupTagger).withBroadcastSet(partitionCounts, "counts");
}
Also used : Configuration(org.apache.flink.configuration.Configuration) AlinkGlobalConfiguration(com.alibaba.alink.common.AlinkGlobalConfiguration) DefaultDistributedInfo(com.alibaba.alink.common.io.directreader.DefaultDistributedInfo) RichMapPartitionFunction(org.apache.flink.api.common.functions.RichMapPartitionFunction) DefaultDistributedInfo(com.alibaba.alink.common.io.directreader.DefaultDistributedInfo) DistributedInfo(com.alibaba.alink.common.io.directreader.DistributedInfo) Tuple2(org.apache.flink.api.java.tuple.Tuple2) Collector(org.apache.flink.util.Collector) Row(org.apache.flink.types.Row)

Example 5 with DefaultDistributedInfo

use of com.alibaba.alink.common.io.directreader.DefaultDistributedInfo in project Alink by alibaba.

The class CalcFeatureGain, method calc.

/**
 * Evaluates, for this task's share of (node, feature) pairs, the best split per queued node
 * and stores the results in the shared {@code "best"} array.
 *
 * <p>Reads {@code "boostingObjs"}, {@code "tree"} and {@code "histogram"} from the context.
 * On step 1 it also creates the {@code "best"} array and one feature splitter per feature.
 * On exit it writes the number of {@code best} slots visited under {@code "bestLength"}.
 *
 * @param context the communication context that supplies the task id, task count, step
 *                number and the shared objects listed above
 */
@Override
public void calc(ComContext context) {
    LOG.info("taskId: {}, {} start", context.getTaskId(), CalcFeatureGain.class.getSimpleName());
    BoostingObjs boostingObjs = context.getObj("boostingObjs");
    HistogramBaseTreeObjs tree = context.getObj("tree");
    double[] histogram = context.getObj("histogram");
    // One-time setup on the first step: allocate the result array and a splitter per feature.
    if (context.getStepNo() == 1) {
        context.putObj("best", new Node[tree.maxNodeSize]);
        featureSplitters = new HistogramFeatureSplitter[boostingObjs.data.getN()];
        for (int i = 0; i < boostingObjs.data.getN(); ++i) {
            featureSplitters[i] = createFeatureSplitter(boostingObjs.data.getFeatureMetas()[i].getType() == FeatureMeta.FeatureType.CATEGORICAL, boostingObjs.params, boostingObjs.data.getFeatureMetas()[i], tree.compareIndex4Categorical);
        }
    }
    // Total number of (node, feature) pairs: numBaggingFeatures for every "small" node,
    // plus the same again for every non-null "big" node.
    int sumFeatureCount = 0;
    for (NodeInfoPair item : tree.queue) {
        sumFeatureCount += boostingObjs.numBaggingFeatures;
        if (item.big != null) {
            sumFeatureCount += boostingObjs.numBaggingFeatures;
        }
    }
    // This task handles the pairs whose running position falls in [start, end).
    DistributedInfo distributedInfo = new DefaultDistributedInfo();
    int start = (int) distributedInfo.startPos(context.getTaskId(), context.getNumTask(), sumFeatureCount);
    int cnt = (int) distributedInfo.localRowCnt(context.getTaskId(), context.getNumTask(), sumFeatureCount);
    int end = start + cnt;
    // featureCnt: running position over all (node, feature) pairs, advanced for every pair.
    int featureCnt = 0;
    // featureBinCnt: offset into histogram, advanced only for pairs this task processes.
    // NOTE(review): presumably histogram holds only this task's slices back to back - confirm
    // against the step that builds it.
    int featureBinCnt = 0;
    Node[] best = context.getObj("best");
    // index: next slot of best[]; one slot per small node and one per non-null big node.
    int index = 0;
    for (NodeInfoPair item : tree.queue) {
        // Clear the slot so a result from an earlier call is not reused.
        best[index] = null;
        final int[] smallBaggingFeatures = item.small.baggingFeatures;
        for (int smallBaggingFeature : smallBaggingFeatures) {
            if (featureCnt >= start && featureCnt < end) {
                featureSplitters[smallBaggingFeature].reset(histogram, new Slice(featureBinCnt, featureBinCnt + DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[smallBaggingFeature], tree.useMissing)), item.small.depth);
                double gain = featureSplitters[smallBaggingFeature].bestSplit(tree.leaves.size());
                // The first feature evaluated for this node is always recorded; later ones
                // replace it only when they can split and report a strictly larger gain.
                if (best[index] == null || (featureSplitters[smallBaggingFeature].canSplit() && gain > best[index].getGain())) {
                    best[index] = new Node();
                    featureSplitters[smallBaggingFeature].fillNode(best[index]);
                }
                featureBinCnt += DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[smallBaggingFeature], tree.useMissing);
            }
            featureCnt++;
        }
        index++;
        // Same procedure for the "big" node of the pair, when there is one.
        if (item.big != null) {
            best[index] = null;
            final int[] bigBaggingFeatures = item.big.baggingFeatures;
            for (int bigBaggingFeature : bigBaggingFeatures) {
                if (featureCnt >= start && featureCnt < end) {
                    featureSplitters[bigBaggingFeature].reset(histogram, new Slice(featureBinCnt, featureBinCnt + DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[bigBaggingFeature], tree.useMissing)), item.big.depth);
                    double gain = featureSplitters[bigBaggingFeature].bestSplit(tree.leaves.size());
                    if (best[index] == null || (featureSplitters[bigBaggingFeature].canSplit() && gain > best[index].getGain())) {
                        best[index] = new Node();
                        featureSplitters[bigBaggingFeature].fillNode(best[index]);
                    }
                    featureBinCnt += DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[bigBaggingFeature], tree.useMissing);
                }
                featureCnt++;
            }
            index++;
        }
    }
    // Publish how many best[] slots were visited in this call.
    context.putObj("bestLength", index);
    LOG.info("taskId: {}, {} end", context.getTaskId(), CalcFeatureGain.class.getSimpleName());
}
Also used : DefaultDistributedInfo(com.alibaba.alink.common.io.directreader.DefaultDistributedInfo) Node(com.alibaba.alink.operator.common.tree.Node) DistributedInfo(com.alibaba.alink.common.io.directreader.DistributedInfo) DefaultDistributedInfo(com.alibaba.alink.common.io.directreader.DefaultDistributedInfo) Slice(com.alibaba.alink.operator.common.tree.parallelcart.data.Slice)

Aggregations

DefaultDistributedInfo (com.alibaba.alink.common.io.directreader.DefaultDistributedInfo)5 DistributedInfo (com.alibaba.alink.common.io.directreader.DistributedInfo)5 HashSet (java.util.HashSet)2 Set (java.util.Set)2 Tuple2 (org.apache.flink.api.java.tuple.Tuple2)2 Configuration (org.apache.flink.configuration.Configuration)2 AlinkGlobalConfiguration (com.alibaba.alink.common.AlinkGlobalConfiguration)1 Node (com.alibaba.alink.operator.common.tree.Node)1 Slice (com.alibaba.alink.operator.common.tree.parallelcart.data.Slice)1 IntCounter (org.apache.flink.api.common.accumulators.IntCounter)1 RichMapFunction (org.apache.flink.api.common.functions.RichMapFunction)1 RichMapPartitionFunction (org.apache.flink.api.common.functions.RichMapPartitionFunction)1 Params (org.apache.flink.ml.api.misc.param.Params)1 Row (org.apache.flink.types.Row)1 Collector (org.apache.flink.util.Collector)1