/* xsc256.c */
#define Main
#include <xsc256.h>

/* S-box inputs recorded by trial() from the all-zero-key reference run:
 * baseline[r][n] is the input to S-box n in round r (r = 0, 1). */
private semiword baseline[2][16];
/* Output path for savesbox(), taken from argv[1] in main(). */
export int8 *filename;
/* Best S-box found so far. Despite the semiword*** type, main() allocates
 * it as a flat 16*16*16*2-byte buffer and savesbox() writes it as bytes. */
export semiword ***best;

/*
 * Allocate and initialise a cipher state from a 256-bit base key.
 *
 * The first 128-bit half of basekey (in memory order) seeds the round-key
 * ring via gensubkeys(); the second half becomes the initial working block
 * st->w. Release the result with xsc256uninit().
 */
private state *mkstate(int256 basekey) {
    int128 ke, w;
    int128 *p;
    state *st;
    int16 size;

    /* Split the key into its two 128-bit halves. */
    p = (int128 *)&basekey;
    ke = p[0];
    w = p[1];

    size = sizeof(struct s_state);
    st = (state *)alloc(size);
    /* Check and clear the new state itself. The old code asserted/zeroed
     * `p` (the on-stack key copy) instead, leaving *st uninitialised and
     * overrunning basekey whenever sizeof(struct s_state) > 32. */
    assert(st);
    zero($1 st, size);

    p = (int128 *)st->w;
    *p = w;
    st->subkey = gensubkeys(ke);

    return st;
}

/*
 * Return the round key st->subkey currently points at and advance
 * st->subkey to the following entry of the key ring.
 * A null st yields a null pointer.
 */
private roundkey *nextrk(state *st) {
    roundkey *current;

    if (st == (state *)0)
        return (roundkey *)0;

    current = st->subkey;
    st->subkey = current->next;

    return current;
}

/*
 * Replicate a 16-bit value into all four 16-bit lanes of a 64-bit word
 * (e.g. 0xabcd -> 0xabcdabcdabcdabcd, assuming int16/int64 are unsigned).
 */
private int64 P(int16 input) {
    int64 lane, out;
    int8 shift;

    lane = $8 input;
    out = 0;
    for (shift = 0; shift < 64; shift += 16)
        out |= lane << shift;

    return out;
}

/*
 * Expansion box: spread the 16 bytes of a 128-bit block over a 192-bit
 * block using the Indices table. Even-numbered input bytes are written to
 * two output positions and odd-numbered ones to one, so 24 table entries
 * are consumed in total. Whether every output byte gets written depends on
 * Indices being a permutation of 0..23 (not visible here).
 */
private int192 xbox(int128 input) {
    int8 *src, *dst;
    const int8 *tbl;
    int192 output;
    int8 i, copies, byte;

    src = $1 &input;
    dst = $1 &output;
    tbl = Indices;

    for (i = 0; i < 16; i++) {
        byte = src[i];
        copies = (i % 2) ? 1 : 2;
        while (copies--)
            dst[*tbl++] = byte;
    }

    return output;
}

/*
 * Debug dump of a state: its current round key (via show()) followed by
 * the 16 bytes of the working block w in hex. Does nothing for a null st.
 */
private void showstate(int8 *ident, state *st) {
    int8 i;

    if (st == (state *)0)
        return;

    printf("(state *)%s = {\n  ", ident);
    show(st->subkey);
    printf("  %s->w = { ", ident);
    i = 0;
    while (i < 16) {
        printf("%.02hhx ", (char)st->w[i]);
        i++;
    }
    printf("}\n}\n");
}

/*
 * Print a 256-bit value as 64 hex digits. The highest-addressed 32-bit
 * word comes first, which is the most significant one on a little-endian
 * host.
 */
