Search in sources :

Example 1 with DoubleVector

use of org.knime.base.node.mine.svm.util.DoubleVector in project knime-core by knime.

the class PMMLSVMTranslator method addVectorDictionary.

/**
 * Populates the PMML vector dictionary of the given SVM model.
 *
 * <p>The support vectors of all SVMs in {@code m_svms} are gathered into one insertion-ordered set. The
 * dictionary then receives one field reference per learning column (mapped through {@code m_nameMapper})
 * and one vector instance per gathered vector, whose values are written as a sparse array with 1-based
 * indices.
 *
 * @param svmModel the SVM model to add the dictionary to
 * @param learningCol a list with the names of the learning columns
 */
private void addVectorDictionary(final SupportVectorMachineModel svmModel, final List<String> learningCol) {
    final VectorDictionary dictionary = svmModel.addNewVectorDictionary();

    // Gather the vectors of every SVM; the set keeps first-seen order and stores equal vectors only once.
    final Set<DoubleVector> allVectors = new LinkedHashSet<DoubleVector>();
    for (final Svm machine : m_svms) {
        for (final DoubleVector candidate : machine.getSupportVectors()) {
            allVectors.add(candidate);
        }
    }
    dictionary.setNumberOfVectors(BigInteger.valueOf(allVectors.size()));

    // Declare the fields the vectors are made of, in learning-column order.
    final VectorFields fields = dictionary.addNewVectorFields();
    fields.setNumberOfFields(BigInteger.valueOf(learningCol.size()));
    for (final String columnName : learningCol) {
        fields.addNewFieldRef().setField(m_nameMapper.getDerivedFieldName(columnName));
    }

    for (final DoubleVector current : allVectors) {
        final VectorInstance instance = dictionary.addNewVectorInstance();
        // The instance id is the class value and the vector key joined by the separator constant.
        instance.setId(current.getClassValue() + CLASS_KEY_SEPARATOR + current.getKey().getString());
        final REALSparseArray sparseArray = instance.addNewREALSparseArray1();
        final int valueCount = current.getNumberValues();
        sparseArray.setN(BigInteger.valueOf(valueCount));

        // Every position is written, so indices simply run from 1 to valueCount.
        final List<String> positions = new ArrayList<String>();
        final List<String> entries = new ArrayList<String>();
        for (int offset = 0; offset < valueCount; offset++) {
            positions.add(Integer.toString(offset + 1));
            entries.add(String.valueOf(current.getValue(offset)));
        }
        sparseArray.setIndices(positions);
        sparseArray.setREALEntries(entries);
    }
}
Also used : LinkedHashSet(java.util.LinkedHashSet) VectorInstance(org.dmg.pmml.VectorInstanceDocument.VectorInstance) VectorDictionary(org.dmg.pmml.VectorDictionaryDocument.VectorDictionary) REALSparseArray(org.dmg.pmml.REALSparseArrayDocument.REALSparseArray) VectorFields(org.dmg.pmml.VectorFieldsDocument.VectorFields) ArrayList(java.util.ArrayList) DoubleVector(org.knime.base.node.mine.svm.util.DoubleVector)

Example 2 with DoubleVector

use of org.knime.base.node.mine.svm.util.DoubleVector in project knime-core by knime.

the class SVMLearnerNodeModel method execute.

