/*
* JBoss, Home of Professional Open Source.
* Copyright 2006, Red Hat Middleware LLC, and individual contributors
* as indicated by the @author tags. See the copyright.txt file in the
* distribution for a full listing of individual contributors. 
*
* This is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of
* the License, or (at your option) any later version.
*
* This software is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this software; if not, write to the Free
* Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
* 02110-1301 USA, or see the FSF site: http://www.fsf.org.
*/ 
package org.jboss.reflect.plugins.bytecode.accessor.generated;

import java.lang.reflect.InvocationTargetException;
import java.security.ProtectionDomain;

import javassist.CtClass;
import javassist.Modifier;
import javassist.bytecode.Bytecode;
import javassist.bytecode.ClassFile;
import javassist.bytecode.ClassFileWriter;
import javassist.bytecode.Descriptor;
import javassist.bytecode.MethodInfo;
import javassist.bytecode.Opcode;
import javassist.bytecode.ClassFileWriter.ConstPoolWriter;
import javassist.bytecode.ClassFileWriter.MethodWriter;

/**
 * Wrapper around a Javassist {@link ClassFileWriter} that builds one class (with a public
 * no-arg constructor that calls the superclass no-arg constructor) and offers helpers for
 * emitting common instructions.
 * <p>
 * Each {@code add*} helper appends an instruction to the current method via {@link #mw} and
 * adjusts {@link #stackDepth}/{@link #maxStackDepth} by that instruction's net operand-stack
 * effect; {@link #endMethod(int)} passes the tracked maximum as the method's {@code max_stack}.
 * Instructions written directly to {@link #mw} are NOT tracked.
 * <p>
 * NOTE(review): holds mutable per-method state and no synchronization; presumably one
 * instance is used by one thread per generated class — confirm with callers.
 *
 * @param <T> the type recorded for the generated class (see {@link #type})
 * @author <a href="kabir.khan@jboss.com">Kabir Khan</a>
 * @version $Revision: 1.1 $
 */
class ClassFileWriterContext<T>
{
   /** {@code ClassLoader.defineClass(String, byte[], int, int)}, made accessible */
   private static final java.lang.reflect.Method defineClass1;

   /** {@code ClassLoader.defineClass(String, byte[], int, int, ProtectionDomain)}, made accessible */
   private static final java.lang.reflect.Method defineClass2;

   static
   {
      try
      {
         defineClass1 = SecurityActions.getDeclaredMethod(
               ClassLoader.class,
               "defineClass",
               new Class[] { String.class, byte[].class, int.class, int.class });

         defineClass2 = SecurityActions.getDeclaredMethod(
               ClassLoader.class,
               "defineClass",
               new Class[] { String.class, byte[].class, int.class, int.class, ProtectionDomain.class });
      }
      catch (Exception e)
      {
         // Keep the cause so that a failure here can actually be diagnosed
         throw new RuntimeException("cannot initialize", e);
      }

      // defineClass is protected on ClassLoader, so it must be opened up before invoking it
      SecurityActions.setAccessible(defineClass1);
      SecurityActions.setAccessible(defineClass2);
   }

   /** The class of the interface we are implementing */
   final Class<T> type;

   /** The name of the class we are creating, in dotted (binary) form */
   final String name;

   /** The underlying class file writer */
   final ClassFileWriter fileWriter;

   /** The underlying constant pool writer */
   final ConstPoolWriter poolWriter;

   /** This class's name index in the const pool */
   final int thisClass;

   /** This class's superclass name index in the const pool */
   final int superClass;

   /** The const pool indexes of the implemented interfaces */
   final int[] interfaces;

   /** The method writer shared by all methods of the generated class */
   final MethodWriter mw;

   /** The created bytes, cached by {@link #getBytes()} */
   byte[] bytes;

   /** Tracked operand stack depth at the current point of the current method */
   int stackDepth;

   /** Largest {@link #stackDepth} seen in the current method; used as max_stack */
   int maxStackDepth;

