Package com.jogamp.opencl.demos.fft

Source Code of com.jogamp.opencl.demos.fft.BlurTest

package com.jogamp.opencl.demos.fft;

import com.jogamp.opencl.CLBuffer;
import com.jogamp.opencl.CLCommandQueue;
import com.jogamp.opencl.CLContext;
import com.jogamp.opencl.CLDevice;
import com.jogamp.opencl.CLKernel;
import com.jogamp.opencl.CLMemory.Mem;
import com.jogamp.opencl.CLPlatform;
import com.jogamp.opencl.CLProgram;
import com.jogamp.opencl.demos.fft.CLFFTPlan.InvalidContextException;
import java.awt.BorderLayout;
import java.awt.Dimension;
import java.awt.Graphics;
import java.awt.GridBagConstraints;
import java.awt.GridBagLayout;
import java.awt.Insets;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.awt.image.DataBufferInt;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.imageio.ImageIO;
import javax.swing.BoxLayout;
import javax.swing.ButtonGroup;
import javax.swing.JButton;
import javax.swing.JFileChooser;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JOptionPane;
import javax.swing.JPanel;
import javax.swing.JSlider;
import javax.swing.JToggleButton;
import javax.swing.SwingUtilities;
import javax.swing.event.ChangeEvent;
import javax.swing.event.ChangeListener;

/**
* Perform some user-controllable blur on an image.
* @author notzed
*/
public class BlurTest implements Runnable, ChangeListener, ActionListener {

    public static void main(String[] args) {
        SwingUtilities.invokeLater(new BlurTest());
    }
   
    boolean demo = false;
    // must be power of 2 and width must be multiple of 64
    int width = 512;
    int height = 512;
    BufferedImage src;
    BufferedImage psf;
    BufferedImage dst;
    PaintView left;
    ImageView right;
    //
    JSlider sizex;
    JSlider sizey;
    JSlider angle;
    //
    JToggleButton blurButton;
    JToggleButton drawButton;

    public void run() {
        try {
            initCL();
        } catch (Exception x) {
            System.out.println("failed to init cl");
            x.printStackTrace();
            System.exit(1);
        }

        JFileChooser fc = new JFileChooser();
        BufferedImage img = null;

        while (img == null) {
            try {
                File file = null;

                if (true) {
                    fc.setDialogTitle("Select Image File");
                    fc.setPreferredSize(new Dimension(500, 600));
                    if (fc.showOpenDialog(null) == JFileChooser.APPROVE_OPTION) {
                        file = fc.getSelectedFile();
                    } else {
                        System.exit(0);
                    }

                } else {
                    file = new File("/home/notzed/cat0.jpg");
                }
                img = ImageIO.read(file);
                if (img == null) {
                    JOptionPane.showMessageDialog(null, "Couldn't load file");
                }
            } catch (IOException x) {
                JOptionPane.showMessageDialog(null, "Couldn't load file");
            }
        }

        src = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        dst = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        psf = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);

        // Ensure loaded image is in known format and size
        Graphics g = src.createGraphics();
        g.drawImage(img, (width - img.getWidth()) / 2, (height - img.getHeight()) / 2, null);
        g.dispose();

        JFrame win = new JFrame("Blur Demo");
        win.setDefaultCloseOperation(win.EXIT_ON_CLOSE);

        JPanel main = new JPanel();
        main.setLayout(new BorderLayout());

        JPanel controls = new JPanel();
        controls.setLayout(new GridBagLayout());

        GridBagConstraints c0 = new GridBagConstraints();
        c0.gridx = 0;
        c0.anchor = GridBagConstraints.BASELINE_LEADING;
        c0.ipadx = 3;
        c0.insets = new Insets(1, 2, 1, 2);

        controls.add(new JLabel("Width"), c0);
        controls.add(new JLabel("Height"), c0);

        GridBagConstraints c2 = (GridBagConstraints) c0.clone();
        c2.gridx = 2;
        controls.add(new JLabel("Angle"), c2);

