Search in sources:

Example 1 with SparseVector

use of org.tribuo.math.la.SparseVector in project tribuo by oracle.

From the class SparseLinearModel, method getExcuse.

/**
 * Explains a prediction by listing each feature's contribution to every output dimension.
 * <p>
 * For each dimension, the contribution of a feature is the dimension's weight for that
 * feature multiplied by the feature's value in the example. Contributions are sorted from
 * largest to smallest.
 * @param example The example to explain.
 * @return An excuse containing the prediction and the per-dimension sorted contributions.
 */
@Override
public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
    Prediction<Regressor> prediction = predict(example);
    Map<String, List<Pair<String, Double>>> contributionsByDimension = new HashMap<>();
    SparseVector exampleVector = createFeatures(example);
    for (int dim = 0; dim < dimensions.length; dim++) {
        List<Pair<String, Double>> contributions = new ArrayList<>();
        for (VectorTuple tuple : exampleVector) {
            double contribution = weights[dim].get(tuple.index) * tuple.value;
            String featureName = featureIDMap.get(tuple.index).getName();
            contributions.add(new Pair<>(featureName, contribution));
        }
        // Descending order: the most positive contribution comes first.
        contributions.sort((left, right) -> right.getB().compareTo(left.getB()));
        contributionsByDimension.put(dimensions[dim], contributions);
    }
    return Optional.of(new Excuse<>(example, prediction, contributionsByDimension));
}
Also used: HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) SparseVector(org.tribuo.math.la.SparseVector) List(java.util.List) VectorTuple(org.tribuo.math.la.VectorTuple) Regressor(org.tribuo.regression.Regressor) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 2 with SparseVector

use of org.tribuo.math.la.SparseVector in project tribuo by oracle.

From the class TrainTest, method main.

/**
 * Command line entry point which trains a sparse linear regression model and evaluates it.
 * <p>
 * Loads the training and testing data described by the options. It then trains with the
 * selected algorithm and logs the selected features and per-target weight norms. Finally it
 * prints the evaluation on the test set, and saves the model when an output path is given.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();

    SLMOptions options = new SLMOptions();
    ConfigurationManager configManager;
    try {
        configManager = new ConfigurationManager(args, options);
    } catch (UsageException usage) {
        logger.info(usage.getMessage());
        return;
    }
    boolean pathsProvided = options.general.trainingPath != null && options.general.testingPath != null;
    if (!pathsProvided) {
        logger.info(configManager.usage());
        return;
    }

    RegressionFactory factory = new RegressionFactory();
    Pair<Dataset<Regressor>, Dataset<Regressor>> splits = options.general.load(factory);
    Dataset<Regressor> trainData = splits.getA();
    Dataset<Regressor> testData = splits.getB();

    // The feature-selecting trainers are capped at the smaller of the requested count and the available features.
    int featureBudget = Math.min(trainData.getFeatureMap().size(), options.maxNumFeatures);
    SparseTrainer<Regressor> trainer;
    switch (options.algorithm) {
        case SFS:
            trainer = new SLMTrainer(false, featureBudget);
            break;
        case SFSN:
            trainer = new SLMTrainer(true, featureBudget);
            break;
        case LARS:
            trainer = new LARSTrainer(featureBudget);
            break;
        case LARSLASSO:
            trainer = new LARSLassoTrainer(featureBudget);
            break;
        case ELASTICNET:
            trainer = new ElasticNetCDTrainer(options.alpha, options.l1Ratio, 1e-4, options.iterations, false, options.general.seed);
            break;
        default:
            logger.warning("Unknown SLMType, found " + options.algorithm);
            return;
    }

    logger.info("Training using " + trainer.toString());
    final long trainStartMillis = System.currentTimeMillis();
    SparseModel<Regressor> model = trainer.train(trainData);
    final long trainEndMillis = System.currentTimeMillis();
    logger.info("Finished training regressor " + Util.formatDuration(trainStartMillis, trainEndMillis));
    logger.info("Selected features: " + model.getActiveFeatures());

    // Every trainer selected above is expected to produce a SparseLinearModel.
    Map<String, SparseVector> weightsByTarget = ((SparseLinearModel) model).getWeights();
    for (Map.Entry<String, SparseVector> targetEntry : weightsByTarget.entrySet()) {
        SparseVector targetWeights = targetEntry.getValue();
        logger.info("Target:" + targetEntry.getKey());
        logger.info("\tWeights: " + targetWeights);
        logger.info("\tWeights one norm: " + targetWeights.oneNorm());
        logger.info("\tWeights two norm: " + targetWeights.twoNorm());
    }

    final long evalStartMillis = System.currentTimeMillis();
    RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, testData);
    final long evalEndMillis = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(evalStartMillis, evalEndMillis));
    System.out.println(evaluation.toString());

    if (options.general.outputPath != null) {
        options.general.saveModel(model);
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) RegressionFactory(org.tribuo.regression.RegressionFactory) Dataset(org.tribuo.Dataset) SparseVector(org.tribuo.math.la.SparseVector) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) Regressor(org.tribuo.regression.Regressor) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) Map(java.util.Map)

Example 3 with SparseVector

use of org.tribuo.math.la.SparseVector in project tribuo by oracle.

From the class Util, method shuffleInPlace.

/**
 * Shuffles the features, regressors and weights in place.
 * <p>
 * The same random permutation is applied to all three arrays, so entries that were
 * aligned at a common index stay aligned. This is a Fisher-Yates shuffle: it walks from
 * the last position down to index 1 and draws one index from {@code rng} per position.
 * Only the first {@code features.length} entries of {@code regressors} and {@code weights}
 * are permuted.
 * @param features Input features.
 * @param regressors Input regressors, must have at least {@code features.length} elements.
 * @param weights Input weights, must have at least {@code features.length} elements.
 * @param rng SplittableRandom number generator.
 * @throws IllegalArgumentException if {@code regressors} or {@code weights} is shorter than {@code features}.
 */
