diff --git a/src/scry/__init__.py b/src/scry/__init__.py index 804c63b..2489ee1 100644 --- a/src/scry/__init__.py +++ b/src/scry/__init__.py @@ -3,7 +3,6 @@ "get_random_card", "get_card_list", "insert_cards", - "db_connect", "db_stats", "get_total_cards", "clear_database", @@ -11,7 +10,7 @@ "transform_card", ] -from .db_setup import create_table, db_connect, clear_database +from .db_setup import create_table, clear_database from .db_insert import insert_cards, transform_card from .request import get_random_card, get_card_list, set_codes from .db_queries import db_stats, get_total_cards diff --git a/src/scry/cli.py b/src/scry/cli.py index 2524c90..2f09442 100644 --- a/src/scry/cli.py +++ b/src/scry/cli.py @@ -209,7 +209,7 @@ def handle_stats(args, db_connection): def handle_clear(args, db_connection): # HACK: why does this only work with db_connection and args, even though neither # are required? Same with handle_setlist... - clear_database() + clear_database(db_connection) # Helper functions: diff --git a/src/scry/db_setup.py b/src/scry/db_setup.py index 65bc182..4479128 100644 --- a/src/scry/db_setup.py +++ b/src/scry/db_setup.py @@ -1,15 +1,4 @@ # create SQLITE database and tables, insert into table -import sqlite3 -from pathlib import Path -from os import makedirs - -# set path (if required) for db: scry/data/cards.db -path_to_root = Path(__file__).parent.parent.parent # project root scry/ -makedirs(path_to_root / "data/", exist_ok=True) - - -def db_connect(): - return sqlite3.connect(Path(path_to_root / "data" / "cards.db")) def create_table(connection): @@ -35,17 +24,19 @@ def create_table(connection): connection.commit() -def clear_database() -> None: +def drop_table(connection) -> None: + try: + cursor = connection.cursor() + cursor.execute("DROP TABLE IF EXISTS cards;") + + except Exception as err: + print(f"Error clearing database: {err}") + + +def clear_database(connection) -> None: check = input("This will delete your database, are you sure? (y/N): ") if check.lower() == "y": - try: - connection = db_connect() - cursor = connection.cursor() - cursor.execute("DROP TABLE IF EXISTS cards") - cursor.close() - print("Database exiled to graveyard and removed from game...") - - except Exception as err: - print(f"Error clearing database: {err}") + drop_table(connection) + print("Your database has been exiled to the graveyard and removed from game...") else: print("Damnation avoided.") diff --git a/src/scry/main.py b/src/scry/main.py index c05ff46..fda2aa6 100644 --- a/src/scry/main.py +++ b/src/scry/main.py @@ -2,30 +2,36 @@ # by thrly -from . import ( - create_table, - db_connect, -) -from .cli import ( - build_arg_parser, -) +from .db_setup import create_table +from .cli import build_arg_parser + +from contextlib import closing +import sqlite3 +from pathlib import Path +from os import makedirs + +# set path (if required) for db: scry/data/cards.db +path_to_root = Path(__file__).parent.parent.parent # project root scry/ +makedirs(path_to_root / "data/", exist_ok=True) def main(argv=None): - connection = db_connect() # setup argument parsing (argv for testing) parser = build_arg_parser() args = parser.parse_args(argv) - try: - # setup / connect to local database - create_table(connection) + # use auto-closing context manager to create db connection, pass it to handlers + with closing( + sqlite3.connect(Path(path_to_root / "data" / "cards.db")) + ) as connection: + + try: + # setup / connect to local database + create_table(connection) - # execute commands from cli arguments (see cli.py for handling) - args.func(args, connection) + # execute commands from cli arguments (see cli.py for handling) + args.func(args, connection) - except Exception as err: - print("Error in main(): ", err) - finally: - connection.close() # finally close db connection + except Exception as err: + print("Error in main(): ", err) diff --git a/tests/test_db.py b/tests/test_db.py index 7bd19f9..a092b2b 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -4,12 +4,13 @@ from src.scry.db_insert import insert_cards from tests.sample_card import sample_card from pathlib import Path +from contextlib import closing def test_card_insert_and_get(): path_test_db = Path(__file__).parent - connection = sqlite3.connect(Path(path_test_db) / "tests.db") - try: + with closing(sqlite3.connect(Path(path_test_db) / "tests.db")) as connection: + create_table(connection) assert insert_cards([sample_card()], datetime.datetime.now(), connection) == 1 @@ -30,5 +31,5 @@ def test_card_insert_and_get(): assert card[7] == "{G}" # mana_cost assert card[8] == 1.0 # cmc - finally: - connection.close() + cursor.execute("DROP TABLE cards;") + # TODO: look into pytest's fixtures for automatic teardown/cleanup