/*
 * Decompiled with CFR 0.152.
 */
package org.baderlab.brain.correlationlearn;

import java.awt.image.RenderedImage;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.TreeMap;
import javax.imageio.ImageIO;
import mt.MatrixEntry;
import org.baderlab.brain.PeptideToProfileReader;
import org.baderlab.brain.ProteinProfile;
import org.baderlab.brain.ProteinSequenceLogo;
import org.baderlab.brain.ProteinSequenceUtil;
import org.baderlab.brain.ProteinTerminus;
import org.baderlab.brain.correlationlearn.CorrelationResult;
import org.baderlab.brain.correlationlearn.FeatureUtils;
import org.baderlab.brain.correlationlearn.ResiduePositionPair;
import org.baderlab.brain.correlationlearn.ResidueResidueCorrelationMatrix;
import org.biojava.bio.BioException;
import org.biojava.bio.dist.Count;
import org.biojava.bio.dist.Distribution;
import org.biojava.bio.dist.DistributionTools;
import org.biojava.bio.dist.IndexedCount;
import org.biojava.bio.seq.ProteinTools;
import org.biojava.bio.seq.Sequence;
import org.biojava.bio.seq.SequenceIterator;
import org.biojava.bio.seq.io.SeqIOTools;
import org.biojava.bio.symbol.FiniteAlphabet;
import org.biojava.bio.symbol.IllegalSymbolException;
import org.biojava.bio.symbol.Symbol;
import org.biojava.utils.ChangeVetoException;
import smt.FlexCompRowMatrix;
import smt.SparseVector;

/**
 * Learns co-occurrence ("correlation") counts between features of aligned domain sequences and
 * features of the peptides each domain binds, and scores those counts to report informative
 * domain/peptide feature pairs (lower scores are better).
 *
 * <p>A feature is an array of {@link ResiduePositionPair}s; the number of positions per feature is
 * configured separately for domains and peptides. Peptide profiles are linked to sequences of the
 * multiple sequence alignment (MSA) by name. When a domain or peptide column filter is set, feature
 * positions are indices into the <em>filtered</em> sequence strings, not the raw sequences.
 */
public class CorrelationLearn {
    /** Output directory used by {@link #outputAllLogos()}; kept for backward compatibility. */
    private static final String DEFAULT_LOGO_OUTPUT_DIRECTORY = "D:\\Gbader\\Code\\PDZ\\data\\PDZ\\SpecificityPrediction\\Logos";
    /** Only correlations scoring below this value contribute to a predicted profile. */
    private static final double SCORE_THRESHOLD_FOR_PREDICTION = -0.03;
    /** Tiny initial weight given to every symbol of every predicted column. */
    private static final double INITIAL_SYMBOL_WEIGHT = 1.0E-10;

    private ResidueResidueCorrelationMatrix rrcm = null;
    private HashMap profileNameToProfile = null;
    protected HashMap sequenceNameToSequence = null;
    protected int multipleSequenceAlignmentLength = 0;
    protected int totalDomainSequenceLength = -1;
    protected int totalPeptideSequenceLength = -1;
    protected String domainSequenceFilter = null;
    protected String peptideSequenceFilter = null;
    protected SparseVector domainFeatureFrequencyVector = null;
    protected SparseVector peptideFeatureFrequencyVector = null;
    protected int numDomainPositionsPerFeature = 0;
    protected int numPeptidePositionsPerFeature = 0;

