package org.baderlab.pdzsvmstruct.validation;

import org.baderlab.pdzsvmstruct.data.DataLoader;
import org.baderlab.pdzsvmstruct.data.Data;
import org.baderlab.pdzsvmstruct.data.manager.DataFileManager;
import org.baderlab.pdzsvmstruct.utils.PDZSVMUtils;
import org.baderlab.pdzsvmstruct.utils.Constants;

import java.util.*;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;

import libsvm.svm_parameter;
import org.baderlab.pdzsvmstruct.predictor.svm.GlobalSVMPredictor;
import org.baderlab.pdzsvmstruct.predictor.Predictor;
import org.baderlab.pdzsvmstruct.evaluation.Evaluation;
import org.baderlab.pdzsvmstruct.encoding.DomainFeatureEncoding;
import weka.core.Instances;

/**
 * Copyright (c) 2011 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVMStruct.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVMStruct.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * CrossValidation calls predictor methods to perform different methods of
 * cross validation. This code generates R code for Figs. 1, 2a, 2b.
 */
public class CrossValidation
{
    private List instList;
    private List aucLabelList;
    private List rocAUCList;
    private List prAUCList;

    private String parentDir = "/CrossValidation";
    private boolean print = true;

    public CrossValidation()
    {
    }
    public void validate(String predictorType, ValidationParameters validParams)
    {
        DataLoader dl = new DataLoader();
        dl.loadMousePDBTrain();
        dl.loadSidhuHumanPDBTrain(Constants.SIDHU_HUMAN_G_PDB, Constants.PHAGE_DISPLAY);

        List posTrainProfileList = dl.getPosTrainProfileList();
        List negTrainProfileList = dl.getNegTrainProfileList();
        System.out.println("\n\tTRAIN DOMAINS:");
        System.out.println("\tNum Pos Domains: " + posTrainProfileList.size());
        System.out.println("\tNum Neg Domains: " + negTrainProfileList.size());
        System.out.println();
        System.out.println("\tTRAIN INTERACTIONS");
        System.out.println("\tNum Pos Interactions:" + dl.getNumPosTrainInteractions());
        System.out.println("\tNum Neg Interactions:" + dl.getNumNegTrainInteractions());
        System.out.println();
        
        svm_parameter svmParams = new svm_parameter();
        svmParams.setDefaults();

        double C = 0.0;
        double g = 0.0;

        //svmParams.kernel_type = svm_parameter.LINEAR;
        //svmParams.data_encoding = svm_parameter.PAIRWISE_POTENTIAL;
        //svmParams.data_encoding = svm_parameter.P0;
        //svmParams.kernel_type = svm_parameter.PRECOMPUTED;
        Predictor p;
        if (predictorType.equals("STRUCT"))
        {
            svmParams.data_encoding = svm_parameter.STRUCT;
            C = 4; g=3;            
            svmParams.C = Math.exp(C);
            svmParams.gamma = Math.exp(-Math.log(2)-g);
            p = new GlobalSVMPredictor(posTrainProfileList, negTrainProfileList, svmParams);
        }
        else
        {
            return;
        }
        System.out.println("\tSVM params: [g,C] = ["+g+","+C+"])");

        //public final static int DOMAIN =0;
        //public final static int PEPTIDE =1;
        //public final static int DOMAIN_PEPTIDE =2;
        //public final static int K_FOLD = 3;
        //public final static int LOOV_DOMAIN = 4;
        //public final static int LOOV_PEPTIDE = 5;
        // public final static int LOSIM_DOMAIN = 6;

        instList = new ArrayList();
        aucLabelList = new ArrayList();
        rocAUCList = new ArrayList();
        prAUCList= new ArrayList();

        runValidation(p, validParams);

        String title = ValidationParameters.CV_STRING[validParams.validationType];
        Predictor.plotCurves(instList, rocAUCList, prAUCList, aucLabelList, title);

    }

    private void print(String output, ValidationParameters params, String fileName)
    {
        try
        {
            String outFileName = params.outputDir + "/" + fileName;

            System.out.println("\tWriting to " + outFileName+ "...");

            BufferedWriter bw = new BufferedWriter(new FileWriter(new File(outFileName)));
            bw.write(output);
            bw.close();
        }
        catch(Exception e)
        {
            System.out.println("Exception: " + e);
        }
    }

