import {
  CellIdToResultMapType,
  ColumnIdToColumnHeaderMapType,
  ColumnIdToToolMapType,
  ColumnIdYToCellMapType,
} from "source/redux/matrix";
import _ from "lodash";
import {
  DEFAULT_TOOL_PARAM_OUTPUT_TYPE,
  SUMMARY_COL_NAME,
  SUMMARY_COL_NUM,
  SUMMARY_PROMPT,
} from "source/constants";
import { RowCellState } from "source/api/matrix/types";
import {
  DEFAULT_ROW_RUN_LIMIT,
  RETRIEVE_COL_MAGIC_VALUE,
} from "../../components/matrix/tables/config";
import {
  AnswerToolType,
  ReduxTool,
  ReportToolDependency,
  ReportToolParamType,
  ReportToolType,
} from "../../components/matrix/types/tools.types";
import { ReportTableRow } from "../../components/matrix/types/cells.types";
import { v4 as uuidv4 } from "uuid";
import { ModelType } from "source/constants/llms";

/**
 * Collects the row indices that count as "already run" for a tab.
 *
 * A row is considered run when either:
 *  - its entry in `rowStateMap` is `RowCellState.ANSWERED`, or
 *  - any cell in that row (across every column of `cellMap[tabId]`) has a
 *    result in `cellResults` whose `loading` flag is truthy.
 *
 * Rows present in `excludedCoordinateSet` are omitted from the result.
 *
 * @param rowStateMap - per-row state keyed by row index (nullish is treated as empty)
 * @param cellMap - cells per tab; `cellMap[tabId]` maps each column to a
 *   row-index -> cell object
 * @param cellResults - results keyed by cell id; only `loading` is read here
 * @param tabId - tab whose cells are scanned for loading state
 * @param excludedCoordinateSet - row indices to leave out of the result
 * @returns de-duplicated row indices in first-seen order (answered rows
 *   first, then loading rows)
 */
export const getRowsAlreadyRan = (
  rowStateMap: { [key: number]: RowCellState },
  cellMap: ColumnIdYToCellMapType,
  cellResults: CellIdToResultMapType,
  tabId: string,
  excludedCoordinateSet: Set<number>
) => {
  // A Set de-duplicates while keeping first-insertion order.
  const rows = new Set<number>();

  // Rows that have already been answered on existing columns.
  Object.entries(rowStateMap ?? {}).forEach(([row, state]) => {
    if (state === RowCellState.ANSWERED) rows.add(parseInt(row, 10));
  });

  // Rows with at least one cell still in a loading state.
  Object.values(cellMap?.[tabId] ?? {}).forEach((col) =>
    Object.entries(col).forEach(([row, cell]) => {
      if (cellResults?.[cell.id]?.loading) rows.add(parseInt(row, 10));
    })
  );

  return _.difference(Array.from(rows), Array.from(excludedCoordinateSet));
};

/**
 * Builds the Redux tool for the precomputed summary column.
 *
 * The column uses the "precomputed_summary" tool and is placed at
 * SUMMARY_COL_NUM. It covers rows 0..DEFAULT_ROW_RUN_LIMIT-1 and depends on
 * the retrieve column (RETRIEVE_COL_MAGIC_VALUE). Its model and output type
 * are hard-coded.
 *
 * @param requestId - request id to attach; when omitted, generateFindTool
 *   assigns a fresh uuid.
 */
export const getReduxSummaryPrecomputedTool = (
  requestId?: string
): ReduxTool => {
  const summaryCoordinates = Array.from(Array(DEFAULT_ROW_RUN_LIMIT).keys());

  const summaryTool: Partial<ReduxTool> = {
    name: SUMMARY_COL_NAME,
    tool: "precomputed_summary",
    request_id: requestId,
    x: SUMMARY_COL_NUM,
    prompt: SUMMARY_PROMPT.prompt,
    tool_params: { output_type: "str", model: "gpt-4t-hebbia" },
    coordinates: summaryCoordinates,
    dependency_column_ids: [RETRIEVE_COL_MAGIC_VALUE],
  };

  return generateFindTool({ tool: summaryTool });
};

/**
 * Builds a "retrieve" tool from defaults; any field present in `tool`
 * overrides the corresponding default.
 *
 * Defaults: empty request id, x = 0, dependencies [-1], no dependency
 * columns, and a freshly generated `static_column_id`.
 */
export const generateRetrieveTool = (tool: Partial<ReduxTool>): ReduxTool => {
  const defaults: ReduxTool = {
    tool: "retrieve",
    request_id: "",
    x: 0,
    dependencies: [-1],
    static_column_id: uuidv4(),
    dependency_column_ids: [],
  };
  return { ...defaults, ...tool };
};

/**
 * Collects the unique document IDs referenced by "retrieve" tools.
 *
 * @param tools - tools to scan; tools other than "retrieve", or without
 *   `tool_params.doc_ids`, are ignored
 * @returns unique doc IDs in first-seen order, or `undefined` when `tools`
 *   is nullish
 */