    /**
     * Reads the domain alignment and the peptide profiles and allocates an empty correlation matrix.
     *
     * @param multipleSequenceAlignmentFile FASTA protein alignment; all sequences must share one length
     * @param peptideOrProjectFile peptide or project file handed to {@link PeptideToProfileReader}
     * @param peptideLength peptide length handed to the reader
     * @param terminus peptide terminus handed to the reader
     * @param numDomainPositionsPerFeature number of domain positions combined into one feature
     * @param numPeptidePositionsPerFeature number of peptide positions combined into one feature
     * @throws IllegalArgumentException if domain sequences or peptide profiles differ in length
     * @throws IOException if a file cannot be read
     * @throws BioException if the alignment cannot be parsed
     */
    public CorrelationLearn(File multipleSequenceAlignmentFile, File peptideOrProjectFile, int peptideLength, ProteinTerminus terminus, int numDomainPositionsPerFeature, int numPeptidePositionsPerFeature) throws BioException, IOException {
        this.readAlignmentAndDetermineSequenceAlignmentWidth(multipleSequenceAlignmentFile);
        this.numDomainPositionsPerFeature = numDomainPositionsPerFeature;
        this.numPeptidePositionsPerFeature = numPeptidePositionsPerFeature;
        List proteinProfileList = PeptideToProfileReader.readPeptidesAsProfiles(peptideOrProjectFile, peptideLength, terminus, 0.0, null, true, false);
        this.profileNameToProfile = new HashMap();
        // NOTE(review): this total includes profiles with no matching MSA sequence, which learn()
        // and calculateDomainMSAFrequencies() skip; confirm that is the intended denominator.
        this.multipleSequenceAlignmentLength = 0;
        for (int i = 0; i < proteinProfileList.size(); ++i) {
            ProteinProfile proteinProfile = (ProteinProfile) proteinProfileList.get(i);
            int profileLength = proteinProfile.getNumColumns();
            if (this.totalPeptideSequenceLength < 0) {
                // The first profile fixes the expected peptide length.
                this.totalPeptideSequenceLength = profileLength;
            } else if (this.totalPeptideSequenceLength != profileLength) {
                throw new IllegalArgumentException("All peptide sequences must be the same length across all files. Found a profile of length " + profileLength + " in " + proteinProfile.getName() + " but was expecting length " + this.totalPeptideSequenceLength + " (based on the length of the first sequence seen).");
            }
            this.profileNameToProfile.put(proteinProfile.getName(), proteinProfile);
            this.multipleSequenceAlignmentLength += proteinProfile.getNumSequences();
        }
        this.rrcm = new ResidueResidueCorrelationMatrix(numDomainPositionsPerFeature, numPeptidePositionsPerFeature, this.totalDomainSequenceLength, this.totalPeptideSequenceLength);
    }

    /**
     * Loads every sequence of a FASTA protein alignment into {@link #sequenceNameToSequence} and
     * records the common alignment width in {@link #totalDomainSequenceLength}.
     *
     * @throws IllegalArgumentException if two sequences differ in length
     */
    private void readAlignmentAndDetermineSequenceAlignmentWidth(File multipleSequenceAlignmentFile) throws IOException, BioException {
        this.sequenceNameToSequence = new HashMap();
        BufferedReader brMSA = new BufferedReader(new FileReader(multipleSequenceAlignmentFile));
        try {
            SequenceIterator sequenceAlignment = (SequenceIterator) SeqIOTools.fileToBiojava("fasta", "PROTEIN", brMSA);
            while (sequenceAlignment.hasNext()) {
                Sequence seq = sequenceAlignment.nextSequence();
                if (this.totalDomainSequenceLength < 0) {
                    // The first sequence fixes the expected alignment width.
                    this.totalDomainSequenceLength = seq.length();
                } else if (this.totalDomainSequenceLength != seq.length()) {
                    throw new IllegalArgumentException("All domain sequences must be the same length. Found a domain sequence of length " + seq.length() + " called " + seq.getName() + " but was expecting length " + this.totalDomainSequenceLength + " (based on the length of the first sequence seen).");
                }
                this.sequenceNameToSequence.put(seq.getName(), seq);
            }
        } finally {
            // Close even when parsing fails or a length mismatch is reported.
            brMSA.close();
        }
    }

    /**
     * Changes the feature sizes and replaces the correlation matrix with a freshly allocated one.
     */
    public void setLearnParams(int numDomainPositionsPerFeature, int numPeptidePositionsPerFeature) {
        this.numDomainPositionsPerFeature = numDomainPositionsPerFeature;
        this.numPeptidePositionsPerFeature = numPeptidePositionsPerFeature;
        this.rrcm = new ResidueResidueCorrelationMatrix(numDomainPositionsPerFeature, numPeptidePositionsPerFeature, this.totalDomainSequenceLength, this.totalPeptideSequenceLength);
    }

