#include <stdio.h>
#include <stdlib.h>

/* Marker stored in an s-box entry that has not been assigned a value yet.
   Parenthesized so the expansion always behaves as a single operand. */
#define UNDEF (-1)

/* NOTE(review): IN and OUT are not referenced anywhere in the visible
   source; presumably bit widths of a different s-box shape -- confirm
   before relying on them. */
#define IN  6
#define OUT 4

/* Number of s-box input values: length of an sbox and row count of a tab. */
#define NUMIN 16
/* Number of s-box output values: column count of a tab. */
#define NUMOUT 16

/* Type for an s-box: NUMIN ints indexed by the input value. An entry equal
   to UNDEF means "not assigned yet" (see initsbox()). Because this is an
   array type, an sbox parameter decays to a pointer, so functions receiving
   one work directly on the caller's array. */
typedef int sbox[NUMIN];

/* Type for a table (differential or linear): NUMIN rows by NUMOUT columns.
   The row is the input difference/mask and the column the output
   difference/mask (see difftab() and lintab()). Like sbox, a tab parameter
   refers to the caller's storage. */
typedef int tab[NUMIN][NUMOUT];

/* Pairs a position with a floating-point penalty so that entries can be
   ordered by penalty (see comparePoswithpenalty()). */
struct poswithpenalty {
  int pos;         /* the rated position; presumably a candidate value tried
                      during backtracking -- confirm against the caller */
  double penalty;  /* sort key used by comparePoswithpenalty() */
};

/* Reference s-box. find_sbox() copies its entries 3..NUMIN-1 into the
   partially defined s-box that the search starts from.
   NOTE(review): the values look like the PRESENT cipher s-box
   (C56B90AD3EF84712) -- confirm. */
sbox posch = { 12, 5, 6, 11, 9, 0, 10, 13, 3, 14, 15, 8, 4, 7, 1, 2 };

/* S-BOX HELPER FUNCTIONS */
/* Marks every entry of the s-box s as unassigned (UNDEF). */
void initsbox(sbox s) {
  int remaining = NUMIN;

  /* Walk the entries from the last one down to index 0. */
  while(remaining-- > 0) {
    s[remaining] = UNDEF;
  }
}

/* Prints the s-box s on a single line terminated by a newline.
   Defined entries are written in uppercase hexadecimal (one digit for
   values below 16), undefined entries as '-'. A blank is written after
   every group of 16 entries.
   s: s-box to print; it is not modified. */
void printsbox(sbox s) {
  int i;

  for(i = 0; i < NUMIN; i++) {
    if(s[i] == UNDEF) {
      printf("-");
    } else {
      /* %X expects an unsigned int, so convert explicitly to make the
         argument type match the conversion specifier. */
      printf("%1X", (unsigned int) s[i]);
    }
    if(i%16==15) {
      printf(" ");
    }
  }
  printf("\n");
}

/* Copies s1 into s2.
   s1: source s-box; it is not modified.
   s2: destination s-box; all NUMIN entries are overwritten.
   Returns 0. The function is declared int, so it must return a value:
   falling off the end would make any use of the result undefined. */
int sboxcopy(sbox s1, sbox s2) {
  int i;

  for(i = 0; i < NUMIN; i++) {
    s2[i] = s1[i];
  }

  return(0);
}

/* Returns how many entries of the s-box s are still UNDEF. */
int countUNDEF(sbox s) {
  int missing = 0;
  int idx;

  for(idx = NUMIN - 1; idx >= 0; idx--) {
    /* The comparison yields 1 for an unassigned entry and 0 otherwise. */
    missing += (s[idx] == UNDEF);
  }

  return missing;
}

/**************************/
/* TABLE HELPER FUNCTIONS */
/**************************/

/* Sets every cell of the table t to initval. */
void inittab(tab t, int initval) {
  int row, col;

  for(row = 0; row < NUMIN; row++) {
    int *cells = t[row];

    for(col = 0; col < NUMOUT; col++) {
      cells[col] = initval;
    }
  }
}

/* Prints the table t: a header line with the column indices in octal,
   then one line per row consisting of the octal row index, " | ", and the
   signed, zero-padded decimal cell values.
   t: table to print; it is not modified. */
void printtab(tab t) {
  /* The indices are printed with %o, which expects an unsigned int, so
     they are declared unsigned to match the conversion specifier. */
  unsigned int i, j;

  printf("     ");
  for(j = 0; j < NUMOUT; j++) {
    printf(" %02o ", j);
  }
  printf("\n");
  for(i = 0; i < NUMIN; i++) {
    printf("%02o | ", i);
    for(j = 0; j < NUMOUT; j++) {
      printf("%+03d ", t[i][j]);
    }
    printf("\n");
  }
}

