|
|
|
# NOTE(review): stray diff hunk marker from a corrupted merge — original context line: "from itertools import islice"
|
|
|
|
import asyncio |
|
|
|
|
from subprocess import run, PIPE |
|
|
|
|
from datetime import datetime, timezone, timedelta |
|
|
|
|
import logging |
|
|
|
|
|
|
|
|
|
from apscheduler.triggers.cron import CronTrigger |
|
|
|
|
from apscheduler.triggers.interval import IntervalTrigger |
|
|
|
# NOTE(review): stray diff hunk marker from a corrupted merge — original context line: "from apscheduler.triggers.interval import IntervalTrigger"
|
|
|
|
from zasd.apscheduler import * |
|
|
|
|
from zasd.asyncio import * |
|
|
|
|
from zasd.config import * |
|
|
|
|
from zasd.zfs import * |
|
|
|
|
from zasd.fs import * |
|
|
|
|
from zasd.log import * |
|
|
|
|
|
|
|
|
|
# |
|
|
|
|
# Constants |
|
|
|
|
class ZASD(): |
|
|
|
|
def __init__(self): |
|
|
|
|
self.event_loop = asyncio.get_event_loop() |
|
|
|
|
|
|
|
|
|
DATASET_COLS = ['type', 'name', 'creation', 'mountpoint'] |
|
|
|
|
self.event_loop.add_signal_handler(signal.SIGINT, |
|
|
|
|
partial(self.signal_handler, 'SIGINT')) |
|
|
|
|
self.event_loop.add_signal_handler(signal.SIGTERM, |
|
|
|
|
partial(self.signal_handler, 'SIGTERM')) |
|
|
|
|
|
|
|
|
|
# |
|
|
|
|
# Functions for running subprocesses with tabulated output |
|
|
|
|
global config |
|
|
|
|
config = load_config() |
|
|
|
|
configure_logging() |
|
|
|
|
|
|
|
|
|
# Run program and convert tabulated output to nested lists
def run_for_table(args):
    """Run *args* as a subprocess and parse its stdout as a tab-separated table."""
    return str_to_table(run(args, check=True, stdout=PIPE, encoding='utf-8').stdout)
|
|
|
|
self.zfs = ZFS(config['zfs_path']) |
|
|
|
|
|
|
|
|
|
# Run program and convert tabulated output to list of dictionaries with given column names as keys
def run_for_dicts(args, column_list):
    """Run *args* and return its tabulated stdout as dicts keyed by *column_list*."""
    table = run_for_table(args)
    return table_to_dicts(table, column_list)
|
|
|
|
log.info('Processing jobs') |
|
|
|
|
|
|
|
|
|
# |
|
|
|
|
# Functions for converting multi-line tabulated strings to data structures |
|
|
|
|
|
|
|
|
|
# Convert tabulated multi-line string to nested lists
def str_to_table(string, sep='\t'):
    """Split *string* into rows on newlines, then each row into fields on *sep*."""
    return [row.split(sep) for row in string.splitlines()]
|
|
|
|
|
|
|
|
|
# Convert table to list of dictionaries with given column names as keys
def table_to_dicts(table, column_list):
    """Convert every row of *table* into a dict keyed by *column_list*."""
    return [row_to_dict(row, column_list) for row in table]
|
|
|
|
|
|
|
|
|
# Convert table row to dictionary with given column names as keys
def row_to_dict(row, column_list):
    """Map the fields of *row* positionally onto the names in *column_list*."""
    return {column_list[i]: field for i, field in enumerate(row)}
|
|
|
|
|
|
|
|
|
# |
|
|
|
|
# ZFS functions |
|
|
|
|
|
|
|
|
|
# Get list of snapshots (dataset dictionaries)
def zfs_get_snapshots():
    """Return every snapshot as a dataset dictionary."""
    return zfs_get_datasets(dataset_type='snapshot')
|
|
|
|
|
|
|
|
|
# Get list of filesystems (dataset dictionaries)
def zfs_get_filesystems():
    """Return every filesystem as a dataset dictionary."""
    return zfs_get_datasets(dataset_type='filesystem')
|
|
|
|
|
|
|
|
|
# Get list of datasets
def zfs_get_datasets(dataset_type='all'):
    """List ZFS datasets of *dataset_type* via 'zfs list' and return them as dataset dicts.

    Uses -H (no header) and -p (parseable values) so the output is machine-readable.
    """
    global config

    args = [
        config['zfs_path'], 'list',
        '-Hp',
        '-t', dataset_type,
        '-o', ','.join(DATASET_COLS),
    ]
    return zfs_dicts_to_datasets(run_for_dicts(args, DATASET_COLS))
