Search in sources :

Example 1 with ArrayExample

use of org.tribuo.impl.ArrayExample in project ml-commons by opensearch-project.

The method generateDatasetWithTarget of the class TribuoUtil.

/**
 * Builds a tribuo dataset from a data frame, using one named column as the target value
 * and every remaining column as a feature.
 *
 * @param dataFrame features data, including the target column
 * @param outputFactory the tribuo output factory
 * @param desc description for tribuo provenance
 * @param outputType the tribuo output type; only {@code REGRESSOR} is handled by this method
 * @param target name of the column that holds the target value
 * @return tribuo dataset containing one example per data frame row
 * @throws IllegalArgumentException if {@code target} is empty, if no column is named
 *         {@code target}, or if a row is converted with an unsupported {@code outputType}
 */
public static <T extends Output<T>> MutableDataset<T> generateDatasetWithTarget(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType, String target) {
    if (StringUtils.isEmpty(target)) {
        throw new IllegalArgumentException("Empty target when generating dataset from data frame.");
    }

    Tuple<String[], double[][]> namesAndValues = transformDataFrame(dataFrame);
    String[] columnNames = namesAndValues.v1();
    double[][] rows = namesAndValues.v2();

    // Walk the column names until the target is found or the names run out.
    int targetColumn = 0;
    while (targetColumn < columnNames.length && !columnNames[targetColumn].equals(target)) {
        ++targetColumn;
    }
    if (targetColumn == columnNames.length) {
        throw new IllegalArgumentException("No matched target when generating dataset from data frame.");
    }

    // Feature names are all column names except the target, in their original order.
    String[] featureNames = new String[columnNames.length - 1];
    for (int src = 0, dst = 0; src < columnNames.length; ++src) {
        if (src != targetColumn) {
            featureNames[dst++] = columnNames[src];
        }
    }

    List<Example<T>> examples = new ArrayList<>();
    for (int rowIndex = 0; rowIndex < dataFrame.size(); ++rowIndex) {
        final T output;
        final double[] featureValues;
        switch (outputType) {
            case REGRESSOR:
                double[] row = rows[rowIndex];
                double targetValue = row[targetColumn];
                // Copy every value of the row except the one in the target column.
                featureValues = new double[row.length - 1];
                for (int src = 0, dst = 0; src < row.length; ++src) {
                    if (src != targetColumn) {
                        featureValues[dst++] = row[src];
                    }
                }
                // Unchecked cast: T is not verified to be Regressor here.
                output = (T) new Regressor(target, targetValue);
                break;
            default:
                throw new IllegalArgumentException("unknown type:" + outputType);
        }
        examples.add(new ArrayExample<>(output, featureNames, featureValues));
    }

    SimpleDataSourceProvenance provenance = new SimpleDataSourceProvenance(desc, outputFactory);
    ListDataSource<T> dataSource = new ListDataSource<>(examples, outputFactory, provenance);
    return new MutableDataset<>(dataSource);
}
Also used : IntStream(java.util.stream.IntStream) Example(org.tribuo.Example) Arrays(java.util.Arrays) Row(org.opensearch.ml.common.dataframe.Row) Iterator(java.util.Iterator) ColumnValue(org.opensearch.ml.common.dataframe.ColumnValue) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) ClusterID(org.tribuo.clustering.ClusterID) StringUtils(org.apache.commons.lang3.StringUtils) OutputFactory(org.tribuo.OutputFactory) Event(org.tribuo.anomaly.Event) Tuple(org.opensearch.common.collect.Tuple) ArrayList(java.util.ArrayList) UtilityClass(lombok.experimental.UtilityClass) ArrayExample(org.tribuo.impl.ArrayExample) Regressor(org.tribuo.regression.Regressor) List(java.util.List) TribuoOutputType(org.opensearch.ml.engine.contants.TribuoOutputType) Output(org.tribuo.Output) ListDataSource(org.tribuo.datasource.ListDataSource) SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) StreamSupport(java.util.stream.StreamSupport) ColumnMeta(org.opensearch.ml.common.dataframe.ColumnMeta) MutableDataset(org.tribuo.MutableDataset) SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) ArrayList(java.util.ArrayList) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) Regressor(org.tribuo.regression.Regressor) MutableDataset(org.tribuo.MutableDataset)

Example 2 with ArrayExample

use of org.tribuo.impl.ArrayExample in project ml-commons by opensearch-project.

The method generateDataset of the class TribuoUtil.

/**
 * Builds a tribuo dataset from a data frame, attaching a placeholder output to every row.
 * Handles the {@code CLUSTERID}, {@code REGRESSOR} and {@code ANOMALY_DETECTION_LIBSVM}
 * output types.
 *
 * @param dataFrame features data
 * @param outputFactory the tribuo output factory
 * @param desc description for tribuo provenance
 * @param outputType the tribuo output type
 * @return tribuo dataset containing one example per data frame row
 * @throws IllegalArgumentException if a row is converted with an unsupported {@code outputType}
 */
