Usage of com.alibaba.alink.pipeline.dataproc.vector.VectorFunction in the Alink project by Alibaba.
Class Chap25, method cnn.
/**
 * Trains a small convolutional Keras classifier on {@code train_set}, predicts on
 * {@code test_set}, and prints multi-class evaluation metrics.
 *
 * <p>Pipeline stages, in order:
 * <ol>
 *   <li>scale the {@code "vec"} column by 1/255 (presumably raw 0-255 pixel values — confirm
 *       against the data source);</li>
 *   <li>convert {@code "vec"} into a 28x28 float tensor in column {@code "tensor"}, keeping
 *       only {@code "label"} alongside it;</li>
 *   <li>fit a Keras sequential classifier (two conv/max-pool stages, flatten, dropout) for
 *       20 epochs with a 10% validation split, keeping the best model by
 *       {@code sparse_categorical_accuracy}.</li>
 * </ol>
 *
 * <p>Side effect: sets the global batch parallelism to 1 before building the job.
 *
 * @param train_set batch data with at least a {@code "vec"} vector column and a {@code "label"} column
 * @param test_set  batch data with the same schema, used for prediction and evaluation
 * @throws Exception if building or executing the batch job fails
 */
public static void cnn(BatchOperator<?> train_set, BatchOperator<?> test_set) throws Exception {
    // Global setting: every batch operator created from here on runs with parallelism 1.
    BatchOperator.setParallelism(1);

    // Stage 1: divide every element of "vec" by 255.
    VectorFunction scaleStage = new VectorFunction()
        .setSelectedCol("vec")
        .setFuncName("Scale")
        .setWithVariable(1.0 / 255.0);

    // Stage 2: reshape the vector into a 28x28 float tensor for the convolutional layers.
    VectorToTensor tensorStage = new VectorToTensor()
        .setTensorDataType("float")
        .setTensorShape(28, 28)
        .setSelectedCol("vec")
        .setOutputCol("tensor")
        .setReservedCols("label");

    // Hidden layers only; the output layer is not listed here.
    String[] convLayers = {
        "Reshape((28, 28, 1))",
        "Conv2D(32, kernel_size=(3, 3), activation='relu')",
        "MaxPooling2D(pool_size=(2, 2))",
        "Conv2D(64, kernel_size=(3, 3), activation='relu')",
        "MaxPooling2D(pool_size=(2, 2))",
        "Flatten()",
        "Dropout(0.5)"
    };

    // Stage 3: the Keras classifier itself.
    KerasSequentialClassifier classifierStage = new KerasSequentialClassifier()
        .setTensorCol("tensor")
        .setLabelCol("label")
        .setPredictionCol("pred")
        .setLayers(convLayers)
        .setNumEpochs(20)
        .setValidationSplit(0.1)
        .setSaveBestOnly(true)
        .setBestMetric("sparse_categorical_accuracy");

    Pipeline pipeline = new Pipeline()
        .add(scaleStage)
        .add(tensorStage)
        .add(classifierStage);

    // Fit on the training data, predict on the test data, and register lazy metric printing.
    pipeline
        .fit(train_set)
        .transform(test_set)
        .link(
            new EvalMultiClassBatchOp()
                .setLabelCol("label")
                .setPredictionCol("pred")
                .lazyPrintMetrics()
        );

    // Run the assembled batch job; the lazily registered metrics print during execution.
    BatchOperator.execute();
}
Usage of com.alibaba.alink.pipeline.dataproc.vector.VectorFunction in the Alink project by Alibaba.
Class Chap25, method dnn.
/**
 * Trains a fully connected Keras classifier on {@code train_set}, predicts on
 * {@code test_set}, and prints multi-class evaluation metrics.
 *
 * <p>Pipeline stages, in order:
 * <ol>
 *   <li>scale the {@code "vec"} column by 1/255 (presumably raw 0-255 pixel values — confirm
 *       against the data source);</li>
 *   <li>convert {@code "vec"} into a float tensor in column {@code "tensor"}. No shape is set,
 *       keeping only {@code "label"} alongside it;</li>
 *   <li>fit a Keras sequential classifier with two ReLU dense layers (256 and 128 units) for
 *       50 epochs, batch size 512 and a 10% validation split, keeping the best model by
 *       {@code sparse_categorical_accuracy}.</li>
 * </ol>
 *
 * <p>Side effect: sets the global batch parallelism to 1 before building the job.
 *
 * @param train_set batch data with at least a {@code "vec"} vector column and a {@code "label"} column
 * @param test_set  batch data with the same schema, used for prediction and evaluation
 * @throws Exception if building or executing the batch job fails
 */
public static void dnn(BatchOperator<?> train_set, BatchOperator<?> test_set) throws Exception {
    // Global setting: every batch operator created from here on runs with parallelism 1.
    BatchOperator.setParallelism(1);

    // Stage 1: divide every element of "vec" by 255.
    VectorFunction normalizer = new VectorFunction()
        .setSelectedCol("vec")
        .setFuncName("Scale")
        .setWithVariable(1.0 / 255.0);

    // Stage 2: turn the vector into a float tensor; dense layers need no explicit shape.
    VectorToTensor toTensor = new VectorToTensor()
        .setTensorDataType("float")
        .setSelectedCol("vec")
        .setOutputCol("tensor")
        .setReservedCols("label");

    // Hidden layers only; the output layer is not listed here.
    String[] denseLayers = {
        "Dense(256, activation='relu')",
        "Dense(128, activation='relu')"
    };

    // Stage 3: the Keras classifier itself.
    KerasSequentialClassifier mlp = new KerasSequentialClassifier()
        .setTensorCol("tensor")
        .setLabelCol("label")
        .setPredictionCol("pred")
        .setLayers(denseLayers)
        .setNumEpochs(50)
        .setBatchSize(512)
        .setValidationSplit(0.1)
        .setSaveBestOnly(true)
        .setBestMetric("sparse_categorical_accuracy");

    Pipeline pipeline = new Pipeline()
        .add(normalizer)
        .add(toTensor)
        .add(mlp);

    // Fit on the training data, predict on the test data, and register lazy metric printing.
    pipeline
        .fit(train_set)
        .transform(test_set)
        .link(
            new EvalMultiClassBatchOp()
                .setLabelCol("label")
                .setPredictionCol("pred")
                .lazyPrintMetrics()
        );

    // Run the assembled batch job; the lazily registered metrics print during execution.
    BatchOperator.execute();
}
Aggregations