mirror of
https://github.com/open-webui/open-webui.git
synced 2026-02-20 07:06:24 +08:00
446 lines
15 KiB
Python
446 lines
15 KiB
Python
import time
|
||
import logging
|
||
import asyncio
|
||
import sys
|
||
|
||
from aiocache import cached
|
||
from fastapi import Request
|
||
|
||
from open_webui.socket.utils import RedisDict
|
||
from open_webui.routers import openai, ollama
|
||
from open_webui.functions import get_function_models
|
||
|
||
|
||
from open_webui.models.functions import Functions
|
||
from open_webui.models.models import Models
|
||
from open_webui.models.access_grants import AccessGrants
|
||
from open_webui.models.groups import Groups
|
||
|
||
|
||
from open_webui.utils.plugin import (
|
||
load_function_module_by_id,
|
||
get_function_module_from_cache,
|
||
)
|
||
from open_webui.utils.access_control import has_access
|
||
|
||
|
||
from open_webui.config import (
|
||
BYPASS_ADMIN_ACCESS_CONTROL,
|
||
DEFAULT_ARENA_MODEL,
|
||
)
|
||
|
||
from open_webui.env import BYPASS_MODEL_ACCESS_CONTROL, GLOBAL_LOG_LEVEL
|
||
from open_webui.models.users import UserModel
|
||
|
||
# Send log records to stdout at the globally configured level, then obtain
# this module's logger.
logging.basicConfig(level=GLOBAL_LOG_LEVEL, stream=sys.stdout)
log = logging.getLogger(name=__name__)
|
||
|
||
|
||
async def fetch_ollama_models(request: Request, user: UserModel = None):
    """Fetch the Ollama model list and convert it to the app's model-entry shape.

    Each returned dict carries the Ollama ``model`` value as ``id``, is marked
    ``owned_by="ollama"``, and keeps the untouched Ollama record under the
    ``"ollama"`` key.
    """
    response = await ollama.get_all_models(request, user=user)

    def to_entry(entry):
        return {
            "id": entry["model"],
            "name": entry["name"],
            "object": "model",
            "created": int(time.time()),
            "owned_by": "ollama",
            "ollama": entry,
            # Default to a local connection when Ollama does not report one.
            "connection_type": entry.get("connection_type", "local"),
            "tags": entry.get("tags", []),
        }

    return [to_entry(entry) for entry in response["models"]]
|
||
|
||
|
||
async def fetch_openai_models(request: Request, user: UserModel = None):
    """Return the ``"data"`` list of the aggregated OpenAI-compatible model listing."""
    return (await openai.get_all_models(request, user=user))["data"]
|
||
|
||
|
||
async def get_all_base_models(request: Request, user: UserModel = None):
    """Concurrently collect base models from every enabled source.

    OpenAI and Ollama are only queried when enabled in the app config; a
    disabled source contributes an empty list. Function (pipe) models are
    always fetched.

    Returns:
        Function models, followed by OpenAI models, followed by Ollama models.
    """
    config = request.app.state.config

    pending = [
        fetch_openai_models(request, user)
        if config.ENABLE_OPENAI_API
        else asyncio.sleep(0, result=[]),
        fetch_ollama_models(request, user)
        if config.ENABLE_OLLAMA_API
        else asyncio.sleep(0, result=[]),
        get_function_models(request),
    ]
    openai_models, ollama_models, function_models = await asyncio.gather(*pending)

    return function_models + openai_models + ollama_models
