import {
  ColumnDef,
  ExpandedState,
  flexRender,
  getCoreRowModel,
  getExpandedRowModel,
  OnChangeFn,
  RowData,
  RowSelectionState,
  RowSelectionTableState,
  useReactTable,
} from '@tanstack/react-table';
import {
  Td,
  Th,
  Tr,
  Tbody,
  Thead,
  Checkbox,
  Flex,
  Box,
  Icon,
} from '@chakra-ui/react';
import { Table, TableContainer } from '@chakra-ui/react';
import React, { useEffect, useMemo, useRef, useState } from 'react';
import { useVirtualizer } from '@tanstack/react-virtual';
import TableLoader from '../Loaders/TableLoader';
import { BsChevronDown, BsChevronRight } from 'react-icons/bs';

declare module '@tanstack/react-table' {
  /**
   * Per-column layout hints. In this file they are read by RowGroupedTable's
   * header cells and sub-row cells.
   */
  interface ColumnMeta<TData extends RowData, TValue> {
    /** Horizontal placement of the content; used as the cell's `justifyContent`. */
    align?: 'left' | 'center' | 'right';
    /** When true, a 1px divider is drawn on the right edge of the column. */
    border?: boolean;
    /** When true, the column is rendered with `position: sticky`. */
    sticky?: boolean;
    /** `left` offset for the column; only applied when `sticky` is true. */
    left?: number;
  }
}

/**
 * Props for RowGroupedTable. Each top-level item of `data` is rendered as a
 * collapsible group row, and its `subRows` are rendered as regular data rows.
 */
interface RowGroupedTableProps<
  TData extends { id: number; subRows: TData[] },
  TValue
> {
  // --- required ---------------------------------------------------------

  /** Top-level group rows; each carries its own `subRows`. */
  data: TData[];
  /** TanStack column definitions for the header and the sub-row cells. */
  columns: ColumnDef<TData, TValue>[];
  /** Controlled expansion state handed to `useReactTable`. */
  expanded: ExpandedState;
  /** Called by TanStack whenever the expansion state changes. */
  setExpanded: OnChangeFn<ExpandedState>;
  /** Forwarded to `useReactTable`'s `enableMultiRowSelection` option. */
  enableMultiRowSelection: boolean;
  /**
   * Receives the currently selected sub-row objects whenever the selection
   * changes. Group (top-level) rows are not included.
   */
  onRowSelectionChange: (rows: TData[]) => void;

  // --- optional ---------------------------------------------------------

  /** When true, a `TableLoader` placeholder is shown instead of the table. Defaults to `false`. */
  isLoading?: boolean;
  /** Height in px of each sub-row cell. Defaults to `50`. */
  rowHeight?: number;
  /** Maximum height of the scroll container. Defaults to `'100%'`. */
  maxHeight?: string | number;
  /** Minimum height of the scroll container. Defaults to `'100%'`. */
  minHeight?: string | number;
  /** Rendered next to the chevron in every group row, receiving that group's original data. */
  ExpanderComponent?: React.ComponentType<{ data: TData }>;
  /** When this becomes true, the table's expand-all action is triggered. */
  isAllExpanded?: boolean;
  /**
   * When true, group-row checkboxes stay enabled even while the group is
   * collapsed; otherwise they are disabled until the group is expanded.
   */
  isAllowCheckOnCollapsed?: boolean;
}

/**
 * Virtualized, grouped table.
 *
 * Each top-level row (a row without `parentId`) is rendered as a group header
 * containing a selection checkbox, an expand/collapse chevron and an optional
 * `ExpanderComponent`. Sub-rows of expanded groups are rendered as regular
 * data rows using the column definitions.
 *
 * Expansion is controlled by the caller via `expanded` / `setExpanded`.
 * Selection is kept internally and reported through `onRowSelectionChange`.
 */
export default function RowGroupedTable<
  TData extends { id: number; subRows: TData[] },
  TValue
