package org.baderlab.pdzsvmstruct.predictor.svm;

import libsvm.*;

import java.util.*;
import java.util.List;

import org.baderlab.pdzsvmstruct.data.Data;
import org.baderlab.pdzsvmstruct.evaluation.Prediction;
import org.baderlab.pdzsvmstruct.evaluation.Evaluation;
import org.baderlab.pdzsvmstruct.data.*;
import org.baderlab.pdzsvmstruct.encoding.*;
import org.baderlab.pdzsvmstruct.validation.ValidationParameters;
import org.baderlab.pdzsvmstruct.utils.PDZSVMUtils;
import org.baderlab.pdzsvmstruct.utils.Constants;

/**
 * Copyright (c) 2011 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVMStruct.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVMStruct.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * SVM class which encodes data into a format that is compatible with
 * LibSVM and converts LibSVM results into a format PDZSVM can use.  This
 * class is the one that calls LibSVM functions.
 */
public class SVM
{

    public SVM()
    {
    }

    public static void main(String[]args)
    {

        DataLoader dl = new DataLoader();
        dl.loadSidhuHumanPDBTrain(Constants.SIDHU_HUMAN_G_PDB, Constants.PHAGE_DISPLAY);

        List posTrainProfileList = dl.getPosTrainProfileList();
        List negTrainProfileList = dl.getNegTrainProfileList();

        List organismList = new ArrayList();
        organismList.add(Constants.HUMAN);
        DomainFeatureEncoding dEnc = new DomainFeatureEncoding(organismList);
        PeptideFeatureEncoding pEnc = new PeptideFeatureEncoding();

        dl.loadMousePDBTest(Constants.CHEN_MOUSE_PDB);
        List posTestProfileList = dl.getPosTestProfileList();
        List negTestProfileList= dl.getNegTestProfileList();


        Data trainData = new Data();
        trainData.addRawData(posTrainProfileList,Constants.CLASS_YES);
        trainData.addRawData(negTrainProfileList, Constants.CLASS_NO);

        trainData.encodeBindingSiteStructureData(dEnc, pEnc);
        List maxMinList = trainData.getMaxMin();
        trainData.scaleData(0.0,1.0,maxMinList);

        List organismList2 = new ArrayList();
        organismList2.add(Constants.MOUSE);
        Data testData = new Data();
        DomainFeatureEncoding dEnc2 = new DomainFeatureEncoding(organismList2);

        testData.addRawData(posTestProfileList, Constants.CLASS_YES);
        testData.addRawData(negTestProfileList,Constants.CLASS_NO);
        testData.encodeBindingSiteStructureData(dEnc2, pEnc);
        testData.scaleData(0.0,1.0,maxMinList);

        svm_parameter structParams = new svm_parameter();
        structParams.setDefaults();
        double g = 3;
        structParams.gamma = Math.exp(-Math.log(2)-g);

        svm_parameter seqParams = new svm_parameter();
        seqParams.setDefaults();
        g = 3;
        seqParams.gamma = Math.exp(-Math.log(2)-g);
    }

