Search in sources:

Example 1 with PropertyException

Use of com.oracle.labs.mlrg.olcut.config.PropertyException in project Tribuo by Oracle.

The postConfig method of the class DateExtractor.

/**
 * Used by the OLCUT configuration system, and should not be called by external code.
 * <p>
 * Resolves the formatting locale and builds the date formatter.
 * <ul>
 *   <li>The locale is the default FORMAT locale unless both {@code localeLanguage} and
 *       {@code localeCountry} are set.</li>
 *   <li>The formatter falls back to {@link DateTimeFormatter#BASIC_ISO_DATE} when no
 *       {@code dateFormat} is configured.</li>
 * </ul>
 *
 * @throws PropertyException if exactly one of the two locale fields is set, or if
 *                           {@code dateFormat} is not a valid {@link DateTimeFormatter} pattern.
 */
@Override
public void postConfig() {
    super.postConfig();
    boolean languageSet = localeLanguage != null;
    boolean countrySet = localeCountry != null;
    if (languageSet != countrySet) {
        // Exactly one half of the locale was supplied; report the half that is missing.
        String missingField = languageSet ? "localeCountry" : "localeLanguage";
        throw new PropertyException("", missingField, "Must supply both localeLanguage and localeCountry when setting the locale.");
    }
    Locale locale = languageSet ? new Locale(localeLanguage, localeCountry) : Locale.getDefault(Locale.Category.FORMAT);
    if (dateFormat == null) {
        formatter = DateTimeFormatter.BASIC_ISO_DATE;
        return;
    }
    try {
        formatter = DateTimeFormatter.ofPattern(dateFormat, locale);
    } catch (IllegalArgumentException e) {
        throw new PropertyException(e, "", "dateFormat", "dateFormat could not be parsed by DateTimeFormatter");
    }
}
Also used : Locale(java.util.Locale) PropertyException(com.oracle.labs.mlrg.olcut.config.PropertyException)

Example 2 with PropertyException

Use of com.oracle.labs.mlrg.olcut.config.PropertyException in project Tribuo by Oracle.

The postConfig method of the class OffsetDateTimeExtractor.

/**
 * Used by the OLCUT configuration system, and should not be called by external code.
 * <p>
 * Resolves the formatting locale and compiles {@code dateTimeFormat} into the formatter.
 * The locale is the default FORMAT locale unless both {@code localeLanguage} and
 * {@code localeCountry} are set. There is no default pattern, so {@code dateTimeFormat}
 * must be supplied.
 *
 * @throws PropertyException if exactly one of the two locale fields is set, if
 *                           {@code dateTimeFormat} is missing, or if it is not a valid
 *                           {@link DateTimeFormatter} pattern.
 */
@Override
public void postConfig() {
    super.postConfig();
    Locale locale;
    if ((localeLanguage == null) && (localeCountry == null)) {
        locale = Locale.getDefault(Locale.Category.FORMAT);
    } else if (localeLanguage == null) {
        throw new PropertyException("", "localeLanguage", "Must supply both localeLanguage and localeCountry when setting the locale.");
    } else if (localeCountry == null) {
        throw new PropertyException("", "localeCountry", "Must supply both localeLanguage and localeCountry when setting the locale.");
    } else {
        locale = new Locale(localeLanguage, localeCountry);
    }
    if (dateTimeFormat == null) {
        // A null pattern means nothing was supplied, so the message must not claim an invalid one was.
        throw new PropertyException("", "dateTimeFormat", "No Date/Time format string supplied, dateTimeFormat must be set");
    }
    try {
        formatter = DateTimeFormatter.ofPattern(dateTimeFormat, locale);
    } catch (IllegalArgumentException e) {
        throw new PropertyException(e, "", "dateTimeFormat", "dateTimeFormat could not be parsed by DateTimeFormatter");
    }
}
Also used : Locale(java.util.Locale) PropertyException(com.oracle.labs.mlrg.olcut.config.PropertyException)

Example 3 with PropertyException

Use of com.oracle.labs.mlrg.olcut.config.PropertyException in project Tribuo by Oracle.

