
Example 6 with QNMinimizer

Use of edu.stanford.nlp.optimization.QNMinimizer in project CoreNLP by stanfordnlp.

From the class EntityClassifier, the train method.

private static void train(List<SceneGraphImage> images, String modelPath, Embedding embeddings) throws IOException {
    RVFDataset<String, String> dataset = new RVFDataset<String, String>();
    SceneGraphSentenceMatcher sentenceMatcher = new SceneGraphSentenceMatcher(embeddings);
    for (SceneGraphImage img : images) {
        for (SceneGraphImageRegion region : img.regions) {
            SemanticGraph sg = region.getEnhancedSemanticGraph();
            SemanticGraphEnhancer.enhance(sg);
            List<Triple<IndexedWord, IndexedWord, String>> relationTriples = sentenceMatcher.getRelationTriples(region);
            for (Triple<IndexedWord, IndexedWord, String> relation : relationTriples) {
                IndexedWord w1 = sg.getNodeByIndexSafe(relation.first.index());
                if (w1 != null) {
                    dataset.add(getDatum(w1, relation.first.get(SceneGraphCoreAnnotations.GoldEntityAnnotation.class), embeddings));
                }
            }
        }
    }
    LinearClassifierFactory<String, String> classifierFactory = new LinearClassifierFactory<String, String>(new QNMinimizer(15), 1e-4, false, REG_STRENGTH);
    Classifier<String, String> classifier = classifierFactory.trainClassifier(dataset);
    IOUtils.writeObjectToFile(classifier, modelPath);
    System.err.println(classifier.evaluateAccuracy(dataset));
}
Also used : RVFDataset(edu.stanford.nlp.classify.RVFDataset) SceneGraphImage(edu.stanford.nlp.scenegraph.image.SceneGraphImage) QNMinimizer(edu.stanford.nlp.optimization.QNMinimizer) Triple(edu.stanford.nlp.util.Triple) LinearClassifierFactory(edu.stanford.nlp.classify.LinearClassifierFactory) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph) IndexedWord(edu.stanford.nlp.ling.IndexedWord) SceneGraphImageRegion(edu.stanford.nlp.scenegraph.image.SceneGraphImageRegion)
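
For orientation, a minimal self-contained sketch of the same RVFDataset / LinearClassifierFactory / QNMinimizer pattern follows; the feature names, labels, and the 1.0 standing in for REG_STRENGTH are invented for illustration, while the factory arguments mirror the call above (L-BFGS memory of 15, 1e-4 convergence tolerance, quadratic prior).

private static void toyTrainingSketch() {
    RVFDataset<String, String> toy = new RVFDataset<>();
    // Two one-feature datums with made-up labels.
    ClassicCounter<String> catFeats = new ClassicCounter<>();
    catFeats.incrementCount("word=cat");
    RVFDatum<String, String> catDatum = new RVFDatum<>(catFeats, "ANIMAL");
    toy.add(catDatum);
    ClassicCounter<String> tableFeats = new ClassicCounter<>();
    tableFeats.incrementCount("word=table");
    toy.add(new RVFDatum<>(tableFeats, "OBJECT"));
    // Same factory configuration as in train() above.
    LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>(new QNMinimizer(15), 1e-4, false, 1.0);
    Classifier<String, String> classifier = factory.trainClassifier(toy);
    // Should print ANIMAL: the model memorizes this tiny training set.
    System.err.println(classifier.classOf(catDatum));
}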

Example 7 with QNMinimizer

Use of edu.stanford.nlp.optimization.QNMinimizer in project CoreNLP by stanfordnlp.

From the class SimpleSentiment, the train method.

/**
 * Train a sentiment model from a set of data.
 *
 * @param data The data to train the model from.
 * @param modelLocation An optional location to save the model.
 *                      Note that this stream will be closed in this method,
 *                      and should not be written to thereafter.
 *
 * @return A sentiment classifier, ready to use.
 */
@SuppressWarnings({ "OptionalUsedAsFieldOrParameterType", "ConstantConditions" })
public static SimpleSentiment train(Stream<SentimentDatum> data, Optional<OutputStream> modelLocation) {
    // Some useful variables configuring how we train
    boolean useL1 = true;
    double sigma = 1.0;
    int featureCountThreshold = 5;
    // Featurize the data
    forceTrack("Featurizing");
    RVFDataset<SentimentClass, String> dataset = new RVFDataset<>();
    AtomicInteger datasize = new AtomicInteger(0);
    Counter<SentimentClass> distribution = new ClassicCounter<>();
    data.unordered().parallel().map(datum -> {
        if (datasize.incrementAndGet() % 10000 == 0) {
            log("Added " + datasize.get() + " datums");
        }
        return new RVFDatum<>(featurize(datum.asCoreMap()), datum.sentiment);
    }).forEach(x -> {
        synchronized (dataset) {
            distribution.incrementCount(x.label());
            dataset.add(x);
        }
    });
    endTrack("Featurizing");
    // Print label distribution
    startTrack("Distribution");
    for (SentimentClass label : SentimentClass.values()) {
        log(String.format("%7d", (int) distribution.getCount(label)) + "   " + label);
    }
    endTrack("Distribution");
    // Train the classifier
    forceTrack("Training");
    if (featureCountThreshold > 1) {
        dataset.applyFeatureCountThreshold(featureCountThreshold);
    }
    dataset.randomize(42L);
    LinearClassifierFactory<SentimentClass, String> factory = new LinearClassifierFactory<>();
    factory.setVerbose(true);
    try {
        factory.setMinimizerCreator(() -> {
            QNMinimizer minimizer = new QNMinimizer();
            if (useL1) {
                minimizer.useOWLQN(true, 1 / (sigma * sigma));
            } else {
                factory.setSigma(sigma);
            }
            return minimizer;
        });
    } catch (Exception ignored) {
    }
    factory.setSigma(sigma);
    LinearClassifier<SentimentClass, String> classifier = factory.trainClassifier(dataset);
    // Optionally save the model
    modelLocation.ifPresent(stream -> {
        try {
            ObjectOutputStream oos = new ObjectOutputStream(stream);
            oos.writeObject(classifier);
            oos.close();
        } catch (IOException e) {
            log.err("Could not save model to stream!");
        }
    });
    endTrack("Training");
    // Evaluate the model
    forceTrack("Evaluating");
    factory.setVerbose(false);
    double sumAccuracy = 0.0;
    Counter<SentimentClass> sumP = new ClassicCounter<>();
    Counter<SentimentClass> sumR = new ClassicCounter<>();
    int numFolds = 4;
    for (int fold = 0; fold < numFolds; ++fold) {
        Pair<GeneralDataset<SentimentClass, String>, GeneralDataset<SentimentClass, String>> trainTest = dataset.splitOutFold(fold, numFolds);
        // convex objective, so this should be OK
        LinearClassifier<SentimentClass, String> foldClassifier = factory.trainClassifierWithInitialWeights(trainTest.first, classifier);
        sumAccuracy += foldClassifier.evaluateAccuracy(trainTest.second);
        for (SentimentClass label : SentimentClass.values()) {
            Pair<Double, Double> pr = foldClassifier.evaluatePrecisionAndRecall(trainTest.second, label);
            sumP.incrementCount(label, pr.first);
            sumR.incrementCount(label, pr.second);
        }
    }
    DecimalFormat df = new DecimalFormat("0.000%");
    log.info("----------");
    double aveAccuracy = sumAccuracy / ((double) numFolds);
    log.info("" + numFolds + "-fold accuracy: " + df.format(aveAccuracy));
    log.info("");
    for (SentimentClass label : SentimentClass.values()) {
        double p = sumP.getCount(label) / numFolds;
        double r = sumR.getCount(label) / numFolds;
        log.info(label + " (P)  = " + df.format(p));
        log.info(label + " (R)  = " + df.format(r));
        log.info(label + " (F1) = " + df.format(2 * p * r / (p + r)));
        log.info("");
    }
    log.info("----------");
    endTrack("Evaluating");
    // Return
    return new SimpleSentiment(classifier);
}
Also used : Arrays(java.util.Arrays) SentimentClass(edu.stanford.nlp.simple.SentimentClass) Document(edu.stanford.nlp.simple.Document) QNMinimizer(edu.stanford.nlp.optimization.QNMinimizer) Counter(edu.stanford.nlp.stats.Counter) StanfordCoreNLP(edu.stanford.nlp.pipeline.StanfordCoreNLP) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) RuntimeIOException(edu.stanford.nlp.io.RuntimeIOException) Pair(edu.stanford.nlp.util.Pair) ObjectOutputStream(java.io.ObjectOutputStream) StreamSupport(java.util.stream.StreamSupport) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) CoreMap(edu.stanford.nlp.util.CoreMap) RVFDatum(edu.stanford.nlp.ling.RVFDatum) OutputStream(java.io.OutputStream) CoreLabel(edu.stanford.nlp.ling.CoreLabel) CoreAnnotations(edu.stanford.nlp.ling.CoreAnnotations) Properties(java.util.Properties) IOUtils(edu.stanford.nlp.io.IOUtils) Redwood(edu.stanford.nlp.util.logging.Redwood) DecimalFormat(java.text.DecimalFormat) Util(edu.stanford.nlp.util.logging.Redwood.Util) IOException(java.io.IOException) File(java.io.File) Lazy(edu.stanford.nlp.util.Lazy) List(java.util.List) Stream(java.util.stream.Stream) Annotation(edu.stanford.nlp.pipeline.Annotation) edu.stanford.nlp.classify(edu.stanford.nlp.classify) StringUtils(edu.stanford.nlp.util.StringUtils) Optional(java.util.Optional) RedwoodConfiguration(edu.stanford.nlp.util.logging.RedwoodConfiguration) Pattern(java.util.regex.Pattern)
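
The minimizer-creator closure above is where the L1/L2 choice happens. Pulled into a hypothetical standalone method (the name is mine, not CoreNLP's), the two paths are:

private static QNMinimizer regularizedMinimizer(boolean useL1, double sigma) {
    QNMinimizer minimizer = new QNMinimizer();
    if (useL1) {
        // OWL-QN: an L-BFGS variant with an L1 penalty of weight 1 / sigma^2,
        // which tends to drive many feature weights to exactly zero.
        minimizer.useOWLQN(true, 1 / (sigma * sigma));
    }
    // Otherwise regularization comes from the factory's quadratic (L2) prior,
    // configured via factory.setSigma(sigma) as in train() above.
    return minimizer;
}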

Example 8 with QNMinimizer

Use of edu.stanford.nlp.optimization.QNMinimizer in project CoreNLP by stanfordnlp.

From the class BoWSceneGraphParser, the train method.

/**
 * Trains a classifier using the examples in trainingFile and saves
 * it to modelPath.
 *
 * @param trainingFile Path to JSON file with images and scene graphs.
 * @param modelPath Path to which the serialized classifier is written.
 * @throws IOException If the training file cannot be read or the model cannot be written.
 */
public void train(String trainingFile, String modelPath) throws IOException {
    LinearClassifierFactory<String, String> classifierFactory = new LinearClassifierFactory<String, String>(new QNMinimizer(15), 1e-4, false, REG_STRENGTH);
    /* Create dataset. */
    Dataset<String, String> dataset = getTrainingExamples(trainingFile, true);
    /* Train the classifier. */
    Classifier<String, String> classifier = classifierFactory.trainClassifier(dataset);
    /* Save classifier to disk. */
    IOUtils.writeObjectToFile(classifier, modelPath);
}
Also used : LinearClassifierFactory(edu.stanford.nlp.classify.LinearClassifierFactory) QNMinimizer(edu.stanford.nlp.optimization.QNMinimizer)
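
For completeness, here is a sketch of the matching load step, assuming IOUtils.readObjectFromFile as the deserialization counterpart of the writeObjectToFile call above (the method name is illustrative):

private static Classifier<String, String> loadClassifier(String modelPath) throws IOException, ClassNotFoundException {
    // Plain Java serialization, hence the possible ClassNotFoundException.
    return IOUtils.readObjectFromFile(modelPath);
}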

Example 9 with QNMinimizer

Use of edu.stanford.nlp.optimization.QNMinimizer in project CoreNLP by stanfordnlp.

From the class LogisticClassifierFactory, the trainWeightedData method.

public LogisticClassifier<L, F> trainWeightedData(GeneralDataset<L, F> data, float[] dataWeights) {
    if (data instanceof RVFDataset)
        ((RVFDataset<L, F>) data).ensureRealValues();
    if (data.labelIndex.size() != 2) {
        throw new RuntimeException("LogisticClassifier is only for binary classification!");
    }
    Minimizer<DiffFunction> minim;
    LogisticObjectiveFunction lof = null;
    if (data instanceof Dataset<?, ?>)
        lof = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getLabelsArray(), new LogPrior(LogPrior.LogPriorType.QUADRATIC), dataWeights);
    else if (data instanceof RVFDataset<?, ?>)
        lof = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getValuesArray(), data.getLabelsArray(), new LogPrior(LogPrior.LogPriorType.QUADRATIC), dataWeights);
    minim = new QNMinimizer(lof);
    weights = minim.minimize(lof, 1e-4, new double[data.numFeatureTypes()]);
    featureIndex = data.featureIndex;
    classes[0] = data.labelIndex.get(0);
    classes[1] = data.labelIndex.get(1);
    return new LogisticClassifier<>(weights, featureIndex, classes);
}
Also used : DiffFunction(edu.stanford.nlp.optimization.DiffFunction) QNMinimizer(edu.stanford.nlp.optimization.QNMinimizer)
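
trainWeightedData shows the contract that the higher-level factories wrap: any DiffFunction (a value plus its gradient) can be handed to QNMinimizer.minimize together with a tolerance and a starting point. A toy instance of the same contract, minimizing a convex quadratic, might look like this:

private static double[] toyMinimize() {
    // f(x) = (x0 - 3)^2 + (x1 + 1)^2, whose gradient is 2 * (x - c).
    DiffFunction f = new DiffFunction() {
        @Override
        public int domainDimension() {
            return 2;
        }
        @Override
        public double valueAt(double[] x) {
            double a = x[0] - 3.0;
            double b = x[1] + 1.0;
            return a * a + b * b;
        }
        @Override
        public double[] derivativeAt(double[] x) {
            return new double[] { 2.0 * (x[0] - 3.0), 2.0 * (x[1] + 1.0) };
        }
    };
    // Same call shape as above: function, tolerance, initial point.
    // Returns approximately {3.0, -1.0}.
    return new QNMinimizer(10).minimize(f, 1e-6, new double[2]);
}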

Example 10 with QNMinimizer

Use of edu.stanford.nlp.optimization.QNMinimizer in project CoreNLP by stanfordnlp.

From the class DVParser, the executeOneTrainingBatch method.

public void executeOneTrainingBatch(List<Tree> trainingBatch, IdentityHashMap<Tree, byte[]> compressedParses, double[] sumGradSquare) {
    Timing convertTiming = new Timing();
    convertTiming.doing("Converting trees");
    IdentityHashMap<Tree, List<Tree>> topParses = CacheParseHypotheses.convertToTrees(trainingBatch, compressedParses, op.trainOptions.trainingThreads);
    convertTiming.done();
    DVParserCostAndGradient gcFunc = new DVParserCostAndGradient(trainingBatch, topParses, dvModel, op);
    double[] theta = dvModel.paramsToVector();
    // 1: QNMinimizer, 2: SGD, 3: AdaGrad
    switch(MINIMIZER) {
        case (1):
            {
                QNMinimizer qn = new QNMinimizer(op.trainOptions.qnEstimates, true);
                qn.useMinPackSearch();
                qn.useDiagonalScaling();
                qn.terminateOnAverageImprovement(true);
                qn.terminateOnNumericalZero(true);
                qn.terminateOnRelativeNorm(true);
                theta = qn.minimize(gcFunc, op.trainOptions.qnTolerance, theta, op.trainOptions.qnIterationsPerBatch);
                break;
            }
        case 2:
            {
                // Minimizer smd = new SGDMinimizer();
                // double tol = 1e-4;
                // theta = smd.minimize(gcFunc, tol, theta, op.trainOptions.qnIterationsPerBatch);
                double lastCost = 0, currCost = 0;
                boolean firstTime = true;
                for (int i = 0; i < op.trainOptions.qnIterationsPerBatch; i++) {
                    // gcFunc.calculate(theta);
                    double[] grad = gcFunc.derivativeAt(theta);
                    currCost = gcFunc.valueAt(theta);
                    log.info("batch cost: " + currCost);
                    // if(!firstTime){
                    // if(currCost > lastCost){
                    // System.out.println("HOW IS FUNCTION VALUE INCREASING????!!! ... still updating theta");
                    // }
                    // if(Math.abs(currCost - lastCost) < 0.0001){
                    // System.out.println("function value is not decreasing. stop");
                    // }
                    // }else{
                    // firstTime = false;
                    // }
                    lastCost = currCost;
                    ArrayMath.addMultInPlace(theta, grad, -1 * op.trainOptions.learningRate);
                }
                break;
            }
        case 3:
            {
                // AdaGrad
                double eps = 1e-3;
                double currCost = 0;
                for (int i = 0; i < op.trainOptions.qnIterationsPerBatch; i++) {
                    double[] gradf = gcFunc.derivativeAt(theta);
                    currCost = gcFunc.valueAt(theta);
                    log.info("batch cost: " + currCost);
                    for (int feature = 0; feature < gradf.length; feature++) {
                        sumGradSquare[feature] = sumGradSquare[feature] + gradf[feature] * gradf[feature];
                        theta[feature] = theta[feature] - (op.trainOptions.learningRate * gradf[feature] / (Math.sqrt(sumGradSquare[feature]) + eps));
                    }
                }
                break;
            }
        default:
            {
                throw new IllegalArgumentException("Unsupported minimizer " + MINIMIZER);
            }
    }
    dvModel.vectorToParams(theta);
}
Also used : Tree(edu.stanford.nlp.trees.Tree) ArrayList(java.util.ArrayList) List(java.util.List) Timing(edu.stanford.nlp.util.Timing) QNMinimizer(edu.stanford.nlp.optimization.QNMinimizer)
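
Case 3 is a standard AdaGrad loop inlined into the switch. Factored into a hypothetical helper (not part of DVParser), the per-parameter update reads:

private static void adaGradStep(double[] theta, double[] grad, double[] sumGradSquare, double learningRate, double eps) {
    for (int i = 0; i < theta.length; i++) {
        // Accumulate this parameter's squared-gradient history...
        sumGradSquare[i] += grad[i] * grad[i];
        // ...and shrink its effective step size as that history grows.
        theta[i] -= learningRate * grad[i] / (Math.sqrt(sumGradSquare[i]) + eps);
    }
}

Parameters that receive large or frequent gradients take progressively smaller steps, which is why executeOneTrainingBatch threads sumGradSquare across batches instead of resetting it.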

Aggregations

QNMinimizer (edu.stanford.nlp.optimization.QNMinimizer) 10
DiffFunction (edu.stanford.nlp.optimization.DiffFunction) 5
LinearClassifierFactory (edu.stanford.nlp.classify.LinearClassifierFactory) 3
List (java.util.List) 2
edu.stanford.nlp.classify (edu.stanford.nlp.classify) 1
RVFDataset (edu.stanford.nlp.classify.RVFDataset) 1
WeightedDataset (edu.stanford.nlp.classify.WeightedDataset) 1
IOUtils (edu.stanford.nlp.io.IOUtils) 1
RuntimeIOException (edu.stanford.nlp.io.RuntimeIOException) 1
BasicDatum (edu.stanford.nlp.ling.BasicDatum) 1
CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations) 1
CoreLabel (edu.stanford.nlp.ling.CoreLabel) 1
IndexedWord (edu.stanford.nlp.ling.IndexedWord) 1
RVFDatum (edu.stanford.nlp.ling.RVFDatum) 1
TaggedWord (edu.stanford.nlp.ling.TaggedWord) 1
Annotation (edu.stanford.nlp.pipeline.Annotation) 1
StanfordCoreNLP (edu.stanford.nlp.pipeline.StanfordCoreNLP) 1
SceneGraphImage (edu.stanford.nlp.scenegraph.image.SceneGraphImage) 1
SceneGraphImageRegion (edu.stanford.nlp.scenegraph.image.SceneGraphImageRegion) 1
SemanticGraph (edu.stanford.nlp.semgraph.SemanticGraph) 1