Search in sources :

Example 1 with MTable

use of com.alibaba.alink.common.MTable in project Alink by alibaba.

the class DeepARModelMapper method predictMultiVar.

/**
 * Forecasts {@code predictNum} future steps for each component of a multivariate history.
 *
 * <p>For every component an input tensor is assembled with one row per history step: the first
 * row carries the value 0 and row {@code t} carries the component's value at step {@code t - 1},
 * each followed by the features produced by {@code DeepARFeaturesGenerator} for that step's
 * timestamp. The value column is divided by a per-component scale (sum of the values divided by
 * the number of non-zero values, plus one) before the model is invoked, and predictions are
 * multiplied by the same scale afterwards. A component whose values are all zero is skipped and
 * keeps whatever {@code new DenseVector(series)} initialises its slots to.
 *
 * @param historyTimes timestamps of the history steps
 * @param historyVals  history vectors, one per step; the size of the last vector is taken as
 *                     the number of components
 * @param predictNum   number of future steps to forecast
 * @return the predicted vectors together with an MTable (single column {@code "sigma"})
 *         rendered via {@code toString()}; {@code Tuple2.of(null, null)} if the model mapper
 *         throws
 * @throws IllegalArgumentException if a history vector converts to {@code null}
 */
@Override
protected Tuple2<Vector[], String> predictMultiVar(Timestamp[] historyTimes, Vector[] historyVals, int predictNum) {
    final Timestamp[] futureTimes = TimeSeriesMapper.getPredictTimes(historyTimes, predictNum);
    final int window = historyVals.length;

    // Densify the history; the size of the last vector seen becomes the component count.
    int series = 0;
    DenseVector[] history = new DenseVector[historyVals.length];
    for (int t = 0; t < window; ++t) {
        DenseVector dense = VectorUtil.getDenseVector(historyVals[t]);
        history[t] = dense;
        if (dense == null) {
            throw new IllegalArgumentException("history values should not be null.");
        }
        series = dense.size();
    }

    // One input row per (component, step): lagged value followed by the time features.
    FloatTensor[][] rows = new FloatTensor[series][window];
    for (int s = 0; s < series; ++s) {
        FloatTensor[] firstParts = {
            new FloatTensor(new float[] { 0.0f }),
            DeepARFeaturesGenerator.generateFromFrequency(calendar.get(), unit, historyTimes[0])
        };
        rows[s][0] = Tensor.cat(firstParts, -1, null);
        for (int t = 1; t < window; ++t) {
            FloatTensor[] parts = {
                new FloatTensor(new float[] { (float) history[t - 1].get(s) }),
                DeepARFeaturesGenerator.generateFromFrequency(calendar.get(), unit, historyTimes[t])
            };
            rows[s][t] = Tensor.cat(parts, -1, null);
        }
    }

    FloatTensor[] batch = new FloatTensor[series];
    for (int s = 0; s < series; ++s) {
        batch[s] = Tensor.stack(rows[s], 0, null);
    }

    // Output holders, one vector of predictions and one vector of sigmas per future step.
    Vector[] result = new Vector[predictNum];
    Row[] sigmas = new Row[predictNum];
    for (int step = 0; step < predictNum; ++step) {
        result[step] = new DenseVector(series);
        sigmas[step] = Row.of(new DenseVector(series));
    }

    for (int s = 0; s < series; ++s) {
        float mu = (float) historyVals[window - 1].get(s);

        // Sum the value column plus the latest observation, counting the non-zero entries.
        float sum = 0.0f;
        int nonZero = 0;
        for (int t = 0; t < window; ++t) {
            float cell = batch[s].getFloat(t, 0);
            if (cell != 0) {
                nonZero += 1;
            }
            sum = sum + cell;
        }
        if (mu != 0) {
            nonZero += 1;
            sum = sum + mu;
        }
        if (nonZero == 0) {
            // All-zero component: nothing to scale, leave its outputs untouched.
            continue;
        }
        final float scale = sum / nonZero + 1.0f;
        final float shift = 0.0f;

        // Normalise the value column and the seed value with the scale.
        for (int t = 0; t < window; ++t) {
            batch[s].setFloat(batch[s].getFloat(t, 0) / scale, t, 0);
        }
        mu = mu / scale;

        // Each step appends the previous (normalised) prediction as a new row and re-runs the model.
        for (int step = 0; step < predictNum; ++step) {
            FloatTensor[] nextParts = {
                new FloatTensor(new float[] { mu }),
                DeepARFeaturesGenerator.generateFromFrequency(calendar.get(), unit, futureTimes[step])
            };
            FloatTensor nextRow = Tensor.stack(new FloatTensor[] { Tensor.cat(nextParts, -1, null) }, 0, null);
            batch[s] = Tensor.cat(new FloatTensor[] { batch[s], nextRow }, 0, null);

            FloatTensor pred;
            try {
                pred = (FloatTensor) tfTableModelPredictModelMapper.map(Row.of(batch[s])).getField(0);
            } catch (Exception e) {
                return Tuple2.of(null, null);
            }

            mu = pred.getFloat(window + step, 0);
            float sigma = pred.getFloat(window + step, 1);
            result[step].set(s, mu * scale + shift);
            ((Vector) (sigmas[step].getField(0))).set(s, sigma * scale);
        }
    }

    MTable sigmaTable = new MTable(Arrays.asList(sigmas), new String[] { "sigma" }, new TypeInformation<?>[] { VectorTypes.DENSE_VECTOR });
    return Tuple2.of(result, sigmaTable.toString());
}
Also used : Timestamp(java.sql.Timestamp) MTable(com.alibaba.alink.common.MTable) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Row(org.apache.flink.types.Row) Vector(com.alibaba.alink.common.linalg.Vector) DenseVector(com.alibaba.alink.common.linalg.DenseVector)

