/*
 * Decompiled with CFR 0.152.
 */
package bham.leakiest;

import bham.leakiest.State;
import bham.leakiest.TestInfoLeak;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class ProbDist {
    private State[] sts;
    private double[] pmf;
    private int size;
    private int numJoint = 0;
    private HashMap<State, Double> dist;
    private ProbDist[] allMarginals;
    private boolean forbid_overwrite;
    private final int NOT_FOUND = -1;
    private final int NAN = -1;
    private static int verbose = TestInfoLeak.verbose;
    protected static final double ERROR = -1.0;
    protected static final double accuracy = 1.0E-10;

    public ProbDist(int numStates) {
        this.size = numStates;
        this.forbid_overwrite = false;
        this.sts = new State[numStates];
        this.pmf = new double[numStates];
        this.dist = new HashMap();
        if (verbose >= 5) {
            System.out.println("A probability distribution is created.");
        }
    }

    private ProbDist(State[] sts_in, double[] pmf_in, boolean lock, boolean copy) {
        if (sts_in.length == pmf_in.length) {
            this.size = sts_in.length;
            this.forbid_overwrite = lock;
            this.sts = new State[this.size];
            this.pmf = new double[this.size];
            this.dist = new HashMap();
            if (copy) {
                System.arraycopy(sts_in, 0, this.sts, 0, sts_in.length);
                System.arraycopy(pmf_in, 0, this.pmf, 0, pmf_in.length);
                int i = 0;
                while (i < this.size) {
                    this.dist.put(sts_in[i], pmf_in[i]);
                    ++i;
                }
            } else {
                this.sts = sts_in;
                this.pmf = pmf_in;
            }
            if (verbose >= 5) {
                System.out.println("A probability distribution is created.");
            }
        } else {
            System.out.println("Two arrays (input to ProbDist) have different lengths.");
        }
    }

    public ProbDist(State[] sts_in, double[] pmf_in) {
        this(sts_in, pmf_in, true, true);
    }

    public ProbDist(State[] sts_in, double[] pmf_in, boolean lock) {
        this(sts_in, pmf_in, lock, true);
    }

    public ProbDist(String[] stateNames, double[] pmf) {
        this(stateNames, pmf, true);
    }

    public ProbDist(String[] stateNames, double[] pmf, boolean lock) {
        State st;
        int numStates = stateNames.length;
        HashMap<State, Double> dist = new HashMap<State, Double>();
        int i = 0;
        while (i < numStates) {
            st = new State();
            st.updateValue("input", stateNames[i]);
            double prob = pmf[i];
            dist.put(st, prob);
            ++i;
        }
        this.dist = dist;
        this.sts = new State[dist.size()];
        this.pmf = new double[dist.size()];
        i = 0;
        Iterator<State> iterator = this.dist.keySet().iterator();
        while (iterator.hasNext()) {
            this.sts[i] = st = iterator.next();
            this.pmf[i] = (Double)dist.get(st);
            ++i;
        }
    }

    public ProbDist(HashMap<State, Double> dist_in, boolean lock) {
        this.dist = dist_in;
        this.sts = new State[this.dist.size()];
        this.pmf = new double[this.dist.size()];
        int i = 0;
        Iterator<State> iterator = this.dist.keySet().iterator();
        while (iterator.hasNext()) {
            State st;
            this.sts[i] = st = iterator.next();
            this.pmf[i] = this.dist.get(st);
            ++i;
        }
    }

    public static double[] uniformProbArray(int noOfInputs) {
        double[] pmf = new double[noOfInputs];
        int i = 0;
        while (i < pmf.length) {
            pmf[i] = 1.0 / (double)noOfInputs;
            ++i;
        }
        return pmf;
    }

    public static ProbDist uniformProbDist(String[] stateNames, boolean lock) {
        int numStates = stateNames.length;
        HashMap<State, Double> dist = new HashMap<State, Double>();
        int i = 0;
        while (i < numStates) {
            State st = new State();
            st.updateValue("input", stateNames[i]);
            double prob = 1.0 / (double)numStates;
            dist.put(st, prob);
            ++i;
        }
        ProbDist pd = new ProbDist(dist, lock);
        return pd;
    }

    public static double[] cumulativeProbArray(double[] pmf) {
        double[] cdf = new double[pmf.length];
        cdf[0] = pmf[0];
        int i = 1;
        while (i < pmf.length) {
            cdf[i] = cdf[i - 1] + pmf[i];
            ++i;
        }
        return cdf;
    }

    public ProbDist cumulativeProbDist() {
        State[] sts = this.getStatesArray();
        double[] cdf = ProbDist.cumulativeProbArray(this.getPMFArray());
        ProbDist cpd = new ProbDist(sts, cdf);
        return cpd;
    }

    public ProbDist sharedProbDist(int numJoint, boolean lock) {
        HashMap<State, Double> jdist = new HashMap<State, Double>();
        for (State st : this.dist.keySet()) {
            try {
                String str = st.getValue("input");
                String jstr = "(";
                int i = 0;
                while (i < numJoint) {
                    jstr = String.valueOf(jstr) + str;
                    if (i != numJoint - 1) {
                        jstr = String.valueOf(jstr) + ", ";
                    }
                    ++i;
                }
                jstr = String.valueOf(jstr) + ")";
                double jprob = this.dist.get(st);
                State jst = new State();
                jst.updateValue("input", jstr);
                jdist.put(jst, jprob);
            }
            catch (Exception ex0) {
                System.out.println("Error in calculating the shared input distribution: " + ex0);
                System.exit(1);
            }
        }
        ProbDist pd = new ProbDist(jdist, lock);
        if (verbose >= 5) {
            pd.printProbDist();
        }
        return pd;
    }

    public double getProb(String str) {
        double result = 0.0;
        for (State st : this.dist.keySet()) {
            if (!st.getValue("input").equals(str)) continue;
            result = this.dist.get(st);
            break;
        }
        return result;
    }

    public double getProb(State st) {
        return this.dist.get(st);
    }

    public void updateProb(State st, double prob) {
        if (this.forbid_overwrite) {
            System.out.println("Error in updateProb: Cannot overwrite the probability distribution.");
        } else if (prob >= 0.0 && prob <= 1.0) {
            this.dist.put(st, prob);
            if (verbose >= 7) {
                System.out.println("Added Pr" + st.stringState() + " = " + prob);
            }
        } else {
            System.out.println("Error in updateProb: " + prob + " is not a probability.");
        }
    }

    public void removeProb(State st) {
        if (this.forbid_overwrite) {
            System.out.println("Error in updateProb: Cannot overwrite the probability distribution.");
            return;
        }
        this.dist.put(st, 0.0);
    }

    public boolean isWellDefined(double error) {
        boolean wellDefined = true;
        double sum = 0.0;
        for (Double prob : this.dist.values()) {
            sum += prob.doubleValue();
        }
        if (sum > 1.0 + error || sum < 1.0 - error) {
            wellDefined = false;
            System.out.println("Sum of probabilities is " + sum + ".");
        }
        return wellDefined;
    }

    public void checkWellDefined(double error) {
        if (this.isWellDefined(error)) {
            System.out.println("The state is well-defined.");
        } else {
            System.out.println("Error: The state is not well-defined.");
        }
    }

    public int sizeSampleSpace() {
        return this.dist.size();
    }

    public Collection<State> getStatesCollection() {
        return this.dist.keySet();
    }

    public State[] getStatesArray() {
        Set<State> set = this.dist.keySet();
        this.sts = set.toArray(new State[set.size()]);
        return this.sts;
    }

    public Collection<Double> getPMFCollection() {
        return this.dist.values();
    }

    public double[] getPMFArray() {
        Collection<Double> set = this.dist.values();
        double[] pmf = new double[set.size()];
        int i = 0;
        Iterator<Double> iterator = set.iterator();
        while (iterator.hasNext()) {
            double d;
            pmf[i] = d = iterator.next().doubleValue();
            ++i;
        }
        return pmf;
    }

    public State[] probDistToStatesArray(String[] inputNames) {
        if (inputNames.length != this.sts.length) {
            System.out.println("Error: The size of the (prior) input domain does not match the channel matrix.");
            System.out.println("  the input domain size: " + this.sts.length);
            System.out.println("  the number of rows in the channel matrix: " + inputNames.length);
            System.out.println("Failed to produce an array of states from the probability distribution.");
            System.exit(1);
        }
        State[] result = new State[inputNames.length];
        int i = 0;
        while (i < inputNames.length) {
            boolean found = false;
            for (State st : this.dist.keySet()) {
                if (!inputNames[i].equals(st.getValue("input"))) continue;
                result[i] = st;
                found = true;
                break;
            }
            if (!found) {
                System.out.println("Error: A label of the (prior) input domain is duplicated or missing in the channel matrix.");
                System.out.println("  Input domain: ");
                for (State st : this.dist.keySet()) {
                    System.out.print("   ");
                    st.printState();
                }
                System.out.println("  Labels of the channel matrix: ");
                System.out.print("    {");
                String[] stringArray = inputNames;
                int n = inputNames.length;
                int n2 = 0;
                while (n2 < n) {
                    String s = stringArray[n2];
                    System.out.print(" " + s);
                    ++n2;
                }
                System.out.println(" }");
                return null;
            }
            ++i;
        }
        return result;
    }

    public double[] probDistToPMFArray(String[] inputNames) {
        if (inputNames.length < this.dist.size()) {
            System.out.println("Error: The size of the (prior) input domain is larger than the channel matrix.");
            System.out.println("  the input domain size: " + this.dist.size());
            System.out.println("  the number of rows in the channel matrix: " + inputNames.length);
            System.out.println("Failed to produce an array of probabilities from the probability distribution.");
            System.out.println("  See ProbDist.probDistToPMFArray(String[] inputNames) for debug.");
            System.exit(1);
        }
        double[] result = new double[inputNames.length];
        int i = 0;
        while (i < inputNames.length) {
            boolean found = false;
            for (State st : this.dist.keySet()) {
                if (!inputNames[i].equals(st.getValue("input"))) continue;
                result[i] = this.dist.get(st);
                found = true;
                break;
            }
            if (!found) {
                result[i] = 0.0;
            }
            ++i;
        }
        double sum1 = 0.0;
        double sum2 = 0.0;
        double[] dArray = result;
        int n = result.length;
        int n2 = 0;
        while (n2 < n) {
            double d = dArray[n2];
            sum1 += d;
            ++n2;
        }
        for (State st : this.dist.keySet()) {
            sum2 += this.dist.get(st).doubleValue();
        }
        if (Math.abs(sum1 - sum2) > 1.0E-10) {
            System.out.println("Error: The states of the given prior and channel are different.");
            System.out.println("  Some label of the (prior) input domain is missing in the channel matrix.");
            System.out.println("  Input domain (size = " + (int)sum1 + "): ");
            for (State st : this.dist.keySet()) {
                System.out.print("   ");
                st.printState();
            }
            System.out.println("  Labels of the channel matrix (size = " + (int)sum2 + "): ");
            System.out.print("    {");
            int i2 = 0;
            while (i2 < inputNames.length) {
                System.out.print(" " + inputNames[i2]);
                if (i2 != inputNames.length - 1) {
                    System.out.print(", ");
                }
                ++i2;
            }
            System.out.println(" }");
            System.out.println("  See ProbDist.probDistToPMFArray(String[] inputNames) for debug.");
            System.exit(1);
        }
        return result;
    }

    private ArrayList<HashSet<String>> getSampleSpace() {
        ProbDist[] marginals = this.getAllMarginals();
        int numJoint = marginals.length;
        ArrayList<HashSet<String>> sampleSpace = new ArrayList<HashSet<String>>();
        int i = 0;
        while (i < numJoint) {
            Collection<State> sts = marginals[i].getStatesCollection();
            HashSet<String> hsst = new HashSet<String>();
            for (State st : sts) {
                hsst.add(st.getValue("input"));
            }
            sampleSpace.add(hsst);
            ++i;
        }
        return sampleSpace;
    }

    public int getNumJoint() {
        if (this.numJoint > 1) {
            return this.numJoint;
        }
        try {
            Iterator<State> iterator = this.dist.keySet().iterator();
            if (iterator.hasNext()) {
                State st = iterator.next();
                String lineInput = st.getValue("input");
                String[] input = lineInput.split(",", 0);
                this.numJoint = input.length;
            }
        }
        catch (Exception ex0) {
            System.out.println("Error in reading elements of an input." + ex0);
            System.out.println("  The file does not follow a prior file (-prior) format.");
            System.exit(1);
        }
        return this.numJoint;
    }

    public boolean consistentChannelsAndPrior(int numChannels) {
        if (numChannels != this.getNumJoint()) {
            System.out.println("Error: The number of channels (" + numChannels + ") does not match with the size of the (prior) input distribution (" + this.getNumJoint() + ").");
            System.exit(1);
        }
        return true;
    }

    private int stringToID(int num, String outcome, ArrayList<HashSet<String>> sampleSpace) {
        int index = 0;
        for (String str : sampleSpace.get(num)) {
            if (outcome.equals(str)) break;
            ++index;
        }
        return index;
    }

    private String IDToString(int num, int id, ArrayList<HashSet<String>> sampleSpace) {
        int index = 0;
        for (String str : sampleSpace.get(num)) {
            if (index == id) {
                return str;
            }
            ++index;
        }
        return "";
    }

    public boolean isJointlySupported() {
        if (this.allMarginals == null) {
            this.getAllMarginals();
        }
        ArrayList<HashSet<String>> sampleSpace = this.getSampleSpace();
        int sizeSampleSpace = sampleSpace.size();
        int jointNum = this.getNumJoint();
        int[] sizeSpace = new int[sizeSampleSpace];
        int i = 0;
        while (i < sizeSampleSpace) {
            sizeSpace[i] = sampleSpace.get(i).size();
            ++i;
        }
        int[] bases = new int[sizeSampleSpace + 1];
        bases[0] = 1;
        int i2 = 1;
        while (i2 <= sizeSampleSpace) {
            bases[i2] = bases[i2 - 1] * sizeSpace[i2 - 1];
            ++i2;
        }
        HashMap<Integer, Double> idmap = new HashMap<Integer, Double>();
        for (State st : this.dist.keySet()) {
            try {
                int id = 0;
                int i3 = 0;
                while (i3 < sizeSampleSpace) {
                    String outcome = st.getValue("input" + i3);
                    id += bases[i3] * this.stringToID(i3, outcome, sampleSpace);
                    ++i3;
                }
                double prob = this.dist.get(st);
                idmap.put(id, prob);
            }
            catch (Exception ex) {
                System.out.println("Error in reading elements of an input to calculate a projection of a state: " + ex);
                ex.printStackTrace();
                System.out.println("  The file does not follow a prior file (-prior) format.");
                System.exit(1);
            }
        }
        int maxID = bases[sizeSampleSpace];
        int id = 0;
        while (id < maxID) {
            if (idmap.containsKey(id) && (Double)idmap.get(id) != 0.0) {
                if (verbose > 5) {
                    System.out.println("  The sample with id = " + id + " exists in the channel matrix.");
                }
            } else {
                int[] marginalsID = new int[sizeSampleSpace];
                String[] outcomes = new String[sizeSampleSpace];
                int remainder = id;
                int i4 = jointNum - 1;
                while (i4 >= 0) {
                    marginalsID[i4] = remainder / bases[i4];
                    outcomes[i4] = this.IDToString(i4, marginalsID[i4], sampleSpace);
                    remainder %= bases[i4];
                    --i4;
                }
                boolean result = false;
                int i5 = 0;
                while (i5 < sizeSampleSpace) {
                    if (this.allMarginals[i5].getProb(outcomes[i5]) == 0.0) {
                        result = true;
                        break;
                    }
                    ++i5;
                }
                if (!result) {
                    if (idmap.containsKey(id)) {
                        System.out.println("The joint input distribution is: ");
                        this.printProbDist();
                        System.out.println("while the marginals are:");
                    } else {
                        System.out.print("There is a sample with the zero joint probability: [");
                        String[] stringArray = outcomes;
                        int n = outcomes.length;
                        int n2 = 0;
                        while (n2 < n) {
                            String outcome = stringArray[n2];
                            System.out.print(" " + outcome);
                            ++n2;
                        }
                        System.out.println(" ] with id = " + id + ", while the marginals are:");
                    }
                    i = 0;
                    while (i < sizeSampleSpace) {
                        System.out.println("  marginal_" + i + "(" + outcomes[i] + ") = " + this.allMarginals[i].getProb(outcomes[i]));
                        ++i;
                    }
                    System.out.println("Hence the joint input distribution is not jointly supported.");
                    return result;
                }
            }
            ++id;
        }
        return true;
    }

    public ProbDist getMarginal(int num) {
        int jointNum = this.getNumJoint();
        if (num >= jointNum) {
            System.out.println("Error: " + num + "-th marginal probability is not defined.");
            System.out.println("  jointNum = " + jointNum);
            System.exit(1);
        } else if (jointNum == 1) {
            for (State jst : this.dist.keySet()) {
                String str = jst.getValue("input").replaceAll("\\(", " ").replaceAll("\\)", "").trim();
                jst.updateValue("input0", str);
            }
            return this;
        }
        if (this.allMarginals != null && this.allMarginals[num] != null) {
            return this.allMarginals[num];
        }
        HashMap<State, Double> marginalProb = new HashMap<State, Double>();
        for (State jst : this.dist.keySet()) {
            String outcomeTuple;
            Pattern pattern = Pattern.compile("\\((.+?)\\)");
            Matcher matcher = pattern.matcher(outcomeTuple = jst.getValue("input"));
            if (matcher.find()) {
                String[] outcomes = matcher.group(1).split(",", 0);
                if (num < outcomes.length) {
                    State st = new State();
                    boolean found = false;
                    for (State key : marginalProb.keySet()) {
                        if (key.getValue("input") == null || !key.getValue("input").equals(outcomes[num].trim())) continue;
                        found = true;
                        st = key;
                        break;
                    }
                    if (found) {
                        double prob = marginalProb.get(st) + this.dist.get(jst);
                        marginalProb.put(st, prob);
                    } else {
                        st.updateValue("input", outcomes[num].trim());
                        marginalProb.put(st, this.dist.get(jst));
                    }
                    jst.updateValue("input" + num, outcomes[num].trim());
                    continue;
                }
                System.out.println("Error in reading elements of an input to calculate a marginal.");
                System.out.println("  The file does not follow a prior file (-prior) format.");
                System.exit(1);
                continue;
            }
            System.out.println("Error in reading elements of an input to calculate a marginal.");
            System.out.println("  The file does not follow a prior file (-prior) format.");
            System.exit(1);
        }
        return new ProbDist(marginalProb, true);
    }

    public String getProjectedState(State jst, int numElements) {
        String outcomeTuple = jst.getValue("input");
        if (this.getNumJoint() == 1) {
            return outcomeTuple.trim();
        }
        Pattern pattern = Pattern.compile("\\((.+?)\\)");
        Matcher matcher = pattern.matcher(outcomeTuple);
        try {
            String[] outcomes;
            if (matcher.find() && numElements < (outcomes = matcher.group(1).split(",", 0)).length) {
                return outcomes[numElements].trim();
            }
        }
        catch (Exception ex) {
            System.out.println("Error in reading elements of an input to calculate a projection of a state.");
            System.out.println("  The file does not follow a prior file (-prior) format.");
            System.exit(1);
        }
        System.out.println("Error in reading elements of an input to calculate a projection of a state.");
        System.out.println("  The file does not follow a prior file (-prior) format.");
        System.exit(1);
        return "";
    }

    public ProbDist[] getAllMarginals() {
        int jointNum = this.getNumJoint();
        if (this.allMarginals != null) {
            return this.allMarginals;
        }
        ProbDist[] pds = new ProbDist[jointNum];
        int i = 0;
        while (i < jointNum) {
            pds[i] = this.getMarginal(i);
            if (verbose > 5) {
                System.out.println("Marginal input distribution (" + i + ")");
                pds[i].printProbDist();
            }
            ++i;
        }
        this.allMarginals = pds;
        return pds;
    }

    public void printProb(State st) {
        System.out.println("Pr" + st.stringState() + " = " + this.getProb(st));
    }

    public void printProbDist() {
        System.out.println("{");
        for (State st : this.dist.keySet()) {
            if (st == null) continue;
            System.out.println("  " + st.stringState() + " = " + this.dist.get(st));
        }
        System.out.println("}");
        if (verbose >= 5) {
            double sum = 0.0;
            for (State st : this.dist.keySet()) {
                if (st == null) continue;
                sum += this.getProb(st);
            }
            System.out.println("Sum of probabilities: " + sum);
        }
    }

    public void printProbDist(File file) {
        try {
            PrintWriter pw = new PrintWriter(new BufferedWriter(new FileWriter(file)));
            for (State st : this.dist.keySet()) {
                if (st == null) continue;
                String str = st.getValue("input");
                pw.println("(" + this.dist.get(st) + ", " + str + ")");
            }
            pw.close();
        }
        catch (FileNotFoundException ex) {
            System.out.println(" file not found " + ex);
        }
        catch (Exception ex) {
            System.out.println(" error " + ex);
        }
        if (verbose >= 5) {
            double sum = 0.0;
            for (State st : this.dist.keySet()) {
                if (st == null) continue;
                sum += this.getProb(st);
            }
            System.out.println("Sum of probabilities: " + sum);
        }
    }
}

