
import java.awt.*;
import java.io.*;
import java.awt.*;
import java.awt.geom.*;
import java.awt.image.*;
import java.awt.color.*;
import java.util.*;
import javax.swing.*;
import com.sun.media.jai.codec.*;


/**
 * Builds {@link BufferedImage}s backed by a banded raster (one data bank per
 * colour component) from an AWT {@link Image}, saves them through the JAI
 * codec API and shows them in a Swing frame.
 */
public class BandedRGBImageFactory implements ImageObserver{

   /**
    * Entry point: converts the image named by {@code args[0]} into a banded
    * RGBA image, saves it and displays it.
    *
    * @param args command-line arguments; the first one is the image path
    */
   public static void  main(String[] args){
      if(args.length < 1) {
         System.err.println("Enter a valid image file name");
         // A missing argument is a usage error, so report failure to the shell.
         System.exit(1);
      }
      BandedRGBImageFactory ifact = new BandedRGBImageFactory();
      ifact.createRGBBufferedImage(args[0]);
   }

   /**
    * Loads {@code filename}, converts it to a 4-band (R, G, B, A) banded
    * byte image, saves it as {@code lena1.png} and displays it.
    * Prints a message and returns if the image cannot be loaded.
    *
    * @param filename path of the image file to load
    */
   public void createRGBBufferedImage(String filename) {
      Dimension size = new Dimension();
      int[] pix = loadPixels(filename, size);
      if(pix == null) return;
      byte[][] data = extractDataBanded(pix, 4);
      BufferedImage image = createBandedRGBImage(size.width, size.height, 8, data, 4);
      try {
         saveAsPNGRGB(image, "lena1.png");
         displayImage(image);
      } catch(Exception e){ e.printStackTrace(); }
   }

   /**
    * Loads {@code filename}, converts it to a 3-band (R, G, B) banded
    * unsigned-short image, saves it as {@code lena.png} and displays it.
    * Prints a message and returns if the image cannot be loaded.
    *
    * @param filename path of the image file to load
    */
   public void createRGBBufferedImageShort(String filename) {
      Dimension size = new Dimension();
      int[] pix = loadPixels(filename, size);
      if(pix == null) return;
      short[][] data = extractDataBandedShort(pix, 3);
      BufferedImage image = createBandedRGBImage(size.width, size.height, 8, data, 3);
      try {
         saveAsPNGRGB(image, "lena.png");
         displayImage(image);
      } catch(Exception e){ e.printStackTrace(); }
   }

   /**
    * Reads {@code filename} and grabs its pixels as default-RGB ints.
    * On success, stores the image dimensions in {@code size}.
    *
    * @return the pixel array, or {@code null} (after printing a message) if
    *         the image could not be loaded or its pixels could not be grabbed
    */
   private int[] loadPixels(String filename, Dimension size) {
      Image awtImage = readImage(filename);
      if(awtImage == null) {
         System.err.println("Could not load image: " + filename);
         return null;
      }
      int imageWidth = awtImage.getWidth(this);
      int imageHeight = awtImage.getHeight(this);
      // getWidth/getHeight return -1 while the dimensions are unknown.
      if(imageWidth <= 0 || imageHeight <= 0) {
         System.err.println("Image has no valid dimensions: " + filename);
         return null;
      }
      int[] pix = fetchPixels(awtImage, imageWidth, imageHeight);
      if(pix == null) {
         System.err.println("Could not grab pixels from: " + filename);
         return null;
      }
      size.setSize(imageWidth, imageHeight);
      return pix;
   }

   /**
    * Loads an image through the default toolkit and blocks until it has
    * finished loading.
    *
    * @param imageName path of the image file
    * @return the loaded image, or {@code null} if loading failed or the
    *         waiting thread was interrupted
    */
   public Image readImage(String imageName){
      Image image = Toolkit.getDefaultToolkit().getImage(imageName);
      MediaTracker imageTracker = new MediaTracker(new JPanel());
      imageTracker.addImage(image, 0);
      try{
         imageTracker.waitForID(0);
      }catch(InterruptedException e){
         // Preserve the interrupt so callers further up can still observe it.
         Thread.currentThread().interrupt();
         return null;
      }
      if(imageTracker.isErrorID(0)) {
         return null;
      }
      return image;
   }

   /**
    * ImageObserver callback. Returns {@code false} (no further updates
    * wanted) once an error is reported or a full frame/image is available,
    * otherwise {@code true}.
    */
   public boolean imageUpdate(Image img,
                             int infoflags,
                             int x,
                             int y,
                             int width,
                             int height){

       if((infoflags & ImageObserver.ERROR) != 0){
          System.out.println("ERROR in image loading or drawing");
          return false;
       }

       if((infoflags & (ImageObserver.FRAMEBITS | ImageObserver.ALLBITS))!= 0) {
           return false;
       }
       return true;
   }

