Search in sources:

Example 21 with Pair

use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

Class NeighbourQueryTestHelper, method neighboursQuerySingleDimension.

/**
 * Checks a one-dimensional neighbour query. Ten points with values 0 through 9
 * (each stored at the index equal to its value) are indexed, and the three
 * nearest neighbours of 1.75 are asserted to be, in order, indices 2, 1 and 3.
 *
 * @param nqf The factory used to build the neighbour query under test.
 */
static void neighboursQuerySingleDimension(NeighboursQueryFactory nqf) {
    // Single-coordinate points whose value equals their array index.
    SGDVector[] points = new SGDVector[10];
    for (int idx = 0; idx < points.length; idx++) {
        points[idx] = DenseVector.createDenseVector(new double[] { idx });
    }
    NeighboursQuery neighboursQuery = nqf.createNeighboursQuery(points);
    SGDVector probe = DenseVector.createDenseVector(new double[] { 1.75 });
    List<Pair<Integer, Double>> nearest = neighboursQuery.query(probe, 3);
    // Ranked by |x - 1.75|: 2 (0.25), then 1 (0.75), then 3 (1.25).
    assertEquals(2, nearest.get(0).getA());
    assertEquals(1, nearest.get(1).getA());
    assertEquals(3, nearest.get(2).getA());
}
Also used : SGDVector(org.tribuo.math.la.SGDVector) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 22 with Pair

use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

Class AnomalyDataGenerator, method gaussianAnomaly.

/**
 * Builds a pair of anomaly-detection datasets from a fixed-seed random source.
 * <p>
 * The first (training) dataset contains only points drawn from the "expected"
 * gaussian, all labelled as expected events. The second (testing) dataset
 * draws each point from either the expected gaussian or a second, anomalous
 * gaussian. Points from the anomalous gaussian are labelled as anomalous events.
 *
 * @param size The number of points to sample for each dataset.
 * @param fractionAnomalous The probability that any given test point comes from the anomalous gaussian.
 * @return A pair of datasets: (training, testing).
 * @throws IllegalArgumentException If size is less than one, or fractionAnomalous is greater than one or less than zero.
 */
public static Pair<Dataset<Event>, Dataset<Event>> gaussianAnomaly(long size, double fractionAnomalous) {
    if (size < 1) {
        throw new IllegalArgumentException("Size must be a positive number, received " + size);
    }
    if (fractionAnomalous > 1 || fractionAnomalous < 0) {
        throw new IllegalArgumentException("FractionAnomalous must be between zero and one, received " + fractionAnomalous);
    }
    // A fixed seed makes every call produce the same data.
    Random random = new Random(1L);

    // Per-feature distribution parameters.
    String[] names = { "A", "B", "C", "D", "E" };
    double[] normalMeans = { 1.0, 2.0, 1.0, 2.0, 5.0 };
    double[] outlierMeans = { -2.0, 2.0, -2.0, 2.0, -10.0 };
    double[] featureVariances = { 1.0, 0.5, 0.25, 1.0, 0.1 };

    // Training set: every point comes from the expected distribution.
    List<Example<Event>> trainExamples = new ArrayList<>();
    for (int n = 0; n < size; n++) {
        List<Feature> features = generateFeatures(random, names, normalMeans, featureVariances);
        trainExamples.add(new ArrayExample<>(EXPECTED_EVENT, features));
    }

    // Testing set: one uniform draw per point decides which distribution it comes from.
    List<Example<Event>> testExamples = new ArrayList<>();
    for (int n = 0; n < size; n++) {
        boolean isAnomaly = random.nextDouble() < fractionAnomalous;
        List<Feature> features = generateFeatures(random, names, isAnomaly ? outlierMeans : normalMeans, featureVariances);
        testExamples.add(new ArrayExample<>(isAnomaly ? ANOMALOUS_EVENT : EXPECTED_EVENT, features));
    }

    MutableDataset<Event> train = new MutableDataset<>(new ListDataSource<>(trainExamples, anomalyFactory,
            new SimpleDataSourceProvenance("Anomaly training data", anomalyFactory)));
    MutableDataset<Event> test = new MutableDataset<>(new ListDataSource<>(testExamples, anomalyFactory,
            new SimpleDataSourceProvenance("Anomaly testing data", anomalyFactory)));
    return new Pair<>(train, test);
}
Also used : SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) ArrayList(java.util.ArrayList) Feature(org.tribuo.Feature) Random(java.util.Random) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) Event(org.tribuo.anomaly.Event) MutableDataset(org.tribuo.MutableDataset) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 23 with Pair

use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

Class LibLinearAnomalyModel, method innerGetExcuse.