    /**
     * Counts domain/peptide feature co-occurrences for every profile that has an aligned sequence of
     * the same name, then recomputes the domain and peptide feature frequency vectors.
     *
     * <p>When a column filter is set, the corresponding total length is replaced by the filtered
     * length first. The correlation matrix is not reset here (NOTE(review): repeated calls presumably
     * accumulate counts; use {@link #setLearnParams} to start over).
     *
     * @return the number of feature pairs counted
     */
    public long learn() throws BioException {
        if (this.domainSequenceFilter != null) {
            this.totalDomainSequenceLength = ProteinSequenceUtil.countLengthOfFilteredStringResult(this.domainSequenceFilter);
        }
        if (this.peptideSequenceFilter != null) {
            this.totalPeptideSequenceLength = ProteinSequenceUtil.countLengthOfFilteredStringResult(this.peptideSequenceFilter);
        }
        long numberCorrelationCounts = 0L;
        long start = System.currentTimeMillis();
        for (Object profileObject : this.profileNameToProfile.values()) {
            ProteinProfile proteinProfile = (ProteinProfile) profileObject;
            Sequence domainSequence = (Sequence) this.sequenceNameToSequence.get(proteinProfile.getName());
            if (domainSequence == null) {
                System.out.println("No aligned sequence was found for profile: " + proteinProfile.getName());
                continue;
            }
            numberCorrelationCounts += this.learnSequenceToPeptides(domainSequence, proteinProfile, this.numDomainPositionsPerFeature, this.numPeptidePositionsPerFeature);
        }
        long end = System.currentTimeMillis();
        System.out.println(numberCorrelationCounts + " correlations counted in " + (end - start) / 1000L + " seconds.");
        this.calculateDomainMSAFrequencies(this.numDomainPositionsPerFeature);
        this.calculatePeptideMSAFrequencies(this.numPeptidePositionsPerFeature);
        return numberCorrelationCounts;
    }

    /** Sets the column filter applied to domain sequences before features are extracted. */
    public void setDomainSequenceFilter(String domainSequenceFilter) {
        this.domainSequenceFilter = domainSequenceFilter;
    }

    /** Sets the column filter applied to peptide sequences before features are extracted. */
    public void setPeptideSequenceFilter(String peptideSequenceFilter) {
        this.peptideSequenceFilter = peptideSequenceFilter;
    }

    /** Returns the domain string feature positions index into (filtered when a domain filter is set). */
    private String getDomainSequenceString(Sequence domainSequence) {
        return this.domainSequenceFilter == null ? domainSequence.seqString() : ProteinSequenceUtil.filterSequenceByColumns(domainSequence, this.domainSequenceFilter);
    }

    /** Returns the peptide string feature positions index into (filtered when a peptide filter is set). */
    private String getPeptideSequenceString(Sequence peptideSequence) {
        return this.peptideSequenceFilter == null ? peptideSequence.seqString() : ProteinSequenceUtil.filterSequenceByColumns(peptideSequence, this.peptideSequenceFilter);
    }

    /**
     * Counts feature co-occurrences between one domain sequence and every peptide of its profile.
     *
     * @return the number of feature pairs added to the correlation matrix
     */
    protected int learnSequenceToPeptides(Sequence domainSequence, ProteinProfile bindingPeptides, int numDomainPositionsPerFeature, int numPeptidePositionsPerFeature) {
        String domainSequenceString = this.getDomainSequenceString(domainSequence);
        Collection peptides = bindingPeptides.getSequenceMap();
        int correlationCounts = 0;
        for (Object peptideObject : peptides) {
            String peptideSequenceString = this.getPeptideSequenceString((Sequence) peptideObject);
            correlationCounts += this.learnSequenceToPeptide(domainSequenceString, peptideSequenceString, numDomainPositionsPerFeature, numPeptidePositionsPerFeature);
        }
        return correlationCounts;
    }

