Usage of com.alibaba.alink.pipeline.regression.KerasSequentialRegressor in the Alink project by Alibaba.
In class Chap25, method dnnReg:
/**
 * Fits a dense feed-forward regressor on {@code train_set} and evaluates it on {@code test_set}.
 *
 * <p>Pipeline stages, in order:
 * <ol>
 *   <li>{@link StandardScaler} over the columns in {@code Chap16.FEATURE_COL_NAMES};</li>
 *   <li>{@link VectorAssembler} packing those columns into the vector column {@code "vec"};</li>
 *   <li>{@link VectorToTensor} turning {@code "vec"} into {@code "tensor"}, reserving {@code "quality"};</li>
 *   <li>{@link KerasSequentialRegressor} with five {@code Dense(64, activation='relu')} layers,
 *       trained for 20 epochs on label {@code "quality"}, writing predictions to {@code "pred"}.</li>
 * </ol>
 *
 * <p>Summary statistics of the transformed test set and regression metrics comparing
 * {@code "quality"} with {@code "pred"} are registered through the {@code lazyPrint*} calls. They are
 * expected to be printed when the trailing {@code BatchOperator.execute()} runs.
 *
 * <p>Side effect: sets the global batch parallelism to 1 and does not restore it.
 *
 * @param train_set data used to fit the pipeline
 * @param test_set  data the fitted pipeline is applied to and evaluated on
 * @throws Exception if fitting the pipeline or executing the batch job fails
 */
public static void dnnReg(BatchOperator<?> train_set, BatchOperator<?> test_set) throws Exception {
    BatchOperator.setParallelism(1);

    // Five identical hidden-layer specs; passing the array to setLayers(String...) is equivalent
    // to listing them individually.
    final String denseLayer = "Dense(64, activation='relu')";
    String[] hiddenLayers = new String[5];
    for (int i = 0; i < hiddenLayers.length; i++) {
        hiddenLayers[i] = denseLayer;
    }

    StandardScaler scaler = new StandardScaler()
        .setSelectedCols(Chap16.FEATURE_COL_NAMES);
    VectorAssembler assembler = new VectorAssembler()
        .setSelectedCols(Chap16.FEATURE_COL_NAMES)
        .setOutputCol("vec");
    VectorToTensor toTensor = new VectorToTensor()
        .setSelectedCol("vec")
        .setOutputCol("tensor")
        .setReservedCols("quality");
    KerasSequentialRegressor regressor = new KerasSequentialRegressor()
        .setTensorCol("tensor")
        .setLabelCol("quality")
        .setPredictionCol("pred")
        .setLayers(hiddenLayers)
        .setNumEpochs(20);

    BatchOperator<?> predictions = new Pipeline()
        .add(scaler)
        .add(assembler)
        .add(toTensor)
        .add(regressor)
        .fit(train_set)
        .transform(test_set)
        .lazyPrintStatistics();

    // The evaluator is created only after the statistics print is registered, preserving the
    // original registration order of the two lazy prints.
    predictions.link(
        new EvalRegressionBatchOp()
            .setLabelCol("quality")
            .setPredictionCol("pred")
            .lazyPrintMetrics());

    BatchOperator.execute();
}
Aggregations