|
|
|
|
|
|
|
|
|
# Transform list of ZFS dictionaries to list of datasets
def zfs_dicts_to_datasets(dicts):
    """Convert each raw ZFS dictionary in *dicts* into a dataset dictionary."""
    return [zfs_dict_to_dataset(d) for d in dicts]
|
|
|
|
|
|
|
|
|
# Transform dictionary to dataset (pool, filesystem)
def zfs_dict_to_dataset(zfs_dict):
    """Return a copy of *zfs_dict*, splitting a snapshot 'name' into extra fields.

    Snapshot names look like '<dataset>@<snapshot>'; when an '@' is present the
    two halves are stored under the 'dataset' and 'snapshot' keys. The input
    dictionary is never mutated.
    """
    dataset = dict(zfs_dict)
    name = dataset['name']

    # Separate dataset and snapshot names out to extra fields
    if '@' in name:
        fs_part, snap_part = name.split('@')[:2]
        dataset['dataset'] = fs_part
        dataset['snapshot'] = snap_part

    return dataset
|
|
|
|
|
|
|
|
|
# Create one or more snapshots
async def zfs_create_snapshot(*snapshots, recursive=False):
    """Spawn one 'zfs snapshot' subprocess per snapshot record.

    With recursive=True the '-r' flag is passed so child datasets are
    snapshotted as well.
    """
    global config

    base = [config['zfs_path'], 'snapshot'] + (['-r'] if recursive else [])

    for snap in snapshots:
        await asyncio.create_subprocess_exec(*base, get_snapshot_zfs_name(snap))
|
|
|
|
|
|
|
|
|
# Destroy one or more snapshots
async def zfs_destroy_snapshot(*snapshots, recursive=False):
    """Spawn one 'zfs destroy' subprocess per snapshot record.

    With recursive=True the '-r' flag is passed so child snapshots are
    destroyed as well.
    """
    global config

    base = [config['zfs_path'], 'destroy'] + (['-r'] if recursive else [])

    for snap in snapshots:
        await asyncio.create_subprocess_exec(*base, get_snapshot_zfs_name(snap))
|
|
|
|
|
|
|
|
|
# Generate ZFS identifier string for snapshot
def get_snapshot_zfs_name(snapshot):
    """Build the 'dataset@tag[:serial]' ZFS name for a snapshot record.

    A record with a 'tag' key takes priority over one with a 'snapshot' key.
    Raises KeyError when the record carries neither.
    """
    if 'tag' in snapshot:
        return make_snapshot_zfs_name(snapshot['dataset'], snapshot['tag'],
                                      snapshot.get('serial'))
    if 'snapshot' in snapshot:
        return make_snapshot_zfs_name(snapshot['dataset'], snapshot['snapshot'])
    raise KeyError('Snapshot has no name or tag')
|
|
|
|
|
|
|
|
|
# Generate ZFS identifier string from arguments
def make_snapshot_zfs_name(dataset, tag_or_snapshot, serial=None):
    """Format 'dataset@tag' — or 'dataset@tag:serial' when a serial is given."""
    suffix = '' if serial is None else ':{}'.format(serial)
    return '{}@{}{}'.format(dataset, tag_or_snapshot, suffix)
|
|
|
|
# |
|
|
|
|
# Configuration functions |
|
|
|
|
|
|
|
|
|
# Retrieve all schedules and merge with default schedule
def get_schedules():
    """Yield each configured schedule with the configured defaults filled in.

    Returns a generator; each item is a fresh dict so callers may mutate it.
    """
    global config

    defaults = config['defaults']
    return ({**defaults, **dict(s)} for s in config['schedules'])
|
|
|
|
|
|
|
|
|
# Get dictionary of tag-modified flags on filesystem
def get_fs_flags(name):
    """Return the tag-modified flag dict for filesystem *name*, creating it if absent.

    The dict is stored in the module-level fs_modified registry, so mutations
    made by callers persist.
    """
    global fs_modified

    # setdefault replaces the original check-then-insert sequence (EAFP over LBYL)
    return fs_modified.setdefault(name, dict())
|
|
|
|
|
|
|
|
|
# Get tag-modified flag for specific tag on filesystem
def get_fs_flag(name, tag):
    """Return the modified flag for *tag* on filesystem *name*, seeding it as False.

    Like the original, this has the side effect of creating the flag entry
    (and the filesystem's flag dict) when missing.
    """
    # setdefault both seeds a missing entry with False and returns the current value
    return get_fs_flags(name).setdefault(tag, False)
