/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.xgboost;

import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.MojoReaderBackend;
import hex.genmodel.MojoReaderBackendFactory;
import hex.genmodel.PredictContributions;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.algos.xgboost.XGBoostJavaMojoModel;
import hex.genmodel.algos.xgboost.XGBoostMojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.URL;
import org.junit.Assert;
import org.junit.Test;

public class XGBoostJavaMojoModelTest {
    @Test
    public void testObjFunction() {
        for (XGBoostMojoModel.ObjectiveType type : XGBoostMojoModel.ObjectiveType.values()) {
            Assert.assertNotNull((Object)type.getId());
            Assert.assertFalse((boolean)type.getId().isEmpty());
            Assert.assertNotNull((Object)XGBoostJavaMojoModel.getObjFunction((String)type.getId()));
        }
    }

    @Test
    public void testPredictContributionsSerialization() throws Exception {
        MojoReaderBackend readerBackend = MojoReaderBackendFactory.createReaderBackend((URL)XGBoostJavaMojoModelTest.class.getResource("xgboost_java.zip"), (MojoReaderBackendFactory.CachingStrategy)MojoReaderBackendFactory.CachingStrategy.MEMORY);
        XGBoostJavaMojoModel mojo = (XGBoostJavaMojoModel)MojoModel.load((MojoReaderBackend)readerBackend);
        PredictContributions pc = mojo.makeContributionsPredictor();
        Assert.assertNotNull((Object)pc);
        Assert.assertTrue((boolean)(XGBoostJavaMojoModelTest.deserialize(XGBoostJavaMojoModelTest.serialize(pc)) instanceof PredictContributions));
    }

    @Test
    public void testLeafNodeAssignments() throws Exception {
        MojoReaderBackend readerBackend = MojoReaderBackendFactory.createReaderBackend((URL)this.getClass().getResource("xgboost_java.zip"), (MojoReaderBackendFactory.CachingStrategy)MojoReaderBackendFactory.CachingStrategy.MEMORY);
        XGBoostJavaMojoModel mojo = (XGBoostJavaMojoModel)MojoModel.load((MojoReaderBackend)readerBackend);
        double[] doubles = new double[]{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
        SharedTreeMojoModel.LeafNodeAssignments res = mojo.getLeafNodeAssignments(doubles);
        Assert.assertNotNull((Object)res._nodeIds);
        Assert.assertNotNull((Object)res._paths);
        Object[] paths = mojo.getDecisionPath(doubles);
        Assert.assertArrayEquals((Object[])paths, (Object[])res._paths);
        RowData data = new RowData();
        for (int i = 0; i < doubles.length; ++i) {
            data.put((Object)mojo._names[i], (Object)doubles[i]);
        }
        EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(new EasyPredictModelWrapper.Config().setModel((GenModel)mojo).setEnableLeafAssignment(true));
        RegressionModelPrediction res2 = (RegressionModelPrediction)wrapper.predict(data);
        Assert.assertNotNull((Object)res2.leafNodeAssignmentIds);
        Assert.assertNotNull((Object)res2.leafNodeAssignments);
        Assert.assertArrayEquals((int[])res._nodeIds, (int[])res2.leafNodeAssignmentIds);
        Assert.assertArrayEquals((Object[])res._paths, (Object[])res2.leafNodeAssignments);
    }

    private static byte[] serialize(Object o) throws Exception {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bos);){
            out.writeObject(o);
        }
        return bos.toByteArray();
    }

    private static Object deserialize(byte[] bs) throws Exception {
        try (ByteArrayInputStream bis = new ByteArrayInputStream(bs);){
            ObjectInputStream in = new ObjectInputStream(bis);
            Object object = in.readObject();
            return object;
        }
    }
}

