package org.baderlab.pdzsvm.data.manager;


import java.util.*;
import java.util.List;
import org.baderlab.pdzsvm.utils.Constants;
import org.baderlab.brain.*;
import org.biojava.bio.seq.Sequence;
import org.biojava.bio.seq.ProteinTools;
import org.biojava.bio.seq.db.HashSequenceDB;
import org.baderlab.pdzsvm.predictor.pwm.PWM;
import org.baderlab.pdzsvm.utils.PDZSVMUtils;
import org.baderlab.pdzsvm.data.DataRepository;
import weka.core.Utils;


/**
 * Copyright (c) 2010 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVM.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVM.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * Manager for the generation of different types of artificial negatives
 * including random, shuffled, mixed up, randomly selected and PWM negatives.
 *
 * Every public generator takes a list of positive (training)
 * {@link ProteinProfile}s and returns a new list holding the artificial
 * negative profiles it could build via {@link PDZSVMUtils#makeProfile};
 * profiles for which makeProfile returns null are left out, so the returned
 * list may be shorter than the input list.  Progress and summary statistics
 * are written to standard output.
 */
public class ArtificialNegativesDataManager
{
    /**
     * Creates the manager and announces the initialization on standard output.
     */
    public ArtificialNegativesDataManager()
    {
        System.out.println("\n\tInitializing artificial negative data manager...\n");
    }

    /**
     * Command line entry point: generates shuffled, random and PWM negatives
     * for the human positive profiles held by the data repository.
     *
     * @param args not used
     */
    public static void main(String[] args)
    {
        DataRepository dr = DataRepository.getInstance();
        ArtificialNegativesDataManager an = new ArtificialNegativesDataManager();
        an.getShuffledNegatives(dr.humanPosPDList,Constants.NUM_RED_PEPTIDES);
        an.getRandomNegatives(dr.humanPosPDList, Constants.NUM_RED_PEPTIDES);
        an.getPWMNegatives(dr.humanPosPDList);
        System.out.println();
    }

    /**
     * Returns the residues of the given sequence in a uniformly random order.
     *
     * @param seq sequence whose string form is permuted
     * @return a random permutation of seq.seqString()
     */
    private String shuffleSequence(Sequence seq)
    {
        String sequence = seq.seqString();
        List<Character> residues = new ArrayList<Character>(sequence.length());
        for (int ii = 0; ii < sequence.length(); ii++)
        {
            residues.add(Character.valueOf(sequence.charAt(ii)));
        }
        Collections.shuffle(residues);

        // Build the result in a StringBuilder rather than by repeated
        // String concatenation inside the loop.
        StringBuilder shuffled = new StringBuilder(sequence.length());
        for (Character residue : residues)
        {
            shuffled.append(residue.charValue());
        }
        return shuffled.toString();
    }

    /**
     * Appends numToPad randomly chosen copies of the elements already in the
     * list to the end of the list.  Nothing happens if numToPad is not
     * positive or the list is empty (there is nothing to copy from).
     *
     * @param seqList  list to grow in place
     * @param numToPad number of elements to append
     */
    private void padSeqList(List seqList, int numToPad)
    {
        int numOriginal = seqList.size();
        if (numToPad <= 0 || numOriginal == 0)
        {
            return;
        }
        // Draw each pad element from the original entries only; this avoids
        // reshuffling the whole (growing) list for every appended element.
        Random random = new Random();
        for (int i = 0; i < numToPad; i++)
        {
            seqList.add(seqList.get(random.nextInt(numOriginal)));
        }
    }

