package org.infinispan.interceptors.impl;

import org.infinispan.commands.FlagAffectedCommand;
import org.infinispan.commands.read.AbstractDataCommand;
import org.infinispan.commands.read.GetAllCommand;
import org.infinispan.commands.read.GetCacheEntryCommand;
import org.infinispan.commands.read.GetKeyValueCommand;
import org.infinispan.commands.write.EvictCommand;
import org.infinispan.commands.write.PutKeyValueCommand;
import org.infinispan.commands.write.PutMapCommand;
import org.infinispan.commands.write.RemoveCommand;
import org.infinispan.commands.write.ReplaceCommand;
import org.infinispan.commands.write.WriteCommand;
import org.infinispan.container.DataContainer;
import org.infinispan.context.Flag;
import org.infinispan.context.InvocationContext;
import org.infinispan.factories.annotations.Inject;
import org.infinispan.factories.annotations.Start;
import org.infinispan.jmx.annotations.DisplayType;
import org.infinispan.jmx.annotations.MBean;
import org.infinispan.jmx.annotations.ManagedAttribute;
import org.infinispan.jmx.annotations.ManagedOperation;
import org.infinispan.jmx.annotations.MeasurementType;
import org.infinispan.jmx.annotations.Units;
import org.infinispan.util.TimeService;

import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.LongAdder;

/**
 * Interceptor that gathers cache management statistics and exposes them as JMX attributes.
 * <p>
 * It tracks:
 * <ul>
 *    <li>read hits and misses,</li>
 *    <li>stores,</li>
 *    <li>remove hits and misses,</li>
 *    <li>evictions,</li>
 *    <li>the time spent in reads, writes and successful removes.</li>
 * </ul>
 * <p>
 * Read, write and remove statistics are only recorded when three conditions hold:
 * <ul>
 *    <li>statistics are enabled for this interceptor,</li>
 *    <li>the command does not carry {@link Flag#SKIP_STATISTICS},</li>
 *    <li>the invocation originated locally.</li>
 * </ul>
 * Eviction statistics do not check the origin. Invocations that complete with a failure are not
 * counted. Timings are accumulated in milliseconds.
 *
 * @author Jerry Gauthier
 * @since 9.0
 */
@MBean(objectName = "Statistics", description = "General statistics such as timings, hit/miss ratio, etc.")
public class CacheMgmtInterceptor extends JmxStatsCommandInterceptor {
   // Accumulated milliseconds spent in reads that found a value / found nothing.
   private final LongAdder hitTimes = new LongAdder();
   private final LongAdder missTimes = new LongAdder();
   // Accumulated milliseconds spent in successful writes and successful removes.
   private final LongAdder storeTimes = new LongAdder();
   private final LongAdder removeTimes = new LongAdder();
   // Operation counters.
   private final LongAdder hits = new LongAdder();
   private final LongAdder misses = new LongAdder();
   private final LongAdder stores = new LongAdder();
   private final LongAdder evictions = new LongAdder();
   // TimeService timestamps of component start and of the last statistics reset.
   private final AtomicLong startNanoseconds = new AtomicLong(0);
   private final AtomicLong resetNanoseconds = new AtomicLong(0);
   private final LongAdder removeHits = new LongAdder();
   private final LongAdder removeMisses = new LongAdder();

   private DataContainer dataContainer;
   private TimeService timeService;

   @Inject
   @SuppressWarnings("unused")
   public void setDependencies(DataContainer dataContainer, TimeService timeService) {
      this.dataContainer = dataContainer;
      this.timeService = timeService;
   }

   /**
    * Records the start time, which also serves as the initial "last reset" time.
    */
   @Start
   public void start() {
      long now = timeService.time();
      startNanoseconds.set(now);
      resetNanoseconds.set(now);
   }

   /**
    * Counts an eviction once the evict command has completed without failure.
    */
   @Override
   public CompletableFuture<Void> visitEvictCommand(InvocationContext ctx, EvictCommand command) throws Throwable {
      if (!getStatisticsEnabled(command))
         return ctx.continueInvocation();

      return ctx.onReturn((rCtx, rCommand, rv, throwable) -> {
         // A failed eviction must not be reported as an eviction.
         if (throwable != null)
            return null;
         evictions.increment();
         return null;
      });
   }

   @Override
   public final CompletableFuture<Void> visitGetKeyValueCommand(InvocationContext ctx, GetKeyValueCommand command) throws Throwable {
      return visitDataReadCommand(ctx, command);
   }

