import logging

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)


def _cross_entropy_pytorch(logits, target, ignore_index=-100, reduction="mean"):
    # Compute log-probabilities in float32 for numerical stability, then take
    # the negative log-likelihood. F.nll_loss expects an integer ignore_index
    # (PyTorch's own default is -100), so default to -100 here rather than
    # None, which F.nll_loss would reject.
    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    return F.nll_loss(
        lprobs,
        target,
        ignore_index=ignore_index,
        reduction=reduction,
    )
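

# If NVIDIA apex is installed with its fused softmax cross entropy CUDA
# kernel, use it for GPU inputs; otherwise fall back to the pure PyTorch
# implementation above.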
try:
    import xentropy_cuda  # noqa: F401  (probe that the compiled kernel exists)
    from apex.contrib import xentropy

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        if logits.device == torch.device("cpu"):
            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
        else:
            # Log once, the first time the fused kernel is actually used.
            if not getattr(cross_entropy, "_has_logged_once", False):
                logger.info("using fused cross entropy")
                cross_entropy._has_logged_once = True

            # The fused kernel upcasts half-precision logits to float32
            # internally when half_to_float is set.
            half_to_float = logits.dtype == torch.half
            losses = xentropy.SoftmaxCrossEntropyLoss.apply(
                logits,
                target,
                0.0,  # no label smoothing
                ignore_index,
                half_to_float,
            )
            if reduction == "sum":
                return losses.sum()
            elif reduction == "mean":
                if ignore_index >= 0:
                    # Average only over targets that are not ignored.
                    return losses.sum() / target.ne(ignore_index).sum()
                else:
                    return losses.mean()
            elif reduction == "none":
                return losses
            else:
                raise NotImplementedError(f"unsupported reduction: {reduction}")
except ImportError:

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
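

# Minimal usage sketch (illustrative, not part of the original module): the
# shapes and values below are assumptions chosen only to show the call
# pattern. On CPU tensors this always takes the PyTorch fallback; with apex
# installed and CUDA tensors it would take the fused path instead.
if __name__ == "__main__":
    batch, vocab = 4, 10
    logits = torch.randn(batch, vocab)
    target = torch.randint(0, vocab, (batch,))
    target[0] = -100  # mark one position as padding to be ignored

    loss = cross_entropy(logits, target, ignore_index=-100, reduction="mean")
    print(f"cross entropy loss: {loss.item():.4f}")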