
Example 1 with RegressionFactory

Use of org.tribuo.regression.RegressionFactory in project ml-commons by opensearch-project.

From class LinearRegression, method predict.

@Override
@SuppressWarnings("unchecked")
public MLOutput predict(DataFrame dataFrame, Model model) {
    if (model == null) {
        throw new IllegalArgumentException("No model found for linear regression prediction.");
    }
    // Restore the trained Tribuo model from the serialized bytes stored in the ml-commons Model.
    org.tribuo.Model<Regressor> regressionModel = (org.tribuo.Model<Regressor>) ModelSerDeSer.deserialize(model.getContent());
    // Convert the input DataFrame into a Tribuo dataset with regression outputs.
    MutableDataset<Regressor> predictionDataset = TribuoUtil.generateDataset(dataFrame, new RegressionFactory(), "Linear regression prediction data from opensearch", TribuoOutputType.REGRESSOR);
    List<Prediction<Regressor>> predictions = regressionModel.predict(predictionDataset);
    // Keep only the first output dimension of each prediction, as a single-column row.
    List<Map<String, Object>> listPrediction = new ArrayList<>();
    predictions.forEach(e -> listPrediction.add(Collections.singletonMap(e.getOutput().getNames()[0], e.getOutput().getValues()[0])));
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(listPrediction)).build();
}
Also used: RegressionFactory (org.tribuo.regression.RegressionFactory), Prediction (org.tribuo.Prediction), ArrayList (java.util.ArrayList), Model (org.opensearch.ml.common.parameter.Model), Regressor (org.tribuo.regression.Regressor), Map (java.util.Map)
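Example 1 restores the trained model from the bytes returned by `model.getContent()` through `ModelSerDeSer.deserialize`, and Example 5 produces those bytes with `ModelSerDeSer.serialize`. The internals of `ModelSerDeSer` are not shown here, so the following is only a minimal sketch of that byte round trip, assuming plain Java object serialization; `SerDeSketch` and `ToyModel` are hypothetical stand-ins for `ModelSerDeSer` and a Tribuo model.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;

// Sketch only: a hypothetical stand-in for ModelSerDeSer, assuming plain Java serialization.
public class SerDeSketch {

    // Hypothetical toy model (weights plus bias); Tribuo models are also Serializable.
    public static final class ToyModel implements Serializable {
        private static final long serialVersionUID = 1L;
        private final double[] weights;
        private final double bias;

        public ToyModel(double[] weights, double bias) {
            this.weights = weights.clone();
            this.bias = bias;
        }

        public double predict(double[] x) {
            double y = bias;
            for (int i = 0; i < weights.length; i++) {
                y += weights[i] * x[i];
            }
            return y;
        }
    }

    // Write an object graph to bytes, the step train performs before Model.setContent.
    public static byte[] serialize(Serializable obj) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Rebuild the object from bytes, the step predict performs on Model.getContent.
    public static Object deserialize(byte[] content) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(content))) {
            return in.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException("Unknown class in serialized model", e);
        }
    }

    public static void main(String[] args) {
        byte[] content = serialize(new ToyModel(new double[] {2.0, -1.0}, 0.5));
        ToyModel restored = (ToyModel) deserialize(content);
        System.out.println(restored.predict(new double[] {1.0, 3.0}));  // prints -0.5
    }
}
```

Because `ObjectInputStream` instantiates whatever serializable classes appear in the stream, code that reads stored model bytes would normally restrict the accepted classes, for example with `java.io.ObjectInputFilter` (Java 9 and later).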

Example 2 with RegressionFactory

Use of org.tribuo.regression.RegressionFactory in project ml-commons by opensearch-project.

From class TribuoUtilTest, method generateDatasetWithTarget.

@SuppressWarnings("unchecked")
@Test
public void generateDatasetWithTarget() {
    MutableDataset<Regressor> dataset = TribuoUtil.generateDatasetWithTarget(dataFrame, new RegressionFactory(), "test", TribuoOutputType.REGRESSOR, "f2");
    List<Example<Regressor>> examples = dataset.getData();
    Assert.assertEquals(rawData.length, examples.size());
    for (int i = 0; i < rawData.length; ++i) {
        ArrayExample arrayExample = (ArrayExample) examples.get(i);
        Iterator<Feature> iterator = arrayExample.iterator();
        int idx = 1;
        while (iterator.hasNext()) {
            Feature feature = iterator.next();
            Assert.assertEquals("f" + idx, feature.getName());
            Assert.assertEquals(i + idx / 10.0, feature.getValue(), 0.01);
            ++idx;
        }
    }
}
Also used: ArrayExample (org.tribuo.impl.ArrayExample), RegressionFactory (org.tribuo.regression.RegressionFactory), Example (org.tribuo.Example), Regressor (org.tribuo.regression.Regressor), Feature (org.tribuo.Feature), Test (org.junit.Test)
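Example 2 checks that `generateDatasetWithTarget` turns each data-frame row into an `ArrayExample` whose features are named `f1`, `f2`, ... in column order. The sketch below shows that row-to-example split in plain Java, without Tribuo, assuming the column named by `target` is taken out as the label and every other column becomes a feature. The `TargetSplitSketch` class, its `Row` type, the `split` method, and the sample values are all hypothetical.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Sketch only: a hypothetical row-to-example split, not TribuoUtil's implementation.
public class TargetSplitSketch {

    // One labelled example: named feature values plus the target value.
    public static final class Row {
        public final Map<String, Double> features;
        public final double target;

        Row(Map<String, Double> features, double target) {
            this.features = features;
            this.target = target;
        }
    }

    // columns: column names in order; data: one double[] per row, in the same order.
    // Assumes `target` is one of `columns`; no validation is done here.
    public static List<Row> split(String[] columns, double[][] data, String target) {
        List<Row> rows = new ArrayList<>();
        for (double[] values : data) {
            Map<String, Double> features = new LinkedHashMap<>();  // keeps column order
            double label = Double.NaN;
            for (int c = 0; c < columns.length; c++) {
                if (columns[c].equals(target)) {
                    label = values[c];                    // target column becomes the label
                } else {
                    features.put(columns[c], values[c]);  // every other column is a feature
                }
            }
            rows.add(new Row(features, label));
        }
        return rows;
    }

    public static void main(String[] args) {
        String[] columns = {"f1", "f2"};
        double[][] data = {{0.1, 0.2}, {1.1, 1.2}, {2.1, 2.2}};  // hypothetical values
        for (Row row : split(columns, data, "f2")) {
            System.out.println(row.features + " -> " + row.target);
        }
    }
}
```

A `LinkedHashMap` is used so that, in this sketch, the remaining features keep the order of the data-frame columns.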

Example 3 with RegressionFactory

Use of org.tribuo.regression.RegressionFactory in project ml-commons by opensearch-project.

From class TribuoUtilTest, method generateDatasetWithUnmatchedTarget.

@Test
public void generateDatasetWithUnmatchedTarget() {
    exceptionRule.expect(RuntimeException.class);
    exceptionRule.expectMessage("No matched target when generating dataset from data frame.");
    TribuoUtil.generateDatasetWithTarget(dataFrame, new RegressionFactory(), "test", TribuoOutputType.REGRESSOR, "f0");
}
Also used: RegressionFactory (org.tribuo.regression.RegressionFactory), Test (org.junit.Test)

Example 4 with RegressionFactory

Use of org.tribuo.regression.RegressionFactory in project ml-commons by opensearch-project.

From class TribuoUtilTest, method generateDatasetWithEmptyTarget.

@Test
public void generateDatasetWithEmptyTarget() {
    exceptionRule.expect(RuntimeException.class);
    exceptionRule.expectMessage("Empty target when generating dataset from data frame.");
    TribuoUtil.generateDatasetWithTarget(dataFrame, new RegressionFactory(), "test", TribuoOutputType.REGRESSOR, null);
}
Also used: RegressionFactory (org.tribuo.regression.RegressionFactory), Test (org.junit.Test)
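Examples 3 and 4 pin down two error cases for `generateDatasetWithTarget`: a `null` target must fail with "Empty target when generating dataset from data frame." and a target that matches no column must fail with "No matched target when generating dataset from data frame." Here is a minimal sketch of those two checks, assuming the target is looked up by exact column name. The `targetIndex` helper is hypothetical and, unlike anything the tests show, also treats an empty string as a missing target. It throws `IllegalArgumentException`, which would still satisfy `exceptionRule.expect(RuntimeException.class)` because that check accepts subclasses.

```java
import java.util.Arrays;

// Sketch only: hypothetical target checks mirroring the messages in Examples 3 and 4.
public class TargetCheckSketch {

    // Returns the index of the target column, or throws with the message the tests expect.
    public static int targetIndex(String[] columns, String target) {
        if (target == null || target.isEmpty()) {
            throw new IllegalArgumentException("Empty target when generating dataset from data frame.");
        }
        int index = Arrays.asList(columns).indexOf(target);
        if (index < 0) {
            throw new IllegalArgumentException("No matched target when generating dataset from data frame.");
        }
        return index;
    }

    public static void main(String[] args) {
        String[] columns = {"f1", "f2"};
        System.out.println(targetIndex(columns, "f2"));  // prints 1
        for (String bad : new String[] {"f0", null}) {
            try {
                targetIndex(columns, bad);
            } catch (IllegalArgumentException e) {
                System.out.println(bad + ": " + e.getMessage());
            }
        }
    }
}
```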

Example 5 with RegressionFactory

Use of org.tribuo.regression.RegressionFactory in project ml-commons by opensearch-project.

From class LinearRegression, method train.

@Override
public Model train(DataFrame dataFrame) {
    // Build the training set, using the configured target column as the regression label.
    MutableDataset<Regressor> trainDataset = TribuoUtil.generateDatasetWithTarget(dataFrame, new RegressionFactory(), "Linear regression training data from opensearch", TribuoOutputType.REGRESSOR, parameters.getTarget());
    // Fall back to DEFAULT_EPOCHS when no epoch count is supplied.
    Integer epochs = Optional.ofNullable(parameters.getEpochs()).orElse(DEFAULT_EPOCHS);
    LinearSGDTrainer linearSGDTrainer = new LinearSGDTrainer(objective, optimiser, epochs, DEFAULT_INTERVAL, DEFAULT_BATCH_SIZE, seed);
    org.tribuo.Model<Regressor> regressionModel = linearSGDTrainer.train(trainDataset);
    // Wrap the trained Tribuo model in an ml-commons Model, storing it as serialized bytes.
    Model model = new Model();
    model.setName(FunctionName.LINEAR_REGRESSION.name());
    model.setVersion(1);
    model.setContent(ModelSerDeSer.serialize(regressionModel));
    return model;
}
Also used: LinearSGDTrainer (org.tribuo.regression.sgd.linear.LinearSGDTrainer), RegressionFactory (org.tribuo.regression.RegressionFactory), Model (org.opensearch.ml.common.parameter.Model), Regressor (org.tribuo.regression.Regressor)
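In Example 5 the actual fitting is delegated to Tribuo's `LinearSGDTrainer`, constructed from an objective, an optimiser, the epoch count, a logging interval, a minibatch size, and a seed. To illustrate what stochastic gradient descent for linear regression does, here is a plain-Java sketch that fixes those choices: squared-error loss, a constant learning rate, and a minibatch size of 1. It is not Tribuo's implementation; the `LinearSgdSketch` class, the learning rate, and the toy data are hypothetical.

```java
import java.util.Arrays;
import java.util.Random;

// Sketch only: plain SGD for linear regression, not Tribuo's LinearSGDTrainer.
public class LinearSgdSketch {

    // Returns weights w[0..d-1] followed by the bias in w[d].
    public static double[] train(double[][] x, double[] y, int epochs, double learningRate, long seed) {
        int n = x.length;
        int d = x[0].length;
        double[] w = new double[d + 1];
        Random rng = new Random(seed);  // seeded, so runs are reproducible
        int[] order = new int[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        for (int epoch = 0; epoch < epochs; epoch++) {
            // Fisher-Yates shuffle: visit the examples in a new random order each epoch.
            for (int i = n - 1; i > 0; i--) {
                int j = rng.nextInt(i + 1);
                int tmp = order[i];
                order[i] = order[j];
                order[j] = tmp;
            }
            for (int idx : order) {
                // err is the gradient of 0.5 * err^2 with respect to the prediction.
                double err = predict(w, x[idx]) - y[idx];
                for (int k = 0; k < d; k++) {
                    w[k] -= learningRate * err * x[idx][k];
                }
                w[d] -= learningRate * err;
            }
        }
        return w;
    }

    public static double predict(double[] w, double[] features) {
        double y = w[features.length];  // bias term
        for (int k = 0; k < features.length; k++) {
            y += w[k] * features[k];
        }
        return y;
    }

    public static void main(String[] args) {
        // Hypothetical noise-free data on the line y = 3x + 1.
        double[][] x = {{0.0}, {0.5}, {1.0}, {1.5}, {2.0}};
        double[] y = {1.0, 2.5, 4.0, 5.5, 7.0};
        double[] w = train(x, y, 500, 0.05, 42L);
        System.out.println(Arrays.toString(w));  // close to [3.0, 1.0]
    }
}
```

Replacing the constant step size with an adaptive optimiser, or the squared loss with a more robust loss, is the kind of choice that the `objective` and `optimiser` arguments in Example 5 leave configurable.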

Aggregations

Types used across the examples above, with the number of examples that use each:
RegressionFactory (org.tribuo.regression.RegressionFactory): 5
Test (org.junit.Test): 3
Regressor (org.tribuo.regression.Regressor): 3
Model (org.opensearch.ml.common.parameter.Model): 2
ArrayList (java.util.ArrayList): 1
Map (java.util.Map): 1
Example (org.tribuo.Example): 1
Feature (org.tribuo.Feature): 1
Prediction (org.tribuo.Prediction): 1
ArrayExample (org.tribuo.impl.ArrayExample): 1
LinearSGDTrainer (org.tribuo.regression.sgd.linear.LinearSGDTrainer): 1