Project

General

Profile

Bug #21568 » arv-mount-stress-test.py

Version 2 - Brett Smith, 03/01/2024 11:57 PM

 
1
#!/usr/bin/env python3
2

    
3
import argparse
4
import dataclasses
5
import itertools
6
import logging
7
import logging.handlers
8
# Using multiprocessing not for the usual GIL avoidance reasons,
9
# but because of historical risks of mixing subprocess+threading.
10
import multiprocessing
11
import os
12
import subprocess
13
import sys
14
import tempfile
15
import time
16

    
17
from pathlib import Path
18
from typing import Optional
19

    
20
# Module-wide logger: accept every level (DEBUG and up) and send records to
# the local syslog socket.
logger = logging.getLogger('arv-mount-stress')
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.handlers.SysLogHandler('/dev/log'))
23

    
24
@dataclasses.dataclass
class MountState:
    """State for one stress-test mount of arv-mount.

    Constructing an instance creates a fresh, uniquely named mount point
    directory under *mount_parent* (via ``tempfile.mkdtemp``) and derives the
    systemd unit name from that directory's name.
    """

    # Mount point directory created by __init__.
    mount_path: Path
    # Transient systemd user unit name: "<mount dir name>.service".
    unit_name: str
    # multiprocessing.Event() is a factory method; the objects it returns are
    # multiprocessing.synchronize.Event instances. Annotated as strings
    # because that submodule isn't imported by name here.
    ready_flag: 'multiprocessing.synchronize.Event'
    crash_flag: 'multiprocessing.synchronize.Event'
    # Starts as None.
    returncode: Optional[int]

    def __init__(self, mount_parent):
        # The explicit __init__ takes precedence over the one @dataclass
        # would generate; the annotations above still drive repr/eq.
        self.mount_path = Path(tempfile.mkdtemp(
            prefix='arv-mount-stress-',
            dir=mount_parent,
        ))
        # Use the whole directory name (Path.stem would drop a trailing
        # '.suffix' if the name ever contained one).
        self.unit_name = f'{self.mount_path.name}.service'
        self.ready_flag = multiprocessing.Event()
        self.crash_flag = multiprocessing.Event()
        self.returncode = None
41

    
42

    
43
def follow_journal(mount_state, journal_fd):
    """Scan arv-mount journal output and flag unhandled FUSE exceptions.

    Reads *journal_fd* line by line until EOF and sets
    ``mount_state.crash_flag`` for every line ending in arv-mount's
    "ERROR: Unhandled exception during FUSE operation" message. Reading
    continues after a match so the pipe keeps being drained. The file
    descriptor is closed on return.

    Args:
        mount_state: object with a ``crash_flag`` Event.
        journal_fd: readable file descriptor carrying the journal text.
    """
    crash_marker = ' ERROR: Unhandled exception during FUSE operation'
    # Journal messages are not guaranteed to be valid UTF-8. A strict decode
    # error would raise out of this loop and silently end crash detection,
    # so undecodable bytes are replaced instead.
    with open(journal_fd, encoding='utf-8', errors='replace') as journal_out:
        for line in journal_out:
            # rstrip so a final line without a newline still matches.
            if line.rstrip('\n').endswith(crash_marker):
                mount_state.crash_flag.set()
48
    
49
def schedule_tries(mount_state, start_sleep=90, sleep_mult=2, stop_sleep=900):
    """Yield try numbers 1, 2, 3, ... on an exponentially growing schedule.

    After each try, the generator pauses so that the try plus the pause
    last ``sleep_time`` seconds. ``sleep_time`` starts at *start_sleep* and
    is multiplied by *sleep_mult* after each try. No further tries are
    yielded once it reaches *stop_sleep*.

    The schedule ends as soon as ``mount_state.crash_flag`` is set, whether
    it was set during a try or during the pause after one.

    Args:
        mount_state: object with a ``crash_flag`` Event.
        start_sleep: seconds allotted to the first try plus its pause.
        sleep_mult: growth factor for the allotment after each try.
        stop_sleep: stop once the allotment reaches this many seconds.
    """
    tries = itertools.count(1)
    sleep_time = start_sleep
    while sleep_time < stop_sleep:
        # Monotonic clock: immune to wall-clock adjustments mid-test.
        start_time = time.monotonic()
        yield next(tries)
        if mount_state.crash_flag.is_set():
            break
        elapsed_time = time.monotonic() - start_time
        # Event.wait() returns True as soon as the flag is set, so a crash
        # logged during the pause ends the schedule without another try.
        if mount_state.crash_flag.wait(max(0, sleep_time - elapsed_time)):
            break
        sleep_time *= sleep_mult
