Search in sources :

Example 1 with OutputFactory

Use of org.tribuo.OutputFactory in project ml-commons by opensearch-project.

From the class TribuoUtil, method generateDatasetWithTarget.

/**
 * Generate a tribuo dataset from a data frame, using one of its columns as the target.
 * The target column is removed from every feature vector; all remaining columns become features.
 * If several columns share the target name, the first one is used.
 * Only {@code TribuoOutputType.REGRESSOR} is currently supported.
 * @param dataFrame features data, including the target column
 * @param outputFactory the tribuo output factory
 * @param desc description for tribuo provenance
 * @param outputType the tribuo output type; must be REGRESSOR
 * @param target target column name
 * @return tribuo dataset with one example per data frame row
 * @throws IllegalArgumentException if {@code target} is empty, no column matches {@code target},
 *         or {@code outputType} is not supported (including {@code null})
 */
public static <T extends Output<T>> MutableDataset<T> generateDatasetWithTarget(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType, String target) {
    if (StringUtils.isEmpty(target)) {
        throw new IllegalArgumentException("Empty target when generating dataset from data frame.");
    }
    Tuple<String[], double[][]> featureNamesValues = transformDataFrame(dataFrame);
    String[] columnNames = featureNamesValues.v1();
    double[][] rows = featureNamesValues.v2();
    int targetIndex = -1;
    for (int i = 0; i < columnNames.length; ++i) {
        if (columnNames[i].equals(target)) {
            targetIndex = i;
            break;
        }
    }
    if (targetIndex == -1) {
        throw new IllegalArgumentException("No matched target when generating dataset from data frame.");
    }
    // Validate once, before the row loop: previously an unsupported type was only rejected
    // when at least one row existed, so an empty data frame silently produced an empty dataset.
    if (outputType != TribuoOutputType.REGRESSOR) {
        throw new IllegalArgumentException("unknown type:" + outputType);
    }
    final int finalTargetIndex = targetIndex;
    String[] featureNames = IntStream.range(0, columnNames.length).filter(e -> e != finalTargetIndex).mapToObj(e -> columnNames[e]).toArray(String[]::new);
    List<Example<T>> dataset = new ArrayList<>(dataFrame.size());
    for (int i = 0; i < dataFrame.size(); ++i) {
        double[] row = rows[i];
        double targetValue = row[finalTargetIndex];
        // Every column except the target becomes a feature value, in column order.
        double[] featureValues = IntStream.range(0, row.length).filter(e -> e != finalTargetIndex).mapToDouble(e -> row[e]).toArray();
        // Safe only because outputType == REGRESSOR implies the caller's T is Regressor.
        @SuppressWarnings("unchecked")
        T output = (T) new Regressor(target, targetValue);
        dataset.add(new ArrayExample<>(output, featureNames, featureValues));
    }
    SimpleDataSourceProvenance provenance = new SimpleDataSourceProvenance(desc, outputFactory);
    return new MutableDataset<>(new ListDataSource<>(dataset, outputFactory, provenance));
}
Also used: IntStream(java.util.stream.IntStream), Example(org.tribuo.Example), Arrays(java.util.Arrays), Row(org.opensearch.ml.common.dataframe.Row), Iterator(java.util.Iterator), ColumnValue(org.opensearch.ml.common.dataframe.ColumnValue), DataFrame(org.opensearch.ml.common.dataframe.DataFrame), ClusterID(org.tribuo.clustering.ClusterID), StringUtils(org.apache.commons.lang3.StringUtils), OutputFactory(org.tribuo.OutputFactory), Event(org.tribuo.anomaly.Event), Tuple(org.opensearch.common.collect.Tuple), ArrayList(java.util.ArrayList), UtilityClass(lombok.experimental.UtilityClass), ArrayExample(org.tribuo.impl.ArrayExample), Regressor(org.tribuo.regression.Regressor), List(java.util.List), TribuoOutputType(org.opensearch.ml.engine.contants.TribuoOutputType), Output(org.tribuo.Output), ListDataSource(org.tribuo.datasource.ListDataSource), SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance), StreamSupport(java.util.stream.StreamSupport), ColumnMeta(org.opensearch.ml.common.dataframe.ColumnMeta), MutableDataset(org.tribuo.MutableDataset)

Aggregations

ArrayList (java.util.ArrayList): 1, Arrays (java.util.Arrays): 1, Iterator (java.util.Iterator): 1, List (java.util.List): 1, IntStream (java.util.stream.IntStream): 1, StreamSupport (java.util.stream.StreamSupport): 1, UtilityClass (lombok.experimental.UtilityClass): 1, StringUtils (org.apache.commons.lang3.StringUtils): 1, Tuple (org.opensearch.common.collect.Tuple): 1, ColumnMeta (org.opensearch.ml.common.dataframe.ColumnMeta): 1, ColumnValue (org.opensearch.ml.common.dataframe.ColumnValue): 1, DataFrame (org.opensearch.ml.common.dataframe.DataFrame): 1, Row (org.opensearch.ml.common.dataframe.Row): 1, TribuoOutputType (org.opensearch.ml.engine.contants.TribuoOutputType): 1, Example (org.tribuo.Example): 1, MutableDataset (org.tribuo.MutableDataset): 1, Output (org.tribuo.Output): 1, OutputFactory (org.tribuo.OutputFactory): 1, Event (org.tribuo.anomaly.Event): 1, ClusterID (org.tribuo.clustering.ClusterID): 1