private void showint256(int8 *ident, int256 input) {
    int32 *words;
    int8 i;

    printf("%s = 0x", ident);
    words = (int32 *)&input;
    for (i = 8; i > 0; i--)
        printf("%.08x", $i words[i - 1]);
    printf("\n");
}

/* Print a semiword as three hex digits: its x, y and z fields in that order. */
private void showsemiword(int8 *ident, semiword input) {
    char hx = (char)input.x;
    char hy = (char)input.y;
    char hz = (char)input.z;

    printf("%s = 0x%.01hhx%.01hhx%.01hhx\n", $c ident, hx, hy, hz);
}

/* Print a 64-bit value as 16 zero-padded hex digits. */
private void showint64(int8 *ident, int64 input) {
    long long int wide = (long long int)input;

    printf("%s = 0x%.16llx\n", $c ident, wide);
}

/*
 * Print a 128-bit value as 32 hex digits. The highest-addressed 32-bit
 * word comes first, which is the most significant one on a little-endian
 * host.
 */
private void showint128(int8 *ident, int128 input) {
    int32 *words;
    int8 i;

    printf("%s = 0x", ident);
    words = (int32 *)&input;
    for (i = 4; i > 0; i--)
        printf("%.08x", $i words[i - 1]);
    printf("\n");
}

/*
 * Print a 192-bit value as 48 hex digits. The highest-addressed 32-bit
 * word comes first, which is the most significant one on a little-endian
 * host.
 */
private void showint192(int8 *ident, int192 input) {
    int32 *words;
    int8 i;

    printf("%s = 0x", ident);
    words = (int32 *)&input;
    for (i = 6; i > 0; i--)
        printf("%.08x", $i words[i - 1]);
    printf("\n");
}

/*
 * S-box layer. The 192-bit input is viewed as an array of 16 semiwords.
 * Each one is recorded in st->lastsbox, which trial() compares against its
 * baseline, and is replaced in the output by sbox() of itself.
 *
 * NOTE(review): the semiword view uses C pointer indexing, so this touches
 * 16 * sizeof(semiword) bytes of input and output. If sizeof(semiword) is 2
 * (as savesbox()'s 16*16*16*2 byte count suggests), that is 32 bytes, more
 * than the 24 bytes showint192() prints for an int192. Confirm int192's
 * actual size.
 */
private int192 sboxes(state *st, int192 input) {
    semiword *in, *out;
    int192 output;
    int8 i;

    output.x = 0;
    in = (semiword *)&input;
    out = (semiword *)&output;

    i = 0;
    do {
        st->lastsbox[i] = in[i];
        out[i] = sbox(in[i]);
    } while (++i < 16);

    return output;
}

/* Debug dump of a round key: its id, round constant and subkey (via show()). */
private void showroundkey(int8 *id, roundkey *rk) {
    int idnum;

    assert(rk);
    idnum = rk->id;

    printf("%s = {\n", $c id);
    printf("    id=%d\n", idnum);
    printf("    rc=0x%.16llx\n  ", $8 rk->rc);
    show(rk->subkey);
    printf("}\n");
}

/* Substitute one semiword through the global table, indexed Sbox[z][y][x]. */
private semiword sbox(semiword input) {
    semiword out;

    out = Sbox[input.z][input.y][input.x];
    return out;
}

/*
 * Compression box: fold a 192-bit value into 128 bits. Output byte i is
 * in[i] & in[i+1] for even i and in[i] | in[i+1] for odd i (i = 0..15).
 *
 * NOTE(review): only input bytes 0..16 are read. If int192 is 24 bytes,
 * bytes 17..23 of the expanded block never affect the output. Confirm this
 * is intended.
 */
private int128 cbox(int192 input) {
    int128 output;
    int8 *in, *out;
    int8 i;

    in = $1 &input;
    out = $1 &output;

    for (i = 0; i < 16; i++)
        out[i] = (i & 1) ? (in[i] | in[i + 1]) : (in[i] & in[i + 1]);

    return output;
}

