
import { onMounted, computed, reactive, toRefs, h, watch, ref, defineComponent } from "vue";
import { useStore } from "@/store/index";

import MatrixView from "./MatrixView/MatrixView.vue";
import Legend from "./Legend/Legend.vue";
import DataGrid from "./DataGrid/DataGrid.vue";
import Circle from './MatrixView/Circle.vue';
import { ArrowUpOutlined, ArrowDownOutlined, ArrowLeftOutlined, ArrowRightOutlined } from "@ant-design/icons-vue";

/**
 * Color-by choices offered when a text model (the non-ViT entries of
 * `modelOptions`) is selected.
 *
 * - `label`: text shown in the color-by dropdown.
 * - `value`: key written to the store's `colorBy` setting.
 * - `desc`:  explanation forwarded to the legend for the active choice.
 */
const text_color_info: Array<{ label: string; value: string; desc: string }> = [
    { label: "query vs. key", value: "query_key", desc: "token type, query or key" },
    { label: "position", value: "position", desc: "token position in sentence (normalized), darker = later in the sentence" },
    { label: "position mod 5", value: "pos_mod_5", desc: "token position in sentence modulo 5 (unnormalized)" },
    { label: "punctuation", value: "punctuation", desc: "punctuation vs. non-punctuation tokens" },
    { label: "embedding norm", value: "embed_norm", desc: "token embedding norm, darker = higher norm" },
    { label: "token length", value: "token_length", desc: "number of chars in token, darker = longer token" },
    { label: "sentence length", value: "sent_length", desc: "number of tokens in sentence, darker = longer sentence" },
    { label: "token frequency", value: "token_freq", desc: "frequency of token in dataset, darker = more frequent" },
];

/**
 * Color-by choices offered when an image (ViT) model is selected.
 *
 * - `label`: text shown in the color-by dropdown.
 * - `value`: key written to the store's `colorBy` setting
 *            ("no_outline" is the fallback the component picks for ViT).
 * - `desc`:  explanation forwarded to the legend for the active choice.
 */
const image_color_info: Array<{ label: string; value: string; desc: string }> = [
    { label: "query vs. key", value: "query_key", desc: "token type, query or key (outline)" },
    { label: "query vs. key (fill)", value: "qk_map", desc: "token type, query or key (fill)" },
    { label: "patch row", value: "row", desc: "row of patch in image, darker = later row" },
    { label: "patch column", value: "column", desc: "column of patch in image, darker = later column" },
    { label: "no outline", value: "no_outline", desc: "original image patch without q/k outline" },
];

/**
 * Top-level view component. It connects the store-backed controls (model,
 * projection method, color-by, search, attention toggles) to the MatrixView,
 * Legend and DataGrid child components.
 *
 * Most UI state lives in the app store. `state` re-exposes it as computed
 * refs, and setters commit/dispatch back to the store. The template can then
 * bind to it directly through `toRefs(state)`.
 */