    public static svm_problem loadData(Data data, List dataSet)
    {
        svm_problem prob = new svm_problem();
        System.out.println("\n\tLoading data (size: " +dataSet.size() +")...");

        Vector<Double> vy = new Vector<Double>();
        Vector<svm_node[]> vx = new Vector<svm_node[]>();
        int max_index = 0;
        HashMap domainNumToEncMap = data.getDomainNumToEncMap();
        HashMap peptideNumToEncMap = data.getPeptideNumToEncMap();
        for (int i = 0;i < dataSet.size();i++)
        {

            Datum dt = (Datum)dataSet.get(i);

            Features dFeat = (Features)domainNumToEncMap.get(dt.domainNum);
            Features pFeat = (Features)peptideNumToEncMap.get(dt.peptideNum);
            Double classValue = Double.valueOf((double)dt.classToInt());
            vy.addElement(classValue);


            int m = dFeat.numFeatures() + pFeat.numFeatures() ;
            svm_node[] x = new svm_node[m];
            List dFeatList = dFeat.getFeatureValues();
            List pFeatList = pFeat.getFeatureValues();

            int j = 0;

            for (int ii=0; ii < dFeatList.size();ii++)
            {
                x[j] = new svm_node();
                x[j].index = (j+1);
                double theValue = (Double)dFeatList.get(ii);
                x[j].value = theValue;
                j = j+1;
            }
            for (int ii=0; ii < pFeatList.size();ii++)
            {
                x[j] = new svm_node();
                x[j].index = (j+1);
                double theValue = (Double)pFeatList.get(ii);
                x[j].value = theValue;
                j = j+1;
            }
            if(m>0) max_index = Math.max(max_index, x[m-1].index);
            vx.addElement(x);
        }

        prob.l = vy.size();
        prob.x = new svm_node[prob.l][];
        for(int i=0;i<prob.l;i++)
            prob.x[i] = vx.elementAt(i);
        prob.y = new double[prob.l];
        for(int i=0;i<prob.l;i++)
            prob.y[i] = vy.elementAt(i);

        prob.names = new String[prob.l];
        prob.domainSeq = new String[prob.l];
        prob.peptideSeq = new String[prob.l];
        prob.organismShort = new String[prob.l];
        prob.methodShort = new String[prob.l];

        HashMap peptideNumToRawMap = data.getPeptideNumToRawMap();

        for (int i = 0;i < (dataSet.size());i++)
        {
            Datum dt = (Datum)dataSet.get(i);
            prob.names[i] = dt.name;

            Features peptideFeatures = (Features)peptideNumToRawMap.get(dt.peptideNum);
            prob.domainSeq[i] = data.getDomainBindingSiteSeq(dt.domainNum);
            prob.peptideSeq[i] = peptideFeatures.toUndelimitedString();
            prob.organismShort[i] = dt.organism;
            prob.methodShort[i] = dt.expMethod;
        }
        return prob;
    }
    public static HashMap kFoldCrossValidation(Data trainData, svm_parameter svmParams, ValidationParameters validParams)
    {
        System.out.println("\n\t"+validParams.k+" fold cross validation...");
        
        svm_problem trainProb = null;
        if(svmParams.data_encoding==svm_parameter.STRUCT)
        {
            trainProb = loadData(trainData, trainData.getDataList());
        }

        HashMap kFoldResultsMap = new HashMap();

        List cvResults = new ArrayList();
        svm.svm_cross_validation_x(trainProb,svmParams,validParams.k,cvResults);

        for (int k = 0; k < cvResults.size();k++)
        {
            List predictionList = new ArrayList();
            List foldResults = (ArrayList)cvResults.get(k);
            double[] target = new double[foldResults.size()];
            double[] decValues = new double[foldResults.size()];
            int[] actual = new int[foldResults.size()];

            for (int m = 0; m < foldResults.size();m++)
            {

                double[] ret = (double[])foldResults.get(m);
                target[m] = PDZSVMUtils.zeroToOne(ret[1]);
                decValues[m] = ret[0];
                actual[m] = (int) PDZSVMUtils.zeroToOne(ret[2]);
                Prediction pred = new Prediction(target[m],actual[m],decValues[m]);
                predictionList.add(pred);
            }
            kFoldResultsMap.put(k,predictionList);

        }

        return kFoldResultsMap;
    }
    public static HashMap leaveKMOutCrossValidation(ValidationParameters validParams, Data trainData, svm_parameter svmParams, HashMap domainNumToDataListMap, HashMap peptideNumToDataListMap)
    {
        int k = validParams.k;
        int d = validParams.d;
        int p = validParams.p;
        double dPercent = (double)d/100;
        double pPercent = (double)p/100;

        System.out.println("\tLeave ["+d+","+p+"]% "+ ValidationParameters.CV_STRING[ValidationParameters.DOMAIN_PEPTIDE]+ " out cross validation...");
        HashMap kFoldResultsMap = new HashMap();

        Set domainNumKeys = domainNumToDataListMap.keySet();
        List domainNumKeyList = new ArrayList(domainNumKeys);
        Collections.sort(domainNumKeyList);

        Set peptideNumKeys = peptideNumToDataListMap.keySet();
        List peptideNumKeyList = new ArrayList(peptideNumKeys);
        Collections.sort(peptideNumKeyList);

        HashMap domainNumToRawMap = trainData.getDomainNumToRawMap();
        HashMap peptideNumToRawMap = trainData.getPeptideNumToRawMap();
        HashMap domainNumToEncMap = trainData.getDomainNumToEncMap();
        HashMap peptideNumToEncMap = trainData.getPeptideNumToEncMap();

        int numDomains = domainNumToRawMap.size();
        int numInDFold = (int)((double)numDomains*(dPercent));

        int numPeptides = peptideNumToRawMap.size();
        int numInPFold = (int)((double)numPeptides*(pPercent));

        System.out.println("\tNum in d Fold: " + numInDFold);
        System.out.println("\tNum in p Fold: " + numInPFold);

        int numFolds;
        if (d ==0)
        {
            System.out.println("\td cannot be zero...");
            return null;
        }
        else numFolds = k;
        System.out.println("\tNum folds: " + numFolds);

        if (p ==0)
        {
            System.out.println("\tp cannot be zero...");
            return null;
        }
        for (int i =0; i < numFolds;i++)
        {
            List testDFoldDataList = new ArrayList();
            Collections.shuffle(domainNumKeyList);
            List testDomainNumKeyList = domainNumKeyList.subList(0,numInDFold);

            for (int ii = 0; ii < testDomainNumKeyList.size();ii++)
            {
                String testFoldNum = (String)testDomainNumKeyList.get(ii);
                List testDomainDataList = (List) domainNumToDataListMap.get(testFoldNum);
                testDFoldDataList.addAll(testDomainDataList);
            }

            List domainTestNameList=new ArrayList();
            for (int ii=0;ii < testDFoldDataList.size();ii++)
            {
                Datum dt = (Datum)testDFoldDataList.get(ii);
                if (!domainTestNameList.contains(dt.name))
                    domainTestNameList.add(dt.name);
            }

            List testPFoldDataList = new ArrayList();
            Collections.shuffle(peptideNumKeyList);
            List testPeptideNumKeyList = peptideNumKeyList.subList(0,numInPFold);


            for (int jj =0; jj < testPeptideNumKeyList.size();jj++)
            {
                String testPeptideNum = (String)testPeptideNumKeyList.get(jj);

                List testPeptideDataList = (List) peptideNumToDataListMap.get(testPeptideNum);
                List subTestPeptideDataList = new ArrayList();
                for (int kk =0; kk < testPeptideDataList.size();kk++)
                {
                    Datum dt = (Datum)testPeptideDataList.get(kk);
                    if (!testDomainNumKeyList.contains(dt.domainNum))
                    {
                        subTestPeptideDataList.add(dt);
                    }
                }

                testPFoldDataList.addAll(subTestPeptideDataList);

            }

            List peptideTestNameList=new ArrayList();
            for (int jj=0;jj < testPFoldDataList.size();jj++)
            {
                Datum dt = (Datum)testPFoldDataList.get(jj);
                if (!peptideTestNameList.contains(dt.peptideNum))
                    peptideTestNameList.add(dt.peptideNum);
            }

            List testKMFoldDataList = new ArrayList();
            testKMFoldDataList.addAll(testDFoldDataList);
            testKMFoldDataList.addAll(testPFoldDataList);
            Data testFoldData = new Data();
            testFoldData.addRawData(testKMFoldDataList, domainNumToRawMap, peptideNumToRawMap, domainNumToEncMap, peptideNumToEncMap);

            List testDomainNames = testFoldData.getDomainNames();
            List testPeptideNums = testFoldData.getPeptideNums();
            if (testFoldData.getNumPositive() == 0 || testFoldData.getNumNegative()==0)
            {
                System.out.println("\n\t===== Fold " + i +" " +domainTestNameList.toString() + " x " + peptideTestNameList.toString() + " =====");
                System.out.println("\tAll: " + testDomainNames.toString() + " x " + testPeptideNums.toString());
                System.out.println("\tNo positive or negative data...");
                continue;
            }

            List trainKMFoldDataList = new ArrayList();
            for (int jj=0; jj < domainNumKeyList.size();jj++)
            {
                String trainFoldNum = (String)domainNumKeyList.get(jj);
                if (!testDomainNumKeyList.contains(trainFoldNum))
                {
                    List trainDomainDataList = (List) domainNumToDataListMap.get(trainFoldNum);
                    List subTrainDomainDataList = new ArrayList();
                    for (int kk=0; kk < trainDomainDataList.size();kk++)
                    {
                        Datum dt = (Datum) trainDomainDataList.get(kk);
                        if (!testPeptideNumKeyList.contains(dt.peptideNum))
                            subTrainDomainDataList.add(dt);
                    }
                    trainKMFoldDataList.addAll(subTrainDomainDataList);

                }

            }
            // test it!

            System.out.println("\n\t===== Fold " + i +" " +domainTestNameList.toString() + " x " + peptideTestNameList.toString() + " =====");
            System.out.println("\tAll: " + testDomainNames.toString() + " x " + testPeptideNums.toString());
            Data trainFoldData = new Data();
            trainFoldData.addRawData(trainKMFoldDataList, domainNumToRawMap, peptideNumToRawMap,domainNumToEncMap, peptideNumToEncMap);

            svm_model svmModel = SVM.train(trainFoldData, svmParams);
            List predictions = SVM.predict(trainFoldData, testFoldData, svmModel, svmParams);
            Evaluation eval = new Evaluation(predictions);
            System.out.println(eval.toString());

            kFoldResultsMap.put(i,predictions);
        } // for all folds (k)
        return kFoldResultsMap;
    }

