Search in sources:

Example 1 with MutableLong

Use of com.oracle.labs.mlrg.olcut.util.MutableLong in project Tribuo by Oracle.

In class MockOutputInfo, method observe.

/**
 * Records one observed {@link MockOutput}.
 * <p>
 * The sentinel unknown output only bumps the unknown counter. Any other output increments the
 * count for its label, caches a {@code MockOutput} for that label, and, the first time the label
 * is seen, assigns it the current value of {@code labelCounter} as its id (in both lookup maps)
 * before advancing the counter.
 * @param output The output to record.
 */
@Override
public void observe(MockOutput output) {
    if (output == MockOutputFactory.UNKNOWN_TEST_OUTPUT) {
        unknownCount++;
        return;
    }
    String name = output.label;
    MutableLong occurrences = labelCounts.computeIfAbsent(name, ignored -> new MutableLong());
    labels.computeIfAbsent(name, MockOutput::new);
    occurrences.increment();
    // Already-known labels keep the id they were first given.
    if (labelIDMap.containsKey(name)) {
        return;
    }
    labelIDMap.put(name, labelCounter);
    idLabelMap.put(labelCounter, name);
    labelCounter++;
}
Also used: MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong)

Example 2 with MutableLong

Use of com.oracle.labs.mlrg.olcut.util.MutableLong in project Tribuo by Oracle.

In class HdbscanTrainer, method train.

/**
 * Trains an HDBSCAN* clustering model on the supplied examples.
 * <p>
 * Steps:
 * <ol>
 * <li>Convert each example to a vector.</li>
 * <li>Compute core distances, then build the extended minimum spanning tree.</li>
 * <li>Build the hierarchy and cluster tree, then propagate it.</li>
 * <li>Select the prominent clusters, compute outlier scores and derive cluster assignments.</li>
 * <li>Package everything, with exemplars and provenance, into an {@link HdbscanModel}.</li>
 * </ol>
 * @param examples The training data.
 * @param runProvenance Run-specific provenance to attach to the model.
 * @return The trained model.
 */
@Override
public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
    // Read the provenance and increment the invocation count together under this trainer's lock.
    TrainerProvenance trainerProv;
    synchronized (this) {
        trainerProv = getProvenance();
        trainInvocationCounter++;
    }

    // An example whose size equals the feature map's size is stored densely; all others sparsely.
    ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
    SGDVector[] vectors = new SGDVector[examples.size()];
    int idx = 0;
    for (Example<ClusterID> ex : examples) {
        boolean dense = ex.size() == featureIDMap.size();
        vectors[idx++] = dense
                ? DenseVector.createDenseVector(ex, featureIDMap, false)
                : SparseVector.createSparseVector(ex, featureIDMap, false);
    }

    DenseVector coreDists = calculateCoreDistances(vectors, k, neighboursQueryFactory);
    ExtendedMinimumSpanningTree spanningTree = constructEMST(vectors, coreDists, distType);

    int numPoints = vectors.length;
    // Level at which each point becomes noise.
    double[] noiseLevels = new double[numPoints];
    // Last cluster label of each point before it becomes noise.
    int[] lastClusters = new int[numPoints];
    // The HDBSCAN* hierarchy, filled in by computeHierarchyAndClusterTree.
    Map<Integer, int[]> clusterHierarchy = new HashMap<>();
    List<HdbscanCluster> clusterTree = computeHierarchyAndClusterTree(spanningTree, minClusterSize, noiseLevels, lastClusters, clusterHierarchy);
    propagateTree(clusterTree);
    List<Integer> pointLabels = findProminentClusters(clusterHierarchy, clusterTree, numPoints);
    DenseVector outlierScores = calculateOutlierScores(noiseLevels, lastClusters, clusterTree);
    Map<Integer, List<Pair<Double, Integer>>> assignments = generateClusterAssignments(pointLabels, outlierScores);

    // The clustering info records, per cluster id, the size of that cluster's assignment list.
    Map<Integer, MutableLong> clusterSizes = new HashMap<>();
    assignments.forEach((clusterId, members) -> clusterSizes.put(clusterId, new MutableLong(members.size())));
    ImmutableOutputInfo<ClusterID> clusteringInfo = new ImmutableClusteringInfo(clusterSizes);

    List<ClusterExemplar> exemplars = computeExemplars(vectors, assignments, distType);
    // Outlier score reported for points that are predicted as noise.
    double noiseOutlierScore = getNoisePointsOutlierScore(assignments);
    logger.log(Level.INFO, "Hdbscan is done.");

    ModelProvenance modelProv = new ModelProvenance(HdbscanModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProv, runProvenance);
    return new HdbscanModel("hdbscan-model", modelProv, featureIDMap, clusteringInfo, pointLabels, outlierScores, exemplars, distType, noiseOutlierScore);
}
Also used: ClusterID (org.tribuo.clustering.ClusterID), HashMap (java.util.HashMap), ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap), SGDVector (org.tribuo.math.la.SGDVector), ArrayList (java.util.ArrayList), List (java.util.List), TrainerProvenance (org.tribuo.provenance.TrainerProvenance), ModelProvenance (org.tribuo.provenance.ModelProvenance), MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong), DenseVector (org.tribuo.math.la.DenseVector), ImmutableClusteringInfo (org.tribuo.clustering.ImmutableClusteringInfo)

