#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <iterator>
#include <vector>

// Debug trace verbosity (override with -D_DBGLEVEL=N):
//   >0 traces modexp, >1 also big_mulmod, >2 also modtrunc, >3 also big_add/big_sub.
#ifndef _DBGLEVEL
#define _DBGLEVEL 0
#endif
// modexp2 0857bde36f680cda9535784bc4aec5f4344131071b419f732ac9c74d0e61db49dd958c7344236e0279df009c6e66aec6ba574c2820d4aeb0c4d814c8e184c6ea7e6d8aa3e15d1c251c78c5364ea2b3edb3c19e90739afa765506242e78fcdc71a87efdfe2df6ce6039fc62cb3b360cb77cd5574292282df352886cbc3fcfbff2 10001 F765A3A0C9C291D81A56FE73794A746B8DA23DBE155D0D495B49D581B5C6545F449A10FDF1C26A92FBD1F43A0687044927A6A21B69A73999E6083D03ACDAFFA6409F1BC71D810628F6E18F76231ED6E22D54ED2502E66F8A33D0D5F07B3EB605F7418110E2EF9A5EE77B070F4EADFCF3D70C53E870F29C9D4F229F2CB6C25383
//   = 0001FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF003021300906052B0E03021A050004147020B30D42ED7CDF0E849F30E8B2E137941E9691
// see also 
// c:\local\cvsprj\secphone\trunk\machine\echocancel\fir_fxp\nr_class.cpp
// c:\local\cvsprj\secphone\trunk\machine\echocancel\fir_fxp\nr_class.h

// Arithmetic types.  A Number is an unsigned magnitude stored little-endian:
// element 0 holds the least significant base_t word.
// NOTE(review): code in this file that accumulates into a carry_t assumes
// carry_t is at least twice as wide as base_t, which holds for a 32-bit
// unsigned long (Windows).  On LP64 targets unsigned long is 64 bits, the same
// width as long long -- confirm the target platform before relying on it.
typedef long long carry_t;
typedef unsigned long base_t;

typedef std::vector<base_t> Number;
// number of bits in one base_t word
#define BITSINBASE (8*sizeof(base_t))

// bitmask[i] holds a word with only bit i set; filled by initmask().
Number bitmask;


// Fills the global bitmask table so that bitmask[i] == (base_t)1 << i for every
// bit position of a base_t word.  Must run before any code that indexes
// bitmask (e.g. big_mulmod, big_bittest).  Calling it again is harmless.
void initmask()
{
    bitmask.clear();
    for (size_t i=0 ; i<BITSINBASE ; i++)
    {
        // Shift a base_t, not an int: "1<<i" overflows a 32-bit int at i == 31
        // and is undefined for i >= 32 when base_t is 64 bits wide.
        bitmask.push_back(static_cast<base_t>(1) << i);
    }
}
void writenum(const Number& num);

