Use of smile.data.AttributeDataset in project smile by haifengl.
Class RandomForestTest, method testCPU.
/**
 * Test of learn method, of class RandomForest.
 *
 * <p>Trains a regression random forest on a random 3/4 split of the CPU
 * dataset. It then prints the RMSE on the held-out quarter, the RMSE after
 * each additional tree, and the attribute importances. Any exception (for
 * example a missing or malformed data file) propagates and fails the test
 * rather than being printed and ignored.
 */
@Test
public void testCPU() throws Exception {
    System.out.println("CPU");
    ArffParser parser = new ArffParser();
    // The attribute at index 6 is used as the regression response.
    parser.setResponseIndex(6);
    AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int m = 3 * n / 4;
    // Random permutation of row indices: the first m rows train, the rest test.
    int[] index = Math.permutate(n);

    double[][] trainx = new double[m][];
    double[] trainy = new double[m];
    for (int i = 0; i < m; i++) {
        trainx[i] = datax[index[i]];
        trainy[i] = datay[index[i]];
    }

    double[][] testx = new double[n - m][];
    double[] testy = new double[n - m];
    for (int i = m; i < n; i++) {
        testx[i - m] = datax[index[i]];
        testy[i - m] = datay[index[i]];
    }

    RandomForest forest = new RandomForest(data.attributes(), trainx, trainy, 100, n, 5, trainx[0].length / 3);
    System.out.format("RMSE = %.4f%n", Validation.test(forest, testx, testy));

    // rmse[k] is reported as the RMSE using the first k + 1 trees.
    double[] rmse = forest.test(testx, testy);
    for (int i = 1; i <= rmse.length; i++) {
        System.out.format("%d trees RMSE = %.4f%n", i, rmse[i - 1]);
    }

    double[] importance = forest.importance();
    // NOTE(review): QuickSort.sort presumably sorts `importance` in place (ascending)
    // and returns the original positions, so importance[i] pairs with index[i];
    // iterating from the end then lists the most important attribute first.
    index = QuickSort.sort(importance);
    for (int i = importance.length; i-- > 0; ) {
        System.out.format("%s importance is %.4f%n", data.attributes()[index[i]], importance[i]);
    }
}
Use of smile.data.AttributeDataset in project smile by haifengl.
Class RESParser, method parse.
/**
 * Parse a RES dataset from an input stream.
 *
 * <p>The stream is read as tab-separated lines:
 * <ol>
 *   <li>a header line, where the fields at indices 2, 4, 6, ... name the attributes;</li>
 *   <li>a description line with exactly one field fewer than the header, where the
 *       fields at indices 1, 3, 5, ... describe the attributes;</li>
 *   <li>a line holding the positive number of data rows;</li>
 *   <li>that many data rows, each with as many fields as the header. Field 0 is
 *       stored as the datum description, field 1 as the datum name, and the
 *       fields at indices 2, 4, 6, ... as the numeric attribute values.</li>
 * </ol>
 * The reader, and with it the given stream, is closed when this method returns
 * or throws.
 *
 * @param name the name of dataset.
 * @param stream the input stream of data.
 * @return the parsed dataset with one datum per data row.
 * @throws java.io.IOException if the stream is empty or ends early, if the
 *     description line or a data line has the wrong number of fields, if the
 *     row count is not positive, or if reading fails.
 * @throws ParseException declared by the signature; this implementation does
 *     not currently throw it.
 */
public AttributeDataset parse(String name, InputStream stream) throws IOException, ParseException {
    // try-with-resources closes the reader on every exit path, including the
    // error paths below, not only after a successful parse.
    try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
        String line = reader.readLine();
        if (line == null) {
            throw new IOException("Empty data source.");
        }

        String[] tokens = line.split("\t", -1);
        // Two leading columns, then one attribute per pair of columns.
        int p = (tokens.length - 2) / 2;

        line = reader.readLine();
        if (line == null) {
            throw new IOException("Premature end of file.");
        }

        String[] samples = line.split("\t", -1);
        if (samples.length != tokens.length - 1) {
            throw new IOException("Invalid sample description header.");
        }

        Attribute[] attributes = new Attribute[p];
        for (int i = 0; i < p; i++) {
            attributes[i] = new NumericAttribute(tokens[2 * i + 2], samples[2 * i + 1]);
        }

        line = reader.readLine();
        if (line == null) {
            throw new IOException("Premature end of file.");
        }

        // Trim so that stray surrounding whitespace does not break the row count.
        int n = Integer.parseInt(line.trim());
        if (n <= 0) {
            throw new IOException("Invalid number of rows: " + n);
        }

        AttributeDataset data = new AttributeDataset(name, attributes);
        for (int i = 0; i < n; i++) {
            line = reader.readLine();
            if (line == null) {
                throw new IOException("Premature end of file.");
            }

            tokens = line.split("\t", -1);
            if (tokens.length != samples.length + 1) {
                // i + 4: three header lines precede the first data row (1-based line numbers).
                throw new IOException(String.format("Invalid number of elements of line %d: %d", i + 4, tokens.length));
            }

            double[] x = new double[p];
            for (int j = 0; j < p; j++) {
                x[j] = Double.parseDouble(tokens[2 * j + 2]);
            }

            Datum<double[]> datum = new Datum<>(x);
            datum.name = tokens[1];
            datum.description = tokens[0];
            data.add(datum);
        }

        return data;
    }
}
Use of smile.data.AttributeDataset in project smile by haifengl.
Class ArffParserTest, method testParseSparse.
/**
 * Test of parse method, of class ArffParser, on a sparse ARFF file.
 *
 * <p>Checks that sparse rows are expanded into dense vectors, with values not
 * listed in the file read back as 0.0. Any exception from parsing propagates
 * and fails the test rather than being printed and ignored.
 */
