/*
 * Decompiled with CFR 0.152.
 */
package org.genemania.engine.core.integration.gram;

import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Vector;
import org.apache.log4j.Logger;
import org.genemania.engine.Constants;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.core.MatrixUtils;
import org.genemania.engine.core.data.CoAnnotationSet;
import org.genemania.engine.core.data.DatasetInfo;
import org.genemania.engine.core.integration.Feature;
import org.genemania.engine.core.integration.FeatureList;
import org.genemania.engine.core.integration.FeatureLoader;
import org.genemania.engine.exception.CancellationException;
import org.genemania.engine.matricks.SymMatrix;
import org.genemania.exception.ApplicationException;
import org.genemania.util.ProgressReporter;

/**
 * Builds the Gram-style matrices used for network integration: the
 * feature-by-feature product matrix {@code KtK} and the feature-by-target
 * column {@code KtT}.
 *
 * <p>Index 0 of every "existing" feature list must be the bias feature
 * ({@link Constants.NetworkType#BIAS}); row/column 0 of the produced matrices
 * holds the bias terms. Lists of features to <em>add</em> must not contain the
 * bias.
 *
 * <p>The long-running loops poll the {@link ProgressReporter} given to the
 * constructor and throw {@link CancellationException} once it reports
 * cancellation. The per-call {@code reporter} arguments are currently unused.
 */
public class BasicGramBuilder {
    private static final Logger logger = Logger.getLogger(BasicGramBuilder.class);
    DataCache cache;
    String namespace;
    long organismId;
    ProgressReporter progress;

    /**
     * @param cache      data cache queried for node ids / dataset info and handed to {@link FeatureLoader}
     * @param namespace  namespace handed to {@link FeatureLoader}
     * @param organismId organism whose data is loaded
     * @param progress   polled for cancellation inside the computation loops
     */
    public BasicGramBuilder(DataCache cache, String namespace, long organismId, ProgressReporter progress) {
        this.cache = cache;
        this.namespace = namespace;
        this.organismId = organismId;
        this.progress = progress;
    }

    /**
     * Builds the symmetric {@code size x size} inner-product matrix of the
     * given features, where {@code size = featureList.size()}:
     * <ul>
     * <li>{@code [0][0]} = numGenes * numGenes (bias with itself)</li>
     * <li>{@code [i][0] = [0][i]} = sum of all elements of network i</li>
     * <li>{@code [i][j] = [j][i]} = sum of the element-wise product of networks i and j</li>
     * </ul>
     *
     * @param featureList features with the bias at index 0
     * @param reporter    currently unused; cancellation is polled on the constructor's reporter
     * @return a newly allocated matrix
     * @throws ApplicationException if the bias is missing from index 0 (may also propagate from
     *                              the cache or feature loader)
     * @throws CancellationException if cancellation is requested
     */
    public DenseMatrix buildBasicKtK(FeatureList featureList, ProgressReporter reporter) throws ApplicationException {
        checkFeatureList(featureList, true);
        int size = featureList.size();
        int numGenes = cache.getNodeIds(organismId).getNodeIds().length;
        DenseMatrix KtK2 = new DenseMatrix(size, size);
        FeatureLoader featureLoader = new FeatureLoader(cache, namespace, organismId);

        // Widen before multiplying: numGenes * numGenes in int arithmetic
        // overflows as soon as numGenes exceeds 46340.
        KtK2.set(0, 0, (double) numGenes * (double) numGenes);
        for (int i = 1; i < size; ++i) {
            SymMatrix networkI = featureLoader.load((Feature) featureList.get(i));
            double networkSum = networkI.elementSum();
            KtK2.set(i, 0, networkSum);
            KtK2.set(0, i, networkSum);
            // Lower triangle only (j <= i); each value is mirrored.
            for (int j = 1; j <= i; ++j) {
                if (progress.isCanceled()) {
                    throw new CancellationException();
                }
                SymMatrix networkJ = featureLoader.load((Feature) featureList.get(j));
                double prodSum = networkJ.elementMultiplySum(networkI);
                KtK2.set(i, j, prodSum);
                KtK2.set(j, i, prodSum);
            }
        }
        return KtK2;
    }