|
|
|
|
|
|
|
|
|
# Set specific tag-modified flag on filesystem
def set_fs_flag(name, tag):
    """Mark *tag* on filesystem *name* as modified."""
    get_fs_flags(name)[tag] = True
|
|
|
|
|
|
|
|
|
# Set all tag-modified flags on filesystem
def set_all_fs_flags(name):
    """Mark every currently-known tag on filesystem *name* as modified."""
    flags = get_fs_flags(name)
    for tag in list(flags):
        flags[tag] = True
|
|
|
|
|
|
|
|
|
# Clear specific tag-modified flag on filesystem
def clear_fs_flag(name, tag):
    """Mark *tag* on filesystem *name* as unmodified."""
    get_fs_flags(name)[tag] = False
|
|
|
|
|
|
|
|
|
# |
|
|
|
|
# fswatch subprocess protocol for asyncio |
|
|
|
|
|
|
|
|
|
class FSWatchProtocol(LineBufferedProtocol):
    """Line-oriented protocol attached to an fswatch subprocess for one filesystem.

    fswatch appears to be launched with '-o' (see main_task), in which case each
    output line is a count of change events; any nonzero count marks every tag
    on the watched filesystem as modified.
    """

    def __init__(self, fs):
        # fs: filesystem dataset dictionary; only the 'name' key is read here
        LineBufferedProtocol.__init__(self, 'utf-8')
        self.fs = fs

    def pipe_line_received(self, line):
        # Invoked by LineBufferedProtocol with one decoded line of fswatch output
        global logger

        # Ignore empty lines and NOOPs
        # NOTE(review): int(line) assumes each line is a decimal event count —
        # consistent with fswatch's '-o' batch mode; confirm against fswatch docs
        if len(line) == 0 or int(line) == 0:
            return

        logger.info('Detected change on filesystem %s', self.fs['name'])

        # Set all tag-modified flags on filesystem
        set_all_fs_flags(self.fs['name'])
|
|
|
|
|
|
|
|
|
# |
|
|
|
|
# Snapshot scheduling functions |
|
|
|
|
|
|
|
|
|
# Create snapshot from a snapshot schedule
async def snapshot_creation_task(schedule, fs):
    """Take one snapshot of filesystem *fs* as prescribed by *schedule*.

    When the schedule sets 'if_modified', the snapshot is skipped unless the
    filesystem's tag-modified flag is set; the flag is cleared afterwards.
    """
    global logger

    tag = schedule['tag']

    # Skip unchanged filesystems for if_modified schedules (guard clause)
    if schedule['if_modified'] and not get_fs_flag(fs, tag):
        return

    # Clear tag-modified flags for this tag on filesystem
    clear_fs_flag(fs, tag)

    logger.info('Taking snapshot of filesystem %s on schedule %s', fs, tag)

    # Create stub snapshot record and take the snapshot
    snapshot = dict(dataset=fs, tag=tag, serial=make_snapshot_serial())
    await zfs_create_snapshot(snapshot, recursive=schedule['recursive'])
|
|
|
|
# Load and activate snapshot schedules |
|
|
|
|
self.scheduler = AsyncIOPriorityScheduler( |
|
|
|
|
event_loop = self.event_loop, |
|
|
|
|
executors = {'default': AsyncIOPriorityExecutor()}) |
|
|
|
|
self.load_schedules() |
|
|
|
|
self.scheduler.start() |
|
|
|
|
|
|
|
|
|
spinner = Spinner() |
|
|
|
|
self.event_loop.create_task(spinner.spin) |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
self.event_loop.run_forever() |
|
|
|
|
finally: |
|
|
|
|
log.info('Terminating') |
|
|
|
|
print(file=stderr) |
|
|
|
|
self.event_loop.close() |
|
|
|
|
|
|
|
|
|
def signal_handler(self, signame): |
|
|
|
|
log.info('Received %s', signame) |
|
|
|
|
self.event_loop.stop() |
|
|
|
|
|
|
|
|
|
def load_schedules(self): |
|
|
|
|
for schedule in self.schedules(): |
|
|
|
|
if schedule['disabled']: |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
tag = schedule['tag'] |
|
|
|
|
for fs in schedule['filesystems']: |
|
|
|
|
self.scheduler.add_job(lambda: self.snapshot_creation_task, |
|
|
|
|
trigger = schedule['trigger'], |
|
|
|
|
id = '{}:{}'.format(fs, tag), |
|
|
|
|
group = fs, |
|
|
|
|
priority = schedule['priority'], |
|
|
|
|
args = [schedule, fs]) |
|
|
|
|
|
|
|
|
|
# Generate time-based 8-character hexadecimal snapshot serial number
def make_snapshot_serial():
    """Return the low 8 hex digits of the current UNIX timestamp."""
    return format(int(time.time()), 'x')[-8:]
