package com.ediagnosis.cdr.dao;


import com.mybatisflex.core.datasource.DataSourceKey;
import com.mybatisflex.core.paginate.Page;
import com.mybatisflex.core.query.QueryWrapper;
import com.mybatisflex.core.row.Db;
import com.mybatisflex.core.row.Row;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.parser.CCJSqlParser;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.parser.ParseException;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.Statements;
import net.sf.jsqlparser.statement.UnsupportedStatement;
import net.sf.jsqlparser.statement.select.*;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.util.Assert;

import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

@SpringBootTest
class HiveQueryExecutorTest {


    @Test
    public void selectBySql() {
        String sql = "SELECT * FROM ods.xjd_patient_info ORDER BY id LIMIT 10";
        try {
            DataSourceKey.use("ds-hive");
            List<Row> rows = Db.selectListBySql(sql);
            System.out.println(rows);
        } finally {
            DataSourceKey.clear();
        }
    }

    @Test
    public void parseSql() throws JSQLParserException {
        String sqlStr = """
                 select * from (select * from
                     (
                         select t1.order_id as orderId,
                                t2.patient_name as patientName,
                                DATE_FORMAT(t2.create_ts,'%Y-%m') as 'ym',
                                t1.diagnosis,
                                t1.is_refill,
                                t2.state,
                                t4.patient_leave_hospital_time as endTime,
                                t2.to_hospital_time as startTime,
                                t4.transfer_code,
                                0 as validPass,
                                0 as pass,
                                t2.to_hospital_way as toTospitalWay,
                                ifnull(t1.is_n2s,0) as isN2s,
                                t1.treat_method
                         from diagnosis t1
                                  left join clinic_order t2 on t1.order_id = t2.id and t2.state not in (6, 10, 40)
                                  left join treat_cp_rongshuan t3 on t2.id = t3.order_id
                                  left join archive_info t4 on t2.id = t4.order_id
                
                                  inner join clinic_order_report AS t8 on t8.order_id = t2.id
                
                         where
                             t2.product_id = 1
                           and t2.state >= '1'
                           and t8.report_state in (1,2)
                           and t2.hospital_code = '1232456'
                           and t2.is_enabled = 1
                           and t1.is_enabled = 1
                           and t2.to_hospital_time is not null
                           and t2.create_ts  between '2021-10-1' and '2022-10-2'
                         order by t1.diag_time desc, t1.id desc
                         limit 10
                     ) tmp group by orderId ) t9 
                     where t9.diagnosis in ('1') 
                        and t9.treat_method is not null 
                        and (t9.treat_method like '%17%' 
                                    or t9.treat_method like '%26%' 
                                    or t9.treat_method like '%50%' 
                             )
                """;


        Statement statement = CCJSqlParserUtil.parse(sqlStr);

        if (statement instanceof PlainSelect plainSelect) {
            Limit limit = plainSelect.getLimit();
            if (limit == null) {
                System.out.println("判断不存在limit从句，添加从句");
                Select select = plainSelect.withLimit(
                        new Limit()
                                .withRowCount(new LongValue(10))
                                .withOffset(new LongValue(0))
                );
                System.out.println(select.toString());
            } else {
                System.out.println("判断存在limit从句，不添加");
                System.out.println(plainSelect.toString());
            }
        }

    }


    @Test
    public void parseSqlFromItem() throws JSQLParserException {
        String sqlStr1 = """
                 select * from (select * from
                     (
                         select t1.order_id as orderId,
                                t2.patient_name as patientName,
                                DATE_FORMAT(t2.create_ts,'%Y-%m') as 'ym',
                                t1.diagnosis,
                                t1.is_refill,
                                t2.state,
                                t4.patient_leave_hospital_time as endTime,
                                t2.to_hospital_time as startTime,
                                t4.transfer_code,
                                0 as validPass,
                                0 as pass,
                                t2.to_hospital_way as toTospitalWay,
                                ifnull(t1.is_n2s,0) as isN2s,
                                t1.treat_method
                         from diagnosis t1
                                  left join clinic_order t2 on t1.order_id = t2.id and t2.state not in (6, 10, 40)
                                  left join treat_cp_rongshuan t3 on t2.id = t3.order_id
                                  left join archive_info t4 on t2.id = t4.order_id
                
                                  inner join clinic_order_report AS t8 on t8.order_id = t2.id
                
                         where
                             t2.product_id = 1
                           and t2.state >= '1'
                           and t8.report_state in (1,2)
                           and t2.hospital_code = '1232456'
                           and t2.is_enabled = 1
                           and t1.is_enabled = 1
                           and t2.to_hospital_time is not null
                           and t2.create_ts  between '2021-10-1' and '2022-10-2'
                         order by t1.diag_time desc, t1.id desc
                         limit 10
                     ) tmp group by orderId ) t9 
                     where t9.diagnosis in ('1') 
                        and t9.treat_method is not null 
                        and (t9.treat_method like '%17%' 
                                    or t9.treat_method like '%26%' 
                                    or t9.treat_method like '%50%' 
                             )
                """;
        String sqlStr2 = """
                select p.patient_id,p.patient_name,r.visit_no from ods.xjd_patient_info p, ods.xjd_emergency_record r
                where p.patient_id=r.patient_id  and p.id='3';
                """;


        Statement statement1 = CCJSqlParserUtil.parse(sqlStr1);
        Statement statement2 = CCJSqlParserUtil.parse(sqlStr2);
        changeToAllColumns(statement1);
        changeToAllColumns(statement2);

    }


