Search in sources:

Example 1 with Row

use of org.opensearch.ml.common.dataframe.Row in project ml-commons by opensearch-project.

From class MLPredictionOutputTest, method readInputStream.

/**
 * Serializes {@code output}, deserializes it again and asserts that the round-tripped copy
 * matches the original field by field.
 *
 * @param output the prediction output to round-trip
 * @throws IOException if writing to or reading from the in-memory stream fails
 */
private void readInputStream(MLPredictionOutput output) throws IOException {
    BytesStreamOutput serialized = new BytesStreamOutput();
    output.writeTo(serialized);
    StreamInput in = serialized.bytes().streamInput();

    // writeTo emits the output type first; it must be consumed before the payload is parsed.
    assertEquals(MLOutputType.PREDICTION, in.readEnum(MLOutputType.class));
    MLPredictionOutput roundTripped = new MLPredictionOutput(in);

    assertEquals(output.getType(), roundTripped.getType());
    assertEquals(output.getTaskId(), roundTripped.getTaskId());
    assertEquals(output.getStatus(), roundTripped.getStatus());

    DataFrame expected = output.predictionResult;
    if (expected == null) {
        // An absent prediction result must also be absent after deserialization.
        assertEquals(expected, roundTripped.getPredictionResult());
        return;
    }

    DataFrame actual = roundTripped.getPredictionResult();
    assertEquals(expected.size(), actual.size());
    for (int rowIndex = 0; rowIndex < expected.size(); rowIndex++) {
        assertEquals(expected.getRow(rowIndex), actual.getRow(rowIndex));
    }
}
Also used : StreamInput(org.opensearch.common.io.stream.StreamInput) Row(org.opensearch.ml.common.dataframe.Row) BytesStreamOutput(org.opensearch.common.io.stream.BytesStreamOutput)

Example 2 with Row

use of org.opensearch.ml.common.dataframe.Row in project ml-commons by opensearch-project.

From class MLPredictionOutputTest, method setUp.

/**
 * Builds the shared {@code output} fixture: a prediction output carrying a one-column INTEGER
 * data frame named "test" with the rows 1 and 2.
 */
@Before
public void setUp() {
    ColumnMeta[] schema = { new ColumnMeta("test", ColumnType.INTEGER) };
    List<Row> sampleRows = new ArrayList<>();
    for (int cell : new int[] { 1, 2 }) {
        sampleRows.add(new Row(new ColumnValue[] { new IntValue(cell) }));
    }
    DataFrame frame = new DefaultDataFrame(schema, sampleRows);
    output = MLPredictionOutput.builder()
            .taskId("test_task_id")
            .status("test_status")
            .predictionResult(frame)
            .build();
}
Also used : ColumnMeta(org.opensearch.ml.common.dataframe.ColumnMeta) ArrayList(java.util.ArrayList) ColumnValue(org.opensearch.ml.common.dataframe.ColumnValue) Row(org.opensearch.ml.common.dataframe.Row) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) DefaultDataFrame(org.opensearch.ml.common.dataframe.DefaultDataFrame) IntValue(org.opensearch.ml.common.dataframe.IntValue) Before(org.junit.Before)

Example 3 with Row

use of org.opensearch.ml.common.dataframe.Row in project ml-commons by opensearch-project.

From class TribuoUtil, method transformDataFrame.

/**
 * Splits a data frame into its column names and a matrix of row values.
 *
 * <p>Each row's cells are taken in the order yielded by the row's spliterator and converted with
 * {@link ColumnValue#doubleValue()}. The matrix has {@code dataFrame.size()} rows, so the
 * frame's iterator is expected to yield exactly that many rows.
 *
 * @param dataFrame the frame to convert
 * @return a tuple of (column names in column-meta order, one {@code double[]} per row)
 */
public static Tuple<String[], double[][]> transformDataFrame(DataFrame dataFrame) {
    ColumnMeta[] metas = dataFrame.columnMetas();
    String[] featureNames = new String[metas.length];
    for (int col = 0; col < metas.length; col++) {
        featureNames[col] = metas[col].getName();
    }

    double[][] featureValues = new double[dataFrame.size()][];
    int rowIndex = 0;
    for (Iterator<Row> rows = dataFrame.iterator(); rows.hasNext(); rowIndex++) {
        Row current = rows.next();
        featureValues[rowIndex] = StreamSupport.stream(current.spliterator(), false)
                .mapToDouble(ColumnValue::doubleValue)
                .toArray();
    }
    return new Tuple<>(featureNames, featureValues);
}
Also used : ColumnValue(org.opensearch.ml.common.dataframe.ColumnValue) Row(org.opensearch.ml.common.dataframe.Row) Tuple(org.opensearch.common.collect.Tuple)

