Conversation
We do operations on int16s because there is (surprisingly) no instruction to multiply-add signed int8 vectors.
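The need for widening can be seen in plain Go: the product of two int8 values does not fit in an int8, so the bytes must be sign-extended before multiplying. A minimal sketch (not the PR's code):

```go
package main

import "fmt"

func main() {
	// Two int8 values whose product overflows int8: -128 * -128 = 16384,
	// but int8 tops out at 127.
	a, b := int8(-128), int8(-128)

	// Widening to int16 before multiplying keeps the exact product, which
	// is why the SIMD path sign-extends bytes to words before the
	// multiply-add.
	p := int16(a) * int16(b)
	fmt.Println(p) // 16384
}
```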
This section sums the 8 32-bit ints in Y0 by repeatedly folding the vector in half and adding vertically. We are left with the sum in the rightmost position of X0.
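The fold-and-add reduction can be sketched in scalar Go: halve the lane count three times (8 → 4 → 2 → 1), adding the upper half onto the lower half at each step. This mirrors the register shuffles, not the PR's actual code:

```go
package main

import "fmt"

func main() {
	// Eight 32-bit lanes, as they might sit in Y0 after the multiply-adds.
	lanes := []int32{1, 2, 3, 4, 5, 6, 7, 8}

	// Fold the vector in half and add vertically, three times.
	for n := len(lanes); n > 1; n /= 2 {
		half := n / 2
		for i := 0; i < half; i++ {
			lanes[i] += lanes[i+half]
		}
	}
	fmt.Println(lanes[0]) // 36, the sum of all eight lanes
}
```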
In case our input is not a multiple of 16 (which it will be for OpenAI embeddings), this handles the remainder.
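The block-plus-tail structure can be illustrated in scalar Go: process full 16-element blocks (the SIMD body), then fall back to a one-at-a-time loop for the remaining `len % 16` elements. A hypothetical sketch, not the PR's code:

```go
package main

import "fmt"

func main() {
	a := make([]int8, 37) // deliberately not a multiple of 16
	b := make([]int8, 37)
	for i := range a {
		a[i], b[i] = int8(i), 1
	}

	var sum int32
	i := 0
	// Full 16-element blocks: 37/16 = 2 blocks, covering 32 elements.
	for ; i+16 <= len(a); i += 16 {
		for j := i; j < i+16; j++ {
			sum += int32(a[j]) * int32(b[j])
		}
	}
	// Tail: the remaining 37 % 16 = 5 elements.
	for ; i < len(a); i++ {
		sum += int32(a[i]) * int32(b[i])
	}
	fmt.Println(sum) // 0+1+...+36 = 666
}
```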
```
goos: linux
goarch: amd64
pkg: github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings
cpu: Intel(R) Xeon(R) CPU E5-2643 v3 @ 3.40GHz
                                  │ /tmp/before.txt │        /tmp/after.txt         │
                                  │     sec/op      │   sec/op     vs base          │
SimilaritySearch/numWorkers=1-24       1817.3m ± 48%   286.2m ± 67%  -84.25%  (p=0.000 n=10)
SimilaritySearch/numWorkers=2-24        845.2m ± 31%   150.9m ± 21%  -82.15%  (p=0.000 n=10)
SimilaritySearch/numWorkers=4-24        593.8m ± 21%   107.0m ± 15%  -81.99%  (p=0.000 n=10)
SimilaritySearch/numWorkers=8-24       302.67m ± 19%   77.49m ± 14%  -74.40%  (p=0.000 n=10)
SimilaritySearch/numWorkers=16-24      173.93m ± 12%   83.05m ±  6%  -52.25%  (p=0.000 n=10)
geomean                                 544.9m         124.3m        -77.18%
```
```go
func CosineSimilarity(row []int8, query []int8) int32 {
	similarity := int32(0)

	count := len(row)
	if count > len(query) {
		// Do this ahead of time so the compiler doesn't need to bounds check
		// every time we index into query.
		panic("mismatched vector lengths")
	}

	i := 0
	for ; i+3 < count; i += 4 {
		m0 := int32(row[i]) * int32(query[i])
		m1 := int32(row[i+1]) * int32(query[i+1])
		m2 := int32(row[i+2]) * int32(query[i+2])
		m3 := int32(row[i+3]) * int32(query[i+3])
		similarity += (m0 + m1 + m2 + m3)
	}

	for ; i < count; i++ {
		similarity += int32(row[i]) * int32(query[i])
	}

	return similarity
}

func CosineSimilarityFloat32(row []float32, query []float32) float32 {
	similarity := float32(0)

	count := len(row)
	if count > len(query) {
		// Do this ahead of time so the compiler doesn't need to bounds check
		// every time we index into query.
		panic("mismatched vector lengths")
	}

	i := 0
	for ; i+3 < count; i += 4 {
		m0 := row[i] * query[i]
		m1 := row[i+1] * query[i+1]
		m2 := row[i+2] * query[i+2]
		m3 := row[i+3] * query[i+3]
		similarity += (m0 + m1 + m2 + m3)
	}

	for ; i < count; i++ {
		similarity += row[i] * query[i]
	}

	return similarity
}
```
I moved these into dot.go and renamed them to Dot*. The dot product is only equivalent to cosine similarity if the vectors are normalized, so I think the rename is justified in the case we ever use non-normalized vectors.
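That equivalence is easy to check numerically: cosine similarity divides the dot product by both vector norms, so on unit-normalized vectors the plain dot product already gives the cosine. A small sketch (names here are illustrative, not the PR's API):

```go
package main

import (
	"fmt"
	"math"
)

// dot is a plain scalar dot product.
func dot(a, b []float32) float32 {
	var s float32
	for i := range a {
		s += a[i] * b[i]
	}
	return s
}

// norm returns v scaled to unit length.
func norm(v []float32) []float32 {
	var n float64
	for _, x := range v {
		n += float64(x) * float64(x)
	}
	n = math.Sqrt(n)
	out := make([]float32, len(v))
	for i, x := range v {
		out[i] = float32(float64(x) / n)
	}
	return out
}

func main() {
	a := []float32{3, 4} // length 5
	b := []float32{4, 3} // length 5

	// Cosine similarity = dot(a, b) / (|a| * |b|) = 24/25 = 0.96.
	cos := dot(a, b) / (5 * 5)

	// On normalized vectors, the bare dot product matches it.
	fmt.Println(math.Abs(float64(cos-dot(norm(a), norm(b)))) < 1e-6)
}
```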
jtibshirani left a comment:
This makes sense to me, but I'm not up-to-speed on Go assembly! Maybe there's someone more knowledgeable who could give a timely review?
Also, it's interesting that when testing on the GCE instance, we get a max 2x speedup. This is different from what we observed when testing locally, where the speedup scales with the number of workers: https://github.com/sourcegraph/sourcegraph/pull/51372. It means we should limit the request parallelism to something more conservative like 2 threads, rather than the number of processors as we do now. This would be for a follow-up, as it's separate from this PR.
vdavid left a comment:
I went through the code (on mobile 😄 ) and it LGTM. It's been two decades since I last did assembly and, of course, I have no idea of these new instructions, but it looks reasonable, and the test coverage is convincing. Overall, wow, impressive work!
jtibshirani left a comment:
Looks good to me from the config and search side. I will dig through the assembly at another time :)
|
The backport to To backport manually, run these commands in your terminal: # Fetch latest updates from GitHub
git fetch
# Create a new working tree
git worktree add .worktrees/backport-5.0 5.0
# Navigate to the new working tree
cd .worktrees/backport-5.0
# Create a new branch
git switch --create backport-51372-to-5.0
# Cherry-pick the merged commit of this pull request and resolve the conflicts
git cherry-pick -x --mainline 1 17a8ec942c1eaca26ae62191460e7ff9bd6285aa
# Push it to GitHub
git push --set-upstream origin backport-51372-to-5.0
# Go back to the original working tree
cd ../..
# Delete the working tree
git worktree remove .worktrees/backport-5.0Then, create a pull request where the |
This implements a hand-written assembly version of the int8 dot product that takes advantage of AVX2 SIMD instructions. This speeds up our embeddings searches by roughly 10x on modern x86_64 machines. (cherry picked from commit 17a8ec9)
```
@@ -0,0 +1,70 @@
#include "textflag.h"
```
This file is missing from this patch; maybe remove this #include?
What do you mean "missing from this patch"? The #include "textflag.h" is needed to define the NOSPLIT symbol
If you mean "you didn't commit a textflag.h file", it's a go compiler builtin
```
	SUBQ $16, DX
	JMP  blockloop

reduce:
```
Given that reduce: and tailloop: are running only once (or are very small), it would make things a bit simpler to remove the assembly for them and write them in Go. (Assuming it's possible to access Y0 from Go, otherwise it would make sense to leave the reduce code in assembly)
AFAIK, it is not possible to access Y0 from Go without assembly
```
	MOVQ a_base+0(FP), AX
	MOVQ b_base+24(FP), BX
	MOVQ a_len+8(FP), DX
```
Is this hard-coding the stack offsets based on the ABI of a slice? If so, it'll break if the compiler starts using SROA for slices.
Instead, you could use https://pkg.go.dev/unsafe#SliceData to get the underlying pointer in a stable way. Then this function would take in two pointers and the one length as the arguments, instead of hard-coding stack offsets here.
Otherwise, at least add a comment describing where these hard-coded constants come from.
My understanding was that, unless I opt into ABIInternal (or any future stable ABI), I can depend on the current (stack-based) ABI to be stable.
Of note, the a_base notation is a mnemonic that is checked by go vet. So if the field offset does not line up with the FP offset I specified there, go vet will complain.
I'll add some comments describing the offsets though 👍
```
	VPMOVSXBW (AX), Y1
	VPMOVSXBW (BX), Y2
```
Could you add a link to the calling convention where it's described whether these registers are preserved or clobbered across a call? It seems like this code is assuming that all the registers it is using are caller-preserve (i.e. OK to clobber).
See "Clobber sets" here. Based on my read of it, I do not need to worry about callee-saved registers
```
	JMP tailloop

end:
	MOVQ R8, ret+48(FP)
```
Is there a more "stable"/"reliable" way to do this rather than hard-coding the stack offset? IIUC this is just writing the return value; it'll also break if Go starts returning small return values in registers.
I'm surprised this is actually working; I thought Go started using a register-based calling convention recently. Maybe only for parameters, or for functions implemented in Go?
All the examples I've seen hard-code the stack offset. I agree it's awkward and error-prone.
The register-based calling convention is only used for compiled go source code unless you opt into it with the ABIInternal flag on the function definition. So, since I did not opt into it, I am using the old stack-based calling convention
This will also be caught by go vet though. ret is the implicit return variable name, so go vet will check that 48(FP) is, in fact, the target address for the return value
Actually though, I thought go vet runs in CI. It does not. Apparently, I miscalculated the frame size. PR incoming.
go vet should be running 🏃♀️ It runs as part of the nogo linters 🤔 I'll double check the BAZEL config to make sure
keegancsmith left a comment:
neat.
> Also, it's interesting that when testing on the GCE instance, we get a max 2x speedup
@jtibshirani is it possible that the node you are running on has a lot more CPUs than kubernetes is configured to let you use? So we end up doing too much parallelism / something else is confusing in the measurements. I would make sure we are using automaxprocs: https://sourcegraph.com/search?q=context:global+r:%5Egithub%5C.com/sourcegraph/+maxprocs.Set&patternType=standard&sm=0&groupBy=repo
```go
got := Dot(a, b)

if want != got {
	t.Fatalf("a: %#v\nb: %#v\ngot: %d\nwant: %d", a, b, got, want)
}
```
Use `t.Log` here instead; otherwise you never return false.
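The pattern being discussed can be sketched with `testing/quick`: a property function should log a mismatch and return false so the checker can report the failing input, rather than aborting via `t.Fatalf`. The `dot` helper below is a hypothetical scalar reference, not the PR's implementation:

```go
package main

import (
	"fmt"
	"testing/quick"
)

// dot is a scalar reference dot product over the common prefix.
func dot(a, b []int8) int32 {
	n := len(a)
	if len(b) < n {
		n = len(b)
	}
	var s int32
	for i := 0; i < n; i++ {
		s += int32(a[i]) * int32(b[i])
	}
	return s
}

func main() {
	// Property: the dot product is commutative. In the real test the
	// comparison would be between the Go and assembly implementations.
	commutative := func(a, b []int8) bool {
		return dot(a, b) == dot(b, a)
	}
	if err := quick.Check(commutative, nil); err != nil {
		fmt.Println("property failed:", err)
		return
	}
	fmt.Println("ok")
}
```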
@keegancsmith, the benchmarks were not running in kubernetes, so it's unlikely to be related to reserved CPUs. For clarity, we're seeing different behaviors on the different machines I've tested on. M1 scales linearly up to 8 cores, which is what we based our initial assumption of scaling on. My home server (2014 Intel 12 core) scales linearly up to 16 cores without SIMD, but only up to 8 cores with SIMD. I expect it's starting to hit memory bandwidth and/or cache limits with the SIMD implementation (M1s have stupidly good memory bandwidth). The GCE n2-standard-4 scales up to 2 cores without SIMD, and 4 cores with SIMD, but the 4 cores is only ~2x faster than 1 core. This is where the "2x" number is coming from. This problem should be very parallel-friendly, but we could be hitting cache effects. I put together a spreadsheet with the numbers I'm working from. Note that these numbers aren't super rigorous and I haven't looked into this very closely. It's more of an observation that I thought was interesting and probably deserves a little bit of looking to make sure we're not throwing more CPU at the problem than we can use.
Something we could consider here: it might be worth building another set of binaries that unlock these gains at the runtime level by setting `GOAMD64=v3`.
AFAICT, the runtime uses very few v3-specific features so far. Am I looking at the right thing?
I'm comparing this hand-written assembly to what clang generates from C++ code, and shouldn't there be a …
This implements a hand-written assembly version of the int8 dot product that takes advantage of AVX2 SIMD instructions. This speeds up our embeddings searches by roughly 10x on modern x86_64 machines.
Test plan
Added quickchecks and fuzz tests to compare output with the Go version.
The following benchmark is for an `n2-standard-4` GCE instance, which is a very standard machine type. For the single-core benchmark, we can search about 6 million embeddings per second, which is the equivalent of a 6GB monorepo. This is low-risk to merge because it is disabled by default.