package org.baderlab.pdzsvm.predictor.pwm;

import org.baderlab.pdzsvm.predictor.Predictor;
import org.baderlab.pdzsvm.predictor.nn.NN;
import org.baderlab.pdzsvm.data.DataLoader;
import org.baderlab.pdzsvm.data.Data;
import org.baderlab.pdzsvm.utils.PDZSVMUtils;
import org.baderlab.pdzsvm.utils.Constants;
import org.baderlab.pdzsvm.evaluation.Evaluation;
import org.baderlab.pdzsvm.evaluation.Prediction;
import org.baderlab.pdzsvm.validation.ValidationParameters;
import org.baderlab.brain.ProteinProfile;
import org.biojava.bio.seq.Sequence;

import java.util.*;
import weka.core.Instances;

/**
 * Copyright (c) 2010 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVM.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVM.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * Optimized PWM predictor that learns an optimized cut off score using positive
 * and negative training data.
 * NOTE: NOT USED IN THE PAPER
 *
 * <p>{@link #train()} selects, for every positive training profile, one cutoff from a
 * fixed grid and stores it by profile name. {@link #predict(List, List)} scores each
 * test sequence with the PWM of the training profile returned by
 * {@code NN.getNNBindingSiteSeqProfile} and subtracts the cutoff stored for that
 * training profile, so that a positive result means "predicted binder".
 */
public class PWMOptPredictor extends Predictor {

    /** Candidate cutoffs tried during training, in the order they are evaluated. */
    private static final double[] PWM_SCORE_CUTOFFS =
            new double[]{0.9,0.8,0.7,0.6,0.5,0.4,0.3,0.2,0.1};

    /** Maps a training profile name (String) to the {@link Parameters} learned for it. */
    private final HashMap predictorParamsMap = new HashMap();

    /**
     * Creates a predictor over the given training profiles and names it "PWM".
     *
     * @param posTrainProfileList positive training profiles, handed to the superclass
     * @param negTrainProfileList negative training profiles, handed to the superclass
     */
    public PWMOptPredictor(List posTrainProfileList,  List negTrainProfileList)
    {
        super(posTrainProfileList, negTrainProfileList);
        predictorName = "PWM";
    }

    /**
     * Command line driver: loads the mouse (Chen) training data and the fly test data,
     * balances the training lists, trains, predicts, prints an evaluation summary and
     * plots the resulting curves.
     *
     * @param args ignored
     */
    public static void main(String[] args)
    {
        DataLoader dl = new DataLoader();
        dl.loadMouseChenTrain();
        List posTrainProfileList = dl.getPosTrainProfileList();
        List negTrainProfileList = dl.getNegTrainProfileList();

        System.out.println("\tNum pos interactions: " + dl.getNumPosTrainInteractions());

        String testName = "";
        String dirName = "";
        //dl.loadWormTest(Constants.PROTEIN_MICROARRAY); testName = "PM WORM G"; dirName = "PMWormG";
        //dl.loadMouseTest("ORPHAN"); testName = "MOUSE ORPHAN PM";  dirName = "PMMouseOrphanG";
        dl.loadFlyTest(); testName = "FLY PM"; dirName = "PMFlyG";

        // Don't balance if using only positive profiles only
        // Balance the training data (only pos/neg profile pairs can be used for training)
        List[] balancedLists = PDZSVMUtils.balanceLists(posTrainProfileList,negTrainProfileList);
        posTrainProfileList = balancedLists[0];
        negTrainProfileList = balancedLists[1];

        List posTestProfileList = dl.getPosTestProfileList();
        List negTestProfileList = dl.getNegTestProfileList();

        PWMOptPredictor pwm = new PWMOptPredictor(posTrainProfileList,negTrainProfileList);

        String predictorName = pwm.getPredictorName();
        pwm.train();
        List predictionList = pwm.predict(posTestProfileList, negTestProfileList);

        Evaluation eval = new Evaluation(predictionList);
        List rocAUCList = new ArrayList();
        rocAUCList.add(eval.getROCAUC());
        List prAUCList = new ArrayList();
        prAUCList.add(eval.getPRAUC());
        List aucLabelList = new ArrayList();
        aucLabelList.add(predictorName);

        Instances inst = eval.getCurve(1);
        List instList = new ArrayList();
        instList.add(inst);

        System.out.println("\t=== Summary " +testName+ " ("+predictorName+") ===");
        System.out.println(eval.toString());

        plotCurves(instList,rocAUCList, prAUCList, aucLabelList, predictorName + " (" +testName+")");
    }

