/*
 * JOCL - Java bindings for OpenCL
 *
 * Copyright 2009 Marco Hutter - http://www.jocl.org/
 */

package org.jocl.samples;

import static javax.media.opengl.GL.*;
import static org.jocl.CL.*;

import java.awt.*;
import java.awt.event.*;
import java.io.*;
import java.nio.*;

import javax.media.opengl.*;
import javax.media.opengl.glu.GLU;
import javax.swing.*;

import org.jocl.*;

import com.sun.opengl.util.Animator;
import com.sun.opengl.util.j2d.TextRenderer;

/**
 * A small example demonstrating the JOCL/JOGL interoperability,
 * using the "simpleGL.cl" kernel from the NVIDIA "oclSimpleGL"
 * example.
 */
public class JOCLSimpleGL implements GLEventListener
{
    /**
     * Entry point for this sample.
     *
     * @param args not used
     */
    public static void main(String[] args)
    {
        SwingUtilities.invokeLater(new Runnable()
        {
            public void run()
            {
                new JOCLSimpleGL();
            }
        });
    }

    /**
     * Compile-time flag which indicates whether the real OpenCL/OpenGL
     * interoperation should be used. If this flag is 'true', then the
     * buffers should be shared between OpenCL and OpenGL. If it is
     * 'false', then the buffer contents will be copied via the host.
     * <br>
     * NOTE(review): init() creates the OpenCL context with only the
     * CL_CONTEXT_PLATFORM property. clCreateFromGLBuffer presumably
     * requires a context that was created with the GL-sharing properties
     * of the cl_khr_gl_sharing extension, so switching this flag on
     * likely also requires extending the context creation - verify.
     */
    private static final boolean GL_INTEROP = false;

    /**
     * Whether the initialization method of this GLEventListener has
     * already been called
     */
    private boolean initialized = false;

    /**
     * Text renderer for status messages
     */
    private TextRenderer renderer;

    /**
     * The width segments of the mesh to be displayed.
     * Should be a multiple of 8.
     */
    private static final int meshWidth = 8 * 64;

    /**
     * The height segments of the mesh to be displayed
     * Should be a multiple of 8.
     */
    private static final int meshHeight = 8 * 64;

    /**
     * The current animation state of the mesh
     */
    private float animationState = 0.0f;

    /**
     * The animator used to animate the mesh.
     */
    private Animator animator;

    /**
     * The VBO identifier
     */
    private int vertexBufferObject;

    /**
     * The cl_mem that has the contents of the VBO
     */
    private cl_mem vboMem;

    /**
     * The OpenCL context
     */
    private cl_context context;

    /**
     * The OpenCL command queue
     */
    private cl_command_queue commandQueue;

    /**
     * The OpenCL kernel
     */
    private cl_kernel kernel;

    /**
     * Whether the computation should be performed with JOCL or
     * with Java. May be toggled by pressing the 't' key.
     */
    private boolean useJOCL = true;

    /**
     * The translation in X-direction
     */
    private float translationX = 0;

    /**
     * The translation in Y-direction
     */
    private float translationY = 0;

    /**
     * The translation in Z-direction
     */
    private float translationZ = -4;

    /**
     * The rotation about the X-axis, in degrees
     */
    private float rotationX = 40;

    /**
     * The rotation about the Y-axis, in degrees
     */
    private float rotationY = 30;

    /**
     * The System.nanoTime() of the previous rendered frame, or 0 if no
     * frame has been rendered yet.
     */
    private long prevFrameNanoTime = 0;

    /**
     * Inner class encapsulating the MouseMotionListener and
     * MouseWheelListener for the interaction
     */
    class MouseControl implements MouseMotionListener, MouseWheelListener
    {
        private Point previousMousePosition = new Point();

        public void mouseDragged(MouseEvent e)
        {
            int dx = e.getX() - previousMousePosition.x;
            int dy = e.getY() - previousMousePosition.y;

            // If the left button is held down, move the object
            if ((e.getModifiersEx() & MouseEvent.BUTTON1_DOWN_MASK) ==
                MouseEvent.BUTTON1_DOWN_MASK)
            {
                translationX += dx / 100.0f;
                translationY -= dy / 100.0f;
            }

            // If the right button is held down, rotate the object
            else if ((e.getModifiersEx() & MouseEvent.BUTTON3_DOWN_MASK) ==
                MouseEvent.BUTTON3_DOWN_MASK)
            {
                rotationX += dy;
                rotationY += dx;
            }
            previousMousePosition = e.getPoint();
        }

