diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 72518c783f..0059f3879c 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -104,7 +104,7 @@ import { import { MODELS_BY_PROVIDER, PROVIDERS } from "./constants" import { inputEventTransform, noTransform } from "./transforms" -import { ModelInfoView } from "./ModelInfoView" +import { ModelPicker } from "./ModelPicker" import { ApiErrorMessage } from "./ApiErrorMessage" import { ThinkingBudget } from "./ThinkingBudget" import { Verbosity } from "./Verbosity" @@ -173,7 +173,6 @@ const ApiOptions = ({ [customHeaders, apiConfiguration?.openAiHeaders, setApiConfigurationField], ) - const [isDescriptionExpanded, setIsDescriptionExpanded] = useState(false) const [isAdvancedSettingsOpen, setIsAdvancedSettingsOpen] = useState(false) const handleInputChange = useCallback( @@ -272,31 +271,47 @@ const ApiOptions = ({ setErrorMessage(apiValidationResult) }, [apiConfiguration, routerModels, organizationAllowList, setErrorMessage]) - const selectedProviderModels = useMemo(() => { + // Get models for static providers (those with models defined in MODELS_BY_PROVIDER) + const staticProviderModels = useMemo(() => { const models = MODELS_BY_PROVIDER[selectedProvider] + if (!models) return null + return filterModels(models, selectedProvider, organizationAllowList) + }, [selectedProvider, organizationAllowList]) + + // Get the default model ID for the current static provider + const staticProviderDefaultModelId = useMemo(() => { + const defaults: Partial> = { + anthropic: anthropicDefaultModelId, + "openai-native": openAiNativeDefaultModelId, + gemini: geminiDefaultModelId, + deepseek: deepSeekDefaultModelId, + doubao: doubaoDefaultModelId, + moonshot: moonshotDefaultModelId, + mistral: mistralDefaultModelId, + xai: xaiDefaultModelId, + groq: groqDefaultModelId, + cerebras: cerebrasDefaultModelId, + baseten: basetenDefaultModelId, + bedrock: bedrockDefaultModelId, + vertex: vertexDefaultModelId, + sambanova: sambaNovaDefaultModelId, + zai: + apiConfiguration.zaiApiLine === "china_coding" + ? mainlandZAiDefaultModelId + : internationalZAiDefaultModelId, + fireworks: fireworksDefaultModelId, + featherless: featherlessDefaultModelId, + minimax: minimaxDefaultModelId, + "qwen-code": qwenCodeDefaultModelId, + } + return defaults[selectedProvider] || "" + }, [selectedProvider, apiConfiguration.zaiApiLine]) - if (!models) return [] - - const filteredModels = filterModels(models, selectedProvider, organizationAllowList) - - // Include the currently selected model even if deprecated (so users can see what they have selected) - // But filter out other deprecated models from being newly selectable - const availableModels = filteredModels - ? Object.entries(filteredModels) - .filter(([modelId, modelInfo]) => { - // Always include the currently selected model - if (modelId === selectedModelId) return true - // Filter out deprecated models that aren't currently selected - return !modelInfo.deprecated - }) - .map(([modelId]) => ({ - value: modelId, - label: modelId, - })) - : [] - - return availableModels - }, [selectedProvider, organizationAllowList, selectedModelId]) + // Get the provider label for display + const staticProviderLabel = useMemo(() => { + const provider = PROVIDERS.find(({ value }) => value === selectedProvider) + return provider?.label || selectedProvider + }, [selectedProvider]) const onProviderChange = useCallback( (value: ProviderName) => { @@ -781,16 +796,16 @@ const ApiOptions = ({ )} {/* Skip generic model picker for claude-code since it has its own in ClaudeCode.tsx */} - {selectedProviderModels.length > 0 && selectedProvider !== "claude-code" && ( + {staticProviderModels !== null && selectedProvider !== "claude-code" && ( <> -
- - -
- - {/* Show error if a deprecated model is selected */} - {selectedModelInfo?.deprecated && ( - - )} + } + }} + defaultModelId={staticProviderDefaultModelId} + models={staticProviderModels} + modelIdKey="apiModelId" + serviceName={staticProviderLabel} + serviceUrl={docs?.url || ""} + organizationAllowList={organizationAllowList} + simplifySettings={fromWelcomeView} + hidePricing + extraOptions={ + selectedProvider === "bedrock" + ? [{ value: "custom-arn", label: t("settings:labels.useCustomArn") }] + : undefined + } + /> {selectedProvider === "bedrock" && selectedModelId === "custom-arn" && ( )} - - {/* Only show model info if not deprecated */} - {!selectedModelInfo?.deprecated && ( - - )} )} diff --git a/webview-ui/src/components/settings/ModelPicker.tsx b/webview-ui/src/components/settings/ModelPicker.tsx index 4fe4c02dda..accfa9efb6 100644 --- a/webview-ui/src/components/settings/ModelPicker.tsx +++ b/webview-ui/src/components/settings/ModelPicker.tsx @@ -39,6 +39,11 @@ type ModelIdKey = keyof Pick< | "apiModelId" > +interface ExtraOption { + value: string + label: string +} + interface ModelPickerProps { defaultModelId: string models: Record | null @@ -55,6 +60,7 @@ interface ModelPickerProps { errorMessage?: string simplifySettings?: boolean hidePricing?: boolean + extraOptions?: ExtraOption[] } export const ModelPicker = ({ @@ -69,6 +75,7 @@ export const ModelPicker = ({ errorMessage, simplifySettings, hidePricing, + extraOptions, }: ModelPickerProps) => { const { t } = useAppTranslation() @@ -232,6 +239,23 @@ export const ModelPicker = ({ /> ))} + {extraOptions?.map((option) => ( + + + {option.label} + + + + ))} {searchValue && !modelIds.includes(searchValue) && (