/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify.evaluate;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.Trial;
import cc.mallet.classify.evaluate.Graph2;
import cc.mallet.classify.evaluate.GraphItem;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelVector;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.PrintUtilities;
import java.awt.Color;
import java.awt.Container;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.Toolkit;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Vector;
import java.util.logging.Logger;
import javax.swing.JButton;
import javax.swing.JFrame;

/**
 * Accuracy-versus-coverage curve for a set of classifications.
 *
 * <p>Classifications are sorted ascending by the score of their best label. The accuracy at
 * coverage {@code c} is the fraction of correctly labelled instances among the
 * {@code round(n * c)} highest-scoring ones, clamped to at least one when the set is non-empty.
 * The curve is sampled at {@code numBuckets} evenly spaced coverages ending at 1.0. It is handed
 * to a {@link Graph2}, which can be shown in a window via {@link #displayGraph()}.
 */
public class AccuracyCoverage
implements ActionListener {
    private static final Logger logger = MalletLogger.getLogger(AccuracyCoverage.class.getName());
    /** Number of coverage points used when the caller does not specify one. */
    static final int DEFAULT_NUM_BUCKETS = 20;
    /** Upper bound passed to the {@link Graph2} constructor for the "Coverage" axis. */
    static final int DEFAULT_MAX_X = 100;
    /** Classifications under evaluation; sorted ascending by best-label score after construction. */
    private ArrayList<Classification> classifications;
    /** Accuracy per bucket; index {@code i} corresponds to coverage {@code (i + 1) / numBuckets}. */
    private double[] accuracyValues;
    private int numBuckets;
    private Graph2 graph;
    /** Window created by {@link #displayGraph()}; {@code null} until then. */
    private JFrame frame;

    /**
     * Builds the curve for a trial and adds it to a new graph as the first data series.
     *
     * <p>Note: the given trial is sorted in place (ascending by best-label score).
     *
     * @param t classifications to evaluate; reordered by this constructor
     * @param numBuckets number of evenly spaced coverage points, the last one at full coverage
     * @param title title passed to the graph
     * @param dataName name passed to the graph for this data series
     */
    public AccuracyCoverage(Trial t, int numBuckets, String title, String dataName) {
        this.classifications = t;
        this.numBuckets = numBuckets;
        this.accuracyValues = new double[numBuckets];
        this.frame = null;
        logger.info("Constructing AccCov with " + this.classifications.size());
        this.sortClassifications();
        this.createAccuracyArray();
        this.graph = new Graph2(title, 0, DEFAULT_MAX_X, "Coverage", "Accuracy");
        this.addDataToGraph(this.accuracyValues, numBuckets, dataName);
    }

    /** Same as the four-argument constructor with {@link #DEFAULT_NUM_BUCKETS} buckets. */
    public AccuracyCoverage(Trial t, String title, String name) {
        this(t, DEFAULT_NUM_BUCKETS, title, name);
    }

    /** Same as the four-argument constructor with default buckets and series name "unnamed". */
    public AccuracyCoverage(Trial t, String title) {
        this(t, DEFAULT_NUM_BUCKETS, title, "unnamed");
    }

    /** Evaluates {@code C} on {@code ilist} via a new {@link Trial}, using default buckets. */
    public AccuracyCoverage(Classifier C, InstanceList ilist, String title) {
        this(new Trial(C, ilist), DEFAULT_NUM_BUCKETS, title, "unnamed");
    }

    /** Evaluates {@code C} on {@code ilist} via a new {@link Trial} with {@code numBuckets} points. */
    public AccuracyCoverage(Classifier C, InstanceList ilist, int numBuckets, String title) {
        this(new Trial(C, ilist), numBuckets, title, "unnamed");
    }

    /**
     * Sums trapezoid heights of the accuracy curve over the 99 intervals between coverage
     * 0.01 and 1.00 (in steps of 0.01).
     *
     * <p>The heights are not multiplied by the interval width, so the result lies in
     * {@code [0, 99]} rather than {@code [0, 1]}. Divide by 99 for a mean accuracy.
     *
     * @return the unscaled trapezoid sum, or NaN when there are no classifications
     */
    public double cumulativeAccuracy() {
        double area = 0.0;
        // Each interior point is both the right edge of one interval and the left edge of the next.
        double leftAccuracy = this.accuracyAtCoverage(1.0 / 100.0);
        for (int i = 1; i < 100; ++i) {
            double rightAccuracy = this.accuracyAtCoverage((double) (i + 1) / 100.0);
            area += 0.5 * (leftAccuracy + rightAccuracy);
            leftAccuracy = rightAccuracy;
        }
        return area;
    }

    /**
     * Fills {@code accuracyValues} with the accuracy at coverages
     * {@code 1/numBuckets, 2/numBuckets, ..., 1.0}.
     */
    public void createAccuracyArray() {
        for (int i = 0; i < this.numBuckets; ++i) {
            // Dividing by numBuckets (rather than multiplying a precomputed step) makes the last
            // coverage exactly 1.0, so round-off cannot push it past the assertion bound.
            this.accuracyValues[i] = this.accuracyAtCoverage((double) (i + 1) / this.numBuckets);
        }
    }

    /**
     * Returns the accuracy over the most confident fraction {@code cov} of the classifications.
     *
     * <p>The number of classifications considered is {@code round(n * cov)}, clamped to
     * {@code [1, n]}. The clamp means small trials, where {@code n * cov < 0.5}, still yield a
     * finite value instead of {@code 0 / 0}. This assumes {@link #sortClassifications()} has
     * already run, as the constructor does.
     *
     * @param cov coverage in (0, 1]; the range is only checked when assertions are enabled
     * @return fraction of the considered classifications whose best label is correct, or
     *     {@code NaN} when there are no classifications at all
     */
    public double accuracyAtCoverage(double cov) {
        assert (cov <= 1.0 && cov > 0.0);
        int total = this.classifications.size();
        if (total == 0) {
            // No data: accuracy is undefined.
            return Double.NaN;
        }
        // Rounding can give 0 for small trials, and an out-of-range cov (assertions disabled)
        // must not index outside the list.
        int numTrials = (int) Math.max(1L, Math.min((long) total, Math.round(total * cov)));
        int numCorrect = 0;
        // Sorted ascending by score, so the most confident classifications sit at the end.
        for (int i = total - numTrials; i < total; ++i) {
            if (this.classifications.get(i).bestLabelIsCorrect()) {
                ++numCorrect;
            }
        }
        return (double) numCorrect / (double) numTrials;
    }

    /** Sorts the classifications in place, ascending by best-label score. */
    @SuppressWarnings("unchecked") // ClassificationComparator is raw to keep its public signature.
    public void sortClassifications() {
        Collections.sort(this.classifications, new ClassificationComparator());
    }

    /**
     * Adds one data series to the graph. Each of the first {@code nBuckets} accuracies is turned
     * into a black, unlabelled item whose value is the accuracy times 100, truncated to an int.
     *
     * @param accValues accuracies (fractions) to plot; must have at least {@code nBuckets} entries
     * @param nBuckets number of entries of {@code accValues} to plot
     * @param name name passed to the graph for this series
     */
    public void addDataToGraph(double[] accValues, int nBuckets, String name) {
        Vector<GraphItem> values = new Vector<GraphItem>(nBuckets);
        for (int i = 0; i < nBuckets; ++i) {
            GraphItem temp = new GraphItem("", (int) (accValues[i] * 100.0), Color.black);
            values.add(temp);
        }
        logger.info("Sending " + values.size() + " elements to graph");
        this.graph.addItemVector(values, name);
    }

    /**
     * Shows the graph and a "Print" button in a new window centred on the screen.
     *
     * <p>Closing the window calls {@code System.exit(0)}, which terminates the whole JVM.
     */
    public void displayGraph() {
        JButton printButton = new JButton("Print");
        this.frame = new JFrame("Graph");
        printButton.addActionListener(this);
        this.frame.addWindowListener(new WindowAdapter(){

            @Override
            public void windowClosing(WindowEvent e) {
                // NOTE(review): exits the entire JVM; embedding callers may prefer dispose().
                System.exit(0);
            }
        });
        Container pane = this.frame.getContentPane();
        pane.setLayout(new FlowLayout());
        assert (this.graph != null);
        pane.add(this.graph);
        pane.add(printButton);
        this.frame.pack();
        Dimension scrnsize = Toolkit.getDefaultToolkit().getScreenSize();
        this.frame.setLocation(
                (int) (scrnsize.getWidth() - (double) this.frame.getWidth()) / 2,
                (int) (scrnsize.getHeight() - (double) this.frame.getHeight()) / 2);
        this.frame.setVisible(true);
    }

    /** Handles the "Print" button by printing the graph component. */
    @Override
    public void actionPerformed(ActionEvent event) {
        PrintUtilities.printComponent(this.graph);
    }

    /** Adds another trial's curve to this graph using {@link #DEFAULT_NUM_BUCKETS} buckets. */
    public void addTrial(Trial t, String name) {
        this.addTrial(t, DEFAULT_NUM_BUCKETS, name);
    }

    /**
     * Computes the curve for another trial and adds it to this graph as a new series.
     * Like the constructor, this sorts {@code t} in place.
     *
     * @param t classifications to evaluate
     * @param nBuckets number of coverage points for the new series
     * @param name name passed to the graph for the new series
     */
    public void addTrial(Trial t, int nBuckets, String name) {
        AccuracyCoverage newData = new AccuracyCoverage(t, nBuckets, "untitled", name);
        double[] accValues = newData.accuracyValues();
        this.addDataToGraph(accValues, nBuckets, name);
    }

    /**
     * Returns the per-bucket accuracies. Index {@code i} is the accuracy at coverage
     * {@code (i + 1) / numBuckets}.
     *
     * @return a copy, so callers cannot alter this object's state
     */
    public double[] accuracyValues() {
        return this.accuracyValues.clone();
    }

    /**
     * Orders classifications ascending by the score of their best label.
     *
     * <p>{@link Double#compare} gives a consistent total order even if a score is NaN. A
     * subtraction-based comparison would treat NaN as equal to everything and break the
     * {@link Comparator} contract.
     */
    public class ClassificationComparator
    implements Comparator {
        @Override
        public final int compare(Object a, Object b) {
            LabelVector x = ((Classification) a).getLabelVector();
            LabelVector y = ((Classification) b).getLabelVector();
            return Double.compare(x.getBestValue(), y.getBestValue());
        }
    }
}

