Usage of com.alibaba.alink.common.linalg.SparseVector in project Alink by Alibaba.
Class OneHotModelMapperTest, method testNullModel:
/**
 * Maps {@code defaultRow} after loading the {@code nullModel} fixture and expects each of the three
 * selected columns ("cnt", "word", "docid") to become a size-2 sparse vector with 1.0 at index 1.
 */
@Test
public void testNullModel() throws Exception {
    Params params = new Params()
        .set(OneHotPredictParams.ENCODE, HasEncodeWithoutWoe.Encode.VECTOR)
        .set(OneHotPredictParams.SELECTED_COLS, new String[] {"cnt", "word", "docid"});
    OneHotModelMapper mapper = new OneHotModelMapper(modelSchema, dataSchema, params);
    mapper.loadModel(nullModel);

    Row expected = Row.of(
        new SparseVector(2, new int[] {1}, new double[] {1.0}),
        new SparseVector(2, new int[] {1}, new double[] {1.0}),
        new SparseVector(2, new int[] {1}, new double[] {1.0}));
    // JUnit's contract is assertEquals(expected, actual); honoring it keeps failure messages correct.
    Assert.assertEquals(expected, mapper.map(defaultRow));
}
Usage of com.alibaba.alink.common.linalg.SparseVector in project Alink by Alibaba.
Class OneHotModelMapperTest, method testHandleInvalidError:
/**
 * Exercises {@code HandleInvalid.ERROR}.
 *
 * <p>With {@code model} loaded, {@code nullElseRow} maps to the expected one-hot vectors. After
 * switching to {@code newModel}, mapping the same row must throw with the message
 * "Unseen token: 梅".
 */
@Test
public void testHandleInvalidError() throws Exception {
    Params params = new Params()
        .set(OneHotPredictParams.ENCODE, HasEncodeWithoutWoe.Encode.VECTOR)
        .set(OneHotPredictParams.HANDLE_INVALID, HasHandleInvalid.HandleInvalid.ERROR)
        .set(OneHotPredictParams.SELECTED_COLS, new String[] {"cnt", "word"})
        .set(OneHotPredictParams.DROP_LAST, false);
    OneHotModelMapper mapper = new OneHotModelMapper(modelSchema, dataSchema, params);
    mapper.loadModel(model);
    assertEquals(
        Row.of(null,
            new SparseVector(8, new int[] {7}, new double[] {1.0}),
            new SparseVector(6, new int[] {2}, new double[] {1.0})),
        mapper.map(nullElseRow));

    mapper.loadModel(newModel);
    try {
        mapper.map(nullElseRow);
        // Without this the test would also pass when no exception is thrown at all.
        // Assert.fail raises an AssertionError (an Error), so the catch below does not swallow it.
        Assert.fail("Expected an 'Unseen token' exception under HandleInvalid.ERROR");
    } catch (Exception e) {
        assertEquals("Unseen token: 梅", e.getMessage());
    }
}
Usage of com.alibaba.alink.common.linalg.SparseVector in project Alink by Alibaba.
Class QuantileDiscretizerModelMapperTest, method testVector:
/**
 * Vector-encodes "col2" and "col3" with DROP_LAST disabled and checks the mapped output for both
 * the {@code defaultRow} and {@code nullElseRow} fixtures (each encoded column is size 4).
 */
