lib/py: reduce code duplication

Refactor lib/py/images/images.py to reduce code duplication
by extracting repetitive code into helper functions and
private methods. This improves code readability and maintainability,
as well as reducing the risk of bugs caused by duplicated code.
Additionally, in Makefile, lib/py/images/images.py is added to the
list of files to run by flake8 during CI.

Fixes: #340

Signed-off-by: Kouame Behouba Manasse <behouba@gmail.com>
This commit is contained in:
Kouame Behouba Manasse 2023-02-24 05:32:44 +03:00 committed by Andrei Vagin
parent 85b5c1e451
commit a0cc95c03e
2 changed files with 53 additions and 52 deletions

View file

@ -417,6 +417,7 @@ lint:
flake8 --config=scripts/flake8.cfg test/inhfd/*.py
flake8 --config=scripts/flake8.cfg test/others/rpc/config_file.py
flake8 --config=scripts/flake8.cfg lib/py/images/pb2dict.py
flake8 --config=scripts/flake8.cfg lib/py/images/images.py
flake8 --config=scripts/flake8.cfg scripts/criu-ns
flake8 --config=scripts/flake8.cfg scripts/crit-setup.py
flake8 --config=scripts/flake8.cfg coredump/

View file

@ -69,6 +69,22 @@ class MagicException(Exception):
self.magic = magic
def decode_base64_data(data):
"""A helper function to decode base64 data."""
if (sys.version_info > (3, 0)):
return base64.decodebytes(str.encode(data))
else:
return base64.decodebytes(data)
def write_base64_data(f, data):
"""A helper function to write base64 encoded data to a file."""
if (sys.version_info > (3, 0)):
f.write(base64.decodebytes(str.encode(data)))
else:
f.write(base64.decodebytes(data))
# Generic class to handle loading/dumping criu images entries from/to bin
# format to/from dict(json).
class entry_handler:
@ -285,15 +301,9 @@ class ghost_file_handler:
size = len(pb_str)
f.write(struct.pack('i', size))
f.write(pb_str)
if (sys.version_info > (3, 0)):
f.write(base64.decodebytes(str.encode(item['extra'])))
else:
f.write(base64.decodebytes(item['extra']))
write_base64_data(f, item['extra'])
else:
if (sys.version_info > (3, 0)):
f.write(base64.decodebytes(str.encode(item['extra'])))
else:
f.write(base64.decodebytes(item['extra']))
write_base64_data(f, item['extra'])
def dumps(self, entries):
f = io.BytesIO('')
@ -314,10 +324,7 @@ class pipes_data_extra_handler:
return base64.encodebytes(data).decode('utf-8')
def dump(self, extra, f, pload):
if (sys.version_info > (3, 0)):
data = base64.decodebytes(str.encode(extra))
else:
data = base64.decodebytes(extra)
data = decode_base64_data(extra)
f.write(data)
def skip(self, f, pload):
@ -332,10 +339,7 @@ class sk_queues_extra_handler:
return base64.encodebytes(data).decode('utf-8')
def dump(self, extra, f, _unused):
if (sys.version_info > (3, 0)):
data = base64.decodebytes(str.encode(extra))
else:
data = base64.decodebytes(extra)
data = decode_base64_data(extra)
f.write(data)
def skip(self, f, pload):
@ -356,12 +360,8 @@ class tcp_stream_extra_handler:
return d
def dump(self, extra, f, _unused):
if (sys.version_info > (3, 0)):
inq = base64.decodebytes(str.encode(extra['inq']))
outq = base64.decodebytes(str.encode(extra['outq']))
else:
inq = base64.decodebytes(extra['inq'])
outq = base64.decodebytes(extra['outq'])
inq = decode_base64_data(extra['inq'])
outq = decode_base64_data(extra['outq'])
f.write(inq)
f.write(outq)
@ -370,6 +370,7 @@ class tcp_stream_extra_handler:
f.seek(0, os.SEEK_END)
return pbuff.inq_len + pbuff.outq_len
class bpfmap_data_extra_handler:
def load(self, f, pload):
size = pload.keys_bytes + pload.values_bytes
@ -384,14 +385,13 @@ class bpfmap_data_extra_handler:
f.seek(pload.bytes, os.SEEK_CUR)
return pload.bytes
class ipc_sem_set_handler:
def load(self, f, pbuff):
entry = pb2dict.pb2dict(pbuff)
size = sizeof_u16 * entry['nsems']
rounded = round_up(size, sizeof_u64)
s = array.array('H')
if s.itemsize != sizeof_u16:
raise Exception("Array size mismatch")
s = self._get_sem_array()
s.frombytes(f.read(size))
f.seek(rounded - size, 1)
return s.tolist()
@ -400,9 +400,7 @@ class ipc_sem_set_handler:
entry = pb2dict.pb2dict(pbuff)
size = sizeof_u16 * entry['nsems']
rounded = round_up(size, sizeof_u64)
s = array.array('H')
if s.itemsize != sizeof_u16:
raise Exception("Array size mismatch")
s = self._get_sem_array()
s.fromlist(extra)
if len(s) != entry['nsems']:
raise Exception("Number of semaphores mismatch")
@ -415,23 +413,16 @@ class ipc_sem_set_handler:
f.seek(round_up(size, sizeof_u64), os.SEEK_CUR)
return size
def _get_sem_array(self):
s = array.array('H')
if s.itemsize != sizeof_u16:
raise Exception("Array size mismatch")
return s
class ipc_msg_queue_handler:
def load(self, f, pbuff):
entry = pb2dict.pb2dict(pbuff)
messages = []
for x in range(0, entry['qnum']):
buf = f.read(4)
if len(buf) == 0:
break
size, = struct.unpack('i', buf)
msg = pb.ipc_msg()
msg.ParseFromString(f.read(size))
rounded = round_up(msg.msize, sizeof_u64)
data = f.read(msg.msize)
f.seek(rounded - msg.msize, 1)
messages.append(pb2dict.pb2dict(msg))
messages.append(base64.encodebytes(data).decode('utf-8'))
messages, _ = self._read_messages(f, pbuff)
return messages
def dump(self, extra, f, pbuff):
@ -443,15 +434,17 @@ class ipc_msg_queue_handler:
f.write(struct.pack('i', size))
f.write(msg_str)
rounded = round_up(msg.msize, sizeof_u64)
if (sys.version_info > (3, 0)):
data = base64.decodebytes(str.encode(extra[i + 1]))
else:
data = base64.decodebytes(extra[i + 1])
data = decode_base64_data(extra[i + 1])
f.write(data[:msg.msize])
f.write(b'\0' * (rounded - msg.msize))
def skip(self, f, pbuff):
_, pl_len = self._read_messages(f, pbuff, skip_data=True)
return pl_len
def _read_messages(self, f, pbuff, skip_data=False):
entry = pb2dict.pb2dict(pbuff)
messages = []
pl_len = 0
for x in range(0, entry['qnum']):
buf = f.read(4)
@ -461,10 +454,17 @@ class ipc_msg_queue_handler:
msg = pb.ipc_msg()
msg.ParseFromString(f.read(size))
rounded = round_up(msg.msize, sizeof_u64)
f.seek(rounded, os.SEEK_CUR)
pl_len += size + msg.msize
return pl_len
if skip_data:
f.seek(rounded, os.SEEK_CUR)
else:
data = f.read(msg.msize)
f.seek(rounded - msg.msize, 1)
messages.append(pb2dict.pb2dict(msg))
messages.append(base64.encodebytes(data).decode('utf-8'))
return messages, pl_len
class ipc_shm_handler:
@ -560,7 +560,7 @@ handlers = {
'MEMFD_INODE': entry_handler(pb.memfd_inode_entry),
'BPFMAP_FILE': entry_handler(pb.bpfmap_file_entry),
'BPFMAP_DATA': entry_handler(pb.bpfmap_data_entry,
bpfmap_data_extra_handler()),
bpfmap_data_extra_handler()),
'APPARMOR': entry_handler(pb.apparmor_entry),
}
@ -574,12 +574,12 @@ def __rhandler(f):
try:
m = magic.by_val[img_magic]
except:
except Exception:
raise MagicException(img_magic)
try:
handler = handlers[m]
except:
except Exception:
raise Exception("No handler found for image with magic " + m)
return m, handler
@ -641,7 +641,7 @@ def dump(img, f):
try:
handler = handlers[m]
except:
except Exception:
raise Exception("No handler found for image with such magic")
handler.dump(img['entries'], f)