    /**
     * Generates negatives by shuffling the residues of each profile's own
     * sequences.  For every profile, the number of negatives equals the
     * number of sequences returned by getNegatives for that profile; if the
     * profile holds fewer sequences than that, randomly chosen duplicates of
     * its sequences are shuffled as well.
     *
     * @param trainProfileList list of positive ProteinProfile objects
     * @param numRedPeptides   similarity parameter forwarded to getNegatives
     * @return list of artificial negative ProteinProfile objects
     */
    public List getShuffledNegatives(List trainProfileList, int numRedPeptides)
    {
        System.out.println("\tShuffled negatives...");

        List artNegProfileList = new ArrayList();
        int totSeq = 0;
        for (int i = 0; i < trainProfileList.size(); i++)
        {
            HashSequenceDB negSequenceDB = new HashSequenceDB();
            ProteinProfile trainProfile = (ProteinProfile)trainProfileList.get(i);

            String domainName = trainProfile.getName();
            List seqList = new ArrayList(trainProfile.getSequenceMap());
            List negSequences = getNegatives(trainProfileList, trainProfile, numRedPeptides);
            int num = negSequences.size();

            padSeqList(seqList, num - seqList.size());

            // A profile without sequences cannot be padded, so never read
            // past the end of the list.
            int numToShuffle = Math.min(num, seqList.size());
            for (int ii = 0; ii < numToShuffle; ii++)
            {
                Sequence seq = (Sequence)seqList.get(ii);
                String possibleNegSequence = shuffleSequence(seq);
                try
                {
                    Sequence prot = ProteinTools.createProteinSequence(possibleNegSequence, domainName + ii);
                    negSequenceDB.addSequence(prot);
                    totSeq = totSeq + 1;
                }
                catch (Exception e)
                {
                    // Best effort: report the failure and continue with the next sequence.
                    System.out.println("Exception: " + e);
                }
            }
            ProteinProfile artNegProfile = PDZSVMUtils.makeProfile(trainProfile, negSequenceDB);
            if (artNegProfile != null)
            {
                artNegProfileList.add(artNegProfile);
            }
        }
        System.out.println("\tTotal number sequence generated: " + totSeq);
        System.out.println("\tShuffled neg stats:");

        getStats(trainProfileList, artNegProfileList);
        return artNegProfileList;
    }

    /**
     * Generates negatives by reassigning sequences between profiles: the
     * pooled sequences of all profiles are shuffled and each profile receives
     * as many pooled sequences as it has positives, skipping any sequence
     * that is one of its own positives.  The pool is walked cyclically and
     * the position is carried over from one profile to the next.
     *
     * If a full pass over the pool yields no usable sequence for a profile
     * (for example because every pooled sequence is one of its positives),
     * the profile is left with the negatives found so far and a warning is
     * printed, instead of searching forever.
     *
     * @param profileList list of positive ProteinProfile objects
     * @return list of artificial negative ProteinProfile objects
     */
    public List getMixedUpNegatives(List profileList)
    {
        System.out.println("\tGetting mixed up negatives...");
        List artNegProfileList = new ArrayList();

        SequencePoolManager sm = new SequencePoolManager(profileList);
        List sequencePoolList = sm.getSequencePool();
        int poolSize = sequencePoolList.size();
        System.out.println("\tNumber of sequences: " + poolSize);

        Collections.shuffle(sequencePoolList);
        // randomly pair them up
        int ix = 0;
        for (int i = 0; i < profileList.size(); i++)
        {
            ProteinProfile trainProfile = (ProteinProfile)profileList.get(i);
            String domainName = trainProfile.getName();
            HashSequenceDB negSequenceDB = new HashSequenceDB();
            Collection seqCollection = trainProfile.getSequenceMap();

            // Hash set gives constant time membership tests for the positives.
            Set<String> trainSeqs = new HashSet<String>();
            for (Object o : seqCollection)
            {
                trainSeqs.add(((Sequence)o).seqString());
            }

            int numWanted = seqCollection.size();
            int numAdded = 0;
            // Consecutive pool entries rejected since the last success; once
            // this reaches the pool size the whole pool has been tried.
            int misses = 0;
            while (numAdded < numWanted && misses < poolSize)
            {
                String reassignedSequence = (String)sequencePoolList.get(ix);
                ix = (ix + 1) % poolSize;

                if (trainSeqs.contains(reassignedSequence))
                {
                    misses = misses + 1;
                    continue;
                }
                try
                {
                    Sequence prot = ProteinTools.createProteinSequence(reassignedSequence, domainName + numAdded);
                    negSequenceDB.addSequence(prot);
                    numAdded = numAdded + 1;
                    misses = 0;
                }
                catch (Exception e)
                {
                    System.out.println("Exception: " + e);
                    misses = misses + 1;
                }
            }
            if (numAdded < numWanted)
            {
                System.out.println("\tWarning: only " + numAdded + " of " + numWanted
                        + " mixed up negatives could be assigned to " + domainName);
            }

            ProteinProfile artNegProfile = PDZSVMUtils.makeProfile(trainProfile, negSequenceDB);
            if (artNegProfile != null)
            {
                artNegProfileList.add(artNegProfile);
            }
        }
        getStats(profileList, artNegProfileList);
        System.out.println("\t#Pos profiles: " + profileList.size());
        System.out.println("\t#Neg profiles: " + artNegProfileList.size());
        System.out.println();

        return artNegProfileList;
    }