/*
 * Round function: expand (xbox), substitute (sboxes, which records its
 * inputs in st), then compress (cbox) back to 128 bits.
 */
private int128 f(state *st, int128 input) {
    int192 expanded, substituted;

    expanded = xbox(input);
    substituted = sboxes(st, expanded);
    return cbox(substituted);
}

/*
 * Allocate a zeroed roundkey node holding id, subkey and round constant
 * rc, with a null next link. Asserts that the allocation succeeded.
 */
private roundkey *mkroundkey(int8 id, int128 subkey, int64 rc) {
    roundkey *node;
    int16 bytes = sizeof(struct s_roundkey);

    node = alloc(bytes);
    assert(node);
    zero($1 node, bytes);

    node->next = (roundkey *)0;
    node->rc = rc;
    node->subkey = subkey;
    node->id = id;

    return node;
}

/*
 * Build the circular list of 16 round keys and return its head. Key 0
 * holds `key` itself with round constant RCs[0]; the last node links back
 * to the head.
 *
 * NOTE(review): keys 1..15 are placeholders (subkey 0, rc 1). The intended
 * derivation through g() and P(RCs[i]) is commented out, so only key 0
 * depends on the caller's key.
 */
private roundkey *gensubkeys(int128 key) {
    roundkey *head, *tail, *node;
    int8 i;

    head = mkroundkey(0, key, RCs[0]);
    tail = head;
    for (i = 1; i < 16; i++) {
        // Intended: rc = P(RCs[i]); node = mkroundkey(i, g(tail->subkey, rc), rc);
        node = mkroundkey(i, 0, 1);
        tail->next = node;
        tail = node;
    }
    tail->next = head;

    return head;
}

/* Clear `size` bytes starting at x. */
private void zero(int8* x, int16 size) {
    int8 *cursor = x;
    int16 left = size;

    while (left) {
        *cursor++ = 0;
        left--;
    }
}

/*
 * Point the global Sbox at the table sbox1d, viewed as a [16][16][16]
 * array of semiwords (sbox() indexes it as Sbox[z][y][x]).
 */
private void mksbox() {
    Sbox = (semiword (*)[16][16])&sbox1d;
}

/*
 * _rotl(b) declares and defines rotl<b>(x, n): rotate the b-bit value x
 * left by n positions. It performs n single-position rotations through
 * rotl1(b, x), which is defined outside this file (presumably in xsc256.h),
 * so the cost grows linearly with n. It is instantiated below for
 * b = 8, 16, 32, 64 and 128.
 */
#define _rotl(b)    private int ## b rotl ## b (int ## b,int8); \
private int ## b rotl ## b (int ## b input, int8 n) { \
    int ## b x=input; for (int8 i=0; i<n; i++) x=rotl1((b),x); return x; \
}
_rotl(8);
_rotl(16);
_rotl(32);
_rotl(64);
_rotl(128);

/*
 * _rotr(b) declares and defines rotr<b>(x, n): rotate the b-bit value x
 * right by n positions. It performs n single-position rotations through
 * rotr1(b, x), which is defined outside this file (presumably in xsc256.h),
 * so the cost grows linearly with n. It is instantiated below for
 * b = 8, 16, 32, 64 and 128.
 */
#define _rotr(b) private int ## b rotr ## b (int ## b,int8); \
private int ## b rotr ## b (int ## b input, int8 n) { \
    int ## b x=input; for (int8 i=0; i<n; i++) x=rotr1((b),x); return x; \
}
_rotr(8);
_rotr(16);
_rotr(32);
_rotr(64);
_rotr(128);

/*
 * Round-key mixing function. Its only visible call, in gensubkeys(), is
 * commented out.
 *
 * Steps: rotate x left by 7 and multiply by rc. Then rotate right by 4,
 * pass the semiword at byte offset 14 through the S-box, and rotate back
 * left by 4.
 */