/* Returns 1 if every cell of t1 equals the corresponding cell of t2,
   otherwise 0. */
int tabeq(tab t1, tab t2) {
  int cell;

  /* Visit the cells in row-major order through one flat counter. */
  for(cell = 0; cell < NUMIN * NUMOUT; cell++) {
    const int row = cell / NUMOUT;
    const int col = cell % NUMOUT;

    if(t1[row][col] != t2[row][col]) {
      return 0;
    }
  }

  return 1;
}

/* Returns 1 if no cell of t1 exceeds the corresponding cell of t2
   (t1 <= t2 componentwise), otherwise 0. */
int tableq(tab t1, tab t2) {
  int row, col;

  for(row = 0; row < NUMIN; row++) {
    const int *lhs = t1[row];
    const int *rhs = t2[row];

    for(col = 0; col < NUMOUT; col++) {
      if(lhs[col] > rhs[col]) {
        return 0;
      }
    }
  }

  return 1;
}

/* Copies t1 into t2.
   t1: source table; it is not modified.
   t2: destination table; every cell is overwritten.
   Returns 0. The function is declared int, so it must return a value:
   falling off the end would make any use of the result undefined. */
int tabcopy(tab t1, tab t2) {
  int i, j;

  for(i = 0; i < NUMIN; i++) {
    for(j = 0; j < NUMOUT; j++) {
      t2[i][j] = t1[i][j];
    }
  }

  return(0);
}

/* DIFFERENTIAL STUFF */
/* Differential table computed straight from the definition:
   res[dx][dy] is the number of inputs x for which s[x] and s[x^dx] are
   both defined and s[x] ^ s[x^dx] == dy. */
void difftabDUMB(sbox s, tab res) {
  int dx, dy, x;

  for(dx = 0; dx < NUMIN; dx++) {
    for(dy = 0; dy < NUMOUT; dy++) {
      int hits = 0;

      for(x = 0; x < NUMIN; x++) {
        const int left = s[x];
        const int right = s[x^dx];

        /* Pairs with an unassigned partner do not count. */
        if(left == UNDEF || right == UNDEF) {
          continue;
        }
        if((left ^ right) == dy) {
          hits++;
        }
      }
      res[dx][dy] = hits;
    }
  }
}

/* Computes the differential table for s-box s and stores the result in res.
   s:   s-box; entries equal to UNDEF are skipped. Defined entries are used
        as column indices, so they are assumed to lie in [0, NUMOUT).
   res: output table, fully overwritten. res[dx][dy] becomes the number of
        inputs x with s[x] and s[x^dx] both defined and
        s[x] ^ s[x^dx] == dy. */
void difftab(sbox s, tab res) {
  int dx, x;

  inittab(res,0);
  for(x = 0; x < NUMIN; x++) {
    /* An unassigned s[x] cannot contribute for any dx. */
    if(s[x] == UNDEF) {
      continue;
    }
    for(dx = 0; dx < NUMIN; dx++) {
      if(s[x^dx] != UNDEF) {
        res[dx][s[x]^s[x^dx]]++;
      }
    }
  }
}

/* Stores the maximal admissible differential table in res.
   res: output table, fully overwritten:
        - every cell starts at the general bound diff_S(dx,dy) <= 4;
        - row 0 is the trivial one: res[0][0] = NUMIN (with dx = 0 every
          input pairs with itself) and res[0][dy] = 0 for dy != 0;
        - res[dx][dy] = 0 whenever wt(dx) = wt(dy) = 1. */
void maxdifftab(tab res) {
  int i,j;

  /* diff_S(dx,dy) <= 4 */
  inittab(res, 4);

  /* Trivial */
  res[0][0] = NUMIN;
  for(j = 1; j < NUMOUT; j++) {
    res[0][j] = 0;
  }

  /* diff_S(dx,dy) = 0 for wt(dx) = wt(dy) = 1.
     The values of Hamming weight 1 are exactly the powers of two, so they
     are enumerated by shifting a single set bit. This also avoids calling
     hamming(), which is defined only further down in the file and hence
     is not declared at this point. */
  for(i = 1; i < NUMIN; i <<= 1) {
    for(j = 1; j < NUMOUT; j <<= 1) {
      res[i][j] = 0;
    }
  }
}

/* LINEAR STUFF */
/* Computes the hamming weight of val, i.e. the number of 1 bits in its
   binary representation. */
/* Better: http://en.wikipedia.org/wiki/Hamming_weight */
/* val: value whose set bits are counted. A negative value is counted
        through its unsigned conversion (all bits of the representation).
   Returns the number of set bits. */
int hamming(int val) {
  /* Work on an unsigned copy: right-shifting a negative int is
     implementation-defined and, with an arithmetic shift, never reaches 0,
     so the loop would not terminate. */
  unsigned int akt = (unsigned int) val;
  int res = 0;

  while(akt != 0) {
    res = res + (int) (akt & 1u);
    akt >>= 1;
  }

  return(res);
}

