/* adapted from rijndael-alg-ref.c   v2.0   August '99 
 * Rijndael Reference ANSI C code
 * authors: Paulo Barreto
 *          Vincent Rijmen
 *
 * Demonstrates a state-recovery given an internal collision in Pelican-MAC
 *
 * This file requires a 64-bit machine to run. It allocates and uses 36GB of RAM,
 * and writes 80GB of data to a hard drive.
 *
 *
 * C. Bouillaguet, P. Derbez, P.-A. Fouque, February 2011
 */



#include <sys/mman.h> // mmap

#include <sys/types.h> // open and friends
#include <sys/stat.h>
#include <fcntl.h>

#include <unistd.h> // ftruncate
#include <stdio.h>
#include <string.h> //memcpy
#include <stdlib.h>
#include <assert.h>
#include <stdint.h>

#include "rijndael.h"

/* Row index into the shifts[][][] table (boxes-ref.h) for a state of BC
 * columns, as in the Rijndael reference code: BC = 4, 6, 8 -> 0, 1, 2.
 * Every call in this file uses BC == 4, i.e. SC == 0.
 * NOTE(review): the previous form ((BC - 4)) only agreed with this for BC == 4. */
#define SC	((((BC) - 4) >> 1))

#include "boxes-ref.h"

/* Difference distribution table of the S-box S, filled by init_diff_tables():
 *   n_sols[a][b]  = number of ordered pairs (x, y) with x ^ y == a and
 *                   S[x] ^ S[y] == b;
 *   sols_x[a][b], sols_y[a][b] = heap arrays of n_sols[a][b] bytes holding
 *                   the x values and y values of those pairs, respectively. */
int n_sols[256][256];
word8 * sols_x[256][256];
word8 * sols_y[256][256];

/* Inverse MixColumn matrix: output byte i of a column is the XOR over j of
 * mul(imc_matrix[i][j], input byte j). These are the same circulant
 * coefficients (0e 0b 0d 09) that InvMixColumn applies. */
const word8 imc_matrix[4][4] = { 
  {0x0e, 0x0b, 0x0d, 0x09},
  {0x09, 0x0e, 0x0b, 0x0d},
  {0x0d, 0x09, 0x0e, 0x0b},
  {0x0b, 0x0d, 0x09, 0x0e}
};


/* Rotate rows 1..3 of a BC-column state to the left, row r by
 * shifts[SC][r][d] positions; row 0 never moves. The direction is chosen
 * by d: this file passes 0 while encrypting and 1 when undoing the step.
 * Straightforward reference-style code, not tuned for speed. */
static inline void ShiftRow(word8 a[4][MAXBC], word8 d, word8 BC) {
	int row;

	for (row = 1; row < 4; row++) {
		word8 rotated[MAXBC];
		int offset = shifts[SC][row][d];
		int col;

		for (col = 0; col < BC; col++)
			rotated[col] = a[row][(col + offset) % BC];
		for (col = 0; col < BC; col++)
			a[row][col] = rotated[col];
	}
}

static inline void Substitution(word8 a[4][MAXBC], const word8 box[256], word8 BC) {
	/* In-place byte substitution: each byte of the first BC columns of the
	 * state is replaced by its entry in the lookup table `box`
	 * (this file passes S when encrypting and Si when decrypting).
	 */
	int r, c;

	for (c = 0; c < BC; c++)
		for (r = 0; r < 4; r++)
			a[r][c] = box[a[r][c]];
}
   
static inline void MixColumn(word8 a[4][MAXBC], word8 BC) {
	/* Forward MixColumn, column by column:
	 *   out[r] = mul(2, in[r]) ^ mul(3, in[r+1]) ^ in[r+2] ^ in[r+3]
	 * with row indices taken mod 4. Each column is buffered before being
	 * overwritten, so every output uses the original column bytes.
	 */
	int col, r;

	for (col = 0; col < BC; col++) {
		word8 in[4];

		for (r = 0; r < 4; r++)
			in[r] = a[r][col];
		for (r = 0; r < 4; r++)
			a[r][col] = mul(2, in[r])
				^ mul(3, in[(r + 1) % 4])
				^ in[(r + 2) % 4]
				^ in[(r + 3) % 4];
	}
}