    /**
     * Generates negatives by drawing sequences at random from the pooled
     * sequences of all profiles.  For every profile the pool is reshuffled
     * and its leading entries are taken; the number taken equals the number
     * of sequences returned by getNegatives for that profile, capped at the
     * pool size.
     *
     * @param trainProfileList list of positive ProteinProfile objects
     * @param numRedPeptides   similarity parameter forwarded to getNegatives
     * @return list of artificial negative ProteinProfile objects
     */
    public List getRandomSelectionNegatives(List trainProfileList, int numRedPeptides)
    {
        System.out.println("\tRandom selection negatives...");

        List artNegProfileList = new ArrayList();
        SequencePoolManager sm = new SequencePoolManager(trainProfileList);
        List sequencePoolList = sm.getSequencePool();
        int totSeq = 0;
        for (int i = 0; i < trainProfileList.size(); i++)
        {
            ProteinProfile trainProfile = (ProteinProfile)trainProfileList.get(i);
            List negSequences = getNegatives(trainProfileList, trainProfile, numRedPeptides);
            // Never ask for more sequences than the pool holds.
            int num = Math.min(negSequences.size(), sequencePoolList.size());

            Collections.shuffle(sequencePoolList);
            HashSequenceDB seqDB = new HashSequenceDB();
            for (int ii = 0; ii < num; ii++)
            {
                String seq = (String)sequencePoolList.get(ii);
                try
                {
                    Sequence prot = ProteinTools.createProteinSequence(seq, trainProfile.getName() + ii);
                    seqDB.addSequence(prot);
                    totSeq = totSeq + 1;
                }
                catch (Exception e)
                {
                    System.out.println("Exception: " + e);
                }
            }
            ProteinProfile artNegProfile = PDZSVMUtils.makeProfile(trainProfile, seqDB);
            if (artNegProfile != null)
            {
                artNegProfileList.add(artNegProfile);
            }
        }
        System.out.println("\tTotal number sequence generated: " + totSeq);

        System.out.println("\tRandom selection neg stats:");

        getStats(trainProfileList, artNegProfileList);

        return artNegProfileList;
    }

