
Conversation

stas00 (Contributor) commented Sep 29, 2021

This PR implements @corbyrosset's suggestion for how to overcome the 104B numerical instability we have been experiencing. Quoting:

Re: 104B instability (https://huggingface.slack.com/archives/C01NHER1JLS/p1632801340055000) One thing I've encountered before is how the self-attention is computed. E.g. this line shows that the norm_factor may be multiplied after the Query * Key matrix multiplication. If the dims of Q and K are very large, the output may blow up and the norm_factor won't be able to save it.

Proposal: move the norm_factor inward, so Q and K are scaled down before matrix multiply:

        matmul_result = torch.baddbmm(
            matmul_result,
            1.0/math.sqrt(self.norm_factor) * query_layer.transpose(0, 1),   # [b * np, sq, hn]
            1.0/math.sqrt(self.norm_factor) * key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0 if alibi is None else 1.0, alpha=1.0)

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

To keep the operation mathematically equivalent, moving the norm factor inward requires taking its square root, since for a scalar n and matrices A and B:

n * (A dot B) === (sqrt(n) * A) dot (sqrt(n) * B)
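
A quick self-contained check of that identity as it applies to the baddbmm call above (a minimal sketch with made-up shapes; `norm_factor` here simply stands in for the module's `self.norm_factor`):

    import math
    import torch

    # made-up sizes: b*np = 6 batched heads, sq = sk = 5, head dim hn = 64
    bnp, sq, sk, hn = 6, 5, 5, 64
    query = torch.randn(bnp, sq, hn)
    key = torch.randn(bnp, sk, hn)
    out = torch.zeros(bnp, sq, sk)
    norm_factor = math.sqrt(hn)

    # previous style: the raw Q * K^T product is formed, then scaled via alpha = 1/norm_factor
    scaled_after = torch.baddbmm(
        out, query, key.transpose(1, 2), beta=0.0, alpha=1.0 / norm_factor)

    # proposed style: each operand is pre-scaled by 1/sqrt(norm_factor), alpha stays 1.0
    s = 1.0 / math.sqrt(norm_factor)
    scaled_inward = torch.baddbmm(
        out, s * query, s * key.transpose(1, 2), beta=0.0, alpha=1.0)

    assert torch.allclose(scaled_after, scaled_inward, atol=1e-5)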

Also thanks to @RezaYazdaniAminabadi, who helped find where this function is defined in cuBLAS (https://docs.nvidia.com/cuda/cublas/index.html#cublas-GemmStridedBatchedEx), which includes the definition:

C + i*strideC = α · op(A + i*strideA) · op(B + i*strideB) + β · (C + i*strideC),  for i ∈ [0, batchCount − 1]

The issue is that alpha is multiplied in only after the matrix-matrix multiplication is done, so the unscaled product can blow up and cause instability.
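
A toy fp16 illustration of that failure mode (hypothetical sizes and magnitudes, chosen only so the raw dot product exceeds fp16's maximum of ~65504):

    import torch

    hn = 256                                           # made-up per-head dim
    q = torch.full((hn,), 16.0, dtype=torch.float16)
    k = torch.full((hn,), 16.0, dtype=torch.float16)
    scale = 1.0 / hn                                   # plays the role of alpha

    # scaling after the product: the raw sum is 256 * 16 * 16 = 65536,
    # which already overflows to inf in fp16 before alpha can rescale it
    late = scale * (q * k).sum()

    # scaling each operand first: every intermediate value stays in range
    s = scale ** 0.5
    early = ((s * q) * (s * k)).sum()

    print(late, early)  # expected: inf vs. 256.0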

stas00 merged commit 48902f1 into tr8-104B Sep 29, 2021
stas00 deleted the stas00-patch-1 branch September 29, 2021 02:05