@Test
public void testVector() throws Exception {
    Params params = new Params()
        .set(QuantileDiscretizerPredictParams.ENCODE, HasEncodeWithoutWoe.Encode.VECTOR)
        .set(QuantileDiscretizerPredictParams.SELECTED_COLS, new String[] {"col2", "col3"})
        .set(QuantileDiscretizerPredictParams.DROP_LAST, false);
    QuantileDiscretizerModelMapper mapper = new QuantileDiscretizerModelMapper(modelSchema, dataSchema, params);
    mapper.loadModel(model);

    // assertEquals(expected, actual): expected first so failure messages label the rows correctly.
    assertEquals(
        Row.of("a",
            new SparseVector(4, new int[] {0}, new double[] {1.0}),
            new SparseVector(4, new int[] {2}, new double[] {1.0})),
        mapper.map(defaultRow));
    assertEquals(
        Row.of("b",
            new SparseVector(4, new int[] {3}, new double[] {1.0}),
            new SparseVector(4, new int[] {1}, new double[] {1.0})),
        mapper.map(nullElseRow));
}
Usage of com.alibaba.alink.common.linalg.SparseVector in project Alink by Alibaba.
Class SosBatchOp, method linkFrom:
/**
 * Scores every input row with {@link SOSImpl#outlierSelection} and appends the result as a trailing
 * DOUBLE column named by {@code getPredictionCol()}.
 *
 * <p>Each row is first tagged with an index from {@code DataSetUtils.zipWithIndex}. The vector
 * column is converted to a {@code DenseVector} (null vectors are passed through as null). The
 * per-index scores are then joined back onto the original rows.
 *
 * @param inputs the operator inputs; only the first one is used
 * @return this operator, with its output set to the input columns plus the score column
 * @throws ArithmeticException at execution time if a row index does not fit in an {@code int}
 */
@Override
public SosBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    final String vectorColName = getVectorCol();
    final String predResultColName = getPredictionCol();
    final int vectorColIdx = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), vectorColName);

    DataSet<Tuple2<Integer, Row>> pointsWithIndex = DataSetUtils.zipWithIndex(in.getDataSet())
        .map(new MapFunction<Tuple2<Long, Row>, Tuple2<Integer, Row>>() {
            private static final long serialVersionUID = -2866816204744880078L;

            @Override
            public Tuple2<Integer, Row> map(Tuple2<Long, Row> indexedRow) throws Exception {
                // SOSImpl works with int ids. Truncating with intValue() would silently wrap large
                // indices, so distinct rows would share an id in the join below. Fail loudly instead.
                return new Tuple2<>(Math.toIntExact(indexedRow.f0), indexedRow.f1);
            }
        });

    DataSet<Tuple2<Integer, DenseVector>> sosInput = pointsWithIndex
        .map(new MapFunction<Tuple2<Integer, Row>, Tuple2<Integer, DenseVector>>() {
            private static final long serialVersionUID = 5798841020968290284L;

            @Override
            public Tuple2<Integer, DenseVector> map(Tuple2<Integer, Row> idAndRow) throws Exception {
                Vector vec = VectorUtil.getVector(idAndRow.f1.getField(vectorColIdx));
                if (null == vec) {
                    return new Tuple2<>(idAndRow.f0, null);
                }
                // Anything that is not dense is assumed to be a SparseVector.
                DenseVector dense = (vec instanceof DenseVector)
                    ? (DenseVector) vec
                    : ((SparseVector) vec).toDenseVector();
                return new Tuple2<>(idAndRow.f0, dense);
            }
        });

    SOSImpl sos = new SOSImpl(this.getParams());
    DataSet<Tuple2<Integer, Double>> outlierProb = sos.outlierSelection(sosInput);

    DataSet<Row> output = outlierProb.join(pointsWithIndex).where(0).equalTo(0)
        .with(new JoinFunction<Tuple2<Integer, Double>, Tuple2<Integer, Row>, Row>() {
            private static final long serialVersionUID = 7086848937713200592L;

            @Override
            public Row join(Tuple2<Integer, Double> score, Tuple2<Integer, Row> source) throws Exception {
                Row original = source.f1;
                int arity = original.getArity();
                // Copy all original fields, then append the score as the last field.
                Row row = new Row(arity + 1);
                for (int i = 0; i < arity; i++) {
                    row.setField(i, original.getField(i));
                }
                row.setField(arity, score.f1);
                return row;
            }
        })
        .returns(new RowTypeInfo(ArrayUtils.add(in.getColTypes(), Types.DOUBLE)));

    this.setOutput(output, ArrayUtils.add(in.getColNames(), predResultColName));
    return this;
}
Usage of com.alibaba.alink.common.linalg.SparseVector in project Alink by Alibaba.
Class OnlineCorpusStep, method onlineCorpusUpdate:
/**
 * Performs one online corpus-update pass over the documents selected by {@code generateOnlineDocs}.
 *
 * <p>Depending on {@code generateOnlineDocs}, documents are visited either in a fixed order or
 * chosen randomly (presumably governed by {@code subSamplingRate}; confirm in that helper). For
 * each visited document this method:
 * <ul>
 *   <li>resizes it to {@code vocabularySize} and drops zero entries, mutating the caller's vector
 *       in place;</li>
 *   <li>draws a fresh gamma initialisation with {@code LdaUtil.geneGamma};</li>
 *   <li>computes its topic distribution;</li>
 *   <li>accumulates the per-word/per-topic statistics and the Dirichlet expectation of the
 *       resulting gamma.</li>
 * </ul>
 *
 * @param data            documents; each selected entry is cast to {@code SparseVector}, so a
 *                        non-sparse vector triggers a {@code ClassCastException}
 * @param lambda          matrix whose Dirichlet expectation (exponentiated, transposed) is used
 *                        as {@code expELogBeta}
 * @param alpha           passed through to {@code LdaUtil.getTopicDistributionMethod}
 * @param gammad          never read: each document gets a freshly generated gamma instead
 * @param vocabularySize  size every selected document is resized to; column count of the word-topic stat
 * @param numTopic        number of topics; row count of both returned matrices
 * @param subSamplingRate forwarded to {@code generateOnlineDocs}
 * @param random          random source shared by document selection and gamma generation
 * @param gammaShape      forwarded to {@code LdaUtil.geneGamma}
 * @return (word-topic stat of shape numTopic x vocabularySize, summed gamma expectation of shape
 *         numTopic x 1, total word weight, number of processed documents)
 */
