/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.expression.function;

import java.lang.reflect.Method;
import java.util.List;
import org.apache.calcite.adapter.enumerable.NotNullImplementor;
import org.apache.calcite.adapter.enumerable.NullPolicy;
import org.apache.calcite.adapter.enumerable.RexImpTable;
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
import org.apache.calcite.linq4j.function.Parameter;
import org.apache.calcite.linq4j.function.Strict;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.Types;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.schema.impl.ScalarFunctionImpl;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.util.BuiltInMethod;
import org.opensearch.sql.calcite.type.ExprSqlType;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.function.ImplementorUDF;
import org.opensearch.sql.planner.physical.collector.Rounding;

/**
 * Calcite UDF implementing {@code SPAN(field, interval, unit)}, which maps a value to the start of
 * the span (bucket) it falls into.
 *
 * <p>Two forms are supported:
 *
 * <ul>
 *   <li>When the {@code unit} operand has the NULL type, {@code field} is treated as a number and
 *       the result is {@code floor(field / interval) * interval}.
 *   <li>Otherwise {@code field} must be an {@link ExprSqlType} of DATE, TIME or TIMESTAMP, and the
 *       call is delegated to {@link #evalDate}, {@link #evalTime} or {@link #evalTimestamp}.
 * </ul>
 *
 * <p>The return type is the type of the first operand ({@link ReturnTypes#ARG0}).
 */
public class SpanFunctionImpl extends ImplementorUDF {
    protected SpanFunctionImpl() {
        // NullPolicy.ARG0: null handling is driven by the first (field) operand.
        super(new SpanImplementor(), NullPolicy.ARG0);
    }

    @Override
    public SqlReturnTypeInference getReturnTypeInference() {
        return ReturnTypes.ARG0;
    }

    /**
     * Rounds a DATE value using {@link Rounding.DateRounding}.
     *
     * @param value the date, converted with {@code ExprValueUtils.fromObjectValue(value, DATE)}
     * @param interval the span length, in {@code unit}s
     * @param unit the span unit name, passed through to {@link Rounding.DateRounding}
     * @return the rounded value as produced by {@code ExprValue.valueForCalcite()}
     */
    @Strict
    public static Object evalDate(
            @Parameter(name = "value") String value,
            @Parameter(name = "interval") int interval,
            @Parameter(name = "unit") String unit) {
        ExprValue exprInterval = ExprValueUtils.fromObjectValue(interval, ExprCoreType.INTEGER);
        ExprValue exprValue = ExprValueUtils.fromObjectValue(value, ExprCoreType.DATE);
        Rounding.DateRounding rounding = new Rounding.DateRounding(exprInterval, unit);
        return rounding.round(exprValue).valueForCalcite();
    }

    /**
     * Rounds a TIME value using {@link Rounding.TimeRounding}.
     *
     * @param value the time, converted with {@code ExprValueUtils.fromObjectValue(value, TIME)}
     * @param interval the span length, in {@code unit}s
     * @param unit the span unit name, passed through to {@link Rounding.TimeRounding}
     * @return the rounded value as produced by {@code ExprValue.valueForCalcite()}
     */
    @Strict
    public static Object evalTime(
            @Parameter(name = "value") String value,
            @Parameter(name = "interval") int interval,
            @Parameter(name = "unit") String unit) {
        ExprValue exprInterval = ExprValueUtils.fromObjectValue(interval, ExprCoreType.INTEGER);
        ExprValue exprValue = ExprValueUtils.fromObjectValue(value, ExprCoreType.TIME);
        Rounding.TimeRounding rounding = new Rounding.TimeRounding(exprInterval, unit);
        return rounding.round(exprValue).valueForCalcite();
    }

    /**
     * Rounds a TIMESTAMP value using {@link Rounding.TimestampRounding}.
     *
     * @param value the timestamp, converted with
     *     {@code ExprValueUtils.fromObjectValue(value, TIMESTAMP)}
     * @param interval the span length, in {@code unit}s
     * @param unit the span unit name, passed through to {@link Rounding.TimestampRounding}
     * @return the rounded value as produced by {@code ExprValue.valueForCalcite()}
     */
    @Strict
    public static Object evalTimestamp(
            @Parameter(name = "value") String value,
            @Parameter(name = "interval") int interval,
            @Parameter(name = "unit") String unit) {
        ExprValue exprInterval = ExprValueUtils.fromObjectValue(interval, ExprCoreType.INTEGER);
        ExprValue exprValue = ExprValueUtils.fromObjectValue(value, ExprCoreType.TIMESTAMP);
        Rounding.TimestampRounding rounding = new Rounding.TimestampRounding(exprInterval, unit);
        return rounding.round(exprValue).valueForCalcite();
    }

