/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.neural.rnn;

import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.Sentence;
import edu.stanford.nlp.neural.rnn.RNNCoreAnnotations;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Map;
import java.util.PriorityQueue;

/**
 * Records, for every prediction class and every phrase length (number of
 * leaves under a subtree), the {@code ngramCount} highest-scoring subtrees
 * seen so far, where a subtree's score for class {@code c} is entry {@code c}
 * of its {@link RNNCoreAnnotations} predictions.
 */
public class TopNGramRecord {
    /** Maximum number of subtrees kept for each (class, length) pair. */
    private final int ngramCount;
    /** Number of class indices tracked; classes are indexed {@code 0 .. numClasses-1}. */
    private final int numClasses;
    /** Subtrees with more leaves than this are not recorded; a value {@code <= 0} disables the limit. */
    private final int maximumLength;
    /**
     * class index -> (leaf count -> min-heap of the best subtrees for that class).
     * The heap head is the lowest-scoring retained subtree, so it is the one evicted.
     */
    Map<Integer, Map<Integer, PriorityQueue<Tree>>> classToNGrams = Generics.newHashMap();

    /**
     * Creates an empty record.
     *
     * @param numClasses number of class indices to track
     * @param ngramCount how many subtrees to keep per class and length
     * @param maximumLength longest subtree (in leaves) to record; {@code <= 0} means unlimited
     */
    public TopNGramRecord(int numClasses, int ngramCount, int maximumLength) {
        this.numClasses = numClasses;
        this.ngramCount = ngramCount;
        this.maximumLength = maximumLength;
        for (int i = 0; i < numClasses; ++i) {
            Map<Integer, PriorityQueue<Tree>> innerMap = Generics.newHashMap();
            this.classToNGrams.put(i, innerMap);
        }
    }

    /**
     * Adds every internal subtree of {@code tree} to the per-class rankings.
     * The tree is first copied into a lightweight form carrying only label
     * values and prediction annotations, so the record does not retain the
     * caller's original labels.
     *
     * @param tree a tree whose internal nodes carry prediction annotations
     */
    public void countTree(Tree tree) {
        Tree simplified = this.simplifyTree(tree);
        for (int i = 0; i < this.numClasses; ++i) {
            this.countTreeHelper(simplified, i, this.classToNGrams.get(i));
        }
    }

    /**
     * Recursively copies {@code tree}, keeping only each node's label value and
     * its {@code Predictions} annotation, using the tree's own factory.
     */
    private Tree simplifyTree(Tree tree) {
        CoreLabel newLabel = new CoreLabel();
        newLabel.set(RNNCoreAnnotations.Predictions.class, RNNCoreAnnotations.getPredictions(tree));
        newLabel.setValue(tree.label().value());
        if (tree.isLeaf()) {
            return tree.treeFactory().newLeaf(newLabel);
        }
        Tree[] kids = tree.children();
        ArrayList<Tree> children = Generics.newArrayList(kids.length);
        for (Tree kid : kids) {
            children.add(this.simplifyTree(kid));
        }
        return tree.treeFactory().newTreeNode(newLabel, children);
    }

    /**
     * Post-order walk that offers each internal node to the queue for its
     * length under class {@code prediction}.
     *
     * @return the number of leaves under {@code tree} (1 for a leaf)
     */
    private int countTreeHelper(Tree tree, int prediction, Map<Integer, PriorityQueue<Tree>> ngrams) {
        if (tree.isLeaf()) {
            return 1;
        }
        int treeSize = 0;
        for (Tree child : tree.children()) {
            treeSize += this.countTreeHelper(child, prediction, ngrams);
        }
        // Children were still visited above; only this node is skipped when too long.
        if (this.maximumLength > 0 && treeSize > this.maximumLength) {
            return treeSize;
        }
        PriorityQueue<Tree> queue = this.getPriorityQueue(treeSize, prediction, ngrams);
        // Skip subtrees already present (as judged by Tree.equals) to avoid duplicates.
        if (!queue.contains(tree)) {
            queue.add(tree);
        }
        // Min-heap: evicting the head drops the lowest-scoring subtree.
        if (queue.size() > this.ngramCount) {
            queue.poll();
        }
        return treeSize;
    }

    /** Returns the queue for {@code size}, creating and registering it on first use. */
    private PriorityQueue<Tree> getPriorityQueue(int size, int prediction, Map<Integer, PriorityQueue<Tree>> ngrams) {
        PriorityQueue<Tree> queue = ngrams.get(size);
        if (queue != null) {
            return queue;
        }
        queue = new PriorityQueue<Tree>(this.ngramCount + 1, this.scoreComparator(prediction));
        ngrams.put(size, queue);
        return queue;
    }

    /**
     * Orders trees by ascending predicted score for class {@code prediction}.
     * Uses {@link Double#compare} so the ordering is total (NaN included),
     * as required by the {@link Comparator} contract.
     */
    private Comparator<Tree> scoreComparator(final int prediction) {
        return new Comparator<Tree>() {

            @Override
            public int compare(Tree tree1, Tree tree2) {
                double score1 = RNNCoreAnnotations.getPredictions(tree1).get(prediction);
                double score2 = RNNCoreAnnotations.getPredictions(tree2).get(prediction);
                return Double.compare(score1, score2);
            }
        };
    }

    /**
     * Renders the record: for each class, each recorded length in ascending
     * order, and within a length the retained subtrees from highest to lowest
     * score, each followed by its score in brackets.
     */
    @Override
    public String toString() {
        StringBuilder result = new StringBuilder();
        for (int prediction = 0; prediction < this.numClasses; ++prediction) {
            result.append("Best scores for class ").append(prediction).append("\n");
            Map<Integer, PriorityQueue<Tree>> ngrams = this.classToNGrams.get(prediction);
            Comparator<Tree> comparator = this.scoreComparator(prediction);
            // HashMap iteration order is unspecified; sort lengths for a stable report.
            ArrayList<Integer> lengths = Generics.newArrayList(ngrams.keySet());
            Collections.sort(lengths);
            for (Integer length : lengths) {
                ArrayList<Tree> trees = Generics.newArrayList(ngrams.get(length));
                Collections.sort(trees, comparator);
                result.append("  Len ").append(length).append("\n");
                // Ascending sort, so walk backwards to print best first.
                for (int i = trees.size() - 1; i >= 0; --i) {
                    Tree tree = trees.get(i);
                    result.append("    ")
                          .append(Sentence.listToString(tree.yield()))
                          .append("  [")
                          .append(RNNCoreAnnotations.getPredictions(tree).get(prediction))
                          .append("]\n");
                }
            }
        }
        return result.toString();
    }
}

