/*
 * Decompiled with CFR 0.152.
 */
package org.genemania.engine.apps;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.io.Reader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Vector;
import no.uib.cipr.matrix.VectorEntry;
import org.apache.log4j.Appender;
import org.apache.log4j.FileAppender;
import org.apache.log4j.Layout;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.PatternLayout;
import org.genemania.domain.AttributeGroup;
import org.genemania.domain.InteractionNetwork;
import org.genemania.domain.InteractionNetworkGroup;
import org.genemania.domain.NetworkMetadata;
import org.genemania.domain.Organism;
import org.genemania.engine.Constants;
import org.genemania.engine.apps.AbstractEngineApp;
import org.genemania.engine.apps.support.LabelWriter;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.cache.FileSerializedObjectCache;
import org.genemania.engine.cache.MemObjectCache;
import org.genemania.engine.cache.SynchronizedObjectCache;
import org.genemania.engine.config.Config;
import org.genemania.engine.core.MatrixUtils;
import org.genemania.engine.core.data.GoAnnotations;
import org.genemania.engine.core.data.GoIds;
import org.genemania.engine.core.data.NodeIds;
import org.genemania.engine.core.mania.CoreMania;
import org.genemania.engine.matricks.Matrix;
import org.genemania.engine.utils.FileUtils;
import org.genemania.engine.validation.AucPr;
import org.genemania.engine.validation.AucRoc;
import org.genemania.engine.validation.EvaluationMeasure;
import org.genemania.engine.validation.PrecisionFixedRecall;
import org.genemania.exception.ApplicationException;
import org.genemania.exception.DataStoreException;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;

