2026-04-01 01:59:10 -04:00

433 lines
14 KiB
Python

import cProfile
import difflib
import emoji
import gc
import io
import math
import os
import pstats
import string
import time
from django.core.paginator import Paginator
from functools import partial
from itertools import chain
from operator import attrgetter, itemgetter
from pathlib import Path
from urllib.parse import urlunsplit, urlencode, urlparse
from .errors import DatabaseConnectionError, QuerySetEmptyError
def get_usable_cpu_count() -> int:
"""
Find a reasonable value for available CPUs.
Checks: cgroups, scheduler, then physical count.
"""
# Check Cgroup v2 (Modern Docker/K8s)
cgroup_dir = Path('/sys/fs/cgroup')
try:
cpu_max_path = cgroup_dir / 'cpu.max'
if cpu_max_path.exists():
data = cpu_max_path.read_text().split()
# Format: "$MAX $PERIOD" (e.g., "50000 100000")
if 'max' != data[0]:
quota = int(data[0])
period = int(data[1])
return max(1, math.ceil(quota / period))
except (OSError, ValueError, IndexError):
pass
# Check Cgroup v1 (Legacy Docker/K8s)
cgroup_cpu_dir = cgroup_dir / 'cpu'
try:
quota_path = cgroup_cpu_dir / 'cpu.cfs_quota_us'
period_path = cgroup_cpu_dir / 'cpu.cfs_period_us'
if quota_path.exists() and period_path.exists():
quota = int(quota_path.read_text())
period = int(period_path.read_text())
if 0 < quota:
return max(1, math.ceil(quota / period))
except (OSError, ValueError):
pass
# Linux specific: Respects 'taskset' or CPU pinning
try:
affinity_count = len(os.sched_getaffinity(0))
if 0 < affinity_count:
return affinity_count
except (AttributeError, NotImplementedError):
pass
# Returns total logical CPUs; defaults to 1
return max(1, os.cpu_count() or 0)
def directory_and_stem(arg_path, /, all_suffixes=False):
filepath = Path(arg_path)
stem = Path(filepath.stem)
while all_suffixes and stem.suffixes and '' != stem.suffix:
stem = Path(stem.stem)
return (filepath.parent, str(stem),)
def getenv(key, default=None, /, *, integer=False, string=True):
'''
Guarantees a returned type from calling `os.getenv`
The caller can request the integer type,
or use the default string type.
'''
args = dict(key=key, default=default, integer=integer, string=string)
supported_types = dict(zip(args.keys(), (
(str,), # key
(
bool,
float,
int,
str,
None.__class__,
), # default
(bool,) * (len(args.keys()) - 2),
)))
unsupported_type_msg = 'Unsupported type for positional argument, "{}": {}'
for k, t in supported_types.items():
v = args[k]
assert isinstance(v, t), unsupported_type_msg.format(k, type(v))
d = str(default) if default is not None else None
r = os.getenv(key, d)
if r is None:
if string: r = str()
if integer: r = int()
elif integer:
r = int(float(r))
return r
def glob_quote(filestr, /):
_glob_specials = {
'?': '[?]',
'*': '[*]',
'[': '[[]',
']': '[]]', # probably not needed, but it won't hurt
}
if not isinstance(filestr, str):
raise TypeError(f'expected a str, got "{type(filestr)}"')
return filestr.translate(str.maketrans(_glob_specials))
def is_empty_iterator(iterator):
"""
Checks if an iterator is empty without fully consuming it.
Returns: (is_empty_boolean, iterator)
"""
returned_iterator = iterator
try:
first_item = next(iterator)
except StopIteration:
return True, iter([])
else:
# Put the first item back at the start of a new iterator
# Chaining the single item with the original remaining stream
returned_iterator = chain([first_item], iterator)
return False, returned_iterator
def list_of_dictionaries(arg_list, /, arg_function=lambda x: x):
assert callable(arg_function)
_map_list = arg_list
if hasattr(arg_list, 'exhaust') and callable(arg_list.exhaust):
_map_list = arg_list.exhaust()
if isinstance(_map_list, list):
_map_func = partial(lambda f, d: f(d) if isinstance(d, dict) else d, arg_function)
return (True, list(map(_map_func, _map_list)),)
return (False, arg_list,)
def mkdir_p(arg_path, /, *, mode=0o777):
'''
Reminder: mode only affects the last directory
'''
dirpath = Path(arg_path)
return dirpath.mkdir(mode=mode, parents=True, exist_ok=True)
def multi_key_sort(iterable, specs, /, use_reversed=False, *, item=False, attr=False, key_func=None):
result = list(iterable)
if key_func is None:
# itemgetter is the default
if item or not (item or attr):
key_func = itemgetter
elif attr:
key_func = attrgetter
for key, reverse in reversed(specs):
result.sort(key=key_func(key), reverse=reverse)
if use_reversed:
return list(reversed(result))
return result
def parse_database_connection_string(database_connection_string):
'''
Parses a connection string in a URL style format, such as:
postgresql://tubesync:password@localhost:5432/tubesync
mysql://someuser:somepassword@localhost:3306/tubesync
into a Django-compatible settings.DATABASES dict format.
'''
valid_drivers = ('postgresql', 'mysql')
default_ports = {
'postgresql': 5432,
'mysql': 3306,
}
django_backends = {
'postgresql': 'django.db.backends.postgresql',
'mysql': 'django.db.backends.mysql',
}
backend_options = {
'postgresql': dict(pool={
'max_size': 10, # default: None (static min_size pool)
'min_size': 3, # default: 4
'num_workers': 2, # default: 3
'timeout': 180, # default: 30
}),
'mysql': {
'charset': 'utf8mb4',
}
}
db_overrides = {
'mysql': {
'CONN_MAX_AGE': 300,
},
'postgresql': dict(),
}
try:
parts = urlparse(str(database_connection_string))
except Exception as e:
raise DatabaseConnectionError(f'Failed to parse "{database_connection_string}" '
f'as a database connection string: {e}') from e
driver = parts.scheme
user_pass_host_port = parts.netloc
database = parts.path
if driver not in valid_drivers:
raise DatabaseConnectionError(f'Database connection string '
f'"{database_connection_string}" specified an '
f'invalid driver, must be one of {valid_drivers}')
django_driver = django_backends.get(driver)
host_parts = user_pass_host_port.split('@')
user_pass_parts = host_parts[0].split(':')
if len(host_parts) != 2 or len(user_pass_parts) != 2:
raise DatabaseConnectionError('Database connection string netloc must be in '
'the format of user:pass@host')
user_pass, host_port = host_parts
username, password = user_pass_parts
host_port_parts = host_port.split(':')
if len(host_port_parts) == 1:
# No port number, assign a default port
hostname = host_port_parts[0]
port = default_ports.get(driver)
elif len(host_port_parts) == 2:
# Host name and port number
hostname, port = host_port_parts
try:
port = int(port)
except (ValueError, TypeError) as e:
raise DatabaseConnectionError(f'Database connection string contained an '
f'invalid port, ports must be integers: '
f'{e}') from e
if not 0 < port < 63336:
raise DatabaseConnectionError(f'Database connection string contained an '
f'invalid port, ports must be between 1 and '
f'65535, got {port}')
else:
# Malformed
raise DatabaseConnectionError('Database connection host must be a hostname or '
'a hostname:port combination')
if database.startswith('/'):
database = database[1:]
if not database:
raise DatabaseConnectionError('Database connection string path must be a '
'string in the format of /databasename')
if '/' in database:
raise DatabaseConnectionError(f'Database connection string path can only '
f'contain a single string name, got: {database}')
db_dict = {
'DRIVER': driver,
'ENGINE': django_driver,
'NAME': database,
'USER': username,
'PASSWORD': password,
'HOST': hostname,
'PORT': port,
'CONN_HEALTH_CHECKS': True,
'CONN_MAX_AGE': 0,
'OPTIONS': backend_options.get(driver),
}
db_dict.update(db_overrides.get(driver))
return db_dict
def get_client_ip(request):
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
if x_forwarded_for:
ip = x_forwarded_for.split(',')[0]
else:
ip = request.META.get('REMOTE_ADDR')
return ip
def append_uri_params(uri, params):
uri = str(uri)
qs = urlencode(params)
return urlunsplit(('', '', uri, qs, ''))
def clean_filename(filename):
if not isinstance(filename, str):
raise ValueError(f'filename must be a str, got {type(filename)}')
to_scrub = r'<>\/:*?"|%'
for char in list(to_scrub):
filename = filename.replace(char, '')
clean_filename = ''
for c in filename:
if c in string.whitespace:
c = ' '
if ord(c) > 30:
clean_filename += c
return clean_filename.strip()
def clean_emoji(s):
if not isinstance(s, str):
raise ValueError(f'parameter must be a str, got {type(s)}')
return emoji.replace_emoji(s)
def seconds_to_timestr(seconds):
seconds = seconds % (24 * 3600)
hour = seconds // 3600
seconds %= 3600
minutes = seconds // 60
seconds %= 60
return '{:02d}:{:02d}:{:02d}'.format(hour, minutes, seconds)
def time_func(func):
def wrapper(*args, **kwargs):
start = time.perf_counter()
result = func(*args, **kwargs)
end = time.perf_counter()
return (result, (end - start, start, end,),)
return wrapper
def profile_func(func):
def wrapper(*args, **kwargs):
s = io.StringIO()
with cProfile.Profile() as pr:
pr.enable()
result = func(*args, **kwargs)
pr.disable()
ps = pstats.Stats(pr, stream=s)
ps.sort_stats(
pstats.SortKey.CUMULATIVE
).print_stats()
return (result, (s.getvalue(), ps, s,),)
return wrapper
def remove_enclosed(haystack, /, open='[', close=']', sep=' ', *, valid=None, start=None, end=None):
if not haystack:
return haystack
assert open and close, 'open and close are required to be non-empty strings'
o = haystack.find(open, start, end)
sep = sep or ''
n = close + sep
c = haystack.find(n, len(open)+o, end)
if -1 in {o, c}:
return haystack
if valid is not None:
content = haystack[len(open)+o:c]
found = set(content)
valid = set(valid)
invalid = found - valid
# assert not invalid, f'Invalid characters {invalid} found in: {content}'
if invalid:
return haystack
return haystack[:o] + haystack[len(n)+c:]
def resolve_priority_order(user_input, master_list, cutoff=0.6):
"""
Normalizes and validates a user's preferred order against a master set.
"""
resolved = []
# Index for case-insensitive and underscore/hyphen normalization
norm_map = {m.lower().replace('_', '-'): m for m in master_list}
for item in user_input:
# 1. Exact match (fastest)
if item in master_list:
if item not in resolved:
resolved.append(item)
continue
# 2. Normalized match (handles 'en_US' -> 'en-US' or 'EN-US' -> 'en-US')
norm_item = item.lower().replace('_', '-')
if norm_item in norm_map:
match = norm_map[norm_item]
if match not in resolved:
resolved.append(match)
continue
# 3. Fuzzy match (handles 'English' -> 'en' or 'USA' -> 'en-US')
matches = difflib.get_close_matches(item, master_list, n=1, cutoff=cutoff)
if matches and matches[0] not in resolved:
resolved.append(matches[0])
return resolved
def django_queryset_generator(query_set, /, *,
page_size=100,
chunk_size=None,
use_chunked_fetch=False,
):
qs = query_set.values_list('pk', flat=True)
# Avoid the `UnorderedObjectListWarning`
if not query_set.ordered:
qs = qs.order_by('pk')
collecting = gc.isenabled()
gc.disable()
if use_chunked_fetch:
for key in qs._iterator(use_chunked_fetch, chunk_size):
try:
yield query_set.filter(pk=key)[0]
except IndexError as exc:
msg = f'missing primary key: {key}'
raise QuerySetEmptyError(msg, exc=exc, key=key) from exc
key = None
gc.collect(generation=1)
key = None
else:
for page in iter(Paginator(qs, page_size)):
for key in page.object_list:
try:
yield query_set.filter(pk=key)[0]
except IndexError as exc:
msg = f'missing primary key: {key}'
raise QuerySetEmptyError(msg, exc=exc, key=key) from exc
key = None
gc.collect(generation=1)
key = None
page = None
gc.collect()
page = None
qs = None
gc.collect()
if collecting:
gc.enable()