   /**
    * Grabs the pixels of {@code image} as default-RGB (ARGB) ints in
    * row-major order.
    *
    * @return the pixel array, or {@code null} if grabbing was interrupted
    *         or aborted
    */
   public int[] fetchPixels(Image image, int width, int height){
      int pixMap[] = new int[width*height];
      System.out.println("pix map len = "+ width*height);
      PixelGrabber pg = new PixelGrabber(image, 0,0,width,height, pixMap, 0, width);
      try {
         pg.grabPixels();
      } catch (InterruptedException e){
         Thread.currentThread().interrupt();
         return null;
      }
      if((pg.status()  & ImageObserver.ABORT)!=0){
          return null;
      }
      return pixMap;
   }

   /** Rejects band counts other than 3 (RGB) or 4 (RGBA). */
   private static void checkNumBands(int numbands) {
      if(numbands != 3 && numbands != 4) {
         throw new IllegalArgumentException(
               "numbands must be 3 (RGB) or 4 (RGBA): " + numbands);
      }
   }

   /**
    * Unpacks ARGB ints into a pixel-interleaved byte array
    * (R, G, B[, A] per pixel).
    *
    * @param pixmap   ARGB pixels
    * @param numbands 3 for RGB or 4 for RGBA
    * @throws IllegalArgumentException if {@code numbands} is not 3 or 4
    */
   public byte[] extractData(int[] pixmap, int numbands) {
      checkNumBands(numbands);
      byte data[] = new byte[pixmap.length*numbands];
      for(int i=0; i<pixmap.length; i++){
         int pixel = pixmap[i];
         int base = i*numbands;
         data[base+0] = (byte)((pixel >> 16) & 0xff);
         data[base+1] = (byte)((pixel >>  8) & 0xff);
         data[base+2] = (byte)((pixel      ) & 0xff);
         if(numbands == 4){
            data[base+3] = (byte)((pixel >> 24) & 0xff);
         }
      }
      return data;
   }

   /**
    * Unpacks ARGB ints into one byte bank per component:
    * {@code data[0]} = R, {@code data[1]} = G, {@code data[2]} = B and,
    * when {@code numbands == 4}, {@code data[3]} = A.
    *
    * @throws IllegalArgumentException if {@code numbands} is not 3 or 4
    */
   public byte[][] extractDataBanded(int[] pixmap, int numbands) {
      checkNumBands(numbands);
      byte data[][] = new byte[numbands][pixmap.length];
      System.out.println("len = "+pixmap.length);
      for(int i=0; i<pixmap.length; i++){
         int pixel = pixmap[i];
         data[0][i] = (byte)((pixel >> 16) & 0xff);
         data[1][i] = (byte)((pixel >>  8) & 0xff);
         data[2][i] = (byte)((pixel      ) & 0xff);
         if(numbands == 4){
            data[3][i] = (byte)((pixel >> 24) & 0xff);
         }
      }
      return data;
   }

   /**
    * Same as {@link #extractDataBanded} but stores each 0-255 component in
    * a {@code short}.
    *
    * @throws IllegalArgumentException if {@code numbands} is not 3 or 4
    */
   public short[][] extractDataBandedShort(int[] pixmap, int numbands) {
      checkNumBands(numbands);
      short data[][] = new short[numbands][pixmap.length];
      System.out.println("len = "+pixmap.length);
      for(int i=0; i<pixmap.length; i++){
         int pixel = pixmap[i];
         data[0][i] = (short)((pixel >> 16) & 0xff);
         data[1][i] = (short)((pixel >>  8) & 0xff);
         data[2][i] = (short)((pixel      ) & 0xff);
         if(numbands == 4){
            data[3][i] = (short)((pixel >> 24) & 0xff);
         }
      }
      return data;
   }

   /**
    * Shows {@code img} in a new 256x256 frame.
    */
   public static void displayImage(BufferedImage img){
      JFrame fr = new JFrame();
      // Dispose rather than merely hide on close, so the frame does not keep
      // the AWT machinery (and thus the JVM) alive after the user closes it.
      fr.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
      ImagePanel pan = new ImagePanel(img);
      pan.setSize(256,256);
      fr.getContentPane().add(pan);
      fr.pack();
      fr.setSize(256,256);
      fr.setVisible(true);
   }

