/*
 * Decompiled with CFR 0.152.
 */
package hex.word2vec;

import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.word2vec.HBWTree;
import hex.word2vec.Word2Vec;
import hex.word2vec.Word2VecMojoWriter;
import hex.word2vec.WordCountTask;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import water.DKV;
import water.H2O;
import water.Iced;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.MemoryManager;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.util.IcedHashMap;
import water.util.IcedHashMapGeneric;
import water.util.IcedLong;
import water.util.RandomBase;
import water.util.RandomUtils;

public class Word2VecModel
extends Model<Word2VecModel, Word2VecParameters, Word2VecOutput> {
    /**
     * Creates a Word2Vec model from the given key, parameters and output container.
     *
     * @param selfKey key handed to the superclass constructor
     * @param params  parameters of this model
     * @param output  output object; it is filled in later by {@code buildModelOutput}
     */
    public Word2VecModel(Key<Word2VecModel> selfKey, Word2VecParameters params, Word2VecOutput output) {
        super(selfKey, params, output);
        // Sanity check: the key held by this model has the same bytes as the one passed in.
        assert Arrays.equals(_key._kb, selfKey._kb);
    }

    /**
     * Model metrics are not available for Word2Vec; this method never returns normally.
     *
     * @param domain ignored
     * @return nothing; the exception produced by {@code H2O.unimpl} is always thrown
     */
    @Override
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        throw H2O.unimpl("No Model Metrics for Word2Vec.");
    }

    /**
     * Chunk-based scoring is not implemented for Word2Vec; this method never returns normally.
     * Use the {@code transform} methods to map words to their vectors instead.
     *
     * @return nothing; the exception produced by {@code H2O.unimpl} is always thrown
     */
    @Override
    public double[] score0(Chunk[] cs, int foo, double[] data, double[] preds) {
        throw H2O.unimpl();
    }

    /**
     * Row-based scoring is not implemented for Word2Vec; this method never returns normally.
     *
     * @return nothing; the exception produced by {@code H2O.unimpl} is always thrown
     */
    @Override
    protected double[] score0(double[] data, double[] preds) {
        throw H2O.unimpl();
    }

    /**
     * Creates the MOJO writer for this model.
     *
     * @return a new {@link Word2VecMojoWriter} constructed with this model
     */
    @Override
    public Word2VecMojoWriter getMojo() {
        return new Word2VecMojoWriter(this);
    }

    /**
     * Exports the word embeddings held in the model output as a frame.
     * <p>
     * The frame has one row per entry of {@code _output._words}. Column {@code "Word"} is a string
     * column holding the word; columns {@code "V1"} .. {@code "Vn"} (n = {@code _output._vecSize})
     * are numeric columns holding the components of the word's vector.
     * <p>
     * A temporary all-zero vector with one row per word drives the conversion task; it is removed
     * again before this method returns, whether or not the conversion succeeded.
     *
     * @return the frame produced by the conversion task
     */
    public Frame toFrame() {
        final Word2VecOutput output = (Word2VecOutput) this._output;
        // Must be typed as Vec (not Keyed): it is passed to the task inside a Vec[].
        Vec zeroVec = null;
        try {
            zeroVec = Vec.makeZero(output._words.length);
            byte[] types = new byte[1 + output._vecSize];
            Arrays.fill(types, Vec.T_NUM);
            types[0] = Vec.T_STR;
            String[] colNames = new String[types.length];
            colNames[0] = "Word";
            for (int i = 1; i < colNames.length; ++i) {
                colNames[i] = "V" + i;
            }
            return ((ConvertToFrameTask) new ConvertToFrameTask(this).doAll(types, new Vec[]{zeroVec}))
                    .outputFrame(colNames, null);
        } finally {
            if (zeroVec != null) {
                zeroVec.remove();
            }
        }
    }

    /**
     * Looks up the vector of a single word given as a {@link String}.
     *
     * @param target the word to look up
     * @return a copy of the word's vector, or {@code null} if the word is not in the vocabulary
     */
    public float[] transform(String target) {
        final BufferedString word = new BufferedString(target);
        return transform(word);
    }

    /**
     * Looks up the vector of a single word.
     * <p>
     * The vocabulary maps a word to its index; the vector of word {@code i} occupies the slice
     * {@code [i * vecSize, (i + 1) * vecSize)} of the flat {@code _vecs} array.
     *
     * @param word the word to look up
     * @return a fresh copy of the word's vector, or {@code null} if the vocabulary has no entry for it
     */
    private float[] transform(BufferedString word) {
        final Word2VecOutput output = (Word2VecOutput) this._output;
        // Single map lookup: a missing word yields null instead of a containsKey + get pair.
        final Integer wordIdx = output._vocab.get(word);
        if (wordIdx == null) {
            return null;
        }
        final int from = wordIdx * output._vecSize;
        return Arrays.copyOfRange(output._vecs, from, from + output._vecSize);
    }

    /**
     * Maps a vector of words to their embeddings.
     * <p>
     * With {@link AggregateMethod#AVERAGE} the words are processed by {@code Word2VecAggregateTask};
     * for any other value (including {@code null}) by {@code Word2VecTransformTask}.
     *
     * @param wordVec         vector of words; its type code must be 2 (a string vector)
     * @param aggregateMethod selects which of the two tasks is run
     * @return a frame with {@code _output._vecSize} columns, stored under a freshly made key
     * @throws IllegalArgumentException if {@code wordVec} is not a string vector
     */
    public Frame transform(Vec wordVec, AggregateMethod aggregateMethod) {
        final boolean isStringVec = wordVec.get_type() == 2;
        if (!isStringVec) {
            throw new IllegalArgumentException("Expected a string vector, got " + wordVec.get_type_str() + " vector.");
        }
        // One output column per vector component; 3 is presumably the numeric type code - confirm against Vec.
        final byte[] types = new byte[((Word2VecOutput)this._output)._vecSize];
        Arrays.fill(types, (byte)3);
        MRTask transformTask;
        if (aggregateMethod == AggregateMethod.AVERAGE) {
            transformTask = new Word2VecAggregateTask(this);
        } else {
            transformTask = new Word2VecTransformTask(this);
        }
        return ((MRTask)transformTask.doAll(types, wordVec)).outputFrame(Key.make(), null, null);
    }

    /**
     * Finds the words whose vectors have the highest cosine similarity to the vector of
     * {@code target}.
     * <p>
     * Every vocabulary word is scored against the target. Words scoring {@code >= 0.999999}
     * (the target itself and vectors practically identical to it) are always skipped, regardless
     * of their position in the vocabulary. The best {@code cnt} remaining words are returned, so
     * the result holds fewer than {@code cnt} entries when the vocabulary does not contain enough
     * candidates.
     *
     * @param target the word to find synonyms for
     * @param cnt    maximum number of synonyms to return; values larger than the vocabulary size
     *               are capped
     * @return map from synonym to its cosine similarity; empty if {@code target} is not in the
     *         vocabulary or {@code cnt <= 0}
     */
    public Map<String, Float> findSynonyms(String target, int cnt) {
        float[] vec = this.transform(target);
        if (vec == null || cnt <= 0) {
            return Collections.emptyMap();
        }
        final Word2VecOutput output = (Word2VecOutput) this._output;
        final int vocabSize = output._vocab.size();
        // Never keep more slots than there are words; avoids reading past the end of _vecs/_words.
        final int capacity = Math.min(cnt, vocabSize);
        int[] synonyms = new int[capacity];
        float[] scores = new float[capacity];
        int found = 0; // number of slots filled so far
        int min = 0;   // slot currently holding the lowest kept score
        for (int i = 0; i < vocabSize; ++i) {
            float score = this.cosineSimilarity(vec, i * vec.length, output._vecs);
            if ((double) score >= 0.999999) {
                continue; // the target word itself (or an identical vector)
            }
            if (found < capacity) {
                // Still filling the initial slots.
                synonyms[found] = i;
                scores[found] = score;
                if (score < scores[min]) {
                    min = found;
                }
                ++found;
                continue;
            }
            if (score <= scores[min]) {
                continue;
            }
            // Evict the weakest kept candidate, then locate the new weakest one.
            synonyms[min] = i;
            scores[min] = score;
            min = 0;
            for (int j = 1; j < capacity; ++j) {
                if (scores[j] < scores[min]) {
                    min = j;
                }
            }
        }
        HashMap<String, Float> result = new HashMap<String, Float>(found);
        for (int i = 0; i < found; ++i) {
            result.put(output._words[synonyms[i]].toString(), Float.valueOf(scores[i]));
        }
        return result;
    }

    /**
     * Computes the cosine similarity between {@code target} and the slice of {@code vecs}
     * that starts at {@code pos} and has the same length as {@code target}.
     * <p>
     * The dot product and both squared norms are accumulated in {@code float}; the final
     * division is carried out in {@code double} and narrowed back to {@code float}. If either
     * vector has zero norm the result is {@code NaN}.
     *
     * @param target the reference vector
     * @param pos    offset of the first component of the candidate vector inside {@code vecs}
     * @param vecs   flat array containing the candidate vector
     * @return the cosine of the angle between the two vectors
     */
    private float cosineSimilarity(float[] target, int pos, float[] vecs) {
        float dot = 0.0f;
        float targetNormSq = 0.0f;
        float candidateNormSq = 0.0f;
        int k = 0;
        while (k < target.length) {
            final float t = target[k];
            final float c = vecs[pos + k];
            dot += t * c;
            // Compound assignment adds in double and narrows the sum back to float.
            targetNormSq += Math.pow(t, 2.0);
            candidateNormSq += Math.pow(c, 2.0);
            ++k;
        }
        final double denominator = Math.sqrt(targetNormSq) * Math.sqrt(candidateNormSq);
        return (float) (dot / denominator);
    }

    /**
     * Fills the model output from a trained model info.
     * <p>
     * The vocabulary (word to index) is fetched from the DKV under {@code modelInfo._vocabKey};
     * the index-to-word array is derived from it by placing every word at its own index.
     * The vectors are taken from {@code modelInfo._syn0} without copying.
     *
     * @param modelInfo source of the vocabulary key and the word vectors
     */
    void buildModelOutput(Word2VecModelInfo modelInfo) {
        IcedHashMapGeneric<BufferedString, Integer> vocab = ((Vocabulary)DKV.getGet(modelInfo._vocabKey))._data;
        BufferedString[] words = new BufferedString[vocab.size()];
        // Invert the word -> index map into an index -> word array.
        for (Map.Entry<BufferedString, Integer> entry : vocab.entrySet()) {
            words[entry.getValue()] = entry.getKey();
        }
        ((Word2VecOutput)this._output)._vecSize = ((Word2VecParameters)this._parms)._vec_size;
        ((Word2VecOutput)this._output)._vecs = modelInfo._syn0;
        ((Word2VecOutput)this._output)._words = words;
        ((Word2VecOutput)this._output)._vocab = vocab;
    }

    /**
     * Fills the model output from an explicit word list and its vectors.
     * <p>
     * The vocabulary is built so that {@code words[i]} maps to index {@code i}; both arrays are
     * stored in the output without copying.
     *
     * @param words the words, in index order
     * @param syn0  flat array of word vectors stored as the output's {@code _vecs}
     */
    void buildModelOutput(BufferedString[] words, float[] syn0) {
        final IcedHashMapGeneric<BufferedString, Integer> wordToIndex = new IcedHashMapGeneric<BufferedString, Integer>();
        int position = 0;
        for (BufferedString w : words) {
            wordToIndex.put(w, position++);
        }
        final Word2VecOutput out = (Word2VecOutput)this._output;
        out._vecSize = ((Word2VecParameters)this._parms)._vec_size;
        out._vecs = syn0;
        out._words = words;
        out._vocab = wordToIndex;
    }

    /**
     * Keyed holder for a map from word to the number of times it was counted.
     * Each instance is created under a freshly made key.
     */
    public static class WordCounts
    extends Keyed<WordCounts> {
        /** Word -> count. */
        IcedHashMap<BufferedString, IcedLong> _data;

        WordCounts(IcedHashMap<BufferedString, IcedLong> data) {
            super(Key.make());
            _data = data;
        }
    }

    /**
     * Keyed holder for the vocabulary: a map from word to its integer index.
     * Each instance is created under a freshly made key.
     */
    public static class Vocabulary
    extends Keyed<Vocabulary> {
        /** Word -> index of the word. */
        IcedHashMapGeneric<BufferedString, Integer> _data;

        Vocabulary(IcedHashMapGeneric<BufferedString, Integer> data) {
            super(Key.make());
            _data = data;
        }
    }

    /**
     * Training state of a Word2Vec model: the weight arrays plus keys of the helper objects
     * (Huffman tree, vocabulary, word counts) that are published to the DKV.
     */
    public static class Word2VecModelInfo
    extends Iced {
        /** Sum of the counts of all words kept in the vocabulary. */
        long _vocabWordCount;
        long _totalProcessedWords = 0L;
        /** Weights of size {@code vec_size * vocabSize}, randomly initialized; used as the word vectors. */
        float[] _syn0;
        /** Second weight array of the same size as {@code _syn0}; allocated here, not otherwise initialized. */
        float[] _syn1;
        Key<HBWTree> _treeKey;
        Key<Vocabulary> _vocabKey;
        Key<WordCounts> _wordCountsKey;
        private Word2VecParameters _parameters;

        /**
         * @return the parameters this model info was created with ({@code null} for instances
         *         made by the no-arg constructor)
         */
        public final Word2VecParameters getParams() {
            return this._parameters;
        }

        /** Creates an empty instance; all fields keep their default values. */
        public Word2VecModelInfo() {
        }

        /**
         * Builds the initial training state from the word counts.
         * <p>
         * Words counted fewer than {@code params._min_word_freq} times are dropped. The remaining
         * words are sorted by ascending count and numbered in that order to form the vocabulary;
         * the counts (in the same order) are used to build the Huffman tree. The tree, vocabulary
         * and word counts are published to the DKV, and the weight arrays are allocated.
         *
         * @param params     model parameters
         * @param wordCounts count of every word seen in the training data
         * @throws IllegalArgumentException if {@code vec_size * vocabulary size} does not fit
         *                                  into an {@code int} array length
         */
        private Word2VecModelInfo(Word2VecParameters params, WordCounts wordCounts) {
            this._parameters = params;
            long vocabWordCount = 0L;
            ArrayList<Map.Entry<BufferedString, IcedLong>> wordCountList =
                    new ArrayList<Map.Entry<BufferedString, IcedLong>>(wordCounts._data.size());
            for (Map.Entry<BufferedString, IcedLong> wc : wordCounts._data.entrySet()) {
                if (wc.getValue()._val < (long) params._min_word_freq) {
                    continue; // too rare to be part of the vocabulary
                }
                wordCountList.add(wc);
                vocabWordCount += wc.getValue()._val;
            }
            Collections.sort(wordCountList, new Comparator<Map.Entry<BufferedString, IcedLong>>() {
                @Override
                public int compare(Map.Entry<BufferedString, IcedLong> o1, Map.Entry<BufferedString, IcedLong> o2) {
                    long x = o1.getValue()._val;
                    long y = o2.getValue()._val;
                    return x < y ? -1 : (x == y ? 0 : 1);
                }
            });
            int vocabSize = wordCountList.size();
            // Compute the weight count in long so an int overflow is reported instead of wrapping.
            long totalWeights = (long) params._vec_size * (long) vocabSize;
            if (totalWeights > (long) Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Word2Vec weights do not fit into a single array: vec_size ("
                        + params._vec_size + ") * vocabulary size (" + vocabSize + ") = " + totalWeights);
            }
            int weightCount = (int) totalWeights;
            long[] countAry = new long[vocabSize];
            Vocabulary vocab = new Vocabulary(new IcedHashMapGeneric<BufferedString, Integer>());
            int idx = 0;
            for (Map.Entry<BufferedString, IcedLong> entry : wordCountList) {
                countAry[idx] = entry.getValue()._val;
                vocab._data.put(entry.getKey(), idx++);
            }
            HBWTree t = HBWTree.buildHuffmanBinaryWordTree(countAry);
            this._vocabWordCount = vocabWordCount;
            this._treeKey = Word2VecModelInfo.publish(t);
            this._vocabKey = Word2VecModelInfo.publish(vocab);
            this._wordCountsKey = Word2VecModelInfo.publish(wordCounts);
            // Fixed seeds (0xDECAF, 0xDA7A) make the initial weights reproducible.
            RandomBase randomBase = RandomUtils.getRNG(912559L, 55930L);
            this._syn1 = MemoryManager.malloc4f(weightCount);
            this._syn0 = MemoryManager.malloc4f(weightCount);
            for (int i = 0; i < weightCount; ++i) {
                this._syn0[i] = (randomBase.nextFloat() - 0.5f) / (float) params._vec_size;
            }
        }

        /**
         * Counts the words of the training vector and builds the initial model info from them.
         *
         * @param params model parameters; {@code params.trainVec()} supplies the words
         * @return the freshly initialized model info
         */
        public static Word2VecModelInfo createInitialModelInfo(Word2VecParameters params) {
            Vec v = params.trainVec();
            WordCounts wordCounts = new WordCounts(((WordCountTask) new WordCountTask().doAll(new Vec[]{v}))._counts);
            return new Word2VecModelInfo(params, wordCounts);
        }

        /** Registers {@code keyed} with the current {@link Scope}, puts it into the DKV and returns its key. */
        private static <T extends Keyed<T>> Key<T> publish(T keyed) {
            Scope.track_generic(keyed);
            DKV.put(keyed);
            return keyed._key;
        }
    }

    /**
     * Output of a Word2Vec model: the vocabulary and the word vectors.
     * The fields other than {@code _epochs} are populated by {@code Word2VecModel.buildModelOutput}.
     */
    public static class Word2VecOutput
    extends Model.Output {
        /** Number of components of each word vector. */
        public int _vecSize;
        /** NOTE(review): not written anywhere in this file; presumably the number of epochs trained - confirm. */
        public int _epochs;
        /** Index -> word. */
        public BufferedString[] _words;
        /** Flat array of all word vectors; word {@code i} occupies {@code [i * _vecSize, (i + 1) * _vecSize)}. */
        public float[] _vecs;
        /** Word -> index into {@code _words}. */
        public IcedHashMapGeneric<BufferedString, Integer> _vocab;

        public Word2VecOutput(Word2Vec b) {
            super(b);
        }

        /** @return always {@link ModelCategory#WordEmbedding} */
        @Override
        public ModelCategory getModelCategory() {
            return ModelCategory.WordEmbedding;
        }
    }

    /**
     * Parameters of a Word2Vec model.
     */
    public static class Word2VecParameters
    extends Model.Parameters {
        /** NOTE(review): not referenced in this file; presumably the upper bound for {@code _vec_size} - confirm. */
        static final int MAX_VEC_SIZE = 10000;
        public Word2Vec.WordModel _word_model = Word2Vec.WordModel.SkipGram;
        public Word2Vec.NormModel _norm_model = Word2Vec.NormModel.HSM;
        /** Words counted fewer times than this are left out of the vocabulary. */
        public int _min_word_freq = 5;
        /** Number of components of each word vector. */
        public int _vec_size = 100;
        public int _window_size = 5;
        /** Number of epochs; multiplies the number of progress units when training. */
        public int _epochs = 5;
        public float _init_learning_rate = 0.025f;
        public float _sent_sample_rate = 0.001f;
        /** Key of a frame with pre-trained vectors, or {@code null} when the model is to be trained. */
        public Key<Frame> _pre_trained;

        @Override
        public String algoName() {
            return "Word2Vec";
        }

        @Override
        public String fullName() {
            return "Word2Vec";
        }

        @Override
        public String javaName() {
            return Word2VecModel.class.getName();
        }

        /**
         * @return the number of chunks of the pre-trained frame when one is given; otherwise the
         *         number of chunks of the training vector multiplied by the number of epochs
         */
        @Override
        public long progressUnits() {
            if (this.isPreTrained()) {
                return (long) this._pre_trained.get().anyVec().nChunks();
            }
            // Widen before multiplying so that chunks * epochs cannot overflow an int.
            return (long) this.trainVec().nChunks() * (long) this._epochs;
        }

        /** @return {@code true} if a pre-trained frame key was supplied */
        boolean isPreTrained() {
            return this._pre_trained != null;
        }

        /** @return the first vector of the training frame */
        Vec trainVec() {
            return this.train().vec(0);
        }
    }

    /**
     * Averages the vectors of the words of each sequence, where sequences are separated by NA rows.
     * <p>
     * One output row is written per sequence. A sequence that starts in one chunk and continues
     * into following chunks is handled entirely by the task working on the chunk where it
     * starts; every chunk except the first therefore skips everything up to and including its
     * first NA.
     */
    private static class Word2VecAggregateTask
    extends MRTask<Word2VecAggregateTask> {
        private Word2VecModel _model;

        public Word2VecAggregateTask(Word2VecModel model) {
            _model = model;
        }

        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
            assert cs.length == 1;
            Chunk current = cs[0];
            int start = 0;
            if (current.cidx() > 0) {
                // Rows before the first NA belong to a sequence opened in an earlier chunk.
                final int firstNA = findNA(current);
                if (firstNA < 0) {
                    return; // no NA at all: no sequence starts in this chunk
                }
                start = firstNA + 1;
            }
            final int width = ncs.length;
            final float[] sums = new float[width];
            int wordsInSeq = 0;
            boolean open = false;
            final BufferedString buffer = new BufferedString();
            chunks:
            do {
                for (int row = start; row < current._len; ++row) {
                    if (!current.isNA(row)) {
                        float[] embedding = _model.transform(current.atStr(buffer, row));
                        if (embedding != null) {
                            for (int d = 0; d < width; ++d) {
                                sums[d] += embedding[d];
                            }
                            ++wordsInSeq;
                        }
                        // Words without a vector still open the sequence, they just don't count.
                        open = true;
                        continue;
                    }
                    // NA row: the current sequence ends here.
                    writeAggregate(wordsInSeq, sums, ncs);
                    Arrays.fill(sums, 0.0f);
                    wordsInSeq = 0;
                    open = false;
                    if (current != cs[0]) {
                        // The sequence carried over from this task's own chunk is closed; stop.
                        break chunks;
                    }
                }
                start = 0;
                current = current.nextChunk();
            } while (current != null);
            // The last sequence need not be terminated by an NA.
            if (open) {
                writeAggregate(wordsInSeq, sums, ncs);
            }
        }

        /**
         * Writes one output row: the per-component average of {@code aggregated}, or a row of NAs
         * when no word of the sequence had a vector.
         */
        private void writeAggregate(int seqLength, float[] aggregated, NewChunk[] ncs) {
            if (seqLength == 0) {
                for (NewChunk nc : ncs) {
                    nc.addNA();
                }
                return;
            }
            for (int j = 0; j < ncs.length; ++j) {
                ncs[j].addNum(aggregated[j] / (float)seqLength);
            }
        }

        /** @return the index of the first NA row of {@code chk}, or -1 if it has none */
        private int findNA(Chunk chk) {
            int row = 0;
            while (row < chk._len) {
                if (chk.isNA(row)) {
                    return row;
                }
                ++row;
            }
            return -1;
        }
    }

    /**
     * Maps every row of a word vector to the vector of its word, one output row per input row.
     * Rows that are NA, or whose word has no vector in the model, produce a row of NAs.
     */
    private static class Word2VecTransformTask
    extends MRTask<Word2VecTransformTask> {
        private Word2VecModel _model;

        public Word2VecTransformTask(Word2VecModel model) {
            _model = model;
        }

        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
            assert cs.length == 1;
            final Chunk words = cs[0];
            final BufferedString buffer = new BufferedString();
            for (int row = 0; row < words._len; ++row) {
                // The word is only read and looked up when the row is not NA.
                final float[] embedding = words.isNA(row) ? null : _model.transform(words.atStr(buffer, row));
                if (embedding == null) {
                    for (NewChunk nc : ncs) {
                        nc.addNA();
                    }
                } else {
                    for (int col = 0; col < ncs.length; ++col) {
                        ncs[col].addNum(embedding[col]);
                    }
                }
            }
        }
    }

    /**
     * Selects how {@code transform(Vec, AggregateMethod)} turns words into output rows.
     */
    public static enum AggregateMethod {
        /** Words are processed by {@code Word2VecTransformTask}: one output row per input row. */
        NONE,
        /** Words are processed by {@code Word2VecAggregateTask}: vectors are averaged per NA-delimited sequence. */
        AVERAGE;

    }

    /**
     * Writes the model's words and vectors into output chunks: column 0 receives the word,
     * columns 1..vecSize receive the components of its vector.
     * <p>
     * Only the model key is kept as a regular field; the model itself is held in a transient
     * field and fetched from the DKV in {@link #setupLocal()}.
     */
    private static class ConvertToFrameTask
    extends MRTask<ConvertToFrameTask> {
        private Key<Word2VecModel> _modelKey;
        private transient Word2VecModel _model;

        public ConvertToFrameTask(Word2VecModel model) {
            _modelKey = model._key;
        }

        @Override
        protected void setupLocal() {
            _model = (Word2VecModel)DKV.getGet(_modelKey);
        }

        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
            assert cs.length == 1;
            assert ncs.length == ((Word2VecOutput)_model._output)._vecSize + 1;
            final Chunk rows = cs[0];
            // The chunk's starting row number is the index of the first word it covers.
            final int firstWord = (int)rows.start();
            final Word2VecOutput out = (Word2VecOutput)_model._output;
            int cursor = out._vecSize * firstWord;
            for (int row = 0; row < rows._len; ++row) {
                ncs[0].addStr(out._words[firstWord + row]);
                for (int col = 1; col < ncs.length; ++col) {
                    ncs[col].addNum(out._vecs[cursor++]);
                }
            }
        }
    }
}

