
Commit 5022b33

Merge pull request #148 from dsanders11/feat/prompt-yaml-model-parameters
feat: support modelParameters in prompt.yaml files
2 parents 36ea137 + c9e1471 commit 5022b33

File tree

7 files changed: +45 -5 lines changed

__tests__/helpers-inference.test.ts

Lines changed: 8 additions & 0 deletions
@@ -106,6 +106,8 @@ describe('helpers.ts - inference request building', () => {
       undefined,
       undefined,
       'gpt-4',
+      undefined,
+      undefined,
       100,
       'https://api.test.com',
       'test-token',
@@ -117,6 +119,8 @@ describe('helpers.ts - inference request building', () => {
         {role: 'user', content: 'User message'},
       ],
       modelName: 'gpt-4',
+      temperature: undefined,
+      topP: undefined,
       maxTokens: 100,
       endpoint: 'https://api.test.com',
       token: 'test-token',
@@ -136,6 +140,8 @@ describe('helpers.ts - inference request building', () => {
       'System prompt',
       'User prompt',
       'gpt-4',
+      undefined,
+      undefined,
       100,
       'https://api.test.com',
       'test-token',
@@ -147,6 +153,8 @@ describe('helpers.ts - inference request building', () => {
         {role: 'user', content: 'User prompt'},
       ],
       modelName: 'gpt-4',
+      temperature: undefined,
+      topP: undefined,
       maxTokens: 100,
       endpoint: 'https://api.test.com',
       token: 'test-token',

dist/index.js

Lines changed: 12 additions & 3 deletions
Some generated files are not rendered by default.

dist/index.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.

src/helpers.ts

Lines changed: 4 additions & 0 deletions
@@ -82,6 +82,8 @@ export function buildInferenceRequest(
   systemPrompt: string | undefined,
   prompt: string | undefined,
   modelName: string,
+  temperature: number | undefined,
+  topP: number | undefined,
   maxTokens: number,
   endpoint: string,
   token: string,
@@ -92,6 +94,8 @@ export function buildInferenceRequest(
   return {
     messages,
     modelName,
+    temperature,
+    topP,
     maxTokens,
     endpoint,
     token,
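
In short, buildInferenceRequest gains two new parameters right after modelName and forwards them verbatim into the returned request object. A minimal sketch of that pass-through (not the full signature: the real function also takes the system prompt, prompt, endpoint, token and more, as the hunks above show; the names below are illustrative, not repository identifiers):

  // Sketch only: mirrors the pass-through added to buildInferenceRequest.
  interface InferenceRequestSketch {
    modelName: string
    maxTokens: number
    temperature?: number
    topP?: number
  }

  function buildRequestSketch(
    modelName: string,
    temperature: number | undefined,
    topP: number | undefined,
    maxTokens: number
  ): InferenceRequestSketch {
    // Values are passed through untouched; undefined simply means "not specified".
    return {modelName, temperature, topP, maxTokens}
  }

  // Callers without model parameters pass undefined, as the updated tests do:
  const request = buildRequestSketch('gpt-4', undefined, undefined, 100)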

src/inference.ts

Lines changed: 6 additions & 0 deletions
@@ -15,6 +15,8 @@ export interface InferenceRequest {
   maxTokens: number
   endpoint: string
   token: string
+  temperature?: number
+  topP?: number
   responseFormat?: {type: 'json_schema'; json_schema: unknown} // Processed response format for the API
 }
 
@@ -45,6 +47,8 @@ export async function simpleInference(request: InferenceRequest): Promise<string
     messages: request.messages as OpenAI.Chat.Completions.ChatCompletionMessageParam[],
     max_tokens: request.maxTokens,
     model: request.modelName,
+    temperature: request.temperature,
+    top_p: request.topP,
   }
 
   // Add response format if specified
@@ -90,6 +94,8 @@ export async function mcpInference(
     messages: messages as OpenAI.Chat.Completions.ChatCompletionMessageParam[],
     max_tokens: request.maxTokens,
     model: request.modelName,
+    temperature: request.temperature,
+    top_p: request.topP,
   }
 
   // Add response format if specified (only on final iteration to avoid conflicts with tool calls)
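
Both inference paths now copy temperature and topP from the request into the chat completions payload, using the API's snake_case top_p. A rough sketch of how such a call might look with the official openai client (the client setup and return handling here are assumptions, not taken from this diff); leaving either value undefined means it is dropped when the body is serialized to JSON, so the service falls back to its default sampling settings:

  // Sketch, assuming the official 'openai' npm package; names other than the
  // request fields (modelName, maxTokens, temperature, topP, ...) are illustrative.
  import OpenAI from 'openai'

  async function sketchInference(request: {
    messages: OpenAI.Chat.Completions.ChatCompletionMessageParam[]
    modelName: string
    maxTokens: number
    temperature?: number
    topP?: number
    endpoint: string
    token: string
  }): Promise<string | null> {
    const client = new OpenAI({baseURL: request.endpoint, apiKey: request.token})

    const response = await client.chat.completions.create({
      messages: request.messages,
      max_tokens: request.maxTokens,
      model: request.modelName,
      temperature: request.temperature, // undefined => omitted from the JSON body
      top_p: request.topP,              // undefined => omitted from the JSON body
    })

    return response.choices[0]?.message?.content ?? null
  }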

src/main.ts

Lines changed: 7 additions & 1 deletion
@@ -53,7 +53,11 @@ export async function run(): Promise<void> {
 
   // Get common parameters
   const modelName = promptConfig?.model || core.getInput('model')
-  const maxTokens = parseInt(core.getInput('max-tokens'), 10)
+  let maxTokens = promptConfig?.modelParameters?.maxTokens ?? core.getInput('max-tokens')
+
+  if (typeof maxTokens === 'string') {
+    maxTokens = parseInt(maxTokens, 10)
+  }
 
   const token = process.env['GITHUB_TOKEN'] || core.getInput('token')
   if (token === undefined) {
@@ -71,6 +75,8 @@ export async function run(): Promise<void> {
     systemPrompt,
     prompt,
     modelName,
+    promptConfig?.modelParameters?.temperature,
+    promptConfig?.modelParameters?.topP,
     maxTokens,
     endpoint,
     token,
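
Two things happen here: modelParameters.maxTokens from prompt.yaml now takes precedence over the action's max-tokens input (which always arrives as a string and therefore still goes through parseInt), and the temperature/topP values are threaded into buildInferenceRequest. An equivalent standalone sketch of the maxTokens resolution (the helper name is illustrative):

  // Sketch of the resolution logic: a number from prompt.yaml wins over the
  // action input string; strings are coerced with parseInt, matching run().
  function resolveMaxTokens(
    fromPromptConfig: number | undefined,
    fromActionInput: string
  ): number {
    let maxTokens: number | string = fromPromptConfig ?? fromActionInput
    if (typeof maxTokens === 'string') {
      maxTokens = parseInt(maxTokens, 10)
    }
    return maxTokens
  }

  // resolveMaxTokens(undefined, '200') === 200  (falls back to the action input)
  // resolveMaxTokens(1024, '200')      === 1024 (prompt.yaml value wins)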

src/prompt.ts

Lines changed: 7 additions & 0 deletions
@@ -7,9 +7,16 @@ export interface PromptMessage {
   content: string
 }
 
+export interface ModelParameters {
+  maxTokens?: number
+  temperature?: number
+  topP?: number
+}
+
 export interface PromptConfig {
   messages: PromptMessage[]
   model?: string
+  modelParameters?: ModelParameters
   responseFormat?: 'text' | 'json_schema'
   jsonSchema?: string
 }
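
The new ModelParameters interface is the shape a prompt.yaml modelParameters block is expected to parse into; only the TypeScript field names come from this diff, and the assumption that the YAML keys mirror them in camelCase is just that, an assumption. A small sketch of a populated value:

  // Concrete values are illustrative; field names come from the interface above.
  import type {ModelParameters} from './prompt'

  const modelParameters: ModelParameters = {
    maxTokens: 500,   // takes precedence over the action's 'max-tokens' input in run()
    temperature: 0.2, // forwarded as 'temperature' on the chat completions call
    topP: 0.9,        // forwarded as 'top_p'
  }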
