-
Notifications
You must be signed in to change notification settings - Fork 74
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Go API - [WIP] #212
base: branch-25.02
Are you sure you want to change the base?
Go API - [WIP] #212
Changes from 52 commits
7961798
764dc1e
b0fd0b1
b47cf5b
1ec0c8c
d5c0dc3
cc6319e
01af7c9
ed9bf47
4e79d8e
d278abc
b91721d
6b90861
773fd94
7cbc1f9
46ec2f7
9dcffbd
f82216c
486d40a
6f5c5a6
bfdf3be
de3cea0
57e8dc0
ab173dc
2d5fb95
505d8dc
c3360ee
e261c8a
a4890ed
f2bac2d
a684b01
7d45e24
a7084c2
3af1b73
44d9e58
9447f63
04b6532
a082f67
a84e764
853d538
4021229
6dd2044
3e33692
b1a0476
dddb165
e05782a
a315632
3c97864
bd2dd76
bd76cf8
1e5a756
927f2be
f7fac35
d814e37
2708bc0
e634821
8c9105a
4a80cf7
f5b8e72
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,6 +47,19 @@ jobs: | |
node_type: "gpu-v100-latest-1" | ||
run_script: "ci/build_rust.sh" | ||
sha: ${{ inputs.sha }} | ||
go-build: | ||
needs: cpp-build | ||
secrets: inherit | ||
uses: rapidsai/shared-workflows/.github/workflows/[email protected] | ||
with: | ||
build_type: ${{ inputs.build_type || 'branch' }} | ||
branch: ${{ inputs.branch }} | ||
arch: "amd64" | ||
date: ${{ inputs.date }} | ||
container_image: "rapidsai/ci-conda:latest" | ||
node_type: "gpu-v100-latest-1" | ||
run_script: "ci/build_go.sh" | ||
sha: ${{ inputs.sha }} | ||
python-build: | ||
needs: [cpp-build] | ||
secrets: inherit | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ jobs: | |
- conda-python-tests | ||
- docs-build | ||
- rust-build | ||
- go-build | ||
- wheel-build-cuvs | ||
- wheel-tests-cuvs | ||
- devcontainer | ||
|
@@ -44,13 +45,15 @@ jobs: | |
- '!notebooks/**' | ||
- '!python/**' | ||
- '!rust/**' | ||
- '!go/**' | ||
- '!thirdparty/LICENSES/**' | ||
test_notebooks: | ||
- '**' | ||
- '!.devcontainer/**' | ||
- '!.pre-commit-config.yaml' | ||
- '!README.md' | ||
- '!rust/**' | ||
- '!go/**' | ||
- '!thirdparty/LICENSES/**' | ||
test_python: | ||
- '**' | ||
|
@@ -61,6 +64,7 @@ jobs: | |
- '!img/**' | ||
- '!notebooks/**' | ||
- '!rust/**' | ||
- '!go/**' | ||
- '!thirdparty/LICENSES/**' | ||
checks: | ||
secrets: inherit | ||
|
@@ -122,6 +126,16 @@ jobs: | |
arch: "amd64" | ||
container_image: "rapidsai/ci-conda:latest" | ||
run_script: "ci/build_rust.sh" | ||
go-build: | ||
needs: conda-cpp-build | ||
secrets: inherit | ||
uses: rapidsai/shared-workflows/.github/workflows/[email protected] | ||
with: | ||
build_type: pull-request | ||
node_type: "gpu-v100-latest-1" | ||
arch: "amd64" | ||
container_image: "rapidsai/ci-conda:latest" | ||
run_script: "ci/build_go.sh" | ||
wheel-build-cuvs: | ||
needs: checks | ||
secrets: inherit | ||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,40 @@ | ||||
#!/bin/bash | ||||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||||
|
||||
set -euo pipefail | ||||
|
||||
rapids-logger "Create test conda environment" | ||||
. /opt/conda/etc/profile.d/conda.sh | ||||
|
||||
RAPIDS_VERSION="$(rapids-version)" | ||||
|
||||
rapids-dependency-file-generator \ | ||||
--output conda \ | ||||
--file-key go \ | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you know how I could add this file key to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a key under Line 70 in 710e9f5
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. More explanation in the README here: https://github.com/rapidsai/dependency-file-generator |
||||
--matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" | tee env.yaml | ||||
|
||||
rapids-mamba-retry env create --yes -f env.yaml -n go | ||||
|
||||
# seeing failures on activating the environment here on unbound locals | ||||
# apply workaround from https://github.com/conda/conda/issues/8186#issuecomment-532874667 | ||||
set +eu | ||||
conda activate go | ||||
set -eu | ||||
|
||||
rapids-print-env | ||||
|
||||
export CGO_CFLAGS="-I/usr/local/cuda/include -I/home/ajit/miniforge3/envs/cuvs/include" | ||||
export CGO_LDFLAGS="-L/usr/local/cuda/lib64 -L/home/ajit/miniforge3/envs/cuvs/lib -lcudart -lcuvs -lcuvs_c" | ||||
export CC=clang | ||||
|
||||
rapids-logger "Downloading artifacts from previous jobs" | ||||
CPP_CHANNEL=$(rapids-download-conda-from-s3 cpp) | ||||
|
||||
# installing libcuvs/libraft will speed up the rust build substantially | ||||
rapids-mamba-retry install \ | ||||
--channel "${CPP_CHANNEL}" \ | ||||
libcuvs \ | ||||
libraft \ | ||||
cuvs | ||||
|
||||
bash ./build.sh go |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
package brute_force | ||
|
||
// #include <cuvs/neighbors/brute_force.h> | ||
import "C" | ||
|
||
import ( | ||
"errors" | ||
"unsafe" | ||
|
||
cuvs "github.com/rapidsai/cuvs/go" | ||
) | ||
|
||
type bruteForceIndex struct { | ||
index C.cuvsBruteForceIndex_t | ||
trained bool | ||
} | ||
|
||
func CreateIndex() (*bruteForceIndex, error) { | ||
var index C.cuvsBruteForceIndex_t | ||
|
||
err := cuvs.CheckCuvs(cuvs.CuvsError(C.cuvsBruteForceIndexCreate(&index))) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return &bruteForceIndex{index: index, trained: false}, nil | ||
} | ||
|
||
func (index *bruteForceIndex) Close() error { | ||
err := cuvs.CheckCuvs(cuvs.CuvsError(C.cuvsBruteForceIndexDestroy(index.index))) | ||
if err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
func BuildIndex[T any](Resources cuvs.Resource, Dataset *cuvs.Tensor[T], metric cuvs.Distance, metric_arg float32, index *bruteForceIndex) error { | ||
CMetric, exists := cuvs.CDistances[metric] | ||
|
||
if !exists { | ||
return errors.New("cuvs: invalid distance metric") | ||
} | ||
|
||
err := cuvs.CheckCuvs(cuvs.CuvsError(C.cuvsBruteForceBuild(C.cuvsResources_t(Resources.Resource), (*C.DLManagedTensor)(unsafe.Pointer(Dataset.C_tensor)), C.cuvsDistanceType(CMetric), C.float(metric_arg), index.index))) | ||
if err != nil { | ||
return err | ||
} | ||
index.trained = true | ||
|
||
return nil | ||
} | ||
|
||
func SearchIndex[T any](resources cuvs.Resource, index bruteForceIndex, queries *cuvs.Tensor[T], neighbors *cuvs.Tensor[int64], distances *cuvs.Tensor[T]) error { | ||
if !index.trained { | ||
return errors.New("index needs to be built before calling search") | ||
} | ||
|
||
prefilter := C.cuvsFilter{ | ||
addr: 0, | ||
_type: C.NO_FILTER, | ||
} | ||
|
||
return cuvs.CheckCuvs(cuvs.CuvsError(C.cuvsBruteForceSearch(C.ulong(resources.Resource), index.index, (*C.DLManagedTensor)(unsafe.Pointer(queries.C_tensor)), (*C.DLManagedTensor)(unsafe.Pointer(neighbors.C_tensor)), (*C.DLManagedTensor)(unsafe.Pointer(distances.C_tensor)), prefilter))) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
package brute_force | ||
|
||
import ( | ||
"math/rand/v2" | ||
"testing" | ||
|
||
cuvs "github.com/rapidsai/cuvs/go" | ||
) | ||
|
||
func TestCagra(t *testing.T) { | ||
const ( | ||
nDataPoints = 1024 | ||
nFeatures = 16 | ||
nQueries = 4 | ||
k = 4 | ||
epsilon = 0.001 | ||
) | ||
|
||
resource, _ := cuvs.NewResource(nil) | ||
defer resource.Close() | ||
|
||
testDataset := make([][]float32, nDataPoints) | ||
for i := range testDataset { | ||
testDataset[i] = make([]float32, nFeatures) | ||
for j := range testDataset[i] { | ||
testDataset[i][j] = rand.Float32() | ||
} | ||
} | ||
|
||
dataset, err := cuvs.NewTensor(testDataset) | ||
if err != nil { | ||
t.Fatalf("error creating dataset tensor: %v", err) | ||
} | ||
defer dataset.Close() | ||
|
||
index, _ := CreateIndex() | ||
defer index.Close() | ||
|
||
// use the first 4 points from the dataset as queries : will test that we get them back | ||
// as their own nearest neighbor | ||
queries, _ := cuvs.NewTensor(testDataset[:nQueries]) | ||
defer queries.Close() | ||
|
||
neighbors, err := cuvs.NewTensorOnDevice[int64](&resource, []int64{int64(nQueries), int64(k)}) | ||
if err != nil { | ||
t.Fatalf("error creating neighbors tensor: %v", err) | ||
} | ||
defer neighbors.Close() | ||
|
||
distances, err := cuvs.NewTensorOnDevice[float32](&resource, []int64{int64(nQueries), int64(k)}) | ||
if err != nil { | ||
t.Fatalf("error creating distances tensor: %v", err) | ||
} | ||
defer distances.Close() | ||
|
||
if _, err := dataset.ToDevice(&resource); err != nil { | ||
t.Fatalf("error moving dataset to device: %v", err) | ||
} | ||
|
||
if err := BuildIndex(resource, &dataset, cuvs.DistanceL2, 2.0, index); err != nil { | ||
t.Fatalf("error building index: %v", err) | ||
} | ||
|
||
if err := resource.Sync(); err != nil { | ||
t.Fatalf("error syncing resource: %v", err) | ||
} | ||
|
||
if _, err := queries.ToDevice(&resource); err != nil { | ||
t.Fatalf("error moving queries to device: %v", err) | ||
} | ||
|
||
err = SearchIndex(resource, *index, &queries, &neighbors, &distances) | ||
if err != nil { | ||
t.Fatalf("error searching index: %v", err) | ||
} | ||
|
||
if _, err := neighbors.ToHost(&resource); err != nil { | ||
t.Fatalf("error moving neighbors to host: %v", err) | ||
} | ||
|
||
if _, err := distances.ToHost(&resource); err != nil { | ||
t.Fatalf("error moving distances to host: %v", err) | ||
} | ||
|
||
if err := resource.Sync(); err != nil { | ||
t.Fatalf("error syncing resource: %v", err) | ||
} | ||
|
||
neighborsSlice, err := neighbors.Slice() | ||
if err != nil { | ||
t.Fatalf("error getting neighbors slice: %v", err) | ||
} | ||
|
||
for i := range neighborsSlice { | ||
if neighborsSlice[i][0] != int64(i) { | ||
t.Error("wrong neighbor, expected", i, "got", neighborsSlice[i][0]) | ||
} | ||
} | ||
|
||
distancesSlice, err := distances.Slice() | ||
if err != nil { | ||
t.Fatalf("error getting distances slice: %v", err) | ||
} | ||
|
||
for i := range distancesSlice { | ||
if distancesSlice[i][0] >= epsilon || distancesSlice[i][0] <= -epsilon { | ||
t.Error("distance should be close to 0, got", distancesSlice[i][0]) | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
am I doing this correctly?