/*
 * Decompiled with CFR 0.152.
 */
package net.savignano.thirdparty.org.bouncycastle.pqc.crypto.bike;

import java.security.SecureRandom;
import net.savignano.thirdparty.org.bouncycastle.crypto.digests.SHA3Digest;
import net.savignano.thirdparty.org.bouncycastle.crypto.digests.SHAKEDigest;
import net.savignano.thirdparty.org.bouncycastle.pqc.crypto.bike.BIKERing;
import net.savignano.thirdparty.org.bouncycastle.pqc.crypto.bike.BIKEUtils;
import net.savignano.thirdparty.org.bouncycastle.util.Arrays;
import net.savignano.thirdparty.org.bouncycastle.util.Bytes;

/**
 * Core routines of the BIKE key-encapsulation mechanism: key-pair generation,
 * encapsulation, and decapsulation driven by an iterative bit-flipping decoder
 * that tracks "black" and "gray" position masks.
 *
 * <p>All polynomial arithmetic is delegated to {@link BIKERing}; random
 * fixed-weight vector generation and bit/byte conversions are delegated to
 * {@link BIKEUtils}. The instance holds only immutable parameters.
 */
class BIKEEngine {
    /** Ring size parameter; individual vectors are handled as {@code r} bits. */
    private final int r;
    /** Row-weight parameter; only {@code w / 2} ({@link #hw}) is used afterwards. */
    private final int w;
    /** {@code w / 2}: number of entries in each compact key representation. */
    private final int hw;
    /** Weight argument handed to {@code BIKEUtils.generateRandomByteArray} in {@link #functionH}. */
    private final int t;
    /** Total number of decoder iterations (the first one is the black/gray iteration). */
    private final int nbIter;
    /** Offset subtracted from the threshold to obtain the "gray" threshold. */
    private final int tau;
    private final BIKERing bikeRing;
    /** Session-key / message length in bytes ({@code l / 8}). */
    private final int L_BYTE;
    /** Number of bytes needed to hold {@code r} bits. */
    private final int R_BYTE;
    /** Number of bytes needed to hold {@code 2 * r} bits. */
    private final int R2_BYTE;

    /**
     * Creates an engine for one parameter set.
     *
     * @param r      ring size in bits; must be one of the values accepted by
     *               {@link #threshold(int, int)} for decapsulation to work
     * @param w      row weight; halved to obtain the per-block weight
     * @param t      weight argument used when expanding a message into an error vector
     * @param l      session-key length in bits (stored as {@code l / 8} bytes)
     * @param nbIter number of decoder iterations
     * @param tau    gap between the "black" and "gray" decoder thresholds
     */
    public BIKEEngine(int r, int w, int t, int l, int nbIter, int tau) {
        this.r = r;
        this.w = w;
        this.t = t;
        this.nbIter = nbIter;
        this.tau = tau;
        this.hw = this.w / 2;
        this.L_BYTE = l / 8;
        this.R_BYTE = (r + 7) >>> 3;
        this.R2_BYTE = (2 * r + 7) >>> 3;
        this.bikeRing = new BIKERing(r);
    }

    /**
     * @return the session-key length in bytes
     */
    public int getSessionKeySize() {
        return this.L_BYTE;
    }

    /**
     * Expands {@code seed} through SHAKE-256 into a {@code 2 * R_BYTE}-byte buffer.
     * The buffer is filled by {@code BIKEUtils.generateRandomByteArray} with
     * size {@code 2 * r} and weight {@code t} (presumably a fixed-weight bit
     * vector; confirm against BIKEUtils).
     */
    private byte[] functionH(byte[] seed) {
        byte[] res = new byte[2 * this.R_BYTE];
        SHAKEDigest digest = new SHAKEDigest(256);
        digest.update(seed, 0, seed.length);
        BIKEUtils.generateRandomByteArray(res, 2 * this.r, this.t, digest);
        return res;
    }