   /**
    * Starts a new class file and writes its public no-arg constructor.
    *
    * @param name the dotted name of the class to create
    * @param superClassName the dotted name of the superclass
    * @param type the type recorded in {@link #type}
    * @param interfaceNames the dotted names of the implemented interfaces; not modified
    */
   ClassFileWriterContext(String name, String superClassName, Class<T> type, String[] interfaceNames)
   {
      this.name = name;
      this.type = type;

      //FIXME Once we get rid of the old ClassFile stuff we should make the real names look like this
      //to start with
      String jvmName = jvmClassName(name);
      String jvmSuperClassName = jvmClassName(superClassName);
      // Convert into a fresh array so the caller's array is left untouched
      String[] jvmInterfaceNames = new String[interfaceNames.length];
      for (int i = 0; i < interfaceNames.length; i++)
         jvmInterfaceNames[i] = jvmClassName(interfaceNames[i]);

      fileWriter = new ClassFileWriter(ClassFile.JAVA_4, 0);
      poolWriter = fileWriter.getConstPool();
      thisClass = poolWriter.addClassInfo(jvmName);
      superClass = poolWriter.addClassInfo(jvmSuperClassName);
      interfaces = poolWriter.addClassInfo(jvmInterfaceNames);

      // Default constructor: aload_0; invokespecial super.<init>()V; return
      mw = fileWriter.getMethodWriter();
      mw.begin(Modifier.PUBLIC, MethodInfo.nameInit, "()V", null, null);
      mw.add(Opcode.ALOAD_0);
      mw.add(Opcode.INVOKESPECIAL);
      int signature = poolWriter.addNameAndTypeInfo(MethodInfo.nameInit, "()V");
      mw.add16(poolWriter.addMethodrefInfo(superClass, signature));
      mw.add(Opcode.RETURN);
      // max_stack = 1 (this), max_locals = 1 (this)
      mw.codeEnd(1, 1);
      mw.end(null, null);
   }

   /**
    * @return the simple name of {@link #type}
    */
   String getSimpleType()
   {
      return type.getSimpleName();
   }

   /**
    * @return the dotted name of the class being created
    */
   String getName()
   {
      return name;
   }

   /**
    * Begins a new method and resets the stack tracking for it.
    *
    * @param accessFlags the method's access flags (e.g. {@code Modifier.PUBLIC})
    * @param name the method name
    * @param descriptor the JVM method descriptor
    * @param exceptions the JVM names of declared exceptions, or null
    */
   void beginMethod(int accessFlags, String name, String descriptor, String[] exceptions)
   {
      // Every method starts with an empty operand stack; don't inherit the previous method's depth
      stackDepth = 0;
      maxStackDepth = 0;
      mw.begin(accessFlags, name, descriptor, exceptions, null);
   }

   /**
    * Ends the current method, using the tracked maximum stack depth as max_stack.
    *
    * @param maxLocals the max_locals value for the method
    */
   void endMethod(int maxLocals)
   {
      mw.codeEnd(maxStackDepth, maxLocals);
      mw.end(null, null);
   }

   /**
    * Appends an {@code invokestatic}.
    *
    * @param targetClass the JVM name of the declaring class
    * @param methodName the method name
    * @param descriptor the JVM method descriptor
    */
   void addInvokeStatic(String targetClass, String methodName, String descriptor)
   {
      mw.addInvoke(Opcode.INVOKESTATIC, targetClass, methodName, descriptor);

      //Stolen from Bytecode.addInvokestatic(): pops the arguments, pushes the result
      growStack(Descriptor.dataSize(descriptor));
   }

   /**
    * Appends an {@code invokevirtual}.
    *
    * @param targetClass the JVM name of the declaring class
    * @param methodName the method name
    * @param descriptor the JVM method descriptor
    */
   void addInvokeVirtual(String targetClass, String methodName, String descriptor)
   {
      mw.addInvoke(Opcode.INVOKEVIRTUAL, targetClass, methodName, descriptor);

      //Stolen from Bytecode.addInvokevirtual(): also pops the receiver
      growStack(Descriptor.dataSize(descriptor) - 1);
   }

