Use of org.tribuo.regression.evaluation.RegressionEvaluation in the Tribuo project by Oracle.
From the class TestRegressionEnsembles, method testThreeDenseData.
/**
 * Trains a bagging ensemble on the dense three dimensional regression data and compares the
 * R^2 of each output dimension, plus the average R^2, against stored reference values.
 * <p>
 * The whole train/evaluate/assert cycle runs twice, first with {@code false} and then with
 * {@code true} as the second argument of {@code threeDimDenseTrainTest}; both passes are
 * expected to produce the same scores.
 */
@Test
public void testThreeDenseData() {
    // Reference scores shared by both passes.
    final double expectedDim1 = 0.1632337913237244;
    final double expectedDim2 = 0.1632337913237244;
    final double expectedDim3 = -0.5727741047992028;
    final double expectedAve = -0.08210217405058466;

    for (boolean generatorFlag : new boolean[] { false, true }) {
        Pair<Dataset<Regressor>, Dataset<Regressor>> split = RegressionDataGenerator.threeDimDenseTrainTest(1.0, generatorFlag);
        // A fresh trainer is constructed on every pass, which resets its RNG.
        BaggingTrainer<Regressor> bagger = new BaggingTrainer<>(t, new AveragingCombiner(), 10);
        Model<Regressor> ensemble = bagger.train(split.getA());
        RegressionEvaluation result = evaluator.evaluate(ensemble, split.getB());

        // The Regressor instances only select a dimension, hence the NaN placeholder value.
        assertEquals(expectedDim1, result.r2(new Regressor(RegressionDataGenerator.firstDimensionName, Double.NaN)), 1e-6);
        assertEquals(expectedDim2, result.r2(new Regressor(RegressionDataGenerator.secondDimensionName, Double.NaN)), 1e-6);
        assertEquals(expectedDim3, result.r2(new Regressor(RegressionDataGenerator.thirdDimensionName, Double.NaN)), 1e-6);
        assertEquals(expectedAve, result.averageR2(), 1e-6);
    }
}
Use of org.tribuo.regression.evaluation.RegressionEvaluation in the Tribuo project by Oracle.
From the class TrainTest, method main.
/**
 * Command line entry point: trains a linear SGD regressor and evaluates it on the test data.
 * <p>
 * The arguments are parsed into an {@code SGDOptions} instance. If parsing fails, if either
 * the training or testing path is missing, or if the requested loss is not recognised, the
 * problem or usage text is logged and the method returns without training. Otherwise the
 * evaluation is printed to standard out and the model is saved when an output path was given.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();

    SGDOptions options = new SGDOptions();
    ConfigurationManager configManager;
    try {
        configManager = new ConfigurationManager(args, options);
    } catch (UsageException usageProblem) {
        logger.info(usageProblem.getMessage());
        return;
    }

    // Both a training and a testing path are required.
    boolean hasBothPaths = options.general.trainingPath != null && options.general.testingPath != null;
    if (!hasBothPaths) {
        logger.info(configManager.usage());
        return;
    }

    logger.info("Configuring gradient optimiser");

    // Map the requested loss onto its objective; the default branch exits,
    // so the objective is always assigned past this point.
    RegressionObjective objective;
    switch (options.loss) {
        case SQUARED:
            objective = new SquaredLoss();
            break;
        case ABSOLUTE:
            objective = new AbsoluteLoss();
            break;
        case HUBER:
            objective = new Huber();
            break;
        default:
            logger.warning("Unknown objective function " + options.loss);
            logger.info(configManager.usage());
            return;
    }

    StochasticGradientOptimiser optimiser = options.gradientOptions.getOptimiser();

    String intervalMessage = String.format("Set logging interval to %d", options.loggingInterval);
    logger.info(intervalMessage);

    // Load the train/test split using the regression output factory.
    RegressionFactory regressionFactory = new RegressionFactory();
    Pair<Dataset<Regressor>, Dataset<Regressor>> split = options.general.load(regressionFactory);
    Dataset<Regressor> trainingData = split.getA();
    Dataset<Regressor> testingData = split.getB();

    Trainer<Regressor> sgdTrainer = new LinearSGDTrainer(objective, optimiser, options.epochs, options.loggingInterval, options.minibatchSize, options.general.seed);
    logger.info("Training using " + sgdTrainer);

    // Time the training run.
    final long trainingBegan = System.currentTimeMillis();
    Model<Regressor> trainedModel = sgdTrainer.train(trainingData);
    final long trainingEnded = System.currentTimeMillis();
    logger.info("Finished training regressor " + Util.formatDuration(trainingBegan, trainingEnded));

    // Time the evaluation run.
    final long evaluationBegan = System.currentTimeMillis();
    RegressionEvaluation evaluation = regressionFactory.getEvaluator().evaluate(trainedModel, testingData);
    final long evaluationEnded = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(evaluationBegan, evaluationEnded));

    System.out.println(evaluation.toString());

    // Persist the model only when an output location was supplied.
    if (options.general.outputPath != null) {
        options.general.saveModel(trainedModel);
    }
}
Use of org.tribuo.regression.evaluation.RegressionEvaluation in the Tribuo project by Oracle.
From the class RegressionTest, method regressionMLPTest.
/**
 * Smoke test for training a TensorFlow MLP regressor.
 * <p>
 * Trains a two hidden layer (50, 50) MLP, checks that an evaluation on held out data
 * produces R^2 values, round-trips the model through Tribuo serialization, and exports it
 * to a temporary directory which is then deleted again.
 * <p>
 * The directory streams returned by {@code Files.list} and {@code Files.walk} are closed via
 * try-with-resources, and the model is closed in a {@code finally} block so that it is
 * released even when one of the assertions fails.
 * @throws IOException if the temporary directory cannot be created, listed or walked.
 */