export const getDocIdsFromTools = (tools: ReduxTool[]) => {
  if (tools == null) return;

  const uniqueDocIds = new Set<string>();
  tools.forEach(({ tool, tool_params }) => {
    if (tool !== "retrieve" || !tool_params?.doc_ids) return;
    tool_params.doc_ids.forEach((docId) => uniqueDocIds.add(docId));
  });
  return Array.from(uniqueDocIds);
};

/**
 * Returns the x position for a new column: one past the largest defined
 * `x` among the tools, never less than 1 (negative or missing x values
 * are ignored because the running maximum starts at 0).
 */
export const getNextX = (tools?: ColumnIdToToolMapType): ReduxTool["x"] => {
  let highestX = 0;
  for (const { x } of Object.values(tools ?? {})) {
    if (x !== undefined && x > highestX) highestX = x;
  }
  return highestX + 1;
};

/**
 * Returns the first tool in the map whose type is "retrieve", or
 * `undefined` if there is none (or the map is missing).
 */
export const getRetrieveColumnFromToolMap = (
  toolMap?: ColumnIdToToolMapType
): ReduxTool | undefined => {
  for (const candidate of Object.values(toolMap ?? {})) {
    if (candidate.tool === "retrieve") return candidate;
  }
  return undefined;
};

/**
 * Returns the `static_column_id` of the retrieve tool in the map, or
 * `undefined` when the map has no retrieve tool.
 */
export const getRetrieveColumnIdFromToolMap = (
  toolMap?: ColumnIdToToolMapType
): ReduxTool["static_column_id"] | undefined => {
  const retrieveTool = getRetrieveColumnFromToolMap(toolMap);
  if (!retrieveTool) return undefined;
  return retrieveTool.static_column_id;
};

/**
 * Builds a complete ReduxTool from a partial definition. This is used for
 * answer columns such as find and summary.
 *
 * - `tool` defaults to "find".
 * - `request_id` and `static_column_id` get a fresh uuid when missing or empty.
 * - `tool_params` starts from `defaultToolParams`. Only `model`, `output_type`,
 *   `tool_spec` and `experiment_config` are taken from `tool.tool_params`;
 *   other keys of `tool.tool_params` are not copied. `output_type` falls back
 *   to DEFAULT_TOOL_PARAM_OUTPUT_TYPE.
 * - `dependencies` is always [0] and `drop_cells` is always false.
 * - `dependency_column_ids` keeps the caller's non-empty list. Otherwise,
 *   answer tools depend on the retrieve column (RETRIEVE_COL_MAGIC_VALUE) and
 *   all other tools get [].
 *
 * @param tool - partial tool definition; its other fields are copied as-is
 * @param defaultToolParams - fallback tool params
 */
export const generateFindTool = ({
  tool,
  defaultToolParams,
}: {
  tool: Partial<ReduxTool>;
  defaultToolParams?: ReportToolParamType;
}): ReduxTool => {
  // Resolve the tool type once. This keeps the dependency default below
  // consistent with the type actually written into the result; previously an
  // omitted type became "find" but was not treated as an answer tool.
  const toolType = tool.tool ?? "find";

  return {
    ...tool,
    tool: toolType,
    request_id: tool.request_id || uuidv4(),
    x: tool.x,
    static_column_id: tool.static_column_id || uuidv4(),
    tool_params: {
      ...defaultToolParams,
      model: tool.tool_params?.model ?? defaultToolParams?.model,
      output_type:
        tool.tool_params?.output_type ??
        defaultToolParams?.output_type ??
        DEFAULT_TOOL_PARAM_OUTPUT_TYPE,
      tool_spec: tool.tool_params?.tool_spec ?? defaultToolParams?.tool_spec,
      experiment_config:
        tool.tool_params?.experiment_config ??
        defaultToolParams?.experiment_config,
    },
    dependencies: [0],
    dependency_column_ids: tool.dependency_column_ids?.length
      ? tool.dependency_column_ids
      : isAnswerToolType(toolType)
        ? [RETRIEVE_COL_MAGIC_VALUE]
        : [],
    drop_cells: false,
  };
};

/**
 * Indexes `items` by their `static_column_id`.
 *
 * @param items - items to index. Items without a truthy `static_column_id`
 *   are skipped, and a nullish `items` yields an empty object.
 * @param key - when given (and non-empty), the stored value is `item[key]`;
 *   otherwise the whole item is stored.
 * @returns a map from static_column_id to the stored value. When several
 *   items share an id, the last one wins.
 */
export const generateStaticColumnIdDataMapping = <
  T extends { static_column_id?: string },
  K,
>(
  items: T[],
  key?: string
): { [static_column_id: string]: K } => {
  // Write into a single object. Re-spreading the accumulator on every item
  // copied all earlier keys each time, which made the build O(n^2).
  const mapping: { [static_column_id: string]: K } = {};
  items?.forEach((item) => {
    if (!item.static_column_id) return;
    mapping[item.static_column_id] = (key ? item[key] : item) as K;
  });
  return mapping;
};

/**
 * Converts a ReportToolDependency into a ReduxTool.
 *
 * All fields of the dependency are copied, then overridden by `reduxTool`,
 * with two exceptions. `static_column_id` is always taken from the dependency
 * (default ""). `dependency_column_ids` is always taken from the dependency's
 * `dependency_columns` (default []).
 */
