import Box from '@mui/material/Box';
import IconButton from '@mui/material/IconButton';
import composeClasses from '@mui/utils/composeClasses';
import type { GridRenderCellParams } from '@mui/x-data-grid-pro';
import {
  getDataGridUtilityClass,
  gridFilteredDescendantCountLookupSelector,
  useGridApiContext,
  useGridRootProps,
  useGridSelector,
} from '@mui/x-data-grid-pro';
import { isNavigationKey } from '@mui/x-data-grid-pro/internals';
import type { DataGridProProcessedProps } from '@mui/x-data-grid-pro/models/dataGridProProps';
import * as React from 'react';
import { Item } from './constants';

/** State passed to {@link useUtilityClasses} to resolve slot class names. */
interface OwnerState {
  /** Class overrides taken from the grid's root props (may be undefined). */
  classes: DataGridProProcessedProps['classes'];
}

/**
 * Resolves the CSS class names for this cell's two slots.
 *
 * The slot names `treeDataGroupingCell` and `treeDataGroupingCellToggle` are
 * mapped through `getDataGridUtilityClass` and combined by `composeClasses`
 * with the `classes` overrides from the owner state.
 *
 * @param ownerState - Holds the class overrides to apply.
 * @returns An object whose `root` and `toggle` entries are used as
 *   `className` values by the cell.
 */
const useUtilityClasses = ({ classes }: OwnerState) =>
  composeClasses(
    {
      root: ['treeDataGroupingCell'],
      toggle: ['treeDataGroupingCellToggle'],
    },
    getDataGridUtilityClass,
    classes
  );

/**
 * Props for {@link CustomGridTreeDataGroupingCell}.
 *
 * These are the grid's render-cell params (typed with `string` and `Item`)
 * plus options specific to this grouping cell.
 */
interface CustomGridTreeDataGroupingCellProps
  extends GridRenderCellParams<string, Item> {
  /** Label content rendered after the expand/collapse toggle. */
  children: React.ReactNode;
  /**
   * When truthy, the ` (N)` filtered-descendant-count suffix is not rendered
   * after the label. Omitting it behaves like `false`.
   */
  hideDescendantCount?: boolean;
}

/**
 * Grouping-column cell for tree data.
 *
 * Renders, inside a box with a left margin of `rowNode.depth * 2` (MUI `sx`
 * `ml` units):
 * - an expand/collapse toggle button, shown only when the row still has at
 *   least one descendant after filtering;
 * - the caller-supplied `children`, followed by the filtered descendant count
 *   in parentheses unless `hideDescendantCount` is set.
 *
 * @param props - Grid render-cell params plus this cell's options.
 * @returns The rendered cell content.
 */
export function CustomGridTreeDataGroupingCell(
  props: CustomGridTreeDataGroupingCellProps
): JSX.Element {
  const { id, field, rowNode, children, hideDescendantCount } = props;

  const rootProps = useGridRootProps();
  const apiRef = useGridApiContext();
  const classes = useUtilityClasses({ classes: rootProps.classes });
  const descendantCountById = useGridSelector(
    apiRef,
    gridFilteredDescendantCountLookupSelector
  );

  // Descendants still visible after filtering; rows absent from the lookup
  // count as having none.
  const visibleDescendants = descendantCountById[rowNode.id] ?? 0;
  const hasVisibleDescendants = visibleDescendants > 0;
  const isExpanded = rowNode.childrenExpanded ?? false;

  const ToggleIcon = isExpanded
    ? rootProps.components.TreeDataCollapseIcon
    : rootProps.components.TreeDataExpandIcon;

  const countSuffix =
    !(hideDescendantCount ?? false) && hasVisibleDescendants
      ? ` (${visibleDescendants})`
      : '';

  const onToggleKeyDown = (e: React.KeyboardEvent<HTMLButtonElement>) => {
    // Stop Space from bubbling to ancestor handlers (presumably the grid's
    // own cell keyboard handling).
    if (e.key === ' ') {
      e.stopPropagation();
    }
    // Re-publish non-Shift navigation keys as a grid-level
    // 'cellNavigationKeyDown' event carrying this cell's params.
    if (isNavigationKey(e.key) && !e.shiftKey) {
      apiRef.current.publishEvent('cellNavigationKeyDown', props, e);
    }
  };

  const onToggleClick = (e: React.MouseEvent<HTMLButtonElement>) => {
    apiRef.current.setRowChildrenExpansion(id, !isExpanded);
    // Move grid focus back onto this cell after toggling.
    apiRef.current.setCellFocus(id, field);
    e.stopPropagation(); // TODO remove event.stopPropagation
  };

  return (
    <Box className={classes.root} sx={{ ml: rowNode.depth * 2 }}>
      <div className={classes.toggle}>
        {hasVisibleDescendants && (
          <IconButton
            size="small"
            onClick={onToggleClick}
            onKeyDown={onToggleKeyDown}
            tabIndex={-1}
            aria-label={apiRef.current.getLocaleText(
              isExpanded ? 'treeDataCollapse' : 'treeDataExpand'
            )}
          >
            <ToggleIcon fontSize="inherit" />
          </IconButton>
        )}
      </div>
      <span>
        {children}
        {countSuffix}
      </span>
    </Box>
  );
}

// Also exposed as the module's default export; this is the same function as
// the named `CustomGridTreeDataGroupingCell` export above.
export default CustomGridTreeDataGroupingCell;