public static void shuffleInPlace(SparseVector[] features, DenseVector[] regressors, double[] weights, SplittableRandom rng) {
    int size = features.length;
    // Check lengths before touching anything. Otherwise a short array fails mid-shuffle,
    // after the features have already been swapped, leaving the arrays misaligned.
    if (regressors.length < size || weights.length < size) {
        throw new IllegalArgumentException("regressors and weights must have at least as many elements as features, found features.length = "
                + size + ", regressors.length = " + regressors.length + ", weights.length = " + weights.length);
    }
    // Shuffle array
    for (int i = size; i > 1; i--) {
        // Pick a partner uniformly from [0, i), then move it to slot i - 1.
        int j = rng.nextInt(i);
        // swap features
        SparseVector tmpFeature = features[i - 1];
        features[i - 1] = features[j];
        features[j] = tmpFeature;
        // swap regressors
        DenseVector tmpRegressors = regressors[i - 1];
        regressors[i - 1] = regressors[j];
        regressors[j] = tmpRegressors;
        // swap weights
        double tmpWeight = weights[i - 1];
        weights[i - 1] = weights[j];
        weights[j] = tmpWeight;
    }
}
Also used : SparseVector(org.tribuo.math.la.SparseVector) DenseVector(org.tribuo.math.la.DenseVector)

Example 4 with SparseVector

use of org.tribuo.math.la.SparseVector in project tribuo by oracle.

From the class TensorFlowModel, method predictBatch.

/**
 * Predicts the outputs for a batch of examples in a single TensorFlow session run.
 * @param batchExamples The examples to predict.
 * @return One prediction per example, as produced by the output converter.
 * @throws IllegalStateException if the model has been closed.
 */
private List<Prediction<T>> predictBatch(List<Example<T>> batchExamples) {
    if (closed) {
        throw new IllegalStateException("Can't use a closed model, the state has gone.");
    }
    final int batchSize = batchExamples.size();
    // Convert each example into a sparse vector, recording how many features it has active.
    List<SparseVector> sparseBatch = new ArrayList<>(batchSize);
    int[] activeCounts = new int[batchSize];
    int position = 0;
    for (Example<T> example : batchExamples) {
        SparseVector featureVector = SparseVector.createSparseVector(example, featureIDMap, false);
        activeCounts[position] = featureVector.numActiveElements();
        sparseBatch.add(featureVector);
        position++;
    }
    // Both the converted input and the fetched output are closed when this block exits.
    try (TensorMap inputs = featureConverter.convert(sparseBatch);
        Tensor output = inputs.feedInto(session.runner()).fetch(outputName).run().get(0)) {
        return outputConverter.convertToBatchPrediction(output, outputIDInfo, activeCounts, batchExamples);
    }
}
Also used : Tensor(org.tensorflow.Tensor) ArrayList(java.util.ArrayList) SparseVector(org.tribuo.math.la.SparseVector)

