Use of org.tribuo.MutableDataset in project ml-commons by opensearch-project.
Class TribuoUtil, method generateDatasetWithTarget.
/**
 * Build a tribuo dataset from a data frame, using one named column as the output value.
 *
 * <p>The first column whose name equals {@code target} is removed from the feature set and
 * its per-row value becomes the output of that row's example; all remaining columns are kept
 * as features, in their original order. Only the {@code REGRESSOR} output type is handled.
 *
 * @param dataFrame features data, including the target column
 * @param outputFactory the tribuo output factory
 * @param desc description for tribuo provenance
 * @param outputType the tribuo output type
 * @param target target name
 * @return tribuo dataset
 * @throws IllegalArgumentException if {@code StringUtils.isEmpty(target)} holds, if no column
 *         name equals {@code target}, or if a row is processed with an output type other
 *         than {@code REGRESSOR}
 */
public static <T extends Output<T>> MutableDataset<T> generateDatasetWithTarget(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType, String target) {
    if (StringUtils.isEmpty(target)) {
        throw new IllegalArgumentException("Empty target when generating dataset from data frame.");
    }

    List<Example<T>> examples = new ArrayList<>();
    Tuple<String[], double[][]> namesAndValues = transformDataFrame(dataFrame);
    String[] columnNames = namesAndValues.v1();
    double[][] rows = namesAndValues.v2();

    // Position of the first column named like the target, or -1 when there is none.
    final int targetColumn = IntStream.range(0, columnNames.length)
        .filter(col -> columnNames[col].equals(target))
        .findFirst()
        .orElse(-1);
    if (targetColumn < 0) {
        throw new IllegalArgumentException("No matched target when generating dataset from data frame.");
    }

    // Feature names are all column names except the target, original order preserved.
    String[] featureNames = new String[columnNames.length - 1];
    for (int src = 0, dst = 0; src < columnNames.length; src++) {
        if (src != targetColumn) {
            featureNames[dst++] = columnNames[src];
        }
    }

    int rowCount = dataFrame.size();
    for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
        ArrayExample<T> example;
        // The output type is checked per row, so it is only rejected once a row is processed.
        switch (outputType) {
            case REGRESSOR:
                double[] row = rows[rowIndex];
                double targetValue = row[targetColumn];
                // Copy the values on either side of the target column into the feature vector.
                double[] featureValues = new double[row.length - 1];
                System.arraycopy(row, 0, featureValues, 0, targetColumn);
                System.arraycopy(row, targetColumn + 1, featureValues, targetColumn, row.length - targetColumn - 1);
                example = new ArrayExample<>((T) new Regressor(target, targetValue), featureNames, featureValues);
                break;
            default:
                throw new IllegalArgumentException("unknown type:" + outputType);
        }
        examples.add(example);
    }

    return new MutableDataset<>(new ListDataSource<>(examples, outputFactory, new SimpleDataSourceProvenance(desc, outputFactory)));
}
Use of org.tribuo.MutableDataset in project ml-commons by opensearch-project.
Class TribuoUtil, method generateDataset.
/**
 * Build a tribuo dataset from a data frame, attaching a placeholder output to every row.
 *
 * <p>Every column of the data frame is used as a feature. The output attached to each
 * example depends only on {@code outputType}: an unassigned cluster id, a single-dimension
 * regressor named {@code DIM-0} with value NaN, or an {@code EXPECTED} anomaly event.
 *
 * @param dataFrame features data
 * @param outputFactory the tribuo output factory
 * @param desc description for tribuo provenance
 * @param outputType the tribuo output type
 * @return tribuo dataset
 * @throws IllegalArgumentException if a row is processed with an output type that is not
 *         one of {@code CLUSTERID}, {@code REGRESSOR} or {@code ANOMALY_DETECTION_LIBSVM}
 */
public static <T extends Output<T>> MutableDataset<T> generateDataset(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType) {
    List<Example<T>> examples = new ArrayList<>();
    Tuple<String[], double[][]> namesAndValues = transformDataFrame(dataFrame);
    String[] featureNames = namesAndValues.v1();
    double[][] rows = namesAndValues.v2();

    int rowCount = dataFrame.size();
    for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
        T placeholderOutput;
        // The output type is checked per row, so it is only rejected once a row is processed.
        switch (outputType) {
            case CLUSTERID:
                placeholderOutput = (T) new ClusterID(ClusterID.UNASSIGNED);
                break;
            case REGRESSOR:
                // Single dimension tribuo regressor with name DIM-0 and value double NaN.
                placeholderOutput = (T) new Regressor("DIM-0", Double.NaN);
                break;
            case ANOMALY_DETECTION_LIBSVM:
                // Why the default event type is EXPECTED (non-anomalous):
                // 1. For training data, Tribuo LibSVMAnomalyTrainer only supports EXPECTED events at training time.
                // 2. For prediction data, the data is treated as non-anomalous by default because Tribuo does not accept the UNKNOWN type.
                // TODO: support anomaly labels to evaluate prediction result
                placeholderOutput = (T) new Event(Event.EventType.EXPECTED);
                break;
            default:
                throw new IllegalArgumentException("unknown type:" + outputType);
        }
        examples.add(new ArrayExample<>(placeholderOutput, featureNames, rows[rowIndex]));
    }

    return new MutableDataset<>(new ListDataSource<>(examples, outputFactory, new SimpleDataSourceProvenance(desc, outputFactory)));
}
Aggregations