The postConfig method of the class HdbscanTrainer.

/**
 * Used by the OLCUT configuration system, and should not be called by external code.
 * <p>
 * First migrates the deprecated {@code distanceType} property into {@code distType}. Then it either:
 * <ul>
 *   <li>builds a brute-force neighbour query factory using {@code distType} and at least one
 *       thread, or</li>
 *   <li>checks that a supplied {@code neighboursQueryFactory} uses the same distance type.</li>
 * </ul>
 *
 * @throws PropertyException if both {@code distType} and {@code distanceType} are set, or if a
 *                           supplied neighbour query factory's distance type does not equal
 *                           {@code distType} (including when {@code distType} was never set).
 */
@Override
public synchronized void postConfig() {
    if (this.distanceType != null) {
        if (this.distType != null) {
            throw new PropertyException("distType", "Both distType and distanceType must not both be set.");
        }
        this.distType = this.distanceType.getDistanceType();
        this.distanceType = null;
    }
    if (neighboursQueryFactory == null) {
        // Non-positive thread counts fall back to a single thread.
        int numberThreads = (this.numThreads <= 0) ? 1 : this.numThreads;
        this.neighboursQueryFactory = new NeighboursBruteForceFactory(distType, numberThreads);
    } else if ((this.distType == null) || !this.distType.equals(neighboursQueryFactory.getDistanceType())) {
        // distType is still null here when neither distance property was configured; report that as a
        // configuration error rather than letting equals() throw a NullPointerException.
        throw new PropertyException("neighboursQueryFactory", "distType and its field on the NeighboursQueryFactory must be equal.");
    }
}
Also used : PropertyException(com.oracle.labs.mlrg.olcut.config.PropertyException) NeighboursBruteForceFactory(org.tribuo.math.neighbour.bruteforce.NeighboursBruteForceFactory)

Example 4 with PropertyException

Use of com.oracle.labs.mlrg.olcut.config.PropertyException in project Tribuo by Oracle.

The postConfig method of the class GaussianClusterDataSource.

/**
 * Used by the OLCUT configuration system, and should not be called by external code.
 * <p>
 * Validates the mixing distribution and the shapes of the five Gaussian components. It then
 * draws {@code numSamples} examples, choosing a component from the mixing distribution for each
 * example and sampling a point from that component, and stores them as an unmodifiable list.
 *
 * @throws PropertyException if any of these conditions fails:
 *                           <ul>
 *                             <li>{@code numSamples} is not positive;</li>
 *                             <li>the mixing distribution does not have 5 elements, does not sum
 *                                 to 1.0 (including a NaN sum), or has a negative element;</li>
 *                             <li>any mean or covariance has the wrong length.</li>
 *                           </ul>
 */