Example 5 with SparseVector

use of org.tribuo.math.la.SparseVector in project tribuo by oracle.

From the class MultinomialNaiveBayesModel, method predict.

/**
 * Scores an example against every label and returns the label with the highest normalized score.
 * @param example The example to classify.
 * @return A prediction whose distribution holds every label's normalized score.
 * @throws IllegalArgumentException if the example has a negative feature value, or has no
 *     features known to this model.
 */
@Override
public Prediction<Label> predict(Example<Label> example) {
    SparseVector exVector = SparseVector.createSparseVector(example, featureIDMap, false);
    // Reject inputs this model cannot score.
    if (exVector.minValue() < 0.0) {
        throw new IllegalArgumentException("Example has negative feature values, example = " + example.toString());
    }
    if (exVector.numActiveElements() == 0) {
        throw new IllegalArgumentException("No features found in Example " + example.toString());
    }

    /* The label-by-feature matrix is kept sparse, so features that are active in the
     * example but absent from a label's row would otherwise contribute nothing. With
     * smoothing (alpha > 0) they still carry a non-zero weight. That share of the inner
     * product is computed here per label and added to the scores afterwards.
     */
    int numLabels = outputIDInfo.size();
    double[] alphaOffsets = new double[numLabels];
    int vocabSize = labelWordProbs.getDimension2Size();
    if (alpha > 0.0) {
        for (int labelIdx = 0; labelIdx < numLabels; labelIdx++) {
            SparseVector labelRow = labelWordProbs.getRow(labelIdx);
            double unobservedProb = Math.log(alpha / (labelRow.oneNorm() + (vocabSize * alpha)));
            double inExampleFactor = 0.0;
            for (int featureIdx : exVector.difference(labelRow)) {
                // TODO - exVector.get is slow as it does a binary search into the vector.
                inExampleFactor += exVector.get(featureIdx) * unobservedProb;
            }
            alphaOffsets[labelIdx] = inExampleFactor;
        }
    }

    DenseVector scores = labelWordProbs.leftMultiply(exVector);
    scores.intersectAndAddInPlace(DenseVector.createDenseVector(alphaOffsets));
    scores.normalize(normalizer);

    // Build the full distribution while tracking the first label with the strictly highest score.
    Map<String, Label> distribution = new LinkedHashMap<>();
    Label best = null;
    double bestScore = Double.NEGATIVE_INFINITY;
    for (VectorTuple entry : scores) {
        String labelName = outputIDInfo.getOutput(entry.index).getLabel();
        Label candidate = new Label(labelName, entry.value);
        distribution.put(labelName, candidate);
        if (entry.value > bestScore) {
            bestScore = entry.value;
            best = candidate;
        }
    }
    return new Prediction<>(best, distribution, exVector.numActiveElements(), example, true);
}
Also used : Prediction(org.tribuo.Prediction) Label(org.tribuo.classification.Label) SparseVector(org.tribuo.math.la.SparseVector) LinkedHashMap(java.util.LinkedHashMap) VectorTuple(org.tribuo.math.la.VectorTuple) DenseVector(org.tribuo.math.la.DenseVector)

Aggregations

SparseVector (org.tribuo.math.la.SparseVector): 44; ArrayList (java.util.ArrayList): 15; Regressor (org.tribuo.regression.Regressor): 14; ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap): 12; VectorTuple (org.tribuo.math.la.VectorTuple): 10; Prediction (org.tribuo.Prediction): 9; Pair (com.oracle.labs.mlrg.olcut.util.Pair): 7; HashMap (java.util.HashMap): 7; Label (org.tribuo.classification.Label): 7; DenseVector (org.tribuo.math.la.DenseVector): 7; ModelProvenance (org.tribuo.provenance.ModelProvenance): 6; TrainerProvenance (org.tribuo.provenance.TrainerProvenance): 6; List (java.util.List): 5; Map (java.util.Map): 5; SplittableRandom (java.util.SplittableRandom): 5; LinkedHashMap (java.util.LinkedHashMap): 4; PriorityQueue (java.util.PriorityQueue): 3; DenseSparseMatrix (org.tribuo.math.la.DenseSparseMatrix): 3; LinkedList (java.util.LinkedList): 2; TFloat32 (org.tensorflow.types.TFloat32): 2