    /**
     * Result of training one profile: the chosen cutoff and the training predictions
     * produced with it. Static because it never touches the enclosing predictor.
     */
    private static class Parameters
    {
        List predictList = new ArrayList();
        double cutoff = 0;
        // The two fields below are never assigned or read anywhere in this class.
        String method = "";
        String artmethod = "";
    }

    /**
     * Learns one cutoff per positive training profile.
     *
     * <p>For each profile a PWM is built from the profile itself, and every value in
     * {@link #PWM_SCORE_CUTOFFS} is tried on the sequences of the profile returned by
     * {@code NN.getNNBindingSiteSeqProfile} (positives) and of the negative training
     * profile with the same name (negatives, if one exists). The cutoff that maximizes
     * {@code eval.truePositiveRate(0) * eval.truePositiveRate(1)} wins; on ties the
     * cutoff evaluated last wins. The result is stored under the profile's name.
     */
    public void train()
    {
        HashMap trainNegHashMap = PDZSVMUtils.profileListToHashMap(negTrainProfileList);
        for (int i = 0; i < posTrainProfileList.size(); i++)
        {
            ProteinProfile predictorProfile = (ProteinProfile)posTrainProfileList.get(i);
            String organismLong = predictorProfile.getOrganism();
            String organism = PDZSVMUtils.organismLongToShortForm(organismLong);
            ProteinProfile trainPosProfile = NN.getNNBindingSiteSeqProfile(predictorProfile.getDomainSequence(), organism, posTrainProfileList);

            // NOTE(review): trainPosProfile is dereferenced here and below, so a null
            // result from NN.getNNBindingSiteSeqProfile fails at this line.
            System.out.println("\t===== Training " + predictorProfile.getName() + " ( by " + trainPosProfile.getName() + " ) =====");

            ProteinProfile trainNegProfile = (ProteinProfile)trainNegHashMap.get(trainPosProfile.getName());
            PWM predictorPWM = new PWM(predictorProfile);

            double bestCutoff = 0.0;
            // Must start below every attainable measure. Double.MIN_VALUE is the
            // smallest POSITIVE double, so a measure of exactly 0 would never be
            // accepted and no cutoff would be selected at all.
            double bestMeasure = Double.NEGATIVE_INFINITY;
            List bestPredictionList = new ArrayList();
            for (int j = 0; j < PWM_SCORE_CUTOFFS.length; j++)
            {
                double cutoff = PWM_SCORE_CUTOFFS[j];

                List candidatePredictions = new ArrayList();
                addTrainingPredictions(predictorPWM, trainPosProfile, 1, cutoff, candidatePredictions);
                addTrainingPredictions(predictorPWM, trainNegProfile, 0, cutoff, candidatePredictions);

                Evaluation eval = new Evaluation(candidatePredictions);
                double tpr = eval.truePositiveRate(0);
                double tnr = eval.truePositiveRate(1);

                // Maximize tpr and tnr
                double tnpr = tpr * tnr;
                if (tnpr >= bestMeasure)
                {
                    bestMeasure = tnpr;
                    bestCutoff = cutoff;
                    bestPredictionList = candidatePredictions;
                }
            } // for over all cutoffs

            Parameters param = new Parameters();
            param.predictList = bestPredictionList;
            param.cutoff = bestCutoff;
            predictorParamsMap.put(predictorProfile.getName(), param);
        }
    }

