package org.baderlab.pdzsvm.predictor.pwm;

import org.baderlab.brain.*;
import org.biojava.bio.dp.SimpleWeightMatrix;
import org.biojava.bio.dp.WeightMatrix;
import org.biojava.bio.dist.DistributionTools;
import org.biojava.bio.dist.Distribution;
import org.biojava.bio.dist.UniformDistribution;
import org.biojava.bio.symbol.*;
import org.biojava.bio.seq.Sequence;
import org.biojava.bio.seq.ProteinTools;
import org.biojava.utils.ChangeVetoException;

import java.util.*;

import org.baderlab.pdzsvm.utils.PDZSVMUtils;

/**
 * Copyright (c) 2010 University of Toronto
 * Code written by: Shirley Hui
 * Authors: Shirley Hui, Gary Bader
 *
 * This file is part of PDZSVM.
 *
 * PDZSVM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PDZSVM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  The software and
 * documentation provided hereunder is on an "as is" basis, and the
 * University of Toronto has no obligations to provide maintenance,
 * support, updates, enhancements or modifications.  In no event shall
 * the University of Toronto be liable to any party for direct, indirect,
 * special, incidental or consequential damages, including lost profits,
 * arising out of the use of this software and its documentation, even if
 * the University of Toronto has been advised of the possibility of such
 * damage. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with PDZSVM.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * Position weight matrix (PWM) over protein sequences, backed by a BioJava
 * {@link SimpleWeightMatrix}.
 */
public class PWM {

    /** The 20 standard amino-acid symbols, as returned by PDZSVMUtils.get20aaAlphabet(). */
    private static final List<?> ALPHABET = PDZSVMUtils.get20aaAlphabet();

    /** Sentinel returned by {@link #score(String)} for sequences containing excluded residues. */
    private static final double INVALID_SCORE = -1.0;

    /** Divisor applied to each per-column Euclidean distance in {@link #distance(PWM, PWM)}. */
    private static final double SQRT_TWO = Math.sqrt(2);

    private final SimpleWeightMatrix wm;
    private final Distribution[] dist;
    /** Largest weight of any of the 20 standard amino acids, per column. */
    private final double[] maxWeights;

    /**
     * Builds a PWM with one column per supplied distribution.
     *
     * <p>After the matrix is built, SEC, TER and the gap symbol are given weight 0.0 in every
     * column (see {@link #getDistribution()}).
     *
     * @param dist one distribution per matrix column
     * @throws IllegalArgumentException if a weight matrix cannot be built from {@code dist}
     */
    public PWM(Distribution[] dist) {
        this.wm = buildMatrix(dist);
        this.dist = stripNonStandardSymbols(wm);
        this.maxWeights = computeMaxWeights(wm);
    }

    /**
     * Builds a PWM from the weight matrix of a protein profile.
     *
     * <p>SEC, TER and the gap symbol are given weight 0.0 in every column of that matrix
     * (see {@link #getDistribution()}).
     *
     * @param proteinProfile profile whose weight matrix backs this PWM
     */
    public PWM(ProteinProfile proteinProfile) {
        this.wm = proteinProfile.getWeightMatrix();
        this.dist = stripNonStandardSymbols(wm);
        this.maxWeights = computeMaxWeights(wm);
    }

    /**
     * Builds a PWM of {@code numCols} columns, each a uniform distribution over
     * {@code ProteinTools.getAlphabet()}.
     *
     * @param numCols number of columns
     * @throws IllegalArgumentException if a weight matrix cannot be built (e.g. for zero columns)
     */
    public PWM(int numCols) {
        FiniteAlphabet proteins = ProteinTools.getAlphabet();
        Distribution[] uniform = new Distribution[numCols];
        for (int i = 0; i < numCols; i++) {
            uniform[i] = new UniformDistribution(proteins);
        }
        this.dist = uniform;
        this.wm = buildMatrix(uniform);
        // Previously left null here, which made maxScore()/score() throw a NullPointerException.
        this.maxWeights = computeMaxWeights(wm);
    }

    /** Wraps the SimpleWeightMatrix constructor, failing fast instead of leaving a null matrix. */
    private static SimpleWeightMatrix buildMatrix(Distribution[] columns) {
        try {
            return new SimpleWeightMatrix(columns);
        } catch (Exception e) {
            throw new IllegalArgumentException("Cannot build weight matrix: " + e, e);
        }
    }