@Test
public void testParseSparse() throws Exception {
    System.out.println("sparse");
    ArffParser arffParser = new ArffParser();
    AttributeDataset sparse = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/sparse.arff"));
    double[][] x = sparse.toArray(new double[sparse.size()][]);

    assertEquals(2, sparse.size());
    assertEquals(5, sparse.attributes().length);

    double[][] expected = {
        {0.0, 2.0, 0.0, 3.0, 0.0},
        {0.0, 0.0, 1.0, 0.0, 1.0}
    };
    for (int i = 0; i < expected.length; i++) {
        for (int j = 0; j < expected[i].length; j++) {
            assertEquals(expected[i][j], x[i][j], 1E-7);
        }
    }
}
Use of smile.data.AttributeDataset in project smile by haifengl.
Class ArffParserTest, method testParseString.
/**
 * Test of parse method, of class ArffParser, on an ARFF file with string attributes.
 *
 * <p>Checks that every attribute is parsed as STRING and that the encoded
 * values map back to the original strings via {@code Attribute.toString(double)}.
 * Any exception from parsing propagates and fails the test rather than being
 * printed and ignored.
 */
@Test
public void testParseString() throws Exception {
    System.out.println("string");
    ArffParser arffParser = new ArffParser();
    AttributeDataset string = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/string.arff"));
    double[][] x = string.toArray(new double[string.size()][]);
    Attribute[] attributes = string.attributes();

    for (Attribute attribute : attributes) {
        assertEquals(Attribute.Type.STRING, attribute.getType());
    }

    assertEquals(5, string.size());
    assertEquals(2, attributes.length);

    // Each cell holds an encoded value; the owning attribute decodes it back to text.
    assertEquals("AG5", attributes[0].toString(x[0][0]));
    assertEquals("Encyclopedias and dictionaries.;Twentieth century.", attributes[1].toString(x[0][1]));
    assertEquals("AS281", attributes[0].toString(x[4][0]));
    assertEquals("Astronomy, Assyro-Babylonian.;Moon -- Tables.", attributes[1].toString(x[4][1]));
}
Use of smile.data.AttributeDataset in project smile by haifengl.
Class DelimitedTextParserTest, method testParse.
/**
 * Test of parse method, of class DelimitedTextParser.
 *
 * <p>Parses the USPS training file with column 0 as a nominal "class" response.
 * It checks the attribute types, the dataset dimensions, and sample values from
 * the first and last rows. Any exception from parsing propagates and fails the
 * test rather than being printed and ignored.
 */
@Test
public void testParse() throws Exception {
    System.out.println("parse");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    AttributeDataset usps = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
    double[][] x = usps.toArray(new double[usps.size()][]);
    int[] y = usps.toArray(new int[usps.size()]);

    assertEquals(Attribute.Type.NOMINAL, usps.response().getType());
    for (Attribute attribute : usps.attributes()) {
        assertEquals(Attribute.Type.NUMERIC, attribute.getType());
    }

    assertEquals(7291, usps.size());
    assertEquals(256, usps.attributes().length);

    // First rows.
    assertEquals("6", usps.response().toString(y[0]));
    assertEquals("5", usps.response().toString(y[1]));
    assertEquals("4", usps.response().toString(y[2]));
    assertEquals(-1.0000, x[0][6], 1E-7);
    assertEquals(-0.6310, x[0][7], 1E-7);
    assertEquals(0.8620, x[0][8], 1E-7);

    // Last row.
    assertEquals("1", usps.response().toString(y[7290]));
    assertEquals(-1.0000, x[7290][4], 1E-7);
    assertEquals(-0.1080, x[7290][5], 1E-7);
    assertEquals(1.0000, x[7290][6], 1E-7);
}
Aggregations