Example 2 with MTable

use of com.alibaba.alink.common.MTable in project Alink by alibaba.

the class DeepARModelMapper method predictSingleVar.

/**
 * Forecasts {@code predictNum} future values of a univariate series.
 *
 * <p>The input tensor has one row per history step: the first row carries the value 0 and row
 * {@code i} carries {@code historyVals[i - 1]}, each followed by the features produced by
 * {@code DeepARFeaturesGenerator} for that step's timestamp. The value column is divided by a
 * scale (sum of the values divided by the number of non-zero values, plus one) before the model
 * is invoked, and predictions are multiplied by the same scale afterwards.
 *
 * @param historyTimes timestamps of the history steps
 * @param historyVals  observed values, one per history step
 * @param predictNum   number of future steps to forecast
 * @return the predicted values together with an MTable (single column {@code "sigma"}) rendered
 *         via {@code toString()}; all-zero predictions and sigmas when every history value is
 *         zero; {@code Tuple2.of(null, null)} if the model mapper throws
 */
@Override
protected Tuple2<double[], String> predictSingleVar(Timestamp[] historyTimes, double[] historyVals, int predictNum) {
    Timestamp[] predictTimes = TimeSeriesMapper.getPredictTimes(historyTimes, predictNum);
    int window = historyVals.length;
    FloatTensor[] tensors = new FloatTensor[window];
    // fill the first z with zero
    tensors[0] = Tensor.cat(new FloatTensor[] { new FloatTensor(new float[] { 0.0f }), DeepARFeaturesGenerator.generateFromFrequency(calendar.get(), unit, historyTimes[0]) }, -1, null);
    // every other row carries the previous observation
    for (int i = 1; i < window; ++i) {
        tensors[i] = Tensor.cat(new FloatTensor[] { new FloatTensor(new float[] { (float) historyVals[i - 1] }), DeepARFeaturesGenerator.generateFromFrequency(calendar.get(), unit, historyTimes[i]) }, -1, null);
    }
    FloatTensor batch = Tensor.stack(tensors, 0, null);
    // initialize mu with the latest observation
    float mu = (float) historyVals[window - 1];
    // calculate v: slot 0 becomes the scale, slot 1 stays 0 and is added back as an offset
    FloatTensor v = new FloatTensor(new float[] { 0.0f, 0.0f });
    int nonZero = 0;
    for (int i = 0; i < window; ++i) {
        float cell = batch.getFloat(i, 0);
        if (cell != 0) {
            nonZero += 1;
        }
        v.setFloat(v.getFloat(0) + cell, 0);
    }
    if (mu != 0) {
        nonZero += 1;
        v.setFloat(v.getFloat(0) + mu, 0);
    }
    if (nonZero == 0) {
        // All-zero history: answer with zeros without calling the model. A new double[] is
        // already zero-filled. Each step gets its own Row holding a Double so the rows are not
        // aliased and the value matches the DOUBLE column type, exactly as on the normal path.
        double[] result = new double[predictNum];
        Row[] sigmas = new Row[predictNum];
        for (int i = 0; i < predictNum; ++i) {
            sigmas[i] = Row.of(0.0);
        }
        return Tuple2.of(result, new MTable(Arrays.asList(sigmas), new String[] { "sigma" }, new TypeInformation<?>[] { Types.DOUBLE }).toString());
    }
    v.setFloat(v.getFloat(0) / nonZero + 1.0f, 0);
    // normalize with v
    for (int i = 0; i < window; ++i) {
        batch.setFloat(batch.getFloat(i, 0) / v.getFloat(0), i, 0);
    }
    mu = mu / v.getFloat(0);
    // result initialize (a new double[] is already zero-filled).
    double[] result = new double[predictNum];
    Row[] sigmas = new Row[predictNum];
    for (int i = 0; i < predictNum; ++i) {
        sigmas[i] = Row.of(0.0);
    }
    // prediction: each step appends the previous (normalized) prediction as a new row and re-runs the model
    for (int j = 0; j < predictNum; ++j) {
        batch = Tensor.cat(new FloatTensor[] { batch, Tensor.stack(new FloatTensor[] { Tensor.cat(new FloatTensor[] { new FloatTensor(new float[] { mu }), DeepARFeaturesGenerator.generateFromFrequency(calendar.get(), unit, predictTimes[j]) }, -1, null) }, 0, null) }, 0, null);
        FloatTensor pred;
        try {
            pred = (FloatTensor) tfTableModelPredictModelMapper.map(Row.of(batch)).getField(0);
        } catch (Exception e) {
            return Tuple2.of(null, null);
        }
        mu = pred.getFloat(window + j, 0);
        float sigma = pred.getFloat(window + j, 1);
        result[j] = mu * v.getFloat(0) + v.getFloat(1);
        // NOTE(review): the product is a float, so a boxed Float appears to be stored in a column
        // declared DOUBLE; left unchanged because widening would alter the serialized digits — confirm.
        sigmas[j].setField(0, sigma * v.getFloat(0));
    }
    return Tuple2.of(result, new MTable(Arrays.asList(sigmas), new String[] { "sigma" }, new TypeInformation<?>[] { Types.DOUBLE }).toString());
}
Also used : MTable(com.alibaba.alink.common.MTable) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Row(org.apache.flink.types.Row) Timestamp(java.sql.Timestamp)