    private void runValidation(Predictor p, ValidationParameters validParams)
    {
        validParams.outputDir = DataFileManager.OUTPUT_ROOT_DIR + parentDir + validParams.dirName;
        validParams.predictorName = p.getPredictorName();

        String validationType = ValidationParameters.CV_STRING[validParams.type];
        List predictionList =  new ArrayList();
        List foldROCAUCList = new ArrayList();
        List foldPRAUCList = new ArrayList();
        List foldPredictionList = new ArrayList();
        List foldDomainNumList = new ArrayList();
        for (int ii=0;ii< validParams.numTimes;ii++)
        {
            HashMap cvResultsMap;
            System.out.println("\t=== Run # " + (ii+1) + " ===");
            if (validParams.type==ValidationParameters.K_FOLD)
                cvResultsMap = p.kFoldCrossValidation(validParams);
            else
                cvResultsMap = p.leaveOutCrossValidation(validParams);

            Set keys = cvResultsMap.keySet();
            List keyList = new ArrayList(keys);
            Collections.sort(keyList);
            for (int i=0; i < keyList.size();i++)
            {
                Integer foldNum = (Integer)keyList.get(i);
                foldDomainNumList.add(foldNum);
                List cvPredictionList = (List)cvResultsMap.get(foldNum);
                predictionList.addAll(cvPredictionList);
                Evaluation eval = new Evaluation(cvPredictionList);
                double rocAUC = eval.getROCAUC();
                double prAUC = eval.getPRAUC();
                foldROCAUCList.add(rocAUC);
                foldPRAUCList.add(prAUC);

                foldPredictionList.add(cvPredictionList);
            }
        }
        Evaluation eval = new Evaluation(predictionList);
        double rocAUC = eval.getROCAUC();
        double prAUC = eval.getPRAUC();

        Instances inst = eval.getCurve(1);
        instList.add(inst);
        aucLabelList.add(p.getPredictorName() +" " + validationType);
        rocAUCList.add(rocAUC);
        prAUCList.add(prAUC);
        
        double[] cirAUC = confidenceInterval(foldROCAUCList);
        System.out.println("\tROC AUC: 95% C.I.: " + cirAUC[0] + "~" + cirAUC[1]);

        double[] ciprAUC = confidenceInterval(foldPRAUCList);
        System.out.println("\tPR  AUC: 95% C.I.: " + ciprAUC[0] + "~" + ciprAUC[1]);

        String cvString = ValidationParameters.CV_STRING[validParams.type];

        if (print)
        {
            String fileName = p.getPredictorName().replace(' ','_') + "_"+cvString;
            if (validParams.d==0 && validParams.type == ValidationParameters.DOMAIN)
                fileName = fileName + "_LODO";
            if (validParams.p==0&& validParams.type == ValidationParameters.PEPTIDE)
                fileName = fileName + "_LOPO";
            if (validParams.sim !=0)
                fileName = fileName + "_"+validParams.sim;
            if (validParams.simRange[0] !=0)
                fileName = fileName + "_"+validParams.simRange[0]+"-"+validParams.simRange[1];
            validParams.predictorName = p.getPredictorName();
            StringBuffer rString = PDZSVMUtils.toRString(foldPredictionList);
            int exFt = p.getExcludedFeature();
            String exFtName = "All";
            if (exFt > -1)
                exFtName = DomainFeatureEncoding.FEATURE_NAMES[exFt];
            print(rString.toString(),validParams, "cv_" + fileName + "_Load-"+exFtName+".r");

            System.out.println();
        }

    }

    private String getRPlotString(List instList, String x, String y)
    {
        double[] avgPrec =null;
        double[] avgRec =null;
        StringBuffer rString =  new StringBuffer();
        for (int i=0; i < instList.size();i++)
        {
            Instances inst = (Instances)instList.get(i);
            int pix = inst.attribute(y).index();
            int rix = inst.attribute(x).index();

            double [] prec = inst.attributeToDoubleArray(pix);
            double [] rec = inst.attributeToDoubleArray(rix);

            avgPrec = new double[prec.length];
            avgRec = new double[rec.length];

            for (int ii=0; ii < prec.length;ii++)
            {
                double precision = prec[ii];
                avgPrec[ii] = avgPrec[ii] + precision;
            }
            for (int ii=0; ii < rec.length;ii++)
            {
                double recall = rec[ii];
                avgRec[ii] = avgRec[ii] + recall;
            }
        }
        for (int ii=0; ii < avgPrec.length;ii++)
        {
            avgPrec[ii] = avgPrec[ii]/instList.size();
        }
        for (int ii=0; ii < avgRec.length;ii++)
        {
            avgRec[ii] = avgRec[ii]/instList.size();
        }

        String var1  = "";
        if (x.equals("Recall"))
            var1 = "rec";
        else
            var1 = "fpr";
        String var2  = "";
        if (y.equals("Precision"))
            var2 = "prec";
        else
            var2 = "tpr";
        rString.append(var1 + " = c("+avgPrec[0]);
        for (int ii=1; ii < avgPrec.length;ii++)
        {
            avgPrec[ii] = avgPrec[ii]/instList.size();
            rString.append(","+avgPrec[ii]);
        }
        rString.append(")\n");
        rString.append(var2+ " = c("+avgRec[0]);

        for (int ii=1; ii < avgRec.length;ii++)
        {
            avgRec[ii] = avgRec[ii]/instList.size();
            rString.append(","+avgRec[ii]);

        }
        rString.append(")\n");
        return rString.toString();
    }
    public double[] confidenceInterval(List statList)
    {
        double confidenceLevel =0.95;
        double c =1.96;
        double n = statList.size();
        double mean = 0.0;
        for (int i=0; i < statList.size();i++)
        {
            double auc = (Double)statList.get(i);
            mean = mean + auc;
        }
        mean = mean/n;

        double sum = 0.0;
        for (int i=0; i < statList.size();i++)
        {
            double auc = (Double)statList.get(i);
            double diff = auc-mean;
            sum = sum + Math.pow(diff,2);
        }
        double var = (1/(n-1)) * sum;
        double s = Math.sqrt(var);

        double ciLow = mean-c*s/Math.sqrt(n);
        double ciHigh = mean+c*s/Math.sqrt(n);

        return new double[]{ciLow, ciHigh};

    }


