import React, { memo, useState, useEffect } from "react";
import styled from "styled-components";
import OpenInNewIcon from "@material-ui/icons/OpenInNew";
import {
  Grid, Typography, TableRow, TableHead, TableCell, TableBody, Table,
} from "@material-ui/core";
import Highcharts from "highcharts";
import { faInfoCircle } from "@fortawesome/free-solid-svg-icons";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";

// local components

import Collapse from "../../common/Collapse";
import SolutionContainer from "../../common/SolutionContainerWrapper";
import {
  Paragraph,
  StyledButton,
  DemoContainer,
  LinkContainer,
  HighchartContainer,
  List,
  ListItem,
} from "../../../styles/common";
import PresignedS3Link from "../../common/PresignedS3Link";

/**
 * Solution page for the stroke-prediction demo.
 *
 * Renders a static write-up (background, dataset, CTGAN, logistic regression,
 * Data Shapley) and a click-through demo that progressively reveals sample
 * rows, metric tables and two Highcharts charts. Every figure displayed is a
 * literal in this file.
 */
const UHGStrokePrediction = () => {
  // Set to true by getSampleData1; gates the whole sample/results section.
  const [sample, setSample] = useState(false);
  // Forwarded to SolutionContainer as `snackbar`. setNotification is never
  // called in this component, so the value stays "".
  const [notification, setNotification] = useState("");
  // Gates the "before Shapley" metrics section. Starts as "" and is set to
  // true or reset to "" by the click handlers; only its truthiness is used.
  const [cnnModelDataRes, setCNNModelDataRes] = useState("");
  // Gates the "after Shapley" metrics table and is also passed as `display`
  // to the comparison chart container. Same ""/true convention as above.
  const [cnnPredictionRes, setCNNPredictionRes] = useState("");

  // Empty dependency list: scroll to the top of the window once, on mount.
  useEffect(() => {
    window.scrollTo(0, 0);
  }, []);

  /**
   * Draws the grouped column chart comparing the model scores before and
   * after Data Shapley into the element with id `container_1`.
   *
   * The x-axis carries the metric names and the y-axis the score values, so
   * the axis titles are 'Metrics' (x) and 'Values' (y).
   */
  const formHighChart = () => {
    Highcharts.chart('container_1', {
      chart: {
        type: 'column',
      },

      title: {
        text: 'Comparison of Scores Before and After Shapley',
      },
      legend: {
        align: 'right',
        verticalAlign: 'middle',
        layout: 'vertical',
      },
      xAxis: {
        // One category per metric; each series supplies its data in this order.
        categories: ['Accuracy', 'AUC', 'Precision', 'F1 '],
        labels: {
          x: -10,
        },
        title: {
          text: 'Metrics',
        },
      },
      yAxis: {
        allowDecimals: false,
        title: {
          text: 'Values',
        },
      },
      series: [{
        name: 'Before Data Shapley ',
        data: [65.47, 65.52, 65.52, 65.59],
      }, {
        name: 'After Data Shapley ',
        data: [67.45, 67.45, 61.67, 67.96],
      }],

      // Below 500px the legend moves under the plot and the y-axis title is
      // dropped to save horizontal space.
      responsive: {
        rules: [{
          condition: {
            maxWidth: 500,
          },
          chartOptions: {
            legend: {
              align: 'center',
              verticalAlign: 'bottom',
              layout: 'horizontal',
            },
            yAxis: {
              labels: {
                align: 'left',
                x: 0,
                y: -5,
              },
              title: {
                text: null,
              },
            },
            subtitle: {
              text: null,
            },
            credits: {
              enabled: false,
            },
          },
        }],
      },
    });
  };

  /**
   * Draws the pie chart showing how the training set splits into real and
   * synthetic (CTGAN-generated) samples, into the element with id `container`.
   *
   * The slice values are derived from the per-class counts shown in the
   * class-count table of this page: real = 640 + 638, synthetic = 766 + 734.
   * (The per-class totals 1406 / 1372 are a different breakdown and must not
   * be used for the real/synthetic slices.)
   */
  const formSampleHighChart = () => {
    // [class 0, class 1] counts, matching the class-count table.
    const originalCounts = [640, 638];
    const syntheticCounts = [766, 734];
    const sum = (values) => values.reduce((total, value) => total + value, 0);

    Highcharts.chart('container', {
      chart: {
        plotBackgroundColor: null,
        plotBorderWidth: null,
        plotShadow: false,
        type: 'pie',
      },
      title: {
        // The dataset is tabular, so the slices are samples, not images.
        text: 'Distribution of Synthetic and Real Samples',
      },
      tooltip: {
        pointFormat: '{series.name}: <b>{point.percentage:.0f}%</b>',
      },
      accessibility: {
        point: {
          valueSuffix: '%',
        },
      },
      plotOptions: {
        pie: {
          allowPointSelect: true,
          cursor: 'pointer',
          dataLabels: {
            enabled: true,
            format: '<b>{point.name}</b>: {point.percentage:.0f} %',
          },
        },
      },
      series: [{
        name: 'Data Count',
        colorByPoint: true,
        data: [{
          name: 'Real Samples',
          y: sum(originalCounts),
        }, {
          name: 'Synthetic Samples',
          y: sum(syntheticCounts),
          sliced: true,
        }],
      }],
    });
  };

  /**
   * Click handler for "RETRAIN WITH SHAPLEY": queues the comparison chart on
   * a zero-delay timer and flags the "after Shapley" results as visible.
   */
  const getcnnPrediction = () => {
    // Deferred rather than called inline — presumably so the state update
    // below has been rendered before Highcharts draws; confirm if changing.
    setTimeout(formHighChart, 0);
    const resultsVisible = true;
    setCNNPredictionRes(resultsVisible);
  };

  /**
   * Click handler for "SHOW RESULTS": clears the "after Shapley" flag and
   * turns on the "before Shapley" metrics flag.
   */
  const getCnnModelData = () => {
    const cleared = "";
    const shown = true;
    setCNNPredictionRes(cleared);
    setCNNModelDataRes(shown);
  };

  /**
   * Click handler for "Sample Tabular Data": resets both result flags to "",
   * reveals the sample section and queues the sample pie chart on a
   * zero-delay timer.
   */
  const getSampleData1 = () => {
    // Reset order is prediction flag first, then model-data flag.
    [setCNNPredictionRes, setCNNModelDataRes].forEach((reset) => reset(""));

    setSample(true);
    setTimeout(formSampleHighChart, 0);
  };
  return (
    <SolutionContainer snackbar={notification}>
      <Collapse text="Description">
        <Paragraph>
          <p>
            According to the World Health Organization (WHO), cerebrovascular accidents (stroke) are the second leading cause of death globally, responsible for approximately 11% of total deaths. Stroke is a leading cause of death in the United States. About 795,000 people in the United States are reported to have a stroke each year.
            There are several risk factors that increase the likelihood of an individual to experience stroke cases. Age is one of the major factors that result in a stroke case. Older adults are at a higher risk of getting a stroke. However, according to research two or more factors in addition to age are usually responsible for stroke cases. Early identification of possible stroke cases can aid in timely care and, to an extent, prevent stroke. Blood pressure levels, BMI, and blood glucose levels, as well as lifestyle factors and family health history, are important factors to consider while monitoring patients to identify their susceptibility to stroke.
          </p>
          <p>A great deal of research goes into using machine learning to aid in medical diagnosis, but the lack of availability of training data becomes a limiting factor. A small, well-balanced subset of the stroke dataset provided by the National Cardiovascular Disease Surveillance System is used for this application. As the size of the dataset is pretty small, to increase its size, synthetic data points are generated from the original dataset using CTGAN and combined with the original dataset. This larger dataset is then used to train a Logistic Regression model to trained to predict whether a patient is likely to experience a stroke or not by analyzing various health indicators and demographic factors.</p>
          <p>In addition to this, a principled framework to address data valuation in the context of supervised machine learning is showcased. Monte Carlo and Gradient-based methods are implemented to efficiently estimate data Shapley values for each of the data points in the dataset to identify and eliminate noisy data. The logistic regression model is trained again on the modified dataset to produce better results.</p>
          <ResponsiveTypography variant="h6" gutterBottom>
            <FontAwesomeIcon icon={faInfoCircle} />
            {' '}
            Dataset
          </ResponsiveTypography>
          <p>The dataset used for this application is provided by the National Cardiovascular Disease Surveillance System. It consists of 43400 rows of patient demographic data and health indicators, and include CVDs (e.g., heart failure) and risk factors (e.g., hypertension). A count of ~640 data points from each class label was extracted to form a smaller subset to work with.</p>
          <ResponsiveTypography variant="h6" gutterBottom>
            <FontAwesomeIcon icon={faInfoCircle} />
            {' '}
            CTGAN
          </ResponsiveTypography>
          <p>The small size of the dataset makes it unsuitable for training an efficient machine learning model that will be capable of making accurate predictions. A technique called Synthetic Data Generation is used to generate a set of synthetic tabular data from the existing dataset to overcome this challenge. This is achieved using CTGAN, a collection of Deep Learning-based Synthetic Data Generators for single-table data that are able to learn from real data and generate synthetic clones with high fidelity.</p>
          <p>
            Tabular data usually contains a mix of discrete and continuous columns. Continuous columns may have multiple modes whereas discrete columns are sometimes imbalanced making the modelling difficult. The sdv.tabular.CTGAN model is based on the GAN-based Deep Learning data synthesizer which was presented at the NeurIPS 2020 conference by the paper titled Modeling Tabular data using Conditional GAN.
            The newly generated tabular data is combined with the original dataset, is preprocessed and made model-ready.
          </p>
          <ResponsiveTypography variant="h6" gutterBottom>
            <FontAwesomeIcon icon={faInfoCircle} />
            {' '}
            Logistic Regression
          </ResponsiveTypography>
          <p>A logistic regression model is built and trained on the newly modified dataset. The model is trained to predict whether a patient is likely to experience a stroke or not by analyzing various health indicators and demographic factors.</p>
          <ResponsiveTypography variant="h6" gutterBottom>
            <FontAwesomeIcon icon={faInfoCircle} />
            {' '}
            Data Shapley
          </ResponsiveTypography>
          <p>Although the logistic regression model gives us fairly good performance results, it is undeniable that the dataset used for training may have noisy data. Quantifying the value of each data point towards the final performance of an AI/ML model can help us identify those data points that contribute more positively and those that bring down the model performance. This can be achieved by using Data Shapley. Data Shapley makes use of a principled framework to address the quantification of data in the context of machine learning. Truncated Monte Carlo (TMC) and Gradient-based methods are implemented to efficiently estimate the Shapley values for each data point in the dataset.</p>
          <ResponsiveTypography variant="h6" gutterBottom>
            <FontAwesomeIcon icon={faInfoCircle} />
            {' '}
            Why use Data Shapley?
          </ResponsiveTypography>
          <List>
            <ListItem>It is more powerful and effective than generic methods like leave-one-out or leverage-score sampling algorithms</ListItem>
            <ListItem>Data Shapley calculates and assigns a “Shapley value” to each data point in the dataset</ListItem>
            <ListItem>Data points with relatively low Shapley values can be considered as outliers or corrupted data. Those with higher Shapley values can be used to retrain an AI/ML model to obtain increased performance</ListItem>
          </List>
        </Paragraph>
        <LinkContainer>
          <Grid container spacing={2}>
            <Grid item>
              <StyledButton
                variant="outlined"
                color="primary"
                size="large"
                startIcon={<OpenInNewIcon />}
              >
                <PresignedS3Link
                  href="UHG_StrokePrediction/Stroke_Prediction.html"
                  target="_blank"
                  rel="noopener noreferrer"
                >
                  Notebook
                </PresignedS3Link>
              </StyledButton>
            </Grid>
            <Grid item>
              <StyledButton
                variant="outlined"
                color="primary"
                size="large"
                startIcon={<OpenInNewIcon />}
              >
                <a
                  href="https://www.kaggle.com/fedesoriano/stroke-prediction-dataset"
                  target="_blank"
                  rel="noopener noreferrer"
                >
                  Dataset
                </a>
              </StyledButton>
            </Grid>
            <Grid item>
              <StyledButton
                variant="outlined"
                color="primary"
                size="large"
                startIcon={<OpenInNewIcon />}
              >
                <a
                  href="https://arxiv.org/abs/1904.02868"
                  target="_blank"
                  rel="noopener noreferrer"
                >
                  Stanford AI : Citations
                </a>
              </StyledButton>
            </Grid>
          </Grid>
        </LinkContainer>
      </Collapse>
      <Collapse text="Demo">
        <DemoContainer>
          <section>
            <Grid
              container
              xs={12}
              direction="row"
              justify="center"
              alignItems="center"
            >
              <Grid xs={12} alignItems="center">
                <StyledButton
                  variant="contained"
                  color="primary"
                  onClick={getSampleData1}
                >
                  Sample Tabular Data
                </StyledButton>
              </Grid>
            </Grid>
          </section>
          {sample && (
            <section>
              <Grid container xs={12} spacing={2} direction="row" justify="center" alignItems="center">
                <Grid item xs={12} sm={12} md={12}>
                  <PaperTable>
                    <Table>
                      <TableHead>
                        <TableRow>
                          <StyledTableHead align="left">gender</StyledTableHead>
                          <StyledTableHead align="left">age</StyledTableHead>
                          <StyledTableHead align="left">hypertension</StyledTableHead>
                          <StyledTableHead align="left">heart_disease</StyledTableHead>
                          <StyledTableHead align="left">ever_married</StyledTableHead>
                          <StyledTableHead align="left">work_type</StyledTableHead>
                          <StyledTableHead align="left">residence_type</StyledTableHead>
                          <StyledTableHead align="left">avg_glucose_level</StyledTableHead>
                          <StyledTableHead align="left">bmi</StyledTableHead>
                          <StyledTableHead align="left">smoking_status</StyledTableHead>
                          <StyledTableHead align="left">stroke</StyledTableHead>
                        </TableRow>
                      </TableHead>
                      <TableBody>
                        <TableRow>
                          <TableCell align="left">1</TableCell>
                          <TableCell align="left">58.0</TableCell>
                          <TableCell align="left">1</TableCell>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">1</TableCell>
                          <TableCell align="left">2</TableCell>
                          <TableCell align="left">1</TableCell>
                          <TableCell align="left">87.96</TableCell>
                          <TableCell align="left">39.2</TableCell>
                          <TableCell align="left">1</TableCell>
                          <TableCell align="left">0</TableCell>
                        </TableRow>
                        <TableRow>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">70.0</TableCell>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">1</TableCell>
                          <TableCell align="left">2</TableCell>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">69.04</TableCell>
                          <TableCell align="left">35.9</TableCell>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">0</TableCell>
                        </TableRow>
                        <TableRow>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">52.0</TableCell>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">1</TableCell>
                          <TableCell align="left">2</TableCell>
                          <TableCell align="left">1</TableCell>
                          <TableCell align="left">77.59</TableCell>
                          <TableCell align="left">17.7</TableCell>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">0</TableCell>
                        </TableRow>
                        <TableRow>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">75.0</TableCell>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">1</TableCell>
                          <TableCell align="left">1</TableCell>
                          <TableCell align="left">3</TableCell>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">243.53</TableCell>
                          <TableCell align="left">27.0</TableCell>
                          <TableCell align="left">1</TableCell>
                          <TableCell align="left">0</TableCell>
                        </TableRow>
                        <TableRow>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">32.0</TableCell>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">1</TableCell>
                          <TableCell align="left">2</TableCell>
                          <TableCell align="left">0</TableCell>
                          <TableCell align="left">77.67</TableCell>
                          <TableCell align="left">32.3</TableCell>
                          <TableCell align="left">2</TableCell>
                          <TableCell align="left">0</TableCell>
                        </TableRow>
                      </TableBody>
                    </Table>
                  </PaperTable>
                </Grid>
                <Grid item xs={12} sm={12} md={10}>
                  <HighchartContainer
                    id="container"
                    display={sample}
                  />
                </Grid>
                <Grid item xs={12} sm={12} md={10}>
                  <PaperTable>
                    <Table>
                      <TableHead>
                        <TableRow>
                          <StyledTableHead>Class</StyledTableHead>
                          <StyledTableHead>Original DataCount</StyledTableHead>
                          <StyledTableHead>Synthetic DataCount</StyledTableHead>
                          <StyledTableHead>Total</StyledTableHead>
                        </TableRow>
                      </TableHead>
                      <TableBody>
                        <TableRow>
                          <TableCell align="left">Class 0 - No Stroke</TableCell>
                          <TableCell align="left">640</TableCell>
                          <TableCell align="left">766</TableCell>
                          <TableCell align="left">1406</TableCell>
                        </TableRow>
                        <TableRow>
                          <TableCell align="left">Class 1 - Stroke</TableCell>
                          <TableCell align="left">638</TableCell>
                          <TableCell align="left">734</TableCell>
                          <TableCell align="left">1372</TableCell>
                        </TableRow>
                      </TableBody>
                    </Table>
                  </PaperTable>
                </Grid>
              </Grid>
              <MarginButton
                variant="contained"
                color="primary"
                onClick={getCnnModelData}
              >
                SHOW RESULTS
              </MarginButton>
              {cnnModelDataRes && (
                <section>
                  <Grid
                    container
                    xs={12}
                    spacing={2}
                    direction="row"
                    justify="center"
                    alignItems="center"
                  >
                    <Grid item xs={12} sm={8} md={4}>
                      <h4>Metrics Before Data Shapley </h4>
                      <Table>
                        <TableHead>
                          <TableRow>
                            <StyledTableHead>Metrics</StyledTableHead>
                            <StyledTableHead>Values</StyledTableHead>
                          </TableRow>
                        </TableHead>
                        <TableBody>
                          <TableRow key="0">
                            <TableCell align="left">Accuracy</TableCell>
                            <TableCell align="left">65.47</TableCell>
                          </TableRow>
                          <TableRow key="1">
                            <TableCell align="left">F1 Score</TableCell>
                            <TableCell align="left">65.59</TableCell>
                          </TableRow>
                          <TableRow key="2">
                            <TableCell align="left">Precision</TableCell>
                            <TableCell align="left">65.52</TableCell>
                          </TableRow>
                          <TableRow key="3">
                            <TableCell align="left">AUC</TableCell>
                            <TableCell align="left">65.52</TableCell>
                          </TableRow>
                        </TableBody>
                      </Table>
                    </Grid>
                    <Grid item xs={12} sm={8} md={4}>
                      <MarginButton
                        variant="contained"
                        color="primary"
                        onClick={getcnnPrediction}
                      >
                        RETRAIN WITH SHAPLEY
                      </MarginButton>
                    </Grid>
                    {cnnPredictionRes && (
                      <Grid xs={12} sm={12} md={4}>
                        <Grid container direction="row" justify="center">
                          <Grid item xs={12} sm={8} md={12}>
                            <h4>
                              <p>Metrics After Data Shapley </p>
                            </h4>
                            <Table>
                              <TableHead>
                                <TableRow>
                                  <StyledTableHead>Metrics</StyledTableHead>
                                  <StyledTableHead>Values</StyledTableHead>
                                </TableRow>
                              </TableHead>
                              <TableBody>
                                <TableRow key="0">
                                  <TableCell align="left">Accuracy</TableCell>
                                  <TableCell align="left">67.45</TableCell>
                                </TableRow>
                                <TableRow key="1">
                                  <TableCell align="left">F1 Score</TableCell>
                                  <TableCell align="left">67.96</TableCell>
                                </TableRow>
                                <TableRow key="2">
                                  <TableCell align="left">Precision</TableCell>
                                  <TableCell align="left">61.67</TableCell>
                                </TableRow>
                                <TableRow key="3">
                                  <TableCell align="left">AUC</TableCell>
                                  <TableCell align="left">67.45</TableCell>
                                </TableRow>
                              </TableBody>
                            </Table>
                          </Grid>
                        </Grid>
                      </Grid>
                    )}
                    <Grid xs={12} md={12}>
                      <Grid container direction="row" justify="center">
                        <Grid item xs={12} sm={12}>
                          <HighchartContainer
                            id="container_1"
                            display={cnnPredictionRes}
                          />
                        </Grid>
                      </Grid>
                    </Grid>
                  </Grid>
                </section>
              )}
            </section>
          )}
        </DemoContainer>
      </Collapse>

    </SolutionContainer>
  );
};
// The default export is the component wrapped in React's memo().
export default memo(UHGStrokePrediction);

// StyledButton with 10px of margin on its left and right sides.
const MarginButton = styled(StyledButton)`
  margin-right: 10px;
  margin-left: 10px;
`;
// Wrapper div with `overflow: auto`, used around the wider tables.
const PaperTable = styled.div`
overflow: auto;
`;
// Header TableCell: white text on a #3c40af background, with the font size
// reduced to 14px at widths up to 900px and 12px at widths up to 450px.
const StyledTableHead = styled(TableCell)`
  background-color: #3c40af;
  color: white;
  @media only screen and (max-width: 900px) {
    font-size: 14px;
  }

  @media only screen and (max-width: 450px) {
    font-size: 12px;
  }
`;

// Bold Typography used for the section headings; font size drops to 18px at
// widths up to 900px and 15px at widths up to 450px.
const ResponsiveTypography = styled(Typography)`
font-weight: bold;
  @media only screen and (max-width: 900px) {
    font-size: 18px;
  }
  @media only screen and (max-width: 450px) {
    font-size: 15px;
  }
`;
