Search in sources :

Example 1 with MutableLong

Use of MutableLong in project Tribuo by Oracle.

In class MockOutputInfo, method observe.

/**
 * Records one observed output (fragment of MockOutputInfo.observe as shown in this listing).
 * <p>
 * An output that is the same reference as {@code MockOutputFactory.UNKNOWN_TEST_OUTPUT} takes the
 * first branch, which is empty in this listing. For any other output, the code does three things:
 * <ul>
 *   <li>it fetches the label's count holder, creating it on first sight;</li>
 *   <li>it registers a {@code MockOutput} for the label;</li>
 *   <li>for a label not yet in {@code labelIDMap}, it records the current {@code labelCounter}
 *       as the label's id in both {@code labelIDMap} and {@code idLabelMap}.</li>
 * </ul>
 * <p>
 * NOTE(review): this snippet is truncated by the source listing. The closing braces are missing,
 * {@code value} is fetched but never incremented in the visible lines, and {@code labelCounter}
 * is never advanced. As shown, every new label would receive the same id, so confirm against the
 * full Tribuo source before relying on this fragment.
 * @param output The output to observe.
 */
public void observe(MockOutput output) {
    if (output == MockOutputFactory.UNKNOWN_TEST_OUTPUT) {
        // NOTE(review): empty in this listing; the full source presumably records the unknown output here — confirm.
    } else {
        String label = output.label;
        // Per-label count holder, created the first time the label is seen.
        MutableLong value = labelCounts.computeIfAbsent(label, k -> new MutableLong());
        labels.computeIfAbsent(label, MockOutput::new);
        if (!labelIDMap.containsKey(label)) {
            // New label: map label -> id and id -> label using the current counter value.
            labelIDMap.put(label, labelCounter);
            idLabelMap.put(labelCounter, label);
Also used: MutableLong

Example 2 with MutableLong

Use of MutableLong in project Tribuo by Oracle.

In class HdbscanTrainer, method train.

/**
 * Trains an HDBSCAN* clustering model (fragment of HdbscanTrainer.train as shown in this listing).
 * <p>
 * The method proceeds in these steps:
 * <ol>
 *   <li>Convert each example to a vector: dense when {@code example.size()} equals the feature
 *       map size, otherwise sparse.</li>
 *   <li>Compute core distances, then an extended minimum spanning tree.</li>
 *   <li>Build the cluster hierarchy and select the prominent clusters.</li>
 *   <li>Compute outlier scores.</li>
 *   <li>Derive per-cluster counts, cluster exemplars and the outlier score used for noise
 *       points.</li>
 *   <li>Return an {@code HdbscanModel}.</li>
 * </ol>
 * <p>
 * NOTE(review): this snippet is truncated by the source listing. Several closing braces are
 * missing, {@code n} is never advanced in the visible loop, and the {@code ModelProvenance} call
 * contains {@code ,,}, meaning an argument was dropped. Restore these from the full Tribuo source.
 * @param examples The dataset to cluster.
 * @param runProvenance Provenance information for this training run.
 * @return The trained HDBSCAN model.
 */
public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
    // Take the trainer provenance while holding this trainer's lock.
    // NOTE(review): the original comment here said "increment the invocation count", but no
    // increment (nor the closing brace of the synchronized block) is visible in this listing — confirm.
    TrainerProvenance trainerProvenance;
    synchronized (this) {
        trainerProvenance = getProvenance();
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    SGDVector[] data = new SGDVector[examples.size()];
    int n = 0;
    for (Example<ClusterID> example : examples) {
        // Dense storage when the example's size matches the feature map size, sparse otherwise.
        if (example.size() == featureMap.size()) {
            data[n] = DenseVector.createDenseVector(example, featureMap, false);
        } else {
            data[n] = SparseVector.createSparseVector(example, featureMap, false);
    // Core distances use the trainer's k and neighboursQueryFactory (fields not shown here).
    DenseVector coreDistances = calculateCoreDistances(data, k, neighboursQueryFactory);
    ExtendedMinimumSpanningTree emst = constructEMST(data, coreDistances, distType);
    // The levels at which each point becomes noise
    double[] pointNoiseLevels = new double[data.length];
    // The last label of each point before becoming noise
    int[] pointLastClusters = new int[data.length];
    // The HDBSCAN* hierarchy
    Map<Integer, int[]> hierarchy = new HashMap<>();
    // pointNoiseLevels, pointLastClusters and hierarchy are passed in to be filled by this call (presumably — confirm).
    List<HdbscanCluster> clusters = computeHierarchyAndClusterTree(emst, minClusterSize, pointNoiseLevels, pointLastClusters, hierarchy);
    List<Integer> clusterLabels = findProminentClusters(hierarchy, clusters, data.length);
    DenseVector outlierScoresVector = calculateOutlierScores(pointNoiseLevels, pointLastClusters, clusters);
    Map<Integer, List<Pair<Double, Integer>>> clusterAssignments = generateClusterAssignments(clusterLabels, outlierScoresVector);
    // Use the cluster assignments to establish the clustering info:
    // each cluster id maps to the number of entries assigned to it.
    Map<Integer, MutableLong> counts = new HashMap<>();
    for (Entry<Integer, List<Pair<Double, Integer>>> e : clusterAssignments.entrySet()) {
        counts.put(e.getKey(), new MutableLong(e.getValue().size()));
    ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);
    // Compute the cluster exemplars.
    List<ClusterExemplar> clusterExemplars = computeExemplars(data, clusterAssignments, distType);
    // Get the outlier score value for points that are predicted as noise points.
    double noisePointsOutlierScore = getNoisePointsOutlierScore(clusterAssignments);
    logger.log(Level.INFO, "Hdbscan is done.");
    // NOTE(review): ",," below is a syntax error introduced by the listing; an argument between the
    // class name and examples.getProvenance() was dropped — restore it from the full source.
    ModelProvenance provenance = new ModelProvenance(HdbscanModel.class.getName(),, examples.getProvenance(), trainerProvenance, runProvenance);
    return new HdbscanModel("hdbscan-model", provenance, featureMap, outputMap, clusterLabels, outlierScoresVector, clusterExemplars, distType, noisePointsOutlierScore);
Also used: ClusterID (org.tribuo.clustering.ClusterID), HashMap (java.util.HashMap), ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap), SGDVector, ArrayList (java.util.ArrayList), List (java.util.List), TrainerProvenance (org.tribuo.provenance.TrainerProvenance), ModelProvenance (org.tribuo.provenance.ModelProvenance), MutableLong, DenseVector, ImmutableClusteringInfo (org.tribuo.clustering.ImmutableClusteringInfo)

Example 3 with MutableLong

Use of MutableLong in project Tribuo by Oracle.

In class PairDistribution, method constructFromMap.

/**
 * Constructs a joint distribution from the counts.
 * @param jointCount The joint count.
 * @param aCount The first marginal count.
 * @param bCount The second marginal count.
 * @param <T1> The type of the first variable.
 * @param <T2> The type of the second variable.
 * @return A pair distribution.
 */
// Fragment of PairDistribution.constructFromMap as shown in this listing.
// Sums every joint count into `count`. It mutates aCount and bCount, inserting an entry for each
// first/second component seen in jointCount, and passes all three maps to the PairDistribution
// constructor.
// NOTE(review): the listing is truncated. The loop's closing brace is missing, and the fetched
// marginal holders (curACount, curBCount) are never updated in the visible lines, so as shown the
// marginals would not accumulate curCount. Confirm the increments against the full Tribuo source.
public static <T1, T2> PairDistribution<T1, T2> constructFromMap(Map<CachedPair<T1, T2>, MutableLong> jointCount, Map<T1, MutableLong> aCount, Map<T2, MutableLong> bCount) {
    // Total of all joint counts.
    long count = 0L;
    for (Entry<CachedPair<T1, T2>, MutableLong> e : jointCount.entrySet()) {
        CachedPair<T1, T2> pair = e.getKey();
        long curCount = e.getValue().longValue();
        T1 a = pair.getA();
        T2 b = pair.getB();
        // Marginal holders for each component, created on first sight.
        MutableLong curACount = aCount.computeIfAbsent(a, k -> new MutableLong());
        MutableLong curBCount = bCount.computeIfAbsent(b, k -> new MutableLong());
        count += curCount;
    return new PairDistribution<>(count, jointCount, aCount, bCount);
Also used: MutableLong

Example 4 with MutableLong

Use of MutableLong in project Tribuo by Oracle.

In class TripleDistribution, method constructFromMap.

/**
 * Constructs a TripleDistribution by marginalising the supplied joint distribution.
 * <p>
 * Sizes are used to preallocate the HashMaps.
 * @param jointCount The joint distribution.
 * @param abCount An empty hashmap for AB.
 * @param acCount An empty hashmap for AC.
 * @param bcCount An empty hashmap for BC.
 * @param aCount An empty hashmap for A.
 * @param bCount An empty hashmap for B.
 * @param cCount An empty hashmap for C.
 * @param <T1> The type of A.
 * @param <T2> The type of B.
 * @param <T3> The type of C.
 * @return A TripleDistribution.
 */
// Fragment of TripleDistribution.constructFromMap as shown in this listing.
// Sums every joint count into `count`. It mutates the six supplied maps, inserting an entry for
// every AB/AC/BC pair and A/B/C component seen in jointCount, and passes all seven maps to the
// TripleDistribution constructor.
// NOTE(review): the listing is truncated. The loop's closing brace is missing, and the fetched
// marginal holders are never updated in the visible lines, so as shown the marginals would not
// accumulate curCount. Confirm the increments against the full Tribuo source.
public static <T1, T2, T3> TripleDistribution<T1, T2, T3> constructFromMap(Map<CachedTriple<T1, T2, T3>, MutableLong> jointCount, Map<CachedPair<T1, T2>, MutableLong> abCount, Map<CachedPair<T1, T3>, MutableLong> acCount, Map<CachedPair<T2, T3>, MutableLong> bcCount, Map<T1, MutableLong> aCount, Map<T2, MutableLong> bCount, Map<T3, MutableLong> cCount) {
    // Total of all joint counts.
    long count = 0L;
    for (Entry<CachedTriple<T1, T2, T3>, MutableLong> e : jointCount.entrySet()) {
        CachedTriple<T1, T2, T3> abc = e.getKey();
        long curCount = e.getValue().longValue();
        // Pairwise and single-variable projections of the current triple.
        CachedPair<T1, T2> ab = abc.getAB();
        CachedPair<T1, T3> ac = abc.getAC();
        CachedPair<T2, T3> bc = abc.getBC();
        T1 a = abc.getA();
        T2 b = abc.getB();
        T3 c = abc.getC();
        count += curCount;
        // Marginal holders for each projection, created on first sight.
        MutableLong abCurCount = abCount.computeIfAbsent(ab, k -> new MutableLong());
        MutableLong acCurCount = acCount.computeIfAbsent(ac, k -> new MutableLong());
        MutableLong bcCurCount = bcCount.computeIfAbsent(bc, k -> new MutableLong());
        MutableLong aCurCount = aCount.computeIfAbsent(a, k -> new MutableLong());
        MutableLong bCurCount = bCount.computeIfAbsent(b, k -> new MutableLong());
        MutableLong cCurCount = cCount.computeIfAbsent(c, k -> new MutableLong());
    return new TripleDistribution<>(count, jointCount, abCount, acCount, bcCount, aCount, bCount, cCount);
Also used: MutableLong

Example 5 with MutableLong

Use of MutableLong in project Tribuo by Oracle.

In class MutableRegressionInfo, method observe.

/**
 * Records one observed regressor (fragment of MutableRegressionInfo.observe as shown in this listing).
 * <p>
 * An output that is the same reference as {@code RegressionFactory.UNKNOWN_REGRESSOR} takes the
 * first branch, which is empty in this listing. Otherwise, if {@code overallCount} is non-zero, the
 * regressor's dimension names must match the dimensions already in {@code countMap}. Then, for each
 * dimension, the method updates four things:
 * <ul>
 *   <li>the min and max;</li>
 *   <li>the count holder;</li>
 *   <li>the running mean;</li>
 *   <li>the running sum of squared deviations, using a Welford-style online update.</li>
 * </ul>
 * <p>
 * NOTE(review): the listing is truncated. Closing braces are missing, no
 * {@code countValue.increment()} is visible before the mean update, and {@code overallCount} is
 * never updated. If the count were still zero at the division, the mean update would divide by
 * zero and produce Infinity/NaN. Confirm against the full Tribuo source.
 * @param output The regressor to observe.
 * @throws IllegalArgumentException if the regressor's dimensions differ in number or name from
 *         those already observed.
 */
public void observe(Regressor output) {
    if (output == RegressionFactory.UNKNOWN_REGRESSOR) {
        // NOTE(review): empty in this listing; confirm whether the full source records unknown outputs here.
    } else {
        if (overallCount != 0) {
            // Validate that the dimensions in this regressor are the same as the ones already observed.
            String[] names = output.getNames();
            if (names.length != countMap.size()) {
                throw new IllegalArgumentException("Expected this Regressor to contain " + countMap.size() + " dimensions, found " + names.length);
            for (String name : names) {
                if (!countMap.containsKey(name)) {
                    throw new IllegalArgumentException("Regressor contains unexpected dimension named '" + name + "'");
        for (Regressor.DimensionTuple r : output) {
            String name = r.getName();
            double value = r.getValue();
            // Update max and min: merge keeps whichever holder has the smaller (resp. larger) value.
            minMap.merge(name, new MutableDouble(value), (a, b) -> a.doubleValue() < b.doubleValue() ? a : b);
            maxMap.merge(name, new MutableDouble(value), (a, b) -> a.doubleValue() > b.doubleValue() ? a : b);
            // Update count (holder is created on first sight of this dimension)
            MutableLong countValue = countMap.computeIfAbsent(name, k -> new MutableLong());
            // Update mean: mean += (value - oldMean) / count
            MutableDouble meanValue = meanMap.computeIfAbsent(name, k -> new MutableDouble());
            double delta = value - meanValue.doubleValue();
            meanValue.increment(delta / countValue.longValue());
            // Update running sum of squares: uses the deviation from the old mean (delta)
            // times the deviation from the new mean (delta2), as in Welford's method.
            double delta2 = value - meanValue.doubleValue();
            MutableDouble sumSquaresValue = sumSquaresMap.computeIfAbsent(name, k -> new MutableDouble());
            sumSquaresValue.increment(delta * delta2);
Also used: MutableLong, MutableDouble


MutableLong — classes also used alongside it (count): HashMap (java.util.HashMap) 6, ArrayList (java.util.ArrayList) 3, LinkedHashMap (java.util.LinkedHashMap) 3, List (java.util.List) 3, Map (java.util.Map) 2, ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap) 2, ClusterID (org.tribuo.clustering.ClusterID) 2, ImmutableClusteringInfo (org.tribuo.clustering.ImmutableClusteringInfo) 2, DenseVector, SGDVector, ModelProvenance (org.tribuo.provenance.ModelProvenance) 2, TrainerProvenance (org.tribuo.provenance.TrainerProvenance) 2, MutableDouble, LinkedHashSet (java.util.LinkedHashSet) 1, LinkedList (java.util.LinkedList) 1, Queue (java.util.Queue) 1, SplittableRandom (java.util.SplittableRandom) 1, ExecutionException (java.util.concurrent.ExecutionException) 1, ForkJoinPool (java.util.concurrent.ForkJoinPool) 1