diff --git a/Makefile b/Makefile index aabe28a92..24318d692 100644 --- a/Makefile +++ b/Makefile @@ -417,6 +417,7 @@ lint: flake8 --config=scripts/flake8.cfg test/inhfd/*.py flake8 --config=scripts/flake8.cfg test/others/rpc/config_file.py flake8 --config=scripts/flake8.cfg lib/py/images/pb2dict.py + flake8 --config=scripts/flake8.cfg lib/py/images/images.py flake8 --config=scripts/flake8.cfg scripts/criu-ns flake8 --config=scripts/flake8.cfg scripts/crit-setup.py flake8 --config=scripts/flake8.cfg coredump/ diff --git a/lib/py/images/images.py b/lib/py/images/images.py index df4f92ac9..a1d76e7cf 100644 --- a/lib/py/images/images.py +++ b/lib/py/images/images.py @@ -69,6 +69,22 @@ class MagicException(Exception): self.magic = magic +def decode_base64_data(data): + """A helper function to decode base64 data.""" + if (sys.version_info > (3, 0)): + return base64.decodebytes(str.encode(data)) + else: + return base64.decodebytes(data) + + +def write_base64_data(f, data): + """A helper function to write base64 encoded data to a file.""" + if (sys.version_info > (3, 0)): + f.write(base64.decodebytes(str.encode(data))) + else: + f.write(base64.decodebytes(data)) + + # Generic class to handle loading/dumping criu images entries from/to bin # format to/from dict(json). class entry_handler: @@ -285,15 +301,9 @@ class ghost_file_handler: size = len(pb_str) f.write(struct.pack('i', size)) f.write(pb_str) - if (sys.version_info > (3, 0)): - f.write(base64.decodebytes(str.encode(item['extra']))) - else: - f.write(base64.decodebytes(item['extra'])) + write_base64_data(f, item['extra']) else: - if (sys.version_info > (3, 0)): - f.write(base64.decodebytes(str.encode(item['extra']))) - else: - f.write(base64.decodebytes(item['extra'])) + write_base64_data(f, item['extra']) def dumps(self, entries): f = io.BytesIO('') @@ -314,10 +324,7 @@ class pipes_data_extra_handler: return base64.encodebytes(data).decode('utf-8') def dump(self, extra, f, pload): - if (sys.version_info > (3, 0)): - data = base64.decodebytes(str.encode(extra)) - else: - data = base64.decodebytes(extra) + data = decode_base64_data(extra) f.write(data) def skip(self, f, pload): @@ -332,10 +339,7 @@ class sk_queues_extra_handler: return base64.encodebytes(data).decode('utf-8') def dump(self, extra, f, _unused): - if (sys.version_info > (3, 0)): - data = base64.decodebytes(str.encode(extra)) - else: - data = base64.decodebytes(extra) + data = decode_base64_data(extra) f.write(data) def skip(self, f, pload): @@ -356,12 +360,8 @@ class tcp_stream_extra_handler: return d def dump(self, extra, f, _unused): - if (sys.version_info > (3, 0)): - inq = base64.decodebytes(str.encode(extra['inq'])) - outq = base64.decodebytes(str.encode(extra['outq'])) - else: - inq = base64.decodebytes(extra['inq']) - outq = base64.decodebytes(extra['outq']) + inq = decode_base64_data(extra['inq']) + outq = decode_base64_data(extra['outq']) f.write(inq) f.write(outq) @@ -370,6 +370,7 @@ class tcp_stream_extra_handler: f.seek(0, os.SEEK_END) return pbuff.inq_len + pbuff.outq_len + class bpfmap_data_extra_handler: def load(self, f, pload): size = pload.keys_bytes + pload.values_bytes @@ -384,14 +385,13 @@ class bpfmap_data_extra_handler: f.seek(pload.bytes, os.SEEK_CUR) return pload.bytes + class ipc_sem_set_handler: def load(self, f, pbuff): entry = pb2dict.pb2dict(pbuff) size = sizeof_u16 * entry['nsems'] rounded = round_up(size, sizeof_u64) - s = array.array('H') - if s.itemsize != sizeof_u16: - raise Exception("Array size mismatch") + s = self._get_sem_array() s.frombytes(f.read(size)) f.seek(rounded - size, 1) return s.tolist() @@ -400,9 +400,7 @@ class ipc_sem_set_handler: entry = pb2dict.pb2dict(pbuff) size = sizeof_u16 * entry['nsems'] rounded = round_up(size, sizeof_u64) - s = array.array('H') - if s.itemsize != sizeof_u16: - raise Exception("Array size mismatch") + s = self._get_sem_array() s.fromlist(extra) if len(s) != entry['nsems']: raise Exception("Number of semaphores mismatch") @@ -415,23 +413,16 @@ class ipc_sem_set_handler: f.seek(round_up(size, sizeof_u64), os.SEEK_CUR) return size + def _get_sem_array(self): + s = array.array('H') + if s.itemsize != sizeof_u16: + raise Exception("Array size mismatch") + return s + class ipc_msg_queue_handler: def load(self, f, pbuff): - entry = pb2dict.pb2dict(pbuff) - messages = [] - for x in range(0, entry['qnum']): - buf = f.read(4) - if len(buf) == 0: - break - size, = struct.unpack('i', buf) - msg = pb.ipc_msg() - msg.ParseFromString(f.read(size)) - rounded = round_up(msg.msize, sizeof_u64) - data = f.read(msg.msize) - f.seek(rounded - msg.msize, 1) - messages.append(pb2dict.pb2dict(msg)) - messages.append(base64.encodebytes(data).decode('utf-8')) + messages, _ = self._read_messages(f, pbuff) return messages def dump(self, extra, f, pbuff): @@ -443,15 +434,17 @@ class ipc_msg_queue_handler: f.write(struct.pack('i', size)) f.write(msg_str) rounded = round_up(msg.msize, sizeof_u64) - if (sys.version_info > (3, 0)): - data = base64.decodebytes(str.encode(extra[i + 1])) - else: - data = base64.decodebytes(extra[i + 1]) + data = decode_base64_data(extra[i + 1]) f.write(data[:msg.msize]) f.write(b'\0' * (rounded - msg.msize)) def skip(self, f, pbuff): + _, pl_len = self._read_messages(f, pbuff, skip_data=True) + return pl_len + + def _read_messages(self, f, pbuff, skip_data=False): entry = pb2dict.pb2dict(pbuff) + messages = [] pl_len = 0 for x in range(0, entry['qnum']): buf = f.read(4) @@ -461,10 +454,17 @@ class ipc_msg_queue_handler: msg = pb.ipc_msg() msg.ParseFromString(f.read(size)) rounded = round_up(msg.msize, sizeof_u64) - f.seek(rounded, os.SEEK_CUR) pl_len += size + msg.msize - return pl_len + if skip_data: + f.seek(rounded, os.SEEK_CUR) + else: + data = f.read(msg.msize) + f.seek(rounded - msg.msize, 1) + messages.append(pb2dict.pb2dict(msg)) + messages.append(base64.encodebytes(data).decode('utf-8')) + + return messages, pl_len class ipc_shm_handler: @@ -560,7 +560,7 @@ handlers = { 'MEMFD_INODE': entry_handler(pb.memfd_inode_entry), 'BPFMAP_FILE': entry_handler(pb.bpfmap_file_entry), 'BPFMAP_DATA': entry_handler(pb.bpfmap_data_entry, - bpfmap_data_extra_handler()), + bpfmap_data_extra_handler()), 'APPARMOR': entry_handler(pb.apparmor_entry), } @@ -574,12 +574,12 @@ def __rhandler(f): try: m = magic.by_val[img_magic] - except: + except Exception: raise MagicException(img_magic) try: handler = handlers[m] - except: + except Exception: raise Exception("No handler found for image with magic " + m) return m, handler @@ -641,7 +641,7 @@ def dump(img, f): try: handler = handlers[m] - except: + except Exception: raise Exception("No handler found for image with such magic") handler.dump(img['entries'], f)