61

    
62
def stress_mount_unlimited_subprocesses(mount_state):
    """Run one ``ls -lR`` per top-level mount entry, all concurrently.

    Every ``ls`` is started before any is waited on, and all of them are
    waited on before returning.

    Returns:
        os.EX_OK if every ``ls`` succeeded, or if the mount has no entries.
        Otherwise the return code with the largest magnitude. A negative
        value means an ``ls`` was killed by that signal; such values are no
        longer masked by a successful ``ls``.
    """
    procs = [
        subprocess.Popen(
            ['ls', '-lR', str(path)],
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
        ) for path in mount_state.mount_path.iterdir()
    ]
    logger.debug("Running %s ls processes", len(procs))
    # Reap every child before picking the result.
    returncodes = [proc.wait() for proc in procs]
    # key=abs keeps signal deaths (negative codes) visible, and default=
    # covers an empty mount; max() alone would return 0 over -9.
    result = max(returncodes, key=abs, default=os.EX_OK)
    logger.debug("Stress returncode = %d", result)
    return result
74

    
75
def walk_dir(start_path):
    """List every directory reachable from *start_path*.

    Uses an explicit LIFO stack. Each popped directory is listed, and every
    entry for which ``Path.is_dir()`` is true is pushed for later listing.
    The listings are otherwise discarded: the work itself is the point.
    Returns None. Any OSError raised while listing propagates.
    """
    stack = [start_path]
    while stack:
        current = stack.pop()
        for entry in current.iterdir():
            if entry.is_dir():
                stack.append(entry)
83

    
84
def stress_mount_process_pool(mount_state, pool_size):
    """Walk each top-level directory of the mount in a process pool.

    Each top-level directory is handed to walk_dir() in a pool of
    *pool_size* worker processes. Top-level non-directories are skipped,
    matching walk_dir()'s own ``is_dir()`` filter; listing a plain file
    would raise NotADirectoryError.

    Returns:
        os.EX_OK. Failures surface as exceptions instead: the first
        exception raised by a worker is re-raised here.
    """
    top_dirs = [
        path for path in mount_state.mount_path.iterdir()
        if path.is_dir()
    ]
    with multiprocessing.Pool(pool_size) as pool:
        # Results are all None. Iterate only to wait for completion and to
        # surface worker exceptions.
        for _ in pool.imap_unordered(walk_dir, top_dirs):
            pass
    return os.EX_OK
89

    
90
def clean_mount(mount_state):
    """Unmount the test mount and return fusermount's return code.

    fusermount gets 10 seconds. After that, the arv-mount unit is killed
    through ``systemctl --user kill`` and fusermount gets 20 more seconds.
    If it still has not finished, it is killed as well, so cleanup cannot
    block forever; the return code is then negative (the signal number).
    """
    with subprocess.Popen(
        ['fusermount', '-qu', str(mount_state.mount_path)],
        stdin=subprocess.DEVNULL,
    ) as umount_proc:
        try:
            umount_proc.wait(10)
        except subprocess.TimeoutExpired:
            subprocess.run([
                'systemctl', '--user',
                'kill', mount_state.unit_name,
            ], stdin=subprocess.DEVNULL)
            try:
                umount_proc.wait(20)
            except subprocess.TimeoutExpired:
                # Letting TimeoutExpired escape would make Popen.__exit__
                # call wait() with no timeout, i.e. hang indefinitely.
                umount_proc.kill()
                umount_proc.wait()
    return umount_proc.returncode
104

    
105
def _positive_int(text):
    """argparse ``type=`` callable: parse *text* as an int >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be at least 1, got {value}")
    return value


def parse_arguments(arglist=None):
    """Parse this script's options and leave the rest for arv-mount.

    Args:
        arglist: argument strings to parse; None means ``sys.argv[1:]``.

    Returns:
        The ``(namespace, extras)`` pair from
        ``ArgumentParser.parse_known_args``. *namespace.jobs* is a positive
        int (default 8). *extras* lists the unrecognized arguments.

    A non-positive ``--jobs`` is rejected here (usage error, SystemExit 2)
    instead of failing later when the worker pool is created.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--jobs', '-j',
        type=_positive_int,
        default=8,
        help="Number of parallel accesses during test",
    )
    return parser.parse_known_args(arglist)