export const transformReportToolDependencyToReduxTool = (
  reportToolDependency: ReportToolDependency,
  reduxTool?: Partial<ReduxTool>
): ReduxTool => {
  const {
    static_column_id: dependencyColumnId,
    dependency_columns: dependencyColumns,
  } = reportToolDependency;

  return {
    ...reportToolDependency,
    ...reduxTool,
    static_column_id: dependencyColumnId ?? "",
    dependency_column_ids: dependencyColumns ?? [],
  };
};

/**
 * Returns the `static_column_id` of the summary tool in the map, or
 * `undefined` when there is no summary tool.
 */
export const getSummaryColumnIdFromToolMap = (
  toolMap?: ColumnIdToToolMapType
): ReduxTool["static_column_id"] | undefined => {
  const summaryTool = getSummaryColumnFromToolMap(toolMap);
  return summaryTool ? summaryTool.static_column_id : undefined;
};

/**
 * Returns the first tool in the map for which `isSummaryTool` holds, or
 * `undefined` when there is none (or the map is missing).
 */
export const getSummaryColumnFromToolMap = (
  toolMap?: ColumnIdToToolMapType
): ReduxTool | undefined => {
  for (const candidate of Object.values(toolMap ?? {})) {
    if (isSummaryTool(candidate.tool)) return candidate;
  }
  return undefined;
};

/**
 * Returns the `static_column_id`s of every answer tool in the map, in map
 * iteration order. Answer tools are those accepted by `isAnswerToolType`
 * with its default arguments, so summary tools are included.
 */
export const getFindColumnIdsFromTools = (
  toolMap?: ColumnIdToToolMapType
): ReduxTool["static_column_id"][] =>
  Object.values(toolMap ?? {})
    .filter((entry) => isAnswerToolType(entry.tool))
    .map((entry) => entry.static_column_id);

/**
 * Reads the tool type stored on a table row's cell for `columnId`, or
 * `undefined` when the row has no entry for that column.
 */
export const getToolTypeFromColumnAndData = (
  row: ReportTableRow,
  columnId: ReduxTool["static_column_id"]
) => {
  const columnCell = row[columnId];
  return columnCell?.tool;
};

/**
 * Looks up the tool type for `columnId` in the tool map, or `undefined`
 * when the column is not in the map.
 */
export const getToolTypeFromColumnAndTools = (
  tools: ColumnIdToToolMapType,
  columnId: ReduxTool["static_column_id"]
) => {
  const columnTool = tools[columnId];
  return columnTool?.tool;
};

/**
 * Type guard for answer tools: "find", "precomputed_fast_find" and, unless
 * `includeSummary` is false, "precomputed_summary".
 *
 * NOTE(review): with `includeSummary = false` this returns false for
 * "precomputed_summary". That value is presumably still an AnswerToolType,
 * so the false branch of the guard narrows more than is strictly true.
 * Confirm callers do not rely on that narrowing.
 */
export const isAnswerToolType = (
  type?: ReportToolType,
  includeSummary = true
): type is AnswerToolType => {
  switch (type) {
    case "find":
    case "precomputed_fast_find":
      return true;
    case "precomputed_summary":
      return !!includeSummary;
    default:
      return false;
  }
};

/** True only for the "precomputed_summary" tool type. */
export const isSummaryTool = (type?: ReportToolType): boolean => {
  return type === "precomputed_summary";
};

/**
 * Maps a model to its fallback. Currently only "gpt-4-1106-preview" is
 * remapped (to "gpt-4t-hebbia"); every other value, including `undefined`,
 * is returned unchanged.
 */
export const getFallbackModel = (model?: ModelType) => {
  if (model === "gpt-4-1106-preview") return "gpt-4t-hebbia";
  return model;
};

/**
 * Counts the tools in the map accepted by `isAnswerToolType` with its
 * default arguments (summary tools included).
 */
export const countAnswerTools = (toolMap: ColumnIdToToolMapType) =>
  Object.values(toolMap).reduce(
    (total, { tool }) => (isAnswerToolType(tool) ? total + 1 : total),
    0
  );

/**
 * Replaces `toolMap[columnId]` in place with a shallow merge of the
 * existing tool and `update`; fields in `update` win.
 *
 * @throws Error when `columnId` has no (truthy) entry in `toolMap`
 */
export const updateToolInToolMap = (
  toolMap: ColumnIdToToolMapType,
  columnId: string,
  update: Partial<ReduxTool>
): void => {
  const existingTool = toolMap[columnId];
  if (!existingTool) {
    throw new Error(
      `Column with ID ${columnId} does not exist in the tool map.`
    );
  }

  toolMap[columnId] = { ...existingTool, ...update } as ReduxTool;
};

/**
 * Stores `tool` under `columnId` in `toolMap` (mutating it), overwriting
 * any tool already stored there.
 */
export const addToolToToolMap = (
  toolMap: ColumnIdToToolMapType,
  columnId: string,
  tool: ReduxTool
): void => {
  Object.assign(toolMap, { [columnId]: tool });
};
