import React, { useEffect, useState, useMemo } from 'react';
import { Button, Dialog, DialogTitle, DialogActions, DialogContent, Box, Typography, Alert } from '@mui/material';
import { DataGrid, gridClasses } from '@mui/x-data-grid';
import { useAuth0 } from '@auth0/auth0-react';
import { datadogRum } from '@datadog/browser-rum';
import { chunkArray } from './api';
import { identifyOverlappingResistanceGenes } from '../AMRPredictions/resistanceGeneUtils';

// Detection-method names that this dialog treats as partial gene coverage; any row whose
// method list contains one of them is shown with "Partial" coverage.
const partialCoverageMethods = ['PARTIALX', 'PARTIALP', 'PARTIAL_CONTIG_ENDX', 'PARTIAL_CONTIG_ENDP', 'INTERNAL_STOP'];

/**
 * Reports whether a (possibly merged) resistance gene includes a partial-coverage hit.
 *
 * @param {{ method?: string | null }} gene - `method` is a single detection method or,
 *   for merged overlapping hits, a comma-separated list of methods.
 * @returns {boolean} true when at least one listed method is in `partialCoverageMethods`;
 *   false when `method` is missing or empty.
 */
const geneHasPartialCoverage = (gene) =>
  // If overlapping hits were merged and multiple methods are used, we consider it partial
  // coverage if at least one of the methods is a partial coverage method. Splitting on ','
  // and trimming accepts both "A, B" (as produced when merging) and "A,B".
  String(gene?.method ?? '')
    .split(',')
    .some((method) => partialCoverageMethods.includes(method.trim()));