@Override
public void postConfig() {
    if (numSamples < 1) {
        throw new PropertyException("", "numSamples", "numSamples must be positive, found " + numSamples);
    }
    if (mixingDistribution.length != 5) {
        throw new PropertyException("", "mixingDistribution", "mixingDistribution must have 5 elements, found " + mixingDistribution.length);
    }
    double mixingSum = Util.sum(mixingDistribution);
    // Written as !(diff <= tol) rather than (diff > tol) so that a NaN sum (from a NaN element) is rejected.
    if (!(Math.abs(mixingSum - 1.0) <= 1e-10)) {
        throw new PropertyException("", "mixingDistribution", "mixingDistribution must sum to 1.0, found " + mixingSum);
    }
    if ((firstMean.length > allFeatureNames.length) || (firstMean.length == 0)) {
        throw new PropertyException("", "firstMean", "Must have 1-" + allFeatureNames.length + " features, found " + firstMean.length);
    }
    int dimension = firstMean.length;
    // Covariances are supplied flattened, so each must hold dimension * dimension entries.
    int covarianceSize = dimension * dimension;
    if (firstVariance.length != covarianceSize) {
        throw new PropertyException("", "firstVariance", "Invalid first covariance matrix, expected " + covarianceSize + " elements, found " + firstVariance.length);
    }
    validateComponentShape("second", secondMean, secondVariance, dimension, covarianceSize);
    validateComponentShape("third", thirdMean, thirdVariance, dimension, covarianceSize);
    validateComponentShape("fourth", fourthMean, fourthVariance, dimension, covarianceSize);
    validateComponentShape("fifth", fifthMean, fifthVariance, dimension, covarianceSize);
    for (double probability : mixingDistribution) {
        if (probability < 0) {
            throw new PropertyException("", "mixingDistribution", "Probability values in the mixing distribution must be non-negative, found " + Arrays.toString(mixingDistribution));
        }
    }
    double[] mixingCDF = Util.generateCDF(mixingDistribution);
    String[] featureNames = Arrays.copyOf(allFeatureNames, dimension);
    Random rng = new Random(seed);
    // Each component gets its own generator seeded from rng, in a fixed order, so the output is reproducible from seed.
    MultivariateNormalDistribution first = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), firstMean, reshapeAndValidate(firstVariance, "firstVariance"));
    MultivariateNormalDistribution second = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), secondMean, reshapeAndValidate(secondVariance, "secondVariance"));
    MultivariateNormalDistribution third = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), thirdMean, reshapeAndValidate(thirdVariance, "thirdVariance"));
    MultivariateNormalDistribution fourth = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), fourthMean, reshapeAndValidate(fourthVariance, "fourthVariance"));
    MultivariateNormalDistribution fifth = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), fifthMean, reshapeAndValidate(fifthVariance, "fifthVariance"));
    MultivariateNormalDistribution[] gaussians = new MultivariateNormalDistribution[] { first, second, third, fourth, fifth };
    List<Example<ClusterID>> examples = new ArrayList<>(numSamples);
    for (int i = 0; i < numSamples; i++) {
        int centroid = Util.sampleFromCDF(mixingCDF, rng);
        double[] sample = gaussians[centroid].sample();
        examples.add(new ArrayExample<>(new ClusterID(centroid), featureNames, sample));
    }
    this.examples = Collections.unmodifiableList(examples);
}

/**
 * Checks that one Gaussian component has the expected shape.
 * <p>
 * The mean must have {@code dimension} elements and the flattened covariance must have
 * {@code covarianceSize} elements. The offending field is named {@code ordinal + "Mean"} or
 * {@code ordinal + "Variance"}.
 *
 * @param ordinal        The component's field prefix, e.g. {@code "second"}.
 * @param mean           The component's mean vector.
 * @param variance       The component's flattened covariance matrix.
 * @param dimension      The required mean length.
 * @param covarianceSize The required covariance length.
 * @throws PropertyException if either length is wrong.
 */
private static void validateComponentShape(String ordinal, double[] mean, double[] variance, int dimension, int covarianceSize) {
    if (mean.length != dimension) {
        throw new PropertyException("", ordinal + "Mean", "All Gaussians must have the same number of dimensions, expected " + dimension + ", found " + mean.length);
    }
    if (variance.length != covarianceSize) {
        throw new PropertyException("", ordinal + "Variance", ordinal + "Variance is invalid, expected " + covarianceSize + ", found " + variance.length);
    }
}
Also used : ClusterID(org.tribuo.clustering.ClusterID) PropertyException(com.oracle.labs.mlrg.olcut.config.PropertyException) ArrayList(java.util.ArrayList) MultivariateNormalDistribution(org.apache.commons.math3.distribution.MultivariateNormalDistribution) Random(java.util.Random) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) JDKRandomGenerator(org.apache.commons.math3.random.JDKRandomGenerator)

Example 5 with PropertyException

Use of com.oracle.labs.mlrg.olcut.config.PropertyException in project Tribuo by Oracle.

The postConfig method of the class NonlinearGaussianDataSource.