        public void mouseMoved(MouseEvent e)
        {
            previousMousePosition = e.getPoint();
        }

        public void mouseWheelMoved(MouseWheelEvent e)
        {
            // Translate along the Z-axis
            translationZ += e.getWheelRotation() * 0.25f;
            previousMousePosition = e.getPoint();
        }
    }

    /**
     * Inner class extending a KeyAdapter for the keyboard
     * interaction
     */
    class KeyboardControl extends KeyAdapter
    {
        public void keyTyped(KeyEvent e)
        {
            char c = e.getKeyChar();
            if (c == 't')
            {
                useJOCL = !useJOCL;
            }
        }
    }

    /**
     * Creates a new JOCLSimpleGL: builds the frame with the GL panel,
     * attaches the mouse/keyboard controls and starts the animator.
     */
    public JOCLSimpleGL()
    {
        // Initialize the GL component and the animator
        GLJPanel glComponent = new GLJPanel();
        glComponent.setFocusable(true);
        glComponent.addGLEventListener(this);

        // Initialize the mouse and keyboard controls
        MouseControl mouseControl = new MouseControl();
        glComponent.addMouseMotionListener(mouseControl);
        glComponent.addMouseWheelListener(mouseControl);
        KeyboardControl keyboardControl = new KeyboardControl();
        glComponent.addKeyListener(keyboardControl);

        // Create the main frame
        JFrame frame = new JFrame("JOCL / JOGL interaction sample");
        frame.addWindowListener(new WindowAdapter()
        {
            public void windowClosing(WindowEvent e)
            {
                runExit();
            }
        });
        frame.setLayout(new BorderLayout());
        glComponent.setPreferredSize(new Dimension(800, 800));
        frame.add(glComponent, BorderLayout.CENTER);
        frame.pack();
        frame.setVisible(true);
        glComponent.requestFocus();

        // Create and start the animator
        animator = new Animator(glComponent);
        animator.setRunAsFastAsPossible(true);
        animator.start();
    }

    /**
     * Implementation of GLEventListener: Called to initialize the
     * GLAutoDrawable. The GL state, view and text renderer are set up on
     * every call; the OpenCL objects and the VBO are only created once
     * (guarded by the 'initialized' flag). If any prerequisite is missing,
     * an exit is requested and initialization is aborted, so that
     * display() keeps returning early.
     */
    public void init(GLAutoDrawable drawable)
    {
        // Perform the default GL initialization
        GL gl = drawable.getGL();
        gl.setSwapInterval(0);
        gl.glEnable(GL_DEPTH_TEST);
        gl.glClearColor(0.0f, 0.0f, 0.0f, 1.0f);

        // Set up the view matrix
        setupView(drawable);

        // Create a TextRenderer for the status messages. This is done
        // before the 'initialized' check, because init() may be called
        // more than once.
        renderer = new TextRenderer(new Font("SansSerif", Font.PLAIN, 18));

        if (initialized)
        {
            return;
        }

        // The rendering requires the GL_ARB_vertex_buffer_object extension.
        // The dialog is shown in a separate thread so that this JOGL
        // callback is not blocked by it.
        if (!gl.isExtensionAvailable("GL_ARB_vertex_buffer_object"))
        {
            new Thread(new Runnable()
            {
                public void run()
                {
                    JOptionPane.showMessageDialog(null,
                        "GL_ARB_vertex_buffer_object extension not available",
                        "Unavailable extension", JOptionPane.ERROR_MESSAGE);
                    runExit();
                }
            }).start();
            return;
        }

        // Obtain the platform IDs and initialize the context properties
        cl_platform_id platforms[] = new cl_platform_id[1];
        clGetPlatformIDs(platforms.length, platforms, null);
        if (platforms[0] == null)
        {
            System.out.println("Unable to obtain an OpenCL platform");
            runExit();
            return;
        }
        cl_context_properties contextProperties = new cl_context_properties();
        contextProperties.addProperty(CL_CONTEXT_PLATFORM, platforms[0]);

        // Try to create the OpenCL context on a GPU device, and fall
        // back to a CPU device if that is not possible.
        context = clCreateContextFromType(
            contextProperties, CL_DEVICE_TYPE_GPU, null, null, null);
        if (context == null)
        {
            context = clCreateContextFromType(
                contextProperties, CL_DEVICE_TYPE_CPU, null, null, null);
        }
        if (context == null)
        {
            System.out.println("Unable to create a context");
            runExit();
            return;
        }

        // Enable exceptions and subsequently omit error checks in this sample
        setExceptionsEnabled(true);
        //setLogLevel(LogLevel.LOG_DEBUGTRACE);

        // Get the list of devices (GPU or CPU) associated with the context
        long numBytes[] = new long[1];
        clGetContextInfo(context, CL_CONTEXT_DEVICES, 0, null, numBytes);
        int numDevices = (int) numBytes[0] / Sizeof.cl_device_id;
        cl_device_id devices[] = new cl_device_id[numDevices];
        clGetContextInfo(context, CL_CONTEXT_DEVICES, numBytes[0],
            Pointer.to(devices), null);
        cl_device_id device = devices[0];

        // Create a command-queue
        commandQueue = clCreateCommandQueue(context, device, 0, null);

        // Program Setup
        String source = readFile("simpleGL.cl");
        if (source == null)
        {
            // readFile has already reported the error and requested exit
            return;
        }

        // Create the program
        cl_program cpProgram = clCreateProgramWithSource(context, 1,
            new String[]{ source }, null, null);

        // Build the program
        clBuildProgram(cpProgram, 0, null, "-cl-mad-enable", null, null);

        // Create the kernel
        kernel = clCreateKernel(cpProgram, "sine_wave", null);

        // Create VBO
        initVBO(gl);

        // Set the kernel arguments that do not change between frames.
        // Argument 3 (the animation state) is set in runJOCL.
        clSetKernelArg(kernel, 0, Sizeof.cl_mem, Pointer.to(vboMem));
        clSetKernelArg(kernel, 1, Sizeof.cl_uint,
            Pointer.to(new int[]{ meshWidth }));
        clSetKernelArg(kernel, 2, Sizeof.cl_uint,
            Pointer.to(new int[]{ meshHeight }));

        initialized = true;
    }