        c0 = (GridBagConstraints) c0.clone();
        c0.gridx = 1;
        c0.weightx = 1;
        c0.fill = GridBagConstraints.HORIZONTAL;
        sizex = new JSlider(100, 5000, 1000);
        sizey = new JSlider(100, 5000, 100);
        controls.add(sizex, c0);
        controls.add(sizey, c0);

        c2 = (GridBagConstraints) c0.clone();
        c2.gridx = 3;
        angle = new JSlider(0, (int) (Math.PI * 1000));
        controls.add(angle, c2);

        sizex.addChangeListener(this);
        sizey.addChangeListener(this);
        angle.addChangeListener(this);

        JPanel buttons = new JPanel();
        controls.add(buttons, c2);
        JButton b;
        b = new JButton("Clear");
        buttons.add(b);
        b.addActionListener(new ActionListener() {

            public void actionPerformed(ActionEvent e) {
                doclear();
            }
        });
        ButtonGroup opt = new ButtonGroup();
        JToggleButton tb;
        blurButton = new JToggleButton("Blur");
        opt.add(blurButton);
        buttons.add(blurButton);
        blurButton.addActionListener(this);
        drawButton = new JToggleButton("Draw");
        opt.add(drawButton);
        buttons.add(drawButton);
        drawButton.addActionListener(this);

        JPanel imgs = new JPanel();
        imgs.setLayout(new BoxLayout(imgs, BoxLayout.X_AXIS));
        left = new PaintView(this, psf);
        right = new ImageView(dst);
        imgs.add(left);
        imgs.add(right);

        main.add(controls, BorderLayout.NORTH);
        main.add(imgs, BorderLayout.CENTER);
        win.getContentPane().add(main);

        win.pack();
        win.setVisible(true);

        // pre-load and transform src, since that wont change
        loadSource(src);

