Commit a79576e

add Knowledge distillation TF example (#533)
1 parent 26b53b1 commit a79576e

6 files changed, 70 insertions(+), 28 deletions(-)

examples/tensorflow/distillation/conf.yaml

Lines changed: 11 additions & 16 deletions
@@ -21,7 +21,7 @@ model:
 distillation:
   train:
     start_epoch: 0
-    end_epoch: 10
+    end_epoch: 90
     iteration: 1000
     frequency: 1
     dataloader:
@@ -30,23 +30,21 @@ distillation:
         ImageFolder:
           root: /path/to/dataset
       transform:
-        AlignImageChannel:
-          dim: 3
-        ResizeCropImagenet:
-          height: 224
-          width: 224
-        Normalize:
+        Resize:
+          size: 224
+          interpolation: nearest
+        KerasRescale:
           rescale: [127.5, 1]
     optimizer:
       SGD:
-        learning_rate: 0.1
+        learning_rate: 0.001
         momentum: 0.1
         nesterov: True
         weight_decay: 0.001
     criterion:
       KnowledgeDistillationLoss:
         temperature: 1.0
-        loss_types: ['CE', 'KL']
+        loss_types: ['CE', 'CE']
         loss_weights: [0.5, 0.5]
 
 evaluation: # optional. required if user doesn't provide eval_func in neural_compressor.Quantization.
@@ -59,14 +57,11 @@ evaluation: # optional. required if use
         ImageFolder:
           root: /path/to/dataset
       transform:
-        AlignImageChannel:
-          dim: 3
-        ResizeCropImagenet:
-          height: 224
-          width: 224
-        Normalize:
+        Resize:
+          size: 224
+          interpolation: nearest
+        KerasRescale:
           rescale: [127.5, 1]
-
 tuning:
   accuracy_criterion:
     relative: 0.01 # the tuning target of accuracy loss percentage: 1%
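This config is consumed by the example's main.py (diffed below) through the experimental Distillation API. A minimal sketch of that wiring, assuming the era's `Distillation(conf_path)` constructor and `.model`/`.teacher_model` attributes; the diff below stops at the import line, so the attribute names here are assumptions, not confirmed by this commit:

```python
import tensorflow as tf
from neural_compressor.experimental import Distillation, common

# Hypothetical wiring; only the import is visible in this commit's diff.
distiller = Distillation('conf.yaml')   # train/criterion/optimizer settings come from the YAML
distiller.model = common.Model(tf.keras.applications.MobileNet(weights='imagenet'))
distiller.teacher_model = common.Model(tf.keras.applications.DenseNet201(weights='imagenet'))
distilled_model = distiller()           # runs the distillation loop defined in conf.yaml
```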

examples/tensorflow/distillation/main.py

Lines changed: 22 additions & 6 deletions
@@ -1,3 +1,20 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import argparse
 import os
 import logging
@@ -6,9 +23,9 @@
 import warnings
 import tensorflow as tf
 from neural_compressor.utils import logger
-model_names = ['mobilenet','mobilenetv2']
+model_names = ['mobilenet','densenet201']
 
-parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
+parser = argparse.ArgumentParser(description='Tensorflow ImageNet Training')
 parser.add_argument('-t', '--topology', metavar='ARCH', default='resnet18',
                     choices=model_names,
                     help='model architecture: ' +
@@ -51,11 +68,10 @@ def main_worker(args):
     global best_acc1
 
     print("=> using pre-trained model '{}'".format(args.topology))
-    model = tf.keras.applications.mobilenet.MobileNet(weights='imagenet')
-
+    model = tf.keras.applications.MobileNet(weights='imagenet')
+
     print("=> using pre-trained teacher model '{}'".format(args.teacher))
-    teacher_model = tf.keras.applications.mobilenet_v2.MobileNetV2(weights='imagenet')
-    # optionally resume from a checkpoint
+    teacher_model = tf.keras.applications.DenseNet201(weights='imagenet')
 
     if args.distillation:
         from neural_compressor.experimental import Distillation, common
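The teacher swap replaces MobileNetV2 (roughly the same capacity as the student) with the much larger DenseNet201. Both are stock ImageNet classifiers in `tf.keras.applications`, so the only structural requirement is a matching 1000-class output; a quick sanity check:

```python
import tensorflow as tf

student = tf.keras.applications.MobileNet(weights='imagenet')      # ~4.3M parameters
teacher = tf.keras.applications.DenseNet201(weights='imagenet')    # ~20M parameters

# Distillation only needs the two output distributions to line up.
assert student.output_shape == teacher.output_shape == (None, 1000)
```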

neural_compressor/adaptor/tensorflow.py

Lines changed: 6 additions & 3 deletions
@@ -120,10 +120,9 @@ def train(self, model, dataloader, optimizer_tuple,
         iters = kwargs['kwargs'].get('iteration', None)
         callbacks = kwargs['kwargs'].get('callbacks', None)
         distributed = getattr(dataloader, 'distributed', False)
-
         from neural_compressor.experimental.common.criterion import TensorflowKnowledgeDistillationLoss
         if isinstance(criterion, TensorflowKnowledgeDistillationLoss):
-           input_model = model._model
+            input_model = model._model
         else:
             input_model = tf.keras.models.load_model(model._model)
             hooks = callbacks['tf_pruning'](model, input_model, hooks)
@@ -148,7 +147,7 @@ def train(self, model, dataloader, optimizer_tuple,
         def training_step(first_batch):
             with tf.GradientTape() as tape:
                 tape.watch(input_model.trainable_variables)
-                y_ = input_model(x, training=True)
+                y_ = input_model(x)
                 loss_value = criterion(y, y_)
                 tape = self.hvd.DistributedGradientTape(tape) if distributed else tape
                 # Get gradient
@@ -178,17 +177,21 @@ def training_step(first_batch):
                 hooks['on_batch_end']() # on_batch_end hook
                 if iters is not None and cnt >= iters:
                     break
+            model._sess = None
             hooks['on_epoch_end']() # on_epoch_end hook
             # End epoch
             train_loss_results.append(epoch_loss_avg.result())
             if not distributed or self.hvd.local_rank() == 0:
                 logger.info("Epoch {:03d}: Loss: {:.3f}".format(epoch+1, epoch_loss_avg.result()))
+
         hooks['post_epoch_end']() # post_epoch_end hook
         model._sess = None
         if not isinstance(criterion, TensorflowKnowledgeDistillationLoss):
             if not distributed or self.hvd.rank() == 0:
                 # Update the input model with pruned weights manually due to keras API limitation.
                 input_model.save(model._model)
+        else:
+            input_model.save('distillation_model')
 
     @dump_elapsed_time(customized_msg="Model inference")
     def evaluate(self, model, dataloader, postprocess=None,
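The `training_step` change is easiest to read in isolation. Below is a self-contained sketch of the tape-based step from this hunk, assuming a Keras model and a criterion with the `(labels, predictions)` signature used above. Note that dropping `training=True` means layers such as BatchNorm and Dropout run in inference mode during the distillation fine-tuning:

```python
import tensorflow as tf

def make_training_step(input_model, criterion, optimizer):
    """Sketch of the adaptor's per-batch update; names mirror the diff."""
    def training_step(x, y):
        with tf.GradientTape() as tape:
            tape.watch(input_model.trainable_variables)
            y_ = input_model(x)            # no training=True after this commit
            loss_value = criterion(y, y_)
        # Compute and apply gradients for the student model's variables.
        grads = tape.gradient(loss_value, input_model.trainable_variables)
        optimizer.apply_gradients(zip(grads, input_model.trainable_variables))
        return loss_value
    return training_step
```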

neural_compressor/experimental/common/criterion.py

Lines changed: 11 additions & 3 deletions
@@ -20,6 +20,8 @@
 from neural_compressor.utils.utility import LazyImport, singleton
 from neural_compressor.utils import logger
 
+import numpy as np
+
 torch = LazyImport('torch')
 tf = LazyImport('tensorflow')
 
@@ -193,7 +195,9 @@ def loss_cal(self, student_outputs, targets):
 
     def __call__(self, student_outputs, targets):
         if isinstance(self, TensorflowKnowledgeDistillationLoss):
-            student_outputs, targets = targets, student_outputs
+            tmp = student_outputs
+            student_outputs = targets
+            targets = tmp
         return self.loss_cal(student_outputs, targets)
 
 class PyTorchKnowledgeDistillationLoss(KnowledgeDistillationLoss):
@@ -293,22 +297,26 @@ def __init__(self, temperature=1.0, loss_types=['CE', 'CE'],
                          loss_weights=loss_weights)
         if self.student_targets_loss is None:
             if self.loss_types[0] == 'CE':
-                self.student_targets_loss = tf.nn.sparse_softmax_cross_entropy_with_logits
+                self.student_targets_loss = tf.keras.losses.SparseCategoricalCrossentropy()
             else:
                 raise NotImplementedError('Now we only support CrossEntropyLoss '
                                           'for loss of student model output with respect to targets.')
         logger.info('student_targets_loss: {}, {}'.format(self.loss_types[0], \
                     self.loss_weights[0]))
         if self.teacher_student_loss is None:
             if self.loss_types[1] == 'CE':
-                self.teacher_student_loss = tf.keras.losses.CategoricalCrossentropy()
+                self.teacher_student_loss = self.SoftCrossEntropy
             elif self.loss_types[1] == 'KL':
                 self.teacher_student_loss = tf.keras.losses.KLDivergence()
             else:
                 raise NotImplementedError('Now we only support CrossEntropyLoss'
                                           ' for loss of student model output with respect to teacher model ouput.')
         logger.info('teacher_student_loss: {}, {}'.format(self.loss_types[1], \
                     self.loss_weights[1]))
+    def SoftCrossEntropy(self, targets, logits):
+        log_prob = tf.math.log(logits)
+        targets_prob = targets
+        return tf.math.reduce_mean(tf.math.reduce_sum(- targets_prob * log_prob, axis=-1), axis=-1)
 
     def teacher_model_forward(self, input, teacher_model=None):
         if self.loss_weights[1] > 0 and input is not None:
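One subtlety in the added `SoftCrossEntropy`: despite the parameter name `logits`, it applies `tf.math.log` to its second argument directly, so it must be fed probabilities (softmax outputs), not raw logits. A standalone numeric check of the formula:

```python
import tensorflow as tf

def soft_cross_entropy(targets, probs):
    # Same math as TensorflowKnowledgeDistillationLoss.SoftCrossEntropy:
    # batch mean of -sum(p_teacher * log p_student).
    log_prob = tf.math.log(probs)
    return tf.math.reduce_mean(tf.math.reduce_sum(-targets * log_prob, axis=-1), axis=-1)

teacher_probs = tf.constant([[0.7, 0.2, 0.1]])
student_probs = tf.constant([[0.6, 0.3, 0.1]])
print(soft_cross_entropy(teacher_probs, student_probs).numpy())
# 0.7*(-ln 0.6) + 0.2*(-ln 0.3) + 0.1*(-ln 0.1) ≈ 0.829
```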

neural_compressor/experimental/data/datasets/dataset.py

Lines changed: 2 additions & 0 deletions
@@ -701,6 +701,8 @@ def __getitem__(self, index):
         sample = self.image_list[index]
         label = sample[1]
         with Image.open(sample[0]) as image:
+            if image.mode != 'RGB':
+                image = image.convert('RGB')
             image = np.array(image)
         if self.transform is not None:
             image, label = self.transform((image, label))
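The two added lines guard against non-RGB inputs: a grayscale or palette image decodes to a 2-D array, which would break channel-dependent transforms further down the pipeline. A quick illustration:

```python
import numpy as np
from PIL import Image

gray = Image.new('L', (224, 224))            # grayscale image, as found in some ImageNet folders
print(np.array(gray).shape)                  # (224, 224)  -- no channel axis
print(np.array(gray.convert('RGB')).shape)   # (224, 224, 3)
```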

neural_compressor/experimental/data/transforms/transform.py

Lines changed: 18 additions & 0 deletions
@@ -1227,6 +1227,24 @@ def __call__(self, sample):
         image -= self.rescale[1]
         return (image, label)
 
+@transform_registry(transform_type='KerasRescale', process="preprocess", \
+                    framework='tensorflow')
+class RescaleKerasPretrainTransform(BaseTransform):
+    """Scale the values of image to [0,1].
+
+    Returns:
+        tuple of processed image and label
+    """
+    def __init__(self, rescale=None):
+        self.rescale = rescale
+
+    def __call__(self, sample):
+        image, label = sample
+        if self.rescale:
+            image /= self.rescale[0]
+            image -= self.rescale[1]
+        return (image, label)
+
 @transform_registry(transform_type='Rescale', process="preprocess", \
                     framework='tensorflow')
 class RescaleTFTransform(BaseTransform):
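With the `rescale: [127.5, 1]` values from conf.yaml, the new transform computes `image / 127.5 - 1`, mapping pixels from [0, 255] to [-1, 1] (the same convention as Keras's 'tf'-mode `preprocess_input`), so the class docstring's "[0,1]" understates the actual output range. A quick check on float input:

```python
import numpy as np

image = np.array([0.0, 127.5, 255.0])   # float input; the transform divides in place
image /= 127.5
image -= 1
print(image)                             # [-1.  0.  1.]
```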
