bzoj4128 Matrix

    xiaoxiao2021-03-25  105

    Description

    给定矩阵A,B和模数p,求最小的x满足

    A^x = B (mod p)

    Input 第一行两个整数n和p,表示矩阵的阶和模数,接下来一个n * n的矩阵A.接下来一个n * n的矩阵B

    Output 输出一个正整数,表示最小的可能的x,数据保证在p内有解

    BSGS求离散对数,这里换成了矩阵乘法,不过是一样的。对矩阵也求逆就可以了。

    #include<cstdio> #include<cstring> #include<cmath> #include<algorithm> #include<map> using namespace std; #define LL unsigned long long const int p=131; int n,mod,m; int pow(int base,int p) { int ret=1; while (p) { if (p&1) ret=ret*base%mod; base=base*base%mod; p>>=1; } return ret; } int inverse(int x) { return pow(x,mod-2); } struct mat { int a[80][80]; LL h; void init() { h=0; for (int i=0;i<n;i++) for (int j=0;j<n;j++) h=h*p+a[i][j]; } void I() { for (int i=0;i<n;i++) for (int j=0;j<n;j++) a[i][j]=(i==j); init(); } void rd() { for (int i=0;i<n;i++) for (int j=0;j<n;j++) scanf("%d",&a[i][j]); init(); } mat operator * (const mat &m) const { mat ret; for (int i=0;i<n;i++) for (int j=0;j<n;j++) { ret.a[i][j]=0; for (int k=0;k<n;k++) ret.a[i][j]=(ret.a[i][j]+a[i][k]*m.a[k][j]%mod)%mod; } ret.init(); return ret; } mat pow(int x) { mat ret,base; ret.I(); for (int i=0;i<n;i++) for (int j=0;j<n;j++) base.a[i][j]=a[i][j]; while (x) { if (x&1) ret=ret*base; base=base*base; x>>=1; } ret.init(); return ret; } mat inv() { mat ret; int p,x; ret.I(); for (int i=0;i<n;i++) { p=-1; for (int j=i;j<n;j++) if (a[j][i]) { p=j; break; } if (p==-1) continue; if (p!=i) for (int j=0;j<n;j++) { swap(a[i][j],a[p][j]); swap(ret.a[i][j],ret.a[p][j]); } x=inverse(a[i][i]); for (int j=0;j<n;j++) { a[i][j]=a[i][j]*x%mod; ret.a[i][j]=ret.a[i][j]*x%mod; } for (int j=0;j<n;j++) if (j!=i) { x=a[j][i]; for (int k=0;k<n;k++) { a[j][k]=(a[j][k]-a[i][k]*x%mod+mod)%mod; ret.a[j][k]=(ret.a[j][k]-ret.a[i][k]*x%mod+mod)%mod; } } } ret.init(); return ret; } }a,b,tem,u; map<LL,int> mp; int main() { scanf("%d%d",&n,&mod); a.rd(); b.rd(); m=sqrt(mod); tem.I(); mp[tem.h]=0; for (int i=1;i<m;i++) { tem=tem*a; if (!mp.count(tem.h)) mp[tem.h]=i; } u=(a.pow(m)).inv(); for (int i=0;;i++) { if (mp.count(b.h)) { printf("%d\n",i*m+mp[b.h]); return 0; } b=b*u; } }
    转载请注明原文地址: https://ju.6miu.com/read-15275.html

    最新回复(0)