package hws.hw6;

import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.Font;
import java.io.File;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.swing.BorderFactory;
import javax.swing.BoxLayout;
import javax.swing.DefaultComboBoxModel;
import javax.swing.JButton;
import javax.swing.JComboBox;
import javax.swing.JFileChooser;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JMenu;
import javax.swing.JMenuBar;
import javax.swing.JMenuItem;
import javax.swing.JOptionPane;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.JTextArea;
import javax.swing.SwingUtilities;
import javax.swing.border.TitledBorder;
import javax.swing.filechooser.FileNameExtensionFilter;

/**
 * LanguageModelGUI: a Swing front end for a {@link SmallLanguageModel}.
 *
 * <p>The window has two areas. The visualizer lets the user pick any word in the model and
 * shows one bar per distinct next word with its count and percentage. The generator lets
 * the user pick a starting word and appends a sentence produced by the model to a log.
 * The File menu can clear the model or train it on one or more text files.
 *
 * <p>As with any Swing component, instances should be created and used on the event
 * dispatch thread; {@link #main(String[])} does this via {@code SwingUtilities.invokeLater}.
 *
 * @author CS159 Instructors and Claude Sonnet 4.5
 * @version Fall 2025
 */
public class LanguageModelGUI extends JFrame {
    // Shared styling. Font and Color are immutable, so a single instance can be reused.
    private static final Font HEADING_FONT = new Font("Arial", Font.BOLD, 13);
    private static final Font NOTE_FONT = new Font("Arial", Font.ITALIC, 12);
    private static final Font LOG_FONT = new Font("Monospaced", Font.PLAIN, 13);
    private static final Font COUNT_FONT = new Font("Monospaced", Font.PLAIN, 11);
    private static final Color BAR_FILL = new Color(90, 155, 250);
    private static final Color BAR_OUTLINE = new Color(60, 120, 220);
    /** Pixel width of the bar drawn for the most frequent next word. */
    private static final int MAX_BAR_WIDTH = 450;

    // Model
    private final SmallLanguageModel model;

    // Visualizer controls
    private JComboBox<String> allWordsCombo;
    private JPanel nextBarsPanel;

    // Generator controls
    private JComboBox<String> startingWordsCombo;
    private JButton generateButton;
    private JTextArea logArea;

    /**
     * Builds the window around the given model. The frame is not shown until
     * {@code setVisible(true)} is called.
     *
     * @param model the language model to visualize, generate from, and train
     */
    public LanguageModelGUI(SmallLanguageModel model) {
        this.model = model;

        setTitle("LanguageModelGUI");
        setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        setSize(980, 720);
        setLayout(new BorderLayout(10, 10));

        // Menu
        setJMenuBar(createMenuBar());

        // Build UI
        add(buildVisualizerPanel(), BorderLayout.CENTER);
        add(buildGeneratorPanel(), BorderLayout.SOUTH);

        setLocationRelativeTo(null);
    }

    // ---------- UI builders ----------

    /**
     * Builds the "Visualizer" section: a word picker and a scrollable column of bars.
     * Assigns {@link #allWordsCombo} and {@link #nextBarsPanel}.
     */
    private JPanel buildVisualizerPanel() {
        JPanel root = createTitledPanel("Visualizer: Next-Word Distribution");

        // Controls row
        JPanel controls = createControlsRow("Word:");
        allWordsCombo = new JComboBox<>(getAllTokensSorted());
        allWordsCombo.setPreferredSize(new Dimension(260, 28));
        allWordsCombo.addActionListener(e -> onAllWordSelected());
        controls.add(allWordsCombo);
        root.add(controls, BorderLayout.NORTH);

        // Bars area
        nextBarsPanel = new JPanel();
        nextBarsPanel.setLayout(new BoxLayout(nextBarsPanel, BoxLayout.Y_AXIS));
        nextBarsPanel.setBackground(Color.WHITE);
        JScrollPane barsScroll = new JScrollPane(nextBarsPanel);
        barsScroll.setVerticalScrollBarPolicy(JScrollPane.VERTICAL_SCROLLBAR_AS_NEEDED);
        barsScroll.setHorizontalScrollBarPolicy(JScrollPane.HORIZONTAL_SCROLLBAR_NEVER);
        root.add(barsScroll, BorderLayout.CENTER);

        // A combo built from a non-empty array already has its first item selected, and
        // onAllWordSelected() shows the empty message when nothing is selected, so a single
        // call renders the initial state without triggering extra listener-driven rebuilds.
        onAllWordSelected();

        return root;
    }