// Replaces the whole contents of num with the single word val.
void SetNumber(Number& num, base_t val)
{
    num.assign(1, val);
}
// Shifts val right by one bit in place (val >>= 1).  The word count is not
// changed, so the top word may become zero; call shrink() to trim it.
void rightshift(Number& val)
{
    // Walk from the most significant word down, moving each word's low bit
    // into the top bit of the next lower word.  Everything stays in base_t,
    // so this is correct whatever the relative widths of base_t and carry_t.
    base_t incoming= 0;
    for (Number::reverse_iterator i=val.rbegin() ; i!=val.rend() ; ++i)
    {
        const base_t outgoing= (*i) & 1;
        (*i) = ((*i) >> 1) | (incoming << (BITSINBASE-1));
        incoming= outgoing;
    }
}
// Shifts val left by one bit in place (val <<= 1).  A new word is appended
// when a set bit is shifted out of the most significant word.
void leftshift(Number& val)
{
    // Each word's top bit becomes the low bit of the next higher word.  Using
    // only base_t avoids signed overflow and full-width shifts of carry_t.
    base_t incoming= 0;
    for (Number::iterator i=val.begin() ; i!=val.end() ; ++i)
    {
        const base_t outgoing= (*i) >> (BITSINBASE-1);
        (*i) = ((*i) << 1) | incoming;
        incoming= outgoing;
    }
    if (incoming)
        val.push_back(incoming);
}
// sum += val.  sum is first grown to val's length if shorter, and gains one
// extra word when the addition carries out of its top word.
void big_add(Number& sum, const Number& val)
{
#if _DBGLEVEL > 3
    writenum(sum); printf(" + "); writenum(val);
#endif

    if (sum.size() < val.size())
        sum.resize(val.size());

    // The carry is a 0/1 base_t detected from unsigned wraparound, so no type
    // wider than base_t is needed.
    base_t carry= 0;
    for (size_t k=0 ; k<sum.size() ; ++k)
    {
        const base_t addend= k<val.size() ? val[k] : 0;
        base_t s= sum[k] + addend;
        base_t c= (s < addend) ? 1 : 0;    // first addition wrapped
        s += carry;
        c |= (s < carry) ? 1 : 0;          // second addition wrapped (never both)
        sum[k]= s;
        carry= c;
    }
    if (carry)
        sum.push_back(carry);
#if _DBGLEVEL > 3
    printf(" = "); writenum(sum); printf("\n");
#endif
}
// sum -= val.  sum is first grown to val's length if shorter.  If val exceeds
// sum the difference wraps: the existing words hold the two's-complement
// result and one all-ones word is appended as a borrow marker.
void big_sub(Number& sum, const Number& val)
{
    if (sum.size() < val.size())
        sum.resize(val.size());
#if _DBGLEVEL > 3
    writenum(sum); printf(" - "); writenum(val);
#endif
    // The borrow is a 0/1 base_t detected by comparison, so no type wider than
    // base_t is needed.
    base_t borrow= 0;
    for (size_t k=0 ; k<sum.size() ; ++k)
    {
        const base_t minuend= sum[k];
        const base_t subtrahend= k<val.size() ? val[k] : 0;
        base_t d= minuend - subtrahend;
        base_t b= (minuend < subtrahend) ? 1 : 0;
        b |= (d < borrow) ? 1 : 0;         // checked before borrow is applied
        d -= borrow;
        sum[k]= d;
        borrow= b;
    }
    if (borrow)
        sum.push_back(static_cast<base_t>(~static_cast<base_t>(0)));
#if _DBGLEVEL > 3
    printf(" = "); writenum(sum); printf("\n");
#endif
}
// Removes most-significant zero words: zero becomes an empty Number and any
// other value ends with a non-zero top word.
void shrink(Number& a)
{
    size_t used= a.size();
    for ( ; used>0 ; --used)
        if (a[used-1]!=0)
            break;
    a.resize(used);
}
int compare(const Number& a, const Number& b)
{
    if (a.size() < b.size())
        return -1;
    if (a.size() > b.size())
        return 1;

    Number::const_iterator i_b= b.begin();
    Number::const_iterator i_a= a.begin();
    while (i_a!=a.end() && i_b!=b.end())
    {
        if ((*i_a)<(*i_b))
            return -1;
        if ((*i_a)>(*i_b))
            return 1;
        ++i_a;
        ++i_b;
    }
    if (i_a!=a.end())
        return 1;
    if (i_b!=b.end())
        return -1;
    return 0;
}

