Use of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in project Alink by Alibaba.
Example: the serializeModel method of the VectorImputerModelDataConverter class.
/**
 * Serialize the model data to "Tuple3&lt;Params, List&lt;String&gt;, List&lt;Row&gt;&gt;".
 *
 * <p>The fill values are taken from the vector summary according to the strategy
 * (MIN / MAX / MEAN); for any other strategy the user-supplied fill value is
 * stored in the meta params instead.
 *
 * @param modelData The model data to serialize: (strategy, vector summary, user fill value).
 * @return The serialization result: (meta params, JSON-encoded fill values, empty row list).
 */
public Tuple3<Params, Iterable<String>, Iterable<Row>> serializeModel(Tuple3<Strategy, BaseVectorSummary, Double> modelData) {
    Strategy strategy = modelData.f0;
    BaseVectorSummary summary = modelData.f1;
    double fillValue = modelData.f2;
    double[] values = null;
    Params meta = new Params().set(SELECTED_COL, vectorColName).set(STRATEGY, strategy);
    switch (strategy) {
        case MIN:
            if (summary.min() instanceof DenseVector) {
                values = ((DenseVector) summary.min()).getData();
            } else {
                values = ((SparseVector) summary.min()).toDenseVector().getData();
            }
            break;
        case MAX:
            if (summary.max() instanceof DenseVector) {
                values = ((DenseVector) summary.max()).getData();
            } else {
                values = ((SparseVector) summary.max()).toDenseVector().getData();
            }
            break;
        case MEAN:
            if (summary.mean() instanceof DenseVector) {
                values = ((DenseVector) summary.mean()).getData();
            } else {
                // BUGFIX: densify first, exactly as the MIN/MAX branches do.
                // SparseVector.getValues() returns only the explicitly stored
                // entries, which drops implicit zeros and misaligns the fill
                // values with the vector indices.
                values = ((SparseVector) summary.mean()).toDenseVector().getData();
            }
            break;
        default:
            // User-defined fill value: no per-dimension values are needed.
            meta.set(FILL_VALUE, fillValue);
    }
    List<String> data = new ArrayList<>();
    data.add(JsonConverter.toJson(values));
    return Tuple3.of(meta, data, new ArrayList<>());
}
Use of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in project Alink by Alibaba.
Example: the linkFrom method of the GmmTrainBatchOp class.
/**
 * Train the Gaussian Mixture model with Expectation-Maximization algorithm.
 *
 * <p>Job graph: (1) summarize the input vectors to obtain the feature size,
 * (2) normalize sparse vectors to that size, (3) randomly initialize the
 * per-cluster summaries, (4) run E-M iterations until the log-likelihood
 * change drops below the tolerance or maxIter is reached, (5) gather the
 * final model on a single task and emit the model rows.
 *
 * @param inputs the first input operator supplies the training vectors.
 * @return this operator, with the serialized model rows set as output.
 */