@Test
public void regressionMLPTest() throws IOException {
    // Create the train and test data
    // The two sources differ only in their final argument (42 vs 42 * 42) - presumably the RNG seed, TODO confirm.
    DataSource<Regressor> trainSource = new NonlinearGaussianDataSource(1024, new float[] { 1.0f, 2.0f, -3.0f, 4.0f }, 1.0f, 0.1f, -5.0f, 5.0f, -1.0f, 1.0f, 42);
    DataSource<Regressor> testSource = new NonlinearGaussianDataSource(1024, new float[] { 1.0f, 2.0f, -3.0f, 4.0f }, 1.0f, 0.1f, -5.0f, 5.0f, -1.0f, 1.0f, 42 * 42);
    MutableDataset<Regressor> trainData = new MutableDataset<>(trainSource);
    MutableDataset<Regressor> testData = new MutableDataset<>(testSource);

    // Build the MLP graph, sized from the number of features and outputs in the training data
    GraphDefTuple graphDefTuple = MLPExamples.buildMLPGraph(INPUT_NAME, trainData.getFeatureMap().size(), new int[] { 50, 50 }, trainData.getOutputs().size());

    // Configure the trainer
    Map<String, Float> gradientParams = new HashMap<>();
    gradientParams.put("learningRate", 0.01f);
    gradientParams.put("initialAccumulatorValue", 0.1f);
    FeatureConverter denseConverter = new DenseFeatureConverter(INPUT_NAME);
    OutputConverter<Regressor> outputConverter = new RegressorConverter();
    TensorFlowTrainer<Regressor> trainer = new TensorFlowTrainer<>(graphDefTuple.graphDef, graphDefTuple.outputName, GradientOptimiser.ADAGRAD, gradientParams, denseConverter, outputConverter, 16, 10, 16, -1);

    // Train the model
    TensorFlowModel<Regressor> model = trainer.train(trainData);
    try {
        // Run smoke test evaluation
        RegressionEvaluation eval = new RegressionEvaluator().evaluate(model, testData);
        Assertions.assertFalse(eval.r2().isEmpty());

        // Check Tribuo serialization
        Helpers.testModelSerialization(model, Regressor.class);

        // Check saved model bundle export
        Path outputPath = Files.createTempDirectory("tf-regression-test");
        model.exportModel(outputPath.toString());
        List<Path> files;
        // Files.list holds an open directory handle until the stream is closed.
        try (java.util.stream.Stream<Path> listing = Files.list(outputPath)) {
            files = listing.collect(Collectors.toList());
        }
        Assertions.assertNotEquals(0, files.size());

        // Cleanup saved model bundle
        // Reverse ordering visits children before their parent directories, so each delete sees an empty directory.
        try (java.util.stream.Stream<Path> tree = Files.walk(outputPath)) {
            tree.sorted(Comparator.reverseOrder()).map(Path::toFile).forEach(File::delete);
        }
        Assertions.assertFalse(Files.exists(outputPath));
    } finally {
        // Cleanup created model, even if a check above failed
        model.close();
    }
}
Use of org.tribuo.regression.evaluation.RegressionEvaluation in the Tribuo project by Oracle.
From the class LIMEBase, method explainWithSamples.
/**
 * Explains the inner model's prediction for one example, and also returns the sampled data
 * the explanation was fitted on.
 * <p>
 * The sparse explainer is trained on the sampled examples together with the supplied example
 * relabelled with the inner model's transformed prediction, and is then evaluated on its own
 * predictions over those same examples.
 * @param example the example to explain.
 * @return a pair of the explanation and the sampled examples.
 */
