|
20 | 20 | from neural_compressor.utils.utility import LazyImport, singleton |
21 | 21 | from neural_compressor.utils import logger |
22 | 22 |
|
| 23 | +import numpy as np |
| 24 | + |
23 | 25 | torch = LazyImport('torch') |
24 | 26 | tf = LazyImport('tensorflow') |
25 | 27 |
|
@@ -193,7 +195,9 @@ def loss_cal(self, student_outputs, targets): |
193 | 195 |
|
def __call__(self, student_outputs, targets):
    """Compute the distillation loss for one batch.

    Args:
        student_outputs: outputs produced by the student model.
        targets: ground-truth labels (interpreted downstream by
            ``self.loss_cal``).

    Returns:
        The loss value computed by ``self.loss_cal``.
    """
    if isinstance(self, TensorflowKnowledgeDistillationLoss):
        # TensorFlow loss callables take (y_true, y_pred) -- the reverse of
        # the (input, target) order used elsewhere in this class -- so the
        # arguments must be swapped for the TF subclass.  Tuple unpacking
        # swaps in a single step; no temporary variable is needed.
        student_outputs, targets = targets, student_outputs
    return self.loss_cal(student_outputs, targets)
198 | 202 |
|
199 | 203 | class PyTorchKnowledgeDistillationLoss(KnowledgeDistillationLoss): |
@@ -293,22 +297,26 @@ def __init__(self, temperature=1.0, loss_types=['CE', 'CE'], |
293 | 297 | loss_weights=loss_weights) |
294 | 298 | if self.student_targets_loss is None: |
295 | 299 | if self.loss_types[0] == 'CE': |
296 | | - self.student_targets_loss = tf.nn.sparse_softmax_cross_entropy_with_logits |
| 300 | + self.student_targets_loss = tf.keras.losses.SparseCategoricalCrossentropy() |
297 | 301 | else: |
298 | 302 | raise NotImplementedError('Now we only support CrossEntropyLoss ' |
299 | 303 | 'for loss of student model output with respect to targets.') |
300 | 304 | logger.info('student_targets_loss: {}, {}'.format(self.loss_types[0], \ |
301 | 305 | self.loss_weights[0])) |
302 | 306 | if self.teacher_student_loss is None: |
303 | 307 | if self.loss_types[1] == 'CE': |
304 | | - self.teacher_student_loss = tf.keras.losses.CategoricalCrossentropy() |
| 308 | + self.teacher_student_loss = self.SoftCrossEntropy |
305 | 309 | elif self.loss_types[1] == 'KL': |
306 | 310 | self.teacher_student_loss = tf.keras.losses.KLDivergence() |
307 | 311 | else: |
308 | 312 | raise NotImplementedError('Now we only support CrossEntropyLoss' |
309 | 313 | ' for loss of student model output with respect to teacher model ouput.') |
310 | 314 | logger.info('teacher_student_loss: {}, {}'.format(self.loss_types[1], \ |
311 | 315 | self.loss_weights[1])) |
def SoftCrossEntropy(self, targets, logits):
    """Soft-label cross entropy between teacher and student distributions.

    Computes ``mean_over_batch( sum_over_classes( -targets * log(logits) ) )``.

    NOTE(review): despite the parameter name, ``logits`` is consumed as a
    probability distribution -- ``tf.math.log`` is applied directly with no
    softmax, so callers must pass softmax-ed outputs, and a zero probability
    yields ``-inf``.  Confirm against the call sites (the PyTorch counterpart
    typically uses ``log_softmax`` on raw logits).

    Args:
        targets: teacher probability distribution, shape (batch, classes).
        logits: student probability distribution, shape (batch, classes).

    Returns:
        Scalar mean cross-entropy over the batch.
    """
    # The original aliased `targets` to `targets_prob` without modification;
    # the redundant local is removed -- the computation is unchanged.
    log_prob = tf.math.log(logits)
    return tf.math.reduce_mean(
        tf.math.reduce_sum(-targets * log_prob, axis=-1), axis=-1)
312 | 320 |
|
313 | 321 | def teacher_model_forward(self, input, teacher_model=None): |
314 | 322 | if self.loss_weights[1] > 0 and input is not None: |
|
0 commit comments