@Override
public GmmTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
BatchOperator<?> in = checkAndGetFirst(inputs);
final String vectorColName = getVectorCol();
final int numClusters = getK();
final int maxIter = getMaxIter();
final double tol = getEpsilon();
// Extract the vectors from the input operator.
Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> vectorAndSummary = StatisticsHelper.summaryHelper(in, null, vectorColName);
// Derive the feature dimension from the vector summary (a one-element data set).
DataSet<Integer> featureSize = vectorAndSummary.f1.map(new MapFunction<BaseVectorSummary, Integer>() {
private static final long serialVersionUID = 8456872852742625845L;
@Override
public Integer map(BaseVectorSummary summary) throws Exception {
return summary.vectorSize();
}
});
// Sparse vectors may omit their size; fix it from the broadcast feature size
// so that all vectors agree on the dimension during training.
DataSet<Vector> data = vectorAndSummary.f0.map(new RichMapFunction<Vector, Vector>() {
private static final long serialVersionUID = -845795862675993897L;
transient int featureSize;
@Override
public void open(Configuration parameters) throws Exception {
List<Integer> bc = getRuntimeContext().getBroadcastVariable("featureSize");
this.featureSize = bc.get(0);
}
@Override
public Vector map(Vector vec) throws Exception {
if (vec instanceof SparseVector) {
((SparseVector) vec).setSize(featureSize);
}
return vec;
}
}).withBroadcastSet(featureSize, "featureSize");
// Initialize the model.
DataSet<Tuple3<Integer, GmmClusterSummary, IterationStatus>> initialModel = initRandom(data, numClusters, getRandomSeed());
// Iteratively update the model with EM algorithm.
IterativeDataSet<Tuple3<Integer, GmmClusterSummary, IterationStatus>> loop = initialModel.iterate(maxIter);
// For each cluster, rebuild a MultivariateGaussian from its mean and compacted
// covariance so densities can be evaluated in the E-step below.
DataSet<Tuple4<Integer, GmmClusterSummary, IterationStatus, MultivariateGaussian>> md = loop.mapPartition(new RichMapPartitionFunction<Tuple3<Integer, GmmClusterSummary, IterationStatus>, Tuple4<Integer, GmmClusterSummary, IterationStatus, MultivariateGaussian>>() {
private static final long serialVersionUID = -1937088240477952410L;
@Override
public void mapPartition(Iterable<Tuple3<Integer, GmmClusterSummary, IterationStatus>> values, Collector<Tuple4<Integer, GmmClusterSummary, IterationStatus, MultivariateGaussian>> collector) throws Exception {
for (Tuple3<Integer, GmmClusterSummary, IterationStatus> value : values) {
DenseVector means = value.f1.mean;
// The covariance is stored compactly (upper triangle); expand to a full matrix.
DenseMatrix cov = GmmModelData.expandCovarianceMatrix(value.f1.cov, means.size());
MultivariateGaussian md = new MultivariateGaussian(means, cov);
collector.collect(Tuple4.of(value.f0, value.f1, value.f2, md));
}
}
}).withForwardedFields("f0;f1;f2");
// E-step + M-step: accumulate per-partition sufficient statistics (LocalAggregator),
// merge them with a reduce, then normalize into the updated cluster summaries.
DataSet<Tuple3<Integer, GmmClusterSummary, IterationStatus>> updatedModel = data.<LocalAggregator>mapPartition(new RichMapPartitionFunction<Vector, LocalAggregator>() {
private static final long serialVersionUID = 8356493076036649604L;
transient DenseVector oldWeights;
transient DenseVector[] oldMeans;
transient DenseVector[] oldCovs;
transient MultivariateGaussian[] mnd;
@Override
public void open(Configuration parameters) throws Exception {
oldWeights = new DenseVector(numClusters);
oldMeans = new DenseVector[numClusters];
oldCovs = new DenseVector[numClusters];
mnd = new MultivariateGaussian[numClusters];
}
@Override
public void mapPartition(Iterable<Vector> values, Collector<LocalAggregator> out) throws Exception {
List<Integer> bcNumFeatures = getRuntimeContext().getBroadcastVariable("featureSize");
List<Tuple4<Integer, GmmClusterSummary, IterationStatus, MultivariateGaussian>> bcOldModel = getRuntimeContext().getBroadcastVariable("oldModel");
double prevLogLikelihood = 0.;
// Unpack the broadcast model of the previous iteration into per-cluster arrays.
for (Tuple4<Integer, GmmClusterSummary, IterationStatus, MultivariateGaussian> t : bcOldModel) {
int clusterId = t.f0;
GmmClusterSummary clusterInfo = t.f1;
prevLogLikelihood = t.f2.currLogLikelihood;
oldWeights.set(clusterId, clusterInfo.weight);
oldMeans[clusterId] = clusterInfo.mean;
oldCovs[clusterId] = clusterInfo.cov;
// Copy the Gaussian: it is mutated locally while evaluating densities.
mnd[clusterId] = new MultivariateGaussian(t.f3);
// mnd[clusterId] = t.f3;
}
LocalAggregator aggregator = new LocalAggregator(numClusters, bcNumFeatures.get(0), prevLogLikelihood, oldWeights, oldMeans, oldCovs, mnd);
values.forEach(aggregator::add);
out.collect(aggregator);
}
}).withBroadcastSet(featureSize, "featureSize").withBroadcastSet(md, "oldModel").name("E-M_step").reduce(new ReduceFunction<LocalAggregator>() {
private static final long serialVersionUID = -6976429920344470952L;
@Override
public LocalAggregator reduce(LocalAggregator value1, LocalAggregator value2) throws Exception {
return value1.merge(value2);
}
}).flatMap(new RichFlatMapFunction<LocalAggregator, Tuple3<Integer, GmmClusterSummary, IterationStatus>>() {
private static final long serialVersionUID = 6599047947335456972L;
@Override
public void flatMap(LocalAggregator aggregator, Collector<Tuple3<Integer, GmmClusterSummary, IterationStatus>> out) throws Exception {
// M-step normalization: turn accumulated sums into weights, means and covariances.
for (int i = 0; i < numClusters; i++) {
double w = aggregator.updatedWeightsSum.get(i);
aggregator.updatedMeansSum[i].scaleEqual(1.0 / w);
aggregator.updatedCovsSum[i].scaleEqual(1.0 / w);
GmmClusterSummary model = new GmmClusterSummary(i, w / aggregator.totalCount, aggregator.updatedMeansSum[i], aggregator.updatedCovsSum[i]);
// note that we use Cov(X,Y) = E[XY] - E[X]E[Y] to compute Cov(X,Y)
int featureSize = model.mean.size();
for (int m = 0; m < featureSize; m++) {
// loop over columns
for (int n = m; n < featureSize; n++) {
int pos = GmmModelData.getElementPositionInCompactMatrix(m, n, featureSize);
model.cov.add(pos, -1.0 * model.mean.get(m) * model.mean.get(n));
}
}
IterationStatus stat = new IterationStatus();
stat.prevLogLikelihood = aggregator.prevLogLikelihood;
stat.currLogLikelihood = aggregator.newLogLikelihood;
out.collect(Tuple3.of(i, model, stat));
}
}
}).partitionCustom(new Partitioner<Integer>() {
private static final long serialVersionUID = 1006932050560340472L;
@Override
public int partition(Integer key, int numPartitions) {
// Spread clusters across tasks by cluster id.
return key % numPartitions;
}
}, 0);
// Check whether stop criterion is met.
// The iteration stops when this data set becomes empty (nothing is collected).
DataSet<Boolean> criterion = updatedModel.first(1).flatMap(new RichFlatMapFunction<Tuple3<Integer, GmmClusterSummary, IterationStatus>, Boolean>() {
private static final long serialVersionUID = 6890280483282243057L;
@Override
public void flatMap(Tuple3<Integer, GmmClusterSummary, IterationStatus> value, Collector<Boolean> out) throws Exception {
IterationStatus stat = value.f2;
int stepNo = getIterationRuntimeContext().getSuperstepNumber();
double diffLogLikelihood = Math.abs(stat.currLogLikelihood - stat.prevLogLikelihood);
LOG.info("step {}, prevLogLikelihood {}, currLogLikelihood {}, diffLogLikelihood {}", stepNo, stat.prevLogLikelihood, stat.currLogLikelihood, diffLogLikelihood);
// Keep iterating while not converged (always continue past the first step).
if (stepNo <= 1 || diffLogLikelihood > tol) {
out.collect(false);
}
}
});
DataSet<Tuple3<Integer, GmmClusterSummary, IterationStatus>> finalModel = loop.closeWith(updatedModel, criterion);
// Output the model.
// The model must be assembled on a single task; parallelism is forced to 1 below.
DataSet<Row> modelRows = finalModel.mapPartition(new RichMapPartitionFunction<Tuple3<Integer, GmmClusterSummary, IterationStatus>, Row>() {
private static final long serialVersionUID = -8411238421923712023L;
transient int featureSize;
@Override
public void open(Configuration parameters) throws Exception {
this.featureSize = (int) (getRuntimeContext().getBroadcastVariable("featureSize").get(0));
}
@Override
public void mapPartition(Iterable<Tuple3<Integer, GmmClusterSummary, IterationStatus>> values, Collector<Row> out) throws Exception {
int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
if (numTasks > 1) {
throw new RuntimeException("parallelism is not 1 when saving model.");
}
GmmModelData model = new GmmModelData();
model.k = numClusters;
model.dim = featureSize;
model.vectorCol = vectorColName;
model.data = new ArrayList<>(numClusters);
for (Tuple3<Integer, GmmClusterSummary, IterationStatus> t : values) {
t.f1.clusterId = t.f0;
model.data.add(t.f1);
}
new GmmModelDataConverter().save(model, out);
}
}).setParallelism(1).withBroadcastSet(featureSize, "featureSize");
this.setOutput(modelRows, new GmmModelDataConverter().getModelSchema());
return this;
}
Use of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in project Alink by Alibaba.
Example: the testDense method of the PCATest class.
/** Fit PCA on a small dense table and verify the first projected component sums to ~0. */
private void testDense() {
    // Schema and raw rows of the test table: an id column and a dense vector column.
    String[] schema = new String[] { "id", "vec" };
    Object[][] rows = new Object[][] {
        { 1, "0.1 0.2 0.3 0.4" },
        { 2, "0.2 0.1 0.2 0.6" },
        { 3, "0.2 0.3 0.5 0.4" },
        { 4, "0.3 0.1 0.3 0.7" },
        { 5, "0.4 0.2 0.4 0.4" }
    };
    MemSourceBatchOp input = new MemSourceBatchOp(rows, schema);

    // Keep 3 principal components, computed from the correlation matrix.
    PCA estimator = new PCA()
        .setK(3)
        .setCalculationType("CORR")
        .setPredictionCol("pred")
        .setReservedCols("id")
        .setVectorCol("vec");
    estimator.enableLazyPrintModelInfo();

    PCAModel fitted = estimator.fit(input);
    BatchOperator<?> projected = fitted.transform(input);

    // Summarize the projection column; centered data makes the first component sum ~0.
    VectorSummarizerBatchOp summarizer = new VectorSummarizerBatchOp().setSelectedCol("pred");
    summarizer.linkFrom(projected);
    summarizer.lazyCollectVectorSummary(summary ->
        Assert.assertEquals(3.4416913763379853E-15, Math.abs(summary.sum().get(0)), 10e-8));
}
Use of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in project Alink by Alibaba.
Example: the linkFrom method of the VectorCorrelationBatchOp class.
/**
 * Compute the correlation matrix of the selected vector column.
 *
 * <p>For PEARSON, the correlation is computed directly from the vector summary.
 * For other methods (Spearman), the values are first replaced by their ranks and
 * this operator is re-applied to the ranked data (rank-based correlation reduces
 * to Pearson on ranks).
 *
 * @param inputs the first input supplies the vector column.
 * @return this operator, with the correlation model rows set as output.
 */
