Example 1 with ImmutableClusteringInfo

Use of org.tribuo.clustering.ImmutableClusteringInfo in project tribuo by oracle.

From the class HdbscanTrainer, the train method:

@Override
public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
    // increment the invocation count.
    TrainerProvenance trainerProvenance;
    synchronized (this) {
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    SGDVector[] data = new SGDVector[examples.size()];
    int n = 0;
    for (Example<ClusterID> example : examples) {
        if (example.size() == featureMap.size()) {
            data[n] = DenseVector.createDenseVector(example, featureMap, false);
        } else {
            data[n] = SparseVector.createSparseVector(example, featureMap, false);
        }
        n++;
    }
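    // Core distances (distance to each point's k-th nearest neighbour) feed the mutual
    // reachability graph from which the extended minimum spanning tree is built below.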
    DenseVector coreDistances = calculateCoreDistances(data, k, neighboursQueryFactory);
    ExtendedMinimumSpanningTree emst = constructEMST(data, coreDistances, distType);
    // The levels at which each point becomes noise
    double[] pointNoiseLevels = new double[data.length];
    // The last label of each point before becoming noise
    int[] pointLastClusters = new int[data.length];
    // The HDBSCAN* hierarchy
    Map<Integer, int[]> hierarchy = new HashMap<>();
    List<HdbscanCluster> clusters = computeHierarchyAndClusterTree(emst, minClusterSize, pointNoiseLevels, pointLastClusters, hierarchy);
    propagateTree(clusters);
    List<Integer> clusterLabels = findProminentClusters(hierarchy, clusters, data.length);
    DenseVector outlierScoresVector = calculateOutlierScores(pointNoiseLevels, pointLastClusters, clusters);
    Map<Integer, List<Pair<Double, Integer>>> clusterAssignments = generateClusterAssignments(clusterLabels, outlierScoresVector);
    // Use the cluster assignments to establish the clustering info
    Map<Integer, MutableLong> counts = new HashMap<>();
    for (Entry<Integer, List<Pair<Double, Integer>>> e : clusterAssignments.entrySet()) {
        counts.put(e.getKey(), new MutableLong(e.getValue().size()));
    }
    ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);
    // Compute the cluster exemplars.
    List<ClusterExemplar> clusterExemplars = computeExemplars(data, clusterAssignments, distType);
    // Get the outlier score value for points that are predicted as noise points.
    double noisePointsOutlierScore = getNoisePointsOutlierScore(clusterAssignments);
    logger.log(Level.INFO, "Hdbscan is done.");
    ModelProvenance provenance = new ModelProvenance(HdbscanModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    return new HdbscanModel("hdbscan-model", provenance, featureMap, outputMap, clusterLabels, outlierScoresVector, clusterExemplars, distType, noisePointsOutlierScore);
}
Also used: ClusterID (org.tribuo.clustering.ClusterID), HashMap (java.util.HashMap), ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap), SGDVector (org.tribuo.math.la.SGDVector), ArrayList (java.util.ArrayList), List (java.util.List), TrainerProvenance (org.tribuo.provenance.TrainerProvenance), ModelProvenance (org.tribuo.provenance.ModelProvenance), MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong), DenseVector (org.tribuo.math.la.DenseVector), ImmutableClusteringInfo (org.tribuo.clustering.ImmutableClusteringInfo)
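
For context, a minimal usage sketch (not part of the indexed source) showing how this train method is typically invoked. It assumes the single-argument HdbscanTrainer(minClusterSize) constructor, the ClusteringDataGenerator helper from the Tribuo clustering tutorial, and the getClusterLabels() accessor; other hyperparameters keep their defaults.

import org.tribuo.Dataset;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.example.ClusteringDataGenerator;
import org.tribuo.clustering.hdbscan.HdbscanModel;
import org.tribuo.clustering.hdbscan.HdbscanTrainer;

// Synthetic Gaussian clusters (size, seed); stands in for any Dataset<ClusterID>.
Dataset<ClusterID> data = ClusteringDataGenerator.gaussianClusters(500, 1L);
// minClusterSize = 5; k, distance type and thread count are assumed to keep their defaults.
HdbscanTrainer trainer = new HdbscanTrainer(5);
HdbscanModel model = trainer.train(data);
// Cluster labels and outlier scores computed by the train method shown above.
System.out.println(model.getClusterLabels().subList(0, 10));

Predictions on new points reuse the cluster exemplars computed at the end of training.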

Example 2 with ImmutableClusteringInfo

Use of org.tribuo.clustering.ImmutableClusteringInfo in project tribuo by oracle.

From the class KMeansTrainer, the train method:

@Override
public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    // Creates a new local RNG and adds one to the invocation count.
    TrainerProvenance trainerProvenance;
    SplittableRandom localRNG;
    synchronized (this) {
        if (invocationCount != INCREMENT_INVOCATION_COUNT) {
            setInvocationCount(invocationCount);
        }
        localRNG = rng.split();
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    int[] oldCentre = new int[examples.size()];
    SGDVector[] data = new SGDVector[examples.size()];
    double[] weights = new double[examples.size()];
    int n = 0;
    for (Example<ClusterID> example : examples) {
        weights[n] = example.getWeight();
        if (example.size() == featureMap.size()) {
            data[n] = DenseVector.createDenseVector(example, featureMap, false);
        } else {
            data[n] = SparseVector.createSparseVector(example, featureMap, false);
        }
        oldCentre[n] = -1;
        n++;
    }
    DenseVector[] centroidVectors;
    switch(initialisationType) {
        case RANDOM:
            centroidVectors = initialiseRandomCentroids(centroids, featureMap, localRNG);
            break;
        case PLUSPLUS:
            centroidVectors = initialisePlusPlusCentroids(centroids, data, localRNG, distType);
            break;
        default:
            throw new IllegalStateException("Unknown initialisation " + initialisationType);
    }
    Map<Integer, List<Integer>> clusterAssignments = new HashMap<>();
    boolean parallel = numThreads > 1;
    for (int i = 0; i < centroids; i++) {
        clusterAssignments.put(i, parallel ? Collections.synchronizedList(new ArrayList<>()) : new ArrayList<>());
    }
    AtomicInteger changeCounter = new AtomicInteger(0);
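    // E-step assignment function: find the closest centroid for one vector,
    // record the assignment, and count how many vectors changed cluster.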
    Consumer<IntAndVector> eStepFunc = (IntAndVector e) -> {
        double minDist = Double.POSITIVE_INFINITY;
        int clusterID = -1;
        int id = e.idx;
        SGDVector vector = e.vector;
        for (int j = 0; j < centroids; j++) {
            DenseVector cluster = centroidVectors[j];
            double distance = DistanceType.getDistance(cluster, vector, distType);
            if (distance < minDist) {
                minDist = distance;
                clusterID = j;
            }
        }
        clusterAssignments.get(clusterID).add(id);
        if (oldCentre[id] != clusterID) {
            // Changed the centroid of this vector.
            oldCentre[id] = clusterID;
            changeCounter.incrementAndGet();
        }
    };
    boolean converged = false;
    ForkJoinPool fjp = null;
    try {
        if (parallel) {
            if (System.getSecurityManager() == null) {
                fjp = new ForkJoinPool(numThreads);
            } else {
                fjp = new ForkJoinPool(numThreads, THREAD_FACTORY, null, false);
            }
        }
        for (int i = 0; (i < iterations) && !converged; i++) {
            logger.log(Level.FINE, "Beginning iteration " + i);
            changeCounter.set(0);
            for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
                e.getValue().clear();
            }
            // E step
            Stream<SGDVector> vecStream = Arrays.stream(data);
            Stream<Integer> intStream = IntStream.range(0, data.length).boxed();
            Stream<IntAndVector> zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
            if (parallel) {
                Stream<IntAndVector> parallelZipStream = StreamUtil.boundParallelism(zipStream.parallel());
                try {
                    fjp.submit(() -> parallelZipStream.forEach(eStepFunc)).get();
                } catch (InterruptedException | ExecutionException e) {
                    throw new RuntimeException("Parallel execution failed", e);
                }
            } else {
                zipStream.forEach(eStepFunc);
            }
            logger.log(Level.FINE, "E step completed. " + changeCounter.get() + " words updated.");
            mStep(fjp, centroidVectors, clusterAssignments, data, weights);
            logger.log(Level.INFO, "Iteration " + i + " completed. " + changeCounter.get() + " examples updated.");
            if (changeCounter.get() == 0) {
                converged = true;
                logger.log(Level.INFO, "K-Means converged at iteration " + i);
            }
        }
    } finally {
        if (fjp != null) {
            fjp.shutdown();
        }
    }
    Map<Integer, MutableLong> counts = new HashMap<>();
    for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
        counts.put(e.getKey(), new MutableLong(e.getValue().size()));
    }
    ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);
    ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    return new KMeansModel("k-means-model", provenance, featureMap, outputMap, centroidVectors, distType);
}
Also used: ClusterID (org.tribuo.clustering.ClusterID), HashMap (java.util.HashMap), ArrayList (java.util.ArrayList), ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap), SGDVector (org.tribuo.math.la.SGDVector), List (java.util.List), ExecutionException (java.util.concurrent.ExecutionException), TrainerProvenance (org.tribuo.provenance.TrainerProvenance), ModelProvenance (org.tribuo.provenance.ModelProvenance), AtomicInteger (java.util.concurrent.atomic.AtomicInteger), MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong), SplittableRandom (java.util.SplittableRandom), DenseVector (org.tribuo.math.la.DenseVector), ForkJoinPool (java.util.concurrent.ForkJoinPool), ImmutableClusteringInfo (org.tribuo.clustering.ImmutableClusteringInfo)
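
A comparable usage sketch (again not part of the indexed source). The constructor signature and the distance enum vary across Tribuo versions, so the five-argument (centroids, iterations, distType, numThreads, seed) form, the DistanceType.L2 constant, and its package location are assumptions here.

import org.tribuo.Dataset;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.example.ClusteringDataGenerator;
import org.tribuo.clustering.kmeans.KMeansModel;
import org.tribuo.clustering.kmeans.KMeansTrainer;
// Package location of DistanceType is assumed; it differs between Tribuo releases.
import org.tribuo.math.distance.DistanceType;

Dataset<ClusterID> data = ClusteringDataGenerator.gaussianClusters(500, 1L);
// 5 centroids, 10 iterations, Euclidean (L2) distance, single-threaded, seed 1.
KMeansTrainer trainer = new KMeansTrainer(5, 10, DistanceType.L2, 1, 1L);
KMeansModel model = trainer.train(data);
// The ImmutableClusteringInfo built at the end of train() is the model's output domain.
System.out.println(model.getOutputIDInfo().toReadableString());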

Aggregations

MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong): 2
ArrayList (java.util.ArrayList): 2
HashMap (java.util.HashMap): 2
List (java.util.List): 2
ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap): 2
ClusterID (org.tribuo.clustering.ClusterID): 2
ImmutableClusteringInfo (org.tribuo.clustering.ImmutableClusteringInfo): 2
DenseVector (org.tribuo.math.la.DenseVector): 2
SGDVector (org.tribuo.math.la.SGDVector): 2
ModelProvenance (org.tribuo.provenance.ModelProvenance): 2
TrainerProvenance (org.tribuo.provenance.TrainerProvenance): 2
SplittableRandom (java.util.SplittableRandom): 1
ExecutionException (java.util.concurrent.ExecutionException): 1
ForkJoinPool (java.util.concurrent.ForkJoinPool): 1
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1
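
Both examples end the same way: they count the examples assigned to each cluster id and wrap the counts in an ImmutableClusteringInfo, which then serves as the model's ImmutableOutputInfo<ClusterID>. A standalone sketch of that pattern, with made-up cluster ids and counts:

import java.util.HashMap;
import java.util.Map;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ImmutableClusteringInfo;

// Hypothetical cluster sizes: cluster 0 has 120 members, cluster 1 has 80.
Map<Integer, MutableLong> counts = new HashMap<>();
counts.put(0, new MutableLong(120));
counts.put(1, new MutableLong(80));
// The same constructor call used by both trainers above.
ImmutableOutputInfo<ClusterID> outputInfo = new ImmutableClusteringInfo(counts);
System.out.println(outputInfo.toReadableString());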