/*
 * Decompiled with CFR 0.152.
 */
package org.genemania.engine.core;

import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrices;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.MatrixEntry;
import no.uib.cipr.matrix.Vector;
import no.uib.cipr.matrix.VectorEntry;
import no.uib.cipr.matrix.sparse.FlexCompColMatrix;
import org.apache.log4j.Logger;
import org.genemania.engine.core.MatrixUtils;

public class ProfileToNetwork {
    /** Class-wide log4j logger; the {@code findNaNs} methods use it to report NaN positions. */
    private static final Logger logger = Logger.getLogger(ProfileToNetwork.class);

    /**
     * Switch for the {@code findNaNs} diagnostics. While {@code false} (the
     * default) those methods return immediately without scanning their argument.
     * Left package-visible and non-final so it can be switched on from within
     * the package.
     */
    static boolean isNaNCheckingEnabled = false;

    /**
     * Builds a network from a continuous-valued profile using the default
     * missing-data threshold.
     *
     * <p>Delegates to the three-argument overload with a maximum missing
     * percentage of {@code 25.0}.
     *
     * @param profile profile matrix, forwarded unchanged to the three-argument overload
     * @param k       neighbour count, forwarded unchanged to the three-argument overload
     * @return whatever the three-argument overload returns
     */
    public static Matrix continuousProfile(Matrix profile, int k) {
        final double defaultMaxMissingPercentage = 25.0;
        return continuousProfile(profile, k, defaultMaxMissingPercentage);
    }

    /**
     * Builds a square gene-by-gene network from a profile whose columns are genes.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>columns are screened with
     *       {@code MatrixUtils.checkColumnsforMissingDataThreshold} on the raw profile;</li>
     *   <li>per-column counts, means and variances (ignoring missing data) are
     *       computed, the variances are passed through {@code MatrixUtils.sqrt},
     *       and {@link #computeCorrelationTerms} rewrites {@code profile} in place;</li>
     *   <li>missing data is replaced by {@code 0.0} and columns that failed the
     *       screen are masked with {@code 0.0};</li>
     *   <li>for every column that passed the screen, its products with all columns
     *       are computed and up to {@code k + 1} of them are written to that
     *       gene's row of the network;</li>
     *   <li>the diagonal is zeroed and {@code MatrixUtils.setToMaxTranspose} is applied.</li>
     * </ol>
     *
     * <p>Note that {@code profile} is modified by this call.
     *
     * @param profile              matrix with one column per gene; overwritten in place
     * @param k                    neighbour count; {@code k + 1} entries are requested per
     *                             gene (presumably the extra slot covers the gene's own
     *                             entry, which is cleared from the diagonal afterwards)
     * @param maxMissingPercentage threshold handed to the missing-data column screen
     * @return a {@code FlexCompColMatrix} of size numColumns x numColumns
     */
    public static Matrix continuousProfile(Matrix profile, int k, double maxMissingPercentage) {
        int geneCount = profile.numColumns();
        Matrix network = new FlexCompColMatrix(geneCount, geneCount);

        // The screen runs on the raw profile, before any value is rewritten below.
        boolean[] usable = MatrixUtils.checkColumnsforMissingDataThreshold(profile, maxMissingPercentage);

        Vector counts = MatrixUtils.columnCountsIgnoreMissingData(profile);
        Vector means = MatrixUtils.columnMeanIgnoreMissingData(profile, counts);
        Vector spread = MatrixUtils.columnVarianceIgnoreMissingData(profile, means);
        MatrixUtils.sqrt(spread);
        computeCorrelationTerms(profile, means, spread, counts);

        MatrixUtils.replaceMissingData(profile, 0.0);
        MatrixUtils.maskFalseColumns(profile, usable, 0.0);
        findNaNs(profile);

        int entriesPerGene = k + 1;
        for (int gene = 0; gene < geneCount; gene++) {
            if (usable[gene]) {
                Matrix geneCorrelations = computeCorrelations(profile, gene);
                findNaNs(geneCorrelations);
                setTopK(network, geneCorrelations, gene, entriesPerGene);
            }
        }

        MatrixUtils.setDiagonalZero(network);
        MatrixUtils.setToMaxTranspose(network);
        return network;
    }