|
|
|
|
# Set tag-modified flags on filesystems (always take snapshots on startup) |
|
|
|
|
for name in schedule['filesystems']: |
|
|
|
|
filesystem = self.zfs.filesystems(name) |
|
|
|
|
filesystem.modified(tag, True) |
|
|
|
|
|
|
|
|
|
# Destroy all expired snapshots
async def snapshot_destruction_task():
    """For every enabled schedule, destroy snapshots beyond its 'keep' count.

    NOTE(review): this span was corrupted by an interleaved diff of a
    class-based rewrite; lines belonging to methods of that rewrite
    (self.scheduler.add_job, def schedules(self), the class creation task)
    have been removed, restoring the module-level function. The reconstructed
    body matches the duplicate fragments visible elsewhere in this file.
    """
    global config, logger

    snapshots = zfs_get_snapshots()

    for schedule in get_schedules():
        if schedule['disabled']:
            continue

        # Find expired snapshots for schedule: keep the newest 'keep' serials
        # (reverse=True sorts newest first), everything past them is expired
        tag = schedule['tag']
        expired = slice_snapshots(snapshots, tag, index=schedule['keep'], stop=None, reverse=True)

        if len(expired) > 0:
            # typo fix: original log message read 'snapsnots'
            logger.info('Destroying snapshots with tag %s:', tag)
            for snapshot in expired:
                logger.info('%s%s', config['tab_size'] * ' ', snapshot['name'])
                await zfs_destroy_snapshot(snapshot)
|
|
|
|
|
|
|
|
|
# Check if snapshot matches tag
def snapshot_matches_tag(snapshot, tag):
    """True when *snapshot* was created under schedule *tag*."""
    return tag == get_snapshot_tag(snapshot)
|
|
|
|
|
|
|
|
|
# Get tag of snapshot
def get_snapshot_tag(snapshot):
    """Return the schedule tag encoded in the snapshot's short name."""
    tag, _serial = get_snapshot_fields(snapshot)
    return tag
|
|
|
|
|
|
|
|
|
# Get serial number of snapshot
def get_snapshot_serial(snapshot):
    """Return the serial number encoded in the snapshot's short name."""
    _tag, serial = get_snapshot_fields(snapshot)
    return serial
|
|
|
|
|
|
|
|
|
# Get tuple of fields in snapshot name
def get_snapshot_fields(snapshot):
    """Split the snapshot's short name on the configured separator into a tuple."""
    global config

    short_name = snapshot['snapshot']
    return tuple(short_name.split(config['separator']))
|
|
|
|
|
|
|
|
|
# Group 'snapshots' list using 'key' function, enumerate groups (in reverse if 'reverse' is
# True), slice by 'index' and 'stop', and return slice as flat list of snapshots
#
# If 'stop' is not specified, assume that 'index' is the index to stop at; otherwise, slice
# beginning at 'index' and ending at 'stop'
#
def slice_snapshots(snapshots, tag, index, stop=0, reverse=False, key=get_snapshot_serial):
    """Select the snapshots for *tag* whose serial lies in the requested slice.

    NOTE(review): this span was corrupted by an interleaved diff; lines
    referencing 'schedule' and 'fs' (undefined in this scope) belonged to a
    different function and have been removed.
    """
    # Find matching snapshots
    matches = list(s for s in snapshots if snapshot_matches_tag(s, tag))

    # Make ordered set of serials present in matched snapshots
    ordered_set = list(sorted(set(key(s) for s in matches), reverse=reverse))

    # Slice n serials from ordered set of serials; stop=None means "to the end"
    serials = ordered_set[:index] if stop == 0 else ordered_set[index:stop]

    # Intersect matching snapshots with sliced set of serials
    result = list(s for s in matches if get_snapshot_serial(s) in set(serials))

    return result