/**
 * Trains one binary SVM per class value found in the training table and returns the resulting model as a
 * single PMML port object.
 *
 * <p>Rows that contain a missing cell are skipped. The binary SVMs are submitted to the global KNIME thread
 * pool; a timer task polls the execution context every three seconds and cancels all submitted jobs once
 * the user cancels the node.
 *
 * {@inheritDoc}
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    BufferedDataTable inTable = (BufferedDataTable) inData[0];
    DataTableSpec inSpec = inTable.getDataTableSpec();
    LearnColumnsAndColumnRearrangerTuple tuple = createTrainTableColumnRearranger(inSpec);
    // no progress needed as constant operation (column removal only)
    BufferedDataTable trainTable = exec.createColumnRearrangeTable(inTable, tuple.getTrainingRearranger(), exec.createSubProgress(0.0));
    DataTableSpec trainSpec = trainTable.getDataTableSpec();
    int classpos = trainSpec.findColumnIndex(m_classcol.getStringValue());
    CheckUtils.checkArgument(classpos >= 0, "Selected class column not found: " + m_classcol.getStringValue());

    // convert input data: every complete row becomes one DoubleVector labelled with its class value
    ArrayList<DoubleVector> inputData = new ArrayList<DoubleVector>();
    List<String> categories = new ArrayList<String>();
    StringValue classvalue = null;
    for (DataRow row : trainTable) {
        exec.checkCanceled();
        ArrayList<Double> values = new ArrayList<Double>();
        boolean add = true;
        for (int i = 0; i < row.getNumCells(); i++) {
            if (row.getCell(i).isMissing()) {
                // a single missing cell disqualifies the whole row
                add = false;
                break;
            }
            if (i != classpos) {
                DoubleValue cell = (DoubleValue) row.getCell(i);
                values.add(cell.getDoubleValue());
            } else {
                // NOTE(review): the class value is registered as soon as its cell is reached, even if a
                // later cell of the same row is missing and the row is skipped - confirm this is intended.
                classvalue = (StringValue) row.getCell(classpos);
                if (!categories.contains(classvalue.getStringValue())) {
                    categories.add(classvalue.getStringValue());
                }
            }
        }
        if (add) {
            // add == true means every cell was visited, so classvalue was assigned for this row
            @SuppressWarnings("null") final String nonNullClassValue = classvalue.getStringValue();
            inputData.add(new DoubleVector(row.getKey(), values, nonNullClassValue));
        }
    }
    if (categories.isEmpty()) {
        throw new Exception("No categories found to train SVM. " + "Possibly an empty input table was provided.");
    }
    DoubleVector[] inputDataArr = inputData.toArray(new DoubleVector[inputData.size()]);

    // configure the kernel from the settings stored for the selected kernel type
    Kernel kernel = KernelFactory.getKernel(m_kernelType);
    Vector<SettingsModelDouble> kernelparams = m_kernelParameters.get(m_kernelType);
    for (int i = 0; i < kernel.getNumberParameters(); ++i) {
        kernel.setParameter(i, kernelparams.get(i).getDoubleValue());
    }

    // one binary SVM per category; each gets an equal share of the progress
    final Svm[] svms = new Svm[categories.size()];
    exec.setMessage("Training SVM");
    final BinarySvmRunnable[] bst = new BinarySvmRunnable[categories.size()];
    for (int i = 0; i < categories.size(); i++) {
        bst[i] = new BinarySvmRunnable(inputDataArr, categories.get(i), kernel, m_paramC.getDoubleValue(), exec.createSubProgress((1.0 / categories.size())));
    }
    ThreadPool pool = KNIMEConstants.GLOBAL_THREAD_POOL;
    final Future<?>[] fut = new Future<?>[bst.length];
    KNIMETimer timer = KNIMETimer.getInstance();
    // watchdog: propagates a user cancel to the submitted jobs, then stops itself
    TimerTask timerTask = new TimerTask() {

        @Override
        public void run() {
            try {
                exec.checkCanceled();
            } catch (final CanceledExecutionException ce) {
                for (int i = 0; i < fut.length; i++) {
                    // entries stay null until the corresponding job has been enqueued
                    if (fut[i] != null) {
                        fut[i].cancel(true);
                    }
                }
                super.cancel();
            }
        }
    };
    timer.scheduleAtFixedRate(timerTask, 0, 3000);
    try {
        // Enqueue inside the try block so that a failure while submitting still reaches the finally
        // clause, which stops the watchdog and cancels the jobs that were already queued.
        for (int i = 0; i < bst.length; i++) {
            fut[i] = pool.enqueue(bst[i]);
        }
        pool.runInvisible(new Callable<Void>() {

            @Override
            public Void call() throws Exception {
                for (int i = 0; i < fut.length; ++i) {
                    fut[i].get();
                    bst[i].ok();
                    if (bst[i].getWarning() != null) {
                        setWarningMessage(bst[i].getWarning());
                    }
                    svms[i] = bst[i].getSvm();
                }
                return null;
            }
        });
    } catch (Exception ex) {
        // a user cancel takes precedence over whatever failure it may have triggered
        exec.checkCanceled();
        Throwable t = ex;
        if (ex instanceof ExecutionException) {
            // report the job's own failure rather than the wrapper
            t = ex.getCause();
        }
        if (t instanceof Exception) {
            throw (Exception) t;
        } else {
            throw new Exception(t);
        }
    } finally {
        for (int i = 0; i < fut.length; i++) {
            // some entries are null if enqueueing failed part-way through
            if (fut[i] != null) {
                fut[i].cancel(true);
            }
        }
        timerTask.cancel();
    }
    // the optional PMML input (can be null)
    PMMLPortObject inPMMLPort = m_pmmlInEnabled ? (PMMLPortObject) inData[1] : null;
    // create the outgoing PMML spec
    PMMLPortObjectSpecCreator specCreator = new PMMLPortObjectSpecCreator(inPMMLPort, inSpec);
    specCreator.setLearningCols(trainSpec);
    specCreator.setTargetCol(trainSpec.getColumnSpec(m_classcol.getStringValue()));
    // create the outgoing PMML port object
    PMMLPortObject outPMMLPort = new PMMLPortObject(specCreator.createSpec(), inPMMLPort, inSpec);
    outPMMLPort.addModelTranslater(new PMMLSVMTranslator(categories, Arrays.asList(svms), kernel));
    m_svms = svms;
    return new PortObject[] { outPMMLPort };
}
Also used : DataTableSpec(org.knime.core.data.DataTableSpec) PMMLSVMTranslator(org.knime.base.node.mine.svm.PMMLSVMTranslator) ArrayList(java.util.ArrayList) ThreadPool(org.knime.core.util.ThreadPool) SettingsModelDouble(org.knime.core.node.defaultnodesettings.SettingsModelDouble) SettingsModelString(org.knime.core.node.defaultnodesettings.SettingsModelString) Svm(org.knime.base.node.mine.svm.Svm) DataRow(org.knime.core.data.DataRow) TimerTask(java.util.TimerTask) CanceledExecutionException(org.knime.core.node.CanceledExecutionException) BufferedDataTable(org.knime.core.node.BufferedDataTable) StringValue(org.knime.core.data.StringValue) CanceledExecutionException(org.knime.core.node.CanceledExecutionException) ExecutionException(java.util.concurrent.ExecutionException) Kernel(org.knime.base.node.mine.svm.kernel.Kernel) PortObject(org.knime.core.node.port.PortObject) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) KNIMETimer(org.knime.core.util.KNIMETimer) BinarySvmRunnable(org.knime.base.node.mine.svm.util.BinarySvmRunnable) SettingsModelDouble(org.knime.core.node.defaultnodesettings.SettingsModelDouble) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) CanceledExecutionException(org.knime.core.node.CanceledExecutionException) IOException(java.io.IOException) ExecutionException(java.util.concurrent.ExecutionException) DoubleValue(org.knime.core.data.DoubleValue) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) Future(java.util.concurrent.Future) DoubleVector(org.knime.base.node.mine.svm.util.DoubleVector) PMMLPortObjectSpecCreator(org.knime.core.node.port.pmml.PMMLPortObjectSpecCreator)

Example 3 with DoubleVector

use of org.knime.base.node.mine.svm.util.DoubleVector in project knime-core by knime.

the class SVMPredictor method doPredict.

/**
 * Determines the class of a vector by asking every SVM for its distance to the vector and picking the
 * SVM that reports the largest one. Ties keep the SVM with the lower index.
 *
 * @param values the values of the vector to classify
 * @return the positive class of the winning SVM
 */
