Search in sources :

Example 56 with AttributeDataset

use of smile.data.AttributeDataset in project smile by haifengl.

the class GradientTreeBoostTest method testUSPS.

/**
 * Trains a GradientTreeBoost classifier on the USPS training file and prints
 * its error rate on the USPS test file, the accuracy values returned by
 * {@code boost.test}, and the variable importance of every attribute.
 *
 * <p>The test makes no numeric assertion; it only reports results. Any
 * exception raised while loading the data or training the model fails the
 * test instead of being printed and ignored.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    // The class label is the first column of each record.
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);
        // Last argument is presumably the number of trees -- confirm against the constructor.
        GradientTreeBoost boost = new GradientTreeBoost(train.attributes(), x, y, 100);
        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (boost.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.format("Gradient Tree Boost error rate = %.2f%%%n", 100.0 * error / testx.length);
        double[] accuracy = boost.test(testx, testy);
        for (int i = 1; i <= accuracy.length; i++) {
            System.out.format("%d trees accuracy = %.2f%%%n", i, 100.0 * accuracy[i - 1]);
        }
        // Sort a copy: if importance() hands out the model's own array, sorting
        // it in place would silently reorder the model's state.
        double[] importance = boost.importance().clone();
        // Relies on QuickSort.sort ordering the array in place (ascending) and
        // returning the original position of each element.
        int[] index = QuickSort.sort(importance);
        // Walk backwards so the most important attribute is printed first.
        for (int i = importance.length; i-- > 0; ) {
            System.out.format("%s importance is %.4f%n", train.attributes()[index[i]], importance[i]);
        }
    } catch (Exception ex) {
        // Fail loudly: a test that could not run must not be reported as passed.
        throw new AssertionError("USPS gradient tree boost test could not complete", ex);
    }
}
Also used : DelimitedTextParser(smile.data.parser.DelimitedTextParser) AttributeDataset(smile.data.AttributeDataset) NominalAttribute(smile.data.NominalAttribute) Test(org.junit.Test)

Example 57 with AttributeDataset

use of smile.data.AttributeDataset in project smile by haifengl.

the class GradientTreeBoostTest method testUSPS2.

/**
 * Trains a GradientTreeBoost classifier on a two-class version of the USPS
 * data (label 0 versus every other label) and prints its test error rate,
 * the accuracy values returned by {@code boost.test}, and the variable
 * importance of every attribute.
 *
 * <p>The test makes no numeric assertion; it only reports results. Any
 * exception raised while loading the data or training the model fails the
 * test instead of being printed and ignored.
 */
@Test
public void testUSPS2() {
    System.out.println("USPS 2 classes");
    DelimitedTextParser parser = new DelimitedTextParser();
    // The class label is the first column of each record.
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);
        // Collapse the labels to a binary problem: 0 stays 0, everything else becomes 1.
        for (int i = 0; i < y.length; i++) {
            if (y[i] != 0) {
                y[i] = 1;
            }
        }
        for (int i = 0; i < testy.length; i++) {
            if (testy[i] != 0) {
                testy[i] = 1;
            }
        }
        // Last argument is presumably the number of trees -- confirm against the constructor.
        GradientTreeBoost boost = new GradientTreeBoost(train.attributes(), x, y, 100);
        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (boost.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.format("Gradient Tree Boost error rate = %.2f%%%n", 100.0 * error / testx.length);
        double[] accuracy = boost.test(testx, testy);
        for (int i = 1; i <= accuracy.length; i++) {
            System.out.format("%d trees accuracy = %.2f%%%n", i, 100.0 * accuracy[i - 1]);
        }
        // Sort a copy: if importance() hands out the model's own array, sorting
        // it in place would silently reorder the model's state.
        double[] importance = boost.importance().clone();
        // Relies on QuickSort.sort ordering the array in place (ascending) and
        // returning the original position of each element.
        int[] index = QuickSort.sort(importance);
        // Walk backwards so the most important attribute is printed first.
        for (int i = importance.length; i-- > 0; ) {
            System.out.format("%s importance is %.4f%n", train.attributes()[index[i]], importance[i]);
        }
    } catch (Exception ex) {
        // Fail loudly: a test that could not run must not be reported as passed.
        throw new AssertionError("USPS two-class gradient tree boost test could not complete", ex);
    }
}
Also used : DelimitedTextParser(smile.data.parser.DelimitedTextParser) AttributeDataset(smile.data.AttributeDataset) NominalAttribute(smile.data.NominalAttribute) Test(org.junit.Test)

Example 58 with AttributeDataset

use of smile.data.AttributeDataset in project smile by haifengl.

the class GradientTreeBoostTest method testSegment.

/**
 * Trains a GradientTreeBoost classifier on the Weka segment-challenge data
 * and prints its error rate on the segment-test data.
 *
 * <p>Any exception raised while loading the data or training the model fails
 * the test instead of being printed and ignored.
 */
