/*
 * Decompiled with CFR 0.152.
 */
package org.genemania.engine.core.integration;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Vector;
import org.apache.log4j.Logger;
import org.genemania.engine.Constants;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.core.MatrixUtils;
import org.genemania.engine.core.data.AttributeData;
import org.genemania.engine.core.data.AttributeGroups;
import org.genemania.engine.core.data.CoAnnotationSet;
import org.genemania.engine.core.integration.AbstractAttributeSelector;
import org.genemania.engine.core.integration.Feature;
import org.genemania.engine.core.integration.FeatureList;
import org.genemania.engine.core.integration.FeatureLoader;
import org.genemania.engine.matricks.Matrix;
import org.genemania.engine.matricks.SymMatrix;
import org.genemania.exception.ApplicationException;

public class CorrelatedAttributeSelector
extends AbstractAttributeSelector {
    private static Logger logger = Logger.getLogger(CorrelatedAttributeSelector.class);
    DataCache cache;
    String goBranch;
    private int maxAttributes;
    private String namespace = "CORE";

    public CorrelatedAttributeSelector(DataCache cache, String goBranch, int maxAttributes) {
        this.cache = cache;
        this.goBranch = goBranch;
        this.maxAttributes = maxAttributes;
    }

    @Override
    public FeatureList selectAttributes(long organismId, long attributeGroupId) throws ApplicationException {
        AttributeData attributeSet = this.cache.getAttributeData(this.namespace, organismId, attributeGroupId);
        Matrix data = attributeSet.getData();
        int numAttributes = data.numCols();
        int numGenes = data.numRows();
        DenseVector sums = new DenseVector(numAttributes);
        data.columnSums(sums.getData());
        boolean oldMode = false;
        boolean scaled = false;
        FeatureLoader featureLoader = new FeatureLoader(this.cache, this.namespace, organismId, oldMode, scaled);
        CoAnnotationSet annoSet = this.cache.getCoAnnotationSet(organismId, this.goBranch);
        DenseVector scores = new DenseVector(numAttributes);
        FeatureList candidateFeatures = this.getCandidateFeatures(organismId, attributeGroupId);
        double n = numGenes;
        double yhatSum = MatrixUtils.sum((Vector)annoSet.GetBHalf());
        double correctionFactor = yhatSum + n * annoSet.GetConstant();
        double diagSum = 0.0;
        for (int i = 0; i < numGenes; ++i) {
            diagSum += annoSet.GetCoAnnotationMatrix().get(i, i);
        }
        correctionFactor += diagSum;
        correctionFactor /= n * n;
        double annoSetSum = annoSet.GetCoAnnotationMatrix().elementSum();
        int i = 0;
        DenseVector tempVec = new DenseVector(numGenes);
        for (Feature feature : candidateFeatures) {
            double score = 0.0;
            double s = sums.get(i);
            double mean = s * (s - 1.0) / (n * n);
            double var = mean * (1.0 - mean);
            SymMatrix network = featureLoader.load(feature);
            double attributeCoannoProd = network.elementMultiplySum(annoSet.GetCoAnnotationMatrix());
            if (attributeCoannoProd == 0.0) {
                score = Double.MIN_VALUE;
            } else {
                double networkSum = s * (s - 1.0);
                network.mult(annoSet.GetBHalf().getData(), tempVec.getData());
                double tempVecSum = MatrixUtils.sum((Vector)tempVec);
                double term1 = attributeCoannoProd + tempVecSum + networkSum * annoSet.GetConstant();
                double term2 = annoSetSum + n * yhatSum + n * n * annoSet.GetConstant();
                score = (term1 - term2) / (Math.sqrt(var) * n * n);
                double correction = -mean / Math.sqrt(var) * correctionFactor;
                score -= correction;
            }
            scores.set(i, score);
            if (++i % 1000 != 0) continue;
            logger.debug((Object)("i: " + i));
        }
        FeatureList featureList = this.selectTopFeatures(organismId, attributeGroupId, candidateFeatures, scores, sums);
        return featureList;
    }

    private FeatureList getCandidateFeatures(long organismId, long attributeGroupId) throws ApplicationException {
        AttributeGroups groups = this.cache.getAttributeGroups(this.namespace, organismId);
        FeatureList list = new FeatureList();
        ArrayList<Long> attributeIds = groups.getAttributesForGroup(attributeGroupId);
        for (long attributeId : attributeIds) {
            Feature feature = new Feature(Constants.NetworkType.ATTRIBUTE_VECTOR, attributeGroupId, attributeId);
            list.add(feature);
        }
        return list;
    }

    private FeatureList selectTopFeatures(long organismId, long attributeGroupId, FeatureList candidateFeatures, DenseVector correlations, DenseVector columnSums) throws ApplicationException {
        AttributeGroups groups = this.cache.getAttributeGroups(this.namespace, organismId);
        FeatureList list = new FeatureList();
        DenseVector ranks = correlations.copy();
        ranks.scale(-1.0);
        MatrixUtils.add((Vector)ranks, 1.0);
        MatrixUtils.rank((Vector)ranks);
        DenseVector ordered = new DenseVector(correlations.size());
        int[] unrank = new int[correlations.size()];
        int i = 0;
        while (i < ranks.size()) {
            int p = (int)Math.round(ranks.get(i)) - 1;
            ordered.set(p, correlations.get(i));
            unrank[p] = i++;
        }
        int max = ranks.size();
        if (this.maxAttributes > 0) {
            max = Math.min(max, this.maxAttributes);
        }
        for (int i2 = 0; i2 < max; ++i2) {
            int attributeIndex = unrank[i2];
            double count = columnSums.get(attributeIndex);
            if (!(count >= 1.0)) continue;
            long attributeId = groups.getAttributeIdForIndex(attributeGroupId, attributeIndex);
            Feature feature = new Feature(Constants.NetworkType.ATTRIBUTE_VECTOR, attributeGroupId, attributeId);
            list.add(feature);
        }
        logger.debug((Object)String.format("selected %d attributes", list.size()));
        return list;
    }
}