    /**
     * Adds one correlation count for every (valid domain feature, valid peptide feature) pair that
     * can be formed from the two strings.
     *
     * @return the number of feature pairs added to the correlation matrix
     */
    protected int learnSequenceToPeptide(String domainSequenceString, String peptideSequenceString, int numDomainPositionsPerFeature, int numPeptidePositionsPerFeature) {
        int numberOfDomainFeatures = FeatureUtils.getCombinations(this.totalDomainSequenceLength, numDomainPositionsPerFeature);
        int numberOfPeptideFeatures = FeatureUtils.getCombinations(this.totalPeptideSequenceLength, numPeptidePositionsPerFeature);
        int[] domainPositionArray = new int[numDomainPositionsPerFeature];
        int[] peptidePositionArray = new int[numPeptidePositionsPerFeature];
        ResiduePositionPair[] domainFeature = FeatureUtils.allocateFeature(numDomainPositionsPerFeature);
        ResiduePositionPair[] peptideFeature = FeatureUtils.allocateFeature(numPeptidePositionsPerFeature);
        int correlationCounts = 0;
        for (int i = 0; i < numberOfDomainFeatures; ++i) {
            // The first call (i == 0) initialises the position combination; later calls advance it.
            domainPositionArray = FeatureUtils.generateFeaturePositions(domainPositionArray, domainSequenceString.length(), i == 0);
            domainFeature = FeatureUtils.createFeature(domainFeature, domainPositionArray, domainSequenceString);
            if (!FeatureUtils.featureValid(domainFeature)) {
                continue;
            }
            for (int j = 0; j < numberOfPeptideFeatures; ++j) {
                // Peptide combinations restart (j == 0) for every valid domain feature.
                peptidePositionArray = FeatureUtils.generateFeaturePositions(peptidePositionArray, peptideSequenceString.length(), j == 0);
                peptideFeature = FeatureUtils.createFeature(peptideFeature, peptidePositionArray, peptideSequenceString);
                if (!FeatureUtils.featureValid(peptideFeature)) {
                    continue;
                }
                this.rrcm.addCorrelationCount(domainFeature, peptideFeature);
                ++correlationCounts;
            }
        }
        return correlationCounts;
    }

    /**
     * Rebuilds {@link #domainFeatureFrequencyVector}. Each valid feature of each aligned sequence
     * that has a same-named profile is weighted by that profile's number of peptide sequences.
     */
    public void calculateDomainMSAFrequencies(int numDomainPositionsPerFeature) {
        int numberOfDomainFeatures = FeatureUtils.getCombinations(this.totalDomainSequenceLength, numDomainPositionsPerFeature);
        int[] domainPositionArray = new int[numDomainPositionsPerFeature];
        ResiduePositionPair[] domainFeature = FeatureUtils.allocateFeature(numDomainPositionsPerFeature);
        this.domainFeatureFrequencyVector = new SparseVector(FeatureUtils.getMaxIndex(numDomainPositionsPerFeature, this.totalDomainSequenceLength));
        for (Object sequenceObject : this.sequenceNameToSequence.values()) {
            Sequence domainSequence = (Sequence) sequenceObject;
            ProteinProfile domainProfile = (ProteinProfile) this.profileNameToProfile.get(domainSequence.getName());
            if (domainProfile == null) {
                continue;
            }
            String domainSequenceString = this.getDomainSequenceString(domainSequence);
            for (int i = 0; i < numberOfDomainFeatures; ++i) {
                domainPositionArray = FeatureUtils.generateFeaturePositions(domainPositionArray, domainSequenceString.length(), i == 0);
                domainFeature = FeatureUtils.createFeature(domainFeature, domainPositionArray, domainSequenceString);
                if (FeatureUtils.featureValid(domainFeature)) {
                    this.addFrequencyCount(domainFeature, this.domainFeatureFrequencyVector, domainProfile.getNumSequences());
                }
            }
        }
    }

