package org.baderlab.pdzsvm.predictor.pwm;

import org.baderlab.pdzsvm.data.*;

import java.util.*;
import java.util.List;
import org.baderlab.brain.ProteinProfile;
import org.biojava.bio.seq.Sequence;
import org.baderlab.pdzsvm.evaluation.Prediction;
import org.baderlab.pdzsvm.evaluation.Evaluation;
import org.baderlab.pdzsvm.predictor.Predictor;
import org.baderlab.pdzsvm.predictor.nn.NN;
import weka.core.Instances;
import weka.core.Utils;
import org.baderlab.pdzsvm.validation.ValidationParameters;
import org.baderlab.pdzsvm.utils.PDZSVMUtils;
import org.baderlab.pdzsvm.utils.Constants;

/**
 * Copyright (c) 2010 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVM.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVM.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * Conventional PWM predictor which, by default, returns the top 10% of interactions
 * ranked by PWM score. In the paper we return the top 1%, which is configured by
 * calling the setPercent method.
 */
public class PWMPredictor extends Predictor {

    /**
     * Fraction (not percentage) of the top ranked sequences of each test
     * profile that is predicted positive. Default is 0.1, i.e. 10 percent.
     */
    private double percent = 0.1;

    /**
     * Creates a PWM predictor from positive training profiles only; an empty
     * list is handed to the superclass as the negative training set.
     *
     * @param posTrainProfileList positive training profiles
     */
    public PWMPredictor(List posTrainProfileList)
    {
        super(posTrainProfileList, new ArrayList());
        System.out.println("\tReturn top 10%");
        predictorName = "PWM";
    }

    /**
     * Sets how many of the top ranked sequences are predicted positive.
     *
     * @param percent value in percent units (e.g. 1 for 1%); it is stored
     *                internally as a fraction (percent / 100)
     */
    public void setPercent(double percent)
    {
        System.out.println("\tReturn top " +percent+"%");

        this.percent = percent/100.00;
    }

    /**
     * Example driver: loads the mouse (Chen) and human (Sidhu) training data,
     * predicts on the fly test set, prints an evaluation summary and plots
     * the resulting curves.
     *
     * @param args ignored
     */
    public static void main(String[] args)
    {
        DataLoader dl = new DataLoader();
        dl.loadMouseChenTrain();
        dl.loadSidhuHumanTrain();

        List posTrainProfileList = dl.getPosTrainProfileList();

        System.out.println("\tNum pos interactions: " + dl.getNumPosTrainInteractions());

        String testName;
        // Alternative test sets (swap with the fly lines below):
        //dl.loadWormTest(Constants.PROTEIN_MICROARRAY); testName = "PM WORM G";
        //dl.loadMouseTest(Constants.MOUSE_ORPHAN); testName = "MOUSE ORPHAN PM";
        dl.loadFlyTest();
        testName = "FLY PM";

        List posTestProfileList = dl.getPosTestProfileList();
        List negTestProfileList = dl.getNegTestProfileList();

        PWMPredictor pwm = new PWMPredictor(posTrainProfileList);

        String predictorName = pwm.getPredictorName();
        List predictionList = pwm.predict(posTestProfileList, negTestProfileList);

        Evaluation eval = new Evaluation(predictionList);
        List rocAUCList = new ArrayList();
        rocAUCList.add(eval.getROCAUC());
        List prAUCList = new ArrayList();
        prAUCList.add(eval.getPRAUC());
        List aucLabelList = new ArrayList();
        aucLabelList.add(predictorName);

        List instList = new ArrayList();
        instList.add(eval.getCurve(1));

        System.out.println("\t=== Summary " +testName+ " ("+predictorName+") ===");
        System.out.println(eval.toString());

        plotCurves(instList,rocAUCList, prAUCList, aucLabelList, predictorName + " (" +testName+")");
    }

