Skip to content

Commit d7d47ed

Browse files
committed
Add support for x0 with skopt search strategy
1 parent c05ed93 commit d7d47ed

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

kernel_tuner/strategies/skopt.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def tune(searchspace: Searchspace, runner, tuning_options):
3232
max_fevals = min(tuning_options.get("max_fevals", np.inf), searchspace.size)
3333

3434
# Const function
35-
cost_func = CostFunc(searchspace, tuning_options, runner)
3635
opt_config, opt_result = None, None
3736

3837
# The dimensions. Parameters with one value become categorical
@@ -70,18 +69,26 @@ def tune(searchspace: Searchspace, runner, tuning_options):
7069
# Ask initial batch of configs
7170
num_initial = optimizer._n_initial_points
7271
batch = optimizer.ask(num_initial, lie_strategy)
73-
Xs, Ys = [], []
72+
xs, ys = [], []
7473
eval_count = 0
7574

7675
if tuning_options.verbose:
7776
print(f"Asked optimizer for {num_initial} points: {batch}")
7877

78+
# Create cost function
79+
cost_func = CostFunc(searchspace, tuning_options, runner)
80+
x0 = cost_func.get_start_pos()
81+
82+
# Add x0 if the user has requested it
83+
if x0 is not None:
84+
batch.insert(0, searchspace.get_param_indices(x0))
85+
7986
try:
8087
while eval_count < max_fevals:
8188
if not batch:
82-
optimizer.tell(Xs, Ys)
89+
optimizer.tell(xs, ys)
8390
batch = optimizer.ask(batch_size, lie_strategy)
84-
Xs, Ys = [], []
91+
xs, ys = [], []
8592

8693
if tuning_options.verbose:
8794
print(f"Asked optimizer for {batch_size} points: {batch}")
@@ -90,8 +97,8 @@ def tune(searchspace: Searchspace, runner, tuning_options):
9097
y = cost_func(searchspace.get_param_config_from_param_indices(x))
9198
eval_count += 1
9299

93-
Xs.append(x)
94-
Ys.append(y)
100+
xs.append(x)
101+
ys.append(y)
95102

96103
if opt_result is None or y < opt_result:
97104
opt_config, opt_result = x, y
@@ -101,7 +108,7 @@ def tune(searchspace: Searchspace, runner, tuning_options):
101108
print(e)
102109

103110
if opt_result is not None and tuning_options.verbose:
104-
print(f"Best configuration: {opt_result}")
111+
print(f"Best configuration: {opt_config}")
105112

106113
return cost_func.results
107114

0 commit comments

Comments
 (0)