Example 3 with MutableLong

Use of com.oracle.labs.mlrg.olcut.util.MutableLong in project Tribuo by Oracle.

In class PairDistribution, method constructFromMap.

/**
 * Builds a pair distribution from a table of joint counts, deriving both marginals from it.
 * <p>
 * Every joint count is added into {@code aCount} (keyed by the pair's first element) and into
 * {@code bCount} (keyed by the pair's second element). Missing keys start at zero. The two
 * marginal maps are therefore mutated in place; pass empty maps unless the joint counts should
 * be added on top of existing values.
 * @param jointCount The joint counts, keyed by (a, b) pair.
 * @param aCount Map which receives the marginal counts of the first variable.
 * @param bCount Map which receives the marginal counts of the second variable.
 * @param <T1> The type of the first variable.
 * @param <T2> The type of the second variable.
 * @return A pair distribution whose total count is the sum of all joint counts.
 */
public static <T1, T2> PairDistribution<T1, T2> constructFromMap(Map<CachedPair<T1, T2>, MutableLong> jointCount, Map<T1, MutableLong> aCount, Map<T2, MutableLong> bCount) {
    long total = 0L;
    for (Entry<CachedPair<T1, T2>, MutableLong> cell : jointCount.entrySet()) {
        CachedPair<T1, T2> key = cell.getKey();
        long cellCount = cell.getValue().longValue();
        T1 first = key.getA();
        T2 second = key.getB();
        aCount.computeIfAbsent(first, ignored -> new MutableLong()).increment(cellCount);
        bCount.computeIfAbsent(second, ignored -> new MutableLong()).increment(cellCount);
        total += cellCount;
    }
    return new PairDistribution<>(total, jointCount, aCount, bCount);
}
Also used: MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong)

Example 4 with MutableLong

Use of com.oracle.labs.mlrg.olcut.util.MutableLong in project Tribuo by Oracle.

In class TripleDistribution, method constructFromMap.

/**
 * Builds a triple distribution by marginalising the supplied joint counts.
 * <p>
 * Every joint count is added into each of the three pairwise maps ({@code abCount},
 * {@code acCount}, {@code bcCount}) and each of the three single-variable maps ({@code aCount},
 * {@code bCount}, {@code cCount}). Missing keys start at zero. All six maps are mutated in place
 * and are expected to be empty on entry; callers may presize them.
 * @param jointCount The joint counts, keyed by (a, b, c) triple.
 * @param abCount An empty map which receives the AB marginal counts.
 * @param acCount An empty map which receives the AC marginal counts.
 * @param bcCount An empty map which receives the BC marginal counts.
 * @param aCount An empty map which receives the A marginal counts.
 * @param bCount An empty map which receives the B marginal counts.
 * @param cCount An empty map which receives the C marginal counts.
 * @param <T1> The type of A.
 * @param <T2> The type of B.
 * @param <T3> The type of C.
 * @return A triple distribution whose total count is the sum of all joint counts.
 */
public static <T1, T2, T3> TripleDistribution<T1, T2, T3> constructFromMap(Map<CachedTriple<T1, T2, T3>, MutableLong> jointCount, Map<CachedPair<T1, T2>, MutableLong> abCount, Map<CachedPair<T1, T3>, MutableLong> acCount, Map<CachedPair<T2, T3>, MutableLong> bcCount, Map<T1, MutableLong> aCount, Map<T2, MutableLong> bCount, Map<T3, MutableLong> cCount) {
    long total = 0L;
    for (Entry<CachedTriple<T1, T2, T3>, MutableLong> cell : jointCount.entrySet()) {
        CachedTriple<T1, T2, T3> key = cell.getKey();
        long cellCount = cell.getValue().longValue();
        CachedPair<T1, T2> abKey = key.getAB();
        CachedPair<T1, T3> acKey = key.getAC();
        CachedPair<T2, T3> bcKey = key.getBC();
        T1 aKey = key.getA();
        T2 bKey = key.getB();
        T3 cKey = key.getC();
        total += cellCount;
        // Pairwise marginals.
        abCount.computeIfAbsent(abKey, ignored -> new MutableLong()).increment(cellCount);
        acCount.computeIfAbsent(acKey, ignored -> new MutableLong()).increment(cellCount);
        bcCount.computeIfAbsent(bcKey, ignored -> new MutableLong()).increment(cellCount);
        // Single-variable marginals.
        aCount.computeIfAbsent(aKey, ignored -> new MutableLong()).increment(cellCount);
        bCount.computeIfAbsent(bKey, ignored -> new MutableLong()).increment(cellCount);
        cCount.computeIfAbsent(cKey, ignored -> new MutableLong()).increment(cellCount);
    }
    return new TripleDistribution<>(total, jointCount, abCount, acCount, bcCount, aCount, bCount, cCount);
}
Also used: MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong)

