/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.calcite;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.DateString;
import org.apache.calcite.util.TimeString;
import org.apache.calcite.util.TimestampString;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.And;
import org.opensearch.sql.ast.expression.Between;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.In;
import org.opensearch.sql.ast.expression.Interval;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Not;
import org.opensearch.sql.ast.expression.Or;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.RelevanceFieldList;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.When;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.ast.expression.subquery.ExistsSubquery;
import org.opensearch.sql.ast.expression.subquery.InSubquery;
import org.opensearch.sql.ast.expression.subquery.ScalarSubquery;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.calcite.CalcitePlanContext;
import org.opensearch.sql.calcite.CalciteRelNodeVisitor;
import org.opensearch.sql.calcite.ExtendedRexBuilder;
import org.opensearch.sql.calcite.type.ExprSqlType;
import org.opensearch.sql.calcite.udf.datetimeUDF.PostprocessDateToStringFunction;
import org.opensearch.sql.calcite.utils.BuiltinFunctionUtils;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.calcite.utils.PlanUtils;
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.exception.CalciteUnsupportedException;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.PPLFuncImpTable;

/**
 * Translates PPL expression AST nodes ({@link UnresolvedExpression}) into Calcite
 * {@link RexNode} row expressions, using the builders held by a {@link CalcitePlanContext}.
 *
 * <p>Subquery expressions (IN, scalar and EXISTS subqueries) are planned by delegating back to
 * the {@link CalciteRelNodeVisitor} passed to the constructor.
 */
public class CalciteRexNodeVisitor extends AbstractNodeVisitor<RexNode, CalcitePlanContext> {

    /** Relational-plan visitor used to plan the bodies of subquery expressions. */
    private final CalciteRelNodeVisitor planVisitor;

    /**
     * Creates a visitor that plans subqueries with the given relational-plan visitor.
     *
     * @param planVisitor visitor used to turn subquery plans into {@link RelNode}s
     */
    @Generated
    public CalciteRexNodeVisitor(CalciteRelNodeVisitor planVisitor) {
        this.planVisitor = planVisitor;
    }

    /**
     * Translates one expression by dispatching it back through this visitor.
     *
     * @param unresolved the PPL expression to translate
     * @param context    planning context holding the Rex/Rel builders
     * @return the translated Calcite expression
     */
    public RexNode analyze(UnresolvedExpression unresolved, CalcitePlanContext context) {
        return unresolved.accept(this, context);
    }

    /**
     * Translates a join condition. The work is delegated to
     * {@code CalcitePlanContext.resolveJoinCondition}, which receives {@link #analyze} as the
     * per-expression translator.
     */
    public RexNode analyzeJoinCondition(UnresolvedExpression unresolved, CalcitePlanContext context) {
        return context.resolveJoinCondition(unresolved, this::analyze);
    }

    /**
     * Converts a PPL literal into a Calcite literal of the corresponding SQL type.
     * A {@code null} value always becomes a NULL-typed null literal, whatever its declared type.
     *
     * @throws UnsupportedOperationException if the literal's data type has no mapping below
     */
    @Override
    public RexNode visitLiteral(Literal node, CalcitePlanContext context) {
        ExtendedRexBuilder rexBuilder = context.rexBuilder;
        RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
        Object value = node.getValue();
        if (value == null) {
            return rexBuilder.makeNullLiteral(typeFactory.createSqlType(SqlTypeName.NULL));
        }
        return switch (node.getType()) {
            case NULL -> rexBuilder.makeNullLiteral(typeFactory.createSqlType(SqlTypeName.NULL));
            case STRING -> rexBuilder.makeLiteral(value.toString());
            case INTEGER -> rexBuilder.makeExactLiteral(new BigDecimal((Integer) value));
            case LONG -> rexBuilder.makeBigintLiteral(new BigDecimal((Long) value));
            case SHORT -> rexBuilder.makeExactLiteral(
                    new BigDecimal(((Short) value).shortValue()),
                    typeFactory.createSqlType(SqlTypeName.SMALLINT));
            // Float/Double go through their decimal string form so the BigDecimal holds the
            // shortest round-trip representation rather than the exact binary expansion.
            case FLOAT -> rexBuilder.makeApproxLiteral(
                    new BigDecimal(Float.toString(((Float) value).floatValue())),
                    typeFactory.createSqlType(SqlTypeName.FLOAT));
            case DOUBLE -> rexBuilder.makeApproxLiteral(
                    new BigDecimal(Double.toString((Double) value)),
                    typeFactory.createSqlType(SqlTypeName.DOUBLE));
            case BOOLEAN -> rexBuilder.makeLiteral((Boolean) value);
            case DATE -> rexBuilder.makeDateLiteral(new DateString(value.toString()));
            // -1 is passed as the precision argument to makeTime/TimestampLiteral.
            case TIME -> rexBuilder.makeTimeLiteral(new TimeString(value.toString()), -1);
            case TIMESTAMP -> rexBuilder.makeTimestampLiteral(new TimestampString(value.toString()), -1);
            default -> throw new UnsupportedOperationException("Unsupported literal type: " + node.getType());
        };
    }

