Search in sources :

Example 1 with MutableDataset

Use of org.tribuo.MutableDataset in the ml-commons project by opensearch-project.

In class TribuoUtil, method generateDatasetWithTarget.

/**
 * Build a tribuo dataset from a data frame, treating one named column as the output target.
 * <p>
 * The first column whose name equals {@code target} supplies the output value of every example
 * and is left out of the feature names and feature values. Only the {@code REGRESSOR} output
 * type is handled by this method.
 *
 * @param dataFrame features data (includes the target column)
 * @param outputFactory the tribuo output factory
 * @param desc description for tribuo provenance
 * @param outputType the tribuo output type
 * @param target target name
 * @return tribuo dataset
 * @throws IllegalArgumentException if {@code target} is null or empty, if no column is named
 *         {@code target}, or if a row is converted with an output type other than REGRESSOR
 */
public static <T extends Output<T>> MutableDataset<T> generateDatasetWithTarget(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType, String target) {
    if (StringUtils.isEmpty(target)) {
        throw new IllegalArgumentException("Empty target when generating dataset from data frame.");
    }
    List<Example<T>> examples = new ArrayList<>();
    Tuple<String[], double[][]> namesAndValues = transformDataFrame(dataFrame);
    String[] columnNames = namesAndValues.v1();
    double[][] rows = namesAndValues.v2();

    // Locate the first column carrying the target name.
    int targetColumn = -1;
    int cursor = 0;
    while (targetColumn < 0 && cursor < columnNames.length) {
        if (columnNames[cursor].equals(target)) {
            targetColumn = cursor;
        }
        cursor++;
    }
    if (targetColumn < 0) {
        throw new IllegalArgumentException("No matched target when generating dataset from data frame.");
    }

    // Feature names are all column names except the target, in their original order.
    String[] featureNames = new String[columnNames.length - 1];
    System.arraycopy(columnNames, 0, featureNames, 0, targetColumn);
    System.arraycopy(columnNames, targetColumn + 1, featureNames, targetColumn, columnNames.length - targetColumn - 1);

    for (int row = 0; row < dataFrame.size(); row++) {
        final T output;
        final double[] featureValues;
        switch (outputType) {
            case REGRESSOR:
                double[] rowValues = rows[row];
                double targetValue = rowValues[targetColumn];
                // Copy the row without the target cell so values stay aligned with featureNames.
                featureValues = new double[rowValues.length - 1];
                System.arraycopy(rowValues, 0, featureValues, 0, targetColumn);
                System.arraycopy(rowValues, targetColumn + 1, featureValues, targetColumn, rowValues.length - targetColumn - 1);
                output = (T) new Regressor(target, targetValue);
                break;
            default:
                throw new IllegalArgumentException("unknown type:" + outputType);
        }
        Example<T> example = new ArrayExample<>(output, featureNames, featureValues);
        examples.add(example);
    }

    SimpleDataSourceProvenance provenance = new SimpleDataSourceProvenance(desc, outputFactory);
    ListDataSource<T> dataSource = new ListDataSource<>(examples, outputFactory, provenance);
    return new MutableDataset<>(dataSource);
}
Also used : IntStream(java.util.stream.IntStream) Example(org.tribuo.Example) Arrays(java.util.Arrays) Row(org.opensearch.ml.common.dataframe.Row) Iterator(java.util.Iterator) ColumnValue(org.opensearch.ml.common.dataframe.ColumnValue) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) ClusterID(org.tribuo.clustering.ClusterID) StringUtils(org.apache.commons.lang3.StringUtils) OutputFactory(org.tribuo.OutputFactory) Event(org.tribuo.anomaly.Event) Tuple(org.opensearch.common.collect.Tuple) ArrayList(java.util.ArrayList) UtilityClass(lombok.experimental.UtilityClass) ArrayExample(org.tribuo.impl.ArrayExample) Regressor(org.tribuo.regression.Regressor) List(java.util.List) TribuoOutputType(org.opensearch.ml.engine.contants.TribuoOutputType) Output(org.tribuo.Output) ListDataSource(org.tribuo.datasource.ListDataSource) SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) StreamSupport(java.util.stream.StreamSupport) ColumnMeta(org.opensearch.ml.common.dataframe.ColumnMeta) MutableDataset(org.tribuo.MutableDataset) SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) ArrayList(java.util.ArrayList) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) Regressor(org.tribuo.regression.Regressor) MutableDataset(org.tribuo.MutableDataset)