/**
 * Builds an excuse for a single example by scoring each of its known features
 * against the supplied weight vector.
 * <p>
 * The call to model.getFeatureWeights in the public methods copies the
 * weights array, so this inner method exists to avoid repeating that copy in
 * getExcuses. If it becomes a problem then we could cache the feature weights
 * in the model.
 * <p>
 * Each feature's contribution is {@code weight * value}. Features not present
 * in the feature map are skipped. The contributions, sorted highest first, are
 * listed under the anomalous event type. Their negations, also sorted highest
 * first, are listed under the expected event type.
 *
 * @param e The example.
 * @param allFeatureWeights The feature weights; only row 0 is read.
 * @return An excuse for this example.
 */
@Override
protected Excuse<Event> innerGetExcuse(Example<Event> e, double[][] allFeatureWeights) {
    // Only the first weight row is needed; the liblinear model itself is not consulted here.
    double[] featureWeights = allFeatureWeights[0];
    Prediction<Event> prediction = predict(e);
    Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
    List<Pair<String, Double>> posScores = new ArrayList<>();
    List<Pair<String, Double>> negScores = new ArrayList<>();
    for (Feature f : e) {
        int id = featureIDMap.getID(f.getName());
        // An id of -1 means the feature is unknown to the model, so it is ignored.
        if (id > -1) {
            double score = featureWeights[id] * f.getValue();
            posScores.add(new Pair<>(f.getName(), score));
            negScores.add(new Pair<>(f.getName(), -score));
        }
    }
    // Descending order by score.
    posScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
    negScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
    weightMap.put(Event.EventType.ANOMALOUS.toString(), posScores);
    weightMap.put(Event.EventType.EXPECTED.toString(), negScores);
    return new Excuse<>(e, prediction, weightMap);
}
Also used : FeatureNode(de.bwaldvogel.liblinear.FeatureNode) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) Feature(org.tribuo.Feature) Event(org.tribuo.anomaly.Event) ArrayList(java.util.ArrayList) List(java.util.List) Pair(com.oracle.labs.mlrg.olcut.util.Pair) Excuse(org.tribuo.Excuse)

Example 24 with Pair

use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

Class KNNModel, method innerPredictStreams.

/**
 * Predicts using a ForkJoinPool and the Streams API.
 * <p>
 * For each example, the distance to every stored vector is computed on a
 * parallel stream running inside a dedicated pool. The k closest are kept, and
 * the combiner merges their outputs into one prediction. The pool is created
 * once per call and shut down before this method returns or throws.
 *
 * @param examples The examples to predict.
 * @return The predictions, in the iteration order of {@code examples}.
 * @throws IllegalStateException If the parallel neighbour search fails, or the
 *                               calling thread is interrupted while waiting for it.
 */
private List<Prediction<T>> innerPredictStreams(Iterable<Example<T>> examples) {
    List<Prediction<T>> predictions = new ArrayList<>();
    ForkJoinPool fjp = System.getSecurityManager() == null ? new ForkJoinPool(numThreads) : new ForkJoinPool(numThreads, THREAD_FACTORY, null, false);
    try {
        for (Example<T> example : examples) {
            SGDVector input;
            if (example.size() == featureIDMap.size()) {
                input = DenseVector.createDenseVector(example, featureIDMap, false);
            } else {
                input = SparseVector.createSparseVector(example, featureIDMap, false);
            }
            Function<Pair<SGDVector, T>, OutputDoublePair<T>> distanceFunc = (a) -> new OutputDoublePair<>(a.getB(), DistanceType.getDistance(a.getA(), input, distType));
            Stream<Pair<SGDVector, T>> stream = Stream.of(vectors);
            // Scoped per example so a failure can never reuse a previous example's neighbours.
            List<Prediction<T>> innerPredictions;
            try {
                innerPredictions = fjp.submit(() -> StreamUtil.boundParallelism(stream.parallel()).map(distanceFunc).sorted().limit(k).map((a) -> new Prediction<>(a.output, input.numActiveElements(), example)).collect(Collectors.toList())).get();
            } catch (InterruptedException e) {
                // Restore the interrupt flag so callers can observe the interruption.
                Thread.currentThread().interrupt();
                logger.log(Level.SEVERE, "Exception when predicting in KNNModel", e);
                throw new IllegalStateException("Interrupted while predicting in KNNModel", e);
            } catch (ExecutionException e) {
                logger.log(Level.SEVERE, "Exception when predicting in KNNModel", e);
                throw new IllegalStateException("Exception when predicting in KNNModel", e);
            }
            predictions.add(combiner.combine(outputIDInfo, innerPredictions));
        }
    } finally {
        // The pool is private to this call; release its worker threads.
        fjp.shutdown();
    }
    return predictions;
}
Also used : Example(org.tribuo.Example) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) ModelProvenance(org.tribuo.provenance.ModelProvenance) Prediction(org.tribuo.Prediction) PriorityQueue(java.util.PriorityQueue) ImmutableOutputInfo(org.tribuo.ImmutableOutputInfo) Model(org.tribuo.Model) ForkJoinWorkerThread(java.util.concurrent.ForkJoinWorkerThread) EnsembleCombiner(org.tribuo.ensemble.EnsembleCombiner) NeighboursQueryFactory(org.tribuo.math.neighbour.NeighboursQueryFactory) Function(java.util.function.Function) ArrayList(java.util.ArrayList) Level(java.util.logging.Level) Future(java.util.concurrent.Future) NeighboursQuery(org.tribuo.math.neighbour.NeighboursQuery) NeighboursBruteForceFactory(org.tribuo.math.neighbour.bruteforce.NeighboursBruteForceFactory) Output(org.tribuo.Output) SGDVector(org.tribuo.math.la.SGDVector) Map(java.util.Map) Excuse(org.tribuo.Excuse) DistanceType(org.tribuo.math.distance.DistanceType) SparseVector(org.tribuo.math.la.SparseVector) ExecutorService(java.util.concurrent.ExecutorService) IOException(java.io.IOException) Pair(com.oracle.labs.mlrg.olcut.util.Pair) DenseVector(org.tribuo.math.la.DenseVector) PrivilegedAction(java.security.PrivilegedAction) Logger(java.util.logging.Logger) Collectors(java.util.stream.Collectors) Executors(java.util.concurrent.Executors) Objects(java.util.Objects) ExecutionException(java.util.concurrent.ExecutionException) List(java.util.List) Stream(java.util.stream.Stream) ForkJoinPool(java.util.concurrent.ForkJoinPool) Optional(java.util.Optional) StreamUtil(com.oracle.labs.mlrg.olcut.util.StreamUtil) AccessController(java.security.AccessController) Collections(java.util.Collections) Distance(org.tribuo.common.nearest.KNNTrainer.Distance) Prediction(org.tribuo.Prediction) ArrayList(java.util.ArrayList) SGDVector(org.tribuo.math.la.SGDVector) ExecutionException(java.util.concurrent.ExecutionException) ForkJoinPool(java.util.concurrent.ForkJoinPool) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 25 with Pair