    /**
     * Scores the positive and negative test profiles.
     *
     * <p>Test profiles are paired by name; a profile without a partner of the other
     * class is paired with {@code null}. Each pair is scored with the PWM of the
     * training profile returned by {@code NN.getNNBindingSiteSeqProfile} and the cutoff
     * learned for that training profile. When negatives were supplied and the ROC AUC
     * of the result is below 0.5, the sign of every decision value is flipped.
     *
     * @param posTestProfileList positive test profiles
     * @param negTestProfileList negative test profiles; may be {@code null} or empty
     * @return the predictions, which are also stored in {@code predictionList}
     * @throws IllegalStateException if no cutoff was learned for a required training
     *         profile, e.g. because {@link #train()} has not been called
     */
    public List predict(List posTestProfileList, List negTestProfileList)
    {
        System.out.println();
        // Clear prediction list!
        predictionList = new ArrayList();
        boolean hasNegatives = negTestProfileList != null && !negTestProfileList.isEmpty();

        // NOTE(review): testData is populated below but never read in this method.
        Data testData = new Data();
        testData.addRawData(posTestProfileList, Constants.CLASS_YES);
        List balPosProfileList = new ArrayList();
        List balNegProfileList = new ArrayList();
        HashMap testPosProfileHashMap = PDZSVMUtils.profileListToHashMap(posTestProfileList);

        HashMap testNegProfileHashMap = new HashMap();
        if (hasNegatives)
        {
            testData.addRawData(negTestProfileList, Constants.CLASS_NO);
            testNegProfileHashMap = PDZSVMUtils.profileListToHashMap(negTestProfileList);

            // Every negative profile, paired with the same-named positive one (or null).
            for (int i = 0; i < negTestProfileList.size(); i++)
            {
                ProteinProfile negProfile = (ProteinProfile)negTestProfileList.get(i);
                ProteinProfile posProfile = (ProteinProfile)testPosProfileHashMap.get(negProfile.getName());
                balNegProfileList.add(negProfile);
                balPosProfileList.add(posProfile);
            }
        }

        // Positive profiles that were not already paired with a negative one above.
        for (int i = 0; i < posTestProfileList.size(); i++)
        {
            ProteinProfile posProfile = (ProteinProfile)posTestProfileList.get(i);
            if (testNegProfileHashMap.get(posProfile.getName()) == null)
            {
                balNegProfileList.add(null);
                balPosProfileList.add(posProfile);
            }
        }

        for (int i = 0; i < balPosProfileList.size(); i++)
        {
            ProteinProfile testPosProfile = (ProteinProfile)balPosProfileList.get(i);
            ProteinProfile testNegProfile = (ProteinProfile)balNegProfileList.get(i);
            // Metadata comes from the positive profile when present, else the negative one.
            ProteinProfile profile = (testPosProfile != null) ? testPosProfile : testNegProfile;

            String name = profile.getName();
            String organismLong = profile.getOrganism();
            String organism = PDZSVMUtils.organismLongToShortForm(organismLong);
            String methodLong = profile.getExperimentalMethod();
            String method = PDZSVMUtils.methodLongToShortForm(methodLong);
            String domainSeqFull = profile.getDomainSequence();

            ProteinProfile predictorProfile = NN.getNNBindingSiteSeqProfile(domainSeqFull, organism, posTrainProfileList);
            PWM pwm = new PWM(predictorProfile);

            Parameters params = (Parameters)predictorParamsMap.get(predictorProfile.getName());
            if (params == null)
            {
                throw new IllegalStateException("No trained cutoff for predictor profile "
                        + predictorProfile.getName() + "; call train() before predict()");
            }
            double cutoff = params.cutoff;
            // Uses 'name' rather than testPosProfile.getName(): testPosProfile is null
            // for test profiles that only have negative sequences.
            System.out.println("\t===== Predicting " + name + " with " + predictorProfile.getName() +" using cutoff \t"+cutoff +" =====");

            addTestPredictions(pwm, testPosProfile, 1, cutoff, name, domainSeqFull, organism, method, predictionList);
            addTestPredictions(pwm, testNegProfile, 0, cutoff, name, domainSeqFull, organism, method, predictionList);
        }

        if (hasNegatives)
        {
            Evaluation eval = new Evaluation(predictionList);
            double rocAUC = eval.getROCAUC();
            if (rocAUC < 0.5)
            {
                System.out.println("\tFlipping signs...");

                for (int ii = 0; ii < predictionList.size(); ii++)
                {
                    Prediction prediction = (Prediction)predictionList.get(ii);
                    prediction.setDecValue(-prediction.getDecValue());
                }
            }
        }
        return predictionList;
    }

