Use of org.verdictdb.coordinator.VerdictSingleResultFromDbmsQueryResult in project traindb by traindb-project.
From the class TrainDBQueryEngine, method trainModelInstance.
/**
 * Trains a new instance of a registered model on the given table columns and
 * records it in the catalog.
 *
 * <p>The method writes the table metadata ({@code metadata.json}) and the
 * training data ({@code data.csv}) into the model instance directory, runs the
 * Python model runner on those files, and registers the instance in the
 * catalog only if the runner exits successfully.
 *
 * @param modelName         name of the model to instantiate
 * @param modelInstanceName name of the new model instance; must not already exist
 * @param schemaName        schema of the training table, or {@code null} to use
 *                          the connection's default schema
 * @param tableName         table providing the training data
 * @param columnNames       columns used for training
 * @throws CatalogException      if a model instance with the same name already exists
 * @throws IllegalStateException if the model runner process exits with a non-zero status
 * @throws Exception             on any other failure (file I/O, data retrieval, process start)
 */
@Override
public void trainModelInstance(String modelName, String modelInstanceName, String schemaName, String tableName, List<String> columnNames) throws Exception {
  if (catalogContext.modelInstanceExists(modelInstanceName)) {
    throw new CatalogException("model instance '" + modelInstanceName + "' already exist");
  }
  if (schemaName == null) {
    schemaName = conn.getDefaultSchema();
  }
  JSONObject tableMetadata = getTableMetadata(schemaName, tableName, columnNames);
  Path instancePath = catalogContext.getModelInstancePath(modelName, modelInstanceName);
  Files.createDirectories(instancePath);
  String outputPath = instancePath.toString();

  // write metadata for model training scripts in python
  String metadataFilename = outputPath + "/metadata.json";
  // try-with-resources closes (and thereby flushes) the writer even if write() throws
  try (FileWriter fileWriter = new FileWriter(metadataFilename)) {
    fileWriter.write(tableMetadata.toJSONString());
  }

  // FIXME securely pass training data for ML model training
  DbmsQueryResult trainingData = getTrainingData(schemaName, tableName, columnNames);
  String dataFilename = outputPath + "/data.csv";
  try (FileWriter datafileWriter = new FileWriter(dataFilename)) {
    datafileWriter.write(new VerdictSingleResultFromDbmsQueryResult(trainingData).toCsv());
  }

  MModel mModel = catalogContext.getModel(modelName);

  // train ML model
  ProcessBuilder pb = new ProcessBuilder("python", conf.getModelRunnerPath(), "train", mModel.getClassName(), TrainDBConfiguration.absoluteUri(mModel.getUri()), dataFilename, metadataFilename, outputPath);
  pb.inheritIO();
  Process process = pb.start();
  int exitCode = process.waitFor();
  // Do not register an instance whose training run failed.
  if (exitCode != 0) {
    throw new IllegalStateException("failed to train model instance '" + modelInstanceName + "': model runner exited with code " + exitCode);
  }

  catalogContext.trainModelInstance(modelName, modelInstanceName, schemaName, tableName, columnNames);
}
Use of org.verdictdb.coordinator.VerdictSingleResultFromDbmsQueryResult in project traindb by traindb-project.
From the class TrainDBQueryEngine, method processQuery.
/**
 * Parses, validates and plans the given SQL text with the Calcite planner,
 * then executes the validated query (rendered in the PostgreSQL dialect)
 * through the internal {@code jdbc:traindb-calcite:} connection.
 *
 * @param query SQL text to process
 * @return the query result, or {@code null} if executing the query raised a
 *         {@link SQLException} (the exception is logged, not rethrown)
 * @throws Exception if parsing, validation or planning fails
 */
@Override
public VerdictSingleResult processQuery(String query) throws Exception {
  SqlParser.Config parserConf = SqlParser.config().withParserFactory(TrainDBCalciteSQLParserImpl.FACTORY).withUnquotedCasing(Casing.TO_LOWER);
  FrameworkConfig config = Frameworks.newConfigBuilder().defaultSchema(schemaManager.getCurrentSchema()).parserConfig(parserConf).build();
  Planner planner = Frameworks.getPlanner(config);

  SqlNode parse = planner.parse(query);
  TableNameQualifier.toFullyQualifiedName(schemaManager, conn.getDefaultSchema(), parse);
  LOG.debug("Parsed query: " + parse.toString());

  SqlNode validate = planner.validate(parse);
  RelRoot relRoot = planner.rel(validate);
  LOG.debug(RelOptUtil.dumpPlan("Generated plan: ", relRoot.rel, SqlExplainFormat.TEXT, SqlExplainLevel.ALL_ATTRIBUTES));

  SqlDialect.DatabaseProduct dp = SqlDialect.DatabaseProduct.POSTGRESQL;
  String queryString = validate.toSqlString(dp.getDialect()).getSql();
  LOG.debug("query string: " + queryString);

  // Close the JDBC resources on every path; previously they were never closed.
  // NOTE(review): this assumes JdbcQueryResult copies all rows out of the
  // ResultSet in its constructor, so the result stays usable after the
  // ResultSet is closed -- confirm against the VerdictDB implementation.
  try (Connection internalConn = DriverManager.getConnection("jdbc:traindb-calcite:");
      PreparedStatement stmt = internalConn.prepareStatement(queryString);
      ResultSet rs = stmt.executeQuery()) {
    return new VerdictSingleResultFromDbmsQueryResult(new JdbcQueryResult(rs));
  } catch (SQLException e) {
    // Callers still receive null on failure, but the failure is logged at
    // error level so it is not hidden when debug logging is disabled.
    LOG.error(ExceptionUtils.getStackTrace(e));
  }
  return null;
}
Aggregations