/*
 * Created on Oct 7, 2003
 *
 */
package edu.mit.six825.bn.bayesnet;

import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

import edu.mit.six825.bn.functiontable.Assignment;
import edu.mit.six825.bn.functiontable.ComparableBoolean;
import edu.mit.six825.bn.functiontable.Compute;
import edu.mit.six825.bn.functiontable.DomainIndex;
import edu.mit.six825.bn.functiontable.Function;
import edu.mit.six825.bn.functiontable.FunctionVariable;
import edu.mit.six825.bn.functiontable.FunctionVariableSet;

/**
 * An implementation of the most basic exact-inference solver for Bayes nets:
 * inference by enumeration. The algorithm follows the pseudo-code for
 * ENUMERATION-ASK / ENUMERATE-ALL in AIMA, 2nd Ed., pg. 506.
 * <p>
 * While answering a query, a trace of the enumeration tree is printed to
 * {@code System.out}.
 *
 * @author vineet, drayside
 */
public class EnumerationSolver extends Solver {

	/**
	 * Creates a solver for the given Bayes net.
	 *
	 * @param _bn the network to answer queries against (handed to the
	 *            {@link Solver} constructor)
	 */
	public EnumerationSolver(BayesNet _bn) {
		super(_bn);
	}

	/**
	 * Creates a solver without a network. Presumably the network must then be
	 * supplied through {@code setBayesNet} before querying, as {@link #main}
	 * does.
	 */
	public EnumerationSolver() {
		super();
	}

	/**
	 * Computes the posterior distribution of {@code node} given the current
	 * evidence ({@code _evidence}).
	 * <p>
	 * For every value in the node's domain, the evidence is extended with
	 * node=value. The joint probability is then summed over all network nodes,
	 * visited in topological order. The resulting vector is passed through
	 * {@link Compute#normalize}.
	 *
	 * @param node the query variable
	 * @return the normalized distribution over {@code node}'s domain
	 */
	private Function enumerationAsk(final BayesNetNode node) {
		final double[] retVal = new double[node.var.domain.size()];
		// The ordering does not depend on the value being tried, and
		// enumerateAll never modifies it, so fetch it once.
		final List ordering = _bn.nodes.getNodesWithTopologicalOrdering();

		// iterate over the domain of the variable, trying each value in the domain
		for (final Iterator i = node.var.domain.iterator(); i.hasNext(); ) {
			final DomainIndex index = (DomainIndex) i.next();
			final Comparable value = index.getValue();
			final Assignment valEvidence = new Assignment(_evidence, node.var, value);
			System.out.println(node + "=" + value);
			retVal[index.i] = enumerateAll("  ", ordering, valEvidence);
		}

		return Compute.normalize(new Function(node.var, retVal));
	}

	/**
	 * Recursively computes the (unnormalized) probability mass for the
	 * variables in {@code varList} under {@code currEvidence} (ENUMERATE-ALL).
	 * <p>
	 * The first variable is handled in one of two ways:
	 * <ul>
	 * <li>If it has a value in {@code currEvidence}, its CPT entry is
	 * multiplied in.</li>
	 * <li>Otherwise it is summed out over its whole domain.</li>
	 * </ul>
	 *
	 * @param h            indentation prefix for the printed trace; it grows
	 *                     by two spaces per recursion level
	 * @param varList      remaining nodes, first node first; not modified
	 * @param currEvidence the assignment accumulated so far
	 * @return the probability mass of this branch, or 1.0 when
	 *         {@code varList} is empty
	 */
	private double enumerateAll(
		final String h,
		final List varList,
		final Assignment currEvidence) {
		if (varList.isEmpty()) {
			return 1.0;
		}
		final BayesNetNode first = (BayesNetNode) varList.get(0);
		// Copy instead of casting to LinkedList and cloning, because the
		// caller's list need not be a LinkedList. The copy is needed because
		// sibling branches reuse varList, so it must stay intact.
		final LinkedList restVars = new LinkedList(varList);
		restVars.removeFirst();

		final Comparable evidenceValue = currEvidence.getAssignedValue(first.var);

		if (evidenceValue != null) {
			// 'first' already has a value: multiply in its CPT entry.
			System.out.println(h + "P(" + first + "=" + evidenceValue + ")");
			return first.cpt.evaluate(currEvidence)
				* enumerateAll(h + "  ", restVars, currEvidence);
		}

		// 'first' is unassigned: sum it out over every value in its domain.
		double sum = 0;
		System.out.println(h + "E [" + first + "]");
		for (final Iterator i = first.var.domain.iterator(); i.hasNext(); ) {
			final DomainIndex index = (DomainIndex) i.next();
			final Comparable value = index.getValue();
			final Assignment valEvidence = new Assignment(currEvidence, first.var, value);
			System.out.println(h + "|-" + value);
			sum += first.cpt.evaluate(valEvidence)
				* enumerateAll(h + "  ", restVars, valEvidence);
		}
		return sum;
	}

	public String toString() {
		return "EnumerationSolver";
	}

	/**
	 * Demo: computes P(Burglary | JohnCalls=true, MaryCalls=true) on the
	 * burglary network and prints it next to the value quoted from AIMA
	 * (0.284).
	 */
	public static void main(String[] args) {
		System.out.println("Prob(Burglary|JohnCalls=true, MaryCalls=true)");
		System.out.println("Burglary=TRUE AIMA: " + 0.284);

		final BayesNet bn = edu.mit.six825.bn.inputs.Nets.getBurglary();
		final Solver solver = new EnumerationSolver();
		solver.setBayesNet(bn);
		// Other solvers can be swapped in here, e.g.
		// GibbsSamplerSolver, LikelihoodWeightingSolver or VariableEliminationSolver.

		final FunctionVariable[] vars = new FunctionVariable[2];
		vars[0] = new FunctionVariable("JohnCalls");
		vars[1] = new FunctionVariable("MaryCalls");
		final Comparable[] vals = new Comparable[2];
		vals[0] = ComparableBoolean.TRUE;
		vals[1] = ComparableBoolean.TRUE;
		final Assignment evidence = new Assignment(new FunctionVariableSet(vars), vals);
		solver.setEvidence(evidence);

		final BayesNetNode burgVar = bn.nodes.getNode("Burglary");
		final Function burgProb = solver.query(burgVar);
		System.out.println("Burglary=TRUE Calc.: " + burgProb);
	}

	/**
	 * Answers a query by enumeration.
	 *
	 * @param variable the query variable
	 * @return the normalized distribution of {@code variable} given the
	 *         current evidence
	 */
	public Function query(BayesNetNode variable) {
		return enumerationAsk(variable);
	}

}
