#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <sys/errno.h>
#include <CL/cl.h>

/* Upper bound, in bytes, on how much of kernel.cl main() reads into memory
 * (a larger file is silently truncated to this size). */
#define MAX_SOURCE_SIZE (0x100000)
/* Capacity, in bytes, of the string payload in each fixed-size record
 * (int length prefix + STRING_MAX bytes) that is copied to the device. */
#define STRING_MAX      200

/*
 * Reads a line-oriented text file and returns the first field of every line.
 *
 * Each line is cut at its first space, or, if it has none, at its first comma;
 * only the text before that separator is kept and any further fields are
 * ignored. A trailing newline is stripped and empty lines yield "" entries.
 * Lines longer than the 1023-byte read buffer are split into several entries.
 *
 * fname: path of the file to read.
 * olsp:  on success receives a malloc'd array of malloc'd strings (NULL when
 *        the file has no lines); the caller frees every string and the array.
 *        Left untouched on failure.
 *
 * Returns the number of entries stored, or a negative errno value on failure
 * (open/read error, or -ENOMEM when an allocation fails). No memory is leaked
 * and the file is always closed.
 */
int csvRead(char *fname, char ***olsp)
{
	char buf[1024];
	char **ols = NULL;
	int olsz = 0;
	int olsnum = 0;
	int err = 0;

	FILE *fp = fopen(fname, "r");
	if (fp == NULL)
		return errno ? -errno : -EIO;

	while (fgets(buf, sizeof(buf), fp) != NULL) {
		size_t len = strlen(buf);

		/* len is 0 when the line starts with an embedded NUL byte, so it
		 * must be checked before looking at buf[len - 1]. */
		if (len > 0 && buf[len - 1] == '\n')
			buf[--len] = '\0';

		/* Keep only the first field: split on a space, else on a comma. */
		char *sep = memchr(buf, ' ', len);
		if (sep == NULL)
			sep = memchr(buf, ',', len);
		if (sep != NULL) {
			*sep = '\0';
			len = (size_t)(sep - buf);
		}

		if (olsnum >= olsz) {
			int newsz = olsz + 1000;
			/* Grow through a temporary so the old array survives a failure. */
			char **tmp = realloc(ols, sizeof(*tmp) * (size_t)newsz);
			if (tmp == NULL) {
				err = ENOMEM;
				goto fail;
			}
			ols = tmp;
			olsz = newsz;
		}

		char *o = malloc(len + 1);
		if (o == NULL) {
			err = ENOMEM;
			goto fail;
		}
		memcpy(o, buf, len + 1);
		ols[olsnum++] = o;
	}

	if (ferror(fp)) {
		/* Capture errno before fclose() can overwrite it. */
		err = errno ? errno : EIO;
		goto fail;
	}

	fclose(fp);
	*olsp = ols;
	return olsnum;

fail:
	fclose(fp);
	for (int k = 0; k < olsnum; k++)
		free(ols[k]);
	free(ols);
	return -err;
}

