package org.infinispan.protostream.impl;

import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.GeneratedMessage;
import com.google.protobuf.GeneratedMessageLite;
import com.google.protobuf.MessageLite;
import com.google.protobuf.Parser;
import org.infinispan.protostream.RawProtobufMarshaller;
import org.infinispan.protostream.SerializationContext;

import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

/**
 * Marshaller for Google's Protobuf message classes generated by the protoc tool.
 * <p>
 * Reading is delegated to the {@link Parser} held in the static {@code PARSER} field of the generated class, and
 * writing is delegated to {@link MessageLite#writeTo(CodedOutputStream)}. For classes extending
 * {@link GeneratedMessage}, the protobuf type name is obtained from the descriptor returned by the static
 * {@code getDescriptor()} method. For other classes (extending only {@link GeneratedMessageLite}), the type name must
 * be supplied explicitly.
 *
 * @author anistor@redhat.com
 * @since 1.0
 */
public final class ProtocMessageMarshaller<T extends MessageLite> implements RawProtobufMarshaller<T> {

   // this is a public static field present in all generated classes extending from GeneratedMessage or GeneratedMessageLite
   private static final String PARSER_FIELD_NAME = "PARSER";

   // this is a public static method present only in generated classes extending from GeneratedMessage
   private static final String GET_DESCRIPTOR_METHOD_NAME = "getDescriptor";

   /**
    * Fully qualified protobuf type name of the marshalled message.
    */
   private final String typeName;

   /**
    * The protoc generated Java class handled by this marshaller.
    */
   private final Class<? extends T> clazz;

   /**
    * Parser taken from the generated class' static {@code PARSER} field; used by {@link #readFrom}.
    */
   private final Parser<T> parser;

   /**
    * Creates a marshaller whose type name is determined automatically from the class' descriptor.
    *
    * @param clazz a protoc generated message class
    * @throws IllegalArgumentException if the class does not look like a protoc generated message class, or if it does
    *                                  not extend {@link GeneratedMessage} (so its type name cannot be determined)
    */
   public ProtocMessageMarshaller(Class<? extends T> clazz) {
      this(null, clazz);
   }

   /**
    * Creates a marshaller for the given protoc generated message class.
    *
    * @param typeName the fully qualified protobuf type name, or {@code null} to determine it from the class'
    *                 descriptor (only possible for classes extending {@link GeneratedMessage})
    * @param clazz    a protoc generated message class extending {@link GeneratedMessage} or
    *                 {@link GeneratedMessageLite}
    * @throws IllegalArgumentException if the class does not look like a protoc generated message class, if its
    *                                  {@code PARSER} or {@code getDescriptor} members cannot be used, if the given
    *                                  typeName does not match the descriptor, or if no typeName was given and it
    *                                  cannot be determined automatically
    */
   @SuppressWarnings("unchecked")
   public ProtocMessageMarshaller(String typeName, Class<? extends T> clazz) {
      if (!GeneratedMessage.class.isAssignableFrom(clazz) && !GeneratedMessageLite.class.isAssignableFrom(clazz)) {
         throw new IllegalArgumentException("The given class does not appear to be a 'protoc' generated message class.");
      }

      try {
         Field parserField = clazz.getDeclaredField(PARSER_FIELD_NAME);
         // PARSER is expected to be static, hence the null receiver
         parser = (Parser<T>) parserField.get(null);
      } catch (NoSuchFieldException e) {
         throw new IllegalArgumentException("Class " + clazz + " does not appear to be a 'protoc' generated message class (missing 'PARSER' field).", e);
      } catch (IllegalAccessException e) {
         throw new IllegalArgumentException("Class " + clazz + " does not appear to be a 'protoc' generated message class ('PARSER' field is not accessible).", e);
      }

      // fail fast here rather than with a NullPointerException on the first readFrom call
      if (parser == null) {
         throw new IllegalArgumentException("Class " + clazz + " does not appear to be a 'protoc' generated message class ('PARSER' field is null).");
      }

      if (GeneratedMessage.class.isAssignableFrom(clazz)) {
         String guessedTypeName;
         try {
            Method getDescriptorMethod = clazz.getDeclaredMethod(GET_DESCRIPTOR_METHOD_NAME);
            // getDescriptor is expected to be static, hence the null receiver
            Descriptor descriptor = (Descriptor) getDescriptorMethod.invoke(null);
            guessedTypeName = descriptor.getFullName();
         } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException("Class " + clazz + " does not appear to be a 'protoc' generated message class (missing 'getDescriptor' method).", e);
         } catch (IllegalAccessException e) {
            throw new IllegalArgumentException("Class " + clazz + " does not appear to be a 'protoc' generated message class ('getDescriptor' method is not accessible).", e);
         } catch (InvocationTargetException e) {
            // unwrap so the exception actually thrown by getDescriptor() is reported as the cause
            throw new IllegalArgumentException("Class " + clazz + " does not appear to be a 'protoc' generated message class (invoking 'getDescriptor' method failed).", e.getCause());
         }

         if (typeName != null && !guessedTypeName.equals(typeName)) {
            throw new IllegalArgumentException("The specified typeName ('"
                                                     + typeName
                                                     + "') does not match the one of the message ('"
                                                     + guessedTypeName + "').");
         }
         this.typeName = guessedTypeName;
      } else {
         // lite classes expose no descriptor, so the caller must provide the type name
         if (typeName == null) {
            throw new IllegalArgumentException("typeName was not specified and cannot be determined automatically");
         }
         this.typeName = typeName;
      }

      this.clazz = clazz;
   }

   @Override
   public Class<? extends T> getJavaClass() {
      return clazz;
   }

   @Override
   public String getTypeName() {
      return typeName;
   }

   /**
    * Reads a message by delegating to the generated class' {@link Parser}.
    */
   @Override
   public T readFrom(SerializationContext ctx, CodedInputStream in) throws IOException {
      return parser.parseFrom(in);
   }

   /**
    * Writes the message using its own {@link MessageLite#writeTo(CodedOutputStream)} implementation.
    */
   @Override
   public void writeTo(SerializationContext ctx, CodedOutputStream out, T o) throws IOException {
      o.writeTo(out);
   }
}