static inline void InvMixColumn(word8 a[4][MAXBC], word8 BC) {
	/* Inverse of MixColumn, column by column:
	 *   out[r] = mul(0e, in[r]) ^ mul(0b, in[r+1]) ^ mul(0d, in[r+2]) ^ mul(09, in[r+3])
	 * with row indices taken mod 4. Each column is buffered before being
	 * overwritten, so every output uses the original column bytes.
	 */
	int col, r;

	for (col = 0; col < BC; col++) {
		word8 in[4];

		for (r = 0; r < 4; r++)
			in[r] = a[r][col];
		for (r = 0; r < 4; r++)
			a[r][col] = mul(0xe, in[r])
				^ mul(0xb, in[(r + 1) % 4])
				^ mul(0xd, in[(r + 2) % 4])
				^ mul(0x9, in[(r + 3) % 4]);
	}
}


/* Print the first BC columns of the state to stdout, one row per line,
 * each byte as two hex digits followed by a space. */
void print_state(word8 a[4][MAXBC], int BC) {
  for (int row = 0; row < 4; ++row) {
    for (int col = 0; col < BC; ++col)
      printf("%02x ", a[row][col]);
    putchar('\n');
  }
}

/* Fill the difference distribution table of the S-box S into the globals:
 *   n_sols[a][b]  = number of ordered pairs (x, y) with x ^ y == a and
 *                   S[x] ^ S[y] == b;
 *   sols_x[a][b][k], sols_y[a][b][k] = the k-th such pair, k < n_sols[a][b].
 * Works in three steps: count, allocate exactly-sized arrays, then record.
 * Exits the program if an allocation fails. */
void init_diff_tables() {
  int x, y, a, b;

  // pass 1: count the pairs for every (input, output) difference
  memset(n_sols, 0, sizeof n_sols);
  for(x=0; x<0x100; x++)
    for(y=0; y<0x100; y++)
      n_sols[x ^ y][ S[x] ^ S[y] ]++;

  // allocate space for the pairs, then reuse n_sols as a fill cursor
  for(a=0; a<0x100; a++)
    for(b=0; b<0x100; b++) {
      const int count = n_sols[a][b];

      sols_x[a][b] = (word8 *) malloc(count);
      sols_y[a][b] = (word8 *) malloc(count);
      // malloc(0) may legitimately return NULL: only non-empty requests can fail
      if (count > 0 && (sols_x[a][b] == NULL || sols_y[a][b] == NULL)) {
        printf("difference table allocation failed !\n");
        exit(1);
      }
      n_sols[a][b] = 0;
    }

  // pass 2: record the pairs; n_sols ends up with the same counts as pass 1
  for(x=0; x<0x100; x++)
    for(y=0; y<0x100; y++) {
      const int din  = x ^ y;
      const int dout = S[x] ^ S[y];
      const int slot = n_sols[din][dout]++;

      sols_x[din][dout][slot] = x;
      sols_y[din][dout][slot] = y;
    }
}



//////// ***************************** Tables generation **********************************