    /**
     * Writes the first {@code L_BYTE} bytes of SHA3-384({@code e0 || e1}) into {@code result}.
     */
    private void functionL(byte[] e0, byte[] e1, byte[] result) {
        byte[] hashRes = new byte[48];
        SHA3Digest digest = new SHA3Digest(384);
        digest.update(e0, 0, e0.length);
        digest.update(e1, 0, e1.length);
        digest.doFinal(hashRes, 0);
        System.arraycopy(hashRes, 0, result, 0, this.L_BYTE);
    }

    /**
     * Writes the first {@code L_BYTE} bytes of SHA3-384({@code m || c0 || c1}) into {@code result}.
     */
    private void functionK(byte[] m, byte[] c0, byte[] c1, byte[] result) {
        byte[] hashRes = new byte[48];
        SHA3Digest digest = new SHA3Digest(384);
        digest.update(m, 0, m.length);
        digest.update(c0, 0, c0.length);
        digest.update(c1, 0, c1.length);
        digest.doFinal(hashRes, 0);
        System.arraycopy(hashRes, 0, result, 0, this.L_BYTE);
    }

    /**
     * Generates a key pair into the caller-supplied buffers.
     *
     * <p>64 random bytes are drawn; the first {@code L_BYTE} seed a SHAKE-256
     * stream from which {@code h0} and {@code h1} are generated, and the bytes
     * starting at offset {@code L_BYTE} are copied into {@code sigma}.
     *
     * @param h0     output: first private-key block
     * @param h1     output: second private-key block
     * @param sigma  output: secret value used by {@link #decaps} on rejection;
     *               {@code sigma.length} bytes are copied, so it must not exceed
     *               {@code 64 - L_BYTE}
     * @param h      output: public key, the encoding of {@code h0^-1 * h1} in the ring
     * @param random randomness source
     */
    public void genKeyPair(byte[] h0, byte[] h1, byte[] sigma, byte[] h, SecureRandom random) {
        byte[] seeds = new byte[64];
        random.nextBytes(seeds);

        SHAKEDigest digest = new SHAKEDigest(256);
        digest.update(seeds, 0, this.L_BYTE);
        // Both blocks are drawn from the same digest stream, in this order.
        BIKEUtils.generateRandomByteArray(h0, this.r, this.hw, digest);
        BIKEUtils.generateRandomByteArray(h1, this.r, this.hw, digest);

        long[] h0Element = this.bikeRing.create();
        long[] h1Element = this.bikeRing.create();
        this.bikeRing.decodeBytes(h0, h0Element);
        this.bikeRing.decodeBytes(h1, h1Element);

        // h = h0^-1 * h1
        long[] t = this.bikeRing.create();
        this.bikeRing.inv(h0Element, t);
        this.bikeRing.multiply(t, h1Element, t);
        this.bikeRing.encodeBytes(t, h);

        System.arraycopy(seeds, this.L_BYTE, sigma, 0, sigma.length);
    }

    /**
     * Encapsulates a fresh session key under public key {@code h}.
     *
     * @param c0     output: first ciphertext component, the encoding of {@code e0 + e1 * h}
     * @param c1     output: second ciphertext component, {@code L(e0, e1) XOR m}
     * @param k      output: session key, {@code K(m, c0, c1)}
     * @param h      public key
     * @param random randomness source for the {@code L_BYTE}-byte message {@code m}
     */
    public void encaps(byte[] c0, byte[] c1, byte[] k, byte[] h, SecureRandom random) {
        byte[] m = new byte[this.L_BYTE];
        random.nextBytes(m);

        // Derive the error vector from m and split it into its two r-bit halves.
        byte[] eBytes = this.functionH(m);
        byte[] e0Bytes = new byte[this.R_BYTE];
        byte[] e1Bytes = new byte[this.R_BYTE];
        this.splitEBytes(eBytes, e0Bytes, e1Bytes);

        long[] e0Element = this.bikeRing.create();
        long[] e1Element = this.bikeRing.create();
        this.bikeRing.decodeBytes(e0Bytes, e0Element);
        this.bikeRing.decodeBytes(e1Bytes, e1Element);

        // c0 = e0 + e1 * h
        long[] t = this.bikeRing.create();
        this.bikeRing.decodeBytes(h, t);
        this.bikeRing.multiply(t, e1Element, t);
        this.bikeRing.add(t, e0Element, t);
        this.bikeRing.encodeBytes(t, c0);

        // c1 = L(e0, e1) XOR m
        this.functionL(e0Bytes, e1Bytes, c1);
        Bytes.xorTo(this.L_BYTE, m, c1);

        this.functionK(m, c0, c1, k);
    }

