/*
 * aes_cbc.c
 *
 * AES Cipher Block Chaining Mode
 *
 * David A. McGrew
 * Cisco Systems, Inc.
 */

/*
 *	
 * Copyright (c) 2001-2006, Cisco Systems, Inc.
 * All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 
 *   Redistributions of source code must retain the above copyright
 *   notice, this list of conditions and the following disclaimer.
 * 
 *   Redistributions in binary form must reproduce the above
 *   copyright notice, this list of conditions and the following
 *   disclaimer in the documentation and/or other materials provided
 *   with the distribution.
 * 
 *   Neither the name of the Cisco Systems, Inc. nor the names of its
 *   contributors may be used to endorse or promote products derived
 *   from this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
 * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
 * OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 */


#include "aes_cbc.h"
#include "alloc.h"

debug_module_t mod_aes_cbc = {
  0,                 /* debugging is off by default */
  "aes cbc"          /* printable module name       */
};



err_status_t
aes_cbc_alloc(cipher_t **c, int key_len) {
  extern cipher_type_t aes_cbc;
  uint8_t *pointer;
  int tmp;

  debug_print(mod_aes_cbc, 
	      "allocating cipher with key length %d", key_len);

  if (key_len != 16)
    return err_status_bad_param;
  
  /* allocate memory a cipher of type aes_icm */
  tmp = (sizeof(aes_cbc_ctx_t) + sizeof(cipher_t));
  pointer = (uint8_t*)crypto_alloc(tmp);
  if (pointer == NULL) 
    return err_status_alloc_fail;

  /* set pointers */
  *c = (cipher_t *)pointer;
  (*c)->type = &aes_cbc;
  (*c)->state = pointer + sizeof(cipher_t);

  /* increment ref_count */
  aes_cbc.ref_count++;

  /* set key size        */
  (*c)->key_len = key_len;

  return err_status_ok;  
}

err_status_t
aes_cbc_dealloc(cipher_t *c) {
  extern cipher_type_t aes_cbc;

  /* zeroize entire state*/
  octet_string_set_to_zero((uint8_t *)c, 
			   sizeof(aes_cbc_ctx_t) + sizeof(cipher_t));

  /* free memory */
  crypto_free(c);

  /* decrement ref_count */
  aes_cbc.ref_count--;
  
  return err_status_ok;  
}

err_status_t
aes_cbc_context_init(aes_cbc_ctx_t *c, const uint8_t *key, 
		     cipher_direction_t dir) {
  v128_t tmp_key;

  /* set tmp_key (for alignment) */
  v128_copy_octet_string(&tmp_key, key);

  debug_print(mod_aes_cbc, 
	      "key:  %s", v128_hex_string(&tmp_key)); 

  /* expand key for the appropriate direction */
  switch (dir) {
  case (direction_encrypt):
    aes_expand_encryption_key(&tmp_key, c->expanded_key);
    break;
  case (direction_decrypt):
    aes_expand_decryption_key(&tmp_key, c->expanded_key);
    break;
  default:
    return err_status_bad_param;
  }


  return err_status_ok;
}


err_status_t
aes_cbc_set_iv(aes_cbc_ctx_t *c, void *iv) {
  int i;
/*   v128_t *input = iv; */
  uint8_t *input = (uint8_t*) iv;
 
  /* set state and 'previous' block to iv */
  for (i=0; i < 16; i++) 
    c->previous.v8[i] = c->state.v8[i] = input[i];

  debug_print(mod_aes_cbc, "setting iv: %s", v128_hex_string(&c->state)); 

  return err_status_ok;
}

err_status_t
aes_cbc_encrypt(aes_cbc_ctx_t *c,
		unsigned char *data, 
		unsigned int *bytes_in_data) {
  int i;
  unsigned char *input  = data;   /* pointer to data being read    */
  unsigned char *output = data;   /* pointer to data being written */
  int bytes_to_encr = *bytes_in_data;

  /*
   * verify that we're 16-octet aligned
   */
  if (*bytes_in_data & 0xf) 
    return err_status_bad_param;

  /*
   * note that we assume that the initialization vector has already
   * been set, e.g. by calling aes_cbc_set_iv()
   */
  debug_print(mod_aes_cbc, "iv: %s", 
	      v128_hex_string(&c->state));
  
  /*
   * loop over plaintext blocks, exoring state into plaintext then
   * encrypting and writing to output
   */
  while (bytes_to_encr > 0) {
    
    /* exor plaintext into state */
    for (i=0; i < 16; i++)
      c->state.v8[i] ^= *input++;

    debug_print(mod_aes_cbc, "inblock:  %s", 
	      v128_hex_string(&c->state));

    aes_encrypt(&c->state, c->expanded_key);

    debug_print(mod_aes_cbc, "outblock: %s", 
	      v128_hex_string(&c->state));

    /* copy ciphertext to output */
    for (i=0; i < 16; i++)
      *output++ = c->state.v8[i];

    bytes_to_encr -= 16;
  }

  return err_status_ok;
}

