Usage of org.tribuo.regression.RegressionFactory in the ml-commons project (opensearch-project).
Class LinearRegression, method predict:
/**
 * Scores {@code dataFrame} with a previously trained linear regression model.
 *
 * @param dataFrame rows to score; converted into a Tribuo dataset by {@code TribuoUtil.generateDataset}
 * @param model stored model whose serialized content is turned back into a Tribuo regression model
 * @return a prediction output with one single-entry row per prediction, keyed by the name of the
 *     first regression dimension and holding that dimension's predicted value
 * @throws IllegalArgumentException if {@code model} is null
 */
@Override
public MLOutput predict(DataFrame dataFrame, Model model) {
    if (model == null) {
        throw new IllegalArgumentException("No model found for linear regression prediction.");
    }
    // The deserialized content is cast without a runtime type check. NOTE(review): this presumes the
    // content was written by this class's train(); confirm no other writer stores a different model type.
    @SuppressWarnings("unchecked")
    org.tribuo.Model<Regressor> regressionModel = (org.tribuo.Model<Regressor>) ModelSerDeSer.deserialize(model.getContent());
    MutableDataset<Regressor> predictionDataset = TribuoUtil.generateDataset(dataFrame, new RegressionFactory(), "Linear regression prediction data from opensearch", TribuoOutputType.REGRESSOR);
    List<Prediction<Regressor>> predictions = regressionModel.predict(predictionDataset);
    List<Map<String, Object>> listPrediction = new ArrayList<>(predictions.size());
    for (Prediction<Regressor> prediction : predictions) {
        Regressor output = prediction.getOutput();
        // Only index 0 is read, so for multi-dimensional regressors every dimension after the first is dropped.
        listPrediction.add(Collections.singletonMap(output.getNames()[0], output.getValues()[0]));
    }
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(listPrediction)).build();
}
Usage of org.tribuo.regression.RegressionFactory in the ml-commons project (opensearch-project).
Class TribuoUtilTest, method generateDatasetWithTarget:
/**
 * Checks that {@code TribuoUtil.generateDatasetWithTarget} (with target column "f2") yields one
 * example per raw row. Each example must be an {@code ArrayExample} whose features are named
 * f1, f2, ... in order, with values within 0.01 of {@code rowIndex + featureIndex / 10.0}.
 *
 * <p>NOTE(review): the expected names/values assume the {@code rawData}/{@code dataFrame} fixture
 * follows that pattern — confirm against the test setup.
 */
@Test
public void generateDatasetWithTarget() {
    MutableDataset<Regressor> dataset = TribuoUtil.generateDatasetWithTarget(dataFrame, new RegressionFactory(), "test", TribuoOutputType.REGRESSOR, "f2");
    List<Example<Regressor>> examples = dataset.getData();
    Assert.assertEquals(rawData.length, examples.size());
    for (int i = 0; i < rawData.length; ++i) {
        // A parameterized cast (instead of the raw ArrayExample) keeps feature iteration type-safe,
        // so no unchecked-warning suppression is needed; it still fails fast if the type differs.
        ArrayExample<Regressor> arrayExample = (ArrayExample<Regressor>) examples.get(i);
        int idx = 1;
        for (Feature feature : arrayExample) {
            Assert.assertEquals("f" + idx, feature.getName());
            Assert.assertEquals(i + idx / 10.0, feature.getValue(), 0.01);
            ++idx;
        }
        // Without this, an example with no features would pass the loop above vacuously.
        Assert.assertTrue("example " + i + " has no features", idx > 1);
    }
}
Usage of org.tribuo.regression.RegressionFactory in the ml-commons project (opensearch-project).
Class TribuoUtilTest, method generateDatasetWithUnmatchedTarget:
/**
 * A target name that matches no column ("f0" here) must make
 * {@code TribuoUtil.generateDatasetWithTarget} throw a RuntimeException with the
 * "No matched target" message.
 */
@Test
public void generateDatasetWithUnmatchedTarget() {
    exceptionRule.expectMessage("No matched target when generating dataset from data frame.");
    exceptionRule.expect(RuntimeException.class);
    String unmatchedTarget = "f0";
    TribuoUtil.generateDatasetWithTarget(dataFrame, new RegressionFactory(), "test", TribuoOutputType.REGRESSOR, unmatchedTarget);
}
Usage of org.tribuo.regression.RegressionFactory in the ml-commons project (opensearch-project).
Class TribuoUtilTest, method generateDatasetWithEmptyTarget:
/**
 * A null target name must make {@code TribuoUtil.generateDatasetWithTarget} throw a
 * RuntimeException with the "Empty target" message.
 */
@Test
public void generateDatasetWithEmptyTarget() {
    exceptionRule.expectMessage("Empty target when generating dataset from data frame.");
    exceptionRule.expect(RuntimeException.class);
    String absentTarget = null;
    TribuoUtil.generateDatasetWithTarget(dataFrame, new RegressionFactory(), "test", TribuoOutputType.REGRESSOR, absentTarget);
}
Usage of org.tribuo.regression.RegressionFactory in the ml-commons project (opensearch-project).
Class LinearRegression, method train:
/**
 * Trains a linear regression model on {@code dataFrame} using Tribuo's {@code LinearSGDTrainer}.
 *
 * <p>The target column comes from {@code parameters.getTarget()}. The epoch count comes from
 * {@code parameters.getEpochs()}, falling back to {@code DEFAULT_EPOCHS} when that is null.
 *
 * @param dataFrame training rows, including the target column
 * @return a model named after {@code FunctionName.LINEAR_REGRESSION}, version 1, whose content is
 *     the serialized Tribuo model
 */
@Override
public Model train(DataFrame dataFrame) {
    String datasetDescription = "Linear regression training data from opensearch";
    MutableDataset<Regressor> trainDataset = TribuoUtil.generateDatasetWithTarget(
            dataFrame, new RegressionFactory(), datasetDescription, TribuoOutputType.REGRESSOR, parameters.getTarget());

    // A caller-supplied epoch count wins; otherwise use the class default.
    Integer requestedEpochs = parameters.getEpochs();
    Integer epochs = requestedEpochs != null ? requestedEpochs : DEFAULT_EPOCHS;

    org.tribuo.Model<Regressor> trainedModel =
            new LinearSGDTrainer(objective, optimiser, epochs, DEFAULT_INTERVAL, DEFAULT_BATCH_SIZE, seed)
                    .train(trainDataset);

    Model result = new Model();
    result.setName(FunctionName.LINEAR_REGRESSION.name());
    result.setVersion(1);
    result.setContent(ModelSerDeSer.serialize(trainedModel));
    return result;
}
Aggregations