   @Override
   public final CompletableFuture<Void> visitGetCacheEntryCommand(InvocationContext ctx, GetCacheEntryCommand command) throws Throwable {
      return visitDataReadCommand(ctx, command);
   }

   /**
    * Shared handling for single-key reads.
    * <p>
    * A non-null return value counts as a hit and a null one as a miss. The elapsed time is added
    * to the matching timing counter.
    */
   private CompletableFuture<Void> visitDataReadCommand(InvocationContext ctx, AbstractDataCommand command) throws Throwable {
      boolean statisticsEnabled = getStatisticsEnabled(command);
      if (!statisticsEnabled || !ctx.isOriginLocal())
         return ctx.continueInvocation();

      long start = timeService.time();
      return ctx.onReturn((rCtx, rCommand, rv, throwable) -> {
         // On failure rv is not a lookup result; counting it as a miss would skew the ratio.
         if (throwable != null)
            return null;

         long intervalMilliseconds = millisSince(start);
         if (rv == null) {
            missTimes.add(intervalMilliseconds);
            misses.increment();
         } else {
            hitTimes.add(intervalMilliseconds);
            hits.increment();
         }
         return null;
      });
   }

   /**
    * Handles a multi-key read.
    * <p>
    * Each requested key is a hit if the result map holds a non-null value for it; every other
    * key is a miss. The elapsed time is split between hit and miss timings in proportion to
    * their counts.
    */
   @SuppressWarnings("unchecked")
   @Override
   public CompletableFuture<Void> visitGetAllCommand(InvocationContext ctx, GetAllCommand command) throws Throwable {
      boolean statisticsEnabled = getStatisticsEnabled(command);
      if (!statisticsEnabled || !ctx.isOriginLocal())
         return ctx.continueInvocation();

      long start = timeService.time();
      return ctx.onReturn((rCtx, rCommand, rv, throwable) -> {
         // A failed invocation has no result map; dereferencing rv would throw a
         // NullPointerException that could hide the original failure.
         if (throwable != null)
            return null;

         long intervalMilliseconds = millisSince(start);
         int requests = ((GetAllCommand) rCommand).getKeys().size();
         int hitCount = 0;
         for (Entry<Object, Object> entry : ((Map<Object, Object>) rv).entrySet()) {
            if (entry.getValue() != null) {
               hitCount++;
            }
         }

         // requests > 0 whenever either branch below runs, so the divisions are safe.
         int missCount = requests - hitCount;
         if (hitCount > 0) {
            hits.add(hitCount);
            hitTimes.add(intervalMilliseconds * hitCount / requests);
         }
         if (missCount > 0) {
            misses.add(missCount);
            missTimes.add(intervalMilliseconds * missCount / requests);
         }
         return null;
      });
   }

   /**
    * Handles a bulk write.
    * <p>
    * Counts one store per map entry and adds the whole elapsed time once, so the average write
    * time is per entry.
    */
   @Override
   public CompletableFuture<Void> visitPutMapCommand(InvocationContext ctx, PutMapCommand command) throws Throwable {
      boolean statisticsEnabled = getStatisticsEnabled(command);
      if (!statisticsEnabled || !ctx.isOriginLocal())
         return ctx.continueInvocation();

      long start = timeService.time();
      return ctx.onReturn((rCtx, rCommand, rv, throwable) -> {
         if (throwable != null)
            return null;

         final long intervalMilliseconds = millisSince(start);
         final Map<Object, Object> data = ((PutMapCommand) rCommand).getMap();
         if (data != null && !data.isEmpty()) {
            storeTimes.add(intervalMilliseconds);
            stores.add(data.size());
         }
         return null;
      });
   }

   @Override
   //Map.put(key,value) :: oldValue
   public CompletableFuture<Void> visitPutKeyValueCommand(InvocationContext ctx, PutKeyValueCommand command) throws Throwable {
      return updateStoreStatistics(ctx, command);
   }

   @Override
   public CompletableFuture<Void> visitReplaceCommand(InvocationContext ctx, ReplaceCommand command) throws Throwable {
      return updateStoreStatistics(ctx, command);
   }

   /**
    * Counts a single store and its elapsed time when the write command reports success.
    */
   private CompletableFuture<Void> updateStoreStatistics(InvocationContext ctx, WriteCommand command) throws Throwable {
      boolean statisticsEnabled = getStatisticsEnabled(command);
      if (!statisticsEnabled || !ctx.isOriginLocal())
         return ctx.continueInvocation();

      long start = timeService.time();
      return ctx.onReturn((rCtx, rCommand, rv, throwable) -> {
         if (throwable != null)
            return null;

         if (command.isSuccessful()) {
            storeTimes.add(millisSince(start));
            stores.increment();
         }
         return null;
      });
   }

