11from abc import ABCMeta , abstractproperty , abstractmethod
22import numpy as np
3+ import math
34from . import crevolve as cr
45from .compression import init_compression as init
56from .schedulers import CRevolve , HRevolve , Action , Architecture
@@ -69,11 +70,12 @@ def __init__(
6970 fwd_operator ,
7071 rev_operator ,
7172 n_checkpoints ,
72- n_timesteps ,
73+ op_timesteps ,
7374 storage_list = None ,
7475 scheduler = None ,
7576 timings = None ,
7677 profiler = None ,
78+ block_size = 1 ,
7779 ):
7880 """
7981 Initialises checkpointer for a given forward- and reverse operator, a
@@ -82,10 +84,10 @@ def __init__(
8284 methods and a scheduler object must be provided as well. Otherwise
8385 NumpyStorage and CRevolve are used as default
8486 """
85- if n_timesteps is None :
87+ if op_timesteps is None :
8688 raise Exception (
8789 "Online checkpointing not yet supported. Specify \
88- number of time steps!"
90+ number of Operator time steps!"
8991 )
9092
9193 if profiler is None :
@@ -100,12 +102,15 @@ def __init__(
100102
101103 self .checkpoint = checkpoint
102104 self .n_checkpoints = n_checkpoints
103- self .n_timesteps = n_timesteps
105+ self .block_size = block_size
106+ self .op_timesteps = op_timesteps
104107 self .timings = timings
105108 self .fwd_operator = fwd_operator
106109 self .rev_operator = rev_operator
107110 self .scheduler = scheduler
108111
112+ self .cp_timesteps = int (math .ceil (self .op_timesteps / self .block_size ))
113+
109114 def addStorage (self , new_storage ):
110115 self .storage_list .append (new_storage )
111116
@@ -175,6 +180,20 @@ def makespan(self):
175180 def ratio (self ):
176181 return 0
177182
183+ @property
184+ def op_old_capo (self ):
185+ return self .scheduler .old_capo * self .block_size
186+
187+ @property
188+ def op_capo (self ):
189+ _op_capo = self .scheduler .capo * self .block_size
190+ return _op_capo if _op_capo < self .op_timesteps else self .op_timesteps
191+
192+ @property
193+ def next_op_capo (self ):
194+ _op_capo = (self .scheduler .capo + 1 ) * self .block_size
195+ return _op_capo if _op_capo < self .op_timesteps else self .op_timesteps
196+
178197 def apply_forward (self ):
179198 """Executes only the forward computation while storing checkpoints,
180199 then returns."""
@@ -185,7 +204,7 @@ def apply_forward(self):
185204 # advance forward computation
186205 with self .profiler .get_timer ("forward" , "advance" ):
187206 self .fwd_operator .apply (
188- t_start = self .scheduler . old_capo , t_end = self .scheduler . capo
207+ t_start = self .op_old_capo , t_end = self .op_capo
189208 )
190209 elif action .type == Action .TAKESHOT :
191210 # take a snapshot: copy from workspace into storage
@@ -199,7 +218,7 @@ def apply_forward(self):
199218 # final step in the forward computation
200219 with self .profiler .get_timer ("forward" , "lastfw" ):
201220 self .fwd_operator .apply (
202- t_start = self .scheduler . old_capo , t_end = self .n_timesteps
221+ t_start = self .op_old_capo , t_end = self .op_timesteps
203222 )
204223 break
205224 elif action .type == Action .REVERSE :
@@ -224,10 +243,10 @@ def apply_reverse(self):
224243 # advance adjoint computation by a single step
225244 with self .profiler .get_timer ("reverse" , "reverse" ):
226245 self .fwd_operator .apply (
227- t_start = self .scheduler . capo , t_end = self .scheduler . capo + 1
246+ t_start = self .op_capo , t_end = self .next_op_capo
228247 )
229248 self .rev_operator .apply (
230- t_start = self .scheduler . capo , t_end = self .scheduler . capo + 1
249+ t_start = self .op_capo , t_end = self .next_op_capo
231250 )
232251 elif action .type == Action .REVSTART :
233252 """Sets the rev_operator to 'nt' only if its not already there.
@@ -236,7 +255,7 @@ def apply_reverse(self):
236255 """
237256 with self .profiler .get_timer ("reverse" , "reverse" ):
238257 self .rev_operator .apply (
239- t_start = self .scheduler . capo , t_end = self .scheduler . capo + 1
258+ t_start = self .op_capo , t_end = self .next_op_capo
240259 )
241260 elif action .type == Action .TAKESHOT :
242261 # take a snapshot: copy from workspace into storage
@@ -246,7 +265,7 @@ def apply_reverse(self):
246265 # advance forward computation
247266 with self .profiler .get_timer ("reverse" , "advance" ):
248267 self .fwd_operator .apply (
249- t_start = self .scheduler . old_capo , t_end = self .scheduler . capo
268+ t_start = self .op_old_capo , t_end = self .op_capo
250269 )
251270 elif action .type == Action .RESTORE :
252271 # restore a snapshot: copy from storage into workspace
@@ -303,13 +322,14 @@ def __init__(
303322 fwd_operator ,
304323 rev_operator ,
305324 n_checkpoints ,
306- n_timesteps ,
325+ op_timesteps ,
307326 timings = None ,
308327 profiler = None ,
309328 compression_params = None ,
310329 diskstorage = False ,
311330 filedir = "./" ,
312331 singlefile = True ,
332+ block_size = 1 ,
313333 ):
314334 """
315335 Initializes a single-level Revolver
@@ -318,7 +338,7 @@ def __init__(
318338 fwd_operator: forward operator
319339 rev_operator: backward operator
320340 n_checkpoints: number of checkpoints
321- n_timesteps: number of timesteps
341+ op_timesteps: number of timesteps
322342 timings: timings
323343 profiler: Profiler
324344 compression_params: compression scheme
@@ -331,20 +351,21 @@ def __init__(
331351 fwd_operator ,
332352 rev_operator ,
333353 n_checkpoints ,
334- n_timesteps ,
354+ op_timesteps ,
335355 timings = timings ,
336356 profiler = profiler ,
357+ block_size = block_size ,
337358 )
338359
339360 self .filedir = filedir
340361 self .singlefile = singlefile
341362
342363 if n_checkpoints is None :
343- self .n_checkpoints = cr .adjust (n_timesteps )
364+ self .n_checkpoints = cr .adjust (self . cp_timesteps )
344365 else :
345366 self .n_checkpoints = n_checkpoints
346367
347- self .scheduler = CRevolve (self .n_checkpoints , self .n_timesteps )
368+ self .scheduler = CRevolve (self .n_checkpoints , self .cp_timesteps )
348369
349370 # remove storage list to avoid memory overflow
350371 self .resetStorageList ()
@@ -388,13 +409,14 @@ def __init__(
388409 checkpoint ,
389410 fwd_operator ,
390411 rev_operator ,
391- n_timesteps ,
412+ op_timesteps ,
392413 storage_list ,
393414 timings = None ,
394415 profiler = None ,
395416 uf = 1 ,
396417 ub = 1 ,
397418 up = 1 ,
419+ block_size = 1 ,
398420 ):
399421 """
400422 Initializes a multi-level Revolver using HRevolve
@@ -408,7 +430,8 @@ def __init__(
408430 checkpoint: checkpoint object
409431 fwd_operator: forward operator
410432 rev_operator: backward operator
411- n_timesteps: number of timesteps
433+ n_checkpoints: number of checkpoints
434+ op_timesteps: number of timesteps
412435 timings: timings
413436 profiler: profiler
414437 storage_list: list of storage objects
@@ -420,12 +443,14 @@ def __init__(
420443 checkpoint ,
421444 fwd_operator ,
422445 rev_operator ,
423- n_timesteps ,
424- n_timesteps ,
446+ op_timesteps ,
447+ op_timesteps ,
425448 storage_list = storage_list ,
426449 timings = timings ,
427450 profiler = profiler ,
451+ block_size = block_size ,
428452 )
453+
429454 self .uf = uf # forward cost (default=1)
430455 self .ub = ub # backward cost (default=1)
431456 self .up = up # turn cost (default=1)
@@ -449,7 +474,8 @@ def reload_scheduler(self, uf=1, ub=1, up=1):
449474 self .up = up
450475 self .arch = Architecture (self .storage_list )
451476 self .scheduler = HRevolve (
452- self .n_checkpoints , self .n_timesteps , self .arch , self .uf , self .ub , self .up
477+ self .n_checkpoints , self .cp_timesteps , self .arch , self .uf , self .ub ,
478+ self .up
453479 )
454480 else :
455481 raise ValueError (
@@ -498,21 +524,23 @@ def __init__(
498524 fwd_operator ,
499525 rev_operator ,
500526 n_checkpoints ,
501- n_timesteps ,
527+ op_timesteps ,
502528 timings = None ,
503529 profiler = None ,
504530 compression_params = None ,
531+ block_size = 1 ,
505532 ):
506533 super ().__init__ (
507534 checkpoint ,
508535 fwd_operator ,
509536 rev_operator ,
510537 n_checkpoints ,
511- n_timesteps ,
538+ op_timesteps ,
512539 timings = timings ,
513540 profiler = profiler ,
514541 compression_params = compression_params ,
515542 diskstorage = False ,
543+ block_size = block_size ,
516544 )
517545
518546
@@ -537,23 +565,25 @@ def __init__(
537565 fwd_operator ,
538566 rev_operator ,
539567 n_checkpoints ,
540- n_timesteps ,
568+ op_timesteps ,
541569 timings = None ,
542570 profiler = None ,
543571 filedir = "./" ,
544572 singlefile = True ,
573+ block_size = 1 ,
545574 ):
546575 super ().__init__ (
547576 checkpoint ,
548577 fwd_operator ,
549578 rev_operator ,
550579 n_checkpoints ,
551- n_timesteps ,
580+ op_timesteps ,
552581 timings = timings ,
553582 profiler = profiler ,
554583 diskstorage = True ,
555584 filedir = filedir ,
556585 singlefile = singlefile ,
586+ block_size = block_size ,
557587 )
558588
559589
0 commit comments