    /**
     * Builds an interval literal. The value is taken from the translated node's string form and
     * the interval qualifier from the node's unit.
     */
    @Override
    public RexNode visitInterval(Interval node, CalcitePlanContext context) {
        RexNode value = analyze(node.getValue(), context);
        SqlIntervalQualifier intervalQualifier =
                context.rexBuilder.createIntervalUntil(PlanUtils.intervalUnitToSpanUnit(node.getUnit()));
        // NOTE(review): relies on value.toString() being a plain number; confirm that a RexLiteral's
        // digest never carries a type suffix for the literal types that can appear here.
        return context.rexBuilder.makeIntervalLiteral(new BigDecimal(value.toString()), intervalQualifier);
    }

    /** Translates {@code left AND right} into an AND call with an explicit BOOLEAN result type. */
    @Override
    public RexNode visitAnd(And node, CalcitePlanContext context) {
        RelDataType booleanType = context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN);
        RexNode left = analyze(node.getLeft(), context);
        RexNode right = analyze(node.getRight(), context);
        return context.rexBuilder.makeCall(booleanType, SqlStdOperatorTable.AND, List.of(left, right));
    }

    /** Translates {@code left OR right} via the RelBuilder. */
    @Override
    public RexNode visitOr(Or node, CalcitePlanContext context) {
        RexNode left = analyze(node.getLeft(), context);
        RexNode right = analyze(node.getRight(), context);
        return context.relBuilder.or(left, right);
    }

    /** Translates {@code left XOR right} as {@code left <> right}. */
    @Override
    public RexNode visitXor(Xor node, CalcitePlanContext context) {
        RexNode left = analyze(node.getLeft(), context);
        RexNode right = analyze(node.getRight(), context);
        return context.relBuilder.notEquals(left, right);
    }

    /** Translates {@code NOT expr} via the RelBuilder. */
    @Override
    public RexNode visitNot(Not node, CalcitePlanContext context) {
        RexNode expr = analyze(node.getExpression(), context);
        return context.relBuilder.not(expr);
    }

    /**
     * Translates {@code field IN (v1, v2, ...)}. Every value is cast to the least-restrictive
     * type of the field and the values. The field itself is not cast.
     *
     * @throws SemanticCheckException if no common type exists
     */
    @Override
    public RexNode visitIn(In node, CalcitePlanContext context) {
        RexNode field = analyze(node.getField(), context);
        List<RexNode> valueList = node.getValueList().stream()
                .map(value -> analyze((UnresolvedExpression) value, context))
                .toList();
        List<RelDataType> dataTypes = new ArrayList<>(valueList.stream().map(RexNode::getType).toList());
        // The field type goes last so the error message below can separate it from the value types.
        dataTypes.add(field.getType());
        RelDataType commonType = context.rexBuilder.getTypeFactory().leastRestrictive(dataTypes);
        if (commonType != null) {
            List<RexNode> newValueList = valueList.stream()
                    .map(value -> context.rexBuilder.makeCast(commonType, value))
                    .toList();
            return context.rexBuilder.makeIn(field, newValueList);
        }
        List<ExprType> exprTypes = dataTypes.stream()
                .map(OpenSearchTypeFactory::convertRelDataTypeToExprType)
                .toList();
        throw new SemanticCheckException(StringUtils.format(
                "In expression types are incompatible: fields type %s, values type %s",
                exprTypes.getLast(), exprTypes.subList(0, exprTypes.size() - 1)));
    }

    /**
     * Translates a binary comparison through {@code PPLFuncImpTable}. If either operand's type is
     * an {@link ExprSqlType}, both operands are first wrapped by
     * {@link #transferCompareForDateRelated}, so that they are compared as the strings produced
     * by the {@code PostprocessDateToString} UDF.
     */
    @Override
    public RexNode visitCompare(Compare node, CalcitePlanContext context) {
        RexNode leftCandidate = analyze(node.getLeft(), context);
        RexNode rightCandidate = analyze(node.getRight(), context);
        boolean whetherCompareByTime =
                leftCandidate.getType() instanceof ExprSqlType || rightCandidate.getType() instanceof ExprSqlType;
        RexNode left = transferCompareForDateRelated(leftCandidate, context, whetherCompareByTime);
        RexNode right = transferCompareForDateRelated(rightCandidate, context, whetherCompareByTime);
        return PPLFuncImpTable.INSTANCE.resolve(context.rexBuilder, node.getOperator(), left, right);
    }

    /**
     * Wraps {@code candidate} in a {@code PostprocessDateToString} UDF call when
     * {@code whetherCompareByTime} is set. The query-start instant (as {@code Instant.toString()})
     * is passed as the UDF's second argument. Otherwise {@code candidate} is returned unchanged.
     */
    private RexNode transferCompareForDateRelated(RexNode candidate, CalcitePlanContext context, boolean whetherCompareByTime) {
        if (!whetherCompareByTime) {
            return candidate;
        }
        SqlOperator postToStringNode = UserDefinedFunctionUtils.TransferUserDefinedFunction(
                PostprocessDateToStringFunction.class,
                "PostprocessDateToString",
                BuiltinFunctionUtils.VARCHAR_FORCE_NULLABLE);
        RexNode queryStart = context.rexBuilder.makeLiteral(
                context.functionProperties.getQueryStartClock().instant().toString());
        return context.rexBuilder.makeCall(postToStringNode, List.of(candidate, queryStart));
    }

    /**
     * Translates {@code value BETWEEN lower AND upper}. Both bounds are cast to the common type of
     * all three operands; the tested value is left as is.
     *
     * @throws SemanticCheckException if the three operand types have no common type
     */
    @Override
    public RexNode visitBetween(Between node, CalcitePlanContext context) {
        RexNode value = analyze(node.getValue(), context);
        RexNode lowerBound = analyze(node.getLowerBound(), context);
        RexNode upperBound = analyze(node.getUpperBound(), context);
        RelDataType commonType = context.rexBuilder.commonType(value, lowerBound, upperBound);
        if (commonType == null) {
            throw new SemanticCheckException(StringUtils.format(
                    "BETWEEN expression types are incompatible: [%s, %s, %s]",
                    OpenSearchTypeFactory.convertRelDataTypeToExprType(value.getType()),
                    OpenSearchTypeFactory.convertRelDataTypeToExprType(lowerBound.getType()),
                    OpenSearchTypeFactory.convertRelDataTypeToExprType(upperBound.getType())));
        }
        lowerBound = context.rexBuilder.makeCast(commonType, lowerBound);
        upperBound = context.rexBuilder.makeCast(commonType, upperBound);
        return context.relBuilder.between(value, lowerBound, upperBound);
    }

    /** Translates {@code left = right} via {@code ExtendedRexBuilder.equals}. */
    @Override
    public RexNode visitEqualTo(EqualTo node, CalcitePlanContext context) {
        RexNode left = analyze(node.getLeft(), context);
        RexNode right = analyze(node.getRight(), context);
        return context.rexBuilder.equals(left, right);
    }

    /**
     * Resolves a (possibly dotted) field name to a field reference.
     *
     * <p>While a join condition is being resolved:
     * <ul>
     *   <li>a one-part name is looked up in the left join input, then the right;</li>
     *   <li>a two-part {@code alias.field} name is resolved by alias across both inputs;</li>
     *   <li>a three-part name is rejected.</li>
     * </ul>
     *
     * <p>Otherwise the lookup proceeds in order:
     * <ol>
     *   <li>an exact match on the full dotted name in the current input's row type;</li>
     *   <li>for two-part names, an alias lookup on the current input, falling back to the
     *       correlation variable from {@code peekCorrelVar()};</li>
     *   <li>if no current field name starts with the dotted name, the correlation variable (when
     *       present) or a plain {@code relBuilder.field(name)} lookup;</li>
     *   <li>otherwise an {@link IllegalArgumentException} listing the input fields.</li>
     * </ol>
     */
    @Override
    public RexNode visitQualifiedName(QualifiedName node, CalcitePlanContext context) {
        List<String> parts = node.getParts();
        if (context.isResolvingJoinCondition()) {
            if (parts.size() == 1) {
                try {
                    return context.relBuilder.field(2, 0, parts.getFirst());
                } catch (IllegalArgumentException notInLeftInput) {
                    return context.relBuilder.field(2, 1, parts.getFirst());
                }
            }
            if (parts.size() == 2) {
                return context.relBuilder.field(2, parts.get(0), parts.get(1));
            }
            if (parts.size() == 3) {
                throw new UnsupportedOperationException("Unsupported qualified name: " + node);
            }
        }
        String qualifiedName = node.toString();
        // Typed List<String> (not raw) so the String-based prefix check below type-checks.
        List<String> currentFields = context.relBuilder.peek().getRowType().getFieldNames();
        if (currentFields.contains(qualifiedName)) {
            return context.relBuilder.field(qualifiedName);
        }
        if (parts.size() == 2) {
            try {
                return context.relBuilder.field(1, parts.get(0), parts.get(1));
            } catch (IllegalArgumentException e) {
                // Not resolvable on the current input; try the correlation variable, else rethrow.
                return context.peekCorrelVar()
                        .map(correlVar -> context.relBuilder.field((RexNode) correlVar, parts.get(1)))
                        .orElseThrow(() -> e);
            }
        }
        if (currentFields.stream().noneMatch(f -> f.startsWith(qualifiedName))) {
            return context.peekCorrelVar()
                    .map(correlVar -> context.relBuilder.field((RexNode) correlVar, qualifiedName))
                    .orElseGet(() -> context.relBuilder.field(qualifiedName));
        }
        // Some input field has this name as a prefix (presumably nested sub-fields) but none
        // matches it exactly, so the name is reported as not found.
        throw new IllegalArgumentException(
                String.format("field [%s] not found; input fields are: %s", qualifiedName, currentFields));
    }

    /** Translates an aliased expression, using the explicit alias if non-empty, else the node's name. */
    @Override
    public RexNode visitAlias(Alias node, CalcitePlanContext context) {
        RexNode expr = analyze(node.getDelegated(), context);
        String name = Strings.isEmpty(node.getAlias()) ? node.getName() : node.getAlias();
        return context.relBuilder.alias(expr, name);
    }

    /**
     * Translates a span expression into the SPAN builtin function. The third argument is the
     * unit name for time-based units ({@link #isTimeBased}), or a NULL literal otherwise.
     */
    @Override
    public RexNode visitSpan(Span node, CalcitePlanContext context) {
        RexNode field = analyze(node.getField(), context);
        RexNode value = analyze(node.getValue(), context);
        SpanUnit unit = node.getUnit();
        RexBuilder rexBuilder = context.relBuilder.getRexBuilder();
        RexLiteral unitNode = isTimeBased(unit) ? rexBuilder.makeLiteral(unit.getName()) : rexBuilder.constantNull();
        return PPLFuncImpTable.INSTANCE.resolve(
                context.rexBuilder, BuiltinFunctionName.SPAN, new RexNode[] {field, value, unitNode});
    }

    /** Returns true for every span unit except {@code NONE} and {@code UNKNOWN}. */
    private boolean isTimeBased(SpanUnit unit) {
        return unit != SpanUnit.NONE && unit != SpanUnit.UNKNOWN;
    }

    /** Translates {@code var = expr} as {@code expr} aliased to the variable's field name. */
    @Override
    public RexNode visitLet(Let node, CalcitePlanContext context) {
        RexNode expr = analyze(node.getExpression(), context);
        return context.relBuilder.alias(expr, node.getVar().getField().toString());
    }

    /**
     * Translates a function call. {@code PPLFuncImpTable.resolveSafe} is tried first; if it
     * returns {@code null}, the call is built from {@code BuiltinFunctionUtils}' operator
     * translation, argument translation and return-type derivation.
     */
    @Override
    public RexNode visitFunction(Function node, CalcitePlanContext context) {
        List<RexNode> arguments = node.getFuncArgs().stream()
                .map(arg -> analyze((UnresolvedExpression) arg, context))
                .collect(Collectors.toList());
        RexNode resolvedNode = PPLFuncImpTable.INSTANCE.resolveSafe(
                context.rexBuilder, node.getFuncName(), arguments.toArray(new RexNode[0]));
        if (resolvedNode != null) {
            return resolvedNode;
        }
        SqlOperator operator = BuiltinFunctionUtils.translate(node.getFuncName());
        List<RexNode> translatedArguments = BuiltinFunctionUtils.translateArgument(
                node.getFuncName(), arguments, context,
                context.functionProperties.getQueryStartClock().instant().toString());
        RelDataType returnType = BuiltinFunctionUtils.deriveReturnType(
                node.getFuncName(), context.rexBuilder, operator, translatedArguments);
        return context.rexBuilder.makeCall(returnType, operator, translatedArguments);
    }

    /**
     * Translates {@code (e1, ...) IN (subquery)}.
     *
     * @throws SemanticCheckException if RelBuilder rejects the call with an AssertionError, which
     *     this method reports as a column-count mismatch
     */
    @Override
    public RexNode visitInSubquery(InSubquery node, CalcitePlanContext context) {
        List<RexNode> nodes = node.getChild().stream()
                .map(child -> analyze((UnresolvedExpression) child, context))
                .toList();
        UnresolvedPlan subquery = node.getQuery();
        RelNode subqueryRel = resolveSubqueryPlan(subquery, context);
        try {
            return context.relBuilder.in(subqueryRel, nodes);
        } catch (AssertionError e) {
            throw new SemanticCheckException("The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery");
        }
    }

    /** Translates a scalar subquery; the subquery plan is built lazily inside the RelBuilder callback. */
    @Override
    public RexNode visitScalarSubquery(ScalarSubquery node, CalcitePlanContext context) {
        return context.relBuilder.scalarQuery(b -> resolveSubqueryPlan(node.getQuery(), context));
    }

    /** Translates an EXISTS subquery; the subquery plan is built lazily inside the RelBuilder callback. */
    @Override
    public RexNode visitExistsSubquery(ExistsSubquery node, CalcitePlanContext context) {
        return context.relBuilder.exists(b -> resolveSubqueryPlan(node.getQuery(), context));
    }

    /**
     * Plans {@code subquery} with the relational visitor and pops the result off the RelBuilder
     * stack via {@code build()}. While it is planned, the context's resolving-join-condition flag
     * is switched off, and the previous value is restored afterwards.
     *
     * @return the planned subquery
     */
    private RelNode resolveSubqueryPlan(UnresolvedPlan subquery, CalcitePlanContext context) {
        boolean isResolvingJoinConditionOuter = context.isResolvingJoinCondition();
        if (isResolvingJoinConditionOuter) {
            context.setResolvingJoinCondition(false);
        }
        try {
            RelNode subqueryRel = subquery.accept(planVisitor, context);
            context.relBuilder.build();
            return subqueryRel;
        } finally {
            // Restore the outer state even when planning fails, so the shared context is not
            // left in the wrong join-resolution mode.
            if (isResolvingJoinConditionOuter) {
                context.setResolvingJoinCondition(true);
            }
        }
    }

    /**
     * Translates a CAST to the Calcite type of the target core type, made nullable. {@code true}
     * is passed for both boolean flags of {@code RexBuilder.makeCast} (matchNullability and safe
     * in recent Calcite versions — confirm against the bundled version).
     */
    @Override
    public RexNode visitCast(Cast node, CalcitePlanContext context) {
        RexNode expr = analyze(node.getExpression(), context);
        RelDataType type = OpenSearchTypeFactory.convertExprTypeToRelDataType(node.getDataType().getCoreType());
        RelDataType nullableType = context.rexBuilder.getTypeFactory().createTypeWithNullability(type, true);
        return context.rexBuilder.makeCast(nullableType, expr, true, true);
    }

    /** Not supported by the Calcite path; always throws {@link CalciteUnsupportedException}. */
    @Override
    public RexNode visitWhen(When node, CalcitePlanContext context) {
        throw new CalciteUnsupportedException("CastWhen function is unsupported in Calcite");
    }

    /** Not supported by the Calcite path; always throws {@link CalciteUnsupportedException}. */
    @Override
    public RexNode visitRelevanceFieldList(RelevanceFieldList node, CalcitePlanContext context) {
        throw new CalciteUnsupportedException("Relevance fields expression is unsupported in Calcite");
    }
}

