diff --git a/pufferlib/config/ocean/lightsout.ini b/pufferlib/config/ocean/lightsout.ini new file mode 100644 index 0000000000..c0b683b3f8 --- /dev/null +++ b/pufferlib/config/ocean/lightsout.ini @@ -0,0 +1,15 @@ +[base] +package = ocean +env_name = puffer_lightsout +policy_name = Policy + +[env] +num_envs = 1024 +grid_size = 5 +max_steps = 200 + +[policy] +hidden_size = 512 + +[train] +total_timesteps = 10_000_000 diff --git a/pufferlib/ocean/environment.py b/pufferlib/ocean/environment.py index 6c56a4ea20..77e9aa12fd 100644 --- a/pufferlib/ocean/environment.py +++ b/pufferlib/ocean/environment.py @@ -130,6 +130,7 @@ def make_multiagent(buf=None, **kwargs): 'moba': 'Moba', 'matsci': 'Matsci', 'memory': 'Memory', + 'lightsout': 'LightsOut', 'boids': 'Boids', 'drone': 'Drone', 'nmmo3': 'NMMO3', diff --git a/pufferlib/ocean/lightsout/binding.c b/pufferlib/ocean/lightsout/binding.c new file mode 100644 index 0000000000..958dcc2e80 --- /dev/null +++ b/pufferlib/ocean/lightsout/binding.c @@ -0,0 +1,20 @@ +#include "lightsout.h" + +#define Env LightsOut +#include "../env_binding.h" + +static int my_init(Env* env, PyObject* args, PyObject* kwargs) { + env->grid_size = unpack(kwargs, "grid_size"); + env->cell_size = unpack(kwargs, "cell_size"); + env->max_steps = unpack(kwargs, "max_steps"); + env->ema = 0.5f; + env->score_ema = 0.0f; + env->scramble_prob = 0.15f; + return 0; +} + +static int my_log(PyObject* dict, Log* log) { + assign_to_dict(dict, "score", log->score); + assign_to_dict(dict, "scramble_p", log->scramble_p); + return 0; +} diff --git a/pufferlib/ocean/lightsout/lightsout.c b/pufferlib/ocean/lightsout/lightsout.c new file mode 100644 index 0000000000..33d574c0c4 --- /dev/null +++ b/pufferlib/ocean/lightsout/lightsout.c @@ -0,0 +1,53 @@ +#include +#include +#include "lightsout.h" + +static LightsOut* g_env = NULL; + +static void demo_cleanup(void) { + if (g_env == NULL) { + return; + } + free(g_env->observations); + free(g_env->actions); + free(g_env->rewards); + free(g_env->terminals); + c_close(g_env); + g_env = NULL; +} + +int demo(){ + srand((unsigned)time(NULL)); + LightsOut env = {.grid_size = 5, .cell_size = 100, .client = NULL}; + g_env = &env; + atexit(demo_cleanup); + env.observations = (unsigned char*)calloc(env.grid_size * env.grid_size, sizeof(unsigned char)); + env.actions = (int*)calloc(1, sizeof(int)); + env.rewards = (float*)calloc(1, sizeof(float)); + env.terminals = (unsigned char*)calloc(1, sizeof(unsigned char)); + + c_reset(&env); + env.client = make_client(env.cell_size, env.grid_size); + + while (!WindowShouldClose()) { + if (IsKeyPressed(KEY_UP) || IsKeyPressed(KEY_W)) env.client->cursor_row = (env.client->cursor_row - 1 + env.grid_size) % env.grid_size; + if (IsKeyPressed(KEY_DOWN) || IsKeyPressed(KEY_S)) env.client->cursor_row = (env.client->cursor_row + 1) % env.grid_size; + if (IsKeyPressed(KEY_LEFT) || IsKeyPressed(KEY_A)) env.client->cursor_col = (env.client->cursor_col - 1 + env.grid_size) % env.grid_size; + if (IsKeyPressed(KEY_RIGHT) || IsKeyPressed(KEY_D)) env.client->cursor_col = (env.client->cursor_col + 1) % env.grid_size; + if (IsKeyPressed(KEY_SPACE)) { + int idx = env.client->cursor_row * env.grid_size + env.client->cursor_col; + env.actions[0] = idx; + c_step(&env); + } else if (IsKeyPressed(KEY_R)) { + c_reset(&env); + } + c_render(&env); + } + + demo_cleanup(); + return 0; +} +int main(void) { + demo(); + return 0; +} diff --git a/pufferlib/ocean/lightsout/lightsout.h b/pufferlib/ocean/lightsout/lightsout.h new file mode 100644 index 0000000000..0b3e1ce7f0 --- /dev/null +++ b/pufferlib/ocean/lightsout/lightsout.h @@ -0,0 +1,232 @@ +#include +#include +#include +#include "raylib.h" + +// Only use floats. +typedef struct { + float score; + float scramble_p; + float n; // Required as the last field. +} Log; + +typedef struct Client { + int cell_size; + int cursor_row; + int cursor_col; +} Client; + +typedef struct { + Log log; // Required field. + unsigned char* observations; // Required field. Ensure type matches in .py and .c. + int* actions; // Required field. Ensure type matches in .py and .c. + float* rewards; // Required field. + unsigned char* terminals; // Required field. + int grid_size; + int cell_size; + int max_steps; + int step_count; + int lights_on; + int prev_action; + int last_action; + float episode_return; + float ema; + float score_ema; + float scramble_prob; + unsigned char* grid; + Client* client; +} LightsOut; + +void step_grid(LightsOut* env, int idx) { + if (idx < 0 || idx >= env->grid_size * env->grid_size) return; + int row = idx/env->grid_size; + int col = idx%env->grid_size; + + static const int dirs[5][2] = {{0,0}, {1,0}, {0,1}, {-1,0}, {0,-1}}; + for (int i = 0; i < 5; i++) { + int dr = dirs[i][0]; + int dc = dirs[i][1]; + int r = row + dr; + int c = col + dc; + if (r >= 0 && r < env->grid_size && c >= 0 && c < env->grid_size) { + int offset = r*env->grid_size + c; + unsigned char old = env->grid[offset]; + env->grid[offset] = (unsigned char)!old; + env->lights_on += old ? -1 : 1; + } + } +} + +void init_lightsout(LightsOut* env) { + int n = env->grid_size * env->grid_size; + if (env->grid == NULL) { + env->grid = (unsigned char*)calloc(n, sizeof(unsigned char)); + } else { + memset(env->grid, 0, n * sizeof(unsigned char)); + } + + if (env->ema > 0.7f && env->score_ema > 0.0f) { + env->scramble_prob = fminf(0.5f, env->scramble_prob + 0.01f); // Increase scramble prob if EMA is high + } else if (env->ema < 0.3f) { + env->scramble_prob = fmaxf(0.15f, env->scramble_prob - 0.01f); // Decrease scramble prob if EMA is low + } + + env->step_count = 0; + env->lights_on = 0; + env->prev_action = -1; + env->last_action = -1; + env->episode_return = 0.0f; + + for (int i = 0; i < n; i++) { + float u = (float)rand() / (float)RAND_MAX; + if (u < env->scramble_prob) { + step_grid(env, i); + } + } +} + +void c_close(LightsOut* env) { + free(env->grid); + env->grid = NULL; + if (env->client != NULL) { + if (IsWindowReady()) { + CloseWindow(); + } + free(env->client); + env->client = NULL; + } +} + +void compute_observations(LightsOut* env) { + for (int i = 0; i < env->grid_size * env->grid_size; i++) { + env->observations[i] = env->grid[i]; + } +} + +void c_reset(LightsOut* env) { + env->rewards[0] = 0.0f; + env->terminals[0] = 0; + init_lightsout(env); + compute_observations(env); +} + +void c_step(LightsOut* env) { + // Defer reset by one step so terminal observation is preserved. + if (env->terminals[0]) { + init_lightsout(env); + env->rewards[0] = 0.0f; + env->terminals[0] = 0; + compute_observations(env); + return; + } + + int num_cells = env->grid_size * env->grid_size; + int atn = env->actions[0]; + env->terminals[0] = 0; + + float reward = -0.02 * (36.0 / (env->grid_size * env->grid_size)); // Base step penalty. + int prev_on = env->lights_on; + if (atn < 0 || atn >= num_cells) { + reward -= 0.5f; // Invalid action penalty. + } else { + if (atn == env->last_action) { + reward -= 0.03f; // Penalty for pressing the same cell twice in a row. + } else if (atn == env->prev_action) { + reward -= 0.02f; // Penalty for 2-step loop (A,B,A). + } + if (env->client != NULL) { + env->client->cursor_row = atn / env->grid_size; + env->client->cursor_col = atn % env->grid_size; + } + step_grid(env, atn); + env->prev_action = env->last_action; + env->last_action = atn; + int next_on = env->lights_on; + reward += 0.005f * (float)(prev_on - next_on); // Dense shaping: improve when lights decrease. + } + env->step_count += 1; + + if (env->lights_on == 0) { + reward = 2.0f; // Solved reward. + env->ema = 0.85f * env->ema + 0.15f; // Update EMA of steps to solve. + env->terminals[0] = 1; + } else if (env->client == NULL && env->step_count >= env->max_steps) { + reward -= 0.5f; // Timeout penalty during training. + env->ema = 0.85f * env->ema; // Decay EMA since we failed to solve. + env->terminals[0] = 1; + } + + env->rewards[0] = reward; + env->episode_return += reward; + if (env->terminals[0]) { + env->score_ema = 0.9f * env->score_ema + 0.1f * env->episode_return; + env->log.n += 1.0f; + env->log.score += env->episode_return; + env->log.scramble_p += env->scramble_prob; + } + + compute_observations(env); +} + +// Raylib client +static const Color COLORS[] = { + (Color){6, 24, 24, 255}, + (Color){0, 0, 255, 255}, + (Color){255, 255, 255, 255} +}; + +Client* make_client(int cell_size, int grid_size) { + Client* client= (Client*)malloc(sizeof(Client)); + client->cell_size = cell_size; + client->cursor_row = 0; + client->cursor_col = 0; + InitWindow(grid_size*cell_size, grid_size*cell_size, "PufferLib LightsOut"); + SetTargetFPS(15); + return client; +} + +void c_render(LightsOut* env) { + if (IsWindowReady() && (WindowShouldClose() || IsKeyPressed(KEY_ESCAPE))) { + c_close(env); + exit(0); + } + + if (env->client == NULL) { + env->client = make_client(env->cell_size, env->grid_size); + } + + Client* client = env->client; + + BeginDrawing(); + ClearBackground(COLORS[0]); + int sz = client->cell_size; + for (int y = 0; y < env->grid_size; y++) { + for (int x = 0; x < env->grid_size; x++){ + int tile = env->grid[y*env->grid_size + x]; + if (tile != 0) + DrawRectangle(x*sz, y*sz, sz, sz, COLORS[tile]); + } + } + DrawRectangleLinesEx( + (Rectangle){client->cursor_col * sz, client->cursor_row * sz, sz, sz}, + 3.0f, + COLORS[2] + ); + + if (env->terminals[0]) { + const char* msg = "Solved"; + int font_size = 48; + int text_w = MeasureText(msg, font_size); + int screen_w = env->grid_size * env->cell_size; + int screen_h = env->grid_size * env->cell_size; + + DrawRectangle(0, 0, screen_w, screen_h, (Color){0, 0, 0, 120}); // dim overlay + DrawText(msg, (screen_w - text_w) / 2, (screen_h - font_size) / 2, font_size, RAYWHITE); + } + + EndDrawing(); + + if (env->terminals[0]) { + WaitTime(0.5); // hold solved screen + } +} diff --git a/pufferlib/ocean/lightsout/lightsout.py b/pufferlib/ocean/lightsout/lightsout.py new file mode 100644 index 0000000000..6469779827 --- /dev/null +++ b/pufferlib/ocean/lightsout/lightsout.py @@ -0,0 +1,75 @@ +"""Scaffold for a future LightsOut ocean environment.""" + +import gymnasium +import numpy as np + +import pufferlib +from pufferlib.ocean.lightsout import binding + +import time + +class LightsOut(pufferlib.PufferEnv): + def __init__(self, num_envs=1, render_mode=None, log_interval=128, grid_size=5, max_steps=None, buf=None, seed=0): + self.single_observation_space = gymnasium.spaces.Box(low=0, high=1, shape=(grid_size * grid_size,), dtype=np.uint8) + self.single_action_space = gymnasium.spaces.Discrete(grid_size * grid_size) + self.render_mode = render_mode + self.num_agents = num_envs + self.log_interval = log_interval + self.tick = 0 + + if max_steps is None: + max_steps = grid_size * grid_size * 10 + + super().__init__(buf) + self.c_envs = binding.vec_init( + self.observations, + self.actions, + self.rewards, + self.terminals, + self.truncations, + num_envs, + seed, + grid_size=grid_size, + cell_size=int(np.ceil(1280 / grid_size)), + max_steps=max_steps, + ) + self.grid_size = grid_size + + def reset(self, seed=None): + self.tick = 0 + if seed is None: + seed = time.time_ns() & 0x7FFFFFFF + binding.vec_reset(self.c_envs, seed) + return self.observations, [] + + def step(self, actions): + self.actions[:] = actions + self.tick += 1 + binding.vec_step(self.c_envs) + info = [] + if self.tick % self.log_interval == 0: + info.append(binding.vec_log(self.c_envs)) + return self.observations, self.rewards, self.terminals, self.truncations, info + + def render(self): + binding.vec_render(self.c_envs, 0) + + def close(self): + binding.vec_close(self.c_envs) + + +if __name__ == "__main__": + n = 4096 + env = LightsOut(num_envs=n) + env.reset() + steps = 0 + + cache = 1024 + actions = np.zeros((cache, n), dtype=np.int32) + + start = time.time() + while time.time() - start < 10: + env.step(actions[steps % cache]) + steps += 1 + + print("LightsOut SPS:", int(env.num_agents * steps / (time.time() - start))) diff --git a/pufferlib/ocean/lightsout/train.py b/pufferlib/ocean/lightsout/train.py new file mode 100644 index 0000000000..35ecb7b2db --- /dev/null +++ b/pufferlib/ocean/lightsout/train.py @@ -0,0 +1,52 @@ +from pufferlib import pufferl + + +def train_until_target(env_name="puffer_lightsout", load_model_path=None): + args = pufferl.load_config(env_name) + + args["train"]["device"] = "cuda" + args["vec"]["backend"] = "PufferEnv" + args["vec"]["num_envs"] = 1 + args["env"]["num_envs"] = 4096 + args["env"]["grid_size"] = 8 + + # High cap; run stops early when target is stable. + args["train"]["total_timesteps"] = 2_000_000_000 + args["train"]["ent_coef"] = 0.005 + args["train"]["learning_rate"] = 0.015 + args["train"]["update_epochs"] = 2 + args["train"]["minibatch_size"] = 32768 + + if load_model_path is not None: + args["load_model_path"] = load_model_path + + target_score = 0.42 + target_scramble_p = 0.499 + target_min_n = 50.0 + target_streak = 3 + streak = 0 + + def stop_on_target(logs): + nonlocal streak + p = logs.get("environment/scramble_p") + score = logs.get("environment/score") + n = logs.get("environment/n", 0.0) + if p is None or score is None: + return False + + hit = p >= target_scramble_p and score >= target_score and n >= target_min_n + streak = streak + 1 if hit else 0 + if hit: + print( + f"target hit: scramble_p={p:.3f} score={score:.3f} n={n:.1f} " + f"streak={streak}/{target_streak}" + ) + + return streak >= target_streak + + pufferl.train(env_name, args=args, early_stop_fn=stop_on_target) + + +if __name__ == "__main__": + train_until_target("puffer_lightsout", load_model_path=None) + # train_until_target("puffer_lightsout", load_model_path="latest") \ No newline at end of file