强网杯2022-强网前锋-polydiv

2022-7-30 Misc强网杯2022

# polydiv

  • task.py
import socketserver
import os, sys, signal
import string, random
from hashlib import sha256

from secret import flag
from poly2 import *

def pad(s):
    """Return ``s`` followed by ``n`` copies of the byte ``n``.

    ``n = (len(s) - 1) % 16 + 1``, i.e. ``len(s) % 16`` with 0 mapped to 16.
    Note this is *not* PKCS#7 (which pads ``16 - len(s) % 16`` bytes).
    """
    count = (len(s) - 1) % 16 + 1
    return s + bytes([count]) * count


testCases = 40  # number of rounds Task.handle requires before sending the flag

class Task(socketserver.BaseRequestHandler):
    """Per-connection handler.

    The client must first solve a SHA-256 proof of work, then answer
    ``testCases`` rounds of "find b(x) with a(x)*b(x) + c(x) = r(x)".
    The flag is sent only if every round is answered correctly.
    """

    def _recvall(self):
        """Read one client message and return it with surrounding whitespace stripped.

        Keeps reading while ``recv`` fills the whole buffer; a short read
        (including ``b''`` on EOF) is taken as the end of the message.
        """
        BUFF_SIZE = 2048
        chunks = []
        while True:
            part = self.request.recv(BUFF_SIZE)
            chunks.append(part)
            if len(part) < BUFF_SIZE:
                break
        return b''.join(chunks).strip()

    def send(self, msg, newline=True):
        """Send the bytes ``msg`` to the client, appending a newline unless ``newline`` is False.

        Socket errors (client gone, connection reset, ...) are ignored so a
        vanished client cannot crash the handler; programming errors such as
        passing ``str`` instead of ``bytes`` are no longer silently swallowed.
        """
        if newline:
            msg += b'\n'
        try:
            self.request.sendall(msg)
        except OSError:
            pass

    def recv(self, prompt=b'> '):
        """Send ``prompt`` (without newline) and return the stripped reply bytes."""
        self.send(prompt, newline=False)
        return self._recvall()

    def close(self):
        """Say goodbye and close the client socket."""
        self.send(b"Bye~")
        self.request.close()

    def proof_of_work(self):
        """Run the SHA-256 puzzle; return True if the client solves it.

        A random 20-character alphanumeric string is generated; the client is
        shown its last 16 characters plus the full digest and must reply with
        the missing 4-character prefix.
        """
        random.seed(os.urandom(8))
        proof = ''.join([random.choice(string.ascii_letters+string.digits) for _ in range(20)])
        _hexdigest = sha256(proof.encode()).hexdigest()
        self.send(f"sha256(XXXX+{proof[4:]}) == {_hexdigest}".encode())
        x = self.recv(prompt=b'Give me XXXX: ')
        return len(x) == 4 and sha256(x+proof[4:].encode()).hexdigest() == _hexdigest

    def guess(self):
        """Play one round; return True if the client replies with exactly ``str(b(x))``.

        a and b are random 256-bit primes, c a 128-bit prime; each is turned
        into a polynomial over GF(2) from its binary digits (``PP``).  The
        client is shown r, a and c.
        """
        from Crypto.Util.number import getPrime
        a,b,c = [getPrime(i) for i in [256,256,128]]
        pa,pb,pc = [PP(bin(i)[2:]) for i in [a,b,c]]
        r = pa*pb+pc
        self.send(b'r(x) = '+str(r).encode())
        self.send(b'a(x) = '+str(pa).encode())
        self.send(b'c(x) = '+str(pc).encode())
        self.send(b'Please give me the b(x) which satisfy a(x)*b(x)+c(x)=r(x)')

        # errors='replace': a non-UTF-8 reply is just a wrong answer instead of
        # an unhandled UnicodeDecodeError that aborts the connection.
        answer = self.recv(prompt=b'> b(x) = ').decode(errors='replace')
        return answer == str(pb)

    def handle(self):
        """Entry point: proof of work, then ``testCases`` rounds, then the flag."""
        # NOTE(review): signal.alarm(1200) is disabled; presumably because with a
        # threaded server a SIGALRM would hit the whole process -- confirm.
        #signal.alarm(1200)

        if not self.proof_of_work():
            return

        for _ in range(testCases):
            if not self.guess():
                self.send(b"What a pity, work harder.")
                return
            self.send(b"Success!")

        # Only reached when every round succeeded.
        self.send(b'Congratulations, this is you reward.')
        self.send(flag)
        
        

class ThreadedServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    """TCP server that serves each connection in its own thread."""

class ForkedServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    """Server class started by ``__main__``.

    Despite its name this is a *threading* server: the ForkingMixIn base was
    replaced by ThreadingMixIn, so all connections run as threads of a single
    process and share its module-level state (e.g. the global ``random``).
    """

if __name__ == "__main__":

    HOST, PORT = '0.0.0.0', 10000
    # allow_reuse_address only has an effect if it is set *before* bind();
    # TCPServer.__init__ binds immediately, so build the server unbound,
    # enable address reuse, then bind and listen explicitly.
    server = ForkedServer((HOST, PORT), Task, bind_and_activate=False)
    server.allow_reuse_address = True
    try:
        server.server_bind()
        server.server_activate()
    except BaseException:
        server.server_close()
        raise
    with server:  # closes the listening socket when serve_forever exits
        server.serve_forever()
  • poly2.py

class Polynomial2():
    '''
    Polynomial ring over GF(2) (coefficients taken mod 2).  Three ways to build one:

    1. Coefficients listed from the highest degree down to the constant term
        >>> Polynomial2([1,1,0,1])
        x^3 + x^2 + 1

    2. The same coefficients written as a 0/1 string
        >>> Polynomial2('1101')
        x^3 + x^2 + 1

    3. The exponents whose coefficient is 1, via the module-level ``Poly``
        >>> Poly([3,1,4])
        x^4 + x^3 + x
        >>> Poly([]) # additive identity
        0
        >>> Poly(0) # multiplicative identity
        1
        >>> Poly(1,2) * Poly(2,3)
        x^5 + x^3

    Attributes: param (coefficients, lowest degree first), ones (exponents with
    coefficient 1, ascending), order (degree; 0 for the zero polynomial),
    b (0/1 coefficient string as given, highest degree first),
    Latex (latex() computed at construction).
    '''
    def __init__(self,ll):
        '''
        ll: a 0/1 string or a sequence of 0/1 ints, highest degree first.
        '''
        if isinstance(ll, str):
            ll = list(map(int, ll))
        else:
            # Copy into a list so tuples etc. work too (__add__/__mul__ index it).
            ll = list(ll)

        self.param = ll[::-1]  # param[i] is the coefficient of x^i
        self.ones = [i for i, coef in enumerate(self.param) if coef == 1]  # exponents with coefficient 1
        self.Latex = self.latex()
        self.b = ''.join(str(i) for i in ll)  # 0/1 string, highest degree first

        self.order = max(self.ones, default=0)  # degree; 0 for the zero polynomial

    @staticmethod
    def _term(exp, latex=False):
        '''Render one monomial: 1, x, x^k (x^{k} in LaTeX when k has two or more digits).'''
        if exp == 0:
            return '1'
        if exp == 1:
            return 'x'
        if latex and exp >= 10:
            return f'x^{{{exp}}}'
        return f'x^{exp}'

    def format(self,reverse = True):
        '''
            Return the polynomial as text, e.g. 'x^3 + x^2 + 1'.
            reverse=True (default): highest degree first.
            reverse=False: lowest degree first.
            (The constructor always takes coefficients highest degree first.)
        '''
        if not self.ones:
            return '0'
        exps = self.ones[::-1] if reverse else self.ones
        return ' + '.join(self._term(e) for e in exps)

    def __call__(self,x):
        '''
            Evaluation is not implemented (unused); this only prints the argument.
        '''
        print(f'call({x})')

    def __add__(self,other):
        '''
            Polynomial addition: coefficient-wise sum mod 2, aligned at the
            constant term.  The result has the longer operand's length
            (leading zero coefficients are kept).
        '''
        a, b = self.param, other.param  # lowest degree first
        if len(a) < len(b):
            a, b = b, a
        low_first = [(x + y) % 2 for x, y in zip(a, b)] + list(a[len(b):])
        return Polynomial2(low_first[::-1])

    def __mul__(self,other):
        '''
            Polynomial multiplication (schoolbook, coefficients mod 2).
        '''
        a, b = self.param, other.param  # lowest degree first
        r = [0] * (len(a) + len(b) - 1)
        for i, bc in enumerate(b):
            if bc == 1:
                # add a(x) * x^i into the result
                for j, ac in enumerate(a):
                    r[i + j] = (r[i + j] + ac) % 2
        return Polynomial2(r[::-1])

    def __sub__(self,oo):
        # Over GF(2) subtraction is the same as addition.
        return self + oo

    def __repr__(self) -> str:
        return self.format()

    def __str__(self) -> str:
        return self.format()

    def __pow__(self,a):
        # Plain repeated multiplication; large exponents are not needed here.
        t = Polynomial2([1])
        for _ in range(a):
            t *= self
        return t

    def latex(self,reverse=True):
        '''
            Same as format(), but exponents with two or more digits are wrapped
            in braces (x^{10}) so the string is valid LaTeX.
        '''
        if not self.ones:
            return '0'
        exps = self.ones[::-1] if reverse else self.ones
        return ' + '.join(self._term(e, latex=True) for e in exps)

    def __eq__(self,other):
        if not isinstance(other, Polynomial2):
            return NotImplemented
        return self.ones == other.ones

    # Ordering compares degrees only; the zero polynomial (no terms) is given
    # degree -1 so it sorts below the constants instead of raising ValueError.
    def __lt__(self,other):
        if not isinstance(other, Polynomial2):
            return NotImplemented
        return max(self.ones, default=-1) < max(other.ones, default=-1)

    def __le__(self,other):
        if not isinstance(other, Polynomial2):
            return NotImplemented
        return max(self.ones, default=-1) <= max(other.ones, default=-1)

def Poly(*args):
    '''
        Alternative constructor from the exponents whose coefficient is 1:
        Poly([3,1,4]) or Poly(3,1,4) -> x^4 + x^3 + x.
        Poly() or Poly([]) gives the zero polynomial; a repeated exponent is
        counted once.

        Raises ValueError on a negative exponent (it would otherwise index the
        coefficient list from the end and silently set the wrong term).
    '''
    if len(args) == 1 and isinstance(args[0], (list, tuple, set, frozenset, range)):
        args = args[0]
    exponents = list(args)

    if not exponents:
        return Polynomial2('0')
    if min(exponents) < 0:
        raise ValueError(f'negative exponent in {exponents!r}')

    coeffs = [0] * (max(exponents) + 1)  # coeffs[e] is the coefficient of x^e
    for e in exponents:
        coeffs[e] = 1
    return Polynomial2(coeffs[::-1])

    
PP = Polynomial2
P = Poly
# Short aliases, told apart by length: PP takes coefficients, P takes exponents.
if __name__ == '__main__':
    # Demo: (x^4 + x + 1) * (x^4 + x^3 + x^2 + x + 1) over GF(2).
    p = Polynomial2('10011')
    p3 = Polynomial2('11111')
    Q = p*p3
    print(Q)  # x^8 + x^7 + x^6 + x^4 + 1

# writeup

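登录验证：爆破 proof of work 中 sha256(XXXX+...) 缺失的 4 个字符 XXXX（字符集为大小写字母和数字）。

计算 $a(x)\cdot b(x)+c(x)=r(x)$ 中的 $b(x)$：使用题目的 poly2.py 中的类和函数处理 $r(x)$、$a(x)$、$c(x)$。根据等式可知 $b(x)$ 的最高次幂的项为 $x^{\deg r-\deg a}$（因为 $\deg c<\deg a$），通过爆破可以得到 $b(x)$ 的值；而 testCases = 40，所以循环 40 次即可。

注意：按 task.py，$b$ 是 256 位素数，逐个枚举并不可行；由于在 GF(2) 上减法就是加法，可以直接做多项式除法 $b(x)=\bigl(r(x)+c(x)\bigr)/a(x)$。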

from pwn import *
import string
from hashlib import *
from poly2 import *

io = remote('123.56.86.227', 16303)
context.log_level='debug'
# Banner: "sha256(XXXX+<suffix>) == <hex digest>\n" then "Give me XXXX: ".
# Read up to the delimiters rather than with fixed-size recv(): tube.recv(n)
# returns *at most* n bytes, so a short read would corrupt the puzzle.
io.recvuntil(b'XXXX+')
proof = io.recvuntil(b') == ', drop=True).decode()
_hexdigest = io.recvline().strip().decode()
io.recvuntil(b'Give me XXXX: ')

tables = string.ascii_letters+string.digits


def bp(suffix=None, target=None, alphabet=tables):
    '''
    Brute-force the 4-character prefix XXXX with sha256(XXXX + suffix) == target.

    suffix and target default to the module-level ``proof`` and ``_hexdigest``
    read from the server banner; alphabet defaults to letters + digits, the
    set the server draws from.  Returns the prefix, or raises ValueError if
    none matches.
    '''
    from itertools import product

    if suffix is None:
        suffix = proof
    if target is None:
        target = _hexdigest
    tail = suffix.encode()
    # product(..., repeat=4) enumerates in the same order as four nested loops.
    for chars in product(alphabet, repeat=4):
        head = ''.join(chars)
        if sha256(head.encode() + tail).hexdigest() == target:
            return head
    raise ValueError('no 4-character prefix matches the digest')

io.sendline(bp().encode())  # pwntools tubes take bytes; str triggers an implicit-ASCII warning

def str2PP(_str):
    '''
    Parse the server's printed form (e.g. 'x^3 + x + 1', or '0') back into a
    polynomial via P(exponents).

    Raises ValueError on a term that is not '1', 'x' or 'x^<n>', instead of
    silently dropping it and producing a wrong polynomial.
    '''
    text = _str.strip()
    if text == '0':  # the zero polynomial is printed as a bare '0'
        return P([])

    exponents = []
    for term in text.split(' + '):
        term = term.strip()
        if term.startswith('x^'):
            exponents.append(int(term[2:]))
        elif term == 'x':
            exponents.append(1)
        elif term == '1':
            exponents.append(0)
        else:
            raise ValueError(f'unexpected term {term!r} in {_str!r}')
    return P(exponents)

def guess():
    '''
    Recover b(x) from r(x) = a(x)*b(x) + c(x) over GF(2); return str(b(x)).

    Reads the globals pa, pc and r parsed from the server.  Over GF(2)
    subtraction is XOR, so b(x) = (r(x) + c(x)) / a(x) exactly.  The quotient
    is found by carry-less long division on integer bit masks (bit i <-> x^i).
    Enumerating small candidates cannot work: in task.py b is a 256-bit prime.

    Raises ZeroDivisionError if a(x) is zero, ValueError if the division is
    not exact.
    '''
    def to_mask(poly):
        return sum(1 << e for e in poly.ones)

    num = to_mask(r) ^ to_mask(pc)
    den = to_mask(pa)
    if den == 0:
        raise ZeroDivisionError('a(x) is the zero polynomial')

    den_len = den.bit_length()
    quotient = 0
    while num.bit_length() >= den_len:
        shift = num.bit_length() - den_len
        quotient |= 1 << shift
        num ^= den << shift  # clears the current leading term
    if num:
        raise ValueError('a(x) does not divide r(x) + c(x)')
    # Same construction the server uses for pb: PP(bin(b)[2:]).
    return str(PP(bin(quotient)[2:]))

testCases = 40  # must match the server's round count
for _ in range(testCases):
    # Each round the server prints "r(x) = ...", "a(x) = ...", "c(x) = ...",
    # an instruction line, then the prompt "> b(x) = ".  The module-level
    # names r / pa / pc are what guess() reads.
    io.recvuntil(b'r(x) = ')
    r = str2PP(io.recvline().decode().strip())
    io.recvuntil(b'a(x) = ')
    pa = str2PP(io.recvline().decode().strip())
    io.recvuntil(b'c(x) = ')
    pc = str2PP(io.recvline().decode().strip())
    # Consume exactly up to the prompt; a bare recv() returns whatever happens
    # to be buffered and can leave the prompt behind or eat the next round.
    io.recvuntil(b'> b(x) = ')

    pb = str(guess()).encode()
    io.sendline(pb)
    io.recvuntil(b'Success!\n')

io.interactive()

flag{dfd4eebe-0974-47a6-a4a2-4b58e5923ae7}