    /**
     * Decapsulates {@code (c0, c1)} with private key {@code (h0, h1, sigma)}.
     *
     * <p>The decoded error vector is re-derived from the recovered message and
     * compared with the decoder output. On a match the key is
     * {@code K(m', c0, c1)}; otherwise, or if the decoder fails, the key is
     * {@code K(sigma, c0, c1)}.
     *
     * @param k     output: session key
     * @param h0    first private-key block
     * @param h1    second private-key block
     * @param sigma secret rejection value from key generation
     * @param c0    first ciphertext component
     * @param c1    second ciphertext component
     */
    public void decaps(byte[] k, byte[] h0, byte[] h1, byte[] sigma, byte[] c0, byte[] c1) {
        int[] h0Compact = new int[this.hw];
        int[] h1Compact = new int[this.hw];
        this.convertToCompact(h0Compact, h0);
        this.convertToCompact(h1Compact, h1);

        byte[] syndrome = this.computeSyndrome(c0, h0);
        byte[] ePrimeBits = this.BGFDecoder(syndrome, h0Compact, h1Compact);

        // BGFDecoder signals failure with null. Previously that null was passed
        // straight to the bit-packing helper. A failed decode is now treated as
        // a rejection: continue with an all-zero error vector so the same
        // hashing steps run, and force the sigma branch below.
        boolean decoded = ePrimeBits != null;
        if (!decoded) {
            ePrimeBits = new byte[2 * this.r];
        }

        byte[] ePrimeBytes = new byte[2 * this.R_BYTE];
        BIKEUtils.fromBitArrayToByteArray(ePrimeBytes, ePrimeBits, 0, 2 * this.r);

        byte[] e0Bytes = new byte[this.R_BYTE];
        byte[] e1Bytes = new byte[this.R_BYTE];
        this.splitEBytes(ePrimeBytes, e0Bytes, e1Bytes);

        // m' = c1 XOR L(e0', e1')
        byte[] mPrime = new byte[this.L_BYTE];
        this.functionL(e0Bytes, e1Bytes, mPrime);
        Bytes.xorTo(this.L_BYTE, c1, mPrime);

        // Re-derive the error vector from m' and compare the first R2_BYTE bytes.
        byte[] wlist = this.functionH(mPrime);
        boolean matches = Arrays.areEqual(ePrimeBytes, 0, this.R2_BYTE, wlist, 0, this.R2_BYTE);

        if (decoded && matches) {
            this.functionK(mPrime, c0, c1, k);
        } else {
            this.functionK(sigma, c0, c1, k);
        }
    }

    /**
     * Computes {@code c0 * h0} in the ring and returns it via
     * {@code BIKERing.encodeBitsTransposed}. The decoder indexes the result
     * with positions in {@code [0, r)}, one value per byte (exact layout is
     * defined by BIKERing).
     */
    private byte[] computeSyndrome(byte[] c0, byte[] h0) {
        long[] t = this.bikeRing.create();
        long[] u = this.bikeRing.create();
        this.bikeRing.decodeBytes(c0, t);
        this.bikeRing.decodeBytes(h0, u);
        this.bikeRing.multiply(t, u, t);
        return this.bikeRing.encodeBitsTransposed(t);
    }

