Example usage of org.tribuo.regression.Regressor in the opensearch-project ml-commons project.
Source: class TribuoUtil, method generateDatasetWithTarget.
/**
 * Generate tribuo dataset from data frame with target.
 * <p>
 * The column whose name equals {@code target} supplies the output of every example; all other
 * columns become features, in their original column order.
 *
 * @param dataFrame features data
 * @param outputFactory the tribuo output factory
 * @param desc description for tribuo provenance
 * @param outputType the tribuo output type; only {@code TribuoOutputType.REGRESSOR} is supported
 * @param target target name
 * @param <T> the tribuo output type produced by {@code outputFactory}
 * @return tribuo dataset
 * @throws IllegalArgumentException if {@code target} is empty, {@code outputType} is not supported,
 *                                  or no column of {@code dataFrame} is named {@code target}
 */
@SuppressWarnings("unchecked")
public static <T extends Output<T>> MutableDataset<T> generateDatasetWithTarget(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType, String target) {
    if (StringUtils.isEmpty(target)) {
        throw new IllegalArgumentException("Empty target when generating dataset from data frame.");
    }
    // Reject unsupported output types before doing any work. When this check lived inside the
    // per-row loop, an unsupported type was silently accepted whenever the data frame was empty.
    if (outputType != TribuoOutputType.REGRESSOR) {
        throw new IllegalArgumentException("unknown type:" + outputType);
    }
    Tuple<String[], double[][]> featureNamesValues = transformDataFrame(dataFrame);
    String[] columnNames = featureNamesValues.v1();
    double[][] rows = featureNamesValues.v2();

    int targetIndex = -1;
    for (int i = 0; i < columnNames.length; ++i) {
        if (columnNames[i].equals(target)) {
            targetIndex = i;
            break;
        }
    }
    if (targetIndex == -1) {
        throw new IllegalArgumentException("No matched target when generating dataset from data frame.");
    }

    // Column indices of the features (every column except the target). They are the same for
    // every row, so compute them once instead of re-filtering per row.
    final int finalTargetIndex = targetIndex;
    int[] featureIndices = IntStream.range(0, columnNames.length).filter(e -> e != finalTargetIndex).toArray();
    String[] featureNames = new String[featureIndices.length];
    for (int j = 0; j < featureIndices.length; ++j) {
        featureNames[j] = columnNames[featureIndices[j]];
    }

    int rowCount = dataFrame.size();
    List<Example<T>> dataset = new ArrayList<>(rowCount);
    for (int i = 0; i < rowCount; ++i) {
        double[] row = rows[i];
        double[] featureValues = new double[featureIndices.length];
        for (int j = 0; j < featureIndices.length; ++j) {
            featureValues[j] = row[featureIndices[j]];
        }
        // Unchecked cast: assumes T is Regressor (i.e. outputFactory is a regression factory) when
        // outputType is REGRESSOR. A mismatch is not detected here.
        dataset.add(new ArrayExample<>((T) new Regressor(target, row[finalTargetIndex]), featureNames, featureValues));
    }
    SimpleDataSourceProvenance provenance = new SimpleDataSourceProvenance(desc, outputFactory);
    return new MutableDataset<>(new ListDataSource<>(dataset, outputFactory, provenance));
}
Example usage of org.tribuo.regression.Regressor in the opensearch-project ml-commons project.
Source: class TribuoUtil, method generateDataset.
/**
 * Builds an unlabeled tribuo dataset from a data frame.
 * <p>
 * Every column of {@code dataFrame} becomes a feature. No target column is involved, so each
 * example receives a placeholder output selected by {@code outputType}:
 * <ul>
 *   <li>{@code CLUSTERID}: a {@code ClusterID} with id {@code ClusterID.UNASSIGNED};</li>
 *   <li>{@code REGRESSOR}: a one-dimensional {@code Regressor} named "DIM-0" with value NaN;</li>
 *   <li>{@code ANOMALY_DETECTION_LIBSVM}: an {@code Event} of type {@code EXPECTED}.</li>
 * </ul>
 *
 * @param dataFrame features data
 * @param outputFactory the tribuo output factory
 * @param desc description for tribuo provenance
 * @param outputType the tribuo output type
 * @param <T> the tribuo output type produced by {@code outputFactory}
 * @return tribuo dataset
 * @throws IllegalArgumentException if {@code outputType} is not one of the handled types; note the
 *                                  check happens per row, so it is only raised for a non-empty frame
 */
