Usage of `org.opensearch.ml.common.dataframe.ColumnValue` in the ml-commons project (opensearch-project).
Class `MLPredictionOutputTest`, method `setUp`:
/**
 * Builds the shared fixture: a prediction output whose result is a one-column INTEGER
 * data frame holding the rows 1 and 2.
 */
@Before
public void setUp() {
    ColumnMeta[] metas = { new ColumnMeta("test", ColumnType.INTEGER) };

    // One single-cell row per value, in ascending order.
    List<Row> rows = new ArrayList<>();
    for (int v = 1; v <= 2; v++) {
        rows.add(new Row(new ColumnValue[] { new IntValue(v) }));
    }

    DataFrame frame = new DefaultDataFrame(metas, rows);
    output = MLPredictionOutput.builder()
            .taskId("test_task_id")
            .status("test_status")
            .predictionResult(frame)
            .build();
}
Usage of `org.opensearch.ml.common.dataframe.ColumnValue` in the ml-commons project (opensearch-project).
Class `TribuoUtil`, method `transformDataFrame`:
/**
 * Splits a data frame into its column names and a row-major matrix of doubles.
 *
 * <p>Every cell is converted with {@code ColumnValue.doubleValue()}. The matrix has
 * {@code dataFrame.size()} rows, filled in the order produced by {@code dataFrame.iterator()}.
 *
 * @param dataFrame the frame to convert
 * @return a tuple of (column names in schema order, per-row feature values)
 */
public static Tuple<String[], double[][]> transformDataFrame(DataFrame dataFrame) {
    ColumnMeta[] metas = dataFrame.columnMetas();
    String[] names = new String[metas.length];
    for (int c = 0; c < metas.length; c++) {
        names[c] = metas[c].getName();
    }

    double[][] matrix = new double[dataFrame.size()][];
    int rowIndex = 0;
    for (Iterator<Row> rows = dataFrame.iterator(); rows.hasNext(); rowIndex++) {
        Row current = rows.next();
        matrix[rowIndex] = StreamSupport.stream(current.spliterator(), false)
                .mapToDouble(ColumnValue::doubleValue)
                .toArray();
    }
    return new Tuple<>(names, matrix);
}
Usage of `org.opensearch.ml.common.dataframe.ColumnValue` in the ml-commons project (opensearch-project).
Class `FixedInTimeRandomCutForest`, method `process`:
/**
 * Scores every row of {@code dataFrame} with {@code forest} and returns one result map per row.
 *
 * <p>The column whose name equals {@code timeField} supplies the row's timestamp:
 * <ul>
 *   <li>LONG columns are read with {@code longValue()}.</li>
 *   <li>STRING columns are parsed with {@code simpleDateFormat}.</li>
 * </ul>
 * Every other column is converted with {@code doubleValue()} and appended, in column order,
 * to the point passed to {@code forest.process}. If no column matches {@code timeField}
 * (or it is null), the timestamp stays {@code -1}.
 *
 * @param dataFrame rows to score; they are assumed to be sorted by time already (see TODO)
 * @param forest    the model that scores each point
 * @return per-row maps keyed by {@code timeField}, {@code "score"} and {@code "anomaly_grade"}
 * @throws MLValidationException if the time column is neither LONG nor STRING, or a STRING
 *         timestamp cannot be parsed
 */