   /**
    * Wraps byte banks in a banded sRGB {@link BufferedImage}
    * (TYPE_BYTE data; alpha present when {@code numbands == 4}).
    *
    * @param data one bank per band, each of length
    *             {@code imageWidth*imageHeight}
    */
   public static BufferedImage createBandedRGBImage(int imageWidth,
                                                 int imageHeight,
                                                 int imageDepth,
                                                 byte data[][], int numbands){
      if(numbands == 4) {
         System.out.println("transparecy = "+ Transparency.TRANSLUCENT );
      }
      DataBuffer dataBuf = new DataBufferByte(data, imageWidth*imageHeight);
      return buildBandedImage(imageWidth, imageHeight, imageDepth, numbands,
                              DataBuffer.TYPE_BYTE, dataBuf);
   }

   /**
    * Wraps unsigned-short banks in a banded sRGB {@link BufferedImage}
    * (TYPE_USHORT data; alpha present when {@code numbands == 4}).
    *
    * @param data one bank per band, each of length
    *             {@code imageWidth*imageHeight}
    */
   public static BufferedImage createBandedRGBImage(int imageWidth,
                                                 int imageHeight,
                                                 int imageDepth,
                                                 short data[][], int numbands){
      DataBuffer dataBuf = new DataBufferUShort(data, imageWidth*imageHeight);
      return buildBandedImage(imageWidth, imageHeight, imageDepth, numbands,
                              DataBuffer.TYPE_USHORT, dataBuf);
   }

   /**
    * Shared construction for the banded images: a non-premultiplied sRGB
    * ComponentColorModel plus a BandedSampleModel where band i reads bank i
    * starting at offset 0.
    */
   private static BufferedImage buildBandedImage(int imageWidth, int imageHeight,
                                                 int imageDepth, int numbands,
                                                 int dataType, DataBuffer dataBuf){
      int depth[] = new int[numbands];
      int bands[] = new int[numbands];
      int offsets[] = new int[numbands]; // Java zero-fills: every bank starts at 0
      for(int i=0; i<numbands; i++){
         depth[i] = imageDepth;
         bands[i] = i;
      }
      boolean hasAlpha = (numbands == 4);
      int transparency = hasAlpha ? Transparency.TRANSLUCENT : Transparency.OPAQUE;
      ComponentColorModel ccm = new ComponentColorModel(
                          ColorSpace.getInstance(ColorSpace.CS_sRGB),
                          depth, hasAlpha, false,
                          transparency,
                          dataType);
      BandedSampleModel bsm = new BandedSampleModel(
                                    dataType,
                                    imageWidth, imageHeight, imageWidth,
                                    bands, offsets);
      WritableRaster wr = Raster.createWritableRaster(bsm, dataBuf, new Point(0,0));
      Hashtable ht = new Hashtable();
      ht.put("owner", "Lawrence Rodrigues");
      return new BufferedImage(ccm, wr, false, ht);
   }

   /** Lightweight component that paints a BufferedImage at its origin. */
   static class ImagePanel extends JComponent {
      protected BufferedImage image;
      public ImagePanel(){}
      public ImagePanel(BufferedImage img){ image = img; }

      /** Replaces the displayed image and schedules a repaint. */
      public void setImage(BufferedImage img){
         image = img;
         repaint();
      }

      public void paintComponent(Graphics g){
         if(image != null) {
            ((Graphics2D)g).drawImage(image, new AffineTransform(), null);
         }
      }
   }

   /**
    * Encodes {@code image} with the named JAI codec and writes it to a file.
    * Appends {@code extension} to {@code file} unless the name already ends
    * with it. The stream is closed even if encoding fails.
    */
   private static void encodeToFile(RenderedImage image, String file,
                                    String extension, String format,
                                    ImageEncodeParam param)
      throws java.io.IOException {
      String filename = file.endsWith(extension) ? file : file + extension;
      OutputStream out = new BufferedOutputStream(new FileOutputStream(filename));
      try {
         ImageEncoder encoder = ImageCodec.createImageEncoder(format, out, param);
         encoder.encode(image);
      } finally {
         out.close();
      }
   }

   /** Saves {@code image} as PNG (RGB encode parameters); appends ".png" if missing. */
   public static void saveAsPNGRGB(RenderedImage image, String file)
      throws java.io.IOException {
      encodeToFile(image, file, ".png", "PNG", new PNGEncodeParam.RGB());
   }

   /** Saves {@code image} as JPEG with default parameters; appends ".jpg" if missing. */
   public static void saveAsJPEG(RenderedImage image, String file)
           throws java.io.IOException {
      encodeToFile(image, file, ".jpg", "JPEG", new JPEGEncodeParam());
   }

   /** Saves {@code image} as BMP with default parameters; appends ".bmp" if missing. */
   public static void saveAsBMP(RenderedImage image, String file)
      throws java.io.IOException {
      encodeToFile(image, file, ".bmp", "BMP", new BMPEncodeParam());
   }

}