    /** Returns, per column of {@code matrix}, the largest weight of any standard amino acid (at least 0.0). */
    private static double[] computeMaxWeights(SimpleWeightMatrix matrix) {
        double[] maxima = new double[matrix.columns()];
        try {
            for (int i = 0; i < maxima.length; i++) {
                Distribution column = matrix.getColumn(i);
                for (Object o : ALPHABET) {
                    double weight = column.getWeight((Symbol) o);
                    if (weight > maxima[i]) {
                        maxima[i] = weight;
                    }
                }
            }
        } catch (IllegalSymbolException e) {
            throw new IllegalStateException("Weight matrix rejects a standard amino acid", e);
        }
        return maxima;
    }

    /**
     * Sets the weights of SEC, TER and the gap symbol to 0.0 in every column of {@code matrix}
     * and returns the column distributions.
     *
     * <p>NOTE(review): {@code setWeight} is applied to the objects returned by
     * {@code matrix.getColumn(i)}; if those are the matrix's live columns (as appears to be the
     * case for BioJava's SimpleWeightMatrix — confirm), the matrix itself is modified.
     */
    private static Distribution[] stripNonStandardSymbols(SimpleWeightMatrix matrix) {
        AtomicSymbol sec = ProteinTools.sec();
        AtomicSymbol ter = ProteinTools.ter();
        Distribution[] columns = new Distribution[matrix.columns()];
        for (int i = 0; i < columns.length; i++) {
            Distribution column = matrix.getColumn(i);
            Symbol gap = ((FiniteAlphabet) column.getAlphabet()).getGapSymbol();
            try {
                column.setWeight(sec, 0.0);
                column.setWeight(ter, 0.0);
                column.setWeight(gap, 0.0);
            } catch (ChangeVetoException ignored) {
                // Best effort: a column that vetoes the change keeps its current weights.
                // NOTE(review): confirm silently skipping such columns is intended.
            } catch (Exception e) {
                System.out.println("Exception: " + e);
                e.printStackTrace();
            }
            columns[i] = column;
        }
        return columns;
    }

    /**
     * Mean per-column distance between two PWMs of equal width.
     *
     * <p>For each column, takes the Euclidean distance between the two columns' weights over the
     * 20 standard amino acids, divided by sqrt(2); the result is the average over all columns.
     *
     * @return the mean column distance, or 0.0 when both matrices have no columns
     * @throws IllegalArgumentException if the matrices have different numbers of columns
     * @throws IllegalStateException if a matrix rejects one of the standard amino-acid symbols
     */
    public static double distance(PWM pwm1, PWM pwm2) {
        WeightMatrix wm1 = pwm1.getWeightMatrix();
        WeightMatrix wm2 = pwm2.getWeightMatrix();

        int columns = wm1.columns();
        if (columns != wm2.columns()) {
            throw new IllegalArgumentException("Profiles to be compared must be the same length.");
        }
        if (columns == 0) {
            return 0.0; // avoid 0/0 = NaN
        }

        double total = 0.0;
        try {
            for (int i = 0; i < columns; i++) {
                Distribution d1 = wm1.getColumn(i);
                Distribution d2 = wm2.getColumn(i);
                double sumOfSquares = 0.0;
                for (Object o : ALPHABET) {
                    Symbol symbol = (Symbol) o;
                    double diff = d1.getWeight(symbol) - d2.getWeight(symbol);
                    sumOfSquares += diff * diff;
                }
                total += Math.sqrt(sumOfSquares) / SQRT_TWO;
            }
        } catch (IllegalSymbolException e) {
            // Previously printed and returned a partial (wrong) distance.
            throw new IllegalStateException("Weight matrix rejects a standard amino acid", e);
        }
        return total / columns;
    }

    /** Returns the underlying weight matrix. */
    public SimpleWeightMatrix getWeightMatrix() {
        return wm;
    }

    /**
     * Prints a tab-separated table to standard output: one row per symbol of the first column's
     * alphabet, giving the symbol name followed by its weight in each column.
     *
     * <p>If looking up a weight fails, the exception is reported and {@code -1.0} is appended
     * after whatever part of that row was already written.
     */
    public void print() {
        if (dist.length == 0) {
            System.out.println();
            return;
        }
        Distribution first = dist[0];
        FiniteAlphabet alpha = (FiniteAlphabet) first.getAlphabet();
        StringBuilder out = new StringBuilder();

        Iterator it = alpha.iterator();
        while (it.hasNext()) {
            Symbol sym = (Symbol) it.next();
            try {
                // Each weight is looked up before anything is appended, so a failed lookup
                // leaves no dangling name or tab for that cell.
                double weight = first.getWeight(sym);
                out.append(sym.getName()).append('\t').append(weight);
                for (int i = 1; i < dist.length; i++) {
                    weight = dist[i].getWeight(sym);
                    out.append('\t').append(weight);
                }
            } catch (Exception e) {
                System.out.println("Exception: " + e + sym.getName());
                e.printStackTrace();
                out.append('\t').append(-1.0);
            }
            out.append('\n');
        }
        System.out.println(out);
    }