Example 4 with Row

use of org.opensearch.ml.common.dataframe.Row in project ml-commons by opensearch-project.

From class FixedInTimeRandomCutForest, method process.

/**
 * Scores every row of the data frame with the given forest, in data-frame order.
 *
 * <p>For each row, a column whose name equals {@code timeField} supplies the timestamp handed to
 * the forest. Every other column is converted with {@link ColumnValue#doubleValue()` and
 * contributes one coordinate of the input point, in column order.
 *
 * @param dataFrame rows to score
 * @param forest    forest whose {@code process(point, timestamp)} result is recorded per row
 * @return one map per row with keys {@code timeField} (the timestamp, -1 if the row had no time
 *         column), {@code "score"} (RCF score) and {@code "anomaly_grade"}. NOTE(review): if
 *         {@code timeField} is null the timestamp is stored under a null key.
 * @throws MLValidationException if the time column is neither LONG nor STRING, or a STRING
 *         timestamp cannot be parsed
 */
private List<Map<String, Object>> process(DataFrame dataFrame, ThresholdedRandomCutForest forest) {
    ColumnMeta[] columnMetas = dataFrame.columnMetas();
    List<Map<String, Object>> predictResult = new ArrayList<>();
    for (int rowIndex = 0; rowIndex < dataFrame.size(); rowIndex++) {
        Row row = dataFrame.getRow(rowIndex);
        // Stays -1 when no column in this row matches timeField.
        long timestamp = -1;
        List<Double> features = new ArrayList<>();
        for (int col = 0; col < columnMetas.length; col++) {
            ColumnMeta meta = columnMetas[col];
            ColumnValue cell = row.getValue(col);
            // TODO: sort dataframe by time field with asc order. Currently consider the date already sorted by time.
            boolean isTimeColumn = timeField != null && timeField.equals(meta.getName());
            if (isTimeColumn) {
                timestamp = parseTimeFieldValue(meta.getColumnType(), cell);
            } else {
                features.add(cell.doubleValue());
            }
        }
        double[] point = features.stream().mapToDouble(Double::doubleValue).toArray();
        AnomalyDescriptor descriptor = forest.process(point, timestamp);
        Map<String, Object> result = new HashMap<>();
        result.put(timeField, timestamp);
        result.put("score", descriptor.getRCFScore());
        result.put("anomaly_grade", descriptor.getAnomalyGrade());
        predictResult.add(result);
    }
    return predictResult;
}

/**
 * Reads the timestamp stored in a time-field cell.
 *
 * <p>NOTE(review): {@code simpleDateFormat} is declared outside this view; DateFormat instances
 * are not thread-safe, so confirm it is not shared across concurrent calls.
 *
 * @param columnType declared type of the time-field column
 * @param value      the cell holding the timestamp
 * @return the LONG value as-is, or for STRING the {@code getTime()} of the date parsed by
 *         {@code simpleDateFormat}
 * @throws MLValidationException if the type is not LONG/STRING or the string cannot be parsed
 */