// One conditional reduction step: if val >= mod, subtract mod from val once.
// Callers (e.g. big_mulmod) are expected to keep val < 2*mod, in which case
// val < mod on return.  val's word count is not trimmed.  Terminates the
// program with exit(1) if val has two or more significant words more than mod.
void modtrunc(Number& val, const Number& mod)
{
#if _DBGLEVEL > 2
    printf("modtrunc "); writenum(val); printf(" : "); writenum(mod); printf("\n");
#endif

    // Skip leading zero words of both operands.
    Number::reverse_iterator i_val= val.rbegin();
    while (i_val!=val.rend() && (*i_val)==0)
        ++i_val;
    Number::const_reverse_iterator i_mod= mod.rbegin();
    while (i_mod!=mod.rend() && (*i_mod)==0)
        ++i_mod;
    // difference in significant word counts
    int diff= static_cast<int>(std::distance(i_val, val.rend()) - std::distance(i_mod, mod.rend()));
#if _DBGLEVEL > 2
    printf(" diff=%d ", diff);
#endif
    if (diff<0)
        return;                 // fewer significant words: val < mod
    if (diff>1) {
        // one subtraction cannot bring val below mod
        printf("!!!diff = %d\n", diff);
        exit(1);
    }
    if (diff==1) {
        // one word longer, so val > mod
#if _DBGLEVEL > 2
        printf("mod trunc %d: %0*lX >= 00\n", static_cast<int>(i_val-val.rbegin()), static_cast<int>(2*sizeof(base_t)), (*i_val));
#endif
        big_sub(val, mod);
        return;
    }
    // Same significant length: find the most significant differing word.
    while (i_val!=val.rend() && (*i_val)==(*i_mod)) {
        ++i_val;
        ++i_mod;
    }
    if (i_val==val.rend()) {
        // every word matched: val == mod, which must reduce to zero
#if _DBGLEVEL > 2
        printf("mod trunc: val == mod\n");
#endif
        big_sub(val, mod);
    }
    else if ((*i_val)>(*i_mod)) {
#if _DBGLEVEL > 2
        printf("mod trunc %d: %0*lX >= %0*lX\n", static_cast<int>(i_val-val.rbegin()), static_cast<int>(2*sizeof(base_t)), (*i_val), static_cast<int>(2*sizeof(base_t)), (*i_mod));
#endif
        big_sub(val, mod);
    }
}

// val = (val * mult) mod mod, by binary shift-and-add over the bits of val,
// least significant first.  mult is the initial addend and is not reduced
// beforehand, so callers should pass mult <= mod; val may be any size.
// val and mult may be the same object.  Uses bitmask, so initmask() must
// have run.
void big_mulmod(Number& val, const Number& mult, const Number& mod)
{
#if _DBGLEVEL > 1
    printf("calc mulmod ");
    writenum(val); printf(" x ");
    writenum(mult); printf(" mod ");
    writenum(mod); printf("\n");
#endif

    // Copy both operands before clearing val: they may alias (squaring).
    const Number multiplier_bits(val);
    Number addend(mult);
    val.clear();

    for (size_t w= 0 ; w<multiplier_bits.size() ; ++w)
    {
        const base_t word= multiplier_bits[w];
        for (size_t b= 0 ; b<BITSINBASE ; ++b)
        {
            // addend == mult * 2^(w*BITSINBASE+b), kept reduced modulo mod
            if ((word & bitmask[b]) != 0) {
                big_add(val, addend);
                modtrunc(val, mod);
            }
#if _DBGLEVEL > 1
            printf("   lsh=");  writenum(addend); printf(" ");
            printf("   val=");  writenum(val); printf("\n");
#endif
            leftshift(addend);
            modtrunc(addend, mod);
        }
    }
#if _DBGLEVEL > 1
    printf("mulmodresult="); writenum(val); printf("\n");
#endif
}

