Search in sources:

Example 6 with RegressionEvaluation

use of org.tribuo.regression.evaluation.RegressionEvaluation in project tribuo by oracle.

The class TestRegressionEnsembles, method testThreeDenseData.

@Test
public void testThreeDenseData() {
    // Per-dimension R^2 values (and their average) that both data variants must reproduce.
    double expectedDim1 = 0.1632337913237244;
    double expectedDim2 = 0.1632337913237244;
    double expectedDim3 = -0.5727741047992028;
    double expectedAve = -0.08210217405058466;
    // The same train/evaluate/check cycle is run once with the generator flag off and once with it on.
    for (boolean generatorFlag : new boolean[] { false, true }) {
        Pair<Dataset<Regressor>, Dataset<Regressor>> trainTest = RegressionDataGenerator.threeDimDenseTrainTest(1.0, generatorFlag);
        // A fresh trainer on every pass resets the bagging RNG, so both passes start from the same state.
        BaggingTrainer<Regressor> bagger = new BaggingTrainer<>(t, new AveragingCombiner(), 10);
        Model<Regressor> baggedModel = bagger.train(trainTest.getA());
        RegressionEvaluation baggedEval = evaluator.evaluate(baggedModel, trainTest.getB());
        assertEquals(expectedDim1, baggedEval.r2(new Regressor(RegressionDataGenerator.firstDimensionName, Double.NaN)), 1e-6);
        assertEquals(expectedDim2, baggedEval.r2(new Regressor(RegressionDataGenerator.secondDimensionName, Double.NaN)), 1e-6);
        assertEquals(expectedDim3, baggedEval.r2(new Regressor(RegressionDataGenerator.thirdDimensionName, Double.NaN)), 1e-6);
        assertEquals(expectedAve, baggedEval.averageR2(), 1e-6);
    }
}
Also used : AveragingCombiner(org.tribuo.regression.ensemble.AveragingCombiner) Dataset(org.tribuo.Dataset) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) Regressor(org.tribuo.regression.Regressor) BaggingTrainer(org.tribuo.ensemble.BaggingTrainer) Test(org.junit.jupiter.api.Test)

Example 7 with RegressionEvaluation

use of org.tribuo.regression.evaluation.RegressionEvaluation in project tribuo by oracle.

The class TrainTest, method main.

