Usage of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in the Alink project by Alibaba: class StatisticsHelperTest, method summaryHelperVector.
/**
 * Verifies {@code StatisticsHelper.summaryHelper} on a dense-vector column:
 * the computed {@link BaseVectorSummary} statistics must match the hand-computed
 * values for the fixture, and the raw vectors must pass through unchanged.
 *
 * <p>Fix: JUnit's {@code assertEquals(expected, actual)} contract puts the
 * expected value first; the original calls had the arguments swapped, which
 * makes failure messages report "expected" and "actual" backwards.
 */
@Test
public void summaryHelperVector() throws Exception {
    BatchOperator data = getDenseBatch();
    String vectorColName = "vec";
    Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> dataSet =
        StatisticsHelper.summaryHelper(data, null, vectorColName);

    // summaryHelper produces exactly one summary element.
    BaseVectorSummary summary = dataSet.f1.collect().get(0);
    assertEquals(3, summary.vectorSize());
    assertEquals(4, summary.count());
    // Column 2 holds {2.0, -3.0, 2.0, 4.0}; the expected values below are the
    // exact statistics of that column (variance uses the n-1 denominator).
    assertEquals(4.0, summary.max(2), 10e-4);
    assertEquals(0.0, summary.min(1), 10e-4);
    assertEquals(1.25, summary.mean(2), 10e-4);
    assertEquals(8.9167, summary.variance(2), 10e-4);
    assertEquals(2.9861, summary.standardDeviation(2), 10e-4);
    assertEquals(11.0, summary.normL1(2), 10e-4);
    assertEquals(5.7446, summary.normL2(2), 10e-4);

    // The vector side of the tuple must echo the input rows unchanged.
    List<Vector> vectors = dataSet.f0.collect();
    assertEquals(4, vectors.size());
    assertArrayEquals(new double[] { 1, 1, 2.0 }, ((DenseVector) vectors.get(0)).getData(), 10e-4);
    assertArrayEquals(new double[] { 2, 2, -3.0 }, ((DenseVector) vectors.get(1)).getData(), 10e-4);
    assertArrayEquals(new double[] { 1, 3, 2.0 }, ((DenseVector) vectors.get(2)).getData(), 10e-4);
    assertArrayEquals(new double[] { 0, 0, 4.0 }, ((DenseVector) vectors.get(3)).getData(), 10e-4);
}
Usage of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in the Alink project by Alibaba: class StatisticsHelperTest, method summaryHelperTableWithReservedCols.
/**
 * Verifies {@code StatisticsHelper.summaryHelper} on table columns with a
 * reserved column: the summary statistics must match the fixture, the selected
 * columns must be assembled into vectors, and the reserved "id" column must be
 * carried alongside each vector.
 *
 * <p>Fix: JUnit's {@code assertEquals(expected, actual)} contract puts the
 * expected value first; the original calls had the arguments swapped, which
 * makes failure messages report "expected" and "actual" backwards.
 */
