Use of com.alibaba.alink.common.mapper.Mapper in the Alink project by Alibaba.
Class TFTableModelClassificationModelMapper, method predictResultDetail.
/**
 * Feeds the selected sample through every entry of {@code mappers}, in iteration order, then
 * extracts the predicted label from the tensor found in the prediction column.
 *
 * @param selection sample whose values are copied into the row handed to the first mapper
 * @return the predicted label paired with the JSON form of {@code predDetail}
 * @throws Exception propagated from the mappers or from the extraction call
 */
@Override
protected Tuple2<Object, String> predictResultDetail(SlicedSelectedSample selection) throws Exception {
    Row current = new Row(selection.length());
    selection.fillRow(current);
    // Each mapper consumes the row produced by the previous one.
    for (Mapper stage : mappers) {
        current = stage.map(current);
    }
    // The field at predColId must be a FloatTensor; any other type fails the cast.
    final FloatTensor prediction = (FloatTensor) current.getField(predColId);
    // predDetail is handed to the extractor and serialized afterwards
    // (presumably the extractor fills it in -- confirm against PredictionExtractUtils).
    final Object label =
        PredictionExtractUtils.extractFromTensor(prediction, sortedLabels, predDetail, isOutputLogits);
    return Tuple2.of(label, JsonConverter.toJson(predDetail));
}
Use of com.alibaba.alink.common.mapper.Mapper in the Alink project by Alibaba.
Class PipelineModel, method collectLocalPredictor.
/**
 * Builds a {@link LocalPredictor} for this pipeline model.
 *
 * <p>Two paths exist:
 * <ul>
 *   <li>When {@code ModelStreamScanParams.MODEL_STREAM_FILE_PATH} is set in {@code params}, the
 *       stages are serialized, the model rows are extended using the pipeline's output schema,
 *       collected, and loaded into a single {@code PipelineModelMapper}.</li>
 *   <li>Otherwise one mapper is created per stage; the output schema of each mapper becomes the
 *       input schema of the next, and the resulting chain backs the predictor.</li>
 * </ul>
 *
 * @param inputSchema schema of the rows the predictor will receive
 * @return a predictor backed either by one pipeline mapper or by the chain of per-stage mappers
 * @throws RuntimeException if there are no stages, or a stage is not {@code LocalPredictable}
 * @throws Exception propagated from model serialization, collection, or mapper creation
 */
@Override
public LocalPredictor collectLocalPredictor(TableSchema inputSchema) throws Exception {
    if (params.get(ModelStreamScanParams.MODEL_STREAM_FILE_PATH) != null) {
        BatchOperator<?> serialized =
            ModelExporterUtils.serializePipelineStages(Arrays.asList(transformers), params);
        TableSchema extendSchema = getOutSchema(this, inputSchema);
        BatchOperator<?> extendedModel = new TableSourceBatchOp(
            DataSetConversionUtil.toTable(
                serialized.getMLEnvironmentId(),
                serialized.getDataSet().map(
                    new PipelineModelMapper.ExtendPipelineModelRow(extendSchema.getFieldNames().length + 1)),
                PipelineModelMapper.getExtendModelSchema(
                    serialized.getSchema(),
                    extendSchema.getFieldNames(),
                    extendSchema.getFieldTypes())));
        List<Row> extendedRows = extendedModel.collect();
        ModelMapper pipelineMapper = new PipelineModelMapper(extendedModel.getSchema(), inputSchema, this.params);
        pipelineMapper.loadModel(extendedRows);
        return new LocalPredictor(pipelineMapper);
    }

    if (transformers == null || transformers.length == 0) {
        throw new RuntimeException("PipelineModel is empty.");
    }

    // First pass: reject stages that cannot predict locally and gather the model data of
    // every ModelBase stage so that all of it can be collected in one call.
    List<BatchOperator<?>> modelOps = new ArrayList<>();
    for (TransformerBase<?> stage : transformers) {
        if (!(stage instanceof LocalPredictable)) {
            throw new RuntimeException(stage.getClass() + " not support local predict.");
        }
        if (stage instanceof ModelBase) {
            modelOps.add(((ModelBase<?>) stage).getModelData());
        }
    }

    List<List<Row>> collectedModelRows;
    if (modelOps.isEmpty()) {
        collectedModelRows = new ArrayList<>();
    } else {
        collectedModelRows = BatchOperator.collect(modelOps.toArray(new BatchOperator<?>[0]));
    }

    // Second pass: create one mapper per stage, threading the schema through the chain.
    // NOTE(review): model rows are looked up by a counter that only advances for MapModel and
    // BaseRecommender stages, while rows were gathered for every ModelBase stage above; this
    // presumably relies on those being the only ModelBase kinds in a pipeline -- confirm.
    List<Mapper> chain = new ArrayList<>();
    TableSchema currentSchema = inputSchema;
    int modelIdx = 0;
    for (TransformerBase<?> stage : transformers) {
        Mapper stageMapper;
        if (stage instanceof MapModel) {
            stageMapper = ModelExporterUtils.createMapperFromStage(
                stage,
                ((MapModel<?>) stage).modelData.getSchema(),
                currentSchema,
                collectedModelRows.get(modelIdx++));
        } else if (stage instanceof BaseRecommender) {
            stageMapper = ModelExporterUtils.createMapperFromStage(
                stage,
                ((BaseRecommender<?>) stage).modelData.getSchema(),
                currentSchema,
                collectedModelRows.get(modelIdx++));
        } else {
            // Stages without model data need neither a model schema nor model rows.
            stageMapper = ModelExporterUtils.createMapperFromStage(stage, null, currentSchema, null);
        }
        chain.add(stageMapper);
        currentSchema = stageMapper.getOutputSchema();
    }
    return new LocalPredictor(chain.toArray(new Mapper[0]));
}
Use of com.alibaba.alink.common.mapper.Mapper in the Alink project by Alibaba.
Class MapStreamOp, method linkFrom.
/**
 * Applies the mapper produced by {@code mapperBuilder} to the data stream of the first input
 * and registers the mapped stream, with the mapper's output schema, as this operator's output.
 *
 * <p>When {@code MapperParams.NUM_THREADS} is at most 1 the rows go through a
 * {@code MapperAdapter} via {@code map}; otherwise they go through a {@code MapperAdapterMT}
 * via {@code flatMap}.
 *
 * @param inputs upstream operators; the first one is obtained through {@code checkAndGetFirst}
 * @return this operator, cast to {@code T}
 * @throws RuntimeException wrapping any exception raised while building or wiring the mapper
 */