    public static HashMap leaveKOutCrossValidation(ValidationParameters validParams, Data trainData, svm_parameter svmParams, HashMap numToDataListMap)
    {
        int k = validParams.k;
        int dp ;
        int type = validParams.type;
        if (validParams.type==ValidationParameters.DOMAIN)
            dp = validParams.d;
        else
            dp = validParams.p;
        double percent = (double)dp/100;
        System.out.println("\tLeave "+dp+"% "+ValidationParameters.CV_STRING[type]+" out cross validation...");
        Set numKeys = numToDataListMap.keySet();
        List numKeyList = new ArrayList(numKeys);
        Collections.sort(numKeyList);
        HashMap domainNumToRawMap = trainData.getDomainNumToRawMap();
        HashMap peptideNumToRawMap = trainData.getPeptideNumToRawMap();

        HashMap domainNumToEncMap = trainData.getDomainNumToEncMap();
        HashMap peptideNumToEncMap = trainData.getPeptideNumToEncMap();

        int num;
        if (type==ValidationParameters.DOMAIN)
            num = domainNumToRawMap.size();
        else
            num = peptideNumToRawMap.size();

        int numInFold = (int)((double)num*(percent));

        System.out.println("\tNum in fold: " + numInFold );
        HashMap kFoldResultsMap = new HashMap();

        int numFolds;
        if (dp == 0) numFolds = num;
        else numFolds = k;
        System.out.println("\tNum Folds: " + numFolds);
        for (int i =0; i < numFolds;i++)
        {
            String testName = "";
            List testNameList= new ArrayList();
            List testNumKeyList = new ArrayList();
            List testFoldDataList = new ArrayList();
            if (dp==0)
            {
                String testFoldNum = (String)numKeyList.get(i);
                testNumKeyList.add(testFoldNum);
                List numToDataList = (List) numToDataListMap.get(testFoldNum);
                testFoldDataList.addAll(numToDataList);

                Datum testDt = (Datum)testFoldDataList.get(0);

                Features ft = (Features)peptideNumToRawMap.get(testDt.peptideNum);
                String peptide = ft.toUndelimitedString();
                if (type==ValidationParameters.DOMAIN)
                    testName  = "["+testDt.name +"]";
                else
                    testName = "["+peptide+"]";
            }
            else
            {
                // Randomly pick numInFold from numKeyList
                Collections.shuffle(numKeyList);
                testNumKeyList = numKeyList.subList(0,numInFold);

                for (int j = 0; j < testNumKeyList.size();j++)
                {
                    String testFoldNum = (String)testNumKeyList.get(j);
                    List numToDataList = (List) numToDataListMap.get(testFoldNum);
                    testFoldDataList.addAll(numToDataList);
                }

                for (int ii = 0; ii < testFoldDataList.size();ii++)
                {
                    Datum testDt = (Datum)testFoldDataList.get(ii);
                    Features ft = (Features)peptideNumToRawMap.get(testDt.peptideNum);
                    String peptide = ft.toUndelimitedString();
                    String testNameTemp;
                    if (type==ValidationParameters.DOMAIN)
                        testNameTemp = testDt.name;
                    else
                        testNameTemp = peptide;

                    if (!testNameList.contains(testNameTemp))
                        testNameList.add(testNameTemp);
                }
                testName = testNameList.toString();
            }

            Data testFoldData = new Data();
            testFoldData.addRawData(testFoldDataList, domainNumToRawMap, peptideNumToRawMap, domainNumToEncMap, peptideNumToEncMap);

            if (testFoldData.getNumPositive() == 0 || testFoldData.getNumNegative()==0)
            {
                System.out.println("\n\t===== Fold " + i +" " +testName+ " =====");
                System.out.println("\tNo positive or negative data, skipping...");
                continue;
            }
            List trainFoldDataList = new ArrayList();
            for (int j=0; j < numKeyList.size();j++)
            {
                String trainFoldNum = (String)numKeyList.get(j);
                if (!testNumKeyList.contains(trainFoldNum))
                {
                    List numDataList = (List) numToDataListMap.get(trainFoldNum);
                    trainFoldDataList.addAll(numDataList);
                }
            }
            System.out.println("\n\t===== Fold " + i +" "+testName+" =====");
            Data trainFoldData = new Data();
            trainFoldData.addRawData(trainFoldDataList, domainNumToRawMap, peptideNumToRawMap,domainNumToEncMap, peptideNumToEncMap);

            svm_model svmModel = SVM.train(trainFoldData, svmParams);
            List predictionList = SVM.predict(trainFoldData, testFoldData, svmModel, svmParams);
            if (predictionList.size()==0)
                continue;
            Evaluation eval = new Evaluation(predictionList);
            System.out.println(eval.toString());

            kFoldResultsMap.put(i,predictionList);

        }

        return kFoldResultsMap;
    }