@Test
public void summaryHelperTableWithReservedCols() throws Exception {
    BatchOperator data = getBatchTable();
    String[] selectedColNames = new String[] { "f_long", "f_int", "f_double" };
    String[] reservedColNames = new String[] { "id" };
    Tuple2<DataSet<Tuple2<Vector, Row>>, DataSet<BaseVectorSummary>> dataSet =
        StatisticsHelper.summaryHelper(data, selectedColNames, null, reservedColNames);

    // summaryHelper produces exactly one summary element.
    BaseVectorSummary summary = dataSet.f1.collect().get(0);
    assertEquals(3, summary.vectorSize());
    assertEquals(4, summary.count());
    // Column 2 (f_double) holds {2.0, -3.0, 2.0, 4.0}; expected values below
    // are the exact statistics of that column (variance uses n-1 denominator).
    assertEquals(4.0, summary.max(2), 10e-4);
    assertEquals(0.0, summary.min(1), 10e-4);
    assertEquals(1.25, summary.mean(2), 10e-4);
    assertEquals(8.9167, summary.variance(2), 10e-4);
    assertEquals(2.9861, summary.standardDeviation(2), 10e-4);
    assertEquals(11.0, summary.normL1(2), 10e-4);
    assertEquals(5.7446, summary.normL2(2), 10e-4);

    // Each output pairs the assembled vector with the reserved-column Row.
    List<Tuple2<Vector, Row>> tuple2s = dataSet.f0.collect();
    assertEquals(4, tuple2s.size());
    assertArrayEquals(new double[] { 1, 1, 2.0 }, ((DenseVector) tuple2s.get(0).f0).getData(), 10e-4);
    assertArrayEquals(new double[] { 2, 2, -3.0 }, ((DenseVector) tuple2s.get(1).f0).getData(), 10e-4);
    assertArrayEquals(new double[] { 1, 3, 2.0 }, ((DenseVector) tuple2s.get(2).f0).getData(), 10e-4);
    assertArrayEquals(new double[] { 0, 0, 4.0 }, ((DenseVector) tuple2s.get(3).f0).getData(), 10e-4);
    // The reserved "id" column rides along unchanged.
    assertEquals(1, tuple2s.get(0).f1.getField(0));
    assertEquals(2, tuple2s.get(1).f1.getField(0));
    assertEquals(3, tuple2s.get(2).f1.getField(0));
    assertEquals(4, tuple2s.get(3).f1.getField(0));
}
Usage of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in the Alink project by Alibaba: class LdaTrainBatchOp, method online.
/**
 * Trains the LDA model with the online (variational Bayes) algorithm and sets
 * the resulting model as this operator's output.
 *
 * @param dataAndStat       pair of the document-vector dataset and its one-element summary dataset
 * @param numTopic          number of topics
 * @param numIter           number of iterations of the distributed loop
 * @param alpha             doc-topic Dirichlet prior; -1 is the "unset" sentinel
 * @param beta              topic-word Dirichlet prior; -1 is the "unset" sentinel
 * @param resDocCountModel  vocabulary/count model broadcast into the final model build
 * @param gammaShape        shape parameter for the Gamma-distributed initialization
 * @param seed              random seed (nullable)
 */
private void online(Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> dataAndStat, int numTopic, int numIter, double alpha, double beta, DataSet<DocCountVectorizerModelData> resDocCountModel, int gammaShape, Integer seed) {
// -1 is the caller's "not set" sentinel: default both priors to the common 1/numTopic heuristic.
if (beta == -1) {
beta = 1.0 / numTopic;
}
if (alpha == -1) {
alpha = 1.0 / numTopic;
}
double learningOffset = getParams().get(ONLINE_LEARNING_OFFSET);
double learningDecay = getParams().get(LEARNING_DECAY);
double subSamplingRate = getParams().get(SUBSAMPLING_RATE);
boolean optimizeDocConcentration = getParams().get(OPTIMIZE_DOC_CONCENTRATION);
DataSet<Vector> data = dataAndStat.f0;
// Reduce the summary to the (documentCount, vocabularySize) pair needed to size the model.
DataSet<Tuple2<Long, Integer>> shape = dataAndStat.f1.map(new MapFunction<BaseVectorSummary, Tuple2<Long, Integer>>() {
private static final long serialVersionUID = 1305270477796787466L;
@Override
public Tuple2<Long, Integer> map(BaseVectorSummary srt) {
return new Tuple2<>(srt.count(), srt.vectorSize());
}
});
// Initialize the lambda matrix per partition; the shape dataset is broadcast so every task sees it.
DataSet<Tuple2<DenseMatrix, DenseMatrix>> initModel = data.mapPartition(new OnlineInit(numTopic, gammaShape, alpha, seed)).name("init lambda").withBroadcastSet(shape, LdaVariable.shape);
// Distributed iteration: per-iteration corpus statistics are AllReduce-aggregated, then lambda
// and alpha are updated, log-likelihood is tracked, and the model is built on close.
// NOTE(review): the step order and the AllReduce variable names must match what
// OnlineCorpusStep/UpdateLambdaAndAlpha read from the context — do not reorder.
DataSet<Row> ldaModelData = new IterativeComQueue().initWithPartitionedData(LdaVariable.data, data).initWithBroadcastData(LdaVariable.shape, shape).initWithBroadcastData(LdaVariable.initModel, initModel).add(new OnlineCorpusStep(numTopic, subSamplingRate, gammaShape, seed)).add(new AllReduce(LdaVariable.wordTopicStat)).add(new AllReduce(LdaVariable.logPhatPart)).add(new AllReduce(LdaVariable.nonEmptyWordCount)).add(new AllReduce(LdaVariable.nonEmptyDocCount)).add(new UpdateLambdaAndAlpha(numTopic, learningOffset, learningDecay, subSamplingRate, optimizeDocConcentration, beta)).add(new OnlineLogLikelihood(beta, numTopic, numIter, gammaShape, seed)).add(new AllReduce(LdaVariable.logLikelihood)).closeWith(new BuildOnlineLdaModel(numTopic, beta)).setMaxIter(numIter).exec();
// Convert the iteration output into model rows, joining in the vocabulary model by broadcast.
DataSet<Row> model = ldaModelData.flatMap(new BuildResModel(seed)).withBroadcastSet(resDocCountModel, "DocCountModel");
setOutput(model, new LdaModelDataConverter().getModelSchema());
saveWordTopicModelAndPerplexity(model, numTopic, true);
}
Usage of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in the Alink project by Alibaba: class KDTreeModelDataConverter, method buildIndex.
/**
 * Builds a per-partition KD-tree index over the input vectors and serializes it
 * into model rows: one row per data vector plus one row per partition carrying
 * the JSON-encoded tree root, followed by a save pass that attaches the meta
 * Params (written only by subtask 0).
 *
 * <p>Only Euclidean distance is supported by the KD-tree solver.
 */
