Search in sources:

Example 1 with BaseClassifier

Use of Classifier.BaseClassifier in project IR_Base by Linda-sunshine.

In class Execution, method main:

/**
 * Command-line entry point. Loads a corpus, optionally applies PageRank instance weighting, and
 * then runs the learning paradigm selected by {@code param.m_style}:
 * "SUP" (supervised classifiers), "SEMI" (semi-supervised classifiers), "TM" (topic models), or
 * "FV" (save the corpus feature vectors to {@code param.m_fvFile}).
 *
 * @param args command-line arguments, parsed by {@link Parameter}
 * @throws IOException    declared to match the original entry point; propagated from the
 *                        analyzer / model calls
 * @throws ParseException declared to match the original entry point; propagated from the
 *                        analyzer / model calls
 */
public static void main(String[] args) throws IOException, ParseException {
    Parameter param = new Parameter(args);
    System.out.println(param.toString());

    _Corpus corpus = loadCorpus(param);

    if (param.m_weightScheme.equals("PR")) {
        System.out.println("Creating PageRank instance weighting, wait...");
        // NOTE(review): the meaning of 100, 50 and 1e-6 is defined by the PageRank constructor.
        PageRank myPR = new PageRank(corpus, param.m_C, 100, 50, 1e-6);
        myPR.train(corpus.getCollection());
    }

    // Execute the selected learning paradigm.
    if (param.m_style.equals("SUP")) {
        runClassifier(createSupervisedClassifier(param, corpus), param, corpus);
    } else if (param.m_style.equals("SEMI")) {
        runClassifier(createSemiSupervisedClassifier(param, corpus), param, corpus);
    } else if (param.m_style.equals("TM")) {
        runTopicModel(createTopicModel(param, corpus), param);
    } else if (param.m_style.equals("FV")) {
        corpus.save2File(param.m_fvFile);
        System.out.format("Vectors saved to %s...\n", param.m_fvFile);
    } else {
        System.out.println("Learning paradigm has not developed yet!");
    }
}

/**
 * Builds the corpus. If {@code param.m_fvFile} names an existing file, the documents are loaded
 * from that vector file. Otherwise they are analyzed from the raw text under
 * {@code param.m_folder}.
 */
private static _Corpus loadCorpus(Parameter param) throws IOException, ParseException {
    if (param.m_fvFile != null && new File(param.m_fvFile).exists()) {
        Analyzer analyzer = new VctAnalyzer(param.m_classNumber, param.m_lengthThreshold, param.m_featureFile);
        // Load all the documents from the vector file as the data set.
        analyzer.LoadDoc(param.m_fvFile);
        return analyzer.getCorpus();
    }
    return loadCorpusFromText(param);
}

/**
 * Builds the corpus from the text documents in {@code param.m_folder}.
 *
 * <p>When no feature file is configured, a feature-selection pass runs first. Its output paths
 * are stored back into {@code param.m_featureFile} and {@code param.m_featureStat}.
 */
private static _Corpus loadCorpusFromText(Parameter param) throws IOException, ParseException {
    // Only the HTMM and LRHTMM models are given the sentence and POS models.
    boolean useSentenceModels = param.m_model.equals("HTMM") || param.m_model.equals("LRHTMM");
    String stnModel = useSentenceModels ? param.m_stnModel : null;
    String posModel = useSentenceModels ? param.m_posModel : null;

    DocAnalyzer analyzer = new DocAnalyzer(param.m_tokenModel, stnModel, posModel, param.m_classNumber,
            param.m_featureFile, param.m_Ngram, param.m_lengthThreshold);
    // Content is not released when the "PR" weighting scheme is selected (presumably PageRank reads it).
    analyzer.setReleaseContent(!param.m_weightScheme.equals("PR"));

    if (param.m_featureFile == null) {
        // Feature selection. The generated paths are written back to param; m_featureStat is
        // read by returnCorpus below.
        System.out.println("Performing feature selection, wait...");
        param.m_featureFile = String.format("./data/Features/%s_fv.dat", param.m_featureSelection);
        param.m_featureStat = String.format("./data/Features/%s_fv_stat.dat", param.m_featureSelection);
        System.out.println(param.printFeatureSelectionConfiguration());
        analyzer.LoadStopwords(param.m_stopwords);
        analyzer.LoadDirectory(param.m_folder, param.m_suffix);
        analyzer.featureSelection(param.m_featureFile, param.m_featureSelection,
                param.m_startProb, param.m_endProb, param.m_maxDF, param.m_minDF);
    }

    // Collect vectors for documents.
    // NOTE(review): after feature selection, the same analyzer loads the folder a second time.
    // This assumes featureSelection resets the loaded corpus; confirm.
    System.out.println("Creating feature vectors, wait...");
    analyzer.LoadDirectory(param.m_folder, param.m_suffix);
    analyzer.setFeatureValues(param.m_featureValue, param.m_norm);
    return analyzer.returnCorpus(param.m_featureStat);
}

