diff --git a/api/src/auth.ts b/api/src/auth.ts index d7f0d874..101a5dc0 100644 --- a/api/src/auth.ts +++ b/api/src/auth.ts @@ -33,6 +33,17 @@ const Jwt = t.Object({ type Jwt = typeof Jwt.static; const validator = TypeCompiler.Compile(Jwt); +export async function verifyJwt(bearer: string) { + // @ts-expect-error ts can't understand that there's two overload idk why + const { payload } = await jwtVerify(bearer, jwtSecret ?? jwks, { + issuer: process.env.JWT_ISSUER, + }); + const raw = validator.Decode(payload); + const jwt = Value.Default(Jwt, raw) as Prettify; + + return { jwt }; +} + export const auth = new Elysia({ name: "auth" }) .guard({ headers: t.Object( @@ -50,18 +61,8 @@ export const auth = new Elysia({ name: "auth" }) message: "No authorization header was found.", }); } - try { - // @ts-expect-error ts can't understand that there's two overload idk why - const { payload } = await jwtVerify(bearer, jwtSecret ?? jwks, { - issuer: process.env.JWT_ISSUER, - }); - const raw = validator.Decode(payload); - const jwt = Value.Default(Jwt, raw) as Prettify< - Jwt & { settings: Settings } - >; - - return { jwt }; + return await verifyJwt(bearer); } catch (err) { return status(403, { status: 403, diff --git a/api/src/base.ts b/api/src/base.ts index ad33b439..920d88ff 100644 --- a/api/src/base.ts +++ b/api/src/base.ts @@ -17,6 +17,7 @@ import { videosReadH, videosWriteH } from "./controllers/videos"; import { db } from "./db"; import type { KError } from "./models/error"; import { otel } from "./otel"; +import { appWs } from "./websockets"; export const base = new Elysia({ name: "base" }) .onError(({ code, error }) => { @@ -91,8 +92,9 @@ export const base = new Elysia({ name: "base" }) export const prefix = "/api"; export const handlers = new Elysia({ prefix }) .use(base) - .use(auth) .use(otel) + .use(appWs) + .use(auth) .guard( { // Those are not applied for now. See https://github.com/elysiajs/elysia/issues/1139 diff --git a/api/src/controllers/profiles/history.ts b/api/src/controllers/profiles/history.ts index 9736ee23..71af81ad 100644 --- a/api/src/controllers/profiles/history.ts +++ b/api/src/controllers/profiles/history.ts @@ -1,11 +1,22 @@ -import { and, count, eq, exists, gt, isNotNull, ne, sql } from "drizzle-orm"; +import { + and, + count, + eq, + exists, + gt, + isNotNull, + lte, + ne, + sql, + TransactionRollbackError, +} from "drizzle-orm"; import { alias } from "drizzle-orm/pg-core"; import Elysia, { t } from "elysia"; import { auth, getUserInfo } from "~/auth"; -import { db } from "~/db"; +import { db, type Transaction } from "~/db"; import { entries, history, profiles, shows, videos } from "~/db/schema"; import { watchlist } from "~/db/schema/watchlist"; -import { coalesce, values } from "~/db/utils"; +import { coalesce, sqlarr } from "~/db/utils"; import { Entry } from "~/models/entry"; import { KError } from "~/models/error"; import { SeedHistory } from "~/models/history"; @@ -19,6 +30,7 @@ import { } from "~/models/utils"; import { desc } from "~/models/utils/descriptions"; import type { WatchlistStatus } from "~/models/watchlist"; +import { traverse } from "~/utils"; import { entryFilters, entryProgressQ, @@ -27,6 +39,275 @@ import { } from "../entries"; import { getOrCreateProfile } from "./profile"; +export async function updateProgress(userPk: number, progress: SeedHistory[]) { + try { + return await db.transaction(async (tx) => { + const hist = await updateHistory(tx, userPk, progress); + if (hist.created.length + hist.updated.length !== progress.length) { + tx.rollback(); + } + // only return new and entries whose status has changed. + // we don't need to update the watchlist every 10s when watching a video. + await updateWatchlist(tx, userPk, [ + ...hist.created, + ...hist.updated.filter((x) => x.percent >= 95), + ]); + return { status: 201, inserted: hist.created.length } as const; + }); + } catch (e) { + if (!(e instanceof TransactionRollbackError)) throw e; + return { + status: 404, + message: "Invalid entry id/slug in progress array", + } as const; + } +} + +async function updateHistory( + dbTx: Transaction, + userPk: number, + progress: SeedHistory[], +) { + return dbTx.transaction(async (tx) => { + // `for("update", { of: history })` will put the `kyoo.history` instead + // of `history` in the sql and that triggers a sql error. + const existing = ( + await tx + .select({ videoId: videos.id }) + .from(history) + .for("update", { of: sql`history` as any }) + .leftJoin(videos, eq(videos.pk, history.videoPk)) + .where( + and( + eq(history.profilePk, userPk), + lte(sql`now() - ${history.playedDate}`, sql`interval '1 day'`), + ), + ) + ).map((x) => x.videoId); + + const toUpdate = traverse( + progress.filter((x) => existing.includes(x.videoId)), + ); + const newEntries = traverse( + progress + .filter((x) => !existing.includes(x.videoId)) + .map((x) => ({ ...x, entryUseid: isUuid(x.entry) })), + ); + + const updated = + toUpdate === null + ? [] + : await tx + .update(history) + .set({ + time: sql`hist.ts`, + percent: sql`hist.percent`, + playedDate: coalesce(sql`hist.played_date`, sql`now()`), + }) + .from(sql`unnest( + ${sqlarr(toUpdate.videoId)}::uuid[], + ${sqlarr(toUpdate.time)}::integer[], + ${sqlarr(toUpdate.percent)}::integer[], + ${sqlarr(toUpdate.playedDate)}::timestamp[] + ) as hist(video_id, ts, percent, played_date)`) + .innerJoin(videos, eq(videos.id, sql`hist.video_id`)) + .where( + and( + eq(history.profilePk, userPk), + eq(history.videoPk, videos.pk), + ), + ) + .returning({ + entryPk: history.entryPk, + videoPk: history.videoPk, + percent: history.percent, + playedDate: history.playedDate, + }); + + const created = + newEntries === null + ? [] + : await tx + .insert(history) + .select( + db + .select({ + profilePk: sql`${userPk}`.as("profilePk"), + videoPk: videos.pk, + entryPk: entries.pk, + percent: sql`hist.percent`.as("percent"), + time: sql`hist.ts`.as("time"), + playedDate: coalesce(sql`hist.played_date`, sql`now()`).as( + "playedDate", + ), + }) + .from(sql`unnest( + ${sqlarr(newEntries.entry)}::text[], + ${sqlarr(newEntries.entryUseid)}::boolean[], + ${sqlarr(newEntries.videoId)}::uuid[], + ${sqlarr(newEntries.time)}::integer[], + ${sqlarr(newEntries.percent)}::integer[], + ${sqlarr(newEntries.playedDate)}::timestamptz[] + ) as hist(entry, entry_use_id, video_id, ts, percent, played_date)`) + .innerJoin( + entries, + sql` + case + when hist.entry_use_id then ${entries.id} = hist.entry::uuid + else ${entries.slug} = hist.entry + end + `, + ) + .leftJoin(videos, eq(videos.id, sql`hist.video_id`)), + ) + .returning({ + entryPk: history.entryPk, + videoPk: history.videoPk, + percent: history.percent, + playedDate: history.playedDate, + }); + + return { created, updated }; + }); +} + +async function updateWatchlist( + tx: Transaction, + userPk: number, + histArr: { + entryPk: number; + percent: number; + playedDate: string; + }[], +) { + if (histArr.length === 0) return; + + const nextEntry = alias(entries, "next_entry"); + const nextEntryQ = tx + .select({ + pk: nextEntry.pk, + }) + .from(nextEntry) + .where( + and( + eq(nextEntry.showPk, entries.showPk), + ne(nextEntry.kind, "extra"), + gt(nextEntry.order, entries.order), + ), + ) + .orderBy(nextEntry.order) + .limit(1) + .as("nextEntryQ"); + + const seenCountQ = tx + .select({ c: count() }) + .from(entries) + .where( + and( + eq(entries.showPk, sql`excluded.show_pk`), + exists( + db + .select() + .from(history) + .where( + and( + eq(history.profilePk, userPk), + eq(history.entryPk, entries.pk), + ), + ), + ), + ), + ); + + const showKindQ = tx + .select({ k: shows.kind }) + .from(shows) + .where(eq(shows.pk, sql`excluded.show_pk`)); + + const hist = traverse(histArr)!; + await tx + .insert(watchlist) + .select( + db + .selectDistinctOn([entries.showPk], { + profilePk: sql`${userPk}`.as("profilePk"), + showPk: entries.showPk, + status: sql` + case + when + hist.percent >= 95 + and ${nextEntryQ.pk} is null + then 'completed'::watchlist_status + else 'watching'::watchlist_status + end + `.as("status"), + seenCount: sql` + case + when ${entries.kind} = 'movie' then hist.percent + when hist.percent >= 95 then 1 + else 0 + end + `.as("seen_count"), + nextEntry: sql` + case + when hist.percent >= 95 then ${nextEntryQ.pk} + else ${entries.pk} + end + `.as("next_entry"), + score: sql`null`.as("score"), + startedAt: sql`hist.played_date`.as("startedAt"), + lastPlayedAt: sql`hist.played_date`.as("lastPlayedAt"), + completedAt: sql` + case + when ${nextEntryQ.pk} is null then hist.played_date + else null + end + `.as("completedAt"), + // see https://github.com/drizzle-team/drizzle-orm/issues/3608 + updatedAt: sql`now()`.as("updatedAt"), + }) + .from(sql`unnest( + ${sqlarr(hist.entryPk)}::integer[], + ${sqlarr(hist.percent)}::integer[], + ${sqlarr(hist.playedDate)}::timestamptz[] + ) as hist(entry_pk, percent, played_date)`) + .innerJoin(entries, eq(entries.pk, sql`hist.entry_pk`)) + .leftJoinLateral(nextEntryQ, sql`true`), + ) + .onConflictDoUpdate({ + target: [watchlist.profilePk, watchlist.showPk], + set: { + status: sql` + case + when excluded.status = 'completed' then excluded.status + when + ${watchlist.status} != 'completed' + and ${watchlist.status} != 'rewatching' + then excluded.status + else ${watchlist.status} + end + `, + seenCount: sql` + case + when ${showKindQ} = 'movie' then excluded.seen_count + else ${seenCountQ} + end`, + nextEntry: sql` + case + when ${watchlist.status} = 'completed' then null + else excluded.next_entry + end + `, + lastPlayedAt: sql`excluded.last_played_at`, + completedAt: coalesce( + watchlist.completedAt, + sql`excluded.completed_at`, + ), + }, + }); +} + +// this one is different than the normal progressQ because we want duplicates const historyProgressQ: typeof entryProgressQ = db .select({ percent: history.percent, @@ -37,7 +318,7 @@ const historyProgressQ: typeof entryProgressQ = db }) .from(history) .leftJoin(videos, eq(history.videoPk, videos.pk)) - .leftJoin(profiles, eq(history.profilePk, profiles.pk)) + .innerJoin(profiles, eq(history.profilePk, profiles.pk)) .where(eq(profiles.id, sql.placeholder("userId"))) .as("progress"); @@ -170,162 +451,8 @@ export const historyH = new Elysia({ tags: ["profiles"] }) async ({ body, jwt: { sub }, status }) => { const profilePk = await getOrCreateProfile(sub); - const hist = values( - body.map((x) => ({ ...x, entryUseId: isUuid(x.entry) })), - { - percent: "integer", - time: "integer", - playedDate: "timestamptz", - videoId: "uuid", - }, - ).as("hist"); - const valEqEntries = sql` - case - when hist.entryUseId::boolean then ${entries.id} = hist.entry::uuid - else ${entries.slug} = hist.entry - end - `; - - const rows = await db - .insert(history) - .select( - db - .select({ - profilePk: sql`${profilePk}`.as("profilePk"), - entryPk: entries.pk, - videoPk: videos.pk, - percent: sql`hist.percent`.as("percent"), - time: sql`hist.time`.as("time"), - playedDate: sql`hist.playedDate`.as("playedDate"), - }) - .from(hist) - .innerJoin(entries, valEqEntries) - .leftJoin(videos, eq(videos.id, sql`hist.videoId`)), - ) - .returning({ pk: history.pk }); - - // automatically update watchlist with this new info - - const nextEntry = alias(entries, "next_entry"); - const nextEntryQ = db - .select({ - pk: nextEntry.pk, - }) - .from(nextEntry) - .where( - and( - eq(nextEntry.showPk, entries.showPk), - ne(nextEntry.kind, "extra"), - gt(nextEntry.order, entries.order), - ), - ) - .orderBy(nextEntry.order) - .limit(1) - .as("nextEntryQ"); - - const seenCountQ = db - .select({ c: count() }) - .from(entries) - .where( - and( - eq(entries.showPk, sql`excluded.show_pk`), - exists( - db - .select() - .from(history) - .where( - and( - eq(history.profilePk, profilePk), - eq(history.entryPk, entries.pk), - ), - ), - ), - ), - ); - - const showKindQ = db - .select({ k: shows.kind }) - .from(shows) - .where(eq(shows.pk, sql`excluded.show_pk`)); - - await db - .insert(watchlist) - .select( - db - .select({ - profilePk: sql`${profilePk}`.as("profilePk"), - showPk: entries.showPk, - status: sql` - case - when - hist.percent >= 95 - and ${nextEntryQ.pk} is null - then 'completed'::watchlist_status - else 'watching'::watchlist_status - end - `.as("status"), - seenCount: sql` - case - when ${entries.kind} = 'movie' then hist.percent - when hist.percent >= 95 then 1 - else 0 - end - `.as("seen_count"), - nextEntry: sql` - case - when hist.percent >= 95 then ${nextEntryQ.pk} - else ${entries.pk} - end - `.as("next_entry"), - score: sql`null`.as("score"), - startedAt: sql`hist.playedDate`.as("startedAt"), - lastPlayedAt: sql`hist.playedDate`.as("lastPlayedAt"), - completedAt: sql` - case - when ${nextEntryQ.pk} is null then hist.playedDate - else null - end - `.as("completedAt"), - // see https://github.com/drizzle-team/drizzle-orm/issues/3608 - updatedAt: sql`now()`.as("updatedAt"), - }) - .from(hist) - .leftJoin(entries, valEqEntries) - .leftJoinLateral(nextEntryQ, sql`true`), - ) - .onConflictDoUpdate({ - target: [watchlist.profilePk, watchlist.showPk], - set: { - status: sql` - case - when excluded.status = 'completed' then excluded.status - when - ${watchlist.status} != 'completed' - and ${watchlist.status} != 'rewatching' - then excluded.status - else ${watchlist.status} - end - `, - seenCount: sql` - case - when ${showKindQ} = 'movie' then excluded.seen_count - else ${seenCountQ} - end`, - nextEntry: sql` - case - when ${watchlist.status} = 'completed' then null - else excluded.next_entry - end - `, - lastPlayedAt: sql`excluded.last_played_at`, - completedAt: coalesce( - watchlist.completedAt, - sql`excluded.completed_at`, - ), - }, - }); - - return status(201, { status: 201, inserted: rows.length }); + const ret = await updateProgress(profilePk, body); + return status(ret.status, ret); }, { detail: { description: "Bulk add entries/movies to your watch history." }, @@ -338,6 +465,10 @@ export const historyH = new Elysia({ tags: ["profiles"] }) description: "The number of history entry inserted", }), }), + 404: { + ...KError, + description: "No entry found with the given id or slug.", + }, 422: KError, }, }, diff --git a/api/src/db/schema/history.ts b/api/src/db/schema/history.ts index a2c6f685..3df6258d 100644 --- a/api/src/db/schema/history.ts +++ b/api/src/db/schema/history.ts @@ -12,6 +12,8 @@ export const history = schema.table( profilePk: integer() .notNull() .references(() => profiles.pk, { onDelete: "cascade" }), + // we need to attach an history to an entry because we want to keep history + // when we delete a video file entryPk: integer() .notNull() .references(() => entries.pk, { onDelete: "cascade" }), diff --git a/api/src/db/utils.ts b/api/src/db/utils.ts index 12c2e43c..c6fd3fdd 100644 --- a/api/src/db/utils.ts +++ b/api/src/db/utils.ts @@ -92,36 +92,6 @@ export function sqlarr(array: unknown[]): string { .join(", ")}}`; } -// See https://github.com/drizzle-team/drizzle-orm/issues/4044 -export function values( - items: Record[], - typeInfo: Partial> = {}, -) { - if (items[0] === undefined) - throw new Error("Invalid values, expecting at least one items"); - const [firstProp, ...props] = Object.keys(items[0]) as K[]; - const values = items - .map((x, i) => { - let ret = sql`(${x[firstProp]}`; - if (i === 0 && typeInfo[firstProp]) - ret = sql`${ret}::${sql.raw(typeInfo[firstProp])}`; - for (const val of props) { - ret = sql`${ret}, ${x[val]}`; - if (i === 0 && typeInfo[val]) - ret = sql`${ret}::${sql.raw(typeInfo[val])}`; - } - return sql`${ret})`; - }) - .reduce((acc, x) => sql`${acc}, ${x}`); - const valueNames = [firstProp, ...props].join(", "); - - return { - as: (name: string) => { - return sql`(values ${values}) as ${sql.raw(name)}(${sql.raw(valueNames)})`; - }, - }; -} - /* goal: * unnestValues([{a: 1, b: 2}, {a: 3, b: 4}], tbl) * diff --git a/api/src/models/history.ts b/api/src/models/history.ts index 541153fd..34064ff0 100644 --- a/api/src/models/history.ts +++ b/api/src/models/history.ts @@ -28,11 +28,11 @@ export const Progress = t.Object({ export type Progress = typeof Progress.static; export const SeedHistory = t.Intersect([ + Progress, t.Object({ entry: t.String({ description: "Id or slug of the entry/movie you watched", }), }), - Progress, ]); export type SeedHistory = typeof SeedHistory.static; diff --git a/api/src/utils.ts b/api/src/utils.ts index c74bd4a9..74c11a88 100644 --- a/api/src/utils.ts +++ b/api/src/utils.ts @@ -38,3 +38,21 @@ export function uniqBy(a: T[], key: (val: T) => string): T[] { return true; }); } + +export function traverse>( + arr: T[], +): { [K in keyof T]: T[K][] } | null { + if (arr.length === 0) return null; + + const result = {} as { [K in keyof T]: T[K][] }; + arr.forEach((obj, i) => { + for (const key in obj) { + if (!result[key]) { + result[key] = new Array(i).fill(null); + } + result[key].push(obj[key]); + } + }); + + return result; +} diff --git a/api/src/websockets.ts b/api/src/websockets.ts new file mode 100644 index 00000000..10da619c --- /dev/null +++ b/api/src/websockets.ts @@ -0,0 +1,102 @@ +import type { TObject, TString } from "@sinclair/typebox"; +import Elysia, { type TSchema, t } from "elysia"; +import { verifyJwt } from "./auth"; +import { updateProgress } from "./controllers/profiles/history"; +import { getOrCreateProfile } from "./controllers/profiles/profile"; +import { SeedHistory } from "./models/history"; + +const actionMap = { + ping: handler({ + message(ws) { + ws.send({ response: "pong" }); + }, + }), + watch: handler({ + body: t.Omit(SeedHistory, ["playedDate"]), + permissions: ["core.read"], + async message(ws, body) { + const profilePk = await getOrCreateProfile(ws.data.jwt.sub); + + const ret = await updateProgress(profilePk, [ + { ...body, playedDate: null }, + ]); + ws.send(ret); + }, + }), +}; + +const baseWs = new Elysia() + .guard({ + headers: t.Object( + { + authorization: t.Optional(t.TemplateLiteral("Bearer ${string}")), + "Sec-WebSocket-Protocol": t.Optional( + t.Array( + t.Union([t.Literal("kyoo"), t.TemplateLiteral("Bearer ${string}")]), + ), + ), + }, + { additionalProperties: true }, + ), + }) + .resolve( + async ({ + headers: { authorization, "Sec-WebSocket-Protocol": wsProtocol }, + status, + }) => { + const auth = + authorization ?? + (wsProtocol?.length === 2 && + wsProtocol[0] === "kyoo" && + wsProtocol[1].startsWith("Bearer ") + ? wsProtocol[1] + : null); + const bearer = auth?.slice(7); + if (!bearer) { + return status(403, { + status: 403, + message: "No authorization header was found.", + }); + } + try { + return await verifyJwt(bearer); + } catch (err) { + return status(403, { + status: 403, + message: "Invalid jwt. Verification vailed", + details: err, + }); + } + }, + ); + +export const appWs = baseWs.ws("/ws", { + body: t.Union( + Object.entries(actionMap).map(([k, v]) => + t.Intersect([t.Object({ action: t.Literal(k) }), v.body ?? t.Object({})]), + ), + ) as unknown as TObject<{ action: TString }>, + async open(ws) { + if (!ws.data.jwt.sub) { + ws.close(3000, "Unauthorized"); + } + }, + async message(ws, { action, ...body }) { + const handler = actionMap[action as keyof typeof actionMap]; + for (const perm of handler.permissions ?? []) { + if (!ws.data.jwt.permissions.includes(perm)) { + return ws.close(3000, `Missing permission: '${perm}'.`); + } + } + await handler.message(ws as any, body as any); + }, +}); + +type Ws = Parameters[1]["open"]>>[0]; +function handler>(ret: { + body?: Schema; + permissions?: string[]; + message: (ws: Ws, body: Schema["static"]) => void | Promise; +}) { + return ret; +}