Skip to content

Commit 11634cc

Browse files
committed
ckp: implement blocked time stepping
1 parent 7f1d4ab commit 11634cc

File tree

1 file changed

+54
-24
lines changed

1 file changed

+54
-24
lines changed

pyrevolve/pyrevolve.py

Lines changed: 54 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from abc import ABCMeta, abstractproperty, abstractmethod
22
import numpy as np
3+
import math
34
from . import crevolve as cr
45
from .compression import init_compression as init
56
from .schedulers import CRevolve, HRevolve, Action, Architecture
@@ -69,11 +70,12 @@ def __init__(
6970
fwd_operator,
7071
rev_operator,
7172
n_checkpoints,
72-
n_timesteps,
73+
op_timesteps,
7374
storage_list=None,
7475
scheduler=None,
7576
timings=None,
7677
profiler=None,
78+
block_size=1,
7779
):
7880
"""
7981
Initialises checkpointer for a given forward- and reverse operator, a
@@ -82,10 +84,10 @@ def __init__(
8284
methods and a scheduler object must be provided as well. Otherwise
8385
NumpyStorage and CRevolve are used as default
8486
"""
85-
if n_timesteps is None:
87+
if op_timesteps is None:
8688
raise Exception(
8789
"Online checkpointing not yet supported. Specify \
88-
number of time steps!"
90+
number of Operator time steps!"
8991
)
9092