    /**
     * Builds the {@code size x 1} column of feature/target products for one
     * co-annotation set. Entry 0 (bias) is
     * {@code sum(BHalf) * numGenes + sum(C) + constant * numGenes^2}; entry i is
     * {@link #computeKttElement} for network i. Here {@code C} is the
     * co-annotation matrix with its diagonal set to zero.
     *
     * <p>Note: calls {@code setDiag(0.0)} on the matrix returned by
     * {@code annoSet.GetCoAnnotationMatrix()}; if that is the set's own
     * instance rather than a copy, the set is modified in place.
     *
     * @param featureList features with the bias at index 0
     * @param annoSet     co-annotation data for a single GO branch
     * @param reporter    currently unused; cancellation is polled on the constructor's reporter
     * @return a newly allocated column matrix
     * @throws ApplicationException if the bias is missing from index 0 (may also propagate from
     *                              the cache or feature loader)
     * @throws CancellationException if cancellation is requested
     */
    public DenseMatrix buildKtT(FeatureList featureList, CoAnnotationSet annoSet, ProgressReporter reporter) throws ApplicationException {
        checkFeatureList(featureList, true);
        DatasetInfo info = cache.getDatasetInfo(organismId);
        int goBranchNum = Constants.getIndexForGoBranch(annoSet.getGoBranch());
        int numberOfCategories = info.getNumCategories()[goBranchNum];
        int numberOfGenes = info.getNumGenes();
        SymMatrix coAnnotationMatrix = annoSet.GetCoAnnotationMatrix();
        DenseVector bHalf = annoSet.GetBHalf();
        double constant = annoSet.GetConstant();
        coAnnotationMatrix.setDiag(0.0);
        int size = featureList.size();
        logger.debug("Number of Genes " + numberOfGenes + ", Number of Categories " + numberOfCategories + ", Number of networks: " + (size - 1));
        // Logged for diagnostics only; not used in the computation below.
        double biasValue = (double) numberOfGenes * (double) numberOfGenes * (double) numberOfCategories;
        logger.debug("biasValue: " + biasValue);

        DenseMatrix ktt = new DenseMatrix(size, 1);
        ktt.set(0, 0, MatrixUtils.sum((Vector) bHalf) * (double) numberOfGenes
                + coAnnotationMatrix.elementSum()
                + constant * (double) numberOfGenes * (double) numberOfGenes);
        FeatureLoader featureLoader = new FeatureLoader(cache, namespace, organismId);
        logger.debug("Ktt bias value is " + ktt.get(0, 0));
        for (int i = 1; i < size; ++i) {
            if (progress.isCanceled()) {
                throw new CancellationException();
            }
            SymMatrix networkI = featureLoader.load((Feature) featureList.get(i));
            ktt.set(i, 0, computeKttElement(numberOfGenes, networkI, coAnnotationMatrix, bHalf, constant));
        }
        return ktt;
    }

    /**
     * Computes one non-bias KtT entry:
     * {@code sum(network .* C) + sum(network * BHalf) + constant * sum(network)}.
     *
     * <p>Assumes {@code network} is {@code numberOfGenes x numberOfGenes} and
     * {@code BHalf} has length {@code numberOfGenes} (TODO confirm against callers).
     *
     * @param numberOfGenes      length of the temporary product vector
     * @param network            network matrix
     * @param CoAnnotationMatrix co-annotation matrix {@code C}
     * @param BHalf              vector multiplied by the network
     * @param constant           scale applied to the network's element sum
     * @return the KtT entry for {@code network}
     */
    public static double computeKttElement(int numberOfGenes, SymMatrix network, SymMatrix CoAnnotationMatrix, DenseVector BHalf, double constant) {
        double networkSum = network.elementSum();
        DenseVector product = new DenseVector(numberOfGenes);
        network.mult(BHalf.getData(), product.getData());
        double productSum = MatrixUtils.sum((Vector) product);
        return network.elementMultiplySum(CoAnnotationMatrix) + productSum + networkSum * constant;
    }

    /**
     * Extends an existing KtK matrix with rows/columns for new features, in the
     * layout produced by {@link #buildBasicKtK}. The existing entries are
     * copied unchanged and the new features occupy indices
     * {@code oldSize .. oldSize + featuresToAdd.size() - 1}.
     *
     * @param KtK2          existing matrix; at least {@code featureList.size()} square
     * @param featureList   features already in {@code KtK2}, bias at index 0
     * @param featuresToAdd features to append, without bias (may be empty)
     * @param reporter      currently unused; cancellation is polled on the constructor's reporter
     * @return a newly allocated, enlarged matrix
     * @throws ApplicationException if either list violates the bias layout (may also propagate
     *                              from the cache or feature loader)
     * @throws CancellationException if cancellation is requested
     */
    public DenseMatrix updateBasicKtK(DenseMatrix KtK2, FeatureList featureList, FeatureList featuresToAdd, ProgressReporter reporter) throws ApplicationException {
        checkFeatureList(featureList, true);
        checkFeatureList(featuresToAdd, false);
        int oldSize = featureList.size();
        int numFeaturesToAdd = featuresToAdd.size();
        int newSize = oldSize + numFeaturesToAdd;

        logger.debug("allocating new KtK and copying data over");
        DenseMatrix ktkNew = new DenseMatrix(newSize, newSize);
        for (int i = 0; i < oldSize; ++i) {
            for (int j = 0; j < oldSize; ++j) {
                ktkNew.set(i, j, KtK2.get(i, j));
            }
        }

        FeatureLoader featureLoader = new FeatureLoader(cache, namespace, organismId);
        logger.debug("preloading new features");
        SymMatrix[] newFeatures = new SymMatrix[numFeaturesToAdd];
        for (int j = 0; j < numFeaturesToAdd; ++j) {
            newFeatures[j] = featureLoader.load((Feature) featuresToAdd.get(j));
        }

        logger.debug(String.format("computing products between %d new and %d old features", numFeaturesToAdd, oldSize));
        // Start at 1: index 0 is the bias, whose products with the new
        // features (their element sums) are filled in by the next loop.
        for (int i = 1; i < oldSize; ++i) {
            SymMatrix networkI = featureLoader.load((Feature) featureList.get(i));
            for (int j = 0; j < numFeaturesToAdd; ++j) {
                if (progress.isCanceled()) {
                    throw new CancellationException();
                }
                double prodSum = networkI.elementMultiplySum(newFeatures[j]);
                ktkNew.set(i, j + oldSize, prodSum);
                ktkNew.set(j + oldSize, i, prodSum);
            }
        }

        logger.debug(String.format("computing products between %d new features, and their biases", numFeaturesToAdd));
        for (int i = 0; i < numFeaturesToAdd; ++i) {
            SymMatrix networkI = newFeatures[i];
            double networkSum = networkI.elementSum();
            ktkNew.set(i + oldSize, 0, networkSum);
            ktkNew.set(0, i + oldSize, networkSum);
            // Use the preloaded networks instead of reloading them. Fill only
            // the lower triangle (j <= i) and mirror each value; the product
            // is symmetric, so every pair is computed once.
            for (int j = 0; j <= i; ++j) {
                if (progress.isCanceled()) {
                    throw new CancellationException();
                }
                double prodSum = networkI.elementMultiplySum(newFeatures[j]);
                ktkNew.set(i + oldSize, j + oldSize, prodSum);
                ktkNew.set(j + oldSize, i + oldSize, prodSum);
            }
        }
        return ktkNew;
    }