   /**
    * Classifies a remove as a hit or a miss.
    * <ul>
    *    <li>A conditional remove is a hit when the command reports success.</li>
    *    <li>An unconditional remove is a hit when it returned a previous value.</li>
    * </ul>
    * Only hits contribute to the remove timing.
    */
   @Override
   public CompletableFuture<Void> visitRemoveCommand(InvocationContext ctx, RemoveCommand command) throws Throwable {
      boolean statisticsEnabled = getStatisticsEnabled(command);
      if (!statisticsEnabled || !ctx.isOriginLocal())
         return ctx.continueInvocation();

      long start = timeService.time();
      return ctx.onReturn((rCtx, rCommand, rv, throwable) -> {
         // A failed remove is neither a hit nor a "key not found" miss.
         if (throwable != null)
            return null;

         RemoveCommand removeCommand = (RemoveCommand) rCommand;
         if (removeCommand.isConditional()) {
            if (removeCommand.isSuccessful())
               increaseRemoveHits(start);
            else
               increaseRemoveMisses();
         } else {
            if (rv == null)
               increaseRemoveMisses();
            else
               increaseRemoveHits(start);
         }
         return null;
      });
   }

   private void increaseRemoveHits(long start) {
      removeTimes.add(millisSince(start));
      removeHits.increment();
   }

   private void increaseRemoveMisses() {
      removeMisses.increment();
   }

   /**
    * Returns the milliseconds elapsed since {@code start}, a timestamp obtained from
    * {@link TimeService#time()}.
    */
   private long millisSince(long start) {
      return timeService.timeDuration(start, TimeUnit.MILLISECONDS);
   }

   @ManagedAttribute(
         description = "Number of cache attribute hits",
         displayName = "Number of cache hits",
         measurementType = MeasurementType.TRENDSUP,
         displayType = DisplayType.SUMMARY)
   public long getHits() {
      return hits.sum();
   }

   @ManagedAttribute(
         description = "Number of cache attribute misses",
         displayName = "Number of cache misses",
         measurementType = MeasurementType.TRENDSUP,
         displayType = DisplayType.SUMMARY
   )
   public long getMisses() {
      return misses.sum();
   }

   @ManagedAttribute(
         description = "Number of cache removal hits",
         displayName = "Number of cache removal hits",
         measurementType = MeasurementType.TRENDSUP,
         displayType = DisplayType.SUMMARY
   )
   public long getRemoveHits() {
      return removeHits.sum();
   }

   @ManagedAttribute(
         description = "Number of cache removals where keys were not found",
         displayName = "Number of cache removal misses",
         measurementType = MeasurementType.TRENDSUP,
         displayType = DisplayType.SUMMARY
   )
   public long getRemoveMisses() {
      return removeMisses.sum();
   }

   @ManagedAttribute(
         description = "number of cache attribute put operations",
         displayName = "Number of cache puts",
         measurementType = MeasurementType.TRENDSUP,
         displayType = DisplayType.SUMMARY
   )
   public long getStores() {
      return stores.sum();
   }

   @ManagedAttribute(
         description = "Number of cache eviction operations",
         displayName = "Number of cache evictions",
         measurementType = MeasurementType.TRENDSUP,
         displayType = DisplayType.SUMMARY
   )
   public long getEvictions() {
      return evictions.sum();
   }

   /**
    * Returns hits / (hits + misses), or 0 when no reads have been recorded.
    */
   @ManagedAttribute(
         description = "Percentage hit/(hit+miss) ratio for the cache",
         displayName = "Hit ratio",
         units = Units.PERCENTAGE,
         displayType = DisplayType.SUMMARY
   )
   @SuppressWarnings("unused")
   public double getHitRatio() {
      long hitsL = hits.sum();
      double total = hitsL + misses.sum();
      // Use <= rather than == to avoid an exact equality test on a floating point value.
      if (total <= 0)
         return 0;
      return (hitsL / total);
   }

   /**
    * Returns (hits + misses) / stores, or 0 when no stores have been recorded.
    */
   @ManagedAttribute(
         description = "read/writes ratio for the cache",
         displayName = "Read/write ratio",
         units = Units.PERCENTAGE,
         displayType = DisplayType.SUMMARY
   )
   @SuppressWarnings("unused")
   public double getReadWriteRatio() {
      long sum = stores.sum();
      if (sum == 0)
         return 0;
      return (((double) (hits.sum() + misses.sum()) / (double) sum));
   }