    /**
     * Sets SEC, TER and the gap symbol to weight 0.0 in every column of this PWM's matrix and
     * returns the column distributions.
     *
     * <p>NOTE(review): the weights are changed on the objects returned by
     * {@code wm.getColumn(i)}; if those are live columns (confirm for SimpleWeightMatrix),
     * this modifies the matrix itself.
     *
     * @return one distribution per matrix column
     */
    public Distribution[] getDistribution() {
        return stripNonStandardSymbols(wm);
    }

    /**
     * Samples one sequence with one residue per distribution.
     *
     * <p>Each residue is drawn with {@code DistributionTools.generateSequence}; a draw of the gap
     * character {@code "-"} is discarded and redrawn. NOTE: this never terminates for a column
     * that can only yield a gap.
     *
     * @param dist one distribution per sequence position
     * @return the sampled sequence ({@code dist.length} characters)
     */
    public static String generateSequence(Distribution[] dist) {
        StringBuilder sequence = new StringBuilder(dist.length);
        for (Distribution column : dist) {
            String residue;
            do {
                Sequence draw = DistributionTools.generateSequence("Residue", column, 1);
                residue = draw.seqString();
            } while (residue.equals("-"));
            sequence.append(residue);
        }
        return sequence.toString();
    }

    /**
     * Samples sequences from {@code dist} until {@code numSequences} have been kept.
     *
     * <p>A sequence is kept when its normalized {@link #score(String)} is non-negative and at
     * most {@code threshold}. Sequences that {@code score} rejects as invalid (sentinel -1.0) are
     * discarded rather than counted as low-scoring. NOTE: this never terminates if acceptable
     * sequences cannot be drawn.
     *
     * @param dist one distribution per sequence position
     * @param threshold inclusive upper bound on the normalized score
     * @param numSequences number of sequences to return; values {@code <= 0} yield an empty list
     * @return the kept sequences, in generation order
     */
    public List<String> generateSequences(Distribution[] dist, double threshold, int numSequences) {
        List<String> sequences = new ArrayList<String>(Math.max(numSequences, 0));
        while (sequences.size() < numSequences) {
            String seq = generateSequence(dist);
            double normalized = score(seq);
            if (normalized >= 0.0 && normalized <= threshold) {
                sequences.add(seq);
            }
        }
        return sequences;
    }

    /** Returns the sum over all columns of the largest standard-amino-acid weight in each column. */
    public double maxScore() {
        double total = 0.0;
        for (double w : maxWeights) {
            total += w;
        }
        return total;
    }

    /**
     * Scores a sequence against the matrix, position by position.
     *
     * <p>Sums the weight of each residue in the corresponding column and divides by
     * {@link #maxScore()}. A shorter sequence is scored against the leading columns only.
     *
     * @param sequence residues in one-letter form
     * @return the normalized score, or -1.0 if the sequence contains a residue that is not one of
     *     the 20 standard amino acids (including SEC, PYL and TER) or is rejected by a column
     * @throws IllegalArgumentException if the sequence is longer than the matrix has columns
     */
    public double score(String sequence) {
        if (sequence.length() > wm.columns()) {
            throw new IllegalArgumentException("Sequence length " + sequence.length()
                    + " exceeds the number of matrix columns " + wm.columns());
        }
        double total = 0.0;
        for (int i = 0; i < sequence.length(); i++) {
            String residue = String.valueOf(sequence.charAt(i));
            // A null symbol means the residue is not one of the 20 standard amino acids.
            AtomicSymbol symbol = PDZSVMUtils.get20aaSymbol(residue);
            if (symbol == null || isExcludedSymbol(symbol)) {
                return INVALID_SCORE;
            }
            try {
                total += wm.getColumn(i).getWeight(symbol);
            } catch (IllegalSymbolException e) {
                System.out.println("Exception: " + residue + ", " + e);
                e.printStackTrace();
                return INVALID_SCORE;
            }
        }
        return total / maxScore();
    }

    /** True for the SEC, PYL and TER symbols, which {@link #score(String)} never accepts. */
    private static boolean isExcludedSymbol(Symbol symbol) {
        String name = symbol.getName();
        return name.equals("SEC") || name.equals("PYL") || name.equals("TER");
    }

}