Example 1 with TrainTestSplitter

Use of org.tribuo.evaluation.TrainTestSplitter in the tribuo project by Oracle.

From the class TestHdbscan, method runEvaluation:

public static void runEvaluation(HdbscanTrainer trainer) {
    DataSource<ClusterID> gaussianSource = new GaussianClusterDataSource(1000, 1L);
    TrainTestSplitter<ClusterID> splitter = new TrainTestSplitter<>(gaussianSource, 0.7f, 2L);
    Dataset<ClusterID> trainData = new MutableDataset<>(splitter.getTrain());
    Dataset<ClusterID> testData = new MutableDataset<>(splitter.getTest());
    ClusteringEvaluator eval = new ClusteringEvaluator();
    HdbscanModel model = trainer.train(trainData);
    // Test serialization
    Helpers.testModelSerialization(model, ClusterID.class);
    // Both mutual-information metrics should be well defined (not NaN) on each split.
    ClusteringEvaluation trainEvaluation = eval.evaluate(model, trainData);
    assertFalse(Double.isNaN(trainEvaluation.adjustedMI()));
    assertFalse(Double.isNaN(trainEvaluation.normalizedMI()));
    ClusteringEvaluation testEvaluation = eval.evaluate(model, testData);
    assertFalse(Double.isNaN(testEvaluation.adjustedMI()));
    assertFalse(Double.isNaN(testEvaluation.normalizedMI()));
}
Also used: TrainTestSplitter (org.tribuo.evaluation.TrainTestSplitter), ClusterID (org.tribuo.clustering.ClusterID), ClusteringEvaluator (org.tribuo.clustering.evaluation.ClusteringEvaluator), GaussianClusterDataSource (org.tribuo.clustering.example.GaussianClusterDataSource), MutableDataset (org.tribuo.MutableDataset), ClusteringEvaluation (org.tribuo.clustering.evaluation.ClusteringEvaluation)
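
The TrainTestSplitter constructor used here takes the wrapped DataSource, the fraction of examples to place in the training split, and an RNG seed, so a fixed seed reproduces the same split. A minimal sketch (not part of the original test; the helper name is hypothetical) that only inspects the resulting split sizes, reusing the classes listed above:

public static void inspectSplitSizes() {
    // 1000 synthetic Gaussian examples, split roughly 70/30 with a fixed seed.
    DataSource<ClusterID> source = new GaussianClusterDataSource(1000, 1L);
    TrainTestSplitter<ClusterID> splitter = new TrainTestSplitter<>(source, 0.7f, 2L);
    Dataset<ClusterID> train = new MutableDataset<>(splitter.getTrain());
    Dataset<ClusterID> test = new MutableDataset<>(splitter.getTest());
    // Expect roughly 700 training and 300 test examples.
    System.out.println("train size = " + train.size() + ", test size = " + test.size());
}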

Example 2 with TrainTestSplitter

Use of org.tribuo.evaluation.TrainTestSplitter in the tribuo project by Oracle.

From the class TestKNN, method testKNNClassification:

private static void testKNNClassification(KNNTrainer<Label> trainer) {
    NoisyInterlockingCrescentsDataSource source = new NoisyInterlockingCrescentsDataSource(200, 1, 0.1);
    TrainTestSplitter<Label> splitter = new TrainTestSplitter<>(source, 0.8, 1L);
    MutableDataset<Label> trainingDataset = new MutableDataset<>(splitter.getTrain());
    MutableDataset<Label> testingDataset = new MutableDataset<>(splitter.getTest());
    Model<Label> model = trainer.train(trainingDataset);
    // The expected list of predictions
    List<String> expectedList = Arrays.asList("O", "X", "O", "X", "O", "X", "O", "X", "O", "X", "O", "X", "X", "O", "X", "O", "X", "O", "X", "O", "O", "X", "O", "X", "X", "X", "O", "X", "O", "O", "O", "O", "X", "O", "O", "X", "O", "X", "X", "O");
    List<Prediction<Label>> predictions = model.predict(testingDataset);
    List<String> predictionList = new ArrayList<>();
    for (Prediction<Label> prediction : predictions) {
        predictionList.add(prediction.getOutput().getLabel());
    }
    assertEquals(expectedList, predictionList);
}
Also used: TrainTestSplitter (org.tribuo.evaluation.TrainTestSplitter), Prediction (org.tribuo.Prediction), Label (org.tribuo.classification.Label), ArrayList (java.util.ArrayList), NoisyInterlockingCrescentsDataSource (org.tribuo.classification.example.NoisyInterlockingCrescentsDataSource), MutableDataset (org.tribuo.MutableDataset)
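
The collection loop above can equivalently be written as a stream over model.predict(testingDataset). A small sketch (not from the original test; the helper name is hypothetical, and java.util.stream.Collectors would also need to be imported) using only calls that appear in the example:

private static List<String> predictedLabels(Model<Label> model, MutableDataset<Label> data) {
    // Map each prediction to its predicted label string.
    return model.predict(data).stream()
            .map(p -> p.getOutput().getLabel())
            .collect(Collectors.toList());
}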

Example 3 with TrainTestSplitter

Use of org.tribuo.evaluation.TrainTestSplitter in the tribuo project by Oracle.

From the class TestKNN, method knnClassificationEvaluationTest:

@Test
public void knnClassificationEvaluationTest() {
    NoisyInterlockingCrescentsDataSource source = new NoisyInterlockingCrescentsDataSource(400, 1, 0.1);
    TrainTestSplitter<Label> splitter = new TrainTestSplitter<>(source, 0.8, 1L);
    MutableDataset<Label> trainingDataset = new MutableDataset<>(splitter.getTrain());
    MutableDataset<Label> testingDataset = new MutableDataset<>(splitter.getTest());
    Model<Label> model = classificationTrainer.train(trainingDataset);
    LabelEvaluation evaluation = (LabelEvaluation) trainingDataset.getOutputFactory().getEvaluator().evaluate(model, testingDataset);
    // The test expects perfect per-class accuracy and recall on this data.
    assertEquals(1.0, evaluation.accuracy(DemoLabelDataSource.FIRST_CLASS));
    assertEquals(1.0, evaluation.accuracy(DemoLabelDataSource.SECOND_CLASS));
    assertEquals(1.0, evaluation.recall(DemoLabelDataSource.FIRST_CLASS));
    assertEquals(1.0, evaluation.recall(DemoLabelDataSource.SECOND_CLASS));
    // Test serialization
    Helpers.testModelSerialization(model, Label.class);
}
Also used: TrainTestSplitter (org.tribuo.evaluation.TrainTestSplitter), LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation), Label (org.tribuo.classification.Label), NoisyInterlockingCrescentsDataSource (org.tribuo.classification.example.NoisyInterlockingCrescentsDataSource), MutableDataset (org.tribuo.MutableDataset), Test (org.junit.jupiter.api.Test)
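
Besides the per-class accuracy and recall asserted above, LabelEvaluation also exposes aggregate metrics such as a no-argument accuracy() (assumed here; it is not exercised by the original test). A hedged one-line follow-up that could be appended after the assertions, reusing the evaluation variable from the test body:

    // Hypothetical addition: log the overall accuracy across both classes.
    System.out.println("overall accuracy = " + evaluation.accuracy());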

Example 4 with TrainTestSplitter

Use of org.tribuo.evaluation.TrainTestSplitter in the gluon-samples project by gluonhq.

From the class Main, method train:

private void train() {
    tpSeries.getData().clear();
    fpSeries.getData().clear();
    Thread thread = new Thread() {

        @Override
        public void run() {
            try {
                URL dataUrl = Main.class.getResource("/bezdekIris.data");
                var irisHeaders = new String[] { "sepalLength", "sepalWidth", "petalLength", "petalWidth", "species" };
                ListDataSource<org.tribuo.classification.Label> irisData = new CSVLoader<>(new LabelFactory()).loadDataSource(dataUrl, irisHeaders[4], irisHeaders);
                TrainTestSplitter<org.tribuo.classification.Label> irisSplitter = new TrainTestSplitter<>(irisData, 0.7, 1L);
                MutableDataset<org.tribuo.classification.Label> trainData = new MutableDataset<>(irisSplitter.getTrain());
                MutableDataset<org.tribuo.classification.Label> testData = new MutableDataset<>(irisSplitter.getTest());
                var cartTrainer = new CARTClassificationTrainer();
                TreeModel<org.tribuo.classification.Label> tree = cartTrainer.train(trainData);
                var evaluator = new LabelEvaluator();
                LabelEvaluation evaluation = evaluator.evaluate(tree, testData);
                for (org.tribuo.classification.Label label : trainData.getOutputs()) {
                    // All five metrics are computed per label, but only tp and fp are charted below.
                    double f1 = evaluation.f1(label);
                    double fn = evaluation.fn(label);
                    double fp = evaluation.fp(label);
                    double tn = evaluation.tn(label);
                    double tp = evaluation.tp(label);
                    javafx.application.Platform.runLater(() -> {
                        tpSeries.getData().add(new Data<>(label.getLabel(), tp));
                        fpSeries.getData().add(new Data<>(label.getLabel(), fp));
                    });
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    };
    thread.start();
}
Also used: TrainTestSplitter (org.tribuo.evaluation.TrainTestSplitter), Label (javafx.scene.control.Label), LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator), URL (java.net.URL), CARTClassificationTrainer (org.tribuo.classification.dtree.CARTClassificationTrainer), LabelFactory (org.tribuo.classification.LabelFactory), LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation), MutableDataset (org.tribuo.MutableDataset)
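
Stripped of the JavaFX charting and the background thread, the Tribuo flow in this sample is: load the CSV, split it, train a CART tree, and evaluate. A condensed sketch under the same assumptions as the code above (the /bezdekIris.data resource and its column names; the helper name trainAndEvaluateIris is hypothetical):

private static void trainAndEvaluateIris() throws Exception {
    URL dataUrl = Main.class.getResource("/bezdekIris.data");
    var irisHeaders = new String[] { "sepalLength", "sepalWidth", "petalLength", "petalWidth", "species" };
    // Load the CSV, using the species column as the label.
    var irisData = new CSVLoader<>(new LabelFactory()).loadDataSource(dataUrl, irisHeaders[4], irisHeaders);
    TrainTestSplitter<org.tribuo.classification.Label> splitter = new TrainTestSplitter<>(irisData, 0.7, 1L);
    var trainData = new MutableDataset<>(splitter.getTrain());
    var testData = new MutableDataset<>(splitter.getTest());
    TreeModel<org.tribuo.classification.Label> tree = new CARTClassificationTrainer().train(trainData);
    LabelEvaluation evaluation = new LabelEvaluator().evaluate(tree, testData);
    // Report per-label F1 on the held-out split instead of charting tp/fp.
    for (org.tribuo.classification.Label label : trainData.getOutputs()) {
        System.out.println(label.getLabel() + " F1 = " + evaluation.f1(label));
    }
}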

Example 5 with TrainTestSplitter

Use of org.tribuo.evaluation.TrainTestSplitter in the tribuo project by Oracle.

From the class TestHdbscan, method runBasicTrainPredict:

public static void runBasicTrainPredict(HdbscanTrainer trainer) {
    DataSource<ClusterID> gaussianSource = new GaussianClusterDataSource(1000, 1L);
    TrainTestSplitter<ClusterID> splitter = new TrainTestSplitter<>(gaussianSource, 0.8f, 2L);
    Dataset<ClusterID> trainData = new MutableDataset<>(splitter.getTrain());
    Dataset<ClusterID> testData = new MutableDataset<>(splitter.getTest());
    HdbscanModel model = trainer.train(trainData);
    for (HdbscanTrainer.ClusterExemplar e : model.getClusterExemplars()) {
        assertTrue(e.getMaxDistToEdge() > 0.0);
    }
    List<Integer> clusterLabels = model.getClusterLabels();
    List<Double> outlierScores = model.getOutlierScores();
    List<Pair<Integer, List<Feature>>> exemplarLists = model.getClusters();
    List<HdbscanTrainer.ClusterExemplar> exemplars = model.getClusterExemplars();
    assertEquals(exemplars.size(), exemplarLists.size());
    // Check there's at least one exemplar per label
    Set<Integer> exemplarLabels = exemplarLists.stream().map(Pair::getA).collect(Collectors.toSet());
    Set<Integer> clusterLabelSet = new HashSet<>(clusterLabels);
    // Remove the noise label
    clusterLabelSet.remove(Integer.valueOf(0));
    assertEquals(exemplarLabels, clusterLabelSet);
    for (int i = 0; i < exemplars.size(); i++) {
        HdbscanTrainer.ClusterExemplar e = exemplars.get(i);
        Pair<Integer, List<Feature>> p = exemplarLists.get(i);
        assertEquals(model.getFeatureIDMap().size(), e.getFeatures().size());
        assertEquals(p.getB().size(), e.getFeatures().size());
        SGDVector otherFeatures = DenseVector.createDenseVector(new ArrayExample<>(trainData.getOutputFactory().getUnknownOutput(), p.getB()), model.getFeatureIDMap(), false);
        assertEquals(otherFeatures, e.getFeatures());
    }
    int[] expectedIntClusterLabels = { 4, 3, 4, 5, 3, 5, 3, 4, 3, 4, 5, 5, 3, 4, 4, 0, 3, 4, 0, 5, 5, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 4, 5, 3, 5, 3, 4, 3, 4, 4, 3, 0, 5, 0, 4, 4, 4, 4, 4, 5, 4, 3, 4, 4, 4, 4, 4, 5, 3, 4, 3, 5, 3, 4, 5, 3, 4, 0, 5, 4, 4, 4, 4, 4, 5, 4, 4, 4, 4, 4, 5, 3, 4, 4, 3, 4, 3, 5, 5, 0, 5, 4, 4, 3, 5, 5, 4, 5, 5, 3, 5, 4, 4, 3, 5, 4, 5, 5, 5, 4, 4, 5, 5, 3, 5, 4, 4, 3, 5, 5, 3, 5, 4, 4, 5, 5, 5, 3, 5, 4, 5, 3, 4, 3, 5, 4, 4, 3, 3, 5, 4, 4, 5, 5, 4, 3, 4, 5, 4, 5, 4, 3, 3, 3, 4, 5, 4, 5, 5, 3, 4, 3, 3, 4, 5, 3, 5, 5, 5, 5, 5, 4, 4, 3, 4, 5, 5, 4, 4, 3, 4, 3, 4, 5, 4, 4, 5, 4, 3, 3, 0, 3, 5, 5, 3, 3, 3, 4, 3, 3, 5, 5, 5, 5, 3, 5, 5, 3, 5, 3, 4, 5, 3, 3, 3, 4, 4, 3, 3, 3, 5, 3, 4, 5, 3, 5, 5, 5, 3, 5, 3, 5, 4, 5, 4, 4, 5, 5, 5, 3, 5, 4, 5, 5, 4, 4, 4, 5, 4, 5, 4, 3, 3, 4, 5, 4, 4, 3, 3, 3, 4, 5, 4, 4, 4, 4, 5, 4, 4, 4, 5, 3, 5, 4, 5, 3, 5, 3, 5, 4, 4, 0, 4, 4, 5, 3, 4, 5, 5, 0, 5, 4, 5, 3, 4, 3, 5, 5, 4, 5, 5, 5, 5, 5, 5, 3, 5, 4, 3, 3, 5, 3, 4, 5, 4, 3, 5, 4, 3, 3, 3, 5, 4, 5, 4, 5, 5, 4, 3, 5, 4, 5, 4, 5, 4, 3, 4, 5, 4, 4, 5, 5, 5, 3, 4, 5, 4, 0, 3, 5, 3, 4, 3, 3, 5, 5, 5, 4, 4, 3, 3, 4, 3, 5, 3, 3, 4, 3, 5, 3, 4, 5, 4, 4, 3, 4, 4, 3, 3, 5, 4, 4, 5, 3, 5, 3, 3, 4, 5, 3, 4, 5, 5, 4, 4, 4, 5, 5, 5, 5, 3, 3, 4, 4, 4, 4, 4, 3, 5, 4, 3, 4, 4, 5, 3, 5, 3, 4, 5, 4, 4, 5, 3, 4, 4, 4, 5, 5, 4, 5, 0, 4, 5, 3, 4, 5, 4, 4, 4, 5, 4, 4, 4, 0, 3, 4, 5, 5, 4, 4, 3, 3, 4, 3, 3, 4, 5, 5, 4, 3, 5, 4, 4, 4, 4, 5, 4, 4, 3, 4, 5, 5, 4, 3, 4, 5, 4, 3, 5, 5, 5, 3, 4, 4, 4, 4, 4, 4, 5, 3, 3, 3, 5, 5, 4, 5, 3, 5, 3, 5, 4, 5, 3, 4, 5, 4, 3, 5, 4, 4, 5, 5, 0, 3, 3, 5, 5, 3, 0, 5, 5, 5, 5, 3, 4, 5, 4, 3, 3, 4, 5, 4, 4, 0, 5, 3, 4, 4, 4, 4, 5, 5, 5, 3, 5, 4, 3, 3, 5, 3, 4, 3, 5, 3, 3, 4, 3, 5, 4, 3, 4, 3, 0, 4, 5, 5, 5, 3, 4, 3, 5, 5, 4, 5, 4, 4, 4, 5, 4, 3, 4, 3, 4, 5, 3, 5, 4, 5, 3, 0, 4, 0, 4, 3, 3, 4, 3, 0, 3, 3, 3, 3, 4, 4, 5, 3, 3, 5, 4, 4, 4, 5, 5, 5, 3, 3, 4, 4, 3, 4, 5, 3, 4, 4, 5, 3, 4, 4, 4, 3, 4, 4, 4, 5, 4, 4, 5, 5, 5, 4, 4, 4, 5, 5, 5, 5, 4, 3, 4, 3, 3, 3, 4, 4, 5, 4, 5, 4, 4, 4, 4, 4, 5, 4, 5, 5, 5, 4, 3, 5, 3, 5, 4, 5, 4, 4, 5, 0, 5, 3, 4, 5, 4, 4, 5, 3, 4, 4, 3, 5, 4, 4, 4, 5, 3, 3, 4, 4, 5, 5, 5, 3, 4, 3, 4, 5, 5, 4, 4, 3, 3, 4, 4, 5, 5, 5, 3, 4, 3, 4, 4, 4, 5, 5, 5, 0, 4, 5, 5, 3, 3, 4, 5, 4, 3, 3, 4, 3, 4, 5, 4, 3, 4, 5, 5, 3, 3, 4, 4, 3, 3, 5, 4, 5, 3, 4, 5, 4, 3, 3, 4, 5, 5, 5, 3, 3, 4, 4, 5, 5, 5, 4, 5, 5, 5, 4, 4, 4, 5, 4, 5, 5, 3, 3, 4, 4, 3, 5, 5, 3, 3, 4, 4, 5, 3, 3, 3 };
    List<Integer> expectedClusterLabels = Arrays.stream(expectedIntClusterLabels).boxed().collect(Collectors.toList());
    double[] expectedDoubleOutlierScores = { 0.46676776260759345, 0.2743698754772864, 0.7559982720268424, 0.8501840034553623, 0.49318092730464635, 0.13138938738160744, 0.4713199767058086, 0.6252876350317327, 0.5993132028604171, 0.5099794170903283, 0.34739656697344323, 0.7877610766946352, 0.6725050057122981, 0.0, 0.3443411864540462, 0.942517632028674, 0.49375727602750374, 0.0, 0.8895331356424256, 0.6324670047095703, 0.42882347542687815, 0.49318092730464635, 0.691903096844513, 0.6380593801053474, 0.2406826282408977, 0.0, 0.6968734399959293, 0.3610993140443196, 0.5535004360403812, 0.6096176323143576, 0.0, 0.0, 0.3913407463849664, 0.9519927727728552, 0.0, 0.5393032152890598, 0.6503011262262826, 0.1433842216333847, 0.49506479112319557, 0.5709634323345956, 0.4563315958116082, 0.44618653226418115, 0.44814977073944906, 0.0, 0.9074703755075781, 0.7291450269088865, 0.9484293814844095, 0.0, 0.3705649211930552, 0.3480591862782948, 0.0, 0.7517459138118392, 0.5690934599956823, 0.6502288567686347, 0.6206513888636165, 0.6017282507095788, 0.5733419619457072, 0.7117066450461398, 0.7782759723827917, 0.0, 0.3564466150534611, 0.5610358783143924, 0.3777803191375566, 0.6968916961864624, 0.0, 0.6017974254583286, 0.46129103283467177, 0.41698356410558357, 0.8424752330761394, 0.5904539743502417, 0.8830377178678264, 0.8611226634391924, 0.20674767396012295, 0.4730307972339155, 0.0, 0.5304342181470512, 0.43017634165005014, 0.4343582676741472, 0.0, 0.47654625883887125, 0.0, 0.25196441733320185, 0.2957896263676023, 0.3640784213318997, 0.0, 0.2569367496898257, 0.40541971866030124, 0.6111594683636595, 0.5793323777062229, 0.8402720631189264, 0.5844081900168824, 0.4972444475275547, 0.8821097291727182, 0.3736050635122288, 0.3443411864540462, 0.0, 0.8810705083902441, 0.0, 0.0, 0.0, 0.455748881321304, 0.27939799074230476, 0.8371148472011978, 0.7906776310593313, 0.16481581471815432, 0.0, 0.8473854479601578, 0.06314528545928999, 0.0, 0.10155008808627008, 0.3964971775769046, 0.0, 0.3443411864540462, 0.6098919291901133, 0.6379495148152017, 0.5841073563731103, 0.5229236206301431, 0.6501619239874012, 0.3443411864540462, 0.5953992476545811, 0.5072740714951296, 0.4605406727289195, 0.0, 0.6544956955298067, 0.6501619239874012, 0.5609881604044372, 0.8026896506782976, 0.3964971775769046, 0.08357327465242192, 0.7193239335926954, 0.5595574991480672, 0.0, 0.1995613282017613, 0.0, 0.8597242790184043, 0.6979503825421514, 0.7104042746619621, 0.5548121181682619, 0.2935311018632297, 0.4636958976102995, 0.7083109193732234, 0.471177917299232, 0.0, 0.7939891118341219, 0.24683871216813147, 0.3464327881240151, 0.7460168123057431, 0.7134267350609619, 0.5610358783143924, 0.07838284850471278, 0.7396457383752255, 0.7617607303705443, 0.47464886817517193, 0.3163448961273, 0.7773362486663004, 0.8263291289832474, 0.5238013994386587, 0.5675502692728531, 0.0, 0.46129103283467177, 0.0, 0.43476375433827963, 0.0, 0.0, 0.8277619726852297, 0.8435902856429793, 0.0, 0.4327059463200057, 0.8641362465433343, 0.27939799074230476, 0.18594335385103455, 0.5298320190732928, 0.14012035981759297, 0.27066293065269187, 0.6360710982377022, 0.45019911030204196, 0.21506286554014598, 0.3443411864540462, 0.8282719359151842, 0.09812661154489999, 0.5733419619457072, 0.3756037273200026, 0.551629899884865, 0.0, 0.6544956955298067, 0.5696608258449514, 0.6907048630196866, 0.45261858796940857, 0.2284755074997018, 0.0, 0.0, 0.6725050057122981, 0.5653348091708341, 0.8901221625800861, 0.5072740714951296, 0.2765984401011835, 0.08357327465242192, 0.6470473237995003, 0.7959804069715428, 
0.5909403594346183, 0.514153415940259, 0.0, 0.9292108460851046, 0.5819766000229493, 0.0, 0.579942440138181, 0.8774911291582113, 0.05399443648683844, 0.6216675811633252, 0.41698356410558357, 0.5404377149880093, 0.14012035981759297, 0.2743698754772864, 0.3486561149101042, 0.49766275174228025, 0.7403568247909462, 0.6725050057122981, 0.7366428541023031, 0.31848573305331973, 0.0, 0.6351806919052676, 0.6379936443946175, 0.5643094888768054, 0.87242048722334, 0.45114165259596173, 0.4267483686744079, 0.4459420844294427, 0.4979058181573891, 0.6687168793513216, 0.6725039248778191, 0.5898326395194236, 0.3786565035491881, 0.49270476278526754, 0.6845482002842296, 0.5312504911095062, 0.0, 0.14012035981759297, 0.0, 0.13260183495549782, 0.0, 0.1114322036219213, 0.0, 0.22354672645501728, 0.4437011486963448, 0.32454746691008174, 0.2129513856466727, 0.7322362660260833, 0.6428127919521037, 0.7299583115950901, 0.3805190861517057, 0.5942874634038553, 0.32454746691008174, 0.48232688782776534, 0.0, 0.5079723021903902, 0.426792378877326, 0.0, 0.0, 0.0, 0.46129103283467177, 0.42701054653212456, 0.05399443648683844, 0.32963857165340793, 0.4011402980489066, 0.49884143112507695, 0.6742625110529403, 0.7000703820982973, 0.0059186060290421505, 0.3935328253717545, 0.4982012977088105, 0.46129103283467177, 0.1884812300740285, 0.0, 0.0, 0.0, 0.7193239335926954, 0.20322704801088176, 0.6147177259392704, 0.6386340973679066, 0.4605406727289195, 0.35336940249769566, 0.7532381385481872, 0.0, 0.0, 0.9119502216589748, 0.2739356391172154, 0.10511679922476258, 0.5928945724077263, 0.6483155244223532, 0.32972277729175326, 0.7437779508567764, 0.8667982942943174, 0.9016466990936328, 0.0, 0.20095036828617263, 0.09796389879511136, 0.6237249089180936, 0.46217811332893666, 0.4874368146961645, 0.8015888606177888, 0.2129513856466727, 0.25867271023886684, 0.8627287953882614, 0.6387184765952185, 0.29477547307621677, 0.44990336909740747, 0.27939799074230476, 0.0, 0.719192575186314, 0.0752302226040954, 0.0, 0.7480360862618636, 0.8490766015294575, 0.3804955800190625, 0.6725050057122981, 0.41293235299928155, 0.7906776310593313, 0.5576591657399872, 0.5993132028604171, 0.11368083044659938, 0.0, 0.2959490414383299, 0.5072740714951296, 0.3872053944517456, 0.3990149169995906, 0.0, 0.31715178735852456, 0.7169533729847726, 0.5178185436513978, 0.7049185121521424, 0.6968734399959293, 0.0, 0.02065034905954999, 0.1400214190055723, 0.5507059666282947, 0.3163448961273, 0.4481475262869722, 0.18097453586463574, 0.811581582918492, 0.563763712019848, 0.08357327465242192, 0.053392009649182115, 0.7567212136194974, 0.8504256359936448, 0.8054570184077979, 0.8501840034553623, 0.5006038683696926, 0.4723191896910045, 0.39335681832489, 0.7128589406590711, 0.9126499391266593, 0.5993132028604171, 0.7707901441751914, 0.5051009955803989, 0.6729664427879043, 0.3872053944517456, 0.6256349097006557, 0.5146933549477546, 0.5596642473643627, 0.11368083044659938, 0.6248666542874606, 0.7404248189769456, 0.22821556208077354, 0.5893320991429196, 0.3935328253717545, 0.22354672645501728, 0.7087678762850806, 0.7944281661194821, 0.5404377149880093, 0.8020958632147266, 0.46174439351502505, 0.09812661154489999, 0.7492684267597052, 0.6309454765188025, 0.576131133494547, 0.11028289394958157, 0.5655946122939484, 0.05399443648683844, 0.43650020150407365, 0.08212338693601973, 0.4901722099881468, 0.0, 0.0, 0.0, 0.4148561856261437, 0.6040842867023796, 0.8217703793127248, 0.6367390214802099, 0.5239122935203155, 0.31517820690329956, 0.00509618317782945, 0.6253843090069545, 0.5751062454265111, 
0.5749814506098023, 0.6799761261709416, 0.5438395449287226, 0.31848573305331973, 0.0, 0.35152584520905694, 0.5004067400291996, 0.7291450269088865, 0.39649970945668, 0.11161437102919114, 0.0, 0.46368151674524083, 0.7904624552413114, 0.3443411864540462, 0.31848573305331973, 0.37130998229165235, 0.5733419619457072, 0.6544956955298067, 0.6316843154288729, 0.5617117678592323, 0.8786618155909707, 0.575985154427317, 0.32454746691008174, 0.4605406727289195, 0.4943123620097287, 0.8719502427750575, 0.0, 0.6210977617501681, 0.5526961492564358, 0.0, 0.3756037273200026, 0.8075793400487667, 0.6725050057122981, 0.39893901604072746, 0.3410240139382218, 0.2935311018632297, 0.635440473010447, 0.29477547307621677, 0.0, 0.02065034905954999, 0.9392912491571164, 0.42249911335168855, 0.14097336145379602, 0.6847290646834571, 0.5194865825883319, 0.14012035981759297, 0.7832174963908662, 0.4365915265432285, 0.3819107449024405, 0.6726949923291138, 0.4011402980489066, 0.3935328253717545, 0.25196441733320185, 0.881350459326978, 0.7349084526542964, 0.6047151306067444, 0.0, 0.20977384411280398, 0.4679944192020826, 0.6698566773200858, 0.0, 0.2959490414383299, 0.7932391110064335, 0.622319399561805, 0.8015596575831598, 0.39998141687078625, 0.4574247863501638, 0.17981462714349217, 0.3717035453553832, 0.6667041054756846, 0.553583549266421, 0.3443411864540462, 0.5433733027536815, 0.5148600399821457, 0.6968734399959293, 0.580828106385362, 0.39422048008001054, 0.5733419619457072, 0.20675485588245224, 0.35152584520905694, 0.5031988755530088, 0.5467067205040299, 0.0, 0.4414501866925844, 0.0, 0.49147375734502297, 0.7537969758806241, 0.05399443648683844, 0.0, 0.8130406304754614, 0.3893369731046228, 0.08857940056865066, 0.12319636014473168, 0.47827834758956966, 0.46196635668966657, 0.36779495204329093, 0.3139968403491683, 0.455185529426422, 0.0, 0.0, 0.23496850036493677, 0.31517820690329956, 0.4727343218057717, 0.0, 0.6980083836215727, 0.7249879488732429, 0.5404377149880093, 0.36704925127119226, 0.0, 0.0, 0.7939891118341219, 0.858959809289114, 0.6243592907663799, 0.0, 0.7193239335926954, 0.4487733002302773, 0.7295180327424304, 0.29131593514163934, 0.13649363353017852, 0.3562234122805952, 0.8627287953882614, 0.0, 0.8959448055905569, 0.697568664523059, 0.6847290646834571, 0.3811096272517188, 0.3921958171277884, 0.0, 0.9138808799654548, 0.8373253417127464, 0.03806887171267048, 0.0, 0.32210719489953354, 0.8993257360128568, 0.5733419619457072, 0.5298320190732928, 0.5406291933185892, 0.4936169024186612, 0.0, 0.5676673482591083, 0.797644954156777, 0.25196441733320185, 0.7517459138118392, 0.8804787122327714, 0.4437951161132294, 0.517719961495386, 0.6736118674798475, 0.0, 0.3869796117891282, 0.0, 0.0, 0.11368083044659938, 0.49476284568662654, 0.0, 0.3990149169995906, 0.5497567156841348, 0.49755861260331113, 0.5431034895139442, 0.4849594616738354, 0.634359329239355, 0.6102777431117171, 0.5993132028604171, 0.029110696378704337, 0.0041352367783195065, 0.5455357591968216, 0.614637828267818, 0.3872053944517456, 0.8774911291582113, 0.0, 0.515011903882065, 0.5653271117366987, 0.6725050057122981, 0.8828338637775253, 0.0, 0.025305516244955806, 0.1153305490422244, 0.1189992879103362, 0.0, 0.4623766305089029, 0.4972070270025144, 0.0, 0.19162317880112645, 0.0, 0.0, 0.46322024301962184, 0.3564486209165576, 0.13221917122566706, 0.8074269371585792, 0.8086158459470157, 0.45215040851165234, 0.7250991976553953, 0.7104042746619621, 0.0, 0.7564957221431223, 0.6678218872757002, 0.49889559243077775, 0.6362040007967809, 0.8504597507708, 0.45590549902450317, 
0.9009290923658844, 0.2935311018632297, 0.8837545475640602, 0.5878793983623484, 0.0, 0.7173151984840158, 0.6299022163181591, 0.0, 0.9104797902272275, 0.5661891151536319, 0.6891663882134245, 0.5072740714951296, 0.594386514812078, 0.5731856857918096, 0.0, 0.19553532551745423, 0.6285842416556849, 0.5404377149880093, 0.029110696378704337, 0.47054972913556337, 0.5475804999462761, 0.3805190861517057, 0.7966719729940861, 0.7555806975782535, 0.010061020299601542, 0.6333336380718192, 0.8917306933911169, 0.2935311018632297, 0.07406510775737263, 0.45215040851165234, 0.18097453586463574, 0.0, 0.08132701192855352, 0.4366734759781771, 0.3564486209165576, 0.7310554814333057, 0.9204338265893961, 0.08471089907230389, 0.5696608258449514, 0.6922091225884852, 0.6279823624589616, 0.0, 0.3935328253717545, 0.0, 0.7413065729103836, 0.680500524047025, 0.6769166895418774, 0.5002418053594004, 0.0, 0.39649970945668, 0.0, 0.25196441733320185, 0.22794789542087268, 0.535253084286895, 0.797644954156777, 0.6736140765300983, 0.6390655334750306, 0.4613651254516825, 0.7655684232457723, 0.0, 0.3872053944517456, 0.6047587670112572, 0.7216841523191753, 0.6581024749535422, 0.2957896263676023, 0.0, 0.0, 0.1358560783357583, 0.3443411864540462, 0.32972277729175326, 0.7412343863236006, 0.7137328127320526, 0.4807164899348386, 0.5816339146436611, 0.0, 0.39649970945668, 0.6554079940503144, 0.32081918857815017, 0.6250251993493536, 0.7099539309159079, 0.8460951407432118, 0.7210253357230385, 0.07792093768838082, 0.0, 0.8460951407432118, 0.617166909604469, 0.0, 0.17820178754345894, 0.9208897313509731, 0.1494984069468719, 0.6800023209762943, 0.0, 0.4314997931110507, 0.3818456558681199, 0.3559937076353532, 0.3573443810325546, 0.0, 0.0, 0.4108575710450929, 0.5404377149880093, 0.3107705511500243, 0.34467641363479895, 0.6588136178954032, 0.0, 0.7124877485819077, 0.4271754772995946, 0.6352932475152177, 0.6358395247009956, 0.6157498971205291, 0.2129513856466727, 0.3300832680463701, 0.3922526981839124, 0.0, 0.0, 0.5993702342757096, 0.6044080434255494, 0.0, 0.4254461589520476, 0.505472792649011, 0.5052779575119282, 0.0, 0.7683354875196753, 0.41622494718112346, 0.28903657152106876, 0.34739656697344323, 0.0, 0.22045718710077167, 0.850119563971185, 0.53811553498556, 0.0, 0.5899675992703288, 0.3935328253717545, 0.5352299710205285, 0.5928945724077263, 0.4252236114436363, 0.7329731025102499, 0.9228363488631618, 0.10329887928234749, 0.7193239335926954, 0.025305516244955806, 0.6827262685358917, 0.5204591459253809, 0.31848573305331973, 0.8307996083812397, 0.08118104079195632, 0.5490352427566516, 0.5779985937481655, 0.7740622259507297, 0.7718601671125394, 0.4945464442491214, 0.818070291907619, 0.2935311018632297, 0.5633646723184442, 0.0, 0.025305516244955806, 0.803889349277999, 0.04556290658626683, 0.6682951293623198, 0.0, 0.16135790337112854, 0.2656863426127436, 0.5856638908845988, 0.6415288134310423, 0.0, 0.5826267384294974, 0.8404883382141526, 0.7401049925075427, 0.8552831593916908, 0.5688163749124342, 0.0, 0.0, 0.4108575710450929, 0.0, 0.7858019964605708, 0.23346803074513822, 0.670754636973998, 0.07860444563433111, 0.0, 0.7286247704137883, 0.4334354337345858, 0.49880706067428204, 0.0, 0.5696608258449514, 0.29875909495688924, 0.5298320190732928, 0.18191542983970188, 0.3769853117713595, 0.43820864595574294, 0.7169533729847726, 0.7224900394386976, 0.41293235299928155, 0.5819766000229493, 0.4295586290786334, 0.936944591291591, 0.5404377149880093, 0.3935328253717545, 0.31848573305331973, 0.0, 0.7036693377297398, 0.02065034905954999, 0.4224925490810527, 
0.11216184489956871, 0.42215466113858024, 0.0, 0.39649970945668, 0.4396553712320629, 0.8400030552287128, 0.6735267975566936 };
    List<Double> expectedOutlierScores = Arrays.stream(expectedDoubleOutlierScores).boxed().collect(Collectors.toList());
    assertEquals(expectedClusterLabels, clusterLabels);
    assertEquals(expectedOutlierScores, outlierScores);
    List<Prediction<ClusterID>> predictions = model.predict(testData);
    int i = 0;
    int[] actualLabelPredictions = new int[testData.size()];
    double[] actualOutlierScorePredictions = new double[testData.size()];
    for (Prediction<ClusterID> pred : predictions) {
        actualLabelPredictions[i] = pred.getOutput().getID();
        actualOutlierScorePredictions[i] = pred.getOutput().getScore();
        i++;
    }
    int[] expectedLabelPredictions = { 4, 5, 3, 5, 5, 3, 5, 4, 5, 3, 5, 5, 4, 4, 4, 5, 3, 4, 4, 3, 3, 5, 4, 5, 4, 5, 3, 3, 4, 5, 4, 4, 5, 3, 4, 5, 4, 4, 5, 5, 3, 5, 5, 5, 4, 5, 3, 4, 4, 5, 5, 5, 3, 3, 5, 4, 3, 5, 5, 5, 4, 5, 4, 5, 3, 4, 4, 3, 3, 3, 4, 5, 5, 5, 5, 3, 4, 5, 3, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 3, 4, 3, 3, 5, 3, 5, 5, 4, 4, 4, 4, 3, 3, 4, 4, 4, 4, 4, 3, 3, 5, 5, 5, 4, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 5, 5, 4, 5, 4, 4, 3, 5, 4, 5, 5, 3, 4, 5, 5, 4, 5, 4, 4, 3, 3, 5, 5, 5, 5, 4, 4, 5, 4, 4, 4, 4, 5, 3, 5, 5, 5, 3, 4, 5, 4, 4, 4, 4, 4, 4, 5, 5, 3, 4, 3, 5, 5, 5, 4, 5, 5, 5, 4, 4, 3, 4, 5, 4, 4, 5, 5, 5, 5, 4, 5, 5, 5, 5 };
    double[] expectedOutlierScorePredictions = { 0.08118104079195632, 0.0, 0.04556290658626683, 0.010061020299601542, 0.029110696378704337, 0.08132701192855352, 0.010061020299601542, 0.07838284850471278, 0.025305516244955806, 0.08132701192855352, 0.03806887171267048, 0.02065034905954999, 0.08118104079195632, 0.0059186060290421505, 0.08118104079195632, 0.0, 0.08132701192855352, 0.0059186060290421505, 0.08118104079195632, 0.0041352367783195065, 0.08132701192855352, 0.02065034905954999, 0.08471089907230389, 0.025305516244955806, 0.0, 0.0, 0.0, 0.0, 0.08118104079195632, 0.06314528545928999, 0.07838284850471278, 0.0059186060290421505, 0.010061020299601542, 0.0041352367783195065, 0.0059186060290421505, 0.010061020299601542, 0.08118104079195632, 0.08118104079195632, 0.029110696378704337, 0.010061020299601542, 0.08132701192855352, 0.03806887171267048, 0.03806887171267048, 0.029110696378704337, 0.07406510775737263, 0.02065034905954999, 0.04556290658626683, 0.07838284850471278, 0.08118104079195632, 0.0, 0.0, 0.02065034905954999, 0.0041352367783195065, 0.0, 0.03806887171267048, 0.08118104079195632, 0.0041352367783195065, 0.03806887171267048, 0.02065034905954999, 0.02065034905954999, 0.00509618317782945, 0.06314528545928999, 0.0059186060290421505, 0.029110696378704337, 0.07860444563433111, 0.08118104079195632, 0.08471089907230389, 0.0041352367783195065, 0.0, 0.07860444563433111, 0.07838284850471278, 0.029110696378704337, 0.010061020299601542, 0.0, 0.06314528545928999, 0.0041352367783195065, 0.08118104079195632, 0.029110696378704337, 0.0, 0.02065034905954999, 0.06314528545928999, 0.0, 0.08471089907230389, 0.00509618317782945, 0.08212338693601973, 0.08471089907230389, 0.00509618317782945, 0.08118104079195632, 0.08118104079195632, 0.0, 0.025305516244955806, 0.0, 0.08118104079195632, 0.0041352367783195065, 0.0, 0.06314528545928999, 0.08132701192855352, 0.025305516244955806, 0.029110696378704337, 0.07838284850471278, 0.08118104079195632, 0.0, 0.053392009649182115, 0.05399443648683844, 0.07860444563433111, 0.08118104079195632, 0.08471089907230389, 0.08212338693601973, 0.08118104079195632, 0.0, 0.04556290658626683, 0.08132701192855352, 0.03806887171267048, 0.0, 0.0, 0.0, 0.08118104079195632, 0.029110696378704337, 0.0059186060290421505, 0.02065034905954999, 0.08118104079195632, 0.02065034905954999, 0.053392009649182115, 0.010061020299601542, 0.07838284850471278, 0.02065034905954999, 0.010061020299601542, 0.0, 0.08118104079195632, 0.029110696378704337, 0.0059186060290421505, 0.0, 0.08132701192855352, 0.025305516244955806, 0.0059186060290421505, 0.029110696378704337, 0.029110696378704337, 0.0, 0.00509618317782945, 0.025305516244955806, 0.010061020299601542, 0.08118104079195632, 0.029110696378704337, 0.08118104079195632, 0.0059186060290421505, 0.08132701192855352, 0.0, 0.06314528545928999, 0.029110696378704337, 0.029110696378704337, 0.029110696378704337, 0.0059186060290421505, 0.08212338693601973, 0.02065034905954999, 0.08118104079195632, 0.0059186060290421505, 0.0059186060290421505, 0.00509618317782945, 0.0, 0.08132701192855352, 0.025305516244955806, 0.025305516244955806, 0.02065034905954999, 0.0, 0.08118104079195632, 0.03806887171267048, 0.08471089907230389, 0.07838284850471278, 0.07838284850471278, 0.08118104079195632, 0.07838284850471278, 0.07838284850471278, 0.02065034905954999, 0.029110696378704337, 0.07860444563433111, 0.07838284850471278, 0.08132701192855352, 0.03806887171267048, 0.025305516244955806, 0.010061020299601542, 0.08471089907230389, 0.029110696378704337, 0.025305516244955806, 
0.03806887171267048, 0.08471089907230389, 0.0, 0.07860444563433111, 0.08212338693601973, 0.029110696378704337, 0.08118104079195632, 0.08118104079195632, 0.06314528545928999, 0.010061020299601542, 0.010061020299601542, 0.029110696378704337, 0.07838284850471278, 0.025305516244955806, 0.010061020299601542, 0.010061020299601542, 0.06314528545928999 };
    assertArrayEquals(expectedLabelPredictions, actualLabelPredictions);
    assertArrayEquals(expectedOutlierScorePredictions, actualOutlierScorePredictions);
}
Also used: ClusterID (org.tribuo.clustering.ClusterID), GaussianClusterDataSource (org.tribuo.clustering.example.GaussianClusterDataSource), Feature (org.tribuo.Feature), List (java.util.List), SGDVector (org.tribuo.math.la.SGDVector), MutableDataset (org.tribuo.MutableDataset), Pair (com.oracle.labs.mlrg.olcut.util.Pair), HashSet (java.util.HashSet), TrainTestSplitter (org.tribuo.evaluation.TrainTestSplitter), Prediction (org.tribuo.Prediction)
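
The prediction loop above is also a natural place to separate noise points, since the test treats cluster id 0 as the noise label (it is removed from clusterLabelSet before the exemplar check). A hedged follow-up sketch, reusing only variables and calls from the test body, that counts noise assignments in the test split:

    // Hypothetical follow-up: count how many test points were assigned to the noise cluster (id 0).
    int noiseCount = 0;
    for (Prediction<ClusterID> pred : predictions) {
        if (pred.getOutput().getID() == 0) {
            noiseCount++;
        }
    }
    System.out.println(noiseCount + " of " + testData.size() + " test points were marked as noise");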

Aggregations

MutableDataset (org.tribuo.MutableDataset): 5 usages
TrainTestSplitter (org.tribuo.evaluation.TrainTestSplitter): 5 usages
Prediction (org.tribuo.Prediction): 2 usages
Label (org.tribuo.classification.Label): 2 usages
LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation): 2 usages
NoisyInterlockingCrescentsDataSource (org.tribuo.classification.example.NoisyInterlockingCrescentsDataSource): 2 usages
ClusterID (org.tribuo.clustering.ClusterID): 2 usages
GaussianClusterDataSource (org.tribuo.clustering.example.GaussianClusterDataSource): 2 usages
Pair (com.oracle.labs.mlrg.olcut.util.Pair): 1 usage
URL (java.net.URL): 1 usage
ArrayList (java.util.ArrayList): 1 usage
HashSet (java.util.HashSet): 1 usage
List (java.util.List): 1 usage
Label (javafx.scene.control.Label): 1 usage
Test (org.junit.jupiter.api.Test): 1 usage
Feature (org.tribuo.Feature): 1 usage
LabelFactory (org.tribuo.classification.LabelFactory): 1 usage
CARTClassificationTrainer (org.tribuo.classification.dtree.CARTClassificationTrainer): 1 usage
LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator): 1 usage
ClusteringEvaluation (org.tribuo.clustering.evaluation.ClusteringEvaluation): 1 usage