err_status_t
aes_cbc_decrypt(aes_cbc_ctx_t *c,
		unsigned char *data, 
		unsigned int *bytes_in_data) {
  int i;
  v128_t state, previous;
  unsigned char *input  = data;   /* pointer to data being read    */
  unsigned char *output = data;   /* pointer to data being written */
  int bytes_to_encr = *bytes_in_data;
  uint8_t tmp;

  /*
   * verify that we're 16-octet aligned
   */
  if (*bytes_in_data & 0x0f)
    return err_status_bad_param;    

  /* set 'previous' block to iv*/
  for (i=0; i < 16; i++) {
    previous.v8[i] = c->previous.v8[i];
  }

  debug_print(mod_aes_cbc, "iv: %s", 
	      v128_hex_string(&previous));
  
  /*
   * loop over ciphertext blocks, decrypting then exoring with state
   * then writing plaintext to output
   */
  while (bytes_to_encr > 0) {
    
    /* set state to ciphertext input block */
    for (i=0; i < 16; i++) {
     state.v8[i] = *input++;
    }

    debug_print(mod_aes_cbc, "inblock:  %s", 
	      v128_hex_string(&state));
    
    /* decrypt state */
    aes_decrypt(&state, c->expanded_key);

    debug_print(mod_aes_cbc, "outblock: %s", 
	      v128_hex_string(&state));

    /* 
     * exor previous ciphertext block out of plaintext, and write new
     * plaintext block to output, while copying old ciphertext block
     * to the 'previous' block
     */
    for (i=0; i < 16; i++) {
      tmp = *output;
      *output++ = state.v8[i] ^ previous.v8[i];
      previous.v8[i] = tmp;
    }

    bytes_to_encr -= 16;
  }

  return err_status_ok;
}


err_status_t
aes_cbc_nist_encrypt(aes_cbc_ctx_t *c,
		     unsigned char *data, 
		     unsigned int *bytes_in_data) {
  int i;
  unsigned char *pad_start; 
  int num_pad_bytes;
  err_status_t status;

  /* 
   * determine the number of padding bytes that we need to add - 
   * this value is always between 1 and 16, inclusive.
   */
  num_pad_bytes = 16 - (*bytes_in_data & 0xf);
  pad_start = data;
  pad_start += *bytes_in_data;
  *pad_start++ = 0xa0;
  for (i=0; i < num_pad_bytes; i++) 
    *pad_start++ = 0x00;
   
  /* 
   * increment the data size 
   */
  *bytes_in_data += num_pad_bytes;  

  /*
   * now cbc encrypt the padded data 
   */
  status = aes_cbc_encrypt(c, data, bytes_in_data);
  if (status) 
    return status;

  return err_status_ok;
}


err_status_t
aes_cbc_nist_decrypt(aes_cbc_ctx_t *c,
		     unsigned char *data, 
		     unsigned int *bytes_in_data) {
  unsigned char *pad_end;
  int num_pad_bytes;
  err_status_t status;

  /*
   * cbc decrypt the padded data 
   */
  status = aes_cbc_decrypt(c, data, bytes_in_data);
  if (status) 
    return status;

  /*
   * determine the number of padding bytes in the decrypted plaintext
   * - this value is always between 1 and 16, inclusive.
   */
  num_pad_bytes = 1;
  pad_end = data + (*bytes_in_data - 1);
  while (*pad_end != 0xa0) {   /* note: should check padding correctness */
    pad_end--;
    num_pad_bytes++;
  }
  
  /* decrement data size */
  *bytes_in_data -= num_pad_bytes;  

  return err_status_ok;
}


char 
aes_cbc_description[] = "aes cipher block chaining (cbc) mode";

/*
 * Test case 0 is derived from FIPS 197 Appendix A; it uses an
 * all-zero IV, so that the first block encryption matches the test
 * case in that appendix.  This property provides a check of the base
 * AES encryption and decryption algorithms; if CBC fails on some
 * particular platform, then you should print out AES intermediate
 * data and compare with the detailed info provided in that appendix.
 *
 */


uint8_t aes_cbc_test_case_0_key[16] = {
  0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 
  0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f
};

uint8_t aes_cbc_test_case_0_plaintext[64] =  {
  0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
  0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff 
};

