diff --git a/include/net/tls.h b/include/net/tls.h index da616db48413..754b130672f0 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -202,6 +202,7 @@ struct cipher_context { char *iv; u16 rec_seq_size; char *rec_seq; + u16 aad_size; }; union tls_crypto_context { diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 9326c06c2ffe..7b6386f4c685 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -185,7 +185,7 @@ static int tls_do_decryption(struct sock *sk, int ret; aead_request_set_tfm(aead_req, ctx->aead_recv); - aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); + aead_request_set_ad(aead_req, tls_ctx->rx.aad_size); aead_request_set_crypt(aead_req, sgin, sgout, data_len + tls_ctx->rx.tag_size, (u8 *)iv_recv); @@ -289,12 +289,12 @@ static struct tls_rec *tls_get_rec(struct sock *sk) sg_init_table(rec->sg_aead_in, 2); sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, - sizeof(rec->aad_space)); + tls_ctx->tx.aad_size); sg_unmark_end(&rec->sg_aead_in[1]); sg_init_table(rec->sg_aead_out, 2); sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, - sizeof(rec->aad_space)); + tls_ctx->tx.aad_size); sg_unmark_end(&rec->sg_aead_out[1]); return rec; @@ -455,7 +455,7 @@ static int tls_do_encryption(struct sock *sk, msg_en->sg.curr = start; aead_request_set_tfm(aead_req, ctx->aead_send); - aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); + aead_request_set_ad(aead_req, tls_ctx->tx.aad_size); aead_request_set_crypt(aead_req, rec->sg_aead_in, rec->sg_aead_out, data_len, rec->iv_data); @@ -1317,7 +1317,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv); mem_size = aead_size + (nsg * sizeof(struct scatterlist)); - mem_size = mem_size + TLS_AAD_SPACE_SIZE; + mem_size = mem_size + tls_ctx->rx.aad_size; mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv); /* Allocate a single block of memory which contains @@ -1333,7 +1333,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, sgin = (struct scatterlist *)(mem + aead_size); sgout = sgin + n_sgin; aad = (u8 *)(sgout + n_sgout); - iv = aad + TLS_AAD_SPACE_SIZE; + iv = aad + tls_ctx->rx.aad_size; /* Prepare IV */ err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, @@ -1352,7 +1352,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, /* Prepare sgin */ sg_init_table(sgin, n_sgin); - sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE); + sg_set_buf(&sgin[0], aad, tls_ctx->rx.aad_size); err = skb_to_sgvec(skb, &sgin[1], rxm->offset + tls_ctx->rx.prepend_size, rxm->full_len - tls_ctx->rx.prepend_size); @@ -1364,7 +1364,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, if (n_sgout) { if (out_iov) { sg_init_table(sgout, n_sgout); - sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE); + sg_set_buf(&sgout[0], aad, tls_ctx->rx.aad_size); *chunk = 0; err = tls_setup_from_iter(sk, out_iov, data_len, @@ -2100,6 +2100,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) goto free_priv; } + cctx->aad_size = TLS_AAD_SPACE_SIZE; cctx->prepend_size = TLS_HEADER_SIZE + nonce_size; cctx->tag_size = tag_size; cctx->overhead_size = cctx->prepend_size + cctx->tag_size;