    public static HashMap leaveSimOutCrossValidation(Data trainData, svm_parameter svmParams, ValidationParameters validParams,HashMap domainNumToDataListMap)
    {
        System.out.println("\n\tLeave out cross validation ("+validParams.sim+")...");
        Sidhu10FeatureEncoding enc = new Sidhu10FeatureEncoding();
        System.out.println("\tLeave "+validParams.sim+" domains out cross validation...");

        HashMap domainNumToRawMap = trainData.getDomainNumToRawMap();
        HashMap peptideNumToRawMap = trainData.getPeptideNumToRawMap();
        HashMap domainNumToEncMap = trainData.getDomainNumToEncMap();
        HashMap peptideNumToEncMap = trainData.getPeptideNumToEncMap();

        HashMap kFoldResultsMap = new HashMap();

        Set key = domainNumToDataListMap.keySet();
        List domainNumList = new ArrayList(key);
        Collections.sort(domainNumList);
        HashMap kFoldInfoMap = new HashMap();
        for (int i=0; i < domainNumList.size();i++)
        {
            String domainKeyi = (String) domainNumList.get(i);
            List numToDataListi = (List) domainNumToDataListMap.get(domainKeyi);
            Datum dti = (Datum)numToDataListi.get(0);
            String domainName = dti.name;
            String organismi = dti.organism;
            int domainNumi = dti.domainNum;

            List trainFoldDataList = new ArrayList();

            System.out.println("\n\t===== Fold " + i + " " +domainName+ " " + organismi + " =====");
            String bindingSiteSeq = enc.getFeatures(domainName, organismi);
            double nnSim = 0.0;
            for (int j=0; j < domainNumList.size();j++)
            {
                String domainKeyj = (String) domainNumList.get(j);
                List numToDataListj = (List) domainNumToDataListMap.get(domainKeyj);
                Datum dtj = (Datum)numToDataListj.get(0);
                String otherDomainName = dtj.name;
                String organismj = dtj.organism;

                if (i == j)
                    continue;
                String otherBindingSiteSeq = enc.getFeatures(otherDomainName, organismj);
                double sim = PDZSVMUtils.identity(bindingSiteSeq, otherBindingSiteSeq);

                if (sim >= nnSim)
                {
                    nnSim = sim;
                }
                if (sim <= validParams.sim)
                {
                    //System.out.println(sim +"," + validParams.sim+ ": " + otherDomainName);
                    List numToTrainDataList = (List) domainNumToDataListMap.get(domainKeyj);
                    trainFoldDataList.addAll(numToTrainDataList);
                }

            }

            List testFoldDataList = new ArrayList();

            List numToTestDataList = (List) domainNumToDataListMap.get(domainKeyi);
            testFoldDataList.addAll(numToTestDataList);
            Data testFoldData = new Data();
            testFoldData.addRawData(testFoldDataList, domainNumToRawMap, peptideNumToRawMap, domainNumToEncMap, peptideNumToEncMap);

            if (testFoldData.getNumPositive() ==0 || testFoldData.getNumNegative()==0)
            {
                //System.out.println("\n\t===== Fold " + (i+1) +" " +domainName+ " =====");
                System.out.println("\tNo positive or negative data, skipping...");
                continue;
            }

            Data trainFoldData = new Data();
            trainFoldData.addRawData(trainFoldDataList, domainNumToRawMap, peptideNumToRawMap, domainNumToEncMap, peptideNumToEncMap);

            svm_model svmModel;
            List predictionList;
            svmModel= SVM.train(trainFoldData, svmParams);
            predictionList = SVM.predict(trainFoldData, testFoldData, svmModel, svmParams);

            Evaluation eval = new Evaluation(predictionList);
            System.out.println("\n\t===== Fold " + i +" " +domainName+ " " +organismi +" =====");
            System.out.println(eval.toString());

            kFoldResultsMap.put(i,predictionList);

            List foldInfoList = new ArrayList();

            foldInfoList.add(domainName);
            foldInfoList.add(organismi);
            foldInfoList.add(nnSim);
            foldInfoList.add(eval.getROCAUC());
            foldInfoList.add(eval.getPRAUC());
            foldInfoList.add(eval.getMCC());
            foldInfoList.add(domainNumi);
            foldInfoList.add(testFoldData.getNumPositive());
            foldInfoList.add(testFoldData.getNumNegative());
            kFoldInfoMap.put(i,foldInfoList);
        }
        Set keys = kFoldResultsMap.keySet();
        List keyNumList = new ArrayList(keys);
        Collections.sort(keyNumList);
        for (int i=0 ; i < keyNumList.size();i++)
        {
            int keyNum = (Integer)keyNumList.get(i);
            List foldInfoList = (List) kFoldInfoMap.get(keyNum);
            if (foldInfoList == null)
                continue;
            String domainName = (String)foldInfoList.get(0);
            String organism = (String)foldInfoList.get(1);
            Double nnSim = (Double)foldInfoList.get(2);
            Double rAUC = (Double)foldInfoList.get(3);
            Double prAUC = (Double)foldInfoList.get(4);
            Double mcc = (Double)foldInfoList.get(5);
            Integer numPos = (Integer)foldInfoList.get(7);
            Integer numNeg = (Integer)foldInfoList.get(8);

            System.out.println(domainName+"\t" +organism+"\t"+ nnSim + "\t"+numPos+"\t"+numNeg+"\t"+ rAUC + "\t" + prAUC + "\t" + mcc);
        }

        return kFoldResultsMap;

    }