    /**
     * Builds the "Generator" section: a starting-word picker, a generate button, and a
     * read-only log. Assigns {@link #startingWordsCombo}, {@link #generateButton} and
     * {@link #logArea}.
     */
    private JPanel buildGeneratorPanel() {
        JPanel root = createTitledPanel("Generator: Create Sentences from Starting Words");

        // Controls row
        JPanel controls = createControlsRow("Start:");
        startingWordsCombo = new JComboBox<>(getStartingTokensSorted());
        startingWordsCombo.setPreferredSize(new Dimension(260, 28));
        controls.add(startingWordsCombo);

        generateButton = new JButton("Generate Sentence");
        generateButton.setFont(HEADING_FONT);
        generateButton.setPreferredSize(new Dimension(180, 30));
        generateButton.addActionListener(e -> onGenerate());
        controls.add(generateButton);
        root.add(controls, BorderLayout.NORTH);

        // Log area
        logArea = new JTextArea(6, 50);
        logArea.setLineWrap(true);
        logArea.setWrapStyleWord(true);
        logArea.setEditable(false);
        logArea.setFont(LOG_FONT);
        logArea.setBorder(BorderFactory.createEmptyBorder(5, 5, 5, 5));
        root.add(new JScrollPane(logArea), BorderLayout.CENTER);

        return root;
    }

    /** Creates a BorderLayout panel framed by a gray line border that carries {@code title}. */
    private static JPanel createTitledPanel(String title) {
        JPanel panel = new JPanel(new BorderLayout(8, 8));
        panel.setBorder(BorderFactory.createTitledBorder(
                BorderFactory.createLineBorder(Color.GRAY),
                title,
                TitledBorder.DEFAULT_JUSTIFICATION,
                TitledBorder.DEFAULT_POSITION,
                HEADING_FONT));
        return panel;
    }

    /** Creates a left-aligned controls row whose first component is a bold caption. */
    private static JPanel createControlsRow(String caption) {
        JPanel row = new JPanel(new FlowLayout(FlowLayout.LEFT, 10, 8));
        JLabel label = new JLabel(caption);
        label.setFont(HEADING_FONT);
        row.add(label);
        return row;
    }

    /** Creates the padded italic label used for informational messages in the bars area. */
    private static JLabel createNoteLabel(String text) {
        JLabel label = new JLabel(text);
        label.setFont(NOTE_FONT);
        label.setBorder(BorderFactory.createEmptyBorder(10, 10, 10, 10));
        return label;
    }

    // ---------- Actions ----------

    /**
     * Rebuilds the bar chart for the word currently selected in {@link #allWordsCombo}.
     *
     * <p>There is one row per node in the word's {@code getNextNodes()}. Rows are ordered by
     * {@code getNextCount} descending, then by token case-insensitively. Bar widths are
     * scaled to the largest count. Percentages are relative to
     * {@code getAllNextNodes().size()}. If nothing is selected, or the token is unknown to
     * the model, the empty-model message is shown instead.
     */
    private void onAllWordSelected() {
        String token = (String) allWordsCombo.getSelectedItem();
        nextBarsPanel.removeAll();

        if (token == null || token.isEmpty()) {
            showEmptyVisualizerMessage();
            return;
        }

        WordNode node = model.getWordNodes().get(token);
        if (node == null) {
            showEmptyVisualizerMessage();
            return;
        }

        int total = node.getAllNextNodes().size();
        Set<WordNode> nexts = node.getNextNodes();
        if (nexts == null || nexts.isEmpty() || total == 0) {
            nextBarsPanel.add(createNoteLabel("No next words for '" + token + "'."));
        } else {
            // Query each count once. The sort below would otherwise call getNextCount
            // O(n log n) times, and its cost is not visible from here.
            Map<WordNode, Integer> counts = new HashMap<>();
            for (WordNode next : nexts) {
                int c = node.getNextCount(next);
                counts.put(next, c);
            }

            List<WordNode> sorted = new ArrayList<>(nexts);
            sorted.sort(Comparator
                    .comparing((WordNode n) -> counts.get(n), Comparator.reverseOrder())
                    .thenComparing(WordNode::getToken, String.CASE_INSENSITIVE_ORDER));

            // Sorted by descending count, so the head holds the maximum; floor at 1 so
            // the bar scaling never divides by zero.
            int max = Math.max(1, counts.get(sorted.get(0)));

            for (WordNode next : sorted) {
                nextBarsPanel.add(createBarRow(next.getToken(), counts.get(next), max, total));
            }
        }

        nextBarsPanel.revalidate();
        nextBarsPanel.repaint();
    }

