Usage of smile.gap.BitString in the smile project by haifengl.
Class GAFeatureSelectionTest, method testLearn.
/**
 * Test of learn method, of class GAFeatureSelection.
 *
 * <p>Loads the USPS train/test files, runs genetic-algorithm feature selection with an LDA
 * classifier scored by accuracy on the test set, prints every returned chromosome (fitness as a
 * percentage, number of set bits, and the bit mask), and checks that the last chromosome of the
 * returned array reaches an accuracy above 0.88.
 *
 * <p>Any exception raised while loading data or learning now fails the test (with the original
 * exception as the cause) instead of being printed and ignored.
 */
@Test
public void testLearn() {
    System.out.println("learn");
    int size = 100;
    int generation = 20;
    ClassifierTrainer<double[]> trainer = new LDA.Trainer();
    ClassificationMeasure measure = new Accuracy();
    DelimitedTextParser parser = new DelimitedTextParser();
    // Column 0 of the USPS files is the nominal class label.
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);
        GAFeatureSelection instance = new GAFeatureSelection();
        BitString[] result = instance.learn(size, generation, trainer, measure, x, y, testx, testy);
        for (BitString bits : result) {
            System.out.format("%.2f%% %d ", 100 * bits.fitness(), Math.sum(bits.bits()));
            for (int i = 0; i < x[0].length; i++) {
                System.out.print(bits.bits()[i] + " ");
            }
            System.out.println();
        }
        assertTrue(result[result.length - 1].fitness() > 0.88);
    } catch (Exception ex) {
        // Previously the exception was only printed to stderr, so a missing data file or a
        // failing learner made this test pass silently. Fail loudly and preserve the cause.
        throw new AssertionError("GAFeatureSelection.learn test failed: " + ex, ex);
    }
}
Usage of smile.gap.BitString in the smile project by haifengl.
Class GAFeatureSelection, method learn (classification overload).
/**
 * Genetic algorithm based feature selection for classification.
 *
 * <p>Each chromosome is a {@link BitString} whose length equals the number of features (columns
 * of {@code x}). Its fitness is evaluated by a {@code ClassificationFitness} built from
 * {@code trainer}, {@code measure}, {@code x}, {@code y} and {@code k}. A population of
 * {@code size} chromosomes is evolved for {@code generation} iterations with this object's
 * selection, crossover, crossover-rate and mutation-rate settings.
 *
 * @param size the population size of Genetic Algorithm.
 * @param generation the maximum number of iterations.
 * @param trainer classifier trainer.
 * @param measure classification measure as the chromosome fitness measure.
 * @param x training instances.
 * @param y training labels.
 * @param k k-fold cross validation for the evaluation.
 * @return bit strings of last generation.
 * @throws IllegalArgumentException if {@code size <= 0}, {@code generation <= 0}, {@code k < 2},
 *     the sizes of {@code x} and {@code y} differ, {@code x} is empty, or {@code k} exceeds the
 *     number of training instances.
 */
public BitString[] learn(int size, int generation, ClassifierTrainer<double[]> trainer, ClassificationMeasure measure, double[][] x, int[] y, int k) {
    if (size <= 0) {
        throw new IllegalArgumentException("Invalid population size: " + size);
    }
    if (generation <= 0) {
        throw new IllegalArgumentException("Invalid number of generations: " + generation);
    }
    if (k < 2) {
        throw new IllegalArgumentException("Invalid k-fold cross validation: " + k);
    }
    if (x.length != y.length) {
        throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
    }
    // Guard before reading x[0]: an empty data set previously surfaced as an
    // ArrayIndexOutOfBoundsException rather than an argument error.
    if (x.length == 0) {
        throw new IllegalArgumentException("Empty training data");
    }
    // k-fold cross validation needs at least one instance per fold.
    if (k > x.length) {
        throw new IllegalArgumentException(String.format("k-fold cross validation %d exceeds the number of instances %d", k, x.length));
    }

    int p = x[0].length;
    ClassificationFitness fitness = new ClassificationFitness(trainer, measure, x, y, k);
    BitString[] seeds = new BitString[size];
    for (int i = 0; i < size; i++) {
        seeds[i] = new BitString(p, fitness, crossover, crossoverRate, mutationRate);
    }
    GeneticAlgorithm<BitString> ga = new GeneticAlgorithm<>(seeds, selection);
    ga.evolve(generation);
    // NOTE(review): returning `seeds` assumes GeneticAlgorithm evolves the passed-in array in
    // place, so that it holds the last generation - confirm against GeneticAlgorithm.
    return seeds;
}
Usage of smile.gap.BitString in the smile project by haifengl.
Class GAFeatureSelection, method learn (regression overload).
/**
 * Genetic algorithm based feature selection for regression.
 *
 * <p>Each chromosome is a {@link BitString} whose length equals the number of features (columns
 * of {@code x}). Its fitness is evaluated by a {@code RegressionFitness} built from
 * {@code trainer}, {@code measure}, {@code x}, {@code y} and {@code k}. A population of
 * {@code size} chromosomes is evolved for {@code generation} iterations with this object's
 * selection, crossover, crossover-rate and mutation-rate settings.
 *
 * @param size the population size of Genetic Algorithm.
 * @param generation the maximum number of iterations.
 * @param trainer regression model trainer.
 * @param measure regression measure as the chromosome fitness measure.
 * @param x training instances.
 * @param y training instance response variable.
 * @param k k-fold cross validation for the evaluation.
 * @return bit strings of last generation.
 * @throws IllegalArgumentException if {@code size <= 0}, {@code generation <= 0}, {@code k < 2},
 *     the sizes of {@code x} and {@code y} differ, {@code x} is empty, or {@code k} exceeds the
 *     number of training instances.
 */
public BitString[] learn(int size, int generation, RegressionTrainer<double[]> trainer, RegressionMeasure measure, double[][] x, double[] y, int k) {
    if (size <= 0) {
        throw new IllegalArgumentException("Invalid population size: " + size);
    }
    if (generation <= 0) {
        throw new IllegalArgumentException("Invalid number of generations: " + generation);
    }
    if (k < 2) {
        throw new IllegalArgumentException("Invalid k-fold cross validation: " + k);
    }
    if (x.length != y.length) {
        throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
    }
    // Guard before reading x[0]: an empty data set previously surfaced as an
    // ArrayIndexOutOfBoundsException rather than an argument error.
    if (x.length == 0) {
        throw new IllegalArgumentException("Empty training data");
    }
    // k-fold cross validation needs at least one instance per fold.
    if (k > x.length) {
        throw new IllegalArgumentException(String.format("k-fold cross validation %d exceeds the number of instances %d", k, x.length));
    }

    int p = x[0].length;
    RegressionFitness fitness = new RegressionFitness(trainer, measure, x, y, k);
    BitString[] seeds = new BitString[size];
    for (int i = 0; i < size; i++) {
        seeds[i] = new BitString(p, fitness, crossover, crossoverRate, mutationRate);
    }
    GeneticAlgorithm<BitString> ga = new GeneticAlgorithm<>(seeds, selection);
    ga.evolve(generation);
    // NOTE(review): returning `seeds` assumes GeneticAlgorithm evolves the passed-in array in
    // place, so that it holds the last generation - confirm against GeneticAlgorithm.
    return seeds;
}
Aggregations