>({
  columns,
  data,
  enableMultiRowSelection,
  isLoading = false,
  rowHeight = 50,
  maxHeight = '100%',
  minHeight = '100%',
  expanded,
  setExpanded,
  onRowSelectionChange,
  isAllExpanded,
  ExpanderComponent,
  isAllowCheckOnCollapsed,
}: RowGroupedTableProps<TData, TValue>) {
  const tableContainerRef = useRef<HTMLDivElement>(null);
  const { rowSelection, setRowSelection } = useRowSelection<TData>(
    onRowSelectionChange,
    data
  );
  const table = useReactTable({
    columns,
    data,
    enableMultiRowSelection,
    getCoreRowModel: getCoreRowModel(),
    getRowId: (originalRow) => String(originalRow.id),
    state: {
      rowSelection,
      expanded,
    },
    getSubRows: (row) => row.subRows,
    onRowSelectionChange: setRowSelection,
    getExpandedRowModel: getExpandedRowModel(),
    onExpandedChange: setExpanded,
  });
  // The expanded row model holds exactly the rows that are rendered: every
  // group row plus the sub-rows of the groups that are currently expanded.
  // TanStack memoizes it internally, so reading it each render is cheap.
  const tableRows = table.getRowModel();
  const rowVirtualizer = useVirtualizer({
    // Must equal the number of rendered rows. Counting every sub-row
    // regardless of expansion makes virtual indexes run past the end of
    // `tableRows.rows` whenever a group is collapsed.
    count: tableRows.rows.length,
    estimateSize: () => 200,
    getScrollElement: () => tableContainerRef.current,
    overscan: 20,
    measureElement:
      typeof window !== 'undefined' &&
      navigator.userAgent.indexOf('Firefox') === -1
        ? (element) => element?.getBoundingClientRect().height
        : undefined,
  });

  useEffect(() => {
    if (isAllExpanded) {
      // Pass `true` explicitly: the argument-less call *toggles*, which would
      // collapse everything if all rows were already expanded.
      table.toggleAllRowsExpanded(true);
    }
  }, [isAllExpanded, table]);

  return (
    <>
      {isLoading ? (
        <TableLoader
          header={
            columns.map(
              (column) => column?.header ?? ''
            ) as Array<React.ReactNode>
          }
        />
      ) : (
        <TableContainer
          boxSizing="border-box"
          border="1px solid"
          borderColor="default.white.400"
          sx={{
            overflow: 'auto',
            position: 'relative',
            height: '100%',
            borderRadius: 'md',
            borderBottomRadius: '0px',
          }}
          maxH={maxHeight}
          minH={minHeight}
          ref={tableContainerRef}
        >
          <Table
            sx={{
              display: 'grid',
              width: '100%',
            }}
          >
            <Thead
              sx={{
                position: 'sticky',
                zIndex: 30,
                top: 0,
                backgroundColor: 'default.white.600',
              }}
            >
              {table.getHeaderGroups().map((headerGroup) => (
                <Tr
                  key={headerGroup.id}
                  sx={{
                    display: 'flex',
                    width: '100%',
                    position: 'relative',
                  }}
                >
                  {headerGroup.headers.map((header) => {
                    const meta = header.column.columnDef.meta;
                    return (
                      <Th
                        key={header?.id}
                        sx={{
                          display: 'flex',
                          alignItems: 'center',
                          width: header.getSize(),
                          position: meta?.sticky ? 'sticky' : 'static',
                          left: meta?.sticky ? meta?.left : 'unset',
                          justifyContent: meta?.align || 'flex-start',
                          '&::after': {
                            content: '""',
                            position: 'absolute',
                            top: 0,
                            right: 0,
                            background: meta?.border
                              ? '#E7E9ED'
                              : 'transparent',
                            width: '1px',
                            height: '100%',
                            zIndex: 20,
                          },
                        }}
                        bg="default.white.600"
                        color="default.gray.600"
                        p="10px 20px"
                      >
                        {flexRender(
                          header.column.columnDef.header,
                          header.getContext()
                        )}
                      </Th>
                    );
                  })}
                </Tr>
              ))}
            </Thead>

            <Tbody
              boxSizing="border-box"
              background="default.white.100"
              borderBottom="1px solid"
              borderColor="default.white.400"
              sx={{
                display: 'grid',
                position: 'relative',
                height: `${rowVirtualizer.getTotalSize()}px`,
              }}
            >
              {rowVirtualizer.getVirtualItems().map((virtualRow) => {
                const row = tableRows.rows[virtualRow.index];
                // Defensive: the virtualizer may briefly hold a stale count
                // while the row model shrinks (e.g. a group collapsing).
                if (!row) {
                  return null;
                }

                // Top-level rows have no parent and render as group headers.
                if (!row.parentId) {
                  return (
                    <Tr
                      data-index={virtualRow.index}
                      ref={(node) => rowVirtualizer.measureElement(node)}
                      key={row.id}
                      borderBottom="1px solid #EEEEEE"
                      sx={{
                        display: 'flex',
                        alignItems: 'center',
                        position: 'absolute',
                        top: 0,
                        left: 0,
                        width: '100%',
                        transform: `translateY(${virtualRow.start}px)`,
                        willChange: 'transform',
                      }}
                      bg="#FAF5FF"
                      _hover={{
                        bg: '#f8f9fa',
                      }}
                    >
                      {/* Only the first visible cell is used; it spans all columns. */}
                      {row
                        .getVisibleCells()
                        .slice(0, 1)
                        .map((cell) => (
                          <Td
                            key={cell.id}
                            sx={{
                              width: '100%',
                              zIndex: 1,
                            }}
                            h={'36px'}
                            p="0px 20px"
                            borderBottom={0}
                            colSpan={columns.length}
                          >
                            <Flex
                              alignItems="center"
                              gap="25px"
                              sx={{
                                color: 'primary.600',
                                fontSize: '12px',
                                fontWeight: 700,
                                textTransform: 'uppercase',
                                cursor: 'pointer',
                                zIndex: 15,
                                w: '100%',
                              }}
                            >
                              <Checkbox
                                isDisabled={
                                  isAllowCheckOnCollapsed
                                    ? false
                                    : !row.getIsExpanded()
                                }
                                isIndeterminate={row.getIsSomeSelected()}
                                isChecked={row.getIsAllSubRowsSelected()}
                                onChange={row.getToggleSelectedHandler()}
                                aria-label="Select row"
                                id={row.id}
                                name={row.id}
                              />
                              <Flex
                                gap="10px"
                                onClick={row.getToggleExpandedHandler()}
                                py="8px"
                              >
                                <Box width="16px">
                                  {row.getIsExpanded() ? (
                                    <Icon as={BsChevronDown} />
                                  ) : (
                                    <Icon as={BsChevronRight} />
                                  )}
                                </Box>
                                {ExpanderComponent && (
                                  <ExpanderComponent data={row.original} />
                                )}
                              </Flex>
                            </Flex>
                          </Td>
                        ))}
                    </Tr>
                  );
                }

                // Sub-rows render one cell per visible column.
                return (
                  <Tr
                    data-index={virtualRow.index}
                    ref={(node) => rowVirtualizer.measureElement(node)}
                    key={row.id}
                    borderBottom="1px solid #EEEEEE"
                    sx={{
                      display: 'flex',
                      alignItems: 'center',
                      position: 'absolute',
                      top: 0,
                      left: 0,
                      width: '100%',
                      transform: `translateY(${virtualRow.start}px)`,
                      willChange: 'transform',
                    }}
                    _hover={{
                      '& td': {
                        bg: '#f8f9fa',
                      },
                    }}
                  >
                    {row.getVisibleCells().map((cell) => {
                      const meta = cell.column.columnDef.meta;
                      return (
                        <Td
                          key={cell.id}
                          sx={{
                            display: 'flex',
                            alignItems: 'center',
                            width: cell.column.getSize(),
                            zIndex: meta?.sticky ? 5 : 1,
                            position: meta?.sticky ? 'sticky' : 'static',
                            left: meta?.sticky ? meta?.left : 'unset',
                            justifyContent: meta?.align || 'flex-start',
                            '&::after': {
                              content: '""',
                              position: 'absolute',
                              top: 0,
                              right: 0,
                              background: meta?.border
                                ? '#E7E9ED'
                                : 'transparent',
                              width: '1px',
                              height: '100%',
                              zIndex: 20,
                            },
                          }}
                          h={`${rowHeight}px`}
                          p={'10px 20px'}
                          borderBottom={0}
                        >
                          {flexRender(
                            cell.column.columnDef.cell,
                            cell.getContext()
                          )}
                        </Td>
                      );
                    })}
                  </Tr>
                );
              })}
            </Tbody>
          </Table>
        </TableContainer>
      )}
    </>
  );
}

