Search in sources:

Example 1 with Svm

Use of org.knime.base.node.mine.svm.Svm in project knime-core by KNIME.

Class SVMLearnerNodeModel, method execute.

/**
 * Trains one binary SVM per class value of the selected class column and returns the
 * trained models wrapped in a PMML port object.
 * <p>
 * Rows that contain a missing value in any training column are skipped entirely, and only
 * class values of rows that are actually used for training become categories. The binary
 * SVMs are trained on {@code KNIMEConstants.GLOBAL_THREAD_POOL}. A timer task checks every
 * 3000 ms whether execution was canceled and, if so, cancels the submitted training jobs.
 *
 * @param inData index 0 is the training table; index 1 is the optional PMML input, which is
 *            only read when the PMML input port is enabled
 * @param exec execution context used for progress, cancellation and table creation
 * @return a single-element array holding the PMML port object with the SVM model
 * @throws Exception if the class column is not found, no usable training rows exist,
 *             training fails, or execution is canceled
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    BufferedDataTable inTable = (BufferedDataTable) inData[0];
    DataTableSpec inSpec = inTable.getDataTableSpec();
    LearnColumnsAndColumnRearrangerTuple tuple = createTrainTableColumnRearranger(inSpec);
    // no progress needed as constant operation (column removal only)
    BufferedDataTable trainTable = exec.createColumnRearrangeTable(inTable, tuple.getTrainingRearranger(), exec.createSubProgress(0.0));
    DataTableSpec trainSpec = trainTable.getDataTableSpec();
    int classpos = trainSpec.findColumnIndex(m_classcol.getStringValue());
    CheckUtils.checkArgument(classpos >= 0, "Selected class column not found: " + m_classcol.getStringValue());
    // convert input data: every non-class cell becomes a vector component, the class cell the label
    ArrayList<DoubleVector> inputData = new ArrayList<DoubleVector>();
    List<String> categories = new ArrayList<String>();
    for (DataRow row : trainTable) {
        exec.checkCanceled();
        ArrayList<Double> values = new ArrayList<Double>();
        String rowClass = null;
        boolean add = true;
        for (int i = 0; i < row.getNumCells(); i++) {
            if (row.getCell(i).isMissing()) {
                add = false;
                break;
            }
            if (i != classpos) {
                values.add(((DoubleValue) row.getCell(i)).getDoubleValue());
            } else {
                rowClass = ((StringValue) row.getCell(i)).getStringValue();
            }
        }
        // Register the class only for rows that are kept; otherwise a class occurring solely in
        // skipped rows would get its own SVM without a single positive training example.
        if (add) {
            if (!categories.contains(rowClass)) {
                categories.add(rowClass);
            }
            inputData.add(new DoubleVector(row.getKey(), values, rowClass));
        }
    }
    if (categories.isEmpty()) {
        throw new Exception("No categories found to train SVM. " + "Possibly an empty input table was provided "
            + "or every row contains missing values.");
    }
    DoubleVector[] inputDataArr = new DoubleVector[inputData.size()];
    inputDataArr = inputData.toArray(inputDataArr);
    Kernel kernel = KernelFactory.getKernel(m_kernelType);
    Vector<SettingsModelDouble> kernelparams = m_kernelParameters.get(m_kernelType);
    for (int i = 0; i < kernel.getNumberParameters(); ++i) {
        kernel.setParameter(i, kernelparams.get(i).getDoubleValue());
    }
    final Svm[] svms = new Svm[categories.size()];
    exec.setMessage("Training SVM");
    final BinarySvmRunnable[] bst = new BinarySvmRunnable[categories.size()];
    for (int i = 0; i < categories.size(); i++) {
        bst[i] = new BinarySvmRunnable(inputDataArr, categories.get(i), kernel, m_paramC.getDoubleValue(), exec.createSubProgress((1.0 / categories.size())));
    }
    ThreadPool pool = KNIMEConstants.GLOBAL_THREAD_POOL;
    final Future<?>[] fut = new Future<?>[bst.length];
    KNIMETimer timer = KNIMETimer.getInstance();
    // Polls for user cancellation and cancels every already submitted training job.
    TimerTask timerTask = new TimerTask() {

        @Override
        public void run() {
            try {
                exec.checkCanceled();
            } catch (final CanceledExecutionException ce) {
                for (int i = 0; i < fut.length; i++) {
                    if (fut[i] != null) {
                        fut[i].cancel(true);
                    }
                }
                super.cancel();
            }
        }
    };
    timer.scheduleAtFixedRate(timerTask, 0, 3000);
    try {
        // Enqueue inside the try so the finally block always cancels the timer task,
        // even if submitting a job fails.
        for (int i = 0; i < bst.length; i++) {
            fut[i] = pool.enqueue(bst[i]);
        }
        pool.runInvisible(new Callable<Void>() {

            @Override
            public Void call() throws Exception {
                for (int i = 0; i < fut.length; ++i) {
                    fut[i].get();
                    bst[i].ok();
                    if (bst[i].getWarning() != null) {
                        setWarningMessage(bst[i].getWarning());
                    }
                    svms[i] = bst[i].getSvm();
                }
                return null;
            }
        });
    } catch (Exception ex) {
        // report cancellation in preference to any follow-up error caused by it
        exec.checkCanceled();
        Throwable t = ex;
        // unwrap the worker's failure, but never replace the exception by a null cause
        if (ex instanceof ExecutionException && ex.getCause() != null) {
            t = ex.getCause();
        }
        if (t instanceof Exception) {
            throw (Exception) t;
        } else {
            throw new Exception(t);
        }
    } finally {
        for (Future<?> f : fut) {
            // slots stay null if enqueueing failed part-way
            if (f != null) {
                f.cancel(true);
            }
        }
        timerTask.cancel();
    }
    // the optional PMML input (can be null)
    PMMLPortObject inPMMLPort = m_pmmlInEnabled ? (PMMLPortObject) inData[1] : null;
    // create the outgoing PMML spec
    PMMLPortObjectSpecCreator specCreator = new PMMLPortObjectSpecCreator(inPMMLPort, inSpec);
    specCreator.setLearningCols(trainSpec);
    specCreator.setTargetCol(trainSpec.getColumnSpec(m_classcol.getStringValue()));
    // create the outgoing PMML port object
    PMMLPortObject outPMMLPort = new PMMLPortObject(specCreator.createSpec(), inPMMLPort, inSpec);
    outPMMLPort.addModelTranslater(new PMMLSVMTranslator(categories, Arrays.asList(svms), kernel));
    m_svms = svms;
    return new PortObject[] { outPMMLPort };
}
Also used: DataTableSpec(org.knime.core.data.DataTableSpec) PMMLSVMTranslator(org.knime.base.node.mine.svm.PMMLSVMTranslator) ArrayList(java.util.ArrayList) ThreadPool(org.knime.core.util.ThreadPool) SettingsModelDouble(org.knime.core.node.defaultnodesettings.SettingsModelDouble) SettingsModelString(org.knime.core.node.defaultnodesettings.SettingsModelString) Svm(org.knime.base.node.mine.svm.Svm) DataRow(org.knime.core.data.DataRow) TimerTask(java.util.TimerTask) CanceledExecutionException(org.knime.core.node.CanceledExecutionException) BufferedDataTable(org.knime.core.node.BufferedDataTable) StringValue(org.knime.core.data.StringValue) ExecutionException(java.util.concurrent.ExecutionException) Kernel(org.knime.base.node.mine.svm.kernel.Kernel) PortObject(org.knime.core.node.port.PortObject) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) KNIMETimer(org.knime.core.util.KNIMETimer) BinarySvmRunnable(org.knime.base.node.mine.svm.util.BinarySvmRunnable) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) IOException(java.io.IOException) DoubleValue(org.knime.core.data.DoubleValue) Future(java.util.concurrent.Future) DoubleVector(org.knime.base.node.mine.svm.util.DoubleVector) PMMLPortObjectSpecCreator(org.knime.core.node.port.pmml.PMMLPortObjectSpecCreator)

Example 2 with Svm

Use of org.knime.base.node.mine.svm.Svm in project knime-core by KNIME.

Class SVMPredictorNodeModel, method adjustOrder.

/**
 * Reorders {@code m_svms} in place to follow the possible values of the target column's domain.
 * <p>
 * An SVM is matched to a domain value when {@code svm.getPositive()} equals the value's
 * {@code toString()}. Matched SVMs come first, in domain iteration order. SVMs whose positive
 * class does not occur in the domain follow in their previous relative order, so no SVM is
 * dropped or duplicated. Nothing changes when the domain, or its value set, is absent.
 *
 * @param targetSpec The target column from the model.
 */
