Use of com.alibaba.alink.common.io.directreader.DistributedInfo in project Alink by alibaba.
The class ApsIndexFunc4PullLine, method requestIndex:
@Override
protected Set<Long> requestIndex(List<Number[]> data) throws Exception {
    long startTimeConsumption = System.nanoTime();
    LOG.info("taskId: {}, localId: {}", getPatitionId(), getRuntimeContext().getIndexOfThisSubtask());
    LOG.info("taskId: {}, negInputSize: {}", getPatitionId(), data.size());
    if (null != this.contextParams) {
        Long[] nsPool = this.contextParams.getLongArray("negBound");
        int subTaskId = getRuntimeContext().getIndexOfThisSubtask();
        Long[] seeds = this.contextParams.get(ApsContext.SEEDS);
        int seed = seeds[subTaskId].intValue();
        int negaTime = params.get(HasNegative.NEGATIVE);
        int threadNum = params.getIntegerOrDefault("threadNum", 8);
        double sampleRatioPerPartition = params.get(LineParams.SAMPLE_RATIO_PER_PARTITION);
        Thread[] thread = new Thread[threadNum];
        // One result set per sampling thread: the vertices which need to be pulled.
        Set<Long>[] output = new Set[threadNum];
        DistributedInfo distributedInfo = new DefaultDistributedInfo();
        // Split the local edge list evenly across threadNum negative-sampling threads.
        for (int i = 0; i < threadNum; ++i) {
            int start = (int) distributedInfo.startPos(i, threadNum, data.size());
            int end = (int) distributedInfo.localRowCnt(i, threadNum, data.size()) + start;
            LOG.info("taskId: {}, negStart: {}, end: {}", getPatitionId(), start, end);
            output[i] = new HashSet<>();
            thread[i] = new NegSampleRunner(nsPool, seed + i, data.subList(start, end), output[i], negaTime, sampleRatioPerPartition);
            thread[i].start();
        }
        for (int i = 0; i < threadNum; ++i) {
            thread[i].join();
        }
        // Merge the per-thread results into a single set of indices to pull.
        Set<Long> outputMerger = new HashSet<>();
        for (int i = 0; i < threadNum; ++i) {
            outputMerger.addAll(output[i]);
        }
        LOG.info("taskId: {}, negOutputSize: {}", getPatitionId(), outputMerger.size());
        long endTimeConsumption = System.nanoTime();
        LOG.info("taskId: {}, negTime: {}", getPatitionId(), (endTimeConsumption - startTimeConsumption) / 1000000.0);
        return outputMerger;
    } else {
        throw new RuntimeException();
    }
}
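Reduced to its essentials, the fan-out pattern used by requestIndex looks like the sketch below. This is a standalone illustration, not Alink code: the NegSampleRunner thread is replaced by a hypothetical worker that simply collects every value it sees, and only the two DistributedInfo calls seen above (startPos and localRowCnt) are assumed.

import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.common.io.directreader.DistributedInfo;

import java.util.*;

// Illustrative only: splits a list across worker threads the way the
// requestIndex methods above do, then merges the per-thread results.
public class FanOutSketch {
    public static Set<Long> collectIndices(List<long[]> data, int threadNum) throws InterruptedException {
        DistributedInfo distributedInfo = new DefaultDistributedInfo();
        Thread[] threads = new Thread[threadNum];
        List<Set<Long>> outputs = new ArrayList<>();
        for (int i = 0; i < threadNum; ++i) {
            int start = (int) distributedInfo.startPos(i, threadNum, data.size());
            int end = start + (int) distributedInfo.localRowCnt(i, threadNum, data.size());
            Set<Long> output = new HashSet<>();
            outputs.add(output);
            List<long[]> slice = data.subList(start, end);
            // Hypothetical worker: in Alink this is NegSampleRunner doing negative sampling.
            threads[i] = new Thread(() -> {
                for (long[] row : slice) {
                    for (long v : row) {
                        output.add(v);
                    }
                }
            });
            threads[i].start();
        }
        for (Thread t : threads) {
            t.join();
        }
        Set<Long> merged = new HashSet<>();
        outputs.forEach(merged::addAll);
        return merged;
    }
}

Note that data.subList(start, end) returns a view rather than a copy, so the only per-thread allocation of any size is the output set.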
Use of com.alibaba.alink.common.io.directreader.DistributedInfo in project Alink by alibaba.
The class ApsContext, method updateLoopInfo:
public <MT> ApsContext updateLoopInfo(IterativeDataSet<Tuple2<Long, MT>> loop) {
    // Record the current superstep number in the context; only subtask 0 emits it to avoid duplicates.
    this.put(alinkApsStepNum, loop.mapPartition(new RichMapPartitionFunction<Tuple2<Long, MT>, Integer>() {
        private static final long serialVersionUID = 4816930852791283240L;

        @Override
        public void mapPartition(Iterable<Tuple2<Long, MT>> iterable, Collector<Integer> collector) throws Exception {
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                collector.collect(getIterationRuntimeContext().getSuperstepNumber());
            }
        }
    }).returns(Types.INT));
    context = context.map(new RichMapFunction<Params, Params>() {
        private static final long serialVersionUID = -7796590264875170687L;

        IntCounter counter = new IntCounter();

        @Override
        public void open(Configuration parameters) throws Exception {
            getRuntimeContext().addAccumulator(alinkApsBreakAll, this.counter);
        }

        @Override
        public Params map(Params params) throws Exception {
            int numCheckPoint = params.getInteger(alinkApsNumCheckpoint);
            int numIter = params.getIntegerOrDefault(alinkApsNumIter, 1);
            int numMiniBatch = params.getInteger(alinkApsNumMiniBatch);
            int startPos, endPos;
            if (numCheckPoint <= 0) {
                startPos = 0;
                endPos = numMiniBatch * numIter;
            } else {
                // Each checkpoint covers a contiguous slice of the numMiniBatch * numIter global steps.
                int curCheckPoint = params.getInteger(alinkApsCurCheckpoint);
                DistributedInfo distributedInfo = new DefaultDistributedInfo();
                startPos = (int) distributedInfo.startPos(curCheckPoint, numCheckPoint, numMiniBatch * numIter);
                endPos = (int) distributedInfo.localRowCnt(curCheckPoint, numCheckPoint, numMiniBatch * numIter) + startPos;
            }
            int curPos = params.getInteger(alinkApsStepNum) - 1 + startPos;
            LOG.info("taskId:{}, stepNum:{}", getRuntimeContext().getIndexOfThisSubtask(), curPos);
            int curBlock;
            if (curPos >= endPos) {
                curBlock = -1;
            } else if (curPos >= numMiniBatch * numIter) {
                curBlock = -1;
            } else {
                curBlock = curPos % numMiniBatch;
            }
            boolean hasNextBlock = true;
            if (curPos >= endPos - 1) {
                hasNextBlock = false;
            } else if (curPos >= numMiniBatch * numIter - 1) {
                hasNextBlock = false;
            }
            if (endPos >= numMiniBatch * numIter) {
                this.counter.add(1);
            }
            params.set(ALINK_APS_HAS_NEXT_BLOCK, hasNextBlock);
            params.set(ALINK_APS_CUR_BLOCK, curBlock);
            return params;
        }
    });
    return this;
}
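To make the checkpoint arithmetic concrete, the following standalone sketch prints the [startPos, endPos) step range each checkpoint covers; the parameter values are invented for illustration and are not Alink defaults.

import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.common.io.directreader.DistributedInfo;

// Illustrative only: prints the global step range assigned to each checkpoint.
public class CheckpointRangeSketch {
    public static void main(String[] args) {
        int numMiniBatch = 4;  // hypothetical values chosen for the example
        int numIter = 3;
        int numCheckPoint = 3;
        int totalSteps = numMiniBatch * numIter;  // 12 global steps

        DistributedInfo distributedInfo = new DefaultDistributedInfo();
        for (int curCheckPoint = 0; curCheckPoint < numCheckPoint; ++curCheckPoint) {
            int startPos = (int) distributedInfo.startPos(curCheckPoint, numCheckPoint, totalSteps);
            int endPos = startPos + (int) distributedInfo.localRowCnt(curCheckPoint, numCheckPoint, totalSteps);
            // In updateLoopInfo, curBlock is curPos % numMiniBatch for curPos in [startPos, endPos).
            System.out.println("checkpoint " + curCheckPoint + " -> steps [" + startPos + ", " + endPos + ")");
        }
    }
}

Assuming DefaultDistributedInfo splits the range evenly, 12 steps over 3 checkpoints gives [0, 4), [4, 8) and [8, 12), so each checkpoint here covers one of the three passes over the mini-batches.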
Use of com.alibaba.alink.common.io.directreader.DistributedInfo in project Alink by alibaba.
The class ApsFuncIndex4PullW2V, method requestIndex:
@Override
protected Set<Long> requestIndex(List<int[]> data) throws Exception {
    long startTimeConsumption = System.nanoTime();
    LOG.info("taskId: {}, localId: {}", getPatitionId(), getRuntimeContext().getIndexOfThisSubtask());
    LOG.info("taskId: {}, negInputSize: {}", getPatitionId(), data.size());
    if (null != this.contextParams) {
        Long[] seeds = this.contextParams.get(ApsContext.SEEDS);
        Long[] nsPool = this.contextParams.getLongArray("negBound");
        final boolean metapathMode = params.getBoolOrDefault("metapathMode", false);
        long[] groupIdxStarts = null;
        if (metapathMode) {
            groupIdxStarts = ArrayUtils.toPrimitive(this.contextParams.getLongArray("groupIdxes"));
        }
        int vocSize = this.contextParams.getLong("vocSize").intValue();
        long seed = seeds[getPatitionId()];
        int threadNum = params.getIntegerOrDefault("threadNum", 8);
        Thread[] thread = new Thread[threadNum];
        // One result set per sampling thread: the model indices that need to be pulled.
        Set<Long>[] output = new Set[threadNum];
        DistributedInfo distributedInfo = new DefaultDistributedInfo();
        // Split the local sequences evenly across threadNum negative-sampling threads.
        for (int i = 0; i < threadNum; ++i) {
            int start = (int) distributedInfo.startPos(i, threadNum, data.size());
            int end = (int) distributedInfo.localRowCnt(i, threadNum, data.size()) + start;
            LOG.info("taskId: {}, negStart: {}, end: {}", getPatitionId(), start, end);
            output[i] = new HashSet<>();
            thread[i] = new NegSampleRunner(vocSize, nsPool, params, seed + i, data.subList(start, end), output[i], null, groupIdxStarts);
            thread[i].start();
        }
        for (int i = 0; i < threadNum; ++i) {
            thread[i].join();
        }
        Set<Long> outputMerger = new HashSet<>();
        for (int i = 0; i < threadNum; ++i) {
            outputMerger.addAll(output[i]);
        }
        LOG.info("taskId: {}, negOutputSize: {}", getPatitionId(), outputMerger.size());
        long endTimeConsumption = System.nanoTime();
        LOG.info("taskId: {}, negTime: {}", getPatitionId(), (endTimeConsumption - startTimeConsumption) / 1000000.0);
        return outputMerger;
    } else {
        throw new RuntimeException();
    }
}
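Both requestIndex variants rely on the same implicit property of DefaultDistributedInfo: the per-thread slices must be contiguous and together cover the whole list, otherwise the subList(start, end) calls would skip or duplicate rows. The standalone check below makes that assumption explicit (illustrative only, not part of Alink's tests).

import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.common.io.directreader.DistributedInfo;

// Illustrative only: verifies that the slices produced by startPos/localRowCnt
// tile the range [0, total) without gaps or overlaps.
public class SliceCoverageCheck {
    public static void main(String[] args) {
        DistributedInfo distributedInfo = new DefaultDistributedInfo();
        int threadNum = 8;
        int total = 1001;  // deliberately not divisible by threadNum
        long expectedStart = 0L;
        for (int i = 0; i < threadNum; ++i) {
            long start = distributedInfo.startPos(i, threadNum, total);
            long cnt = distributedInfo.localRowCnt(i, threadNum, total);
            if (start != expectedStart) {
                throw new AssertionError("slice " + i + " does not start where the previous one ended");
            }
            expectedStart = start + cnt;
        }
        if (expectedStart != total) {
            throw new AssertionError("slices do not cover the whole range");
        }
        System.out.println("slices are contiguous and cover [0, " + total + ")");
    }
}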
Use of com.alibaba.alink.common.io.directreader.DistributedInfo in project Alink by alibaba.
The class BaseTuning, method split:
private DataSet<Tuple2<Integer, Row>> split(BatchOperator<?> data, int k) {
    DataSet<Row> input = shuffle(data.getDataSet());
    DataSet<Tuple2<Integer, Long>> counts = DataSetUtils.countElementsPerPartition(input);
    return input.mapPartition(new RichMapPartitionFunction<Row, Tuple2<Integer, Row>>() {
        private static final long serialVersionUID = -902599228310615694L;

        long taskStart = 0L;
        long totalNumInstance = 0L;

        @Override
        public void open(Configuration parameters) {
            List<Tuple2<Integer, Long>> counts1 = getRuntimeContext().getBroadcastVariable("counts");
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            for (Tuple2<Integer, Long> cnt : counts1) {
                // taskStart is the global offset of this partition: sum the counts of all preceding partitions.
                if (cnt.f0 < taskId) {
                    taskStart += cnt.f1;
                }
                totalNumInstance += cnt.f1;
            }
        }

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Tuple2<Integer, Row>> out) {
            DistributedInfo distributedInfo = new DefaultDistributedInfo();
            // split1.f0 is the fold this partition starts in; split1.f1 is the global end position of that fold.
            Tuple2<Integer, Long> split1 = new Tuple2<>(-1, -1L);
            long lcnt = taskStart;
            for (int i = 0; i <= k; ++i) {
                long sp = distributedInfo.startPos(i, k, totalNumInstance);
                long lrc = distributedInfo.localRowCnt(i, k, totalNumInstance);
                if (taskStart < sp) {
                    split1.f0 = i - 1;
                    split1.f1 = distributedInfo.startPos(i - 1, k, totalNumInstance) + distributedInfo.localRowCnt(i - 1, k, totalNumInstance);
                    break;
                }
                if (taskStart == sp) {
                    split1.f0 = i;
                    split1.f1 = sp + lrc;
                    break;
                }
            }
            for (Row val : values) {
                // Advance to the next fold once the current fold's end position is reached.
                if (lcnt >= split1.f1) {
                    split1.f0 += 1;
                    split1.f1 = distributedInfo.localRowCnt(split1.f0, k, totalNumInstance) + lcnt;
                }
                out.collect(Tuple2.of(split1.f0, val));
                lcnt++;
            }
        }
    }).withBroadcastSet(counts, "counts");
}
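The net effect of split is to tag every row with a fold index in [0, k): fold i owns the contiguous global index range starting at startPos(i, k, total) with localRowCnt(i, k, total) rows. The sketch below expresses the same mapping directly on a single machine (illustrative only; the operator cannot do it this way because each Flink subtask sees only its own slice, hence the taskStart bookkeeping above).

import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.common.io.directreader.DistributedInfo;

// Illustrative only: maps a global row index to the k-fold it belongs to.
public class FoldOfRow {
    public static int foldOf(long globalRowIndex, int k, long totalNumInstance) {
        DistributedInfo distributedInfo = new DefaultDistributedInfo();
        for (int i = 0; i < k; ++i) {
            long start = distributedInfo.startPos(i, k, totalNumInstance);
            long end = start + distributedInfo.localRowCnt(i, k, totalNumInstance);
            if (globalRowIndex >= start && globalRowIndex < end) {
                return i;
            }
        }
        throw new IllegalArgumentException("row index out of range");
    }

    public static void main(String[] args) {
        // With 10 rows and k = 3, the folds cover 4 + 3 + 3 rows (assuming DefaultDistributedInfo
        // gives the remainder to the first folds; the exact remainder placement is an assumption).
        for (long r = 0; r < 10; ++r) {
            System.out.println("row " + r + " -> fold " + foldOf(r, 3, 10));
        }
    }
}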
Use of com.alibaba.alink.common.io.directreader.DistributedInfo in project Alink by alibaba.
The class CalcFeatureGain, method calc:
@Override
public void calc(ComContext context) {
    LOG.info("taskId: {}, {} start", context.getTaskId(), CalcFeatureGain.class.getSimpleName());
    BoostingObjs boostingObjs = context.getObj("boostingObjs");
    HistogramBaseTreeObjs tree = context.getObj("tree");
    double[] histogram = context.getObj("histogram");
    if (context.getStepNo() == 1) {
        context.putObj("best", new Node[tree.maxNodeSize]);
        featureSplitters = new HistogramFeatureSplitter[boostingObjs.data.getN()];
        for (int i = 0; i < boostingObjs.data.getN(); ++i) {
            featureSplitters[i] = createFeatureSplitter(
                boostingObjs.data.getFeatureMetas()[i].getType() == FeatureMeta.FeatureType.CATEGORICAL,
                boostingObjs.params, boostingObjs.data.getFeatureMetas()[i], tree.compareIndex4Categorical);
        }
    }
    // Total number of (node, bagging-feature) pairs queued for split evaluation.
    int sumFeatureCount = 0;
    for (NodeInfoPair item : tree.queue) {
        sumFeatureCount += boostingObjs.numBaggingFeatures;
        if (item.big != null) {
            sumFeatureCount += boostingObjs.numBaggingFeatures;
        }
    }
    // Each task only evaluates the pairs in its own [start, end) range.
    DistributedInfo distributedInfo = new DefaultDistributedInfo();
    int start = (int) distributedInfo.startPos(context.getTaskId(), context.getNumTask(), sumFeatureCount);
    int cnt = (int) distributedInfo.localRowCnt(context.getTaskId(), context.getNumTask(), sumFeatureCount);
    int end = start + cnt;
    int featureCnt = 0;
    int featureBinCnt = 0;
    Node[] best = context.getObj("best");
    int index = 0;
    for (NodeInfoPair item : tree.queue) {
        best[index] = null;
        final int[] smallBaggingFeatures = item.small.baggingFeatures;
        for (int smallBaggingFeature : smallBaggingFeatures) {
            if (featureCnt >= start && featureCnt < end) {
                featureSplitters[smallBaggingFeature].reset(
                    histogram,
                    new Slice(featureBinCnt, featureBinCnt + DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[smallBaggingFeature], tree.useMissing)),
                    item.small.depth);
                double gain = featureSplitters[smallBaggingFeature].bestSplit(tree.leaves.size());
                if (best[index] == null || (featureSplitters[smallBaggingFeature].canSplit() && gain > best[index].getGain())) {
                    best[index] = new Node();
                    featureSplitters[smallBaggingFeature].fillNode(best[index]);
                }
                featureBinCnt += DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[smallBaggingFeature], tree.useMissing);
            }
            featureCnt++;
        }
        index++;
        if (item.big != null) {
            best[index] = null;
            final int[] bigBaggingFeatures = item.big.baggingFeatures;
            for (int bigBaggingFeature : bigBaggingFeatures) {
                if (featureCnt >= start && featureCnt < end) {
                    featureSplitters[bigBaggingFeature].reset(
                        histogram,
                        new Slice(featureBinCnt, featureBinCnt + DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[bigBaggingFeature], tree.useMissing)),
                        item.big.depth);
                    double gain = featureSplitters[bigBaggingFeature].bestSplit(tree.leaves.size());
                    if (best[index] == null || (featureSplitters[bigBaggingFeature].canSplit() && gain > best[index].getGain())) {
                        best[index] = new Node();
                        featureSplitters[bigBaggingFeature].fillNode(best[index]);
                    }
                    featureBinCnt += DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[bigBaggingFeature], tree.useMissing);
                }
                featureCnt++;
            }
            index++;
        }
    }
    context.putObj("bestLength", index);
    LOG.info("taskId: {}, {} end", context.getTaskId(), CalcFeatureGain.class.getSimpleName());
}
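Here DistributedInfo slices a count of work items rather than a data list: each task is assigned a contiguous range [start, end) of the sumFeatureCount (node, feature) pairs and skips everything outside it while still advancing the counters. Stripped of the tree-specific details, the pattern looks like the sketch below; workItems and process are hypothetical placeholders, not Alink APIs.

import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.common.io.directreader.DistributedInfo;

import java.util.List;

// Illustrative only: each of numTask workers walks the full list of work items
// but processes only the contiguous index range assigned to it.
public class WorkRangeSketch {
    static <T> void processMyRange(List<T> workItems, int taskId, int numTask) {
        DistributedInfo distributedInfo = new DefaultDistributedInfo();
        int start = (int) distributedInfo.startPos(taskId, numTask, workItems.size());
        int end = start + (int) distributedInfo.localRowCnt(taskId, numTask, workItems.size());
        for (int i = 0; i < workItems.size(); ++i) {
            if (i >= start && i < end) {
                process(workItems.get(i));  // hypothetical per-item work
            }
            // Items outside [start, end) are skipped but still counted, exactly as
            // featureCnt is advanced for every (node, feature) pair in calc above.
        }
    }

    private static <T> void process(T item) {
        System.out.println("processing " + item);
    }
}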