protected Pair<LIMEExplanation, List<Example<Regressor>>> explainWithSamples(Example<Label> example) {
    // Ask the model being explained for its prediction, and relabel the input with it.
    Prediction<Label> innerPrediction = innerModel.predict(example);
    Example<Regressor> relabelledExample = new ArrayExample<>(transformOutput(innerPrediction), example, 1.0f);

    // Draw the local sample the explainer is fitted on.
    List<Example<Regressor>> sampledExamples = sampleData(example);

    // Fit the sparse explainer.
    SparseModel<Regressor> explainer = trainExplainer(relabelledExample, sampledExamples);

    // Collect the explainer's predictions on the sample, followed by the relabelled example.
    List<Prediction<Regressor>> explainerPredictions = new ArrayList<>(explainer.predict(sampledExamples));
    explainerPredictions.add(explainer.predict(relabelledExample));

    // Score the explainer on those predictions.
    SimpleDataSourceProvenance sampleProvenance = new SimpleDataSourceProvenance("LIMEColumnar sampled data", regressionFactory);
    RegressionEvaluation explainerEvaluation = evaluator.evaluate(explainer, explainerPredictions, sampleProvenance);

    LIMEExplanation explanation = new LIMEExplanation(explainer, innerPrediction, explainerEvaluation);
    return new Pair<>(explanation, sampledExamples);
}
Use of org.tribuo.regression.evaluation.RegressionEvaluation in the Tribuo project by Oracle.
From the class LIMEText, method explain.
/**
 * Explains the inner model's prediction for a piece of text.
 * <p>
 * The text is tokenized and turned into an example containing one feature per token, named
 * from the token text and its position, with value 1.0 and labelled with the inner model's
 * transformed prediction. A sparse explainer is trained on that example plus sampled data,
 * and evaluated on its own predictions over those same examples.
 * @param inputText the text to explain.
 * @return the explanation for the inner model's prediction on the text.
 */
@Override
public LIMEExplanation explain(String inputText) {
    // Ask the model being explained for its prediction on the extracted example.
    Example<Label> extractedExample = extractor.extract(LabelFactory.UNKNOWN_LABEL, inputText);
    Prediction<Label> innerPrediction = innerModel.predict(extractedExample);

    // Build the per-token example, labelled with the transformed prediction.
    ArrayExample<Regressor> tokenExample = new ArrayExample<>(transformOutput(innerPrediction));
    List<Token> tokenList = tokenizerThreadLocal.get().tokenize(inputText);
    int position = 0;
    for (Token token : tokenList) {
        tokenExample.add(nameFeature(token.text, position), 1.0);
        position++;
    }

    // Draw the local sample the explainer is fitted on.
    List<Example<Regressor>> sampledExamples = sampleData(inputText, tokenList);

    // Fit the sparse explainer.
    SparseModel<Regressor> explainer = trainExplainer(tokenExample, sampledExamples);

    // Collect the explainer's predictions on the sample, followed by the per-token example.
    List<Prediction<Regressor>> explainerPredictions = new ArrayList<>(explainer.predict(sampledExamples));
    explainerPredictions.add(explainer.predict(tokenExample));

    // Score the explainer on those predictions.
    SimpleDataSourceProvenance sampleProvenance = new SimpleDataSourceProvenance("LIMEText sampled data", regressionFactory);
    RegressionEvaluation explainerEvaluation = evaluator.evaluate(explainer, explainerPredictions, sampleProvenance);

    return new LIMEExplanation(explainer, innerPrediction, explainerEvaluation);
}
Aggregations