private int128 g(int128 x, int64 rc) {
    int128 y;
    semiword *top;

    y = rotl(x, 7);
    y *= rc;

    y = rotr(y, 4);
    top = (semiword *)($1 &y + 14);
    *top = sbox(*top);

    return rotl(y, 4);
}

/* Public entry point: create a cipher state from a 256-bit key. */
export state *xsc256init(int256 basekey) {
    state *fresh = mkstate(basekey);

    /* Warm-up rounds are currently disabled: rnds(fresh, 62); */
    return fresh;
}

/*
 * Produce one keystream byte. The state is advanced by two rounds, then
 * the byte of st->w at index grabidx(st->w[0]) is returned.
 */
export unsigned char xsc256byte(state *st) {
    int8 pick;

    rnds(st, 2);
    pick = grabidx(st->w[0]);
    return (unsigned char)st->w[pick];
}

/*
 * XOR `len` bytes of input with the keystream from st.
 *
 * Returns a newly allocated len-byte buffer holding the result; the caller
 * releases it (encryptfile_() uses destroy()). Asserts that st and input
 * are non-null and len is non-zero.
 */
export unsigned char *xsc256encrypt(state *st, unsigned char *input,
    unsigned short int len) {
    int16 size, left;
    int8 *out, *dst;
    unsigned char *src;

    assert(st && input && len);
    size = $2 len;
    out = $1 alloc(size);
    assert(out);
    zero($1 out, size);

    src = input;
    dst = out;
    left = size;
    while (left--)
        *dst++ = *src++ ^ xsc256byte(st);

    return (unsigned char *)out;
}

/*
 * Release a state created by xsc256init(): free the 16 round keys of its
 * circular key ring, then the state itself. A null st is ignored.
 */
export void xsc256uninit(state *st) {
    int8 n;
    roundkey *rk;

    if (!st)
        return;

    /* Fetch exactly 16 keys. The ring is circular, so a 17th nextrk() call
     * (the old loop header made it before testing n < 16) reads ->next from
     * the first key, which has already been freed. */
    for (n = 0; n < 16; n++) {
        rk = nextrk(st);
        if (!rk)
            break;
        destroy(rk);
    }
    st->subkey = (roundkey *)0;
    free(st);
}

/*
 * Derive a 256-bit key from a NUL-terminated passphrase of at most 256
 * bytes by iterating SHA-256 64 times. The first round hashes the
 * passphrase; each later round hashes the previous 32-byte digest.
 *
 * NOTE: earlier versions hashed `inputlength` bytes in every round. That
 * truncated the digest for passphrases shorter than 32 bytes and mixed
 * stale passphrase bytes back in for longer ones, so keys derived here
 * differ from keys derived by those versions.
 */
export int256 xsc256kdf(unsigned char *input) {
    int256 *ke;
    int8 hash[256];
    SHA2_CTX ctx;
    size_t inputlength;
    int16 chunk;
    int8 n;

    assert(input != 0);
    /* Keep the length as size_t so an overlong input cannot wrap a 16-bit
     * value and slip past the bound check. */
    inputlength = strlen($c input);
    assert(inputlength <= 256);
    zero($1 &hash, 256);
    strncpy($c hash, $c input, inputlength);

    chunk = $2 inputlength;
    for (n=0; n<64; n++) {
        zero($1 &ctx, sizeof(SHA2_CTX));
        SHA256Init(&ctx);
        SHA256Update(&ctx, hash, $i chunk);
        SHA256Final(hash, &ctx);
        /* hash[0..31] now holds the digest; feed all of it next round. */
        chunk = 32;
    }
    ke = (int256 *)&hash;

    return *ke;
}

/*
 * Encrypt stdin to stdout with the keystream derived from `ke`, in chunks
 * of up to 1023 bytes. Stops at end of input or on the first read or
 * write error.
 */
