Skip to content

Commit 875d545

Browse files
authored
Merge pull request #382 from weaviate/fix/grpc-metadata
fix(grpc): pass request metadata via request context
2 parents 3100a8e + f58c875 commit 875d545

4 files changed

Lines changed: 104 additions & 39 deletions

File tree

internal/api/transport/transport.go

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,14 @@ func (t *transport) Do(ctx context.Context, req any, dest any) error {
144144
}
145145
}
146146

147-
func newRPC[In RequestMessage, Out ReplyMessage](req Message[In, Out], dest any) rpcFunc {
147+
func newRPC[In RequestMessage, Out ReplyMessage](req Message[In, Out], dest any) transports.RPC[proto.WeaviateClient] {
148148
dev.AssertType[MessageUnmarshaler[Out]](dest, "dest")
149149
out := dest.(MessageUnmarshaler[Out])
150150

151151
body := req.Body()
152152
dev.AssertNotNil(body, "body")
153153

154-
return rpcFunc(func(ctx context.Context, wc proto.WeaviateClient) error {
154+
return func(ctx context.Context, wc proto.WeaviateClient) error {
155155
in, err := body.MarshalMessage()
156156
if err != nil {
157157
return fmt.Errorf("%s: marshal message: %w", req, err)
@@ -168,16 +168,7 @@ func newRPC[In RequestMessage, Out ReplyMessage](req Message[In, Out], dest any)
168168
return err
169169
}
170170
return nil
171-
})
172-
}
173-
174-
// rpcFunc implements [transports.RPC] as a function.
175-
type rpcFunc func(context.Context, proto.WeaviateClient) error
176-
177-
var _ transports.RPC[proto.WeaviateClient] = (*rpcFunc)(nil)
178-
179-
func (f rpcFunc) Do(ctx context.Context, wc proto.WeaviateClient) error {
180-
return f(ctx, wc)
171+
}
181172
}
182173

183174
// unmarshal unmarshals reply Out into dest. A nil dest means the reply can be ignored,

internal/api/transport/transport_test.go

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,11 @@ func TestTransport_Do(t *testing.T) {
156156
})
157157