private long parseTimeFieldValue(ColumnType columnType, ColumnValue value) {
    if (columnType == ColumnType.LONG) {
        return value.longValue();
    }
    if (columnType != ColumnType.STRING) {
        throw new MLValidationException("Wrong data type of time field. Should use LONG or STRING, but got " + columnType);
    }
    try {
        return simpleDateFormat.parse(value.stringValue()).getTime();
    } catch (ParseException e) {
        log.error("Failed to parse timestamp " + value.stringValue(), e);
        throw new MLValidationException("Failed to parse timestamp " + value.stringValue());
    }
}
Also used : MLOutput(org.opensearch.ml.common.parameter.MLOutput) Precision(com.amazon.randomcutforest.config.Precision) ThresholdedRandomCutForestMapper(com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper) SimpleDateFormat(java.text.SimpleDateFormat) MLValidationException(org.opensearch.ml.common.exception.MLValidationException) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) FunctionName(org.opensearch.ml.common.parameter.FunctionName) Map(java.util.Map) MLAlgoParams(org.opensearch.ml.common.parameter.MLAlgoParams) FitRCFParams(org.opensearch.ml.common.parameter.FitRCFParams) DataFrameBuilder(org.opensearch.ml.common.dataframe.DataFrameBuilder) ParseException(java.text.ParseException) DateFormat(java.text.DateFormat) Row(org.opensearch.ml.common.dataframe.Row) ColumnValue(org.opensearch.ml.common.dataframe.ColumnValue) TimeZone(java.util.TimeZone) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) Function(org.opensearch.ml.engine.annotation.Function) ThresholdedRandomCutForestState(com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState) List(java.util.List) ColumnType(org.opensearch.ml.common.dataframe.ColumnType) Model(org.opensearch.ml.common.parameter.Model) ThresholdedRandomCutForest(com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest) ModelSerDeSer(org.opensearch.ml.engine.utils.ModelSerDeSer) Log4j2(lombok.extern.log4j.Log4j2) Optional(java.util.Optional) ForestMode(com.amazon.randomcutforest.config.ForestMode) TrainAndPredictable(org.opensearch.ml.engine.TrainAndPredictable) ColumnMeta(org.opensearch.ml.common.dataframe.ColumnMeta) AnomalyDescriptor(com.amazon.randomcutforest.parkservices.AnomalyDescriptor)

Example 5 with Row

use of org.opensearch.ml.common.dataframe.Row in project ml-commons by opensearch-project.

From class MLInputTest, method setUp.

/**
 * Builds the shared {@code input} fixture: an {@code MLInput} for {@code algorithm} with a
 * learning rate of 0.1 and a one-column DOUBLE data frame named "test" holding 1.0, 2.0 and 3.0.
 */
@Before
public void setUp() throws Exception {
    final ColumnMeta[] schema = { new ColumnMeta("test", ColumnType.DOUBLE) };
    List<Row> sampleRows = new ArrayList<>();
    for (double cell : new double[] { 1.0, 2.0, 3.0 }) {
        sampleRows.add(new Row(new ColumnValue[] { new DoubleValue(cell) }));
    }
    DataFrame frame = new DefaultDataFrame(schema, sampleRows);
    input = MLInput.builder()
            .algorithm(algorithm)
            .parameters(LinearRegressionParams.builder().learningRate(0.1).build())
            .inputDataset(DataFrameInputDataset.builder().dataFrame(frame).build())
            .build();
}
Also used : ColumnMeta(org.opensearch.ml.common.dataframe.ColumnMeta) DoubleValue(org.opensearch.ml.common.dataframe.DoubleValue) ArrayList(java.util.ArrayList) ColumnValue(org.opensearch.ml.common.dataframe.ColumnValue) Row(org.opensearch.ml.common.dataframe.Row) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) DefaultDataFrame(org.opensearch.ml.common.dataframe.DefaultDataFrame) Before(org.junit.Before)

Aggregations

Row (org.opensearch.ml.common.dataframe.Row)8 DataFrame (org.opensearch.ml.common.dataframe.DataFrame)6 ColumnValue (org.opensearch.ml.common.dataframe.ColumnValue)5 ArrayList (java.util.ArrayList)4 ColumnMeta (org.opensearch.ml.common.dataframe.ColumnMeta)4 DefaultDataFrame (org.opensearch.ml.common.dataframe.DefaultDataFrame)4 MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput)3 Model (org.opensearch.ml.common.parameter.Model)3 HashMap (java.util.HashMap)2 List (java.util.List)2 Map (java.util.Map)2 Optional (java.util.Optional)2 Log4j2 (lombok.extern.log4j.Log4j2)2 Before (org.junit.Before)2 DataFrameBuilder (org.opensearch.ml.common.dataframe.DataFrameBuilder)2 FunctionName (org.opensearch.ml.common.parameter.FunctionName)2 MLAlgoParams (org.opensearch.ml.common.parameter.MLAlgoParams)2 MLOutput (org.opensearch.ml.common.parameter.MLOutput)2 TrainAndPredictable (org.opensearch.ml.engine.TrainAndPredictable)2 Function (org.opensearch.ml.engine.annotation.Function)2