/*
 * Decompiled with CFR 0.152.
 */
package org.genemania.engine.core.integration.attribute;

import java.util.ArrayList;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Vector;
import org.apache.log4j.Logger;
import org.genemania.engine.Constants;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.core.MatrixUtils;
import org.genemania.engine.core.data.AttributeData;
import org.genemania.engine.core.data.AttributeGroups;
import org.genemania.engine.core.data.CoAnnotationSet;
import org.genemania.engine.core.integration.Feature;
import org.genemania.engine.core.integration.FeatureList;
import org.genemania.engine.core.integration.FeatureLoader;
import org.genemania.engine.core.integration.attribute.IAttributeScorer;
import org.genemania.engine.core.utils.ObjectSelector;
import org.genemania.engine.matricks.Matrix;
import org.genemania.engine.matricks.SymMatrix;
import org.genemania.exception.ApplicationException;

/**
 * Scores the attributes of an attribute group against the GO co-annotation data of one GO
 * branch.
 *
 * <p>For every attribute (column of the group's attribute matrix) the scorer loads the
 * attribute's network and combines it with the cached {@link CoAnnotationSet}: the
 * co-annotation matrix, the {@code BHalf} vector and the constant term. The result is
 * normalised by a mean/variance derived from the attribute's column sum. The scores are then
 * converted into {@code 1 - score} weights in an {@link ObjectSelector}.
 */
