import static org.junit.jupiter.api.Assertions.*;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.NoSuchElementException;
import org.junit.jupiter.api.Test;

/**
 * Unit tests for the Counter Multiset class.
 *
 * <p>Tests do not assume any particular iteration order (the toString tests likewise accept any
 * ordering of distinct elements); iterator tests compare element counts instead of sequences.
 *
 * @author CS 240 Instructors
 * @version 07/2025
 *
 */
public class CounterMultisetTest {

  /**
   * Creates the implementation under test. Every test goes through this factory so the concrete
   * class is named in exactly one place.
   *
   * @param capacity capacity passed to the {@code CounterMultiset} constructor
   * @return a new, empty multiset
   */
  private <T> Multiset<T> makeMultiset(int capacity) {
    return new CounterMultiset<T>(capacity);
  }

  // --------------------------------------------
  // TESTS FOR SIZE
  // --------------------------------------------
  @Test
  public void testEmptySizeZero() {
    Multiset<String> set = makeMultiset(100);
    assertEquals(0, set.size());
  }

  @Test
  public void testSizeDuplicates() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("a");
    set.add("b");
    assertEquals(3, set.size());
  }


  @Test
  public void testSizeNulls() {
    Multiset<String> set = makeMultiset(100);
    set.add(null);
    set.add("a");
    set.add("b");
    assertEquals(3, set.size());
  }

  @Test
  public void testSizeNoDuplicates() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("b");
    set.add("c");
    assertEquals(3, set.size());
  }

  @Test
  public void testSizeRemoveDuplicate() {
    Multiset<String> set = makeMultiset(100);

    set.add("a");
    set.add("a");
    set.add("b");
    set.remove("a");
    assertEquals(2, set.size());
  }

  @Test
  public void testSizeRemoveNull() {
    Multiset<String> set = makeMultiset(100);

    set.add(null);
    set.add(null);
    set.add("b");
    set.remove(null);
    assertEquals(2, set.size());
  }

  @Test
  public void testSizeRemoveNonDuplicate() {
    Multiset<String> set = makeMultiset(100);

    set.add("a");
    set.add("b");
    set.add("c");
    set.remove("a");
    assertEquals(2, set.size());
  }

  // --------------------------------------------
  // TESTS FOR MULTIPLE ADD
  // --------------------------------------------
  @Test
  public void testAddZeroDoesntChangeSize() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("a", 0);
    assertEquals(1, set.size());
  }


  @Test
  public void testAddMultipleNulls() {
    Multiset<String> set = makeMultiset(100);
    set.add("a", 4);
    set.add(null, 3);
    assertEquals(4, set.getCount("a"));
    assertEquals(3, set.getCount(null));
  }

  @Test
  public void testAddMultipleAddsCorrectNumber() {
    Multiset<String> set = makeMultiset(100);
    set.add("a", 4);
    set.add("b", 3);
    assertEquals(4, set.getCount("a"));
    assertEquals(3, set.getCount("b"));
    set.add("b", 10);
    assertEquals(4, set.getCount("a"));
    assertEquals(13, set.getCount("b"));
    assertEquals(17, set.size());
  }

  @Test
  public void testAddMultipleCorrectReturnValue() {
    // Capacity 3: a fourth *distinct* element is rejected, but adding more copies of an
    // existing element still succeeds.
    Multiset<String> set = makeMultiset(3);
    assertTrue(set.add("a", 20));
    assertTrue(set.add("b", 20));
    assertTrue(set.add("c", 20));
    assertTrue(set.add("a", 20));
    assertFalse(set.add("d", 20));
  }

  @Test
  public void testNegativeAdd() {
    Multiset<String> set = makeMultiset(100);
    assertThrows(IllegalArgumentException.class, () -> set.add("A", -1));
  }

  // --------------------------------------------
  // TESTS FOR CONTAINS
  // --------------------------------------------

  @Test
  public void testContainsSetEmpty() {
    Multiset<String> set = makeMultiset(100);

    assertFalse(set.contains("a"));
  }

  @Test
  public void testContainsHandlesNulls() {
    Multiset<String> set = makeMultiset(100);

    assertFalse(set.contains(null));
    set.add(null);
    assertTrue(set.contains(null));
  }

  @Test
  public void testContainsTrueOneElement() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    assertTrue(set.contains("a"));
  }

  @Test
  public void testContainsFalseOneElement() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    assertFalse(set.contains("tree"));
    assertFalse(set.contains("house"));
  }

  @Test
  public void testContainsMultipleElementsNoDuplicates() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("b");
    set.add("c");
    set.add("d");

    assertTrue(set.contains("a"));
    assertTrue(set.contains("b"));
    assertTrue(set.contains("c"));
    assertTrue(set.contains("d"));
    assertFalse(set.contains("q"));
  }

  @Test
  public void testContainsMultipleElementsWithDuplicates() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("a");
    set.add("b");
    set.add("b");
    set.add("c");
    set.add("c");

    assertTrue(set.contains("a"));
    assertTrue(set.contains("b"));
    assertTrue(set.contains("c"));
    assertFalse(set.contains("q"));
  }

  @Test
  public void testContainsAfterRemovalNoDuplicates() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("b");
    set.add("c");
    set.add("d");
    set.add("e");
    set.remove("a");
    set.remove("c");
    set.remove("e");

    assertFalse(set.contains("a"));
    assertFalse(set.contains("c"));
    assertFalse(set.contains("e"));
    assertTrue(set.contains("b"));
    assertTrue(set.contains("d"));
  }

  @Test
  public void testContainsAfterRemovalWithDuplicates() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("a");
    set.add("c");
    set.add("c");
    set.add("e");
    set.add("e");

    // One copy of each remains after the first round of removals...
    set.remove("a");
    set.remove("c");
    set.remove("e");
    assertTrue(set.contains("a"));
    assertTrue(set.contains("c"));
    assertTrue(set.contains("e"));

    // ...and none after the second.
    set.remove("a");
    set.remove("c");
    set.remove("e");
    assertFalse(set.contains("a"));
    assertFalse(set.contains("c"));
    assertFalse(set.contains("e"));

  }

  @SuppressWarnings("unlikely-arg-type")
  @Test
  public void testContainsObjectArgument() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    Integer integer = Integer.valueOf(3);
    assertFalse(set.contains(integer));
  }

  // --------------------------------------------
  // TESTS FOR COUNT
  // --------------------------------------------
  @Test
  public void testCountSetEmpty() {
    Multiset<String> set = makeMultiset(100);

    assertEquals(0, set.getCount("A"));
  }

  @Test
  public void testCountZeroNonEmpty() {
    Multiset<String> set = makeMultiset(100);

    set.add("B");
    assertEquals(0, set.getCount("A"));
  }

  @Test
  public void testCountOneOneElement() {
    Multiset<String> set = makeMultiset(100);

    set.add("C");
    assertEquals(1, set.getCount("C"));
  }

  @Test
  public void testCountMultipleElementsNoDuplicates() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("b");
    set.add("c");
    set.add("d");

    assertEquals(1, set.getCount("a"));
    assertEquals(1, set.getCount("b"));
    assertEquals(1, set.getCount("c"));
    assertEquals(1, set.getCount("d"));
    assertEquals(0, set.getCount("q"));
  }

  @Test
  public void testCountMultipleElementsWithDuplicates() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("a");
    set.add("b");
    set.add("b");
    set.add("c");
    set.add("c");

    assertEquals(2, set.getCount("a"));
    assertEquals(2, set.getCount("b"));
    assertEquals(2, set.getCount("c"));
    assertEquals(0, set.getCount("q"));
  }

  @Test
  public void testCountWithNulls() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("a");
    set.add(null);
    set.add(null);
    set.add("c");
    set.add("c");

    assertEquals(2, set.getCount("a"));
    assertEquals(2, set.getCount(null));
    assertEquals(2, set.getCount("c"));
    assertEquals(0, set.getCount("q"));
  }

  @Test
  public void testCountAfterRemovalNoDuplicates() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("b");
    set.add("c");
    set.add("d");
    set.add("e");
    set.remove("a");
    set.remove("c");
    set.remove("e");

    assertEquals(0, set.getCount("a"));
    assertEquals(1, set.getCount("b"));
    assertEquals(0, set.getCount("c"));
    assertEquals(1, set.getCount("d"));
    assertEquals(0, set.getCount("e"));
  }

  @Test
  public void testCountAfterRemovalWithDuplicates() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("a");
    set.add("c");
    set.add("c");
    set.add("e");
    set.add("e");

    set.remove("a");
    set.remove("c");
    set.remove("e");

    assertEquals(1, set.getCount("a"));
    assertEquals(1, set.getCount("c"));
    assertEquals(1, set.getCount("e"));
  }

  @Test
  public void testCountAfterRemovalWithDuplicatesAndNulls() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("a");
    set.add(null);
    set.add(null);
    set.add("c");
    set.add("c");

    set.remove("a");
    set.remove(null);
    set.remove("c");

    assertEquals(1, set.getCount("a"));
    assertEquals(1, set.getCount(null));
    assertEquals(1, set.getCount("c"));
  }

  @Test
  public void testCountObjectArgument() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    Integer integer = Integer.valueOf(3);
    assertEquals(0, set.getCount(integer));

  }

  // --------------------------------------------
  // TESTS FOR EQUALS
  // --------------------------------------------

  @Test
  public void testEqualsNonMultiset() {
    Multiset<String> set = makeMultiset(100);

    assertFalse(set.equals("HELLO"));
  }

  @Test
  public void testEqualsNull() {
    // Object.equals contract: x.equals(null) must be false for any non-null x.
    Multiset<String> set = makeMultiset(100);

    assertFalse(set.equals(null));
  }

  @Test
  public void testEqualsSameMultiset() {
    Multiset<String> set = makeMultiset(100);
    set.add("Z");
    assertTrue(set.equals(set));
  }

  @Test
  public void testEqualsEmptyMultisetsAfterRemoval() {
    Multiset<String> set1 = makeMultiset(100);
    Multiset<String> set2 = makeMultiset(100);
    set1.add("a");
    set1.remove("a");

    assertTrue(set1.equals(set2));
    assertTrue(set2.equals(set1));
  }

  @Test
  public void testEqualsNoRepeats() {
    Multiset<String> set1 = makeMultiset(100);
    Multiset<String> set2 = makeMultiset(100);
    Multiset<String> set3 = makeMultiset(100);

    set1.add("a");
    set1.add("b");
    set1.add("c");
    set1.add("d");

    set2.add("d");
    set2.add("b");
    set2.add("c");
    set2.add("a");

    set3.add("b");
    set3.add("c");
    set3.add("a");

    assertTrue(set1.equals(set2));
    assertTrue(set2.equals(set1));
    assertFalse(set3.equals(set1));
  }

  @Test
  public void testEqualsWithRepeats() {
    Multiset<String> set1 = makeMultiset(100);
    Multiset<String> set2 = makeMultiset(100);
    Multiset<String> set3 = makeMultiset(100);
    set1.add("a");
    set1.add("b");
    set1.add("a");
    set1.add("b");
    set1.add("c");
    set1.add("d");

    set2.add("d");
    set2.add("b");
    set2.add("b");
    set2.add("c");
    set2.add("a");
    set2.add("a");

    // Same size and same distinct elements as set1, but different multiplicities.
    set3.add("d");
    set3.add("b");
    set3.add("d");
    set3.add("c");
    set3.add("a");
    set3.add("a");

    assertTrue(set1.equals(set2));
    assertTrue(set2.equals(set1));
    assertFalse(set3.equals(set1));
    assertFalse(set1.equals(set3));
  }

  @Test
  public void testEqualsWithRepeatsAndNulls() {
    Multiset<String> set1 = makeMultiset(100);
    Multiset<String> set2 = makeMultiset(100);
    Multiset<String> set3 = makeMultiset(100);
    set1.add(null);
    set1.add("a");
    set1.add("b");
    set1.add("a");
    set1.add("b");
    set1.add("c");
    set1.add("d");

    set2.add(null);
    set2.add("d");
    set2.add("b");
    set2.add("b");
    set2.add("c");
    set2.add("a");
    set2.add("a");

    set3.add("d");
    set3.add("b");
    set3.add("d");
    set3.add("c");
    set3.add("a");
    set3.add("a");

    assertTrue(set1.equals(set2));
    assertTrue(set2.equals(set1));
    assertFalse(set3.equals(set1));
    assertFalse(set1.equals(set3));
  }

  // --------------------------------------------
  // TESTS FOR ITERATOR
  // --------------------------------------------

  @Test
  public void testIteratorEmptyHasNext() {
    Multiset<String> set1 = makeMultiset(100);
    Iterator<String> it = set1.iterator();

    assertFalse(it.hasNext());
  }

  @Test
  public void testIteratorEmptyNextException() {
    Multiset<String> set1 = makeMultiset(100);
    Iterator<String> it = set1.iterator();
    assertThrows(NoSuchElementException.class, () -> it.next());
  }

  @Test
  public void testIteratorNonEmptyNextException() {
    Multiset<String> set1 = makeMultiset(100);
    set1.add("a");
    Iterator<String> it = set1.iterator();
    it.next();
    assertThrows(NoSuchElementException.class, () -> it.next());
  }

  /**
   * Asserts that iterating over {@code set} yields exactly the elements of {@code expected}, with
   * matching multiplicities, in any order. {@code null} elements are supported.
   *
   * @param set the multiset to iterate
   * @param expected the expected contents, duplicates included
   */
  private void assertIteratesAs(Multiset<String> set, ArrayList<String> expected) {
    // Expected multiplicity of each element.
    HashMap<String, Integer> remaining = new HashMap<>();
    for (String item : expected) {
      remaining.merge(item, 1, Integer::sum);
    }

    // Every element the iterator produces must be expected; consume one copy of it.
    for (String item : set) {
      Integer count = remaining.get(item);
      if (count == null) {
        fail("iterator produced unexpected element: " + item);
      }
      remaining.put(item, count - 1);
    }

    // Too few copies leave a positive count; too many leave a negative one.
    for (String key : remaining.keySet()) {
      assertEquals(0, (int) remaining.get(key), "wrong multiplicity for element: " + key);
    }
  }

  /**
   * Adds each of {@code items} to a fresh multiset, then checks that iteration reproduces them.
   *
   * @param items the elements to add, duplicates included
   */
  private void testIteration(ArrayList<String> items) {
    Multiset<String> set = makeMultiset(100);
    for (String item : items) {
      set.add(item);
    }
    assertIteratesAs(set, items);
  }

  @Test
  public void testIteratorNoRepeats() {
    ArrayList<String> items = new ArrayList<>();
    items.add("a");
    items.add("b");
    items.add("c");
    items.add("d");
    items.add("e");
    items.add("f");
    testIteration(items);
  }

  @Test
  public void testIteratorWithRepeatsAndRemovals() {
    ArrayList<String> items = new ArrayList<>();
    items.add("a");
    items.add("b");
    items.add("c");
    items.add("d");
    items.add("e");
    items.add("f");

    items.add("a");
    items.add("b");
    items.add("c");
    items.add("d");
    items.add("e");

    items.add("a");
    items.add("b");

    Multiset<String> set = makeMultiset(100);
    for (String item : items) {
      set.add(item);
    }

    // Remove one "a" from the multiset itself (and from the expected contents) so iteration is
    // exercised after a removal, not only after insertions.
    assertTrue(set.remove("a"));
    items.remove("a");

    assertIteratesAs(set, items);
  }

  // --------------------------------------------
  // TESTS FOR ITERATOR REMOVE
  // --------------------------------------------

  /**
   * Builds a tally map {A=countA, B=countB} for the iterator-removal tests.
   *
   * @param countA expected count of "A"
   * @param countB expected count of "B"
   * @return a new mutable map holding exactly those two entries
   */
  private static HashMap<String, Integer> tally(int countA, int countB) {
    HashMap<String, Integer> counts = new HashMap<>();
    counts.put("A", countA);
    counts.put("B", countB);
    return counts;
  }

  /**
   * Consumes the rest of {@code it} and tallies each element it produces.
   *
   * <p>"A" and "B" start at 0, so an absent element still compares equal to a {@link #tally}
   * with a zero count. Any other element appears as an extra key, making the comparison fail.
   *
   * @param it the iterator to drain
   * @return per-element counts of what {@code it} produced
   */
  private static HashMap<String, Integer> drainAndTally(Iterator<String> it) {
    HashMap<String, Integer> counts = tally(0, 0);
    while (it.hasNext()) {
      counts.merge(it.next(), 1, Integer::sum);
    }
    return counts;
  }

  @Test
  public void testIteratorRemoveToCount1() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("a");
    set.add("b");
    set.add("b");
    set.add("c");
    set.add("c");

    Iterator<String> it = set.iterator();

    while (it.hasNext()) { // remove one "b"
      if (it.next().equals("b")) {
        it.remove();
        break;
      }
    }
    assertEquals(2, set.getCount("a"));
    assertEquals(1, set.getCount("b"));
    assertEquals(2, set.getCount("c"));
    assertEquals(5, set.size());

    it = set.iterator();
    while (it.hasNext()) { // remove one "a"
      if (it.next().equals("a")) {
        it.remove();
        break;
      }
    }
    assertEquals(1, set.getCount("a"));
    assertEquals(1, set.getCount("b"));
    assertEquals(2, set.getCount("c"));
    assertEquals(4, set.size());

    it = set.iterator();
    while (it.hasNext()) { // remove one "c"
      if (it.next().equals("c")) {
        it.remove();
        break;
      }
    }
    assertEquals(1, set.getCount("a"));
    assertEquals(1, set.getCount("b"));
    assertEquals(1, set.getCount("c"));
    assertEquals(3, set.size());
  }

  @Test
  public void testIteratorRemoveToCount0() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("a");
    set.add("b");
    set.add("b");
    set.add("c");
    set.add("c");

    Iterator<String> it = set.iterator();

    while (it.hasNext()) { // remove all instances of "b"
      if (it.next().equals("b")) {
        it.remove();
      }
    }
    assertEquals(2, set.getCount("a"));
    assertEquals(0, set.getCount("b"));
    assertEquals(2, set.getCount("c"));
    assertEquals(4, set.size());

    it = set.iterator();
    while (it.hasNext()) { // remove all instances of "a"
      if (it.next().equals("a")) {
        it.remove();
      }
    }
    assertEquals(0, set.getCount("a"));
    assertEquals(0, set.getCount("b"));
    assertEquals(2, set.getCount("c"));
    assertEquals(2, set.size());

    it = set.iterator();
    while (it.hasNext()) { // remove all instances of "c"
      if (it.next().equals("c")) {
        it.remove();
      }
    }
    assertEquals(0, set.getCount("a"));
    assertEquals(0, set.getCount("b"));
    assertEquals(0, set.getCount("c"));
    assertEquals(0, set.size());
  }

  @Test
  public void testIteratorDoubleRemoveThrowsIllegalStateException() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("a");
    Iterator<String> it = set.iterator();
    it.next();
    it.remove();

    // Test that calling remove twice in a row throws IllegalStateException
    assertThrows(IllegalStateException.class, () -> it.remove());
  }

  @Test
  public void testIteratorRemoveBeforeNextThrowsIllegalStateException() {

    Multiset<String> set = makeMultiset(100);
    set.add("a");
    set.add("a");
    Iterator<String> it = set.iterator();
    assertThrows(IllegalStateException.class, () -> it.remove());
  }

  @Test
  public void testIteratorRemoveDuringIteration() {

    Multiset<String> a = makeMultiset(10);
    a.add("A", 3);
    a.add("B", 2);

    // Remove whichever element comes first, then consume one more. Iteration order is not
    // assumed, so the expectations are derived from what the iterator actually returned.
    Iterator<String> it = a.iterator();
    String removed = it.next();
    it.remove();
    String skipped = it.next();

    HashMap<String, Integer> afterRemoval = tally(3, 2);
    afterRemoval.merge(removed, -1, Integer::sum);
    HashMap<String, Integer> restOfPass = new HashMap<>(afterRemoval);
    restOfPass.merge(skipped, -1, Integer::sum);

    // The same iterator must continue with exactly the not-yet-consumed elements...
    assertEquals(restOfPass, drainAndTally(it));
    // ...and a fresh iterator must reflect the removal.
    assertEquals(afterRemoval, drainAndTally(a.iterator()));
  }

  @Test
  public void testIteratorRemoveItemFullyDuringIteration() {

    Multiset<String> a = makeMultiset(10);
    a.add("A", 1);
    a.add("B", 2);

    // Advance to the single "A", counting any "B"s produced before it, then remove it. This
    // drops "A" from the multiset entirely, whatever the iteration order.
    Iterator<String> it = a.iterator();
    int bsBeforeA = 0;
    while (!"A".equals(it.next())) {
      bsBeforeA++;
    }
    it.remove();

    assertEquals(tally(0, 2 - bsBeforeA), drainAndTally(it));
    assertEquals(tally(0, 2), drainAndTally(a.iterator()));
  }

  // --------------------------------------------
  // TESTS FOR CLEAR
  // --------------------------------------------

  @Test
  public void testClear() {
    Multiset<String> set = makeMultiset(100);

    set.add("a");
    set.add("a");
    set.add("c");

    set.clear();

    assertEquals(0, set.size());
    assertEquals(0, set.getCount("a"));
    assertEquals(0, set.getCount("c"));
  }

  // --------------------------------------------
  // TESTS FOR REMOVE
  // (Note that remove is implicitly covered by lots of other tests. These
  // are mostly intended to test the return value.)
  // --------------------------------------------

  @Test
  public void testRemoveFromEmptyReturnsFalse() {
    Multiset<String> set = makeMultiset(100);
    assertFalse(set.remove("A"));
  }

  @Test
  public void testRemoveSingletonReturnsTrue() {
    Multiset<String> set = makeMultiset(100);
    set.add("A");
    assertTrue(set.remove("A"));
  }

  @Test
  public void testRemoveDuplicateReturnsTrue() {
    Multiset<String> set = makeMultiset(100);
    set.add("A");
    set.add("A");
    set.remove("A");
    assertTrue(set.remove("A"));
  }

  @SuppressWarnings("unlikely-arg-type")
  @Test
  public void testRemoveObjectArgument() {
    Multiset<String> set = makeMultiset(100);
    set.add("a");
    Integer integer = Integer.valueOf(3);
    assertFalse(set.remove(integer));
  }


  // --------------------------------------------
  // TESTS FOR toString()
  // --------------------------------------------

  @Test
  public void testToStringEmpty() {
    Multiset<String> set = makeMultiset(100);

    assertEquals("[]", set.toString());
  }

  @Test
  public void testToStringSingleElement() {
    Multiset<String> set = makeMultiset(100);
    set.add("C");

    assertEquals("[C]", set.toString());
  }

  @Test
  public void testToStringDuplicateElement() {
    Multiset<String> set = makeMultiset(100);
    set.add("C");
    set.add("C");

    assertEquals("[C, C]", set.toString());
  }

  @Test
  public void testToStringMultipleElements() {
    Multiset<String> set = makeMultiset(100);
    set.add("B");
    set.add("C");
    set.add("C");

    // Element order is unspecified, so accept every arrangement of {B, C, C}.
    String actual = set.toString();
    boolean ok = "[B, C, C]".equals(actual) || "[C, B, C]".equals(actual)
        || "[C, C, B]".equals(actual);
    assertTrue(ok, "unexpected toString: " + actual);
  }

  @Test
  public void testToStringAfterRemoval() {
    Multiset<String> set = makeMultiset(100);
    set.add("B");
    set.add("C");
    set.add("C");
    set.remove("B");
    assertEquals("[C, C]", set.toString());
  }

}
