/*
 * JET - JOCL Event Tracer
 * 
 * Copyright 2010 Marco Hutter - http://www.jocl.org/
 */

package org.jocl.jet.samples;

import static org.jocl.CL.*;

import java.nio.*;

import org.jocl.*;
import org.jocl.jet.JET;

/**
 * A simple example of how to use the JET, the JOCL Event Tracer.<br />
 * <br />
 * This example adds some large arrays using an OpenCL kernel. 
 * It will...
 * <ul>
 * <li>allocate five large memory objects: the input float vectors 
 * A, B and C, and two output vectors for the sums</li>
 * <li>enqueue commands to write host data into the input vectors</li> 
 * <li>enqueue kernels which compute the sums A+B and B+C</li>
 * <li>enqueue commands to read the sums back to the host</li>
 * </ul>
 * All events in this example will be sent to the JET, which
 * will display the events graphically.<br />
 * <br />
 * <b>NOTE:</b> The program will allocate <b>large</b> arrays
 * (five direct buffers of 10,000,000 floats each, about 200 MB in
 * total), so it should be started with the <i>-Xmx256m</i> parameter.  
 */
public class JETSimpleSample
{
    /**
     * The source code of the kernel that will be executed.
     * It adds two large arrays and writes the result into a third
     * array. Each work-item processes <code>stride</code> consecutive
     * elements, starting at index <code>get_global_id(0) * stride</code>.
     * The kernel is mainly intended to take some time...
     */
    private static final String programSource =
        "__kernel void test(" +
        "     __global const float *a,"+
        "     __global const float *b, " +
        "     __global float *c,"+
        "     int stride)"+
        "{"+
        "    int gid = get_global_id(0);"+
        "    for (int i=0; i<stride; i++)"+
        "    {"+
        "        c[gid*stride+i]=a[gid*stride+i]+b[gid*stride+i];"+
        "    }"+
        "}";
    