@Test
public void testSegment() {
    System.out.println("Segment");
    ArffParser arffParser = new ArffParser();
    // Column 19 of the ARFF files holds the response variable.
    arffParser.setResponseIndex(19);
    try {
        AttributeDataset train = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/segment-challenge.arff"));
        AttributeDataset test = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/segment-test.arff"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);
        // Last argument is presumably the number of trees -- confirm against the constructor.
        GradientTreeBoost boost = new GradientTreeBoost(train.attributes(), x, y, 100);
        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (boost.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.format("Gradient Tree Boost error rate = %.2f%%%n", 100.0 * error / testx.length);
        // NOTE(review): the exact-count check below is disabled, so this test
        // currently enforces no accuracy threshold -- confirm whether that is intended.
        //assertEquals(28, error);
    } catch (Exception ex) {
        // Fail loudly: a test that could not run must not be reported as passed.
        throw new AssertionError("Segment gradient tree boost test could not complete", ex);
    }
}
Also used : ArffParser(smile.data.parser.ArffParser) AttributeDataset(smile.data.AttributeDataset) Test(org.junit.Test)

Example 59 with AttributeDataset

use of smile.data.AttributeDataset in project smile by haifengl.

the class KNNTest method testSegment.

/**
 * Trains a KNN classifier with {@code KNN.learn(x, y)} on the Weka
 * segment-challenge data and asserts that it misclassifies exactly 39
 * samples of the segment-test data.
 *
 * <p>An {@link IOException} while reading the data files fails the test
 * instead of being printed and ignored, so the assertion can no longer be
 * skipped silently.
 *
 * @throws ParseException if an ARFF file cannot be parsed
 */
@Test
public void testSegment() throws ParseException {
    System.out.println("Segment");
    ArffParser parser = new ArffParser();
    // Column 19 of the ARFF files holds the response variable.
    parser.setResponseIndex(19);
    try {
        AttributeDataset train = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/segment-challenge.arff"));
        AttributeDataset test = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/segment-test.arff"));
        double[][] x = train.toArray(new double[0][]);
        int[] y = train.toArray(new int[0]);
        double[][] testx = test.toArray(new double[0][]);
        int[] testy = test.toArray(new int[0]);
        KNN<double[]> knn = KNN.learn(x, y);
        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (knn.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.format("Segment error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertEquals(39, error);
    } catch (IOException ex) {
        // Fail loudly: a test that could not read its data must not be reported as passed.
        throw new AssertionError("Segment KNN test could not read its data files", ex);
    }
}
Also used : ArffParser(smile.data.parser.ArffParser) AttributeDataset(smile.data.AttributeDataset) IOException(java.io.IOException) Test(org.junit.Test)

Example 60 with AttributeDataset

use of smile.data.AttributeDataset in project smile by haifengl.

the class CLARANSTest method testUSPS.

/**
 * Clusters the USPS training data with CLARANS and checks the agreement
 * between the cluster labels and the true class labels, first on the
 * training set and then on predictions for the test set.
 *
 * <p>Agreement is measured with the Rand index (required to exceed 0.8 on
 * both sets) and the adjusted Rand index (required to exceed 0.28 on the
 * training set and 0.25 on the test set). Any exception raised while loading
 * the data or clustering fails the test instead of being printed and
 * ignored, so the assertions can no longer be skipped silently.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    // The class label is the first column of each record.
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);
        AdjustedRandIndex ari = new AdjustedRandIndex();
        RandIndex rand = new RandIndex();
        // NOTE(review): 10 is presumably the number of clusters (one per digit);
        // the meaning of 50 and 8 should be confirmed against the CLARANS constructor.
        CLARANS<double[]> clarans = new CLARANS<>(x, new EuclideanDistance(), 10, 50, 8);
        double r = rand.measure(y, clarans.getClusterLabel());
        double r2 = ari.measure(y, clarans.getClusterLabel());
        System.out.format("Training rand index = %.2f%%\tadjusted rand index = %.2f%%%n", 100.0 * r, 100.0 * r2);
        assertTrue(r > 0.8);
        assertTrue(r2 > 0.28);
        // Assign every test sample to a cluster and score those assignments too.
        int[] p = new int[testx.length];
        for (int i = 0; i < testx.length; i++) {
            p[i] = clarans.predict(testx[i]);
        }
        r = rand.measure(testy, p);
        r2 = ari.measure(testy, p);
        System.out.format("Testing rand index = %.2f%%\tadjusted rand index = %.2f%%%n", 100.0 * r, 100.0 * r2);
        assertTrue(r > 0.8);
        assertTrue(r2 > 0.25);
    } catch (Exception ex) {
        // Fail loudly: a test that could not run must not be reported as passed.
        throw new AssertionError("USPS CLARANS test could not complete", ex);
    }
}
Also used : DelimitedTextParser(smile.data.parser.DelimitedTextParser) AttributeDataset(smile.data.AttributeDataset) RandIndex(smile.validation.RandIndex) AdjustedRandIndex(smile.validation.AdjustedRandIndex) EuclideanDistance(smile.math.distance.EuclideanDistance) NominalAttribute(smile.data.NominalAttribute) Test(org.junit.Test)

Aggregations

AttributeDataset (smile.data.AttributeDataset) 140, Test (org.junit.Test) 125, ArffParser (smile.data.parser.ArffParser) 75, NominalAttribute (smile.data.NominalAttribute) 50, DelimitedTextParser (smile.data.parser.DelimitedTextParser) 48, Attribute (smile.data.Attribute) 29, EuclideanDistance (smile.math.distance.EuclideanDistance) 19, LOOCV (smile.validation.LOOCV) 18, CrossValidation (smile.validation.CrossValidation) 17, AdjustedRandIndex (smile.validation.AdjustedRandIndex) 14, RandIndex (smile.validation.RandIndex) 14, ClassifierTrainer (smile.classification.ClassifierTrainer) 13, GaussianKernel (smile.math.kernel.GaussianKernel) 11, IOException (java.io.IOException) 10, RadialBasisFunction (smile.math.rbf.RadialBasisFunction) 9, RBFNetwork (smile.regression.RBFNetwork) 8, ArrayList (java.util.ArrayList) 6, KMeans (smile.clustering.KMeans) 6, Datum (smile.data.Datum) 6, NumericAttribute (smile.data.NumericAttribute) 6