    /**
     * Rebuilds {@link #peptideFeatureFrequencyVector}, counting each valid feature of every peptide
     * of every profile once.
     */
    public void calculatePeptideMSAFrequencies(int numPeptidePositionsPerFeature) {
        int numberOfPeptideFeatures = FeatureUtils.getCombinations(this.totalPeptideSequenceLength, numPeptidePositionsPerFeature);
        int[] peptidePositionArray = new int[numPeptidePositionsPerFeature];
        ResiduePositionPair[] peptideFeature = FeatureUtils.allocateFeature(numPeptidePositionsPerFeature);
        this.peptideFeatureFrequencyVector = new SparseVector(FeatureUtils.getMaxIndex(numPeptidePositionsPerFeature, this.totalPeptideSequenceLength));
        for (Object profileObject : this.profileNameToProfile.values()) {
            Collection peptides = ((ProteinProfile) profileObject).getSequenceMap();
            for (Object peptideObject : peptides) {
                String peptideSequenceString = this.getPeptideSequenceString((Sequence) peptideObject);
                for (int j = 0; j < numberOfPeptideFeatures; ++j) {
                    peptidePositionArray = FeatureUtils.generateFeaturePositions(peptidePositionArray, peptideSequenceString.length(), j == 0);
                    peptideFeature = FeatureUtils.createFeature(peptideFeature, peptidePositionArray, peptideSequenceString);
                    if (FeatureUtils.featureValid(peptideFeature)) {
                        this.addFrequencyCount(peptideFeature, this.peptideFeatureFrequencyVector, 1.0);
                    }
                }
            }
        }
    }

    /** Adds {@code count} to the entry of {@code frequencyVector} indexed by {@code feature}. */
    private void addFrequencyCount(ResiduePositionPair[] feature, SparseVector frequencyVector, double count) {
        frequencyVector.add(FeatureUtils.getFeatureIndex(feature), count);
    }

    /** Returns the entry of {@code frequencyVector} indexed by {@code feature}. */
    protected double getFrequencyCount(ResiduePositionPair[] feature, SparseVector frequencyVector) {
        return frequencyVector.get(FeatureUtils.getFeatureIndex(feature));
    }

    /**
     * Scores every stored correlation and keeps those scoring strictly below {@code scoreThreshold}.
     *
     * @return map from score (ascending) to a list of {@link CorrelationResult}s, or {@code null} if
     *     there is no correlation matrix
     */
    private TreeMap getMostInformativeFeatures(double scoreThreshold) {
        FlexCompRowMatrix correlationMatrix = this.rrcm.getCorrelationMatrix();
        if (correlationMatrix == null) {
            return null;
        }
        TreeMap sortedResultMap = new TreeMap();
        ResiduePositionPair[] domainFeature = FeatureUtils.allocateFeature(this.numDomainPositionsPerFeature);
        ResiduePositionPair[] peptideFeature = FeatureUtils.allocateFeature(this.numPeptidePositionsPerFeature);
        for (Iterator iterator = correlationMatrix.iterator(); iterator.hasNext();) {
            MatrixEntry matrixEntry = (MatrixEntry) iterator.next();
            domainFeature = this.getDomainFeatureFromSparseMatrixEntry(matrixEntry, domainFeature);
            peptideFeature = this.getPeptideFeatureFromSparseMatrixEntry(matrixEntry, peptideFeature);
            double domainFrequency = this.getFrequencyCount(domainFeature, this.domainFeatureFrequencyVector);
            double peptideFrequency = this.getFrequencyCount(peptideFeature, this.peptideFeatureFrequencyVector);
            double score = FeatureUtils.scoreFeature(matrixEntry, domainFrequency, peptideFrequency, this.multipleSequenceAlignmentLength);
            // Strict "<" also rejects NaN scores.
            if (score < scoreThreshold) {
                this.addResultToSortedResultMap(score, matrixEntry, sortedResultMap);
            }
        }
        return sortedResultMap;
    }

