from sys import stdout, stderr
import signal
import time
from functools import partial, reduce
from itertools import islice
import asyncio
from subprocess import run, PIPE
from datetime import datetime, timezone, timedelta
import logging

from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger

from zasd.apscheduler import *
from zasd.asyncio import *
from zasd.config import *

#
# Constants

DATASET_COLS = ['type', 'name', 'creation', 'mountpoint']

#
# Functions for running subprocesses with tabulated output

# Run program and convert tabulated output to nested lists
def run_for_table(args):
    result = run(args, check=True, stdout=PIPE, encoding='utf-8')
    return str_to_table(result.stdout)

# Run program and convert tabulated output to list of dictionaries
# with given column names as keys
def run_for_dicts(args, column_list):
    return table_to_dicts(run_for_table(args), column_list)

#
# Functions for converting multi-line tabulated strings to data structures

# Convert tabulated multi-line string to nested lists
def str_to_table(string, sep='\t'):
    return list(line.split(sep) for line in string.splitlines())

# Convert table to list of dictionaries with given column names as keys
def table_to_dicts(table, column_list):
    return list(row_to_dict(row, column_list) for row in table)

# Convert table row to dictionary with given column names as keys
def row_to_dict(row, column_list):
    return {column_list[i]: row[i] for i in range(len(row))}

#
# ZFS functions

# Get list of snapshots (dataset dictionaries)
def zfs_get_snapshots():
    return zfs_get_datasets('snapshot')

# Get list of filesystems (dataset dictionaries)
def zfs_get_filesystems():
    return zfs_get_datasets('filesystem')

# Get list of datasets
def zfs_get_datasets(dataset_type='all'):
    global config
    return zfs_dicts_to_datasets(run_for_dicts(
        [config['zfs_path'], 'list', '-Hp',
         '-t', dataset_type,
         '-o', ','.join(DATASET_COLS)],
        DATASET_COLS))

# Transform list of ZFS dictionaries to list of datasets
def zfs_dicts_to_datasets(dicts):
    return list(zfs_dict_to_dataset(d) for d in dicts)

# Transform dictionary to dataset (pool, filesystem)
def zfs_dict_to_dataset(zfs_dict):
    name = zfs_dict['name']
    dataset = dict(zfs_dict)

    # Separate dataset and snapshot names out to extra fields
    if '@' in name:
        fields = name.split('@')
        dataset['dataset'] = fields[0]
        dataset['snapshot'] = fields[1]

    return dataset

# Create one or more snapshots
async def zfs_create_snapshot(*snapshots, recursive=False):
    global config

    args = [config['zfs_path'], 'snapshot']
    if recursive:
        args.append('-r')

    for snapshot in snapshots:
        sargs = args + [get_snapshot_zfs_name(snapshot)]
        await asyncio.create_subprocess_exec(*sargs)

# Destroy one or more snapshots
async def zfs_destroy_snapshot(*snapshots, recursive=False):
    global config

    args = [config['zfs_path'], 'destroy']
    if recursive:
        args.append('-r')

    for snapshot in snapshots:
        sargs = args + [get_snapshot_zfs_name(snapshot)]
        await asyncio.create_subprocess_exec(*sargs)

# Generate ZFS identifier string for snapshot
def get_snapshot_zfs_name(snapshot):
    if 'tag' in snapshot:
        return make_snapshot_zfs_name(
            snapshot['dataset'], snapshot['tag'], snapshot.get('serial', None))
    elif 'snapshot' in snapshot:
        return make_snapshot_zfs_name(snapshot['dataset'], snapshot['snapshot'])
    else:
        raise KeyError('Snapshot has no name or tag')

# Generate ZFS identifier string from arguments
def make_snapshot_zfs_name(dataset, tag_or_snapshot, serial=None):
    if serial is None:
        return '{}@{}'.format(dataset, tag_or_snapshot)
    else:
        return '{}@{}:{}'.format(dataset, tag_or_snapshot, serial)
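# Illustrative example only (the dataset, tag and serial values below are
# hypothetical): a snapshot record such as
#
#     {'dataset': 'tank/home', 'tag': 'hourly', 'serial': '5f3c2a1b'}
#
# maps to the ZFS name 'tank/home@hourly:5f3c2a1b', while a record carrying a
# 'snapshot' field (as returned by zfs_get_snapshots) maps back to the plain
# 'dataset@snapshot' form.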
#
# Configuration functions

# Retrieve all schedules and merge with default schedule
def get_schedules():
    global config
    schedules = ({**config['defaults'], **dict(s)} for s in config['schedules'])
    return schedules

# Get dictionary of tag-modified flags on filesystem
def get_fs_flags(name):
    global fs_modified
    if name not in fs_modified:
        fs_modified[name] = dict()
    return fs_modified[name]

# Get tag-modified flag for specific tag on filesystem
def get_fs_flag(name, tag):
    flags = get_fs_flags(name)
    if tag not in flags:
        flags[tag] = False
    return flags[tag]

# Set specific tag-modified flag on filesystem
def set_fs_flag(name, tag):
    flags = get_fs_flags(name)
    flags[tag] = True

# Set all tag-modified flags on filesystem
def set_all_fs_flags(name):
    flags = get_fs_flags(name)
    for tag in flags.keys():
        set_fs_flag(name, tag)

# Clear specific tag-modified flag on filesystem
def clear_fs_flag(name, tag):
    flags = get_fs_flags(name)
    flags[tag] = False

#
# fswatch subprocess protocol for asyncio

class FSWatchProtocol(LineBufferedProtocol):
    def __init__(self, fs):
        LineBufferedProtocol.__init__(self, 'utf-8')
        self.fs = fs

    def pipe_line_received(self, line):
        global logger

        # Ignore empty lines and NOOPs
        if len(line) == 0 or int(line) == 0:
            return

        logger.info('Detected change on filesystem %s', self.fs['name'])

        # Set all tag-modified flags on filesystem
        set_all_fs_flags(self.fs['name'])

#
# Snapshot scheduling functions

# Create snapshot from a snapshot schedule
async def snapshot_creation_task(schedule, fs):
    global logger

    tag = schedule['tag']
    serial = make_snapshot_serial()
    recursive = schedule['recursive']

    if get_fs_flag(fs, tag):
        # Clear tag-modified flag for this tag on filesystem
        clear_fs_flag(fs, tag)

        logger.info('Taking snapshot of filesystem %s on schedule %s', fs, tag)

        # Create stub snapshot record and take the snapshot
        snapshot = dict(dataset=fs, tag=tag, serial=serial)
        await zfs_create_snapshot(snapshot, recursive=recursive)

# Generate time-based 8-character hexadecimal snapshot serial number
def make_snapshot_serial():
    return ('%x' % int(time.time()))[-8:]

# Destroy all expired snapshots
async def snapshot_destruction_task():
    global config, logger

    snapshots = zfs_get_snapshots()

    for schedule in get_schedules():
        if schedule['disabled']:
            continue

        # Find expired snapshots for schedule
        tag = schedule['tag']
        expired = slice_snapshots(
            snapshots, tag, index=schedule['keep'], stop=None, reverse=True)

        if len(expired) > 0:
            logger.info('Destroying snapshots with tag %s:', tag)

            for snapshot in expired:
                logger.info('%s%s', config['tab_size'] * ' ', snapshot['name'])
                await zfs_destroy_snapshot(snapshot)

# Check if snapshot matches tag
def snapshot_matches_tag(snapshot, tag):
    return get_snapshot_tag(snapshot) == tag

# Get tag of snapshot
def get_snapshot_tag(snapshot):
    (tag, serial) = get_snapshot_fields(snapshot)
    return tag

# Get serial number of snapshot
def get_snapshot_serial(snapshot):
    (tag, serial) = get_snapshot_fields(snapshot)
    return serial

# Get tuple of fields in snapshot name
def get_snapshot_fields(snapshot):
    global config
    return tuple(snapshot['snapshot'].split(config['separator']))
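# Illustrative example only (hypothetical name, assuming config['separator']
# is ':'): for a snapshot listed as 'tank/home@hourly:5f3c2a1b', the
# 'snapshot' field is 'hourly:5f3c2a1b', so get_snapshot_fields() returns
# ('hourly', '5f3c2a1b'), get_snapshot_tag() returns 'hourly' and
# get_snapshot_serial() returns '5f3c2a1b'.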
# Group 'snapshots' list using 'key' function, enumerate groups (in reverse if 'reverse' is
# True), slice by 'index' and 'stop', and return slice as flat list of snapshots
#
# If 'stop' is not specified, assume that 'index' is the index to stop at; otherwise, slice
# beginning at 'index' and ending at 'stop'
#
def slice_snapshots(snapshots, tag, index, stop=0, reverse=False, key=get_snapshot_serial):
    # Find matching snapshots
    matches = list(s for s in snapshots if snapshot_matches_tag(s, tag))

    # Make ordered set of serials present in matched snapshots
    ordered_set = list(sorted(set(key(s) for s in matches), reverse=reverse))

    # Slice n serials from ordered set of serials
    serials = ordered_set[:index] if stop == 0 else ordered_set[index:stop]

    # Intersect matching snapshots with sliced set of serials
    result = list(s for s in matches if get_snapshot_serial(s) in set(serials))

    return result

# Load and activate snapshot schedules
def load_snapshot_schedules():
    global config, scheduler, fs_modified, logger

    fs_modified = dict()

    for schedule in get_schedules():
        if schedule['disabled']:
            continue

        tag = schedule['tag']

        for fs in schedule['filesystems']:
            scheduler.add_job(
                snapshot_creation_task,
                trigger  = schedule['trigger'],
                id       = make_snapshot_zfs_name(fs, tag),
                group    = fs,
                priority = schedule['priority'],
                args     = [schedule, fs])

        # Set tag-modified flags on filesystems (always take snapshots on startup)
        for name in schedule['filesystems']:
            set_fs_flag(name, tag)

    scheduler.add_job(
        snapshot_destruction_task,
        trigger = config['destroy_trigger'],
        group   = 'destroy')

async def main_task():
    global config, event_loop, scheduler, fs_listeners

    scheduler = AsyncIOPriorityScheduler(
        event_loop = event_loop,
        executors  = {'default': AsyncIOPriorityExecutor()})

    # Watch file system mountpoints
    fs_listeners = dict()
    for fs in zfs_get_filesystems():
        await event_loop.subprocess_exec(
            lambda: FSWatchProtocol(fs),
            config['fswatch_path'], '-o', fs['mountpoint'],
            stdout=PIPE)

    load_snapshot_schedules()
    scheduler.start()

    if stdout.isatty():
        while True:
            print_spinner()
            await asyncio.sleep(1)

# Print idle spinner
def print_spinner():
    print(print_spinner.chars[print_spinner.index] + '\x1B[G',
          end='', file=stderr, flush=True)
    print_spinner.index = (print_spinner.index + 1) % len(print_spinner.chars)

print_spinner.index = 0
print_spinner.chars = ['|', '/', '-', '\\']

def signal_handler(signame):
    global logger, event_loop

    logger.info('Received %s', signame)
    event_loop.stop()

#
# Program

config = load_config()
configure_logging(config)

logger.info('Processing jobs')

event_loop = asyncio.get_event_loop()

event_loop.add_signal_handler(signal.SIGINT, partial(signal_handler, 'SIGINT'))
event_loop.add_signal_handler(signal.SIGTERM, partial(signal_handler, 'SIGTERM'))

event_loop.create_task(main_task())

try:
    event_loop.run_forever()
finally:
    logger.info('Terminating')
    print(file=stderr)
    event_loop.close()
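# Configuration keys read by this module (the schema itself is defined by
# zasd.config; listed here for reference only): 'zfs_path', 'fswatch_path',
# 'separator', 'tab_size', 'destroy_trigger', 'defaults' and 'schedules',
# where each schedule provides 'tag', 'filesystems', 'trigger', 'keep',
# 'recursive', 'priority' and 'disabled'.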