/* Scalar product of the bit vectors v1 and v2, returned as the number of
   positions where both have a 1 bit. Callers that need the product over
   GF(2) take the lowest bit of the result. */
int sprod(int v1, int v2) {
  const int common = v1 & v2;

  return hamming(common);
}

/* Fills res with the linear table of the s-box s: for every input mask
   alpha and output mask beta, res[alpha][beta] is the number of defined
   inputs x where the parity of <s[x],beta> equals the parity of <x,alpha>,
   minus the number of defined inputs where the parities differ. */
void lintab(sbox s, tab res) {
  int alpha, beta, x;

  for(alpha = 0; alpha < NUMIN; alpha++) {
    for(beta = 0; beta < NUMOUT; beta++) {
      int balance = 0;

      for(x = 0; x < NUMIN; x++) {
        int outparity, inparity;

        /* Unassigned entries contribute nothing. */
        if(s[x] == UNDEF) {
          continue;
        }

        outparity = sprod(s[x], beta) & 1;
        inparity = sprod(x, alpha) & 1;
        balance += (outparity == inparity) ? 1 : -1;
      }
      res[alpha][beta] = balance;
    }
  }
}

/* Writes the maximal admissible bias table into res:
   8 everywhere, 4 where both masks have Hamming weight 1, 16 at [0][0],
   and 0 in the rest of column 0. */
void maxlintab(tab res) {
  int a, b;

  /* C2 */
  inittab(res, 8);

  /* Masks of Hamming weight 1 are the powers of two; visit them by
     shifting a single set bit. */
  for(a = 1; a < NUMIN; a <<= 1) {
    for(b = 1; b < NUMOUT; b <<= 1) {
      res[a][b] = 4;
    }
  }

  /* trivial: column 0 is zero except for the top-left cell */
  for(a = NUMIN - 1; a >= 1; a--) {
    res[a][0] = 0;
  }
  res[0][0] = 16;
}

/* Heuristic score of the linear table ltab relative to the bound table
   maxltab. For each cell, e = |ltab| - maxltab contributes 2^e - 1 to the
   sum (e may be negative), so a cell exactly at its bound adds 0, a cell
   above it adds a positive amount and a cell below it a negative one. */
double computepenalty(tab ltab, tab maxltab) {
  double total = 0;
  int alpha, beta;

  for(alpha = 0; alpha < NUMIN; alpha++) {
    for(beta = 0; beta < NUMOUT; beta++) {
      const int excess = abs(ltab[alpha][beta]) - maxltab[alpha][beta];
      double weight;

      if(excess < 0) {
        weight = 1.0/(1<<(-excess));
      } else {
        weight = (1<<excess);
      }

      /* Add the weight first, then remove the per-cell offset of 1. */
      total += weight;
      total -= 1;
    }
  }

  return total;
}

/* Array helper */

/* Comparison function of two poswithpenalty elements wrt. the double stored in there */
int comparePoswithpenalty(const void *p1, const void *p2) {
  struct poswithpenalty* s1;
  struct poswithpenalty* s2;

  s1 = (struct poswithpenalty*) p1;
  s2 = (struct poswithpenalty*) p2;

  return((s1->penalty > s2->penalty) - (s1->penalty < s2->penalty));
}

/* Writes p1 to stdout as "(pos, penalty)" without a trailing newline. */
void printPoswithpenalty(struct poswithpenalty p1) {
  const int where = p1.pos;
  const double cost = p1.penalty;

  printf("(%d, %f)", where, cost);
}


/**********************************************/
/**********************************************/
/**********************************************/
/******* THIS IS WHERE YOUR WORK STARTS *******/
/**********************************************/
/**********************************************/
/**********************************************/

/* Incrementally updates the differential table dtab for the additional
   assignment s[x0] = y0.
   s:    s-box; only entries at positions other than x0 are read, so it
         does not matter whether the caller has already stored y0 in s[x0]
         (find_sbox() does, computeoptions() does not). s is not modified.
         Defined entries are assumed to lie in [0, NUMOUT), as in difftab().
   x0:   position being assigned, 0 <= x0 < NUMIN.
   y0:   value being assigned, 0 <= y0 < NUMOUT.
   dtab: on entry the differential table of s with position x0 not yet
         counted; on return the table that difftab() would compute for s
         with s[x0] = y0.
   An out-of-range x0 or y0 (e.g. UNDEF) leaves dtab untouched. */