    /**
     * Prints every correlation scoring below {@code scoreThreshold}, best (lowest) first. Each result
     * is shown against the named aligned sequence.
     *
     * @throws IllegalArgumentException if no aligned sequence has the given name
     */
    public void printMostInformativeFeatures(String chosenSequenceName, double scoreThreshold) {
        Sequence chosenSequence = (Sequence) this.sequenceNameToSequence.get(chosenSequenceName);
        if (chosenSequence == null) {
            throw new IllegalArgumentException("No aligned domain sequence named " + chosenSequenceName + " was found.");
        }
        // Feature positions refer to the filtered domain string, so mark that string.
        String chosenDomainString = this.getDomainSequenceString(chosenSequence);
        ResiduePositionPair[] domainFeature = FeatureUtils.allocateFeature(this.numDomainPositionsPerFeature);
        ResiduePositionPair[] peptideFeature = FeatureUtils.allocateFeature(this.numPeptidePositionsPerFeature);
        TreeMap sortedResultMap = this.getMostInformativeFeatures(scoreThreshold);
        System.out.println("Features better than score " + scoreThreshold + " (lower is better) shown with " + chosenSequenceName);
        System.out.println("Total domain length: " + this.totalDomainSequenceLength);
        System.out.println("Total peptide length: " + this.totalPeptideSequenceLength);
        if (sortedResultMap == null) {
            System.out.println("No correlation matrix is available.");
            return;
        }
        Set scores = sortedResultMap.keySet();
        for (Object scoreObject : scores) {
            Double score = (Double) scoreObject;
            ArrayList resultList = (ArrayList) sortedResultMap.get(score);
            for (int i = 0; i < resultList.size(); ++i) {
                this.printResult((CorrelationResult) resultList.get(i), chosenDomainString, domainFeature, peptideFeature, score);
            }
        }
    }

    /**
     * Prints one result as three lines:
     * <ul>
     *   <li>the two features with score and counts;</li>
     *   <li>the chosen domain string with the domain-feature residues in brackets;</li>
     *   <li>the peptide feature laid out over the peptide length, with '-' at unused positions.</li>
     * </ul>
     */
    private void printResult(CorrelationResult correlationResult, String chosenDomainString, ResiduePositionPair[] domainFeature, ResiduePositionPair[] peptideFeature, Double score) {
        domainFeature = FeatureUtils.indexToFeature(correlationResult.domainFeature, domainFeature, this.totalDomainSequenceLength);
        peptideFeature = FeatureUtils.indexToFeature(correlationResult.peptideFeature, peptideFeature, this.totalPeptideSequenceLength);
        int domainFrequency = (int) this.getFrequencyCount(domainFeature, this.domainFeatureFrequencyVector);
        int peptideFrequency = (int) this.getFrequencyCount(peptideFeature, this.peptideFeatureFrequencyVector);
        System.out.println(FeatureUtils.featureToString(domainFeature) + " " + FeatureUtils.featureToString(peptideFeature) + " " + score + " (" + correlationResult.correlationCount + "," + domainFrequency + "," + peptideFrequency + ")");
        System.out.println(this.markFeatureInSequence(chosenDomainString, domainFeature));
        System.out.print(this.renderPeptideFeature(peptideFeature) + "\n");
    }

    /** Returns {@code sequenceString} with each character at a feature position wrapped in [ ]. */
    private String markFeatureInSequence(String sequenceString, ResiduePositionPair[] feature) {
        boolean[] inFeature = new boolean[sequenceString.length()];
        for (int i = 0; i < feature.length; ++i) {
            inFeature[feature[i].position] = true;
        }
        StringBuffer marked = new StringBuffer(sequenceString.length() + 2 * feature.length);
        for (int i = 0; i < sequenceString.length(); ++i) {
            char residue = sequenceString.charAt(i);
            if (inFeature[i]) {
                marked.append('[').append(residue).append(']');
            } else {
                marked.append(residue);
            }
        }
        return marked.toString();
    }