bool big_bittest(const Number& num, size_t bit)
{
    return (num[bit/BITSINBASE] & bitmask[bit&(BITSINBASE-1)])!=0;
}
void modexp(const Number& num, const Number& exp, const Number& modulus, Number& result)
{
#if _DBGLEVEL > 0
    printf("calc modexp "); 
    writenum(num); printf(" ^ ");
    writenum(exp); printf(" mod "); 
    writenum(modulus); printf("\n");
#endif

    Number num2; SetNumber(num2, 1);

    big_mulmod(num2, num, modulus); // num2 = (num2*num) % modulus

    SetNumber(result, 1);

    for (size_t expbit=0 ; expbit<exp.size()*BITSINBASE ; expbit++) {
#if _DBGLEVEL > 0
        printf("%d: num2=", expbit); writenum(num2); printf("\t");
        printf("result="); writenum(result); printf("\n");
#endif
        if (big_bittest(exp, expbit)) {
            big_mulmod(result, num2, modulus);
#if _DBGLEVEL > 0
            printf(" ***");
#endif
        }
#if _DBGLEVEL > 0
        printf("\n");
#endif
        big_mulmod(num2, num2, modulus);  // num2 = (num2*num2) % modulus
    }
#if _DBGLEVEL > 0
    printf("modexpresult="); writenum(result); printf("\n");
#endif
}
// Value of one hexadecimal digit: '0'-'9' -> 0-9, 'A'-'F' and 'a'-'f' -> 10-15.
// Any other character yields 0, so the result is always in [0, 15].
int digit2val(char c)
{
    if (c>='0' && c<='9')
        return c-'0';
    if (c>='A' && c<='F')
        return c-'A'+10;
    if (c>='a' && c<='f')
        return c-'a'+10;
    return 0;
}
// Parses hexstr (most significant digit first) into num, replacing its
// contents.  Each character is converted with digit2val; an empty string
// gives an empty Number (zero).
void hexstr2num(Number& num, const char* hexstr)
{
    num.clear();
    const size_t digits_per_word= 2*sizeof(base_t);
    const size_t len= strlen(hexstr);
    // Consume digits from the right; every digits_per_word digits fill a word.
    for (size_t j=0 ; j<len ; j++) {
        // Widen before shifting: an int digit shifted by 28 bits overflows
        // int, and shifts of 32+ bits are undefined for a 64-bit base_t.
        const base_t digit= static_cast<base_t>(digit2val(hexstr[len-1-j]));
        const size_t shift= j % digits_per_word;
        if (shift)
            num.back() |= digit << (4*shift);
        else
            num.push_back(digit);
    }
}
// Prints num to stdout as uppercase hexadecimal with no leading zeros and no
// trailing newline.  Zero (including an empty Number) prints as "0".
void writenum(const Number& num)
{
    Number::const_reverse_iterator i= num.rbegin();

    // skip leading zero words
    while (i!=num.rend() && (*i)==0)
        ++i;

    if (i==num.rend()) {
        printf("0");
        return;
    }
    // top word unpadded, every lower word zero-padded to full width
    printf("%lX", (*i++));
    // the '*' field width must be passed as int (sizeof yields size_t)
    const int width= static_cast<int>(2*sizeof(base_t));
    for ( ; i!=num.rend() ; ++i)
        printf("%0*lX", width, (*i));
}

// True when the word a is even (lowest bit clear).
bool is_even(base_t a)
{
    return !(a & 1);
}

