use of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in project Alink by alibaba.
the class VectorStandardScalerModelDataConverter method serializeModel.
/**
 * Packs the scaler statistics into the (meta, data, rows) triple returned to the caller.
 *
 * <p>The data part holds exactly two JSON strings, in this order: the mean array and the
 * standard-deviation array. If {@code withMean} is false the mean array is all zeros; if
 * {@code withStd} is false the standard-deviation array is all ones. Both fallback arrays have
 * {@code summary.vectorSize()} entries. The third tuple field is always an empty list.
 *
 * @param modelData Tuple of (withMean, withStd, summary) in fields f0, f1 and f2.
 * @return Tuple of the meta params, the two JSON-encoded arrays, and an empty row collection.
 */
public Tuple3<Params, Iterable<String>, Iterable<Row>> serializeModel(Tuple3<Boolean, Boolean, BaseVectorSummary> modelData) {
    final Boolean centerWithMean = modelData.f0;
    final Boolean scaleWithStd = modelData.f1;
    final BaseVectorSummary stats = modelData.f2;
    final int dim = stats.vectorSize();

    // Centering disabled -> zero offsets; otherwise take the mean as a dense array.
    final double[] meanValues;
    if (!centerWithMean) {
        meanValues = new double[dim];
    } else {
        meanValues = stats.mean() instanceof DenseVector
            ? ((DenseVector) stats.mean()).getData()
            : ((SparseVector) stats.mean()).toDenseVector().getData();
    }

    // Scaling disabled -> unit divisors; otherwise take the standard deviation as a dense array.
    final double[] stdValues;
    if (!scaleWithStd) {
        stdValues = new double[dim];
        Arrays.fill(stdValues, 1);
    } else {
        stdValues = stats.standardDeviation() instanceof DenseVector
            ? ((DenseVector) stats.standardDeviation()).getData()
            : ((SparseVector) stats.standardDeviation()).toDenseVector().getData();
    }

    // Order matters: means first, standard deviations second.
    final List<String> jsonArrays = new ArrayList<>();
    jsonArrays.add(JsonConverter.toJson(meanValues));
    jsonArrays.add(JsonConverter.toJson(stdValues));

    // NOTE(review): SELECTED_COL is read from VectorMinMaxScalerTrainParams while the other two keys
    // come from VectorStandardTrainParams - confirm this is intentional.
    final Params meta = new Params()
        .set(VectorStandardTrainParams.WITH_MEAN, centerWithMean)
        .set(VectorStandardTrainParams.WITH_STD, scaleWithStd)
        .set(VectorMinMaxScalerTrainParams.SELECTED_COL, vectorColName);
    return Tuple3.of(meta, jsonArrays, new ArrayList<>());
}
use of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in project Alink by alibaba.
the class StatisticsHelperTest method summaryHelperVectorWithReservedCols.
/**
 * Runs {@code StatisticsHelper.summaryHelper} on the {@code getDenseBatch()} fixture with vector
 * column "vec" and reserved column "id", then checks the collected summary statistics and the
 * collected (vector, reserved row) pairs.
 *
 * <p>Assertions follow the JUnit {@code (expected, actual)} argument order so that failure
 * messages report the right values.
 */
@Test
public void summaryHelperVectorWithReservedCols() throws Exception {
    // Shared tolerance for all floating-point comparisons; 10e-4 is 1e-3.
    final double tolerance = 10e-4;

    BatchOperator data = getDenseBatch();
    String vectorColName = "vec";
    String[] reservedColNames = new String[] { "id" };
    Tuple2<DataSet<Tuple2<Vector, Row>>, DataSet<BaseVectorSummary>> dataSet = StatisticsHelper.summaryHelper(data, null, vectorColName, reservedColNames);

    BaseVectorSummary summary = dataSet.f1.collect().get(0);
    assertEquals(3, summary.vectorSize());
    assertEquals(4, summary.count());
    assertEquals(4.0, summary.max(2), tolerance);
    assertEquals(0.0, summary.min(1), tolerance);
    assertEquals(1.25, summary.mean(2), tolerance);
    assertEquals(8.9167, summary.variance(2), tolerance);
    assertEquals(2.9861, summary.standardDeviation(2), tolerance);
    assertEquals(11.0, summary.normL1(2), tolerance);
    assertEquals(5.7446, summary.normL2(2), tolerance);

    // NOTE(review): the checks below assume collect() returns rows in fixture order - confirm.
    List<Tuple2<Vector, Row>> tuple2s = dataSet.f0.collect();
    assertEquals(4, tuple2s.size());
    assertArrayEquals(new double[] { 1, 1, 2.0 }, ((DenseVector) tuple2s.get(0).f0).getData(), tolerance);
    assertArrayEquals(new double[] { 2, 2, -3.0 }, ((DenseVector) tuple2s.get(1).f0).getData(), tolerance);
    assertArrayEquals(new double[] { 1, 3, 2.0 }, ((DenseVector) tuple2s.get(2).f0).getData(), tolerance);
    assertArrayEquals(new double[] { 0, 0, 4.0 }, ((DenseVector) tuple2s.get(3).f0).getData(), tolerance);

    // The reserved row carries only the "id" column, at field index 0.
    assertEquals(1, tuple2s.get(0).f1.getField(0));
    assertEquals(2, tuple2s.get(1).f1.getField(0));
    assertEquals(3, tuple2s.get(2).f1.getField(0));
    assertEquals(4, tuple2s.get(3).f1.getField(0));
}
use of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in project Alink by alibaba.
the class StatisticsHelperTest method vectorPearsonCorrelation.
/**
 * Runs {@code StatisticsHelper.vectorPearsonCorrelation} on the {@code getDenseBatch()} fixture
 * with vector column "vec", then checks both the returned summary statistics and the flattened
 * 3x3 correlation matrix.
 *
 * <p>Assertions follow the JUnit {@code (expected, actual)} argument order so that failure
 * messages report the right values.
 */