private List<Map<String, Object>> process(DataFrame dataFrame, ThresholdedRandomCutForest forest) {
    final ColumnMeta[] metas = dataFrame.columnMetas();
    final List<Map<String, Object>> results = new ArrayList<>();
    for (int r = 0; r < dataFrame.size(); r++) {
        final Row current = dataFrame.getRow(r);
        // Fresh buffer per row rather than clearing a shared one between rows.
        final List<Double> features = new ArrayList<>();
        long ts = -1;
        // TODO: sort dataframe by time field with asc order. Currently consider the date already sorted by time.
        for (int c = 0; c < metas.length; c++) {
            final ColumnMeta meta = metas[c];
            final ColumnValue cell = current.getValue(c);
            final boolean isTimeColumn = timeField != null && timeField.equals(meta.getName());
            if (!isTimeColumn) {
                features.add(cell.doubleValue());
                continue;
            }
            final ColumnType type = meta.getColumnType();
            if (type == ColumnType.LONG) {
                ts = cell.longValue();
            } else if (type == ColumnType.STRING) {
                try {
                    ts = simpleDateFormat.parse(cell.stringValue()).getTime();
                } catch (ParseException e) {
                    log.error("Failed to parse timestamp " + cell.stringValue(), e);
                    throw new MLValidationException("Failed to parse timestamp " + cell.stringValue());
                }
            } else {
                throw new MLValidationException("Wrong data type of time field. Should use LONG or STRING, but got " + type);
            }
        }

        // Unbox the collected feature values into the primitive array the forest expects.
        final double[] point = new double[features.size()];
        for (int k = 0; k < point.length; k++) {
            point[k] = features.get(k);
        }

        final AnomalyDescriptor descriptor = forest.process(point, ts);
        final Map<String, Object> rowResult = new HashMap<>();
        rowResult.put(timeField, ts);
        rowResult.put("score", descriptor.getRCFScore());
        rowResult.put("anomaly_grade", descriptor.getAnomalyGrade());
        results.add(rowResult);
    }
    return results;
}
Usage of `org.opensearch.ml.common.dataframe.ColumnValue` in the ml-commons project (opensearch-project).
Class `AnomalyDetectionLibSVMTest`, method `constructDataFrame`:
/**
 * Converts a Tribuo dataset into a {@code DataFrame} whose columns are all DOUBLE.
 *
 * <p>The first example defines the schema: one {@code ColumnMeta} per feature, named after the
 * feature. Every example, including the first, is then turned into a list of
 * {@code DoubleValue}s (one per feature, in iteration order) and passed to {@code addRow}
 * together with {@code training} and {@code labels}.
 *
 * @param data     the examples to convert
 * @param training forwarded unchanged to {@code addRow}
 * @param labels   forwarded unchanged to {@code addRow}
 * @return the populated frame, or {@code null} if {@code data} contains no examples
 */
private DataFrame constructDataFrame(Dataset<Event> data, boolean training, List<Event.EventType> labels) {
    DataFrame dataFrame = null;
    for (Iterator<Example<Event>> iterator = data.iterator(); iterator.hasNext(); ) {
        Example<Event> example = iterator.next();
        // Only the first example defines the schema; later examples are assumed to share it.
        boolean defineSchema = dataFrame == null;
        List<ColumnMeta> columns = new ArrayList<>();
        List<ColumnValue> columnValues = new ArrayList<>();
        for (Feature feature : example) {
            if (defineSchema) {
                columns.add(new ColumnMeta(feature.getName(), ColumnType.DOUBLE));
            }
            columnValues.add(new DoubleValue(feature.getValue()));
        }
        if (defineSchema) {
            dataFrame = new DefaultDataFrame(columns.toArray(new ColumnMeta[0]));
        }
        addRow(columnValues, training, example, dataFrame, labels);
    }
    return dataFrame;
}
Usage of `org.opensearch.ml.common.dataframe.ColumnValue` in the ml-commons project (opensearch-project).
Class `MLInputTest`, method `setUp`:
/**
 * Builds the shared fixture: an {@code MLInput} for {@code algorithm} with learning rate 0.1,
 * whose dataset is a one-column DOUBLE frame holding the rows 1.0, 2.0 and 3.0.
 */
@Before
public void setUp() throws Exception {
    final ColumnMeta[] metas = { new ColumnMeta("test", ColumnType.DOUBLE) };
    final double[] samples = { 1.0, 2.0, 3.0 };

    // One single-cell row per sample, preserving the order above.
    List<Row> rows = new ArrayList<>(samples.length);
    for (double sample : samples) {
        rows.add(new Row(new ColumnValue[] { new DoubleValue(sample) }));
    }

    DataFrame frame = new DefaultDataFrame(metas, rows);
    input = MLInput.builder()
            .algorithm(algorithm)
            .parameters(LinearRegressionParams.builder().learningRate(0.1).build())
            .inputDataset(DataFrameInputDataset.builder().dataFrame(frame).build())
            .build();
}
Aggregations