Use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.
The class MutableMockMultiOutputInfo, method observe.
/**
 * Throws IllegalStateException if the MockMultiOutput contains a Label which has a "," in it.
 * <p>
 * Such labels are disallowed; an exception should also be thrown when one is constructed.
 * @param output The observed output.
 */
@Override
public void observe(MockMultiOutput output) {
    if (output == MockMultiOutputFactory.UNKNOWN_MULTILABEL) {
        unknownCount++;
    } else {
        for (String label : output.getNameSet()) {
            if (label.contains(",")) {
                throw new IllegalStateException("MockMultiOutput cannot use a Label which contains ','. The supplied label was " + label + ".");
            }
            MutableLong value = labelCounts.computeIfAbsent(label, k -> new MutableLong());
            labels.computeIfAbsent(label, MockMultiOutput::new);
            value.increment();
        }
        totalCount++;
    }
}
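The computeIfAbsent call above is the standard OLCUT counting idiom: insert a zero-valued MutableLong on first sight of a key, then increment it in place. A minimal self-contained sketch of that idiom (the class name and label strings are invented for illustration):

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import java.util.HashMap;
import java.util.Map;

public class CountingExample {
    public static void main(String[] args) {
        Map<String, MutableLong> counts = new HashMap<>();
        for (String label : new String[]{"sport", "news", "sport"}) {
            // First sighting inserts a zero counter; every sighting increments it.
            counts.computeIfAbsent(label, k -> new MutableLong()).increment();
        }
        System.out.println(counts.get("sport").intValue()); // prints 2
    }
}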
Use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.
The class Dataset, method createTransformers.
/**
* Takes a {@link TransformationMap} and converts it into a {@link TransformerMap} by
* observing all the values in this dataset.
* <p>
* Does not mutate the dataset; to apply the TransformerMap, use
* {@link MutableDataset#transform} or {@link TransformerMap#transformDataset}.
* <p>
* TransformerMaps operate on feature values which are present; sparse values
* are ignored and not transformed. If the zeros should be transformed, call
* {@link MutableDataset#densify} on the datasets before applying a transformer.
* See {@link org.tribuo.transform} for a more detailed discussion of densify and includeImplicitZeroFeatures.
* <p>
* Throws {@link IllegalArgumentException} if the TransformationMap object has
* regexes which apply to multiple features.
* @param transformations The transformations to fit.
* @param includeImplicitZeroFeatures Use the implicit zero feature values to construct the transformations.
* @return A TransformerMap which can apply the transformations to a dataset.
*/
public TransformerMap createTransformers(TransformationMap transformations, boolean includeImplicitZeroFeatures) {
    ArrayList<String> featureNames = new ArrayList<>(getFeatureMap().keySet());
    // Validate map by checking no regex applies to multiple features.
    logger.fine(String.format("Processing %d feature specific transforms", transformations.getFeatureTransformations().size()));
    Map<String, List<Transformation>> featureTransformations = new HashMap<>();
    for (Map.Entry<String, List<Transformation>> entry : transformations.getFeatureTransformations().entrySet()) {
        // Compile the regex.
        Pattern pattern = Pattern.compile(entry.getKey());
        // Check all the feature names.
        for (String name : featureNames) {
            // If the regex matches.
            if (pattern.matcher(name).matches()) {
                List<Transformation> oldTransformations = featureTransformations.put(name, entry.getValue());
                // See if there is already a transformation list for that name.
                if (oldTransformations != null) {
                    throw new IllegalArgumentException("Feature name '" + name + "' matches multiple regexes, at least one of which was '" + entry.getKey() + "'.");
                }
            }
        }
    }
    // Populate the feature transforms map.
    Map<String, Queue<TransformStatistics>> featureStats = new HashMap<>();
    // sparseCount tracks how many times a feature was not observed.
    Map<String, MutableLong> sparseCount = new HashMap<>();
    for (Map.Entry<String, List<Transformation>> entry : featureTransformations.entrySet()) {
        // Create the queue of feature transformations for this feature.
        Queue<TransformStatistics> l = new LinkedList<>();
        for (Transformation t : entry.getValue()) {
            l.add(t.createStats());
        }
        // Add the queue to the map for that feature.
        featureStats.put(entry.getKey(), l);
        sparseCount.put(entry.getKey(), new MutableLong(data.size()));
    }
    if (!transformations.getGlobalTransformations().isEmpty()) {
        // Append all the global transformations.
        int ntransform = featureNames.size();
        logger.fine(String.format("Starting %,d global transformations", ntransform));
        int ndone = 0;
        for (String v : featureNames) {
            // Create the queue of feature transformations for this feature.
            Queue<TransformStatistics> l = featureStats.computeIfAbsent(v, (k) -> new LinkedList<>());
            for (Transformation t : transformations.getGlobalTransformations()) {
                l.add(t.createStats());
            }
            // Add the queue to the map for that feature.
            featureStats.put(v, l);
            // Generate the sparse count, initialised to the number of examples.
            sparseCount.putIfAbsent(v, new MutableLong(data.size()));
            ndone++;
            if (logger.isLoggable(Level.FINE) && ndone % 10000 == 0) {
                logger.fine(String.format("Completed %,d of %,d global transformations", ndone, ntransform));
            }
        }
    }
    Map<String, List<Transformer>> output = new LinkedHashMap<>();
    Set<String> removeSet = new LinkedHashSet<>();
    boolean initialisedSparseCounts = false;
    // Iterate through the dataset max(transformations.length) times.
    while (!featureStats.isEmpty()) {
        for (Example<T> example : data) {
            for (Feature f : example) {
                if (featureStats.containsKey(f.getName())) {
                    if (!initialisedSparseCounts) {
                        sparseCount.get(f.getName()).decrement();
                    }
                    List<Transformer> curTransformers = output.get(f.getName());
                    // Apply all current transformations.
                    double fValue = TransformerMap.applyTransformerList(f.getValue(), curTransformers);
                    // Observe the transformed value.
                    featureStats.get(f.getName()).peek().observeValue(fValue);
                }
            }
        }
        // Sparse counts are now fully updated (this could be protected by an if statement).
        initialisedSparseCounts = true;
        removeSet.clear();
        // Emit the new transformers onto the end of the list in the output map.
        for (Map.Entry<String, Queue<TransformStatistics>> entry : featureStats.entrySet()) {
            TransformStatistics currentStats = entry.getValue().poll();
            if (includeImplicitZeroFeatures) {
                // Observe all the sparse feature values.
                int unobservedFeatures = sparseCount.get(entry.getKey()).intValue();
                currentStats.observeSparse(unobservedFeatures);
            }
            // Get the transformer list for that feature, creating it if absent.
            List<Transformer> l = output.computeIfAbsent(entry.getKey(), (k) -> new ArrayList<>());
            // Generate the transformer and add it to the appropriate list.
            l.add(currentStats.generateTransformer());
            // If the queue is empty, remove that feature, ensuring that featureStats is eventually empty.
            if (entry.getValue().isEmpty()) {
                removeSet.add(entry.getKey());
            }
        }
        // Remove the features with empty queues.
        for (String s : removeSet) {
            featureStats.remove(s);
        }
    }
    return new TransformerMap(output, getProvenance(), transformations.getProvenance());
}
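For context, a typical call path looks like the hedged sketch below. It assumes an existing MutableDataset<Label> named train and uses LinearScalingTransformation from org.tribuo.transform.transformations as the example transformation; the variable names are invented.

// Hedged usage sketch (not taken from the Tribuo docs): fit a [0,1] scaling
// on every feature, then apply it without mutating the original dataset.
TransformationMap map = new TransformationMap(
        Collections.singletonList(new LinearScalingTransformation(0.0, 1.0)));
TransformerMap transformers = train.createTransformers(map, false); // false: ignore implicit zeros
Dataset<Label> scaled = transformers.transformDataset(train);       // 'train' is left unchanged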
Use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.
The class KMeansTrainer, method train.
@Override
public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    // Creates a new local RNG and adds one to the invocation count.
    TrainerProvenance trainerProvenance;
    SplittableRandom localRNG;
    synchronized (this) {
        if (invocationCount != INCREMENT_INVOCATION_COUNT) {
            setInvocationCount(invocationCount);
        }
        localRNG = rng.split();
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    int[] oldCentre = new int[examples.size()];
    SGDVector[] data = new SGDVector[examples.size()];
    double[] weights = new double[examples.size()];
    int n = 0;
    for (Example<ClusterID> example : examples) {
        weights[n] = example.getWeight();
        if (example.size() == featureMap.size()) {
            data[n] = DenseVector.createDenseVector(example, featureMap, false);
        } else {
            data[n] = SparseVector.createSparseVector(example, featureMap, false);
        }
        oldCentre[n] = -1;
        n++;
    }
    DenseVector[] centroidVectors;
    switch (initialisationType) {
        case RANDOM:
            centroidVectors = initialiseRandomCentroids(centroids, featureMap, localRNG);
            break;
        case PLUSPLUS:
            centroidVectors = initialisePlusPlusCentroids(centroids, data, localRNG, distType);
            break;
        default:
            throw new IllegalStateException("Unknown initialisation " + initialisationType);
    }
    Map<Integer, List<Integer>> clusterAssignments = new HashMap<>();
    boolean parallel = numThreads > 1;
    for (int i = 0; i < centroids; i++) {
        clusterAssignments.put(i, parallel ? Collections.synchronizedList(new ArrayList<>()) : new ArrayList<>());
    }
    AtomicInteger changeCounter = new AtomicInteger(0);
    Consumer<IntAndVector> eStepFunc = (IntAndVector e) -> {
        double minDist = Double.POSITIVE_INFINITY;
        int clusterID = -1;
        int id = e.idx;
        SGDVector vector = e.vector;
        for (int j = 0; j < centroids; j++) {
            DenseVector cluster = centroidVectors[j];
            double distance = DistanceType.getDistance(cluster, vector, distType);
            if (distance < minDist) {
                minDist = distance;
                clusterID = j;
            }
        }
        clusterAssignments.get(clusterID).add(id);
        if (oldCentre[id] != clusterID) {
            // Changed the centroid of this vector.
            oldCentre[id] = clusterID;
            changeCounter.incrementAndGet();
        }
    };
    boolean converged = false;
    ForkJoinPool fjp = null;
    try {
        if (parallel) {
            if (System.getSecurityManager() == null) {
                fjp = new ForkJoinPool(numThreads);
            } else {
                fjp = new ForkJoinPool(numThreads, THREAD_FACTORY, null, false);
            }
        }
        for (int i = 0; (i < iterations) && !converged; i++) {
            logger.log(Level.FINE, "Beginning iteration " + i);
            changeCounter.set(0);
            for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
                e.getValue().clear();
            }
            // E step
            Stream<SGDVector> vecStream = Arrays.stream(data);
            Stream<Integer> intStream = IntStream.range(0, data.length).boxed();
            Stream<IntAndVector> zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
            if (parallel) {
                Stream<IntAndVector> parallelZipStream = StreamUtil.boundParallelism(zipStream.parallel());
                try {
                    fjp.submit(() -> parallelZipStream.forEach(eStepFunc)).get();
                } catch (InterruptedException | ExecutionException e) {
                    throw new RuntimeException("Parallel execution failed", e);
                }
            } else {
                zipStream.forEach(eStepFunc);
            }
            logger.log(Level.FINE, "E step completed. " + changeCounter.get() + " examples updated.");
            mStep(fjp, centroidVectors, clusterAssignments, data, weights);
            logger.log(Level.INFO, "Iteration " + i + " completed. " + changeCounter.get() + " examples updated.");
            if (changeCounter.get() == 0) {
                converged = true;
                logger.log(Level.INFO, "K-Means converged at iteration " + i);
            }
        }
    } finally {
        if (fjp != null) {
            fjp.shutdown();
        }
    }
    Map<Integer, MutableLong> counts = new HashMap<>();
    for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
        counts.put(e.getKey(), new MutableLong(e.getValue().size()));
    }
    ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);
    ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    return new KMeansModel("k-means-model", provenance, featureMap, outputMap, centroidVectors, distType);
}
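A hedged training sketch follows. The KMeansTrainer constructor arguments have shifted across Tribuo releases, so the (centroids, iterations, distance, initialisation, numThreads, seed) signature below is an assumption inferred from the fields this method reads; dataset is an assumed pre-existing Dataset<ClusterID>.

// Hedged sketch: constructor signature assumed, see note above.
KMeansTrainer trainer = new KMeansTrainer(5, 10, DistanceType.L2,
        KMeansTrainer.Initialisation.RANDOM, 4, 1L);
KMeansModel model = trainer.train(dataset); // single-arg train increments the invocation counter internally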
Use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.
The class ClusteringFactory, method constructInfoForExternalModel.
/**
* Unlike the other info types, clustering directly uses the integer IDs as the stored value,
* so this mapping discards the cluster IDs and just uses the supplied integers.
* @param mapping The mapping to use.
* @return An {@link ImmutableOutputInfo} for the clustering.
*/
@Override
public ImmutableOutputInfo<ClusterID> constructInfoForExternalModel(Map<ClusterID, Integer> mapping) {
    // Validate inputs are dense.
    OutputFactory.validateMapping(mapping);
    Map<Integer, MutableLong> countsMap = new HashMap<>();
    for (Map.Entry<ClusterID, Integer> e : mapping.entrySet()) {
        countsMap.put(e.getValue(), new MutableLong(1));
    }
    return new ImmutableClusteringInfo(countsMap);
}
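A hedged sketch of the expected input: validateMapping requires the supplied integers to form a dense range starting at 0, so wiring up an external model with two clusters would look roughly like this (the variable names are invented):

// Hedged usage sketch: a dense id mapping for a two-cluster external model.
Map<ClusterID, Integer> mapping = new HashMap<>();
mapping.put(new ClusterID(0), 0);
mapping.put(new ClusterID(1), 1);
ImmutableOutputInfo<ClusterID> info = new ClusteringFactory().constructInfoForExternalModel(mapping);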
Use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.
The class MutableClusteringInfo, method observe.
@Override
public void observe(ClusterID output) {
    if (output == ClusteringFactory.UNASSIGNED_CLUSTER_ID) {
        unknownCount++;
    } else {
        int id = output.getID();
        MutableLong value = clusterCounts.computeIfAbsent(id, k -> new MutableLong());
        value.increment();
    }
}
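The sentinel check mirrors the MockMultiOutput version at the top of this page: the unassigned marker feeds unknownCount while real ids feed the per-cluster counts. A hedged sketch, assuming the mutable info is obtained from the factory's generateInfo():

// Hedged usage sketch: observe a real id and the unassigned sentinel.
MutableOutputInfo<ClusterID> info = new ClusteringFactory().generateInfo();
info.observe(new ClusterID(3));                        // counted under cluster id 3
info.observe(ClusteringFactory.UNASSIGNED_CLUSTER_ID); // increments unknownCount instead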