/**
 * Owns the TanStack row-selection state and reports the selected sub-rows to
 * `callbackFn` whenever the selection or the source rows change.
 *
 * Only items found in some top-level row's `subRows` are reported. Selected
 * top-level (group) ids have no entry in the lookup and are dropped, as are
 * ids whose selection value is `false`.
 *
 * @param callbackFn - receives the selected sub-row objects. It may be an
 *   inline function: the latest one is kept in a ref, so a new identity on
 *   every render does not retrigger the report (which would loop forever if
 *   the callback sets parent state).
 * @param rows - top-level rows whose `subRows` are selectable.
 * @returns the selection state and its setter, suitable for `useReactTable`.
 */
function useRowSelection<T extends { id: number; subRows: T[] }>(
  callbackFn: (rows: T[]) => void,
  rows: T[]
) {
  // id -> sub-row lookup; `?? []` tolerates rows that arrive without subRows.
  const mappedIdRows = useMemo(() => {
    const byId = new Map<number, T>();
    for (const row of rows) {
      for (const subRow of row.subRows ?? []) {
        byId.set(subRow.id, subRow);
      }
    }
    return byId;
  }, [rows]);
  const [rowSelection, setRowSelection] = useState<RowSelectionState>({});

  // Latest-callback ref: keeps the reporting effect independent of the
  // callback's identity. Declared before the reporting effect so it is
  // updated first within the same commit.
  const callbackRef = useRef(callbackFn);
  useEffect(() => {
    callbackRef.current = callbackFn;
  }, [callbackFn]);

  useEffect(() => {
    const selectedRows = Object.entries(rowSelection)
      .filter(([, isSelected]) => isSelected)
      .map(([rowId]) => mappedIdRows.get(Number(rowId)))
      .filter((row): row is T => row !== undefined);
    callbackRef.current(selectedRows);
  }, [rowSelection, mappedIdRows]);

  return { rowSelection, setRowSelection };
}

// Re-export the TanStack state types used by this module's props/state.
export type {
  ExpandedState,
  RowSelectionState,
  RowSelectionTableState,
};
