/*
This was written in a few minutes.
It's purely experimental, probably buggy,
not verified, not heavily tested, not reviewed, not constant-time.
*/

#include <immintrin.h>
#include <string.h>
#include "djbsort.h"
#include "radixwrapper.h"

#define CUTOFF 8192 /* use djbsort for n <= CUTOFF */
#define MAXSPLIT 1024 /* number of radix buckets allocated (could also use VLA) */
#define TARGET 50 /* preferred size for recursion */

/* short names for the fixed-width integer types */
#define int32 int32_t
#define uint32 uint32_t

/* a 256-bit vector viewed as 8 int32 entries */
typedef __m256i int32x8;

/* unaligned load of 8 int32 from z; the cast keeps const so const arrays can be read without dropping the qualifier */
#define int32x8_load(z) _mm256_loadu_si256((const __m256i *) (z))

/* entrywise signed min/max */
#define int32x8_min _mm256_min_epi32
#define int32x8_max _mm256_max_epi32

/* within each 128-bit half independently: output entry j is input entry pj of that half */
#define int32x8_constextract_eachside(v,p0,p1,p2,p3) _mm256_shuffle_epi32(v,_MM_SHUFFLE(p3,p2,p1,p0))

/* the digits name the source entry for each output position 0..7 */
#define int32x8_10325476(a) int32x8_constextract_eachside(a,1,0,3,2) /* swap adjacent entries */
#define int32x8_23016745(a) int32x8_constextract_eachside(a,2,3,0,1) /* swap adjacent pairs */
#define int32x8_45670123(a) _mm256_permute4x64_epi64(a,0x4e) /* swap the two 128-bit halves */

/* entry 0 as a scalar */
#define int32x8_0(a) _mm256_extract_epi32(a,0)

/*
 * Store the smallest entry of x[0..n-1] in *xmin and the largest in *xmax.
 *
 * Requires n >= 8: the accumulators are seeded from the last 8 entries
 * (x[n-8..n-1]), which is what covers the tail when n is not a multiple
 * of 8. The loop then folds in every full group of 8 from the front.
 */
static void horizontal_minmax_int32_atleast8(int32 *xmin,int32 *xmax,const int32 *x,long long n)
{
  int32x8 vmin = int32x8_load(x+n-8);
  int32x8 vmax = vmin;
  long long pos = 0;

  while (pos+8 <= n) {
    int32x8 chunk = int32x8_load(x+pos);
    vmin = int32x8_min(vmin,chunk);
    vmax = int32x8_max(vmax,chunk);
    pos += 8;
  }

  /* reduce across the 8 entries: neighbours, then pairs, then 128-bit halves */
  vmin = int32x8_min(vmin,int32x8_10325476(vmin));
  vmin = int32x8_min(vmin,int32x8_23016745(vmin));
  vmin = int32x8_min(vmin,int32x8_45670123(vmin));

  vmax = int32x8_max(vmax,int32x8_10325476(vmax));
  vmax = int32x8_max(vmax,int32x8_23016745(vmax));
  vmax = int32x8_max(vmax,int32x8_45670123(vmax));

  /* after the reduction every entry holds the result; read entry 0 */
  *xmin = int32x8_0(vmin);
  *xmax = int32x8_0(vmax);
}

/*
 * Sort x[0..n-1], using y[0..n-1] as scratch space.
 *
 * Inputs with n <= CUTOFF are handed straight to djbsort_int32. Larger
 * inputs get one counting-sort pass keyed on the top bits of x[i]-xmin,
 * which scatters the entries into at most MAXSPLIT buckets ordered by
 * value; each bucket is copied back into x and sorted by a recursive call.
 * The result is ascending provided djbsort_int32 sorts ascending
 * (NOTE(review): djbsort_int32 is declared elsewhere; confirm).
 *
 * x: array of n entries, sorted in place.
 * y: scratch array of n entries; must not overlap x (it is the source of a
 *    memcpy into x); its contents are overwritten.
 * n: number of entries.
 *
 * Memory access depends on the data (y[c[u]++]), so this is not constant-time.
 */
void radixwrapper_int32(int32 *x,int32 *y,long long n)
{
  if (n <= CUTOFF) { djbsort_int32(x,n); return; }

  int32 xmin,xmax;
  long long i,t,c[MAXSPLIT];

  /* aim for about TARGET entries per bucket, but never more than MAXSPLIT buckets */
  long long split = n/TARGET;
  while (split > MAXSPLIT) split >>= 1;

  horizontal_minmax_int32_atleast8(&xmin,&xmax,x,n);

  /*
   * Offsets from xmin are computed in uint32. The true difference can be as
   * large as 2^32-1, which overflows a signed subtraction (undefined
   * behavior); unsigned subtraction wraps mod 2^32 and yields the exact
   * offset because xmin <= x[i] <= xmax.
   */
  const uint32 base = (uint32) xmin;
  uint32 range = (uint32) xmax - base;
  if (range == 0) return; /* all entries equal: nothing to do */

  /* smallest shift with (range>>shift) < split, so every bucket index fits in c[0..split-1] */
  long long shift = 0;
  while (range >= split) { shift += 1; range >>= 1; }

  /* histogram of bucket indices */
  for (i = 0;i < split;++i) c[i] = 0;
  for (i = 0;i < n;++i) {
    uint32 u = (uint32) x[i] - base;
    u >>= shift;
    ++c[u];
  }

  /* exclusive prefix sum: c[b] becomes the start position of bucket b */
  t = 0;
  for (i = 0;i < split;++i) { long long ci = c[i]; c[i] = t; t += ci; }

  /* scatter into y; afterwards c[b] is the end position of bucket b */
  for (i = 0;i < n;++i) {
    int32 xi = x[i];
    uint32 u = (uint32) xi - base;
    u >>= shift;
    y[c[u]++] = xi;
  }

  /* copy each bucket back into x and sort it, reusing the same slice of y as scratch */
  t = 0;
  for (i = 0;i < split;++i) {
    long long ci = c[i];
    memcpy(x+t,y+t,sizeof *x * (size_t) (ci-t));
    radixwrapper_int32(x+t,y+t,ci-t);
    t = ci;
  }
}
