# Writeup for ByteCTF 2021

## EASYXOR

 1 2 3 4  def shift(m, k, c): if k < 0: return m ^ (m >> (-k)) & c return m ^ ((m << k) & c) 

  1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57   def invshift_opt(c,k,mask): c = bin(c)[2:].rjust(64,'0') cip=[int(i) for i in c] mask = bin(mask)[2:].rjust(64,'0') mask=[int(i) for i in mask] ans={} idx = 63 for i in range(k): ans[idx]=cip.pop() idx-=1 for i in range(63-k,-1,-1): tmp = cip[i]^(ans[i+k]&mask[i]) ans[i]=tmp flag ='' for i in range(64): flag += str(ans[i]) ans = int(flag,2) return ans def invshift_ngt(c,k,mask): k=-k c = bin(c)[2:].rjust(64,'0') cip=[int(i) for i in c] mask = bin(mask)[2:].rjust(64,'0') mask=[int(i) for i in mask] ans={} for i in range(k): ans[i]=cip[i] for i in range(k,64): tmp = cip[i]^(ans[i-k]&mask[i]) ans[i]=tmp flag ='' for i in range(64): flag += str(ans[i]) # ans=[str(ans[i]) for i in range(64)] # ans = "".join(ans) ans = int(flag,2) return ans def invconvert(m, key): c_list = [0x37386180af9ae39e, 0xaf754e29895ee11a, 0x85e1a429a2b7030c, 0x964c5a89f6d3ae8c] for t in range(3,-1,-1): if(key[t]>0): m = invshift_opt(m, key[t], c_list[t]) else: m = invshift_ngt(m, key[t], c_list[t]) return m 

  1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155   import struct import sys def bytes_to_long(s): """Convert a byte string to a long integer (big endian). In Python 3.2+, use the native method instead:: >>> int.from_bytes(s, 'big') For instance:: >>> int.from_bytes(b'\x00P', 'big') 80 This is (essentially) the inverse of :func:long_to_bytes. """ acc = 0 unpack = struct.unpack # Up to Python 2.7.4, struct.unpack can't work with bytearrays nor # memoryviews if sys.version_info[0:3] < (2, 7, 4): if isinstance(s, bytearray): s = bytes(s) elif isinstance(s, memoryview): s = s.tobytes() length = len(s) if length % 4: extra = (4 - length % 4) s = b'\x00' * extra + s length = length + extra for i in range(0, length, 4): acc = (acc << 32) + unpack('>I', s[i:i+4])[0] return acc def long_to_bytes(n, blocksize=0): """Convert an integer to a byte string. In Python 3.2+, use the native method instead:: >>> n.to_bytes(blocksize, 'big') For instance:: >>> n = 80 >>> n.to_bytes(2, 'big') b'\x00P' If the optional :data:blocksize is provided and greater than zero, the byte string is padded with binary zeros (on the front) so that the total length of the output is a multiple of blocksize. If :data:blocksize is zero or not provided, the byte string will be of minimal length. 
""" # after much testing, this algorithm was deemed to be the fastest s = b'' n = int(n) pack = struct.pack while n > 0: s = pack('>I', n & 0xffffffff) + s n = n >> 32 # strip off leading zeros for i in range(len(s)): if s[i] != b'\x00'[0]: break else: # only happens when n == 0 s = b'\x00' i = 0 s = s[i:] # add back some pad bytes. this could be done more efficiently w.r.t. the # de-padding being done above, but sigh... if blocksize > 0 and len(s) % blocksize: s = (blocksize - len(s) % blocksize) * b'\x00' + s return s def check(s): for i in s: if(i>32 and i<127): continue else: return False return True def shift(m, k, c): if k < 0: return m ^ (m >> (-k)) & c return m ^ ((m << k) & c) def convert(m, key): c_list = [0x37386180af9ae39e, 0xaf754e29895ee11a, 0x85e1a429a2b7030c, 0x964c5a89f6d3ae8c] for t in range(4): m = shift(m, key[t], c_list[t]) return m def invshift_opt(c,k,mask): c = bin(c)[2:].rjust(64,'0') cip=[int(i) for i in c] mask = bin(mask)[2:].rjust(64,'0') mask=[int(i) for i in mask] ans={} idx = 63 for i in range(k): ans[idx]=cip.pop() idx-=1 for i in range(63-k,-1,-1): tmp = cip[i]^(ans[i+k]&mask[i]) ans[i]=tmp flag ='' for i in range(64): flag += str(ans[i]) ans = int(flag,2) return ans def invshift_ngt(c,k,mask): k=-k c = bin(c)[2:].rjust(64,'0') cip=[int(i) for i in c] mask = bin(mask)[2:].rjust(64,'0') mask=[int(i) for i in mask] ans={} for i in range(k): ans[i]=cip[i] for i in range(k,64): tmp = cip[i]^(ans[i-k]&mask[i]) ans[i]=tmp flag ='' for i in range(64): flag += str(ans[i]) ans = int(flag,2) return ans def invconvert(m, key): c_list = [0x37386180af9ae39e, 0xaf754e29895ee11a, 0x85e1a429a2b7030c, 0x964c5a89f6d3ae8c] for t in range(3,-1,-1): if(key[t]>0): m = invshift_opt(m, key[t], c_list[t]) else: m = invshift_ngt(m, key[t], c_list[t]) return m 

## JustDecrypt

Challenge server source:

  1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150  #!/usr/bin/env python3.9 # -*- coding: utf-8 -*- import string import random import socketserver import signal import codecs from os import urandom from hashlib import sha256 from Crypto.Cipher import AES from flag import FLAG BANNER = rb""" ___ _ ______ _ |_ | | | | _ \ | | | |_ _ ___| |_ | | | |___ ___ _ __ _ _ _ __ | |_ | | | | / __| __| | | | / _ \/ __| '__| | | | '_ \| __| /\__/ / |_| \__ \ |_ | |/ / __/ (__| | | |_| | |_) | |_ \____/ \__,_|___/\__| |___/ \___|\___|_| \__, | .__/ \__| __/ | | |___/|_| """ BLOCK_SIZE = 16 class AES_CFB(object): def __init__(self): self.key = urandom(BLOCK_SIZE) self.iv = urandom(16) self.aes_encrypt = AES.new(self.key, AES.MODE_CFB, self.iv) self.aes_decrypt = AES.new(self.key, AES.MODE_CFB, self.iv) def encrypt(self, plain): return self.aes_encrypt.encrypt(self.pad(plain)) def decrypt(self, cipher): return self.unpad(self.aes_decrypt.decrypt(cipher)) @staticmethod def pad(s): num = BLOCK_SIZE - (len(s) % BLOCK_SIZE) return s + bytes([num] * num) @staticmethod def unpad(s): return s[:-s[-1]] class Task(socketserver.BaseRequestHandler): def _recvall(self): BUFF_SIZE = 1024 data = b'' while True: part = self.request.recv(BUFF_SIZE) data += part if len(part) < BUFF_SIZE: break return data.strip() def send(self, msg, newline=True): try: if newline: msg += b'\n' self.request.sendall(msg) except: pass def recv(self, prompt=b'> '): self.send(prompt, newline=False) return self._recvall() def proof_of_work(self): random.seed(urandom(32)) alphabet = 
string.ascii_letters + string.digits proof = ''.join(random.choices(alphabet, k=32)) hash_value = sha256(proof.encode()).hexdigest() self.send(f'sha256(XXXX+{proof[4:]}) == {hash_value}'.encode()) nonce = self.recv(prompt=b'Give me XXXX > ') if len(nonce) != 4 or sha256(nonce + proof[4:].encode()).hexdigest() != hash_value: return False return True def timeout_handler(self, signum, frame): raise TimeoutError def handle(self): try: signal.signal(signal.SIGALRM, self.timeout_handler) signal.alarm(60) self.send(BANNER) # if not self.proof_of_work(): # self.send(b'\nWrong!') # self.request.close() # return self.send(b"It's just a decryption system. And I heard that only the Bytedancer can get secret.") aes = AES_CFB() # signal.alarm(300) for i in range(52): cipher_hex = self.recv(prompt=b'Please enter your cipher in hex > ') if len(cipher_hex) > 2048: self.send(b"It's too long!") continue try: cipher = codecs.decode(cipher_hex, 'hex') except: self.send(b'Not hex data!') continue if len(cipher) == 0 or len(cipher) % BLOCK_SIZE != 0: self.send(f'Cipher length must be a multiple of {BLOCK_SIZE}!'.encode()) continue plaintext = aes.decrypt(cipher) plaintext_hex = codecs.encode(plaintext, 'hex') self.send(b'Your plaintext in hex: \n%s\n' % plaintext_hex) if plaintext == b"Hello, I'm a Bytedancer. Please give me the flag!": self.send(b'OK! Here is your flag: ') self.send(FLAG.encode()) break self.send(b'Bye!\n') except TimeoutError: self.send(b'\nTimeout!') except Exception as err: self.send(b'Something Wrong!') finally: self.request.close() class ForkedServer(socketserver.ForkingMixIn, socketserver.TCPServer): pass if __name__ == "__main__": HOST, PORT = '0.0.0.0', 30002 print(HOST, PORT) server = ForkedServer((HOST, PORT), Task) server.allow_reuse_address = True server.serve_forever() 

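• Exploit: the CFB mode mechanism (CFB模式机制). The server decrypts every query with one stateful AES-CFB object, and pycryptodome's MODE_CFB defaults to 8-bit segments. Each keystream byte therefore depends only on the previous 16 ciphertext bytes. The attack proceeds as follows:

- Every query is padded with at least 16 trailing zero bytes, which resets that register to all zeros for the next query.
- The plaintext returned for a zero ciphertext byte at position i is then the keystream byte at i.
- XORing that keystream byte with the wanted plaintext byte gives the ciphertext byte to append.
- `unpad` trusts the last plaintext byte. The final 256-byte query therefore succeeds only when that byte happens to be 207 (256 − 49), so the script reconnects until it does.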
  1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67  from Crypto.Util.number import * from pwn import * from tqdm import tqdm def main(): r = remote('0.0.0.0', '30002') plaintext = b"Hello, I'm a Bytedancer. Please give me the flag!"+b"\x0f"*15 def my_XOR(a, b): assert len(a) == len(b) return b''.join([long_to_bytes(a[i]^b[i]) for i in range(len(a))]) def proof_of_work(): rev = r.recvuntil(b"sha256(XXXX+") suffix = r.recv(28).decode() rev = r.recvuntil(b" == ") tar = r.recv(64).decode() def f(x): hashresult = hashlib.sha256(x.encode()+suffix.encode()).hexdigest() return hashresult == tar prefix = util.iters.mbruteforce(f, string.digits + string.ascii_letters, 4, 'upto') r.recvuntil(b'Give me XXXX > ') r.sendline(prefix.encode()) def decrypt(msg): newmsg = msg + b'\x00'*(256+64-len(msg)) r.recvuntil(b'Please enter your cipher in hex > ') r.sendline(newmsg.hex().encode()) r.recvline() result = r.recvline().decode().strip() return bytes.fromhex(result) def decrypt_(msg): newmsg = msg + b'\x00'*(256-len(msg)) r.recvuntil(b'Please enter your cipher in hex > ') r.sendline(newmsg.hex().encode()) r.recvline() result = r.recvline().decode().strip() return bytes.fromhex(result) # proof_of_work() msg = b'\x00'*16 decrypt(msg) c = b"" for i in range(50): t = decrypt(c)[i] c += long_to_bytes(t^plaintext[i]) decc = decrypt_(c) print(decc) res = r.recvline()+r.recvline() if b"Here is your flag" in res: print(r.recvline()) print(r.recvline()) r.close() return (True, len(decc)) r.close() return (False, len(decc)) ll = [] while True: ss = main() ll.append(ss[1]) if ss[0]: break print(len(ll), ll)