|
|
|
|
# Create stub snapshot record and take the snapshot |
|
|
|
|
snapshot = dict(dataset=fs, tag=tag) |
|
|
|
|
await self.zfs.create_snapshot(snapshot, recursive=recursive) |
|
|
|
|
|
|
|
|
|
# Load and activate snapshot schedules
def load_snapshot_schedules():
    """Register one creation job per (schedule, filesystem) plus one destruction job.

    NOTE(review): this span was corrupted by an interleaved diff of a
    class-based rewrite (a 'self' destruction task and duplicate loops were
    spliced into the body); the module-level function is reconstructed here
    from the duplicate fragments visible in this file.
    """
    global config, scheduler, fs_modified, logger

    # Registry of per-filesystem tag-modified flag dicts used by get_fs_flags()
    fs_modified = dict()

    for schedule in get_schedules():
        if schedule['disabled']:
            continue

        tag = schedule['tag']
        for fs in schedule['filesystems']:
            scheduler.add_job(snapshot_creation_task,
                trigger = schedule['trigger'],
                id = make_snapshot_zfs_name(fs, tag),
                group = fs,
                priority = schedule['priority'],
                args = [schedule, fs])

        # Set tag-modified flags on filesystems (always take snapshots on startup)
        for name in schedule['filesystems']:
            set_fs_flag(name, tag)

    scheduler.add_job(snapshot_destruction_task,
        trigger = config['destroy_trigger'],
        group = 'destroy')
|
|
|
|
|
|
|
|
|
async def main_task():
    """Set up the scheduler, spawn fswatch watchers, and activate all schedules."""
    global config, event_loop, scheduler, fs_listeners

    # Priority scheduler driving snapshot jobs on the asyncio event loop
    scheduler = AsyncIOPriorityScheduler(
        event_loop = event_loop,
        executors = {'default': AsyncIOPriorityExecutor()})

    # Watch file system mountpoints: one fswatch subprocess per filesystem,
    # '-o' batch mode (one event-count line per change burst)
    # NOTE(review): fs_listeners is created but never populated here — verify
    # whether the transports were meant to be stored in it
    fs_listeners = dict()
    for fs in zfs_get_filesystems():
        # the lambda closes over the loop variable, but subprocess_exec calls
        # the factory before the next iteration rebinds fs, so each watcher
        # gets its own filesystem
        await event_loop.subprocess_exec(
            lambda: FSWatchProtocol(fs), config['fswatch_path'], '-o', fs['mountpoint'], stdout=PIPE)

    load_snapshot_schedules()

    scheduler.start()
|
|
|
|
# Class for printing idle spinner
class Spinner():
    """Draws a rotating |/-\\ character to stderr, one frame per second.

    NOTE(review): this span was corrupted by an interleaved diff; the stray
    'if stdout.isatty():' and 'print_spinner()' lines belonged elsewhere and
    have been removed. spin() is declared async because its body awaits and
    it is scheduled with event_loop.create_task elsewhere in this file.
    """

    CHARS = ['|', '/', '-', '\\']
    counter = 0  # index of the next frame to draw

    async def spin(self):
        """Run forever, drawing the next spinner frame each second."""
        while True:
            # '\x1B[G' moves the cursor back to column 1 so frames overwrite
            print(self.CHARS[self.counter] + '\x1B[G', end='', file=stderr, flush=True)
            self.counter = (self.counter + 1) % len(self.CHARS)
            await asyncio.sleep(1)
|
|
|
|
|
|
|
|
|
# Print idle spinner
def print_spinner():
    """Draw the next spinner frame to stderr and advance the frame index."""
    frame = print_spinner.chars[print_spinner.index]
    # '\x1B[G' moves the cursor back to column 1 so frames overwrite each other
    print(frame + '\x1B[G', end='', file=stderr, flush=True)
    print_spinner.index = (print_spinner.index + 1) % len(print_spinner.chars)

# Spinner state carried as function attributes
print_spinner.index = 0
print_spinner.chars = ['|', '/', '-', '\\']
|
|
|
|
|
|
|
|
|
def signal_handler(signame):
    """Log receipt of *signame* and stop the global event loop."""
    global logger, event_loop

    logger.info('Received %s', signame)
    event_loop.stop()
|
|
|
|
|
|
|
|
|
#
# Program

# Load configuration and wire up logging before anything else runs
config = load_config()
configure_logging(config)

logger.info('Processing jobs')

# One asyncio event loop drives the scheduler and the fswatch subprocesses;
# SIGINT/SIGTERM stop the loop via signal_handler
event_loop = asyncio.get_event_loop()
event_loop.add_signal_handler(signal.SIGINT, partial(signal_handler, 'SIGINT'))
event_loop.add_signal_handler(signal.SIGTERM, partial(signal_handler, 'SIGTERM'))

event_loop.create_task(main_task())

try:
    event_loop.run_forever()
finally:
    # Reached after signal_handler stops the loop
    logger.info('Terminating')
    # Emit a newline on stderr so the spinner's partial line is terminated
    print(file=stderr)
    event_loop.close()
|
|
|
|