   /**
    * Appends an {@code invokeinterface}.
    *
    * @param targetClass the JVM name of the interface
    * @param methodName the method name
    * @param descriptor the JVM method descriptor
    * @param count the invokeinterface count operand (argument slots including the receiver)
    */
   void addInvokeInterface(String targetClass, String methodName, String descriptor, int count)
   {
      // NOTE(review): invokeinterface requires a CONSTANT_InterfaceMethodref entry; confirm that
      // MethodWriter.addInvoke emits that (and not a plain Methodref) for this opcode.
      mw.addInvoke(Opcode.INVOKEINTERFACE, targetClass, methodName, descriptor);
      mw.add(count);
      mw.add(0);

      //Stolen from Bytecode.addInvokeinterface(): also pops the receiver
      growStack(Descriptor.dataSize(descriptor) - 1);
   }

   /**
    * Appends an {@code invokespecial}.
    *
    * @param targetClass the JVM name of the declaring class
    * @param methodName the method name
    * @param descriptor the JVM method descriptor
    */
   void addInvokeSpecial(String targetClass, String methodName, String descriptor)
   {
      mw.addInvoke(Opcode.INVOKESPECIAL, targetClass, methodName, descriptor);

      //Stolen from Bytecode.addInvokespecial(): also pops the receiver
      growStack(Descriptor.dataSize(descriptor) - 1);
   }

   /**
    * Appends a {@code getfield}.
    *
    * @param className the declaring class
    * @param fieldName the field name
    * @param type the JVM field descriptor
    */
   void addGetField(String className, String fieldName, String type)
   {
      mw.add(Opcode.GETFIELD);
      addFieldRefInfo(className, fieldName, type);

      //Stolen from Bytecode.addGetfield(): pops the object ref, pushes the value
      growStack(Descriptor.dataSize(type) - 1);
   }

   /**
    * Appends a {@code getstatic}.
    *
    * @param className the declaring class
    * @param fieldName the field name
    * @param type the JVM field descriptor
    */
   void addGetStatic(String className, String fieldName, String type)
   {
      mw.add(Opcode.GETSTATIC);
      addFieldRefInfo(className, fieldName, type);

      //Stolen from Bytecode.addGetstatic(): pushes the value
      growStack(Descriptor.dataSize(type));
   }

   /**
    * Appends a {@code putfield}.
    *
    * @param className the declaring class
    * @param fieldName the field name
    * @param type the JVM field descriptor
    */
   void addPutField(String className, String fieldName, String type)
   {
      mw.add(Opcode.PUTFIELD);
      addFieldRefInfo(className, fieldName, type);

      //As Bytecode.addPutfield(): pops both the object ref and the value
      growStack(-1 - Descriptor.dataSize(type));
   }

   /**
    * Appends a {@code putstatic}.
    *
    * @param className the declaring class
    * @param fieldName the field name
    * @param type the JVM field descriptor
    */
   void addPutStatic(String className, String fieldName, String type)
   {
      mw.add(Opcode.PUTSTATIC);
      addFieldRefInfo(className, fieldName, type);

      //Stolen from Bytecode.addPutStatic(): pops the value
      growStack(-Descriptor.dataSize(type));
   }

   /** Appends an {@code areturn}. */
   void addAReturn()
   {
      mw.add(Opcode.ARETURN);

      //From Opcode.STACK_GROW[] (see test in main())
      growStack(-1);
   }

   /** Appends an {@code aconst_null}. */
   void addAConstNull()
   {
      mw.add(Opcode.ACONST_NULL);

      //From Opcode.STACK_GROW[] (see test in main())
      growStack(1);
   }

   /** Appends an {@code aaload}. */
   void addAALoad()
   {
      mw.add(Opcode.AALOAD);

      //From Opcode.STACK_GROW[] (see test in main())
      growStack(-1);
   }

   /**
    * Adds the right bytecode to call ALOAD depending on the
    * number of the local variable
    *
    * @param i the local variable index
    * @see Bytecode#addAload(int)
    */
   void addAload(int i)
   {
      if (i < 4)
         mw.add(Opcode.ALOAD_0 + i);
      else if (i < 0x100)
      {
         mw.add(Opcode.ALOAD);
         mw.add(i);
      }
      else
      {
         mw.add(Opcode.WIDE);
         mw.add(Opcode.ALOAD);
         addIndex(i);
      }
      //From Opcode.STACK_GROW[] (see test in main())
      growStack(1);
   }