    /**
     * Helper function which reads the file with the given name and returns
     * the contents of this file as a String, with every line terminated
     * by a newline. If the file can not be read, the error is printed,
     * an exit is requested, and null is returned.
     *
     * @param fileName The name of the file to read.
     * @return The contents of the file, or null if it could not be read
     */
    private String readFile(String fileName)
    {
        BufferedReader br = null;
        try
        {
            br = new BufferedReader(
                new InputStreamReader(new FileInputStream(fileName)));
            StringBuilder sb = new StringBuilder();
            String line;
            while ((line = br.readLine()) != null)
            {
                sb.append(line).append("\n");
            }
            return sb.toString();
        }
        catch (IOException e)
        {
            e.printStackTrace();
            runExit();
            return null;
        }
        finally
        {
            // Always release the file handle, also when reading failed
            if (br != null)
            {
                try
                {
                    br.close();
                }
                catch (IOException ignored)
                {
                    // Nothing sensible to do if closing fails
                }
            }
        }
    }

    /**
     * Create the vertex buffer object (VBO) that stores the
     * vertex positions (4 floats per vertex), and the cl_mem that
     * the kernel writes into. Any previously created VBO and cl_mem
     * are released first.
     *
     * @param gl The GL context
     */
    private void initVBO(GL gl)
    {
        if (vertexBufferObject != 0)
        {
            gl.glDeleteBuffers(1, new int[]{ vertexBufferObject }, 0);
            vertexBufferObject = 0;
        }
        if (vboMem != null)
        {
            // Release the cl_mem belonging to the previous VBO
            clReleaseMemObject(vboMem);
            vboMem = null;
        }

        // Create the vertex buffer object
        int buffer[] = new int[1];
        gl.glGenBuffers(1, IntBuffer.wrap(buffer));
        vertexBufferObject = buffer[0];

        // Initialize the vertex buffer object
        gl.glBindBuffer(GL_ARRAY_BUFFER, vertexBufferObject);
        int size = meshWidth * meshHeight * 4 * Sizeof.cl_float;
        gl.glBufferData(GL_ARRAY_BUFFER, size, (Buffer) null,
            GL_DYNAMIC_DRAW);
        gl.glBindBuffer(GL_ARRAY_BUFFER, 0);

        if (GL_INTEROP)
        {
            // Create OpenCL buffer from GL VBO
            vboMem = clCreateFromGLBuffer(context, CL_MEM_WRITE_ONLY,
                vertexBufferObject, null);
        }
        else
        {
            // Separate CL buffer; its contents are copied into the VBO
            // via the host in runJOCL
            vboMem = clCreateBuffer(context, CL_MEM_WRITE_ONLY, size,
                null, null);
        }
    }