export void encryptfile_(int256 ke) {
    int16 n, done;
    int8 *p;
    int8 buf[1024];
    signed int ret;
    state *st;

    st = xsc256init(ke);
    do {
        zero($1 &buf, 1024);
        ret = read(0, $c buf, 1023);
        if (ret < 1)
            break;
        /* A chunk can be up to 1023 bytes. It used to be stored in an int8,
         * which cannot hold values above 255: larger reads were truncated,
         * and a read of exactly 256 gave len 0 and tripped
         * xsc256encrypt()'s assert. */
        n = $2 ret;

        p = xsc256encrypt(st, buf, n);
        if (!p)
            break;
        /* write() may accept fewer bytes than requested; keep writing until
         * the whole chunk is out or an error occurs. */
        for (done = 0; done < n; done += $2 ret) {
            ret = write(1, $c (p + done), $i (n - done));
            if (ret < 1)
                break;
        }
        destroy(p);
        if (done < n)
            break;
    } while(true);
    xsc256uninit(st);
}

/*
 * One cipher round on st->w, in place:
 *   w = f(rotr(w, 7)) ^ (next round key's subkey).
 * f() records this round's S-box inputs in st->lastsbox.
 */
private void rnd(state *st) {
    int128 *block;
    int128 mixed;

    assert(st);
    block = (int128 *)&st->w;

    mixed = f(st, rotr(*block, 7));
    *block = mixed ^ nextrk(st)->subkey;
}

// void main1() {
//     int8 n;
//     int8 *p;
//     int128 x;

//     x = 0;
//     for (n=0, p=$1 &x; n<16; n++, p++)
//         *p = n;

//     for (n=0; n<64; n++)
//         x = f(x);
//     show(x);

//     return;
// }

// void main2() {
//     int128 x;
//     roundkey *rk;
//     int8 n;

//     x = 3;
//     rk = gensubkeys(x);
//     for (n=0; n<16; n++, rk = rk->next)
//         show(rk);

//     return;
// }

// void main3() {
//     state *st;
//     int256 x;
//     int8 *p;
//     int8 n;
//     unsigned char byte;
//     // roundkey *rk;

//     x = 0;
//     for (n=0, p=$1 &x; n<32; n++, p++)
//         *p = n;

//     // show(x);
//     st = xsc256init(x);
//     // show(st);
//     // for (n=0, rk=nextrk(st); n<17; n++, rk = nextrk(st))  
//         // show(rk);

//     for (n=0; n<12; n++) {
//         byte = xsc256byte(st);
//         //show(st);
//         printf("%.02hhx\n", (char)byte);
//     }
//     xsc256uninit(st);

//     return;

// }

// void main4(unsigned char *key) {
//     int256 x;

//     x = 3423948239482;
//     encryptfile(x);

//     return;
// }

// int main(int argc, char *argv[]) {
//     if (argc < 2) {
//         fprintf(stderr, "Usage: cat input | %s <key> >> output.txt\n",
//             *argv);
//         return -1;
//     }
//     //mksbox();
//     main4((unsigned char *)argv[1]);

//     return 0;
// }

/*
 * Return 32 random bits from getrandom() (GRND_RANDOM, non-blocking) for
 * seeding rand(). Prints a message and exits the process if the kernel
 * cannot supply all of the bytes right now.
 */
private int32 genseed() {
    int32 seed;
    signed int ret;

    /* Read directly into the result. This avoids an unaligned int32 load
     * from a byte buffer, and a short read is treated as failure instead of
     * returning partly uninitialised bytes. */
    ret = getrandom($v &seed, sizeof(seed), GRND_RANDOM|GRND_NONBLOCK);
    if (ret != (signed int)sizeof(seed)) {
        fprintf(stderr, "No random bytes available. Try again later.\n");
        exit(-1);
    }

    return seed;
}