Example 3 with MTable

use of com.alibaba.alink.common.MTable in project Alink by alibaba.

the class LookupValueInTimeSeriesMapper method map.

/**
 * Looks up the value of a time series at a given timestamp.
 *
 * <p>{@code selection.get(0)} is the lookup timestamp and {@code selection.get(1)} is the series,
 * either an {@link MTable} or its string form. The last column typed as SQL timestamp is used as
 * the time axis and the first column reported by {@code TableUtil.getNumericCols} as the value.
 * An exact timestamp match returns that row's value; otherwise the value is linearly
 * interpolated between the two neighbouring rows. The result is {@code null} when the series is
 * null or empty, has no numeric column, or the lookup time lies before the first or after the
 * last timestamp.
 *
 * @param selection input fields: lookup time at index 0, series at index 1
 * @param result    receives the looked-up value (a Double, or {@code null}) at index 0
 * @throws RuntimeException if the series has no timestamp column
 */
@Override
protected void map(SlicedSelectedSample selection, SlicedResult result) throws Exception {
    MTable mTable = null;
    if (selection.get(1) == null) {
        result.set(0, null);
        return;
    }
    if (selection.get(1) instanceof MTable) {
        mTable = (MTable) selection.get(1);
    } else {
        mTable = new MTable((String) selection.get(1));
    }
    if (mTable.getNumRow() == 0) {
        result.set(0, null);
        return;
    }
    Timestamp lookupTime = (Timestamp) selection.get(0);
    TableSchema schema = mTable.getTableSchema();
    String timeCol = null;
    int timeIdx = -1;
    TypeInformation<?>[] colTypes = schema.getFieldTypes();
    for (int i = 0; i < colTypes.length; i++) {
        // equals() rather than == so a type-info instance that is equal but not identical still matches.
        if (Types.SQL_TIMESTAMP.equals(colTypes[i])) {
            timeCol = schema.getFieldNames()[i];
            timeIdx = i;
        }
    }
    if (timeIdx == -1) {
        throw new RuntimeException("can not find time column, lookup failed");
    }
    String[] valueCols = TableUtil.getNumericCols(schema);
    if (valueCols.length >= 1) {
        List<Object> times = MTableUtils.getColumn(mTable, timeCol);
        int idxRow = times.indexOf(lookupTime);
        int idxCol = TableUtil.findColIndex(schema, valueCols[0]);
        if (idxRow >= 0) {
            // Exact hit: no interpolation needed.
            result.set(0, ((Number) mTable.getEntry(idxRow, idxCol)).doubleValue());
            return;
        } else {
            // Binary search needs sorted timestamps; presumably orderBy sorts the table in place — confirm.
            mTable.orderBy(timeIdx);
            Timestamp[] timesArr = MTableUtils.getColumn(mTable, timeCol).toArray(new Timestamp[] {});
            int pos = Arrays.binarySearch(timesArr, lookupTime);
            if (pos == -1 || -1 - pos == timesArr.length) {
                // Lookup time is outside the covered range: report "no value" instead of failing.
                result.set(0, null);
                return;
            } else {
                // pos encodes the insertion point as -(insertionPoint) - 1.
                int pos0 = -2 - pos;
                int pos1 = -1 - pos;
                long time0 = timesArr[pos0].getTime();
                long time1 = timesArr[pos1].getTime();
                double scale = (double) (lookupTime.getTime() - time0) / (double) (time1 - time0);
                // Go through Number, as the exact-hit branch does: the value column may be any
                // numeric type, and casting a non-Double entry straight to double would throw.
                double value0 = ((Number) mTable.getEntry(pos0, idxCol)).doubleValue();
                double value1 = ((Number) mTable.getEntry(pos1, idxCol)).doubleValue();
                double inter = (1 - scale) * value0 + scale * value1;
                result.set(0, inter);
            }
            return;
        }
    }
    result.set(0, null);
}
Also used : TableSchema(org.apache.flink.table.api.TableSchema) Timestamp(java.sql.Timestamp) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) MTable(com.alibaba.alink.common.MTable)