    /**
     * Extends an existing KtT column with entries for new features, in the
     * layout produced by {@link #buildKtT}. The existing entries are copied
     * unchanged.
     *
     * <p>Note: like {@link #buildKtT}, calls {@code setDiag(0.0)} on the matrix
     * returned by {@code annoSet.GetCoAnnotationMatrix()}. The new entries are
     * therefore computed the same way as the existing ones.
     *
     * @param Ktt           existing column; at least {@code featureList.size()} rows
     * @param featureList   features already in {@code Ktt}, bias at index 0
     * @param featuresToAdd features to append, without bias (may be empty)
     * @param annoSet       the co-annotation set {@code Ktt} was built from
     * @param reporter      currently unused; cancellation is polled on the constructor's reporter
     * @return a newly allocated, enlarged column matrix
     * @throws ApplicationException if either list violates the bias layout (may also propagate
     *                              from the cache or feature loader)
     * @throws CancellationException if cancellation is requested
     */
    public DenseMatrix updateKtT(DenseMatrix Ktt, FeatureList featureList, FeatureList featuresToAdd, CoAnnotationSet annoSet, ProgressReporter reporter) throws ApplicationException {
        checkFeatureList(featureList, true);
        checkFeatureList(featuresToAdd, false);
        DatasetInfo info = cache.getDatasetInfo(organismId);
        int numberOfGenes = info.getNumGenes();
        SymMatrix coAnnotationMatrix = annoSet.GetCoAnnotationMatrix();
        // Match buildKtT, which zeroes the diagonal before computing any entry;
        // idempotent if this set has already been through buildKtT.
        coAnnotationMatrix.setDiag(0.0);
        DenseVector bHalf = annoSet.GetBHalf();
        double constant = annoSet.GetConstant();
        int oldSize = featureList.size();
        int numFeaturesToAdd = featuresToAdd.size();

        DenseMatrix kttNew = new DenseMatrix(oldSize + numFeaturesToAdd, 1);
        for (int i = 0; i < oldSize; ++i) {
            kttNew.set(i, 0, Ktt.get(i, 0));
        }
        FeatureLoader featureLoader = new FeatureLoader(cache, namespace, organismId);
        for (int i = 0; i < numFeaturesToAdd; ++i) {
            if (progress.isCanceled()) {
                throw new CancellationException();
            }
            SymMatrix networkI = featureLoader.load((Feature) featuresToAdd.get(i));
            kttNew.set(i + oldSize, 0, computeKttElement(numberOfGenes, networkI, coAnnotationMatrix, bHalf, constant));
        }
        return kttNew;
    }

    /**
     * Validates the bias layout of a feature list.
     *
     * @param featureList list to check
     * @param hasBias     {@code true} if index 0 must be the bias feature;
     *                    {@code false} if index 0 must not be the bias (an empty list is accepted)
     * @throws ApplicationException if the layout does not match {@code hasBias}; an empty list
     *                              fails when {@code hasBias} is {@code true}
     */
    public static void checkFeatureList(FeatureList featureList, boolean hasBias) throws ApplicationException {
        if (featureList.size() == 0) {
            // Previously surfaced as an opaque IndexOutOfBoundsException from get(0).
            if (hasBias) {
                throw new ApplicationException("must include bias in first row/col");
            }
            return;
        }
        boolean firstIsBias = ((Feature) featureList.get(0)).getType() == Constants.NetworkType.BIAS;
        if (hasBias && !firstIsBias) {
            throw new ApplicationException("must include bias in first row/col");
        }
        if (!hasBias && firstIsBias) {
            throw new ApplicationException("must not include bias in first row/col");
        }
    }
}

