Usage of `com.amplifyframework.predictions.PredictionsException` in project amplify-android (aws-amplify).
Class `AWSPredictionsPlugin`, method `configure`.
/**
 * Configures the plugin from its JSON configuration and creates the AWS
 * Predictions service with an appropriate credentials provider.
 *
 * <p>If {@code credentialsProviderOverride} is set it is used as-is;
 * otherwise the escape hatch of the auth plugin registered under
 * {@code AUTH_DEPENDENCY_PLUGIN_KEY} is cast to {@link AWSMobileClient}
 * and used as the credentials provider. The {@code context} argument is
 * not used by this method.
 *
 * @param pluginConfiguration this plugin's section of the Amplify configuration
 * @param context Android context
 * @throws PredictionsException if the configuration is missing or malformed,
 *         or if the auth plugin dependency cannot be obtained
 */
@Override
public void configure(JSONObject pluginConfiguration, @NonNull Context context) throws PredictionsException {
    this.configuration = AWSPredictionsPluginConfiguration.fromJson(pluginConfiguration);

    AWSCredentialsProvider credentialsProvider;
    if (credentialsProviderOverride != null) {
        credentialsProvider = credentialsProviderOverride;
    } else {
        try {
            credentialsProvider = (AWSMobileClient) Amplify.Auth.getPlugin(AUTH_DEPENDENCY_PLUGIN_KEY).getEscapeHatch();
        } catch (IllegalStateException exception) {
            // Treated as "auth plugin not registered"; the recovery suggestion must name the
            // missing dependency (the auth plugin), not this plugin.
            throw new PredictionsException(
                "AWSPredictionsPlugin depends on AWSCognitoAuthPlugin but it is currently missing",
                exception,
                "Before configuring Amplify, be sure to add AWSCognitoAuthPlugin same as you added "
                    + "AWSPredictionsPlugin."
            );
        }
    }
    this.predictionsService = new AWSPredictionsService(configuration, credentialsProvider);
}
Usage of `com.amplifyframework.predictions.PredictionsException` in project amplify-android (aws-amplify).
Class `AWSPredictionsPluginConfiguration`, method `fromJson`.
/**
 * Constructs an instance of {@link AWSPredictionsPluginConfiguration} from
 * the plugin configuration JSON object.
 *
 * <p>The default region entry is required. The sections keyed by
 * {@code ConfigKey.CONVERT}, {@code ConfigKey.IDENTIFY} and
 * {@code ConfigKey.INTERPRET} are optional; when a section is absent,
 * every sub-configuration parsed from it is {@code null}.
 *
 * @param configurationJson the plugin configuration
 * @return the configuration object for AWS Predictions Plugin
 * @throws PredictionsException if configuration is missing or malformed,
 *         or if the default region name is not recognized
 */
@NonNull
static AWSPredictionsPluginConfiguration fromJson(JSONObject configurationJson) throws PredictionsException {
    if (configurationJson == null) {
        throw new PredictionsException(
            "Could not locate predictions configuration for AWS Predictions Plugin.",
            "Verify that amplifyconfiguration.json contains a section for \"awsPredictionsPlugin\"."
        );
    }

    final Region defaultRegion;
    final SpeechGeneratorConfiguration speechGeneratorConfiguration;
    final TranslateTextConfiguration translateTextConfiguration;
    final IdentifyLabelsConfiguration identifyLabelsConfiguration;
    final IdentifyEntitiesConfiguration identifyEntitiesConfiguration;
    final IdentifyTextConfiguration identifyTextConfiguration;
    final InterpretTextConfiguration interpretConfiguration;

    try {
        // Get default region (required: getString throws JSONException if absent)
        String regionString = configurationJson.getString(ConfigKey.DEFAULT_REGION.key());
        defaultRegion = Region.getRegion(regionString);
        if (defaultRegion == null) {
            // Region.getRegion(String) reports an unknown region name by returning null rather
            // than throwing, so reject it here instead of building a config with no region.
            throw new PredictionsException(
                "Unrecognized default region: " + regionString,
                "Verify that the default region in amplifyconfiguration.json is a valid AWS region name."
            );
        }

        // Optional "convert" section: speech generation and text translation
        if (configurationJson.has(ConfigKey.CONVERT.key())) {
            JSONObject convertJson = configurationJson.getJSONObject(ConfigKey.CONVERT.key());
            speechGeneratorConfiguration = SpeechGeneratorConfiguration.fromJson(convertJson);
            translateTextConfiguration = TranslateTextConfiguration.fromJson(convertJson);
        } else {
            speechGeneratorConfiguration = null;
            translateTextConfiguration = null;
        }

        // Optional "identify" section: labels, entities and text
        if (configurationJson.has(ConfigKey.IDENTIFY.key())) {
            JSONObject identifyJson = configurationJson.getJSONObject(ConfigKey.IDENTIFY.key());
            identifyLabelsConfiguration = IdentifyLabelsConfiguration.fromJson(identifyJson);
            identifyEntitiesConfiguration = IdentifyEntitiesConfiguration.fromJson(identifyJson);
            identifyTextConfiguration = IdentifyTextConfiguration.fromJson(identifyJson);
        } else {
            identifyLabelsConfiguration = null;
            identifyEntitiesConfiguration = null;
            identifyTextConfiguration = null;
        }

        // Optional "interpret" section
        if (configurationJson.has(ConfigKey.INTERPRET.key())) {
            JSONObject interpretJson = configurationJson.getJSONObject(ConfigKey.INTERPRET.key());
            interpretConfiguration = InterpretTextConfiguration.fromJson(interpretJson);
        } else {
            interpretConfiguration = null;
        }
    } catch (JSONException | IllegalArgumentException exception) {
        throw new PredictionsException(
            "Issue encountered while parsing configuration JSON",
            exception,
            "Check the attached exception for more details."
        );
    }

    return new AWSPredictionsPluginConfiguration(
        defaultRegion,
        speechGeneratorConfiguration,
        translateTextConfiguration,
        identifyLabelsConfiguration,
        identifyEntitiesConfiguration,
        identifyTextConfiguration,
        interpretConfiguration
    );
}
Usage of `com.amplifyframework.predictions.PredictionsException` in project amplify-android (aws-amplify).
Class `TensorFlowTextClassificationService`, method `classify`.
/**
 * Classifies text to analyze associated sentiments.
 *
 * <p>If initialization previously failed, {@code loadingError} is reported
 * immediately. Otherwise the calling thread blocks for up to
 * {@code LOAD_TIMEOUT_MS} milliseconds waiting on {@code loaded}. If the wait
 * is interrupted or times out, {@code onError} is notified and the method
 * returns without attempting classification.
 *
 * @param text the text to classify
 * @param onSuccess notified when classification succeeds
 * @param onError notified when classification fails
 */
