16#ifndef CRYPT_ML_KEM_LOCAL_H
17#define CRYPT_ML_KEM_LOCAL_H
18#include "crypt_mlkem.h"
19#include "sal_atomic.h"
20#include "crypt_local_types.h"
/*
 * Internal constants of the ML-KEM implementation.
 * NOTE(review): MLKEM_Q is not defined in this excerpt. The arithmetic notes
 * below assume it is the ML-KEM modulus 3329; confirm in crypt_mlkem.h.
 */
/* Value 128; by its name, half of the polynomial length N. It sizes the second
 * dimension of the mulCache arrays taken by the matrix/vector helpers. */
23#define MLKEM_N_HALF 128
/* 384 == (2 * MLKEM_N_HALF) * MLKEM_BITS_OF_Q / 8, i.e. 256 coefficients packed
 * at 12 bits each. Presumably the encoded size in bytes of one polynomial; confirm. */
24#define MLKEM_CIPHER_LEN 384
/* Seed length (presumably bytes). A context field elsewhere in this header holds
 * MLKEM_SEED_LEN * 2 of these. */
26#define MLKEM_SEED_LEN 32
/* Shared-key length (presumably bytes). */
27#define MLKEM_SHARED_KEY_LEN 32
/* Block sizes used by the PRF and by encoding. Units presumably bytes; usage is
 * not visible in this excerpt. */
28#define MLKEM_PRF_BLOCKSIZE 64
29#define MLKEM_ENCODE_BLOCKSIZE 32
/* With q == 3329: 3329 * (-3327) == -169 * 2^16 + 1, so this is q^-1 mod 2^16
 * (beta presumably 2^16). */
32#define MLKEM_Q_INV_BETA (-3327)
/* (q + 1) / 2, i.e. ceil(q / 2) for odd q. */
33#define MLKEM_Q_HALF ((MLKEM_Q + 1) / 2)
/* Bit width of q (3329 < 2^12). */
34#define MLKEM_BITS_OF_Q 12
/* With q == 3329: 128 * 3303 == 127 * 3329 + 1, so this is 128^-1 mod q.
 * Presumably the inverse-NTT scaling factor; confirm. */
35#define MLKEM_INVN 3303
/* Shift parameters l and alpha used by PlantardReduction in this header. */
40#define MLKEM_PLANTARD_L 16
41#define MLKEM_PLANTARD_ALPHA 3
/* Precomputed constants whose derivation cannot be checked from this excerpt.
 * NOTE(review): by name, a twiddle factor for the last NTT round and the inverse
 * of half the degree mod q, presumably in Plantard representation; confirm
 * against the generator. */
45#define MLKEM_LAST_ROUND_ZETA 2131356556
46#define MLKEM_HALF_DEGREE_INVERSE_MOD_Q (-33544352)
/*
 * Function-pointer type for a hash primitive used by the ML-KEM code.
 * Signature: (id, in, inLen, out, outLen) -> int32_t, where in is read-only
 * input of inLen bytes and out/outLen are writable.
 * NOTE(review): presumably
 *   - id selects the hash/XOF algorithm;
 *   - outLen is in/out (capacity on entry, bytes written on return);
 *   - the return value is a status code.
 * Confirm against the implementations assigned to this type.
 */
48typedef int32_t (*MlKemHashFunc)(uint32_t id,
const uint8_t *in, uint32_t inLen, uint8_t *out, uint32_t *outLen);
/*
 * Barrett reduction of a 32-bit value modulo MLKEM_Q, returned as int16_t.
 *
 * v = round(2^27 / MLKEM_Q) is a fixed-point reciprocal of q.
 * t = round(v * a / 2^27) is therefore an estimate of round(a / MLKEM_Q).
 *
 * NOTE(review): this excerpt shows neither the function's braces nor the statement
 * between the quotient estimate and the return. That statement is presumably
 * `t *= MLKEM_Q;`, which would make the result a - round(a / q) * q. Confirm
 * before relying on the output range.
 */
51static inline int16_t BarrettReduction(int32_t a)
    /* Integer division after adding q/2 rounds 2^27 / q to the nearest integer. */
53 const int32_t v = ((1 << 27) + MLKEM_Q / 2) / MLKEM_Q;
    /* v * a can exceed 32 bits for large |a|, hence the int64_t product.
     * Adding 2^26 before the >> 27 rounds to nearest instead of flooring.
     * For negative a this relies on arithmetic right shift of a negative int64_t,
     * which is implementation-defined in C. */
54 int32_t t = ((int64_t)v * a + (1 << 26)) >> 27;
56 return (int16_t)(a - t);
/*
 * Plantard-style modular reduction (by name), parameterized by MLKEM_PLANTARD_L
 * and MLKEM_PLANTARD_ALPHA.
 *
 * Visible steps on tmp:
 *   1. shift right by L;
 *   2. add 2^ALPHA, then multiply by MLKEM_Q;
 *   3. shift right by L again.
 *
 * NOTE(review): this excerpt does not show the declaration or initialization of
 * tmp (presumably derived from a), the return statement, or the braces. The exact
 * input contract is unknown from here, e.g. whether a must already be multiplied
 * by a Plantard-form constant, and so is the output range. The right shifts
 * presumably operate on a signed value, which is implementation-defined in C when
 * negative. Confirm against the full definition.
 */
