DB migrations aren't hard

i like north. it's incredibly simple and it does the job of managing database migrations in a straightforward way

recently, because making unhinged technical decisions is my hobby, i decided to make a database migration tool in pure zero-dependency (technically) python in order to optimize for ease of "production" deployment. the migration tool should Just Work in any environment without needing a complicated setup

so let's steal north syntax. it's pretty simple: each migration is a plain SQL file with metadata included in the comments. this is a really basic form of DSL that allows you to use the existing SQL syntax in your favorite editor

here's an example

-- @revision: e864e1e88bbcf9f0a78554faa1563314
-- @parent: b71f6d5426b85d56aaed8cdee8a1ab64
-- @description: creates a table
-- @up {
CREATE TABLE my_table();
-- }

-- @down {
DROP TABLE my_table;
-- }

the revision line specifies a unique ID for this revision, and the parent line identifies its parent revision. this is one of north's design choices that i like, because rather than relying on filename lexicographic order to apply revisions correctly, it creates an ordered migration list using only the content of the files, allowing them to be named arbitrarily. also, both the up and down scripts are contained in the same file, which makes it more convenient to work with imo

[leverage music plays]

import dataclasses
import pathlib
import subprocess
from typing import List, Optional

@dataclasses.dataclass
class Migration:
    filename: str
    revision: str
    description: str
    parent: Optional[str]
    up: str
    down: str

pretty shrimple

def parse_migration(filename: str, script: str) -> Migration:
    migration = Migration(filename, "", "", None, "", "")

and now it's time for a small parsing state machine, with states being "default" (reading metadata), "up" (reading the up part of the migration), and "down" (reading the down part of the migration)

    state = "default"
    for line in script.splitlines():
        if state == "default":
            ls = line.strip()
            if ls.startswith("#") or len(ls) == 0:
                continue

for good measure, this ignores the #lang line in north files, allowing you to use your existing north migrations

            elif ls.startswith("-- @revision: "):
                migration.revision = ls[14:]
            elif ls.startswith("-- @parent: "):
                migration.parent = ls[12:]
            elif ls.startswith("-- @description: "):
                migration.description = ls[17:]
            elif ls.startswith("-- @up {"):
                state = "up"
            elif ls.startswith("-- @down {"):
                state = "down"
            else:
                raise Exception("bad script line:", ls)
        elif state == "up":
            if line.strip().startswith("-- }"):
                state = "default"
            else:
                migration.up += line + "\n"
        elif state == "down":
            if line.strip().startswith("-- }"):
                state = "default"
            else:
                migration.down += line + "\n"

    if len(migration.revision) == 0:
        raise Exception("revision line is invalid")

    return migration
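side note: the magic slice offsets in the parser (14, 12, 17) are just the lengths of the metadata prefixes, which is worth sanity-checking since they're easy to get wrong:

```python
# the slice offsets in parse_migration are the metadata prefix lengths
for prefix, offset in [("-- @revision: ", 14),
                       ("-- @parent: ", 12),
                       ("-- @description: ", 17)]:
    assert len(prefix) == offset, (prefix, offset)

# slicing by len(prefix) instead of a literal avoids the magic number entirely
line = "-- @revision: e864e1e88bbcf9f0a78554faa1563314"
revision = line[len("-- @revision: "):]
print(revision)  # e864e1e88bbcf9f0a78554faa1563314
```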

cool now we can read migration files

let's make a function that reads all the migrations from a folder and returns them in order

def read_migrations(source: pathlib.Path) -> List[Migration]:
    migrations = []
    for file in source.iterdir():
        if file.name.endswith(".sql"):
            with file.open("r") as f:
                script = f.read()
            migrations.append(parse_migration(file.name, script))

ok now it's time for some turbo spaghetti code. ignore the poor quality here ;____;

    sorted_migrations = []

    for m in migrations:
        if m.parent is None:
            sorted_migrations.append(m)
            break

    if len(sorted_migrations) == 0:
        return []

    while True:
        for m in migrations:
            if m.parent == sorted_migrations[-1].revision:
                sorted_migrations.append(m)
                break
        else:
            break

    return sorted_migrations
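to see the chain-walk in isolation, here's the same idea over plain (revision, parent) tuples (made-up revision names, just for illustration):

```python
# same parent-chain ordering over plain (revision, parent) tuples
migrations = [("c", "b"), ("a", None), ("b", "a")]  # deliberately shuffled

# find the root: the migration with no parent
ordered = [m for m in migrations if m[1] is None][:1]

# repeatedly append whichever migration names the last one as its parent
while ordered:
    for m in migrations:
        if m[1] == ordered[-1][0]:
            ordered.append(m)
            break
    else:
        break  # no child of the last revision -> chain complete

print([rev for rev, _ in ordered])  # ['a', 'b', 'c']
```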

OK but what is the job of this program, fundamentally

the key thing this tool needs to do is

a. keep track of the current migration
b. run available migrations starting from the current migration

in order to accomplish the first thing we can create a special table in the database itself to store the revision. we do this on program startup

def south_init() -> None:
    run_commands("""
                 CREATE TABLE IF NOT EXISTS south_schema_version(current_revision TEXT NOT NULL)
                 WITH (fillfactor=10);
                 """)

oh yeah we also only care about postgres. having support for multiple DBMS is neat but personally i don't care, and we're intentionally limiting scope to keep this as simple as possible

(the fillfactor=10 part instructs postgres to optimize this table for frequent UPDATEs. in fact we only plan to have one row in here ever, which just stores the current revision)

now we just need to define run_commands as a function to run postgres commands

...

so you know how my hobby is making poor technical decisions?

turns out there actually isn't a particular need to include a postgres library here. we can shrimply use the psql command to run the commands. this will work as long as the environment contains the right variables PGHOST and optionally PGUSER, PGDATABASE etc. this makes the script technically zero-dependency, or at least we get the psql command for free with any postgres install

def run_commands(script: str) -> str:
    proc = subprocess.Popen(["psql"], stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    (stdout, stderr) = proc.communicate(script.encode())
    if proc.returncode != 0 or b"ERROR" in stderr:
        raise Exception("failed to execute psql:", proc.returncode, stdout, stderr)
    return stdout.decode()

the or b"ERROR" in stderr part is kind of nasty because sometimes psql will exit with a returncode of zero even though an error occurred, and that's the only way to tell. just make sure you don't name anything in your migrations "ERROR" i guess
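as a contrived illustration of that false positive (psql writes NOTICE messages to stderr too, and the table name here is hypothetical):

```python
# a NOTICE mentioning an unluckily-named object trips the substring check
stderr = b'NOTICE:  relation "ERROR_LOG" already exists, skipping\n'
print(b"ERROR" in stderr)  # True, even though nothing actually failed
```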

ok now we need a way to fetch the current revision. this requires receiving output from psql, which we can do with the --csv flag that makes the output CSV and thus more machine-readable. here's the update to run_commands

def run_commands(script: str) -> str:
    proc = subprocess.Popen(["psql", "--csv"], stdin=subprocess.PIPE,
...

and now we can make the function

def south_get_migration() -> Optional[str]:
    migration = run_commands("SELECT current_revision FROM south_schema_version").splitlines()
    if len(migration) < 2:
        return None
    return migration[1].strip()

row 0 is the column header, so we get the contents of row 1, or if there is no such row (indicating the database is currently unmigrated), we return None
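concretely, here's what the --csv output looks like in both cases (revision value taken from the example migration above):

```python
# psql --csv output: a header row, then at most one data row
migrated = "current_revision\ne864e1e88bbcf9f0a78554faa1563314\n"
rows = migrated.splitlines()
print(rows[1].strip())  # the stored revision

unmigrated = "current_revision\n"  # table exists but has no rows yet
print(len(unmigrated.splitlines()) < 2)  # True -> return None
```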

rest of the fucking owl

part b of the tool's job is slightly harder, but only slightly. the key insight is you can bundle each migration together with the update of the current_revision value using a transaction. that way, either the full migration and metadata update works or it all fails, and the DB is never left in an inconsistent state

first, a helper function

def south_cur_migration_idx(migrations: List[Migration], cur: Optional[str]) -> int:
    idx = 0
    if cur is not None:
        for i, m in enumerate(migrations):
            if cur == m.revision:
                idx = i + 1
                break
        else:
            raise Exception("current migration not in revisions")
    return idx
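the idx semantics in a nutshell, sketched over plain revision strings (next_idx here is a simplified stand-in for illustration, not the real function):

```python
# index of the first migration that still needs applying
revisions = ["r1", "r2", "r3"]

def next_idx(revisions, cur):
    if cur is None:
        return 0                      # unmigrated database: start at the beginning
    return revisions.index(cur) + 1   # otherwise resume just after cur

print(next_idx(revisions, None))   # 0
print(next_idx(revisions, "r2"))   # 2
print(next_idx(revisions, "r3"))   # 3 -> nothing left to apply
```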

and then

def south_migrate(migrations: List[Migration]) -> None:
    idx = south_cur_migration_idx(migrations, south_get_migration())
    to_apply = migrations[idx:]
    for m in to_apply:
        print("applying migration", m.revision, m.description)

and now we set up a transaction bundling the revision update with the migration contents

        transaction = "BEGIN;\n" + m.up + "DELETE FROM south_schema_version;\n"
        transaction += "INSERT INTO south_schema_version VALUES ('" + m.revision + "');\n"
        transaction += "COMMIT;\n"
        run_commands(transaction)

ok hold up

        transaction += "INSERT INTO south_schema_version VALUES ('" + m.revision + "');\n"

this line from above looks stinky

technically a revision Should not contain a SQL injection, and since the entire purpose of this tool is basically executing arbitrary SQL, it doesn't really matter much

but we can fix this

so here's the really interesting piece of insight: psql supports a form of prepared statements

you can bind variables on the command line using -v name=value, and then use them with :'name' in the SQL

with that knowledge let's update run_commands one more time

def run_commands(script: str, *args: str) -> str:
    extra_args = []
    for (i, arg) in enumerate(args):
        extra_args.extend(["-v", f"v{i}={arg}"])
    proc = subprocess.Popen(["psql", "--csv", *extra_args], stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    (stdout, stderr) = proc.communicate(script.encode())
    if proc.returncode != 0 or b"ERROR" in stderr:
        raise Exception("failed to execute psql:", proc.returncode, stdout, stderr)
    return stdout.decode()

and the migrate function

def south_migrate(migrations: List[Migration]) -> None:
    idx = south_cur_migration_idx(migrations, south_get_migration())
    to_apply = migrations[idx:]
    for m in to_apply:
        print("applying migration", m.revision, m.description)
        transaction = "BEGIN;\n" + m.up + "DELETE FROM south_schema_version;\n"
        transaction += "INSERT INTO south_schema_version VALUES (:'v0');\n"
        transaction += "COMMIT;\n"
        run_commands(transaction, m.revision)
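for the example migration from the top of the post, the text piped into psql ends up looking like this (reassembled here just for illustration):

```python
# what gets piped to psql for one migration (values from the example file)
up = "CREATE TABLE my_table();\n"
revision = "e864e1e88bbcf9f0a78554faa1563314"  # passed separately via -v v0=...

transaction = "BEGIN;\n" + up + "DELETE FROM south_schema_version;\n"
transaction += "INSERT INTO south_schema_version VALUES (:'v0');\n"
transaction += "COMMIT;\n"
print(transaction)
```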

that's it (mostly) !

then i added a similar rollback function and made a basic argparse CLI and utility functions to generate template migrations parented by the latest migration file revision ID and that's basically it

this is a functional database migrations tool that technically has zero dependencies and runs wherever python 3 (and postgres, obviously) runs, with the full script being about 200 lines / 160 SLOC

yo where's the full source file

it's part of a currently-private project. 70% of the code is on this page and if you really want to be making poor technical decisions like i do, you should be able to fill in the rest and get a usable tool for yourself

conclusion

i have NIH syndrome

COMMIT;