void classify(@NonNull String text, @NonNull Consumer<InterpretResult> onSuccess, @NonNull Consumer<PredictionsException> onError) {
    // Escape early if the initialization failed
    if (loadingError != null) {
        onError.accept(loadingError);
        return;
    }

    // Wait for initialization to complete
    // TODO: encapsulate blocking logic elsewhere
    final boolean didLoad;
    try {
        didLoad = loaded.await(LOAD_TIMEOUT_MS, TimeUnit.MILLISECONDS);
    } catch (InterruptedException exception) {
        // Restore the interrupt flag so code further up this thread's stack can observe it.
        Thread.currentThread().interrupt();
        onError.accept(new PredictionsException(
            "Text classification service initialization was interrupted.",
            exception,
            "Please wait for the required assets to be fully loaded."
        ));
        return;
    }
    if (!didLoad) {
        onError.accept(new PredictionsException(
            "Text classification service timed out while awaiting load.",
            "Your classification data may be too resource intensive?"
        ));
        // Stop here: proceeding would run inference against a model that has not loaded and
        // could notify onSuccess after onError was already notified.
        return;
    }

    try {
        final Sentiment sentiment = fetchSentiment(text);
        onSuccess.accept(InterpretResult.builder().sentiment(sentiment).build());
    } catch (PredictionsException exception) {
        onError.accept(exception);
    }
}
Usage of `com.amplifyframework.predictions.PredictionsException` in project amplify-android (aws-amplify).
Class `TensorFlowTextClassificationService`, method `fetchSentiment`.
/**
 * Runs the TensorFlow Lite interpreter on the tokenized text and returns the
 * label with the highest score as the predominant sentiment.
 *
 * <p>Each raw score is multiplied by {@code PERCENT} to form the confidence.
 * A later label replaces the current best only when its confidence is
 * strictly greater, so ties keep the earliest label. Returns {@code null}
 * when {@code labels} is empty.
 *
 * @param text the text to classify
 * @return the highest-confidence sentiment, or {@code null} if there are no labels
 * @throws PredictionsException if tokenization or inference throws an
 *         IllegalArgumentException
 */
@VisibleForTesting
Sentiment fetchSentiment(String text) throws PredictionsException {
    final float[][] inference;
    try {
        final float[][] tokens = dictionary.tokenizeInputText(text);
        inference = new float[1][labels.size()];
        interpreter.run(tokens, inference);
    } catch (IllegalArgumentException exception) {
        throw new PredictionsException(
            "TensorFlow Lite failed to make an inference.",
            exception,
            "Verify that the label size matches the output size of the model."
        );
    }

    // Single batch row: one score per label, in label order.
    final float[] scores = inference[0];
    Sentiment predominant = null;
    for (int index = 0; index < labels.size(); index++) {
        final SentimentType type = SentimentTypeAdapter.fromTensorFlow(labels.get(index));
        final float confidence = scores[index] * PERCENT;
        final boolean isStronger = predominant == null || confidence > predominant.getConfidence();
        if (isStronger) {
            predominant = Sentiment.builder()
                .value(type)
                .confidence(confidence)
                .build();
        }
    }
    return predominant;
}
Usage of `com.amplifyframework.predictions.PredictionsException` in project amplify-android (aws-amplify).
Class `TextClassificationModel`, method `load`.
/**
 * Loads the pre-trained text classification model into
 * TensorFlow Lite interpreter.
 *
 * <p>No-op once a previous call has succeeded. On success, {@code onLoaded}
 * (if set) receives the new interpreter and the model is marked as loaded.
 * On failure, {@code onLoadError} (if set) receives a
 * {@link PredictionsException} and the model stays unloaded; if
 * {@code onLoadError} is not set, the failure is not reported.
 */
@WorkerThread
@Override
public synchronized void load() {
    // No-op if loaded already
    if (loaded) {
        return;
    }
    try {
        ByteBuffer buffer = loadModelFile();
        try {
            interpreter = new Interpreter(buffer);
        } catch (IllegalArgumentException exception) {
            // The Interpreter constructor rejects an unusable model buffer with an unchecked
            // IllegalArgumentException. Convert it so it reaches onLoadError like other load
            // failures instead of escaping this worker thread.
            throw new PredictionsException(
                "Failed to initialize TensorFlow Lite interpreter with the loaded model.",
                exception,
                "Verify that the model file is a valid TensorFlow Lite model."
            );
        }
        if (onLoaded != null) {
            onLoaded.accept(interpreter);
        }
        loaded = true;
    } catch (PredictionsException exception) {
        if (onLoadError != null) {
            onLoadError.accept(exception);
        }
    }
}
Aggregations