    /**
     * Generates a sentence from the selected starting word and appends it to the log.
     * Warnings are logged, and no sentence is generated, when nothing is selected or the
     * selected token is no longer in the model.
     */
    private void onGenerate() {
        String startToken = (String) startingWordsCombo.getSelectedItem();
        if (startToken == null || startToken.isEmpty()) {
            logArea.append("[warn] Please choose a starting word.\n");
            return;
        }

        WordNode start = model.getWordNodes().get(startToken);
        if (start == null) {
            logArea.append("[warn] Starting token missing in model: " + startToken + "\n");
            return;
        }

        String sentence = model.generateSentence(start);
        if (sentence == null) {
            // Defensive: a null from the model must not crash the event handler.
            logArea.append("[warn] No sentence generated from: " + startToken + "\n");
        } else {
            logArea.append(sentence.trim());
            logArea.append("\n\n");
        }
        // Keep the newest output in view.
        logArea.setCaretPosition(logArea.getDocument().getLength());
    }

    // ---------- Helpers ----------

    /** Returns every token key of {@code model.getWordNodes()}, sorted case-insensitively. */
    private String[] getAllTokensSorted() {
        List<String> tokens = new ArrayList<>(model.getWordNodes().keySet());
        tokens.sort(String.CASE_INSENSITIVE_ORDER);
        return tokens.toArray(new String[0]);
    }

    /** Returns the tokens of {@code model.getStartingWords()}, sorted case-insensitively. */
    private String[] getStartingTokensSorted() {
        List<String> tokens = new ArrayList<>();
        for (WordNode n : model.getStartingWords()) {
            tokens.add(n.getToken());
        }
        tokens.sort(String.CASE_INSENSITIVE_ORDER);
        return tokens.toArray(new String[0]);
    }

    /** Replaces the bars area content with the "train the model first" message. */
    private void showEmptyVisualizerMessage() {
        nextBarsPanel.removeAll();
        nextBarsPanel.add(createNoteLabel("No words to visualize. Train the model first."));
        nextBarsPanel.revalidate();
        nextBarsPanel.repaint();
    }

    /**
     * Builds one chart row: the word, a bar scaled against {@code maxCount}, and the count
     * with its percentage of {@code totalCount}.
     *
     * @param word       label shown on the left
     * @param count      occurrences for this word
     * @param maxCount   count that maps to a full-width bar; values {@code <= 0} give no bar
     * @param totalCount denominator for the percentage; values {@code <= 0} show 0.0%
     * @return the assembled row panel
     */
    private JPanel createBarRow(String word, int count, int maxCount, int totalCount) {
        JPanel row = new JPanel(new BorderLayout(6, 4));
        row.setBorder(BorderFactory.createEmptyBorder(5, 10, 5, 10));
        row.setBackground(Color.WHITE);

        JLabel wordLbl = new JLabel(word);
        wordLbl.setPreferredSize(new Dimension(120, 20));
        row.add(wordLbl, BorderLayout.WEST);

        // Guard the scale so a non-positive maxCount cannot produce NaN/Infinity widths,
        // and clamp the result to the drawable range.
        int width = maxCount > 0
                ? (int) Math.round((double) MAX_BAR_WIDTH * count / maxCount)
                : 0;
        width = Math.max(0, Math.min(MAX_BAR_WIDTH, width));

        JPanel bar = new JPanel();
        bar.setPreferredSize(new Dimension(width, 20));
        bar.setBackground(BAR_FILL);
        bar.setBorder(BorderFactory.createLineBorder(BAR_OUTLINE));

        // Wrapping in a WEST-anchored panel keeps the bar at its preferred width instead of
        // letting the CENTER slot stretch it.
        JPanel barWrap = new JPanel(new BorderLayout());
        barWrap.setBackground(Color.WHITE);
        barWrap.add(bar, BorderLayout.WEST);
        row.add(barWrap, BorderLayout.CENTER);

        double pct = totalCount > 0 ? (count * 100.0 / totalCount) : 0.0;
        JLabel num = new JLabel(String.format(" %d (%.1f%%)", count, pct));
        num.setPreferredSize(new Dimension(110, 20));
        num.setFont(COUNT_FONT);
        row.add(num, BorderLayout.EAST);

        return row;
    }

    // ---------- Menu ----------