export default defineComponent({
    components: { MatrixView, Legend, DataGrid, Circle, ArrowUpOutlined, ArrowDownOutlined, ArrowLeftOutlined, ArrowRightOutlined },
    setup() {
        const store = useStore();

        // Template refs to child components (filled in once mounted).
        const matrixView = ref(null);
        const legend = ref(null);
        const dataGrid = ref(null);

        const state = reactive({
            mode: computed(() => store.state.mode),
            renderState: computed(() => store.state.renderState),
            searchToken: "",
            view: computed(() => store.state.view),
            showAll: computed({
                get: () => store.state.showAll,
                set: (v) => store.commit("setShowAll", v)
            }),
            sizeByNorm: computed({
                get: () => store.state.sizeByNorm,
                set: (v) => store.commit("setSizeByNorm", v)
            }),
            showAttention: computed({
                get: () => store.state.showAttention,
                set: (v) => store.commit("setShowAttention", v)
            }),

            attnLoading: computed(() => store.state.attentionLoading),
            layer: computed(() => store.state.layer),
            head: computed(() => store.state.head),

            dimension: computed({
                get: () => store.state.dimension,
                set: (v) => store.commit("setDimension", v)
            }),

            placeholder: "",
            num_message: "",

            colorBy: computed({
                get: () => store.state.colorBy,
                set: (v) => store.commit("setColorBy", v),
            }),
            // Rebuilt by switchColorOptions() whenever the model changes.
            colorByOptions: [] as any,
            colorByDict: {} as any,

            projectionMethod: computed({
                get: () => store.state.projectionMethod,
                set: (v) => store.commit("setProjectionMethod", v),
            }),
            projectionMethods: ["tsne", "umap", "pca"].map((x) => ({ value: x, label: x })),

            // Switching models goes through an action (dispatch), not a plain commit.
            modelType: computed({
                get: () => store.state.modelType,
                set: (v) => store.dispatch("switchModel", v)
            }),
            // modelOptions: ["vit-16", "vit-32", "bert", "gpt-2"].map((x) => (
            // modelOptions: ["vit-nat", "vit-syn", "bert", "gpt-2"].map((x) => (
            modelOptions: ["vit-nat", "bert", "gpt-2"].map((x) => (
                { value: x, label: x }
            )),

            tokenData: computed(() => store.state.tokenData),
            attnMsg: "click a plot to zoom in",
            curLayer: computed(() => store.state.layer),
            curHead: computed(() => store.state.head),
            showAttn: computed(() => store.state.showAttn),
            transitionInProgress: computed(() => store.state.transitionInProgress),
            clearSelection: computed(() => store.state.clearSelection),
        });

        /** True when the selected model is a ViT (image patches) rather than a text model. */
        const isImageModel = () => state.modelType.includes("vit");

        onMounted(() => {
            switchViewMsg();
            switchPlaceholder();
            switchColorOptions();
        });

        /** Reset the matrix view's zoom (delegates to MatrixView.resetZoom). */
        const onClickReset = () => {
            (matrixView.value as any).resetZoom();
        };

        /** Delegate to MatrixView.reset(true); the flag's meaning is defined in MatrixView. */
        const resetToMatrix = () => {
            (matrixView.value as any).reset(true);
        };

        /**
         * Clear the search box, drop highlighted tokens and leave search view.
         * The store is flagged as "in transition" while the updates are committed.
         */
        const clearSearch = () => {
            store.commit("updateTransitionInProgress", true);
            state.searchToken = "";
            // actually need to clear search results from scatterplot
            store.commit("setHighlightedTokenIndices", []);
            store.commit("setView", "none");
            store.commit("updateTransitionInProgress", false);
        };

        /**
         * Search the matrix view for `str` and echo the hit count in the search box.
         * Empty queries are ignored. Switches the store view to "search" if needed.
         */
        const onSearch = (str: string) => {
            if (str === "") { // don't search empty string
                return;
            }
            if (state.view !== "search") {
                store.commit("setView", "search");
            }
            const numResults = (matrixView.value as any).onSearch(str);
            state.searchToken = str + " (" + numResults + " results)";
        };

        /** Debug helper: print the matrix view's viewport after a 100 ms delay. */
        const logViewport = () => {
            // NOTE(review): delay presumably lets a pending render settle — confirm.
            setTimeout(() => {
                (matrixView.value as any).printViewport();
            }, 100);
        };

        /** Zoom the matrix view onto the plot for (`layer`, `head`). */
        const zoomToPlot = (layer: number, head: number) => {
            (matrixView.value as any).zoomToPlot(layer, head, true, false);
        };

        /**
         * Move one plot up/down (layer -/+ 1) or left/right (head -/+ 1) from the
         * current plot.
         *
         * @throws Error for any direction other than "up" | "down" | "left" | "right".
         */
        const moveToPlot = (direction: string) => {
            // NOTE(review): no bounds check here; out-of-range indices are passed
            // straight to MatrixView.zoomToPlot — confirm it handles them.
            const layer = state.layer as number;
            const head = state.head as number;
            switch (direction) {
                case "up":
                    zoomToPlot(layer - 1, head);
                    break;
                case "left":
                    zoomToPlot(layer, head - 1);
                    break;
                case "right":
                    zoomToPlot(layer, head + 1);
                    break;
                case "down":
                    zoomToPlot(layer + 1, head);
                    break;
                default:
                    throw new Error("Invalid direction!");
            }
        };

        /** Set the search-box placeholder to examples that fit the current model. */
        const switchPlaceholder = () => {
            state.placeholder = isImageModel()
                ? "e.g., person, background"
                : "e.g., cat, april";
        };

        /** Build the "<N> tokens (<M> images|sentences)" summary message. */
        const switchDataMsg = () => {
            const numTokens = state.tokenData.length;
            const numInstances = (matrixView.value as any).getUnique().length;
            const unit = isImageModel() ? " images)" : " sentences)";
            state.num_message = numTokens + " tokens (" + numInstances + unit;
        };

        /**
         * Clear the current selection. In attention view this hides attention;
         * in any other view ('search' or 'none') it clears the search state.
         */
        const clearSelection = () => {
            console.log('clearing selection');
            if (state.view === 'attn') {
                store.commit("updateTransitionInProgress", true);
                store.commit("setShowAttn", false);
                store.commit("updateTransitionInProgress", false);
            } else {
                clearSearch();
            }
        };

        /** For image models, draw the unique images from the matrix view into the data grid. */
        const showImages = () => {
            if (isImageModel()) {
                const images = (matrixView.value as any).getUnique();
                (dataGrid.value as any).drawGrid(images);
            }
        };

        /**
         * Rebuild the color-by dropdown options and lookup dict for the current
         * model. If the active `colorBy` is not valid for the model, fall back to
         * the model's default ("no_outline" for ViT, "query_key" otherwise).
         */
        const switchColorOptions = () => {
            const curColorBy = state.colorBy;
            const colorInfo = isImageModel() ? image_color_info : text_color_info;

            state.colorByOptions = colorInfo.map((x) => ({ value: x.value, label: x.label }));
            state.colorByDict = Object.assign({}, ...colorInfo.map((x) => ({ [x.value]: { label: x.label, desc: x.desc } })));

            if (!(curColorBy in state.colorByDict)) {
                store.commit("setColorBy", isImageModel() ? "no_outline" : "query_key");
            }
        };

        /**
         * Push the description of the active color-by option to the legend.
         * This is a no-op while the option is unknown for the current model or
         * the legend is not mounted, instead of throwing a TypeError from a watcher.
         */
        const getColorMsg = () => {
            const entry = state.colorByDict[state.colorBy];
            if (!entry || !legend.value) {
                return;
            }
            (legend.value as any).setColorMsg(entry.desc);
        };

        /** Set the hint text for the current mode ("single" plot vs. matrix). */
        const switchViewMsg = () => {
            state.attnMsg = state.mode === "single"
                ? "click a point to explore its attention"
                : "click a plot to zoom in";
        };

        // Entering attention view discards the search-box text.
        watch(() => state.view, () => {
            if (state.view === "attn") {
                state.searchToken = "";
            }
        });

        watch(() => state.mode, () => {
            switchViewMsg();
        });

        // On model switch: clear highlighted tokens and search, stop any
        // attention loading, then refresh model-dependent options and zoom.
        watch(() => state.modelType, () => {
            store.commit("setHighlightedTokenIndices", []);
            if (state.searchToken.length > 0) {
                state.searchToken = "";
                store.commit("setView", "none");
            }
            if (state.attnLoading) {
                store.commit("updateAttentionLoading", false);
            }

            switchPlaceholder();
            switchColorOptions();
            onClickReset();
        });

        // renderState turning falsy triggers the post-render refresh.
        watch(() => state.renderState, () => {
            if (!state.renderState) {
                showImages();
                switchDataMsg();
                getColorMsg();
            }
        });

        // change color msg
        watch(() => state.colorBy, () => {
            getColorMsg();
        });

        // clear search from matrix view if user unselects; then reset the request flag
        watch(() => state.clearSelection, () => {
            if (state.clearSelection) {
                clearSelection();
                store.commit("setClearSelection", false);
            }
        });

        return {
            ...toRefs(state),
            legend,
            dataGrid,
            matrixView,
            onClickReset,
            resetToMatrix,
            clearSearch,
            onSearch,
            logViewport,
            zoomToPlot,
            moveToPlot,
            switchPlaceholder,
            // Intentionally listed after the spread: the template's
            // `clearSelection` is this function, not the store-flag ref in `state`.
            clearSelection,
            getColorMsg
        };
    }
});
