/*
 * Decompiled with CFR 0.152.
 */
package org.lenskit.mf.funksvd;

import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nonnull;
import javax.inject.Inject;
import javax.inject.Provider;
import org.apache.commons.lang3.time.StopWatch;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.grouplens.lenskit.iterative.TrainingLoopController;
import org.lenskit.data.ratings.RatingMatrix;
import org.lenskit.data.ratings.RatingMatrixEntry;
import org.lenskit.inject.Transient;
import org.lenskit.mf.funksvd.FeatureCount;
import org.lenskit.mf.funksvd.FeatureInfo;
import org.lenskit.mf.funksvd.FunkSVDModel;
import org.lenskit.mf.funksvd.FunkSVDUpdateRule;
import org.lenskit.mf.funksvd.FunkSVDUpdater;
import org.lenskit.mf.funksvd.InitialFeatureValue;
import org.lenskit.mf.funksvd.TrainingEstimator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class FunkSVDModelProvider
implements Provider<FunkSVDModel> {
    private static Logger logger = LoggerFactory.getLogger(FunkSVDModelProvider.class);
    protected final int featureCount;
    protected final RatingMatrix snapshot;
    protected final double initialValue;
    protected final FunkSVDUpdateRule rule;

    @Inject
    public FunkSVDModelProvider(@Transient @Nonnull RatingMatrix snapshot, @Transient @Nonnull FunkSVDUpdateRule rule, @FeatureCount int featureCount, @InitialFeatureValue double initVal) {
        this.featureCount = featureCount;
        this.initialValue = initVal;
        this.snapshot = snapshot;
        this.rule = rule;
    }

    public FunkSVDModel get() {
        int userCount = this.snapshot.getUserIds().size();
        RealMatrix userFeatures = MatrixUtils.createRealMatrix((int)userCount, (int)this.featureCount);
        int itemCount = this.snapshot.getItemIds().size();
        RealMatrix itemFeatures = MatrixUtils.createRealMatrix((int)itemCount, (int)this.featureCount);
        logger.debug("Learning rate is {}", (Object)this.rule.getLearningRate());
        logger.debug("Regularization term is {}", (Object)this.rule.getTrainingRegularization());
        logger.info("Building SVD with {} features for {} ratings", (Object)this.featureCount, (Object)this.snapshot.getRatings().size());
        TrainingEstimator estimates = this.rule.makeEstimator(this.snapshot);
        ArrayList<FeatureInfo> featureInfo = new ArrayList<FeatureInfo>(this.featureCount);
        RealVector uvec = MatrixUtils.createRealVector((double[])new double[userCount]);
        RealVector ivec = MatrixUtils.createRealVector((double[])new double[itemCount]);
        for (int f = 0; f < this.featureCount; ++f) {
            logger.debug("Training feature {}", (Object)f);
            StopWatch timer = new StopWatch();
            timer.start();
            uvec.set(this.initialValue);
            ivec.set(this.initialValue);
            FeatureInfo.Builder fib = new FeatureInfo.Builder(f);
            this.trainFeature(f, estimates, uvec, ivec, fib);
            this.summarizeFeature(uvec, ivec, fib);
            featureInfo.add(fib.build());
            estimates.update(uvec, ivec);
            userFeatures.setColumnVector(f, uvec);
            assert (Math.abs(userFeatures.getColumnVector(f).getL1Norm() - uvec.getL1Norm()) < 1.0E-4) : "user column sum matches";
            itemFeatures.setColumnVector(f, ivec);
            assert (Math.abs(itemFeatures.getColumnVector(f).getL1Norm() - ivec.getL1Norm()) < 1.0E-4) : "item column sum matches";
            timer.stop();
            logger.info("Finished feature {} in {}", (Object)f, (Object)timer);
        }
        return new FunkSVDModel(userFeatures, itemFeatures, this.snapshot.userIndex(), this.snapshot.itemIndex(), featureInfo);
    }

    protected void trainFeature(int feature, TrainingEstimator estimates, RealVector userFeatureVector, RealVector itemFeatureVector, FeatureInfo.Builder fib) {
        double rmse = Double.MAX_VALUE;
        double trail = this.initialValue * this.initialValue * (double)(this.featureCount - feature - 1);
        TrainingLoopController controller = this.rule.getTrainingLoopController();
        List ratings = this.snapshot.getRatings();
        while (controller.keepTraining(rmse)) {
            rmse = this.doFeatureIteration(estimates, ratings, userFeatureVector, itemFeatureVector, trail);
            fib.addTrainingRound(rmse);
            logger.trace("iteration {} finished with RMSE {}", (Object)controller.getIterationCount(), (Object)rmse);
        }
    }

    protected double doFeatureIteration(TrainingEstimator estimates, List<RatingMatrixEntry> ratings, RealVector userFeatureVector, RealVector itemFeatureVector, double trail) {
        FunkSVDUpdater updater = this.rule.createUpdater();
        for (RatingMatrixEntry r : ratings) {
            int uidx = r.getUserIndex();
            int iidx = r.getItemIndex();
            updater.prepare(0, r.getValue(), estimates.get(r), userFeatureVector.getEntry(uidx), itemFeatureVector.getEntry(iidx), trail);
            userFeatureVector.addToEntry(uidx, updater.getUserFeatureUpdate());
            itemFeatureVector.addToEntry(iidx, updater.getItemFeatureUpdate());
        }
        return updater.getRMSE();
    }

    protected void summarizeFeature(RealVector ufv, RealVector ifv, FeatureInfo.Builder fib) {
        fib.setUserAverage(this.realVectorSum(ufv) / (double)ufv.getDimension()).setItemAverage(this.realVectorSum(ifv) / (double)ifv.getDimension()).setSingularValue(ufv.getNorm() * ifv.getNorm());
    }

    private double realVectorSum(RealVector rv) {
        double total = 0.0;
        for (double i : rv.toArray()) {
            total += i;
        }
        return total;
    }
}