/*
 * Return a newly allocated tuple of two pseudo-random semiwords (drawn
 * from rand()); the caller frees it.
 *
 * Each semiword is read from its own 16-bit half of a 32-bit word, with
 * only the low 12 bits of each half set. This relies on the same
 * "low 12 bits of an int16 form a semiword" layout that sbox2csv() uses.
 */
private tuple *semirand() {
    tuple *p;
    int32 rnd_;
    semiword *sm;
    int16 size;

    /* Previously `rand() % 0x00ffffff` was used. That never produced
     * 0xffffff, and because sm[1] starts at bit 16, its top 4 bits
     * (bits 24..27) were always zero. Two separate statements keep the
     * rand() call order deterministic. */
    rnd_ = rand() & 0x0fff;
    rnd_ |= (rand() & 0x0fff) << 16;
    sm = (semiword *)&rnd_;

    size = sizeof(struct s_tuple);
    p = (tuple *)alloc(size);
    assert(p);
    zero($1 p, size);

    p->x = sm[0];
    p->y = sm[1];

    return p;
}

/*
 * Write the 16*16*16*2-byte S-box image at sb to `file`. The file is
 * created with mode 0644, or truncated if it already exists. Returns true
 * only if every byte was written and the file closed cleanly.
 *
 * Despite its type, sb is treated as a flat byte buffer; main() allocates
 * `best` that way.
 */
private bool savesbox(int8 *file, semiword ***sb) {
    int32 size, done;
    int8 *p;
    signed int fd, ret;

    assert(file && sb);
    /* O_TRUNC keeps a longer existing file from retaining stale trailing
     * bytes. Any non-negative descriptor is valid; the old `> 2` test
     * rejected fds 0-2 and leaked them. */
    fd = open($c file, O_WRONLY|O_CREAT|O_TRUNC, 00644);
    if (fd < 0)
        return false;

    size = (16*16*16*2);
    p = $1 sb;
    /* One write() per remaining span instead of one per byte; the loop
     * absorbs short writes. */
    for (done = 0; done < size; done += $4 ret) {
        ret = write(fd, $c (p + done), $i (size - done));
        if (ret < 1) {
            close(fd);
            return false;
        }
    }

    /* close() can report deferred write errors. */
    if (close(fd) < 0)
        return false;

    return true;
}

/*
 * Entry point. Usage: <prog> <filename>.
 *
 * Records the output filename, allocates the flat `best` S-box buffer,
 * seeds rand() from genseed() and installs the S-box table with mksbox().
 * It then prints the S-box as CSV (sbox2csv()) and returns 0.
 *
 * NOTE(review): everything after the early `return 0` is unreachable. That
 * code is the randomised search that repeatedly calls runtrials() and
 * copies the best-scoring Sbox into `best`. The SIGINT handler that would
 * save `best` is also not installed, because the signal() calls are
 * commented out.
 */
export int main(int argc, char *argv[]) {
    int8 branch, hiscore;
    int32 size;

    filename = $1 0;
    best = (semiword ***)0;
    if (argc < 2) {
        fprintf(stderr, "Usage: %s <filename>\n", *argv);
        return(-1);
    }
    else
        filename = $1 argv[1];

    // signal(SIGTERM, SIG_IGN);
    // signal(SIGINT, handler);

    /* 16*16*16 semiwords of 2 bytes each -- the same size savesbox() writes. */
    size = (16*16*16*2);
    best = (semiword ***)alloc(size);
    assert(best);

    branch = 0;
    srand(genseed());
    mksbox();

    /* Only the CSV dump currently runs; the search below is unreachable. */
    sbox2csv();
    return 0;
    

    branch = runtrials(0, false);
    hiscore = 0;
    // return 0;

    /* Search loop: keep the S-box with the highest branch score seen. */
    do {
        branch = runtrials(0, false);
        if (branch > hiscore) {
            printf("*** New hiscore!\n");
            hiscore = branch;
            memcpy($c best, $c Sbox, size);
        }
        printf("branch = Δ%hhd, hiscore = Δ%hhd \n",
            (char)branch, (char)hiscore);
    } while(true);

    return 0;
}