void difftabAcc(sbox s, int x0, int y0, tab dtab) {
  int x;

  /* Both values are used to form table indices; refuse anything that
     would index outside dtab. */
  if(x0 < 0 || x0 >= NUMIN || y0 < 0 || y0 >= NUMOUT) {
    return;
  }

  /* dx = 0: the new input paired with itself. */
  dtab[0][0]++;

  /* Every other defined input x forms two ordered pairs with x0,
     (x, x0) and (x0, x); both have input difference x ^ x0 and output
     difference s[x] ^ y0. */
  for(x = 0; x < NUMIN; x++) {
    if(x != x0 && s[x] != UNDEF) {
      dtab[x ^ x0][s[x] ^ y0] += 2;
    }
  }
}

/* Incrementally updates the linear table ltab for the additional
   assignment s[x0] = y0.
   s:    s-box the table belongs to. It is not read: each defined input
         contributes to the linear table independently of the others (see
         lintab()), so only x0 and y0 are needed.
   x0:   position being assigned.
   y0:   value being assigned; UNDEF leaves ltab untouched, matching
         lintab(), which skips unassigned entries.
   ltab: on entry the linear table of s with position x0 not yet counted;
         on return the table that lintab() would compute for s with
         s[x0] = y0. */
void lintabAcc(sbox s, int x0, int y0, tab ltab) {
  int alpha, beta;
  int inparity;

  (void) s;

  if(y0 == UNDEF) {
    return;
  }

  for(alpha = 0; alpha < NUMIN; alpha++) {
    /* The parity of <x0,alpha> does not depend on beta. */
    inparity = sprod(x0, alpha) & 1;

    for(beta = 0; beta < NUMOUT; beta++) {
      if((sprod(y0, beta) & 1) == inparity) {
        ltab[alpha][beta]++;
      } else {
        ltab[alpha][beta]--;
      }
    }
  }
}

/* Collects the admissible values for position x0 of the s-box s.
   Each candidate value in 0..NUMOUT-1 is applied to dtab through
   difftabAcc(); it is admissible if the resulting table is componentwise
   <= maxdtab. dtab is restored after every candidate, so it is unchanged
   on return.
   res: receives the admissible values in increasing order.
   num: receives the number of values written to res. */
void computeoptions(sbox s, tab maxdtab, tab dtab, int x0, int* res, int* num) {
  tab saved;
  int candidate = 0;

  *num = 0;

  while(candidate < NUMOUT) {
    /* Snapshot the table, try the candidate, then roll back. */
    tabcopy(dtab, saved);
    difftabAcc(s, x0, candidate, dtab);

    if(tableq(dtab, maxdtab)) {
      res[*num] = candidate;
      *num += 1;
    }

    tabcopy(saved, dtab);
    candidate++;
  }
}


/* Backtracking procedure. Gets an s-box s and a new assignment s[x0] = y0 as
   well as all necessary tables to find an s-box that fulfills the
   requirements stated by maxdtab and maxltab.

   NOTE: this is still an unimplemented placeholder. The body contains only
   the outline below and always returns 0. find_sbox() starts the search
   with y0 == UNDEF and ignores the return value. */
int backtrack(int x0, int y0, tab maxdtab, tab maxltab, tab actdtab, tab actltab, sbox s) {
  /*
     Intended algorithm (not implemented yet):

     1. If y0 is not undefined, update s, actdtab and actltab.

     2. If a fully defined s-box has been found, exit.

     3. For every candidate position x1, compute all options y1 available
        for it. Select one of the x1's that has a minimal number of
        possible y1's. If several x1's share the minimal number, select
        one of them uniformly at random.

     4. For all options of the selected x1, compute the penalties of the
        possible y1's and backtrack through those options in increasing
        order of penalty.
  */

  return(0);
}

void find_sbox() {
  tab maxdtab, maxltab, actltab,actdtab;
  sbox s;
  int x0;

  printf("**********************************************\n");
  printf("*** Hello! Starting search for an s-box... ***\n");
  printf("**********************************************\n\n");

  printf("Computing the Coppersmith/Poschmann bounds...\n");
  maxdifftab(maxdtab);
  maxlintab(maxltab);
  printf("Done.\n");

  printf("Generating a partially empty s-box...\n");
  initsbox(s);
  inittab(actltab, 0);
  inittab(actdtab, 0);

  for(x0 = 3; x0 < NUMIN; x0++) {
    s[x0] = posch[x0];
    difftabAcc(s, x0, posch[x0], actdtab);
    lintabAcc(s, x0,  posch[x0], actltab);
  }
  printf("Done.\n");

  printf("STARTING SEARCH...\n");
  backtrack(0, UNDEF, maxdtab, maxltab, actdtab, actltab, s);
}

/* Program entry point: runs the s-box search and reports success to the
   environment. The explicit return keeps the exit status well defined on
   C89 compilers as well, where falling off the end of main does not imply
   a 0 status. */
int main(void) {
  find_sbox();

  return 0;
}

