#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>

#include <CL/cl.h>


/*
 * Parse a base-10 integer command-line argument into *out.
 * Returns 1 on success; returns 0 (leaving *out untouched) when text is
 * empty, has trailing characters, or does not fit in an int.
 */
static int parse_int_arg(const char * text, int * out)
{
   char * end;
   long value;

   errno = 0;
   value = strtol(text, &end, 10);
   if (end == text || *end != '\0' || errno == ERANGE
       || value < INT_MIN || value > INT_MAX) {
      return 0;
   }
   *out = (int)value;
   return 1;
}

/*
 * Run a single-work-item OpenCL kernel and check its output.
 *
 * Usage: prog <kernel_name> <arg0> <arg1> <expected>
 *
 * The kernel named kernel_name is built via clSimpleCreateKernel() and
 * invoked with (out_buffer, arg0, arg1). The single int it writes to
 * out_buffer is read back and compared to <expected>.
 *
 * Returns EXIT_SUCCESS when the kernel output matches, EXIT_FAILURE on a
 * mismatch, bad arguments, or any OpenCL error. All OpenCL objects created
 * here are released before returning.
 */
int main(int argc, char ** argv)
{
   const char * kernel_name;
   int arg0, arg1, expected;
   int status = EXIT_FAILURE;
   cl_int error;
   cl_device_id device_id;

   /* NULL-initialised so the cleanup path only releases what was created. */
   cl_context context = NULL;
   cl_command_queue command_queue = NULL;
   cl_kernel kernel = NULL;
   cl_mem out_buffer = NULL;

   int out_value = 0;
   size_t global_work_size = 1;

   /* Parse command line args; argv[1..4] are only valid when argc >= 5. */
   if (argc < 5) {
      fprintf(stderr, "Usage: %s <kernel_name> <arg0> <arg1> <expected>\n",
                      (argc > 0 && argv[0]) ? argv[0] : "program");
      return EXIT_FAILURE;
   }

   kernel_name = argv[1];
   if (   !parse_int_arg(argv[2], &arg0)
       || !parse_int_arg(argv[3], &arg1)
       || !parse_int_arg(argv[4], &expected)) {
      fprintf(stderr, "arg0, arg1 and expected must be decimal integers.\n");
      return EXIT_FAILURE;
   }

   if (!clSimpleInitGpuDevice(&device_id)) {
      return EXIT_FAILURE;
   }

   context = clCreateContext(NULL, /* Properties */
                           1, /* Number of devices */
                           &device_id, /* Device pointer */
                           NULL, /* Callback for reporting errors */
                           NULL, /* User data to pass to error callback */
                           &error); /* Error code */

   if (error != CL_SUCCESS) {
      fprintf(stderr, "clCreateContext() failed: %s\n", clUtilErrorString(error));
      goto cleanup;
   }

   fprintf(stderr, "clCreateContext() succeeded.\n");

   command_queue = clCreateCommandQueue(context,
                                        device_id,
                                        0, /* Command queue properties */
                                        &error); /* Error code */

   if (error != CL_SUCCESS) {
      fprintf(stderr, "clCreateCommandQueue() failed: %s\n",
                      clUtilErrorString(error));
      goto cleanup;
   }

   fprintf(stderr, "clCreateCommandQueue() succeeded.\n");

   if (!clSimpleCreateKernel(context, device_id, &kernel, kernel_name)) {
      goto cleanup;
   }

   out_buffer = clCreateBuffer(context,
                               CL_MEM_WRITE_ONLY, /* Flags */
                               sizeof(int), /* Size of buffer */
                               NULL, /* Pointer to the data */
                               &error); /* error code */

   if (error != CL_SUCCESS) {
      fprintf(stderr, "clCreateBuffer() failed: %s\n", clUtilErrorString(error));
      goto cleanup;
   }

   fprintf(stderr, "clCreateBuffer() succeeded.\n");

   if (   !clSimpleKernelSetArg(kernel, 0, sizeof(cl_mem), &out_buffer)
       || !clSimpleKernelSetArg(kernel, 1, sizeof(int), &arg0)
       || !clSimpleKernelSetArg(kernel, 2, sizeof(int), &arg1)) {
      goto cleanup;
   }

   error = clEnqueueNDRangeKernel(command_queue,
                                  kernel,
                                  1, /* Number of dimensions */
                                  NULL, /* Global work offset */
                                  &global_work_size,
                                  &global_work_size, /* local work size */
                                  0, /* Events in wait list */
                                  NULL, /* Wait list */
                                  NULL); /* Event object for this event */

   if (error != CL_SUCCESS) {
      fprintf(stderr, "clEnqueueNDRangeKernel() failed: %s\n",
                      clUtilErrorString(error));
      goto cleanup;
   }

   fprintf(stderr, "clEnqueueNDRangeKernel() succeeded.\n");

   /*
    * The queue was created without out-of-order execution, so this blocking
    * read runs after the kernel above has completed.
    */
   error = clEnqueueReadBuffer(command_queue,
                                out_buffer,
                                CL_TRUE, /* TRUE means it is a blocking read. */
                                0, /* Buffer offset to read from. */
                                sizeof(int), /* Bytes to read */
                                &out_value, /* Pointer to store the data */
                                0, /* Events in wait list */
                                NULL, /* Wait list */
                                NULL); /* Event object */

   if (error != CL_SUCCESS) {
      fprintf(stderr, "clEnqueueReadBuffer() failed: %s\n",
                      clUtilErrorString(error));
      goto cleanup;
   }

   fprintf(stderr, "clEnqueueReadBuffer() succeeded.\n");

   if (out_value == expected) {
      fprintf(stderr, "Pass\n");
      status = EXIT_SUCCESS;
   } else {
      fprintf(stderr, "Expected %d, but got %d\n", expected, out_value);
   }

cleanup:
   /* Release in reverse order of creation; skip objects never created. */
   if (out_buffer) {
      clReleaseMemObject(out_buffer);
   }
   if (kernel) {
      clReleaseKernel(kernel);
   }
   if (command_queue) {
      clReleaseCommandQueue(command_queue);
   }
   if (context) {
      clReleaseContext(context);
   }
   return status;
}