    /**
     * Runs the bit-flipping decoder on syndrome {@code s}, which is modified in place.
     *
     * <p>The first iteration is a full pass ({@link #BFIter}) followed by two
     * masked passes over the "black" and "gray" positions; the remaining
     * {@code nbIter - 1} iterations are plain passes ({@link #BFIter2}).
     *
     * @return the error vector as {@code 2 * r} bytes (one bit per byte) if the
     *         syndrome's Hamming weight reached zero, otherwise {@code null}
     */
    private byte[] BGFDecoder(byte[] s, int[] h0Compact, int[] h1Compact) {
        byte[] e = new byte[2 * this.r];
        int[] h0CompactCol = this.getColumnFromCompactVersion(h0Compact);
        int[] h1CompactCol = this.getColumnFromCompactVersion(h1Compact);
        byte[] black = new byte[2 * this.r];
        byte[] ctrs = new byte[this.r];
        byte[] gray = new byte[2 * this.r];

        int T = this.threshold(BIKEUtils.getHammingWeight(s), this.r);
        this.BFIter(s, e, T, h0Compact, h1Compact, h0CompactCol, h1CompactCol, black, gray, ctrs);

        int maskedThreshold = (this.hw + 1) / 2 + 1;
        this.BFMaskedIter(s, e, black, maskedThreshold, h0Compact, h1Compact, h0CompactCol, h1CompactCol);
        this.BFMaskedIter(s, e, gray, maskedThreshold, h0Compact, h1Compact, h0CompactCol, h1CompactCol);

        for (int i = 1; i < this.nbIter; ++i) {
            // Retained from the original; 'black' is not read again in this method.
            Arrays.fill(black, (byte)0);
            T = this.threshold(BIKEUtils.getHammingWeight(s), this.r);
            this.BFIter2(s, e, T, h0Compact, h1Compact, h0CompactCol, h1CompactCol, ctrs);
        }

        if (BIKEUtils.getHammingWeight(s) == 0) {
            return e;
        }
        return null;
    }

    /**
     * Branch-free comparison: returns 1 when {@code count >= threshold}, else 0.
     * Relies on the sign bit of {@code count - threshold}; callers pass counts
     * in {@code [0, 255]}, so the subtraction does not overflow for the
     * thresholds used here.
     */
    private static int atLeastBit(int count, int threshold) {
        return ((count - threshold) >> 31) + 1;
    }

    /**
     * First decoder pass. For every position the counter is compared with
     * {@code T} ("black") and {@code T - tau} ("gray"). Black positions are
     * flipped in {@code e}, both masks are recorded, and finally the syndrome
     * is updated for every black position.
     *
     * <p>Counter index {@code j} of a block maps to error index
     * {@code (r - j) mod r} within that block, while the masks are stored at
     * index {@code j} (offset by {@code r} for the second block).
     */
    private void BFIter(byte[] s, byte[] e, int T, int[] h0Compact, int[] h1Compact, int[] h0CompactCol, int[] h1CompactCol, byte[] black, byte[] gray, byte[] ctrs) {
        int grayT = T - this.tau;

        // ---- first block ----
        this.ctrAll(h0CompactCol, s, ctrs);
        int count = ctrs[0] & 0xFF;
        int blackBit = atLeastBit(count, T);
        int grayBit = atLeastBit(count, grayT);
        e[0] ^= (byte)blackBit;
        black[0] = (byte)blackBit;
        gray[0] = (byte)grayBit;
        for (int j = 1; j < this.r; ++j) {
            count = ctrs[j] & 0xFF;
            blackBit = atLeastBit(count, T);
            grayBit = atLeastBit(count, grayT);
            e[this.r - j] ^= (byte)blackBit;
            black[j] = (byte)blackBit;
            gray[j] = (byte)grayBit;
        }

        // ---- second block ----
        this.ctrAll(h1CompactCol, s, ctrs);
        count = ctrs[0] & 0xFF;
        blackBit = atLeastBit(count, T);
        grayBit = atLeastBit(count, grayT);
        e[this.r] ^= (byte)blackBit;
        black[this.r] = (byte)blackBit;
        gray[this.r] = (byte)grayBit;
        for (int j = 1; j < this.r; ++j) {
            count = ctrs[j] & 0xFF;
            blackBit = atLeastBit(count, T);
            grayBit = atLeastBit(count, grayT);
            e[this.r + this.r - j] ^= (byte)blackBit;
            black[this.r + j] = (byte)blackBit;
            gray[this.r + j] = (byte)grayBit;
        }

        // Syndrome update is deferred so that all counters above are computed
        // from the syndrome as it was at the start of the pass.
        for (int i = 0; i < 2 * this.r; ++i) {
            this.recomputeSyndrome(s, i, h0Compact, h1Compact, black[i] != 0);
        }
    }

