/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.prims.mungers;

import ai.h2o.targetencoding.BlendingParams;
import ai.h2o.targetencoding.TargetEncoder;
import java.util.HashMap;
import java.util.Map;
import water.DKV;
import water.fvec.Frame;
import water.rapids.Env;
import water.rapids.ast.AstBuiltin;
import water.rapids.ast.AstRoot;
import water.rapids.ast.params.AstStr;
import water.rapids.ast.params.AstStrList;
import water.rapids.vals.ValFrame;

public class AstTargetEncoderTransform
extends AstBuiltin<AstTargetEncoderTransform> {
    public String[] args() {
        return new String[]{"encodingMapKeys encodingMapFrames frameToTransform teColumns strategy targetColumnName foldColumnName withBlending inflectionPoint smoothing noise seed"};
    }

    public String str() {
        return "target.encoder.transform";
    }

    public int nargs() {
        return 13;
    }

    public ValFrame apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        String[] encodingMapKeys = this.getEncodingMapKeys(env, stk, asts);
        Frame[] encodingMapFrames = this.getEncodingMapFrames(env, stk, asts);
        Frame frame = this.getFrameToTransform(env, stk, asts);
        String[] teColumnsToEncode = this.getTEColumns(env, stk, asts);
        TargetEncoder.DataLeakageHandlingStrategy dataLeakageHandlingStrategy = this.getDataLeakageHandlingStrategy(env, stk, asts);
        String targetColumnName = this.getTargetColumnName(env, stk, asts);
        String foldColumnName = this.getFoldColumnName(env, stk, asts);
        boolean withBlending = this.getWithBlending(env, stk, asts);
        BlendingParams blendingParams = withBlending ? this.getBlendingParams(env, stk, asts) : null;
        double noise = this.getNoise(env, stk, asts);
        double seed = this.getSeed(env, stk, asts);
        boolean withImputationForOriginalColumns = true;
        TargetEncoder tec = blendingParams == null ? new TargetEncoder(teColumnsToEncode) : new TargetEncoder(teColumnsToEncode);
        Map<String, Frame> encodingMap = this.reconstructEncodingMap(encodingMapKeys, encodingMapFrames);
        if (noise == -1.0) {
            return new ValFrame(tec.applyTargetEncoding(frame, targetColumnName, encodingMap, dataLeakageHandlingStrategy, foldColumnName, withImputationForOriginalColumns, withImputationForOriginalColumns, blendingParams, (long)seed));
        }
        return new ValFrame(tec.applyTargetEncoding(frame, targetColumnName, encodingMap, dataLeakageHandlingStrategy, foldColumnName, withBlending, noise, withImputationForOriginalColumns, blendingParams, (long)seed));
    }

    private BlendingParams getBlendingParams(Env env, Env.StackHelp stk, AstRoot[] asts) {
        double inflectionPoint = this.getInflectionPoint(env, stk, asts);
        double smoothing = this.getSmoothing(env, stk, asts);
        return new BlendingParams(inflectionPoint, smoothing);
    }

    private Map<String, Frame> reconstructEncodingMap(String[] encodingMapKeys, Frame[] encodingMapFrames) {
        HashMap<String, Frame> encodingMap = new HashMap<String, Frame>();
        assert (encodingMapKeys.length == encodingMapFrames.length) : "EncodingMap elements are inconsistent";
        for (int i = 0; i < encodingMapKeys.length; ++i) {
            encodingMap.put(encodingMapKeys[i], encodingMapFrames[i]);
        }
        return encodingMap;
    }

    private String[] getEncodingMapKeys(Env env, Env.StackHelp stk, AstRoot[] asts) {
        return this.getArgAsStrings(env, stk, asts[1]);
    }

    private Frame[] getEncodingMapFrames(Env env, Env.StackHelp stk, AstRoot[] asts) {
        String[] frameKeys = this.getArgAsStrings(env, stk, asts[2]);
        Frame[] framesWithEncodings = new Frame[frameKeys.length];
        int i = 0;
        for (String key : frameKeys) {
            framesWithEncodings[i++] = (Frame)DKV.getGet((String)key);
        }
        return framesWithEncodings;
    }

    private String[] getArgAsStrings(Env env, Env.StackHelp stk, AstRoot ast) {
        String[] frameKeys;
        if (ast instanceof AstStrList) {
            AstStrList teColumns = (AstStrList)ast;
            frameKeys = teColumns._strs;
        } else if (ast instanceof AstStr) {
            String teColumn = stk.track(ast.exec(env)).getStr();
            frameKeys = new String[]{teColumn};
        } else {
            throw new IllegalStateException("Failed to parse ast parameter: " + ast.toString());
        }
        return frameKeys;
    }

    private Frame getFrameToTransform(Env env, Env.StackHelp stk, AstRoot[] asts) {
        return stk.track(asts[3].exec(env)).getFrame();
    }

    private String[] getTEColumns(Env env, Env.StackHelp stk, AstRoot[] asts) {
        return this.getArgAsStrings(env, stk, asts[4]);
    }

    private TargetEncoder.DataLeakageHandlingStrategy getDataLeakageHandlingStrategy(Env env, Env.StackHelp stk, AstRoot[] asts) {
        String strategy = stk.track(asts[5].exec(env)).getStr();
        if (strategy.equals("kfold")) {
            return TargetEncoder.DataLeakageHandlingStrategy.KFold;
        }
        if (strategy.equals("loo")) {
            return TargetEncoder.DataLeakageHandlingStrategy.LeaveOneOut;
        }
        if (strategy.equals("loo")) {
            return TargetEncoder.DataLeakageHandlingStrategy.None;
        }
        return TargetEncoder.DataLeakageHandlingStrategy.None;
    }

    private String getTargetColumnName(Env env, Env.StackHelp stk, AstRoot[] asts) {
        return stk.track(asts[6].exec(env)).getStr();
    }

    private String getFoldColumnName(Env env, Env.StackHelp stk, AstRoot[] asts) {
        try {
            String str = stk.track(asts[7].exec(env)).getStr();
            if (str.equals("")) {
                return null;
            }
            return str;
        }
        catch (IllegalArgumentException ex) {
            return null;
        }
    }

    private boolean getWithBlending(Env env, Env.StackHelp stk, AstRoot[] asts) {
        return stk.track(asts[8].exec(env)).getNum() == 1.0;
    }

    private double getInflectionPoint(Env env, Env.StackHelp stk, AstRoot[] asts) {
        return stk.track(asts[9].exec(env)).getNum();
    }

    private double getSmoothing(Env env, Env.StackHelp stk, AstRoot[] asts) {
        return stk.track(asts[10].exec(env)).getNum();
    }

    private double getNoise(Env env, Env.StackHelp stk, AstRoot[] asts) {
        return stk.track(asts[11].exec(env)).getNum();
    }

    private double getSeed(Env env, Env.StackHelp stk, AstRoot[] asts) {
        return stk.track(asts[12].exec(env)).getNum();
    }
}

