Example 1 with BitString

Use of smile.gap.BitString in project smile by haifengl.

Class GAFeatureSelectionTest, method testLearn.

/**
 * Test of learn method, of class GAFeatureSelection.
 */
@Test
public void testLearn() {
    System.out.println("learn");
    int size = 100;
    int generation = 20;
    ClassifierTrainer<double[]> trainer = new LDA.Trainer();
    ClassificationMeasure measure = new Accuracy();
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);
        GAFeatureSelection instance = new GAFeatureSelection();
        BitString[] result = instance.learn(size, generation, trainer, measure, x, y, testx, testy);
        for (BitString bits : result) {
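            // Math is presumably smile.math.Math here (its import is not shown); java.lang.Math has no sum method.
            System.out.format("%.2f%% %d ", 100 * bits.fitness(), Math.sum(bits.bits()));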
            for (int i = 0; i < x[0].length; i++) {
                System.out.print(bits.bits()[i] + " ");
            }
            System.out.println();
        }
        assertTrue(result[result.length - 1].fitness() > 0.88);
    } catch (Exception ex) {
        System.err.println(ex);
    }
}
Also used: DelimitedTextParser (smile.data.parser.DelimitedTextParser), AttributeDataset (smile.data.AttributeDataset), ClassificationMeasure (smile.validation.ClassificationMeasure), ClassifierTrainer (smile.classification.ClassifierTrainer), Accuracy (smile.validation.Accuracy), NominalAttribute (smile.data.NominalAttribute), BitString (smile.gap.BitString), Test (org.junit.Test)
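
The test above prints, for each chromosome, its fitness, the number of set bits, and then one bit per feature column. As a rough illustration of how such a bit array can be read, here is a minimal, library-free sketch. The class name SelectedFeatures and its methods are hypothetical helpers, not part of smile, and the sample array is made up.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper for illustration only; not part of smile.
class SelectedFeatures {
    /** Counts the bits set to 1, i.e. the number of selected features. */
    static int count(int[] bits) {
        int n = 0;
        for (int b : bits) {
            n += b;
        }
        return n;
    }

    /** Returns the column indices whose bit is 1. */
    static List<Integer> indices(int[] bits) {
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < bits.length; i++) {
            if (bits[i] == 1) {
                idx.add(i);
            }
        }
        return idx;
    }

    public static void main(String[] args) {
        int[] bits = {1, 0, 0, 1, 1}; // made-up chromosome, one bit per feature
        // prints "3 features selected: [0, 3, 4]"
        System.out.println(count(bits) + " features selected: " + indices(bits));
    }
}
```

With a real result, the array passed in would be the int[] returned by bits.bits() in the loop above.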

Example 2 with BitString

Use of smile.gap.BitString in project smile by haifengl.

Class GAFeatureSelection, method learn (classification overload).

/**
 * Genetic algorithm based feature selection for classification.
 * @param size the population size of Genetic Algorithm.
 * @param generation the maximum number of iterations.
 * @param trainer classifier trainer.
 * @param measure classification measure as the chromosome fitness measure.
 * @param x training instances.
 * @param y training labels.
 * @param k k-fold cross validation for the evaluation.
 * @return bit strings of last generation.
 */
public BitString[] learn(int size, int generation, ClassifierTrainer<double[]> trainer, ClassificationMeasure measure, double[][] x, int[] y, int k) {
    if (size <= 0) {
        throw new IllegalArgumentException("Invalid population size: " + size);
    }
    if (k < 2) {
        throw new IllegalArgumentException("Invalid k-fold cross validation: " + k);
    }
    if (x.length != y.length) {
        throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
    }
    int p = x[0].length;
    ClassificationFitness fitness = new ClassificationFitness(trainer, measure, x, y, k);
    BitString[] seeds = new BitString[size];
    for (int i = 0; i < size; i++) {
        seeds[i] = new BitString(p, fitness, crossover, crossoverRate, mutationRate);
    }
    GeneticAlgorithm<BitString> ga = new GeneticAlgorithm<>(seeds, selection);
    ga.evolve(generation);
    return seeds;
}
Also used: BitString (smile.gap.BitString), GeneticAlgorithm (smile.gap.GeneticAlgorithm)
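
In the method above each chromosome is a BitString of length p = x[0].length, so there is one bit per feature column. The fitness classes themselves are not shown here. The sketch below assumes a fitness evaluation would first keep only the columns whose bit is 1 before training and scoring. FeatureMask and select are hypothetical names, not smile API.

```java
import java.util.Arrays;

// Hypothetical helper for illustration only; not part of smile.
class FeatureMask {
    /** Returns a copy of x that keeps only the columns whose bit is 1. */
    static double[][] select(double[][] x, int[] bits) {
        if (x.length > 0 && x[0].length != bits.length) {
            throw new IllegalArgumentException("Mask length must equal the number of columns");
        }
        int p = 0;
        for (int b : bits) {
            if (b == 1) {
                p++;
            }
        }
        double[][] out = new double[x.length][p];
        for (int i = 0; i < x.length; i++) {
            int j = 0; // next free column in the reduced row
            for (int col = 0; col < bits.length; col++) {
                if (bits[col] == 1) {
                    out[i][j++] = x[i][col];
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] x = {{1, 2, 3}, {4, 5, 6}}; // made-up data: 2 rows, 3 features
        int[] bits = {1, 0, 1};                // keep columns 0 and 2
        // prints "[[1.0, 3.0], [4.0, 6.0]]"
        System.out.println(Arrays.deepToString(select(x, bits)));
    }
}
```

Copying into a new array keeps the original data untouched, which matters when many chromosomes are scored against the same x.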

Example 3 with BitString

Use of smile.gap.BitString in project smile by haifengl.

Class GAFeatureSelection, method learn (regression overload).

/**
 * Genetic algorithm based feature selection for regression.
 * @param size the population size of Genetic Algorithm.
 * @param generation the maximum number of iterations.
 * @param trainer regression model trainer.
 * @param measure regression measure as the chromosome fitness measure.
 * @param x training instances.
 * @param y training instance response variable.
 * @param k k-fold cross validation for the evaluation.
 * @return bit strings of last generation.
 */
public BitString[] learn(int size, int generation, RegressionTrainer<double[]> trainer, RegressionMeasure measure, double[][] x, double[] y, int k) {
    if (size <= 0) {
        throw new IllegalArgumentException("Invalid population size: " + size);
    }
    if (k < 2) {
        throw new IllegalArgumentException("Invalid k-fold cross validation: " + k);
    }
    if (x.length != y.length) {
        throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
    }
    int p = x[0].length;
    RegressionFitness fitness = new RegressionFitness(trainer, measure, x, y, k);
    BitString[] seeds = new BitString[size];
    for (int i = 0; i < size; i++) {
        seeds[i] = new BitString(p, fitness, crossover, crossoverRate, mutationRate);
    }
    GeneticAlgorithm<BitString> ga = new GeneticAlgorithm<>(seeds, selection);
    ga.evolve(generation);
    return seeds;
}
Also used: BitString (smile.gap.BitString), GeneticAlgorithm (smile.gap.GeneticAlgorithm)

Aggregations

BitString (smile.gap.BitString): 3
GeneticAlgorithm (smile.gap.GeneticAlgorithm): 2
Test (org.junit.Test): 1
ClassifierTrainer (smile.classification.ClassifierTrainer): 1
AttributeDataset (smile.data.AttributeDataset): 1
NominalAttribute (smile.data.NominalAttribute): 1
DelimitedTextParser (smile.data.parser.DelimitedTextParser): 1
Accuracy (smile.validation.Accuracy): 1
ClassificationMeasure (smile.validation.ClassificationMeasure): 1