    /**
     * Entry point of this sample.
     * 
     * @param args Not used
     */
    public static void main(String args[])
    {
        // Obtain the platform IDs and initialize the context properties
        System.out.println("Creating context...");
        cl_platform_id platforms[] = new cl_platform_id[1];
        clGetPlatformIDs(platforms.length, platforms, null);
        if (platforms[0] == null)
        {
            // Exceptions are not enabled yet, so a failing call does
            // not throw here; the array entry is then presumably left
            // as null and must not be used for the context properties.
            System.out.println("Unable to obtain an OpenCL platform");
            System.exit(1);
            return;
        }
        cl_context_properties contextProperties = new cl_context_properties();
        contextProperties.addProperty(CL_CONTEXT_PLATFORM, platforms[0]);

        // Try to create an OpenCL context on a GPU device
        cl_context context = clCreateContextFromType(
            contextProperties, CL_DEVICE_TYPE_GPU, null, null, null);
        if (context == null)
        {
            // If no context for a GPU device could be created,
            // try to create one for a CPU device.
            System.out.println("Unable to create a GPU context, using CPU...");
            context = clCreateContextFromType(
                contextProperties, CL_DEVICE_TYPE_CPU, null, null, null);

            if (context == null)
            {
                System.out.println("Unable to create a context");
                System.exit(1);
                return;
            }
        }
        
        // Exceptions are enabled only now, so that the GPU/CPU fallback
        // above can rely on null return values instead of exceptions.
        CL.setExceptionsEnabled(true);
        
        // Get the list of devices associated with the context
        System.out.println("Initializing device...");
        long numBytes[] = new long[1];
        clGetContextInfo(context, CL_CONTEXT_DEVICES, 0, null, numBytes);
        int numDevices = (int) numBytes[0] / Sizeof.cl_device_id; 
        cl_device_id devices[] = new cl_device_id[numDevices];
        clGetContextInfo(context, CL_CONTEXT_DEVICES, numBytes[0], 
            Pointer.to(devices), null);

        // Create a command-queue. The properties depend on what the
        // device supports (see queuePropertiesFor).
        System.out.println("Creating command queue...");
        long properties = queuePropertiesFor(devices[0]);
        cl_command_queue commandQueue = 
            clCreateCommandQueue(context, devices[0], properties, null);
        
        // Create the program
        System.out.println("Creating program...");
        cl_program program = clCreateProgramWithSource(context, 
            1, new String[]{ programSource }, null, null);
        
        // Build the program
        System.out.println("Building program...");
        clBuildProgram(program, 0, null, null, null, null);
        
        // Create the kernel
        System.out.println("Creating kernel...");
        cl_kernel kernel = clCreateKernel(program, "test", null);


        // Initialize the input data
        System.out.println("Initializing input data...");
        int globalSize = 1000;
        int stride = 10000;
        int n = globalSize * stride;
        
        // The input buffers are filled with absolute puts, so their
        // position remains 0. Depending on the JOCL version,
        // Pointer.to(Buffer) takes the buffer position into account,
        // and a position at the end of the buffer would make the
        // writes below read beyond the host data.
        FloatBuffer srcBufferA = createFilledFloatBuffer(n, 1.0f);
        FloatBuffer srcBufferB = createFilledFloatBuffer(n, 2.0f);
        FloatBuffer srcBufferC = createFilledFloatBuffer(n, 3.0f);
        FloatBuffer dstBufferAB = createFloatBuffer(n);
        FloatBuffer dstBufferBC = createFloatBuffer(n);
        Pointer srcA = Pointer.to(srcBufferA);
        Pointer srcB = Pointer.to(srcBufferB);
        Pointer srcC = Pointer.to(srcBufferC);
        Pointer dstAB = Pointer.to(dstBufferAB);
        Pointer dstBC = Pointer.to(dstBufferBC);
        
        // Allocate the buffer memory objects
        System.out.println("Initializing buffers...");
        cl_mem srcMemA = createFloatMem(context, n);
        cl_mem srcMemB = createFloatMem(context, n);
        cl_mem srcMemC = createFloatMem(context, n);
        cl_mem dstMemAB = createFloatMem(context, n);
        cl_mem dstMemBC = createFloatMem(context, n);
        
        // Initialize the JET. This will avoid delays on the calling 
        // thread which might otherwise occur when the first event 
        // is traced. 
        JET.init(devices[0], commandQueue);

        // Write input data, and trace all events
        System.out.println("Enqueueing input write...");
        
        cl_event writeEventA = new cl_event();
        clEnqueueWriteBuffer(commandQueue, srcMemA, false, 0, 
            n * Sizeof.cl_float, srcA, 0, null, writeEventA);
        JET.traceEvent(writeEventA, "writeA");

        cl_event writeEventB = new cl_event();
        clEnqueueWriteBuffer(commandQueue, srcMemB, false, 0, 
            n * Sizeof.cl_float, srcB, 0, null, writeEventB);
        JET.traceEvent(writeEventB, "writeB");

        cl_event writeEventC = new cl_event();
        clEnqueueWriteBuffer(commandQueue, srcMemC, false, 0, 
            n * Sizeof.cl_float, srcC, 0, null, writeEventC);
        JET.traceEvent(writeEventC, "writeC");
        
        // Enqueue the kernels which will compute the sums A+B and 
        // B+C, and trace all events. The kernel executions will
        // wait for the events that indicate that the necessary
        // input writes are finished. These explicit dependencies
        // keep the results correct even when the queue executes
        // commands out of order.
        System.out.println("Enqueueing kernel...");

        long globalWorkSize[] = new long[]{globalSize};
        long localWorkSize[] = new long[]{1};
        
        cl_event kernelEventAB = new cl_event();
        clSetKernelArg(kernel, 0, Sizeof.cl_mem, Pointer.to(srcMemA));
        clSetKernelArg(kernel, 1, Sizeof.cl_mem, Pointer.to(srcMemB));
        clSetKernelArg(kernel, 2, Sizeof.cl_mem, Pointer.to(dstMemAB));
        clSetKernelArg(kernel, 3, Sizeof.cl_int, Pointer.to(new int[]{stride}));
        
        cl_event waitListAB[] = new cl_event[] { writeEventA, writeEventB };
        clEnqueueNDRangeKernel(commandQueue, kernel, 1, null,
            globalWorkSize, localWorkSize, waitListAB.length, 
            waitListAB, kernelEventAB);

        JET.traceEvent(kernelEventAB, "kernelAB", waitListAB);

        // The same kernel object is reused: its arguments are set
        // again after the first kernel has been enqueued.
        cl_event kernelEventBC = new cl_event();
        clSetKernelArg(kernel, 0, Sizeof.cl_mem, Pointer.to(srcMemB));
        clSetKernelArg(kernel, 1, Sizeof.cl_mem, Pointer.to(srcMemC));
        clSetKernelArg(kernel, 2, Sizeof.cl_mem, Pointer.to(dstMemBC));
        clSetKernelArg(kernel, 3, Sizeof.cl_int, Pointer.to(new int[]{stride}));
        
        cl_event waitListBC[] = new cl_event[] { writeEventB, writeEventC };
        clEnqueueNDRangeKernel(commandQueue, kernel, 1, null,
            globalWorkSize, localWorkSize, waitListBC.length, 
            waitListBC, kernelEventBC);

        JET.traceEvent(kernelEventBC, "kernelBC", waitListBC);
        
        // Read output data. These reads will wait for the respective
        // kernels to be finished.
        System.out.println("Enqueueing output read...");

        cl_event readWaitListAB[] = new cl_event[] { kernelEventAB };
        cl_event readEventAB = new cl_event();
        clEnqueueReadBuffer(commandQueue, dstMemAB, CL_FALSE, 0,
            n * Sizeof.cl_float, dstAB, readWaitListAB.length, 
            readWaitListAB, readEventAB);
        JET.traceEvent(readEventAB, "readAB", readWaitListAB);

        cl_event readWaitListBC[] = new cl_event[] { kernelEventBC };
        cl_event readEventBC = new cl_event();
        clEnqueueReadBuffer(commandQueue, dstMemBC, CL_FALSE, 0,
            n * Sizeof.cl_float, dstBC, readWaitListBC.length, 
            readWaitListBC, readEventBC);
        JET.traceEvent(readEventBC, "readBC", readWaitListBC);

        // Block until all enqueued commands (including the non-blocking
        // reads into the host buffers) have completed.
        clFinish(commandQueue);
        
        // Clean up.
        // NOTE(review): the cl_event objects are never released. Whether
        // the JET still needs them after clFinish cannot be determined
        // from this file, so they are deliberately left alone here.
        clReleaseMemObject(srcMemA);
        clReleaseMemObject(srcMemB);
        clReleaseMemObject(srcMemC);
        clReleaseMemObject(dstMemAB);
        clReleaseMemObject(dstMemBC);
        clReleaseKernel(kernel);
        clReleaseProgram(program);
        clReleaseCommandQueue(commandQueue);
        clReleaseContext(context);
        
        // Print the result
        System.out.println("Result: ");
        System.out.println(stringFor(dstBufferAB, 10));
        System.out.println(stringFor(dstBufferBC, 10));
    }
    