|
||
|
||
|
||
async def get_all_models(request, refresh: bool = False, user: UserModel = None):
    """Assemble every model the app exposes and cache the result on app state.

    The result combines, in order:
      1. base models from ``get_all_base_models`` (function/pipe, OpenAI,
         Ollama). These are re-used from ``request.app.state.BASE_MODELS``
         when ``MODELS`` and ``BASE_MODELS`` are populated, the base-model
         cache is enabled, and ``refresh`` is False;
      2. arena models, when evaluation arena models are enabled;
      3. custom models from ``Models.get_all_models()``. Entries whose
         ``base_model_id`` is None override the base model(s) with the same
         id (or hide them when inactive). Other active entries are appended
         as presets, unless a model with that id already exists.

    Every entry then gets ``actions`` and ``filters`` lists. These are built
    from the model's own action/filter ids plus the global ones, restricted to
    active functions.

    Args:
        request: The current request. ``request.app.state`` supplies config,
            the base-model cache, the ``MODELS`` map and loaded function
            modules (``FUNCTIONS``).
        refresh: Force re-fetching base models even if the cache is usable.
        user: Forwarded to the backend model fetchers.

    Returns:
        The list of model dicts. When there are no base models, ``[]`` is
        returned and ``request.app.state.MODELS`` is left untouched.
        Otherwise the same entries, keyed by ``"id"``, are written to
        ``request.app.state.MODELS`` (via ``.set`` when it is a ``RedisDict``).
    """
    if (
        request.app.state.MODELS
        and request.app.state.BASE_MODELS
        and (request.app.state.config.ENABLE_BASE_MODELS_CACHE and not refresh)
    ):
        base_models = request.app.state.BASE_MODELS
    else:
        base_models = await get_all_base_models(request, user=user)
        request.app.state.BASE_MODELS = base_models

    # Shallow-copy each base model so the top-level keys assigned below
    # (name, info, action_ids, ...) do not leak into the cached BASE_MODELS.
    # Nested values are still shared, so they must not be mutated in place.
    models = [model.copy() for model in base_models]

    # If there are no models, return an empty list
    if len(models) == 0:
        return []

    # Add arena models
    if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
        arena_configs = request.app.state.config.EVALUATION_ARENA_MODELS
        if len(arena_configs) == 0:
            # Fall back to the single default arena model
            arena_configs = [DEFAULT_ARENA_MODEL]
        models = models + [
            {
                "id": arena["id"],
                "name": arena["name"],
                "info": {
                    "meta": arena["meta"],
                },
                "object": "model",
                "created": int(time.time()),
                "owned_by": "arena",
                "arena": True,
            }
            for arena in arena_configs
        ]

    global_action_ids = [
        function.id for function in Functions.get_global_action_functions()
    ]
    # Sets: these are only used for membership tests in the per-model loop.
    enabled_action_ids = {
        function.id
        for function in Functions.get_functions_by_type("action", active_only=True)
    }

    global_filter_ids = [
        function.id for function in Functions.get_global_filter_functions()
    ]
    enabled_filter_ids = {
        function.id
        for function in Functions.get_functions_by_type("filter", active_only=True)
    }

    for custom_model in Models.get_all_models():
        if custom_model.base_model_id is None:
            # Applied directly to a base model. Build a new list instead of
            # calling models.remove() while iterating models, which skipped
            # the element following each removed one.
            remaining = []
            for model in models:
                is_target = custom_model.id == model["id"] or (
                    model.get("owned_by") == "ollama"
                    # Ollama may return model ids in different formats
                    # (e.g., 'llama3' vs. 'llama3:7b')
                    and custom_model.id == model["id"].split(":")[0]
                )
                if not is_target:
                    remaining.append(model)
                    continue
                if not custom_model.is_active:
                    # An inactive override hides the base model it targets.
                    continue

                info = custom_model.model_dump()
                meta = info.get("meta") or {}
                # Remove params to avoid exposing sensitive info
                info.pop("params", None)

                model["name"] = custom_model.name
                model["info"] = info
                model["action_ids"] = list(meta.get("actionIds") or [])
                model["filter_ids"] = list(meta.get("filterIds") or [])
                remaining.append(model)
            models = remaining

        elif custom_model.is_active and not any(
            m["id"] == custom_model.id for m in models
        ):
            # Custom model based on a base model: inherit owner, connection
            # type and pipe from the first model whose id (or id without an
            # Ollama-style ":tag" suffix) equals base_model_id.
            owned_by = "openai"
            connection_type = None
            pipe = None
            for m in models:
                if custom_model.base_model_id in (m["id"], m["id"].split(":")[0]):
                    owned_by = m.get("owned_by", "unknown")
                    pipe = m.get("pipe")
                    connection_type = m.get("connection_type", None)
                    break

            info = custom_model.model_dump()
            # Remove params to avoid exposing sensitive info
            info.pop("params", None)

            meta = custom_model.meta.model_dump() if custom_model.meta else {}

            models.append(
                {
                    "id": f"{custom_model.id}",
                    "name": custom_model.name,
                    "object": "model",
                    "created": custom_model.created_at,
                    "owned_by": owned_by,
                    "connection_type": connection_type,
                    "preset": True,
                    **({"pipe": pipe} if pipe is not None else {}),
                    "info": info,
                    "action_ids": list(meta.get("actionIds") or []),
                    "filter_ids": list(meta.get("filterIds") or []),
                }
            )

    def resolve_default_icon(function, module):
        # Manifest icon_url first, then module-level icon_url / icon.
        return (
            function.meta.manifest.get("icon_url", None)
            or getattr(module, "icon_url", None)
            or getattr(module, "icon", None)
        )

    # Process action_ids to get the actions
    def get_action_items_from_module(function, module):
        # A module may expose several sub-actions through ``actions``;
        # otherwise the function itself is the single action.
        if hasattr(module, "actions"):
            return [
                {
                    "id": f"{function.id}.{action['id']}",
                    "name": action.get("name", f"{function.name} ({action['id']})"),
                    "description": function.meta.description,
                    "icon": action.get(
                        "icon_url", resolve_default_icon(function, module)
                    ),
                }
                for action in module.actions
            ]
        return [
            {
                "id": function.id,
                "name": function.name,
                "description": function.meta.description,
                "icon": resolve_default_icon(function, module),
            }
        ]

    # Process filter_ids to get the filters
    def get_filter_items_from_module(function, module):
        return [
            {
                "id": function.id,
                "name": function.name,
                "description": function.meta.description,
                "icon": resolve_default_icon(function, module),
                "has_user_valves": hasattr(module, "UserValves"),
            }
        ]

    # Batch-prefetch all needed function records to avoid N+1 queries
    all_function_ids = set(global_action_ids) | set(global_filter_ids)
    for model in models:
        all_function_ids.update(model.get("action_ids", []))
        all_function_ids.update(model.get("filter_ids", []))

    functions_by_id = {
        f.id: f for f in Functions.get_functions_by_ids(list(all_function_ids))
    }

    # Pre-warm the function module cache once per unique function ID.
    # This ensures each function's DB freshness check runs exactly once,
    # not once per (model × function) pair.
    for function_id in all_function_ids:
        try:
            get_function_module_from_cache(request, function_id)
        except Exception as e:
            log.info(f"Failed to load function module for {function_id}: {e}")

    for model in models:
        action_ids = [
            action_id
            for action_id in set(model.pop("action_ids", []) + global_action_ids)
            if action_id in enabled_action_ids
        ]
        filter_ids = [
            filter_id
            for filter_id in set(model.pop("filter_ids", []) + global_filter_ids)
            if filter_id in enabled_filter_ids
        ]

        model["actions"] = []
        for action_id in action_ids:
            action_function = functions_by_id.get(action_id)
            if action_function is None:
                log.info(f"Action not found: {action_id}")
                continue

            function_module = request.app.state.FUNCTIONS.get(action_id)
            if function_module is None:
                log.info(f"Failed to load action module: {action_id}")
                continue
            model["actions"].extend(
                get_action_items_from_module(action_function, function_module)
            )

        model["filters"] = []
        for filter_id in filter_ids:
            filter_function = functions_by_id.get(filter_id)
            if filter_function is None:
                log.info(f"Filter not found: {filter_id}")
                continue

            function_module = request.app.state.FUNCTIONS.get(filter_id)
            if function_module is None:
                log.info(f"Failed to load filter module: {filter_id}")
                continue
            # Only filters whose module sets a truthy ``toggle`` are listed.
            if getattr(function_module, "toggle", None):
                model["filters"].extend(
                    get_filter_items_from_module(filter_function, function_module)
                )

    log.debug(f"get_all_models() returned {len(models)} models")

    models_dict = {model["id"]: model for model in models}
    if isinstance(request.app.state.MODELS, RedisDict):
        request.app.state.MODELS.set(models_dict)
    else:
        request.app.state.MODELS = models_dict

    return models