Example 4 with MTable

use of com.alibaba.alink.common.MTable in project Alink by alibaba.

the class LookupVectorInTimeSeriesMapper method map.

/**
 * Looks up the vector of a time series at a given timestamp.
 *
 * <p>{@code selection.get(0)} is the lookup timestamp and {@code selection.get(1)} is the series,
 * either an {@link MTable} or its string form. The last SQL-timestamp column is the time axis and
 * the last vector-typed column supplies the values. An exact timestamp match returns that row's
 * entry as is. A lookup time before the first or after the last timestamp returns the first or
 * last entry respectively. Anything in between is linearly interpolated from the two
 * neighbouring rows. The result is {@code null} when the series is null or empty, or when either
 * column is missing.
 *
 * @param selection input fields: lookup time at index 0, series at index 1
 * @param result    receives the looked-up vector (or {@code null}) at index 0
 */
@Override
protected void map(SlicedSelectedSample selection, SlicedResult result) throws Exception {
    if (selection.get(1) == null) {
        result.set(0, null);
        return;
    }
    MTable series;
    if (selection.get(1) instanceof MTable) {
        series = (MTable) selection.get(1);
    } else {
        series = new MTable((String) selection.get(1));
    }
    if (series.getNumRow() == 0) {
        result.set(0, null);
        return;
    }

    Timestamp lookupTime = (Timestamp) selection.get(0);
    TableSchema schema = series.getTableSchema();
    TypeInformation<?>[] colTypes = schema.getFieldTypes();

    // Scan every column; when several qualify, the last one wins.
    String timeCol = null;
    String vectorCol = null;
    int timeIdx = -1;
    for (int c = 0; c < colTypes.length; c++) {
        TypeInformation<?> type = colTypes[c];
        if (type == Types.SQL_TIMESTAMP) {
            timeCol = schema.getFieldNames()[c];
            timeIdx = c;
        }
        boolean isVector = type == VectorTypes.VECTOR
            || type == VectorTypes.DENSE_VECTOR
            || type == VectorTypes.SPARSE_VECTOR;
        if (isVector) {
            vectorCol = schema.getFieldNames()[c];
        }
    }
    if (timeCol == null || vectorCol == null) {
        result.set(0, null);
        return;
    }

    List<Object> times = MTableUtils.getColumn(series, timeCol);
    int exactRow = times.indexOf(lookupTime);
    int vecIdx = TableUtil.findColIndex(schema, vectorCol);
    if (exactRow >= 0) {
        result.set(0, series.getEntry(exactRow, vecIdx));
        return;
    }

    // No exact hit: order by time and locate the insertion point.
    series.orderBy(timeIdx);
    Timestamp[] sortedTimes = MTableUtils.getColumn(series, timeCol).toArray(new Timestamp[] {});
    int pos = Arrays.binarySearch(sortedTimes, lookupTime);
    if (pos == -1) {
        // Before the first timestamp: use the first entry.
        result.set(0, series.getEntry(0, vecIdx));
        return;
    }
    if (-pos == sortedTimes.length + 1) {
        // After the last timestamp: use the last entry.
        result.set(0, series.getEntry(sortedTimes.length - 1, vecIdx));
        return;
    }

    // pos encodes the insertion point as -(insertionPoint) - 1.
    int before = -2 - pos;
    int after = -1 - pos;
    long timeBefore = sortedTimes[before].getTime();
    long timeAfter = sortedTimes[after].getTime();
    double weight = (double) (lookupTime.getTime() - timeBefore) / (double) (timeAfter - timeBefore);
    // NOTE(review): entries are cast to DenseVector although VECTOR and SPARSE_VECTOR columns are
    // accepted above; a non-dense entry would presumably fail this cast — confirm.
    DenseVector blended = ((DenseVector) series.getEntry(before, vecIdx)).scale(1 - weight);
    blended.plusEqual(((DenseVector) series.getEntry(after, vecIdx)).scale(weight));
    result.set(0, blended);
}
Also used : MTable(com.alibaba.alink.common.MTable) TableSchema(org.apache.flink.table.api.TableSchema) Timestamp(java.sql.Timestamp) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) DenseVector(com.alibaba.alink.common.linalg.DenseVector)