    /**
     * Variant of the continuous-profile network construction that first merges
     * columns sharing an identifier.
     *
     * <p>The profile is normalised in place via {@link #computeCorrelationTerms},
     * then {@link #averageRepeatedIdentifiers} produces a merged profile and a
     * matching name array. The missing-data screen, masking, per-gene
     * correlation and top-{@code k + 1} selection all run on the merged profile.
     *
     * @param profile              matrix with one column per entry of {@code names};
     *                             overwritten in place by the normalisation step
     * @param k                    neighbour count; {@code k + 1} entries are requested per gene
     * @param names                column names of {@code profile}
     * @param identifierMap        maps a name to its unique identifier
     * @param maxMissingPercentage threshold handed to the missing-data column screen
     * @return a two-element array: index 0 is the network ({@code FlexCompColMatrix}),
     *         index 1 is the {@code String[]} of names returned by the merge step
     */
    public static Object[] continuousProfileWithNames(Matrix profile, int k, String[] names, Map<String, String> identifierMap, double maxMissingPercentage) {
        Vector counts = MatrixUtils.columnCountsIgnoreMissingData(profile);
        Vector means = MatrixUtils.columnMeanIgnoreMissingData(profile, counts);
        Vector spread = MatrixUtils.columnVarianceIgnoreMissingData(profile, means);
        MatrixUtils.sqrt(spread);
        computeCorrelationTerms(profile, means, spread, counts);

        // Collapse repeated identifiers; everything below works on the merged data.
        Object[] merged = averageRepeatedIdentifiers(profile, names, identifierMap);
        Matrix mergedProfile = (Matrix) merged[0];
        String[] mergedNames = (String[]) merged[1];

        boolean[] usable = MatrixUtils.checkColumnsforMissingDataThreshold(mergedProfile, maxMissingPercentage);
        MatrixUtils.replaceMissingData(mergedProfile, 0.0);
        MatrixUtils.maskFalseColumns(mergedProfile, usable, 0.0);
        findNaNs(mergedProfile);

        int geneCount = mergedProfile.numColumns();
        Matrix network = new FlexCompColMatrix(geneCount, geneCount);
        int entriesPerGene = k + 1;
        for (int gene = 0; gene < geneCount; gene++) {
            if (usable[gene]) {
                Matrix geneCorrelations = computeCorrelations(mergedProfile, gene);
                findNaNs(geneCorrelations);
                setTopK(network, geneCorrelations, gene, entriesPerGene);
            }
        }

        MatrixUtils.setDiagonalZero(network);
        MatrixUtils.setToMaxTranspose(network);
        return new Object[]{network, mergedNames};
    }

    /**
     * Normalises {@code profile} in place and computes, for every column, the
     * products with the columns up to and including itself.
     *
     * <p>NOTE(review): as written, the computed products are only passed to
     * {@link #findNaNs(Matrix)}; nothing is ever stored in the returned network,
     * and neither {@code low} nor {@code hi} is read. The method therefore
     * returns a {@code FlexCompColMatrix} with no entries set by this code.
     * Confirm whether a thresholding step is missing before relying on it.
     *
     * <p>Unlike the continuous-profile methods, means and variances here come
     * from {@code MatrixUtils.columnMean} / {@code MatrixUtils.columnVariance},
     * while the counts still come from the ignore-missing-data variant.
     *
     * @param profile matrix with one column per gene; overwritten in place
     * @param low     not used by the current implementation
     * @param hi      not used by the current implementation
     * @return a numColumns x numColumns {@code FlexCompColMatrix}
     */
    public static Matrix ThresholdOnly(Matrix profile, double low, double hi) {
        int geneCount = profile.numColumns();
        Matrix network = new FlexCompColMatrix(geneCount, geneCount);

        Vector counts = MatrixUtils.columnCountsIgnoreMissingData(profile);
        Vector means = MatrixUtils.columnMean(profile);
        Vector spread = MatrixUtils.columnVariance(profile, means);
        MatrixUtils.sqrt(spread);
        computeCorrelationTerms(profile, means, spread, counts);

        for (int gene = 0; gene < geneCount; gene++) {
            Matrix geneCorrelations = computeCorrelationsDiag(profile, gene);
            findNaNs(geneCorrelations);
        }
        return network;
    }

