Search in sources :

Example 1 with VectorStandardScalerModel

Use of com.alibaba.alink.pipeline.dataproc.vector.VectorStandardScalerModel in the project Alink by Alibaba.

The method test of the class VectorStandardScalerTest.

/**
 * Runs the vector standard scaler over the dense sample data with each
 * combination of {@code withMean} / {@code withStd} and compares the scaled
 * vectors against hard-coded expected values.
 *
 * <p>The first scenario uses the train/predict batch operators directly; the
 * remaining three go through the {@code VectorStandardScaler} pipeline stage.
 * Every scenario additionally applies the model to the stream source and
 * prints the result.
 *
 * <p>Collected rows are sorted by field 1 before asserting, so the assertions
 * do not depend on the order in which rows are collected. Field 1 is presumably
 * the "id" column added by {@code AppendIdBatchOp} and field 2 the "vec_1"
 * output column -- confirm against the schema of {@code GenerateData}.
 *
 * @throws Exception if any batch collect or the final stream execution fails
 */
@Test
public void test() throws Exception {
    BatchOperator batchData = new TableSourceBatchOp(GenerateData.getDenseBatch()).link(new AppendIdBatchOp().setIdCol("id"));
    StreamOperator streamData = new TableSourceStreamOp(GenerateData.getDenseStream());

    // Scenario 1: withMean = true, withStd = true, using the batch train/predict operators.
    VectorStandardScalerTrainBatchOp op = new VectorStandardScalerTrainBatchOp().setWithMean(true).setWithStd(true).setSelectedCol("vec").linkFrom(batchData);
    BatchOperator res = new VectorStandardScalerPredictBatchOp().setOutputCol("vec_1").linkFrom(op, batchData);
    List<Row> list = res.collect();
    sortByLongField(list, 1);
    assertDv(VectorUtil.getDenseVector(list.get(0).getField(2)), new DenseVector(new double[] { -0.1325, 0.5774 }));
    assertDv(VectorUtil.getDenseVector(list.get(1).getField(2)), new DenseVector(new double[] { -0.9272, -1.1547 }));
    assertDv(VectorUtil.getDenseVector(list.get(2).getField(2)), new DenseVector(new double[] { 1.0596, 0.5774 }));
    new VectorStandardScalerPredictStreamOp(op).setOutputCol("vec_1").linkFrom(streamData).print();

    // Scenario 2: withMean = true, withStd = false, using the pipeline estimator.
    VectorStandardScalerModel model1 = new VectorStandardScaler().setWithMean(true).setWithStd(false).setSelectedCol("vec").setOutputCol("vec_1").fit(batchData);
    list = model1.transform(batchData).collect();
    sortByLongField(list, 1);
    assertDv(VectorUtil.getDenseVector(list.get(0).getField(2)), new DenseVector(new double[] { -0.3333, 1.6666 }));
    assertDv(VectorUtil.getDenseVector(list.get(1).getField(2)), new DenseVector(new double[] { -2.3333, -3.3333 }));
    assertDv(VectorUtil.getDenseVector(list.get(2).getField(2)), new DenseVector(new double[] { 2.6666, 1.6666 }));
    model1.transform(streamData).print();

    // Scenario 3: withMean = false, withStd = true.
    VectorStandardScalerModel model2 = new VectorStandardScaler().setWithMean(false).setWithStd(true).setSelectedCol("vec").setOutputCol("vec_1").fit(batchData);
    list = model2.transform(batchData).collect();
    sortByLongField(list, 1);
    assertDv(VectorUtil.getDenseVector(list.get(0).getField(2)), new DenseVector(new double[] { 0.3974, 0.6928 }));
    assertDv(VectorUtil.getDenseVector(list.get(1).getField(2)), new DenseVector(new double[] { -0.3974, -1.0392 }));
    assertDv(VectorUtil.getDenseVector(list.get(2).getField(2)), new DenseVector(new double[] { 1.5894, 0.6928 }));
    model2.transform(streamData).print();

    // Scenario 4: withMean = false, withStd = false.
    VectorStandardScalerModel model3 = new VectorStandardScaler().setWithMean(false).setWithStd(false).setSelectedCol("vec").setOutputCol("vec_1").fit(batchData);
    list = model3.transform(batchData).collect();
    sortByLongField(list, 1);
    assertDv(VectorUtil.getDenseVector(list.get(0).getField(2)), new DenseVector(new double[] { 1., 2. }));
    assertDv(VectorUtil.getDenseVector(list.get(1).getField(2)), new DenseVector(new double[] { -1., -3. }));
    assertDv(VectorUtil.getDenseVector(list.get(2).getField(2)), new DenseVector(new double[] { 4., 2. }));
    model3.transform(streamData).print();

    // NOTE(review): presumably this runs the stream jobs set up by the print() calls above -- confirm.
    StreamOperator.execute();
}

/**
 * Sorts the given rows in place, ascending by the {@code long} value held in
 * one of their fields.
 *
 * @param rows       rows to sort; the list is modified
 * @param fieldIndex index of the field used as the sort key; its value is cast to {@code long}
 */
private static void sortByLongField(List<Row> rows, final int fieldIndex) {
    Collections.sort(rows, new Comparator<Row>() {

        @Override
        public int compare(Row left, Row right) {
            return Long.compare((long) left.getField(fieldIndex), (long) right.getField(fieldIndex));
        }
    });
}
Also used : VectorStandardScalerPredictStreamOp(com.alibaba.alink.operator.stream.dataproc.vector.VectorStandardScalerPredictStreamOp) VectorStandardScalerModel(com.alibaba.alink.pipeline.dataproc.vector.VectorStandardScalerModel) VectorStandardScaler(com.alibaba.alink.pipeline.dataproc.vector.VectorStandardScaler) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) AppendIdBatchOp(com.alibaba.alink.operator.batch.dataproc.AppendIdBatchOp) TableSourceStreamOp(com.alibaba.alink.operator.stream.source.TableSourceStreamOp) Row(org.apache.flink.types.Row) StreamOperator(com.alibaba.alink.operator.stream.StreamOperator) DenseVector(com.alibaba.alink.common.linalg.DenseVector) Test(org.junit.Test)

Aggregations

DenseVector (com.alibaba.alink.common.linalg.DenseVector): 1; BatchOperator (com.alibaba.alink.operator.batch.BatchOperator): 1; AppendIdBatchOp (com.alibaba.alink.operator.batch.dataproc.AppendIdBatchOp): 1; TableSourceBatchOp (com.alibaba.alink.operator.batch.source.TableSourceBatchOp): 1; StreamOperator (com.alibaba.alink.operator.stream.StreamOperator): 1; VectorStandardScalerPredictStreamOp (com.alibaba.alink.operator.stream.dataproc.vector.VectorStandardScalerPredictStreamOp): 1; TableSourceStreamOp (com.alibaba.alink.operator.stream.source.TableSourceStreamOp): 1; VectorStandardScaler (com.alibaba.alink.pipeline.dataproc.vector.VectorStandardScaler): 1; VectorStandardScalerModel (com.alibaba.alink.pipeline.dataproc.vector.VectorStandardScalerModel): 1; Row (org.apache.flink.types.Row): 1; Test (org.junit.Test): 1