private void adjustOrder(final DataColumnSpec targetSpec) {
    if (targetSpec.getDomain() == null) {
        return;
    }
    // Defensive: getValues() is null when the domain carries no set of possible values.
    final Iterable<DataCell> domainValues = targetSpec.getDomain().getValues();
    if (domainValues == null) {
        return;
    }
    final Map<String, Svm> unplaced = new LinkedHashMap<>();
    for (Svm svm : m_svms) {
        unplaced.put(svm.getPositive(), svm);
    }
    int i = 0;
    for (DataCell v : domainValues) {
        // remove() ensures each SVM is placed at most once
        final Svm svm = unplaced.remove(v.toString());
        if (svm != null) {
            m_svms[i++] = svm;
        }
    }
    // Append SVMs whose class is not part of the domain instead of leaving stale entries behind.
    for (Svm svm : unplaced.values()) {
        m_svms[i++] = svm;
    }
}
Also used : DataCell(org.knime.core.data.DataCell) SettingsModelString(org.knime.core.node.defaultnodesettings.SettingsModelString) Svm(org.knime.base.node.mine.svm.Svm) LinkedHashMap(java.util.LinkedHashMap)

Example 3 with Svm

Use of org.knime.base.node.mine.svm.Svm in project knime-core by KNIME.

