Use of org.tribuo.clustering.ImmutableClusteringInfo in project tribuo by oracle.
The class HdbscanTrainer, method train:
@Override
public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
    // increment the invocation count.
    TrainerProvenance trainerProvenance;
    synchronized (this) {
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    SGDVector[] data = new SGDVector[examples.size()];
    int n = 0;
    for (Example<ClusterID> example : examples) {
        if (example.size() == featureMap.size()) {
            data[n] = DenseVector.createDenseVector(example, featureMap, false);
        } else {
            data[n] = SparseVector.createSparseVector(example, featureMap, false);
        }
        n++;
    }
    DenseVector coreDistances = calculateCoreDistances(data, k, neighboursQueryFactory);
    ExtendedMinimumSpanningTree emst = constructEMST(data, coreDistances, distType);
    // The levels at which each point becomes noise
    double[] pointNoiseLevels = new double[data.length];
    // The last label of each point before becoming noise
    int[] pointLastClusters = new int[data.length];
    // The HDBSCAN* hierarchy
    Map<Integer, int[]> hierarchy = new HashMap<>();
    List<HdbscanCluster> clusters = computeHierarchyAndClusterTree(emst, minClusterSize, pointNoiseLevels, pointLastClusters, hierarchy);
    propagateTree(clusters);
    List<Integer> clusterLabels = findProminentClusters(hierarchy, clusters, data.length);
    DenseVector outlierScoresVector = calculateOutlierScores(pointNoiseLevels, pointLastClusters, clusters);
    Map<Integer, List<Pair<Double, Integer>>> clusterAssignments = generateClusterAssignments(clusterLabels, outlierScoresVector);
    // Use the cluster assignments to establish the clustering info
    Map<Integer, MutableLong> counts = new HashMap<>();
    for (Entry<Integer, List<Pair<Double, Integer>>> e : clusterAssignments.entrySet()) {
        counts.put(e.getKey(), new MutableLong(e.getValue().size()));
    }
    ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);
    // Compute the cluster exemplars.
    List<ClusterExemplar> clusterExemplars = computeExemplars(data, clusterAssignments, distType);
    // Get the outlier score value for points that are predicted as noise points.
    double noisePointsOutlierScore = getNoisePointsOutlierScore(clusterAssignments);
    logger.log(Level.INFO, "Hdbscan is done.");
    ModelProvenance provenance = new ModelProvenance(HdbscanModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    return new HdbscanModel("hdbscan-model", provenance, featureMap, outputMap, clusterLabels, outlierScoresVector, clusterExemplars, distType, noisePointsOutlierScore);
}
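For context, a minimal sketch of how this train method is typically reached. It is illustrative only: the single-argument HdbscanTrainer constructor (k nearest neighbours for the core-distance calculation), the ClusteringDataGenerator synthetic dataset, and the class name HdbscanUsageSketch are assumptions drawn from the Tribuo tutorials, not part of the snippet above.

import org.tribuo.Dataset;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.example.ClusteringDataGenerator;
import org.tribuo.clustering.hdbscan.HdbscanModel;
import org.tribuo.clustering.hdbscan.HdbscanTrainer;

public class HdbscanUsageSketch {
    public static void main(String[] args) {
        // Assumption: a synthetic Dataset<ClusterID> from Tribuo's example generator.
        Dataset<ClusterID> dataset = ClusteringDataGenerator.gaussianClusters(500, 1L);
        // Assumption: k = 5, the single-hyperparameter constructor with default settings.
        HdbscanTrainer trainer = new HdbscanTrainer(5);
        // Dispatches to the train method shown above with an empty run provenance map;
        // training builds the ImmutableClusteringInfo from the cluster assignments.
        HdbscanModel model = trainer.train(dataset);
        // The model's output domain is the ImmutableClusteringInfo built during training.
        System.out.println("Discovered output domain size: " + model.getOutputIDInfo().size());
    }
}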
Use of org.tribuo.clustering.ImmutableClusteringInfo in project tribuo by oracle.
The class KMeansTrainer, method train:
@Override
public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    // Creates a new local RNG and adds one to the invocation count.
    TrainerProvenance trainerProvenance;
    SplittableRandom localRNG;
    synchronized (this) {
        if (invocationCount != INCREMENT_INVOCATION_COUNT) {
            setInvocationCount(invocationCount);
        }
        localRNG = rng.split();
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    int[] oldCentre = new int[examples.size()];
    SGDVector[] data = new SGDVector[examples.size()];
    double[] weights = new double[examples.size()];
    int n = 0;
    for (Example<ClusterID> example : examples) {
        weights[n] = example.getWeight();
        if (example.size() == featureMap.size()) {
            data[n] = DenseVector.createDenseVector(example, featureMap, false);
        } else {
            data[n] = SparseVector.createSparseVector(example, featureMap, false);
        }
        oldCentre[n] = -1;
        n++;
    }
    DenseVector[] centroidVectors;
    switch (initialisationType) {
        case RANDOM:
            centroidVectors = initialiseRandomCentroids(centroids, featureMap, localRNG);
            break;
        case PLUSPLUS:
            centroidVectors = initialisePlusPlusCentroids(centroids, data, localRNG, distType);
            break;
        default:
            throw new IllegalStateException("Unknown initialisation " + initialisationType);
    }
    Map<Integer, List<Integer>> clusterAssignments = new HashMap<>();
    boolean parallel = numThreads > 1;
    for (int i = 0; i < centroids; i++) {
        clusterAssignments.put(i, parallel ? Collections.synchronizedList(new ArrayList<>()) : new ArrayList<>());
    }
    AtomicInteger changeCounter = new AtomicInteger(0);
    Consumer<IntAndVector> eStepFunc = (IntAndVector e) -> {
        double minDist = Double.POSITIVE_INFINITY;
        int clusterID = -1;
        int id = e.idx;
        SGDVector vector = e.vector;
        for (int j = 0; j < centroids; j++) {
            DenseVector cluster = centroidVectors[j];
            double distance = DistanceType.getDistance(cluster, vector, distType);
            if (distance < minDist) {
                minDist = distance;
                clusterID = j;
            }
        }
        clusterAssignments.get(clusterID).add(id);
        if (oldCentre[id] != clusterID) {
            // Changed the centroid of this vector.
            oldCentre[id] = clusterID;
            changeCounter.incrementAndGet();
        }
    };
    boolean converged = false;
    ForkJoinPool fjp = null;
    try {
        if (parallel) {
            if (System.getSecurityManager() == null) {
                fjp = new ForkJoinPool(numThreads);
            } else {
                fjp = new ForkJoinPool(numThreads, THREAD_FACTORY, null, false);
            }
        }
        for (int i = 0; (i < iterations) && !converged; i++) {
            logger.log(Level.FINE, "Beginning iteration " + i);
            changeCounter.set(0);
            for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
                e.getValue().clear();
            }
            // E step
            Stream<SGDVector> vecStream = Arrays.stream(data);
            Stream<Integer> intStream = IntStream.range(0, data.length).boxed();
            Stream<IntAndVector> zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
            if (parallel) {
                Stream<IntAndVector> parallelZipStream = StreamUtil.boundParallelism(zipStream.parallel());
                try {
                    fjp.submit(() -> parallelZipStream.forEach(eStepFunc)).get();
                } catch (InterruptedException | ExecutionException e) {
                    throw new RuntimeException("Parallel execution failed", e);
                }
            } else {
                zipStream.forEach(eStepFunc);
            }
logger.log(Level.FINE, "E step completed. " + changeCounter.get() + " words updated.");
            mStep(fjp, centroidVectors, clusterAssignments, data, weights);
            logger.log(Level.INFO, "Iteration " + i + " completed. " + changeCounter.get() + " examples updated.");
            if (changeCounter.get() == 0) {
                converged = true;
                logger.log(Level.INFO, "K-Means converged at iteration " + i);
            }
        }
    } finally {
        if (fjp != null) {
            fjp.shutdown();
        }
    }
    Map<Integer, MutableLong> counts = new HashMap<>();
    for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
        counts.put(e.getKey(), new MutableLong(e.getValue().size()));
    }
    ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);
    ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    return new KMeansModel("k-means-model", provenance, featureMap, outputMap, centroidVectors, distType);
}
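As with the HDBSCAN example above, a hedged sketch of invoking this trainer. The five-argument constructor (centroids, iterations, distance, thread count, seed) and the Distance enum follow the Tribuo 4.x K-Means tutorial; the exact overload varies between releases, so treat it as an assumption to check against the version in use. The class name KMeansUsageSketch is likewise hypothetical.

import org.tribuo.Dataset;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.example.ClusteringDataGenerator;
import org.tribuo.clustering.kmeans.KMeansModel;
import org.tribuo.clustering.kmeans.KMeansTrainer;
import org.tribuo.clustering.kmeans.KMeansTrainer.Distance;

public class KMeansUsageSketch {
    public static void main(String[] args) {
        // Assumption: the same synthetic dataset generator used in the Tribuo tutorials.
        Dataset<ClusterID> dataset = ClusteringDataGenerator.gaussianClusters(500, 1L);
        // Assumption: 5 centroids, 10 iterations, Euclidean distance, 1 thread, seed 1.
        KMeansTrainer trainer = new KMeansTrainer(5, 10, Distance.EUCLIDEAN, 1, 1L);
        // Trainer.train(Dataset) reaches the train method above, which counts the
        // per-cluster assignments into an ImmutableClusteringInfo.
        KMeansModel model = trainer.train(dataset);
        System.out.println("Discovered output domain size: " + model.getOutputIDInfo().size());
    }
}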