Chunk identify scans (#1169)

This commit is contained in:
2025-11-20 09:46:49 +01:00
committed by GitHub
5 changed files with 93 additions and 72 deletions
+3
View File
@@ -5,8 +5,10 @@ from logging import getLogger
from typing import Any, cast
from asyncpg import Connection, Pool, create_pool
from opentelemetry import trace
logger = getLogger(__name__)
tracer = trace.get_tracer("kyoo.scanner")
pool: Pool
@@ -55,6 +57,7 @@ async def get_db_fapi():
yield db
@tracer.start_as_current_span("migrate")
async def migrate(migrations_dir="./migrations"):
async with get_db() as db:
_ = await db.execute(
+24 -19
View File
@@ -1,3 +1,5 @@
import asyncio
import itertools
import os
import re
from contextlib import asynccontextmanager
@@ -111,30 +113,33 @@ class FsScanner:
logger.error("Unexpected error while monitoring files.", exc_info=e)
async def _register(self, videos: list[str] | set[str]):
# TODO: we should probably chunk those
vids: list[Video] = []
for path in list(videos):
async def process(path: str):
try:
vid = await identify(path)
vid = self._match(vid)
vids.append(vid)
return self._match(vid)
except Exception as e:
logger.error("Couldn't identify %s.", path, exc_info=e)
created = await self._client.create_videos(vids)
return None
await self._requests.enqueue(
[
Request(
kind=x.guess.kind,
title=x.guess.title,
year=next(iter(x.guess.years), None),
external_id=x.guess.external_id,
videos=[Request.Video(id=x.id, episodes=x.guess.episodes)],
)
for x in created
if not any(x.entries) and x.guess.kind != "extra"
]
)
for batch in itertools.batched(videos, 20):
vids = await asyncio.gather(*(process(path) for path in batch))
created = await self._client.create_videos(
[v for v in vids if v is not None]
)
await self._requests.enqueue(
[
Request(
kind=x.guess.kind,
title=x.guess.title,
year=next(iter(x.guess.years), None),
external_id=x.guess.external_id,
videos=[Request.Video(id=x.id, episodes=x.guess.episodes)],
)
for x in created
if not any(x.entries) and x.guess.kind != "extra"
]
)
def _match(self, video: Video) -> Video:
video.for_ = []
+57 -50
View File
@@ -1,15 +1,18 @@
import os
from collections.abc import Awaitable
from hashlib import sha256
from itertools import zip_longest
from logging import getLogger
from typing import Callable, Literal, cast
from opentelemetry import trace
from rebulk.match import Match
from ..models.videos import Guess, Video
from .guess.guess import guessit
logger = getLogger(__name__)
tracer = trace.get_tracer("kyoo.scanner")
pipeline: list[Callable[[str, Guess], Awaitable[Guess]]] = [
# TODO: add nfo scanner
@@ -19,62 +22,66 @@ pipeline: list[Callable[[str, Guess], Awaitable[Guess]]] = [
async def identify(path: str) -> Video:
raw = guessit(path, expected_titles=[])
with tracer.start_as_current_span(f"identify {os.path.basename(path)}") as span:
span.set_attribute("video.path", path)
# guessit should only return one (according to the doc)
title = raw.get("title", [])[0]
kind = raw.get("type", [])[0]
version = next(iter(raw.get("version", [])), None)
# apparently guessit can return multiples but tbh idk what to do with
# multiples part. we'll just ignore them for now
part = next(iter(raw.get("part", [])), None)
raw = guessit(path, expected_titles=[])
years = raw.get("year", [])
seasons = raw.get("season", [])
episodes = raw.get("episode", [])
# guessit should only return one (according to the doc)
title = raw.get("title", [])[0]
kind = raw.get("type", [])[0]
version = next(iter(raw.get("version", [])), None)
# apparently guessit can return multiples but tbh idk what to do with
# multiples part. we'll just ignore them for now
part = next(iter(raw.get("part", [])), None)
# just strip the version & part number from the path
rendering_path = "".join(
c
for i, c in enumerate(path)
if not (version and version.start <= i < version.end)
and not (part and part.start <= i < part.end)
)
years = raw.get("year", [])
seasons = raw.get("season", [])
episodes = raw.get("episode", [])
guess = Guess(
title=cast(str, title.value),
kind=cast(Literal["episode", "movie"], kind.value),
extra_kind=None,
years=[cast(int, y.value) for y in years],
episodes=[
Guess.Episode(season=cast(int, s.value), episode=cast(int, e.value))
for s, e in zip_longest(
seasons,
episodes,
fillvalue=seasons[-1] if any(seasons) else Match(0, 0, value=1),
)
],
external_id={},
from_="guessit",
raw={
k: [x.value if x.value is int else str(x.value) for x in v]
for k, v in raw.items()
},
)
# just strip the version & part number from the path
rendering_path = "".join(
c
for i, c in enumerate(path)
if not (version and version.start <= i < version.end)
and not (part and part.start <= i < part.end)
)
for step in pipeline:
try:
guess = await step(path, guess)
except Exception as e:
logger.error("Couldn't run %s.", step.__name__, exc_info=e)
guess = Guess(
title=cast(str, title.value),
kind=cast(Literal["episode", "movie"], kind.value),
extra_kind=None,
years=[cast(int, y.value) for y in years],
episodes=[
Guess.Episode(season=cast(int, s.value), episode=cast(int, e.value))
for s, e in zip_longest(
seasons,
episodes,
fillvalue=seasons[-1] if any(seasons) else Match(0, 0, value=1),
)
],
external_id={},
from_="guessit",
raw={
k: [x.value if x.value is int else str(x.value) for x in v]
for k, v in raw.items()
},
)
span.set_attribute("video.name", guess.title)
return Video(
path=path,
rendering=sha256(rendering_path.encode()).hexdigest(),
part=cast(int, part.value) if part else None,
version=cast(int, version.value) if version else 1,
guess=guess,
)
for step in pipeline:
try:
guess = await step(path, guess)
except Exception as e:
logger.error("Couldn't run %s.", step.__name__, exc_info=e)
return Video(
path=path,
rendering=sha256(rendering_path.encode()).hexdigest(),
part=cast(int, part.value) if part else None,
version=cast(int, version.value) if version else 1,
guess=guess,
)
if __name__ == "__main__":
+8 -2
View File
@@ -1,5 +1,6 @@
import logging
import os
import sys
from fastapi import FastAPI
from opentelemetry import metrics, trace
@@ -45,8 +46,13 @@ def instrument(app: FastAPI):
)
)
set_logger_provider(provider)
handler = LoggingHandler(level=logging.DEBUG, logger_provider=provider)
logging.basicConfig(handlers=[handler], level=logging.DEBUG)
logging.basicConfig(
handlers=[
LoggingHandler(level=logging.DEBUG, logger_provider=provider),
logging.StreamHandler(sys.stdout),
],
level=logging.DEBUG,
)
logging.getLogger("watchfiles").setLevel(logging.WARNING)
logging.getLogger("rebulk").setLevel(logging.WARNING)
+1 -1
View File
@@ -75,7 +75,7 @@ class RequestProcessor:
self._database.add_termination_listener(terminated)
await self._database.add_listener("scanner_requests", process)
logger.info("Listening for requestes")
logger.info("Listening for requests")
_ = await closed.wait()
logger.info("stopping...")
except CancelledError: