Search in sources:

Example 1 with Predictor

Use of ai.djl.inference.Predictor in the project build-your-own-social-media-analytics-with-apache-kafka by scholzj.

Class TopologyProducer, method buildTopology.

/**
 * Builds the Kafka Streams topology that classifies the sentiment of incoming tweets and
 * forwards an alert for each confidently classified original (non-retweet) tweet.
 *
 * <p>Loads a DJL sentiment-analysis model from the model zoo and stores a new predictor in the
 * {@code predictor} field. It then reads tweets from {@code SOURCE_TOPIC}, classifies their
 * text, and writes alert strings to {@code TARGET_TOPIC}. The following tweets produce no
 * output:
 * <ul>
 *   <li>retweets;</li>
 *   <li>tweets whose best classification has a probability of 0.50 or less;</li>
 *   <li>tweets whose classification throws {@code TranslateException}.</li>
 * </ul>
 *
 * @return the built topology
 * @throws RuntimeException if the sentiment-analysis model cannot be loaded
 */
@Produces
public Topology buildTopology() {
    final TweetSerde tweetSerde = new TweetSerde();
    try {
        Criteria<String, Classifications> criteria = Criteria.builder()
                .optApplication(Application.NLP.SENTIMENT_ANALYSIS)
                .setTypes(String.class, Classifications.class)
                .build();
        // NOTE(review): neither the loaded model nor the predictor is closed anywhere visible
        // here, and a single predictor instance is shared by the stream processor. If
        // num.stream.threads > 1, confirm DJL Predictor thread-safety or use one predictor per
        // thread.
        predictor = ModelZoo.loadModel(criteria).newPredictor();
    } catch (IOException | ModelNotFoundException | MalformedModelException e) {
        LOG.error("Failed to load model", e);
        throw new RuntimeException("Failed to load model", e);
    }
    final StreamsBuilder builder = new StreamsBuilder();
    builder.stream(SOURCE_TOPIC, Consumed.with(Serdes.ByteArray(), tweetSerde)).flatMapValues(value -> {
        if (value.getRetweetedStatus() != null) {
            // We ignore retweets => we do not want an alert for every retweet
            return List.of();
        } else {
            String tweet = value.getText();
            try {
                Classifications classifications = predictor.predict(tweet);
                // Evaluate the top classification once and reuse it for the message and the threshold check
                Classifications.Classification best = classifications.best();
                double probability = best.getProbability();
                String statusUrl = "https://twitter.com/" + value.getUser().getScreenName() + "/status/" + value.getId();
                String alert = String.format("The following tweet was classified as %s with %2.2f%% probability: %s",
                        best.getClassName().toLowerCase(Locale.ENGLISH), probability * 100, statusUrl);
                // We care only about strong results where probability is > 50%
                if (probability > 0.50) {
                    LOG.infov("Tweeting: {0}", alert);
                    return List.of(alert);
                } else {
                    LOG.infov("Not tweeting: {0}", alert);
                    return List.of();
                }
            } catch (TranslateException e) {
                // Keep the cause so classification failures can be diagnosed; the tweet itself is skipped
                LOG.errorv(e, "Failed to classify the tweet {0}", value);
                return List.of();
            }
        }
    }).peek((key, value) -> LOG.infov("{0}", value)).to(TARGET_TOPIC, Produced.with(Serdes.ByteArray(), Serdes.String()));
    return builder.build();
}
Also used: StreamsBuilder(org.apache.kafka.streams.StreamsBuilder) ModelNotFoundException(ai.djl.repository.zoo.ModelNotFoundException) Produces(javax.enterprise.inject.Produces) Produced(org.apache.kafka.streams.kstream.Produced) Consumed(org.apache.kafka.streams.kstream.Consumed) Logger(org.jboss.logging.Logger) TweetSerde(cz.scholz.sentimentanalysis.model.TweetSerde) IOException(java.io.IOException) MalformedModelException(ai.djl.MalformedModelException) Predictor(ai.djl.inference.Predictor) Classifications(ai.djl.modality.Classifications) ModelZoo(ai.djl.repository.zoo.ModelZoo) Application(ai.djl.Application) List(java.util.List) TranslateException(ai.djl.translate.TranslateException) Locale(java.util.Locale) Serdes(org.apache.kafka.common.serialization.Serdes) ApplicationScoped(javax.enterprise.context.ApplicationScoped) Criteria(ai.djl.repository.zoo.Criteria) Topology(org.apache.kafka.streams.Topology)

Aggregations

Application (ai.djl.Application): 1; MalformedModelException (ai.djl.MalformedModelException): 1; Predictor (ai.djl.inference.Predictor): 1; Classifications (ai.djl.modality.Classifications): 1; Criteria (ai.djl.repository.zoo.Criteria): 1; ModelNotFoundException (ai.djl.repository.zoo.ModelNotFoundException): 1; ModelZoo (ai.djl.repository.zoo.ModelZoo): 1; TranslateException (ai.djl.translate.TranslateException): 1; TweetSerde (cz.scholz.sentimentanalysis.model.TweetSerde): 1; IOException (java.io.IOException): 1; List (java.util.List): 1; Locale (java.util.Locale): 1; ApplicationScoped (javax.enterprise.context.ApplicationScoped): 1; Produces (javax.enterprise.inject.Produces): 1; Serdes (org.apache.kafka.common.serialization.Serdes): 1; StreamsBuilder (org.apache.kafka.streams.StreamsBuilder): 1; Topology (org.apache.kafka.streams.Topology): 1; Consumed (org.apache.kafka.streams.kstream.Consumed): 1; Produced (org.apache.kafka.streams.kstream.Produced): 1; Logger (org.jboss.logging.Logger): 1