import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Locale;
import java.util.Random;
import java.util.Scanner;
import java.util.TreeMap;
import java.util.regex.Pattern;

/**
 * Comprehensive benchmark for Dictionary implementations.
 *
 * <p>This benchmark performs realistic operations on a text index:
 * <ul>
 *   <li>Indexing: Build the index from a text file</li>
 *   <li>Successful lookups: Search for words that exist</li>
 *   <li>Unsuccessful lookups: Search for words that don't exist</li>
 *   <li>Removals: Delete existing entries</li>
 *   <li>Re-insertions: Add entries back after removal</li>
 * </ul>
 *
 * <p>The shared random generator is re-seeded before each implementation is
 * benchmarked, so every implementation is driven with the same sequence of
 * absent words, lookup words and removal words.
 *
 * @author CS240 Instructors and Claude Sonnet 4.5
 * @version 11/2025
 */
public class DictionaryBenchmark {

  private static final int WARMUP_ITERATIONS = 3;
  private static final int TRIAL_ITERATIONS = 10;
  private static final int LOOKUP_COUNT = 10000;
  private static final int REMOVAL_COUNT = 1000;

  // Fixed seed for reproducibility; re-applied before each implementation is benchmarked
  private static final long RANDOM_SEED = 42L;
  private static final Random random = new Random(RANDOM_SEED);

  // Compiled once instead of once per token read from the input file
  private static final Pattern NON_LETTERS = Pattern.compile("[^a-zA-Z]");

  // Slots of the timing array returned by runBenchmarkIteration
  private static final int PHASE_INDEX = 0;
  private static final int PHASE_LOOKUP_HIT = 1;
  private static final int PHASE_LOOKUP_MISS = 2;
  private static final int PHASE_REMOVE = 3;
  private static final int PHASE_REINSERT = 4;
  private static final int PHASE_COUNT = 5;

  private static final String HEAVY_RULE =
      "======================================================================";
  private static final String LIGHT_RULE =
      "----------------------------------------------------------------------";

  // Receives the lookup hit count so the timed get() calls have an observable result
  // and cannot be discarded by the JIT as dead code.
  private static volatile int sink;

  /**
   * Main entry point for benchmarking.
   *
   * <p>Usage: java hw.hashing.DictionaryBenchmark &lt;filename&gt;
   *
   * @param args command line arguments; exactly one is expected, the path of the text file
   */
  public static void main(String[] args) {
    if (args.length != 1) {
      System.err.println("Usage: java hw.hashing.DictionaryBenchmark <filename>");
      return;
    }

    String filename = args[0];

    // Load all words from file
    ArrayList<String> allWords;
    try {
      allWords = loadWords(filename);
    } catch (FileNotFoundException e) {
      System.err.println("ERROR: File not found: " + e.getMessage());
      return;
    }
    ArrayList<String> uniqueWords = getUniqueWords(allWords);

    // Random sampling from an empty word list would throw, so stop early.
    if (uniqueWords.isEmpty()) {
      System.err.println("ERROR: No words found in file: " + filename);
      return;
    }

    System.out.println(HEAVY_RULE);
    System.out.println("DICTIONARY IMPLEMENTATION BENCHMARK");
    System.out.println(HEAVY_RULE);
    System.out.println("File: " + filename);
    System.out.println("Total words: " + allWords.size());
    System.out.println("Unique words: " + uniqueWords.size());
    System.out.println();
    System.out.println("Running " + WARMUP_ITERATIONS + " warm-up iterations...");
    System.out.println("Then " + TRIAL_ITERATIONS + " timed trials (reporting median)");
    System.out.println();

    // Benchmark HashDictionary
    printSectionHeader("CS 240 HashDictionary");
    HashDictionary<String, ArrayList<Integer>> index = new HashDictionary<>();
    benchmarkDictionary(index, allWords, uniqueWords);
    System.out.println(String.format("\nFinal load factor: %.2f", index.loadFactor()));
    System.out.println();

    // Benchmark Java HashMap
    printSectionHeader("Java HashMap (java.util.HashMap)");
    benchmarkDictionary(new JavaHashMapWrapper<>(), allWords, uniqueWords);
    System.out.println();

    // Benchmark Java TreeMap
    printSectionHeader("Java TreeMap (java.util.TreeMap - Red-Black Tree)");
    benchmarkDictionary(new JavaTreeMapWrapper<>(), allWords, uniqueWords);
    System.out.println();

    // Benchmark BST Dictionary
    printSectionHeader("CS 240 BSTDictionary (Non-Self-Balancing BST)");
    benchmarkDictionary(new BSTDictionary<>(), allWords, uniqueWords);
  }

