Use of org.apache.spark.ml.regression.GeneralizedLinearRegression in the project net.jgp.labs.spark by jgperrin:
the main method of the class GeneralizedLinearRegressionApp.
/**
 * Fits a generalized linear regression model (gamma family, log link) to a
 * small libsvm-format training file and prints the fitted coefficients, the
 * training summary statistics, and two sample predictions.
 *
 * @param args unused command-line arguments
 */
public static void main(String[] args) {
  SparkSession spark = SparkSession.builder()
      .appName("GeneralizedLinearRegressionApp")
      .master("local[*]")
      .getOrCreate();

  // Load training data (libsvm format: label followed by index:value pairs).
  Dataset<Row> dataset = spark.read().format("libsvm")
      .load("data/sample-ml/simplegauss.txt");
  dataset.show(20, false);
  System.out.println("Records: " + dataset.count());

  // NOTE(review): gamma family with a log link was chosen for this sample;
  // confirm it matches the distribution of the actual labels.
  GeneralizedLinearRegression glr = new GeneralizedLinearRegression()
      .setFamily("gamma")
      .setLink("log")
      .setMaxIter(10)
      .setRegParam(0.3);

  // Fit the model on the full dataset (no train/test split in this example).
  GeneralizedLinearRegressionModel model = glr.fit(dataset);

  // Coefficients and intercept of the fitted model.
  System.out.println("Coefficients: " + model.coefficients());
  System.out.println("Intercept: " + model.intercept());

  // Summarize the model over the training set and print out some metrics.
  GeneralizedLinearRegressionTrainingSummary summary = model.summary();
  System.out.println("Coefficient Standard Errors: "
      + Arrays.toString(summary.coefficientStandardErrors()));
  System.out.println("T Values: " + Arrays.toString(summary.tValues()));
  System.out.println("P Values: " + Arrays.toString(summary.pValues()));
  System.out.println("Dispersion: " + summary.dispersion());
  System.out.println("Null Deviance: " + summary.nullDeviance());
  System.out.println("Residual Degree Of Freedom Null: " + summary.residualDegreeOfFreedomNull());
  System.out.println("Deviance: " + summary.deviance());
  System.out.println("Residual Degree Of Freedom: " + summary.residualDegreeOfFreedom());
  System.out.println("AIC: " + summary.aic());
  System.out.println("Deviance Residuals: ");
  summary.residuals().show();

  // Predict on two single-feature vectors. A primitive double is used instead
  // of a boxed Double (the box added no value and caused needless autoboxing);
  // string concatenation prints the same "2.0"/"11.0" text either way.
  double feature = 2.0;
  Vector features = Vectors.dense(feature);
  double p = model.predict(features);
  System.out.println("Prediction for feature " + feature + " is " + p);

  feature = 11.0;
  features = Vectors.dense(feature);
  p = model.predict(features);
  System.out.println("Prediction for feature " + feature + " is " + p);

  spark.stop();
}
Use of org.apache.spark.ml.regression.GeneralizedLinearRegression in the project mmtf-spark by sbl-sdsc:
the main method of the class DatasetRegressor.
/**
 * Trains and evaluates three Spark ML regressors — linear regression,
 * gradient-boosted trees, and a generalized linear model — on a feature
 * dataset read from a parquet file, printing each fit result and the total
 * elapsed time.
 *
 * @param args args[0] path to parquet file, args[1] name of the prediction column
 * @throws IOException if the input data cannot be read
 */
public static void main(String[] args) throws IOException {
    if (args.length != 2) {
        System.err.println("Usage: " + DatasetRegressor.class.getSimpleName() + " <parquet file> <prediction column name>");
        System.exit(1);
    }

    long start = System.nanoTime();

    // Column holding the value to predict.
    String label = args[1];

    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .appName(DatasetRegressor.class.getSimpleName())
        .getOrCreate();

    Dataset<Row> dataset = spark.read().parquet(args[0]).cache();

    // Feature count is taken from the first row's dense feature vector.
    int featureCount = ((DenseVector) dataset.first().getAs("features")).numActives();
    System.out.println("Feature count: " + featureCount);
    System.out.println("Dataset size : " + dataset.count());

    double testFraction = 0.3;
    long seed = 123;

    // Baseline: ordinary linear regression.
    LinearRegression lr = new LinearRegression()
        .setLabelCol(label)
        .setFeaturesCol("features");
    System.out.println(new SparkRegressor(lr, label, testFraction, seed).fit(dataset));

    // Gradient-boosted tree ensemble.
    GBTRegressor gbt = new GBTRegressor()
        .setLabelCol(label)
        .setFeaturesCol("features");
    System.out.println(new SparkRegressor(gbt, label, testFraction, seed).fit(dataset));

    // Generalized linear model (gaussian family, identity link).
    GeneralizedLinearRegression glr = new GeneralizedLinearRegression()
        .setLabelCol(label)
        .setFeaturesCol("features")
        .setFamily("gaussian")
        .setLink("identity")
        .setMaxIter(10)
        .setRegParam(0.3);
    System.out.println(new SparkRegressor(glr, label, testFraction, seed).fit(dataset));

    long end = System.nanoTime();
    System.out.println((end - start) / 1E9 + " sec");
}
Aggregations