    /** Generates linq4j code for {@code SPAN(field, interval, unit)} calls. */
    public static class SpanImplementor implements NotNullImplementor {
        /**
         * Translates a SPAN call.
         *
         * @param translator the active Rex-to-linq4j translator
         * @param call the SPAN call; must have exactly three operands
         * @param translatedOperands the already-translated {@code field}, {@code interval} and
         *     {@code unit} operands
         * @return the expression computing the span start
         * @throws IllegalArgumentException if {@code unit} is non-null and {@code field} is not a
         *     DATE, TIME or TIMESTAMP {@link ExprSqlType}
         */
        @Override
        public Expression implement(
                RexToLixTranslator translator, RexCall call, List<Expression> translatedOperands) {
            assert call.getOperands().size() == 3 : "SPAN should have 3 arguments";
            assert translatedOperands.size() == 3 : "SPAN should have 3 arguments";
            Expression field = translatedOperands.get(0);
            Expression interval = translatedOperands.get(1);
            RelDataType fieldType = call.getOperands().get(0).getType();
            RelDataType unitType = call.getOperands().get(2).getType();

            // A NULL-typed unit selects plain numeric bucketing.
            if (SqlTypeUtil.isNull(unitType)) {
                return implementNumericSpan(call.getType().getSqlTypeName(), field, interval);
            }
            if (fieldType instanceof ExprSqlType exprSqlType) {
                return implementTemporalSpan(translator, call, exprSqlType);
            }
            throw new IllegalArgumentException(String.format("Unsupported expr type: %s",
                    OpenSearchTypeFactory.convertRelDataTypeToExprType(fieldType)));
        }

        /** Builds {@code floor(field / interval) * interval} for a numeric field. */
        private static Expression implementNumericSpan(
                SqlTypeName resultType, Expression field, Expression interval) {
            return switch (resultType) {
                case BIGINT, INTEGER, SMALLINT, TINYINT -> {
                    // Plain '/' truncates toward zero, so e.g. SPAN(-1, 10) would yield 0 instead
                    // of -10. Math.floorDiv rounds toward negative infinity, matching the FLOOR
                    // used for the non-integral branch below. Both operands are widened to long
                    // so the (long, long) overload is resolved unambiguously.
                    // NOTE(review): the long result relies on Calcite's call implementor
                    // converting it to the call's Java type, as the default branch already does.
                    Expression longInterval = Expressions.convert_(interval, long.class);
                    Expression quotient = Expressions.call(
                            Math.class, "floorDiv",
                            Expressions.convert_(field, long.class), longInterval);
                    yield Expressions.multiply(quotient, longInterval);
                }
                default -> Expressions.multiply(
                        Expressions.call(
                                BuiltInMethod.FLOOR.method, Expressions.divide(field, interval)),
                        interval);
            };
        }

        /** Delegates a DATE/TIME/TIMESTAMP field to the matching static {@code eval*} method. */
        private static Expression implementTemporalSpan(
                RexToLixTranslator translator, RexCall call, ExprSqlType exprSqlType) {
            String methodName = switch (exprSqlType.getUdt()) {
                case EXPR_DATE -> "evalDate";
                case EXPR_TIME -> "evalTime";
                case EXPR_TIMESTAMP -> "evalTimestamp";
                default -> throw new IllegalArgumentException(
                        String.format("Unsupported expr type: %s", exprSqlType.getExprType()));
            };
            ScalarFunctionImpl function = (ScalarFunctionImpl) ScalarFunctionImpl.create(
                    Types.lookupMethod(
                            SpanFunctionImpl.class, methodName, String.class, int.class, String.class));
            return function.getImplementor().implement(translator, call, RexImpTable.NullAs.NULL);
        }
    }
}