114

    
115
def main(arglist=None):
    """Run the arv-mount stress test and return a process exit status.

    Steps:
    1. Start arv-mount as a transient systemd user unit. It runs read-only
       with a shared view, plus any arguments parse_arguments() did not
       recognize. The mount point goes under $XDG_RUNTIME_DIR, $TMPDIR, or
       /tmp.
    2. Follow that unit's journal in a child process to spot unhandled FUSE
       exceptions.
    3. Wait for the mount to show contents.
    4. Repeatedly walk the mount with a process pool on a growing schedule.

    Returns:
        The stress test's return code if it failed; else os.EX_SOFTWARE if
        arv-mount logged an unhandled FUSE exception; else fusermount's
        return code from unmounting.

    Raises:
        RuntimeError: the mount never showed any contents.
        subprocess.CalledProcessError: systemd-run failed.
    """
    args, arv_mount_opts = parse_arguments(arglist)
    mount_parent = Path(
        os.environ.get('XDG_RUNTIME_DIR')
        or os.environ.get('TMPDIR')
        or '/tmp'
    )
    mount_state = MountState(mount_parent)
    logger.debug("starting mount service %s", mount_state.unit_name)
    unit_arg = f'--unit={mount_state.unit_name}'
    try:
        subprocess.run([
            'systemd-run', '--user', unit_arg, '--quiet',
            'arv-mount', '--foreground', '--read-only', '--shared',
            f'--directory-cache={2 << 20}',
            *arv_mount_opts,
            '--', str(mount_state.mount_path),
        ], stdin=subprocess.DEVNULL, check=True)
    except (OSError, subprocess.CalledProcessError):
        # Nothing was mounted; don't leave the empty mount point behind.
        mount_state.mount_path.rmdir()
        raise
    journal_proc = subprocess.Popen(
        ['journalctl', '--user', unit_arg, '--follow', '--output=cat'],
        stdin=subprocess.DEVNULL,
        stdout=subprocess.PIPE,
    )
    # follow_journal() is handed a raw file descriptor number, which is only
    # meaningful in the child if the child is forked from this process.
    # Request 'fork' explicitly because it is not the default start method
    # everywhere (Python 3.14 switched the Linux default to 'forkserver').
    follow_proc = multiprocessing.get_context('fork').Process(
        target=follow_journal,
        args=(mount_state, journal_proc.stdout.fileno()),
    )
    follow_proc.start()

    stress_returncode = os.EX_OK
    try:
        logger.debug("waiting for mount")
        have_contents = False
        for _ in schedule_tries(mount_state, 1, 2, 60):
            if have_contents := any(mount_state.mount_path.iterdir()):
                break
        # Explicit check rather than assert: `python -O` strips asserts and
        # would then "stress" an empty mount and report success.
        if not have_contents:
            raise RuntimeError("mount never had contents")
        for count in schedule_tries(mount_state):
            logger.debug("starting stress test #%d", count)
            stress_returncode = stress_mount_process_pool(mount_state, args.jobs)
            if stress_returncode != os.EX_OK:
                break
    finally:
        journal_proc.terminate()
        umount_returncode = clean_mount(mount_state)
        follow_proc.join()
        # Reap journalctl and release our copy of its output pipe.
        journal_proc.wait()
        journal_proc.stdout.close()
        mount_state.mount_path.rmdir()

    # follow_proc has been joined, so every journal line has been scanned.
    crashed = mount_state.crash_flag.is_set()
    if crashed:
        logger.error("arv-mount logged an unhandled exception during a FUSE operation")
    if crashed or stress_returncode != os.EX_OK:
        if sys.stdout.isatty():
            subprocess.run(['journalctl', '--user', '--no-pager', unit_arg])
        if stress_returncode != os.EX_OK:
            return stress_returncode
        return os.EX_SOFTWARE
    return umount_returncode
166

    
167
if __name__ == '__main__':
    # When run as a script, also echo log records to stderr (the
    # StreamHandler default) with a syslog-style prefix, then exit with
    # main()'s status.
    stderr_handler = logging.StreamHandler()
    stderr_handler.setFormatter(logging.Formatter(
        '%(asctime)s %(name)s[%(process)d]: %(levelname)s: %(message)s'
    ))
    logger.addHandler(stderr_handler)
    sys.exit(main())
(2-2/2)