Use of com.alibaba.alink.operator.common.nlp.DocCountVectorizerModelData in project Alink by alibaba.
From the class DocCountVectorizerTrainBatchOp, the method generateDocCountModel: it ranks words by document frequency and keeps the top VOCAB_SIZE entries as the model vocabulary.
public static DataSet<DocCountVectorizerModelData> generateDocCountModel(Params params, BatchOperator<?> in) {
    // Split each document into words and count word occurrences per document.
    BatchOperator<?> docWordCnt = in.udtf(
        params.get(SELECTED_COL),
        new String[] {WORD_COL_NAME, DOC_WORD_COUNT_COL_NAME},
        new DocWordSplitCount(NLPConstant.WORD_DELIMITER),
        new String[] {});
    // Total number of documents, broadcast to the IDF computation below.
    BatchOperator<?> docCnt = in.select("COUNT(1) AS " + DOC_COUNT_COL_NAME);
    // Per-word document frequency and IDF, filtered by the minDF/maxDF thresholds.
    DataSet<Row> sortInput = docWordCnt
        .select(new String[] {WORD_COL_NAME, DOC_WORD_COUNT_COL_NAME})
        .getDataSet()
        .groupBy(0)
        .reduceGroup(new CalcIdf(params.get(MAX_DF), params.get(MIN_DF)))
        .withBroadcastSet(docCnt.getDataSet(), "docCnt");
    // Distributed sort on the count column to rank words globally.
    Tuple2<DataSet<Tuple2<Integer, Row>>, DataSet<Tuple2<Integer, Long>>> partitioned =
        SortUtils.pSort(sortInput, 1);
    DataSet<Tuple2<Long, Row>> ordered = localSort(partitioned.f0, partitioned.f1, 1);
    int vocabSize = params.get(VOCAB_SIZE);
    // Keep only the top vocabSize words, then assemble the model on a single task.
    DataSet<DocCountVectorizerModelData> resDocCountModel = ordered
        .flatMap(new FlatMapFunction<Tuple2<Long, Row>, Tuple2<String, Double>>() {
            private static final long serialVersionUID = -1668412648425550909L;
            @Override
            public void flatMap(Tuple2<Long, Row> value, Collector<Tuple2<String, Double>> out) throws Exception {
                if (value.f0 < vocabSize) {
                    out.collect(Tuple2.of(value.f1.getField(0).toString(),
                        ((Number) value.f1.getField(2)).doubleValue()));
                }
            }
        })
        // Route every record to partition 0 so the model is built in one place.
        .partitionCustom(new Partitioner<String>() {
            private static final long serialVersionUID = 5129015018479212319L;
            @Override
            public int partition(String key, int numPartitions) {
                return 0;
            }
        }, 0)
        .sortPartition(0, Order.DESCENDING)
        .mapPartition(new BuildDocCountModel(params))
        .setParallelism(1);
    return resDocCountModel;
}
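A minimal sketch of how this method might be driven. The input rows, the column name, and the vocabulary size are illustrative assumptions; SELECTED_COL and VOCAB_SIZE are the statically imported ParamInfo constants the method reads, and any parameters not set here are assumed to carry defaults.

// Hypothetical driver (all values are assumptions for illustration).
BatchOperator<?> docs = new MemSourceBatchOp(
    new Object[][] {{"hello world"}, {"hello alink"}},
    new String[] {"content"});                      // assumed single text column

Params params = new Params()
    .set(SELECTED_COL, "content")                   // same ParamInfo the method reads
    .set(VOCAB_SIZE, 100);                          // assumed vocabulary cap

DataSet<DocCountVectorizerModelData> model =
    DocCountVectorizerTrainBatchOp.generateDocCountModel(params, docs);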
Use of com.alibaba.alink.operator.common.nlp.DocCountVectorizerModelData in project Alink by alibaba.
From the class LdaTrainBatchOp, the method linkFrom: it vectorizes the input documents with the model above, then trains LDA with either the EM (Gibbs sampling) or the Online optimizer.
@Override
public LdaTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    int parallelism = BatchOperator.getExecutionEnvironmentFromOps(in).getParallelism();
    long mlEnvId = getMLEnvironmentId();
    int numTopic = getTopicNum();
    int numIter = getNumIter();
    Integer seed = getRandomSeed();
    boolean setSeed = (seed != null);
    String vectorColName = getSelectedCol();
    Method optimizer = getMethod();
    // Build the vocabulary model, then turn each document into a sparse count vector.
    final DataSet<DocCountVectorizerModelData> resDocCountModel =
        DocCountVectorizerTrainBatchOp.generateDocCountModel(getParams(), in);
    int index = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), vectorColName);
    DataSet<Row> resRow = in.getDataSet()
        .flatMap(new Document2Vector(index))
        .withBroadcastSet(resDocCountModel, "DocCountModel");
    TypeInformation<?>[] types = in.getColTypes().clone();
    types[index] = TypeInformation.of(SparseVector.class);
    BatchOperator<?> trainData = new TableSourceBatchOp(
        DataSetConversionUtil.toTable(mlEnvId, resRow, in.getColNames(), types))
        .setMLEnvironmentId(mlEnvId);
    // Column summary statistics are needed by both optimizers.
    Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> dataAndStat =
        StatisticsHelper.summaryHelper(trainData, null, vectorColName);
    if (setSeed) {
        // With a fixed seed, impose a deterministic order on the vectors:
        // hash each vector, partition by hash, then sort within each partition.
        DataSet<Tuple2<Long, Vector>> hashValue = dataAndStat.f0
            .map(new MapHashValue(seed))
            .partitionCustom(new Partitioner<Long>() {
                private static final long serialVersionUID = 5179898093029365608L;
                @Override
                public int partition(Long key, int numPartitions) {
                    return (int) (Math.abs(key) % ((long) numPartitions));
                }
            }, 0);
        dataAndStat.f0 = hashValue.mapPartition(new MapPartitionFunction<Tuple2<Long, Vector>, Vector>() {
            private static final long serialVersionUID = -550512476573928350L;
            @Override
            public void mapPartition(Iterable<Tuple2<Long, Vector>> values, Collector<Vector> out) throws Exception {
                // Sort by hash, breaking ties on the vector's string form,
                // so runs with the same seed see the same data order.
                List<Tuple2<Long, Vector>> listValues = Lists.newArrayList(values);
                listValues.sort((o1, o2) -> {
                    int compare1 = o1.f0.compareTo(o2.f0);
                    if (compare1 == 0) {
                        return o1.f1.toString().compareTo(o2.f1.toString());
                    }
                    return compare1;
                });
                listValues.forEach(x -> out.collect(x.f1));
            }
        }).setParallelism(parallelism);
    }
    double beta = getParams().get(BETA);
    double alpha = getParams().get(ALPHA);
    int gammaShape = 250; // gamma shape parameter, used only by the online optimizer
    // Dispatch to the chosen optimizer: Gibbs sampling (EM) or online variational Bayes.
    switch (optimizer) {
        case EM:
            gibbsSample(dataAndStat, numTopic, numIter, alpha, beta, resDocCountModel, seed);
            break;
        case Online:
            online(dataAndStat, numTopic, numIter, alpha, beta, resDocCountModel, gammaShape, seed);
            break;
        default:
            throw new NotImplementedException("Optimizer not supported.");
    }
    return this;
}
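For completeness, a hedged sketch of driving linkFrom through the operator's fluent API. The setter names are assumed to mirror the getters used above (getSelectedCol, getTopicNum, getNumIter, getMethod, getRandomSeed), and all values are illustrative.

// Hypothetical usage (setter names assumed from the getters in linkFrom).
LdaTrainBatchOp lda = new LdaTrainBatchOp()
    .setSelectedCol("content")   // text column to vectorize
    .setTopicNum(10)             // assumed topic count
    .setNumIter(50)              // assumed iteration count
    .setMethod("em")             // "em" (Gibbs) or "online", per the switch above
    .setRandomSeed(2022);        // non-null seed triggers the deterministic-order branch
lda.linkFrom(docs);              // docs: a BatchOperator holding the text column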