/*
 * Return bit `num` of the byte array data: bit (num % 8) of byte
 * (num / 8), tested with getbit8().
 */
private bool getbit(int8 *data, int16 num) {
    int8 byte = data[num / 8];
    int8 bit = num % 8;

    if (getbit8(byte, bit))
        return true;
    return false;
}

/*
 * Print bits length-1 down to 0 of data as '0'/'1', followed by a newline.
 * A space is inserted after every `sep` printed bits; sep == 0 means no
 * separators. Asserts that data is non-null and length is non-zero.
 */
private void printbits(void *data, int16 length, int8 sep) {
    int16 bitno, printed;
    int8 *bytes;

    assert(data && length);
    bytes = $1 data;

    for (printed = 0, bitno = length; bitno--; printed++) {
        /* The `sep &&` guard avoids the modulo-by-zero crash for sep == 0. */
        if (sep && printed && !(printed % sep))
            printf(" ");
        if (getbit(bytes, bitno))
            printf("1");
        else
            printf("0");
    }
    printf("\n");
}

/*
 * Score the global Sbox with a differential experiment and return the
 * smallest branch count seen. The running minimum starts at `input`, or at
 * 100 when input is 0.
 *
 * Search mode (verify == false): first perturb Sbox with 0xffffff random
 * swap() attempts. Each attempt whose result has a non-zero .x is retried
 * once with a fresh random pair. Then, for each n in [1, max) and each key
 * bit in [l, m), take the minimum of trial(n, bit, 1).
 *
 * NOTE(review): trial() ignores its first argument, so every n repeats the
 * same trials. Also confirm that swap()'s rejection value `{0,0,1}`
 * actually has a non-zero .x; that depends on semiword's member order.
 * If it does not, the retry below never runs.
 */
int8 runtrials(int8 input, bool verify) {
    int8 bit, branch, branch_, l, m;
    int16 n;
    int32 bign;
    int16 max;
    tuple *rnd_;
    semiword sw;

    max = (verify) ? 0xffff : 2;
    l = (verify) ? 150 : 128;
    m = (verify) ? 167 : 140;

    branch = (input) ? input : 100;
    zero($1 &baseline, (sizeof(struct s_semiword)*16*2));

    if (!verify) {
        for (bign=0; bign<0x00ffffff; bign++) {
            rnd_ = semirand();
            sw = swap(rnd_->x, rnd_->y);
            free(rnd_);
            if (sw.x) {
                /* Rejected: one retry with a fresh random pair. */
                rnd_ = semirand();
                (void)swap(rnd_->x, rnd_->y);
                free(rnd_);
            }
        }
    }

    /* trial(0, 0, 0) only refreshes `baseline` from the all-zero key.
     * Neither it nor the scoring trials modify Sbox, so computing it once
     * is equivalent to the former recomputation before every trial. As
     * before, it is skipped when no scoring trial will run. */
    if ((1 < max) && (l < m))
        (void)trial(0, 0, 0);

    for (n=1; n<max; n++) {
        for (bit=l; bit<m; bit++) {
            branch_ = trial(n, bit, 1);
            if (branch_ < branch)
                branch = branch_;
        }
    }

    return branch;
}

/*
 * Run one differential trial.
 *
 * bit == 0: run two rounds under the all-zero key, store each round's 16
 * S-box inputs in baseline[0] and baseline[1], and return 0.
 *
 * Otherwise: run two rounds under the key 1 << bit and return initbranch
 * plus the number of round-2 S-box inputs that differ from baseline[1].
 *
 * `val` is not used.
 */