    /**
     * Plain decoder pass: flips every position whose counter is at least
     * {@code T}, then updates the syndrome for the flipped positions. Uses the
     * same counter-index to error-index mapping as {@link #BFIter}.
     */
    private void BFIter2(byte[] s, byte[] e, int T, int[] h0Compact, int[] h1Compact, int[] h0CompactCol, int[] h1CompactCol, byte[] ctrs) {
        int[] updatedIndices = new int[2 * this.r];

        // ---- first block ----
        this.ctrAll(h0CompactCol, s, ctrs);
        int flip = atLeastBit(ctrs[0] & 0xFF, T);
        e[0] ^= (byte)flip;
        updatedIndices[0] = flip;
        for (int j = 1; j < this.r; ++j) {
            flip = atLeastBit(ctrs[j] & 0xFF, T);
            e[this.r - j] ^= (byte)flip;
            updatedIndices[j] = flip;
        }

        // ---- second block ----
        this.ctrAll(h1CompactCol, s, ctrs);
        flip = atLeastBit(ctrs[0] & 0xFF, T);
        e[this.r] ^= (byte)flip;
        updatedIndices[this.r] = flip;
        for (int j = 1; j < this.r; ++j) {
            flip = atLeastBit(ctrs[j] & 0xFF, T);
            e[this.r + this.r - j] ^= (byte)flip;
            updatedIndices[this.r + j] = flip;
        }

        for (int i = 0; i < 2 * this.r; ++i) {
            this.recomputeSyndrome(s, i, h0Compact, h1Compact, updatedIndices[i] == 1);
        }
    }

    /**
     * Masked decoder pass: only positions with {@code mask[i] == 1} are
     * examined. Each such position is flipped when its counter is at least
     * {@code T}; the syndrome is updated afterwards for the flipped positions.
     */
    private void BFMaskedIter(byte[] s, byte[] e, byte[] mask, int T, int[] h0Compact, int[] h1Compact, int[] h0CompactCol, int[] h1CompactCol) {
        int[] updatedIndices = new int[2 * this.r];

        for (int j = 0; j < this.r; ++j) {
            if (mask[j] != 1) {
                continue;
            }
            boolean isCtrGtEqT = this.ctr(h0CompactCol, s, j) >= T;
            this.updateNewErrorIndex(e, j, isCtrGtEqT);
            updatedIndices[j] = isCtrGtEqT ? 1 : 0;
        }

        for (int j = 0; j < this.r; ++j) {
            if (mask[this.r + j] != 1) {
                continue;
            }
            boolean isCtrGtEqT = this.ctr(h1CompactCol, s, j) >= T;
            this.updateNewErrorIndex(e, this.r + j, isCtrGtEqT);
            updatedIndices[this.r + j] = isCtrGtEqT ? 1 : 0;
        }

        for (int i = 0; i < 2 * this.r; ++i) {
            this.recomputeSyndrome(s, i, h0Compact, h1Compact, updatedIndices[i] == 1);
        }
    }

    /**
     * Returns the flip threshold for a syndrome of the given Hamming weight.
     * The affine coefficients and the minimum are hard-coded per supported
     * ring size.
     *
     * @throws IllegalArgumentException if {@code r} is not 12323, 24659 or 40973
     */
    private int threshold(int hammingWeight, int r) {
        switch (r) {
            case 12323: {
                return BIKEEngine.thresholdFromParameters(hammingWeight, 0.0069722, 13.53, 36);
            }
            case 24659: {
                return BIKEEngine.thresholdFromParameters(hammingWeight, 0.005265, 15.2588, 52);
            }
            case 40973: {
                return BIKEEngine.thresholdFromParameters(hammingWeight, 0.00402312, 17.8785, 69);
            }
        }
        throw new IllegalArgumentException("unsupported ring size r = " + r);
    }

    /**
     * @return {@code max(min, floor(dm * hammingWeight + da))}
     */
    private static int thresholdFromParameters(int hammingWeight, double dm, double da, int min) {
        return Math.max(min, (int)Math.floor(dm * (double)hammingWeight + da));
    }

