from sys import stdout, stderr

import signal
import time
from functools import partial, reduce
from itertools import islice
import asyncio
from subprocess import run, PIPE
from datetime import datetime, timezone, timedelta
import logging

from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger

from zasd.apscheduler import *
from zasd.asyncio import *
from zasd.config import *

#
# Constants

DATASET_COLS = ['type', 'name', 'creation', 'mountpoint']

#
# Functions for running subprocesses with tabulated output

# Run program and convert tabulated output to nested lists
def run_for_table(args):
    result = run(args, check=True, stdout=PIPE, encoding='utf-8')
    return str_to_table(result.stdout)

# Run program and convert tabulated output to list of dictionaries with given column names as keys
def run_for_dicts(args, column_list):
    return table_to_dicts(run_for_table(args), column_list)

#
# Functions for converting multi-line tabulated strings to data structures

# Convert tabulated multi-line string to nested lists
def str_to_table(string, sep='\t'):
    return list(line.split(sep) for line in string.splitlines())

# Convert table to list of dictionaries with given column names as keys
def table_to_dicts(table, column_list):
    return list(row_to_dict(row, column_list) for row in table)

# Convert table row to dictionary with given column names as keys
def row_to_dict(row, column_list):
    return {column_list[i]: row[i] for i in range(len(row))}

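# A sketch of the conversion pipeline on a hypothetical two-line,
# tab-separated 'zfs list' output (values invented):
#
#   table = str_to_table('rpool\t123\nrpool/home\t456\n')
#   # -> [['rpool', '123'], ['rpool/home', '456']]
#   table_to_dicts(table, ['name', 'creation'])
#   # -> [{'name': 'rpool', 'creation': '123'},
#   #     {'name': 'rpool/home', 'creation': '456'}]
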
#
# ZFS functions

# Get list of snapshots (dataset dictionaries)
def zfs_get_snapshots():
    return zfs_get_datasets('snapshot')

# Get list of filesystems (dataset dictionaries)
def zfs_get_filesystems():
    return zfs_get_datasets('filesystem')

# Get list of datasets
def zfs_get_datasets(dataset_type='all'):
    global config
    return zfs_dicts_to_datasets(run_for_dicts(
        [config['zfs_path'],
         'list',
         '-Hp',
         '-t', dataset_type,
         '-o', ','.join(DATASET_COLS)], DATASET_COLS))

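# With the default columns above, the command built here is equivalent to
# (assuming config['zfs_path'] points at the zfs binary):
#
#   zfs list -Hp -t snapshot -o type,name,creation,mountpoint
#
# -H omits headers and separates columns with single tabs, and -p prints
# exact (machine-parsable) values, which is the format str_to_table expects.
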
# Transform list of ZFS dictionaries to list of datasets
def zfs_dicts_to_datasets(dicts):
    return list(zfs_dict_to_dataset(d) for d in dicts)

# Transform dictionary to dataset (pool, filesystem)
def zfs_dict_to_dataset(zfs_dict):
    name = zfs_dict['name']
    dataset = dict(zfs_dict)

    # Separate dataset and snapshot names out to extra fields
    if '@' in name:
        fields = name.split('@')
        dataset['dataset'] = fields[0]
        dataset['snapshot'] = fields[1]

    return dataset

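# For example (values invented), a snapshot row named
# 'rpool/home@frequent:5e4d2c1a' becomes:
#
#   {'type': 'snapshot', 'name': 'rpool/home@frequent:5e4d2c1a', ...,
#    'dataset': 'rpool/home', 'snapshot': 'frequent:5e4d2c1a'}
#
# while filesystem rows contain no '@' and pass through unchanged.
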
# Create one or more snapshots
async def zfs_create_snapshot(*snapshots, recursive=False):
    global config

    args = [config['zfs_path'], 'snapshot']
    if recursive:
        args.append('-r')

    for snapshot in snapshots:
        sargs = args + [get_snapshot_zfs_name(snapshot)]
        # Wait for each zfs invocation to finish before starting the next
        proc = await asyncio.create_subprocess_exec(*sargs)
        await proc.wait()

# Destroy one or more snapshots
async def zfs_destroy_snapshot(*snapshots, recursive=False):
    global config

    args = [config['zfs_path'], 'destroy']
    if recursive:
        args.append('-r')

    for snapshot in snapshots:
        sargs = args + [get_snapshot_zfs_name(snapshot)]
        # Wait for each zfs invocation to finish before starting the next
        proc = await asyncio.create_subprocess_exec(*sargs)
        await proc.wait()

# Generate ZFS identifier string for snapshot
def get_snapshot_zfs_name(snapshot):
    if 'tag' in snapshot:
        return make_snapshot_zfs_name(snapshot['dataset'], snapshot['tag'], snapshot.get('serial', None))
    elif 'snapshot' in snapshot:
        return make_snapshot_zfs_name(snapshot['dataset'], snapshot['snapshot'])
    else:
        raise KeyError('Snapshot has no name or tag')

# Generate ZFS identifier string from arguments
def make_snapshot_zfs_name(dataset, tag_or_snapshot, serial=None):
    if serial is None:
        return '{}@{}'.format(dataset, tag_or_snapshot)
    else:
        return '{}@{}:{}'.format(dataset, tag_or_snapshot, serial)

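# For example, make_snapshot_zfs_name('rpool/home', 'hourly', '5e4d2c1a')
# yields 'rpool/home@hourly:5e4d2c1a', and without a serial it yields
# 'rpool/home@hourly'. The ':' between tag and serial is presumably the
# same character as config['separator'], which get_snapshot_fields()
# below uses to split the name back apart.
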
#
# Configuration functions

# Retrieve all schedules and merge with default schedule
def get_schedules():
    global config
    schedules = ({**config['defaults'], **dict(s)} for s in config['schedules'])
    return schedules

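# Each schedule overrides the defaults key by key (values invented):
#
#   config['defaults']  = {'disabled': False, 'if_modified': False, 'keep': 4}
#   config['schedules'] = [{'tag': 'hourly', 'keep': 24}]
#   # get_schedules() -> [{'disabled': False, 'if_modified': False,
#   #                      'keep': 24, 'tag': 'hourly'}]
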
# Get dictionary of tag-modified flags on filesystem
def get_fs_flags(name):
    global fs_modified

    if name not in fs_modified:
        fs_modified[name] = dict()

    return fs_modified[name]

# Get tag-modified flag for specific tag on filesystem
def get_fs_flag(name, tag):
    flags = get_fs_flags(name)

    if tag not in flags:
        flags[tag] = False

    return flags[tag]

# Set specific tag-modified flag on filesystem
def set_fs_flag(name, tag):
    flags = get_fs_flags(name)
    flags[tag] = True

# Set all tag-modified flags on filesystem
def set_all_fs_flags(name):
    flags = get_fs_flags(name)
    for tag in flags.keys():
        set_fs_flag(name, tag)

# Clear specific tag-modified flag on filesystem
def clear_fs_flag(name, tag):
    flags = get_fs_flags(name)
    flags[tag] = False

#
# fswatch subprocess protocol for asyncio

class FSWatchProtocol(LineBufferedProtocol):
    def __init__(self, fs):
        LineBufferedProtocol.__init__(self, 'utf-8')
        self.fs = fs

    def pipe_line_received(self, line):
        global logger

        # Ignore empty lines and NOOPs
        if len(line) == 0 or int(line) == 0:
            return

        logger.info('Detected change on filesystem %s', self.fs['name'])

        # Set all tag-modified flags on filesystem
        set_all_fs_flags(self.fs['name'])

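# fswatch -o (started in main_task below) batches events and prints one
# line per batch containing the number of change events seen, so each
# line is a decimal count such as '3'; the guard above ignores empty
# lines and zero counts rather than trusting every line to be a change.
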
#
# Snapshot scheduling functions

# Create snapshot from a snapshot schedule
async def snapshot_creation_task(schedule, fs):
    global logger

    tag = schedule['tag']
    serial = make_snapshot_serial()
    recursive = schedule['recursive']

    if not schedule['if_modified'] or get_fs_flag(fs, tag):
        # Clear tag-modified flag for this tag on filesystem
        clear_fs_flag(fs, tag)

        logger.info('Taking snapshot of filesystem %s on schedule %s', fs, tag)

        # Create stub snapshot record and take the snapshot
        snapshot = dict(dataset=fs, tag=tag, serial=serial)
        await zfs_create_snapshot(snapshot, recursive=recursive)

# Generate time-based 8-character hexadecimal snapshot serial number
def make_snapshot_serial():
    return ('%x' % int(time.time()))[-8:]

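# For example, a Unix time of 1700000000 formats as '6553f100'. Serials
# are fixed-width lowercase hex until 2106, so sorting them as strings
# (as slice_snapshots does) matches chronological order in practice.
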
# Destroy all expired snapshots
async def snapshot_destruction_task():
    global config, logger

    snapshots = zfs_get_snapshots()

    for schedule in get_schedules():
        if schedule['disabled']:
            continue

        # Find expired snapshots for schedule
        tag = schedule['tag']
        expired = slice_snapshots(snapshots, tag, index=schedule['keep'], stop=None, reverse=True)

        if len(expired) > 0:
            logger.info('Destroying snapshots with tag %s:', tag)
            for snapshot in expired:
                logger.info('%s%s', config['tab_size'] * ' ', snapshot['name'])
                await zfs_destroy_snapshot(snapshot)

# Check if snapshot matches tag
def snapshot_matches_tag(snapshot, tag):
    return get_snapshot_tag(snapshot) == tag

# Get tag of snapshot
def get_snapshot_tag(snapshot):
    fields = get_snapshot_fields(snapshot)
    if len(fields) == 2:
        (tag, serial) = fields
        return tag
    else:
        return ''

# Get serial number of snapshot
def get_snapshot_serial(snapshot):
    fields = get_snapshot_fields(snapshot)
    if len(fields) == 2:
        (tag, serial) = fields
        return serial
    else:
        return ''

# Get tuple of fields in snapshot name
def get_snapshot_fields(snapshot):
    global config
    return tuple(snapshot['snapshot'].split(config['separator']))

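# Assuming config['separator'] is ':' (matching make_snapshot_zfs_name
# above), a record whose 'snapshot' field is 'hourly:5e4d2c1a' splits
# into ('hourly', '5e4d2c1a'), so get_snapshot_tag() returns 'hourly'
# and get_snapshot_serial() returns '5e4d2c1a'. Names without the
# separator yield a single field, and both getters return ''.
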
# Group 'snapshots' list using 'key' function, enumerate groups (in reverse if 'reverse' is
# True), slice by 'index' and 'stop', and return slice as flat list of snapshots
#
# If 'stop' is not specified, assume that 'index' is the index to stop at; otherwise, slice
# beginning at 'index' and ending at 'stop'
#
def slice_snapshots(snapshots, tag, index, stop=0, reverse=False, key=get_snapshot_serial):
    # Find matching snapshots
    matches = list(s for s in snapshots if snapshot_matches_tag(s, tag))

    # Make ordered set of serials present in matched snapshots
    ordered_set = list(sorted(set(key(s) for s in matches), reverse=reverse))

    # Slice n serials from ordered set of serials
    serials = ordered_set[:index] if stop == 0 else ordered_set[index:stop]

    # Intersect matching snapshots with sliced set of serials
    result = list(s for s in matches if get_snapshot_serial(s) in set(serials))

    return result

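# A worked example of the retention slice used by snapshot_destruction_task
# (serials invented): given snapshots tagged 'hourly' whose serials sort
# to ['a1', 'b2', 'c3', 'd4'] and a schedule with keep=2,
#
#   slice_snapshots(snapshots, 'hourly', index=2, stop=None, reverse=True)
#
# orders the serials newest first (['d4', 'c3', 'b2', 'a1']), skips the
# two to keep, and returns the snapshots with serials 'b2' and 'a1',
# which are the expired ones to destroy.
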
# Load and activate snapshot schedules
def load_snapshot_schedules():
    global config, scheduler, fs_modified, logger

    fs_modified = dict()

    for schedule in get_schedules():
        if schedule['disabled']:
            continue

        tag = schedule['tag']
        for fs in schedule['filesystems']:
            scheduler.add_job(
                snapshot_creation_task,
                trigger=schedule['trigger'],
                id=make_snapshot_zfs_name(fs, tag),
                group=fs,
                priority=schedule['priority'],
                args=[schedule, fs])

        # Set tag-modified flags on filesystems (always take snapshots on startup)
        for name in schedule['filesystems']:
            set_fs_flag(name, tag)

    scheduler.add_job(
        snapshot_destruction_task,
        trigger=config['destroy_trigger'],
        group='destroy')

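# A hypothetical schedule entry wiring this together (field names taken
# from their use above; 'trigger' is any APScheduler trigger, such as the
# IntervalTrigger or CronTrigger imported at the top):
#
#   {'tag': 'hourly',
#    'filesystems': ['rpool/home'],
#    'trigger': IntervalTrigger(hours=1),
#    'priority': 1,
#    'keep': 24,
#    'recursive': False,
#    'if_modified': True,
#    'disabled': False}
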
async def main_task():
    global config, event_loop, scheduler, fs_listeners

    scheduler = AsyncIOPriorityScheduler(
        event_loop=event_loop,
        executors={'default': AsyncIOPriorityExecutor()})

    # Determine whether any schedule needs modification tracking
    monitor = False
    for schedule in get_schedules():
        if schedule['if_modified']:
            monitor = True

    # Watch file system mountpoints
    if monitor:
        logger.info('Starting file system watcher')
        fs_listeners = dict()
        for fs in zfs_get_filesystems():
            await event_loop.subprocess_exec(
                lambda: FSWatchProtocol(fs), config['fswatch_path'], '-o', fs['mountpoint'],
                stdout=PIPE)

    load_snapshot_schedules()

    scheduler.start()

    if stdout.isatty():
        while True:
            print_spinner()
            await asyncio.sleep(1)

# Print idle spinner
def print_spinner():
    # '\x1B[G' returns the cursor to column 1 so the next frame overwrites this one
    print(print_spinner.chars[print_spinner.index] + '\x1B[G', end='', file=stderr, flush=True)
    print_spinner.index = (print_spinner.index + 1) % len(print_spinner.chars)

# Spinner state kept as function attributes
print_spinner.index = 0
print_spinner.chars = ['|', '/', '-', '\\']

def signal_handler(signame):
    global logger, event_loop
    logger.info('Received %s', signame)
    event_loop.stop()

#
# Program

config = load_config()
configure_logging(config)

logger.info('Processing jobs')

event_loop = asyncio.get_event_loop()
event_loop.add_signal_handler(signal.SIGINT, partial(signal_handler, 'SIGINT'))
event_loop.add_signal_handler(signal.SIGTERM, partial(signal_handler, 'SIGTERM'))

event_loop.create_task(main_task())

try:
    event_loop.run_forever()
finally:
    logger.info('Terminating')
    print(file=stderr)
    event_loop.close()