Search in sources :

Example 1 with KNN

use of Classifier.supervised.KNN in project IR_Base by Linda-sunshine.

the class VectorReviewMain method main.

public static void main(String[] args) throws IOException, ParseException {
    /**
     ***Set these parameters before run the classifiers.****
     */
    // Define the number of classes in this Naive Bayes.
    int classNumber = 5;
    // Document length threshold
    int lengthThreshold = 5;
    // k fold-cross validation
    int CVFold = 10;
    // Supervised classification models: "NB", "LR", "PR-LR", "SVM"
    // Semi-supervised classification models: "GF", "GF-RW", "GF-RW-ML"
    // Which classifier to use.
    String classifier = "GF-RW-ML";
    // String modelPath = "./data/Model/";
    double C = 1.0;
    // "SUP", "SEMI"
    String style = "SEMI";
    String multipleLearner = "SVM";
    /**
     ***The parameters used in loading files.****
     */
    String featureLocation = "data/Features/fv_2gram_BM25_CHI_small.txt";
    String vctfile = "data/FVs/vct_2gram_BM25_CHI_tablet_small.dat";
    // String featureLocation = "data/Features/fv_fake.txt";
    // String vctfile = "data/Fvs/LinearRegression.dat";
    /**
     ***Parameters in time series analysis.****
     */
    // String debugOutput = String.format("data/debug/%s.sim.pair", classifier);
    String debugOutput = null;
    /**
     **Pre-process the data.****
     */
    // Feture selection.
    System.out.println("Loading vectors from file, wait...");
    VctAnalyzer analyzer = new VctAnalyzer(classNumber, lengthThreshold, featureLocation);
    // Load all the documents as the data set.
    analyzer.LoadDoc(vctfile);
    _Corpus corpus = analyzer.getCorpus();
    // make it binary
    corpus.mapLabels(4);
    /**
     ******Choose different classification methods.********
     */
    if (style.equals("SUP")) {
        if (classifier.equals("NB")) {
            // Define a new naive bayes with the parameters.
            System.out.println("Start naive bayes, wait...");
            NaiveBayes myNB = new NaiveBayes(corpus);
            // Use the movie reviews for testing the codes.
            myNB.crossValidation(CVFold, corpus);
        } else if (classifier.equals("KNN")) {
            // Define a new naive bayes with the parameters.
            System.out.println("Start kNN, wait...");
            KNN myKNN = new KNN(corpus, 10, 1);
            // Use the movie reviews for testing the codes.
            myKNN.crossValidation(CVFold, corpus);
        } else if (classifier.equals("LR")) {
            // Define a new logistics regression with the parameters.
            System.out.println("Start logistic regression, wait...");
            LogisticRegression myLR = new LogisticRegression(corpus, C);
            myLR.setDebugOutput(debugOutput);
            // Use the movie reviews for testing the codes.
            myLR.crossValidation(CVFold, corpus);
        // myLR.saveModel(modelPath + "LR.model");
        } else if (classifier.equals("PRLR")) {
            // Define a new logistics regression with the parameters.
            System.out.println("Start posterior regularized logistic regression, wait...");
            PRLogisticRegression myLR = new PRLogisticRegression(corpus, C);
            myLR.setDebugOutput(debugOutput);
            // Use the movie reviews for testing the codes.
            myLR.crossValidation(CVFold, corpus);
        // myLR.saveModel(modelPath + "LR.model");
        } else if (classifier.equals("SVM")) {
            System.out.println("Start SVM, wait...");
            SVM mySVM = new SVM(corpus, C);
            mySVM.crossValidation(CVFold, corpus);
        } else if (classifier.equals("PR")) {
            System.out.println("Start PageRank, wait...");
            PageRank myPR = new PageRank(corpus, C, 100, 50, 1e-6);
            myPR.train(corpus.getCollection());
        } else
            System.out.println("Classifier has not been developed yet!");
    } else if (style.equals("SEMI")) {
        double learningRatio = 1.0;
        // k nearest labeled, k' nearest unlabeled
        int k = 20, kPrime = 20;
        // labeled data weight, unlabeled data weight
        double tAlpha = 1.0, tBeta = 0.1;
        // convergence of random walk, weight of random walk
        double tDelta = 1e-4, tEta = 0.5;
        boolean simFlag = false;
        double threshold = 0.5;
        // bound for generating rating constraints (must be zero in binary case)
        int bound = 0;
        boolean metricLearning = true;
        if (classifier.equals("GF")) {
            GaussianFields mySemi = new GaussianFields(corpus, multipleLearner, C);
            mySemi.crossValidation(CVFold, corpus);
        } else if (classifier.equals("GF-RW")) {
            GaussianFields mySemi = new GaussianFieldsByRandomWalk(corpus, multipleLearner, C, learningRatio, k, kPrime, tAlpha, tBeta, tDelta, tEta, false);
            mySemi.setDebugOutput(debugOutput);
            mySemi.crossValidation(CVFold, corpus);
        } else if (classifier.equals("GF-RW-ML")) {
            LinearSVMMetricLearning lMetricLearner = new LinearSVMMetricLearning(corpus, multipleLearner, C, learningRatio, k, kPrime, tAlpha, tBeta, tDelta, tEta, false, bound);
            lMetricLearner.setMetricLearningMethod(metricLearning);
            lMetricLearner.setDebugOutput(debugOutput);
            lMetricLearner.crossValidation(CVFold, corpus);
        } else
            System.out.println("Classifier has not been developed yet!");
    } else
        System.out.println("Learning paradigm has not been developed yet!");
}
Also used : PRLogisticRegression(Classifier.supervised.PRLogisticRegression) GaussianFieldsByRandomWalk(Classifier.semisupervised.GaussianFieldsByRandomWalk) PageRank(influence.PageRank) VctAnalyzer(Analyzer.VctAnalyzer) KNN(Classifier.supervised.KNN) SVM(Classifier.supervised.SVM) structures._Corpus(structures._Corpus) NaiveBayes(Classifier.supervised.NaiveBayes) LinearSVMMetricLearning(Classifier.metricLearning.LinearSVMMetricLearning) GaussianFields(Classifier.semisupervised.GaussianFields) PRLogisticRegression(Classifier.supervised.PRLogisticRegression) LogisticRegression(Classifier.supervised.LogisticRegression)

Aggregations

VctAnalyzer (Analyzer.VctAnalyzer)1 LinearSVMMetricLearning (Classifier.metricLearning.LinearSVMMetricLearning)1 GaussianFields (Classifier.semisupervised.GaussianFields)1 GaussianFieldsByRandomWalk (Classifier.semisupervised.GaussianFieldsByRandomWalk)1 KNN (Classifier.supervised.KNN)1 LogisticRegression (Classifier.supervised.LogisticRegression)1 NaiveBayes (Classifier.supervised.NaiveBayes)1 PRLogisticRegression (Classifier.supervised.PRLogisticRegression)1 SVM (Classifier.supervised.SVM)1 PageRank (influence.PageRank)1 structures._Corpus (structures._Corpus)1