    /**
     * Counter for a single position {@code j}: the sum over all {@code hw}
     * entries of {@code s[(hCompactCol[i] + j) mod r]}. The modular reduction
     * is branch-free and assumes {@code hCompactCol[i] + j < 2 * r}. The main
     * loop is unrolled four times, with a scalar tail.
     */
    private int ctr(int[] hCompactCol, byte[] s, int j) {
        int count = 0;
        int i = 0;
        int limit = this.hw - 4;
        for (; i <= limit; i += 4) {
            int sPos0 = hCompactCol[i + 0] + j - this.r;
            int sPos1 = hCompactCol[i + 1] + j - this.r;
            int sPos2 = hCompactCol[i + 2] + j - this.r;
            int sPos3 = hCompactCol[i + 3] + j - this.r;
            // Add r back when the difference went negative (sign mask is all ones).
            sPos0 += (sPos0 >> 31) & this.r;
            sPos1 += (sPos1 >> 31) & this.r;
            sPos2 += (sPos2 >> 31) & this.r;
            sPos3 += (sPos3 >> 31) & this.r;
            count += s[sPos0] & 0xFF;
            count += s[sPos1] & 0xFF;
            count += s[sPos2] & 0xFF;
            count += s[sPos3] & 0xFF;
        }
        for (; i < this.hw; ++i) {
            int sPos = hCompactCol[i] + j - this.r;
            sPos += (sPos >> 31) & this.r;
            count += s[sPos] & 0xFF;
        }
        return count;
    }

    /**
     * Computes all {@code r} counters at once:
     * {@code ctrs[j] = sum_i s[(hCompactCol[i] + j) mod r]}.
     *
     * <p>The first column initialises {@code ctrs} with a rotated copy of
     * {@code s}; each further column is added as two contiguous runs (before
     * and after the wrap-around), each unrolled four times. Sums are kept in
     * bytes and therefore wrap modulo 256; callers read them with
     * {@code & 0xFF}.
     */
    private void ctrAll(int[] hCompactCol, byte[] s, byte[] ctrs) {
        int firstCol = hCompactCol[0];
        int firstNeg = this.r - firstCol;
        System.arraycopy(s, firstCol, ctrs, 0, firstNeg);
        System.arraycopy(s, 0, ctrs, firstNeg, firstCol);

        for (int i = 1; i < this.hw; ++i) {
            int col = hCompactCol[i];
            int neg = this.r - col;

            // Run 1: ctrs[0 .. neg) += s[col .. r)
            int j = 0;
            int jLimit = neg - 4;
            for (; j <= jLimit; j += 4) {
                ctrs[j + 0] += s[col + j + 0] & 0xFF;
                ctrs[j + 1] += s[col + j + 1] & 0xFF;
                ctrs[j + 2] += s[col + j + 2] & 0xFF;
                ctrs[j + 3] += s[col + j + 3] & 0xFF;
            }
            for (; j < neg; ++j) {
                ctrs[j] += s[col + j] & 0xFF;
            }

            // Run 2: ctrs[neg .. r) += s[0 .. col)
            int k = neg;
            int kLimit = this.r - 4;
            for (; k <= kLimit; k += 4) {
                ctrs[k + 0] += s[k + 0 - neg] & 0xFF;
                ctrs[k + 1] += s[k + 1 - neg] & 0xFF;
                ctrs[k + 2] += s[k + 2 - neg] & 0xFF;
                ctrs[k + 3] += s[k + 3 - neg] & 0xFF;
            }
            for (; k < this.r; ++k) {
                ctrs[k] += s[k - neg] & 0xFF;
            }
        }
    }

    /**
     * Records, in ascending order, the indices of the set bits among the first
     * {@code r} bits of {@code h} (least-significant bit of each byte first).
     *
     * <p>The store is a branch-free select: when the bit is set the slot
     * receives the bit index, otherwise the slot keeps its value. The write
     * cursor advances only on set bits and wraps modulo {@code hw}.
     */
    private void convertToCompact(int[] compactVersion, byte[] h) {
        int count = 0;
        for (int i = 0; i < this.R_BYTE; ++i) {
            for (int j = 0; j < 8 && i * 8 + j != this.r; ++j) {
                int mask = (h[i] >> j) & 1;
                compactVersion[count] = ((i * 8 + j) & -mask) | (compactVersion[count] & ~(-mask));
                count = (count + mask) % this.hw;
            }
        }
    }