@Override
public DataSet<Row> buildIndex(BatchOperator in, Params params) {
Preconditions.checkArgument(params.get(VectorApproxNearestNeighborTrainParams.METRIC).equals(VectorApproxNearestNeighborTrainParams.Metric.EUCLIDEAN), "KDTree solver only supports Euclidean distance!");
EuclideanDistance distance = new EuclideanDistance();
// Only the summary side (f1) is used below, broadcast under the name "vectorSize".
Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> statistics = StatisticsHelper.summaryHelper(in, null, params.get(VectorApproxNearestNeighborTrainParams.SELECTED_COL));
// rebalance() spreads rows evenly so each partition builds a tree of similar size.
return in.getDataSet().rebalance().mapPartition(new RichMapPartitionFunction<Row, Row>() {
private static final long serialVersionUID = 6654757741959479783L;
@Override
public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
// Seed vectorSize from the broadcast summary; it is then overwritten by the size of
// each vector seen. NOTE(review): this makes the last vector's size win — presumably
// all vectors share one size and the summary is only a fallback; confirm.
BaseVectorSummary summary = (BaseVectorSummary) getRuntimeContext().getBroadcastVariable("vectorSize").get(0);
int vectorSize = summary.vectorSize();
List<FastDistanceVectorData> list = new ArrayList<>();
for (Row row : values) {
FastDistanceVectorData vector = distance.prepareVectorData(row, 1, 0);
list.add(vector);
vectorSize = vector.getVector().size();
}
// Empty partitions emit nothing.
if (list.size() > 0) {
FastDistanceVectorData[] vectorArray = list.toArray(new FastDistanceVectorData[0]);
KDTree tree = new KDTree(vectorArray, vectorSize, distance);
tree.buildTree();
int taskId = getRuntimeContext().getIndexOfThisSubtask();
// One Row instance is reused for every emission; collect() copies/serializes it.
Row row = new Row(ROW_SIZE);
row.setField(TASKID_INDEX, (long) taskId);
// Emit one row per vector: (taskId, dataId, serialized vector).
for (int i = 0; i < vectorArray.length; i++) {
row.setField(DATA_ID_INDEX, (long) i);
row.setField(DATA_IDNEX, vectorArray[i].toString());
out.collect(row);
}
// Final row of the partition carries the JSON-encoded tree root instead of data.
row.setField(DATA_ID_INDEX, null);
row.setField(DATA_IDNEX, null);
row.setField(ROOT_IDDEX, JsonConverter.toJson(tree.getRoot()));
out.collect(row);
}
}
}).withBroadcastSet(statistics.f1, "vectorSize").mapPartition(new RichMapPartitionFunction<Row, Row>() {
private static final long serialVersionUID = 6849403933586157611L;
@Override
public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
// Only subtask 0 writes the meta Params (with the vector size from the broadcast
// summary); all other subtasks pass a null meta so the converter saves data only.
Params meta = null;
if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
meta = params;
BaseVectorSummary summary = (BaseVectorSummary) getRuntimeContext().getBroadcastVariable("vectorSize").get(0);
int vectorSize = summary.vectorSize();
meta.set(VECTOR_SIZE, vectorSize);
}
new KDTreeModelDataConverter().save(Tuple2.of(meta, values), out);
}
}).withBroadcastSet(statistics.f1, "vectorSize");
}
Usage of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in the Alink project by Alibaba: class LocalitySensitiveHashApproxFunctions, method buildLSH.
/**
 * Builds the LSH family matching the configured metric as a one-element DataSet.
 *
 * <p>JACCARD uses MinHash and needs no data statistics; EUCLIDEAN uses bucketed
 * random projection and derives the projection dimensionality from the vector
 * summary of the input column.
 *
 * <p>Fix: the EUCLIDEAN branch previously collected the partition into a list and
 * unconditionally read {@code get(0)}; a mapPartition instance handed an empty
 * partition (possible when the summary dataset runs with parallelism > 1) would
 * throw {@link IndexOutOfBoundsException}. Iterating and emitting per element is
 * equivalent for the single-summary dataset and safe for empty partitions.
 *
 * @param in        input batch operator holding the vector column
 * @param params    training params (metric, seed, table/projection counts, width)
 * @param vectorCol name of the vector column to summarize (EUCLIDEAN only)
 * @return one-element DataSet containing the configured {@link BaseLSH}
 * @throws IllegalArgumentException if the metric is not supported
 */