    /**
     * Set up a default view (viewport and perspective projection)
     * for the given GLAutoDrawable
     *
     * @param drawable The GLAutoDrawable to set the view for
     */
    private void setupView(GLAutoDrawable drawable)
    {
        GL gl = drawable.getGL();

        int width = drawable.getWidth();
        int height = drawable.getHeight();
        gl.glViewport(0, 0, width, height);

        gl.glMatrixMode(GL_PROJECTION);
        gl.glLoadIdentity();
        GLU glu = new GLU();
        // Avoid a division by zero (infinite/NaN aspect ratio) when the
        // drawable has a height of 0, e.g. while the window is collapsed
        float aspect = (float) width / Math.max(1, height);
        glu.gluPerspective(50.0, aspect, 0.1, 100.0);
    }

    /**
     * Implementation of GLEventListener: Called when the given GLAutoDrawable
     * is to be displayed. Computes the new vertex positions (with JOCL or
     * Java), renders them as points, and draws the status text.
     */
    public void display(GLAutoDrawable drawable)
    {
        if (!initialized)
        {
            return;
        }
        GL gl = drawable.getGL();

        if (useJOCL)
        {
            // Run the JOCL kernel to generate new vertex positions.
            runJOCL(gl);
        }
        else
        {
            // Run the Java method to generate new vertex positions.
            runJava(gl);
        }

        gl.glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);

        // Set up the modelview matrix according to the current
        // translation and rotation.
        gl.glMatrixMode(GL_MODELVIEW);
        gl.glLoadIdentity();
        gl.glTranslatef(translationX, translationY, translationZ);
        gl.glRotatef(rotationX, 1.0f, 0.0f, 0.0f);
        gl.glRotatef(rotationY, 0.0f, 1.0f, 0.0f);

        // Render the VBO
        gl.glBindBuffer(GL_ARRAY_BUFFER, vertexBufferObject);
        gl.glVertexPointer(4, GL_FLOAT, 0, 0);
        gl.glEnableClientState(GL_VERTEX_ARRAY);
        gl.glColor3f(1.0f, 0.0f, 0.0f);
        gl.glDrawArrays(GL_POINTS, 0, meshWidth * meshHeight);
        gl.glBindBuffer(GL_ARRAY_BUFFER, 0);
        gl.glDisableClientState(GL_VERTEX_ARRAY);

        // Compute FPS. There is no meaningful value for the very first
        // frame, and a zero frame time would yield an infinite value.
        long nanoTime = System.nanoTime();
        String fpsString = "n/a";
        if (prevFrameNanoTime != 0 && nanoTime > prevFrameNanoTime)
        {
            double frameTimeMs = (nanoTime - prevFrameNanoTime) / 1000000.0;
            fpsString = String.format("%.2f", 1000.0 / frameTimeMs);
        }
        prevFrameNanoTime = nanoTime;

        // Print status message
        renderer.beginRendering(drawable.getWidth(), drawable.getHeight());
        renderer.setColor(1.0f, 1.0f, 1.0f, 0.5f);
        renderer.draw("[t] Toggle JOCL/Java mode", 20, 30);
        if (useJOCL)
        {
            renderer.draw("Current mode: JOCL, " + fpsString + " fps", 20, 10);
        }
        else
        {
            renderer.draw("Current mode: Java, " + fpsString + " fps", 20, 10);
        }
        renderer.endRendering();