  /**
   * Print a section title framed by two horizontal rules.
   *
   * @param title the title line to print
   */
  private static void printSectionHeader(String title) {
    System.out.println(LIGHT_RULE);
    System.out.println(title);
    System.out.println(LIGHT_RULE);
  }

  /**
   * Benchmark a dictionary implementation with various operations.
   *
   * <p>Runs the warm-up iterations, then the timed trials, and prints the median
   * time of each phase together with derived per-operation times.
   *
   * @param prototype prototype dictionary (will be cleared and reused)
   * @param allWords all words from the text file
   * @param uniqueWords unique words from the text file
   */
  private static void benchmarkDictionary(
      Dictionary<String, ArrayList<Integer>> prototype,
      ArrayList<String> allWords,
      ArrayList<String> uniqueWords) {

    // Restart the generator so each implementation is measured on the same workload.
    random.setSeed(RANDOM_SEED);

    // times[phase][trial]
    long[][] times = new long[PHASE_COUNT][TRIAL_ITERATIONS];

    // Generate words that are NOT in the dictionary for unsuccessful lookups
    ArrayList<String> absentWords = generateAbsentWords(uniqueWords, LOOKUP_COUNT);

    // Warm-up iterations (timings discarded)
    for (int i = 0; i < WARMUP_ITERATIONS; i++) {
      runBenchmarkIteration(prototype, allWords, uniqueWords, absentWords);
    }

    // Timed trials
    for (int trial = 0; trial < TRIAL_ITERATIONS; trial++) {
      long[] elapsed = runBenchmarkIteration(prototype, allWords, uniqueWords, absentWords);
      for (int phase = 0; phase < PHASE_COUNT; phase++) {
        times[phase][trial] = elapsed[phase];
      }

      if (trial < TRIAL_ITERATIONS - 1) {
        pauseBetweenTrials();
      }
    }

    long medianIndex = median(times[PHASE_INDEX]);
    long medianHit = median(times[PHASE_LOOKUP_HIT]);
    long medianMiss = median(times[PHASE_LOOKUP_MISS]);
    long medianRemove = median(times[PHASE_REMOVE]);
    long medianReinsert = median(times[PHASE_REINSERT]);

    // Report median times
    System.out.println("Indexing time (median):            " + formatTime(medianIndex));
    System.out.println("Successful lookups (median):       "
        + formatTime(medianHit) + " (" + LOOKUP_COUNT + " lookups)");
    System.out.println("Unsuccessful lookups (median):     "
        + formatTime(medianMiss) + " (" + LOOKUP_COUNT + " lookups)");
    System.out.println("Removals (median):                 " + formatTime(medianRemove)
        + " (" + REMOVAL_COUNT + " removals)");
    System.out.println("Re-insertions (median):            " + formatTime(medianReinsert)
        + " (" + REMOVAL_COUNT + " insertions)");

    // Calculate per-operation times
    double avgSuccessLookup = medianHit / (double) LOOKUP_COUNT;
    double avgFailLookup = medianMiss / (double) LOOKUP_COUNT;
    double avgRemoval = medianRemove / (double) REMOVAL_COUNT;
    double avgReinsertion = medianReinsert / (double) REMOVAL_COUNT;

    System.out.println();
    System.out.println("Per-operation times:");
    System.out.println("  Successful lookup:   " + formatNanos(avgSuccessLookup));
    System.out.println("  Unsuccessful lookup: " + formatNanos(avgFailLookup));
    System.out.println("  Removal:             " + formatNanos(avgRemoval));
    System.out.println("  Insertion:           " + formatNanos(avgReinsertion));
  }

  /**
   * Suggest a garbage collection and pause briefly between two timed trials.
   */
  private static void pauseBetweenTrials() {
    System.gc(); // Suggest garbage collection between trials
    try {
      Thread.sleep(100); // Brief pause
    } catch (InterruptedException e) {
      // Keep the interrupt visible to whoever owns this thread instead of dropping it.
      Thread.currentThread().interrupt();
    }
  }