    public static HashMap leaveOutCrossValidation(Data trainData, svm_parameter svmParams, ValidationParameters validParams)
    {
        List dataList = trainData.getDataList();

        System.out.println("\n\tLeave out cross validation...");
        HashMap domainNumToDataListMap = new HashMap();
        HashMap peptideNumToDataListMap = new HashMap();

        for (int i=0; i < dataList.size();i++)
        {
            Datum dt = (Datum)dataList.get(i);
            int peptideNum = dt.peptideNum;
            String organism = dt.organism;
            String domainName = dt.name;
            List domainDataList = (List)domainNumToDataListMap.get(domainName+"-"+organism);
            if (domainDataList ==  null)
                domainDataList = new ArrayList();
            domainDataList.add(dt);
            domainNumToDataListMap.put(domainName+"-"+organism,domainDataList);

            List peptideDataList = (List)peptideNumToDataListMap.get(peptideNum+"-"+organism);
            if (peptideDataList ==  null)
                peptideDataList = new ArrayList();
            peptideDataList.add(dt);
            peptideNumToDataListMap.put(peptideNum+"-"+organism,peptideDataList);
        }
        HashMap kFoldResultsMap;
        if (validParams.type== ValidationParameters.DOMAIN)
        {
            kFoldResultsMap = leaveKOutCrossValidation(validParams,  trainData, svmParams, domainNumToDataListMap);
        }
        else if (validParams.type == ValidationParameters.PEPTIDE)
        {
            kFoldResultsMap = leaveKOutCrossValidation(validParams, trainData, svmParams, peptideNumToDataListMap);
        }
        else if (validParams.type == ValidationParameters.DOMAIN_PEPTIDE)
        {
            kFoldResultsMap = leaveKMOutCrossValidation(validParams, trainData, svmParams, domainNumToDataListMap, peptideNumToDataListMap);
        }
        else if (validParams.type == ValidationParameters.LOSIM_DOMAIN)
        {
            kFoldResultsMap = leaveSimOutCrossValidation(trainData, svmParams, validParams, domainNumToDataListMap);
        }
        else
            kFoldResultsMap = new HashMap();
        return kFoldResultsMap;
    }

