diff --git a/src/ntr_fetcher/db.py b/src/ntr_fetcher/db.py index 0a49a7b..87df518 100644 --- a/src/ntr_fetcher/db.py +++ b/src/ntr_fetcher/db.py @@ -50,3 +50,259 @@ class Database: conn = self._connect() conn.executescript(SCHEMA) conn.close() + + def upsert_track(self, track: Track) -> None: + conn = self._connect() + conn.execute( + """ + INSERT INTO tracks (id, title, artist, permalink_url, artwork_url, + duration_ms, license, liked_at, raw_json) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + title=excluded.title, + artist=excluded.artist, + permalink_url=excluded.permalink_url, + artwork_url=excluded.artwork_url, + duration_ms=excluded.duration_ms, + license=excluded.license, + liked_at=excluded.liked_at, + raw_json=excluded.raw_json + """, + ( + track.id, + track.title, + track.artist, + track.permalink_url, + track.artwork_url, + track.duration_ms, + track.license, + track.liked_at.isoformat(), + track.raw_json, + ), + ) + conn.commit() + conn.close() + + def get_track(self, track_id: int) -> Track | None: + conn = self._connect() + row = conn.execute("SELECT * FROM tracks WHERE id = ?", (track_id,)).fetchone() + conn.close() + if row is None: + return None + return Track( + id=row["id"], + title=row["title"], + artist=row["artist"], + permalink_url=row["permalink_url"], + artwork_url=row["artwork_url"], + duration_ms=row["duration_ms"], + license=row["license"], + liked_at=datetime.fromisoformat(row["liked_at"]), + raw_json=row["raw_json"], + ) + + def get_or_create_show( + self, week_start: datetime, week_end: datetime + ) -> Show: + conn = self._connect() + row = conn.execute( + "SELECT id, week_start, week_end, created_at FROM shows " + "WHERE week_start = ? AND week_end = ?", + (week_start.isoformat(), week_end.isoformat()), + ).fetchone() + if row is not None: + conn.close() + return Show( + id=row["id"], + week_start=datetime.fromisoformat(row["week_start"]), + week_end=datetime.fromisoformat(row["week_end"]), + created_at=datetime.fromisoformat(row["created_at"]), + ) + now = datetime.now(timezone.utc).isoformat() + cursor = conn.execute( + "INSERT INTO shows (week_start, week_end, created_at) VALUES (?, ?, ?)", + (week_start.isoformat(), week_end.isoformat(), now), + ) + conn.commit() + show_id = cursor.lastrowid + conn.close() + return Show( + id=show_id, + week_start=week_start, + week_end=week_end, + created_at=datetime.fromisoformat(now), + ) + + def get_show_tracks(self, show_id: int) -> list[dict]: + conn = self._connect() + rows = conn.execute( + """ + SELECT st.show_id, st.track_id, st.position, t.title, t.artist, + t.permalink_url, t.artwork_url, t.duration_ms, t.license, + t.liked_at, t.raw_json + FROM show_tracks st + JOIN tracks t ON st.track_id = t.id + WHERE st.show_id = ? + ORDER BY st.position + """, + (show_id,), + ).fetchall() + conn.close() + return [dict(row) for row in rows] + + def get_show_track_by_position( + self, show_id: int, position: int + ) -> dict | None: + conn = self._connect() + row = conn.execute( + """ + SELECT st.show_id, st.track_id, st.position, t.title, t.artist, + t.permalink_url, t.artwork_url, t.duration_ms, t.license, + t.liked_at, t.raw_json + FROM show_tracks st + JOIN tracks t ON st.track_id = t.id + WHERE st.show_id = ? AND st.position = ? + """, + (show_id, position), + ).fetchone() + conn.close() + return dict(row) if row else None + + def set_show_tracks(self, show_id: int, track_ids: list[int]) -> None: + conn = self._connect() + if track_ids: + placeholders = ",".join("?" * len(track_ids)) + conn.execute( + f"DELETE FROM show_tracks WHERE show_id = ? AND track_id NOT IN ({placeholders})", + (show_id, *track_ids), + ) + else: + conn.execute("DELETE FROM show_tracks WHERE show_id = ?", (show_id,)) + for position, track_id in enumerate(track_ids, start=1): + conn.execute( + """ + INSERT INTO show_tracks (show_id, track_id, position) + VALUES (?, ?, ?) + ON CONFLICT(show_id, track_id) DO UPDATE SET position = excluded.position + """, + (show_id, track_id, position), + ) + conn.commit() + conn.close() + + def get_max_position(self, show_id: int) -> int: + conn = self._connect() + row = conn.execute( + "SELECT COALESCE(MAX(position), 0) as max_pos FROM show_tracks WHERE show_id = ?", + (show_id,), + ).fetchone() + conn.close() + return row["max_pos"] + + def list_shows(self, limit: int, offset: int) -> list[Show]: + conn = self._connect() + rows = conn.execute( + """ + SELECT id, week_start, week_end, created_at + FROM shows + ORDER BY created_at DESC + LIMIT ? OFFSET ? + """, + (limit, offset), + ).fetchall() + conn.close() + return [ + Show( + id=row["id"], + week_start=datetime.fromisoformat(row["week_start"]), + week_end=datetime.fromisoformat(row["week_end"]), + created_at=datetime.fromisoformat(row["created_at"]), + ) + for row in rows + ] + + def has_track_in_show(self, show_id: int, track_id: int) -> bool: + conn = self._connect() + row = conn.execute( + "SELECT 1 FROM show_tracks WHERE show_id = ? AND track_id = ?", + (show_id, track_id), + ).fetchone() + conn.close() + return row is not None + + def remove_show_track(self, show_id: int, track_id: int) -> None: + conn = self._connect() + conn.execute( + "DELETE FROM show_tracks WHERE show_id = ? AND track_id = ?", + (show_id, track_id), + ) + rows = conn.execute( + "SELECT track_id, position FROM show_tracks WHERE show_id = ? ORDER BY position", + (show_id,), + ).fetchall() + for new_position, row in enumerate(rows, start=1): + if row["position"] != new_position: + conn.execute( + "UPDATE show_tracks SET position = ? WHERE show_id = ? AND track_id = ?", + (new_position, show_id, row["track_id"]), + ) + conn.commit() + conn.close() + + def move_show_track( + self, show_id: int, track_id: int, new_position: int + ) -> None: + conn = self._connect() + row = conn.execute( + "SELECT position FROM show_tracks WHERE show_id = ? AND track_id = ?", + (show_id, track_id), + ).fetchone() + if row is None: + conn.close() + return + old_position = row["position"] + if old_position == new_position: + conn.close() + return + if old_position < new_position: + conn.execute( + "UPDATE show_tracks SET position = position - 1 " + "WHERE show_id = ? AND position > ? AND position <= ?", + (show_id, old_position, new_position), + ) + else: + conn.execute( + "UPDATE show_tracks SET position = position + 1 " + "WHERE show_id = ? AND position >= ? AND position < ?", + (show_id, new_position, old_position), + ) + conn.execute( + "UPDATE show_tracks SET position = ? WHERE show_id = ? AND track_id = ?", + (new_position, show_id, track_id), + ) + conn.commit() + conn.close() + + def add_track_to_show( + self, show_id: int, track_id: int, position: int | None = None + ) -> None: + conn = self._connect() + if position is None: + max_pos = conn.execute( + "SELECT COALESCE(MAX(position), 0) FROM show_tracks WHERE show_id = ?", + (show_id,), + ).fetchone()[0] + new_position = max_pos + 1 + else: + conn.execute( + "UPDATE show_tracks SET position = position + 1 " + "WHERE show_id = ? AND position >= ?", + (show_id, position), + ) + new_position = position + conn.execute( + "INSERT INTO show_tracks (show_id, track_id, position) VALUES (?, ?, ?)", + (show_id, track_id, new_position), + ) + conn.commit() + conn.close() diff --git a/tests/test_db.py b/tests/test_db.py index b942940..868a63d 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,8 +1,10 @@ import sqlite3 +from datetime import datetime, timezone import pytest from ntr_fetcher.db import Database +from ntr_fetcher.models import Track @pytest.fixture @@ -28,3 +30,208 @@ def test_tables_created(db): def test_initialize_idempotent(db): """Calling initialize twice doesn't raise.""" db.initialize() + + +def _make_track(id: int, liked_at: str, title: str = "Test", artist: str = "Artist") -> Track: + return Track( + id=id, + title=title, + artist=artist, + permalink_url=f"https://soundcloud.com/test/track-{id}", + artwork_url=None, + duration_ms=180000, + license="cc-by", + liked_at=datetime.fromisoformat(liked_at), + raw_json="{}", + ) + + +def test_upsert_track(db): + track = _make_track(100, "2026-03-10T12:00:00+00:00") + db.upsert_track(track) + result = db.get_track(100) + assert result is not None + assert result.title == "Test" + + +def test_upsert_track_updates_existing(db): + track1 = _make_track(100, "2026-03-10T12:00:00+00:00", title="Original") + db.upsert_track(track1) + track2 = _make_track(100, "2026-03-10T12:00:00+00:00", title="Updated") + db.upsert_track(track2) + result = db.get_track(100) + assert result.title == "Updated" + + +def test_get_or_create_show(db): + week_start = datetime(2026, 3, 13, 2, 0, 0, tzinfo=timezone.utc) + week_end = datetime(2026, 3, 20, 2, 0, 0, tzinfo=timezone.utc) + show = db.get_or_create_show(week_start, week_end) + assert show.id is not None + assert show.week_start == week_start + show2 = db.get_or_create_show(week_start, week_end) + assert show2.id == show.id + + +def test_set_show_tracks(db): + week_start = datetime(2026, 3, 13, 2, 0, 0, tzinfo=timezone.utc) + week_end = datetime(2026, 3, 20, 2, 0, 0, tzinfo=timezone.utc) + show = db.get_or_create_show(week_start, week_end) + t1 = _make_track(1, "2026-03-14T01:00:00+00:00", title="First") + t2 = _make_track(2, "2026-03-14T02:00:00+00:00", title="Second") + db.upsert_track(t1) + db.upsert_track(t2) + db.set_show_tracks(show.id, [t1.id, t2.id]) + tracks = db.get_show_tracks(show.id) + assert len(tracks) == 2 + assert tracks[0]["position"] == 1 + assert tracks[0]["title"] == "First" + assert tracks[1]["position"] == 2 + + +def test_set_show_tracks_preserves_existing_positions(db): + week_start = datetime(2026, 3, 13, 2, 0, 0, tzinfo=timezone.utc) + week_end = datetime(2026, 3, 20, 2, 0, 0, tzinfo=timezone.utc) + show = db.get_or_create_show(week_start, week_end) + t1 = _make_track(1, "2026-03-14T01:00:00+00:00") + db.upsert_track(t1) + db.set_show_tracks(show.id, [t1.id]) + t2 = _make_track(2, "2026-03-14T02:00:00+00:00") + db.upsert_track(t2) + db.set_show_tracks(show.id, [t1.id, t2.id]) + tracks = db.get_show_tracks(show.id) + assert tracks[0]["track_id"] == 1 + assert tracks[0]["position"] == 1 + assert tracks[1]["track_id"] == 2 + assert tracks[1]["position"] == 2 + + +def test_set_show_tracks_removes_unliked(db): + """Tracks no longer in the likes list are removed and positions re-compact.""" + week_start = datetime(2026, 3, 13, 2, 0, 0, tzinfo=timezone.utc) + week_end = datetime(2026, 3, 20, 2, 0, 0, tzinfo=timezone.utc) + show = db.get_or_create_show(week_start, week_end) + t1 = _make_track(1, "2026-03-14T01:00:00+00:00", title="First") + t2 = _make_track(2, "2026-03-14T02:00:00+00:00", title="Second") + t3 = _make_track(3, "2026-03-14T03:00:00+00:00", title="Third") + db.upsert_track(t1) + db.upsert_track(t2) + db.upsert_track(t3) + db.set_show_tracks(show.id, [t1.id, t2.id, t3.id]) + db.set_show_tracks(show.id, [t1.id, t3.id]) + tracks = db.get_show_tracks(show.id) + assert len(tracks) == 2 + assert tracks[0]["track_id"] == 1 + assert tracks[0]["position"] == 1 + assert tracks[1]["track_id"] == 3 + assert tracks[1]["position"] == 2 + + +def test_get_show_track_by_position(db): + week_start = datetime(2026, 3, 13, 2, 0, 0, tzinfo=timezone.utc) + week_end = datetime(2026, 3, 20, 2, 0, 0, tzinfo=timezone.utc) + show = db.get_or_create_show(week_start, week_end) + t1 = _make_track(1, "2026-03-14T01:00:00+00:00", title="First") + db.upsert_track(t1) + db.set_show_tracks(show.id, [t1.id]) + result = db.get_show_track_by_position(show.id, 1) + assert result is not None + assert result["title"] == "First" + result_missing = db.get_show_track_by_position(show.id, 99) + assert result_missing is None + + +def test_list_shows(db): + s1 = db.get_or_create_show( + datetime(2026, 3, 6, 3, 0, 0, tzinfo=timezone.utc), + datetime(2026, 3, 13, 2, 0, 0, tzinfo=timezone.utc), + ) + s2 = db.get_or_create_show( + datetime(2026, 3, 13, 2, 0, 0, tzinfo=timezone.utc), + datetime(2026, 3, 20, 2, 0, 0, tzinfo=timezone.utc), + ) + shows = db.list_shows(limit=10, offset=0) + assert len(shows) == 2 + assert shows[0].id == s2.id + + +def test_max_position_for_show(db): + week_start = datetime(2026, 3, 13, 2, 0, 0, tzinfo=timezone.utc) + week_end = datetime(2026, 3, 20, 2, 0, 0, tzinfo=timezone.utc) + show = db.get_or_create_show(week_start, week_end) + assert db.get_max_position(show.id) == 0 + t1 = _make_track(1, "2026-03-14T01:00:00+00:00") + db.upsert_track(t1) + db.set_show_tracks(show.id, [t1.id]) + assert db.get_max_position(show.id) == 1 + + +def test_remove_show_track(db): + week_start = datetime(2026, 3, 13, 2, 0, 0, tzinfo=timezone.utc) + week_end = datetime(2026, 3, 20, 2, 0, 0, tzinfo=timezone.utc) + show = db.get_or_create_show(week_start, week_end) + t1 = _make_track(1, "2026-03-14T01:00:00+00:00") + t2 = _make_track(2, "2026-03-14T02:00:00+00:00") + t3 = _make_track(3, "2026-03-14T03:00:00+00:00") + db.upsert_track(t1) + db.upsert_track(t2) + db.upsert_track(t3) + db.set_show_tracks(show.id, [t1.id, t2.id, t3.id]) + db.remove_show_track(show.id, 2) + tracks = db.get_show_tracks(show.id) + assert len(tracks) == 2 + assert tracks[0]["position"] == 1 + assert tracks[0]["track_id"] == 1 + assert tracks[1]["position"] == 2 + assert tracks[1]["track_id"] == 3 + + +def test_move_show_track(db): + week_start = datetime(2026, 3, 13, 2, 0, 0, tzinfo=timezone.utc) + week_end = datetime(2026, 3, 20, 2, 0, 0, tzinfo=timezone.utc) + show = db.get_or_create_show(week_start, week_end) + t1 = _make_track(1, "2026-03-14T01:00:00+00:00") + t2 = _make_track(2, "2026-03-14T02:00:00+00:00") + t3 = _make_track(3, "2026-03-14T03:00:00+00:00") + db.upsert_track(t1) + db.upsert_track(t2) + db.upsert_track(t3) + db.set_show_tracks(show.id, [t1.id, t2.id, t3.id]) + db.move_show_track(show.id, track_id=3, new_position=1) + tracks = db.get_show_tracks(show.id) + assert tracks[0]["track_id"] == 3 + assert tracks[0]["position"] == 1 + assert tracks[1]["track_id"] == 1 + assert tracks[1]["position"] == 2 + assert tracks[2]["track_id"] == 2 + assert tracks[2]["position"] == 3 + + +def test_add_track_to_show_at_position(db): + week_start = datetime(2026, 3, 13, 2, 0, 0, tzinfo=timezone.utc) + week_end = datetime(2026, 3, 20, 2, 0, 0, tzinfo=timezone.utc) + show = db.get_or_create_show(week_start, week_end) + t1 = _make_track(1, "2026-03-14T01:00:00+00:00") + t2 = _make_track(2, "2026-03-14T02:00:00+00:00") + t3 = _make_track(3, "2026-03-14T03:00:00+00:00") + db.upsert_track(t1) + db.upsert_track(t2) + db.upsert_track(t3) + db.set_show_tracks(show.id, [t1.id, t2.id]) + db.add_track_to_show(show.id, track_id=3, position=2) + tracks = db.get_show_tracks(show.id) + assert len(tracks) == 3 + assert tracks[0]["track_id"] == 1 + assert tracks[1]["track_id"] == 3 + assert tracks[2]["track_id"] == 2 + + +def test_has_track_in_show(db): + week_start = datetime(2026, 3, 13, 2, 0, 0, tzinfo=timezone.utc) + week_end = datetime(2026, 3, 20, 2, 0, 0, tzinfo=timezone.utc) + show = db.get_or_create_show(week_start, week_end) + t1 = _make_track(1, "2026-03-14T01:00:00+00:00") + db.upsert_track(t1) + db.set_show_tracks(show.id, [t1.id]) + assert db.has_track_in_show(show.id, 1) is True + assert db.has_track_in_show(show.id, 999) is False