
Example 1 with Feature

Use of org.tribuo.Feature in the ml-commons project by opensearch-project.

Class AnomalyDetectionLibSVMTest, method constructDataFrame.

private DataFrame constructDataFrame(Dataset<Event> data, boolean training, List<Event.EventType> labels) {
    Iterator<Example<Event>> iterator = data.iterator();
    List<ColumnMeta> columns = null;
    DataFrame dataFrame = null;
    while (iterator.hasNext()) {
        Example<Event> example = iterator.next();
        if (columns == null) {
            // First example: its features define the column schema of the data frame.
            columns = new ArrayList<>();
            List<ColumnValue> columnValues = new ArrayList<>();
            for (Feature feature : example) {
                columns.add(new ColumnMeta(feature.getName(), ColumnType.DOUBLE));
                columnValues.add(new DoubleValue(feature.getValue()));
            }
            ColumnMeta[] columnMetas = columns.toArray(new ColumnMeta[0]);
            dataFrame = new DefaultDataFrame(columnMetas);
            addRow(columnValues, training, example, dataFrame, labels);
        } else {
            // Later examples: only the feature values are needed, in iteration order.
            List<ColumnValue> columnValues = new ArrayList<>();
            for (Feature feature : example) {
                columnValues.add(new DoubleValue(feature.getValue()));
            }
            addRow(columnValues, training, example, dataFrame, labels);
        }
    }
    return dataFrame;
}
Also used : ArrayList(java.util.ArrayList) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) DefaultDataFrame(org.opensearch.ml.common.dataframe.DefaultDataFrame) Feature(org.tribuo.Feature) ColumnMeta(org.opensearch.ml.common.dataframe.ColumnMeta) DoubleValue(org.opensearch.ml.common.dataframe.DoubleValue) Example(org.tribuo.Example) Event(org.tribuo.anomaly.Event) ColumnValue(org.opensearch.ml.common.dataframe.ColumnValue)
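
The pattern in constructDataFrame (take the column names from the first example, then append one row of values per example) can be sketched without any library dependencies. The version below is a minimal illustration only: Feature and Table are hypothetical stand-ins, not org.tribuo.Feature or the ml-commons DataFrame classes, and toTable is a name invented for this sketch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class FeatureTableSketch {

    // Hypothetical stand-in for org.tribuo.Feature: a named double value.
    static final class Feature {
        final String name;
        final double value;

        Feature(String name, double value) {
            this.name = name;
            this.value = value;
        }
    }

    // Hypothetical stand-in for a data frame: one header plus rows of doubles.
    static final class Table {
        final List<String> header = new ArrayList<>();
        final List<double[]> rows = new ArrayList<>();
    }

    // Builds the table in a single pass; the first example fixes the column names.
    static Table toTable(List<List<Feature>> examples) {
        Table table = new Table();
        for (List<Feature> example : examples) {
            if (table.rows.isEmpty()) {
                for (Feature feature : example) {
                    table.header.add(feature.name);
                }
            }
            double[] row = new double[example.size()];
            for (int i = 0; i < row.length; i++) {
                row[i] = example.get(i).value;
            }
            table.rows.add(row);
        }
        return table;
    }

    public static void main(String[] args) {
        List<List<Feature>> examples = Arrays.asList(
            Arrays.asList(new Feature("f1", 0.1), new Feature("f2", 0.2)),
            Arrays.asList(new Feature("f1", 1.1), new Feature("f2", 1.2)));
        Table table = toTable(examples);
        System.out.println(table.header + " rows=" + table.rows.size()); // prints "[f1, f2] rows=2"
    }
}
```

Unlike the original method, which returns null for an empty dataset, this sketch returns an empty table. Both approaches assume every example yields its features in the same order as the first one.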

Example 2 with Feature

Use of org.tribuo.Feature in the ml-commons project by opensearch-project.

Class TribuoUtilTest, method generateDatasetWithTarget.

@SuppressWarnings("unchecked")
@Test
public void generateDatasetWithTarget() {
    MutableDataset<Regressor> dataset = TribuoUtil.generateDatasetWithTarget(dataFrame, new RegressionFactory(), "test", TribuoOutputType.REGRESSOR, "f2");
    List<Example<Regressor>> examples = dataset.getData();
    Assert.assertEquals(rawData.length, examples.size());
    for (int i = 0; i < rawData.length; ++i) {
        ArrayExample arrayExample = (ArrayExample) examples.get(i);
        Iterator<Feature> iterator = arrayExample.iterator();
        // Feature names are expected to be f1, f2, ... in iteration order.
        int idx = 1;
        while (iterator.hasNext()) {
            Feature feature = iterator.next();
            Assert.assertEquals("f" + idx, feature.getName());
            // Expected value is i + (idx / 10.0), compared with a tolerance of 0.01.
            Assert.assertEquals(i + idx / 10.0, feature.getValue(), 0.01);
            ++idx;
        }
    }
}
Also used : ArrayExample(org.tribuo.impl.ArrayExample) RegressionFactory(org.tribuo.regression.RegressionFactory) Example(org.tribuo.Example) Regressor(org.tribuo.regression.Regressor) Feature(org.tribuo.Feature) Test(org.junit.Test)
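
The expected value in the assertion, i + idx / 10.0, relies on operator precedence: the division is evaluated first, so the feature at 1-based position idx in row i is expected to hold i plus idx tenths. The sketch below isolates that expression and approximates the delta comparison performed by the three-argument assertEquals. The names expected and withinDelta are invented for this illustration, and the sample values are assumptions, since the contents of rawData are not shown in the snippet.

```java
public class ExpectedValueSketch {

    // Same expression as in the test: idx / 10.0 is evaluated before the addition.
    static double expected(int row, int idx) {
        return row + idx / 10.0;
    }

    // Approximates a delta-based equality check: true when the values differ by at most delta.
    static boolean withinDelta(double expected, double actual, double delta) {
        return Math.abs(expected - actual) <= delta;
    }

    public static void main(String[] args) {
        // Row 2, first feature: expected to be about 2.1.
        System.out.println(withinDelta(expected(2, 1), 2.1, 0.01)); // prints "true"
        // A value that is off by 0.1 falls outside the 0.01 tolerance.
        System.out.println(withinDelta(expected(2, 1), 2.2, 0.01)); // prints "false"
    }
}
```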

Example 3 with Feature

Use of org.tribuo.Feature in the ml-commons project by opensearch-project.

Class TribuoUtilTest, method generateDataset.

@SuppressWarnings("unchecked")
@Test
public void generateDataset() {
    MutableDataset<ClusterID> dataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "test", TribuoOutputType.CLUSTERID);
    List<Example<ClusterID>> examples = dataset.getData();
    Assert.assertEquals(rawData.length, examples.size());
    for (int i = 0; i < rawData.length; ++i) {
        ArrayExample arrayExample = (ArrayExample) examples.get(i);
        Iterator<Feature> iterator = arrayExample.iterator();
        // Feature names are expected to be f1, f2, ... in iteration order.
        int idx = 1;
        while (iterator.hasNext()) {
            Feature feature = iterator.next();
            Assert.assertEquals("f" + idx, feature.getName());
            // Expected value is i + (idx / 10.0), compared with a tolerance of 0.01.
            Assert.assertEquals(i + idx / 10.0, feature.getValue(), 0.01);
            ++idx;
        }
    }
}
Also used : ArrayExample(org.tribuo.impl.ArrayExample) ClusterID(org.tribuo.clustering.ClusterID) Example(org.tribuo.Example) ClusteringFactory(org.tribuo.clustering.ClusteringFactory) Feature(org.tribuo.Feature) Test(org.junit.Test)

Aggregations

Example (org.tribuo.Example): 3
Feature (org.tribuo.Feature): 3
Test (org.junit.Test): 2
ArrayExample (org.tribuo.impl.ArrayExample): 2
ArrayList (java.util.ArrayList): 1
ColumnMeta (org.opensearch.ml.common.dataframe.ColumnMeta): 1
ColumnValue (org.opensearch.ml.common.dataframe.ColumnValue): 1
DataFrame (org.opensearch.ml.common.dataframe.DataFrame): 1
DefaultDataFrame (org.opensearch.ml.common.dataframe.DefaultDataFrame): 1
DoubleValue (org.opensearch.ml.common.dataframe.DoubleValue): 1
Event (org.tribuo.anomaly.Event): 1
ClusterID (org.tribuo.clustering.ClusterID): 1
ClusteringFactory (org.tribuo.clustering.ClusteringFactory): 1
RegressionFactory (org.tribuo.regression.RegressionFactory): 1
Regressor (org.tribuo.regression.Regressor): 1