    /**
     * Scores every sequence with the given PWM and returns the sequences
     * ordered from highest to lowest score.
     *
     * @param seqList list of {@link Sequence} objects to score
     * @param pwm     position weight matrix used for scoring
     * @return a two element array: index 0 holds the sequence strings and
     *         index 1 the matching scores (as Double), both sorted by
     *         descending score
     */
    public List[] getSortedLists(List seqList, PWM pwm)
    {
        int numSeqs = seqList.size();
        String[] seqStrings = new String[numSeqs];
        double[] pwmScores = new double[numSeqs];
        for (int j = 0; j < numSeqs; j++)
        {
            Sequence seq = (Sequence) seqList.get(j);
            // Keep the string so it does not have to be rebuilt when sorting
            seqStrings[j] = seq.seqString();
            pwmScores[j] = pwm.score(seqStrings[j]);
        }

        // Utils.sort yields the indices in ascending score order; walking it
        // backwards gives the descending order we need.
        int[] lowToHigh = Utils.sort(pwmScores);

        List sortedSequenceList = new ArrayList(numSeqs);
        List sortedPWMScoreList = new ArrayList(numSeqs);
        for (int ii = lowToHigh.length - 1; ii >= 0; ii--)
        {
            int ixi = lowToHigh[ii];
            sortedSequenceList.add(seqStrings[ixi]);
            sortedPWMScoreList.add(pwmScores[ixi]);
        }

        List[] sortedLists = new ArrayList[2];
        sortedLists[0] = sortedSequenceList;
        sortedLists[1] = sortedPWMScoreList;
        return sortedLists;
    }

    /**
     * No training step: the PWM is built per test profile inside
     * {@link #predict(List, List)}.
     */
    public void train()
    {
        // Nothing to do here
    }

    /**
     * Predicts interactions for the given test profiles. Profiles from the
     * positive and negative lists are paired by name; for each pair a PWM is
     * built from the profile selected by NN.getNNBindingSiteSeqProfile and
     * the top fraction of sequences by PWM score is predicted positive. When
     * negatives are supplied and the resulting ROC AUC is below 0.5, the sign
     * of every decision value is flipped.
     *
     * @param posTestProfileList positive test profiles
     * @param negTestProfileList negative test profiles; may be null or empty
     * @return the list of {@link Prediction} objects (also stored in
     *         predictionList)
     */
    public List predict(List posTestProfileList, List negTestProfileList)
    {
        // Clear prediction list!
        predictionList = new ArrayList();
        boolean hasNegatives = negTestProfileList != null && !negTestProfileList.isEmpty();

        // NOTE(review): testData is never read afterwards; kept only in case
        // Data.addRawData has side effects - confirm and remove if not.
        Data testData = new Data();
        testData.addRawData(posTestProfileList, Constants.CLASS_YES);
        HashMap testPosProfileHashMap = PDZSVMUtils.profileListToHashMap(posTestProfileList);

        HashMap testNegProfileHashMap = new HashMap();
        if (hasNegatives)
        {
            testData.addRawData(negTestProfileList, Constants.CLASS_NO);
            testNegProfileHashMap = PDZSVMUtils.profileListToHashMap(negTestProfileList);
        }

        // Build two parallel lists in which index i holds the positive and
        // negative profile of the same name (null where one side is missing).
        List balPosProfileList = new ArrayList();
        List balNegProfileList = new ArrayList();
        if (hasNegatives)
        {
            // Guarded: the negative list may legitimately be null
            for (int i = 0; i < negTestProfileList.size(); i++)
            {
                ProteinProfile negProfile = (ProteinProfile) negTestProfileList.get(i);
                ProteinProfile posProfile = (ProteinProfile) testPosProfileHashMap.get(negProfile.getName());
                balNegProfileList.add(negProfile);
                balPosProfileList.add(posProfile); // null when no positive counterpart
            }
        }
        for (int i = 0; i < posTestProfileList.size(); i++)
        {
            ProteinProfile posProfile = (ProteinProfile) posTestProfileList.get(i);
            // Positives with a negative counterpart were already paired above
            if (testNegProfileHashMap.get(posProfile.getName()) == null)
            {
                balNegProfileList.add(null);
                balPosProfileList.add(posProfile);
            }
        }

        for (int i = 0; i < balPosProfileList.size(); i++)
        {
            ProteinProfile testPosProfile = (ProteinProfile) balPosProfileList.get(i);
            ProteinProfile testNegProfile = (ProteinProfile) balNegProfileList.get(i);
            // Domain information is taken from whichever side is present
            ProteinProfile profile = (testPosProfile != null) ? testPosProfile : testNegProfile;

            String name = profile.getName();
            String organism = PDZSVMUtils.organismLongToShortForm(profile.getOrganism());
            String method = PDZSVMUtils.methodLongToShortForm(profile.getExperimentalMethod());
            String domainSeqFull = profile.getDomainSequence();

            ProteinProfile predictorProfile = NN.getNNBindingSiteSeqProfile(domainSeqFull, organism, posTrainProfileList);
            PWM pwm = new PWM(predictorProfile);

            if (testPosProfile != null)
            {
                addPredictions(testPosProfile, pwm, 1, name, domainSeqFull, organism, method);
            }
            if (testNegProfile != null)
            {
                addPredictions(testNegProfile, pwm, 0, name, domainSeqFull, organism, method);
            }
        }

        if (hasNegatives)
        {
            Evaluation eval = new Evaluation(predictionList);
            double rocAUC = eval.getROCAUC();
            if (rocAUC<0.5)
            {
                System.out.println("\tFlipping signs...");

                for (int ii = 0; ii < predictionList.size(); ii++)
                {
                    Prediction prediction = (Prediction) predictionList.get(ii);
                    prediction.setDecValue(-prediction.getDecValue());
                }
            }
        }
        System.out.println(predictionList.size());
        return predictionList;
    }