59static inline int16_t PlantardReduction(int32_t a)
62 tmp >>= MLKEM_PLANTARD_L;
63 tmp = (tmp + (1 << MLKEM_PLANTARD_ALPHA)) * MLKEM_Q;
64 tmp >>= MLKEM_PLANTARD_L;
70 int16_t *matrix[MLKEM_K_MAX][MLKEM_K_MAX];
71 int16_t *vectorS[MLKEM_K_MAX];
72 int16_t *vectorE[MLKEM_K_MAX];
73 int16_t *vectorT[MLKEM_K_MAX];
84 uint32_t encapsKeyLen;
85 uint32_t decapsKeyLen;
101 CRYPT_ALGO_MLKEM_DK_FORMAT_TYPE dkFormat;
103 uint8_t seed[MLKEM_SEED_LEN * 2];
/*
 * Internal ML-KEM helpers implemented in the module's .c files.
 * Descriptions below are based on names and signatures only; the definitions are
 * not visible here.
 */
/* By name: decodes a serialized decapsulation key (dk, dkLen bytes; read-only)
 * into ctx. NOTE(review): return is presumably a status code; confirm. */
105int32_t MLKEM_DecodeDk(CRYPT_ML_KEM_Ctx *ctx,
const uint8_t *dk, uint32_t dkLen);
/* By name: decodes a serialized encapsulation key (ek, ekLen bytes; read-only)
 * into ctx. NOTE(review): return is presumably a status code; confirm. */
106int32_t MLKEM_DecodeEk(CRYPT_ML_KEM_Ctx *ctx,
const uint8_t *ek, uint32_t ekLen);
/* By name: forward NTT of polynomial a using the read-only table psi.
 * a is non-const, so it is presumably transformed in place.
 * NOTE(review): the length of a and the layout of psi (int32_t, presumably
 * Plantard-form twiddles) are not visible here. */
107void MLKEM_ComputNTT(int16_t *a,
const int32_t *psi);
/* By name: inverse NTT of polynomial a using the read-only table psi
 * (presumably in place, like MLKEM_ComputNTT). */
108void MLKEM_ComputINTT(int16_t *a,
const int32_t *psi);
/* By name: samples polyF from a centered binomial distribution with parameter
 * eta, driven by the bytes in buf.
 * NOTE(review): buf is not const-qualified although it is presumably only read.
 * If so, add const to both this prototype and the definition together. */
109void MLKEM_SamplePolyCBD(int16_t *polyF, uint8_t *buf, uint8_t eta);
/* By name: polyVecOut += transpose(matrix) * polyVec for a k x k matrix of
 * polynomials. mulCache holds up to MLKEM_K_MAX rows of MLKEM_N_HALF precomputed
 * values (presumably from MLKEM_ComputeMulCache). */
110void MLKEM_TransposeMatrixMulAdd(uint8_t k, int16_t **matrix, int16_t **polyVec, int16_t **polyVecOut,
111 const int16_t mulCache[MLKEM_K_MAX][MLKEM_N_HALF]);
/* By name: polyVecOut += matrix * polyVec (non-transposed), with the same
 * mulCache layout as MLKEM_TransposeMatrixMulAdd. */
112void MLKEM_MatrixMulAdd(uint8_t k, int16_t **matrix, int16_t **polyVec, int16_t **polyVecOut,
113 const int16_t mulCache[MLKEM_K_MAX][MLKEM_N_HALF]);
/* By name: polyOut += inner product of the k-polynomial vectors polyVec1 and
 * polyVec2, using the read-only table factor (presumably base-multiplication
 * twiddles). */
114void MLKEM_VectorInnerProductAdd(uint8_t k, int16_t **polyVec1, int16_t **polyVec2, int16_t *polyOut,
115 const int32_t *factor);
/* By name: same accumulation as MLKEM_VectorInnerProductAdd, but using a
 * precomputed mulCache instead of factor. */
116void MLKEM_VectorInnerProductAddUseCache(uint8_t k, int16_t **polyVec1, int16_t **polyVec2, int16_t *polyOut,
117 const int16_t mulCache[MLKEM_K_MAX][MLKEM_N_HALF]);
/* By name: fills output (one MLKEM_N_HALF-entry row per polynomial, k rows) with
 * a multiplication cache computed from the k polynomials in input and the table
 * factor. */
119void MLKEM_ComputeMulCache(uint8_t k, int16_t **input, int16_t output[MLKEM_K_MAX][MLKEM_N_HALF],
120 const int32_t *factor);
/* By name: internal key generation into ctx from seeds d and z (presumably
 * FIPS 203 ML-KEM.KeyGen_internal).
 * NOTE(review): d and z are not const-qualified although they are presumably only
 * read; change prototype and definition together if so. */
122int32_t MLKEM_KeyGenInternal(CRYPT_ML_KEM_Ctx *ctx, uint8_t *d, uint8_t *z);
124int32_t MLKEM_EncapsInternal(CRYPT_ML_KEM_Ctx *ctx, uint8_t *ct, uint32_t *ctLen, uint8_t *sk, uint32_t *skLen,
/*
 * By name: internal decapsulation (presumably FIPS 203 ML-KEM.Decaps_internal).
 * Inputs: the key material in ctx and ciphertext ct of ctLen bytes.
 * Output: sk, presumably the shared key.
 * NOTE(review):
 *   - skLen is a pointer, presumably in/out (capacity on entry, bytes written on
 *     return);
 *   - ct is not const-qualified although it is presumably only read;
 *   - the return value is presumably a status code.
 * Confirm all three in the definition.
 */
127int32_t MLKEM_DecapsInternal(CRYPT_ML_KEM_Ctx *ctx, uint8_t *ct, uint32_t ctLen, uint8_t *sk, uint32_t *skLen);