public class VectorCrossValidator
extends AbstractEngineApp {
    /** Sentinel measure array meaning "this query/fold was skipped"; writeResult compares it by reference. */
    public static final double[] SKIPPED = new double[0];
    private static Logger logger = Logger.getLogger(VectorCrossValidator.class);
    @Option(name="-method", usage="network combination method, should be one of 'equal' or 'smart'")
    private String combiningMethodName;
    @Option(name="-organism", usage="organism name, eg 'Home Sapiens', if -orgid not given")
    private String organismName;
    // -1 means "not given"; getOrganism() then falls back to -organism.
    @Option(name="-orgid", usage="organism id")
    private long organismId = -1L;
    @Option(name="-qfile", usage="file containing gene queries")
    private String queryFileName;
    @Option(name="-out", usage="name of output file to contain validation results")
    private String outFilename;
    @Option(name="-numfolds", usage="number of folds to use for each query, defaults to 5")
    private int numFolds = 5;
    // The usage example must spell the option exactly as declared: args4j option names are case-sensitive.
    @Option(name="-recallpoint", usage="computes precision at the given recall. E.g. -recallpoint 10 computes precision at 10% recall")
    private String recallPoint;
    @Option(name="-biasing_method", usage="biasing method, defaults to average")
    private String biasingMethod = "average";
    @Option(name="-all_neg_cross_validation", usage="use all genes less positives as negatives")
    private boolean allNegCrossVal;
    @Option(name="-useCachedGoAnnos", usage="use only a go category id 'GO:001234' from each line, and lookup the +ve annotations from engine cache files. this forces all other genes to be negatives (allnegcrossval)")
    private boolean useCachedGoAnnos;
    @Option(name="-netids", usage="comma delim list of network ids to use eg '3,4,19', or 'all', or 'default', or 'preferred' for our selection heuristic.")
    private String networkIdsList;
    @Option(name="-attrIds", usage="comma delim list of attribute group ids.")
    private String attrIds;
    @Option(name="-dump", usage="optional, one of 'organisms' or 'networks'. causes program to dump ids & names and exit without executing any queries")
    private String dumpType;
    // 0 means "pick a seed from the system clock" (see initValidation).
    @Option(name="-seed", usage="optional, random seed to use when generating cross-validation folds, 0 (default) will select a seed based on system time")
    private long seed;
    @Option(name="-threads", usage="optional, total threads to use for parallel prediction, defaults to 1")
    private int totalThreads = 1;
    @Option(name="-label", usage="optional, output labels to file. defaults to false")
    private boolean writeLabels = false;
    /** Interaction-count cutoff for the 'preferred' network heuristic (networks at or below it are skipped). */
    private static final long INTERACTION_COUNT_THRESHOLD = 1000L;
    /** Network group codes considered by the 'preferred' network heuristic. */
    private static String[] preferredGroupCodes = new String[]{"coexp", "pi", "gi"};
    private Organism organism;
    private List<EvaluationMeasure> measures;
    private PrintWriter writer = null;
    private Collection<String[]> queries;
    private NodeIds nodeIds;
    private Collection<Collection<Long>> idList;
    private Collection<Long> attributeGroupIds;
    private DataCache cache;
    /** Gene symbol -> node index lookups; symbols missing from the db are cached as SYMBOL_NOT_FOUND (-1). */
    private Map<String, Integer> symbolToIndexCache = Collections.synchronizedMap(new HashMap<String, Integer>());
    private int queryCounter = 0;
    private final Object outputMutex = new Object();
    private String namespace;
    private static final int SYMBOL_NOT_FOUND = -1;
    private Integer minimumGeneSetSize;
    private Integer maximumGeneSetSize;

    /** @return the configured organism id, or -1 when none was given */
    public long getOrganismId() {
        return organismId;
    }

    /**
     * Sets the organism id used by getOrganism(); -1 means "look up by name instead".
     *
     * @param organismId organism id
     */
    public void setOrganismId(long organismId) {
        this.organismId = organismId;
    }

    /** @return the raw network selection string (ids or 'all'/'default'/'preferred') */
    public String getNetworkIdsList() {
        return networkIdsList;
    }

    /**
     * Sets the raw network selection string.
     *
     * @param networkIdsList comma-delimited ids, or 'all', 'default' or 'preferred'
     */
    public void setNetworkIdsList(String networkIdsList) {
        this.networkIdsList = networkIdsList;
    }

    /** @return the raw comma-delimited attribute group id string */
    public String getAttrIdsList() {
        return attrIds;
    }

    /**
     * Sets the raw attribute group id string.
     *
     * @param attrIds comma-delimited attribute group ids
     */
    public void setAttrIdsList(String attrIds) {
        this.attrIds = attrIds;
    }

    /** @return the resolved attribute group ids, or null if not yet resolved */
    public Collection<Long> getAttrIds() {
        return attributeGroupIds;
    }

    /**
     * Sets the resolved attribute group ids; when non-null, initValidation will not re-resolve them.
     *
     * @param ids attribute group ids
     */
    public void setAttrIds(Collection<Long> ids) {
        this.attributeGroupIds = ids;
    }

    /**
     * Sets the resolved network ids (grouped); when non-null, initValidation will not re-resolve them.
     *
     * @param ids network ids, one inner collection per network group
     */
    public void setNetworkIds(Collection<Collection<Long>> ids) {
        this.idList = ids;
    }

    /** @return the resolved network ids (one inner collection per group), or null if not yet resolved */
    public Collection<Collection<Long>> getNetworkIds() {
        return idList;
    }

    /** @return the biasing method name passed to the discriminant computation */
    public String getBiasingMethod() {
        return biasingMethod;
    }

    /**
     * Sets the biasing method name.
     *
     * @param biasingMethod biasing method name
     */
    public void setBiasingMethod(String biasingMethod) {
        this.biasingMethod = biasingMethod;
    }

    /** @return the configured network combining method name */
    public String getCombiningMethodName() {
        return combiningMethodName;
    }

    /**
     * Sets the network combining method name.
     *
     * @param combiningMethodName combining method name
     */
    public void setCombiningMethodName(String combiningMethodName) {
        this.combiningMethodName = combiningMethodName;
    }

    /** @return the log file name, or null when file logging is disabled */
    @Override
    public String getLogFilename() {
        return logFilename;
    }

    /**
     * Sets the log file name used by setupLogging().
     *
     * @param logFilename log file path, or null to disable file logging
     */
    @Override
    public void setLogFilename(String logFilename) {
        this.logFilename = logFilename;
    }

    /** @return the results output file name, or null when results are only logged */
    public String getOutFilename() {
        return outFilename;
    }

    /**
     * Sets the results output file name.
     *
     * @param outFilename output file path
     */
    public void setOutFilename(String outFilename) {
        this.outFilename = outFilename;
    }

    /** @return the path of the query file read by loadQueries() */
    public String getQueryFileName() {
        return queryFileName;
    }

    /**
     * Sets the query file path.
     *
     * @param queryFileName query file path
     */
    public void setQueryFileName(String queryFileName) {
        this.queryFileName = queryFileName;
    }

    /** @return the recall point (percent) as text, as used in the precision measure name */
    public String getRecallPoint() {
        return recallPoint;
    }

    /**
     * Sets the recall point (percent) for the fixed-recall precision measure.
     *
     * @param recallPoint recall point text, e.g. "10"
     */
    public void setRecallPoint(String recallPoint) {
        this.recallPoint = recallPoint;
    }

    /** @return the engine data cache (created by initValidation) */
    @Override
    public DataCache getCache() {
        return cache;
    }

    /**
     * Sets the engine data cache.
     *
     * @param cache data cache
     */
    @Override
    public void setCache(DataCache cache) {
        this.cache = cache;
    }

    /** @return the number of folds evaluated so far */
    public int getQueryCounter() {
        return queryCounter;
    }

    /**
     * Overrides the evaluated-fold counter.
     *
     * @param queryCounter new counter value
     */
    public void setQueryCounter(int queryCounter) {
        this.queryCounter = queryCounter;
    }

    /** @return the number of cross-validation folds per query */
    public int getNumFolds() {
        return numFolds;
    }

    /**
     * Sets the number of cross-validation folds per query.
     *
     * @param numFolds fold count
     */
    public void setNumFolds(int numFolds) {
        this.numFolds = numFolds;
    }

    /** @return the random seed for fold generation (0 until initValidation picks one, if unset) */
    public long getSeed() {
        return seed;
    }

    /**
     * Sets the random seed; 0 makes initValidation derive one from the system clock.
     *
     * @param seed random seed
     */
    public void setSeed(long seed) {
        this.seed = seed;
    }

    /**
     * Enables or disables using all non-positive genes as negatives.
     *
     * @param value true to use all remaining genes as negatives
     */
    public void setAllNegCrossVal(boolean value) {
        this.allNegCrossVal = value;
    }

    /**
     * Sets the number of worker threads used by crossValidate().
     *
     * @param threads thread count
     */
    public void setThreads(int threads) {
        this.totalThreads = threads;
    }

    /**
     * Enables or disables reading positive annotations from the engine cache instead of the query file.
     *
     * @param value true to use cached GO annotations
     */
    public void setUseCachedGoAnnotations(boolean value) {
        this.useCachedGoAnnos = value;
    }

    /**
     * Prepares a validation run.
     *
     * Opens the output file and builds the data cache. Resolves the organism, and resolves the
     * network and attribute-group selections unless they were set explicitly. Sets up the
     * AUC-ROC, AUC-PR and precision-at-fixed-recall measures, writes the result header, loads
     * the queries and fixes the random seed.
     *
     * @throws Exception if any setup step fails (I/O, missing organism/networks, data store errors)
     */
    public void initValidation() throws Exception {
        this.openOutput();
        logger.info((Object)"initializing...");
        this.cache = new DataCache(new SynchronizedObjectCache(new MemObjectCache(new FileSerializedObjectCache(this.getCacheDir()))));
        this.organism = this.getOrganism();
        if (this.idList == null) {
            this.idList = this.getNetworkIdList(this.networkIdsList);
        }
        if (this.attributeGroupIds == null) {
            this.attributeGroupIds = this.getAttributeGroupIdsList(this.attrIds);
            logger.debug((Object)("intialized attributeGroupIds to " + this.attributeGroupIds));
        }
        logger.debug((Object)String.format("regularization enabled: %s, constant: %f", Config.instance().isRegularizationEnabled(), Config.instance().getRegularizationConstant()));
        logger.debug((Object)String.format("attribute pre-selection limit: %d", Config.instance().getAttributeEnrichmentMaxSize()));
        this.measures = new ArrayList<EvaluationMeasure>();
        this.measures.add(new AucRoc("AUC-ROC"));
        this.measures.add(new AucPr("AUC-PR"));
        double recall = 10.0;
        if (this.recallPoint == null) {
            this.recallPoint = "10";
        } else {
            try {
                recall = Double.parseDouble(this.recallPoint);
            }
            catch (NumberFormatException e) {
                logger.warn((Object)e.getMessage());
                logger.info((Object)"setting recall point to 10%");
                recall = 10.0;
                // Keep the reported measure name ("PR-<recallPoint>") in sync with the value
                // actually used, instead of labelling the column with the unparsable text.
                this.recallPoint = "10";
            }
        }
        this.measures.add(new PrecisionFixedRecall("PR-" + this.recallPoint, recall));
        this.writeHeader();
        this.loadQueries();
        this.nodeIds = this.cache.getNodeIds(this.organism.getId());
        if (this.seed == 0L) {
            this.seed = System.currentTimeMillis();
        }
        logger.info((Object)("setting random seed to: " + this.seed));
    }

    /**
     * Emits the tab-separated column header line (query id, fold number, one column per measure,
     * then training/validation positive and negative counts) to the log and the output file.
     */
    private void writeHeader() {
        StringBuilder line = new StringBuilder("queryIdentifier\tfold #");
        String[] measureNames = getMeasureNames();
        for (int col = 0; col < measureNames.length; col++) {
            line.append('\t').append(measureNames[col]);
        }
        line.append("\t#t(+)\t#t(-)\t#v(+)\t#v(-)");
        String text = line.toString();
        logger.info((Object)text);
        writeOutput(text);
    }

    /**
     * Opens the results file named by {@code outFilename}, if one was given. When no file name is
     * set, results are only logged.
     *
     * Any writer left open by a previous call is closed first so it does not leak.
     *
     * @throws IOException if the file cannot be created
     */
    public void openOutput() throws IOException {
        if (this.outFilename == null) {
            return;
        }
        if (this.writer != null) {
            this.writer.close();
        }
        logger.info((Object)("writing results to " + this.outFilename));
        this.writer = new PrintWriter(new File(this.outFilename));
    }

    /**
     * Appends one line to the results file and flushes it. This is a no-op when no results file is open.
     *
     * @param msg line to write
     */
    private void writeOutput(String msg) {
        PrintWriter out = this.writer;
        if (out == null) {
            return;
        }
        out.println(msg);
        out.flush();
    }

    /**
     * Reports the result of one fold. The fold number is printed 1-based.
     *
     * When {@code measures} is the {@code SKIPPED} sentinel (compared by reference), only
     * "skipped" is written. Otherwise the line holds each measure value followed by the
     * training/validation positive and negative counts.
     *
     * @param queryIdentifier query id
     * @param fold zero-based fold index
     * @param numPosT training positives
     * @param numNegT training negatives
     * @param numPosV validation positives
     * @param numNegV validation negatives
     * @param measures measure values, or SKIPPED
     */
    public void writeResult(String queryIdentifier, int fold, int numPosT, int numNegT, int numPosV, int numNegV, double[] measures) {
        StringBuilder line = new StringBuilder(String.format("%s\t%s", queryIdentifier, fold + 1));
        if (measures != SKIPPED) {
            for (int m = 0; m < measures.length; m++) {
                line.append('\t').append(measures[m]);
            }
            line.append(String.format("\t%s\t%s\t%s\t%s", numPosT, numNegT, numPosV, numNegV));
        } else {
            line.append("\tskipped");
        }
        String text = line.toString();
        logger.info((Object)text);
        writeOutput(text);
    }

    /**
     * Reports a per-query summary line. The fold column is shown as "-".
     *
     * @param queryIdentifier query id
     * @param measures measure values, or the SKIPPED sentinel (compared by reference)
     */
    public void writeResult(String queryIdentifier, double[] measures) {
        StringBuilder line = new StringBuilder(String.format("%s\t-", queryIdentifier));
        if (measures != SKIPPED) {
            for (int m = 0; m < measures.length; m++) {
                line.append('\t').append(measures[m]);
            }
        } else {
            line.append("\tskipped");
        }
        String text = line.toString();
        logger.info((Object)text);
        writeOutput(text);
    }

    /** @return the names of the configured evaluation measures, in evaluation order */
    public String[] getMeasureNames() {
        String[] names = new String[this.measures.size()];
        int slot = 0;
        for (EvaluationMeasure measure : this.measures) {
            names[slot++] = measure.getName();
        }
        return names;
    }

    /**
     * Resolves the network selection string into grouped network ids.
     *
     * @param ids 'all', 'preferred', 'default' (case-insensitive), or a comma-delimited id list
     * @return one collection of network ids per contributing network group
     * @throws ApplicationException if no selection was given or the selection cannot be satisfied
     * @throws DataStoreException propagated from the selection helpers
     */
    private Collection<Collection<Long>> getNetworkIdList(String ids) throws ApplicationException, DataStoreException {
        if (ids == null) {
            // Previously this failed with a bare NullPointerException when -netids was omitted.
            throw new ApplicationException("no networks specified (use -netids)");
        }
        if (ids.equalsIgnoreCase("all")) {
            return this.getAllNetworks();
        }
        if (ids.equalsIgnoreCase("preferred")) {
            return this.getPreferredNetworks();
        }
        if (ids.equalsIgnoreCase("default")) {
            return this.getDefaultNetworks();
        }
        return this.getNetworksById(ids);
    }

    /**
     * Resolves the attribute group id string; delegates to getAttributeGroupsById.
     *
     * @param ids comma-delimited attribute group ids, may be null
     * @return resolved attribute group ids
     * @throws ApplicationException if an id is not recognised
     */
    private Collection<Long> getAttributeGroupIdsList(String ids) throws ApplicationException {
        Collection<Long> resolved = getAttributeGroupsById(ids);
        return resolved;
    }

    /**
     * Selects every network of the current organism.
     *
     * @return one collection of network ids per network group that has at least one network
     */
    private Collection<Collection<Long>> getAllNetworks() {
        int numFound = 0;
        ArrayList<Collection<Long>> ids = new ArrayList<Collection<Long>>();
        // Iterate the typed getters directly: looping over a raw Collection with a typed
        // enhanced-for variable does not compile.
        for (InteractionNetworkGroup group : this.organism.getInteractionNetworkGroups()) {
            ArrayList<Long> list = new ArrayList<Long>();
            for (InteractionNetwork n : group.getInteractionNetworks()) {
                list.add(n.getId());
                ++numFound;
            }
            if (!list.isEmpty()) {
                ids.add(list);
            }
        }
        logger.info((Object)String.format("total %d networks selected", numFound));
        return ids;
    }

    /**
     * Selects the networks whose ids appear in a comma-delimited list.
     *
     * Whitespace around ids and empty entries are ignored, and duplicate ids are counted once.
     *
     * @param idsArg comma-delimited network ids, e.g. "3,4,19"
     * @return one collection of selected network ids per contributing network group
     * @throws ApplicationException if the list is empty or any requested id is not found
     * @throws DataStoreException declared for callers
     */
    private Collection<Collection<Long>> getNetworksById(String idsArg) throws ApplicationException, DataStoreException {
        HashSet<String> wantedIds = new HashSet<String>();
        for (String part : idsArg.split(",")) {
            String trimmed = part.trim();
            if (trimmed.length() > 0) {
                wantedIds.add(trimmed);
            }
        }
        if (wantedIds.isEmpty()) {
            throw new ApplicationException("no network ids specified");
        }
        HashSet<String> foundIds = new HashSet<String>();
        int numFound = 0;
        ArrayList<Collection<Long>> ids = new ArrayList<Collection<Long>>();
        for (InteractionNetworkGroup group : this.organism.getInteractionNetworkGroups()) {
            ArrayList<Long> list = new ArrayList<Long>();
            for (InteractionNetwork n : group.getInteractionNetworks()) {
                String key = String.valueOf(n.getId());
                if (!wantedIds.contains(key)) {
                    continue;
                }
                NetworkMetadata metadata = n.getMetadata();
                logger.info((Object)String.format("using network %d containing %d interactions from group %s: %s", n.getId(), metadata.getInteractionCount(), group.getName(), n.getName()));
                list.add(n.getId());
                foundIds.add(key);
                ++numFound;
            }
            if (!list.isEmpty()) {
                ids.add(list);
            }
        }
        // Compare against the de-duplicated request. Comparing against the raw split length
        // falsely failed on inputs like "3,3" or "3, 4".
        if (foundIds.size() != wantedIds.size()) {
            HashSet<String> missing = new HashSet<String>(wantedIds);
            missing.removeAll(foundIds);
            throw new ApplicationException("some of the specified networks could not be found: " + missing);
        }
        logger.info((Object)String.format("total %d networks selected", numFound));
        return ids;
    }

    /**
     * Parses and validates a comma-delimited list of attribute group ids for the current organism.
     *
     * @param idsArg comma-delimited attribute group ids; null yields an empty list
     * @return the validated ids, in input order
     * @throws ApplicationException if an id is not numeric or not known for the organism
     */
    private Collection<Long> getAttributeGroupsById(String idsArg) throws ApplicationException {
        ArrayList<Long> ids = new ArrayList<Long>();
        if (idsArg == null) {
            return ids;
        }
        // Use the resolved organism: this.organismId is -1 when the organism was selected by
        // name, which made every lookup fail.
        long resolvedOrganismId = this.organism.getId();
        for (String part : idsArg.split(",")) {
            String trimmed = part.trim();
            if (trimmed.length() == 0) {
                continue;
            }
            long attributeGroupId;
            try {
                attributeGroupId = Long.parseLong(trimmed);
            }
            catch (NumberFormatException e) {
                throw new ApplicationException("invalid attribute group id: " + trimmed);
            }
            AttributeGroup attributeGroup = this.getAttributeMediator().findAttributeGroup(resolvedOrganismId, attributeGroupId);
            if (attributeGroup == null) {
                throw new ApplicationException("unrecognized attribute group id: " + attributeGroupId);
            }
            ids.add(attributeGroupId);
        }
        return ids;
    }

    /**
     * Selects networks with the "preferred" heuristic.
     *
     * Only groups whose code is in {@code preferredGroupCodes} are considered. Within those,
     * only networks with more than {@code INTERACTION_COUNT_THRESHOLD} interactions are kept.
     *
     * @return one collection of network ids per contributing network group
     * @throws ApplicationException if no network qualifies
     * @throws DataStoreException declared for callers
     */
    private Collection<Collection<Long>> getPreferredNetworks() throws ApplicationException, DataStoreException {
        HashSet<String> preferredGroupSet = new HashSet<String>(Arrays.asList(preferredGroupCodes));
        int numFound = 0;
        ArrayList<Collection<Long>> ids = new ArrayList<Collection<Long>>();
        for (InteractionNetworkGroup group : this.organism.getInteractionNetworkGroups()) {
            if (!preferredGroupSet.contains(group.getCode())) {
                logger.debug((Object)("skipping all networks in group since not preferred: " + group.getName() + " " + group.getCode()));
                continue;
            }
            ArrayList<Long> list = new ArrayList<Long>();
            for (InteractionNetwork n : group.getInteractionNetworks()) {
                NetworkMetadata metadata = n.getMetadata();
                // Use the named threshold rather than a duplicated literal.
                if (metadata.getInteractionCount() <= INTERACTION_COUNT_THRESHOLD) {
                    continue;
                }
                logger.info((Object)String.format("using network %d containing %d interactions from group %s: %s", n.getId(), metadata.getInteractionCount(), group.getName(), n.getName()));
                list.add(n.getId());
                ++numFound;
            }
            if (!list.isEmpty()) {
                ids.add(list);
            }
        }
        if (ids.isEmpty()) {
            throw new ApplicationException("no preferred networks found!");
        }
        logger.info((Object)String.format("total %d networks selected", numFound));
        return ids;
    }

    /**
     * Selects the networks flagged as default-selected.
     *
     * @return one collection of network ids per contributing network group
     * @throws ApplicationException if no network is default-selected
     * @throws DataStoreException declared for callers
     */
    private Collection<Collection<Long>> getDefaultNetworks() throws ApplicationException, DataStoreException {
        int numFound = 0;
        ArrayList<Collection<Long>> ids = new ArrayList<Collection<Long>>();
        // Typed iteration (the raw-Collection form does not compile with typed loop variables).
        for (InteractionNetworkGroup group : this.organism.getInteractionNetworkGroups()) {
            ArrayList<Long> list = new ArrayList<Long>();
            for (InteractionNetwork n : group.getInteractionNetworks()) {
                if (!n.isDefaultSelected()) {
                    continue;
                }
                NetworkMetadata metadata = n.getMetadata();
                logger.info((Object)String.format("using default network %d containing %d interactions from group %s: %s", n.getId(), metadata.getInteractionCount(), group.getName(), n.getName()));
                list.add(n.getId());
                ++numFound;
            }
            if (!list.isEmpty()) {
                ids.add(list);
            }
        }
        if (ids.isEmpty()) {
            throw new ApplicationException("no default networks found!");
        }
        logger.info((Object)String.format("total %d networks selected", numFound));
        return ids;
    }

    /**
     * Prints either the organism list or the network list to stdout.
     *
     * @param option "organisms" or "networks" (case-insensitive)
     * @throws ApplicationException for any other option
     * @throws DataStoreException propagated from the dump helpers
     */
    private void dump(String option) throws ApplicationException, DataStoreException {
        boolean organisms = option.equalsIgnoreCase("organisms");
        if (organisms) {
            dumpOrganisms();
            return;
        }
        if (option.equalsIgnoreCase("networks")) {
            dumpNetworks();
            return;
        }
        throw new ApplicationException("unknown dump option: " + option);
    }

    /**
     * Prints every network of the configured organism, with its interaction count and group, to stdout.
     *
     * @throws ApplicationException if the organism cannot be resolved
     * @throws DataStoreException propagated from the organism lookup
     */
    private void dumpNetworks() throws ApplicationException, DataStoreException {
        Organism organism = this.getOrganism();
        // Typed iteration (the raw-Collection form does not compile with typed loop variables).
        for (InteractionNetworkGroup group : organism.getInteractionNetworkGroups()) {
            for (InteractionNetwork n : group.getInteractionNetworks()) {
                NetworkMetadata metadata = n.getMetadata();
                System.out.println(String.format("network %d contains %d interactions from group %s: %s", n.getId(), metadata.getInteractionCount(), group.getName(), n.getName()));
            }
        }
    }

    /**
     * Prints "id: name" for every organism known to the organism mediator.
     *
     * @throws DataStoreException propagated from the mediator
     */
    private void dumpOrganisms() throws DataStoreException {
        // Typed iteration instead of a raw List (which does not compile with a typed loop variable).
        for (Organism o : this.organismMediator.getAllOrganisms()) {
            System.out.println(String.format("%d: %s", o.getId(), o.getName()));
        }
    }

    /**
     * Adds a file appender to the root logger when a log file name is configured. It also raises
     * the "org.genemania" logger to DEBUG. The log file is overwritten, not appended.
     *
     * @throws Exception if the appender cannot open the file
     */
    @Override
    public void setupLogging() throws Exception {
        if (this.logFilename != null) {
            Layout lineLayout = new PatternLayout("%d{HH:mm:ss} %-5p: %m%n");
            Appender fileAppender = new FileAppender(lineLayout, this.logFilename, false);
            Logger.getLogger((String)"org.genemania").setLevel(Level.DEBUG);
            Logger.getRootLogger().addAppender(fileAppender);
        }
    }

    /**
     * Populates this object's @Option fields from the command line.
     *
     * @param args raw command-line arguments
     * @return true on success; false after printing the error and usage to stderr
     */
    @Override
    public boolean getCommandLineArgs(String[] args) {
        CmdLineParser parser = new CmdLineParser((Object)this);
        boolean parsed;
        try {
            parser.parseArgument(args);
            parsed = true;
        }
        catch (CmdLineException problem) {
            System.err.println(problem.getMessage());
            System.err.println("java -jar myprogram.jar [options...] arguments...");
            parser.printUsage((OutputStream)System.err);
            parsed = false;
        }
        return parsed;
    }

    /**
     * Resolves the organism to query. The organism id takes precedence; the name is the fallback.
     *
     * @return the resolved organism
     * @throws ApplicationException if neither id nor name was given, or lookup fails
     * @throws DataStoreException propagated from the lookup
     */
    private Organism getOrganism() throws ApplicationException, DataStoreException {
        final Organism selected;
        if (this.organismId == -1L) {
            if (this.organismName == null) {
                throw new ApplicationException("organism not specified");
            }
            selected = getOrganismByName(this.organismName);
        } else {
            selected = getOrganismById(this.organismId);
        }
        logger.info((Object)("quering organism: " + selected.getName()));
        return selected;
    }

    /**
     * Finds an organism by case-insensitive name match. If several match, the last one wins,
     * as before.
     *
     * @param name organism name to look for
     * @return the matching organism
     * @throws ApplicationException if no organism matches
     * @throws DataStoreException propagated from the mediator
     */
    private Organism getOrganismByName(String name) throws ApplicationException, DataStoreException {
        Organism organism = null;
        // Match against the parameter (previously ignored in favour of this.organismName).
        // Calling equalsIgnoreCase on the query side also tolerates organisms with a null name.
        for (Organism o : this.organismMediator.getAllOrganisms()) {
            if (name != null && name.equalsIgnoreCase(o.getName())) {
                organism = o;
            }
        }
        if (organism == null) {
            throw new ApplicationException("Failed to find organism " + name);
        }
        return organism;
    }

    /**
     * Looks up an organism by id.
     *
     * @param organismId organism id
     * @return the organism
     * @throws ApplicationException if the mediator returns no organism for the id
     * @throws DataStoreException propagated from the mediator
     */
    private Organism getOrganismById(long organismId) throws ApplicationException, DataStoreException {
        Organism organism = this.organismMediator.getOrganism(organismId);
        if (organism == null) {
            // Fail here with context rather than with a NullPointerException in the caller.
            throw new ApplicationException("Failed to find organism with id " + organismId);
        }
        return organism;
    }

    /**
     * Reads the tab-delimited query records from {@code queryFileName} into {@code queries}.
     *
     * @throws IOException if the file cannot be opened or read
     */
    public void loadQueries() throws IOException {
        BufferedReader reader = new BufferedReader(new FileReader(this.queryFileName));
        try {
            this.queries = FileUtils.loadRecords(reader, '\t', '\t');
        }
        finally {
            // Close even when parsing fails; previously the reader leaked on exceptions.
            reader.close();
        }
    }

    /**
     * Maps the configured combining method name to its engine constant.
     *
     * @param queryRecord the query record (currently unused)
     * @return the combining method, as resolved by Constants.getCombiningMethod
     * @throws ApplicationException for the unimplemented "auto_detect_category" option
     */
    private Constants.CombiningMethod getCombiningMethod(String[] queryRecord) throws ApplicationException {
        boolean autoDetect = "auto_detect_category".equalsIgnoreCase(this.combiningMethodName);
        if (autoDetect) {
            throw new ApplicationException("auto_detect_category not implemented");
        }
        return Constants.getCombiningMethod(this.combiningMethodName);
    }

    /**
     * Scores the held-out rows with every configured measure.
     *
     * A row counts as positive when its original label equals 1.0.
     *
     * @param initialLabel original (unmasked) label vector
     * @param discriminant predicted scores
     * @param excludedRowIndices held-out row indices
     * @return one value per configured measure
     */
    private double[] getMeasureResults(Vector initialLabel, Vector discriminant, Collection<Integer> excludedRowIndices) {
        double[] scores = new double[excludedRowIndices.size()];
        boolean[] classes = new boolean[scores.length];
        Iterator<Integer> rows = excludedRowIndices.iterator();
        for (int pos = 0; rows.hasNext(); pos++) {
            int row = rows.next().intValue();
            scores[pos] = discriminant.get(row);
            classes[pos] = initialLabel.get(row) == 1.0;
        }
        return calculateMeasureResults(this.measures, scores, classes);
    }

    /**
     * Applies each measure to the same scores/classes, preserving the measures' iteration order.
     *
     * @param measures evaluation measures
     * @param scores predicted scores
     * @param classes true for positive rows
     * @return one result per measure
     */
    private double[] calculateMeasureResults(Collection<EvaluationMeasure> measures, double[] scores, boolean[] classes) {
        double[] results = new double[measures.size()];
        Iterator<EvaluationMeasure> it = measures.iterator();
        for (int slot = 0; it.hasNext(); slot++) {
            results[slot] = it.next().computeResult(classes, scores);
        }
        return results;
    }

    /**
     * Runs one cross-validation fold for one query.
     *
     * {@code allPerm} is cut into {@code numFolds} contiguous slices, and slice {@code k} is held
     * out. Its rows are set to -2 in a copy of the label, the weights and discriminant are
     * computed from that masked label, and the held-out rows are scored against the original
     * label. Each measure value is added into {@code averageMeasures} and the fold is reported
     * through writeResult.
     *
     * @param coreMania engine instance; reused under its own lock for non query-specific methods,
     *        replaced by a fresh instance for query-specific ones
     * @param initialLabel original label (+1 / -1 entries are counted); only read, never modified
     * @param method network combining method
     * @param goCategory query identifier used for reporting
     * @param averageMeasures accumulator updated in place
     * @param k zero-based fold index
     * @param allPerm permutation of row indices to split into folds
     * @throws ApplicationException propagated from the engine computations
     */
    private void crossValidateVector(CoreMania coreMania, Vector initialLabel, Constants.CombiningMethod method, String goCategory, double[] averageMeasures, int k, int[] allPerm) throws ApplicationException {
        DenseVector label = new DenseVector(initialLabel);
        ArrayList<Integer> excludedRowIndices = new ArrayList<Integer>();
        ArrayList<Integer> includedRowIndices = new ArrayList<Integer>();

        // Held-out slice is allPerm[firstIndex..lastIndex] (inclusive).
        double foldSize = (double)allPerm.length / (double)this.numFolds;
        int firstIndex = (int)Math.ceil((double)k * foldSize);
        // Floating-point rounding of (k+1)*foldSize can land just above an integer and push the
        // last fold one past the end of the array; clamp to stay in bounds.
        int lastIndex = Math.min((int)Math.ceil((double)(k + 1) * foldSize) - 1, allPerm.length - 1);

        for (int i = 0; i < firstIndex; ++i) {
            includedRowIndices.add(allPerm[i]);
        }
        for (int i = lastIndex + 1; i < allPerm.length; ++i) {
            includedRowIndices.add(allPerm[i]);
        }

        int numPosIncluded = 0;
        int numNegIncluded = 0;
        for (int index : includedRowIndices) {
            double value = initialLabel.get(index);
            if (value == 1.0) {
                ++numPosIncluded;
            } else if (value == -1.0) {
                ++numNegIncluded;
            }
        }

        int numPos = 0;
        int numNeg = 0;
        for (int j = firstIndex; j <= lastIndex; ++j) {
            int row = allPerm[j];
            excludedRowIndices.add(row);
            label.set(row, -2.0);
            double value = initialLabel.get(row);
            if (value == 1.0) {
                ++numPos;
            } else if (value == -1.0) {
                ++numNeg;
            }
        }

        logger.info((Object)(MatrixUtils.countMatches((Vector)label, -2.0) + " unknowns in label"));
        this.checkLabels((Vector)label);

        if (method.isQuerySpecific()) {
            coreMania = new CoreMania(this.cache);
            coreMania.computeWeights(this.namespace, this.organism.getId(), (Vector)label, method, this.idList, this.attributeGroupIds, Config.instance().getAttributeEnrichmentMaxSize());
        } else {
            synchronized (coreMania) {
                if (coreMania.getCombinedKernel(this.organism.getId(), this.namespace) == null) {
                    logger.info((Object)"computing weights since none saved");
                    coreMania.computeWeights(this.namespace, this.organism.getId(), (Vector)label, method, this.idList, this.attributeGroupIds, Config.instance().getAttributeEnrichmentMaxSize());
                } else {
                    logger.info((Object)"reusing weights");
                }
            }
        }
        // NOTE(review): for non query-specific methods the discriminant is computed on the shared
        // coreMania outside its lock; confirm callers never share one instance across threads.
        coreMania.computeDiscriminant(this.namespace, this.organism.getId(), (Vector)label, goCategory, this.biasingMethod);
        Vector discriminant = coreMania.getDiscriminant();

        double[] measures = this.getMeasureResults(initialLabel, discriminant, excludedRowIndices);
        for (int m = 0; m < measures.length; ++m) {
            averageMeasures[m] += measures[m];
        }
        this.writeResult(goCategory, k, numPosIncluded, numNegIncluded, numPos, numNeg, measures);
        if (this.writeLabels) {
            // Use the resolved organism id: this.organismId is -1 when selected by name.
            LabelWriter labelWriter = new LabelWriter(this.outFilename, this.nodeMediator, this.organism.getId());
            labelWriter.write(goCategory, k, initialLabel, discriminant, excludedRowIndices, this.nodeIds);
        }
        // Folds run on several worker threads; guard the shared counter.
        synchronized (this.outputMutex) {
            ++this.queryCounter;
        }
    }

    /**
     * Builds a label vector from the cached "ALL" GO annotations. Rows annotated (1.0) with
     * {@code goCategory} become +1, and every other row becomes -1.
     *
     * @param goCategory GO category id, e.g. "GO:001234"
     * @return label vector with one entry per annotation row
     * @throws ApplicationException propagated from the cache lookups
     */
    private Vector loadAnnosFromCache(String goCategory) throws ApplicationException {
        logger.info((Object)("loading annotations for " + goCategory + " from cache"));
        long orgId = this.organism.getId();
        GoIds categoryIds = this.cache.getGoIds(orgId, "ALL");
        int column = categoryIds.getIndexForId(goCategory);
        Matrix annotationData = this.cache.getGoAnnotations(orgId, "ALL").getData();
        DenseVector label = new DenseVector(annotationData.numRows());
        int rows = label.size();
        for (int row = 0; row < rows; row++) {
            label.set(row, annotationData.get(row, column) == 1.0 ? 1.0 : -1.0);
        }
        return label;
    }

    /**
     * Runs one ValidationTask per loaded query on {@code totalThreads} worker threads and waits
     * for all of them. At least one worker is always started.
     *
     * A task that throws is logged and the worker moves on to the next task.
     *
     * @return map from each task's query id to its measure results
     * @throws Exception if waiting for the workers is interrupted
     */
    public Map<String, double[]> crossValidate() throws Exception {
        ArrayList<ValidationTask> tasks = new ArrayList<ValidationTask>();
        for (String[] queryRecord : this.queries) {
            tasks.add(new ValidationTask(queryRecord, this.seed));
        }
        final Iterator<ValidationTask> jobQueue = tasks.iterator();
        final Object jobMutex = new Object();
        final int[] jobCount = new int[1];
        final int totalJobs = tasks.size();
        // With zero or negative thread counts no task would ever run.
        int workerCount = Math.max(1, this.totalThreads);
        ArrayList<Thread> threads = new ArrayList<Thread>();
        for (int threadIndex = 0; threadIndex < workerCount; ++threadIndex) {
            final int threadId = threadIndex + 1;
            Thread thread = new Thread(new Runnable(){

                @Override
                public void run() {
                    while (true) {
                        ValidationTask task;
                        int jobNumber;
                        synchronized (jobMutex) {
                            if (!jobQueue.hasNext()) {
                                return;
                            }
                            task = jobQueue.next();
                            // Increment under the lock; it used to race between workers.
                            jobNumber = ++jobCount[0];
                        }
                        try {
                            logger.info((Object)String.format("[Thread %d] %d/%d %s", threadId, jobNumber, totalJobs, task.queryRecord[0]));
                            task.run();
                        }
                        catch (Throwable t) {
                            logger.error((Object)"Unexpected error", t);
                        }
                    }
                }
            });
            threads.add(thread);
            thread.start();
        }
        for (Thread thread : threads) {
            thread.join();
        }
        HashMap<String, double[]> result = new HashMap<String, double[]>();
        for (ValidationTask task : tasks) {
            result.put(task.getQueryId(), task.getMeasures());
        }
        return result;
    }

    /**
     * Fills {@code initialLabel} from one query record.
     *
     * Tokens before the first "+" or "-" are ignored. Tokens after "+" are positive symbols (set
     * to 1.0). Tokens after "-" are negative symbols (set to -1.0), and a "%" in the negative
     * section stops parsing. A symbol is labelled only if lookupSymbol returns a non-null index
     * whose entry is still 0.0, so the first assignment wins.
     *
     * If {@code autoComputeNegatives} is set, or the record supplied no negatives, every entry
     * still at 0.0 becomes -1.0.
     *
     * @param initialLabel label vector to fill; NOTE(review): assumed to start as all zeros
     * @param queryRecord tokens of one query line; element 0 is used in the warning message
     * @param autoComputeNegatives whether to turn all unlabelled entries into negatives
     * @return number of entries labelled positive plus negative
     */
    private int populateLabel(Vector initialLabel, String[] queryRecord, boolean autoComputeNegatives) {
        ProcessMode state = ProcessMode.Scan;
        int positives = 0;
        int negatives = 0;
        for (String token : queryRecord) {
            if (state == ProcessMode.Scan) {
                if ("+".equals(token)) {
                    state = ProcessMode.Positive;
                } else if ("-".equals(token)) {
                    state = ProcessMode.Negative;
                }
            } else if (state == ProcessMode.Positive) {
                if ("-".equals(token)) {
                    state = ProcessMode.Negative;
                } else {
                    Integer row = this.lookupSymbol(this.organism, token);
                    if (row != null && initialLabel.get(row.intValue()) == 0.0) {
                        positives++;
                        initialLabel.set(row.intValue(), 1.0);
                    }
                }
            } else if (state == ProcessMode.Negative) {
                if ("%".equals(token)) {
                    break;
                }
                Integer row = this.lookupSymbol(this.organism, token);
                if (row != null && initialLabel.get(row.intValue()) == 0.0) {
                    negatives++;
                    initialLabel.set(row.intValue(), -1.0);
                }
            }
        }
        boolean fillNegatives = autoComputeNegatives;
        if (!fillNegatives && negatives == 0) {
            fillNegatives = true;
            logger.warn((Object)String.format("Query %s has no negative examples.  Forcing automatic computation of negatives.", queryRecord[0]));
        }
        if (fillNegatives) {
            int size = initialLabel.size();
            for (int row = 0; row < size; row++) {
                if (initialLabel.get(row) == 0.0) {
                    initialLabel.set(row, -1.0);
                    negatives++;
                }
            }
        }
        return positives + negatives;
    }

    /**
     * Logs a breakdown of the values in a label vector.
     *
     * <p>The breakdown covers the number of {@code +1}, {@code -1}, {@code 0} and {@code -2}
     * ("excl") entries. Any remaining entries are reported as "unaccounted".
     *
     * @param labels label vector to summarise; it is not modified
     */
    private void checkLabels(Vector labels) {
        final int size = labels.size();
        final int positives = MatrixUtils.countMatches(labels, 1.0);
        final int negatives = MatrixUtils.countMatches(labels, -1.0);
        final int unlabelled = MatrixUtils.countMatches(labels, 0.0);
        final int excluded = MatrixUtils.countMatches(labels, -2.0);
        final int unaccounted = size - (positives + negatives + unlabelled + excluded);
        final String summary = String.format("label vector: size %d, +1 %d, -1 %d, 0: %d, excl: %d, unaccounted %d",
                size, positives, negatives, unlabelled, excluded, unaccounted);
        logger.info((Object)summary);
    }

    /**
     * Resolves a gene symbol to its index in {@code nodeIds}, memoising the outcome.
     *
     * <p>Results are cached in {@code symbolToIndexCache}, keyed by symbol only. Misses are
     * cached as the sentinel {@code -1}, so each unknown symbol is logged only once. The
     * {@code organism} argument is consulted only on a cache miss.
     *
     * <p>NOTE(review): this is reached from {@code ValidationTask} worker threads via
     * {@code populateLabel}. Confirm that {@code symbolToIndexCache} is safe for concurrent use.
     *
     * @param organism organism whose id is passed to {@code geneMediator.getNodeId}
     * @param symbol gene symbol to resolve
     * @return the node index, or {@code null} if the symbol is unknown to the mediator or
     *         absent from the node-id mapping
     */
    private Integer lookupSymbol(Organism organism, String symbol) {
        final Integer cached = this.symbolToIndexCache.get(symbol);
        if (cached != null) {
            return cached == -1 ? null : cached;
        }

        final Long nodeId = this.geneMediator.getNodeId(organism.getId(), symbol);
        if (nodeId == null) {
            logger.info((Object)("symbol not in db: " + symbol));
            this.symbolToIndexCache.put(symbol, -1);
            return null;
        }

        try {
            final Integer resolved = this.nodeIds.getIndexForId(nodeId);
            this.symbolToIndexCache.put(symbol, resolved);
            return resolved;
        }
        catch (ApplicationException e) {
            logger.warn((Object)("gene not in mappings for " + symbol));
            this.symbolToIndexCache.put(symbol, -1);
            return null;
        }
    }

    /** Writes the engine version reported by {@code Config.instance()} to the info log. */
    private static void logEngineVersion() {
        logger.info((Object)("Version: " + Config.instance().getVersion()));
    }

    /**
     * Runs the application.
     *
     * <p>When no dump type was requested, it initialises and runs cross-validation.
     * Otherwise it only performs the requested dump.
     */
    @Override
    public void process() throws Exception {
        if (this.dumpType == null) {
            this.initValidation();
            this.crossValidate();
            return;
        }
        this.dump(this.dumpType);
    }

    /** Performs the base-class initialisation, then logs the engine version. */
    @Override
    public void init() throws Exception {
        super.init();
        logEngineVersion();
    }

    /**
     * Command-line entry point: parses arguments, then runs init, process and cleanup in turn.
     *
     * <p>The JVM exits with status 1 if argument parsing fails or if any step throws. In the
     * latter case {@code cleanup()} is not reached.
     *
     * @param args command-line arguments
     */
    public static void main(String[] args) throws Exception {
        final VectorCrossValidator app = new VectorCrossValidator();
        final boolean parsed = app.getCommandLineArgs(args);
        if (!parsed) {
            System.exit(1);
        }
        try {
            app.init();
            app.process();
            app.cleanup();
        }
        catch (Exception fatal) {
            logger.error((Object)"Fatal error", (Throwable)fatal);
            System.exit(1);
        }
    }

    /**
     * Counts the entries of {@code label} that are exactly {@code 1.0}.
     *
     * @param label label vector to inspect
     * @return number of positive-labelled entries
     */
    private int countPositive(Vector label) {
        int positives = 0;
        for (VectorEntry e : label) {
            if (e.get() == 1.0) {
                positives++;
            }
        }
        return positives;
    }

    /**
     * Sets the smallest number of positive genes a query must have to be validated.
     *
     * @param size lower bound, or {@code null} for no lower bound
     */
    public void setMinimumGeneSetSize(Integer size) {
        this.minimumGeneSetSize = size;
    }

    /**
     * Sets the largest number of positive genes a query may have to be validated.
     * The method name's misspelling is kept for compatibility with existing callers.
     *
     * @param size upper bound, or {@code null} for no upper bound
     */
    public void setMaxmimumGeneSetSize(Integer size) {
        this.maximumGeneSetSize = size;
    }

    /**
     * Sets the cache namespace used by this validator.
     *
     * @param ns namespace name
     */
    public void setCacheNamespace(String ns) {
        this.namespace = ns;
    }

    /**
     * Sets the {@code writeLabels} flag.
     *
     * @param enabled new value of the flag
     */
    public void setWriteLabels(boolean enabled) {
        this.writeLabels = enabled;
    }

    /**
     * One cross-validation job for a single query record.
     *
     * <p>The job builds the label vector, runs {@code numFolds} folds through a
     * {@link CoreMania} instance and averages the measures over the folds. It then writes the
     * result while holding {@code outputMutex}.
     */
    private class ValidationTask {
        private String[] queryRecord;
        private long randomSeed;
        private double[] averageMeasures;

        ValidationTask(String[] queryRecord, long seed) {
            this.queryRecord = queryRecord;
            this.randomSeed = seed;
        }

        /** Returns the query identifier, i.e. the first field of the query record. */
        public String getQueryId() {
            return this.queryRecord[0];
        }

        /**
         * Returns the fold-averaged measures.
         *
         * <p>Returns the shared {@code SKIPPED} marker if the query fell outside the
         * configured gene-set size bounds. Returns {@code null} before {@link #run()} has
         * assigned a value.
         */
        public double[] getMeasures() {
            return this.averageMeasures;
        }

        /**
         * Labels the query's genes, runs every fold and records the averaged measures.
         * (The decompiler reported simplifying exception handling around the synchronized
         * write below.)
         */
        void run() throws ApplicationException {
            final VectorCrossValidator outer = VectorCrossValidator.this;
            final String goCategory = this.queryRecord[0];
            final Constants.CombiningMethod combiningMethod = outer.getCombiningMethod(this.queryRecord);

            final DenseVector initialLabel;
            final int totalNodes;
            if (outer.useCachedGoAnnos) {
                totalNodes = outer.nodeIds.getNodeIds().length;
                initialLabel = outer.loadAnnosFromCache(goCategory);
            } else {
                initialLabel = new DenseVector(outer.nodeIds.getNodeIds().length);
                totalNodes = outer.populateLabel((Vector)initialLabel, this.queryRecord, outer.allNegCrossVal);
            }

            if (this.isOutsideSizeBounds((Vector)initialLabel)) {
                this.averageMeasures = SKIPPED;
                return;
            }

            final int[] allPerm = this.computePermutation(totalNodes, (Vector)initialLabel);
            this.averageMeasures = new double[outer.measures.size()];
            final CoreMania coreMania = new CoreMania(outer.cache);
            for (int fold = 0; fold < outer.numFolds; fold++) {
                logger.debug((Object)String.format("executing fold %d of %d", fold + 1, outer.numFolds));
                outer.crossValidateVector(coreMania, (Vector)initialLabel, combiningMethod, goCategory, this.averageMeasures, fold, allPerm);
            }

            // Presumably crossValidateVector accumulates per-fold values into averageMeasures
            // (confirm); dividing by the fold count turns the totals into means.
            for (int m = 0; m < this.averageMeasures.length; m++) {
                this.averageMeasures[m] /= outer.numFolds;
            }

            synchronized (outer.outputMutex) {
                outer.writeResult(goCategory, this.averageMeasures);
            }
        }

        /**
         * Reports whether the number of positive labels falls outside the configured
         * minimum/maximum gene-set size. Returns {@code false} without counting when
         * neither bound is set.
         */
        private boolean isOutsideSizeBounds(Vector label) {
            final VectorCrossValidator outer = VectorCrossValidator.this;
            final Integer min = outer.minimumGeneSetSize;
            final Integer max = outer.maximumGeneSetSize;
            if (min == null && max == null) {
                return false;
            }
            final int positives = outer.countPositive(label);
            return (min != null && positives < min) || (max != null && positives > max);
        }

        /**
         * Produces the node ordering used to split folds.
         *
         * <p>A random permutation of all nodes is seeded with {@code randomSeed}. If
         * {@code totalNodes} equals the full node count it is returned as is. Otherwise only
         * the nodes whose label is non-zero are kept, in permuted order.
         *
         * @param totalNodes number of labelled nodes expected
         * @param initialLabel label vector used to filter the permutation
         * @return permuted node indices
         */
        private int[] computePermutation(int totalNodes, Vector initialLabel) {
            final int labelSize = VectorCrossValidator.this.nodeIds.getNodeIds().length;
            final int[] shuffled = MatrixUtils.permutation(labelSize, new Random(this.randomSeed));
            if (totalNodes == labelSize) {
                return shuffled;
            }
            final int[] labelled = new int[totalNodes];
            int count = 0;
            for (int node : shuffled) {
                if (initialLabel.get(node) != 0.0) {
                    labelled[count++] = node;
                }
            }
            return labelled;
        }
    }

    /**
     * Parser state used by {@code populateLabel} while scanning a query record.
     *
     * <ul>
     *   <li>{@code Scan}: before any marker.</li>
     *   <li>{@code Positive}: after a {@code "+"} marker.</li>
     *   <li>{@code Negative}: after a {@code "-"} marker.</li>
     * </ul>
     * Nested enums are implicitly static, so no modifier is needed.
     */
    enum ProcessMode {
        Scan,
        Positive,
        Negative
    }
}