public static Tuple4<DenseMatrix, DenseMatrix, Long, Long> onlineCorpusUpdate(List<Vector> data, DenseMatrix lambda, DenseMatrix alpha, DenseMatrix gammad, int vocabularySize, int numTopic, double subSamplingRate, RandomDataGenerator random, int gammaShape) {
    final DenseMatrix wordTopicStat = DenseMatrix.zeros(numTopic, vocabularySize);
    final DenseMatrix logPhatPart = new DenseMatrix(numTopic, 1);
    final DenseMatrix expELogBeta = LdaUtil.expDirichletExpectation(lambda).transpose();
    long wordTotal = 0L;
    // NOTE(review): despite the returned slot's original name (nonEmptyDocCount), this counts every
    // selected document, including ones left with no entries after removeZeroValues(). Confirm intent.
    long docTotal = 0L;

    final int[] selectedDocs = generateOnlineDocs(data.size(), subSamplingRate, random);
    for (final int docIdx : selectedDocs) {
        final SparseVector doc = (SparseVector) data.get(docIdx);
        doc.setSize(vocabularySize);
        doc.removeZeroValues();

        final double[] weights = doc.getValues();
        final int nnz = doc.numberOfValues();
        for (int p = 0; p < nnz; p++) {
            // long += double narrows after every addition, so fractional weights are truncated.
            wordTotal += weights[p];
        }

        final DenseMatrix gammaInit = LdaUtil.geneGamma(numTopic, gammaShape, random);
        final Tuple2<DenseMatrix, DenseMatrix> dist =
            LdaUtil.getTopicDistributionMethod(doc, expELogBeta, alpha, gammaInit, numTopic);

        final int[] termIds = doc.getIndices();
        final DenseMatrix termTopic = dist.f1;
        for (int t = 0; t < termIds.length; t++) {
            final int termId = termIds[t];
            for (int topic = 0; topic < numTopic; topic++) {
                wordTopicStat.add(topic, termId, termTopic.get(topic, t));
            }
        }

        final DenseMatrix deGamma = LdaUtil.dirichletExpectationVec(dist.f0);
        for (int topic = 0; topic < numTopic; topic++) {
            logPhatPart.add(topic, 0, deGamma.get(topic, 0));
        }
        docTotal++;
    }
    return new Tuple4<>(wordTopicStat, logPhatPart, wordTotal, docTotal);
}
Aggregations