/**
 * Creates the supervised classifier named by {@code param.m_model}: "NB", "LR", "PR-LR" or
 * "SVM". For an unknown name, it prints an error and terminates the JVM with status -1.
 */
private static BaseClassifier createSupervisedClassifier(Parameter param, _Corpus corpus)
        throws IOException, ParseException {
    switch (param.m_model) {
        case "NB":
            System.out.println("Start naive bayes, wait...");
            return new NaiveBayes(corpus);
        case "LR":
            System.out.println("Start logistic regression, wait...");
            return new LogisticRegression(corpus, param.m_C);
        case "PR-LR":
            System.out.println("Start posterior regularized logistic regression, wait...");
            return new PRLogisticRegression(corpus, param.m_C);
        case "SVM":
            System.out.println("Start SVM, wait...");
            return new SVM(corpus, param.m_C);
        default:
            return exitUnsupported("Classifier has not been developed yet!");
    }
}

/**
 * Creates the semi-supervised classifier named by {@code param.m_model}: "GF", "GF-RW" or
 * "GF-RW-ML". For an unknown name, it prints an error and terminates the JVM with status -1.
 */
private static BaseClassifier createSemiSupervisedClassifier(Parameter param, _Corpus corpus)
        throws IOException, ParseException {
    switch (param.m_model) {
        case "GF":
            System.out.println("Start Gaussian Field by matrix inversion, wait...");
            return new GaussianFields(corpus, param.m_classifier, param.m_C, param.m_sampleRate,
                    param.m_kUL, param.m_kUU);
        case "GF-RW":
            System.out.println("Start Gaussian Field by random walk, wait...");
            return new GaussianFieldsByRandomWalk(corpus, param.m_classifier, param.m_C,
                    param.m_sampleRate, param.m_kUL, param.m_kUU, param.m_alpha, param.m_beta,
                    param.m_converge, param.m_eta, param.m_weightedAvg);
        case "GF-RW-ML":
            System.out.println("Start Gaussian Field with distance metric learning by random walk, wait...");
            return new LinearSVMMetricLearning(corpus, param.m_classifier, param.m_C,
                    param.m_sampleRate, param.m_kUL, param.m_kUU, param.m_alpha, param.m_beta,
                    param.m_converge, param.m_eta, param.m_weightedAvg, param.m_bound);
        default:
            return exitUnsupported("Classifier has not been developed yet!");
    }
}

/** Applies the debug-output setting, then runs {@code param.m_CVFold}-fold cross-validation. */
private static void runClassifier(BaseClassifier model, Parameter param, _Corpus corpus)
        throws IOException, ParseException {
    model.setDebugOutput(param.m_debugOutput);
    model.crossValidation(param.m_CVFold, corpus);
}

/**
 * Creates the topic model named by {@code param.m_model}: "2topic", "pLSA", "vLDA", "gLDA",
 * "HTMM" or "LRHTMM". For "pLSA" and "vLDA", {@code param.m_multithread} selects the
 * multithreaded implementation. For "pLSA", "vLDA" and "gLDA", the prior from
 * {@code param.m_priorFile} is loaded. For an unknown name, it prints an error and terminates
 * the JVM with status -1.
 */
private static TopicModel createTopicModel(Parameter param, _Corpus corpus)
        throws IOException, ParseException {
    switch (param.m_model) {
        case "2topic":
            return new twoTopic(param.m_maxmIterations, param.m_converge, param.m_beta, corpus, param.m_lambda);
        case "pLSA": {
            TopicModel model;
            if (param.m_multithread) {
                model = new pLSA_multithread(param.m_maxmIterations, param.m_converge, param.m_beta,
                        corpus, param.m_lambda, param.m_numTopics, param.m_alpha);
            } else {
                model = new pLSA(param.m_maxmIterations, param.m_converge, param.m_beta,
                        corpus, param.m_lambda, param.m_numTopics, param.m_alpha);
            }
            ((pLSA) model).LoadPrior(param.m_priorFile, param.m_gamma);
            return model;
        }
        case "vLDA": {
            TopicModel model;
            if (param.m_multithread) {
                model = new LDA_Variational_multithread(param.m_maxmIterations, param.m_converge,
                        param.m_beta, corpus, param.m_lambda, param.m_numTopics, param.m_alpha,
                        param.m_maxVarIterations, param.m_varConverge);
            } else {
                model = new LDA_Variational(param.m_maxmIterations, param.m_converge,
                        param.m_beta, corpus, param.m_lambda, param.m_numTopics, param.m_alpha,
                        param.m_maxVarIterations, param.m_varConverge);
            }
            ((LDA_Variational) model).LoadPrior(param.m_priorFile, param.m_gamma);
            return model;
        }
        case "gLDA": {
            LDA_Gibbs gibbs = new LDA_Gibbs(param.m_maxmIterations, param.m_converge, param.m_beta,
                    corpus, param.m_lambda, param.m_numTopics, param.m_alpha, param.m_burnIn, param.m_lag);
            gibbs.LoadPrior(param.m_priorFile, param.m_gamma);
            return gibbs;
        }
        case "HTMM":
            return new HTMM(param.m_maxmIterations, param.m_converge, param.m_beta, corpus,
                    param.m_numTopics, param.m_alpha);
        case "LRHTMM":
            return new LRHTMM(param.m_maxmIterations, param.m_converge, param.m_beta, corpus,
                    param.m_numTopics, param.m_alpha, param.m_C);
        default:
            return exitUnsupported("The specified topic model has not been developed yet!");
    }
}