    /** Lays the feature's residues out over the peptide length, using '-' for positions not in it. */
    private String renderPeptideFeature(ResiduePositionPair[] feature) {
        StringBuffer line = new StringBuffer(Math.max(this.totalPeptideSequenceLength, 0));
        for (int position = 0; position < this.totalPeptideSequenceLength; ++position) {
            ResiduePositionPair match = null;
            for (int j = 0; j < feature.length && match == null; ++j) {
                if (feature[j].position == position) {
                    match = feature[j];
                }
            }
            if (match == null) {
                line.append('-');
            } else {
                line.append(match.residue);
            }
        }
        return line.toString();
    }

    /** Appends a {@link CorrelationResult} for {@code matrixEntry} to the list stored under {@code score}. */
    private void addResultToSortedResultMap(double score, MatrixEntry matrixEntry, TreeMap sortedResultMap) {
        Double scoreKey = Double.valueOf(score);
        ArrayList resultList = (ArrayList) sortedResultMap.get(scoreKey);
        if (resultList == null) {
            resultList = new ArrayList();
            sortedResultMap.put(scoreKey, resultList);
        }
        int correlationCount = (int) matrixEntry.get();
        // CorrelationResult always takes (domain index, peptide index, count), whatever the matrix orientation.
        CorrelationResult result = this.rrcm.domainPositionsAsRows()
                ? new CorrelationResult(matrixEntry.row(), matrixEntry.column(), correlationCount)
                : new CorrelationResult(matrixEntry.column(), matrixEntry.row(), correlationCount);
        resultList.add(result);
    }

    /** Decodes the domain feature of a matrix entry into {@code preAllocatedDomainFeature}. */
    protected ResiduePositionPair[] getDomainFeatureFromSparseMatrixEntry(MatrixEntry matrixEntry, ResiduePositionPair[] preAllocatedDomainFeature) {
        int domainIndex = this.rrcm.domainPositionsAsRows() ? matrixEntry.row() : matrixEntry.column();
        return FeatureUtils.indexToFeature(domainIndex, preAllocatedDomainFeature, this.totalDomainSequenceLength);
    }

    /** Decodes the peptide feature of a matrix entry into {@code preAllocatedPeptideFeature}. */
    protected ResiduePositionPair[] getPeptideFeatureFromSparseMatrixEntry(MatrixEntry matrixEntry, ResiduePositionPair[] preAllocatedPeptideFeature) {
        int peptideIndex = this.rrcm.domainPositionsAsRows() ? matrixEntry.column() : matrixEntry.row();
        return FeatureUtils.indexToFeature(peptideIndex, preAllocatedPeptideFeature, this.totalPeptideSequenceLength);
    }

