/*
 * Copyright 2016 Red Hat, Inc. and/or its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.kie.dmn.feel.lang.ast;

import org.antlr.v4.runtime.ParserRuleContext;
import org.kie.dmn.feel.lang.EvaluationContext;
import org.kie.dmn.feel.runtime.Range;
import org.kie.dmn.feel.util.EvalHelper;

import java.math.BigDecimal;
import java.math.MathContext;
import java.time.*;
import java.time.temporal.Temporal;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.function.BiPredicate;
import java.util.function.BinaryOperator;

/**
 * AST node for a binary (infix) expression: arithmetic, comparison,
 * equality and the three-valued logical operators.
 *
 * <p>Evaluation never throws for unsupported operand combinations; it yields
 * {@code null} instead, which is also the result whenever an arithmetic or
 * ordering operator receives a {@code null} operand.</p>
 */
public class InfixOpNode
        extends BaseNode {

    /**
     * The infix operators understood by this node, each tied to the textual
     * symbol used to look it up.
     */
    public enum InfixOperator {
        ADD( "+" ),
        SUB( "-" ),
        MULT( "*" ),
        DIV( "/" ),
        POW( "**" ),
        LTE( "<=" ),
        LT( "<" ),
        GT( ">" ),
        GTE( ">=" ),
        EQ( "=" ),
        NE( "!=" ),
        AND( "and" ),
        OR( "or" );

        public final String symbol;

        InfixOperator(String symbol) {
            this.symbol = symbol;
        }

        /**
         * Looks up the operator whose symbol is exactly {@code symbol}.
         *
         * @param symbol the operator text, e.g. {@code "+"} or {@code "and"}
         * @return the matching operator
         * @throws IllegalArgumentException if no operator uses that symbol
         */
        public static InfixOperator determineOperator(String symbol) {
            for ( InfixOperator op : InfixOperator.values() ) {
                if ( op.symbol.equals( symbol ) ) {
                    return op;
                }
            }
            throw new IllegalArgumentException( "No operator found for symbol '" + symbol + "'" );
        }
    }

    private InfixOperator operator;
    private BaseNode      left;
    private BaseNode      right;

    /**
     * @param ctx   parser context handed to the base node
     * @param left  left operand expression
     * @param op    operator symbol, resolved through {@link InfixOperator#determineOperator(String)}
     * @param right right operand expression
     * @throws IllegalArgumentException if {@code op} is not a known operator symbol
     */
    public InfixOpNode(ParserRuleContext ctx, BaseNode left, String op, BaseNode right) {
        super( ctx );
        this.left = left;
        this.operator = InfixOperator.determineOperator( op );
        this.right = right;
    }

    public InfixOperator getOperator() {
        return operator;
    }

    public void setOperator(InfixOperator operator) {
        this.operator = operator;
    }

    public BaseNode getLeft() {
        return left;
    }

    public void setLeft(BaseNode left) {
        this.left = left;
    }

    public BaseNode getRight() {
        return right;
    }

    public void setRight(BaseNode right) {
        this.right = right;
    }

    /**
     * Evaluates both operands (always both, left first) and applies the operator.
     *
     * @param ctx the evaluation context passed down to the operand nodes
     * @return the operation result, or {@code null} when the operands are not
     *         supported by the operator
     */
    @Override
    public Object evaluate(EvaluationContext ctx) {
        Object leftValue = this.left.evaluate( ctx );
        Object rightValue = this.right.evaluate( ctx );
        switch ( operator ) {
            case ADD:
                return add( leftValue, rightValue, ctx );
            case SUB:
                return sub( leftValue, rightValue, ctx );
            case MULT:
                return math( leftValue, rightValue, ctx, (l, r) -> l.multiply( r, MathContext.DECIMAL128 ) );
            case DIV:
                return math( leftValue, rightValue, ctx, (l, r) -> l.divide( r, MathContext.DECIMAL128 ) );
            case POW:
                // intValue() drops any fractional part of the exponent
                return math( leftValue, rightValue, ctx, (l, r) -> l.pow( r.intValue(), MathContext.DECIMAL128 ) );
            case AND:
                return and( leftValue, rightValue, ctx );
            case OR:
                return or( leftValue, rightValue, ctx );
            case LTE:
                return comparison( leftValue, rightValue, ctx, (l, r) -> l.compareTo( r ) <= 0 );
            case LT:
                return comparison( leftValue, rightValue, ctx, (l, r) -> l.compareTo( r ) < 0 );
            case GT:
                return comparison( leftValue, rightValue, ctx, (l, r) -> l.compareTo( r ) > 0 );
            case GTE:
                return comparison( leftValue, rightValue, ctx, (l, r) -> l.compareTo( r ) >= 0 );
            case EQ:
                return equality( leftValue, rightValue, ctx, (l, r) -> l.compareTo( r ) == 0 );
            case NE:
                return equality( leftValue, rightValue, ctx, (l, r) -> l.compareTo( r ) != 0 );
            default:
                return null;
        }
    }

    /**
     * Addition: string concatenation, period + period, duration + duration,
     * date/time shifted by a period or duration (in either operand order),
     * and numeric addition as the fallback.
     *
     * @return the sum, or {@code null} if either operand is {@code null} or
     *         the operand types cannot be added
     */
    private Object add(Object left, Object right, EvaluationContext ctx) {
        if ( left == null || right == null ) {
            return null;
        }
        if ( left instanceof String && right instanceof String ) {
            return ((String) left) + ((String) right);
        }
        if ( left instanceof Period && right instanceof Period ) {
            return ((Period) left).plus( (Period) right );
        }
        if ( left instanceof Duration && right instanceof Duration ) {
            return ((Duration) left).plus( (Duration) right );
        }
        // addition is commutative: accept the temporal value on either side
        Object shifted = shift( left, right, false );
        if ( shifted == null ) {
            shifted = shift( right, left, false );
        }
        if ( shifted != null ) {
            return shifted;
        }
        return math( left, right, ctx, (l, r) -> l.add( r, MathContext.DECIMAL128 ) );
    }

    /**
     * Subtraction: difference between two date/time values of the same type
     * (as a {@link Duration}), period - period, duration - duration,
     * date/time shifted back by a period or duration, and numeric
     * subtraction as the fallback.
     *
     * @return {@code left - right}, or {@code null} if either operand is
     *         {@code null} or the operand types cannot be subtracted
     */
    private Object sub(Object left, Object right, EvaluationContext ctx) {
        if ( left == null || right == null ) {
            return null;
        }
        boolean sameTemporalKind =
                (left instanceof ZonedDateTime && right instanceof ZonedDateTime) ||
                (left instanceof OffsetDateTime && right instanceof OffsetDateTime) ||
                (left instanceof LocalDateTime && right instanceof LocalDateTime) ||
                (left instanceof LocalTime && right instanceof LocalTime) ||
                (left instanceof OffsetTime && right instanceof OffsetTime);
        if ( sameTemporalKind ) {
            // Duration.between(start, end) is end - start, so the subtrahend goes first
            return Duration.between( (Temporal) right, (Temporal) left );
        }
        if ( left instanceof Period && right instanceof Period ) {
            return ((Period) left).minus( (Period) right );
        }
        if ( left instanceof Duration && right instanceof Duration ) {
            return ((Duration) left).minus( (Duration) right );
        }
        Object shifted = shift( left, right, true );
        if ( shifted != null ) {
            return shifted;
        }
        return math( left, right, ctx, (l, r) -> l.subtract( r, MathContext.DECIMAL128 ) );
    }

    /**
     * Moves a date/time value forward or backward by a period or duration.
     *
     * <p>Periods apply to {@link ZonedDateTime}, {@link OffsetDateTime} and
     * {@link LocalDateTime}; durations additionally apply to {@link LocalTime}
     * and {@link OffsetTime}.</p>
     *
     * @param temporal candidate date/time value
     * @param amount   candidate {@link Period} or {@link Duration}
     * @param subtract {@code true} to subtract the amount, {@code false} to add it
     * @return the shifted value (same type as {@code temporal}), or {@code null}
     *         when the pair of types is not one of the supported combinations
     */
    private static Object shift(Object temporal, Object amount, boolean subtract) {
        if ( amount instanceof Period ) {
            Period period = (Period) amount;
            if ( temporal instanceof ZonedDateTime ) {
                ZonedDateTime t = (ZonedDateTime) temporal;
                return subtract ? t.minus( period ) : t.plus( period );
            }
            if ( temporal instanceof OffsetDateTime ) {
                OffsetDateTime t = (OffsetDateTime) temporal;
                return subtract ? t.minus( period ) : t.plus( period );
            }
            if ( temporal instanceof LocalDateTime ) {
                LocalDateTime t = (LocalDateTime) temporal;
                return subtract ? t.minus( period ) : t.plus( period );
            }
        } else if ( amount instanceof Duration ) {
            Duration duration = (Duration) amount;
            if ( temporal instanceof ZonedDateTime ) {
                ZonedDateTime t = (ZonedDateTime) temporal;
                return subtract ? t.minus( duration ) : t.plus( duration );
            }
            if ( temporal instanceof OffsetDateTime ) {
                OffsetDateTime t = (OffsetDateTime) temporal;
                return subtract ? t.minus( duration ) : t.plus( duration );
            }
            if ( temporal instanceof LocalDateTime ) {
                LocalDateTime t = (LocalDateTime) temporal;
                return subtract ? t.minus( duration ) : t.plus( duration );
            }
            if ( temporal instanceof LocalTime ) {
                LocalTime t = (LocalTime) temporal;
                return subtract ? t.minus( duration ) : t.plus( duration );
            }
            if ( temporal instanceof OffsetTime ) {
                OffsetTime t = (OffsetTime) temporal;
                return subtract ? t.minus( duration ) : t.plus( duration );
            }
        }
        return null;
    }

    /**
     * Applies a numeric operation after converting both operands with
     * {@link EvalHelper#getBigDecimalOrNull(Object)}.
     *
     * @return the result of {@code op}, or {@code null} if either operand did
     *         not convert or the operation raised an {@link ArithmeticException}
     */
    private Object math(Object left, Object right, EvaluationContext ctx, BinaryOperator<BigDecimal> op) {
        BigDecimal l = EvalHelper.getBigDecimalOrNull( left );
        BigDecimal r = EvalHelper.getBigDecimalOrNull( right );
        if ( l == null || r == null ) {
            return null;
        }
        try {
            return op.apply( l, r );
        } catch ( ArithmeticException e ) {
            // happens in cases like division by 0
            return null;
        }
    }

    /**
     * Implements the ternary logic AND operation: {@code false} if either
     * operand is {@code false}, otherwise {@code null} if either operand is
     * unknown ({@code null}), otherwise {@code true}.
     */
    private Object and(Object left, Object right, EvaluationContext ctx) {
        Boolean l = EvalHelper.getBooleanOrNull( left );
        Boolean r = EvalHelper.getBooleanOrNull( right );
        // a definite false decides the result regardless of the other operand
        if ( Boolean.FALSE.equals( l ) || Boolean.FALSE.equals( r ) ) {
            return Boolean.FALSE;
        }
        if ( l == null || r == null ) {
            return null;
        }
        return Boolean.TRUE;
    }

    /**
     * Implements the ternary logic OR operation: {@code true} if either
     * operand is {@code true}, otherwise {@code null} if either operand is
     * unknown ({@code null}), otherwise {@code false}.
     */
    private Object or(Object left, Object right, EvaluationContext ctx) {
        Boolean l = EvalHelper.getBooleanOrNull( left );
        Boolean r = EvalHelper.getBooleanOrNull( right );
        // a definite true decides the result regardless of the other operand
        if ( Boolean.TRUE.equals( l ) || Boolean.TRUE.equals( r ) ) {
            return Boolean.TRUE;
        }
        if ( l == null || r == null ) {
            return null;
        }
        return Boolean.FALSE;
    }

    /**
     * Ordering comparison of two values of compatible types.
     *
     * @param op predicate applied to the two operands once they are known to
     *           be mutually comparable
     * @return the predicate result, or {@code null} if either operand is
     *         {@code null} or the operands are not comparable with each other
     */
    private Object comparison(Object left, Object right, EvaluationContext ctx, BiPredicate<Comparable, Comparable> op) {
        if ( left == null || right == null ) {
            return null;
        }
        if ( left instanceof Number && right instanceof Number && left.getClass() != right.getClass() ) {
            // Different boxed types (e.g. Integer vs BigDecimal) cannot be passed to
            // each other's compareTo without a ClassCastException, so normalise first.
            BigDecimal l = EvalHelper.getBigDecimalOrNull( left );
            BigDecimal r = EvalHelper.getBigDecimalOrNull( right );
            if ( l == null || r == null ) {
                return null;
            }
            return op.test( l, r );
        }
        if ( (left instanceof String && right instanceof String) ||
             (left instanceof Number && right instanceof Number) ||
             (left instanceof Boolean && right instanceof Boolean) ||
             (left instanceof Comparable && left.getClass().isAssignableFrom( right.getClass() )) ) {
            Comparable l = (Comparable) left;
            Comparable r = (Comparable) right;
            return op.test( l, r );
        }
        return null;
    }

    /**
     * Equality / inequality test. Unlike {@link #comparison}, {@code null}
     * operands are legal here: two nulls are equal, and a null never equals a
     * non-null. Ranges, iterables and maps are compared structurally; any
     * other pair is delegated to {@link #comparison} with {@code op}.
     *
     * @return a {@code Boolean}, or {@code null} when the operands are of
     *         types that cannot be compared with each other
     */
    private Object equality(Object left, Object right, EvaluationContext ctx, BiPredicate<Comparable, Comparable> op) {
        if ( left == null && right == null ) {
            return operator == InfixOperator.EQ;
        }
        if ( left == null || right == null ) {
            return operator == InfixOperator.NE;
        }
        // XOR with "negate" flips the structural result for the != operator
        boolean negate = operator == InfixOperator.NE;
        if ( left instanceof Range && right instanceof Range ) {
            return negate ^ isEqual( (Range) left, (Range) right );
        }
        if ( left instanceof Iterable && right instanceof Iterable ) {
            return negate ^ isEqual( (Iterable<?>) left, (Iterable<?>) right );
        }
        if ( left instanceof Map && right instanceof Map ) {
            return negate ^ isEqual( (Map<?, ?>) left, (Map<?, ?>) right );
        }
        return comparison( left, right, ctx, op );
    }

    /** Range equality, delegated to {@code Range.equals}. */
    private boolean isEqual(Range left, Range right) {
        return left.equals( right );
    }

    /**
     * Two iterables are equal when they yield the same number of elements and
     * the elements are pairwise equal in iteration order.
     */
    private boolean isEqual(Iterable<?> left, Iterable<?> right) {
        Iterator<?> li = left.iterator();
        Iterator<?> ri = right.iterator();
        while ( li.hasNext() && ri.hasNext() ) {
            if ( !isEqual( li.next(), ri.next() ) ) {
                return false;
            }
        }
        // equal only if both were exhausted at the same time
        return li.hasNext() == ri.hasNext();
    }

    /**
     * Two maps are equal when they have the same key set and the values
     * stored under each key are equal.
     */
    private boolean isEqual(Map<?, ?> left, Map<?, ?> right) {
        if ( left.size() != right.size() ) {
            return false;
        }
        for ( Map.Entry<?, ?> entry : left.entrySet() ) {
            Object key = entry.getKey();
            // containsKey distinguishes "missing key" from "key mapped to null"
            if ( !right.containsKey( key ) || !isEqual( entry.getValue(), right.get( key ) ) ) {
                return false;
            }
        }
        return true;
    }

    /**
     * Element-level equality used while walking iterables and maps: nested
     * iterables and maps are compared structurally, everything else with
     * {@code equals}; {@code null} only equals {@code null}.
     */
    private boolean isEqual(Object l, Object r) {
        if ( l instanceof Iterable && r instanceof Iterable ) {
            return isEqual( (Iterable<?>) l, (Iterable<?>) r );
        }
        if ( l instanceof Map && r instanceof Map ) {
            return isEqual( (Map<?, ?>) l, (Map<?, ?>) r );
        }
        if ( l == null || r == null ) {
            return l == r;
        }
        return l.equals( r );
    }
}