   /**
    * Adds the right bytecode to load an int constant depending on its size:
    * iconst_&lt;n&gt; for -1..5, bipush for a byte, sipush for a short, otherwise
    * ldc/ldc_w of a CONSTANT_Integer pool entry.
    *
    * @param i the number
    * @see Bytecode#addIconst(int)
    */
   void addIconst(int i)
   {
      if (i >= -1 && i <= 5)
         mw.add(Opcode.ICONST_0 + i);           // iconst_<i>   -1..5
      else if (i >= -128 && i <= 127)
      {
         mw.add(Opcode.BIPUSH);
         mw.add(i);
      }
      else if (i >= -32768 && i <= 32767)
      {
         mw.add(Opcode.SIPUSH);
         mw.add(i >> 8);
         mw.add(i);
      }
      else
      {
         int ref = poolWriter.addIntegerInfo(i);

         // ldc only has a one-byte operand, so the choice depends on the pool index, not the value
         if (ref > 0xFF)
         {
            mw.add(Opcode.LDC_W);
            addIndex(ref);
         }
         else
         {
            mw.add(Opcode.LDC);
            mw.add(ref);
         }
      }
      //From Opcode.STACK_GROW[] (see test in main())
      growStack(1);
   }

   /**
    * Appends a {@code new}.
    *
    * @param className the class to instantiate (dotted or JVM form)
    */
   void addNew(String className)
   {
      mw.add(Opcode.NEW);
      addIndex(addClassInfo(className));

      //From Opcode.STACK_GROW[] (see test in main())
      growStack(1);
   }

   /** Appends a {@code dup}. */
   void addDup()
   {
      mw.add(Opcode.DUP);

      //From Opcode.STACK_GROW[] (see test in main())
      growStack(1);
   }

   /**
    * Appends a {@code checkcast}.
    *
    * @param clazz the class to cast to (dotted or JVM form, or an array descriptor)
    */
   void addCheckcast(String clazz)
   {
      mw.add(Opcode.CHECKCAST);
      addIndex(addClassInfo(clazz));

      //From Opcode.STACK_GROW[] (see test in main())
      //No change to stack
   }

   /**
    * Finishes the class file on first call and caches the result.
    *
    * @return the class file bytes
    */
   byte[] getBytes()
   {
      if (bytes == null)
         bytes = fileWriter.end(Modifier.PUBLIC, thisClass, superClass, interfaces, null);
      return bytes;
   }

   /**
    * Defines the generated class in the given loader by reflectively calling
    * {@code ClassLoader.defineClass}.
    *
    * @param loader the class loader to define the class in
    * @param domain the protection domain, or null to use the loader's default
    * @return the defined class
    * @throws InvocationTargetException if defineClass throws
    * @throws IllegalAccessException if defineClass cannot be invoked
    */
   @SuppressWarnings("unchecked") // defineClass returns Class<?>; it defines the class named by this.name
   Class<T> toClass(ClassLoader loader, ProtectionDomain domain) throws InvocationTargetException, IllegalAccessException
   {
      byte[] classBytes = getBytes();
      if (domain == null)
         return (Class<T>) SecurityActions.invoke(defineClass1, loader, name, classBytes, Integer.valueOf(0), Integer.valueOf(classBytes.length));
      else
         return (Class<T>) SecurityActions.invoke(defineClass2, loader, name, classBytes, Integer.valueOf(0), Integer.valueOf(classBytes.length), domain);
   }

   /** Writes a big-endian two-byte operand. */
   private void addIndex(int i)
   {
      mw.add(i >> 8);
      mw.add(i);
   }

   /** Writes the const pool index of a field reference as a two-byte operand. */
   private void addFieldRefInfo(String className, String fieldName, String type)
   {
      int classIndex = addClassInfo(className);
      int nameAndType = poolWriter.addNameAndTypeInfo(fieldName, type);
      addIndex(poolWriter.addFieldrefInfo(classIndex, nameAndType));
   }