Class SVMPredictorNodeModel, method createColumnRearranger.

/**
 * Creates the column rearranger that appends the SVM prediction column(s) to the input table.
 * <p>
 * Side effects: stores the SVMs decoded from {@code pmmlModel} in {@code m_svms}. It also
 * stores, in {@code m_colindices}, the positions within {@code inSpec} of the model's
 * double-compatible learning columns.
 *
 * @param pmmlModel the PMML port object expected to contain a support vector machine model
 * @param inSpec the spec of the table to predict
 * @return a rearranger over {@code inSpec} with an appended {@link SVMPredictor}
 * @throws InvalidSettingsException if a double-compatible learning column of the model is
 *             not found in {@code inSpec}, or if the model declares no target field
 * @throws RuntimeException if the PMML document contains no support vector machine model
 */
private ColumnRearranger createColumnRearranger(final PMMLPortObject pmmlModel, final DataTableSpec inSpec) throws InvalidSettingsException {
    List<Node> models = pmmlModel.getPMMLValue().getModels(PMMLModelType.SupportVectorMachineModel);
    if (models.isEmpty()) {
        String msg = "SVM evaluation failed: " + "No support vector machine model found.";
        LOGGER.error(msg);
        throw new RuntimeException(msg);
    }
    PMMLSVMTranslator trans = new PMMLSVMTranslator();
    pmmlModel.initializeModelTranslator(trans);
    List<Svm> svms = trans.getSVMs();
    m_svms = svms.toArray(new Svm[svms.size()]);
    PMMLPortObjectSpec pmmlSpec = pmmlModel.getSpec();
    // Reorder the SVMs by the target domain only when probability columns are requested AND the
    // model declares a target column; get(0) requires the list to be non-empty.
    if (m_addProbabilities.getBooleanValue() && !pmmlSpec.getTargetCols().isEmpty()) {
        adjustOrder(pmmlSpec.getTargetCols().get(0));
    }
    DataTableSpec testSpec = inSpec;
    DataTableSpec trainingSpec = pmmlSpec.getDataTableSpec();
    // try to find all columns (except the class column)
    Vector<Integer> colindices = new Vector<Integer>();
    for (DataColumnSpec colspec : trainingSpec) {
        if (colspec.getType().isCompatible(DoubleValue.class)) {
            int colindex = testSpec.findColumnIndex(colspec.getName());
            if (colindex < 0) {
                throw new InvalidSettingsException("Column " + "\'" + colspec.getName() + "\' not found" + " in test data");
            }
            colindices.add(colindex);
        }
    }
    m_colindices = new int[colindices.size()];
    for (int i = 0; i < m_colindices.length; i++) {
        m_colindices[i] = colindices.get(i);
    }
    final PredictorHelper predictorHelper = PredictorHelper.getInstance();
    // fail with a descriptive message instead of a NoSuchElementException from next()
    if (!pmmlSpec.getTargetFields().iterator().hasNext()) {
        throw new InvalidSettingsException("SVM evaluation failed: the model does not define a target column.");
    }
    final String targetCol = pmmlSpec.getTargetFields().iterator().next();
    SVMPredictor svmpredict = new SVMPredictor(targetCol, m_svms, m_colindices, predictorHelper.computePredictionColumnName(m_predictionColumn.getStringValue(), m_overridePrediction.getBooleanValue(), targetCol), m_addProbabilities.getBooleanValue(), m_suffix.getStringValue());
    ColumnRearranger colre = new ColumnRearranger(testSpec);
    colre.append(svmpredict);
    return colre;
}
Also used : DataTableSpec(org.knime.core.data.DataTableSpec) PMMLPortObjectSpec(org.knime.core.node.port.pmml.PMMLPortObjectSpec) PredictorHelper(org.knime.base.node.mine.util.PredictorHelper) PMMLSVMTranslator(org.knime.base.node.mine.svm.PMMLSVMTranslator) Node(org.w3c.dom.Node) SettingsModelString(org.knime.core.node.defaultnodesettings.SettingsModelString) Svm(org.knime.base.node.mine.svm.Svm) DataColumnSpec(org.knime.core.data.DataColumnSpec) ColumnRearranger(org.knime.core.data.container.ColumnRearranger) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) Vector(java.util.Vector)

