Example usage of `org.opensearch.ml.engine.contants.TribuoOutputType` in the ml-commons project by opensearch-project.
Taken from class `TribuoUtil`, method `generateDatasetWithTarget`.
/**
 * Generate tribuo dataset from data frame with target.
 *
 * <p>The first column whose name equals {@code target} supplies the output value of each example;
 * every other column becomes a feature of that example, in the original column order.
 *
 * @param dataFrame features data
 * @param outputFactory the tribuo output factory
 * @param desc description for tribuo provenance
 * @param outputType the tribuo output type; only {@code REGRESSOR} is currently supported
 * @param target target name
 * @param <T> the tribuo output type produced by {@code outputFactory}
 * @return tribuo dataset
 * @throws IllegalArgumentException if {@code target} is null or empty, no column is named
 *     {@code target}, or {@code outputType} is not supported
 */
public static <T extends Output<T>> MutableDataset<T> generateDatasetWithTarget(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType, String target) {
    if (StringUtils.isEmpty(target)) {
        throw new IllegalArgumentException("Empty target when generating dataset from data frame.");
    }
    Tuple<String[], double[][]> featureNamesValues = transformDataFrame(dataFrame);
    String[] columnNames = featureNamesValues.v1();
    double[][] rows = featureNamesValues.v2();

    int targetIndex = -1;
    for (int i = 0; i < columnNames.length; ++i) {
        if (columnNames[i].equals(target)) {
            targetIndex = i;
            break;
        }
    }
    if (targetIndex == -1) {
        throw new IllegalArgumentException("No matched target when generating dataset from data frame.");
    }
    // Validate the output type once, before touching any rows. When this was checked per row,
    // an unsupported (or null) type was silently accepted whenever the data frame was empty.
    if (outputType != TribuoOutputType.REGRESSOR) {
        throw new IllegalArgumentException("unknown type:" + outputType);
    }

    // Feature names are all column names except the target, with order preserved.
    String[] featureNames = new String[columnNames.length - 1];
    System.arraycopy(columnNames, 0, featureNames, 0, targetIndex);
    System.arraycopy(columnNames, targetIndex + 1, featureNames, targetIndex, featureNames.length - targetIndex);

    int rowCount = dataFrame.size();
    List<Example<T>> dataset = new ArrayList<>(rowCount);
    for (int i = 0; i < rowCount; ++i) {
        double[] row = rows[i];
        // Copy the row without its target cell; avoids building a stream pipeline for every row.
        double[] featureValues = new double[row.length - 1];
        System.arraycopy(row, 0, featureValues, 0, targetIndex);
        System.arraycopy(row, targetIndex + 1, featureValues, targetIndex, featureValues.length - targetIndex);
        // Unchecked: T is erased, so this cast cannot fail here. It is only sound when the caller
        // pairs REGRESSOR with a Regressor output factory. NOTE(review): confirm all callers do.
        @SuppressWarnings("unchecked")
        T output = (T) new Regressor(target, row[targetIndex]);
        dataset.add(new ArrayExample<>(output, featureNames, featureValues));
    }
    SimpleDataSourceProvenance provenance = new SimpleDataSourceProvenance(desc, outputFactory);
    return new MutableDataset<>(new ListDataSource<>(dataset, outputFactory, provenance));
}
Aggregations