@Test
public void vectorPearsonCorrelation() throws Exception {
    // Shared tolerance for all floating-point comparisons; 10e-4 is 1e-3.
    final double tolerance = 10e-4;

    BatchOperator data = getDenseBatch();
    String vectorColName = "vec";
    DataSet<Tuple2<BaseVectorSummary, CorrelationResult>> dataSet = StatisticsHelper.vectorPearsonCorrelation(data, vectorColName);
    Tuple2<BaseVectorSummary, CorrelationResult> tuple2 = dataSet.collect().get(0);
    BaseVectorSummary summary = tuple2.f0;
    CorrelationResult corr = tuple2.f1;

    assertEquals(3, summary.vectorSize());
    assertEquals(4, summary.count());
    assertEquals(4.0, summary.max(2), tolerance);
    assertEquals(0.0, summary.min(1), tolerance);
    assertEquals(1.25, summary.mean(2), tolerance);
    assertEquals(8.9167, summary.variance(2), tolerance);
    assertEquals(2.9861, summary.standardDeviation(2), tolerance);
    assertEquals(11.0, summary.normL1(2), tolerance);
    assertEquals(5.7446, summary.normL2(2), tolerance);

    // Expected matrix is symmetric with a unit diagonal, flattened to nine entries.
    double[] expectedCorrelation = new double[] { 1.0, 0.6325, -0.9570, 0.6325, 1.0, -0.4756, -0.9570, -0.4756, 1.0 };
    assertArrayEquals(expectedCorrelation, corr.getCorrelationMatrix().getArrayCopy1D(true), tolerance);
}
use of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in project Alink by alibaba.
the class StatisticsHelperTest method summaryHelperTable.
/**
 * Runs {@code StatisticsHelper.summaryHelper} on the {@code getBatchTable()} fixture with the
 * selected columns "f_long", "f_int" and "f_double" (and no vector column), then checks the
 * collected summary statistics and the collected vectors.
 *
 * <p>Assertions follow the JUnit {@code (expected, actual)} argument order so that failure
 * messages report the right values.
 */
@Test
public void summaryHelperTable() throws Exception {
    // Shared tolerance for all floating-point comparisons; 10e-4 is 1e-3.
    final double tolerance = 10e-4;

    BatchOperator data = getBatchTable();
    String[] selectedColNames = new String[] { "f_long", "f_int", "f_double" };
    Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> dataSet = StatisticsHelper.summaryHelper(data, selectedColNames, null);

    BaseVectorSummary summary = dataSet.f1.collect().get(0);
    assertEquals(3, summary.vectorSize());
    assertEquals(4, summary.count());
    assertEquals(4.0, summary.max(2), tolerance);
    assertEquals(0.0, summary.min(1), tolerance);
    assertEquals(1.25, summary.mean(2), tolerance);
    assertEquals(8.9167, summary.variance(2), tolerance);
    assertEquals(2.9861, summary.standardDeviation(2), tolerance);
    assertEquals(11.0, summary.normL1(2), tolerance);
    assertEquals(5.7446, summary.normL2(2), tolerance);

    // NOTE(review): the checks below assume collect() returns rows in fixture order - confirm.
    List<Vector> vectors = dataSet.f0.collect();
    assertEquals(4, vectors.size());
    assertArrayEquals(new double[] { 1, 1, 2.0 }, ((DenseVector) vectors.get(0)).getData(), tolerance);
    assertArrayEquals(new double[] { 2, 2, -3.0 }, ((DenseVector) vectors.get(1)).getData(), tolerance);
    assertArrayEquals(new double[] { 1, 3, 2.0 }, ((DenseVector) vectors.get(2)).getData(), tolerance);
    assertArrayEquals(new double[] { 0, 0, 4.0 }, ((DenseVector) vectors.get(3)).getData(), tolerance);
}
use of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in project Alink by alibaba.
the class StatisticsHelperTest method vectorSummary.
/**
 * Runs {@code StatisticsHelper.vectorSummary} on the {@code getDenseBatch()} fixture with vector
 * column "vec" and checks the collected summary statistics.
 *
 * <p>Assertions follow the JUnit {@code (expected, actual)} argument order so that failure
 * messages report the right values.
 */
@Test
public void vectorSummary() throws Exception {
    // Shared tolerance for all floating-point comparisons; 10e-4 is 1e-3.
    final double tolerance = 10e-4;

    BatchOperator data = getDenseBatch();
    String vectorColName = "vec";
    DataSet<BaseVectorSummary> dataSet = StatisticsHelper.vectorSummary(data, vectorColName);
    BaseVectorSummary summary = dataSet.collect().get(0);

    assertEquals(3, summary.vectorSize());
    assertEquals(4, summary.count());
    assertEquals(4.0, summary.max(2), tolerance);
    assertEquals(0.0, summary.min(1), tolerance);
    assertEquals(1.25, summary.mean(2), tolerance);
    assertEquals(8.9167, summary.variance(2), tolerance);
    assertEquals(2.9861, summary.standardDeviation(2), tolerance);
    assertEquals(11.0, summary.normL1(2), tolerance);
    assertEquals(5.7446, summary.normL2(2), tolerance);
}
Aggregations