    /**
     * Rewrites every entry of {@code profile} visited by its iterator as
     * {@code (value - means[col]) / (stdevs[col] + 1.0E-13)}.
     *
     * <p>Before the rewrite, {@code stdevs} is combined with the square root of
     * {@code counts} through {@code MatrixUtils.elementMult(stdevs, sqrt(counts))};
     * the divisor is then read back from {@code stdevs}, so that call presumably
     * scales {@code stdevs} in place (verify against {@code MatrixUtils}).
     * {@code counts} itself is copied first and is not passed to any mutating call here.
     *
     * @param profile matrix to normalise; overwritten in place
     * @param means   per-column value subtracted from each entry
     * @param stdevs  per-column divisor; handed to {@code elementMult} before use
     * @param counts  per-column counts whose square roots are multiplied into {@code stdevs}
     */
    protected static void computeCorrelationTerms(Matrix profile, Vector means, Vector stdevs, Vector counts) {
        // Small constant added to the divisor so a zero divisor cannot occur for non-negative stdevs.
        final double epsilon = 1.0E-13;

        Vector rootCounts = counts.copy();
        MatrixUtils.sqrt(rootCounts);
        MatrixUtils.elementMult(stdevs, rootCounts);

        for (MatrixEntry entry : profile) {
            int column = entry.column();
            double centred = entry.get() - means.get(column);
            double divisor = stdevs.get(column) + epsilon;
            entry.set(centred / divisor);
        }
    }

    /**
     * Computes the products of column {@code i} of {@code terms} with every
     * column of {@code terms}.
     *
     * <p>Column {@code i} is copied out as a numRows x 1 matrix and multiplied,
     * transposed, with {@code terms}, giving a 1 x numColumns result.
     *
     * @param terms normalised profile (one column per gene)
     * @param i     index of the column to compare against all columns
     * @return a 1 x numColumns {@code DenseMatrix} holding column_i^T * terms
     */
    private static Matrix computeCorrelations(Matrix terms, int i) {
        findNaNs(terms);

        int[] everyRow = Matrices.index(0, terms.numRows());
        int[] singleColumn = {i};
        Matrix columnCopy = Matrices.getSubMatrix(terms, everyRow, singleColumn).copy();

        Matrix products = new DenseMatrix(1, terms.numColumns());
        columnCopy.transAmult(terms, products);
        return products;
    }

    /**
     * Computes the products of column {@code i} of {@code terms} with columns
     * {@code 0..i} (inclusive), scaled by {@code 1 / numRows}.
     *
     * <p>Each product is obtained from {@link #sillyColumnMult(Matrix, int, int)}.
     * An earlier revision additionally built a sub-matrix copy of columns
     * {@code 0..i-1} and multiplied it against column {@code i} into a second
     * result that was never read; that dead work has been removed. The returned
     * values are unchanged.
     *
     * <p>NOTE(review): the reason for the extra {@code 1 / numRows} scaling is
     * not evident from this file; it is preserved as-is.
     *
     * @param terms normalised profile (one column per gene)
     * @param i     index of the column to compare; also the last column compared against
     * @return an (i + 1) x 1 {@code DenseMatrix}; row {@code j} holds the scaled
     *         product of columns {@code i} and {@code j}
     */
    private static Matrix computeCorrelationsDiag(Matrix terms, int i) {
        findNaNs(terms);
        int numFeatures = terms.numRows();

        DenseMatrix correlations = new DenseMatrix(i + 1, 1);
        for (int j = 0; j <= i; ++j) {
            correlations.set(j, 0, sillyColumnMult(terms, i, j));
        }

        correlations.scale(1.0 / (double) numFeatures);
        return correlations;
    }