        blurButton.doClick();
    }

    public void stateChanged(ChangeEvent e) {
        if (drawButton.isSelected()) {
            recalc();
        } else {
            double w = sizex.getValue() / 100.0;
            double h = sizey.getValue() / 100.0;
            double a = angle.getValue() / 1000.0;

            Graphics g = psf.createGraphics();

            g.clearRect(0, 0, width, height);
            g.dispose();

            left.drawDot(w, h, a);
        }
    }

    public void actionPerformed(ActionEvent e) {
        stateChanged(null);
    }

    private void doclear() {
        Graphics g = psf.createGraphics();

        g.clearRect(0, 0, width, height);
        g.dispose();
        left.repaint();
        recalc();
    }

    private void dorecalc() {
        loadPSF(psf);

        // convolve each plane in freq domain
        convolve(aCBuffer, psfBuffer, aGBuffer);
        convolve(rCBuffer, psfBuffer, rGBuffer);
        convolve(gCBuffer, psfBuffer, gGBuffer);
        convolve(bCBuffer, psfBuffer, bGBuffer);

        // convert back to spatial domain
        fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Inverse, aGBuffer, aBuffer, null, null);
        fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Inverse, rGBuffer, rBuffer, null, null);
        fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Inverse, gGBuffer, gBuffer, null, null);
        fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Inverse, bGBuffer, bBuffer, null, null);

        // while gpu is running, calculate energy of psf
        float scale;

        long total = 0;
        DataBufferByte pd = (DataBufferByte) psf.getRaster().getDataBuffer();
        byte[] data = pd.getData();
        for (int i = 0; i < data.length; i++) {
            total += data[i] & 0xff;
        }
        scale = 255.0f / total / width / height;

        getDestination(argbBuffer, aBuffer, rBuffer, gBuffer, bBuffer, scale);

        // drop back to java, slow-crappy-method
        q.putReadBuffer(argbBuffer, true);
        DataBufferInt db = (DataBufferInt) dst.getRaster().getDataBuffer();
        argbBuffer.getBuffer().position(0);
        argbBuffer.getBuffer().get(db.getData());
        argbBuffer.getBuffer().position(0);
        right.repaint();
    }
    Runnable later;

    void recalc() {
        if (later == null) {
            later = new Runnable() {

                public void run() {
                    later = null;
                    dorecalc();
                }
            };
            SwingUtilities.invokeLater(later);
        }
    }
    CLContext cl;
    CLCommandQueue q;
    CLProgram prog;
    CLKernel kImg2Planes;
    CLKernel kPlanes2Img;
    CLKernel kGrey2Plane;
    CLKernel kConvolve;
    CLKernel kDeconvolve;
    CLFFTPlan fft;
    CLBuffer<IntBuffer> argbBuffer;
    CLBuffer<ByteBuffer> greyBuffer;
    CLBuffer<FloatBuffer> aBuffer;
    CLBuffer<FloatBuffer> rBuffer;
    CLBuffer<FloatBuffer> gBuffer;
    CLBuffer<FloatBuffer> bBuffer;
    CLBuffer<FloatBuffer> aCBuffer;
    CLBuffer<FloatBuffer> rCBuffer;
    CLBuffer<FloatBuffer> gCBuffer;
    CLBuffer<FloatBuffer> bCBuffer;
    CLBuffer<FloatBuffer> aGBuffer;
    CLBuffer<FloatBuffer> rGBuffer;
    CLBuffer<FloatBuffer> gGBuffer;
    CLBuffer<FloatBuffer> bGBuffer;
    CLBuffer<FloatBuffer> psfBuffer;
    CLBuffer<FloatBuffer> tmpBuffer;
    //
    CLKernel fft512;

    void initCL() throws InvalidContextException {
       
        // search a platform with a GPU
        CLPlatform[] platforms = CLPlatform.listCLPlatforms();
        CLDevice gpu = null;
        for (CLPlatform platform : platforms) {
            gpu = platform.getMaxFlopsDevice(CLDevice.Type.GPU);
            if(gpu != null) {
                break;
            }
        }

        cl = CLContext.create(gpu);

        q = cl.getDevices()[0].createCommandQueue();

        prog = cl.createProgram(img2Planes + planes2Img + convolve + grey2Plane + deconvolve);
        prog.build("-cl-mad-enable");

        kImg2Planes = prog.createCLKernel("img2planes");
        kPlanes2Img = prog.createCLKernel("planes2img");
        kGrey2Plane = prog.createCLKernel("grey2plane");
        kConvolve = prog.createCLKernel("convolve");
        kDeconvolve = prog.createCLKernel("deconvolve");

        argbBuffer = cl.createIntBuffer(width * height, Mem.READ_WRITE);
        greyBuffer = cl.createByteBuffer(width * height, Mem.READ_WRITE);
        aBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
        rBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
        gBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
        bBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
        psfBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
        tmpBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);

        aCBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
        rCBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
        gCBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
        bCBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);

        aGBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
        rGBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
        gGBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
        bGBuffer = cl.createFloatBuffer(width * height * 2, Mem.READ_WRITE);
        if (false) {
            try {
                CLProgram p = cl.createProgram(new FileInputStream("/home/notzed/cl/fft-512.cl"));
                p.build();
                fft512 = p.createCLKernel("fft0");
            } catch (IOException ex) {
                Logger.getLogger(BlurTest.class.getName()).log(Level.SEVERE, null, ex);
            }
        } else {
            fft = new CLFFTPlan(cl, new int[]{width, height}, CLFFTPlan.CLFFTDataFormat.InterleavedComplexFormat);
        }
        //fft.dumpPlan(null);
    }

    void loadSource(BufferedImage src) {
        DataBufferInt sb = (DataBufferInt) src.getRaster().getDataBuffer();

        argbBuffer.getBuffer().position(0);
        argbBuffer.getBuffer().put(sb.getData());
        argbBuffer.getBuffer().position(0);
        q.putWriteBuffer(argbBuffer, false);

        kImg2Planes.setArg(0, argbBuffer);
        kImg2Planes.setArg(1, 0);
        kImg2Planes.setArg(2, width);
        kImg2Planes.setArg(3, aBuffer);
        kImg2Planes.setArg(4, rBuffer);
        kImg2Planes.setArg(5, gBuffer);
        kImg2Planes.setArg(6, bBuffer);
        kImg2Planes.setArg(7, 0);
        kImg2Planes.setArg(8, width);
        q.put2DRangeKernel(kImg2Planes, 0, 0, width, height, 64, 1);
        q.finish();

        fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Forward, aBuffer, aCBuffer, null, null);
        fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Forward, rBuffer, rCBuffer, null, null);
        fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Forward, gBuffer, gCBuffer, null, null);
        fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Forward, bBuffer, bCBuffer, null, null);
    }

    void loadPSF(BufferedImage psf) {
        assert (psf.getType() == BufferedImage.TYPE_BYTE_GRAY);
        DataBufferByte pb = (DataBufferByte) psf.getRaster().getDataBuffer();

        greyBuffer.getBuffer().position(0);
        greyBuffer.getBuffer().put(pb.getData());
        greyBuffer.getBuffer().position(0);
        q.putWriteBuffer(greyBuffer, false);

        kGrey2Plane.setArg(0, greyBuffer);
        kGrey2Plane.setArg(1, 0);
        kGrey2Plane.setArg(2, width);
        kGrey2Plane.setArg(3, tmpBuffer);
        kGrey2Plane.setArg(4, 0);
        kGrey2Plane.setArg(5, width);
        q.put2DRangeKernel(kGrey2Plane, 0, 0, width, height, 64, 1);

        if (true) {
            fft.executeInterleaved(q, 1, CLFFTPlan.CLFFTDirection.Forward, tmpBuffer, psfBuffer, null, null);
        } else if (true) {
            fft512.setArg(0, tmpBuffer);
            fft512.setArg(1, psfBuffer);
            fft512.setArg(2, -1);
            fft512.setArg(3, height);
            //q.put1DRangeKernel(fft512, 0,height*64, 64);
            q.put2DRangeKernel(fft512, 0, 0, height * 64, 1, 64, 1);
            System.out.println("running kernel " + 64 * height + ", " + 64);
        }
    }

    // g = f x h
    void convolve(CLBuffer<FloatBuffer> h, CLBuffer<FloatBuffer> f, CLBuffer<FloatBuffer> g) {
        kConvolve.setArg(0, h);
        kConvolve.setArg(1, f);
        kConvolve.setArg(2, g);
        kConvolve.setArg(3, width);
        q.put2DRangeKernel(kConvolve, 0, 0, width, height, 64, 1);
    }

    // g = h*conj(f) / (abs(f)^2 + k)
    void deconvolve(CLBuffer<FloatBuffer> h, CLBuffer<FloatBuffer> f, CLBuffer<FloatBuffer> g, float k) {
        kDeconvolve.setArg(0, h);
        kDeconvolve.setArg(1, f);
        kDeconvolve.setArg(2, g);
        kDeconvolve.setArg(3, width);
        kDeconvolve.setArg(4, k);
        q.put2DRangeKernel(kDeconvolve, 0, 0, width, height, 64, 1);
    }

    void getDestination(CLBuffer<IntBuffer> dst, CLBuffer<FloatBuffer> a, CLBuffer<FloatBuffer> r, CLBuffer<FloatBuffer> g, CLBuffer<FloatBuffer> b, float scale) {
        kPlanes2Img.setArg(0, dst);
        kPlanes2Img.setArg(1, 0);
        kPlanes2Img.setArg(2, width);
        kPlanes2Img.setArg(3, a);
        kPlanes2Img.setArg(4, r);
        kPlanes2Img.setArg(5, g);
        kPlanes2Img.setArg(6, b);
        kPlanes2Img.setArg(7, 0);
        kPlanes2Img.setArg(8, width);
        kPlanes2Img.setArg(9, scale);
        q.put2DRangeKernel(kPlanes2Img, 0, 0, width, height, 64, 1);
    }
    // Convert packed ARGB byte image to planes of complex floats
    final String img2Planes =
              "kernel void img2planes(global const uchar4 *argb, int soff, int sstride,"
            + "  global float2 *a, global float2 *r, global float2 *g, global float2 *b, int doff, int dstride) {"
            + " int gx = get_global_id(0);"
            + " int gy = get_global_id(1);"
            + " uchar4 v = argb[soff+sstride*gy+gx];"
            + " float4 ff = convert_float4(v) * (float4)(1.0f/255);"
            + " doff += (dstride * gy + gx);"
            + " b[doff] = (float2){ ff.s0, 0 };\n"
            + " g[doff] = (float2){ ff.s1, 0 };"
            + " r[doff] = (float2){ ff.s2, 0 };"
            + " a[doff] = (float2){ ff.s3, 0 };\n"
            + "}\n\n";
    // not the best implementation
    // this also performs an 'fftshift'
    final String grey2Plane =
              "kernel void grey2plane(global const uchar *src, int soff, int sstride,"
            + "  global float2 *dst, int doff, int dstride) {"
            + " int gx = get_global_id(0);"
            + " int gy = get_global_id(1);"
            + " uchar v = src[soff+sstride*gy+gx];"
            + " float ff = convert_float(v) * (1.0f/255);"
            // fftshift
            + " gx ^= get_global_size(0)>>1;"
            + " gy ^= get_global_size(1)>>1;"
            + " doff += (dstride * gy + gx);"
            + " dst[doff] = (float2) { ff, 0 };"
            + "}\n\n";
    // This also does the 'fftscale' after the inverse fft.
    final String planes2Img =
              "kernel void planes2img(global uchar4 *argb, int soff, int sstride, const global float2 *a, const global float2 *r, const global float2 *g, const global float2 *b, int doff, int dstride, float scale) {"
            + " int gx = get_global_id(0);"
            + " int gy = get_global_id(1);"
            + " float4 fr, fi, fa;"
            + " float2 t;"
            + " doff += (dstride * gy + gx);"
            + " float2 s = (float2)scale;"
            + " t = b[doff]*s; fr.s0 = t.s0; fi.s0 = t.s1;"
            + " t = g[doff]*s; fr.s1 = t.s0; fi.s1 = t.s1;"
            + " t = r[doff]*s; fr.s2 = t.s0; fi.s2 = t.s1;"
            + " t = a[doff]*s; fr.s3 = t.s0; fi.s3 = t.s1;"
            + " fa = sqrt(fr*fr + fi*fi) * 255.0f;"
            + " fa = clamp(fa, 0.0f, 255.0f);"
            + " argb[soff +sstride*gy+gx] = convert_uchar4(fa);"
            + "}\n\n";
    final String convolve =
              "kernel void convolve(global const float2 *h, global const float2 *ff, global float2 *g, int stride) {"
            + " int gx = get_global_id(0);"
            + " int gy = get_global_id(1);"
            + " int off = stride * gy + gx;"
            + " float2 a = h[off];"
            + " float2 b = ff[off];"
            + " g[off] = (float2) { a.s0 * b.s0 - a.s1 * b.s1, a.s0 * b.s1 + a.s1 * b.s0 };"
            + "}\n\n";
    final String deconvolve =
              "kernel void deconvolve(global const float2 *h, global const float2 *ff, global float2 *g, int stride, float k) {"
            + " int gx = get_global_id(0);"
            + " int gy = get_global_id(1);"
            + " int off = stride * gy + gx;"
            + " float2 a = h[off];"
            + " float2 b = ff[off];"
            + " float d = b.s0 * b.s0 + b.s1 * b.s1 + k;"
            + " b.s0 /= d;"
            + " b.s1 /= -d;"
            + " g[off] = (float2) { a.s0 * b.s0 - a.s1 * b.s1, a.s0 * b.s1 + a.s1 * b.s0 };"
            + "}\n\n";
}
TOP

Related Classes of com.jogamp.opencl.demos.fft.BlurTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.