9193
if profiler is None:
@@ -100,12 +102,15 @@ def __init__(
100102

101103
self.checkpoint = checkpoint
102104
self.n_checkpoints = n_checkpoints
103-
self.n_timesteps = n_timesteps
105+
self.block_size = block_size
106+
self.op_timesteps = op_timesteps
104107
self.timings = timings
105108
self.fwd_operator = fwd_operator
106109
self.rev_operator = rev_operator
107110
self.scheduler = scheduler
108111

112+
self.cp_timesteps = int(math.ceil(self.op_timesteps / self.block_size))
113+
109114
def addStorage(self, new_storage):
110115
self.storage_list.append(new_storage)
111116

@@ -175,6 +180,20 @@ def makespan(self):
175180
def ratio(self):
176181
return 0
177182

183+
@property
def op_old_capo(self):
    """Operator time index matching the scheduler's previous position.

    The scheduler's ``old_capo`` is multiplied by ``block_size`` to turn it
    into an index on the operator's time axis. No upper bound is applied.
    """
    previous_position = self.scheduler.old_capo
    steps_per_block = self.block_size
    return previous_position * steps_per_block
186+
187+
@property
def op_capo(self):
    """Operator time index matching the scheduler's current position.

    The scheduler's ``capo`` is multiplied by ``block_size`` and capped at
    ``op_timesteps``, so the returned index never runs past the last
    operator time step (the final block may be shorter than ``block_size``).
    """
    return min(self.scheduler.capo * self.block_size, self.op_timesteps)
191+
192+
@property
def next_op_capo(self):
    """Operator time index one block past the scheduler's current position.

    Computed as ``(scheduler.capo + 1) * block_size`` and capped at
    ``op_timesteps``, so a trailing block shorter than ``block_size`` ends
    exactly at the last operator time step.
    """
    return min((self.scheduler.capo + 1) * self.block_size, self.op_timesteps)
196+
178197
def apply_forward(self):
179198
"""Executes only the forward computation while storing checkpoints,
180199
then returns."""
@@ -185,7 +204,7 @@ def apply_forward(self):
185204
# advance forward computation
186205
with self.profiler.get_timer("forward", "advance"):
187206
self.fwd_operator.apply(
188-
t_start=self.scheduler.old_capo, t_end=self.scheduler.capo
207+
t_start=self.op_old_capo, t_end=self.op_capo
189208
)
190209
elif action.type == Action.TAKESHOT:
191210
# take a snapshot: copy from workspace into storage
@@ -199,7 +218,7 @@ def apply_forward(self):
199218
# final step in the forward computation
200219
with self.profiler.get_timer("forward", "lastfw"):
201220
self.fwd_operator.apply(
202-
t_start=self.scheduler.old_capo, t_end=self.n_timesteps
221+
t_start=self.op_old_capo, t_end=self.op_timesteps
203222
)
204223
break
205224
elif action.type == Action.REVERSE:
@@ -224,10 +243,10 @@ def apply_reverse(self):
224243
# advance adjoint computation by a single step
225244
with self.profiler.get_timer("reverse", "reverse"):
226245
self.fwd_operator.apply(
227-
t_start=self.scheduler.capo, t_end=self.scheduler.capo + 1
246+
t_start=self.op_capo, t_end=self.next_op_capo
228247
)
229248
self.rev_operator.apply(
230-
t_start=self.scheduler.capo, t_end=self.scheduler.capo + 1
249+
t_start=self.op_capo, t_end=self.next_op_capo
231250
)
232251
elif action.type == Action.REVSTART:
233252
"""Sets the rev_operator to 'nt' only if its not already there.
@@ -236,7 +255,7 @@ def apply_reverse(self):
236255
"""
237256
with self.profiler.get_timer("reverse", "reverse"):
238257
self.rev_operator.apply(
239-
t_start=self.scheduler.capo, t_end=self.scheduler.capo + 1
258+
t_start=self.op_capo, t_end=self.next_op_capo
240259
)
241260
elif action.type == Action.TAKESHOT:
242261
# take a snapshot: copy from workspace into storage
@@ -246,7 +265,7 @@ def apply_reverse(self):
246265
# advance forward computation
247266
with self.profiler.get_timer("reverse", "advance"):
248267
self.fwd_operator.apply(
249-
t_start=self.scheduler.old_capo, t_end=self.scheduler.capo
268+
t_start=self.op_old_capo, t_end=self.op_capo
250269
)
251270
elif action.type == Action.RESTORE:
252271
# restore a snapshot: copy from storage into workspace
@@ -303,13 +322,14 @@ def __init__(
303322
fwd_operator,
304323
rev_operator,
305324
n_checkpoints,
306-
n_timesteps,
325+
op_timesteps,
307326
timings=None,
308327
profiler=None,
309328
compression_params=None,
310329
diskstorage=False,
311330
filedir="./",
312331
singlefile=True,
332+
block_size=1,
313333
):
314334
"""
315335
Initializes a single-level Revolver
@@ -318,7 +338,7 @@ def __init__(
318338
fwd_operator: forward operator
319339
rev_operator: backward operator
320340
n_checkpoints: number of checkpoints
321-
n_timesteps: number of timesteps
341+
op_timesteps: number of Operator time steps
322342
timings: timings
323343
profiler: Profiler
324344
compression_params: compression scheme
@@ -331,20 +351,21 @@ def __init__(
331351
fwd_operator,
332352
rev_operator,
333353
n_checkpoints,
334-
n_timesteps,
354+
op_timesteps,
335355
timings=timings,
336356
profiler=profiler,
357+
block_size=block_size,
337358
)
338359

339360
self.filedir = filedir
340361
self.singlefile = singlefile
341362

342363
if n_checkpoints is None:
343-
self.n_checkpoints = cr.adjust(n_timesteps)
364+
self.n_checkpoints = cr.adjust(self.cp_timesteps)
344365
else:
345366
self.n_checkpoints = n_checkpoints
346367

347-
self.scheduler = CRevolve(self.n_checkpoints, self.n_timesteps)
368+
self.scheduler = CRevolve(self.n_checkpoints, self.cp_timesteps)
348369

349370
# remove storage list to avoid memory overflow
350371
self.resetStorageList()
@@ -388,13 +409,14 @@ def __init__(
388409
checkpoint,
389410
fwd_operator,
390411
rev_operator,
391-
n_timesteps,
412+
op_timesteps,
392413
storage_list,
393414
timings=None,
394415
profiler=None,
395416
uf=1,
396417
ub=1,
397418
up=1,
419+
block_size=1,
398420
):
399421
"""
400422
Initializes a multi-level Revolver using HRevolve
@@ -408,7 +430,8 @@ def __init__(
408430
checkpoint: checkpoint object
409431
fwd_operator: forward operator
410432
rev_operator: backward operator
411-
n_timesteps: number of timesteps
433+
n_checkpoints: number of checkpoints
434+
op_timesteps: number of Operator time steps
412435
timings: timings
413436
profiler: profiler
414437
storage_list: list of storage objects
@@ -420,12 +443,14 @@ def __init__(
420443
checkpoint,
421444
fwd_operator,
422445
rev_operator,
423-
n_timesteps,
424-
n_timesteps,
446+
op_timesteps,
447+
op_timesteps,
425448
storage_list=storage_list,
426449
timings=timings,
427450
profiler=profiler,
451+
block_size=block_size,
428452
)
453+
429454
self.uf = uf # forward cost (default=1)
430455
self.ub = ub # backward cost (default=1)
431456
self.up = up # turn cost (default=1)
@@ -449,7 +474,8 @@ def reload_scheduler(self, uf=1, ub=1, up=1):
449474
self.up = up
450475
self.arch = Architecture(self.storage_list)
451476
self.scheduler = HRevolve(
452-
self.n_checkpoints, self.n_timesteps, self.arch, self.uf, self.ub, self.up
477+
self.n_checkpoints, self.cp_timesteps, self.arch, self.uf, self.ub,
478+
self.up
453479
)
454480
else:
455481
raise ValueError(
@@ -498,21 +524,23 @@ def __init__(
498524
fwd_operator,
499525
rev_operator,
500526
n_checkpoints,
501-
n_timesteps,
527+
op_timesteps,
502528
timings=None,
503529
profiler=None,
504530
compression_params=None,
531+
block_size=1,
505532
):
506533
super().__init__(
507534
checkpoint,
508535
fwd_operator,
509536
rev_operator,
510537
n_checkpoints,
511-
n_timesteps,
538+
op_timesteps,
512539
timings=timings,
513540
profiler=profiler,
514541
compression_params=compression_params,
515542
diskstorage=False,
543+
block_size=block_size,
516544
)
517545

518546

@@ -537,23 +565,25 @@ def __init__(
537565
fwd_operator,
538566
rev_operator,
539567
n_checkpoints,
540-
n_timesteps,
568+
op_timesteps,
541569
timings=None,
542570
profiler=None,
543571
filedir="./",
544572
singlefile=True,
573+
block_size=1,
545574
):
546575
super().__init__(
547576
checkpoint,
548577
fwd_operator,
549578
rev_operator,
550579
n_checkpoints,
551-
n_timesteps,
580+
op_timesteps,
552581
timings=timings,
553582
profiler=profiler,
554583
diskstorage=True,
555584
filedir=filedir,
556585
singlefile=singlefile,
586+
block_size=block_size,
557587
)
558588

559589

0 commit comments

Comments
 (0)