    /**
     * Accumulates per-position residue weights for an aligned domain sequence.
     *
     * <p>Only correlations that meet both conditions contribute:
     * <ul>
     *   <li>the domain feature occurs in the (filtered) sequence;</li>
     *   <li>the score is below {@link #SCORE_THRESHOLD_FOR_PREDICTION}.</li>
     * </ul>
     * Each such correlation adds |score| to each peptide-feature residue at its position.
     *
     * <p><b>Currently always returns {@code null}:</b> the accumulated weight columns are never wrapped
     * into a {@link ProteinProfile}. TODO(review): build the profile from {@code weightMatrixColumns}
     * once a suitable ProteinProfile factory is confirmed; the weights are also not normalised.
     *
     * @return {@code null} (see above)
     */
    public ProteinProfile predictProfileFromSequence(Sequence alignedDomainSequence) {
        FlexCompRowMatrix correlationMatrix = this.rrcm.getCorrelationMatrix();
        if (correlationMatrix == null) {
            return null;
        }
        Distribution[] weightMatrixColumns = this.createInitialWeightColumns(ProteinTools.getAlphabet());
        HashMap alphabetMap = ProteinSequenceUtil.get20aaAlphabet();
        // Feature positions are filtered-string indices, so test membership against the same string.
        String domainSequenceString = this.getDomainSequenceString(alignedDomainSequence);
        ResiduePositionPair[] domainFeature = FeatureUtils.allocateFeature(this.numDomainPositionsPerFeature);
        ResiduePositionPair[] peptideFeature = FeatureUtils.allocateFeature(this.numPeptidePositionsPerFeature);
        for (Iterator iterator = correlationMatrix.iterator(); iterator.hasNext();) {
            MatrixEntry matrixEntry = (MatrixEntry) iterator.next();
            domainFeature = this.getDomainFeatureFromSparseMatrixEntry(matrixEntry, domainFeature);
            peptideFeature = this.getPeptideFeatureFromSparseMatrixEntry(matrixEntry, peptideFeature);
            if (!FeatureUtils.isFeatureInSequence(domainFeature, domainSequenceString)) {
                continue;
            }
            double domainFrequency = this.getFrequencyCount(domainFeature, this.domainFeatureFrequencyVector);
            double peptideFrequency = this.getFrequencyCount(peptideFeature, this.peptideFeatureFrequencyVector);
            double score = FeatureUtils.scoreFeature(matrixEntry, domainFrequency, peptideFrequency, this.multipleSequenceAlignmentLength);
            if (!(score < SCORE_THRESHOLD_FOR_PREDICTION)) {
                continue;
            }
            double weightIncrement = Math.abs(score);
            for (int i = 0; i < peptideFeature.length; ++i) {
                ResiduePositionPair residuePositionPair = peptideFeature[i];
                Symbol residue = (Symbol) alphabetMap.get(String.valueOf(residuePositionPair.residue));
                if (residue == null) {
                    // Not one of the 20 amino acids in the lookup map; nothing to weight.
                    continue;
                }
                Distribution column = weightMatrixColumns[residuePositionPair.position];
                try {
                    column.setWeight(residue, column.getWeight(residue) + weightIncrement);
                } catch (IllegalSymbolException e) {
                    e.printStackTrace();
                } catch (ChangeVetoException e) {
                    e.printStackTrace();
                }
            }
        }
        return null;
    }

    /**
     * Creates one distribution per peptide position, each with every symbol of {@code alphabet}
     * set to {@link #INITIAL_SYMBOL_WEIGHT}.
     */
    private Distribution[] createInitialWeightColumns(FiniteAlphabet alphabet) {
        Distribution[] columns = new Distribution[this.totalPeptideSequenceLength];
        IndexedCount emptyCount = new IndexedCount(alphabet);
        for (int i = 0; i < columns.length; ++i) {
            columns[i] = DistributionTools.countToDistribution((Count) emptyCount);
            for (Iterator symbols = alphabet.iterator(); symbols.hasNext();) {
                Symbol symbol = (Symbol) symbols.next();
                try {
                    columns[i].setWeight(symbol, INITIAL_SYMBOL_WEIGHT);
                } catch (IllegalSymbolException e) {
                    e.printStackTrace();
                } catch (ChangeVetoException e) {
                    e.printStackTrace();
                }
            }
        }
        return columns;
    }

    /** Writes a PNG logo for every aligned sequence into {@link #DEFAULT_LOGO_OUTPUT_DIRECTORY}. */
    public void outputAllLogos() {
        this.outputAllLogos(DEFAULT_LOGO_OUTPUT_DIRECTORY);
    }

    /**
     * Predicts a profile for every aligned sequence and writes its logo to
     * {@code outputDirectory/<profile name>.png}. Sequences for which no profile is predicted are
     * reported and skipped; write failures are printed and skipped.
     */
    public void outputAllLogos(String outputDirectory) {
        for (Object sequenceObject : this.sequenceNameToSequence.values()) {
            Sequence sequence = (Sequence) sequenceObject;
            ProteinProfile proteinProfile = this.predictProfileFromSequence(sequence);
            if (proteinProfile == null) {
                System.out.println("No profile could be predicted for sequence: " + sequence.getName());
                continue;
            }
            File outFile = new File(outputDirectory, proteinProfile.getName() + ".png");
            ProteinSequenceLogo logo = new ProteinSequenceLogo(proteinProfile, 240);
            try {
                // NOTE(review): -9 presumably labels the last of 10 peptide positions as 0 — confirm.
                logo.sequenceLogoSetStartIndex(-9);
                ImageIO.write((RenderedImage) logo.drawSequenceLogo(), "png", outFile);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }
}