    /**
     * Returns the command queue properties to use for the given device.
     * Profiling is always requested; the OpenCL specification mandates
     * it as the minimum supported queue property. Out-of-order execution
     * is only requested when the device reports support for it, because
     * clCreateCommandQueue fails with CL_INVALID_QUEUE_PROPERTIES for
     * valid properties that the device does not support.
     * 
     * @param device The device for which the queue will be created
     * @return The command queue properties
     */
    private static long queuePropertiesFor(cl_device_id device)
    {
        // cl_command_queue_properties is a 64 bit bitfield
        long supported[] = new long[1];
        clGetDeviceInfo(device, CL_DEVICE_QUEUE_PROPERTIES, 
            Sizeof.cl_long, Pointer.to(supported), null);
        
        long properties = CL_QUEUE_PROFILING_ENABLE;
        if ((supported[0] & CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE) != 0)
        {
            properties |= CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE;
        }
        else
        {
            System.out.println(
                "Out-of-order execution is not supported, " +
                "using an in-order command queue...");
        }
        return properties;
    }
    
    /**
     * Creates a read/write memory object in the given context that
     * is large enough to hold the given number of float values.
     * 
     * @param context The context
     * @param n The number of float values
     * @return The memory object
     */
    private static cl_mem createFloatMem(cl_context context, int n)
    {
        return clCreateBuffer(context, 
            CL_MEM_READ_WRITE, (long) Sizeof.cl_float * n, null, null);
    }
    
    /**
     * Returns a string containing the first entries of the
     * given buffer, separated by commas. If the buffer contains
     * more entries than are shown, " ..." is appended. The
     * position of the buffer is not modified.
     * 
     * @param b The buffer
     * @param max The maximum number of elements to show 
     * @return The string
     */
    private static String stringFor(FloatBuffer b, int max)
    {
        StringBuilder sb = new StringBuilder();
        max = Math.min(max, b.capacity());
        for (int i=0; i<max; i++)
        {
            sb.append(b.get(i));
            if (i < max-1)
            {
                sb.append(", ");
            }
            else if (b.capacity() > max)
            {
                sb.append(" ...");
            }
        }
        return sb.toString();
    }

    /**
     * Creates a direct FloatBuffer with the given size and the
     * native byte order.
     * 
     * @param n The size, in number of float elements
     * @return The FloatBuffer
     */
    private static FloatBuffer createFloatBuffer(int n)
    {
        ByteBuffer byteBuffer = ByteBuffer.allocateDirect(n * Sizeof.cl_float);
        return byteBuffer.order(ByteOrder.nativeOrder()).asFloatBuffer();        
    }
    
    /**
     * Creates a direct FloatBuffer with the given size and the native
     * byte order, where every element is set to the given value. Only
     * absolute puts are used, so the position of the returned buffer
     * is 0.
     * 
     * @param n The size, in number of float elements
     * @param value The value for all elements
     * @return The FloatBuffer
     */
    private static FloatBuffer createFilledFloatBuffer(int n, float value)
    {
        FloatBuffer buffer = createFloatBuffer(n);
        for (int i=0; i<n; i++)
        {
            buffer.put(i, value);
        }
        return buffer;
    }
    
}

