From c02747d93fb27b2a9a768efeaaf50dc4ca6afd13 Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Fri, 3 Jul 2026 15:19:00 +0800 Subject: [PATCH] fix: include workspace_id in model parameters for get and save methods --- .../serializers/model_serializer.py | 426 ++++++++++-------- apps/models_provider/views/model.py | 4 +- 2 files changed, 231 insertions(+), 199 deletions(-) diff --git a/apps/models_provider/serializers/model_serializer.py b/apps/models_provider/serializers/model_serializer.py index 7e866eda128..8d0cef2a3e7 100644 --- a/apps/models_provider/serializers/model_serializer.py +++ b/apps/models_provider/serializers/model_serializer.py @@ -6,26 +6,25 @@ from typing import Dict import uuid_utils.compat as uuid -from django.core.cache import cache -from django.db import transaction -from django.db.models import QuerySet -from django.utils.translation import gettext_lazy as _ -from rest_framework import serializers - from common.config.embedding_config import ModelManage from common.constants.cache_version import Cache_Version -from common.constants.permission_constants import ResourcePermission, ResourceAuthType +from common.constants.permission_constants import ResourceAuthType, ResourcePermission from common.database_model_manage.database_model_manage import DatabaseModelManage from common.db.search import native_search from common.exception.app_exception import AppApiException from common.utils.common import get_file_content -from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt +from common.utils.rsa_util import rsa_long_decrypt, rsa_long_encrypt +from django.core.cache import cache +from django.db import transaction +from django.db.models import QuerySet +from django.utils.translation import gettext_lazy as _ from maxkb.conf import PROJECT_DIR -from models_provider.base_model_provider import ValidCode, DownModelChunkStatus +from models_provider.base_model_provider import DownModelChunkStatus, ValidCode from models_provider.constants.model_provider_constants import ModelProvideConstants from models_provider.models import Model, Status from models_provider.tools import get_model_credential -from system_manage.models import WorkspaceUserResourcePermission, AuthTargetType +from rest_framework import serializers +from system_manage.models import AuthTargetType, WorkspaceUserResourcePermission from system_manage.models.resource_mapping import ResourceMapping from system_manage.serializers.resource_mapping_serializers import ResourceMappingSerializer from system_manage.serializers.user_resource_permission import UserResourcePermissionSerializer @@ -44,9 +43,19 @@ class ModelModelSerializer(serializers.ModelSerializer): class Meta: model = Model fields = [ - 'id', 'name', 'status', 'model_type', 'model_name', - 'user', 'provider', 'credential', 'meta', - 'model_params_form', 'workspace_id', 'create_time', 'update_time' + "id", + "name", + "status", + "model_type", + "model_name", + "user", + "provider", + "credential", + "meta", + "model_params_form", + "workspace_id", + "create_time", + "update_time", ] @@ -83,19 +92,15 @@ def pull(model: Model, credential: Dict): status = Status.ERROR message = "" for chunk in down_model_chunk.values(): - if chunk.get('status') == DownModelChunkStatus.success.value: + if chunk.get("status") == DownModelChunkStatus.success.value: status = Status.SUCCESS - elif chunk.get('status') == DownModelChunkStatus.error.value: + elif chunk.get("status") == DownModelChunkStatus.error.value: message = chunk.get("digest") - QuerySet(Model).filter(id=model.id).update( - meta={"down_model_chunk": [], "message": message}, - status=status - ) + QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": [], "message": message}, status=status) except Exception as e: QuerySet(Model).filter(id=model.id).update( - meta={"down_model_chunk": [], "message": str(e)}, - status=Status.ERROR + meta={"down_model_chunk": [], "message": str(e)}, status=Status.ERROR ) @@ -104,19 +109,19 @@ class ModelSerializer(serializers.Serializer): def model_to_dict(model: Model): credential = json.loads(rsa_long_decrypt(model.credential)) return { - 'id': str(model.id), - 'provider': model.provider, - 'name': model.name, - 'model_type': model.model_type, - 'model_name': model.model_name, - 'status': model.status, - 'meta': model.meta, - 'credential': ModelProvideConstants[model.provider].value.get_model_credential( - model.model_type, model.model_name - ).encryption_dict(credential), - 'workspace_id': model.workspace_id, - 'nick_name': model.user.nick_name if model.user else '', - 'username': model.user.username if model.user else '' + "id": str(model.id), + "provider": model.provider, + "name": model.name, + "model_type": model.model_type, + "model_name": model.model_name, + "status": model.status, + "meta": model.meta, + "credential": ModelProvideConstants[model.provider] + .value.get_model_credential(model.model_type, model.model_name) + .encryption_dict(credential), + "workspace_id": model.workspace_id, + "nick_name": model.user.nick_name if model.user else "", + "username": model.user.username if model.user else "", } class Operate(serializers.Serializer): @@ -132,44 +137,49 @@ def is_valid(self, *, raise_exception=False): model_query = model_query.filter(workspace_id=workspace_id) model = model_query.first() if model is None: - raise AppApiException(500, _('Model does not exist')) - if model.workspace_id == 'None': - raise AppApiException(500, _('Shared models cannot be deleted or modified')) + raise AppApiException(500, _("Model does not exist")) + if model.workspace_id == "None": + raise AppApiException(500, _("Shared models cannot be deleted or modified")) def one(self, with_valid=False): if with_valid: super().is_valid(raise_exception=True) - model = QuerySet(Model).get( - id=self.data.get('id'), workspace_id=self.data.get('workspace_id', 'None') - ) + model = QuerySet(Model).get(id=self.data.get("id"), workspace_id=self.data.get("workspace_id", "None")) return ModelSerializer.model_to_dict(model) def one_meta(self, with_valid=False): model = None if with_valid: super().is_valid(raise_exception=True) - model = QuerySet(Model).filter(id=self.data.get("id"), - workspace_id=self.data.get('workspace_id', 'None')).first() + model = ( + QuerySet(Model) + .filter(id=self.data.get("id"), workspace_id=self.data.get("workspace_id", "None")) + .first() + ) if model is None: - raise AppApiException(500, _('Model does not exist')) - return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, - 'model_name': model.model_name, - 'status': model.status, - 'meta': model.meta, - 'workspace_id': model.workspace_id, - } + raise AppApiException(500, _("Model does not exist")) + return { + "id": str(model.id), + "provider": model.provider, + "name": model.name, + "model_type": model.model_type, + "model_name": model.model_name, + "status": model.status, + "meta": model.meta, + "workspace_id": model.workspace_id, + } def pause_download(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - QuerySet(Model).filter(id=self.data.get('id')).update(status=Status.PAUSE_DOWNLOAD) + QuerySet(Model).filter(id=self.data.get("id")).update(status=Status.PAUSE_DOWNLOAD) return True @transaction.atomic def delete(self, with_valid=True): if with_valid: super().is_valid(raise_exception=True) - model_id = self.data.get('id') + model_id = self.data.get("id") model = Model.objects.filter(id=model_id).first() if model is None: return True @@ -198,30 +208,28 @@ def delete(self, with_valid=True): def edit(self, instance: Dict, user_id: str, with_valid=True): if with_valid: super().is_valid(raise_exception=True) - model = QuerySet(Model).filter(id=self.data.get('id')).first() + model = QuerySet(Model).filter(id=self.data.get("id")).first() - credential, model_credential, provider_handler = ModelSerializer.Edit( - data={**instance}).is_valid( - model=model) + credential, model_credential, provider_handler = ModelSerializer.Edit(data={**instance}).is_valid( + model=model + ) try: model.status = Status.SUCCESS - default_params = {item['field']: item['default_value'] for item in model.model_params_form} + default_params = {item["field"]: item["default_value"] for item in model.model_params_form} # 校验模型认证数据 - provider_handler.is_valid_credential(model.model_type, - instance.get("model_name"), - credential, - default_params, - raise_exception=True) + provider_handler.is_valid_credential( + model.model_type, instance.get("model_name"), credential, default_params, raise_exception=True + ) except AppApiException as e: if e.code == ValidCode.model_not_fount: model.status = Status.DOWNLOAD else: raise e - update_keys = ['credential', 'name', 'model_type', 'model_name'] + update_keys = ["credential", "name", "model_type", "model_name"] for update_key in update_keys: if update_key in instance and instance.get(update_key) is not None: - if update_key == 'credential': + if update_key == "credential": model_credential_str = json.dumps(credential) model.__setattr__(update_key, rsa_long_encrypt(model_credential_str)) else: @@ -235,38 +243,35 @@ def edit(self, instance: Dict, user_id: str, with_valid=True): return self.one(with_valid=False) class Edit(serializers.Serializer): - user_id = serializers.CharField(required=False, label=(_('user id'))) + user_id = serializers.CharField(required=False, label=(_("user id"))) - name = serializers.CharField(required=False, max_length=64, - label=(_("model name"))) + name = serializers.CharField(required=False, max_length=64, label=(_("model name"))) model_type = serializers.CharField(required=False, label=(_("model type"))) model_name = serializers.CharField(required=False, label=(_("base model"))) - credential = serializers.DictField(required=False, - label=(_("certification information"))) + credential = serializers.DictField(required=False, label=(_("certification information"))) workspace_id = serializers.CharField(required=False, label=(_("workspace id"))) def is_valid(self, model=None, raise_exception=False): super().is_valid(raise_exception=True) - filter_params = {'workspace_id': model.workspace_id} - if 'name' in self.data and self.data.get('name') is not None: - filter_params['name'] = self.data.get('name') + filter_params = {"workspace_id": model.workspace_id} + if "name" in self.data and self.data.get("name") is not None: + filter_params["name"] = self.data.get("name") if QuerySet(Model).exclude(id=model.id).filter(**filter_params).exists(): - raise AppApiException(500, _('base model【{model_name}】already exists').format( - model_name=self.data.get("name"))) + raise AppApiException( + 500, _("base model【{model_name}】already exists").format(model_name=self.data.get("name")) + ) ModelSerializer.model_to_dict(model) provider = model.provider - model_type = self.data.get('model_type') - model_name = self.data.get( - 'model_name') - credential = self.data.get('credential') + model_type = self.data.get("model_type") + model_name = self.data.get("model_name") + credential = self.data.get("credential") provider_handler = ModelProvideConstants[provider].value - model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type, - model_name) + model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type, model_name) source_model_credential = json.loads(rsa_long_decrypt(model.credential)) source_encryption_model_credential = model_credential.encryption_dict(source_model_credential) if credential is not None: @@ -276,7 +281,7 @@ def is_valid(self, model=None, raise_exception=False): return credential, model_credential, provider_handler class Create(serializers.Serializer): - user_id = serializers.UUIDField(required=True, label=_('user id')) + user_id = serializers.UUIDField(required=True, label=_("user id")) name = serializers.CharField(required=True, max_length=64, label=_("model name")) provider = serializers.CharField(required=True, label=_("provider")) model_type = serializers.CharField(required=True, label=_("model type")) @@ -287,21 +292,21 @@ class Create(serializers.Serializer): def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) - if QuerySet(Model).filter( - name=self.data.get('name'), - workspace_id=self.data.get('workspace_id', 'None') - ).exists(): + if ( + QuerySet(Model) + .filter(name=self.data.get("name"), workspace_id=self.data.get("workspace_id", "None")) + .exists() + ): raise AppApiException( - 500, - _('base model【{model_name}】already exists').format(model_name=self.data.get("name")) + 500, _("base model【{model_name}】already exists").format(model_name=self.data.get("name")) ) - default_params = {item['field']: item['default_value'] for item in self.data.get('model_params_form')} - ModelProvideConstants[self.data.get('provider')].value.is_valid_credential( - self.data.get('model_type'), - self.data.get('model_name'), - self.data.get('credential'), + default_params = {item["field"]: item["default_value"] for item in self.data.get("model_params_form")} + ModelProvideConstants[self.data.get("provider")].value.is_valid_credential( + self.data.get("model_type"), + self.data.get("model_name"), + self.data.get("credential"), default_params, - raise_exception=True + raise_exception=True, ) def insert(self, workspace_id, with_valid=True): @@ -315,28 +320,30 @@ def insert(self, workspace_id, with_valid=True): else: raise e - credential = self.data.get('credential') + credential = self.data.get("credential") model_data = { - 'id': uuid.uuid7(), - 'status': status, - 'user_id': self.data.get('user_id'), - 'name': self.data.get('name'), - 'credential': rsa_long_encrypt(json.dumps(credential)), - 'provider': self.data.get('provider'), - 'model_type': self.data.get('model_type'), - 'model_name': self.data.get('model_name'), - 'model_params_form': self.data.get('model_params_form'), - 'workspace_id': workspace_id + "id": uuid.uuid7(), + "status": status, + "user_id": self.data.get("user_id"), + "name": self.data.get("name"), + "credential": rsa_long_encrypt(json.dumps(credential)), + "provider": self.data.get("provider"), + "model_type": self.data.get("model_type"), + "model_name": self.data.get("model_name"), + "model_params_form": self.data.get("model_params_form"), + "workspace_id": workspace_id, } model = Model(**model_data) try: model.save() - if workspace_id != 'None': - UserResourcePermissionSerializer(data={ - 'workspace_id': workspace_id, - 'user_id': self.data.get('user_id'), - 'auth_target_type': AuthTargetType.MODEL.value - }).auth_resource(str(model.id)) + if workspace_id != "None": + UserResourcePermissionSerializer( + data={ + "workspace_id": workspace_id, + "user_id": self.data.get("user_id"), + "auth_target_type": AuthTargetType.MODEL.value, + } + ).auth_resource(str(model.id)) except Exception as save_error: # 可添加日志记录 raise AppApiException(500, _("Model saving failed")) from save_error @@ -349,12 +356,12 @@ def insert(self, workspace_id, with_valid=True): class Query(serializers.Serializer): user_id = serializers.CharField(required=True, label=_("User ID")) - name = serializers.CharField(required=False, max_length=64, label=_('model name')) - model_type = serializers.CharField(required=False, label=_('model type')) - model_name = serializers.CharField(required=False, label=_('base model')) - provider = serializers.CharField(required=False, label=_('provider')) - create_user = serializers.CharField(required=False, label=_('create user')) - workspace_id = serializers.CharField(required=False, label=_('workspace id')) + name = serializers.CharField(required=False, max_length=64, label=_("model name")) + model_type = serializers.CharField(required=False, label=_("model type")) + model_name = serializers.CharField(required=False, label=_("base model")) + provider = serializers.CharField(required=False, label=_("provider")) + create_user = serializers.CharField(required=False, label=_("create user")) + workspace_id = serializers.CharField(required=False, label=_("workspace id")) @staticmethod def is_x_pack_ee(): @@ -366,15 +373,23 @@ def list(self, workspace_id, with_valid): if with_valid: self.is_valid(raise_exception=True) user_id = self.data.get("user_id") - workspace_manage = is_workspace_manage_permission_read(user_id, workspace_id, 'MODEL:READ') + workspace_manage = is_workspace_manage_permission_read(user_id, workspace_id, "MODEL:READ") query_params = self._build_query_params(workspace_id, workspace_manage, user_id) is_x_pack_ee = self.is_x_pack_ee() - result = native_search(query_params, - select_string=get_file_content( - os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql', - 'list_model.sql' if workspace_manage else ( - 'list_model_user_ee.sql' if is_x_pack_ee else 'list_model_user.sql') - ))) + result = native_search( + query_params, + select_string=get_file_content( + os.path.join( + PROJECT_DIR, + "apps", + "models_provider", + "sql", + "list_model.sql" + if workspace_manage + else ("list_model_user_ee.sql" if is_x_pack_ee else "list_model_user.sql"), + ) + ), + ) return ResourceMappingSerializer().get_resource_count(result) def share_list(self, workspace_id, with_valid=True): @@ -382,24 +397,20 @@ def share_list(self, workspace_id, with_valid=True): self.is_valid(raise_exception=True) user_id = self.data.get("user_id") query_params = self._build_query_params(workspace_id, False, user_id) - result = [ - self._build_model_data( - model - ) for model in query_params.get('model_query_set') - ] + result = [self._build_model_data(model) for model in query_params.get("model_query_set")] return ResourceMappingSerializer().get_resource_count(result) def model_list(self, workspace_id, with_valid=True): if with_valid: self.is_valid(raise_exception=True) user_id = self.data.get("user_id") - workspace_manage = is_workspace_manage_permission_read(user_id, workspace_id, 'MODEL:READ') + workspace_manage = is_workspace_manage_permission_read(user_id, workspace_id, "MODEL:READ") queryset = self._build_query_params(workspace_id, workspace_manage, user_id) get_authorized_model = DatabaseModelManage.get_model("get_authorized_model") shared_queryset = QuerySet(Model).none() if get_authorized_model is not None: - shared_queryset = self._build_query_params('None', False, user_id)['model_query_set'] + shared_queryset = self._build_query_params("None", False, user_id)["model_query_set"] shared_queryset = get_authorized_model(shared_queryset, workspace_id) # 构建共享模型和普通模型列表 @@ -410,95 +421,116 @@ def model_list(self, workspace_id, with_valid=True): queryset, select_string=get_file_content( os.path.join( - PROJECT_DIR, "apps", "models_provider", 'sql', - 'list_model.sql' if workspace_manage else ( - 'list_model_user_ee.sql' if is_x_pack_ee else 'list_model_user.sql') + PROJECT_DIR, + "apps", + "models_provider", + "sql", + "list_model.sql" + if workspace_manage + else ("list_model_user_ee.sql" if is_x_pack_ee else "list_model_user.sql"), ) - ) + ), ) - return { - "shared_model": shared_model, - "model": normal_model - } + return {"shared_model": shared_model, "model": normal_model} def _build_query_params(self, workspace_id, workspace_manage: bool, user_id): queryset = QuerySet(Model) if workspace_id: queryset = queryset.filter(workspace_id=workspace_id) - for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']: + for field in ["name", "model_type", "model_name", "provider", "create_user"]: value = self.data.get(field) if value is not None: - if field == 'name': - queryset = queryset.filter(**{f'{field}__icontains': value}) - elif field == 'create_user': + if field == "name": + queryset = queryset.filter(**{f"{field}__icontains": value}) + elif field == "create_user": queryset = queryset.filter(user_id=value) else: queryset = queryset.filter(**{field: value}) queryset = queryset.order_by("-create_time") - return { - 'model_query_set': queryset, - 'workspace_user_resource_permission_query_set': QuerySet(WorkspaceUserResourcePermission).filter( - auth_target_type="MODEL", - workspace_id=workspace_id, - user_id=user_id)} if ( - not workspace_manage) else { - 'model_query_set': queryset, - } + return ( + { + "model_query_set": queryset, + "workspace_user_resource_permission_query_set": QuerySet(WorkspaceUserResourcePermission).filter( + auth_target_type="MODEL", workspace_id=workspace_id, user_id=user_id + ), + } + if (not workspace_manage) + else { + "model_query_set": queryset, + } + ) def _build_model_data(self, model): return { - 'id': str(model.id), - 'provider': model.provider, - 'name': model.name, - 'model_type': model.model_type, - 'model_name': model.model_name, - 'status': model.status, - 'meta': model.meta, - 'user_id': model.user_id, - 'username': model.user.username, - 'nick_name': model.user.nick_name, + "id": str(model.id), + "provider": model.provider, + "name": model.name, + "model_type": model.model_type, + "model_name": model.model_name, + "status": model.status, + "meta": model.meta, + "user_id": model.user_id, + "username": model.user.username, + "nick_name": model.user.nick_name, } def page(self, current_page, page_size): pass class ModelParams(serializers.Serializer): - id = serializers.UUIDField(required=True, label=_('model id')) + id = serializers.UUIDField(required=True, label=_("model id")) + workspace_id = serializers.UUIDField(required=False, label=_("workspace id")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) - model = QuerySet(Model).filter(id=self.data.get("id")).first() + + validated_data = self.validated_data + model = ( + QuerySet(Model) + .filter( + id=validated_data.get("id"), + workspace_id=validated_data.get("workspace_id"), + ) + .first() + ) + if model is None: raise AppApiException(500, _("Model does not exist")) + return model + def get_model_params(self, with_valid=True): + model = None + if with_valid: - self.is_valid(raise_exception=True) - model_id = self.data.get('id') - model = QuerySet(Model).filter(id=model_id).first() - return model.model_params_form + model = self.is_valid(raise_exception=True) + + return model.model_params_form if model else None def save_model_params_form(self, model_params_form, with_valid=True): + model = None if with_valid: - self.is_valid(raise_exception=True) + model = self.is_valid(raise_exception=True) if model_params_form is None: model_params_form = [] - model_id = self.data.get('id') - model = QuerySet(Model).filter(id=model_id).first() if not isinstance(model_params_form, list): - raise AppApiException(500, _('model_params_form must be a list')) + raise AppApiException(500, _("model_params_form must be a list")) # 还需要校验几个字段:label required default_value # 校验每个配置项的必要字段 for index, param in enumerate(model_params_form): if not isinstance(param, dict): - raise AppApiException(500, _('The {index}th item in model_params_form must be a dictionary').format( - index=index)) + raise AppApiException( + 500, _("The {index}th item in model_params_form must be a dictionary").format(index=index) + ) # 校验 label 字段 - if 'label' not in param or param['label'] is None: - raise AppApiException(500, - _('The label field is required for the {index}th item in model_params_form').format( - index=index)) + if "label" not in param or param["label"] is None: + raise AppApiException( + 500, + _("The label field is required for the {index}th item in model_params_form").format( + index=index + ), + ) model.model_params_form = model_params_form model.save() @@ -506,31 +538,31 @@ def save_model_params_form(self, model_params_form, with_valid=True): class WorkspaceSharedModelSerializer(serializers.Serializer): - workspace_id = serializers.CharField(required=True, label=_('workspace id')) - name = serializers.CharField(required=False, max_length=64, label=_('model name')) - model_type = serializers.CharField(required=False, label=_('model type')) - model_name = serializers.CharField(required=False, label=_('base model')) - provider = serializers.CharField(required=False, label=_('provider')) - create_user = serializers.CharField(required=False, label=_('create user')) + workspace_id = serializers.CharField(required=True, label=_("workspace id")) + name = serializers.CharField(required=False, max_length=64, label=_("model name")) + model_type = serializers.CharField(required=False, label=_("model type")) + model_name = serializers.CharField(required=False, label=_("base model")) + provider = serializers.CharField(required=False, label=_("provider")) + create_user = serializers.CharField(required=False, label=_("create user")) def get_share_model_list(self): self.is_valid(raise_exception=True) - workspace_id = self.data.get('workspace_id') + workspace_id = self.data.get("workspace_id") queryset = self._build_queryset(workspace_id) return [ { - 'id': str(model.id), - 'provider': model.provider, - 'name': model.name, - 'model_type': model.model_type, - 'model_name': model.model_name, - 'status': model.status, - 'meta': model.meta, - 'user_id': model.user_id, - 'nick_name': model.user.nick_name, - 'username': model.user.username + "id": str(model.id), + "provider": model.provider, + "name": model.name, + "model_type": model.model_type, + "model_name": model.model_name, + "status": model.status, + "meta": model.meta, + "user_id": model.user_id, + "nick_name": model.user.nick_name, + "username": model.user.username, } for model in queryset.order_by("-create_time") ] @@ -542,12 +574,12 @@ def _build_queryset(self, workspace_id): if get_authorized_model is not None: queryset = get_authorized_model(queryset, workspace_id) - for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']: + for field in ["name", "model_type", "model_name", "provider", "create_user"]: value = self.data.get(field) if value is not None: - if field == 'name': - queryset = queryset.filter(**{f'{field}__icontains': value}) - elif field == 'create_user': + if field == "name": + queryset = queryset.filter(**{f"{field}__icontains": value}) + elif field == "create_user": queryset = queryset.filter(user_id=value) else: queryset = queryset.filter(**{field: value}) diff --git a/apps/models_provider/views/model.py b/apps/models_provider/views/model.py index fe498b20ac1..fa9fb69b80d 100644 --- a/apps/models_provider/views/model.py +++ b/apps/models_provider/views/model.py @@ -195,7 +195,7 @@ class ModelParamsForm(APIView): RoleConstants.USER.get_workspace_role(),) def get(self, request: Request, workspace_id: str, model_id: str): return result.success( - ModelSerializer.ModelParams(data={'id': model_id}).get_model_params()) + ModelSerializer.ModelParams(data={'id': model_id, 'workspace_id': workspace_id}).get_model_params()) @extend_schema(methods=['PUT'], summary=_('Save model parameter form'), @@ -217,7 +217,7 @@ def get(self, request: Request, workspace_id: str, model_id: str): ) def put(self, request: Request, workspace_id: str, model_id: str): return result.success( - ModelSerializer.ModelParams(data={'id': model_id}).save_model_params_form(request.data)) + ModelSerializer.ModelParams(data={'id': model_id, 'workspace_id': workspace_id}).save_model_params_form(request.data)) class ModelMeta(APIView): authentication_classes = [TokenAuth]