Example 4 with Svm

Use of org.knime.base.node.mine.svm.Svm in project knime-core by KNIME.

Class SVMPredictorNodeModel, method execute.

/**
 * Applies the SVMs stored in the PMML input to the data table input and appends the column
 * computed by {@link SVMPredictor}.
 * <p>
 * Side effects: stores the decoded SVMs in {@code m_svms}. It also stores, in
 * {@code m_colindices}, the positions within the test table of the model's double-compatible
 * learning columns.
 *
 * @param inData index 0 is the PMML model port, index 1 the table to predict
 * @param exec execution context used to create the output table and report progress
 * @return a single-element array holding the input table with the appended prediction column
 * @throws InvalidSettingsException if a double-compatible learning column is missing from the
 *             test table
 * @throws RuntimeException if the PMML document contains no support vector machine model
 */
@Override
public PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    final PMMLPortObject pmmlPort = (PMMLPortObject) inData[0];
    final boolean noSvmModel = pmmlPort.getPMMLValue().getModels(PMMLModelType.SupportVectorMachineModel).isEmpty();
    if (noSvmModel) {
        final String msg = "SVM evaluation failed: " + "No support vector machine model found.";
        LOGGER.error(msg);
        throw new RuntimeException(msg);
    }
    final PMMLSVMTranslator translator = new PMMLSVMTranslator();
    pmmlPort.initializeModelTranslator(translator);
    final List<Svm> decoded = translator.getSVMs();
    m_svms = decoded.toArray(new Svm[decoded.size()]);

    final BufferedDataTable testData = (BufferedDataTable) inData[1];
    final DataTableSpec testSpec = testData.getDataTableSpec();
    final DataTableSpec trainingSpec = pmmlPort.getSpec().getDataTableSpec();
    // locate every double-compatible learning column (the class column is skipped implicitly)
    final Vector<Integer> found = new Vector<Integer>();
    for (DataColumnSpec learnCol : trainingSpec) {
        if (!learnCol.getType().isCompatible(DoubleValue.class)) {
            continue;
        }
        final int pos = testSpec.findColumnIndex(learnCol.getName());
        if (pos < 0) {
            throw new InvalidSettingsException("Column " + "\'" + learnCol.getName() + "\' not found" + " in test data");
        }
        found.add(pos);
    }
    m_colindices = new int[found.size()];
    int k = 0;
    for (Integer pos : found) {
        m_colindices[k++] = pos;
    }

    final ColumnRearranger rearranger = new ColumnRearranger(testData.getDataTableSpec());
    rearranger.append(new SVMPredictor(m_svms, m_colindices));
    return new BufferedDataTable[] { exec.createColumnRearrangeTable(testData, rearranger, exec) };
}
Also used : DataTableSpec(org.knime.core.data.DataTableSpec) PMMLSVMTranslator(org.knime.base.node.mine.svm.PMMLSVMTranslator) Node(org.w3c.dom.Node) Svm(org.knime.base.node.mine.svm.Svm) DataColumnSpec(org.knime.core.data.DataColumnSpec) ColumnRearranger(org.knime.core.data.container.ColumnRearranger) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) BufferedDataTable(org.knime.core.node.BufferedDataTable) Vector(java.util.Vector)