bool is_even(const Number& a)
{
    return (a.size()==0) || is_even(a[0]);
}
bool is_zero(const Number& a)
{
    if (a.size()==0)
        return true;
    for (Number::const_iterator i= a.begin() ; i!=a.end() ; ++i)
        if (*i)
            return false;
    return true;
}
// for x,y - calculates a,b,v such that a*x+b*y=v = gcd(x,y)
void gcd_calc(const Number& x, Number& a, const Number& y, Number& b, Number& v)
{
    // algorithm 14.61

    int nshifts= 0;
    Number xx= x;
    Number yy= y;
    while (is_even(xx) && is_even(yy)) {
        rightshift(xx);
        rightshift(yy);
        nshifts++;
    }
    Number u = xx;
    v = yy;
    Number A; SetNumber(A, 1);
    Number B;
    Number C;
    Number D; SetNumber(D, 1);
    while (!is_zero(u)) {
        while (is_even(u)) {
            rightshift(u);
            if (!(is_even(A) && is_even(B))) {
                big_add(A,yy); 
                big_sub(B,xx);
            }
            rightshift(A);
            rightshift(B);
        }
        while (is_even(v)) {
            rightshift(v);
            if (!(is_even(C) && is_even(D))) {
                big_add(C,yy); 
                big_sub(D,xx);
            }
            rightshift(C);
            rightshift(D);
        }
        if (compare(u,v)>=0) {
            big_sub(u, v);
            big_sub(A, C);
            big_sub(B, D);
        }
        else {
            big_sub(v, u);
            big_sub(C, A);
            big_sub(D, B);
        }
    }
    a= C;
    b= D;
    for (int i=0 ; i<nshifts ; i++)
        leftshift(v);
}
class Montgomery {
public:
    Montgomery(Number& modulus, Number& base)
        : _modulus(modulus), _base(base)
    {
        Number g;
        gcd_calc(modulus, _inversemodulus, base, _inversebase, g);

    }
    void multiply(const Number& x, const Number& y, Number& a)
    {
        SetNumber(a, 0);
        for (size_t i=0 ; i<x.size() ; i++) {
            base_t u = (a[0]+x[i]*y[0]) * _inversemodulus[0];
            
            carry_t carry= 0;
            for (size_t j=0 ; j<a.size() ; j++) {

                carry += (j<a.size()-1 ? a[j+1] : 0) + (j<y.size() ? x[i]*y[j] : 0) + (j<_modulus.size() ? u*_modulus[j] : 0);
                a[j]= static_cast<base_t>(carry);
                carry >>= BITSINBASE;
            }
            if (carry)
                a.push_back(static_cast<base_t>(carry));
        }
        if (compare(a, _modulus)>0)
            big_sub(a, _modulus);
    }

    // returns value * base^-1 ( mod modulus )
    void reduce(const Number& val, Number& a)
    {
        a.clear();
        for (size_t i=0 ; i<val.size() ; i++) {
            base_t u= val[i]*_inversemodulus[0];

            carry_t carry= 0;
            for (size_t j=i ; j<a.size() ; j++) {
                carry += (j<_modulus.size() ? u*_modulus[j] : 0) + a[j];
                a[j]= static_cast<base_t>(carry);
                carry >>= BITSINBASE;
            }
            if (carry)
                a.push_back(static_cast<base_t>(carry));
        }
    }
    void modexp(const Number& num, const Number& exp, Number& result)
    {
        Number a= _base;
        Number b= num;
        big_mulmod(b, _base, _modulus);

        SetNumber(result, 1);
        int byte;
        for (byte=exp.size()-1 ; exp[byte]==0 && byte>=0 ; byte--)
            ;
        if (byte<0)
            return;
        int bit;
        for (bit=BITSINBASE-1 ; (exp[byte]&bitmask[bit])==0 && bit>=0 ; bit--)
            ;
        while(byte>=0)
        {
            Number a2;
            multiply(a, a, a2);
            a= a2;
            if (exp[byte]&bitmask[bit])
            {
                Number ab;
                multiply(a, b, ab);
                a= ab;
            }
            bit--;
            if (bit<0) {
                bit=BITSINBASE-1;
                byte--;
            }
        }
        Number one; SetNumber(one, 1);
        multiply(a, one, result);
    }

private:
    Number _modulus;
    Number _base;
    Number _inversebase;
    Number _inversemodulus;
};


#include <windows.h>
#include <stdlib.h>
#include <string.h>