        animationState += 0.01f;
    }

    /**
     * Run the JOCL computation to create new vertex positions
     * inside the vertexBufferObject.
     *
     * @param gl The current GL
     * @throws RuntimeException if the VBO can not be mapped (host copy mode)
     */
    private void runJOCL(GL gl)
    {
        if (GL_INTEROP)
        {
            // Make sure GL is done with the buffer, then acquire it
            // for writing from OpenCL
            gl.glFinish();
            clEnqueueAcquireGLObjects(commandQueue, 1,
                new cl_mem[]{ vboMem }, 0, null, null);
        }

        // Set work size and execute the kernel
        long globalWorkSize[] = new long[]{ meshWidth, meshHeight };

        clSetKernelArg(kernel, 3, Sizeof.cl_float,
            Pointer.to(new float[]{ animationState }));
        clEnqueueNDRangeKernel(commandQueue, kernel, 2, null,
            globalWorkSize, null, 0, null, null);

        if (GL_INTEROP)
        {
            // Hand the buffer back to OpenGL
            clEnqueueReleaseGLObjects(commandQueue, 1,
                new cl_mem[]{ vboMem }, 0, null, null);
            clFinish(commandQueue);
        }
        else
        {
            // Explicit Copy:
            // Map the VBO to copy data from the CL buffer via the host
            gl.glBindBuffer(GL_ARRAY_BUFFER, vertexBufferObject);

            // Map the buffer object into client's memory
            ByteBuffer pointer = gl.glMapBuffer(GL_ARRAY_BUFFER,
                GL_WRITE_ONLY);
            if (pointer == null)
            {
                gl.glBindBuffer(GL_ARRAY_BUFFER, 0);
                throw new RuntimeException("Unable to map buffer");
            }

            try
            {
                // Blocking read (CL_TRUE), so the data is complete
                // before the buffer is unmapped
                clEnqueueReadBuffer(commandQueue, vboMem, CL_TRUE, 0,
                    Sizeof.cl_float * 4 * meshHeight * meshWidth,
                    Pointer.to(pointer), 0, null, null);
            }
            finally
            {
                // Never leave the VBO mapped, even if the read failed
                gl.glUnmapBuffer(GL_ARRAY_BUFFER);
                gl.glBindBuffer(GL_ARRAY_BUFFER, 0);
            }
        }
    }

    /**
     * Run the Java computation to create new vertex positions
     * inside the vertexBufferObject. Each vertex is (u, w, v, 1), where
     * u and v are the grid coordinates mapped to [-1, 1) and w is a
     * sine/cosine wave depending on the current animation state.
     *
     * @param gl The current GL.
     * @throws RuntimeException if the VBO can not be mapped
     */
    private void runJava(GL gl)
    {
        gl.glBindBuffer(GL_ARRAY_BUFFER, vertexBufferObject);
        ByteBuffer byteBuffer = gl.glMapBuffer(GL_ARRAY_BUFFER,
            GL_READ_WRITE);
        if (byteBuffer == null)
        {
            gl.glBindBuffer(GL_ARRAY_BUFFER, 0);
            throw new RuntimeException("Unable to map buffer");
        }
        try
        {
            FloatBuffer vertices = byteBuffer.order(
                ByteOrder.nativeOrder()).asFloatBuffer();
            final float freq = 4.0f;

            // Iterate row by row, so that the vertices are written
            // sequentially into the buffer (index grows with x fastest)
            for (int y = 0; y < meshHeight; y++)
            {
                // Calculate v coordinate, mapped to [-1, 1)
                float v = (y / (float) meshHeight) * 2.0f - 1.0f;
                for (int x = 0; x < meshWidth; x++)
                {
                    // Calculate u coordinate, mapped to [-1, 1)
                    float u = (x / (float) meshWidth) * 2.0f - 1.0f;

                    // Calculate simple sine wave pattern
                    float w = (float) Math.sin(u * freq + animationState) *
                              (float) Math.cos(v * freq + animationState) * 0.5f;

                    // Write output vertex
                    int index = 4 * (y * meshWidth + x);
                    vertices.put(index + 0, u);
                    vertices.put(index + 1, w);
                    vertices.put(index + 2, v);
                    vertices.put(index + 3, 1);
                }
            }
        }
        finally
        {
            gl.glUnmapBuffer(GL_ARRAY_BUFFER);
            gl.glBindBuffer(GL_ARRAY_BUFFER, 0);
        }
    }

    /**
     * Implementation of GLEventListener: Called when the
     * GLAutoDrawable was reshaped
     */
    public void reshape(GLAutoDrawable drawable, int x, int y,
                    int width, int height)
    {
        setupView(drawable);
    }

    /**
     * Implementation of GLEventListener - not used
     */
    public void displayChanged(GLAutoDrawable drawable, boolean modeChanged,
                    boolean deviceChanged)
    {}

    /**
     * Stops the animator and calls System.exit() in a new Thread.
     * (System.exit() may not be called synchronously inside one
     * of the JOGL callbacks)
     */
    private void runExit()
    {
        new Thread(new Runnable()
        {
            public void run()
            {
                // Defensive: the animator is only assigned at the end
                // of the constructor
                if (animator != null)
                {
                    animator.stop();
                }
                System.exit(0);
            }
        }).start();
    }
}