|
||
|
||
|
||
def check_model_access(user, model, db=None):
    """Raise ``Exception("Model not found")`` unless *user* may read *model*.

    Arena models are checked against the ``info.meta.access_grants`` list
    carried on the model dict itself. Any other model is looked up through
    ``Models.get_model_by_id``. It is readable when the user owns it or holds
    a ``read`` grant via ``AccessGrants``. The same exception message is
    raised for every failure case, including a missing model record.
    """
    if model.get("arena"):
        grants = model.get("info", {}).get("meta", {}).get("access_grants", [])
        allowed = has_access(
            user.id,
            permission="read",
            access_grants=grants,
            db=db,
        )
    else:
        model_info = Models.get_model_by_id(model.get("id"), db=db)
        if not model_info:
            raise Exception("Model not found")
        allowed = user.id == model_info.user_id or AccessGrants.has_access(
            user_id=user.id,
            resource_type="model",
            resource_id=model_info.id,
            permission="read",
            db=db,
        )

    if not allowed:
        raise Exception("Model not found")
|
||
|
||
|
||
def get_filtered_models(models, user, db=None):
    """Return the subset of *models* that *user* is allowed to read.

    The input list is returned unchanged when ``BYPASS_MODEL_ACCESS_CONTROL``
    is set. It is also returned unchanged when the user is an admin and
    ``BYPASS_ADMIN_ACCESS_CONTROL`` is set, or when the role is neither
    ``"user"`` nor ``"admin"``.

    Otherwise:
      * arena models are kept when ``has_access`` grants ``read`` on their
        ``info.meta.access_grants``;
      * other models are kept only if they carry a truthy ``info``, and the
        user either owns them (``info["user_id"]``) or holds a ``read`` grant.
        Grants are fetched in one batched ``AccessGrants`` query.
    """
    enforce = (
        user.role == "user"
        or (user.role == "admin" and not BYPASS_ADMIN_ACCESS_CONTROL)
    ) and not BYPASS_MODEL_ACCESS_CONTROL
    if not enforce:
        return models

    # Collect the info record of every non-arena model that has one.
    model_infos = {
        m["id"]: m["info"]
        for m in models
        if not m.get("arena") and m.get("info")
    }

    user_group_ids = {
        group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
    }

    # Batch-fetch accessible resource IDs in a single query instead of N has_access calls
    accessible_model_ids = AccessGrants.get_accessible_resource_ids(
        user_id=user.id,
        resource_type="model",
        resource_ids=list(model_infos.keys()),
        permission="read",
        user_group_ids=user_group_ids,
        db=db,
    )

    visible = []
    for model in models:
        if model.get("arena"):
            grants = model.get("info", {}).get("meta", {}).get("access_grants", [])
            if has_access(
                user.id,
                permission="read",
                access_grants=grants,
                user_group_ids=user_group_ids,
            ):
                visible.append(model)
            continue

        # Admins with BYPASS_ADMIN_ACCESS_CONTROL already returned above, so
        # only ownership and explicit grants matter here.
        info = model_infos.get(model["id"])
        if info and (
            user.id == info["user_id"] or model["id"] in accessible_model_ids
        ):
            visible.append(model)

    return visible
|