void run_tests()
{

    // test left+right shift
    Number x; SetNumber(x, 1);
    for (size_t i=0 ; i<BITSINBASE*16 ; i++)
        leftshift(x);
    for (size_t i=0 ; i<BITSINBASE*16 ; i++)
        rightshift(x);
    shrink(x);
    if (x.size()!=1 || x[0]!=1) {
        printf("shift test failed\n");
        printf(" >> %d =", BITSINBASE*16); writenum(x); printf("\n");
        exit(1);
    }
    rightshift(x);
    shrink(x);
    if (x.size()!=0) {
        printf("shift test failed\n");
        printf(" >> 1 = "); writenum(x); printf("\n");
        exit(1);
    }

    // test add/sub
    Number y;
    x.clear();
    srand(0);
    for (int i=0 ; i<8 ; i++)
        x.push_back(static_cast<base_t>(rand()));
    for (int i=0 ; i<8 ; i++)
        y.push_back(static_cast<base_t>(rand()));
    //writenum(x); printf(" + "); writenum(y);
    big_add(x, y);
    //printf(" = "); writenum(x); printf("\n");
    big_sub(x, y);
    //printf("- = "); writenum(x); printf("\n");

    Number z;
    for (int i=0 ; i<8 ; i++)
        z.push_back(static_cast<base_t>(rand()));

//                writenum(x);
//  printf("* "); writenum(y);
//  printf("mod "); writenum(z);
//  big_mulmod(x, y, z);
//  printf(" = "); writenum(x); printf("\n");

//                writenum(x);
//  printf("^ "); writenum(y);
//  printf("mod "); writenum(z);
//  Number q;
//  modexp(x, y, z, q);
//  printf(" = "); writenum(q); printf("\n");


    hexstr2num(x, "0857BDE36F680CDA9535784BC4AEC5F4344131071B419F732AC9C74D0E61DB49DD958C7344236E0279DF009C6E66AEC6BA574C2820D4AEB0C4D814C8E184C6EA7E6D8AA3E15D1C251C78C5364EA2B3EDB3C19E90739AFA765506242E78FCDC71A87EFDFE2DF6CE6039FC62CB3B360CB77CD5574292282DF352886CBC3FCFBFF2");
    hexstr2num(y, "10001");
    hexstr2num(z, "F765A3A0C9C291D81A56FE73794A746B8DA23DBE155D0D495B49D581B5C6545F449A10FDF1C26A92FBD1F43A0687044927A6A21B69A73999E6083D03ACDAFFA6409F1BC71D810628F6E18F76231ED6E22D54ED2502E66F8A33D0D5F07B3EB605F7418110E2EF9A5EE77B070F4EADFCF3D70C53E870F29C9D4F229F2CB6C25383");
    Number q;
    DWORD t0= GetTickCount();
    modexp(x, y, z, q);
    printf("%lu ticks\n", GetTickCount()-t0);

    Number r;
    hexstr2num(r, "1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF003021300906052B0E03021A050004147020B30D42ED7CDF0E849F30E8B2E137941E9691");
    shrink(r);
    shrink(q);
    if (compare(r,q)!=0) {
        printf("got:      "); writenum(q); printf("\n");
        printf("expected: "); writenum(r); printf("\n");
    }
    Number a;
    Number b;
    Number g;
    SetNumber(x, 693);
    SetNumber(y, 609);
    gcd_calc(x, a, y, b, g);
    printf("a="); writenum(a); printf("\n");
    printf("b="); writenum(b); printf("\n");
    printf("g="); writenum(g); printf("\n");
}
int main(int argc, char **argv)
{
    initmask();
    run_tests();
    if (argc!=4) {
        printf("Usage: modexp num exp mod\n");
        return 1;
    }
    char *numstr= argv[1];
    char *expstr= argv[2];
    char *modstr= argv[3];

    // strip leading zeros
    while (*numstr=='0') numstr++;
    while (*expstr=='0') expstr++;
    while (*modstr=='0') modstr++;

    Number num; hexstr2num(num, numstr);
    Number exp; hexstr2num(exp, expstr);
    Number mod; hexstr2num(mod, modstr);

    if (num.size() > mod.size()) {
        printf("number cannot be larger than mod\n");
        return 1;
    }

    //printf("num="); writenum(num); printf("\n");
    //printf("mod="); writenum(mod); printf("\n");
    //printf("exp="); writenum(exp); printf("\n");

    Number result;

    try {
        modexp(num, exp, mod, result);
    } catch(char *msg) {
        printf("EXCEPTION: %s\n", msg);
    }

    writenum(result);
    printf("\n");

    printf(" now trying montgomery\n");

    Number base; SetNumber(base, 0x10000);
    Montgomery m(mod, base);
    m.modexp(num, exp, result);
    writenum(result);

    printf(".\n");
}