private String doPredict(final ArrayList<Double> values) {
    // the class label is only a placeholder; it is what this method is about to find out
    final DoubleVector candidate = new DoubleVector(values, "not_known_yet");
    double maxDistance = m_svms[0].distance(candidate);
    int winner = 0;
    int index = 1;
    while (index < m_svms.length) {
        final double distance = m_svms[index].distance(candidate);
        if (distance > maxDistance) {
            maxDistance = distance;
            winner = index;
        }
        index++;
    }
    return m_svms[winner].getPositive();
}
Also used : DoubleVector(org.knime.base.node.mine.svm.util.DoubleVector)

Example 4 with DoubleVector

use of org.knime.base.node.mine.svm.util.DoubleVector in project knime-core by knime.

the class SVMPredictor method doPredict.

/**
 * Given a vector, finds out its class: the positive class of the SVM that reports the largest distance
 * for the vector. Ties keep the SVM with the lower index.
 *
 * @param values the values of the vector to classify
 * @return the positive class of the best matching SVM
 */
private String doPredict(final ArrayList<Double> values) {
    // the class label is only a placeholder; it is what this method is about to find out
    DoubleVector vector = new DoubleVector(values, "not_known_yet");
    int pos = 0;
    double bestDistance = m_svms[0].distance(vector);
    for (int i = 1; i < m_svms.length; ++i) {
        // evaluate each SVM exactly once; the distance is reused when it becomes the new best
        final double newDist = m_svms[i].distance(vector);
        if (newDist > bestDistance) {
            pos = i;
            bestDistance = newDist;
        }
    }
    return m_svms[pos].getPositive();
}
Also used : DoubleVector(org.knime.base.node.mine.svm.util.DoubleVector)

Example 5 with DoubleVector

use of org.knime.base.node.mine.svm.util.DoubleVector in project knime-core by knime.

the class PMMLSVMTranslator method addSVMs.

/**
 * Adds the trained SVMs to the PMML model: for each one its target category, the ids of its support
 * vectors and its coefficients. If there are exactly two SVMs only the first one is written, with the
 * positive class of the second one as its alternate target category. An empty model gets a single empty
 * SupportVectorMachine element.
 *
 * @param svmModel the SVM model to add the SVMs to
 */