    /**
     * Builds the "column" form of a compact representation: every entry
     * {@code x} becomes {@code r - x} and the order is reversed. If the first
     * compact entry is 0 it stays 0 at index 0 and only the remaining entries
     * are reversed and mapped.
     */
    private int[] getColumnFromCompactVersion(int[] hCompact) {
        int[] hCompactColumn = new int[this.hw];
        if (hCompact[0] == 0) {
            hCompactColumn[0] = 0;
            for (int i = 1; i < this.hw; ++i) {
                hCompactColumn[i] = this.r - hCompact[this.hw - i];
            }
        } else {
            for (int i = 0; i < this.hw; ++i) {
                hCompactColumn[i] = this.r - hCompact[this.hw - 1 - i];
            }
        }
        return hCompactColumn;
    }

    /**
     * XORs {@code isOne} into the {@code hw} syndrome entries affected by
     * position {@code index}: entry {@code (pos - hCompact[i]) mod r}, where
     * {@code pos = index} with {@code h0Compact} for {@code index < r}, and
     * {@code pos = index - r} with {@code h1Compact} otherwise.
     *
     * <p>When {@code isOne} is false the XOR operand is 0, so the syndrome is
     * left unchanged but the same entries are still touched.
     */
    private void recomputeSyndrome(byte[] syndrome, int index, int[] h0Compact, int[] h1Compact, boolean isOne) {
        byte bit = isOne ? (byte)1 : 0;
        if (index < this.r) {
            for (int i = 0; i < this.hw; ++i) {
                if (h0Compact[i] <= index) {
                    syndrome[index - h0Compact[i]] ^= bit;
                } else {
                    syndrome[this.r + index - h0Compact[i]] ^= bit;
                }
            }
        } else {
            for (int i = 0; i < this.hw; ++i) {
                if (h1Compact[i] <= index - this.r) {
                    syndrome[index - this.r - h1Compact[i]] ^= bit;
                } else {
                    syndrome[this.r - h1Compact[i] + (index - this.r)] ^= bit;
                }
            }
        }
    }

    /**
     * Splits a packed {@code 2 * r}-bit vector {@code e} into two
     * {@code R_BYTE}-byte halves.
     *
     * <p>{@code e0} receives the first {@code r} bits (the last byte is masked
     * to its low {@code r mod 8} bits). {@code e1} receives the following
     * bits, realigned to bit 0 by shifting across byte boundaries.
     */
    private void splitEBytes(byte[] e, byte[] e0, byte[] e1) {
        int partial = this.r & 7;
        System.arraycopy(e, 0, e0, 0, this.R_BYTE - 1);

        // The byte that straddles the boundary between the two halves.
        byte split = e[this.R_BYTE - 1];
        byte mask = (byte)(-1 << partial);
        e0[this.R_BYTE - 1] = (byte)(split & ~mask);

        // Carry the high bits of each byte into the low bits of the next output byte.
        byte c = (byte)(split & mask);
        for (int i = 0; i < this.R_BYTE; ++i) {
            byte next = e[this.R_BYTE + i];
            e1[i] = (byte)((next << (8 - partial)) | ((c & 0xFF) >>> partial));
            c = next;
        }
    }

    /**
     * XORs {@code isOne} into the error vector at the position that
     * corresponds to counter index {@code index}: indices 0 and {@code r} map
     * to themselves, {@code index < r} maps to {@code r - index}, and
     * {@code index > r} maps to {@code 3 * r - index}.
     */
    private void updateNewErrorIndex(byte[] e, int index, boolean isOne) {
        int newIndex = index;
        if (index != 0 && index != this.r) {
            newIndex = index > this.r ? 2 * this.r - index + this.r : this.r - index;
        }
        e[newIndex] ^= (isOne ? (byte)1 : 0);
    }
}