    /** Builds the File menu: New Model, Train on File..., and Exit. */
    private JMenuBar createMenuBar() {
        JMenuBar bar = new JMenuBar();

        JMenu file = new JMenu("File");

        JMenuItem newModel = new JMenuItem("New Model");
        newModel.addActionListener(e -> onNewModel());
        file.add(newModel);

        JMenuItem train = new JMenuItem("Train on File...");
        train.addActionListener(e -> onTrainOnFile());
        file.add(train);

        file.addSeparator();

        JMenuItem exit = new JMenuItem("Exit");
        exit.addActionListener(e -> System.exit(0));
        file.add(exit);

        bar.add(file);
        return bar;
    }

    /**
     * After an OK/Cancel confirmation, empties the model's word and starting-word
     * collections, refreshes the views, and clears the log.
     */
    private void onNewModel() {
        int choice = JOptionPane.showConfirmDialog(this,
                "Clear current model and start new?",
                "New Model",
                JOptionPane.OK_CANCEL_OPTION,
                JOptionPane.WARNING_MESSAGE);
        if (choice != JOptionPane.OK_OPTION) {
            return;
        }

        // The model field is final, so it is cleared in place.
        // NOTE(review): assumes getWordNodes()/getStartingWords() return the model's live,
        // mutable collections rather than copies; confirm against SmallLanguageModel.
        model.getWordNodes().clear();
        model.getStartingWords().clear();

        refreshCombosAndViews();
        logArea.setText("");
    }

    /**
     * Lets the user pick one or more text files and trains the model on each of them.
     *
     * <p>A failing file does not stop the others. Because files that did train have already
     * changed the model, the views are always refreshed. Any failures are then reported
     * together in a single error dialog.
     */
    private void onTrainOnFile() {
        JFileChooser chooser = new JFileChooser();
        chooser.setDialogTitle("Select text file(s) to train");
        chooser.setFileFilter(new FileNameExtensionFilter("Text Files (*.txt)", "txt"));
        chooser.setCurrentDirectory(new File("src/hws/hw6"));
        chooser.setMultiSelectionEnabled(true);
        if (chooser.showOpenDialog(this) != JFileChooser.APPROVE_OPTION) {
            return;
        }

        File[] files = chooser.getSelectedFiles();
        if (files.length == 0 && chooser.getSelectedFile() != null) {
            // In multi-selection mode some look-and-feels may report only getSelectedFile()
            // (e.g. when a name is typed), so fall back to it.
            files = new File[] {chooser.getSelectedFile()};
        }

        StringBuilder failures = new StringBuilder();
        for (File f : files) {
            try {
                model.trainOnFile(f.getAbsolutePath());
            } catch (Exception ex) {
                // UI boundary: record the failure and continue with the remaining files.
                failures.append(f.getName()).append(": ").append(describeFailure(ex)).append('\n');
            }
        }

        refreshCombosAndViews();

        if (failures.length() > 0) {
            JOptionPane.showMessageDialog(this,
                    "Failed to train on file(s):\n" + failures,
                    "Error",
                    JOptionPane.ERROR_MESSAGE);
        }
    }

    /** Returns a human-readable reason for {@code ex}, never the literal "null". */
    private static String describeFailure(Exception ex) {
        String msg = ex.getMessage();
        return (msg == null || msg.isEmpty()) ? ex.getClass().getSimpleName() : msg;
    }

    /**
     * Reloads both combo boxes from the model (first item selected, if any) and redraws the
     * visualizer once.
     */
    private void refreshCombosAndViews() {
        // JComboBox.setModel fires no ActionEvent, unlike removeAllItems/addItem/
        // setSelectedIndex, so the bars are rebuilt exactly once by the call below.
        // DefaultComboBoxModel(E[]) preselects the first element when the array is non-empty.
        allWordsCombo.setModel(new DefaultComboBoxModel<>(getAllTokensSorted()));
        startingWordsCombo.setModel(new DefaultComboBoxModel<>(getStartingTokensSorted()));

        // Handles both cases: bars for the first word, or the empty message when no words exist.
        onAllWordSelected();
    }

    // ---------- Main ----------

    /** Launches the GUI with an empty {@link SmallLanguageModel} on the event dispatch thread. */
    public static void main(String[] args) {
        SwingUtilities.invokeLater(() -> {
            SmallLanguageModel model = new SmallLanguageModel();
            LanguageModelGUI gui = new LanguageModelGUI(model);
            gui.setVisible(true);
        });
    }
}