    /**
     * Generates negatives from sequences produced by a PWM constructed from
     * the peptide length alone (the length is taken from the first sequence
     * of the first profile).  For every profile, the number of negatives
     * equals the number of sequences returned by getNegatives for it.
     *
     * @param trainProfileList list of positive ProteinProfile objects
     * @param numRedPeptides   similarity parameter forwarded to getNegatives
     * @return list of artificial negative ProteinProfile objects; empty if
     *         trainProfileList is empty
     */
    public List getRandomNegatives(List trainProfileList,int numRedPeptides)
    {
        System.out.println("\tRandom negatives...");

        List artNegProfileList = new ArrayList();
        if (trainProfileList.isEmpty())
        {
            // Without a profile there is no peptide length to generate for.
            System.out.println("\tNo training profiles supplied; no random negatives generated.");
            return artNegProfileList;
        }

        ProteinProfile profile = (ProteinProfile)trainProfileList.get(0);
        int peptideLength = getPeptideLength(profile);
        PWM randomPWM = new PWM(peptideLength);
        int totSeq = 0;
        for (int i = 0; i < trainProfileList.size(); i++)
        {
            HashSequenceDB negSequenceDB = new HashSequenceDB();

            ProteinProfile trainProfile = (ProteinProfile)trainProfileList.get(i);
            List negSequences = getNegatives(trainProfileList, trainProfile, numRedPeptides);
            int num = negSequences.size();

            String domainName = trainProfile.getName();

            for (int ii = 0; ii < num; ii++)
            {
                String possibleNegSequence = randomPWM.generateSequence(randomPWM.getDistribution());
                try
                {
                    Sequence prot = ProteinTools.createProteinSequence(possibleNegSequence, domainName + ii);
                    negSequenceDB.addSequence(prot);
                    totSeq = totSeq + 1;
                }
                catch (Exception e)
                {
                    // Reported the same way as in the other generators.
                    System.out.println("Exception: " + e);
                }
            }

            ProteinProfile artNegProfile = PDZSVMUtils.makeProfile(trainProfile, negSequenceDB);
            if (artNegProfile != null)
            {
                artNegProfileList.add(artNegProfile);
            }
        }
        System.out.println("\tTotal number sequence generated: " + totSeq);

        System.out.println("\tRandom neg stats:");

        getStats(trainProfileList, artNegProfileList);
        return artNegProfileList;
    }

    /**
     * Returns the length of the first sequence of the given profile.
     *
     * @param profile profile to inspect
     * @return length of the profile's first sequence
     * @throws IllegalArgumentException if the profile holds no sequences
     */
    private int getPeptideLength(ProteinProfile profile)
    {
        Collection seqCollection = profile.getSequenceMap();
        if (seqCollection.isEmpty())
        {
            throw new IllegalArgumentException("Profile " + profile.getName()
                    + " has no sequences; cannot determine the peptide length");
        }
        Sequence seq = (Sequence)seqCollection.iterator().next();
        return seq.length();
    }

    /**
     * Selects, from the pooled sequences of refProfileList, those that score
     * below every sequence of the given profile under the profile's own PWM.
     *
     * The cutoff is the lowest PWM score among the profile's own sequences
     * (Double.MAX_VALUE if it has none).  The pool is visited in the order
     * returned by sortSequencePool(SequencePoolManager.DESC, pwm) and a
     * sequence is kept when SequencePoolManager.isLike reports it is not
     * like the sequences kept so far and its score is below the cutoff.
     *
     * @param refProfileList list of ProteinProfile objects forming the pool
     * @param profile        profile the negatives are selected for
     * @param numPeptideSim  similarity parameter handed to isLike
     * @return list of the selected sequence strings
     */
    private List getNegatives(List refProfileList, ProteinProfile profile, int numPeptideSim)
    {
        SequencePoolManager sm = new SequencePoolManager(refProfileList);
        List sequencePoolList = sm.getSequencePool();

        PWM pwm = new PWM(profile);
        List sortedSequencePoolList = sm.sortSequencePool(SequencePoolManager.DESC, pwm);

        // Find out the cutoff with respect to the profile passed in
        double cutoff = Double.MAX_VALUE;
        for (Object o : profile.getSequenceMap())
        {
            double score = pwm.score(((Sequence)o).seqString());
            if (score < cutoff)
            {
                cutoff = score;
            }
        }

        // Scan all peptides in the sequence pool for those < cutoff and not
        // too similar to the ones already found as defined by isLike
        List reducedSequencePoolList = new ArrayList();
        for (int ii = 0; ii < sortedSequencePoolList.size(); ii++)
        {
            String seq = (String)sortedSequencePoolList.get(ii);
            if (!SequencePoolManager.isLike(reducedSequencePoolList, seq, numPeptideSim))
            {
                double score = pwm.score(seq);
                if (score < cutoff)
                {
                    reducedSequencePoolList.add(seq);
                }
            }
        }

        System.out.println("\t=== " +profile.getName() +" (" +reducedSequencePoolList.size() +" of "+sequencePoolList.size()+ ") ===");

        return reducedSequencePoolList;
    }

