package org.baderlab.pdzsvm.optimize;

import libsvm.svm_parameter;
import org.baderlab.pdzsvm.predictor.svm.ContactMapSVMPredictor;
import org.baderlab.pdzsvm.predictor.Predictor;

import java.util.*;

import org.baderlab.pdzsvm.data.DataLoader;
import org.baderlab.pdzsvm.validation.ValidationParameters;
import org.baderlab.pdzsvm.utils.Constants;

/**
 * Copyright (c) 2010 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVM.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVM.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * Grid search program for the contact map svm predictor only
 * TODO: make into one program for regular global svm predictor
 */
public class OptimizeContactMapPredictor
{
    public OptimizeContactMapPredictor()
    {

    }

    /**
     * Loads the requested training data, builds a ContactMapSVMPredictor on it
     * and grid-searches the SVM parameters, printing every grid result.
     *
     * The grid is fixed inside this method (lnG = {3,4,5}, lnC = {2,3,4,5,6});
     * the values are converted by GridSearchUtils.getG/getC before being used
     * (presumably log-scale exponents -- confirm against GridSearchUtils).
     *
     * @param data          which data set(s) to load: Constants.CHEN_MOUSE,
     *                      Constants.SIDHU_HUMAN, or the concatenation of both
     * @param artNegMethod  artificial negative method passed to the human data
     *                      loaders (ignored for the Sidhu-only case, which
     *                      always uses Constants.PWM)
     * @param type          Constants.NONE or a data type forwarded to
     *                      DataLoader.loadHumanTrain (only used for the
     *                      combined data set)
     * @param numPeptideSim forwarded unchanged to the human data loaders
     * @return the first entry of the sorted grid result list, i.e. the result
     *         with the highest rocAUC
     */
    public GridResult optimize(String data, String artNegMethod, String type,  int numPeptideSim)
    {
        System.out.println("\t" + data +"," +artNegMethod + "," + type + "," +numPeptideSim);

        DataLoader dl = new DataLoader();
        if (data.equals(Constants.CHEN_MOUSE))
        {
            System.out.println("\tLoading only Chen Mouse data...");
            dl.loadMouseChenTrain();
        }
        else if (data.equals(Constants.SIDHU_HUMAN))
        {
            System.out.println("\tLoading only Sidhu Human data...");

            // Sidhu-only training always uses PWM artificial negatives.
            dl.loadSidhuHumanTrain(Constants.PWM,  numPeptideSim);
        }
        else if (data.equals(Constants.CHEN_MOUSE+Constants.SIDHU_HUMAN))
        {
            if (type.equals(Constants.NONE))
            {
                System.out.println("\tLoading only Chen Mouse + Sidhu Human data...");
                dl.loadMouseChenTrain();
                dl.loadSidhuHumanTrain(artNegMethod,  numPeptideSim);
            }
            else
            {
                System.out.println("\tLoading only Chen Mouse + Sidhu Human "+type+" data...");
                dl.loadMouseChenTrain();
                dl.loadHumanTrain(artNegMethod, type, numPeptideSim);
            }
        }
        else
        {
            // None of the loaders above ran; make that visible instead of
            // silently continuing with whatever the DataLoader holds.
            System.out.println("\tWarning: unrecognized data set '" + data + "', no training data was loaded by this method.");
        }

        List posTrainProfileList = dl.getPosTrainProfileList();
        List negTrainProfileList = dl.getNegTrainProfileList();

        System.out.println("\tTRAIN DOMAINS");
        System.out.println("\tNum Pos Domains: " + posTrainProfileList.size());
        System.out.println("\tNum Neg Domains: " + negTrainProfileList.size());

        System.out.println("\tTRAIN INTERACTIONS");
        System.out.println("\tNum Pos Interactions:" + dl.getNumPosTrainInteractions());
        System.out.println("\tNum Neg Interactions:" + dl.getNumNegTrainInteractions());

        svm_parameter svmParams = new svm_parameter();
        svmParams.setDefaults();
        svmParams.data_encoding = svm_parameter.CONTACTMAP2020;

        // Fixed search grid. A coarser alternative used previously was
        // {2,4,6,8,10} for both lnC and lnG.
        double[] lnG = new double[]{3,4,5};
        double[] lnC = new double[]{2,3,4,5,6};

        ContactMapSVMPredictor cp =new ContactMapSVMPredictor(posTrainProfileList,
                negTrainProfileList,
                svmParams);

        // The kernel type comes from svmParams (as left by setDefaults()).
        List gridResultList;
        if (svmParams.kernel_type == svm_parameter.LINEAR)
            gridResultList = optimizeLINEAR(cp, lnC);
        else
            gridResultList = optimizeRBF(cp,lnG, lnC);

        for (int i=0; i < gridResultList.size();i++)
        {
            GridResult result = (GridResult)gridResultList.get(i);
            if (i==0)
                System.out.println(result.headerString());
            System.out.println(result.toString());

        }

        // The list is sorted by descending rocAUC, so index 0 is the best.
        GridResult topResult = (GridResult) gridResultList.get(0);
        System.out.println("\tOptimal grid result: "+ topResult.g + ", " + topResult.C);
        return topResult;
    }