@Override
public T linkFrom(StreamOperator<?>... inputs) {
    StreamOperator<?> source = checkAndGetFirst(inputs);
    try {
        Mapper rowMapper = this.mapperBuilder.apply(source.getSchema(), this.getParams());
        boolean singleThreaded = getParams().get(MapperParams.NUM_THREADS) <= 1;
        DataStream<Row> mapped;
        if (singleThreaded) {
            mapped = source.getDataStream().map(new MapperAdapter(rowMapper));
        } else {
            mapped = source.getDataStream().flatMap(
                new MapperAdapterMT(rowMapper, getParams().get(MapperParams.NUM_THREADS)));
        }
        this.setOutput(mapped, rowMapper.getOutputSchema());
        return (T) this;
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    }
}
Use of com.alibaba.alink.common.mapper.Mapper in the Alink project by Alibaba.
Class VectorToTensorMapperTest, method testFloatType.
/**
 * Checks that requesting {@code DataType.FLOAT} makes the mapper report a FLOAT_TENSOR output
 * column and map a six-element vector to the expected 2x3 {@code FloatTensor}.
 */
@Test
public void testFloatType() throws Exception {
    final TableSchema inputSchema = new TableSchema(
        new String[] { "vec" },
        new TypeInformation<?>[] { VectorTypes.DENSE_VECTOR });
    final Params mapperParams = new Params()
        .set(VectorToTensorParams.SELECTED_COL, "vec")
        .set(VectorToTensorParams.TENSOR_SHAPE, new Long[] { 2L, 3L })
        .set(VectorToTensorParams.TENSOR_DATA_TYPE, DataType.FLOAT);
    final Mapper mapper = new VectorToTensorMapper(inputSchema, mapperParams);

    Assert.assertEquals(TensorTypes.FLOAT_TENSOR, mapper.getOutputSchema().getFieldTypes()[0]);

    final DoubleTensor tensor = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final FloatTensor expect = FloatTensor.of(tensor.reshape(new Shape(2L, 3L)));

    final Row mapped = mapper.map(Row.of(tensor.toVector()));
    final Tensor<?> result = (Tensor<?>) mapped.getField(0);
    Assert.assertEquals(expect, result);
}
Use of com.alibaba.alink.common.mapper.Mapper in the Alink project by Alibaba.
Class VectorToTensorMapperTest, method testReshape.
/**
 * Checks that, with only a selected column and a 2x3 tensor shape configured, the mapper turns
 * a six-element vector into a tensor equal to the source tensor reshaped to 2x3.
 */
@Test
public void testReshape() throws Exception {
    final TableSchema inputSchema = new TableSchema(
        new String[] { "vec" },
        new TypeInformation<?>[] { VectorTypes.DENSE_VECTOR });
    final Params mapperParams = new Params()
        .set(VectorToTensorParams.SELECTED_COL, "vec")
        .set(VectorToTensorParams.TENSOR_SHAPE, new Long[] { 2L, 3L });
    final Mapper mapper = new VectorToTensorMapper(inputSchema, mapperParams);

    final DoubleTensor tensor = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final DoubleTensor expect = tensor.reshape(new Shape(2L, 3L));

    final Row mapped = mapper.map(Row.of(tensor.toVector()));
    final Tensor<?> result = (Tensor<?>) mapped.getField(0);
    Assert.assertEquals(expect, result);
}
Aggregations