Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 65 additions & 66 deletions webview-ui/src/components/settings/ApiOptions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ import {

import { MODELS_BY_PROVIDER, PROVIDERS } from "./constants"
import { inputEventTransform, noTransform } from "./transforms"
import { ModelInfoView } from "./ModelInfoView"
import { ModelPicker } from "./ModelPicker"
import { ApiErrorMessage } from "./ApiErrorMessage"
import { ThinkingBudget } from "./ThinkingBudget"
import { Verbosity } from "./Verbosity"
Expand Down Expand Up @@ -173,7 +173,6 @@ const ApiOptions = ({
[customHeaders, apiConfiguration?.openAiHeaders, setApiConfigurationField],
)

const [isDescriptionExpanded, setIsDescriptionExpanded] = useState(false)
const [isAdvancedSettingsOpen, setIsAdvancedSettingsOpen] = useState(false)

const handleInputChange = useCallback(
Expand Down Expand Up @@ -272,31 +271,47 @@ const ApiOptions = ({
setErrorMessage(apiValidationResult)
}, [apiConfiguration, routerModels, organizationAllowList, setErrorMessage])

const selectedProviderModels = useMemo(() => {
// Get models for static providers (those with models defined in MODELS_BY_PROVIDER)
const staticProviderModels = useMemo(() => {
const models = MODELS_BY_PROVIDER[selectedProvider]
if (!models) return null
return filterModels(models, selectedProvider, organizationAllowList)
}, [selectedProvider, organizationAllowList])

// Get the default model ID for the current static provider
const staticProviderDefaultModelId = useMemo(() => {
const defaults: Partial<Record<ProviderName, string>> = {
anthropic: anthropicDefaultModelId,
"openai-native": openAiNativeDefaultModelId,
gemini: geminiDefaultModelId,
deepseek: deepSeekDefaultModelId,
doubao: doubaoDefaultModelId,
moonshot: moonshotDefaultModelId,
mistral: mistralDefaultModelId,
xai: xaiDefaultModelId,
groq: groqDefaultModelId,
cerebras: cerebrasDefaultModelId,
baseten: basetenDefaultModelId,
bedrock: bedrockDefaultModelId,
vertex: vertexDefaultModelId,
sambanova: sambaNovaDefaultModelId,
zai:
apiConfiguration.zaiApiLine === "china_coding"
? mainlandZAiDefaultModelId
: internationalZAiDefaultModelId,
fireworks: fireworksDefaultModelId,
featherless: featherlessDefaultModelId,
minimax: minimaxDefaultModelId,
"qwen-code": qwenCodeDefaultModelId,
}
return defaults[selectedProvider] || ""
}, [selectedProvider, apiConfiguration.zaiApiLine])

if (!models) return []

const filteredModels = filterModels(models, selectedProvider, organizationAllowList)

// Include the currently selected model even if deprecated (so users can see what they have selected)
// But filter out other deprecated models from being newly selectable
const availableModels = filteredModels
? Object.entries(filteredModels)
.filter(([modelId, modelInfo]) => {
// Always include the currently selected model
if (modelId === selectedModelId) return true
// Filter out deprecated models that aren't currently selected
return !modelInfo.deprecated
})
.map(([modelId]) => ({
value: modelId,
label: modelId,
}))
: []

return availableModels
}, [selectedProvider, organizationAllowList, selectedModelId])
// Get the provider label for display
const staticProviderLabel = useMemo(() => {
const provider = PROVIDERS.find(({ value }) => value === selectedProvider)
return provider?.label || selectedProvider
}, [selectedProvider])

const onProviderChange = useCallback(
(value: ProviderName) => {
Expand Down Expand Up @@ -781,16 +796,16 @@ const ApiOptions = ({
)}

{/* Skip generic model picker for claude-code since it has its own in ClaudeCode.tsx */}
{selectedProviderModels.length > 0 && selectedProvider !== "claude-code" && (
{staticProviderModels !== null && selectedProvider !== "claude-code" && (
<>
<div>
<label className="block font-medium mb-1">{t("settings:providers.model")}</label>
<Select
value={selectedModelId === "custom-arn" ? "custom-arn" : selectedModelId}
onValueChange={(value) => {
setApiConfigurationField("apiModelId", value)

// Clear custom ARN if not using custom ARN option.
<ModelPicker
apiConfiguration={apiConfiguration}
setApiConfigurationField={(field, value, isUserAction) => {
setApiConfigurationField(field, value, isUserAction)

// Handle special cases when model changes
if (field === "apiModelId") {
// Clear custom ARN if not using custom ARN option for Bedrock
if (value !== "custom-arn" && selectedProvider === "bedrock") {
setApiConfigurationField("awsCustomArn", "")
}
Expand All @@ -800,45 +815,29 @@ const ApiOptions = ({
if (selectedProvider === "openai-native") {
setApiConfigurationField("reasoningEffort", undefined)
}
}}>
<SelectTrigger className="w-full">
<SelectValue placeholder={t("settings:common.select")} />
</SelectTrigger>
<SelectContent>
{selectedProviderModels.map((option) => (
<SelectItem key={option.value} value={option.value}>
{option.label}
</SelectItem>
))}
{selectedProvider === "bedrock" && (
<SelectItem value="custom-arn">{t("settings:labels.useCustomArn")}</SelectItem>
)}
</SelectContent>
</Select>
</div>

{/* Show error if a deprecated model is selected */}
{selectedModelInfo?.deprecated && (
<ApiErrorMessage errorMessage={t("settings:validation.modelDeprecated")} />
)}
}
}}
defaultModelId={staticProviderDefaultModelId}
models={staticProviderModels}
modelIdKey="apiModelId"
serviceName={staticProviderLabel}
serviceUrl={docs?.url || ""}
organizationAllowList={organizationAllowList}
simplifySettings={fromWelcomeView}
hidePricing
extraOptions={
selectedProvider === "bedrock"
? [{ value: "custom-arn", label: t("settings:labels.useCustomArn") }]
: undefined
}
/>

{selectedProvider === "bedrock" && selectedModelId === "custom-arn" && (
<BedrockCustomArn
apiConfiguration={apiConfiguration}
setApiConfigurationField={setApiConfigurationField}
/>
)}

{/* Only show model info if not deprecated */}
{!selectedModelInfo?.deprecated && (
<ModelInfoView
apiProvider={selectedProvider}
selectedModelId={selectedModelId}
modelInfo={selectedModelInfo}
isDescriptionExpanded={isDescriptionExpanded}
setIsDescriptionExpanded={setIsDescriptionExpanded}
/>
)}
</>
)}

Expand Down
24 changes: 24 additions & 0 deletions webview-ui/src/components/settings/ModelPicker.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ type ModelIdKey = keyof Pick<
| "apiModelId"
>

interface ExtraOption {
value: string
label: string
}

interface ModelPickerProps {
defaultModelId: string
models: Record<string, ModelInfo> | null
Expand All @@ -55,6 +60,7 @@ interface ModelPickerProps {
errorMessage?: string
simplifySettings?: boolean
hidePricing?: boolean
extraOptions?: ExtraOption[]
}

export const ModelPicker = ({
Expand All @@ -69,6 +75,7 @@ export const ModelPicker = ({
errorMessage,
simplifySettings,
hidePricing,
extraOptions,
}: ModelPickerProps) => {
const { t } = useAppTranslation()

Expand Down Expand Up @@ -232,6 +239,23 @@ export const ModelPicker = ({
/>
</CommandItem>
))}
{extraOptions?.map((option) => (
<CommandItem
key={option.value}
value={option.value}
onSelect={onSelect}
data-testid={`model-option-${option.value}`}>
<span className="truncate" title={option.label}>
{option.label}
</span>
<Check
className={cn(
"size-4 p-0.5 ml-auto",
option.value === selectedModelId ? "opacity-100" : "opacity-0",
)}
/>
</CommandItem>
))}
</CommandGroup>
</CommandList>
{searchValue && !modelIds.includes(searchValue) && (
Expand Down
Loading