Example 2 with MutableDataset

Use of org.tribuo.MutableDataset in the ml-commons project by opensearch-project.

In class TribuoUtil, method generateDataset.

/**
 * Build a tribuo dataset from a data frame that carries features only.
 * <p>
 * Every row becomes one example whose output is a fixed placeholder chosen by
 * {@code outputType}; all columns returned by {@code transformDataFrame} are used as features.
 *
 * @param dataFrame features data
 * @param outputFactory the tribuo output factory
 * @param desc description for tribuo provenance
 * @param outputType the tribuo output type
 * @return tribuo dataset
 * @throws IllegalArgumentException if a row is converted with an output type that is not
 *         CLUSTERID, REGRESSOR or ANOMALY_DETECTION_LIBSVM
 */
public static <T extends Output<T>> MutableDataset<T> generateDataset(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType) {
    List<Example<T>> examples = new ArrayList<>();
    Tuple<String[], double[][]> namesAndValues = transformDataFrame(dataFrame);
    String[] featureNames = namesAndValues.v1();
    double[][] rows = namesAndValues.v2();

    for (int row = 0; row < dataFrame.size(); row++) {
        // Pick the placeholder output for this row; the example itself is assembled below.
        final T output;
        switch (outputType) {
            case CLUSTERID:
                output = (T) new ClusterID(ClusterID.UNASSIGNED);
                break;
            case REGRESSOR:
                // Single-dimension tribuo regressor named DIM-0 whose value is Double.NaN.
                output = (T) new Regressor("DIM-0", Double.NaN);
                break;
            case ANOMALY_DETECTION_LIBSVM:
                // The default event type is EXPECTED (non-anomalous) because:
                // 1. For training data, Tribuo LibSVMAnomalyTrainer only supports EXPECTED events at training time.
                // 2. For prediction data, rows are treated as non-anomalous by default since Tribuo does not accept the UNKNOWN type.
                // TODO: support anomaly labels to evaluate prediction result
                output = (T) new Event(Event.EventType.EXPECTED);
                break;
            default:
                throw new IllegalArgumentException("unknown type:" + outputType);
        }
        Example<T> example = new ArrayExample<>(output, featureNames, rows[row]);
        examples.add(example);
    }

    SimpleDataSourceProvenance provenance = new SimpleDataSourceProvenance(desc, outputFactory);
    ListDataSource<T> dataSource = new ListDataSource<>(examples, outputFactory, provenance);
    return new MutableDataset<>(dataSource);
}
Also used : ClusterID(org.tribuo.clustering.ClusterID) SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) ArrayList(java.util.ArrayList) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) Event(org.tribuo.anomaly.Event) Regressor(org.tribuo.regression.Regressor) MutableDataset(org.tribuo.MutableDataset)

Aggregations

ArrayList (java.util.ArrayList)2 Example (org.tribuo.Example)2 MutableDataset (org.tribuo.MutableDataset)2 Event (org.tribuo.anomaly.Event)2 ClusterID (org.tribuo.clustering.ClusterID)2 ArrayExample (org.tribuo.impl.ArrayExample)2 SimpleDataSourceProvenance (org.tribuo.provenance.SimpleDataSourceProvenance)2 Regressor (org.tribuo.regression.Regressor)2 Arrays (java.util.Arrays)1 Iterator (java.util.Iterator)1 List (java.util.List)1 IntStream (java.util.stream.IntStream)1 StreamSupport (java.util.stream.StreamSupport)1 UtilityClass (lombok.experimental.UtilityClass)1 StringUtils (org.apache.commons.lang3.StringUtils)1 Tuple (org.opensearch.common.collect.Tuple)1 ColumnMeta (org.opensearch.ml.common.dataframe.ColumnMeta)1 ColumnValue (org.opensearch.ml.common.dataframe.ColumnValue)1 DataFrame (org.opensearch.ml.common.dataframe.DataFrame)1 Row (org.opensearch.ml.common.dataframe.Row)1