    @Test
    public void changeToCountSql() throws JSQLParserException {
        String sqlStr1 = """
                 select * from (select * from
                     (
                         select t1.order_id as orderId,
                                t2.patient_name as patientName,
                                DATE_FORMAT(t2.create_ts,'%Y-%m') as 'ym',
                                t1.diagnosis,
                                t1.is_refill,
                                t2.state,
                                t4.patient_leave_hospital_time as endTime,
                                t2.to_hospital_time as startTime,
                                t4.transfer_code,
                                0 as validPass,
                                0 as pass,
                                t2.to_hospital_way as toTospitalWay,
                                ifnull(t1.is_n2s,0) as isN2s,
                                t1.treat_method
                         from diagnosis t1
                                  left join clinic_order t2 on t1.order_id = t2.id and t2.state not in (6, 10, 40)
                                  left join treat_cp_rongshuan t3 on t2.id = t3.order_id
                                  left join archive_info t4 on t2.id = t4.order_id
                
                                  inner join clinic_order_report AS t8 on t8.order_id = t2.id
                
                         where
                             t2.product_id = 1
                           and t2.state >= '1'
                           and t8.report_state in (1,2)
                           and t2.hospital_code = '1232456'
                           and t2.is_enabled = 1
                           and t1.is_enabled = 1
                           and t2.to_hospital_time is not null
                           and t2.create_ts  between '2021-10-1' and '2022-10-2'
                         order by t1.diag_time desc, t1.id desc
                         limit 10
                     ) tmp group by orderId ) t9 
                     where t9.diagnosis in ('1') 
                        and t9.treat_method is not null 
                        and (t9.treat_method like '%17%' 
                                    or t9.treat_method like '%26%' 
                                    or t9.treat_method like '%50%' 
                             )
                """;
        String sqlStr2 = """
                select p.patient_id,p.patient_name,r.visit_no from ods.xjd_patient_info p, ods.xjd_emergency_record r
                where p.patient_id=r.patient_id  and p.id='3';
                """;

        String sqlStr3 = """
                select count(*)
                from ods.xjd_patient_info p, ods.xjd_emergency_record r
                where p.patient_id=r.patient_id  and p.id='3';
                """;

        Statement statement3 = CCJSqlParserUtil.parse(sqlStr3);
        Statement statement2 = CCJSqlParserUtil.parse(sqlStr2);
        Statement statement1 = CCJSqlParserUtil.parse(sqlStr1);

        changeToCuntColumns(statement1);
        changeToCuntColumns(statement2);
        changeToCuntColumns(statement3);

    }

    private void changeToCuntColumns(Statement statement) {
        List<SelectItem<?>> selectItems = null;
        if (statement instanceof PlainSelect plainSelect) {
            selectItems = plainSelect.getSelectItems();
        }
        Assert.notNull(selectItems, "");
        selectItems.clear();
        selectItems.add(new SelectItem<>(
                        new Function("count", new AllColumns())
                )
        );
        System.out.println(statement);

    }


    private void changeToAllColumns(Statement statement) {


        List<SelectItem<?>> selectItems = null;

        if (statement instanceof PlainSelect plainSelect) {
            selectItems = plainSelect.getSelectItems();
        }
        Assert.notNull(selectItems, "");

        if (selectItems.size() > 1) {
            selectItems.clear();
            selectItems.add(new SelectItem<>(new AllColumns()));
        } else if (selectItems.size() == 1) {
            SelectItem<?> selectItem = selectItems.getFirst();
            Expression expression = selectItem.getExpression();
            if (!(expression instanceof AllColumns allColumns)) {
                expression = new AllColumns();
            }
        }
        System.out.println(statement.toString());
    }


    @Test
    public void parseSqlErrorRecoveryTest() throws ParseException {
        CCJSqlParser parser = new CCJSqlParser(
                "select * from mytable; select from; select * from mytable2");
        Statements statements = parser.withErrorRecovery().Statements();

// 3 statements, the failing one set to NULL
        assertEquals(3, statements.size());
        assertNull(statements.get(1));

// errors are recorded
        assertEquals(1, parser.getParseErrors().size());
    }

    @Test
    public void parseSqlUnsupportedStatementTest() throws ParseException, JSQLParserException {
        CCJSqlParser parser = new CCJSqlParser(
                "select * from mytable; select from; select * from mytable2; select 4;");
        parser.withUnsupportedStatements();
        Statements statements = parser.Statements();

// 4 statements with one Unsupported Statement holding the content
        assertEquals(4, statements.size());
        assertInstanceOf(UnsupportedStatement.class, statements.get(1));
        assertEquals("select from", statements.get(1).toString());

// no errors records, because a statement has been returned
        assertEquals(0, parser.getParseErrors().size());
    }


    @Test
    public void selectPage() {
        try {
            DataSourceKey.use("ds-hive");
            QueryWrapper query = QueryWrapper.create()
                    .where("id > ?", 18);
            Page<Row> rowPage = Db.paginate("ods.xjd_patient_info", 1, 10, query);
            System.out.println(rowPage);
        } finally {
            DataSourceKey.clear();
        }

    }


}