Usage of org.apache.ignite.ml.environment.LearningEnvironment in the Apache Ignite project.
Class ConvergenceChecker, method isConverged:
/**
 * Checks whether the current model composition has converged on the dataset built by the given dataset builder.
 * <p>
 * The dataset is opened in a try-with-resources block so it is always closed. Any exception raised while building,
 * checking or closing it is rethrown wrapped in a {@link RuntimeException}.
 *
 * @param envBuilder Learning environment builder.
 * @param datasetBuilder Dataset builder used to construct the dataset to check convergence on.
 * @param currMdl Current model.
 * @return True if GDB is converged.
 */
public boolean isConverged(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, ModelsComposition currMdl) {
    LearningEnvironment trainerEnv = envBuilder.buildForTrainer();
    trainerEnv.initDeployingContext(preprocessor);

    try (Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> data = datasetBuilder.build(
        envBuilder,
        new EmptyContextBuilder<>(),
        new FeatureMatrixWithLabelsOnHeapDataBuilder<>(preprocessor),
        trainerEnv
    )) {
        // Delegate to the dataset-based overload of isConverged.
        return isConverged(data, currMdl);
    }
    catch (Exception e) {
        // Closing the dataset may throw a checked exception; surface everything as unchecked.
        throw new RuntimeException(e);
    }
}
Usage of org.apache.ignite.ml.environment.LearningEnvironment in the Apache Ignite project.
Class Pipeline, method fit:
/**
 * Fits the pipeline to the input dataset builder.
 * <p>
 * The preprocessor chain is rebuilt from the vectorizer on every call. Each preprocessing trainer is fitted in turn
 * on top of the preprocessor produced by the previous one. The final training stage is then fitted with the
 * resulting preprocessor. The learning environment for the final stage is built from the same {@code envBuilder}
 * that is handed to the preprocessing trainers, falling back to the default builder if none is set.
 *
 * @param datasetBuilder Dataset builder used by every preprocessing trainer and by the final training stage.
 * @return Pipeline model holding the final preprocessor and the trained internal model.
 * @throws IllegalStateException If the pipeline has no final training stage.
 */
public PipelineMdl<K, V> fit(DatasetBuilder datasetBuilder) {
    if (finalStage == null)
        throw new IllegalStateException("The Pipeline should be finished with the Training Stage.");

    // Reload for new fit: restart from the vectorizer so a repeated fit does not chain onto the previous result.
    finalPreprocessor = vectorizer;

    preprocessingTrainers.forEach(e -> {
        finalPreprocessor = e.fit(envBuilder, datasetBuilder, finalPreprocessor);
    });

    // Build the final-stage environment from the pipeline's configured builder (the one the preprocessing
    // trainers received above) instead of always using a fresh default builder, which ignored that configuration.
    LearningEnvironment env = (envBuilder != null ? envBuilder : LearningEnvironmentBuilder.defaultBuilder())
        .buildForTrainer();
    env.initDeployingContext(finalPreprocessor);

    IgniteModel<Vector, Double> internalMdl = finalStage.fit(datasetBuilder, finalPreprocessor, env);

    return new PipelineMdl<K, V>().withPreprocessor(finalPreprocessor).withInternalMdl(internalMdl);
}
Usage of org.apache.ignite.ml.environment.LearningEnvironment in the Apache Ignite project.
Class PreprocessingTrainer, method learningEnvironment:
/**
 * Creates a local learning environment and initializes its deploying context from the given preprocessor.
 * <p>
 * The environment always comes from {@code LearningEnvironmentBuilder.defaultBuilder()}.
 *
 * @param basePreprocessor Preprocessor used to initialize the deploying context.
 * @return Learning environment with an initialized deploying context.
 */
public default LearningEnvironment learningEnvironment(Preprocessor<K, V> basePreprocessor) {
    LearningEnvironmentBuilder dfltBuilder = LearningEnvironmentBuilder.defaultBuilder();

    LearningEnvironment locEnv = dfltBuilder.buildForTrainer();
    locEnv.initDeployingContext(basePreprocessor);

    return locEnv;
}
Aggregations