    /**
     * Returns the dot product of columns {@code i} and {@code j} of {@code terms}.
     *
     * <p>For a {@code DenseMatrix} the backing array is read directly, with
     * column {@code c} addressed as {@code data[c * numRows + row]}. Any other
     * {@code Matrix} implementation is handled through {@code get(row, col)},
     * accumulating in the same row order; previously such an argument failed
     * with a {@code ClassCastException}.
     *
     * @param terms matrix whose columns are multiplied
     * @param i     index of the first column
     * @param j     index of the second column
     * @return sum over all rows of {@code terms[row][i] * terms[row][j]}
     */
    public static double sillyColumnMult(Matrix terms, int i, int j) {
        int numRows = terms.numRows();
        double prod = 0.0;

        if (terms instanceof DenseMatrix) {
            // Fast path: walk the two columns inside the raw backing array.
            double[] data = ((DenseMatrix) terms).getData();
            int starti = i * numRows;
            int startj = j * numRows;
            for (int k = 0; k < numRows; ++k) {
                prod += data[starti + k] * data[startj + k];
            }
            return prod;
        }

        // Generic path for matrices that do not expose a dense backing array.
        for (int k = 0; k < numRows; ++k) {
            prod += terms.get(k, i) * terms.get(k, j);
        }
        return prod;
    }

    /**
     * Copies up to {@code k} values from row 0 of {@code correlations} into row
     * {@code i} of {@code network}.
     *
     * <p>The candidates are the first {@code k} indices returned by
     * {@code MatrixUtils.getIndicesForSortedValues} (the sort direction is
     * defined by that helper, not here). Candidates whose value is exactly
     * {@code 0.0} are skipped, so fewer than {@code k} entries may be written.
     *
     * @param network      destination matrix; row {@code i} is written
     * @param correlations matrix whose row 0 supplies the values
     * @param i            destination row
     * @param k            maximum number of candidates considered
     */
    private static void setTopK(Matrix network, Matrix correlations, int i, int k) {
        Vector rowValues = MatrixUtils.extractRowToVector(correlations, 0);
        int[] order = MatrixUtils.getIndicesForSortedValues(rowValues);

        int limit = Math.min(order.length, k);
        for (int rank = 0; rank < limit; rank++) {
            int target = order[rank];
            double weight = rowValues.get(target);
            if (weight != 0.0) {
                network.set(i, target, weight);
            }
        }
    }

    /**
     * Copies into row {@code i} of {@code network} every value from row 0 of
     * {@code correlations} that lies at or below {@code low} or at or above
     * {@code hi}.
     *
     * <p>Values strictly between the two thresholds are left untouched in
     * {@code network}; NaN values satisfy neither comparison and are skipped
     * as well. An earlier revision also extracted row 0 into a vector that was
     * never read; that unused extraction has been removed.
     *
     * @param network      destination matrix; row {@code i} is written
     * @param correlations matrix whose row 0 supplies the values
     * @param i            destination row
     * @param hi           upper threshold (inclusive)
     * @param low          lower threshold (inclusive)
     */
    public static void setHiLow(Matrix network, Matrix correlations, int i, double hi, double low) {
        int columns = correlations.numColumns();
        for (int j = 0; j < columns; ++j) {
            double val = correlations.get(0, j);
            if (val <= low || val >= hi) {
                network.set(i, j, val);
            }
        }
    }

    /**
     * Builds a network from a binary (0/1) profile.
     *
     * <p>The profile is first rewritten in place by the binary transform and
     * then handed to the two-argument continuous-profile construction.
     *
     * @param profile matrix expected to contain only 0.0 and 1.0; overwritten in place
     * @param k       neighbour count forwarded to the continuous-profile construction
     * @return the network produced by the continuous-profile construction
     */
    public static Matrix binaryProfile(Matrix profile, int k) {
        transformBinaryProfile(profile);
        Matrix network = continuousProfile(profile, k);
        return network;
    }