/**
 * Runs a TrainTest CLI for linear SGD regression.
 * <p>
 * Parses the command line into {@code SGDOptions}, selects the regression objective,
 * trains a {@code LinearSGDTrainer} on the training data, evaluates it on the test data,
 * prints the evaluation to standard out, and saves the model when an output path is set.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();

    SGDOptions options = new SGDOptions();
    ConfigurationManager configManager;
    try {
        configManager = new ConfigurationManager(args, options);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }

    // Both a training and a testing path are mandatory.
    boolean pathsMissing = options.general.trainingPath == null || options.general.testingPath == null;
    if (pathsMissing) {
        logger.info(configManager.usage());
        return;
    }

    logger.info("Configuring gradient optimiser");
    // Definitely assigned: every case assigns it, and the default branch returns.
    RegressionObjective objective;
    switch (options.loss) {
        case ABSOLUTE:
            objective = new AbsoluteLoss();
            break;
        case SQUARED:
            objective = new SquaredLoss();
            break;
        case HUBER:
            objective = new Huber();
            break;
        default:
            logger.warning("Unknown objective function " + options.loss);
            logger.info(configManager.usage());
            return;
    }
    StochasticGradientOptimiser optimiser = options.gradientOptions.getOptimiser();
    logger.info(String.format("Set logging interval to %d", options.loggingInterval));

    RegressionFactory regressionFactory = new RegressionFactory();
    Pair<Dataset<Regressor>, Dataset<Regressor>> loaded = options.general.load(regressionFactory);
    Dataset<Regressor> trainSet = loaded.getA();
    Dataset<Regressor> testSet = loaded.getB();

    Trainer<Regressor> sgdTrainer = new LinearSGDTrainer(objective, optimiser, options.epochs, options.loggingInterval, options.minibatchSize, options.general.seed);
    logger.info("Training using " + sgdTrainer.toString());
    final long trainingBegan = System.currentTimeMillis();
    Model<Regressor> trained = sgdTrainer.train(trainSet);
    final long trainingEnded = System.currentTimeMillis();
    logger.info("Finished training regressor " + Util.formatDuration(trainingBegan, trainingEnded));

    final long evalBegan = System.currentTimeMillis();
    RegressionEvaluation result = regressionFactory.getEvaluator().evaluate(trained, testSet);
    final long evalEnded = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(evalBegan, evalEnded));
    System.out.println(result.toString());

    if (options.general.outputPath != null) {
        options.general.saveModel(trained);
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) LinearSGDTrainer(org.tribuo.regression.sgd.linear.LinearSGDTrainer) SquaredLoss(org.tribuo.regression.sgd.objectives.SquaredLoss) RegressionFactory(org.tribuo.regression.RegressionFactory) Dataset(org.tribuo.Dataset) Huber(org.tribuo.regression.sgd.objectives.Huber) AbsoluteLoss(org.tribuo.regression.sgd.objectives.AbsoluteLoss) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) StochasticGradientOptimiser(org.tribuo.math.StochasticGradientOptimiser) Regressor(org.tribuo.regression.Regressor) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Example 8 with RegressionEvaluation

use of org.tribuo.regression.evaluation.RegressionEvaluation in project tribuo by oracle.

The class RegressionTest, method regressionMLPTest.

@Test
public void regressionMLPTest() throws IOException {
    // Create the train and test data; the sources differ only in their final argument
    // (presumably the RNG seed — TODO confirm against NonlinearGaussianDataSource).
    DataSource<Regressor> trainSource = new NonlinearGaussianDataSource(1024, new float[] { 1.0f, 2.0f, -3.0f, 4.0f }, 1.0f, 0.1f, -5.0f, 5.0f, -1.0f, 1.0f, 42);
    DataSource<Regressor> testSource = new NonlinearGaussianDataSource(1024, new float[] { 1.0f, 2.0f, -3.0f, 4.0f }, 1.0f, 0.1f, -5.0f, 5.0f, -1.0f, 1.0f, 42 * 42);
    MutableDataset<Regressor> trainData = new MutableDataset<>(trainSource);
    MutableDataset<Regressor> testData = new MutableDataset<>(testSource);
    // Build the MLP graph
    GraphDefTuple graphDefTuple = MLPExamples.buildMLPGraph(INPUT_NAME, trainData.getFeatureMap().size(), new int[] { 50, 50 }, trainData.getOutputs().size());
    // Configure the trainer
    Map<String, Float> gradientParams = new HashMap<>();
    gradientParams.put("learningRate", 0.01f);
    gradientParams.put("initialAccumulatorValue", 0.1f);
    FeatureConverter denseConverter = new DenseFeatureConverter(INPUT_NAME);
    OutputConverter<Regressor> outputConverter = new RegressorConverter();
    TensorFlowTrainer<Regressor> trainer = new TensorFlowTrainer<>(graphDefTuple.graphDef, graphDefTuple.outputName, GradientOptimiser.ADAGRAD, gradientParams, denseConverter, outputConverter, 16, 10, 16, -1);
    // Train the model
    TensorFlowModel<Regressor> model = trainer.train(trainData);
    // The model exposes close(); release it even if an assertion below fails.
    try {
        // Run smoke test evaluation
        RegressionEvaluation eval = new RegressionEvaluator().evaluate(model, testData);
        Assertions.assertFalse(eval.r2().isEmpty());
        // Check Tribuo serialization
        Helpers.testModelSerialization(model, Regressor.class);
        // Check saved model bundle export
        Path outputPath = Files.createTempDirectory("tf-regression-test");
        try {
            model.exportModel(outputPath.toString());
            List<Path> files;
            // Files.list keeps a directory handle open until its stream is closed.
            try (java.util.stream.Stream<Path> listing = Files.list(outputPath)) {
                files = listing.collect(Collectors.toList());
            }
            Assertions.assertNotEquals(0, files.size());
        } finally {
            // Cleanup saved model bundle: reverse order deletes children before their parent
            // directories. Runs even if export or the assertion fails, so no temp dir is leaked.
            try (java.util.stream.Stream<Path> walk = Files.walk(outputPath)) {
                walk.sorted(Comparator.reverseOrder()).map(Path::toFile).forEach(File::delete);
            }
        }
        Assertions.assertFalse(Files.exists(outputPath));
    } finally {
        // Cleanup created model
        model.close();
    }
}
Also used : Path(java.nio.file.Path) GraphDefTuple(org.tribuo.interop.tensorflow.example.GraphDefTuple) HashMap(java.util.HashMap) NonlinearGaussianDataSource(org.tribuo.regression.example.NonlinearGaussianDataSource) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) Regressor(org.tribuo.regression.Regressor) RegressionEvaluator(org.tribuo.regression.evaluation.RegressionEvaluator) MutableDataset(org.tribuo.MutableDataset) File(java.io.File) Test(org.junit.jupiter.api.Test)

Example 9 with RegressionEvaluation

use of org.tribuo.regression.evaluation.RegressionEvaluation in project tribuo by oracle.

The class LIMEBase, method explainWithSamples.

protected Pair<LIMEExplanation, List<Example<Regressor>>> explainWithSamples(Example<Label> example) {
    // Ask the wrapped model for its prediction, then build a regression example that carries
    // the transformed prediction as its output and copies the input example's features.
    Prediction<Label> innerPrediction = innerModel.predict(example);
    Example<Regressor> target = new ArrayExample<>(transformOutput(innerPrediction), example, 1.0f);

    // Sample a dataset for this example and fit the sparse explanation model on it.
    List<Example<Regressor>> samples = sampleData(example);
    SparseModel<Regressor> explainer = trainExplainer(target, samples);

    // Score the sparse model on every sample plus the target example, comparing it against
    // the predictions of the real model.
    List<Prediction<Regressor>> explainerPredictions = new ArrayList<>(explainer.predict(samples));
    explainerPredictions.add(explainer.predict(target));
    SimpleDataSourceProvenance provenance = new SimpleDataSourceProvenance("LIMEColumnar sampled data", regressionFactory);
    RegressionEvaluation evaluation = evaluator.evaluate(explainer, explainerPredictions, provenance);

    LIMEExplanation explanation = new LIMEExplanation(explainer, innerPrediction, evaluation);
    return new Pair<>(explanation, samples);
}
Also used : SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) Prediction(org.tribuo.Prediction) Label(org.tribuo.classification.Label) ArrayList(java.util.ArrayList) ArrayExample(org.tribuo.impl.ArrayExample) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) Regressor(org.tribuo.regression.Regressor) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 10 with RegressionEvaluation

use of org.tribuo.regression.evaluation.RegressionEvaluation in project tribuo by oracle.

The class LIMEText, method explain.

@Override
public LIMEExplanation explain(String inputText) {
    // Run the wrapped model on the raw text, using the unknown label as a placeholder output.
    Example<Label> trueExample = extractor.extract(LabelFactory.UNKNOWN_LABEL, inputText);
    Prediction<Label> innerPrediction = innerModel.predict(trueExample);

    // Bag-of-words view of the input: one feature of value 1.0 per token, named from the
    // token text and its position in the token list.
    ArrayExample<Regressor> bagOfWords = new ArrayExample<>(transformOutput(innerPrediction));
    List<Token> tokens = tokenizerThreadLocal.get().tokenize(inputText);
    int position = 0;
    for (Token token : tokens) {
        bagOfWords.add(nameFeature(token.text, position), 1.0);
        position++;
    }

    // Sample a dataset and fit the sparse explanation model on it.
    List<Example<Regressor>> samples = sampleData(inputText, tokens);
    SparseModel<Regressor> explainer = trainExplainer(bagOfWords, samples);

    // Test the sparse model on the samples plus the bag-of-words example itself.
    List<Prediction<Regressor>> explainerPredictions = new ArrayList<>(explainer.predict(samples));
    explainerPredictions.add(explainer.predict(bagOfWords));
    SimpleDataSourceProvenance provenance = new SimpleDataSourceProvenance("LIMEText sampled data", regressionFactory);
    RegressionEvaluation evaluation = evaluator.evaluate(explainer, explainerPredictions, provenance);
    return new LIMEExplanation(explainer, innerPrediction, evaluation);
}
Also used : SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) Prediction(org.tribuo.Prediction) Label(org.tribuo.classification.Label) ArrayList(java.util.ArrayList) Token(org.tribuo.util.tokens.Token) ArrayExample(org.tribuo.impl.ArrayExample) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) Regressor(org.tribuo.regression.Regressor)

Aggregations

Regressor (org.tribuo.regression.Regressor)35 RegressionEvaluation (org.tribuo.regression.evaluation.RegressionEvaluation)35 Dataset (org.tribuo.Dataset)21 Test (org.junit.jupiter.api.Test)14 List (java.util.List)9 ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager)7 UsageException (com.oracle.labs.mlrg.olcut.config.UsageException)7 RegressionFactory (org.tribuo.regression.RegressionFactory)7 ObjectInputStream (java.io.ObjectInputStream)6 Model (org.tribuo.Model)6 MutableDataset (org.tribuo.MutableDataset)4 RegressionEvaluator (org.tribuo.regression.evaluation.RegressionEvaluator)4 ArrayList (java.util.ArrayList)3 ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)3 Example (org.tribuo.Example)3 Prediction (org.tribuo.Prediction)3 Label (org.tribuo.classification.Label)3 ArrayExample (org.tribuo.impl.ArrayExample)3 SquaredLoss (org.tribuo.regression.sgd.objectives.SquaredLoss)3 Pair (com.oracle.labs.mlrg.olcut.util.Pair)2