uint8_t aes_cbc_test_case_0_ciphertext[80] = {
  0x69, 0xc4, 0xe0, 0xd8, 0x6a, 0x7b, 0x04, 0x30, 
  0xd8, 0xcd, 0xb7, 0x80, 0x70, 0xb4, 0xc5, 0x5a,
  0x03, 0x35, 0xed, 0x27, 0x67, 0xf2, 0x6d, 0xf1, 
  0x64, 0x83, 0x2e, 0x23, 0x44, 0x38, 0x70, 0x8b

};

uint8_t aes_cbc_test_case_0_iv[16] = {
  0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
  0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
};


cipher_test_case_t aes_cbc_test_case_0 = {
  16,                                    /* octets in key            */
  aes_cbc_test_case_0_key,               /* key                      */
  aes_cbc_test_case_0_iv,                /* initialization vector    */
  16,                                    /* octets in plaintext      */
  aes_cbc_test_case_0_plaintext,         /* plaintext                */
  32,                                    /* octets in ciphertext     */
  aes_cbc_test_case_0_ciphertext,        /* ciphertext               */
  NULL                                   /* pointer to next testcase */
};


/*
 * this test case is taken directly from Appendix F.2 of NIST Special
 * Publication SP 800-38A
 */

uint8_t aes_cbc_test_case_1_key[16] = {
  0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6,
  0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c,
};

uint8_t aes_cbc_test_case_1_plaintext[64] =  {
  0x6b, 0xc1, 0xbe, 0xe2, 0x2e, 0x40, 0x9f, 0x96, 
  0xe9, 0x3d, 0x7e, 0x11, 0x73, 0x93, 0x17, 0x2a,
  0xae, 0x2d, 0x8a, 0x57, 0x1e, 0x03, 0xac, 0x9c, 
  0x9e, 0xb7, 0x6f, 0xac, 0x45, 0xaf, 0x8e, 0x51,
  0x30, 0xc8, 0x1c, 0x46, 0xa3, 0x5c, 0xe4, 0x11,
  0xe5, 0xfb, 0xc1, 0x19, 0x1a, 0x0a, 0x52, 0xef,
  0xf6, 0x9f, 0x24, 0x45, 0xdf, 0x4f, 0x9b, 0x17, 
  0xad, 0x2b, 0x41, 0x7b, 0xe6, 0x6c, 0x37, 0x10
};

uint8_t aes_cbc_test_case_1_ciphertext[80] = {
  0x76, 0x49, 0xab, 0xac, 0x81, 0x19, 0xb2, 0x46,
  0xce, 0xe9, 0x8e, 0x9b, 0x12, 0xe9, 0x19, 0x7d,
  0x50, 0x86, 0xcb, 0x9b, 0x50, 0x72, 0x19, 0xee,
  0x95, 0xdb, 0x11, 0x3a, 0x91, 0x76, 0x78, 0xb2,
  0x73, 0xbe, 0xd6, 0xb8, 0xe3, 0xc1, 0x74, 0x3b,
  0x71, 0x16, 0xe6, 0x9e, 0x22, 0x22, 0x95, 0x16, 
  0x3f, 0xf1, 0xca, 0xa1, 0x68, 0x1f, 0xac, 0x09, 
  0x12, 0x0e, 0xca, 0x30, 0x75, 0x86, 0xe1, 0xa7,
  0x39, 0x34, 0x07, 0x03, 0x36, 0xd0, 0x77, 0x99, 
  0xe0, 0xc4, 0x2f, 0xdd, 0xa8, 0xdf, 0x4c, 0xa3
};

uint8_t aes_cbc_test_case_1_iv[16] = {
  0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 
  0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f
};

cipher_test_case_t aes_cbc_test_case_1 = {
  16,                                    /* octets in key            */
  aes_cbc_test_case_1_key,               /* key                      */
  aes_cbc_test_case_1_iv,                /* initialization vector    */
  64,                                    /* octets in plaintext      */
  aes_cbc_test_case_1_plaintext,         /* plaintext                */
  80,                                    /* octets in ciphertext     */
  aes_cbc_test_case_1_ciphertext,        /* ciphertext               */
  &aes_cbc_test_case_0                    /* pointer to next testcase */
};

cipher_type_t aes_cbc = {
  (cipher_alloc_func_t)          aes_cbc_alloc,
  (cipher_dealloc_func_t)        aes_cbc_dealloc,  
  (cipher_init_func_t)           aes_cbc_context_init,
  (cipher_encrypt_func_t)        aes_cbc_nist_encrypt,
  (cipher_decrypt_func_t)        aes_cbc_nist_decrypt,
  (cipher_set_iv_func_t)         aes_cbc_set_iv,
  (char *)                       aes_cbc_description,
  (int)                          0,   /* instance count */
  (cipher_test_case_t *)        &aes_cbc_test_case_0,
  (debug_module_t *)            &mod_aes_cbc
};