diff --git a/backend/endpoints/auth.py b/backend/endpoints/auth.py index d973ceb7a..2c7240e6d 100644 --- a/backend/endpoints/auth.py +++ b/backend/endpoints/auth.py @@ -105,15 +105,14 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp status_code=status.HTTP_400_BAD_REQUEST, detail="Missing refresh token" ) - potential_user = await oauth_handler.get_current_active_user_from_bearer_token( + user, claims = await oauth_handler.get_current_active_user_from_bearer_token( token ) - if not potential_user: + if not user or not claims: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token" ) - user, claims = potential_user if claims.get("type") != "refresh": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token" diff --git a/backend/endpoints/collections.py b/backend/endpoints/collections.py index 8dff4fe4a..390649f79 100644 --- a/backend/endpoints/collections.py +++ b/backend/endpoints/collections.py @@ -57,7 +57,7 @@ async def add_collection( _added_collection = db_collection_handler.add_collection(Collection(**cleaned_data)) - if artwork is not None: + if artwork is not None and artwork.filename is not None: file_ext = artwork.filename.split(".")[-1] ( path_cover_l, @@ -82,8 +82,9 @@ async def add_collection( _added_collection.path_cover_s = path_cover_s _added_collection.path_cover_l = path_cover_l + # Update the collection with the cover path and update database - return db_collection_handler.update_collection( + created_collection = db_collection_handler.update_collection( _added_collection.id, { c: getattr(_added_collection, c) @@ -91,6 +92,8 @@ async def add_collection( }, ) + return CollectionSchema.model_validate(created_collection) + @protected_route(router.get, "/collections", [Scope.COLLECTIONS_READ]) def get_collections(request: Request) -> list[CollectionSchema]: @@ -105,7 +108,7 @@ def get_collections(request: Request) -> list[CollectionSchema]: """ collections = db_collection_handler.get_collections() - return CollectionSchema.for_user(request.user.id, collections) + return CollectionSchema.for_user(request.user.id, [c for c in collections]) @protected_route(router.get, "/collections/{id}", [Scope.COLLECTIONS_READ]) @@ -125,7 +128,7 @@ def get_collection(request: Request, id: int) -> CollectionSchema: if not collection: raise CollectionNotFoundInDatabaseException(id) - return collection + return CollectionSchema.model_validate(collection) @protected_route(router.put, "/collections/{id}", [Scope.COLLECTIONS_WRITE]) @@ -148,6 +151,8 @@ async def update_collection( data = await request.form() collection = db_collection_handler.get_collection(id) + if not collection: + raise CollectionNotFoundInDatabaseException(id) if collection.user_id != request.user.id: raise CollectionPermissionError(id) @@ -156,7 +161,7 @@ async def update_collection( raise CollectionNotFoundInDatabaseException(id) try: - roms = json.loads(data["roms"]) + roms = json.loads(data["roms"]) # type: ignore except json.JSONDecodeError as e: raise ValueError("Invalid list for roms field in update collection") from e except KeyError: @@ -174,7 +179,7 @@ async def update_collection( cleaned_data.update(fs_resource_handler.remove_cover(collection)) cleaned_data.update({"url_cover": ""}) else: - if artwork is not None: + if artwork is not None and artwork.filename is not None: file_ext = artwork.filename.split(".")[-1] ( path_cover_l, @@ -205,13 +210,14 @@ async def update_collection( path_cover_s, path_cover_l = await fs_resource_handler.get_cover( overwrite=True, entity=collection, - url_cover=data.get("url_cover", ""), + url_cover=data.get("url_cover", ""), # type: ignore ) cleaned_data.update( {"path_cover_s": path_cover_s, "path_cover_l": path_cover_l} ) - return db_collection_handler.update_collection(id, cleaned_data) + updated_collection = db_collection_handler.update_collection(id, cleaned_data) + return CollectionSchema.model_validate(updated_collection) @protected_route(router.delete, "/collections/{id}", [Scope.COLLECTIONS_WRITE]) diff --git a/backend/endpoints/feeds.py b/backend/endpoints/feeds.py index 40f305999..0f61bec9f 100644 --- a/backend/endpoints/feeds.py +++ b/backend/endpoints/feeds.py @@ -57,7 +57,7 @@ def platforms_webrcade_feed(request: Request) -> WebrcadeFeedSchema: request.url_for( "get_rom_content", id=rom.id, - file_name=rom.file_name, + file_name=rom.fs_name, ) ), ), @@ -135,10 +135,10 @@ async def tinfoil_index_feed( async def extract_titledb(roms: list[Rom]) -> dict[str, TinfoilFeedTitleDBSchema]: titledb = {} for rom in roms: - match = SWITCH_TITLEDB_REGEX.search(rom.file_name) + match = SWITCH_TITLEDB_REGEX.search(rom.fs_name) if match: _search_term, index_entry = ( - await meta_igdb_handler._switch_titledb_format(match, rom.file_name) + await meta_igdb_handler._switch_titledb_format(match, rom.fs_name) ) if index_entry: titledb[str(index_entry["nsuId"])] = TinfoilFeedTitleDBSchema( @@ -160,9 +160,7 @@ async def tinfoil_index_feed( files=[ TinfoilFeedFileSchema( url=str( - request.url_for( - "get_rom_content", id=rom.id, file_name=rom.file_name - ) + request.url_for("get_rom_content", id=rom.id, file_name=rom.fs_name) ), size=rom.file_size_bytes, ) diff --git a/backend/endpoints/firmware.py b/backend/endpoints/firmware.py index 39a1dc5b6..6f6a4c16d 100644 --- a/backend/endpoints/firmware.py +++ b/backend/endpoints/firmware.py @@ -28,25 +28,28 @@ def add_firmware( files (list[UploadFile], optional): List of files to upload Raises: - HTTPException: No files were uploaded + HTTPException Returns: AddFirmwareResponse: Standard message response """ db_platform = db_platform_handler.get_platform(platform_id) + if not db_platform: + error = f"Platform with ID {platform_id} not found" + log.error(error) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=error) + log.info(f"Uploading firmware to {db_platform.fs_slug}") - if files is None: - log.error("No files were uploaded") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="No files were uploaded", - ) uploaded_firmware = [] firmware_path = fs_firmware_handler.build_upload_file_path(db_platform.fs_slug) for file in files: + if not file.filename: + log.warning("Empty filename, skipping") + continue + fs_firmware_handler.write_file(file=file, path=firmware_path) db_firmware = db_firmware_handler.get_firmware_by_filename( @@ -70,10 +73,14 @@ def add_firmware( uploaded_firmware.append(scanned_firmware) db_platform = db_platform_handler.get_platform(platform_id) + if not db_platform: + error = f"Platform with ID {platform_id} not found" + log.error(error) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=error) return { "uploaded": len(files), - "firmware": db_platform.firmware, + "firmware": [FirmwareSchema.model_validate(f) for f in db_platform.firmware], } @@ -90,7 +97,10 @@ def get_platform_firmware( Returns: list[FirmwareSchema]: Firmware stored in the database """ - return db_firmware_handler.list_firmware(platform_id=platform_id) + return [ + FirmwareSchema.model_validate(f) + for f in db_firmware_handler.list_firmware(platform_id=platform_id) + ] @protected_route( @@ -108,7 +118,13 @@ def get_firmware(request: Request, id: int) -> FirmwareSchema: Returns: FirmwareSchema: Firmware stored in the database """ - return FirmwareSchema(**db_firmware_handler.get_firmware(id)) + firmware = db_firmware_handler.get_firmware(id) + if not firmware: + error = f"Firmware with ID {id} not found" + log.error(error) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=error) + + return FirmwareSchema.model_validate(firmware) @protected_route( @@ -129,6 +145,11 @@ def head_firmware_content(request: Request, id: int, file_name: str): """ firmware = db_firmware_handler.get_firmware(id) + if not firmware: + error = f"Firmware with ID {id} not found" + log.error(error) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=error) + firmware_path = f"{LIBRARY_BASE_PATH}/{firmware.full_path}" return FileResponse( @@ -162,6 +183,11 @@ def get_firmware_content( """ firmware = db_firmware_handler.get_firmware(id) + if not firmware: + error = f"Firmware with ID {id} not found" + log.error(error) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=error) + firmware_path = f"{LIBRARY_BASE_PATH}/{firmware.full_path}" return FileResponse(path=firmware_path, filename=firmware.file_name) diff --git a/backend/endpoints/platform.py b/backend/endpoints/platform.py index b7b7daf3a..7d9f2e9e1 100644 --- a/backend/endpoints/platform.py +++ b/backend/endpoints/platform.py @@ -12,7 +12,6 @@ from handler.filesystem import fs_platform_handler from handler.metadata.igdb_handler import IGDB_PLATFORM_LIST from handler.scan_handler import scan_platform from logger.logger import log -from models.platform import Platform from utils.router import APIRouter router = APIRouter() @@ -36,7 +35,9 @@ async def add_platforms(request: Request) -> PlatformSchema: except PlatformAlreadyExistsException: log.info(f"Detected platform: {fs_slug}") scanned_platform = await scan_platform(fs_slug, [fs_slug]) - return db_platform_handler.add_platform(scanned_platform) + return PlatformSchema.model_validate( + db_platform_handler.add_platform(scanned_platform) + ) @protected_route(router.get, "/platforms", [Scope.PLATFORMS_READ]) @@ -51,7 +52,9 @@ def get_platforms(request: Request) -> list[PlatformSchema]: list[PlatformSchema]: List of platforms """ - return db_platform_handler.get_platforms() + return [ + PlatformSchema.model_validate(p) for p in db_platform_handler.get_platforms() + ] @protected_route(router.get, "/platforms/supported", [Scope.PLATFORMS_READ]) @@ -66,7 +69,7 @@ def get_supported_platforms(request: Request) -> list[PlatformSchema]: """ supported_platforms = [] - db_platforms: list[Platform] = db_platform_handler.get_platforms() + db_platforms = db_platform_handler.get_platforms() db_platforms_map = {p.name: p.id for p in db_platforms} for platform in IGDB_PLATFORM_LIST: @@ -108,7 +111,7 @@ def get_platform(request: Request, id: int) -> PlatformSchema: if not platform: raise PlatformNotFoundInDatabaseException(id) - return platform + return PlatformSchema.model_validate(platform) @protected_route(router.put, "/platforms/{id}", [Scope.PLATFORMS_WRITE]) diff --git a/backend/endpoints/responses/rom.py b/backend/endpoints/responses/rom.py index f101a7b6c..fac0bbe12 100644 --- a/backend/endpoints/responses/rom.py +++ b/backend/endpoints/responses/rom.py @@ -16,12 +16,12 @@ SORT_COMPARE_REGEX = re.compile(r"^([Tt]he|[Aa]|[Aa]nd)\s") RomIGDBMetadata = TypedDict( # type: ignore[misc] "RomIGDBMetadata", - {k: NotRequired[v] for k, v in get_type_hints(IGDBMetadata).items()}, + dict((k, NotRequired[v]) for k, v in get_type_hints(IGDBMetadata).items()), total=False, ) RomMobyMetadata = TypedDict( # type: ignore[misc] "RomMobyMetadata", - {k: NotRequired[v] for k, v in get_type_hints(MobyMetadata).items()}, + dict((k, NotRequired[v]) for k, v in get_type_hints(MobyMetadata).items()), total=False, ) @@ -179,16 +179,15 @@ class SimpleRomSchema(RomSchema): @classmethod def from_orm_with_request(cls, db_rom: Rom, request: Request) -> SimpleRomSchema: user_id = request.user.id - - db_rom.rom_user = RomUserSchema.for_user(user_id, db_rom) - - return cls.model_validate(db_rom) + rom = cls.model_validate(db_rom) + rom.rom_user = RomUserSchema.for_user(user_id, db_rom) + return rom @classmethod def from_orm_with_factory(cls, db_rom: Rom) -> SimpleRomSchema: - db_rom.rom_user = rom_user_schema_factory() - - return cls.model_validate(db_rom) + rom = cls.model_validate(db_rom) + rom.rom_user = rom_user_schema_factory() + return rom class DetailedRomSchema(RomSchema): @@ -205,24 +204,26 @@ class DetailedRomSchema(RomSchema): def from_orm_with_request(cls, db_rom: Rom, request: Request) -> DetailedRomSchema: user_id = request.user.id - db_rom.rom_user = RomUserSchema.for_user(user_id, db_rom) - db_rom.user_notes = RomUserSchema.notes_for_user(user_id, db_rom) - db_rom.user_saves = [ + rom = cls.model_validate(db_rom) + + rom.rom_user = RomUserSchema.for_user(user_id, db_rom) + rom.user_notes = RomUserSchema.notes_for_user(user_id, db_rom) + rom.user_saves = [ SaveSchema.model_validate(s) for s in db_rom.saves if s.user_id == user_id ] - db_rom.user_states = [ + rom.user_states = [ StateSchema.model_validate(s) for s in db_rom.states if s.user_id == user_id ] - db_rom.user_screenshots = [ + rom.user_screenshots = [ ScreenshotSchema.model_validate(s) for s in db_rom.screenshots if s.user_id == user_id ] - db_rom.user_collections = CollectionSchema.for_user( + rom.user_collections = CollectionSchema.for_user( user_id, db_rom.get_collections() ) - return cls.model_validate(db_rom) + return rom class UserNotesSchema(TypedDict): diff --git a/backend/endpoints/rom.py b/backend/endpoints/rom.py index e97ff30e3..ab05d7ce3 100644 --- a/backend/endpoints/rom.py +++ b/backend/endpoints/rom.py @@ -187,16 +187,16 @@ async def head_rom_content( raise RomNotFoundInDatabaseException(id) rom_path = f"{LIBRARY_BASE_PATH}/{rom.full_path}" - files_to_check = files or [r["filename"] for r in rom.files] + files_to_check = files or [r.file_name for r in rom.files] if not rom.multi: # Serve the file directly in development mode for emulatorjs if DEV_MODE: return FileResponse( path=rom_path, - filename=rom.file_name, + filename=rom.fs_name, headers={ - "Content-Disposition": f'attachment; filename="{quote(rom.file_name)}"', + "Content-Disposition": f'attachment; filename="{quote(rom.fs_name)}"', "Content-Type": "application/octet-stream", "Content-Length": str(rom.file_size_bytes), }, @@ -204,7 +204,7 @@ async def head_rom_content( return FileRedirectResponse( download_path=Path(f"/library/{rom.full_path}"), - filename=rom.file_name, + filename=rom.fs_name, ) if len(files_to_check) == 1: @@ -252,14 +252,14 @@ async def get_rom_content( raise RomNotFoundInDatabaseException(id) rom_path = f"{LIBRARY_BASE_PATH}/{rom.full_path}" - files_to_download = sorted(files or [r["filename"] for r in rom.files]) + files_to_download = sorted(files or [r.file_name for r in rom.files]) - log.info(f"User {current_username} is downloading {rom.file_name}") + log.info(f"User {current_username} is downloading {rom.fs_name}") if not rom.multi: return FileRedirectResponse( download_path=Path(f"/library/{rom.full_path}"), - filename=rom.file_name, + filename=rom.fs_name, ) if len(files_to_download) == 1: @@ -332,7 +332,7 @@ async def update_rom( "igdb_id": None, "sgdb_id": None, "moby_id": None, - "name": rom.file_name, + "name": rom.fs_name, "summary": "", "url_screenshots": [], "path_screenshots": [], @@ -346,20 +346,20 @@ async def update_rom( }, ) - return DetailedRomSchema.from_orm_with_request( - db_rom_handler.get_rom(id), request - ) + rom = db_rom_handler.get_rom(id) + if not rom: + raise RomNotFoundInDatabaseException(id) - cleaned_data = { - "igdb_id": data.get("igdb_id", None), - "moby_id": data.get("moby_id", None), + return DetailedRomSchema.from_orm_with_request(rom, request) + + cleaned_data: dict = { + "igdb_id": str(data.get("igdb_id", "")), + "moby_id": str(data.get("moby_id", "")), } - if ( - cleaned_data.get("moby_id", "") - and int(cleaned_data.get("moby_id", "")) != rom.moby_id - ): - moby_rom = await meta_moby_handler.get_rom_by_id(cleaned_data["moby_id"]) + moby_id: str = cleaned_data["moby_id"] + if moby_id and int(moby_id) != rom.moby_id: + moby_rom = await meta_moby_handler.get_rom_by_id(int(moby_id)) cleaned_data.update(moby_rom) path_screenshots = await fs_resource_handler.get_rom_screenshots( rom=rom, @@ -367,11 +367,9 @@ async def update_rom( ) cleaned_data.update({"path_screenshots": path_screenshots}) - if ( - cleaned_data.get("igdb_id", "") - and int(cleaned_data.get("igdb_id", "")) != rom.igdb_id - ): - igdb_rom = await meta_igdb_handler.get_rom_by_id(cleaned_data["igdb_id"]) + igdb_id: str = cleaned_data["igdb_id"] + if igdb_id and int(igdb_id) != rom.igdb_id: + igdb_rom = await meta_igdb_handler.get_rom_by_id(int(igdb_id)) cleaned_data.update(igdb_rom) path_screenshots = await fs_resource_handler.get_rom_screenshots( rom=rom, @@ -386,26 +384,26 @@ async def update_rom( } ) - new_file_name = data.get("file_name", rom.file_name) + new_file_name = str(data.get("file_name", rom.fs_name)) try: if rename_as_source: - new_file_name = rom.file_name.replace( - rom.file_name_no_tags or rom.file_name_no_ext, - data.get("name", rom.name), + new_file_name = rom.fs_name.replace( + rom.fs_name_no_tags or rom.fs_name_no_ext, + str(data.get("name", rom.name)), ) new_file_name = sanitize_filename(new_file_name) fs_rom_handler.rename_file( - old_name=rom.file_name, + old_name=rom.fs_name, new_name=new_file_name, - file_path=rom.file_path, + file_path=rom.fs_path, ) - elif rom.file_name != new_file_name: + elif rom.fs_name != new_file_name: new_file_name = sanitize_filename(new_file_name) fs_rom_handler.rename_file( - old_name=rom.file_name, + old_name=rom.fs_name, new_name=new_file_name, - file_path=rom.file_path, + file_path=rom.fs_path, ) except RomAlreadyExistsException as exc: log.error(exc) @@ -429,7 +427,7 @@ async def update_rom( cleaned_data.update(fs_resource_handler.remove_cover(rom)) cleaned_data.update({"url_cover": ""}) else: - if artwork: + if artwork is not None and artwork.filename is not None: file_ext = artwork.filename.split(".")[-1] ( path_cover_l, @@ -459,15 +457,18 @@ async def update_rom( path_cover_s, path_cover_l = await fs_resource_handler.get_cover( overwrite=True, entity=rom, - url_cover=data.get("url_cover", ""), + url_cover=str(data.get("url_cover", "")), ) cleaned_data.update( {"path_cover_s": path_cover_s, "path_cover_l": path_cover_l} ) db_rom_handler.update_rom(id, cleaned_data) + rom = db_rom_handler.get_rom(id) + if not rom: + raise RomNotFoundInDatabaseException(id) - return DetailedRomSchema.from_orm_with_request(db_rom_handler.get_rom(id), request) + return DetailedRomSchema.from_orm_with_request(rom, request) @protected_route(router.post, "/roms/delete", [Scope.ROMS_WRITE]) @@ -497,13 +498,15 @@ async def delete_roms( if not rom: raise RomNotFoundInDatabaseException(id) - log.info(f"Deleting {rom.file_name} from database") + log.info(f"Deleting {rom.fs_name} from database") db_rom_handler.delete_rom(id) # Update collections to remove the deleted rom collections = db_collection_handler.get_collections_by_rom_id(id) for collection in collections: - collection.roms = [rom_id for rom_id in collection.roms if rom_id != id] + collection.roms = set( + [rom_id for rom_id in collection.roms if rom_id != id] + ) db_collection_handler.update_collection( collection.id, {"roms": collection.roms} ) @@ -514,13 +517,13 @@ async def delete_roms( log.error(f"Couldn't find resources to delete for {rom.name}") if id in delete_from_fs: - log.info(f"Deleting {rom.file_name} from filesystem") + log.info(f"Deleting {rom.fs_name} from filesystem") try: - fs_rom_handler.remove_file( - file_name=rom.file_name, file_path=rom.file_path - ) + fs_rom_handler.remove_file(file_name=rom.fs_name, file_path=rom.fs_name) except FileNotFoundError as exc: - error = f"Rom file {rom.file_name} not found for platform {rom.platform_slug}" + error = ( + f"Rom file {rom.fs_name} not found for platform {rom.platform_slug}" + ) log.error(error) raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=error @@ -556,5 +559,6 @@ async def update_rom_user(request: Request, id: int) -> RomUserSchema: ] cleaned_data = {field: data[field] for field in fields_to_update if field in data} + rom_user = db_rom_handler.update_rom_user(db_rom_user.id, cleaned_data) - return db_rom_handler.update_rom_user(db_rom_user.id, cleaned_data) + return RomUserSchema.model_validate(rom_user) diff --git a/backend/endpoints/saves.py b/backend/endpoints/saves.py index 5cee897ec..448cb8aa2 100644 --- a/backend/endpoints/saves.py +++ b/backend/endpoints/saves.py @@ -29,18 +29,15 @@ def add_saves( current_user = request.user log.info(f"Uploading saves to {rom.name}") - if saves is None: - log.error("No saves were uploaded") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="No saves were uploaded", - ) - saves_path = fs_asset_handler.build_saves_file_path( user=request.user, platform_fs_slug=rom.platform.fs_slug, emulator=emulator ) for save in saves: + if not save.filename: + log.error("Save file has no filename") + continue + fs_asset_handler.write_file(file=save, path=saves_path) # Scan or update save @@ -79,7 +76,11 @@ def add_saves( return { "uploaded": len(saves), - "saves": [s for s in rom.saves if s.user_id == current_user.id], + "saves": [ + SaveSchema.model_validate(s) + for s in rom.saves + if s.user_id == current_user.id + ], } @@ -109,7 +110,7 @@ async def update_save(request: Request, id: int) -> SaveSchema: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=error) if "file" in data: - file: UploadFile = data["file"] + file: UploadFile = data["file"] # type: ignore fs_asset_handler.write_file(file=file, path=db_save.file_path) db_save_handler.update_save(db_save.id, {"file_size_bytes": file.size}) @@ -124,7 +125,7 @@ async def update_save(request: Request, id: int) -> SaveSchema: # Refetch the save to get updated fields db_save = db_save_handler.get_save(id) - return db_save + return SaveSchema.model_validate(db_save) @protected_route(router.post, "/saves/delete", [Scope.ASSETS_WRITE]) diff --git a/backend/endpoints/screenshots.py b/backend/endpoints/screenshots.py index 480a52407..10fc9b8d5 100644 --- a/backend/endpoints/screenshots.py +++ b/backend/endpoints/screenshots.py @@ -1,6 +1,7 @@ from decorators.auth import protected_route -from endpoints.responses.assets import UploadedScreenshotsResponse -from fastapi import File, HTTPException, Request, UploadFile, status +from endpoints.responses.assets import ScreenshotSchema, UploadedScreenshotsResponse +from exceptions.endpoint_exceptions import RomNotFoundInDatabaseException +from fastapi import File, Request, UploadFile from handler.auth.base_handler import Scope from handler.database import db_rom_handler, db_screenshot_handler from handler.filesystem import fs_asset_handler @@ -18,21 +19,21 @@ def add_screenshots( screenshots: list[UploadFile] = File(...), # noqa: B008 ) -> UploadedScreenshotsResponse: rom = db_rom_handler.get_rom(rom_id) + if not rom: + raise RomNotFoundInDatabaseException(rom_id) + current_user = request.user log.info(f"Uploading screenshots to {rom.name}") - if screenshots is None: - log.error("No screenshots were uploaded") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="No screenshots were uploaded", - ) - screenshots_path = fs_asset_handler.build_screenshots_file_path( user=request.user, platform_fs_slug=rom.platform_slug ) for screenshot in screenshots: + if not screenshot.filename: + log.warning("Skipping empty screenshot") + continue + fs_asset_handler.write_file(file=screenshot, path=screenshots_path) # Scan or update screenshot @@ -56,8 +57,15 @@ def add_screenshots( db_screenshot_handler.add_screenshot(scanned_screenshot) rom = db_rom_handler.get_rom(rom_id) + if not rom: + raise RomNotFoundInDatabaseException(rom_id) + return { "uploaded": len(screenshots), - "screenshots": [s for s in rom.screenshots if s.user_id == current_user.id], + "screenshots": [ + ScreenshotSchema.model_validate(s) + for s in rom.screenshots + if s.user_id == current_user.id + ], "merged_screenshots": rom.merged_screenshots, } diff --git a/backend/endpoints/search.py b/backend/endpoints/search.py index e07d757f1..ce8a103c1 100644 --- a/backend/endpoints/search.py +++ b/backend/endpoints/search.py @@ -5,8 +5,8 @@ from fastapi import HTTPException, Request, status from handler.auth.base_handler import Scope from handler.database import db_rom_handler from handler.metadata import meta_igdb_handler, meta_moby_handler, meta_sgdb_handler -from handler.metadata.igdb_handler import IGDB_API_ENABLED -from handler.metadata.moby_handler import MOBY_API_ENABLED +from handler.metadata.igdb_handler import IGDB_API_ENABLED, IGDBRom +from handler.metadata.moby_handler import MOBY_API_ENABLED, MobyGamesRom from handler.metadata.sgdb_handler import STEAMGRIDDB_API_ENABLED from handler.scan_handler import _get_main_platform_igdb_id from logger.logger import log @@ -43,11 +43,11 @@ async def search_rom( detail="No metadata providers enabled", ) - rom = db_rom_handler.get_rom(rom_id) + rom = db_rom_handler.get_rom(int(rom_id)) if not rom: return [] - search_term = search_term or rom.file_name_no_tags + search_term = search_term or rom.fs_name_no_tags if not search_term: return [] @@ -57,7 +57,11 @@ async def search_rom( matched_roms: list = [] log.info(f"Searching by {search_by.lower()}: {search_term}") - log.info(emoji.emojize(f":video_game: {rom.platform_slug}: {rom.file_name}")) + log.info(emoji.emojize(f":video_game: {rom.platform_slug}: {rom.fs_name}")) + + igdb_matched_roms: list[IGDBRom] = [] + moby_matched_roms: list[MobyGamesRom] = [] + if search_by.lower() == "id": try: igdb_matched_roms = await meta_igdb_handler.get_matched_roms_by_id( @@ -73,19 +77,20 @@ async def search_rom( detail=f"Tried searching by ID, but '{search_term}' is not a valid ID", ) from exc elif search_by.lower() == "name": + main_platform_igdb_id = await _get_main_platform_igdb_id(rom.platform) igdb_matched_roms = await meta_igdb_handler.get_matched_roms_by_name( - search_term, (await _get_main_platform_igdb_id(rom.platform)) + search_term, main_platform_igdb_id or rom.platform.igdb_id ) moby_matched_roms = await meta_moby_handler.get_matched_roms_by_name( search_term, rom.platform.moby_id ) merged_dict = { - item["name"]: {**item, "igdb_url_cover": item.pop("url_cover", "")} + item["name"]: {**item, "igdb_url_cover": item.pop("url_cover", "")} # type: ignore for item in igdb_matched_roms } for item in moby_matched_roms: - merged_dict[item["name"]] = { + merged_dict[item["name"]] = { # type: ignore **item, "moby_url_cover": item.pop("url_cover", ""), **merged_dict.get(item.get("name", ""), {}), diff --git a/backend/endpoints/sockets/scan.py b/backend/endpoints/sockets/scan.py index 4766a86f7..188f46f67 100644 --- a/backend/endpoints/sockets/scan.py +++ b/backend/endpoints/sockets/scan.py @@ -143,7 +143,9 @@ async def scan_platforms( try: platform_list = [ - db_platform_handler.get_platform(s).fs_slug for s in platform_ids + platform.fs_slug + for s in platform_ids + if (platform := db_platform_handler.get_platform(s)) is not None ] or fs_platforms if len(platform_list) == 0: @@ -271,14 +273,14 @@ async def _identify_platform( for fs_roms_batch in batched(fs_roms, 200): rom_by_filename_map = db_rom_handler.get_roms_by_filename( platform_id=platform.id, - file_names={fs_rom["file_name"] for fs_rom in fs_roms_batch}, + file_names={fs_rom["fs_name"] for fs_rom in fs_roms_batch}, ) for fs_rom in fs_roms_batch: scan_stats += await _identify_rom( platform=platform, fs_rom=fs_rom, - rom=rom_by_filename_map.get(fs_rom["file_name"]), + rom=rom_by_filename_map.get(fs_rom["fs_name"]), scan_type=scan_type, roms_ids=roms_ids, metadata_sources=metadata_sources, @@ -290,12 +292,12 @@ async def _identify_platform( # the folder structure is not correct or the drive is not mounted if len(fs_roms) > 0: purged_roms = db_rom_handler.purge_roms( - platform.id, [rom["file_name"] for rom in fs_roms] + platform.id, [rom["fs_name"] for rom in fs_roms] ) if len(purged_roms) > 0: log.info("Purging roms not found in the filesystem:") for r in purged_roms: - log.info(f" - {r.file_name}") + log.info(f" - {r.fs_name}") # Same protection for firmware if len(fs_firmware) > 0: @@ -346,7 +348,7 @@ def _set_rom_hashes(rom_id: int): return try: - rom_hashes = fs_rom_handler.get_rom_hashes(rom.file_name, rom.file_path) + rom_hashes = fs_rom_handler.get_rom_hashes(rom.fs_name, rom.fs_path) db_rom_handler.update_rom( rom_id, { @@ -358,7 +360,7 @@ def _set_rom_hashes(rom_id: int): except zlib.error as e: # Set empty hashes if calculating them fails for corrupted files log.error( - f"Hashes of {rom.file_name} couldn't be calculated: {hl(str(e), color=RED)}" + f"Hashes of {rom.fs_name} couldn't be calculated: {hl(str(e), color=RED)}" ) db_rom_handler.update_rom( rom_id, @@ -387,12 +389,12 @@ async def _identify_rom( if not _should_scan_rom(scan_type=scan_type, rom=rom, roms_ids=roms_ids): if rom and ( - rom.file_name != fs_rom["file_name"] + rom.fs_name != fs_rom["fs_name"] or rom.multi != fs_rom["multi"] or rom.files != fs_rom["files"] ): # Just to update the filesystem data - rom.file_name = fs_rom["file_name"] + rom.fs_name = fs_rom["fs_name"] rom.multi = fs_rom["multi"] rom.files = fs_rom["files"] db_rom_handler.add_rom(rom) diff --git a/backend/endpoints/states.py b/backend/endpoints/states.py index dc95311c6..70e4200b7 100644 --- a/backend/endpoints/states.py +++ b/backend/endpoints/states.py @@ -29,18 +29,15 @@ def add_states( current_user = request.user log.info(f"Uploading states to {rom.name}") - if states is None: - log.error("No states were uploaded") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="No states were uploaded", - ) - states_path = fs_asset_handler.build_states_file_path( user=request.user, platform_fs_slug=rom.platform.fs_slug, emulator=emulator ) for state in states: + if not state.filename: + log.warning("Skipping file with no filename") + continue + fs_asset_handler.write_file(file=state, path=states_path) # Scan or update state @@ -73,9 +70,16 @@ def add_states( ) rom = db_rom_handler.get_rom(rom_id) + if not rom: + raise RomNotFoundInDatabaseException(rom_id) + return { "uploaded": len(states), - "states": [s for s in rom.states if s.user_id == current_user.id], + "states": [ + StateSchema.model_validate(s) + for s in rom.states + if s.user_id == current_user.id + ], } @@ -105,7 +109,7 @@ async def update_state(request: Request, id: int) -> StateSchema: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=error) if "file" in data: - file: UploadFile = data["file"] + file: UploadFile = data["file"] # type: ignore fs_asset_handler.write_file(file=file, path=db_state.file_path) db_state_handler.update_state(db_state.id, {"file_size_bytes": file.size}) @@ -119,7 +123,7 @@ async def update_state(request: Request, id: int) -> StateSchema: ) db_state = db_state_handler.get_state(id) - return db_state + return StateSchema.model_validate(db_state) @protected_route(router.post, "/states/delete", [Scope.ASSETS_WRITE]) diff --git a/backend/endpoints/user.py b/backend/endpoints/user.py index 98940147b..1488159e7 100644 --- a/backend/endpoints/user.py +++ b/backend/endpoints/user.py @@ -76,7 +76,7 @@ def add_user( role=Role[role.upper()], ) - return db_user_handler.add_user(user) + return UserSchema.model_validate(db_user_handler.add_user(user)) @protected_route(router.get, "/users", [Scope.USERS_READ]) @@ -90,7 +90,7 @@ def get_users(request: Request) -> list[UserSchema]: list[UserSchema]: All users stored in the RomM's database """ - return [u for u in db_user_handler.get_users()] + return [UserSchema.model_validate(u) for u in db_user_handler.get_users()] @protected_route(router.get, "/users/me", [Scope.ME_READ]) @@ -122,7 +122,7 @@ def get_user(request: Request, id: int) -> UserSchema: if not user: raise HTTPException(status_code=404, detail="User not found") - return user + return UserSchema.model_validate(user) @protected_route(router.put, "/users/{id}", [Scope.ME_WRITE]) @@ -215,7 +215,13 @@ async def update_user( if request.user.id == id and creds_updated: request.session.clear() - return db_user_handler.get_user(id) + db_user = db_user_handler.get_user(id) + if not db_user: + msg = f"Username with id {id} not found" + log.error(msg) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=msg) + + return UserSchema.model_validate(db_user) @protected_route(router.delete, "/users/{id}", [Scope.USERS_WRITE]) diff --git a/backend/handler/database/collections_handler.py b/backend/handler/database/collections_handler.py index 4faadc963..9e60b7549 100644 --- a/backend/handler/database/collections_handler.py +++ b/backend/handler/database/collections_handler.py @@ -1,6 +1,8 @@ +from typing import Sequence + from decorators.database import begin_session from models.collection import Collection -from sqlalchemy import Select, delete, select, update +from sqlalchemy import delete, select, update from sqlalchemy.orm import Session from .base_handler import DBBaseHandler @@ -10,10 +12,17 @@ class DBCollectionsHandler(DBBaseHandler): @begin_session def add_collection( self, collection: Collection, session: Session = None - ) -> Collection | None: + ) -> Collection: collection = session.merge(collection) session.flush() - return session.scalar(select(Collection).filter_by(id=collection.id).limit(1)) + + new_collection = session.scalar( + select(Collection).filter_by(id=collection.id).limit(1) + ) + if not new_collection: + raise ValueError("Could not find newly created collection") + + return new_collection @begin_session def get_collection(self, id: int, session: Session = None) -> Collection | None: @@ -28,9 +37,9 @@ class DBCollectionsHandler(DBBaseHandler): ) @begin_session - def get_collections(self, session: Session = None) -> Select[tuple[Collection]]: + def get_collections(self, session: Session = None) -> Sequence[Collection]: return ( - session.scalars(select(Collection).order_by(Collection.name.asc())) # type: ignore[attr-defined] + session.scalars(select(Collection).order_by(Collection.name.asc())) .unique() .all() ) @@ -38,7 +47,7 @@ class DBCollectionsHandler(DBBaseHandler): @begin_session def get_collections_by_rom_id( self, rom_id: int, session: Session = None - ) -> list[Collection]: + ) -> Sequence[Collection]: return session.scalars( select(Collection).filter(Collection.roms.contains(rom_id)) ).all() @@ -56,8 +65,8 @@ class DBCollectionsHandler(DBBaseHandler): return session.query(Collection).filter_by(id=id).one() @begin_session - def delete_collection(self, id: int, session: Session = None) -> int: - return session.execute( + def delete_collection(self, id: int, session: Session = None) -> None: + session.execute( delete(Collection) .where(Collection.id == id) .execution_options(synchronize_session="evaluate") diff --git a/backend/handler/database/firmware_handler.py b/backend/handler/database/firmware_handler.py index 88ccebadd..f99ebc36f 100644 --- a/backend/handler/database/firmware_handler.py +++ b/backend/handler/database/firmware_handler.py @@ -1,3 +1,5 @@ +from typing import Sequence + from decorators.database import begin_session from models.firmware import Firmware from sqlalchemy import and_, delete, select, update @@ -26,7 +28,7 @@ class DBFirmwareHandler(DBBaseHandler): *, platform_id: int | None = None, session: Session = None, - ) -> list[Firmware]: + ) -> Sequence[Firmware]: return session.scalars( select(Firmware) .filter_by(platform_id=platform_id) @@ -45,7 +47,7 @@ class DBFirmwareHandler(DBBaseHandler): @begin_session def update_firmware(self, id: int, data: dict, session: Session = None) -> Firmware: - return session.execute( + return session.scalar( update(Firmware) .where(Firmware.id == id) .values(**data) @@ -54,7 +56,7 @@ class DBFirmwareHandler(DBBaseHandler): @begin_session def delete_firmware(self, id: int, session: Session = None) -> None: - return session.execute( + session.execute( delete(Firmware) .where(Firmware.id == id) .execution_options(synchronize_session="evaluate") @@ -63,7 +65,7 @@ class DBFirmwareHandler(DBBaseHandler): @begin_session def purge_firmware( self, platform_id: int, fs_firmwares: list[str], session: Session = None - ) -> None: + ) -> Sequence[Firmware]: purged_firmware = ( session.scalars( select(Firmware) diff --git a/backend/handler/database/platforms_handler.py b/backend/handler/database/platforms_handler.py index 56ee6261b..e17c5de1f 100644 --- a/backend/handler/database/platforms_handler.py +++ b/backend/handler/database/platforms_handler.py @@ -1,7 +1,9 @@ +from typing import Sequence + from decorators.database import begin_session from models.platform import Platform from models.rom import Rom -from sqlalchemy import Select, delete, or_, select +from sqlalchemy import delete, or_, select from sqlalchemy.orm import Session from .base_handler import DBBaseHandler @@ -21,7 +23,7 @@ class DBPlatformsHandler(DBBaseHandler): select(Platform).filter_by(id=platform.id).limit(1) ) if not new_platform: - raise ValueError("Could not find newlyewly created platform") + raise ValueError("Could not find newly created platform") return new_platform @@ -30,9 +32,9 @@ class DBPlatformsHandler(DBBaseHandler): return session.scalar(select(Platform).filter_by(id=id).limit(1)) @begin_session - def get_platforms(self, *, session: Session = None) -> Select[tuple[Platform]]: + def get_platforms(self, *, session: Session = None) -> Sequence[Platform]: return ( - session.scalars(select(Platform).order_by(Platform.name.asc())) # type: ignore[attr-defined] + session.scalars(select(Platform).order_by(Platform.name.asc())) .unique() .all() ) @@ -61,7 +63,7 @@ class DBPlatformsHandler(DBBaseHandler): @begin_session def purge_platforms( self, fs_platforms: list[str], session: Session = None - ) -> Select[tuple[Platform]]: + ) -> Sequence[Platform]: purged_platforms = ( session.scalars( select(Platform) diff --git a/backend/handler/database/saves_handler.py b/backend/handler/database/saves_handler.py index ac65a177d..698c0279d 100644 --- a/backend/handler/database/saves_handler.py +++ b/backend/handler/database/saves_handler.py @@ -1,3 +1,5 @@ +from typing import Sequence + from decorators.database import begin_session from models.assets import Save from sqlalchemy import and_, delete, select, update @@ -12,7 +14,7 @@ class DBSavesHandler(DBBaseHandler): return session.merge(save) @begin_session - def get_save(self, id: int, session: Session = None) -> Save: + def get_save(self, id: int, session: Session = None) -> Save | None: return session.get(Save, id) @begin_session @@ -27,7 +29,7 @@ class DBSavesHandler(DBBaseHandler): @begin_session def update_save(self, id: int, data: dict, session: Session = None) -> Save: - return session.execute( + return session.scalar( update(Save) .where(Save.id == id) .values(**data) @@ -36,7 +38,7 @@ class DBSavesHandler(DBBaseHandler): @begin_session def delete_save(self, id: int, session: Session = None) -> None: - return session.execute( + session.execute( delete(Save) .where(Save.id == id) .execution_options(synchronize_session="evaluate") @@ -45,8 +47,18 @@ class DBSavesHandler(DBBaseHandler): @begin_session def purge_saves( self, rom_id: int, user_id: int, saves: list[str], session: Session = None - ) -> None: - return session.execute( + ) -> Sequence[Save]: + purged_saves = session.scalars( + select(Save).filter( + and_( + Save.rom_id == rom_id, + Save.user_id == user_id, + Save.file_name.not_in(saves), + ) + ) + ).all() + + session.execute( delete(Save) .where( and_( @@ -57,3 +69,5 @@ class DBSavesHandler(DBBaseHandler): ) .execution_options(synchronize_session="evaluate") ) + + return purged_saves diff --git a/backend/handler/database/screenshots_handler.py b/backend/handler/database/screenshots_handler.py index d7b50c637..94d30f43a 100644 --- a/backend/handler/database/screenshots_handler.py +++ b/backend/handler/database/screenshots_handler.py @@ -1,6 +1,8 @@ +from typing import Sequence + from decorators.database import begin_session from models.assets import Screenshot -from sqlalchemy import delete, select, update +from sqlalchemy import and_, delete, select, update from sqlalchemy.orm import Session from .base_handler import DBBaseHandler @@ -14,7 +16,7 @@ class DBScreenshotsHandler(DBBaseHandler): return session.merge(screenshot) @begin_session - def get_screenshot(self, id, session: Session = None) -> Screenshot: + def get_screenshot(self, id, session: Session = None) -> Screenshot | None: return session.get(Screenshot, id) @begin_session @@ -31,7 +33,7 @@ class DBScreenshotsHandler(DBBaseHandler): def update_screenshot( self, id: int, data: dict, session: Session = None ) -> Screenshot: - return session.execute( + return session.scalar( update(Screenshot) .where(Screenshot.id == id) .values(**data) @@ -40,7 +42,7 @@ class DBScreenshotsHandler(DBBaseHandler): @begin_session def delete_screenshot(self, id: int, session: Session = None) -> None: - return session.execute( + session.execute( delete(Screenshot) .where(Screenshot.id == id) .execution_options(synchronize_session="evaluate") @@ -49,13 +51,27 @@ class DBScreenshotsHandler(DBBaseHandler): @begin_session def purge_screenshots( self, rom_id: int, user_id: int, screenshots: list[str], session: Session = None - ) -> None: - return session.execute( + ) -> Sequence[Screenshot]: + purged_screenshots = session.scalars( + select(Screenshot).filter( + and_( + Screenshot.rom_id == rom_id, + Screenshot.user_id == user_id, + Screenshot.file_name.not_in(screenshots), + ) + ) + ).all() + + session.execute( delete(Screenshot) .where( - Screenshot.rom_id == rom_id, - Screenshot.user_id == user_id, - Screenshot.file_name.not_in(screenshots), + and_( + Screenshot.rom_id == rom_id, + Screenshot.user_id == user_id, + Screenshot.file_name.not_in(screenshots), + ) ) .execution_options(synchronize_session="evaluate") ) + + return purged_screenshots diff --git a/backend/handler/database/states_handler.py b/backend/handler/database/states_handler.py index c70daca14..8b3315e8c 100644 --- a/backend/handler/database/states_handler.py +++ b/backend/handler/database/states_handler.py @@ -1,3 +1,5 @@ +from typing import Sequence + from decorators.database import begin_session from models.assets import State from sqlalchemy import and_, delete, select, update @@ -12,7 +14,7 @@ class DBStatesHandler(DBBaseHandler): return session.merge(state) @begin_session - def get_state(self, id: int, session: Session = None) -> State: + def get_state(self, id: int, session: Session = None) -> State | None: return session.get(State, id) @begin_session @@ -27,7 +29,7 @@ class DBStatesHandler(DBBaseHandler): @begin_session def update_state(self, id: int, data: dict, session: Session = None) -> State: - return session.execute( + return session.scalar( update(State) .where(State.id == id) .values(**data) @@ -36,7 +38,7 @@ class DBStatesHandler(DBBaseHandler): @begin_session def delete_state(self, id: int, session: Session = None) -> None: - return session.execute( + session.execute( delete(State) .where(State.id == id) .execution_options(synchronize_session="evaluate") @@ -45,8 +47,18 @@ class DBStatesHandler(DBBaseHandler): @begin_session def purge_states( self, rom_id: int, user_id: int, states: list[str], session: Session = None - ) -> None: - return session.execute( + ) -> Sequence[State]: + purged_states = session.scalars( + select(State).filter( + and_( + State.rom_id == rom_id, + State.user_id == user_id, + State.file_name.not_in(states), + ) + ) + ).all() + + session.execute( delete(State) .where( and_( @@ -57,3 +69,5 @@ class DBStatesHandler(DBBaseHandler): ) .execution_options(synchronize_session="evaluate") ) + + return purged_states diff --git a/backend/handler/database/stats_handler.py b/backend/handler/database/stats_handler.py index c526456f1..8a8af03db 100644 --- a/backend/handler/database/stats_handler.py +++ b/backend/handler/database/stats_handler.py @@ -11,25 +11,28 @@ class DBStatsHandler(DBBaseHandler): @begin_session def get_platforms_count(self, session: Session = None) -> int: """Get the number of platforms with any roms.""" - return session.scalar( - select(func.count(distinct(Rom.platform_id))).select_from(Rom) + return ( + session.scalar( + select(func.count(distinct(Rom.platform_id))).select_from(Rom) + ) + or 0 ) @begin_session def get_roms_count(self, session: Session = None) -> int: - return session.scalar(select(func.count()).select_from(Rom)) + return session.scalar(select(func.count()).select_from(Rom)) or 0 @begin_session def get_saves_count(self, session: Session = None) -> int: - return session.scalar(select(func.count()).select_from(Save)) + return session.scalar(select(func.count()).select_from(Save)) or 0 @begin_session def get_states_count(self, session: Session = None) -> int: - return session.scalar(select(func.count()).select_from(State)) + return session.scalar(select(func.count()).select_from(State)) or 0 @begin_session def get_screenshots_count(self, session: Session = None) -> int: - return session.scalar(select(func.count()).select_from(Screenshot)) + return session.scalar(select(func.count()).select_from(Screenshot)) or 0 @begin_session def get_total_filesize(self, session: Session = None) -> int: diff --git a/backend/handler/database/users_handler.py b/backend/handler/database/users_handler.py index 1a5da23c3..c17641f06 100644 --- a/backend/handler/database/users_handler.py +++ b/backend/handler/database/users_handler.py @@ -1,3 +1,5 @@ +from typing import Sequence + from decorators.database import begin_session from models.user import Role, User from sqlalchemy import delete, select, update @@ -27,7 +29,7 @@ class DBUsersHandler(DBBaseHandler): @begin_session def update_user(self, id: int, data: dict, session: Session = None) -> User: - return session.execute( + return session.scalar( update(User) .where(User.id == id) .values(**data) @@ -35,7 +37,7 @@ class DBUsersHandler(DBBaseHandler): ) @begin_session - def get_users(self, session: Session = None) -> list[User]: + def get_users(self, session: Session = None) -> Sequence[User]: return session.scalars(select(User)).all() @begin_session @@ -47,5 +49,5 @@ class DBUsersHandler(DBBaseHandler): ) @begin_session - def get_admin_users(self, session: Session = None) -> list[User]: + def get_admin_users(self, session: Session = None) -> Sequence[User]: return session.scalars(select(User).filter_by(role=Role.ADMIN)).all() diff --git a/backend/handler/filesystem/resources_handler.py b/backend/handler/filesystem/resources_handler.py index 033327ef9..df00ecab7 100644 --- a/backend/handler/filesystem/resources_handler.py +++ b/backend/handler/filesystem/resources_handler.py @@ -95,7 +95,7 @@ class FSResourcesHandler(FSHandler): return "" async def get_cover( - self, entity: Rom | Collection | None, overwrite: bool, url_cover: str = "" + self, entity: Rom | Collection | None, overwrite: bool, url_cover: str | None ) -> tuple[str, str]: if not entity: return "", "" @@ -192,9 +192,9 @@ class FSResourcesHandler(FSHandler): return f"{rom.fs_resources_path}/screenshots/{idx}.jpg" async def get_rom_screenshots( - self, rom: Rom | None, url_screenshots: list + self, rom: Rom | None, url_screenshots: list | None ) -> list[str]: - if not rom: + if not rom or not url_screenshots: return [] path_screenshots: list[str] = [] diff --git a/backend/handler/filesystem/roms_handler.py b/backend/handler/filesystem/roms_handler.py index 0e425049f..d9f2dcf71 100644 --- a/backend/handler/filesystem/roms_handler.py +++ b/backend/handler/filesystem/roms_handler.py @@ -8,7 +8,7 @@ import tarfile import zipfile from collections.abc import Callable, Iterator from pathlib import Path -from typing import Any, Final, TypedDict +from typing import Any, Final, Literal, TypedDict import magic import py7zr @@ -59,7 +59,7 @@ FILE_READ_CHUNK_SIZE = 1024 * 8 class FSRom(TypedDict): multi: bool - file_name: str + fs_name: str files: list[RomFile] @@ -90,7 +90,9 @@ def read_zip_file(file_path: Path) -> Iterator[bytes]: yield chunk -def read_tar_file(file_path: Path, mode: str = "r") -> Iterator[bytes]: +def read_tar_file( + file_path: Path, mode: Literal["r", "r:*", "r:", "r:gz", "r:bz2", "r:xz"] = "r" +) -> Iterator[bytes]: try: with tarfile.open(file_path, mode) as f: for member in f.getmembers(): @@ -339,10 +341,10 @@ class FSRomsHandler(FSHandler): raise RomsNotFoundException(platform_fs_slug) from exc fs_roms: list[dict] = [ - {"multi": False, "file_name": rom} + {"multi": False, "fs_name": rom} for rom in self._exclude_files(fs_single_roms, "single") ] + [ - {"multi": True, "file_name": rom} + {"multi": True, "fs_name": rom} for rom in self._exclude_multi_roms(fs_multi_roms) ] @@ -350,12 +352,12 @@ class FSRomsHandler(FSHandler): [ FSRom( multi=rom["multi"], - file_name=rom["file_name"], - files=self.get_rom_files(rom["file_name"], roms_file_path), + fs_name=rom["fs_name"], + files=self.get_rom_files(rom["fs_name"], roms_file_path), ) for rom in fs_roms ], - key=lambda rom: rom["file_name"], + key=lambda rom: rom["fs_name"], ) def file_exists(self, path: str, file_name: str) -> bool: diff --git a/backend/handler/metadata/igdb_handler.py b/backend/handler/metadata/igdb_handler.py index d387540d9..b5a1a17ce 100644 --- a/backend/handler/metadata/igdb_handler.py +++ b/backend/handler/metadata/igdb_handler.py @@ -567,7 +567,7 @@ class IGDBBaseHandler(MetadataHandler): @check_twitch_token async def get_matched_roms_by_name( - self, search_term: str, platform_igdb_id: int + self, search_term: str, platform_igdb_id: int | None ) -> list[IGDBRom]: if not IGDB_API_ENABLED: return [] diff --git a/backend/handler/metadata/moby_handler.py b/backend/handler/metadata/moby_handler.py index 53cba5d3f..c2c57025e 100644 --- a/backend/handler/metadata/moby_handler.py +++ b/backend/handler/metadata/moby_handler.py @@ -310,7 +310,7 @@ class MobyGamesHandler(MetadataHandler): return [rom] if rom["moby_id"] else [] async def get_matched_roms_by_name( - self, search_term: str, platform_moby_id: int + self, search_term: str, platform_moby_id: int | None ) -> list[MobyGamesRom]: if not MOBY_API_ENABLED: return [] diff --git a/backend/models/rom.py b/backend/models/rom.py index f885b5c2c..027aad78f 100644 --- a/backend/models/rom.py +++ b/backend/models/rom.py @@ -145,6 +145,10 @@ class Rom(BaseModel): def multi(self) -> bool: return len(self.files) > 1 + @cached_property + def file_size_bytes(self) -> int: + return sum(f.file_size_bytes for f in self.files) + def get_collections(self) -> list[Collection]: from handler.database import db_rom_handler