/**
 * Used by the OLCUT configuration system, and should not be called by external code.
 * <p>
 * Validates the configuration, then draws {@code numSamples} regression examples.
 * <ul>
 *   <li>Each example has two features: x_0 uniform in [xZeroMin, xZeroMax) and x_1 uniform in
 *       [xOneMin, xOneMax).</li>
 *   <li>Each target is drawn from N(w_0*x_0 + w_1*x_1 + w_2*x_0*x_1 + w_3*x_1^3 + intercept,
 *       variance).</li>
 * </ul>
 * The examples are stored as an unmodifiable list.
 *
 * @throws PropertyException if {@code numSamples} is not positive, {@code weights} does not have
 *                           4 elements, either feature range is empty, or {@code variance} is
 *                           not positive.
 */
@Override
public void postConfig() {
    if (numSamples < 1) {
        throw new PropertyException("", "numSamples", "numSamples must be positive, found " + numSamples);
    }
    if (weights.length != 4) {
        throw new PropertyException("", "weights", "Must supply 4 weights, found " + weights.length);
    }
    if (xZeroMax <= xZeroMin) {
        throw new PropertyException("", "xZeroMax", "xZeroMax must be greater than xZeroMin, found xZeroMax = " + xZeroMax + ", xZeroMin = " + xZeroMin);
    }
    if (xOneMax <= xOneMin) {
        throw new PropertyException("", "xOneMax", "xOneMax must be greater than xOneMin, found xOneMax = " + xOneMax + ", xOneMin = " + xOneMin);
    }
    if (variance <= 0.0) {
        throw new PropertyException("", "variance", "Variance must be positive, found variance = " + variance);
    }
    // We use java.util.Random here because SplittableRandom doesn't have nextGaussian yet.
    Random rng = new Random(seed);
    // nextGaussian() has standard deviation 1, so scale by sqrt(variance) to get noise with the configured variance.
    double noiseStdDev = Math.sqrt(variance);
    List<Example<Regressor>> examples = new ArrayList<>(numSamples);
    double zeroRange = xZeroMax - xZeroMin;
    double oneRange = xOneMax - xOneMin;
    for (int i = 0; i < numSamples; i++) {
        double xZero = (rng.nextDouble() * zeroRange) + xZeroMin;
        double xOne = (rng.nextDouble() * oneRange) + xOneMin;
        // N(w_0*x_0 + w_1*x_1 + w_2*x_1*x_0 + w_3*x_1*x_1*x_1 + intercept,variance).
        double outputValue = (weights[0] * xZero) + (weights[1] * xOne) + (weights[2] * xZero * xOne) + (weights[3] * Math.pow(xOne, 3)) + intercept;
        Regressor output = new Regressor("Y", (rng.nextGaussian() * noiseStdDev) + outputValue);
        ArrayExample<Regressor> e = new ArrayExample<>(output, featureNames, new double[] { xZero, xOne });
        examples.add(e);
    }
    this.examples = Collections.unmodifiableList(examples);
}
Also used : ArrayExample(org.tribuo.impl.ArrayExample) Random(java.util.Random) PropertyException(com.oracle.labs.mlrg.olcut.config.PropertyException) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) ArrayList(java.util.ArrayList) Regressor(org.tribuo.regression.Regressor)

Aggregations

PropertyException (com.oracle.labs.mlrg.olcut.config.PropertyException)18 Random (java.util.Random)6 ArrayList (java.util.ArrayList)5 Example (org.tribuo.Example)5 ArrayExample (org.tribuo.impl.ArrayExample)5 Locale (java.util.Locale)2 NeighboursBruteForceFactory (org.tribuo.math.neighbour.bruteforce.NeighboursBruteForceFactory)2 Regressor (org.tribuo.regression.Regressor)2 NodeInfo (ai.onnxruntime.NodeInfo)1 OrtException (ai.onnxruntime.OrtException)1 OrtSession (ai.onnxruntime.OrtSession)1 TensorInfo (ai.onnxruntime.TensorInfo)1 ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager)1 FileNotFoundException (java.io.FileNotFoundException)1 IOException (java.io.IOException)1 MalformedURLException (java.net.MalformedURLException)1 MessageDigest (java.security.MessageDigest)1 HashSet (java.util.HashSet)1 SplittableRandom (java.util.SplittableRandom)1 MultivariateNormalDistribution (org.apache.commons.math3.distribution.MultivariateNormalDistribution)1