private void addSVMs(final SupportVectorMachineModel svmModel) {
    if (m_svms.isEmpty()) {
        svmModel.addNewSupportVectorMachine();
        // TODO Review what is necessary for the case of an empty model
        // for (String target : m_targetValues) {
        // SupportVectorMachine svm =
        // svmModel.addNewSupportVectorMachine();
        // svm.setTargetCategory(target);
        // SupportVectors supportVectors = svm.addNewSupportVectors();
        // supportVectors.setNumberOfSupportVectors(
        // BigInteger.valueOf(0));
        // }
        return;
    }
    // does not change while iterating, so it is evaluated once
    final boolean binaryClassification = (m_svms.size() == 2);
    for (Svm svm : m_svms) {
        SupportVectorMachine pmmlSvm = svmModel.addNewSupportVectorMachine();
        pmmlSvm.setTargetCategory(svm.getPositive());
        if (binaryClassification) {
            pmmlSvm.setAlternateTargetCategory(m_svms.get(1).getPositive());
        }
        // add support vectors
        SupportVectors pmmlSupportVectors = pmmlSvm.addNewSupportVectors();
        DoubleVector[] supVectors = svm.getSupportVectors();
        if (supVectors.length > 0) {
            // The attribute count is read from the first support vector. Without any support vector
            // it is unknown, so the attribute (optional in the PMML schema) is left out rather than
            // failing with an ArrayIndexOutOfBoundsException.
            pmmlSupportVectors.setNumberOfAttributes(BigInteger.valueOf(supVectors[0].getNumberValues()));
        }
        pmmlSupportVectors.setNumberOfSupportVectors(BigInteger.valueOf(supVectors.length));
        for (DoubleVector supVector : supVectors) {
            SupportVector pmmlSupportVector = pmmlSupportVectors.addNewSupportVector();
            pmmlSupportVector.setVectorId(getSupportVectorId(supVector));
        }
        // ----------------------------------------
        // add coefficients
        Coefficients pmmlCoefficients = pmmlSvm.addNewCoefficients();
        double[] alphas = svm.getTargetAlphas();
        pmmlCoefficients.setNumberOfCoefficients(BigInteger.valueOf(alphas.length));
        pmmlCoefficients.setAbsoluteValue(svm.getThreshold());
        for (double alpha : alphas) {
            Coefficient pmmlCoefficient = pmmlCoefficients.addNewCoefficient();
            /* KNIME defines the winner as the positive side of the
             * threshold, but the DMG defines the winner as the negative
             * side of the threshold. Therefore, to avoid changing KNIME
             * algorithm, we need to add an additional minus sign for
             * each svm output. When importing the PMML into KNIME
             * the absolute value of the alphas is read. Hence the
             * negative sign in the PMML alpha has no influence on
             * the KNIME model.*/
            pmmlCoefficient.setValue(-1 * alpha);
        }
        if (binaryClassification) {
            /* Binary classification case. Only one SVM is needed. */
            break;
        }
    }
}
Also used : Coefficient(org.dmg.pmml.CoefficientDocument.Coefficient) SupportVectors(org.dmg.pmml.SupportVectorsDocument.SupportVectors) SupportVectorMachine(org.dmg.pmml.SupportVectorMachineDocument.SupportVectorMachine) DoubleVector(org.knime.base.node.mine.svm.util.DoubleVector) Coefficients(org.dmg.pmml.CoefficientsDocument.Coefficients) SupportVector(org.dmg.pmml.SupportVectorDocument.SupportVector)

Aggregations

DoubleVector (org.knime.base.node.mine.svm.util.DoubleVector)7 ArrayList (java.util.ArrayList)3 Coefficient (org.dmg.pmml.CoefficientDocument.Coefficient)2 Coefficients (org.dmg.pmml.CoefficientsDocument.Coefficients)2 REALSparseArray (org.dmg.pmml.REALSparseArrayDocument.REALSparseArray)2 SupportVector (org.dmg.pmml.SupportVectorDocument.SupportVector)2 SupportVectorMachine (org.dmg.pmml.SupportVectorMachineDocument.SupportVectorMachine)2 SupportVectors (org.dmg.pmml.SupportVectorsDocument.SupportVectors)2 VectorInstance (org.dmg.pmml.VectorInstanceDocument.VectorInstance)2 Svm (org.knime.base.node.mine.svm.Svm)2 IOException (java.io.IOException)1 LinkedHashMap (java.util.LinkedHashMap)1 LinkedHashSet (java.util.LinkedHashSet)1 TimerTask (java.util.TimerTask)1 ExecutionException (java.util.concurrent.ExecutionException)1 Future (java.util.concurrent.Future)1 VectorDictionary (org.dmg.pmml.VectorDictionaryDocument.VectorDictionary)1 VectorFields (org.dmg.pmml.VectorFieldsDocument.VectorFields)1 PMMLSVMTranslator (org.knime.base.node.mine.svm.PMMLSVMTranslator)1 Kernel (org.knime.base.node.mine.svm.kernel.Kernel)1