use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

Class KNNTrainer, method train.

/**
 * Builds a k-nearest-neighbour model. No fitting takes place. Each example is
 * converted to a vector and paired with its output, and the pairs are passed to
 * the returned model. The vector is dense when the example has as many
 * features as the feature map, and sparse otherwise.
 *
 * @param examples The training data.
 * @param runProvenance Run-specific provenance recorded in the model provenance.
 * @param invocationCount If not INCREMENT_INVOCATION_COUNT, the trainer's invocation
 *                        count is first set to this value; the count is then incremented.
 * @return A KNNModel holding the converted training vectors.
 */
@Override
public Model<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    ImmutableOutputInfo<T> outputInfo = examples.getOutputIDInfo();
    // Generic arrays cannot be created directly, so allocate a raw array and rely on the unchecked cast.
    @SuppressWarnings("unchecked")
    Pair<SGDVector, T>[] vectors = new Pair[examples.size()];
    int idx = 0;
    for (Example<T> example : examples) {
        boolean dense = example.size() == featureMap.size();
        SGDVector vec = dense
                ? DenseVector.createDenseVector(example, featureMap, false)
                : SparseVector.createSparseVector(example, featureMap, false);
        vectors[idx++] = new Pair<>(vec, example.getOutput());
    }
    if (invocationCount != INCREMENT_INVOCATION_COUNT) {
        setInvocationCount(invocationCount);
    }
    trainInvocationCount++;
    ModelProvenance provenance = new ModelProvenance(KNNModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), runProvenance);
    return new KNNModel<>(k + "nn", provenance, featureMap, outputInfo, false, k, distType, numThreads, combiner, vectors, backend, neighboursQueryFactory);
}
Also used : ModelProvenance(org.tribuo.provenance.ModelProvenance) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Aggregations

Pair (com.oracle.labs.mlrg.olcut.util.Pair)59 ArrayList (java.util.ArrayList)27 List (java.util.List)21 HashMap (java.util.HashMap)18 MutableDataset (org.tribuo.MutableDataset)17 SimpleDataSourceProvenance (org.tribuo.provenance.SimpleDataSourceProvenance)16 Label (org.tribuo.classification.Label)14 Feature (org.tribuo.Feature)11 Regressor (org.tribuo.regression.Regressor)11 Prediction (org.tribuo.Prediction)10 DenseVector (org.tribuo.math.la.DenseVector)10 SparseVector (org.tribuo.math.la.SparseVector)10 SGDVector (org.tribuo.math.la.SGDVector)9 Map (java.util.Map)7 Example (org.tribuo.Example)7 ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap)7 PriorityQueue (java.util.PriorityQueue)6 Excuse (org.tribuo.Excuse)5 Model (org.tribuo.Model)5 LabelFactory (org.tribuo.classification.LabelFactory)5