int main(int argc, char **argv) 
{
    // Create the inputs
	int i,j,sz;
	int dnum, dsnum, olnum;
	char **ds, **ols;

	dnum = csvRead(argv[1], &ds);
	olnum = csvRead(argv[2], &ols);

	i = dnum / 64;
	if (dnum % 64)
		i++;

	dsnum = i * 64;

	printf("dataset %d reads %d\n", dnum, olnum);
	char *A = (char*)malloc(dsnum*(STRING_MAX + sizeof(int)));
	char *B = (char*)malloc(olnum*(STRING_MAX + sizeof(int)));
	char *results = (char*)malloc(dsnum * sizeof(int)); 
	char temp[STRING_MAX + sizeof(int)];
	double exec_time = 0.0;

	srand(1234);

	memset(A,dsnum*(STRING_MAX + sizeof(int)), 0);
	memset(B,olnum*(STRING_MAX + sizeof(int)), 0);

	// Create the set A, SET_SIZE pascal-type strings up to STRING_MAX length
	for(i = 0; i < dsnum; i++) {
		if (i < dnum)
			sz = strlen(ds[i]); //  1 + rand() % (STRING_MAX-1);
		else
			sz = 0;

		memcpy((char*)(A + i * (STRING_MAX + sizeof(int))), &sz, sizeof(int));
		if (sz != 0)
			memcpy((char *)(A + i * (STRING_MAX + sizeof(int)) + sizeof(int)), ds[i], sz);
	}

	// Create the set B, single pascal-type string
	for(i = 0; i < olnum; i++) {
		sz = strlen(ols[i]);
		memcpy((char*)(B + i * (STRING_MAX + sizeof(int))), &sz, sizeof(int));
		memcpy((char*)(B + i * (STRING_MAX + sizeof(int)) + sizeof(int)), ols[i], sz);
	}

    // Load the kernel 
    FILE *fp;
    char *source_str;
    size_t source_size;
 
    fp = fopen("kernel.cl", "r");
    if (!fp) {
        fprintf(stderr, "Failed to load kernel.\n");
        exit(1);
    }
    source_str = (char*)malloc(MAX_SOURCE_SIZE);
    source_size = fread( source_str, 1, MAX_SOURCE_SIZE, fp);
    fclose( fp );
 
    // OpenCL stuff initialization
    cl_platform_id platform_id = NULL;
    cl_device_id device_id = NULL;   
    cl_uint ret_num_devices;
    cl_uint ret_num_platforms;
    cl_int ret = clGetPlatformIDs(1, &platform_id, &ret_num_platforms);
    ret = clGetDeviceIDs( platform_id, CL_DEVICE_TYPE_DEFAULT, 1, &device_id, &ret_num_devices);
 
    cl_context context = clCreateContext( NULL, 1, &device_id, NULL, NULL, &ret);
    cl_command_queue command_queue = clCreateCommandQueue(context, device_id, 0, &ret);
 
    // Create the buffers and copy to device
    cl_mem set_a = clCreateBuffer(context, CL_MEM_READ_ONLY, dsnum*(STRING_MAX + sizeof(int)), NULL, &ret);
    cl_mem set_b = clCreateBuffer(context, CL_MEM_READ_ONLY, olnum*(STRING_MAX + sizeof(int)), NULL, &ret);
    cl_mem temp_set = clCreateBuffer(context, CL_MEM_READ_WRITE, dsnum * sizeof(int), NULL, &ret);
    cl_mem results_set = clCreateBuffer(context, CL_MEM_WRITE_ONLY, olnum * sizeof(int) * 2, NULL, &ret);
 
    // Copy the lists A and B to their respective memory buffers
    ret = clEnqueueWriteBuffer(command_queue, set_a, CL_TRUE, 0, dsnum*(STRING_MAX + sizeof(int)), A, 0, NULL, NULL);
    ret = clEnqueueWriteBuffer(command_queue, set_b, CL_TRUE, 0, olnum*(STRING_MAX + sizeof(int)), B, 0, NULL, NULL);
 
    // Compile
    cl_program program = clCreateProgramWithSource(context, 1, (const char **)&source_str, (const size_t *)&source_size, &ret);
    ret = clBuildProgram(program, 1, &device_id, NULL, NULL, NULL);
    if (ret != 0)
    {
        size_t log_size;
        clGetProgramBuildInfo(program, device_id, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
        char *log = (char *) malloc(log_size);
        clGetProgramBuildInfo(program, device_id, CL_PROGRAM_BUILD_LOG, log_size, log, NULL);
        printf("%s\n", log);
    }
    cl_kernel kernel_l = clCreateKernel(program, "levenshtein", &ret);
    cl_kernel kernel_r = clCreateKernel(program, "reduce", &ret);
 
    ret = clSetKernelArg(kernel_l, 0, sizeof(cl_mem), (void *)&set_a);
    ret = clSetKernelArg(kernel_l, 1, sizeof(cl_mem), (void *)&set_b);
    ret = clSetKernelArg(kernel_l, 2, sizeof(cl_mem), (void *)&temp_set);

    ret = clSetKernelArg(kernel_r, 0, sizeof(cl_mem), (void *)&temp_set);
    ret = clSetKernelArg(kernel_r, 1, sizeof(cl_mem), (void *)&results_set);
    i = dsnum;
    ret = clSetKernelArg(kernel_r, 3, sizeof(cl_uint), (void *)&i);

    size_t global_item_size = dsnum;
    size_t global_item2_size = 64;
    size_t local_item_size = 64; 

    unsigned int *C = (int*)malloc(sizeof(int)*olnum*2);
    clock_t start = clock();

    for (i=0;i < olnum; i++)
    {
        // Execute
        ret = clSetKernelArg(kernel_l, 3, sizeof(cl_uint), (void *)&i);
        ret = clEnqueueNDRangeKernel(command_queue, kernel_l, 1, NULL, &global_item_size, &local_item_size, 0, NULL, NULL);
        ret = clSetKernelArg(kernel_r, 2, sizeof(cl_uint), (void *)&i);
        ret = clEnqueueNDRangeKernel(command_queue, kernel_r, 1, NULL, &global_item2_size, &local_item_size, 0, NULL, NULL);
        ret = clFinish(command_queue);
    }

    ret = clEnqueueReadBuffer(command_queue, results_set, CL_TRUE, 0, olnum * sizeof(int) * 2, C, 0, NULL, NULL);
    clock_t end = clock();
    exec_time += (double)(end - start) / CLOCKS_PER_SEC;
    printf("execution time = %f seconds\n", exec_time);

    // Print best
    for (i = 0; i < olnum; i++)
        printf("%d = %u (index %u) %s\n", i, C[i*2], C[i*2+1], ols[i]);

    ret = clFlush(command_queue);
    ret = clFinish(command_queue);
    ret = clReleaseKernel(kernel_l);
    ret = clReleaseKernel(kernel_r);
    ret = clReleaseProgram(program);
    ret = clReleaseMemObject(set_a);
    ret = clReleaseMemObject(set_b);
    ret = clReleaseMemObject(results_set);
    ret = clReleaseMemObject(temp_set);
    ret = clReleaseCommandQueue(command_queue);
    ret = clReleaseContext(context);
    free(A);
    free(B);
    free(C);
    return 0;
}

