#!/usr/bin/env python
# Test cases for problem set 6
import unittest

# Change this import, and the ps6.ProblemSet6() call below, to match the
# module name of your code file (it must be a valid Python module name, so
# it cannot contain dashes).
import ps6
import nhpn

# One shared ProblemSet6 instance, created once when this module is imported;
# each test's setUp binds it to self.ps.
ps = ps6.ProblemSet6()

# See the code template for what your interface should look like.

class TestPS6(unittest.TestCase):
    """Tests for the ProblemSet6 interface: node lookup, distance, and
    Dijkstra shortest-path search."""

    def setUp(self):
        # Bind the shared ProblemSet6 instance and the node/link lists its
        # loader provides.
        self.ps = ps
        self.nodes = self.ps.loader.nodes()
        self.links = self.ps.loader.links()

    def test1(self):
        # The first node in MA whose description contains "cambridge"
        ans = self.ps.node_by_name(self.nodes, 'CAMBRIDGE', 'MA')
        self.assertEqual(ans.state, 'MA')
        self.assertEqual('NORTH CAMBRIDGE', ans.description)

    def test2(self):
        # The distance between two nodes
        ans = self.ps.lat_long_length(self.nodes[0], self.nodes[1])
        self.assertAlmostEqual(105747.58, ans, 2)

    def test3(self):
        # The length of an actual edge
        ans = self.ps.lat_long_length(self.links[0].begin, self.links[0].end)
        self.assertAlmostEqual(2440, ans, 0)

    def verifyPath(self, path, edges, src, dest):
        """Verify that path is a valid path from src to dest.

        A path is valid if it is non-empty, starts at src, ends at dest, and
        every pair of adjacent nodes is joined by some edge in edges. Edges
        are matched in either direction (begin/end may be swapped).
        Fails the current test (via self.fail / assert*) otherwise.
        """
        # Fail cleanly instead of raising IndexError/TypeError on [] or None.
        self.assertTrue(path, 'Expected a non-empty path, got %r' % (path,))
        self.assertEqual(src, path[0])
        self.assertEqual(dest, path[-1])
        for a, b in zip(path, path[1:]):
            # NOTE: linear scan over all edges for each step; inefficient, but
            # it only needs == on nodes, not hashing.
            if not any((e.begin == a and e.end == b) or
                       (e.begin == b and e.end == a)
                       for e in edges):
                self.fail('Adjacent nodes in path have no edge between them')

    def sumPath(self, path, weight):
        """Return the sum of weight(u, v) over each adjacent pair (u, v) in
        path (0 for a path with fewer than two nodes).
        Requires path to be valid (see verifyPath)."""
        return sum(weight(a, b) for a, b in zip(path, path[1:]))

    def testCLRS(self):
        # Run shortest paths on the example graph from CLRS, by constructing
        # a fake data set and weight function.
        V = [nhpn.Node(i, i, '', '') for i in range(5)]
        # W[u][v] is the weight of the directed edge u -> v.
        W = [
            {1: 10, 2: 5},          # out of node 0
            {2: 2, 3: 1},           # out of node 1
            {1: 3, 3: 9, 4: 2},     # out of node 2
            {4: 4},                 # out of node 3
            {0: 7, 3: 6},           # out of node 4
        ]
        # The link set is undirected: one link per unordered pair {i, j}
        # that has a weight in at least one direction.
        E = []
        for i in range(5):
            for j in range(i + 1, 5):
                if j in W[i] or i in W[j]:
                    E.append(nhpn.Link(V[i], V[j], ''))

        def eweight(n1, n2):
            # Each node's longitude holds its index i (nodes are built as
            # nhpn.Node(i, i, ...) above). Traversing a link against a
            # direction with no weight costs 1000, far more than the sum of
            # all real weights, so such moves are effectively never optimal.
            return W[n1.longitude].get(n2.longitude, 1000)

        ans3 = self.ps.dijkstra_search(V, E, eweight, V[0], V[3])
        self.verifyPath(ans3, E, V[0], V[3])
        # Shortest 0 -> 3 is 0 -> 2 -> 1 -> 3 with weight 5 + 3 + 1 = 9.
        self.assertEqual(9, self.sumPath(ans3, eweight))

    def testAL(self):
        # Path between the first two towns in the data set.
        # The shortest path has 1 node in the middle.
        src = self.ps.node_by_name(self.nodes, 'BRIDGEPORT', 'AL')
        dest = self.ps.node_by_name(self.nodes, 'STEVENSON', 'AL')
        ans = self.ps.dijkstra_search(self.nodes,
                                      self.links,
                                      self.ps.lat_long_length,
                                      src, dest)
        self.assertEqual(3, len(ans))
        self.verifyPath(ans, self.links, src, dest)

    def testLonger(self):
        # Path between two arbitrarily chosen nodes.
        src = self.nodes[0]
        dest = self.nodes[100]
        ans = self.ps.dijkstra_search(self.nodes,
                                      self.links,
                                      self.ps.lat_long_length,
                                      src, dest)
        self.verifyPath(ans, self.links, src, dest)
        self.assertAlmostEqual(2076299,
                               self.sumPath(ans, self.ps.lat_long_length),
                               0)

if __name__ == '__main__':
    import sys
    # Run only the TestPS6 cases, printing one line per test.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestPS6)
    result = unittest.TextTestRunner(verbosity=2).run(suite)
    # Exit non-zero when any test fails or errors, so callers (graders,
    # CI, shell scripts) can detect failure from the exit status.
    sys.exit(0 if result.wasSuccessful() else 1)