    /**
     * Ranks the sequences of one test profile by PWM score and appends one
     * Prediction per sequence to predictionList: the top fraction is
     * predicted 1.0, the remainder 0.0. Decision values are the PWM score
     * minus the score of the first sequence below the cut (0 if none).
     *
     * @param testProfile   profile whose sequences are scored
     * @param pwm           PWM used for scoring
     * @param actualClass   true class recorded on each prediction
     *                      (1 = positive profile, 0 = negative profile)
     * @param name          profile name stored on each prediction
     * @param domainSeqFull domain sequence stored on each prediction
     * @param organism      short organism form stored on each prediction
     * @param method        short method form stored on each prediction
     */
    private void addPredictions(ProteinProfile testProfile, PWM pwm, int actualClass,
                                String name, String domainSeqFull, String organism, String method)
    {
        List testSeqList = new ArrayList(testProfile.getSequenceMap());

        List[] sortedLists = getSortedLists(testSeqList, pwm);
        List sortedSeqList = sortedLists[0];
        List sortedPWMScoreList = sortedLists[1];

        int numSeqs = sortedSeqList.size();
        int numTopPercent = (int) Math.rint(numSeqs * percent);
        // At least one positive, but never more than there are sequences
        // (an empty profile therefore yields no predictions instead of failing).
        numTopPercent = Math.min(numSeqs, Math.max(1, numTopPercent));

        // Cutoff is the best score among the sequences predicted negative
        double cutoff = 0;
        if (numTopPercent < numSeqs)
        {
            cutoff = (Double) sortedPWMScoreList.get(numTopPercent);
        }

        for (int ii = 0; ii < numSeqs; ii++)
        {
            String seqString = (String) sortedSeqList.get(ii);
            double pwmScore = (Double) sortedPWMScoreList.get(ii);
            double score = pwmScore - cutoff;
            double predicted = (ii < numTopPercent) ? 1.0 : 0.0;
            Prediction prediction = new Prediction(predicted, actualClass, score, name, domainSeqFull, seqString, organism, method);

            predictionList.add(prediction);
        }
    }

    /**
     * Not implemented for the PWM predictor.
     *
     * @param validParams ignored
     * @return an empty map
     */
    public HashMap kFoldCrossValidation(ValidationParameters validParams)
    {
        return new HashMap();
    }

    /**
     * Not implemented for the PWM predictor.
     *
     * @param validParams ignored
     * @return an empty map
     */
    public HashMap leaveOutCrossValidation(ValidationParameters validParams)
    {
        return new HashMap();
    }

}