torch--drop out

星空28發表於2024-11-07
"""
drop out隨機丟棄神經元
"""

import torch


def dropout_layer(X, dropout):
    assert 0 <= dropout <= 1
    # 在本情況中,所有元素都被丟棄
    if dropout == 1:
        return torch.zeros_like(X)
    # 在本情況中,所有元素都被保留
    if dropout == 0:
        return X

    mask = (torch.rand(X.shape) > dropout).float()
    return mask * X / (1.0 - dropout)


X = torch.arange(16, dtype=torch.float32).reshape((2, 8))
print(X)
print(dropout_layer(X, 0.))
print(dropout_layer(X, 0.5))
print(dropout_layer(X, 1.))

相關文章