    /**
     * Command line entry point.
     *
     * Expected arguments, in order:
     *   args[0] data set (chen mouse, sidhu human, chen mouse + sidhu human)
     *   args[1] artificial negative method (none, pwm, random, shuffled, random sel)
     *   args[2] type (genomic or non genomic)
     *   args[3] numPeptideSim, parsed with Integer.parseInt
     * Arguments beyond the fourth are ignored.
     */
    public static void main(String[] args)
    {
        OptimizeContactMapPredictor opt = new OptimizeContactMapPredictor();

        // The message promises "at least four", so only reject too few.
        if (args.length < 4 )
        {
            System.out.println("\tNeed at least four parameters, exiting...");   
            return;
        }
        String data = args[0]; // chen mouse, sidhu human, chen mouse + sidhu human

        // chen mouse and sidhu human artnegmethod will be NONE and PWM by default

        String artNegMethod = args[1]; // none, pwm, random, shuffled, random sel
        String type = args[2];  // genomic or non genomic
        int numPeptideSim = Integer.parseInt(args[3]);

        System.out.println("\tTraining method, type, numPeptideSim: " +artNegMethod+", "+type + ", " + numPeptideSim);


        opt.optimize(data, artNegMethod, type, numPeptideSim);

    }

    /**
     * Grid search over C only, evaluating each value with
     * leaveOutCrossValidation configured as k = 10, d = 12,
     * type = ValidationParameters.DOMAIN.
     *
     * @param predictor must be a ContactMapSVMPredictor (it is cast); its
     *                  svm_parameter object is modified in place
     * @param lnC       C grid, converted with GridSearchUtils.getC
     * @return list of GridResult sorted by descending rocAUC
     */
    public List optimizeLINEAR(Predictor predictor, double[] lnC)
    {
        ContactMapSVMPredictor cmSVMPredictor = (ContactMapSVMPredictor)predictor;

        svm_parameter svmParams = cmSVMPredictor.getSVMParams();

        double[] C = GridSearchUtils.getC(lnC);
        List gridResultList= new ArrayList();
        ValidationParameters validParams = new ValidationParameters();
        validParams.k = 10;
        validParams.d = 12;
        validParams.type = ValidationParameters.DOMAIN;

        for (int j = 0; j < C.length;j++)
        {
            svmParams.C = C[j];
            svmParams.print();

            // kFoldCrossValidation(validParams) is the alternative evaluation.
            HashMap cvResultsMap = cmSVMPredictor.leaveOutCrossValidation(validParams);

            GridResult result = GridSearchUtils.computeAvgGridResult(cvResultsMap);

            // Record C before printing so the printed row matches the one
            // stored in the list (same order as optimizeRBF).
            result.C = GridSearchUtils.toLnC(C[j]);

            System.out.println(result.toString());

            gridResultList.add(result);

        }
        Collections.sort(gridResultList, new MyComparator());

        return gridResultList;
    }