/**
 * Runs a topic model. With at most one fold, it runs EM on the whole corpus and prints the top
 * 10 words. Otherwise it runs {@code param.m_CVFold}-fold cross-validation.
 */
private static void runTopicModel(TopicModel model, Parameter param) throws IOException, ParseException {
    if (param.m_CVFold <= 1) {
        model.EMonCorpus();
        model.printTopWords(10);
    } else {
        model.crossValidation(param.m_CVFold);
    }
}

/**
 * Prints {@code message} to standard output and terminates the JVM with status -1.
 * It never returns normally. The trailing throw makes that explicit to the compiler, so callers
 * never observe a null model.
 */
private static <T> T exitUnsupported(String message) {
    System.out.println(message);
    System.exit(-1);
    throw new AssertionError("System.exit returned unexpectedly");
}
Also used: PRLogisticRegression (Classifier.supervised.PRLogisticRegression), GaussianFieldsByRandomWalk (Classifier.semisupervised.GaussianFieldsByRandomWalk), DocAnalyzer (Analyzer.DocAnalyzer), twoTopic (topicmodels.twoTopic), VctAnalyzer (Analyzer.VctAnalyzer), SVM (Classifier.supervised.SVM), Analyzer (Analyzer.Analyzer), TopicModel (topicmodels.TopicModel), LDA_Variational (topicmodels.LDA.LDA_Variational), HTMM (topicmodels.markovmodel.HTMM), LRHTMM (topicmodels.markovmodel.LRHTMM), _Corpus (structures._Corpus), NaiveBayes (Classifier.supervised.NaiveBayes), BaseClassifier (Classifier.BaseClassifier), LDA_Variational_multithread (topicmodels.multithreads.LDA.LDA_Variational_multithread), pLSA_multithread (topicmodels.multithreads.pLSA.pLSA_multithread), LogisticRegression (Classifier.supervised.LogisticRegression), pLSA (topicmodels.pLSA.pLSA), PageRank (influence.PageRank), LinearSVMMetricLearning (Classifier.metricLearning.LinearSVMMetricLearning), LDA_Gibbs (topicmodels.LDA.LDA_Gibbs), Parameter (structures.Parameter), GaussianFields (Classifier.semisupervised.GaussianFields), File (java.io.File)

Aggregations

Analyzer (Analyzer.Analyzer): 1, DocAnalyzer (Analyzer.DocAnalyzer): 1, VctAnalyzer (Analyzer.VctAnalyzer): 1, BaseClassifier (Classifier.BaseClassifier): 1, LinearSVMMetricLearning (Classifier.metricLearning.LinearSVMMetricLearning): 1, GaussianFields (Classifier.semisupervised.GaussianFields): 1, GaussianFieldsByRandomWalk (Classifier.semisupervised.GaussianFieldsByRandomWalk): 1, LogisticRegression (Classifier.supervised.LogisticRegression): 1, NaiveBayes (Classifier.supervised.NaiveBayes): 1, PRLogisticRegression (Classifier.supervised.PRLogisticRegression): 1, SVM (Classifier.supervised.SVM): 1, PageRank (influence.PageRank): 1, File (java.io.File): 1, Parameter (structures.Parameter): 1, _Corpus (structures._Corpus): 1, LDA_Gibbs (topicmodels.LDA.LDA_Gibbs): 1, LDA_Variational (topicmodels.LDA.LDA_Variational): 1, TopicModel (topicmodels.TopicModel): 1, HTMM (topicmodels.markovmodel.HTMM): 1, LRHTMM (topicmodels.markovmodel.LRHTMM): 1