/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.deepwater.caffe;

import com.google.protobuf.nano.CodedInputByteBufferNano;
import com.google.protobuf.nano.CodedOutputByteBufferNano;
import deepwater.backends.BackendModel;
import hex.genmodel.algos.deepwater.caffe.nano.Deepwater;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/**
 * {@link BackendModel} that drives a Caffe backend running in a separate
 * {@code python3 backend.py} process (working directory {@code /opt/caffe-h2o/}).
 *
 * <p>Every operation is sent to the child process over its stdin as a length-prefixed
 * {@link Deepwater.Cmd} protobuf message. The reply is read back from its stdout using the same
 * framing: a 4-byte big-endian length followed by that many bytes of serialized message.
 * Float tensors inside a message are encoded as little-endian 32-bit floats.
 *
 * <p>NOTE(review): a request/reply pair is exchanged over a single pipe pair without any
 * locking, so one instance should not be used from several threads concurrently.
 */
public class DeepwaterCaffeModel
implements BackendModel {
    // Command codes stored in Deepwater.Cmd.type; they must stay in sync with backend.py.
    private static final int CMD_CREATE = 0;
    private static final int CMD_TRAIN = 1;
    private static final int CMD_PREDICT = 2;
    private static final int CMD_SAVE_MODEL = 3;
    private static final int CMD_SAVE_PARAM = 4;
    private static final int CMD_LOAD_PARAM = 5;

    // Size in bytes of the length prefix that frames each message on the pipe.
    private static final int HEADER_BYTES = 4;

    private int[] _input_shape = new int[0];
    private int[] _sizes = new int[0];
    private String[] _types = new String[0];
    private double[] _dropout_ratios = new double[0];
    private long _seed;
    private boolean _useGPU;
    private String _graph = "";
    private Process _process;
    // Per-thread scratch buffer (little-endian, direct) reused for float <-> byte conversions.
    private static final ThreadLocal<ByteBuffer> _buffer = new ThreadLocal<ByteBuffer>();

    /**
     * Creates a model described by layer sizes/types and immediately starts the backend.
     *
     * @param batch_size     number of rows per batch; becomes {@code _input_shape[0]}
     * @param sizes          layer sizes; {@code sizes[0]} is used as the input width
     * @param types          layer types, forwarded to the backend as-is
     * @param dropout_ratios per-layer dropout ratios, forwarded to the backend as-is
     * @param seed           random seed for the backend
     * @param useGPU         whether the backend should use the GPU
     */
    public DeepwaterCaffeModel(int batch_size, int[] sizes, String[] types, double[] dropout_ratios, long seed, boolean useGPU) {
        this._input_shape = new int[]{batch_size, 1, 1, sizes[0]};
        this._sizes = sizes;
        this._types = types;
        this._dropout_ratios = dropout_ratios;
        this._seed = seed;
        this._useGPU = useGPU;
        this.start();
    }

    /**
     * Creates a model from a graph definition and immediately starts the backend.
     *
     * @param graph       graph definition, forwarded to the backend as-is
     * @param input_shape input shape; {@link #train} expects 4 dimensions
     * @param seed        random seed for the backend
     * @param useGPU      whether the backend should use the GPU
     */
    public DeepwaterCaffeModel(String graph, int[] input_shape, long seed, boolean useGPU) {
        this._graph = graph;
        this._input_shape = input_shape;
        this._seed = seed;
        this._useGPU = useGPU;
        this.start();
    }

    /**
     * Launches the backend process if it is not running yet and sends the initial command that
     * configures the network and solver.
     *
     * @throws RuntimeException if the process cannot be started or the command fails
     */
    private void start() {
        if (this._process != null) {
            return;
        }
        try {
            this.startRegular();
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
        Deepwater.Cmd cmd = new Deepwater.Cmd();
        cmd.type = CMD_CREATE;
        cmd.graph = this._graph;
        cmd.inputShape = this._input_shape;
        // Solver settings are fixed here and are not configurable through the constructors.
        cmd.solverType = "Adam";
        cmd.sizes = this._sizes;
        cmd.types = this._types;
        cmd.dropoutRatios = this._dropout_ratios;
        cmd.learningRate = 0.01f;
        cmd.momentum = 0.99f;
        cmd.randomSeed = this._seed;
        cmd.useGpu = this._useGPU;
        this.call(cmd);
    }

    /** Asks the backend to save the model to {@code model_path}. */
    public void saveModel(String model_path) {
        Deepwater.Cmd cmd = new Deepwater.Cmd();
        cmd.type = CMD_SAVE_MODEL;
        cmd.path = model_path;
        this.call(cmd);
    }

    /** Asks the backend to save the parameters to {@code param_path}. */
    public void saveParam(String param_path) {
        Deepwater.Cmd cmd = new Deepwater.Cmd();
        cmd.type = CMD_SAVE_PARAM;
        cmd.path = param_path;
        this.call(cmd);
    }

    /** Asks the backend to load parameters from {@code param_path}. */
    public void loadParam(String param_path) {
        Deepwater.Cmd cmd = new Deepwater.Cmd();
        cmd.type = CMD_LOAD_PARAM;
        cmd.path = param_path;
        this.call(cmd);
    }

    /**
     * Returns this thread's scratch buffer, growing it if it is smaller than {@code capacity}.
     * The buffer is little-endian and is returned cleared (position 0, limit = capacity).
     */
    private static ByteBuffer threadBuffer(int capacity) {
        ByteBuffer buffer = _buffer.get();
        if (buffer == null || buffer.capacity() < capacity) {
            buffer = ByteBuffer.allocateDirect(capacity);
            buffer.order(ByteOrder.LITTLE_ENDIAN);
            _buffer.set(buffer);
        }
        buffer.clear();
        return buffer;
    }

    /**
     * Encodes {@code data} as little-endian 32-bit floats into {@code buff}.
     *
     * @throws IllegalArgumentException if {@code buff} does not hold exactly 4 bytes per float
     */
    private static void copy(float[] data, byte[] buff) {
        if (data.length * 4 != buff.length) {
            throw new IllegalArgumentException("Expected " + data.length * 4 + " bytes for "
                    + data.length + " floats, got " + buff.length);
        }
        ByteBuffer buffer = threadBuffer(buff.length);
        buffer.asFloatBuffer().put(data);
        // Writing through the float view leaves the byte buffer's position at 0, so this reads
        // back exactly the bytes just written.
        buffer.get(buff);
    }

    /** Encodes each float array in {@code buffs} into a matching entry of {@code cmd.data}. */
    private static void copy(float[][] buffs, Deepwater.Cmd cmd) {
        cmd.data = new byte[buffs.length][];
        for (int i = 0; i < buffs.length; ++i) {
            cmd.data[i] = new byte[buffs[i].length * 4];
            DeepwaterCaffeModel.copy(buffs[i], cmd.data[i]);
        }
    }

    /**
     * Sends one training batch to the backend.
     *
     * @param data  flattened input; must contain the product of the 4 input-shape dimensions
     * @param label one label per row; must contain {@code _input_shape[0]} values
     * @throws IllegalArgumentException if {@code data} or {@code label} has the wrong length
     */
    public void train(float[] data, float[] label) {
        Deepwater.Cmd cmd = new Deepwater.Cmd();
        cmd.type = CMD_TRAIN;
        cmd.inputShape = this._input_shape;
        int len = this._input_shape[0] * this._input_shape[1] * this._input_shape[2] * this._input_shape[3];
        if (data.length != len) {
            throw new IllegalArgumentException("Expected " + len + " input values, got " + data.length);
        }
        if (label.length != this._input_shape[0]) {
            throw new IllegalArgumentException("Expected " + this._input_shape[0] + " labels, got " + label.length);
        }
        float[][] buffs = new float[][]{data, label};
        DeepwaterCaffeModel.copy(buffs, cmd);
        this.call(cmd);
    }

    /**
     * Runs inference on {@code data} and returns the first data payload of the reply, decoded
     * as little-endian floats.
     *
     * @throws IllegalStateException if the backend reply contains no data payload
     */
    public float[] predict(float[] data) {
        Deepwater.Cmd cmd = new Deepwater.Cmd();
        cmd.type = CMD_PREDICT;
        cmd.inputShape = this._input_shape;
        float[][] buffs = new float[][]{data};
        DeepwaterCaffeModel.copy(buffs, cmd);
        cmd = this.call(cmd);
        if (cmd.data == null || cmd.data.length == 0) {
            throw new IllegalStateException("Backend returned no prediction data");
        }
        byte[] payload = cmd.data[0];
        ByteBuffer buffer = threadBuffer(payload.length);
        buffer.put(payload);
        buffer.flip();
        float[] res = new float[payload.length / 4];
        buffer.asFloatBuffer().get(res);
        return res;
    }

    /**
     * Starts {@code python3 backend.py} in {@code /opt/caffe-h2o/} with
     * {@code PYTHONPATH=/opt/caffe/python}. The child's stderr is sent to this JVM's stderr.
     */
    private void startRegular() throws IOException {
        String pwd = "/opt/caffe-h2o/";
        ProcessBuilder pb = new ProcessBuilder("python3 backend.py".split(" "));
        pb.environment().put("PYTHONPATH", "/opt/caffe/python");
        pb.redirectError(ProcessBuilder.Redirect.INHERIT);
        pb.directory(new File(pwd));
        this._process = pb.start();
    }

    /**
     * Terminates the backend process and waits for it to exit. If the wait is interrupted, the
     * method returns early and the thread's interrupt status is restored.
     */
    void close() {
        this._process.destroy();
        try {
            this._process.waitFor();
        }
        catch (InterruptedException e) {
            // Give up waiting but keep the interrupt visible to callers/executors.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Sends {@code cmd} to the backend and blocks until its reply has been fully read.
     *
     * @return the decoded reply message
     * @throws RuntimeException wrapping an {@link IOException} if writing, reading, or decoding
     *                          fails, including when the backend closes its stdout early
     */
    private Deepwater.Cmd call(Deepwater.Cmd cmd) {
        try {
            // Request: 4-byte big-endian length (ByteBuffer's default order) + serialized message.
            int len = cmd.getSerializedSize();
            byte[] request = new byte[HEADER_BYTES + len];
            ByteBuffer.wrap(request).putInt(len);
            cmd.writeTo(CodedOutputByteBufferNano.newInstance(request, HEADER_BYTES, len));
            OutputStream stdin = this._process.getOutputStream();
            stdin.write(request, 0, request.length);
            stdin.flush();

            // Reply uses the same framing. InputStream.read may return fewer bytes than requested,
            // so both the header and the body are read with readFully.
            InputStream stdout = this._process.getInputStream();
            byte[] header = new byte[HEADER_BYTES];
            readFully(stdout, header, 0, HEADER_BYTES);
            int replyLen = ByteBuffer.wrap(header).getInt();
            if (replyLen < 0) {
                throw new IOException("Invalid reply length from backend: " + replyLen);
            }
            byte[] reply = new byte[replyLen];
            readFully(stdout, reply, 0, replyLen);

            Deepwater.Cmd res = new Deepwater.Cmd();
            res.mergeFrom(CodedInputByteBufferNano.newInstance(reply, 0, replyLen));
            return res;
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads exactly {@code length} bytes from {@code in} into {@code dst} starting at
     * {@code offset}, looping over short reads.
     *
     * @throws IOException if the stream ends before {@code length} bytes were read
     */
    private static void readFully(InputStream in, byte[] dst, int offset, int length) throws IOException {
        int done = 0;
        while (done < length) {
            int n = in.read(dst, offset + done, length - done);
            if (n < 0) {
                throw new IOException("Unexpected end of stream from backend: read " + done
                        + " of " + length + " bytes");
            }
            done += n;
        }
    }
}