    public static svm_model train(Data trainData, svm_parameter svmParams)
    {
        System.out.println("\tTraining...\n");
        System.out.println("\t=== TRAINING DATA ===");
        trainData.printSummary();
        System.out.println();

        svm_problem trainProb = null;

        if(svmParams.data_encoding==svm_parameter.STRUCT)
        {
            System.out.println("\tLoading STRUCT training data");
            trainProb = loadData(trainData, trainData.getDataList());
        }

        svm_model svmModel = null;
        try
        {
            svmParams.print();
            System.out.println("\tTraining model...");
            svmModel = svm.svm_train( trainProb,  svmParams);
        }
        catch(Exception e)
        {
            e.printStackTrace();
        }

        System.out.println("\tFinished training...\n");
        return svmModel;
    }


    public static List predict(Data trainData, Data testData, svm_model svmModel, svm_parameter svmParams)
    {
        System.out.println("\tPredicting...\n");
        System.out.println("\t=== TESTING DATA ===");
        testData.printSummary();
        System.out.println();
        svm_problem testProb = null;
        if(svmParams.data_encoding==svm_parameter.STRUCT)
        {
            System.out.println("\tLoading STRUCT test data");
            testProb = loadData(testData, testData.getDataList());
        }

        // Iterate over all testProb instances
        List predictions = new ArrayList();
        for (int i = 0; i < testProb.l;i++)
        {
            svm_node[] node = testProb.x[i];

            double[] ret = svm.svm_predict_x(svmModel,node);
            double decValue = ret[0];
            double predValue = PDZSVMUtils.zeroToOne(ret[1]);
            int actualValue = (int)PDZSVMUtils.zeroToOne(testProb.y[i]);
            String name = testProb.names[i];
            String domainSeq = testProb.domainSeq[i];
            String peptideSeq = testProb.peptideSeq[i];
            String organism = testProb.organismShort[i];
            String method = testProb.methodShort[i];
            Prediction prediction = new Prediction(predValue, actualValue, decValue, name, domainSeq, peptideSeq, organism, method);
            predictions.add(prediction);
        } // for


        System.out.println("\tFinished predicting...");

        // Don't compute AUC scores if we don't have any negatives or positives
        if (testData.getNumPositive() >0 && testData.getNumNegative() > 0)
        {
            Evaluation eval = new Evaluation(predictions);
            double rocAUC = eval.getROCAUC();
            // Flip the decvalue signs
            if (rocAUC<0.5)
            {
                System.out.println("\tFlipping signs...");

                for (int i = 0; i < predictions.size();i++)
                {
                    Prediction prediction = (Prediction)predictions.get(i);
                    double decValue = prediction.getDecValue();
                    prediction.setDecValue(-decValue);
                }
            }
            System.out.println();
        }
        return predictions;

    }

}
