ZFS Automatic Snapshot Daemon
from sys import stdout, stderr
import signal
import time
from functools import partial, reduce
from itertools import islice
import asyncio
from subprocess import run, PIPE
from datetime import datetime, timezone, timedelta
import logging
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from zasd.apscheduler import *
from zasd.asyncio import *
from zasd.config import *
#
# Constants
DATASET_COLS = ['type', 'name', 'creation', 'mountpoint']
#
# Functions for running subprocesses with tabulated output
# Run program and convert tabulated output to nested lists
def run_for_table(args):
    result = run(args, check=True, stdout=PIPE, encoding='utf-8')
    return str_to_table(result.stdout)

# Run program and convert tabulated output to list of dictionaries with given column names as keys
def run_for_dicts(args, column_list):
    return table_to_dicts(run_for_table(args), column_list)

#
# Functions for converting multi-line tabulated strings to data structures

# Convert tabulated multi-line string to nested lists
def str_to_table(string, sep='\t'):
    return list(line.split(sep) for line in string.splitlines())

# Convert table to list of dictionaries with given column names as keys
def table_to_dicts(table, column_list):
    return list(row_to_dict(row, column_list) for row in table)

# Convert table row to dictionary with given column names as keys
def row_to_dict(row, column_list):
    return {column_list[i]: row[i] for i in range(len(row))}
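
# Illustrative example (field values assumed, not taken from a real pool): a raw
# 'zfs list -Hp -o type,name,creation,mountpoint' output line such as
#     'filesystem\ttank/home\t1700000000\t/home'
# becomes ['filesystem', 'tank/home', '1700000000', '/home'] via str_to_table(),
# and table_to_dicts(..., DATASET_COLS) then yields
#     {'type': 'filesystem', 'name': 'tank/home', 'creation': '1700000000', 'mountpoint': '/home'}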
#
# ZFS functions
# Get list of snapshots (dataset dictionaries)
def zfs_get_snapshots():
    return zfs_get_datasets('snapshot')

# Get list of filesystems (dataset dictionaries)
def zfs_get_filesystems():
    return zfs_get_datasets('filesystem')

# Get list of datasets
def zfs_get_datasets(dataset_type='all'):
    global config
    return zfs_dicts_to_datasets(run_for_dicts(
        [config['zfs_path'],
         'list',
         '-Hp',
         '-t', dataset_type,
         '-o', ','.join(DATASET_COLS)], DATASET_COLS))
# Transform list of ZFS dictionaries to list of datasets
def zfs_dicts_to_datasets(dicts):
    return list(zfs_dict_to_dataset(d) for d in dicts)

# Transform dictionary to dataset (pool, filesystem or snapshot)
def zfs_dict_to_dataset(zfs_dict):
    name = zfs_dict['name']
    dataset = dict(zfs_dict)

    # Separate dataset and snapshot names out to extra fields
    if '@' in name:
        fields = name.split('@')
        dataset['dataset'] = fields[0]
        dataset['snapshot'] = fields[1]

    return dataset
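
# Illustrative example (names assumed): a snapshot listed by ZFS as
# 'tank/home@hourly:6553f100' keeps its full 'name' and additionally gets
# 'dataset': 'tank/home' and 'snapshot': 'hourly:6553f100'; filesystems and
# pools pass through unchanged.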
# Create one or more snapshots
async def zfs_create_snapshot(*snapshots, recursive=False):
    global config
    args = [config['zfs_path'], 'snapshot']
    if recursive:
        args.append('-r')

    for snapshot in snapshots:
        sargs = args + [get_snapshot_zfs_name(snapshot)]
        await asyncio.create_subprocess_exec(*sargs)

# Destroy one or more snapshots
async def zfs_destroy_snapshot(*snapshots, recursive=False):
    global config
    args = [config['zfs_path'], 'destroy']
    if recursive:
        args.append('-r')

    for snapshot in snapshots:
        sargs = args + [get_snapshot_zfs_name(snapshot)]
        await asyncio.create_subprocess_exec(*sargs)
# Generate ZFS identifier string for snapshot
def get_snapshot_zfs_name(snapshot):
    if 'tag' in snapshot:
        return make_snapshot_zfs_name(snapshot['dataset'], snapshot['tag'], snapshot.get('serial', None))
    elif 'snapshot' in snapshot:
        return make_snapshot_zfs_name(snapshot['dataset'], snapshot['snapshot'])
    else:
        raise KeyError('Snapshot has no name or tag')

# Generate ZFS identifier string from arguments
def make_snapshot_zfs_name(dataset, tag_or_snapshot, serial=None):
    if serial is None:
        return '{}@{}'.format(dataset, tag_or_snapshot)
    else:
        return '{}@{}:{}'.format(dataset, tag_or_snapshot, serial)
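
# Illustrative example (dataset and tag names assumed):
#     make_snapshot_zfs_name('tank/home', 'hourly')             -> 'tank/home@hourly'
#     make_snapshot_zfs_name('tank/home', 'hourly', '6553f100') -> 'tank/home@hourly:6553f100'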
#
# Configuration functions
# Retrieve all schedules and merge with default schedule
def get_schedules():
    global config
    schedules = ({**config['defaults'], **dict(s)} for s in config['schedules'])
    return schedules

# Get dictionary of tag-modified flags on filesystem
def get_fs_flags(name):
    global fs_modified
    if name not in fs_modified:
        fs_modified[name] = dict()
    return fs_modified[name]
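
# fs_modified maps each filesystem name to its per-tag flags, e.g. (values
# assumed for illustration): {'tank/home': {'hourly': True, 'daily': False}}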
# Get tag-modified flag for specific tag on filesystem
def get_fs_flag(name, tag):
    flags = get_fs_flags(name)
    if tag not in flags:
        flags[tag] = False
    return flags[tag]

# Set specific tag-modified flag on filesystem
def set_fs_flag(name, tag):
    flags = get_fs_flags(name)
    flags[tag] = True

# Set all tag-modified flags on filesystem
def set_all_fs_flags(name):
    flags = get_fs_flags(name)
    for tag in flags.keys():
        set_fs_flag(name, tag)

# Clear specific tag-modified flag on filesystem
def clear_fs_flag(name, tag):
    flags = get_fs_flags(name)
    flags[tag] = False
#
# fswatch subprocess protocol for asyncio
class FSWatchProtocol(LineBufferedProtocol):
    def __init__(self, fs):
        LineBufferedProtocol.__init__(self, 'utf-8')
        self.fs = fs

    def pipe_line_received(self, line):
        global logger

        # Ignore empty lines and NOOPs
        if len(line) == 0 or int(line) == 0:
            return

        logger.info('Detected change on filesystem %s', self.fs['name'])

        # Set all tag-modified flags on filesystem
        set_all_fs_flags(self.fs['name'])
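
# Note: fswatch is started with '-o' (see main_task() below), so each line it
# writes is a batch event count; a count of zero carries no change
# information, hence the NOOP check above.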
#
# Snapshot scheduling functions
# Create snapshot from a snapshot schedule
async def snapshot_creation_task(schedule, fs):
    global logger

    tag = schedule['tag']
    serial = make_snapshot_serial()
    recursive = schedule['recursive']

    if get_fs_flag(fs, tag):
        # Clear tag-modified flags for this tag on filesystem
        clear_fs_flag(fs, tag)

        logger.info('Taking snapshot of filesystem %s on schedule %s', fs, tag)

        # Create stub snapshot record and take the snapshot
        snapshot = dict(dataset=fs, tag=tag, serial=serial)
        await zfs_create_snapshot(snapshot, recursive=recursive)

# Generate time-based 8-character hexadecimal snapshot serial number
def make_snapshot_serial():
    return ('%x' % int(time.time()))[-8:]
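
# Illustrative example (timestamp assumed): at Unix time 1700000000,
# '%x' % 1700000000 == '6553f100', so the serial would be '6553f100'.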
# Destroy all expired snapshots
async def snapshot_destruction_task():
    global config, logger

    snapshots = zfs_get_snapshots()

    for schedule in get_schedules():
        if schedule['disabled']:
            continue

        # Find expired snapshots for schedule
        tag = schedule['tag']
        expired = slice_snapshots(snapshots, tag, index=schedule['keep'], stop=None, reverse=True)

        if len(expired) > 0:
            logger.info('Destroying snapshots with tag %s:', tag)
            for snapshot in expired:
                logger.info('%s%s', config['tab_size'] * ' ', snapshot['name'])
                await zfs_destroy_snapshot(snapshot)
# Check if snapshot matches tag
def snapshot_matches_tag(snapshot, tag):
    return get_snapshot_tag(snapshot) == tag

# Get tag of snapshot
def get_snapshot_tag(snapshot):
    return get_snapshot_fields(snapshot)[0]

# Get serial number of snapshot (None if the name carries no serial)
def get_snapshot_serial(snapshot):
    fields = get_snapshot_fields(snapshot)
    return fields[1] if len(fields) > 1 else None

# Get tuple of fields in snapshot name; snapshots not created by this daemon
# may contain no separator and yield a single-element tuple
def get_snapshot_fields(snapshot):
    global config
    return tuple(snapshot['snapshot'].split(config['separator']))
# Group 'snapshots' list using 'key' function, enumerate groups (in reverse if 'reverse' is
# True), slice by 'index' and 'stop', and return slice as flat list of snapshots
#
# If 'stop' is not specified, assume that 'index' is the index to stop at; otherwise, slice
# beginning at 'index' and ending at 'stop'
#
def slice_snapshots(snapshots, tag, index, stop=0, reverse=False, key=get_snapshot_serial):
    # Find matching snapshots
    matches = list(s for s in snapshots if snapshot_matches_tag(s, tag))

    # Make ordered set of serials present in matched snapshots
    ordered_set = list(sorted(set(key(s) for s in matches), reverse=reverse))

    # Slice n serials from ordered set of serials
    serials = ordered_set[:index] if stop == 0 else ordered_set[index:stop]

    # Intersect matching snapshots with sliced set of serials
    result = list(s for s in matches if get_snapshot_serial(s) in set(serials))

    return result
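
# Illustrative example (serials assumed): given snapshots with serials
# ['a1', 'a2', 'a3', 'a4', 'a5'] for a tag,
#     slice_snapshots(snapshots, tag, index=3, stop=None, reverse=True)
# (the call made by snapshot_destruction_task() with keep=3) orders the serials
# newest first, skips the three newest and returns the snapshots carrying
# ['a2', 'a1'], i.e. those eligible for destruction.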
# Load and activate snapshot schedules
def load_snapshot_schedules():
    global config, scheduler, fs_modified, logger

    fs_modified = dict()

    for schedule in get_schedules():
        if schedule['disabled']:
            continue

        tag = schedule['tag']

        for fs in schedule['filesystems']:
            scheduler.add_job(snapshot_creation_task,
                trigger  = schedule['trigger'],
                id       = make_snapshot_zfs_name(fs, tag),
                group    = fs,
                priority = schedule['priority'],
                args     = [schedule, fs])

        # Set tag-modified flags on filesystems (always take snapshots on startup)
        for name in schedule['filesystems']:
            set_fs_flag(name, tag)

    scheduler.add_job(snapshot_destruction_task,
        trigger = config['destroy_trigger'],
        group   = 'destroy')
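
# Note: as consumed above and in snapshot_creation_task(), each merged schedule
# is expected to carry at least 'tag', 'filesystems', 'trigger', 'priority',
# 'keep', 'recursive' and 'disabled'; missing keys are filled in from
# config['defaults'] by get_schedules().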
# Main daemon task: set up the scheduler, start watching filesystem
# mountpoints with fswatch, load the snapshot schedules and idle
async def main_task():
    global config, event_loop, scheduler, fs_listeners

    scheduler = AsyncIOPriorityScheduler(
        event_loop = event_loop,
        executors  = {'default': AsyncIOPriorityExecutor()})

    # Watch file system mountpoints
    fs_listeners = dict()
    for fs in zfs_get_filesystems():
        # Bind fs by value in the factory so each protocol gets its own filesystem record
        await event_loop.subprocess_exec(
            lambda fs=fs: FSWatchProtocol(fs), config['fswatch_path'], '-o', fs['mountpoint'],
            stdout=PIPE)

    load_snapshot_schedules()
    scheduler.start()

    if stdout.isatty():
        while True:
            print_spinner()
            await asyncio.sleep(1)
# Print idle spinner
def print_spinner():
    # '\x1B[G' returns the cursor to column 1 so the next frame overwrites this one
    print(print_spinner.chars[print_spinner.index] + '\x1B[G', end='', file=stderr, flush=True)
    print_spinner.index = (print_spinner.index + 1) % len(print_spinner.chars)

print_spinner.index = 0
print_spinner.chars = ['|', '/', '-', '\\']

def signal_handler(signame):
    global logger, event_loop

    logger.info('Received %s', signame)
    event_loop.stop()
#
# Program
config = load_config()
configure_logging(config)
logger.info('Processing jobs')
event_loop = asyncio.get_event_loop()
event_loop.add_signal_handler(signal.SIGINT, partial(signal_handler, 'SIGINT'))
event_loop.add_signal_handler(signal.SIGTERM, partial(signal_handler, 'SIGTERM'))
event_loop.create_task(main_task())
try:
    event_loop.run_forever()
finally:
    logger.info('Terminating')
    print(file=stderr)
    event_loop.close()