    /**
     * Grid search over every (gamma, C) pair, evaluating each pair with
     * 10-fold cross validation (validParams.type = K_FOLD, k = 10).
     *
     * @param predictor must be a ContactMapSVMPredictor (it is cast); its
     *                  svm_parameter object is modified in place
     * @param lnG       gamma grid, converted with GridSearchUtils.getG
     * @param lnC       C grid, converted with GridSearchUtils.getC
     * @return list of GridResult sorted by descending rocAUC
     */
    public List optimizeRBF(Predictor predictor, double[] lnG,double[] lnC)
    {
        System.out.println("\t=== Optimizing RBF Kernel ===\n");
        ContactMapSVMPredictor cmSVPredictor = (ContactMapSVMPredictor)predictor;

        svm_parameter svmParams = cmSVPredictor.getSVMParams();

        double[] C = GridSearchUtils.getC(lnC);
        double[] g = GridSearchUtils.getG(lnG);
        ValidationParameters validParams = new ValidationParameters();
        validParams.k = 10;
        validParams.type = ValidationParameters.K_FOLD;
        // Alternative set-up used previously: type = DOMAIN with
        // numTimes = 1, d = 12, p = 8.

        List gridResultList= new ArrayList();
        for (int i=0; i < g.length;i++)
        {
            for (int j = 0; j < C.length;j++)
            {
                System.out.println("\t==== Grid [" + lnG[i] + "," + lnC[j] + "] ===\n");

                svmParams.C = C[j];
                svmParams.gamma = g[i];
                svmParams.print();
                HashMap cvResultsMap;
                if (validParams.type == ValidationParameters.K_FOLD)
                    cvResultsMap= cmSVPredictor.kFoldCrossValidation(validParams);
                else
                    cvResultsMap = cmSVPredictor.leaveOutCrossValidation(validParams);

                GridResult result = GridSearchUtils.computeAvgGridResult(cvResultsMap);

                result.C = GridSearchUtils.toLnC(C[j]);
                result.g = GridSearchUtils.toLnG(g[i]);

                System.out.println(result.toString());

                gridResultList.add(result);

            }
        }

        Collections.sort(gridResultList, new MyComparator());

        return gridResultList;
    }

    /**
     * Orders GridResult objects by descending rocAUC, so that after sorting
     * the best result is at index 0. Results whose rocAUC is NaN are placed
     * after all others; this keeps the ordering consistent, which
     * Collections.sort requires.
     */
    public class MyComparator implements Comparator
    {
        /**
         * @param anotherResult2 first GridResult
         * @param anotherResult  second GridResult
         * @return negative if the first argument has the larger rocAUC,
         *         positive if the second has, 0 if they are equal
         * @throws ClassCastException if either argument is not a GridResult
         */
        public int compare(Object anotherResult2, Object anotherResult) throws ClassCastException
        {
            if (!(anotherResult instanceof GridResult) || !(anotherResult2 instanceof GridResult))
                throw new ClassCastException("A GridResult object expected.");
            double auc = ((GridResult) anotherResult).rocAUC;
            double auc2 = ((GridResult) anotherResult2).rocAUC;

            // NaN compares false against everything; rank it explicitly as
            // the worst score instead of "equal to all".
            boolean aucIsNaN = Double.isNaN(auc);
            boolean auc2IsNaN = Double.isNaN(auc2);
            if (aucIsNaN || auc2IsNaN)
            {
                if (aucIsNaN && auc2IsNaN) return 0;
                // auc2 belongs to the first argument: NaN there sorts it later.
                return auc2IsNaN ? 1 : -1;
            }

            if (auc > auc2) return 1;
            else if (auc < auc2) return -1;
            else return 0;
        }
    }
}