158158
t.Run("grpc message", func(t *testing.T) {
159+
fakeGRPC := gRPCFunc(func(ctx context.Context, rpc transports.RPC[proto.WeaviateClient]) error {
160+
return rpc(ctx, nil)
161+
})
159162
t.Run("ok", func(t *testing.T) {
160-
tport := transport{gRPC: new(fakeGRPC)}
163+
tport := transport{gRPC: fakeGRPC}
161164

162165
// Actual request is captured by message itself,
163166
// because unlike transports.Endpoint, each transports.RPC
@@ -180,7 +183,7 @@ func TestTransport_Do(t *testing.T) {
180183
})
181184

182185
t.Run("error", func(t *testing.T) {
183-
tport := transport{gRPC: new(fakeGRPC)}
186+
tport := transport{gRPC: fakeGRPC}
184187

185188
var resp reply[proto.SearchReply]
186189
req := &message[proto.SearchRequest, proto.SearchReply]{
@@ -273,14 +276,6 @@ func (f gRPCFunc) Do(ctx context.Context, rpc transports.RPC[proto.WeaviateClien
273276
return f(ctx, rpc)
274277
}
275278

276-
// fakeGRPC calls rpc.Do with nil [proto.WeaviateClient].
277-
// It's a dummy that should be used together with [message].
278-
type fakeGRPC struct{}
279-
280-
func (*fakeGRPC) Do(ctx context.Context, rpc transports.RPC[proto.WeaviateClient]) error {
281-
return rpc.Do(ctx, nil)
282-
}
283-
284279
// message implements [Message] for testing.
285280
type message[In RequestMessage, Out ReplyMessage] struct {
286281
req *In // Expected request.

internal/transports/grpc.go

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,12 @@ type GRPCConfig[Client any] struct {
3030
type NewGRPCClientFunc[Client any] func(grpc.ClientConnInterface) Client
3131

3232
// RPC describes a gRPC request in the given Client.
33-
type RPC[Client any] interface {
34-
Do(context.Context, Client) error
35-
}
33+
type RPC[Client any] func(context.Context, Client) error
3634

3735
func (c *GRPC[Client]) Do(ctx context.Context, rpc RPC[Client]) error {
3836
dev.AssertNotNil(rpc, "rpc")
3937

40-
if err := rpc.Do(ctx, c.client); err != nil {
38+
if err := rpc(ctx, c.client); err != nil {
4139
return fmt.Errorf("grpc: %w", err)
4240
}
4341
return nil
@@ -58,9 +56,7 @@ type GRPC[Client any] struct {
5856
func NewGRPC[Client any](cfg GRPCConfig[Client]) (*GRPC[Client], error) {
5957
dev.AssertNotNil(cfg.NewGRPCClient, "cfg.NewGRPCClient")
6058

61-
callOpts := []grpc.CallOption{
62-
grpc.Header(cfg.Header),
63-
}
59+
var callOpts []grpc.CallOption
6460
if cfg.MaxMessageSize > 0 {
6561
callOpts = append(callOpts,
6662
grpc.MaxCallSendMsgSize(cfg.MaxMessageSize),
@@ -72,6 +68,10 @@ func NewGRPC[Client any](cfg GRPCConfig[Client]) (*GRPC[Client], error) {
7268
grpc.WithDefaultCallOptions(callOpts...),
7369
}
7470

71+
if cfg.Header != nil {
72+
dialOpts = append(dialOpts, withDefaultHeader(*cfg.Header))
73+
}
74+
7575
transportCreds := insecure.NewCredentials()
7676
if cfg.TLS {
7777
transportCreds = credentials.NewTLS(nil)
@@ -105,3 +105,17 @@ var _ io.Closer = (*GRPC[any])(nil)
105105
func (c *GRPC[Client]) Close() error {
106106
return c.channel.Close()
107107
}
108+
109+
// withDefaultHeader creates an interceptor that adds md headers to the request context.
110+
func withDefaultHeader(md metadata.MD) grpc.DialOption {
111+
return grpc.WithUnaryInterceptor(func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
112+
var pairs []string
113+
for k, v := range md {
114+
if len(v) == 0 {
115+
continue
116+
}
117+
pairs = append(pairs, k, v[0])
118+
}
119+
return invoker(metadata.AppendToOutgoingContext(ctx, pairs...), method, req, reply, cc, opts...)
120+
})
121+
}

internal/transports/grpc_test.go

Lines changed: 75 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
1+
//nolint:errcheck
12
package transports_test
23

34
import (
45
"context"
6+
"net"
7+
"strconv"
8+
"strings"
59
"testing"
610

711
"github.com/stretchr/testify/assert"
812
"github.com/stretchr/testify/require"
913
"github.com/weaviate/weaviate-go-client/v6/internal/testkit"
1014
"github.com/weaviate/weaviate-go-client/v6/internal/transports"
1115
"google.golang.org/grpc"
16+
"google.golang.org/grpc/metadata"
17+
"google.golang.org/protobuf/types/known/emptypb"
1218
)
1319

1420
func TestNewGRPC(t *testing.T) {
@@ -32,18 +38,18 @@ func TestNewGRPC(t *testing.T) {
3238

3339
func TestGRPC_Do(t *testing.T) {
3440
t.Run("ok", func(t *testing.T) {
35-
grpc, err := transports.NewGRPC(transports.GRPCConfig[any]{
41+
gRPC, err := transports.NewGRPC(transports.GRPCConfig[any]{
3642
NewGRPCClient: func(channel grpc.ClientConnInterface) any {
3743
return 92
3844
},
3945
})
4046
require.NoError(t, err, "create grpc transport")
41-
require.NotNil(t, grpc, "grpc transport")
47+
require.NotNil(t, gRPC, "grpc transport")
4248

43-
require.NoError(t, grpc.Do(t.Context(), rpcFunc(func(_ context.Context, client any) error {
49+
require.NoError(t, gRPC.Do(t.Context(), func(_ context.Context, client any) error {
4450
assert.Equal(t, 92, client, "injected client")
4551
return nil
46-
})), "request error")
52+
}), "request error")
4753
})
4854

4955
t.Run("with error", func(t *testing.T) {
@@ -55,17 +61,76 @@ func TestGRPC_Do(t *testing.T) {
5561
require.NoError(t, err, "create grpc transport")
5662
require.NotNil(t, grpc, "grpc transport")
5763

58-
require.ErrorIs(t, grpc.Do(t.Context(), rpcFunc(func(_ context.Context, client any) error {
64+
require.ErrorIs(t, grpc.Do(t.Context(), func(_ context.Context, client any) error {
5965
assert.Equal(t, 92, client, "injected client")
6066
return testkit.ErrWhaam
61-
})), testkit.ErrWhaam, "request error")
67+
}), testkit.ErrWhaam, "request error")
68+
})
69+
70+
t.Run("default headers", func(t *testing.T) {
71+
// Arrange: start a local gRPC server and register a handler with assertions.
72+
ts := startTestService(t, func(_ any, ctx context.Context, _ func(any) error, _ grpc.UnaryServerInterceptor) (any, error) {
73+
md, ok := metadata.FromIncomingContext(ctx)
74+
assert.True(t, ok, "incoming context should contain metadata")
75+
assert.Subset(t, md, metadata.MD{"x-findme": {"foo"}}, "default headers not present in request metadata")
76+
return nil, nil
77+
})
78+
79+
gRPC, err := transports.NewGRPC(transports.GRPCConfig[grpc.ClientConnInterface]{
80+
Host: ts.Host(),
81+
Port: ts.Port(),
82+
Header: &metadata.MD{"X-FindMe": {"foo"}},
83+
NewGRPCClient: func(channel grpc.ClientConnInterface) grpc.ClientConnInterface { return channel },
84+
})
85+
require.NoError(t, err, "new grpc transport")
86+
87+
// Act: our handled above will verify that the request included expected headers.
88+
gRPC.Do(t.Context(), func(ctx context.Context, client grpc.ClientConnInterface) error {
89+
var empty emptypb.Empty
90+
return client.Invoke(ctx, ts.MethodName(), nil, &empty)
91+
})
6292
})
6393
}
6494

65-
type rpcFunc func(ctx context.Context, client any) error
95+
type testService struct {
96+
lis net.Listener
97+
srv *grpc.Server
98+
host string
99+
port int
100+
}
101+
102+
// startTestService starts a local TCP [net.Listener] and creates a [grpc.Server]
103+
// using that listener. The mh handler can be used to make assertions about the
104+
// request or control how the requeset is processed.
105+
//
106+
// All resources are freed via [testing.T.Cleanup] hook.
107+
func startTestService(t *testing.T, mh grpc.MethodHandler) *testService {
108+
lis, err := net.Listen("tcp", "localhost:0")
109+
require.NoError(t, err)
110+
t.Cleanup(func() { lis.Close() })
111+
112+
addr := strings.Split(lis.Addr().String(), ":")
113+
port, _ := strconv.Atoi(addr[1])
114+
115+
srv := grpc.NewServer()
116+
srv.RegisterService(&grpc.ServiceDesc{
117+
ServiceName: "testService",
118+
Methods: []grpc.MethodDesc{
119+
{MethodName: "Test", Handler: mh},
120+
},
121+
}, nil)
66122

67-
var _ transports.RPC[any] = (*rpcFunc)(nil)
123+
go srv.Serve(lis)
124+
t.Cleanup(srv.Stop)
68125

69-
func (f rpcFunc) Do(ctx context.Context, client any) error {
70-
return f(ctx, client)
126+
return &testService{
127+
lis: lis,
128+
srv: srv,
129+
host: addr[0],
130+
port: port,
131+
}
71132
}
133+
134+
func (ts *testService) Host() string { return ts.host }
135+
func (ts *testService) Port() int { return ts.port }
136+
func (ts *testService) MethodName() string { return "/testService/Test" }

0 commit comments

Comments
 (0)