   /** Adds a class entry to the const pool, converting dotted names to JVM form. */
   private int addClassInfo(String className)
   {
      return poolWriter.addClassInfo(jvmClassName(className));
   }

   /** Adjusts the tracked stack depth and records a new maximum if reached. */
   private void growStack(int i)
   {
      stackDepth += i;
      if (stackDepth > maxStackDepth)
         maxStackDepth = stackDepth;
   }

   /**
    * @param clazz the class
    * @return the class name with '.' replaced by '/'
    */
   static String jvmClassName(CtClass clazz)
   {
      return jvmClassName(clazz.getName());
   }

   /**
    * @param name a dotted class name
    * @return the name with '.' replaced by '/'
    */
   static String jvmClassName(String name)
   {
      return name.replace('.', '/');
   }

   /**
    * Prints the javassist {@code Opcode.STACK_GROW} values the stack tracking above relies on.
    *
    * @param args ignored
    */
   public static void main(String[] args)
   {
      System.out.println("  Opcode.INVOKESTATIC   \t" + Opcode.STACK_GROW[ Opcode.INVOKESTATIC ]);
      System.out.println("  Opcode.INVOKEVIRTUAL  \t" + Opcode.STACK_GROW[ Opcode.INVOKEVIRTUAL ]);
      System.out.println("  Opcode.INVOKEINTERFACE\t" + Opcode.STACK_GROW[ Opcode.INVOKEINTERFACE ]);
      System.out.println("  Opcode.INVOKESPECIAL  \t" + Opcode.STACK_GROW[ Opcode.INVOKESPECIAL ]);
      System.out.println("  Opcode.GETFIELD       \t" + Opcode.STACK_GROW[ Opcode.GETFIELD ]);
      System.out.println("  Opcode.GETSTATIC      \t" + Opcode.STACK_GROW[ Opcode.GETSTATIC ]);
      System.out.println("  Opcode.PUTFIELD       \t" + Opcode.STACK_GROW[ Opcode.PUTFIELD ]);
      System.out.println("  Opcode.PUTSTATIC      \t" + Opcode.STACK_GROW[ Opcode.PUTSTATIC ]);
      System.out.println("  Opcode.ARETURN        \t" + Opcode.STACK_GROW[ Opcode.ARETURN ]);
      System.out.println("  Opcode.ACONST_NULL    \t" + Opcode.STACK_GROW[ Opcode.ACONST_NULL ]);
      System.out.println("  Opcode.AALOAD         \t" + Opcode.STACK_GROW[ Opcode.AALOAD ]);
      System.out.println("  Opcode.ALOAD_1        \t" + Opcode.STACK_GROW[ Opcode.ALOAD_1 ]);
      System.out.println("  Opcode.ALOAD          \t" + Opcode.STACK_GROW[ Opcode.ALOAD ]);
      System.out.println("  Opcode.WIDE           \t" + Opcode.STACK_GROW[ Opcode.WIDE ]);
      System.out.println("  Opcode.NEW            \t" + Opcode.STACK_GROW[ Opcode.NEW ]);
      System.out.println("  Opcode.DUP            \t" + Opcode.STACK_GROW[ Opcode.DUP ]);
      System.out.println("  Opcode.CHECKCAST      \t" + Opcode.STACK_GROW[ Opcode.CHECKCAST ]);
      System.out.println("  Opcode.ICONST_1       \t" + Opcode.STACK_GROW[ Opcode.ICONST_1 ]);
      System.out.println("  Opcode.BIPUSH         \t" + Opcode.STACK_GROW[ Opcode.BIPUSH ]);
      System.out.println("  Opcode.SIPUSH         \t" + Opcode.STACK_GROW[ Opcode.SIPUSH ]);
      System.out.println("  Opcode.LDC_W          \t" + Opcode.STACK_GROW[ Opcode.LDC_W ]);
      System.out.println("  Opcode.LDC            \t" + Opcode.STACK_GROW[ Opcode.LDC ]);
   }
}