public class TargetCorrelatedAttributeScorer
implements IAttributeScorer {
    private static final Logger logger = Logger.getLogger(TargetCorrelatedAttributeScorer.class);
    DataCache cache;
    String goBranch;

    /**
     * @param cache    data cache supplying attribute data, attribute groups, networks and
     *                 co-annotation sets
     * @param goBranch GO branch whose co-annotation set is used as the scoring target
     */
    public TargetCorrelatedAttributeScorer(DataCache cache, String goBranch) {
        this.cache = cache;
        this.goBranch = goBranch;
    }

    /**
     * Computes a score for every attribute in the given group and returns the attributes wrapped
     * in an {@link ObjectSelector}, weighted by {@code 1 - score}.
     *
     * @param namespace        cache namespace the attribute data and networks are read from
     * @param organismId       organism whose attribute data and co-annotation set are used
     * @param attributeGroupId group whose attributes are scored
     * @return selector holding every attribute whose column sum is at least 1
     * @throws ApplicationException if reading from the cache or loading a network fails
     */
    @Override
    public ObjectSelector<Feature> scoreAttributes(String namespace, long organismId, long attributeGroupId) throws ApplicationException {
        AttributeData attributeSet = this.cache.getAttributeData(namespace, organismId, attributeGroupId);
        Matrix data = attributeSet.getData();
        int numAttributes = data.numCols();
        int numGenes = data.numRows();

        // Per-attribute column sums; presumably the number of genes carrying the attribute
        // (assumes 0/1 entries — TODO confirm).
        DenseVector sums = new DenseVector(numAttributes);
        data.columnSums(sums.getData());

        boolean oldMode = false;
        boolean scaled = false;
        FeatureLoader featureLoader = new FeatureLoader(this.cache, namespace, organismId, oldMode, scaled);
        CoAnnotationSet annoSet = this.cache.getCoAnnotationSet(organismId, this.goBranch);
        DenseVector scores = new DenseVector(numAttributes);
        FeatureList candidateFeatures = this.getCandidateFeatures(namespace, organismId, attributeGroupId);

        double n = numGenes;
        double yhatSum = MatrixUtils.sum((Vector)annoSet.GetBHalf());

        // correctionFactor = (sum(BHalf) + n * constant + trace(coannotation)) / n^2
        double correctionFactor = yhatSum + n * annoSet.GetConstant();
        double diagSum = 0.0;
        for (int g = 0; g < numGenes; ++g) {
            diagSum += annoSet.GetCoAnnotationMatrix().get(g, g);
        }
        correctionFactor += diagSum;
        correctionFactor /= n * n;

        double annoSetSum = annoSet.GetCoAnnotationMatrix().elementSum();
        // This part of the score numerator does not depend on the attribute, so compute it once.
        double term2 = annoSetSum + n * yhatSum + n * n * annoSet.GetConstant();

        DenseVector tempVec = new DenseVector(numGenes);
        // Feature i is paired with column i of the attribute matrix (and with slot i of
        // `scores`). NOTE(review): assumes the group's attribute ids are listed in column order
        // and that there are at most numAttributes of them — confirm.
        int i = 0;
        for (Feature feature : candidateFeatures) {
            double s = sums.get(i);
            double networkSum = s * (s - 1.0);
            double mean = networkSum / (n * n);
            double var = mean * (1.0 - mean);
            double score;
            if (!(var > 0.0)) {
                // Fewer than two genes carry this attribute (or the sum is fractional). The
                // normalising variance is not positive, so the formula below would produce
                // NaN/Infinity; treat it like an attribute with no co-annotation overlap.
                score = Double.MIN_VALUE;
            } else {
                SymMatrix network = featureLoader.load(feature);
                double attributeCoannoProd = network.elementMultiplySum(annoSet.GetCoAnnotationMatrix());
                if (attributeCoannoProd == 0.0) {
                    // NOTE(review): Double.MIN_VALUE is the smallest *positive* double
                    // (~4.9e-324), i.e. effectively 0, not the most negative value.
                    // Kept as-is; confirm this is the intended sentinel.
                    score = Double.MIN_VALUE;
                } else {
                    network.mult(annoSet.GetBHalf().getData(), tempVec.getData());
                    double tempVecSum = MatrixUtils.sum((Vector)tempVec);
                    double term1 = attributeCoannoProd + tempVecSum + networkSum * annoSet.GetConstant();
                    double sd = Math.sqrt(var);
                    score = (term1 - term2) / (sd * n * n);
                    double correction = -mean / sd * correctionFactor;
                    score -= correction;
                }
            }
            scores.set(i, score);
            if (++i % 1000 == 0) {
                logger.debug("i: " + i);
            }
        }
        return this.buildList(namespace, organismId, attributeGroupId, candidateFeatures, scores, sums);
    }

    /**
     * Builds one {@code ATTRIBUTE_VECTOR} feature per attribute id in the given group, in the
     * order returned by {@code AttributeGroups.getAttributesForGroup}.
     */
    private FeatureList getCandidateFeatures(String namespace, long organismId, long attributeGroupId) throws ApplicationException {
        AttributeGroups groups = this.cache.getAttributeGroups(namespace, organismId);
        FeatureList list = new FeatureList();
        ArrayList<Long> attributeIds = groups.getAttributesForGroup(attributeGroupId);
        for (long attributeId : attributeIds) {
            Feature feature = new Feature(Constants.NetworkType.ATTRIBUTE_VECTOR, attributeGroupId, attributeId);
            list.add(feature);
        }
        return list;
    }

    /**
     * Wraps the scored features in an {@link ObjectSelector} with weight {@code 1 - score}.
     * Features whose column sum is below 1 (or NaN) are skipped.
     *
     * <p>Iterates {@code candidateFeatures}, the same list the scores were computed from, so
     * score {@code i} is always paired with the feature it was computed for. The
     * {@code namespace}, {@code organismId} and {@code attributeGroupId} parameters are no longer
     * needed but are kept for signature stability.
     *
     * @param candidateFeatures features in scoring order
     * @param correlations      per-feature scores; not modified (a copy is transformed)
     * @param columnSums        per-feature column sums of the attribute matrix
     */
    private ObjectSelector<Feature> buildList(String namespace, long organismId, long attributeGroupId, FeatureList candidateFeatures, DenseVector correlations, DenseVector columnSums) throws ApplicationException {
        ObjectSelector<Feature> list = new ObjectSelector<Feature>();
        // weights = 1 - score, computed on a copy so the caller's vector is untouched.
        DenseVector weights = correlations.copy();
        weights.scale(-1.0);
        MatrixUtils.add((Vector)weights, 1.0);
        int i = 0;
        for (Feature feature : candidateFeatures) {
            double count = columnSums.get(i);
            // Written as a positive test so a NaN count is also skipped.
            if (count >= 1.0) {
                list.add(feature, weights.get(i));
            }
            ++i;
        }
        logger.debug(String.format("selected %d attributes", list.size()));
        return list;
    }
}

