mirror of
https://github.com/checkpoint-restore/criu.git
synced 2026-01-23 02:14:37 +00:00
lib/py: reduce code duplication
Refactor lib/py/images/images.py to reduce code duplication by extracting repetitive code into helper functions and private methods. This improves code readability and maintainability, as well as reducing the risk of bugs caused by duplicated code. Additionally, in Makefile, lib/py/images/images.py is added to the list of files to run by flake8 during CI. Fixes: #340 Signed-off-by: Kouame Behouba Manasse <behouba@gmail.com>
This commit is contained in:
parent
85b5c1e451
commit
a0cc95c03e
2 changed files with 53 additions and 52 deletions
1
Makefile
1
Makefile
|
|
@ -417,6 +417,7 @@ lint:
|
|||
flake8 --config=scripts/flake8.cfg test/inhfd/*.py
|
||||
flake8 --config=scripts/flake8.cfg test/others/rpc/config_file.py
|
||||
flake8 --config=scripts/flake8.cfg lib/py/images/pb2dict.py
|
||||
flake8 --config=scripts/flake8.cfg lib/py/images/images.py
|
||||
flake8 --config=scripts/flake8.cfg scripts/criu-ns
|
||||
flake8 --config=scripts/flake8.cfg scripts/crit-setup.py
|
||||
flake8 --config=scripts/flake8.cfg coredump/
|
||||
|
|
|
|||
|
|
@ -69,6 +69,22 @@ class MagicException(Exception):
|
|||
self.magic = magic
|
||||
|
||||
|
||||
def decode_base64_data(data):
|
||||
"""A helper function to decode base64 data."""
|
||||
if (sys.version_info > (3, 0)):
|
||||
return base64.decodebytes(str.encode(data))
|
||||
else:
|
||||
return base64.decodebytes(data)
|
||||
|
||||
|
||||
def write_base64_data(f, data):
|
||||
"""A helper function to write base64 encoded data to a file."""
|
||||
if (sys.version_info > (3, 0)):
|
||||
f.write(base64.decodebytes(str.encode(data)))
|
||||
else:
|
||||
f.write(base64.decodebytes(data))
|
||||
|
||||
|
||||
# Generic class to handle loading/dumping criu images entries from/to bin
|
||||
# format to/from dict(json).
|
||||
class entry_handler:
|
||||
|
|
@ -285,15 +301,9 @@ class ghost_file_handler:
|
|||
size = len(pb_str)
|
||||
f.write(struct.pack('i', size))
|
||||
f.write(pb_str)
|
||||
if (sys.version_info > (3, 0)):
|
||||
f.write(base64.decodebytes(str.encode(item['extra'])))
|
||||
else:
|
||||
f.write(base64.decodebytes(item['extra']))
|
||||
write_base64_data(f, item['extra'])
|
||||
else:
|
||||
if (sys.version_info > (3, 0)):
|
||||
f.write(base64.decodebytes(str.encode(item['extra'])))
|
||||
else:
|
||||
f.write(base64.decodebytes(item['extra']))
|
||||
write_base64_data(f, item['extra'])
|
||||
|
||||
def dumps(self, entries):
|
||||
f = io.BytesIO('')
|
||||
|
|
@ -314,10 +324,7 @@ class pipes_data_extra_handler:
|
|||
return base64.encodebytes(data).decode('utf-8')
|
||||
|
||||
def dump(self, extra, f, pload):
|
||||
if (sys.version_info > (3, 0)):
|
||||
data = base64.decodebytes(str.encode(extra))
|
||||
else:
|
||||
data = base64.decodebytes(extra)
|
||||
data = decode_base64_data(extra)
|
||||
f.write(data)
|
||||
|
||||
def skip(self, f, pload):
|
||||
|
|
@ -332,10 +339,7 @@ class sk_queues_extra_handler:
|
|||
return base64.encodebytes(data).decode('utf-8')
|
||||
|
||||
def dump(self, extra, f, _unused):
|
||||
if (sys.version_info > (3, 0)):
|
||||
data = base64.decodebytes(str.encode(extra))
|
||||
else:
|
||||
data = base64.decodebytes(extra)
|
||||
data = decode_base64_data(extra)
|
||||
f.write(data)
|
||||
|
||||
def skip(self, f, pload):
|
||||
|
|
@ -356,12 +360,8 @@ class tcp_stream_extra_handler:
|
|||
return d
|
||||
|
||||
def dump(self, extra, f, _unused):
|
||||
if (sys.version_info > (3, 0)):
|
||||
inq = base64.decodebytes(str.encode(extra['inq']))
|
||||
outq = base64.decodebytes(str.encode(extra['outq']))
|
||||
else:
|
||||
inq = base64.decodebytes(extra['inq'])
|
||||
outq = base64.decodebytes(extra['outq'])
|
||||
inq = decode_base64_data(extra['inq'])
|
||||
outq = decode_base64_data(extra['outq'])
|
||||
|
||||
f.write(inq)
|
||||
f.write(outq)
|
||||
|
|
@ -370,6 +370,7 @@ class tcp_stream_extra_handler:
|
|||
f.seek(0, os.SEEK_END)
|
||||
return pbuff.inq_len + pbuff.outq_len
|
||||
|
||||
|
||||
class bpfmap_data_extra_handler:
|
||||
def load(self, f, pload):
|
||||
size = pload.keys_bytes + pload.values_bytes
|
||||
|
|
@ -384,14 +385,13 @@ class bpfmap_data_extra_handler:
|
|||
f.seek(pload.bytes, os.SEEK_CUR)
|
||||
return pload.bytes
|
||||
|
||||
|
||||
class ipc_sem_set_handler:
|
||||
def load(self, f, pbuff):
|
||||
entry = pb2dict.pb2dict(pbuff)
|
||||
size = sizeof_u16 * entry['nsems']
|
||||
rounded = round_up(size, sizeof_u64)
|
||||
s = array.array('H')
|
||||
if s.itemsize != sizeof_u16:
|
||||
raise Exception("Array size mismatch")
|
||||
s = self._get_sem_array()
|
||||
s.frombytes(f.read(size))
|
||||
f.seek(rounded - size, 1)
|
||||
return s.tolist()
|
||||
|
|
@ -400,9 +400,7 @@ class ipc_sem_set_handler:
|
|||
entry = pb2dict.pb2dict(pbuff)
|
||||
size = sizeof_u16 * entry['nsems']
|
||||
rounded = round_up(size, sizeof_u64)
|
||||
s = array.array('H')
|
||||
if s.itemsize != sizeof_u16:
|
||||
raise Exception("Array size mismatch")
|
||||
s = self._get_sem_array()
|
||||
s.fromlist(extra)
|
||||
if len(s) != entry['nsems']:
|
||||
raise Exception("Number of semaphores mismatch")
|
||||
|
|
@ -415,23 +413,16 @@ class ipc_sem_set_handler:
|
|||
f.seek(round_up(size, sizeof_u64), os.SEEK_CUR)
|
||||
return size
|
||||
|
||||
def _get_sem_array(self):
|
||||
s = array.array('H')
|
||||
if s.itemsize != sizeof_u16:
|
||||
raise Exception("Array size mismatch")
|
||||
return s
|
||||
|
||||
|
||||
class ipc_msg_queue_handler:
|
||||
def load(self, f, pbuff):
|
||||
entry = pb2dict.pb2dict(pbuff)
|
||||
messages = []
|
||||
for x in range(0, entry['qnum']):
|
||||
buf = f.read(4)
|
||||
if len(buf) == 0:
|
||||
break
|
||||
size, = struct.unpack('i', buf)
|
||||
msg = pb.ipc_msg()
|
||||
msg.ParseFromString(f.read(size))
|
||||
rounded = round_up(msg.msize, sizeof_u64)
|
||||
data = f.read(msg.msize)
|
||||
f.seek(rounded - msg.msize, 1)
|
||||
messages.append(pb2dict.pb2dict(msg))
|
||||
messages.append(base64.encodebytes(data).decode('utf-8'))
|
||||
messages, _ = self._read_messages(f, pbuff)
|
||||
return messages
|
||||
|
||||
def dump(self, extra, f, pbuff):
|
||||
|
|
@ -443,15 +434,17 @@ class ipc_msg_queue_handler:
|
|||
f.write(struct.pack('i', size))
|
||||
f.write(msg_str)
|
||||
rounded = round_up(msg.msize, sizeof_u64)
|
||||
if (sys.version_info > (3, 0)):
|
||||
data = base64.decodebytes(str.encode(extra[i + 1]))
|
||||
else:
|
||||
data = base64.decodebytes(extra[i + 1])
|
||||
data = decode_base64_data(extra[i + 1])
|
||||
f.write(data[:msg.msize])
|
||||
f.write(b'\0' * (rounded - msg.msize))
|
||||
|
||||
def skip(self, f, pbuff):
|
||||
_, pl_len = self._read_messages(f, pbuff, skip_data=True)
|
||||
return pl_len
|
||||
|
||||
def _read_messages(self, f, pbuff, skip_data=False):
|
||||
entry = pb2dict.pb2dict(pbuff)
|
||||
messages = []
|
||||
pl_len = 0
|
||||
for x in range(0, entry['qnum']):
|
||||
buf = f.read(4)
|
||||
|
|
@ -461,10 +454,17 @@ class ipc_msg_queue_handler:
|
|||
msg = pb.ipc_msg()
|
||||
msg.ParseFromString(f.read(size))
|
||||
rounded = round_up(msg.msize, sizeof_u64)
|
||||
f.seek(rounded, os.SEEK_CUR)
|
||||
pl_len += size + msg.msize
|
||||
|
||||
return pl_len
|
||||
if skip_data:
|
||||
f.seek(rounded, os.SEEK_CUR)
|
||||
else:
|
||||
data = f.read(msg.msize)
|
||||
f.seek(rounded - msg.msize, 1)
|
||||
messages.append(pb2dict.pb2dict(msg))
|
||||
messages.append(base64.encodebytes(data).decode('utf-8'))
|
||||
|
||||
return messages, pl_len
|
||||
|
||||
|
||||
class ipc_shm_handler:
|
||||
|
|
@ -560,7 +560,7 @@ handlers = {
|
|||
'MEMFD_INODE': entry_handler(pb.memfd_inode_entry),
|
||||
'BPFMAP_FILE': entry_handler(pb.bpfmap_file_entry),
|
||||
'BPFMAP_DATA': entry_handler(pb.bpfmap_data_entry,
|
||||
bpfmap_data_extra_handler()),
|
||||
bpfmap_data_extra_handler()),
|
||||
'APPARMOR': entry_handler(pb.apparmor_entry),
|
||||
}
|
||||
|
||||
|
|
@ -574,12 +574,12 @@ def __rhandler(f):
|
|||
|
||||
try:
|
||||
m = magic.by_val[img_magic]
|
||||
except:
|
||||
except Exception:
|
||||
raise MagicException(img_magic)
|
||||
|
||||
try:
|
||||
handler = handlers[m]
|
||||
except:
|
||||
except Exception:
|
||||
raise Exception("No handler found for image with magic " + m)
|
||||
|
||||
return m, handler
|
||||
|
|
@ -641,7 +641,7 @@ def dump(img, f):
|
|||
|
||||
try:
|
||||
handler = handlers[m]
|
||||
except:
|
||||
except Exception:
|
||||
raise Exception("No handler found for image with such magic")
|
||||
|
||||
handler.dump(img['entries'], f)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue