""" Simple unit tests for particle_demo.py

Author: Nathan Sprague
Version: 10/2020
"""

import unittest
import particle_demo
from particle_demo import Particle


class TestFilter(unittest.TestCase):
    """Unit tests for the particle-filter helpers in particle_demo.

    Particles are constructed as ``Particle(weight, x)``; the tests exercise
    ``normalize_particles``, ``calc_probability`` and ``resample``.
    """

    def test_normalize_one_particle_correct_weight(self):
        """A lone particle normalizes to weight 1 and keeps its state."""
        particles = [Particle(10.0, 3)]

        particle_demo.normalize_particles(particles)
        self.assertAlmostEqual(1.0, particles[0].weight)
        self.assertEqual(3, particles[0].x)

        # Normalizing must also scale weights *up* when the total is < 1.
        particles[0].weight = .1
        particle_demo.normalize_particles(particles)
        self.assertAlmostEqual(1.0, particles[0].weight)

    def test_normalize_correct_weights(self):
        """Each weight is divided by the total weight (2.5 + 7 + .5 == 10)."""
        particles = [Particle(2.5, 3), Particle(7, 3), Particle(.5, 3)]

        particle_demo.normalize_particles(particles)
        self.assertAlmostEqual(.25, particles[0].weight)
        self.assertAlmostEqual(.7, particles[1].weight)
        self.assertAlmostEqual(.05, particles[2].weight)

    def test_calc_probability(self):
        """A state's probability is the summed weight of particles in it."""
        particles = [Particle(.25, 3), Particle(.7, 3), Particle(.05, 1)]

        self.assertAlmostEqual(.95,
                               particle_demo.calc_probability(particles, 3))
        self.assertAlmostEqual(.05,
                               particle_demo.calc_probability(particles, 1))
        # A state that no particle occupies has probability exactly zero.
        self.assertEqual(0, particle_demo.calc_probability(particles, 2))

    def test_resample_single_particle_creates_copy_not_alias(self):
        """resample should return new Particle objects, not the originals."""
        particle = Particle(1., 3)
        particles = particle_demo.resample([particle])
        self.assertIsNot(particle, particles[0])
        # The copy must still carry the original particle's state.
        self.assertEqual(3, particles[0].x)

    def test_resample_ignores_weight_zero_particles(self):
        """Zero-weight particles are never drawn; new weights are 1/N."""
        particles = [Particle(0, 3), Particle(0, 3), Particle(1, 1)]
        num_particles = len(particles)
        particles = particle_demo.resample(particles)

        # Without this check an empty result would pass the loop vacuously.
        self.assertEqual(num_particles, len(particles))
        for p in particles:
            self.assertEqual(1, p.x)
            self.assertAlmostEqual(1. / num_particles, p.weight)

    def test_resampled_particles_are_normalized(self):
        """The weights of the resampled set sum to one."""
        particles = [Particle(.25, 3), Particle(.7, 3), Particle(.05, 1)]
        particles = particle_demo.resample(particles)

        total = sum(p.weight for p in particles)
        self.assertAlmostEqual(1, total)

    def test_resampling_is_proportional(self):
        """Over many trials, states are drawn in proportion to their weight.

        This is a statistical check (resample is presumably stochastic); with
        ``trials * len(particles)`` = 10,000 draws the .05 tolerance is
        generous, so spurious failures should be rare.
        """
        # Make 10 particles covering three different states.
        particles = [Particle(.1 / 8, 0) for _ in range(8)]  # 10% 0's
        particles.append(Particle(.2, 1))                     # 20% 1's
        particles.append(Particle(.7, 2))                     # 70% 2's
        expected = {0: .1, 1: .2, 2: .7}

        # Only the known states are pre-seeded, so an unexpected state
        # returned by resample raises KeyError instead of being ignored.
        counters = {state: 0 for state in expected}
        trials = 1000

        for _ in range(trials):
            new_particles = particle_demo.resample(particles)
            for p in new_particles:
                counters[p.x] += 1

        # Normalize by the real number of draws rather than a hard-coded 10;
        # float() keeps the division true even under integer semantics.
        total_draws = float(trials * len(particles))
        for state, fraction in expected.items():
            with self.subTest(state=state):
                self.assertAlmostEqual(fraction,
                                       counters[state] / total_draws,
                                       delta=.05)

if __name__ == '__main__':
    # Executing this file directly runs the TestCase classes it defines.
    unittest.main()