int8 trial(int16 val, int8 bit, int8 initbranch) {
    int8 n, branch;
    int256 key;
    state *st;

    (void)val;
    key = (!bit) ?
            0 :
         ($32 1 << bit);

    st = xsc256init(key);
    assert(st);

    if (!bit) {
        rnd(st);
        for (n=0; n<16; n++)
            baseline[0][n] = st->lastsbox[n];
        rnd(st);
        for (n=0; n<16; n++)
            baseline[1][n] = st->lastsbox[n];
        xsc256uninit(st);

        return 0;
    }

    branch = initbranch;

    /* Round 1 runs (it advances the state) but is not scored. The former
     * comparison against baseline[0] discarded its result -- its branch++
     * was commented out -- so that side-effect-free loop is gone. */
    rnd(st);

    /* Round 2: count S-boxes whose input differs from the baseline. */
    rnd(st);
    for (n=0; n<16; n++) {
        if ((baseline[1][n].x != st->lastsbox[n].x) ||
            (baseline[1][n].y != st->lastsbox[n].y) ||
            (baseline[1][n].z != st->lastsbox[n].z))
            branch++;
    }
    xsc256uninit(st);

    return branch;
}

/*
 * Exchange the Sbox entries at positions idxx and idxy. Each position is a
 * semiword used as Sbox[z][y][x].
 *
 * Returns (semiword){0,0,0} on success. Returns the rejection value
 * `$s {0,0,1}` and leaves Sbox untouched when:
 *   - the two positions are equal, or
 *   - the swap would make either position map to itself.
 */
semiword swap(semiword idxx, semiword idxy) {
    semiword atx, aty;

    if (idxx.x == idxy.x && idxx.y == idxy.y && idxx.z == idxy.z)
        return $s {0,0,1};

    atx = Sbox[idxx.z][idxx.y][idxx.x];
    aty = Sbox[idxy.z][idxy.y][idxy.x];

    /* After the exchange idxx would hold aty and idxy would hold atx;
     * refuse if either would become a fixed point. */
    if (idxx.x == aty.x && idxx.y == aty.y && idxx.z == aty.z)
        return $s {0,0,1};
    if (idxy.x == atx.x && idxy.y == atx.y && idxy.z == atx.z)
        return $s {0,0,1};

    Sbox[idxx.z][idxx.y][idxx.x] = aty;
    Sbox[idxy.z][idxy.y][idxy.x] = atx;

    return (semiword){0,0,0};
}

/*
 * Save `best` to `filename`, re-score the S-box with runtrials(0, true),
 * print the result and exit: status 0 if the save succeeded, 1 otherwise.
 * main() would install this for SIGINT, but that signal() call is
 * commented out.
 *
 * NOTE(review): printf/fflush and other stdio calls are not
 * async-signal-safe. If this is re-enabled as a signal handler, consider
 * setting a flag and saving from the main loop instead.
 */
export void handler(int _) {
    int8 branchno;
    bool saved;

    (void)_;
    assert(filename);
    printf("\n\nSaving file...");
    fflush(stdout);
    /* Report a failed save instead of always claiming success. */
    saved = savesbox(filename, best);
    if (saved)
        printf("%s done\n", $1 filename);
    else
        printf("%s FAILED\n", $1 filename);
    printf("Verifying branch number...");
    fflush(stdout);

    branchno = runtrials(0, true);
    printf(" Δ%hhd\n", (char)branchno);
    exit(saved ? 0 : 1);
}

/*
 * Print the S-box as CSV on stdout. Output is a header line "x,s(x)",
 * then one line for each input 0..0xfff. Each line gives the input and
 * the sbox() output, both as 16-bit integers reinterpreted from/to
 * semiwords.
 */
private void sbox2csv() {
    int16 in, out;
    semiword result;

    printf("x,s(x)\n");
    for (in = 0; in < 0x1000; in++) {
        result = sbox(*(semiword *)&in);
        out = *(int16 *)&result;
        printf("%d,%d\n", in, out);
    }
}