    /**
     * Generates PWM negatives using Constants.NUM_RED_PEPTIDES as the
     * similarity parameter.
     *
     * @param profileList list of positive ProteinProfile objects
     * @return list of artificial negative ProteinProfile objects
     */
    public List getPWMNegatives(List profileList)
    {
        return getPWMNegatives(profileList, Constants.NUM_RED_PEPTIDES);
    }

    /**
     * Generates PWM negatives: for every profile, the sequences returned by
     * getNegatives for that profile become its artificial negatives.
     *
     * @param profileList    list of positive ProteinProfile objects
     * @param numRedPeptides similarity parameter forwarded to getNegatives
     * @return list of artificial negative ProteinProfile objects
     */
    public List getPWMNegatives(List profileList, int numRedPeptides)
    {
        System.out.println("\tGetting PWM negatives...");
        System.out.println("\tNum similar peptides: " + numRedPeptides);

        List artNegProfileList = new ArrayList();
        int totSeq = 0;
        for (int i = 0; i < profileList.size(); i++)
        {
            ProteinProfile trainProfile = (ProteinProfile)profileList.get(i);

            List artNegSequences = getNegatives(profileList, trainProfile, numRedPeptides);

            HashSequenceDB seqDB = new HashSequenceDB();

            String domainName = trainProfile.getName();
            for (int j = 0; j < artNegSequences.size(); j++)
            {
                String seq = (String)artNegSequences.get(j);
                try
                {
                    Sequence prot = ProteinTools.createProteinSequence(seq, domainName + "-" + j);
                    seqDB.addSequence(prot);
                    totSeq = totSeq + 1;
                }
                catch (Exception e)
                {
                    System.out.println("Exception: " + e);
                }
            }
            ProteinProfile artNegProfile = PDZSVMUtils.makeProfile(trainProfile, seqDB);
            if (artNegProfile != null)
            {
                artNegProfileList.add(artNegProfile);
            }
        }
        System.out.println("\tTotal number sequence generated: " + totSeq);

        System.out.println("\tPWM neg stats:");
        getStats(profileList, artNegProfileList);
        System.out.println("\t#Pos profiles: " + profileList.size());
        System.out.println("\t#Neg profiles: " + artNegProfileList.size());
        System.out.println();
        return artNegProfileList;
    }

    /**
     * Prints the minimum, maximum and mean PWM score of all artificial
     * negative sequences, each scored against the PWM of the training
     * profile that has the same name.  Negative profiles without a matching
     * training profile are skipped.
     *
     * @param trainProfileList  list of positive ProteinProfile objects
     * @param artNegProfileList list of artificial negative ProteinProfile objects
     */
    private void getStats(List trainProfileList, List artNegProfileList)
    {
        // Looked up by profile name below.
        HashMap trainProfileMap = PDZSVMUtils.profileListToHashMap(trainProfileList);
        double sum = 0.0;
        double min = Double.MAX_VALUE;
        // Double.MIN_VALUE is the smallest POSITIVE double, so it cannot be
        // used as the starting point of a maximum over negative scores.
        double max = -Double.MAX_VALUE;
        int total = 0;
        for (int i = 0; i < artNegProfileList.size(); i++)
        {
            ProteinProfile profile = (ProteinProfile)artNegProfileList.get(i);
            ProteinProfile trainProfile = (ProteinProfile)trainProfileMap.get(profile.getName());
            if (trainProfile == null)
            {
                continue;
            }
            PWM pwm = new PWM(trainProfile);
            for (Object o : profile.getSequenceMap())
            {
                double score = pwm.score(((Sequence)o).seqString());
                sum = sum + score;
                if (score < min)
                {
                    min = score;
                }
                if (score > max)
                {
                    max = score;
                }
                total = total + 1;
            }
        }
        if (total == 0)
        {
            // Avoid printing sentinel values and a 0/0 mean.
            System.out.println("\tMin, max, mean: [no sequences scored]");
            return;
        }
        System.out.println("\tMin, max, mean: [" +Utils.doubleToString(min,7,3) + "," +Utils.doubleToString(max,7,3) + "," + Utils.doubleToString(sum/total,7,3) + "]");
    }

} // end class