   /**
    * Returns the total read time divided by the number of reads (hits + misses), or 0 when
    * there were none.
    */
   @ManagedAttribute(
         description = "Average number of milliseconds for a read operation on the cache",
         displayName = "Average read time",
         units = Units.MILLISECONDS,
         displayType = DisplayType.SUMMARY
   )
   @SuppressWarnings("unused")
   public long getAverageReadTime() {
      long total = hits.sum() + misses.sum();
      if (total == 0)
         return 0;
      return (hitTimes.sum() + missTimes.sum()) / total;
   }

   /**
    * Returns the total store time divided by the number of stores, or 0 when there were none.
    */
   @ManagedAttribute(
         description = "Average number of milliseconds for a write operation in the cache",
         displayName = "Average write time",
         units = Units.MILLISECONDS,
         displayType = DisplayType.SUMMARY
   )
   @SuppressWarnings("unused")
   public long getAverageWriteTime() {
      long sum = stores.sum();
      if (sum == 0)
         return 0;
      return (storeTimes.sum()) / sum;
   }

   /**
    * Returns the total remove time divided by the number of remove hits, or 0 when there were
    * none. Remove misses are not timed.
    */
   @ManagedAttribute(
         description = "Average number of milliseconds for a remove operation in the cache",
         displayName = "Average remove time",
         units = Units.MILLISECONDS,
         displayType = DisplayType.SUMMARY
   )
   @SuppressWarnings("unused")
   public long getAverageRemoveTime() {
      long removes = getRemoveHits();
      if (removes == 0)
         return 0;
      return (removeTimes.sum()) / removes;
   }

   @ManagedAttribute(
         description = "Number of entries currently in memory including expired entries",
         displayName = "Number of current cache entries",
         displayType = DisplayType.SUMMARY
   )
   public int getNumberOfEntries() {
      return dataContainer.sizeIncludingExpired();
   }

   @ManagedAttribute(
         description = "Number of seconds since cache started",
         displayName = "Seconds since cache started",
         units = Units.SECONDS,
         measurementType = MeasurementType.TRENDSUP,
         displayType = DisplayType.SUMMARY
   )
   public long getTimeSinceStart() {
      return timeService.timeDuration(startNanoseconds.get(), TimeUnit.SECONDS);
   }

   /**
    * Returns number of seconds since cache started
    *
    * @deprecated use {@link #getTimeSinceStart()} instead.
    * @return number of seconds since cache started
    */
   @ManagedAttribute(
         description = "Number of seconds since cache started",
         displayName = "Seconds since cache started",
         units = Units.SECONDS,
         measurementType = MeasurementType.TRENDSUP,
         displayType = DisplayType.SUMMARY
   )
   @Deprecated
   public long getElapsedTime() {
      // backward compatibility as we renamed ElapsedTime to TimeSinceStart
      return getTimeSinceStart();
   }

   @ManagedAttribute(
         description = "Number of seconds since the cache statistics were last reset",
         displayName = "Seconds since cache statistics were reset",
         units = Units.SECONDS,
         displayType = DisplayType.SUMMARY
   )
   @SuppressWarnings("unused")
   public long getTimeSinceReset() {
      return timeService.timeDuration(resetNanoseconds.get(), TimeUnit.SECONDS);
   }

   /**
    * Clears every counter and timing and records the current time as the last reset time.
    * <p>
    * Counters are cleared one at a time. An operation finishing concurrently may therefore be
    * only partially reflected after the reset.
    */
   @Override
   @ManagedOperation(
         description = "Resets statistics gathered by this component",
         displayName = "Reset Statistics (Statistics)"
   )
   public void resetStatistics() {
      hits.reset();
      misses.reset();
      stores.reset();
      evictions.reset();
      hitTimes.reset();
      missTimes.reset();
      storeTimes.reset();
      removeHits.reset();
      removeTimes.reset();
      removeMisses.reset();
      resetNanoseconds.set(timeService.time());
   }

   /**
    * Returns whether statistics should be gathered for {@code cmd}. This requires statistics to
    * be enabled on the interceptor and {@link Flag#SKIP_STATISTICS} to be absent from the command.
    */
   private boolean getStatisticsEnabled(FlagAffectedCommand cmd) {
      return super.getStatisticsEnabled() && !cmd.hasFlag(Flag.SKIP_STATISTICS);
   }

   /**
    * Adds {@code numEvictions} to the eviction counter. This is for evictions that are not
    * observed through {@link #visitEvictCommand}.
    */
   public void addEvictions(long numEvictions) {
      evictions.add(numEvictions);
   }
}