  /**
   * Run a single benchmark iteration.
   *
   * <p>The dictionary is cleared, then the five phases run in order: indexing,
   * successful lookups, unsuccessful lookups, removals and re-insertions. The
   * lookup and removal words are drawn at random (with replacement) outside the
   * timed regions.
   *
   * @param dict the dictionary under test
   * @param allWords all words from the text file, in file order
   * @param uniqueWords unique words from the text file
   * @param absentWords words used for the unsuccessful lookups
   * @return elapsed nanoseconds of each phase, indexed by the {@code PHASE_*} constants
   */
  private static long[] runBenchmarkIteration(
      Dictionary<String, ArrayList<Integer>> dict,
      ArrayList<String> allWords,
      ArrayList<String> uniqueWords,
      ArrayList<String> absentWords) {

    long[] elapsed = new long[PHASE_COUNT];
    int found = 0;

    // Clear dictionary for fresh start
    dict.clear();

    // 1. Indexing phase
    long startTime = System.nanoTime();
    int position = 0;
    for (String word : allWords) {
      ArrayList<Integer> positions = dict.get(word);
      if (positions == null) {
        positions = new ArrayList<>();
        dict.put(word, positions);
      }
      positions.add(position++);
    }
    elapsed[PHASE_INDEX] = System.nanoTime() - startTime;

    // 2. Successful lookups (words that exist)
    ArrayList<String> lookupWords = selectRandomWords(uniqueWords, LOOKUP_COUNT);
    startTime = System.nanoTime();
    for (String word : lookupWords) {
      if (dict.get(word) != null) {
        found++;
      }
    }
    elapsed[PHASE_LOOKUP_HIT] = System.nanoTime() - startTime;

    // 3. Unsuccessful lookups (words that don't exist)
    startTime = System.nanoTime();
    for (String word : absentWords) {
      if (dict.get(word) != null) {
        found++;
      }
    }
    elapsed[PHASE_LOOKUP_MISS] = System.nanoTime() - startTime;

    // 4. Removals
    ArrayList<String> wordsToRemove = selectRandomWords(uniqueWords, REMOVAL_COUNT);
    startTime = System.nanoTime();
    for (String word : wordsToRemove) {
      dict.remove(word);
    }
    elapsed[PHASE_REMOVE] = System.nanoTime() - startTime;

    // 5. Re-insertions (put the removed words back)
    startTime = System.nanoTime();
    for (String word : wordsToRemove) {
      dict.put(word, new ArrayList<>());
    }
    elapsed[PHASE_REINSERT] = System.nanoTime() - startTime;

    // Publish the hit count so the lookup loops above have a used result.
    sink = found;
    return elapsed;
  }

  /**
   * Load all words from a text file.
   *
   * <p>Each whitespace-separated token is stripped of everything except ASCII
   * letters and lower-cased; tokens that end up empty are dropped.
   *
   * @param filename path of the file to read
   * @return the normalized words in file order
   * @throws FileNotFoundException if the file cannot be opened
   */
  private static ArrayList<String> loadWords(String filename) throws FileNotFoundException {
    ArrayList<String> words = new ArrayList<>();

    // try-with-resources closes the scanner even if reading fails part-way
    try (Scanner scanner = new Scanner(new File(filename))) {
      while (scanner.hasNext()) {
        // Locale.ROOT keeps the result ASCII regardless of the default locale
        String word = NON_LETTERS.matcher(scanner.next()).replaceAll("")
            .toLowerCase(Locale.ROOT);
        if (!word.isEmpty()) {
          words.add(word);
        }
      }
    }
    return words;
  }

  /**
   * Get unique words from a list (preserving first occurrence order).
   *
   * @param words the words to de-duplicate
   * @return each distinct word once, in order of first appearance
   */
  private static ArrayList<String> getUniqueWords(ArrayList<String> words) {
    HashDictionary<String, Boolean> seen = new HashDictionary<>();
    ArrayList<String> unique = new ArrayList<>();

    for (String word : words) {
      if (seen.get(word) == null) {
        seen.put(word, true);
        unique.add(word);
      }
    }
    return unique;
  }

  /**
   * Generate distinct random words that do not occur in {@code uniqueWords}.
   *
   * <p>At most {@code count * 10} candidates are tried, so the result can hold
   * fewer than {@code count} words if too many candidates collide.
   *
   * @param uniqueWords the words that must not be produced
   * @param count the number of absent words wanted
   * @return the generated words, in generation order
   */
  private static ArrayList<String> generateAbsentWords(ArrayList<String> uniqueWords, int count) {
    // Holds both the real words and every candidate accepted so far, so one
    // lookup rejects duplicates of either kind without scanning the result list.
    HashDictionary<String, Boolean> taken = new HashDictionary<>();
    for (String word : uniqueWords) {
      taken.put(word, true);
    }

    ArrayList<String> absent = new ArrayList<>();
    int maxAttempts = count * 10;
    for (int attempts = 0; absent.size() < count && attempts < maxAttempts; attempts++) {
      String candidate = generateRandomWord();
      if (taken.get(candidate) == null) {
        taken.put(candidate, true);
        absent.add(candidate);
      }
    }
    return absent;
  }