@SuppressWarnings("unchecked")
public static <T extends Output<T>> MutableDataset<T> generateDataset(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType) {
    Tuple<String[], double[][]> namesAndRows = transformDataFrame(dataFrame);
    String[] columnNames = namesAndRows.v1();
    double[][] rows = namesAndRows.v2();
    List<Example<T>> examples = new ArrayList<>();

    for (int row = 0; row < dataFrame.size(); row++) {
        T placeholder;
        switch (outputType) {
            case CLUSTERID:
                placeholder = (T) new ClusterID(ClusterID.UNASSIGNED);
                break;
            case REGRESSOR:
                // Single-dimension regressor named DIM-0 whose value is NaN (the true value is unknown).
                placeholder = (T) new Regressor("DIM-0", Double.NaN);
                break;
            case ANOMALY_DETECTION_LIBSVM:
                // EXPECTED (non-anomalous) is the default because Tribuo's LibSVMAnomalyTrainer only
                // supports EXPECTED events at training time, and Tribuo does not accept UNKNOWN events
                // for prediction data.
                // TODO: support anomaly labels to evaluate prediction result
                placeholder = (T) new Event(Event.EventType.EXPECTED);
                break;
            default:
                throw new IllegalArgumentException("unknown type:" + outputType);
        }
        examples.add(new ArrayExample<>(placeholder, columnNames, rows[row]));
    }

    SimpleDataSourceProvenance provenance = new SimpleDataSourceProvenance(desc, outputFactory);
    return new MutableDataset<>(new ListDataSource<>(examples, outputFactory, provenance));
}
Example usage of org.tribuo.regression.Regressor in the opensearch-project ml-commons project.
Source: class LinearRegression, method predict.
/**
 * Scores a data frame with a previously trained linear regression model.
 * <p>
 * The model content is deserialized into a Tribuo regression model. Every column of
 * {@code dataFrame} is treated as a feature.
 *
 * @param dataFrame rows to predict
 * @param model trained model; must not be null
 * @return prediction output with one entry per prediction returned by the Tribuo model, each mapping
 *         the first output dimension name to its predicted value
 * @throws IllegalArgumentException if {@code model} is null
 */
@Override
public MLOutput predict(DataFrame dataFrame, Model model) {
    if (model == null) {
        throw new IllegalArgumentException("No model found for linear regression prediction.");
    }
    @SuppressWarnings("unchecked")
    org.tribuo.Model<Regressor> tribuoModel = (org.tribuo.Model<Regressor>) ModelSerDeSer.deserialize(model.getContent());
    MutableDataset<Regressor> toScore = TribuoUtil.generateDataset(dataFrame, new RegressionFactory(), "Linear regression prediction data from opensearch", TribuoOutputType.REGRESSOR);
    List<Prediction<Regressor>> scored = tribuoModel.predict(toScore);

    List<Map<String, Object>> rows = new ArrayList<>();
    for (Prediction<Regressor> prediction : scored) {
        Regressor predicted = prediction.getOutput();
        rows.add(Collections.singletonMap(predicted.getNames()[0], predicted.getValues()[0]));
    }
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(rows)).build();
}
Example usage of org.tribuo.regression.Regressor in the opensearch-project ml-commons project.
Source: class TribuoUtilTest, test method generateDatasetWithTarget.
/**
 * Verifies that generateDatasetWithTarget:
 * <ul>
 *   <li>produces one example per input row;</li>
 *   <li>uses the target column ("f2") as the single regression output dimension;</li>
 *   <li>keeps every other column as a feature, with the expected name and value;</li>
 *   <li>never exposes the target column as a feature.</li>
 * </ul>
 */
@Test
public void generateDatasetWithTarget() {
    MutableDataset<Regressor> dataset = TribuoUtil.generateDatasetWithTarget(dataFrame, new RegressionFactory(), "test", TribuoOutputType.REGRESSOR, "f2");
    List<Example<Regressor>> examples = dataset.getData();
    Assert.assertEquals(rawData.length, examples.size());
    for (int i = 0; i < rawData.length; ++i) {
        Example<Regressor> example = examples.get(i);
        // The target column name becomes the (only) output dimension name.
        Assert.assertArrayEquals(new String[] {"f2"}, example.getOutput().getNames());
        int idx = 1;
        for (Feature feature : example) {
            // Without this check a leaked target at position 2 would still satisfy "f" + idx below.
            Assert.assertNotEquals("target column must not be a feature", "f2", feature.getName());
            Assert.assertEquals("f" + idx, feature.getName());
            Assert.assertEquals(i + idx / 10.0, feature.getValue(), 0.01);
            ++idx;
        }
        // Guard against the loop above passing vacuously on an example with no features.
        Assert.assertTrue("expected at least one feature per example", idx > 1);
    }
}
Example usage of org.tribuo.regression.Regressor in the opensearch-project ml-commons project.
Source: class LinearRegression, method train.
/**
 * Trains a linear SGD regression model on a data frame.
 * <p>
 * The column named by {@code parameters.getTarget()} is the regression target; the other columns
 * are features. The epoch count comes from {@code parameters.getEpochs()}, falling back to
 * {@code DEFAULT_EPOCHS} when it is null.
 *
 * @param dataFrame training data including the target column
 * @return a model named {@code LINEAR_REGRESSION}, version 1, whose content is the serialized
 *         Tribuo model
 */
@Override
public Model train(DataFrame dataFrame) {
    MutableDataset<Regressor> trainingData = TribuoUtil.generateDatasetWithTarget(dataFrame, new RegressionFactory(), "Linear regression training data from opensearch", TribuoOutputType.REGRESSOR, parameters.getTarget());
    Integer requestedEpochs = parameters.getEpochs();
    int epochs = requestedEpochs == null ? DEFAULT_EPOCHS : requestedEpochs;

    LinearSGDTrainer trainer = new LinearSGDTrainer(objective, optimiser, epochs, DEFAULT_INTERVAL, DEFAULT_BATCH_SIZE, seed);
    org.tribuo.Model<Regressor> tribuoModel = trainer.train(trainingData);

    Model result = new Model();
    result.setName(FunctionName.LINEAR_REGRESSION.name());
    result.setVersion(1);
    result.setContent(ModelSerDeSer.serialize(tribuoModel));
    return result;
}
Aggregations