Example 5 with MTable

use of com.alibaba.alink.common.MTable in project Alink by alibaba.

the class FlattenMTableStreamTest method linkFrom.

/**
 * Feeds one row holding a serialized MTable through {@code FlattenMTableStreamOp} and checks
 * that every collected row starts with the reserved {@code "id"} value.
 */
@Test
public void linkFrom() throws Exception {
    List<Row> rows = new ArrayList<>();
    rows.add(Row.of(1, "2", 0, null, new SparseVector(3, new int[] { 1 }, new double[] { 2.0 }), new FloatTensor(new float[] { 3.0f })));
    rows.add(Row.of(null, "2", 0, new DenseVector(new double[] { 0.0, 1.0 }), new SparseVector(4, new int[] { 2 }, new double[] { 3.0 }), new FloatTensor(new float[] { 3.0f })));
    rows.add(Row.of(null, "2", 0, new DenseVector(new double[] { 0.1, 1.0 }), new SparseVector(4, new int[] { 2 }, new double[] { 3.0 }), new FloatTensor(new float[] { 3.0f })));
    String schemaStr = "col0 int, col1 string, label int" + ", d_vec DENSE_VECTOR" + ", s_vec SPARSE_VECTOR" + ", tensor FLOAT_TENSOR";
    MTable mTable = new MTable(rows, schemaStr);
    // The source table has a single row: the literal "id" plus the MTable in its string form.
    List<Row> table = new ArrayList<>();
    table.add(Row.of("id", mTable.toString()));
    StreamOperator<?> op = new MemSourceStreamOp(table, new String[] { "id", "mTable" });
    StreamOperator<?> res = op.link(new FlattenMTableStreamOp().setSchemaStr(schemaStr).setSelectedCol("mTable").setReservedCols("id"));
    CollectSinkStreamOp sop = res.link(new CollectSinkStreamOp());
    StreamOperator.execute();
    List<Row> list = sop.getAndRemoveValues();
    // Without this check the loop below would pass vacuously if nothing was collected.
    Assert.assertFalse(list.isEmpty());
    for (Row row : list) {
        // JUnit's argument order is (expected, actual); keeps failure messages the right way round.
        Assert.assertEquals("id", row.getField(0));
    }
}
Also used : MemSourceStreamOp(com.alibaba.alink.operator.stream.source.MemSourceStreamOp) MTable(com.alibaba.alink.common.MTable) CollectSinkStreamOp(com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp) ArrayList(java.util.ArrayList) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Row(org.apache.flink.types.Row) SparseVector(com.alibaba.alink.common.linalg.SparseVector) DenseVector(com.alibaba.alink.common.linalg.DenseVector) Test(org.junit.Test)

Aggregations

MTable (com.alibaba.alink.common.MTable) 26, Row (org.apache.flink.types.Row) 19, Test (org.junit.Test) 12, TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation) 9, Timestamp (java.sql.Timestamp) 8, ArrayList (java.util.ArrayList) 8, TableSchema (org.apache.flink.table.api.TableSchema) 8, SparseVector (com.alibaba.alink.common.linalg.SparseVector) 7, MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp) 6, FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor) 5, HashMap (java.util.HashMap) 5, Params (org.apache.flink.ml.api.misc.param.Params) 5, DenseVector (com.alibaba.alink.common.linalg.DenseVector) 4, List (java.util.List) 4, FlinkTypeConverter (com.alibaba.alink.operator.common.io.types.FlinkTypeConverter) 3, CollectSinkStreamOp (com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp) 3, MemSourceStreamOp (com.alibaba.alink.operator.stream.source.MemSourceStreamOp) 3, BaseItemsPerUserRecommParams (com.alibaba.alink.params.recommendation.BaseItemsPerUserRecommParams) 3, BaseSimilarItemsRecommParams (com.alibaba.alink.params.recommendation.BaseSimilarItemsRecommParams) 3, Serializable (java.io.Serializable) 3