Example 5 with MutableLong

Use of com.oracle.labs.mlrg.olcut.util.MutableLong in project Tribuo by Oracle.

In class MutableRegressionInfo, method observe.

/**
 * Records one observed {@link Regressor}, updating per-dimension statistics.
 * <p>
 * The sentinel unknown regressor only bumps the unknown counter. After the first real
 * observation, every regressor must have exactly the same dimension names as those already seen.
 * Otherwise an {@link IllegalArgumentException} is thrown before any statistic is touched. For
 * each dimension, the method updates the running min and max. It then updates the count, mean and
 * sum of squared deviations using Welford's online algorithm.
 * @param output The regressor to record.
 * @throws IllegalArgumentException If the regressor's dimensions differ from those already observed.
 */
@Override
public void observe(Regressor output) {
    if (output == RegressionFactory.UNKNOWN_REGRESSOR) {
        unknownCount++;
        return;
    }
    if (overallCount != 0) {
        String[] dimNames = output.getNames();
        int expected = countMap.size();
        if (dimNames.length != expected) {
            throw new IllegalArgumentException("Expected this Regressor to contain " + expected + " dimensions, found " + dimNames.length);
        }
        for (String dimName : dimNames) {
            if (!countMap.containsKey(dimName)) {
                throw new IllegalArgumentException("Regressor contains unexpected dimension named '" + dimName + "'");
            }
        }
    }
    for (Regressor.DimensionTuple dim : output) {
        String dimName = dim.getName();
        double x = dim.getValue();
        // Running extremes; on a tie the newly supplied MutableDouble is kept.
        minMap.merge(dimName, new MutableDouble(x), (seen, fresh) -> seen.doubleValue() < fresh.doubleValue() ? seen : fresh);
        maxMap.merge(dimName, new MutableDouble(x), (seen, fresh) -> seen.doubleValue() > fresh.doubleValue() ? seen : fresh);
        // Welford update: count first, then mean, then the sum of squared deviations.
        MutableLong n = countMap.computeIfAbsent(dimName, ignored -> new MutableLong());
        n.increment();
        MutableDouble mean = meanMap.computeIfAbsent(dimName, ignored -> new MutableDouble());
        double deltaOld = x - mean.doubleValue();
        mean.increment(deltaOld / n.longValue());
        double deltaNew = x - mean.doubleValue();
        MutableDouble sumSquares = sumSquaresMap.computeIfAbsent(dimName, ignored -> new MutableDouble());
        sumSquares.increment(deltaOld * deltaNew);
    }
    overallCount++;
}
Also used: MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong), MutableDouble (com.oracle.labs.mlrg.olcut.util.MutableDouble)

Aggregations

MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong) 15, HashMap (java.util.HashMap) 6, ArrayList (java.util.ArrayList) 3, LinkedHashMap (java.util.LinkedHashMap) 3, List (java.util.List) 3, Map (java.util.Map) 2, ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap) 2, ClusterID (org.tribuo.clustering.ClusterID) 2, ImmutableClusteringInfo (org.tribuo.clustering.ImmutableClusteringInfo) 2, DenseVector (org.tribuo.math.la.DenseVector) 2, SGDVector (org.tribuo.math.la.SGDVector) 2, ModelProvenance (org.tribuo.provenance.ModelProvenance) 2, TrainerProvenance (org.tribuo.provenance.TrainerProvenance) 2, MutableDouble (com.oracle.labs.mlrg.olcut.util.MutableDouble) 1, LinkedHashSet (java.util.LinkedHashSet) 1, LinkedList (java.util.LinkedList) 1, Queue (java.util.Queue) 1, SplittableRandom (java.util.SplittableRandom) 1, ExecutionException (java.util.concurrent.ExecutionException) 1, ForkJoinPool (java.util.concurrent.ForkJoinPool) 1