    /**
     * Replaces the entries of a binary profile, in place, with per-row weights.
     *
     * <p>Let {@code m} be the mean of a row as returned by
     * {@code MatrixUtils.columnMean} applied to the transposed profile. Using
     * {@code MatrixUtils.log}, two per-row vectors are derived:
     * {@code f1 = -log(m)} and {@code f0 = log(1 - m)}. Every visited entry
     * equal to {@code 1.0} becomes {@code f1[row]}, every entry equal to
     * {@code 0.0} becomes {@code f0[row]}; any other value is left unchanged
     * and a message is printed to standard output.
     *
     * <p>The transpose is written into a scratch {@code DenseMatrix} instead of
     * transposing {@code profile} in place and back again. MTJ's in-place
     * {@code transpose()} is restricted to square matrices, so the previous
     * approach failed for profiles with differing row and column counts, and it
     * also mutated the input while the means were being computed.
     *
     * @param profile matrix expected to contain only 0.0 and 1.0; overwritten in place
     */
    private static void transformBinaryProfile(Matrix profile) {
        // Row means of the profile == column means of its transpose.
        Matrix transposed = profile.transpose(new DenseMatrix(profile.numColumns(), profile.numRows()));
        Vector means = MatrixUtils.columnMean(transposed);

        // f1 = -log(mean)
        Vector f1 = means.copy();
        MatrixUtils.log(f1);
        f1.scale(-1.0);

        // f0 = log(1 - mean)
        Vector f0 = means.copy();
        f0.scale(-1.0);
        MatrixUtils.add(f0, 1.0);
        MatrixUtils.log(f0);

        for (MatrixEntry e : profile) {
            double value = e.get();
            if (value == 0.0) {
                e.set(f0.get(e.row()));
            } else if (value == 1.0) {
                e.set(f1.get(e.row()));
            } else {
                System.out.println("not a binary profile");
            }
        }
    }

    /**
     * Collapses profile columns that map to the same unique identifier.
     *
     * <p>Columns are grouped by the identifier that {@code identifierMap}
     * assigns to their name; names without a mapping are dropped. A group with
     * a single column is copied through unchanged; a group with several columns
     * is replaced by {@code MatrixUtils.rowMeanIgnoreMissingData} of those
     * columns. Output columns follow the order in which identifiers are first
     * seen in {@code names}. Each output column is named with the name chosen
     * for its identifier by {@code getPreferredNames}.
     *
     * @param profile       matrix with one column per entry of {@code names}; not modified here
     * @param names         column names of {@code profile}
     * @param identifierMap maps a name to its unique identifier
     * @return a two-element array: index 0 is the merged {@code DenseMatrix}
     *         (numRows x number of distinct identifiers), index 1 is the
     *         {@code String[]} of output column names
     */
    public static Object[] averageRepeatedIdentifiers(Matrix profile, String[] names, Map<String, String> identifierMap) {
        LinkedHashMap<String, java.util.Vector<Integer>> columnsByUid = findRepeats(names, identifierMap);
        int featureCount = profile.numRows();
        int uniqueCount = columnsByUid.size();

        Matrix merged = new DenseMatrix(featureCount, uniqueCount);
        String[] mergedNames = new String[uniqueCount];
        HashMap<String, String> preferredNames = getPreferredNames(identifierMap);
        int[] everyRow = Matrices.index(0, featureCount);

        int target = 0;
        for (Map.Entry<String, java.util.Vector<Integer>> group : columnsByUid.entrySet()) {
            java.util.Vector<Integer> sources = group.getValue();
            if (sources.size() == 1) {
                // Single occurrence: straight copy of the source column.
                MatrixUtils.setColumn(merged, target, profile, sources.get(0));
            } else {
                // Repeated identifier: average the contributing columns row by row.
                int[] sourceColumns = new int[sources.size()];
                int next = 0;
                for (Integer source : sources) {
                    sourceColumns[next++] = source;
                }
                Matrix repeated = Matrices.getSubMatrix(profile, everyRow, sourceColumns);
                Vector rowMeans = MatrixUtils.rowMeanIgnoreMissingData(repeated);
                MatrixUtils.setColumn(merged, target, rowMeans);
            }
            mergedNames[target] = preferredNames.get(group.getKey());
            target++;
        }

        return new Object[]{merged, mergedNames};
    }