public static <T extends Output<T>> MutableDataset<T> generateDataset(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType) {
    Tuple<String[], double[][]> namesAndValues = transformDataFrame(dataFrame);
    String[] featureNames = namesAndValues.v1();
    double[][] rows = namesAndValues.v2();

    List<Example<T>> examples = new ArrayList<>();
    for (int rowIndex = 0; rowIndex < dataFrame.size(); ++rowIndex) {
        // Choose the placeholder output first; the example itself is built after the switch.
        final T output;
        switch (outputType) {
            case CLUSTERID:
                output = (T) new ClusterID(ClusterID.UNASSIGNED);
                break;
            case REGRESSOR:
                // Single-dimension tribuo regressor named DIM-0 whose value is double NaN.
                output = (T) new Regressor("DIM-0", Double.NaN);
                break;
            case ANOMALY_DETECTION_LIBSVM:
                // The default event type is EXPECTED (non-anomalous) because:
                // 1. For training data, Tribuo LibSVMAnomalyTrainer only supports EXPECTED events at training time.
                // 2. For prediction data, the data is treated as non-anomalous by default, as the Tribuo library does not accept the UNKNOWN type.
                // TODO: support anomaly labels to evaluate prediction result
                output = (T) new Event(Event.EventType.EXPECTED);
                break;
            default:
                throw new IllegalArgumentException("unknown type:" + outputType);
        }
        examples.add(new ArrayExample<>(output, featureNames, rows[rowIndex]));
    }

    SimpleDataSourceProvenance provenance = new SimpleDataSourceProvenance(desc, outputFactory);
    ListDataSource<T> dataSource = new ListDataSource<>(examples, outputFactory, provenance);
    return new MutableDataset<>(dataSource);
}
Also used : ClusterID(org.tribuo.clustering.ClusterID) SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) ArrayList(java.util.ArrayList) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) Event(org.tribuo.anomaly.Event) Regressor(org.tribuo.regression.Regressor) MutableDataset(org.tribuo.MutableDataset)

Example 3 with ArrayExample

use of org.tribuo.impl.ArrayExample in project ml-commons by opensearch-project.

The method generateDatasetWithTarget of the class TribuoUtilTest.

/**
 * Generates a regression dataset with "f2" as the target and checks that it holds one
 * example per raw data row, and that the n-th feature (counting from 1) of row i is
 * named "f" + n with value i + n / 10.0 (tolerance 0.01).
 */
@SuppressWarnings("unchecked")
@Test
public void generateDatasetWithTarget() {
    RegressionFactory factory = new RegressionFactory();
    MutableDataset<Regressor> dataset = TribuoUtil.generateDatasetWithTarget(dataFrame, factory, "test", TribuoOutputType.REGRESSOR, "f2");
    List<Example<Regressor>> examples = dataset.getData();
    Assert.assertEquals(rawData.length, examples.size());

    for (int row = 0; row < rawData.length; ++row) {
        // The cast also verifies that every generated example is an ArrayExample.
        ArrayExample<Regressor> arrayExample = (ArrayExample<Regressor>) examples.get(row);
        int position = 0;
        for (Feature feature : arrayExample) {
            ++position;
            Assert.assertEquals("f" + position, feature.getName());
            Assert.assertEquals(row + position / 10.0, feature.getValue(), 0.01);
        }
    }
}
Also used : ArrayExample(org.tribuo.impl.ArrayExample) RegressionFactory(org.tribuo.regression.RegressionFactory) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) Regressor(org.tribuo.regression.Regressor) Feature(org.tribuo.Feature) Test(org.junit.Test)

Example 4 with ArrayExample

use of org.tribuo.impl.ArrayExample in project ml-commons by opensearch-project.

The method generateDataset of the class TribuoUtilTest.

/**
 * Generates a clustering dataset and checks that it holds one example per raw data row,
 * and that the n-th feature (counting from 1) of row i is named "f" + n with value
 * i + n / 10.0 (tolerance 0.01).
 */
@SuppressWarnings("unchecked")
@Test
public void generateDataset() {
    ClusteringFactory factory = new ClusteringFactory();
    MutableDataset<ClusterID> dataset = TribuoUtil.generateDataset(dataFrame, factory, "test", TribuoOutputType.CLUSTERID);
    List<Example<ClusterID>> examples = dataset.getData();
    Assert.assertEquals(rawData.length, examples.size());

    for (int row = 0; row < rawData.length; ++row) {
        // The cast also verifies that every generated example is an ArrayExample.
        ArrayExample<ClusterID> arrayExample = (ArrayExample<ClusterID>) examples.get(row);
        int position = 0;
        for (Feature feature : arrayExample) {
            ++position;
            Assert.assertEquals("f" + position, feature.getName());
            Assert.assertEquals(row + position / 10.0, feature.getValue(), 0.01);
        }
    }
}
Also used : ArrayExample(org.tribuo.impl.ArrayExample) ClusterID(org.tribuo.clustering.ClusterID) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) ClusteringFactory(org.tribuo.clustering.ClusteringFactory) Feature(org.tribuo.Feature) Test(org.junit.Test)

Aggregations

Example (org.tribuo.Example)4 ArrayExample (org.tribuo.impl.ArrayExample)4 ClusterID (org.tribuo.clustering.ClusterID)3 Regressor (org.tribuo.regression.Regressor)3 ArrayList (java.util.ArrayList)2 Test (org.junit.Test)2 Feature (org.tribuo.Feature)2 MutableDataset (org.tribuo.MutableDataset)2 Event (org.tribuo.anomaly.Event)2 SimpleDataSourceProvenance (org.tribuo.provenance.SimpleDataSourceProvenance)2 Arrays (java.util.Arrays)1 Iterator (java.util.Iterator)1 List (java.util.List)1 IntStream (java.util.stream.IntStream)1 StreamSupport (java.util.stream.StreamSupport)1 UtilityClass (lombok.experimental.UtilityClass)1 StringUtils (org.apache.commons.lang3.StringUtils)1 Tuple (org.opensearch.common.collect.Tuple)1 ColumnMeta (org.opensearch.ml.common.dataframe.ColumnMeta)1 ColumnValue (org.opensearch.ml.common.dataframe.ColumnValue)1