    /**
     * Not implemented for this predictor.
     *
     * @param validParams ignored
     * @return always a new, empty map
     */
    public HashMap kFoldCrossValidation(ValidationParameters validParams)
    {
        return new HashMap();
    }

    /**
     * Not implemented for this predictor.
     *
     * @param validParams ignored
     * @return always a new, empty map
     */
    public HashMap leaveOutCrossValidation(ValidationParameters validParams)
    {
        return new HashMap();
    }

    /**
     * Returns the first {@code Prediction} constructor argument for a cutoff-adjusted
     * score: 1.0 when the score is strictly positive, otherwise 0.0 (this includes a
     * score of exactly 0 and NaN).
     */
    private static double predictedClass(double score)
    {
        return (score > 0) ? 1.0 : 0.0;
    }

    /**
     * Scores every sequence of a training profile and appends one three-argument
     * {@code Prediction} per sequence to {@code predictions}. Does nothing when
     * {@code profile} is {@code null}.
     *
     * @param pwm         PWM used for scoring
     * @param profile     profile whose sequences are scored; may be {@code null}
     * @param actualClass value passed as the second {@code Prediction} argument
     *                    (1 for positive profiles, 0 for negative ones)
     * @param cutoff      value subtracted from the raw PWM score
     * @param predictions list receiving the new predictions
     */
    private static void addTrainingPredictions(PWM pwm, ProteinProfile profile, int actualClass,
                                               double cutoff, List predictions)
    {
        if (profile == null)
        {
            return;
        }
        Collection seqCollection = profile.getSequenceMap();
        Iterator it = seqCollection.iterator();
        while (it.hasNext())
        {
            Sequence seq = (Sequence)it.next();
            double score = pwm.score(seq.seqString()) - cutoff;
            predictions.add(new Prediction(predictedClass(score), actualClass, score));
        }
    }

    /**
     * Scores every sequence of a test profile and appends one fully annotated
     * {@code Prediction} per sequence to {@code predictions}. Does nothing when
     * {@code testProfile} is {@code null}.
     *
     * @param pwm           PWM used for scoring
     * @param testProfile   profile whose sequences are scored; may be {@code null}
     * @param actualClass   value passed as the second {@code Prediction} argument
     *                      (1 for positive profiles, 0 for negative ones)
     * @param cutoff        value subtracted from the raw PWM score
     * @param name          profile name recorded on each prediction
     * @param domainSeqFull domain sequence recorded on each prediction
     * @param organism      short organism name recorded on each prediction
     * @param method        short experimental method recorded on each prediction
     * @param predictions   list receiving the new predictions
     */
    private static void addTestPredictions(PWM pwm, ProteinProfile testProfile, int actualClass,
                                           double cutoff, String name, String domainSeqFull,
                                           String organism, String method, List predictions)
    {
        if (testProfile == null)
        {
            return;
        }
        Collection testSeqCollection = testProfile.getSequenceMap();
        Iterator it = testSeqCollection.iterator();
        while (it.hasNext())
        {
            Sequence seq = (Sequence)it.next();
            String seqString = seq.seqString();
            double score = pwm.score(seqString) - cutoff;
            predictions.add(new Prediction(predictedClass(score), actualClass, score,
                    name, domainSeqFull, seqString, organism, method));
        }
    }

}