// with this template, I am sure that the compiler will correctly optimise the two "switches" below
template<int column>
void StateRecovery_gentables(const word8 delta_i, const word8 delta_o[4][4], const char *filename) {

  int i,j;
  word8 delta_op[4][4];
  uint32_t * data;     // actual solutions
  word8 * n_entries; // counting solutions
  uint32_t * indices;   // where do they start

  const int64_t N = 0x100000000;
  int ctr0,ctr1;
  int ctr2,ctr3;

  uint64_t n=0;

  word8 W_2[4];
  word8 X3[4];
  word8 Xp3[4];

  word8 Wp_2[4];
  word8 delta[4]; // differences on the diagonal(s) of X_2
  
  printf("\n============================================\n");
  printf("Creating table for column %d\n", column);
  printf("============================================\n\n");

  for(i=0; i<4; i++)
    for(j=0; j<4; j++)
      delta_op[i][j] = delta_o[i][j];

  // compute the difference just after the last SubByte
  InvMixColumn(delta_op, 4);
  ShiftRow(delta_op,1,4);                

  // first, try to allocate the tables
  n_entries = (word8 *) calloc(N, sizeof(word8));
  if (n_entries == NULL) {
    printf("n_entries allocation failed !\n");
    exit(1);
  }

  indices = (uint32_t *) calloc(N, sizeof(uint32_t));
  if (indices == NULL) {
    printf("indices allocation failed !\n");
    exit(1);
  }

  data = (uint32_t *) calloc(N, sizeof(uint32_t));
  if (data == NULL) {
    printf("data allocation failed !\n");
    exit(1);
  }

  printf("tables allocated\n");




  ///************************* first pass : counting
  
  // STEP 1-a : guess first column of X_3
  for(ctr0=0; ctr0<0x100; ctr0++)
  for(ctr1=0; ctr1<0x100; ctr1++) 
  {
    n++;
    printf("\r%.1f", (100.0*n)/0x10000);
    fflush(stdout);

    X3[0] = ctr0;
    X3[1] = ctr1;

    for(ctr2=0; ctr2<0x100; ctr2++)
    for(ctr3=0; ctr3<0x100; ctr3++) 
    {

      X3[2] = ctr2;
      X3[3] = ctr3;

    for(i=0; i<4; i++)
      Xp3[i] = Si[ S[ X3[i] ] ^ delta_op[i][column] ]; 
  
    // STEP 1-b : computes bytes of X_2 and X'_2    
    // inverse MixColumn to get first column of W_2 and W'_2
    for(i=0; i<4; i++) {
      Wp_2[i] = W_2[i] = 0;    
      for(j=0; j<4; j++) {
	W_2[i]  ^= mul(imc_matrix[i][j], X3[j]);
	Wp_2[i] ^= mul(imc_matrix[i][j], Xp3[j]);
      }
    }
    // inverse SubBytes to get diagonal(s) of X_2
    for(i=0; i<4; i++)
      delta[i] = Si[ W_2[i] ] ^ Si[ Wp_2[i] ];
    
    // now try to put delta in the big_table
    uint32_t key;  // the key is (alpha,beta,gamma,delta). Obtaining it depends on the column
    switch(column) {
    case 0: key = mul(0x8d,delta[0]) ^ (delta[1] << 8) ^ (mul(0x8d,delta[2]) << 16) ^ (delta[3] << 24); break;
    case 1: key = mul(0xf6,delta[3]) ^ (delta[0] << 8) ^ (mul(0xf6,delta[1]) << 16) ^ (delta[2] << 24); break;
    case 2: key = delta[2] ^ (mul(0x8d,delta[3]) << 8) ^ (delta[0] << 16) ^ (mul(0x8d,delta[1]) << 24); break;
    case 3: key = delta[1] ^ (mul(0xf6,delta[2]) << 8) ^ (delta[3] << 16) ^ (mul(0xf6,delta[0]) << 24); break;
    }

    /*    printf("\n my diagonal : %02x %02x %02x %02x\n", delta[0], delta[1], delta[2], delta[3]);
    printf("\n hitting the table with %08x\n", key);
    return;
    */

    //#pragma omp atomic
    n_entries[key] += 1;
  } // loop for X_3  
  }

  printf("\n\ndone counting solutions. Now: prefix-sum\n");


  int n_max = 0;
  indices[0] = 0;
  for(n=1; n<N; n++) {
    if (n_entries[n] > n_max) n_max = n_entries[n];
    indices[n] = indices[n-1] + n_entries[n-1];
  }

  printf("last one (should be 2^32): %x\n", indices[N-1] + n_entries[N-1]);
  printf("max. number of sols : %d\n", n_max);
  printf("Now filling the data structure\n\n");

  ///***************************** filling. Very same as before...


  n=0;
  
  //#pragma omp parallel for shared(indices,data,n)		\
  //  private(i,j,W_2, Wp_2, delta)				\
  //  schedule(dynamic, 1)
  
  for(ctr0=0; ctr0<0x100; ctr0++)
  for(ctr1=0; ctr1<0x100; ctr1++) {
    n++;
    printf("\r%.1f", (100.0*n)/0x10000);
    fflush(stdout);

    X3[0] = ctr0;
    X3[1] = ctr1;


    for(ctr2=0; ctr2<0x100; ctr2++)
    for(ctr3=0; ctr3<0x100; ctr3++) {

    X3[2] = ctr2;
    X3[3] = ctr3;

    for(i=0; i<4; i++)
      Xp3[i] = Si[ S[ X3[i] ] ^ delta_op[i][column] ]; 
  
    for(i=0; i<4; i++) {
      Wp_2[i] = W_2[i] = 0;    
      for(j=0; j<4; j++) {
	W_2[i]  ^= mul(imc_matrix[i][j], X3[j] );
	Wp_2[i] ^= mul(imc_matrix[i][j], Xp3[j]);
      }
    }
    for(i=0; i<4; i++)
      delta[i] = Si[ W_2[i] ] ^ Si[ Wp_2[i] ];
    
    uint32_t key;  // the key is (alpha,beta,gamma,delta). Obtaining it depends on the column
    switch(column) {
    case 0: key = mul(0x8d,delta[0]) ^ (delta[1] << 8) ^ (mul(0x8d,delta[2]) << 16) ^ (delta[3] << 24); break;
    case 1: key = mul(0xf6,delta[3]) ^ (delta[0] << 8) ^ (mul(0xf6,delta[1]) << 16) ^ (delta[2] << 24); break;
    case 2: key = delta[2] ^ (mul(0x8d,delta[3]) << 8) ^ (delta[0] << 16) ^ (mul(0x8d,delta[1]) << 24); break;
    case 3: key = delta[1] ^ (mul(0xf6,delta[2]) << 8) ^ (delta[3] << 16) ^ (mul(0xf6,delta[0]) << 24); break;
    }

    uint32_t value = X3[0] ^ (X3[1] << 8) ^ (X3[2] << 16) ^ (X3[3] << 24) ;


    //#pragma omp critical
    data[ indices[key]++ ] = value;
    } // loop for X_3  
  }

  free(indices); // we don't need those anymore.

  // *************** now, save n_entries/data into a file.

  printf("\n\nSaving tables to %s\n", filename);

  // first, try to create the file 
  int fd = open(filename, O_CREAT | O_TRUNC | O_RDWR, 0644);
  if (fd == -1) {
    perror("failure to create the file\n");
    exit(1);
  }

  int64_t file_size = N*(sizeof(word8) + sizeof(uint32_t));

  // set the right file size
  if (ftruncate(fd, file_size) != 0) {
    perror("failure to resize the file\n");
    exit(1);
  }

  // project the file into memory
  void *mmaped_zone = mmap(NULL, file_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
  if (mmaped_zone == MAP_FAILED) {
    perror("mmap failed : ");
    exit(1);
  }

  // be nice with the VM subsystem
  if (madvise(mmaped_zone, file_size, MADV_SEQUENTIAL) != 0) perror("madvise() error : ");  

  word8 * n_entries_target = (word8 *) mmaped_zone;
  uint32_t * data_target  = (uint32_t *) (n_entries_target + N);

  // save to file
  memcpy(n_entries_target, n_entries, N*sizeof(word8));
  memcpy(data_target, data, N*sizeof(uint32_t));

  // write the file to the disk
  munmap(mmaped_zone, file_size); // save the tables in the file !
  close(fd);

  // release the memory
  free(data);
  free(n_entries);
}



//////// ***************************** Tables matching **********************************

void StateRecovery_final(const word8 delta_i, word8 delta_o[4][4], const char *filenames[4]) {

  // open and map the table files into memory
  int fd[4];
  void *mmaped_zone[4];
  word8 * n_entries[4];
  uint32_t * data[4];

  int i,j;
  int64_t k;

  int64_t N = 0x100000000;
  int64_t file_size = N*(sizeof(word8) + sizeof(uint32_t));

  word8 delta_op[4][4];
  int dev_null;

  // alpha,beta,gamma,delta = 1e,5e,1e,93

  // compute the difference just after the last SubByte
  for(i=0; i<4; i++)
    for(j=0; j<4; j++)
      delta_op[i][j] = delta_o[i][j];

  InvMixColumn(delta_op, 4);
  ShiftRow(delta_op,1,4);                
  
  // preparate the 4 tables
  for(i=0; i<4; i++) {

    // first, open the files 
    fd[i] = open(filenames[i], O_RDONLY, 0644);
    if (fd[i] == -1) {
      perror("failure to open the file\n");
      exit(1);
    }

    // project the file into memory
    mmaped_zone[i] = mmap(NULL, file_size, PROT_READ, MAP_SHARED, fd[i], 0);
    if (mmaped_zone == MAP_FAILED) {
      perror("mmap failed : ");
      exit(1);
    }
    
    // be nice with the VM subsystem
    if (madvise(mmaped_zone[i], file_size, MADV_SEQUENTIAL) != 0) perror("madvise() error : ");  

    n_entries[i] = (word8 *) mmaped_zone[i];
    data[i]  = (uint32_t *) (n_entries[i] + N);

    printf("mapped %s. Caching index...\n", filenames[i]);
    fflush(stdout);
    //    for(j=0; j<N; j++)
    //      dev_null += n_entries[i][j];
  }

  


  int64_t idx[4] = { 0,0,0,0 }; // idx[i] denotes the offset in file #i where the next n_entries[i][k] entries are.
  int l,m;


  for(i=0; i<0x10000; i++) 
    {
    
    printf("\r%.1f -- %08llx", i*100./0x10000,k);
    fflush(stdout);

    for(j=0; j<0x10000; j++) 
      {

  k = (i << 16) + j;
  /*  for(i=0;i<k;i++)
    for(j=0;j<4;j++)
    idx[j] += n_entries[j][i];
*/

      int ctr[4];
      for(ctr[0]=0; ctr[0]<n_entries[0][k]; ctr[0]++)
      for(ctr[1]=0; ctr[1]<n_entries[1][k]; ctr[1]++)
      for(ctr[2]=0; ctr[2]<n_entries[2][k]; ctr[2]++)
      for(ctr[3]=0; ctr[3]<n_entries[3][k]; ctr[3]++) {

        word8 X[4][4];
        word8 Xp[4][4];
        uint32_t z[4];

        //assemble candidate X_3

        // load data
        for(l=0; l<4; l++)
          z[l] = data[l][idx[l] + ctr[l]];

        // convert to lousy format
        for(l=0; l<4; l++) 
          for(m=0; m<4; m++) 
            X[m][l] = (z[l] >> (8*m)) & 0xff;

        // deduce X'_3
        for(l=0; l<4; l++) 
          for(m=0; m<4; m++) 
            Xp[l][m] = Si[ S[ X[l][m] ] ^ delta_op[l][m] ]; 

  
        // deduce X_2 and X'_2, by usual decyption
        InvMixColumn(X,4);            InvMixColumn(Xp,4);
        ShiftRow(X,1,4);              ShiftRow(Xp,1,4);
        Substitution(X,Si,4);         Substitution(Xp,Si,4);

        // deduce X_1 and X'_1
        InvMixColumn(X,4);            InvMixColumn(Xp,4);
        ShiftRow(X,1,4);              ShiftRow(Xp,1,4);
        Substitution(X,Si,4);         Substitution(Xp,Si,4);

        // deduce X_0 and X'_0
        InvMixColumn(X,4);            InvMixColumn(Xp,4);
        ShiftRow(X,1,4);              ShiftRow(Xp,1,4);
        Substitution(X,Si,4);         Substitution(Xp,Si,4);

        // now check that the difference pattern is right
        if ((X[0][0] ^ Xp[0][0]) == delta_i) {
        if (X[0][1] == Xp[0][1]) {
        if (X[0][2] == Xp[0][2]) {
        if (X[0][3] == Xp[0][3]) {
        if (X[1][0] == Xp[1][0]) {
        if (X[1][1] == Xp[1][1]) {
        if (X[1][2] == Xp[1][2]) {
        if (X[1][3] == Xp[1][3]) {
        if (X[2][0] == Xp[2][0]) {
        if (X[2][1] == Xp[2][1]) {
        if (X[2][2] == Xp[2][2]) {
        if (X[2][3] == Xp[2][3]) {
        if (X[3][0] == Xp[3][0]) {
        if (X[3][1] == Xp[3][1]) {
        if (X[3][2] == Xp[3][2]) {
        if (X[3][3] == Xp[3][3]) {
        
        // Seems correct ! Re-build and show X_3
        for(l=0; l<4; l++) 
          for(m=0; m<4; m++) 
            X[m][l] = (z[l] >> (8*m)) & 0xff;

        printf("actual X_3 likely found :\n");
        print_state(X,4);
        printf("\n\n");
	fflush(stdout);
	return;

	}}}}}}}}}}}}}}}}
      }

      // advance to the next quadruplet (alpha,beta,gamma,delta)
      for(l=0; l<4; l++)
        idx[l] += n_entries[l][k];
        

    }
  }
}

//////// ***************************** Main **********************************

/* Demo driver. Build a random 4x4 state and a copy that differs by 0x01 in
 * byte (0,0), and push both through 4 unkeyed rounds. Record the intermediate
 * values for reference ("cheating"), then run the table generation and the
 * matching phase.
 *
 * Usage: prog [table_dir]
 *   table_dir  directory receiving the four ~20GB tables T_0.dat..T_3.dat
 *              (default: /aspilia).
 * rand() is never seeded, so the same internal state is used on every run. */
int main(int argc, char **argv) {

  word8 P_1[4][4];
  word8 P_2[4][4];

  word8 C_1[4][4];
  word8 C_2[4][4];

  word8 X_3[4][4];
  word8 delta_X_2[4][4];
  word8 delta_o[4][4];

  int i,j,r;

  // build the table file names first, so a bad argument fails before the heavy work
  const char *table_dir = (argc > 1) ? argv[1] : "/aspilia";
  char paths[4][4096];
  const char *filenames[4];
  for(i=0; i<4; i++) {
    int len = snprintf(paths[i], sizeof paths[i], "%s/T_%d.dat", table_dir, i);
    if (len < 0 || (size_t) len >= sizeof paths[i]) {
      fprintf(stderr, "table directory name too long: %s\n", table_dir);
      return 1;
    }
    filenames[i] = paths[i];
  }

  init_diff_tables();

  printf("\nWARNING : this program will try to allocate something like 36GB of RAM and put stuff in it.\n");
  printf("Then, it will write about 80GB of junk to your hard drive.\n");

  printf("\nRun it at your own risk. You've been warned. And don't forget to remove the junk afterwards...\n\n");

  // initialize random internal state of the MAC
  for(i=0; i<4; i++)
    for(j=0; j<4; j++)
      P_1[i][j] = P_2[i][j] = rand() & 0xff;

  // inject known difference in input byte.
  P_2[0][0] = P_1[0][0] ^ 0x01;

  for(i=0; i<4; i++)
    for(j=0; j<4; j++) {
      C_1[i][j] = P_1[i][j];
      C_2[i][j] = P_2[i][j];
    }

  // 4 full rounds, no key addition
  for(r=0; r<4; r++) {

    if (r == 2) {
      printf("\nX_2 : \n");
      print_state(C_1,4);

      // store delta_X_2 for cheating
      for(i=0; i<4; i++)
        for(j=0; j<4; j++)
          delta_X_2[i][j] = C_1[i][j] ^ C_2[i][j];
    }

    if (r == 3) { // store X_3 for cheating
      for(i=0; i<4; i++)
        for(j=0; j<4; j++)
          X_3[i][j] = C_1[i][j];

      printf("X_3 : \n");
      print_state(X_3, 4);

      printf("\nX'_3 : \n");
      print_state(C_2, 4);
    }

    Substitution(C_1,S,4);          Substitution(C_2,S,4);
    ShiftRow(C_1,0,4);              ShiftRow(C_2,0,4);
    MixColumn(C_1,4);               MixColumn(C_2,4);
  }

  for(i=0; i<4; i++)
    for(j=0; j<4; j++)
      delta_o[i][j] = C_1[i][j] ^ C_2[i][j];

  printf("\ndelta X_2 : \n");
  print_state(delta_X_2, 4);

  printf("\n consequence: (alpha,beta,gamma,delta) = (%02x,%02x,%02x,%02x)\n\n", delta_X_2[1][0], delta_X_2[0][1], delta_X_2[0][2], delta_X_2[2][3]);

  // data initialized ; run the attack !
  StateRecovery_gentables<0>(0x01, delta_o, filenames[0]);
  StateRecovery_gentables<1>(0x01, delta_o, filenames[1]);
  StateRecovery_gentables<2>(0x01, delta_o, filenames[2]);
  StateRecovery_gentables<3>(0x01, delta_o, filenames[3]);

  printf("\n\n");

  StateRecovery_final(0x01, delta_o, filenames);

  printf("\n");
  return 0;
}