const ResistanceGenesDialog = ({ identifiedSpecies, open, onClose, setAllAvailableResistanceGenesLoading }) => {
  const { getAccessTokenSilently, isAuthenticated } = useAuth0();
  const [allAvailableResistanceGenes, setAllAvailableResistanceGenes] = useState(null);

  useEffect(() => {
    const fetchAllAvailableResistanceGenes = async () => {
      if (!identifiedSpecies?.length || !isAuthenticated) return;
      try {
        const accessToken = await getAccessTokenSilently();
        const headers = { Authorization: `Bearer ${accessToken}` };
  
        const speciesChunks = chunkArray(identifiedSpecies, 5);
        let allResults = [];

        setAllAvailableResistanceGenesLoading(true);
        for (const chunk of speciesChunks) {
          const results = await Promise.all(
            chunk.map(async ({ id, attributes }) => {
              const response = await fetch(
                `${process.env.REACT_APP_KEYNOME_API_URL_BASE}/v1/identified_species/${id}/resistance_genes`,
                { method: 'GET', headers }
              );
              const data = await response.json();

              const filteredGenes = data?.data.filter(
                gene => gene.attributes?.element_type?.toUpperCase() === 'AMR'
              );

              const uniqueGenes = filteredGenes.reduce((acc, gene) => {
                const { id, attributes } = gene;
                // TODO: determine what the cases are that the same gene symbol would appear twice
                // same gene in multiple locations or multiple contigs? Same gene with different 
                // drug classes?
                // We might aggregate data (for example join with comma) if collisions have data we
                // might want to retain and display.
                acc[attributes.gene_symbol] = {
                  id,
                  contig_id: attributes.contig_id,
                  strand: attributes.strand,
                  start: attributes.start,
                  stop: attributes.stop,
                  gene_symbol: attributes.gene_symbol,
                  drug_class: attributes.drug_class || 'Unknown',
                  method: attributes.method,
                  gene_family: attributes.gene_family
                };
                return acc;
              }, {});

              const genesById = Object.values(uniqueGenes).reduce((acc, gene) => {
                acc[gene.id] = gene;
                return acc;
              }, {});

              const geneClusters = identifyOverlappingResistanceGenes(Object.values(genesById));

              const uniqueGenesWithOverlaps = geneClusters.map(cluster => {
                if (cluster.length === 1) {
                  const gene = genesById[cluster[0]];
                  return {
                    gene_symbol: gene.gene_symbol,
                    drug_class: gene.drug_class,
                    method: gene.method
                  };
                }
            
                const geneFamilies = new Set(cluster.map(id => genesById[id].gene_family || genesById[id].gene_symbol));
                const uniqueDrugClasses = new Set(cluster.flatMap(id => genesById[id].drug_class.split('/')));
                const uniqueMethods = new Set(cluster.map(id => genesById[id].method))

                return {
                  gene_symbol: geneFamilies.size > 1 ? `Gene similar to ${Array.from(geneFamilies).join(' or ')}` : geneFamilies[0],
                  drug_class: Array.from(uniqueDrugClasses).join('/'),
                  method: Array.from(uniqueMethods).join(', '), 
                };
              });

              return { species_name: attributes.species, genes: uniqueGenesWithOverlaps };
            })
          );

          allResults = [...allResults, ...results];
        }

        const resistanceGenes = allResults.reduce((acc, { species_name, genes }) => {
          if (genes.length) acc[species_name] = genes;
          return acc;
        }, {});
        setAllAvailableResistanceGenesLoading(false);
        setAllAvailableResistanceGenes(resistanceGenes);
      } catch (error) {
        datadogRum.addError(error);
        console.error('Error fetching resistance genes:', error);
      }
    };

  fetchAllAvailableResistanceGenes();
}, [identifiedSpecies, getAccessTokenSilently, isAuthenticated, setAllAvailableResistanceGenesLoading]);

  const columns = useMemo(() => [
    { field: 'organism', headerName: 'Organism', flex: 1 },
    { field: 'drugClass', headerName: 'Drug Class', flex: 1 },
    { 
      field: 'gene', 
      headerName: 'Resistance Marker', 
      flex: 1, 
      renderCell: (params) => <b style={{ fontStyle: 'italic' }}>{params.value}</b>
    },
    {
      field: 'coverage',
      headerName: 'Gene Coverage',
      flex: 1
    }
  ], []);

  const rows = useMemo(() => {
    return Object.entries(allAvailableResistanceGenes || {})
    .sort(([speciesNameA], [speciesNameB]) => speciesNameA.localeCompare(speciesNameB))
    .flatMap(
      ([speciesName, genes]) =>
        genes
        .sort((a, b) => {
          const drugClassComparison = a.drug_class.localeCompare(b.drug_class);
          return drugClassComparison !== 0 ? drugClassComparison : a.gene_symbol.localeCompare(b.gene_symbol);
        })
        .map((gene, index) => ({
          id: `${speciesName}-${index}`,
          organism: speciesName,
          gene: gene.gene_symbol,
          drugClass: gene.drug_class,
          coverage: geneHasPartialCoverage(gene) ? 'Partial' : 'High'
        }))
    );
  }, [allAvailableResistanceGenes]);

  return (
    <Dialog open={open} onClose={onClose} fullWidth maxWidth="lg" scroll="paper" >
      <DialogTitle sx={{ backgroundColor: "#eee" }}>Complete Resistance Gene Profile</DialogTitle>
      <DialogContent>
        <Box sx={{ padding: 2, '& .even': { backgroundColor: '#fff' }, '& .odd': { backgroundColor: '#f5f5f5' } }}>
        {rows.length > 0 ? (
          <>
            <Typography paragraph>
              The following table represents all resistance genes identified in the analyzed sample.
            </Typography>
            <Alert 
              severity="warning"
              sx={{ marginBottom: 2 }}
            >
              Resistance gene presence or absence is often not indicative of antibiotic resistance or susceptibility. Refer to resistance genes reported alongside AMR predictions for highest predictive value.
            </Alert>
            <DataGrid
              rows={rows}
              columns={columns}
              disableColumnSorting
              getRowHeight={() => 'auto'}
              getRowClassName={(params) => (params.indexRelativeToCurrentPage % 2 === 0 ? 'even' : 'odd')}
              sx={{ maxHeight: '500px', [`& .${gridClasses.cell}`]: { py: 0.75 } }}
            />
          </>
          ) : (<p>No resistance genes found.</p>)}
        </Box>
      </DialogContent>
      <DialogActions sx={{ backgroundColor: "#eee" }}>
        <Button color="primary" sx={{ marginRight: 2 }} onClick={onClose}>Close</Button>
      </DialogActions>
    </Dialog>
  );
};

export default ResistanceGenesDialog;