Example 5 with Svm

Use of org.knime.base.node.mine.svm.Svm in project knime-core by KNIME.

Class SVMLearnerNodeModel, method getSVMInfos.

/**
 * Returns an HTML summary of the learned SVMs for the node view and caches it in
 * {@code m_svmInfo}.
 * <p>
 * For every non-null SVM the summary lists its array index and positive class, followed by
 * its support vectors, one per line, as comma-separated values ending with the class value.
 * If no SVMs are available the result is the empty string, which is not treated as a cache
 * hit on the next call.
 *
 * @return a string containing all SVM infos in HTML for the view, or {@code ""} if none exist
 */
String getSVMInfos() {
    if (!m_svmInfo.isEmpty()) {
        return m_svmInfo;
    }
    // Snapshot the field once to avoid an NPE when reset is called during a view update.
    final Svm[] models = m_svms;
    if (models == null) {
        m_svmInfo = "";
        return m_svmInfo;
    }
    final StringBuilder html = new StringBuilder("<html>\n").append("<body>\n");
    int index = 0;
    for (Svm model : models) {
        if (model != null) {
            html.append("<h1> SVM ").append(index).append(" Class: ").append(model.getPositive()).append("</h1>");
            html.append("<b> Support Vectors: </b><br>");
            for (DoubleVector vec : model.getSupportVectors()) {
                for (int s = 0; s < vec.getNumberValues(); s++) {
                    html.append(vec.getValue(s)).append(", ");
                }
                html.append(vec.getClassValue()).append("<br>");
            }
        }
        index++;
    }
    html.append("</body>\n").append("</html>\n");
    m_svmInfo = html.toString();
    return m_svmInfo;
}
Also used : Svm(org.knime.base.node.mine.svm.Svm) DoubleVector(org.knime.base.node.mine.svm.util.DoubleVector)

Aggregations

Svm (org.knime.base.node.mine.svm.Svm)5 PMMLSVMTranslator (org.knime.base.node.mine.svm.PMMLSVMTranslator)3 DataTableSpec (org.knime.core.data.DataTableSpec)3 InvalidSettingsException (org.knime.core.node.InvalidSettingsException)3 SettingsModelString (org.knime.core.node.defaultnodesettings.SettingsModelString)3 Vector (java.util.Vector)2 DoubleVector (org.knime.base.node.mine.svm.util.DoubleVector)2 DataColumnSpec (org.knime.core.data.DataColumnSpec)2 ColumnRearranger (org.knime.core.data.container.ColumnRearranger)2 BufferedDataTable (org.knime.core.node.BufferedDataTable)2 PMMLPortObject (org.knime.core.node.port.pmml.PMMLPortObject)2 Node (org.w3c.dom.Node)2 IOException (java.io.IOException)1 ArrayList (java.util.ArrayList)1 LinkedHashMap (java.util.LinkedHashMap)1 TimerTask (java.util.TimerTask)1 ExecutionException (java.util.concurrent.ExecutionException)1 Future (java.util.concurrent.Future)1 Kernel (org.knime.base.node.mine.svm.kernel.Kernel)1 BinarySvmRunnable (org.knime.base.node.mine.svm.util.BinarySvmRunnable)1