@Override
public VectorCorrelationBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    String vectorColName = getSelectedCol();
    Method corrType = getMethod();
    if (Method.PEARSON == corrType) {
        DataSet<Tuple2<BaseVectorSummary, CorrelationResult>> srt = StatisticsHelper.vectorPearsonCorrelation(in, vectorColName);
        // Serialize only the correlation result (f1); the summary is a by-product.
        DataSet<Row> result = srt.flatMap(new FlatMapFunction<Tuple2<BaseVectorSummary, CorrelationResult>, Row>() {
            private static final long serialVersionUID = 2134644397476490118L;
            @Override
            public void flatMap(Tuple2<BaseVectorSummary, CorrelationResult> srt, Collector<Row> collector) throws Exception {
                new CorrelationDataConverter().save(srt.f1, collector);
            }
        });
        this.setOutput(result, new CorrelationDataConverter().getModelSchema());
    } else {
        // Spearman: rank-transform the data, then run Pearson correlation on the ranks.
        DataSet<Row> data = StatisticsHelper.transformToColumns(in, null, vectorColName, null);
        DataSet<Row> rank = SpearmanCorrelation.calcRank(data, true);
        // FIX: use the wildcard type instead of the raw BatchOperator to avoid unchecked warnings.
        BatchOperator<?> rankOp = new TableSourceBatchOp(DataSetConversionUtil.toTable(getMLEnvironmentId(), rank, new String[] { "col" }, new TypeInformation[] { Types.STRING })).setMLEnvironmentId(getMLEnvironmentId());
        VectorCorrelationBatchOp corrBatchOp = new VectorCorrelationBatchOp().setMLEnvironmentId(getMLEnvironmentId()).setSelectedCol("col");
        rankOp.link(corrBatchOp);
        this.setOutput(corrBatchOp.getDataSet(), corrBatchOp.getSchema());
    }
    return this;
}
Use of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in project Alink by Alibaba.
Example: the linkFrom method of the VectorSummarizerBatchOp class.
/**
 * Summarize the selected vector column and output the summary as model rows.
 *
 * @param inputs the first input supplies the vector column to summarize.
 * @return this operator, with the serialized summary set as output.
 */
@Override
public VectorSummarizerBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> input = checkAndGetFirst(inputs);
    VectorSummaryDataConverter converter = new VectorSummaryDataConverter();
    // Compute the vector summary, then serialize it into model rows.
    DataSet<BaseVectorSummary> summaries = StatisticsHelper.vectorSummary(input, getSelectedCol());
    DataSet<Row> modelRows = summaries.flatMap(new VectorSummaryBuildModel());
    this.setOutput(modelRows, converter.getModelSchema());
    return this;
}
Aggregations