public static DataSet<BaseLSH> buildLSH(BatchOperator in, Params params, String vectorCol) {
    VectorApproxNearestNeighborTrainParams.Metric metric =
        params.get(VectorApproxNearestNeighborTrainParams.METRIC);
    switch (metric) {
        case JACCARD: {
            // MinHash needs no statistics: wrap a single configured instance.
            return MLEnvironmentFactory
                .get(params.get(HasMLEnvironmentId.ML_ENVIRONMENT_ID))
                .getExecutionEnvironment()
                .fromElements(new MinHashLSH(
                    params.get(VectorApproxNearestNeighborTrainParams.SEED),
                    params.get(VectorApproxNearestNeighborTrainParams.NUM_PROJECTIONS_PER_TABLE),
                    params.get(VectorApproxNearestNeighborTrainParams.NUM_HASH_TABLES)));
        }
        case EUCLIDEAN: {
            // The summary (f1) is a one-element dataset carrying the vector size.
            Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> statistics =
                StatisticsHelper.summaryHelper(in, null, vectorCol);
            return statistics.f1.mapPartition(new MapPartitionFunction<BaseVectorSummary, BaseLSH>() {
                private static final long serialVersionUID = -3698577489884292933L;

                @Override
                public void mapPartition(Iterable<BaseVectorSummary> values, Collector<BaseLSH> out) {
                    // Emit one LSH per summary element; empty partitions emit nothing.
                    for (BaseVectorSummary summary : values) {
                        out.collect(new BucketRandomProjectionLSH(
                            params.get(VectorApproxNearestNeighborTrainParams.SEED),
                            summary.vectorSize(),
                            params.get(VectorApproxNearestNeighborTrainParams.NUM_PROJECTIONS_PER_TABLE),
                            params.get(VectorApproxNearestNeighborTrainParams.NUM_HASH_TABLES),
                            params.get(VectorApproxNearestNeighborTrainParams.PROJECTION_WIDTH)));
                    }
                }
            });
        }
        default:
            throw new IllegalArgumentException("Not support " + metric);
    }
}
Aggregations