  /**
   * Generate a random word-like string.
   *
   * @return a string of 3 to 14 random lower-case ASCII letters
   */
  private static String generateRandomWord() {
    int length = 3 + random.nextInt(12); // 3-14 characters
    StringBuilder sb = new StringBuilder(length);
    for (int i = 0; i < length; i++) {
      sb.append((char) ('a' + random.nextInt(26)));
    }
    return sb.toString();
  }

  /**
   * Select random words from a list (with replacement).
   *
   * @param words the non-empty list to draw from
   * @param count the number of words to draw
   * @return the drawn words; the same word may appear more than once
   */
  private static ArrayList<String> selectRandomWords(ArrayList<String> words, int count) {
    ArrayList<String> selected = new ArrayList<>(count);
    for (int i = 0; i < count; i++) {
      selected.add(words.get(random.nextInt(words.size())));
    }
    return selected;
  }

  /**
   * Calculate median of an array.
   *
   * <p>For an even number of values the two middle values are averaged
   * (rounding toward the lower one).
   *
   * @param values the values to summarize; not modified
   * @return the median value
   * @throws IllegalArgumentException if {@code values} is empty
   */
  private static long median(long[] values) {
    if (values.length == 0) {
      throw new IllegalArgumentException("median of an empty array is undefined");
    }
    long[] sorted = values.clone();
    Arrays.sort(sorted);

    int mid = sorted.length / 2;
    if (sorted.length % 2 == 1) {
      return sorted[mid];
    }
    long lower = sorted[mid - 1];
    long upper = sorted[mid];
    return lower + (upper - lower) / 2; // overflow-safe midpoint
  }

  /**
   * Format time in milliseconds.
   *
   * @param nanos a duration in nanoseconds
   * @return the duration in milliseconds with two decimals, e.g. {@code "1.50 ms"}
   */
  private static String formatTime(long nanos) {
    return String.format("%.2f ms", nanos / 1_000_000.0);
  }

  /**
   * Format nanoseconds with appropriate unit.
   *
   * @param nanos a duration in nanoseconds
   * @return the duration in ns, μs or ms, whichever keeps the value below 1000 where possible
   */
  private static String formatNanos(double nanos) {
    if (nanos < 1000) {
      return String.format("%.2f ns", nanos);
    } else if (nanos < 1_000_000) {
      return String.format("%.2f μs", nanos / 1000);
    } else {
      return String.format("%.2f ms", nanos / 1_000_000);
    }
  }

  // ============================================================================
  // WRAPPER CLASSES FOR JAVA'S BUILT-IN MAPS
  // ============================================================================

  /**
   * Wrapper class that adapts Java's HashMap to implement the Dictionary interface.
   *
   * <p>This allows Java's HashMap to be benchmarked alongside custom Dictionary
   * implementations such as HashDictionary. Every method delegates directly to
   * the wrapped map.
   *
   * @param <K> the type of keys
   * @param <V> the type of values
   */
  static class JavaHashMapWrapper<K, V> implements Dictionary<K, V> {

    private final HashMap<K, V> map;

    /**
     * Create a new wrapper around an empty HashMap.
     */
    public JavaHashMapWrapper() {
      map = new HashMap<>();
    }

    @Override
    public void put(K key, V value) {
      map.put(key, value);
    }

    @Override
    public V get(K key) {
      return map.get(key);
    }

    @Override
    public V remove(K key) {
      return map.remove(key);
    }

    @Override
    public void clear() {
      map.clear();
    }

    @Override
    public int size() {
      return map.size();
    }

    @Override
    public Iterator<K> iterator() {
      return map.keySet().iterator();
    }
  }

  /**
   * Wrapper class that adapts Java's TreeMap to implement the Dictionary interface.
   *
   * <p>This allows Java's TreeMap (a red-black tree implementation) to be benchmarked
   * alongside custom Dictionary implementations. TreeMap provides O(log n) operations
   * and maintains keys in sorted order. Every method delegates directly to the
   * wrapped map.
   *
   * @param <K> the type of keys (must be Comparable)
   * @param <V> the type of values
   */
  static class JavaTreeMapWrapper<K, V> implements Dictionary<K, V> {

    private final TreeMap<K, V> map;

    /**
     * Create a new wrapper around an empty TreeMap.
     */
    public JavaTreeMapWrapper() {
      map = new TreeMap<>();
    }

    @Override
    public void put(K key, V value) {
      map.put(key, value);
    }

    @Override
    public V get(K key) {
      return map.get(key);
    }

    @Override
    public V remove(K key) {
      return map.remove(key);
    }

    @Override
    public void clear() {
      map.clear();
    }

    @Override
    public int size() {
      return map.size();
    }

    @Override
    public Iterator<K> iterator() {
      return map.keySet().iterator();
    }
  }
}
