complete updating the endpoints and models

This commit is contained in:
Georges-Antoine Assi 2024-12-20 22:41:56 -05:00
parent 0850c0cbcf
commit 3fcce6606c
No known key found for this signature in database
GPG Key ID: 30F6E9865ABBA06E
26 changed files with 332 additions and 201 deletions

View File

@ -105,15 +105,14 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp
status_code=status.HTTP_400_BAD_REQUEST, detail="Missing refresh token"
)
potential_user = await oauth_handler.get_current_active_user_from_bearer_token(
user, claims = await oauth_handler.get_current_active_user_from_bearer_token(
token
)
if not potential_user:
if not user or not claims:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
)
user, claims = potential_user
if claims.get("type") != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"

View File

@ -57,7 +57,7 @@ async def add_collection(
_added_collection = db_collection_handler.add_collection(Collection(**cleaned_data))
if artwork is not None:
if artwork is not None and artwork.filename is not None:
file_ext = artwork.filename.split(".")[-1]
(
path_cover_l,
@ -82,8 +82,9 @@ async def add_collection(
_added_collection.path_cover_s = path_cover_s
_added_collection.path_cover_l = path_cover_l
# Update the collection with the cover path and update database
return db_collection_handler.update_collection(
created_collection = db_collection_handler.update_collection(
_added_collection.id,
{
c: getattr(_added_collection, c)
@ -91,6 +92,8 @@ async def add_collection(
},
)
return CollectionSchema.model_validate(created_collection)
@protected_route(router.get, "/collections", [Scope.COLLECTIONS_READ])
def get_collections(request: Request) -> list[CollectionSchema]:
@ -105,7 +108,7 @@ def get_collections(request: Request) -> list[CollectionSchema]:
"""
collections = db_collection_handler.get_collections()
return CollectionSchema.for_user(request.user.id, collections)
return CollectionSchema.for_user(request.user.id, [c for c in collections])
@protected_route(router.get, "/collections/{id}", [Scope.COLLECTIONS_READ])
@ -125,7 +128,7 @@ def get_collection(request: Request, id: int) -> CollectionSchema:
if not collection:
raise CollectionNotFoundInDatabaseException(id)
return collection
return CollectionSchema.model_validate(collection)
@protected_route(router.put, "/collections/{id}", [Scope.COLLECTIONS_WRITE])
@ -148,6 +151,8 @@ async def update_collection(
data = await request.form()
collection = db_collection_handler.get_collection(id)
if not collection:
raise CollectionNotFoundInDatabaseException(id)
if collection.user_id != request.user.id:
raise CollectionPermissionError(id)
@ -156,7 +161,7 @@ async def update_collection(
raise CollectionNotFoundInDatabaseException(id)
try:
roms = json.loads(data["roms"])
roms = json.loads(data["roms"]) # type: ignore
except json.JSONDecodeError as e:
raise ValueError("Invalid list for roms field in update collection") from e
except KeyError:
@ -174,7 +179,7 @@ async def update_collection(
cleaned_data.update(fs_resource_handler.remove_cover(collection))
cleaned_data.update({"url_cover": ""})
else:
if artwork is not None:
if artwork is not None and artwork.filename is not None:
file_ext = artwork.filename.split(".")[-1]
(
path_cover_l,
@ -205,13 +210,14 @@ async def update_collection(
path_cover_s, path_cover_l = await fs_resource_handler.get_cover(
overwrite=True,
entity=collection,
url_cover=data.get("url_cover", ""),
url_cover=data.get("url_cover", ""), # type: ignore
)
cleaned_data.update(
{"path_cover_s": path_cover_s, "path_cover_l": path_cover_l}
)
return db_collection_handler.update_collection(id, cleaned_data)
updated_collection = db_collection_handler.update_collection(id, cleaned_data)
return CollectionSchema.model_validate(updated_collection)
@protected_route(router.delete, "/collections/{id}", [Scope.COLLECTIONS_WRITE])

View File

@ -57,7 +57,7 @@ def platforms_webrcade_feed(request: Request) -> WebrcadeFeedSchema:
request.url_for(
"get_rom_content",
id=rom.id,
file_name=rom.file_name,
file_name=rom.fs_name,
)
),
),
@ -135,10 +135,10 @@ async def tinfoil_index_feed(
async def extract_titledb(roms: list[Rom]) -> dict[str, TinfoilFeedTitleDBSchema]:
titledb = {}
for rom in roms:
match = SWITCH_TITLEDB_REGEX.search(rom.file_name)
match = SWITCH_TITLEDB_REGEX.search(rom.fs_name)
if match:
_search_term, index_entry = (
await meta_igdb_handler._switch_titledb_format(match, rom.file_name)
await meta_igdb_handler._switch_titledb_format(match, rom.fs_name)
)
if index_entry:
titledb[str(index_entry["nsuId"])] = TinfoilFeedTitleDBSchema(
@ -160,9 +160,7 @@ async def tinfoil_index_feed(
files=[
TinfoilFeedFileSchema(
url=str(
request.url_for(
"get_rom_content", id=rom.id, file_name=rom.file_name
)
request.url_for("get_rom_content", id=rom.id, file_name=rom.fs_name)
),
size=rom.file_size_bytes,
)

View File

@ -28,25 +28,28 @@ def add_firmware(
files (list[UploadFile], optional): List of files to upload
Raises:
HTTPException: No files were uploaded
HTTPException
Returns:
AddFirmwareResponse: Standard message response
"""
db_platform = db_platform_handler.get_platform(platform_id)
if not db_platform:
error = f"Platform with ID {platform_id} not found"
log.error(error)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=error)
log.info(f"Uploading firmware to {db_platform.fs_slug}")
if files is None:
log.error("No files were uploaded")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="No files were uploaded",
)
uploaded_firmware = []
firmware_path = fs_firmware_handler.build_upload_file_path(db_platform.fs_slug)
for file in files:
if not file.filename:
log.warning("Empty filename, skipping")
continue
fs_firmware_handler.write_file(file=file, path=firmware_path)
db_firmware = db_firmware_handler.get_firmware_by_filename(
@ -70,10 +73,14 @@ def add_firmware(
uploaded_firmware.append(scanned_firmware)
db_platform = db_platform_handler.get_platform(platform_id)
if not db_platform:
error = f"Platform with ID {platform_id} not found"
log.error(error)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=error)
return {
"uploaded": len(files),
"firmware": db_platform.firmware,
"firmware": [FirmwareSchema.model_validate(f) for f in db_platform.firmware],
}
@ -90,7 +97,10 @@ def get_platform_firmware(
Returns:
list[FirmwareSchema]: Firmware stored in the database
"""
return db_firmware_handler.list_firmware(platform_id=platform_id)
return [
FirmwareSchema.model_validate(f)
for f in db_firmware_handler.list_firmware(platform_id=platform_id)
]
@protected_route(
@ -108,7 +118,13 @@ def get_firmware(request: Request, id: int) -> FirmwareSchema:
Returns:
FirmwareSchema: Firmware stored in the database
"""
return FirmwareSchema(**db_firmware_handler.get_firmware(id))
firmware = db_firmware_handler.get_firmware(id)
if not firmware:
error = f"Firmware with ID {id} not found"
log.error(error)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=error)
return FirmwareSchema.model_validate(firmware)
@protected_route(
@ -129,6 +145,11 @@ def head_firmware_content(request: Request, id: int, file_name: str):
"""
firmware = db_firmware_handler.get_firmware(id)
if not firmware:
error = f"Firmware with ID {id} not found"
log.error(error)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=error)
firmware_path = f"{LIBRARY_BASE_PATH}/{firmware.full_path}"
return FileResponse(
@ -162,6 +183,11 @@ def get_firmware_content(
"""
firmware = db_firmware_handler.get_firmware(id)
if not firmware:
error = f"Firmware with ID {id} not found"
log.error(error)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=error)
firmware_path = f"{LIBRARY_BASE_PATH}/{firmware.full_path}"
return FileResponse(path=firmware_path, filename=firmware.file_name)

View File

@ -12,7 +12,6 @@ from handler.filesystem import fs_platform_handler
from handler.metadata.igdb_handler import IGDB_PLATFORM_LIST
from handler.scan_handler import scan_platform
from logger.logger import log
from models.platform import Platform
from utils.router import APIRouter
router = APIRouter()
@ -36,7 +35,9 @@ async def add_platforms(request: Request) -> PlatformSchema:
except PlatformAlreadyExistsException:
log.info(f"Detected platform: {fs_slug}")
scanned_platform = await scan_platform(fs_slug, [fs_slug])
return db_platform_handler.add_platform(scanned_platform)
return PlatformSchema.model_validate(
db_platform_handler.add_platform(scanned_platform)
)
@protected_route(router.get, "/platforms", [Scope.PLATFORMS_READ])
@ -51,7 +52,9 @@ def get_platforms(request: Request) -> list[PlatformSchema]:
list[PlatformSchema]: List of platforms
"""
return db_platform_handler.get_platforms()
return [
PlatformSchema.model_validate(p) for p in db_platform_handler.get_platforms()
]
@protected_route(router.get, "/platforms/supported", [Scope.PLATFORMS_READ])
@ -66,7 +69,7 @@ def get_supported_platforms(request: Request) -> list[PlatformSchema]:
"""
supported_platforms = []
db_platforms: list[Platform] = db_platform_handler.get_platforms()
db_platforms = db_platform_handler.get_platforms()
db_platforms_map = {p.name: p.id for p in db_platforms}
for platform in IGDB_PLATFORM_LIST:
@ -108,7 +111,7 @@ def get_platform(request: Request, id: int) -> PlatformSchema:
if not platform:
raise PlatformNotFoundInDatabaseException(id)
return platform
return PlatformSchema.model_validate(platform)
@protected_route(router.put, "/platforms/{id}", [Scope.PLATFORMS_WRITE])

View File

@ -16,12 +16,12 @@ SORT_COMPARE_REGEX = re.compile(r"^([Tt]he|[Aa]|[Aa]nd)\s")
RomIGDBMetadata = TypedDict( # type: ignore[misc]
"RomIGDBMetadata",
{k: NotRequired[v] for k, v in get_type_hints(IGDBMetadata).items()},
dict((k, NotRequired[v]) for k, v in get_type_hints(IGDBMetadata).items()),
total=False,
)
RomMobyMetadata = TypedDict( # type: ignore[misc]
"RomMobyMetadata",
{k: NotRequired[v] for k, v in get_type_hints(MobyMetadata).items()},
dict((k, NotRequired[v]) for k, v in get_type_hints(MobyMetadata).items()),
total=False,
)
@ -179,16 +179,15 @@ class SimpleRomSchema(RomSchema):
@classmethod
def from_orm_with_request(cls, db_rom: Rom, request: Request) -> SimpleRomSchema:
user_id = request.user.id
db_rom.rom_user = RomUserSchema.for_user(user_id, db_rom)
return cls.model_validate(db_rom)
rom = cls.model_validate(db_rom)
rom.rom_user = RomUserSchema.for_user(user_id, db_rom)
return rom
@classmethod
def from_orm_with_factory(cls, db_rom: Rom) -> SimpleRomSchema:
db_rom.rom_user = rom_user_schema_factory()
return cls.model_validate(db_rom)
rom = cls.model_validate(db_rom)
rom.rom_user = rom_user_schema_factory()
return rom
class DetailedRomSchema(RomSchema):
@ -205,24 +204,26 @@ class DetailedRomSchema(RomSchema):
def from_orm_with_request(cls, db_rom: Rom, request: Request) -> DetailedRomSchema:
user_id = request.user.id
db_rom.rom_user = RomUserSchema.for_user(user_id, db_rom)
db_rom.user_notes = RomUserSchema.notes_for_user(user_id, db_rom)
db_rom.user_saves = [
rom = cls.model_validate(db_rom)
rom.rom_user = RomUserSchema.for_user(user_id, db_rom)
rom.user_notes = RomUserSchema.notes_for_user(user_id, db_rom)
rom.user_saves = [
SaveSchema.model_validate(s) for s in db_rom.saves if s.user_id == user_id
]
db_rom.user_states = [
rom.user_states = [
StateSchema.model_validate(s) for s in db_rom.states if s.user_id == user_id
]
db_rom.user_screenshots = [
rom.user_screenshots = [
ScreenshotSchema.model_validate(s)
for s in db_rom.screenshots
if s.user_id == user_id
]
db_rom.user_collections = CollectionSchema.for_user(
rom.user_collections = CollectionSchema.for_user(
user_id, db_rom.get_collections()
)
return cls.model_validate(db_rom)
return rom
class UserNotesSchema(TypedDict):

View File

@ -187,16 +187,16 @@ async def head_rom_content(
raise RomNotFoundInDatabaseException(id)
rom_path = f"{LIBRARY_BASE_PATH}/{rom.full_path}"
files_to_check = files or [r["filename"] for r in rom.files]
files_to_check = files or [r.file_name for r in rom.files]
if not rom.multi:
# Serve the file directly in development mode for emulatorjs
if DEV_MODE:
return FileResponse(
path=rom_path,
filename=rom.file_name,
filename=rom.fs_name,
headers={
"Content-Disposition": f'attachment; filename="{quote(rom.file_name)}"',
"Content-Disposition": f'attachment; filename="{quote(rom.fs_name)}"',
"Content-Type": "application/octet-stream",
"Content-Length": str(rom.file_size_bytes),
},
@ -204,7 +204,7 @@ async def head_rom_content(
return FileRedirectResponse(
download_path=Path(f"/library/{rom.full_path}"),
filename=rom.file_name,
filename=rom.fs_name,
)
if len(files_to_check) == 1:
@ -252,14 +252,14 @@ async def get_rom_content(
raise RomNotFoundInDatabaseException(id)
rom_path = f"{LIBRARY_BASE_PATH}/{rom.full_path}"
files_to_download = sorted(files or [r["filename"] for r in rom.files])
files_to_download = sorted(files or [r.file_name for r in rom.files])
log.info(f"User {current_username} is downloading {rom.file_name}")
log.info(f"User {current_username} is downloading {rom.fs_name}")
if not rom.multi:
return FileRedirectResponse(
download_path=Path(f"/library/{rom.full_path}"),
filename=rom.file_name,
filename=rom.fs_name,
)
if len(files_to_download) == 1:
@ -332,7 +332,7 @@ async def update_rom(
"igdb_id": None,
"sgdb_id": None,
"moby_id": None,
"name": rom.file_name,
"name": rom.fs_name,
"summary": "",
"url_screenshots": [],
"path_screenshots": [],
@ -346,20 +346,20 @@ async def update_rom(
},
)
return DetailedRomSchema.from_orm_with_request(
db_rom_handler.get_rom(id), request
)
rom = db_rom_handler.get_rom(id)
if not rom:
raise RomNotFoundInDatabaseException(id)
cleaned_data = {
"igdb_id": data.get("igdb_id", None),
"moby_id": data.get("moby_id", None),
return DetailedRomSchema.from_orm_with_request(rom, request)
cleaned_data: dict = {
"igdb_id": str(data.get("igdb_id", "")),
"moby_id": str(data.get("moby_id", "")),
}
if (
cleaned_data.get("moby_id", "")
and int(cleaned_data.get("moby_id", "")) != rom.moby_id
):
moby_rom = await meta_moby_handler.get_rom_by_id(cleaned_data["moby_id"])
moby_id: str = cleaned_data["moby_id"]
if moby_id and int(moby_id) != rom.moby_id:
moby_rom = await meta_moby_handler.get_rom_by_id(int(moby_id))
cleaned_data.update(moby_rom)
path_screenshots = await fs_resource_handler.get_rom_screenshots(
rom=rom,
@ -367,11 +367,9 @@ async def update_rom(
)
cleaned_data.update({"path_screenshots": path_screenshots})
if (
cleaned_data.get("igdb_id", "")
and int(cleaned_data.get("igdb_id", "")) != rom.igdb_id
):
igdb_rom = await meta_igdb_handler.get_rom_by_id(cleaned_data["igdb_id"])
igdb_id: str = cleaned_data["igdb_id"]
if igdb_id and int(igdb_id) != rom.igdb_id:
igdb_rom = await meta_igdb_handler.get_rom_by_id(int(igdb_id))
cleaned_data.update(igdb_rom)
path_screenshots = await fs_resource_handler.get_rom_screenshots(
rom=rom,
@ -386,26 +384,26 @@ async def update_rom(
}
)
new_file_name = data.get("file_name", rom.file_name)
new_file_name = str(data.get("file_name", rom.fs_name))
try:
if rename_as_source:
new_file_name = rom.file_name.replace(
rom.file_name_no_tags or rom.file_name_no_ext,
data.get("name", rom.name),
new_file_name = rom.fs_name.replace(
rom.fs_name_no_tags or rom.fs_name_no_ext,
str(data.get("name", rom.name)),
)
new_file_name = sanitize_filename(new_file_name)
fs_rom_handler.rename_file(
old_name=rom.file_name,
old_name=rom.fs_name,
new_name=new_file_name,
file_path=rom.file_path,
file_path=rom.fs_path,
)
elif rom.file_name != new_file_name:
elif rom.fs_name != new_file_name:
new_file_name = sanitize_filename(new_file_name)
fs_rom_handler.rename_file(
old_name=rom.file_name,
old_name=rom.fs_name,
new_name=new_file_name,
file_path=rom.file_path,
file_path=rom.fs_path,
)
except RomAlreadyExistsException as exc:
log.error(exc)
@ -429,7 +427,7 @@ async def update_rom(
cleaned_data.update(fs_resource_handler.remove_cover(rom))
cleaned_data.update({"url_cover": ""})
else:
if artwork:
if artwork is not None and artwork.filename is not None:
file_ext = artwork.filename.split(".")[-1]
(
path_cover_l,
@ -459,15 +457,18 @@ async def update_rom(
path_cover_s, path_cover_l = await fs_resource_handler.get_cover(
overwrite=True,
entity=rom,
url_cover=data.get("url_cover", ""),
url_cover=str(data.get("url_cover", "")),
)
cleaned_data.update(
{"path_cover_s": path_cover_s, "path_cover_l": path_cover_l}
)
db_rom_handler.update_rom(id, cleaned_data)
rom = db_rom_handler.get_rom(id)
if not rom:
raise RomNotFoundInDatabaseException(id)
return DetailedRomSchema.from_orm_with_request(db_rom_handler.get_rom(id), request)
return DetailedRomSchema.from_orm_with_request(rom, request)
@protected_route(router.post, "/roms/delete", [Scope.ROMS_WRITE])
@ -497,13 +498,15 @@ async def delete_roms(
if not rom:
raise RomNotFoundInDatabaseException(id)
log.info(f"Deleting {rom.file_name} from database")
log.info(f"Deleting {rom.fs_name} from database")
db_rom_handler.delete_rom(id)
# Update collections to remove the deleted rom
collections = db_collection_handler.get_collections_by_rom_id(id)
for collection in collections:
collection.roms = [rom_id for rom_id in collection.roms if rom_id != id]
collection.roms = set(
[rom_id for rom_id in collection.roms if rom_id != id]
)
db_collection_handler.update_collection(
collection.id, {"roms": collection.roms}
)
@ -514,13 +517,13 @@ async def delete_roms(
log.error(f"Couldn't find resources to delete for {rom.name}")
if id in delete_from_fs:
log.info(f"Deleting {rom.file_name} from filesystem")
log.info(f"Deleting {rom.fs_name} from filesystem")
try:
fs_rom_handler.remove_file(
file_name=rom.file_name, file_path=rom.file_path
)
fs_rom_handler.remove_file(file_name=rom.fs_name, file_path=rom.fs_name)
except FileNotFoundError as exc:
error = f"Rom file {rom.file_name} not found for platform {rom.platform_slug}"
error = (
f"Rom file {rom.fs_name} not found for platform {rom.platform_slug}"
)
log.error(error)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=error
@ -556,5 +559,6 @@ async def update_rom_user(request: Request, id: int) -> RomUserSchema:
]
cleaned_data = {field: data[field] for field in fields_to_update if field in data}
rom_user = db_rom_handler.update_rom_user(db_rom_user.id, cleaned_data)
return db_rom_handler.update_rom_user(db_rom_user.id, cleaned_data)
return RomUserSchema.model_validate(rom_user)

View File

@ -29,18 +29,15 @@ def add_saves(
current_user = request.user
log.info(f"Uploading saves to {rom.name}")
if saves is None:
log.error("No saves were uploaded")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="No saves were uploaded",
)
saves_path = fs_asset_handler.build_saves_file_path(
user=request.user, platform_fs_slug=rom.platform.fs_slug, emulator=emulator
)
for save in saves:
if not save.filename:
log.error("Save file has no filename")
continue
fs_asset_handler.write_file(file=save, path=saves_path)
# Scan or update save
@ -79,7 +76,11 @@ def add_saves(
return {
"uploaded": len(saves),
"saves": [s for s in rom.saves if s.user_id == current_user.id],
"saves": [
SaveSchema.model_validate(s)
for s in rom.saves
if s.user_id == current_user.id
],
}
@ -109,7 +110,7 @@ async def update_save(request: Request, id: int) -> SaveSchema:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=error)
if "file" in data:
file: UploadFile = data["file"]
file: UploadFile = data["file"] # type: ignore
fs_asset_handler.write_file(file=file, path=db_save.file_path)
db_save_handler.update_save(db_save.id, {"file_size_bytes": file.size})
@ -124,7 +125,7 @@ async def update_save(request: Request, id: int) -> SaveSchema:
# Refetch the save to get updated fields
db_save = db_save_handler.get_save(id)
return db_save
return SaveSchema.model_validate(db_save)
@protected_route(router.post, "/saves/delete", [Scope.ASSETS_WRITE])

View File

@ -1,6 +1,7 @@
from decorators.auth import protected_route
from endpoints.responses.assets import UploadedScreenshotsResponse
from fastapi import File, HTTPException, Request, UploadFile, status
from endpoints.responses.assets import ScreenshotSchema, UploadedScreenshotsResponse
from exceptions.endpoint_exceptions import RomNotFoundInDatabaseException
from fastapi import File, Request, UploadFile
from handler.auth.base_handler import Scope
from handler.database import db_rom_handler, db_screenshot_handler
from handler.filesystem import fs_asset_handler
@ -18,21 +19,21 @@ def add_screenshots(
screenshots: list[UploadFile] = File(...), # noqa: B008
) -> UploadedScreenshotsResponse:
rom = db_rom_handler.get_rom(rom_id)
if not rom:
raise RomNotFoundInDatabaseException(rom_id)
current_user = request.user
log.info(f"Uploading screenshots to {rom.name}")
if screenshots is None:
log.error("No screenshots were uploaded")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="No screenshots were uploaded",
)
screenshots_path = fs_asset_handler.build_screenshots_file_path(
user=request.user, platform_fs_slug=rom.platform_slug
)
for screenshot in screenshots:
if not screenshot.filename:
log.warning("Skipping empty screenshot")
continue
fs_asset_handler.write_file(file=screenshot, path=screenshots_path)
# Scan or update screenshot
@ -56,8 +57,15 @@ def add_screenshots(
db_screenshot_handler.add_screenshot(scanned_screenshot)
rom = db_rom_handler.get_rom(rom_id)
if not rom:
raise RomNotFoundInDatabaseException(rom_id)
return {
"uploaded": len(screenshots),
"screenshots": [s for s in rom.screenshots if s.user_id == current_user.id],
"screenshots": [
ScreenshotSchema.model_validate(s)
for s in rom.screenshots
if s.user_id == current_user.id
],
"merged_screenshots": rom.merged_screenshots,
}

View File

@ -5,8 +5,8 @@ from fastapi import HTTPException, Request, status
from handler.auth.base_handler import Scope
from handler.database import db_rom_handler
from handler.metadata import meta_igdb_handler, meta_moby_handler, meta_sgdb_handler
from handler.metadata.igdb_handler import IGDB_API_ENABLED
from handler.metadata.moby_handler import MOBY_API_ENABLED
from handler.metadata.igdb_handler import IGDB_API_ENABLED, IGDBRom
from handler.metadata.moby_handler import MOBY_API_ENABLED, MobyGamesRom
from handler.metadata.sgdb_handler import STEAMGRIDDB_API_ENABLED
from handler.scan_handler import _get_main_platform_igdb_id
from logger.logger import log
@ -43,11 +43,11 @@ async def search_rom(
detail="No metadata providers enabled",
)
rom = db_rom_handler.get_rom(rom_id)
rom = db_rom_handler.get_rom(int(rom_id))
if not rom:
return []
search_term = search_term or rom.file_name_no_tags
search_term = search_term or rom.fs_name_no_tags
if not search_term:
return []
@ -57,7 +57,11 @@ async def search_rom(
matched_roms: list = []
log.info(f"Searching by {search_by.lower()}: {search_term}")
log.info(emoji.emojize(f":video_game: {rom.platform_slug}: {rom.file_name}"))
log.info(emoji.emojize(f":video_game: {rom.platform_slug}: {rom.fs_name}"))
igdb_matched_roms: list[IGDBRom] = []
moby_matched_roms: list[MobyGamesRom] = []
if search_by.lower() == "id":
try:
igdb_matched_roms = await meta_igdb_handler.get_matched_roms_by_id(
@ -73,19 +77,20 @@ async def search_rom(
detail=f"Tried searching by ID, but '{search_term}' is not a valid ID",
) from exc
elif search_by.lower() == "name":
main_platform_igdb_id = await _get_main_platform_igdb_id(rom.platform)
igdb_matched_roms = await meta_igdb_handler.get_matched_roms_by_name(
search_term, (await _get_main_platform_igdb_id(rom.platform))
search_term, main_platform_igdb_id or rom.platform.igdb_id
)
moby_matched_roms = await meta_moby_handler.get_matched_roms_by_name(
search_term, rom.platform.moby_id
)
merged_dict = {
item["name"]: {**item, "igdb_url_cover": item.pop("url_cover", "")}
item["name"]: {**item, "igdb_url_cover": item.pop("url_cover", "")} # type: ignore
for item in igdb_matched_roms
}
for item in moby_matched_roms:
merged_dict[item["name"]] = {
merged_dict[item["name"]] = { # type: ignore
**item,
"moby_url_cover": item.pop("url_cover", ""),
**merged_dict.get(item.get("name", ""), {}),

View File

@ -143,7 +143,9 @@ async def scan_platforms(
try:
platform_list = [
db_platform_handler.get_platform(s).fs_slug for s in platform_ids
platform.fs_slug
for s in platform_ids
if (platform := db_platform_handler.get_platform(s)) is not None
] or fs_platforms
if len(platform_list) == 0:
@ -271,14 +273,14 @@ async def _identify_platform(
for fs_roms_batch in batched(fs_roms, 200):
rom_by_filename_map = db_rom_handler.get_roms_by_filename(
platform_id=platform.id,
file_names={fs_rom["file_name"] for fs_rom in fs_roms_batch},
file_names={fs_rom["fs_name"] for fs_rom in fs_roms_batch},
)
for fs_rom in fs_roms_batch:
scan_stats += await _identify_rom(
platform=platform,
fs_rom=fs_rom,
rom=rom_by_filename_map.get(fs_rom["file_name"]),
rom=rom_by_filename_map.get(fs_rom["fs_name"]),
scan_type=scan_type,
roms_ids=roms_ids,
metadata_sources=metadata_sources,
@ -290,12 +292,12 @@ async def _identify_platform(
# the folder structure is not correct or the drive is not mounted
if len(fs_roms) > 0:
purged_roms = db_rom_handler.purge_roms(
platform.id, [rom["file_name"] for rom in fs_roms]
platform.id, [rom["fs_name"] for rom in fs_roms]
)
if len(purged_roms) > 0:
log.info("Purging roms not found in the filesystem:")
for r in purged_roms:
log.info(f" - {r.file_name}")
log.info(f" - {r.fs_name}")
# Same protection for firmware
if len(fs_firmware) > 0:
@ -346,7 +348,7 @@ def _set_rom_hashes(rom_id: int):
return
try:
rom_hashes = fs_rom_handler.get_rom_hashes(rom.file_name, rom.file_path)
rom_hashes = fs_rom_handler.get_rom_hashes(rom.fs_name, rom.fs_path)
db_rom_handler.update_rom(
rom_id,
{
@ -358,7 +360,7 @@ def _set_rom_hashes(rom_id: int):
except zlib.error as e:
# Set empty hashes if calculating them fails for corrupted files
log.error(
f"Hashes of {rom.file_name} couldn't be calculated: {hl(str(e), color=RED)}"
f"Hashes of {rom.fs_name} couldn't be calculated: {hl(str(e), color=RED)}"
)
db_rom_handler.update_rom(
rom_id,
@ -387,12 +389,12 @@ async def _identify_rom(
if not _should_scan_rom(scan_type=scan_type, rom=rom, roms_ids=roms_ids):
if rom and (
rom.file_name != fs_rom["file_name"]
rom.fs_name != fs_rom["fs_name"]
or rom.multi != fs_rom["multi"]
or rom.files != fs_rom["files"]
):
# Just to update the filesystem data
rom.file_name = fs_rom["file_name"]
rom.fs_name = fs_rom["fs_name"]
rom.multi = fs_rom["multi"]
rom.files = fs_rom["files"]
db_rom_handler.add_rom(rom)

View File

@ -29,18 +29,15 @@ def add_states(
current_user = request.user
log.info(f"Uploading states to {rom.name}")
if states is None:
log.error("No states were uploaded")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="No states were uploaded",
)
states_path = fs_asset_handler.build_states_file_path(
user=request.user, platform_fs_slug=rom.platform.fs_slug, emulator=emulator
)
for state in states:
if not state.filename:
log.warning("Skipping file with no filename")
continue
fs_asset_handler.write_file(file=state, path=states_path)
# Scan or update state
@ -73,9 +70,16 @@ def add_states(
)
rom = db_rom_handler.get_rom(rom_id)
if not rom:
raise RomNotFoundInDatabaseException(rom_id)
return {
"uploaded": len(states),
"states": [s for s in rom.states if s.user_id == current_user.id],
"states": [
StateSchema.model_validate(s)
for s in rom.states
if s.user_id == current_user.id
],
}
@ -105,7 +109,7 @@ async def update_state(request: Request, id: int) -> StateSchema:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=error)
if "file" in data:
file: UploadFile = data["file"]
file: UploadFile = data["file"] # type: ignore
fs_asset_handler.write_file(file=file, path=db_state.file_path)
db_state_handler.update_state(db_state.id, {"file_size_bytes": file.size})
@ -119,7 +123,7 @@ async def update_state(request: Request, id: int) -> StateSchema:
)
db_state = db_state_handler.get_state(id)
return db_state
return StateSchema.model_validate(db_state)
@protected_route(router.post, "/states/delete", [Scope.ASSETS_WRITE])

View File

@ -76,7 +76,7 @@ def add_user(
role=Role[role.upper()],
)
return db_user_handler.add_user(user)
return UserSchema.model_validate(db_user_handler.add_user(user))
@protected_route(router.get, "/users", [Scope.USERS_READ])
@ -90,7 +90,7 @@ def get_users(request: Request) -> list[UserSchema]:
list[UserSchema]: All users stored in the RomM's database
"""
return [u for u in db_user_handler.get_users()]
return [UserSchema.model_validate(u) for u in db_user_handler.get_users()]
@protected_route(router.get, "/users/me", [Scope.ME_READ])
@ -122,7 +122,7 @@ def get_user(request: Request, id: int) -> UserSchema:
if not user:
raise HTTPException(status_code=404, detail="User not found")
return user
return UserSchema.model_validate(user)
@protected_route(router.put, "/users/{id}", [Scope.ME_WRITE])
@ -215,7 +215,13 @@ async def update_user(
if request.user.id == id and creds_updated:
request.session.clear()
return db_user_handler.get_user(id)
db_user = db_user_handler.get_user(id)
if not db_user:
msg = f"Username with id {id} not found"
log.error(msg)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=msg)
return UserSchema.model_validate(db_user)
@protected_route(router.delete, "/users/{id}", [Scope.USERS_WRITE])

View File

@ -1,6 +1,8 @@
from typing import Sequence
from decorators.database import begin_session
from models.collection import Collection
from sqlalchemy import Select, delete, select, update
from sqlalchemy import delete, select, update
from sqlalchemy.orm import Session
from .base_handler import DBBaseHandler
@ -10,10 +12,17 @@ class DBCollectionsHandler(DBBaseHandler):
@begin_session
def add_collection(
self, collection: Collection, session: Session = None
) -> Collection | None:
) -> Collection:
collection = session.merge(collection)
session.flush()
return session.scalar(select(Collection).filter_by(id=collection.id).limit(1))
new_collection = session.scalar(
select(Collection).filter_by(id=collection.id).limit(1)
)
if not new_collection:
raise ValueError("Could not find newly created collection")
return new_collection
@begin_session
def get_collection(self, id: int, session: Session = None) -> Collection | None:
@ -28,9 +37,9 @@ class DBCollectionsHandler(DBBaseHandler):
)
@begin_session
def get_collections(self, session: Session = None) -> Select[tuple[Collection]]:
def get_collections(self, session: Session = None) -> Sequence[Collection]:
return (
session.scalars(select(Collection).order_by(Collection.name.asc())) # type: ignore[attr-defined]
session.scalars(select(Collection).order_by(Collection.name.asc()))
.unique()
.all()
)
@ -38,7 +47,7 @@ class DBCollectionsHandler(DBBaseHandler):
@begin_session
def get_collections_by_rom_id(
self, rom_id: int, session: Session = None
) -> list[Collection]:
) -> Sequence[Collection]:
return session.scalars(
select(Collection).filter(Collection.roms.contains(rom_id))
).all()
@ -56,8 +65,8 @@ class DBCollectionsHandler(DBBaseHandler):
return session.query(Collection).filter_by(id=id).one()
@begin_session
def delete_collection(self, id: int, session: Session = None) -> int:
return session.execute(
def delete_collection(self, id: int, session: Session = None) -> None:
session.execute(
delete(Collection)
.where(Collection.id == id)
.execution_options(synchronize_session="evaluate")

View File

@ -1,3 +1,5 @@
from typing import Sequence
from decorators.database import begin_session
from models.firmware import Firmware
from sqlalchemy import and_, delete, select, update
@ -26,7 +28,7 @@ class DBFirmwareHandler(DBBaseHandler):
*,
platform_id: int | None = None,
session: Session = None,
) -> list[Firmware]:
) -> Sequence[Firmware]:
return session.scalars(
select(Firmware)
.filter_by(platform_id=platform_id)
@ -45,7 +47,7 @@ class DBFirmwareHandler(DBBaseHandler):
@begin_session
def update_firmware(self, id: int, data: dict, session: Session = None) -> Firmware:
return session.execute(
return session.scalar(
update(Firmware)
.where(Firmware.id == id)
.values(**data)
@ -54,7 +56,7 @@ class DBFirmwareHandler(DBBaseHandler):
@begin_session
def delete_firmware(self, id: int, session: Session = None) -> None:
return session.execute(
session.execute(
delete(Firmware)
.where(Firmware.id == id)
.execution_options(synchronize_session="evaluate")
@ -63,7 +65,7 @@ class DBFirmwareHandler(DBBaseHandler):
@begin_session
def purge_firmware(
self, platform_id: int, fs_firmwares: list[str], session: Session = None
) -> None:
) -> Sequence[Firmware]:
purged_firmware = (
session.scalars(
select(Firmware)

View File

@ -1,7 +1,9 @@
from typing import Sequence
from decorators.database import begin_session
from models.platform import Platform
from models.rom import Rom
from sqlalchemy import Select, delete, or_, select
from sqlalchemy import delete, or_, select
from sqlalchemy.orm import Session
from .base_handler import DBBaseHandler
@ -21,7 +23,7 @@ class DBPlatformsHandler(DBBaseHandler):
select(Platform).filter_by(id=platform.id).limit(1)
)
if not new_platform:
raise ValueError("Could not find newlyewly created platform")
raise ValueError("Could not find newly created platform")
return new_platform
@ -30,9 +32,9 @@ class DBPlatformsHandler(DBBaseHandler):
return session.scalar(select(Platform).filter_by(id=id).limit(1))
@begin_session
def get_platforms(self, *, session: Session = None) -> Select[tuple[Platform]]:
def get_platforms(self, *, session: Session = None) -> Sequence[Platform]:
return (
session.scalars(select(Platform).order_by(Platform.name.asc())) # type: ignore[attr-defined]
session.scalars(select(Platform).order_by(Platform.name.asc()))
.unique()
.all()
)
@ -61,7 +63,7 @@ class DBPlatformsHandler(DBBaseHandler):
@begin_session
def purge_platforms(
self, fs_platforms: list[str], session: Session = None
) -> Select[tuple[Platform]]:
) -> Sequence[Platform]:
purged_platforms = (
session.scalars(
select(Platform)

View File

@ -1,3 +1,5 @@
from typing import Sequence
from decorators.database import begin_session
from models.assets import Save
from sqlalchemy import and_, delete, select, update
@ -12,7 +14,7 @@ class DBSavesHandler(DBBaseHandler):
return session.merge(save)
@begin_session
def get_save(self, id: int, session: Session = None) -> Save:
def get_save(self, id: int, session: Session = None) -> Save | None:
return session.get(Save, id)
@begin_session
@ -27,7 +29,7 @@ class DBSavesHandler(DBBaseHandler):
@begin_session
def update_save(self, id: int, data: dict, session: Session = None) -> Save:
return session.execute(
return session.scalar(
update(Save)
.where(Save.id == id)
.values(**data)
@ -36,7 +38,7 @@ class DBSavesHandler(DBBaseHandler):
@begin_session
def delete_save(self, id: int, session: Session = None) -> None:
return session.execute(
session.execute(
delete(Save)
.where(Save.id == id)
.execution_options(synchronize_session="evaluate")
@ -45,8 +47,18 @@ class DBSavesHandler(DBBaseHandler):
@begin_session
def purge_saves(
self, rom_id: int, user_id: int, saves: list[str], session: Session = None
) -> None:
return session.execute(
) -> Sequence[Save]:
purged_saves = session.scalars(
select(Save).filter(
and_(
Save.rom_id == rom_id,
Save.user_id == user_id,
Save.file_name.not_in(saves),
)
)
).all()
session.execute(
delete(Save)
.where(
and_(
@ -57,3 +69,5 @@ class DBSavesHandler(DBBaseHandler):
)
.execution_options(synchronize_session="evaluate")
)
return purged_saves

View File

@ -1,6 +1,8 @@
from typing import Sequence
from decorators.database import begin_session
from models.assets import Screenshot
from sqlalchemy import delete, select, update
from sqlalchemy import and_, delete, select, update
from sqlalchemy.orm import Session
from .base_handler import DBBaseHandler
@ -14,7 +16,7 @@ class DBScreenshotsHandler(DBBaseHandler):
return session.merge(screenshot)
@begin_session
def get_screenshot(self, id, session: Session = None) -> Screenshot:
def get_screenshot(self, id, session: Session = None) -> Screenshot | None:
return session.get(Screenshot, id)
@begin_session
@ -31,7 +33,7 @@ class DBScreenshotsHandler(DBBaseHandler):
def update_screenshot(
self, id: int, data: dict, session: Session = None
) -> Screenshot:
return session.execute(
return session.scalar(
update(Screenshot)
.where(Screenshot.id == id)
.values(**data)
@ -40,7 +42,7 @@ class DBScreenshotsHandler(DBBaseHandler):
@begin_session
def delete_screenshot(self, id: int, session: Session = None) -> None:
return session.execute(
session.execute(
delete(Screenshot)
.where(Screenshot.id == id)
.execution_options(synchronize_session="evaluate")
@ -49,13 +51,27 @@ class DBScreenshotsHandler(DBBaseHandler):
@begin_session
def purge_screenshots(
self, rom_id: int, user_id: int, screenshots: list[str], session: Session = None
) -> None:
return session.execute(
) -> Sequence[Screenshot]:
purged_screenshots = session.scalars(
select(Screenshot).filter(
and_(
Screenshot.rom_id == rom_id,
Screenshot.user_id == user_id,
Screenshot.file_name.not_in(screenshots),
)
)
).all()
session.execute(
delete(Screenshot)
.where(
Screenshot.rom_id == rom_id,
Screenshot.user_id == user_id,
Screenshot.file_name.not_in(screenshots),
and_(
Screenshot.rom_id == rom_id,
Screenshot.user_id == user_id,
Screenshot.file_name.not_in(screenshots),
)
)
.execution_options(synchronize_session="evaluate")
)
return purged_screenshots

View File

@ -1,3 +1,5 @@
from typing import Sequence
from decorators.database import begin_session
from models.assets import State
from sqlalchemy import and_, delete, select, update
@ -12,7 +14,7 @@ class DBStatesHandler(DBBaseHandler):
return session.merge(state)
@begin_session
def get_state(self, id: int, session: Session = None) -> State:
def get_state(self, id: int, session: Session = None) -> State | None:
return session.get(State, id)
@begin_session
@ -27,7 +29,7 @@ class DBStatesHandler(DBBaseHandler):
@begin_session
def update_state(self, id: int, data: dict, session: Session = None) -> State:
return session.execute(
return session.scalar(
update(State)
.where(State.id == id)
.values(**data)
@ -36,7 +38,7 @@ class DBStatesHandler(DBBaseHandler):
@begin_session
def delete_state(self, id: int, session: Session = None) -> None:
return session.execute(
session.execute(
delete(State)
.where(State.id == id)
.execution_options(synchronize_session="evaluate")
@ -45,8 +47,18 @@ class DBStatesHandler(DBBaseHandler):
@begin_session
def purge_states(
self, rom_id: int, user_id: int, states: list[str], session: Session = None
) -> None:
return session.execute(
) -> Sequence[State]:
purged_states = session.scalars(
select(State).filter(
and_(
State.rom_id == rom_id,
State.user_id == user_id,
State.file_name.not_in(states),
)
)
).all()
session.execute(
delete(State)
.where(
and_(
@ -57,3 +69,5 @@ class DBStatesHandler(DBBaseHandler):
)
.execution_options(synchronize_session="evaluate")
)
return purged_states

View File

@ -11,25 +11,28 @@ class DBStatsHandler(DBBaseHandler):
@begin_session
def get_platforms_count(self, session: Session = None) -> int:
"""Get the number of platforms with any roms."""
return session.scalar(
select(func.count(distinct(Rom.platform_id))).select_from(Rom)
return (
session.scalar(
select(func.count(distinct(Rom.platform_id))).select_from(Rom)
)
or 0
)
@begin_session
def get_roms_count(self, session: Session = None) -> int:
return session.scalar(select(func.count()).select_from(Rom))
return session.scalar(select(func.count()).select_from(Rom)) or 0
@begin_session
def get_saves_count(self, session: Session = None) -> int:
return session.scalar(select(func.count()).select_from(Save))
return session.scalar(select(func.count()).select_from(Save)) or 0
@begin_session
def get_states_count(self, session: Session = None) -> int:
return session.scalar(select(func.count()).select_from(State))
return session.scalar(select(func.count()).select_from(State)) or 0
@begin_session
def get_screenshots_count(self, session: Session = None) -> int:
return session.scalar(select(func.count()).select_from(Screenshot))
return session.scalar(select(func.count()).select_from(Screenshot)) or 0
@begin_session
def get_total_filesize(self, session: Session = None) -> int:

View File

@ -1,3 +1,5 @@
from typing import Sequence
from decorators.database import begin_session
from models.user import Role, User
from sqlalchemy import delete, select, update
@ -27,7 +29,7 @@ class DBUsersHandler(DBBaseHandler):
@begin_session
def update_user(self, id: int, data: dict, session: Session = None) -> User:
return session.execute(
return session.scalar(
update(User)
.where(User.id == id)
.values(**data)
@ -35,7 +37,7 @@ class DBUsersHandler(DBBaseHandler):
)
@begin_session
def get_users(self, session: Session = None) -> list[User]:
def get_users(self, session: Session = None) -> Sequence[User]:
return session.scalars(select(User)).all()
@begin_session
@ -47,5 +49,5 @@ class DBUsersHandler(DBBaseHandler):
)
@begin_session
def get_admin_users(self, session: Session = None) -> list[User]:
def get_admin_users(self, session: Session = None) -> Sequence[User]:
return session.scalars(select(User).filter_by(role=Role.ADMIN)).all()

View File

@ -95,7 +95,7 @@ class FSResourcesHandler(FSHandler):
return ""
async def get_cover(
self, entity: Rom | Collection | None, overwrite: bool, url_cover: str = ""
self, entity: Rom | Collection | None, overwrite: bool, url_cover: str | None
) -> tuple[str, str]:
if not entity:
return "", ""
@ -192,9 +192,9 @@ class FSResourcesHandler(FSHandler):
return f"{rom.fs_resources_path}/screenshots/{idx}.jpg"
async def get_rom_screenshots(
self, rom: Rom | None, url_screenshots: list
self, rom: Rom | None, url_screenshots: list | None
) -> list[str]:
if not rom:
if not rom or not url_screenshots:
return []
path_screenshots: list[str] = []

View File

@ -8,7 +8,7 @@ import tarfile
import zipfile
from collections.abc import Callable, Iterator
from pathlib import Path
from typing import Any, Final, TypedDict
from typing import Any, Final, Literal, TypedDict
import magic
import py7zr
@ -59,7 +59,7 @@ FILE_READ_CHUNK_SIZE = 1024 * 8
class FSRom(TypedDict):
multi: bool
file_name: str
fs_name: str
files: list[RomFile]
@ -90,7 +90,9 @@ def read_zip_file(file_path: Path) -> Iterator[bytes]:
yield chunk
def read_tar_file(file_path: Path, mode: str = "r") -> Iterator[bytes]:
def read_tar_file(
file_path: Path, mode: Literal["r", "r:*", "r:", "r:gz", "r:bz2", "r:xz"] = "r"
) -> Iterator[bytes]:
try:
with tarfile.open(file_path, mode) as f:
for member in f.getmembers():
@ -339,10 +341,10 @@ class FSRomsHandler(FSHandler):
raise RomsNotFoundException(platform_fs_slug) from exc
fs_roms: list[dict] = [
{"multi": False, "file_name": rom}
{"multi": False, "fs_name": rom}
for rom in self._exclude_files(fs_single_roms, "single")
] + [
{"multi": True, "file_name": rom}
{"multi": True, "fs_name": rom}
for rom in self._exclude_multi_roms(fs_multi_roms)
]
@ -350,12 +352,12 @@ class FSRomsHandler(FSHandler):
[
FSRom(
multi=rom["multi"],
file_name=rom["file_name"],
files=self.get_rom_files(rom["file_name"], roms_file_path),
fs_name=rom["fs_name"],
files=self.get_rom_files(rom["fs_name"], roms_file_path),
)
for rom in fs_roms
],
key=lambda rom: rom["file_name"],
key=lambda rom: rom["fs_name"],
)
def file_exists(self, path: str, file_name: str) -> bool:

View File

@ -567,7 +567,7 @@ class IGDBBaseHandler(MetadataHandler):
@check_twitch_token
async def get_matched_roms_by_name(
self, search_term: str, platform_igdb_id: int
self, search_term: str, platform_igdb_id: int | None
) -> list[IGDBRom]:
if not IGDB_API_ENABLED:
return []

View File

@ -310,7 +310,7 @@ class MobyGamesHandler(MetadataHandler):
return [rom] if rom["moby_id"] else []
async def get_matched_roms_by_name(
self, search_term: str, platform_moby_id: int
self, search_term: str, platform_moby_id: int | None
) -> list[MobyGamesRom]:
if not MOBY_API_ENABLED:
return []

View File

@ -145,6 +145,10 @@ class Rom(BaseModel):
def multi(self) -> bool:
return len(self.files) > 1
@cached_property
def file_size_bytes(self) -> int:
return sum(f.file_size_bytes for f in self.files)
def get_collections(self) -> list[Collection]:
from handler.database import db_rom_handler