    public static void main(String[] args)
    {
        CrossValidation cv = new CrossValidation();
        String predictorType = args[0];
        if (predictorType.equals("C"))
            System.out.println("\tPredictor type: ContactMapSVMPredictor");
        else if (predictorType.equals("NN"))
            System.out.println("\tPredictor type: NNPredictor");
        else if (predictorType.equals("NNSTRUCT"))
            System.out.println("\tPredictor type: NNPredictor-Struct");

        else
            System.out.println("\tPredictor type: GlobalSVMPredictor");
        int validationType = Integer.parseInt(args[1]);

        System.out.println("\tValidation type: " + ValidationParameters.CV_STRING[validationType]);

        ValidationParameters validParams = new ValidationParameters();
        validParams.validationType=validationType;
        if (validationType==ValidationParameters.K_FOLD)
        {
            validParams.type = ValidationParameters.K_FOLD; validParams.dirName = "/KFold";   validParams.k = 10;   validParams.numTimes = 10;
        }
        else if (validationType==ValidationParameters.DOMAIN)
        {
            validParams.type = ValidationParameters.DOMAIN;  validParams.dirName = "/Domain";
            validParams.d = 12; validParams.k = 10;  validParams.numTimes = 10;
        }
        else if (validationType==ValidationParameters.LOOV_DOMAIN)
        {
            validParams.type = ValidationParameters.DOMAIN;  validParams.dirName = "/Domain";
            validParams.d = 0; validParams.p = 0; validParams.k = 1; validParams.numTimes = 1;
        }
        else if (validationType==ValidationParameters.PEPTIDE)
        {
            validParams.type = ValidationParameters.PEPTIDE; validParams.dirName = "/Peptide";
            validParams.p = 8; validParams.k = 10;  validParams.numTimes = 10;
        }
        else if (validationType == ValidationParameters.LOOV_PEPTIDE)
        {
            validParams.type = ValidationParameters.PEPTIDE; validParams.dirName = "/Peptide";
            validParams.d = 0; validParams.p = 0; validParams.k = 1; validParams.numTimes = 1;
        }
        else if (validationType == ValidationParameters.DOMAIN_PEPTIDE)
        {
            validParams.type = ValidationParameters.DOMAIN_PEPTIDE; validParams.dirName = "/DomainPeptide";
            validParams.d = 12;validParams.p = 8;  validParams.k = 10;  validParams.numTimes = 10;
        }
        else if (validationType == ValidationParameters.LOSIM_DOMAIN)
        {
            validParams.type = ValidationParameters.LOSIM_DOMAIN; validParams.dirName = "/SimDomain";
            validParams.sim = 1.0;   validParams.numTimes = 1;
        }
        else if (validationType == ValidationParameters.ALLSIM_DOMAIN)
        {
            double low = Double.parseDouble(args[2]);
            double high = Double.parseDouble(args[3]);

            validParams.type = ValidationParameters.ALLSIM_DOMAIN; validParams.dirName = "/SimRangeDomain";
            validParams.simRange = new double[]{low,high};   validParams.numTimes = 10;  validParams.k = 10;
        }
        else
        {
            System.out.println("\tUnknown validation type...exiting.");
            return;
        }



        cv.validate(predictorType, validParams);

    }
}