    /**
     * Inverts {@code identifierMap}, keeping one name per unique identifier.
     *
     * <p>When several names share an identifier, the one encountered first
     * while iterating {@code identifierMap} is kept; which name that is
     * therefore depends on the iteration order of the supplied map.
     *
     * @param identifierMap maps a name to its unique identifier
     * @return map from unique identifier to the retained name
     */
    private static HashMap<String, String> getPreferredNames(Map<String, String> identifierMap) {
        HashMap<String, String> nameByUid = new HashMap<String, String>();
        for (Map.Entry<String, String> mapping : identifierMap.entrySet()) {
            String uid = mapping.getValue();
            if (!nameByUid.containsKey(uid)) {
                nameByUid.put(uid, mapping.getKey());
            }
        }
        return nameByUid;
    }

    /**
     * Groups the positions of {@code names} by the unique identifier each name
     * maps to.
     *
     * <p>Names with no entry in {@code identifierMap} are skipped. The returned
     * map preserves the order in which identifiers are first encountered, and
     * each position list is in ascending order.
     *
     * <p>The position list is typed {@code java.util.Vector<Integer>}
     * throughout; the previous {@code Vector<Object>} / raw {@code Vector}
     * declarations were not assignable to the map's value type.
     * {@code java.util.Vector} is spelled out because the simple name
     * {@code Vector} refers to the MTJ vector type in this file.
     *
     * @param names         names to scan, in column order
     * @param identifierMap maps a name to its unique identifier
     * @return insertion-ordered map from unique identifier to the indices in
     *         {@code names} that carry it
     */
    private static LinkedHashMap<String, java.util.Vector<Integer>> findRepeats(String[] names, Map<String, String> identifierMap) {
        LinkedHashMap<String, java.util.Vector<Integer>> occurances = new LinkedHashMap<String, java.util.Vector<Integer>>();
        for (int i = 0; i < names.length; ++i) {
            String uid = identifierMap.get(names[i]);
            if (uid == null) {
                continue;
            }
            java.util.Vector<Integer> colList = occurances.get(uid);
            if (colList == null) {
                // First time this identifier is seen: start its position list.
                colList = new java.util.Vector<Integer>();
                occurances.put(uid, colList);
            }
            colList.add(i);
        }
        return occurances;
    }

    /**
     * Logs, at INFO level, the row and column of every NaN entry visited by the
     * matrix iterator.
     *
     * <p>Does nothing unless {@code isNaNCheckingEnabled} is {@code true}.
     *
     * @param m matrix to scan; not modified
     */
    public static void findNaNs(Matrix m) {
        if (isNaNCheckingEnabled) {
            for (MatrixEntry entry : m) {
                if (Double.isNaN(entry.get())) {
                    logger.info("found NaN at " + entry.row() + " " + entry.column());
                }
            }
        }
    }

    /**
     * Logs, at INFO level, the index of every NaN entry visited by the vector
     * iterator.
     *
     * <p>Does nothing unless {@code isNaNCheckingEnabled} is {@code true}.
     *
     * @param v vector to scan; not modified
     */
    public static void findNaNs(Vector v) {
        if (isNaNCheckingEnabled) {
            for (VectorEntry entry : v) {
                if (Double.isNaN(entry.get())) {
                    logger.info("found NaN at " + entry.index());
                }
            }
        }
    }

    /**
     * Replaces each row of {@code profile}, in place, with the result of
     * {@code MatrixUtils.tiedRank} applied to that row.
     *
     * <p>Each row is copied into a dense vector, passed to {@code tiedRank},
     * and the vector's contents are written back over the row.
     *
     * <p>NOTE(review): this method ranks along rows, whereas the network
     * construction methods in this class treat columns as genes — confirm the
     * orientation expected by callers.
     *
     * @param profile matrix to convert; overwritten in place
     */
    public static void convertProfileToRanks(Matrix profile) {
        int rowCount = profile.numRows();
        int columnCount = profile.numColumns();

        for (int row = 0; row < rowCount; row++) {
            Vector rowValues = new DenseVector(columnCount);
            for (int col = 0; col < columnCount; col++) {
                rowValues.set(col, profile.get(row, col));
            }

            MatrixUtils.tiedRank(rowValues);

            for (int col = 0; col < columnCount; col++) {
                profile.set(row, col, rowValues.get(col));
            }
        }
    }
}

