diff --git a/ext/go-plugin/capabilities.go b/ext/go-plugin/capabilities.go index 30517d515..9991e676f 100644 --- a/ext/go-plugin/capabilities.go +++ b/ext/go-plugin/capabilities.go @@ -2,6 +2,7 @@ package main import ( "C" + "context" "encoding/json" "errors" @@ -53,7 +54,8 @@ func GuestCapability(pluginName, pluginType, cname, cplatform, cargs, cmachine * cap := &vagrant.SystemCapability{ Name: to_gs(cname), Platform: to_gs(cplatform)} - r.Result, r.Error = p.GuestCapability(cap, args, machine) + ctx := context.Background() + r.Result, r.Error = p.GuestCapability(ctx, cap, args, machine) return r.Dump() } @@ -103,7 +105,8 @@ func HostCapability(pluginName, pluginType, cname, cplatform, cargs, cenv *C.cha cap := &vagrant.SystemCapability{ Name: to_gs(cname), Platform: to_gs(cplatform)} - r.Result, r.Error = p.HostCapability(cap, args, env) + ctx := context.Background() + r.Result, r.Error = p.HostCapability(ctx, cap, args, env) return r.Dump() } @@ -153,6 +156,7 @@ func ProviderCapability(pluginName, pluginType, cname, cprovider, cargs, cmach * cap := &vagrant.ProviderCapability{ Name: to_gs(cname), Provider: to_gs(cprovider)} - r.Result, r.Error = p.ProviderCapability(cap, args, m) + ctx := context.Background() + r.Result, r.Error = p.ProviderCapability(ctx, cap, args, m) return r.Dump() } diff --git a/ext/go-plugin/config.go b/ext/go-plugin/config.go index 0aeba47df..34f69fabf 100644 --- a/ext/go-plugin/config.go +++ b/ext/go-plugin/config.go @@ -2,6 +2,7 @@ package main import ( "C" + "context" "encoding/json" "errors" @@ -27,7 +28,8 @@ func ConfigLoad(pluginName, pluginType, data *C.char) *C.char { if r.Error != nil { return r.Dump() } - r.Result, r.Error = p.ConfigLoad(cdata) + ctx := context.Background() + r.Result, r.Error = p.ConfigLoad(ctx, cdata) return r.Dump() } @@ -71,7 +73,8 @@ func ConfigValidate(pluginName, pluginType, data, machData *C.char) *C.char { if r.Error != nil { return r.Dump() } - r.Result, r.Error = p.ConfigValidate(cdata, m) + ctx := context.Background() + r.Result, r.Error = p.ConfigValidate(ctx, cdata, m) return r.Dump() } @@ -91,7 +94,8 @@ func ConfigFinalize(pluginName, pluginType, data *C.char) *C.char { var cdata map[string]interface{} r.Error = json.Unmarshal([]byte(to_gs(data)), &cdata) if r.Error == nil { - r.Result, r.Error = p.ConfigFinalize(cdata) + ctx := context.Background() + r.Result, r.Error = p.ConfigFinalize(ctx, cdata) } return r.Dump() } diff --git a/ext/go-plugin/provider.go b/ext/go-plugin/provider.go index 881528b76..146d502e4 100644 --- a/ext/go-plugin/provider.go +++ b/ext/go-plugin/provider.go @@ -2,6 +2,7 @@ package main import ( "C" + "context" "encoding/json" "errors" @@ -42,7 +43,8 @@ func ProviderAction(providerName *C.char, actionName *C.char, machData *C.char) return r.Dump() } aName := to_gs(actionName) - r.Result, r.Error = p.Action(aName, m) + ctx := context.Background() + r.Result, r.Error = p.Action(ctx, aName, m) return r.Dump() } @@ -64,7 +66,8 @@ func ProviderIsInstalled(providerName *C.char, machData *C.char) *C.char { r.Error = err return r.Dump() } - r.Result, r.Error = p.IsInstalled(m) + ctx := context.Background() + r.Result, r.Error = p.IsInstalled(ctx, m) return r.Dump() } @@ -87,7 +90,8 @@ func ProviderIsUsable(providerName *C.char, machData *C.char) *C.char { r.Error = err return r.Dump() } - r.Result, r.Error = p.IsUsable(m) + ctx := context.Background() + r.Result, r.Error = p.IsUsable(ctx, m) return r.Dump() } @@ -109,7 +113,8 @@ func ProviderMachineIdChanged(providerName *C.char, machData *C.char) *C.char { r.Error = err return r.Dump() } - r.Error = p.MachineIdChanged(m) + ctx := context.Background() + r.Error = p.MachineIdChanged(ctx, m) return r.Dump() } @@ -138,7 +143,8 @@ func ProviderRunAction(providerName *C.char, actName *C.char, runData *C.char, m r.Error = err return r.Dump() } - r.Result, r.Error = p.RunAction(aName, rData, m) + ctx := context.Background() + r.Result, r.Error = p.RunAction(ctx, aName, rData, m) return r.Dump() } @@ -160,7 +166,8 @@ func ProviderSshInfo(providerName *C.char, machData *C.char) *C.char { r.Error = err return r.Dump() } - r.Result, r.Error = p.SshInfo(m) + ctx := context.Background() + r.Result, r.Error = p.SshInfo(ctx, m) return r.Dump() } @@ -182,6 +189,7 @@ func ProviderState(providerName *C.char, machData *C.char) *C.char { r.Error = err return r.Dump() } - r.Result, r.Error = p.State(m) + ctx := context.Background() + r.Result, r.Error = p.State(ctx, m) return r.Dump() } diff --git a/ext/go-plugin/synced_folder.go b/ext/go-plugin/synced_folder.go index 96d811fbf..5faef28a1 100644 --- a/ext/go-plugin/synced_folder.go +++ b/ext/go-plugin/synced_folder.go @@ -2,6 +2,7 @@ package main import ( "C" + "context" "encoding/json" "errors" @@ -47,7 +48,8 @@ func SyncedFolderCleanup(pluginName, machine, opts *C.char) *C.char { if r.Error != nil { return r.Dump() } - r.Error = p.Cleanup(m, o) + ctx := context.Background() + r.Error = p.Cleanup(ctx, m, o) return r.Dump() } @@ -79,7 +81,8 @@ func SyncedFolderDisable(pluginName, machine, folders, opts *C.char) *C.char { if r.Error != nil { return r.Dump() } - r.Error = p.Disable(m, f, o) + ctx := context.Background() + r.Error = p.Disable(ctx, m, f, o) return r.Dump() } @@ -111,7 +114,8 @@ func SyncedFolderEnable(pluginName, machine, folders, opts *C.char) *C.char { if r.Error != nil { return r.Dump() } - r.Error = p.Enable(m, f, o) + ctx := context.Background() + r.Error = p.Enable(ctx, m, f, o) return r.Dump() } @@ -133,7 +137,8 @@ func SyncedFolderIsUsable(pluginName, machine *C.char) *C.char { r.Error = err return r.Dump() } - r.Result, r.Error = p.IsUsable(m) + ctx := context.Background() + r.Result, r.Error = p.IsUsable(ctx, m) return r.Dump() } @@ -165,6 +170,7 @@ func SyncedFolderPrepare(pluginName, machine, folders, opts *C.char) *C.char { if r.Error != nil { return r.Dump() } - r.Error = p.Prepare(m, f, o) + ctx := context.Background() + r.Error = p.Prepare(ctx, m, f, o) return r.Dump() } diff --git a/ext/go-plugin/vagrant/capabilities.go b/ext/go-plugin/vagrant/capabilities.go index 996cdddba..44bb66c57 100644 --- a/ext/go-plugin/vagrant/capabilities.go +++ b/ext/go-plugin/vagrant/capabilities.go @@ -1,5 +1,9 @@ package vagrant +import ( + "context" +) + type SystemCapability struct { Name string `json:"name"` Platform string `json:"platform"` @@ -12,17 +16,17 @@ type ProviderCapability struct { type GuestCapabilities interface { GuestCapabilities() (caps []SystemCapability, err error) - GuestCapability(cap *SystemCapability, args interface{}, machine *Machine) (result interface{}, err error) + GuestCapability(ctx context.Context, cap *SystemCapability, args interface{}, machine *Machine) (result interface{}, err error) } type HostCapabilities interface { HostCapabilities() (caps []SystemCapability, err error) - HostCapability(cap *SystemCapability, args interface{}, env *Environment) (result interface{}, err error) + HostCapability(ctx context.Context, cap *SystemCapability, args interface{}, env *Environment) (result interface{}, err error) } type ProviderCapabilities interface { ProviderCapabilities() (caps []ProviderCapability, err error) - ProviderCapability(cap *ProviderCapability, args interface{}, machine *Machine) (result interface{}, err error) + ProviderCapability(ctx context.Context, cap *ProviderCapability, args interface{}, machine *Machine) (result interface{}, err error) } type NoGuestCapabilities struct{} @@ -34,7 +38,7 @@ func (g *NoGuestCapabilities) GuestCapabilities() (caps []SystemCapability, err return } -func (g *NoGuestCapabilities) GuestCapability(c *SystemCapability, a interface{}, m *Machine) (r interface{}, err error) { +func (g *NoGuestCapabilities) GuestCapability(x context.Context, c *SystemCapability, a interface{}, m *Machine) (r interface{}, err error) { return } @@ -43,7 +47,7 @@ func (h *NoHostCapabilities) HostCapabilities() (caps []SystemCapability, err er return } -func (h *NoHostCapabilities) HostCapability(c *SystemCapability, a interface{}, e *Environment) (r interface{}, err error) { +func (h *NoHostCapabilities) HostCapability(x context.Context, c *SystemCapability, a interface{}, e *Environment) (r interface{}, err error) { return } @@ -52,6 +56,6 @@ func (p *NoProviderCapabilities) ProviderCapabilities() (caps []ProviderCapabili return } -func (p *NoProviderCapabilities) ProviderCapability(cap *ProviderCapability, args interface{}, machine *Machine) (result interface{}, err error) { +func (p *NoProviderCapabilities) ProviderCapability(x context.Context, cap *ProviderCapability, args interface{}, machine *Machine) (result interface{}, err error) { return } diff --git a/ext/go-plugin/vagrant/capabilities_test.go b/ext/go-plugin/vagrant/capabilities_test.go index 6bc4790d4..b8a22316c 100644 --- a/ext/go-plugin/vagrant/capabilities_test.go +++ b/ext/go-plugin/vagrant/capabilities_test.go @@ -1,6 +1,7 @@ package vagrant import ( + "context" "testing" ) @@ -19,7 +20,7 @@ func TestNoGuestCapability(t *testing.T) { g := NoGuestCapabilities{} m := &Machine{} cap := &SystemCapability{"Test", "Test"} - r, err := g.GuestCapability(cap, "args", m) + r, err := g.GuestCapability(context.Background(), cap, "args", m) if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -43,7 +44,7 @@ func TestNoHostCapability(t *testing.T) { h := NoHostCapabilities{} e := &Environment{} cap := &SystemCapability{"Test", "Test"} - r, err := h.HostCapability(cap, "args", e) + r, err := h.HostCapability(context.Background(), cap, "args", e) if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -67,7 +68,7 @@ func TestNoProviderCapability(t *testing.T) { p := NoProviderCapabilities{} m := &Machine{} cap := &ProviderCapability{"Test", "Test"} - r, err := p.ProviderCapability(cap, "args", m) + r, err := p.ProviderCapability(context.Background(), cap, "args", m) if err != nil { t.Fatalf("unexpected error: %s", err) } diff --git a/ext/go-plugin/vagrant/config.go b/ext/go-plugin/vagrant/config.go index befc9ce19..9761121be 100644 --- a/ext/go-plugin/vagrant/config.go +++ b/ext/go-plugin/vagrant/config.go @@ -1,15 +1,25 @@ package vagrant +import ( + "context" +) + type Config interface { ConfigAttributes() (attrs []string, err error) - ConfigLoad(data map[string]interface{}) (loaddata map[string]interface{}, err error) - ConfigValidate(data map[string]interface{}, m *Machine) (errors []string, err error) - ConfigFinalize(data map[string]interface{}) (finaldata map[string]interface{}, err error) + ConfigLoad(ctx context.Context, data map[string]interface{}) (loaddata map[string]interface{}, err error) + ConfigValidate(ctx context.Context, data map[string]interface{}, m *Machine) (errors []string, err error) + ConfigFinalize(ctx context.Context, data map[string]interface{}) (finaldata map[string]interface{}, err error) } type NoConfig struct{} -func (c *NoConfig) ConfigAttributes() (a []string, e error) { return } -func (c *NoConfig) ConfigLoad(map[string]interface{}) (d map[string]interface{}, e error) { return } -func (c *NoConfig) ConfigValidate(map[string]interface{}, *Machine) (es []string, e error) { return } -func (c *NoConfig) ConfigFinalize(map[string]interface{}) (f map[string]interface{}, e error) { return } +func (c *NoConfig) ConfigAttributes() (a []string, e error) { return } +func (c *NoConfig) ConfigLoad(context.Context, map[string]interface{}) (d map[string]interface{}, e error) { + return +} +func (c *NoConfig) ConfigValidate(context.Context, map[string]interface{}, *Machine) (es []string, e error) { + return +} +func (c *NoConfig) ConfigFinalize(context.Context, map[string]interface{}) (f map[string]interface{}, e error) { + return +} diff --git a/ext/go-plugin/vagrant/plugin/base.go b/ext/go-plugin/vagrant/plugin/base.go index 6afa5e7d3..366580994 100644 --- a/ext/go-plugin/vagrant/plugin/base.go +++ b/ext/go-plugin/vagrant/plugin/base.go @@ -1,14 +1,18 @@ package plugin import ( + "context" "errors" "fmt" "os" "os/exec" + "time" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" hclog "github.com/hashicorp/go-hclog" go_plugin "github.com/hashicorp/go-plugin" - "github.com/hashicorp/vagrant/ext/go-plugin/vagrant" ) @@ -17,6 +21,7 @@ var ( MagicCookieKey: "VAGRANT_PLUGIN_MAGIC_COOKIE", MagicCookieValue: "1561a662a76642f98df77ad025aa13a9b16225d93f90475e91090fbe577317ed", ProtocolVersion: 1} + ErrPluginShutdown = errors.New("plugin has shutdown") ) type RemotePlugin interface { @@ -239,3 +244,51 @@ func (v *VagrantPlugin) Kill() { v.Logger.Info("plugin killed", "name", n, "type", "synced_folder") } } + +// Helper used for inspect GRPC related errors and providing "correct" +// error message +func handleGrpcError(err error, pluginCtx context.Context, reqCtx context.Context) error { + // If there was no error then nothing to process + if err == nil { + return nil + } + + // If a request context is provided, check that it + // was not canceled or timed out. If no context + // provided, stub one for later. + if reqCtx != nil { + s := status.FromContextError(reqCtx.Err()) + switch s.Code() { + case codes.Canceled: + return context.Canceled + case codes.DeadlineExceeded: + return context.DeadlineExceeded + } + } else { + reqCtx = context.Background() + } + + s, ok := status.FromError(err) + if ok && (s.Code() == codes.Unavailable || s.Code() == codes.Canceled) { + select { + case <-pluginCtx.Done(): + err = ErrPluginShutdown + case <-reqCtx.Done(): + err = reqCtx.Err() + select { + case <-pluginCtx.Done(): + err = ErrPluginShutdown + default: + } + case <-time.After(5): + return errors.New("exceeded context timeout") + } + return err + } else if s != nil { + // Extract actual error message received + // and create new error + return errors.New(s.Message()) + } + + return err +} diff --git a/ext/go-plugin/vagrant/plugin/capabilities.go b/ext/go-plugin/vagrant/plugin/capabilities.go index 6efb92d55..9bf1f0164 100644 --- a/ext/go-plugin/vagrant/plugin/capabilities.go +++ b/ext/go-plugin/vagrant/plugin/capabilities.go @@ -3,7 +3,6 @@ package plugin import ( "context" "encoding/json" - "errors" "google.golang.org/grpc" @@ -11,6 +10,8 @@ import ( "github.com/hashicorp/vagrant/ext/go-plugin/vagrant" "github.com/hashicorp/vagrant/ext/go-plugin/vagrant/plugin/proto/vagrant_caps" "github.com/hashicorp/vagrant/ext/go-plugin/vagrant/plugin/proto/vagrant_common" + + "github.com/LK4D4/joincontext" ) type GuestCapabilities interface { @@ -35,9 +36,11 @@ func (g *GuestCapabilitiesPlugin) GRPCServer(broker *go_plugin.GRPCBroker, s *gr func (g *GuestCapabilitiesPlugin) GRPCClient(ctx context.Context, broker *go_plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) { client := vagrant_caps.NewGuestCapabilitiesClient(c) return &GRPCGuestCapabilitiesClient{ - client: client, + client: client, + doneCtx: ctx, GRPCIOClient: GRPCIOClient{ - client: client}}, nil + client: client, + doneCtx: ctx}}, nil } type GRPCGuestCapabilitiesServer struct { @@ -47,9 +50,18 @@ type GRPCGuestCapabilitiesServer struct { func (s *GRPCGuestCapabilitiesServer) GuestCapabilities(ctx context.Context, req *vagrant_common.NullRequest) (resp *vagrant_caps.CapabilitiesResponse, err error) { resp = &vagrant_caps.CapabilitiesResponse{} - r, e := s.Impl.GuestCapabilities() - if e != nil { - resp.Error = e.Error() + var r []vagrant.SystemCapability + n := make(chan struct{}, 1) + go func() { + r, err = s.Impl.GuestCapabilities() + n <- struct{}{} + }() + select { + case <-ctx.Done(): + return + case <-n: + } + if err != nil { return } for _, cap := range r { @@ -61,7 +73,7 @@ func (s *GRPCGuestCapabilitiesServer) GuestCapabilities(ctx context.Context, req func (s *GRPCGuestCapabilitiesServer) GuestCapability(ctx context.Context, req *vagrant_caps.GuestCapabilityRequest) (resp *vagrant_caps.GuestCapabilityResponse, err error) { resp = &vagrant_caps.GuestCapabilityResponse{} - var args interface{} + var args, r interface{} if err = json.Unmarshal([]byte(req.Arguments), &args); err != nil { return } @@ -72,7 +84,17 @@ func (s *GRPCGuestCapabilitiesServer) GuestCapability(ctx context.Context, req * cap := &vagrant.SystemCapability{ Name: req.Capability.Name, Platform: req.Capability.Platform} - r, err := s.Impl.GuestCapability(cap, args, machine) + n := make(chan struct{}, 1) + go func() { + r, err = s.Impl.GuestCapability(ctx, cap, args, machine) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + return + case <-n: + } + if err != nil { return } @@ -87,17 +109,16 @@ func (s *GRPCGuestCapabilitiesServer) GuestCapability(ctx context.Context, req * type GRPCGuestCapabilitiesClient struct { GRPCCoreClient GRPCIOClient - client vagrant_caps.GuestCapabilitiesClient + client vagrant_caps.GuestCapabilitiesClient + doneCtx context.Context } func (c *GRPCGuestCapabilitiesClient) GuestCapabilities() (caps []vagrant.SystemCapability, err error) { - resp, err := c.client.GuestCapabilities(context.Background(), &vagrant_common.NullRequest{}) + ctx := context.Background() + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.GuestCapabilities(jctx, &vagrant_common.NullRequest{}) if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) - return + return nil, handleGrpcError(err, c.doneCtx, ctx) } caps = make([]vagrant.SystemCapability, len(resp.Capabilities)) for i := 0; i < len(resp.Capabilities); i++ { @@ -109,7 +130,7 @@ func (c *GRPCGuestCapabilitiesClient) GuestCapabilities() (caps []vagrant.System return } -func (c *GRPCGuestCapabilitiesClient) GuestCapability(cap *vagrant.SystemCapability, args interface{}, machine *vagrant.Machine) (result interface{}, err error) { +func (c *GRPCGuestCapabilitiesClient) GuestCapability(ctx context.Context, cap *vagrant.SystemCapability, args interface{}, machine *vagrant.Machine) (result interface{}, err error) { a, err := json.Marshal(args) if err != nil { return @@ -118,16 +139,13 @@ func (c *GRPCGuestCapabilitiesClient) GuestCapability(cap *vagrant.SystemCapabil if err != nil { return } - resp, err := c.client.GuestCapability(context.Background(), &vagrant_caps.GuestCapabilityRequest{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.GuestCapability(jctx, &vagrant_caps.GuestCapabilityRequest{ Capability: &vagrant_caps.Capability{Name: cap.Name, Platform: cap.Platform}, Machine: m, Arguments: string(a)}) if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) - return + return nil, handleGrpcError(err, c.doneCtx, ctx) } err = json.Unmarshal([]byte(resp.Result), &result) return @@ -155,9 +173,11 @@ func (h *HostCapabilitiesPlugin) GRPCServer(broker *go_plugin.GRPCBroker, s *grp func (h *HostCapabilitiesPlugin) GRPCClient(ctx context.Context, broker *go_plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) { client := vagrant_caps.NewHostCapabilitiesClient(c) return &GRPCHostCapabilitiesClient{ - client: client, + client: client, + doneCtx: ctx, GRPCIOClient: GRPCIOClient{ - client: client}}, nil + client: client, + doneCtx: ctx}}, nil } type GRPCHostCapabilitiesServer struct { @@ -167,9 +187,18 @@ type GRPCHostCapabilitiesServer struct { func (s *GRPCHostCapabilitiesServer) HostCapabilities(ctx context.Context, req *vagrant_common.NullRequest) (resp *vagrant_caps.CapabilitiesResponse, err error) { resp = &vagrant_caps.CapabilitiesResponse{} - r, e := s.Impl.HostCapabilities() - if e != nil { - resp.Error = e.Error() + var r []vagrant.SystemCapability + n := make(chan struct{}, 1) + go func() { + r, err = s.Impl.HostCapabilities() + n <- struct{}{} + }() + select { + case <-ctx.Done(): + return + case <-n: + } + if err != nil { return } for _, cap := range r { @@ -181,7 +210,7 @@ func (s *GRPCHostCapabilitiesServer) HostCapabilities(ctx context.Context, req * func (s *GRPCHostCapabilitiesServer) HostCapability(ctx context.Context, req *vagrant_caps.HostCapabilityRequest) (resp *vagrant_caps.HostCapabilityResponse, err error) { resp = &vagrant_caps.HostCapabilityResponse{} - var args interface{} + var args, r interface{} if err = json.Unmarshal([]byte(req.Arguments), &args); err != nil { return } @@ -192,7 +221,16 @@ func (s *GRPCHostCapabilitiesServer) HostCapability(ctx context.Context, req *va cap := &vagrant.SystemCapability{ Name: req.Capability.Name, Platform: req.Capability.Platform} - r, err := s.Impl.HostCapability(cap, args, env) + n := make(chan struct{}, 1) + go func() { + r, err = s.Impl.HostCapability(ctx, cap, args, env) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + return + case <-n: + } if err != nil { return } @@ -207,17 +245,16 @@ func (s *GRPCHostCapabilitiesServer) HostCapability(ctx context.Context, req *va type GRPCHostCapabilitiesClient struct { GRPCCoreClient GRPCIOClient - client vagrant_caps.HostCapabilitiesClient + client vagrant_caps.HostCapabilitiesClient + doneCtx context.Context } func (c *GRPCHostCapabilitiesClient) HostCapabilities() (caps []vagrant.SystemCapability, err error) { - resp, err := c.client.HostCapabilities(context.Background(), &vagrant_common.NullRequest{}) + ctx := context.Background() + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.HostCapabilities(jctx, &vagrant_common.NullRequest{}) if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) - return + return nil, handleGrpcError(err, c.doneCtx, ctx) } caps = make([]vagrant.SystemCapability, len(resp.Capabilities)) for i := 0; i < len(resp.Capabilities); i++ { @@ -229,7 +266,7 @@ func (c *GRPCHostCapabilitiesClient) HostCapabilities() (caps []vagrant.SystemCa return } -func (c *GRPCHostCapabilitiesClient) HostCapability(cap *vagrant.SystemCapability, args interface{}, env *vagrant.Environment) (result interface{}, err error) { +func (c *GRPCHostCapabilitiesClient) HostCapability(ctx context.Context, cap *vagrant.SystemCapability, args interface{}, env *vagrant.Environment) (result interface{}, err error) { a, err := json.Marshal(args) if err != nil { return @@ -238,14 +275,15 @@ func (c *GRPCHostCapabilitiesClient) HostCapability(cap *vagrant.SystemCapabilit if err != nil { return } - resp, err := c.client.HostCapability(context.Background(), &vagrant_caps.HostCapabilityRequest{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.HostCapability(jctx, &vagrant_caps.HostCapabilityRequest{ Capability: &vagrant_caps.Capability{ Name: cap.Name, Platform: cap.Platform}, Environment: e, Arguments: string(a)}) if err != nil { - return + return nil, handleGrpcError(err, c.doneCtx, ctx) } err = json.Unmarshal([]byte(resp.Result), &result) return @@ -273,9 +311,11 @@ func (p *ProviderCapabilitiesPlugin) GRPCServer(broker *go_plugin.GRPCBroker, s func (p *ProviderCapabilitiesPlugin) GRPCClient(ctx context.Context, broker *go_plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) { client := vagrant_caps.NewProviderCapabilitiesClient(c) return &GRPCProviderCapabilitiesClient{ - client: client, + client: client, + doneCtx: ctx, GRPCIOClient: GRPCIOClient{ - client: client}}, nil + client: client, + doneCtx: ctx}}, nil } type GRPCProviderCapabilitiesServer struct { @@ -285,9 +325,18 @@ type GRPCProviderCapabilitiesServer struct { func (s *GRPCProviderCapabilitiesServer) ProviderCapabilities(ctx context.Context, req *vagrant_common.NullRequest) (resp *vagrant_caps.ProviderCapabilitiesResponse, err error) { resp = &vagrant_caps.ProviderCapabilitiesResponse{} - r, e := s.Impl.ProviderCapabilities() - if e != nil { - resp.Error = e.Error() + var r []vagrant.ProviderCapability + n := make(chan struct{}, 1) + go func() { + r, err = s.Impl.ProviderCapabilities() + n <- struct{}{} + }() + select { + case <-ctx.Done(): + return + case <-n: + } + if err != nil { return } for _, cap := range r { @@ -299,7 +348,7 @@ func (s *GRPCProviderCapabilitiesServer) ProviderCapabilities(ctx context.Contex func (s *GRPCProviderCapabilitiesServer) ProviderCapability(ctx context.Context, req *vagrant_caps.ProviderCapabilityRequest) (resp *vagrant_caps.ProviderCapabilityResponse, err error) { resp = &vagrant_caps.ProviderCapabilityResponse{} - var args interface{} + var args, r interface{} err = json.Unmarshal([]byte(req.Arguments), &args) if err != nil { return @@ -311,7 +360,16 @@ func (s *GRPCProviderCapabilitiesServer) ProviderCapability(ctx context.Context, cap := &vagrant.ProviderCapability{ Name: req.Capability.Name, Provider: req.Capability.Provider} - r, err := s.Impl.ProviderCapability(cap, args, m) + n := make(chan struct{}, 1) + go func() { + r, err = s.Impl.ProviderCapability(ctx, cap, args, m) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + return + case <-n: + } if err != nil { return } @@ -326,17 +384,16 @@ func (s *GRPCProviderCapabilitiesServer) ProviderCapability(ctx context.Context, type GRPCProviderCapabilitiesClient struct { GRPCCoreClient GRPCIOClient - client vagrant_caps.ProviderCapabilitiesClient + client vagrant_caps.ProviderCapabilitiesClient + doneCtx context.Context } func (c *GRPCProviderCapabilitiesClient) ProviderCapabilities() (caps []vagrant.ProviderCapability, err error) { - resp, err := c.client.ProviderCapabilities(context.Background(), &vagrant_common.NullRequest{}) + ctx := context.Background() + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.ProviderCapabilities(jctx, &vagrant_common.NullRequest{}) if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) - return + return nil, handleGrpcError(err, c.doneCtx, ctx) } caps = make([]vagrant.ProviderCapability, len(resp.Capabilities)) for i := 0; i < len(resp.Capabilities); i++ { @@ -348,7 +405,7 @@ func (c *GRPCProviderCapabilitiesClient) ProviderCapabilities() (caps []vagrant. return } -func (c *GRPCProviderCapabilitiesClient) ProviderCapability(cap *vagrant.ProviderCapability, args interface{}, machine *vagrant.Machine) (result interface{}, err error) { +func (c *GRPCProviderCapabilitiesClient) ProviderCapability(ctx context.Context, cap *vagrant.ProviderCapability, args interface{}, machine *vagrant.Machine) (result interface{}, err error) { a, err := json.Marshal(args) if err != nil { return @@ -357,14 +414,15 @@ func (c *GRPCProviderCapabilitiesClient) ProviderCapability(cap *vagrant.Provide if err != nil { return } - resp, err := c.client.ProviderCapability(context.Background(), &vagrant_caps.ProviderCapabilityRequest{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.ProviderCapability(jctx, &vagrant_caps.ProviderCapabilityRequest{ Capability: &vagrant_caps.ProviderCapability{ Name: cap.Name, Provider: cap.Provider}, Machine: m, Arguments: string(a)}) if err != nil { - return + return nil, handleGrpcError(err, c.doneCtx, ctx) } err = json.Unmarshal([]byte(resp.Result), &result) return diff --git a/ext/go-plugin/vagrant/plugin/capabilities_test.go b/ext/go-plugin/vagrant/plugin/capabilities_test.go index 984924bb0..3993c705a 100644 --- a/ext/go-plugin/vagrant/plugin/capabilities_test.go +++ b/ext/go-plugin/vagrant/plugin/capabilities_test.go @@ -1,7 +1,9 @@ package plugin import ( + "context" "testing" + "time" "github.com/hashicorp/go-plugin" "github.com/hashicorp/vagrant/ext/go-plugin/vagrant" @@ -57,7 +59,7 @@ func TestCapabilities_GuestCapability(t *testing.T) { m := &vagrant.Machine{} args := []string{"test_value", "next_test_value"} - resp, err := impl.GuestCapability(cap, args, m) + resp, err := impl.GuestCapability(context.Background(), cap, args, m) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -95,7 +97,7 @@ func TestCapabilities_GuestCapability_noargs(t *testing.T) { var args interface{} args = nil - resp, err := impl.GuestCapability(cap, args, m) + resp, err := impl.GuestCapability(context.Background(), cap, args, m) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -108,6 +110,80 @@ func TestCapabilities_GuestCapability_noargs(t *testing.T) { } } +func TestCapabilities_GuestCapability_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "caps": &GuestCapabilitiesPlugin{Impl: &MockGuestCapabilities{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("caps") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(GuestCapabilities) + if !ok { + t.Fatalf("bad %#v", raw) + } + + cap := &vagrant.SystemCapability{ + Name: "test_cap", + Platform: "TestOS"} + m := &vagrant.Machine{} + args := []string{"pause", "test_value", "next_test_value"} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}, 1) + go func() { + _, err = impl.GuestCapability(ctx, cap, args, m) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("bad resp: %s", err) + } +} + +func TestCapabilities_GuestCapability_context_timeout(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "caps": &GuestCapabilitiesPlugin{Impl: &MockGuestCapabilities{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("caps") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(GuestCapabilities) + if !ok { + t.Fatalf("bad %#v", raw) + } + + cap := &vagrant.SystemCapability{ + Name: "test_cap", + Platform: "TestOS"} + m := &vagrant.Machine{} + args := []string{"pause", "test_value", "next_test_value"} + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + n := make(chan struct{}, 1) + go func() { + _, err = impl.GuestCapability(ctx, cap, args, m) + n <- struct{}{} + }() + <-n + if err != context.DeadlineExceeded { + t.Fatalf("bad resp: %s", err) + } +} + func TestCapabilities_HostCapabilities(t *testing.T) { client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ "caps": &HostCapabilitiesPlugin{Impl: &MockHostCapabilities{}}}) @@ -158,7 +234,7 @@ func TestCapabilities_HostCapability(t *testing.T) { e := &vagrant.Environment{} args := []string{"test_value", "next_test_value"} - resp, err := impl.HostCapability(cap, args, e) + resp, err := impl.HostCapability(context.Background(), cap, args, e) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -196,7 +272,7 @@ func TestCapabilities_HostCapability_noargs(t *testing.T) { var args interface{} args = nil - resp, err := impl.HostCapability(cap, args, e) + resp, err := impl.HostCapability(context.Background(), cap, args, e) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -209,6 +285,80 @@ func TestCapabilities_HostCapability_noargs(t *testing.T) { } } +func TestCapabilities_HostCapability_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "caps": &HostCapabilitiesPlugin{Impl: &MockHostCapabilities{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("caps") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(HostCapabilities) + if !ok { + t.Fatalf("bad %#v", raw) + } + + cap := &vagrant.SystemCapability{ + Name: "test_cap", + Platform: "TestOS"} + e := &vagrant.Environment{} + args := []string{"pause", "test_value", "next_test_value"} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}, 1) + go func() { + _, err = impl.HostCapability(ctx, cap, args, e) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("bad resp: %s", err) + } +} + +func TestCapabilities_HostCapability_context_timeout(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "caps": &HostCapabilitiesPlugin{Impl: &MockHostCapabilities{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("caps") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(HostCapabilities) + if !ok { + t.Fatalf("bad %#v", raw) + } + + cap := &vagrant.SystemCapability{ + Name: "test_cap", + Platform: "TestOS"} + e := &vagrant.Environment{} + args := []string{"pause", "test_value", "next_test_value"} + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + n := make(chan struct{}, 1) + go func() { + _, err = impl.HostCapability(ctx, cap, args, e) + n <- struct{}{} + }() + <-n + if err != context.DeadlineExceeded { + t.Fatalf("bad resp: %s", err) + } +} + func TestCapabilities_ProviderCapabilities(t *testing.T) { client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ "caps": &ProviderCapabilitiesPlugin{Impl: &MockProviderCapabilities{}}}) @@ -259,7 +409,7 @@ func TestCapabilities_ProviderCapability(t *testing.T) { m := &vagrant.Machine{} args := []string{"test_value", "next_test_value"} - resp, err := impl.ProviderCapability(cap, args, m) + resp, err := impl.ProviderCapability(context.Background(), cap, args, m) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -297,7 +447,7 @@ func TestCapabilities_ProviderCapability_noargs(t *testing.T) { var args interface{} args = nil - resp, err := impl.ProviderCapability(cap, args, m) + resp, err := impl.ProviderCapability(context.Background(), cap, args, m) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -309,3 +459,77 @@ func TestCapabilities_ProviderCapability_noargs(t *testing.T) { t.Errorf("%s != test_cap", result[0]) } } + +func TestCapabilities_ProviderCapability_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "caps": &ProviderCapabilitiesPlugin{Impl: &MockProviderCapabilities{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("caps") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(ProviderCapabilities) + if !ok { + t.Fatalf("bad %#v", raw) + } + + cap := &vagrant.ProviderCapability{ + Name: "test_cap", + Provider: "test_provider"} + m := &vagrant.Machine{} + args := []string{"pause", "test_value", "next_test_value"} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}, 1) + go func() { + _, err = impl.ProviderCapability(ctx, cap, args, m) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("bad resp: %s", err) + } +} + +func TestCapabilities_ProviderCapability_context_timeout(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "caps": &ProviderCapabilitiesPlugin{Impl: &MockProviderCapabilities{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("caps") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(ProviderCapabilities) + if !ok { + t.Fatalf("bad %#v", raw) + } + + cap := &vagrant.ProviderCapability{ + Name: "test_cap", + Provider: "test_provider"} + m := &vagrant.Machine{} + args := []string{"pause", "test_value", "next_test_value"} + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + n := make(chan struct{}, 1) + go func() { + _, err = impl.ProviderCapability(ctx, cap, args, m) + n <- struct{}{} + }() + <-n + if err != context.DeadlineExceeded { + t.Fatalf("bad resp: %s", err) + } +} diff --git a/ext/go-plugin/vagrant/plugin/config.go b/ext/go-plugin/vagrant/plugin/config.go index 926baf915..7dccdadab 100644 --- a/ext/go-plugin/vagrant/plugin/config.go +++ b/ext/go-plugin/vagrant/plugin/config.go @@ -3,7 +3,6 @@ package plugin import ( "context" "encoding/json" - "errors" "google.golang.org/grpc" @@ -11,6 +10,8 @@ import ( "github.com/hashicorp/vagrant/ext/go-plugin/vagrant" "github.com/hashicorp/vagrant/ext/go-plugin/vagrant/plugin/proto/vagrant_common" "github.com/hashicorp/vagrant/ext/go-plugin/vagrant/plugin/proto/vagrant_config" + + "github.com/LK4D4/joincontext" ) type Config interface { @@ -35,9 +36,11 @@ func (c *ConfigPlugin) GRPCServer(broker *go_plugin.GRPCBroker, s *grpc.Server) func (c *ConfigPlugin) GRPCClient(ctx context.Context, broker *go_plugin.GRPCBroker, con *grpc.ClientConn) (interface{}, error) { client := vagrant_config.NewConfigClient(con) return &GRPCConfigClient{ - client: client, + client: client, + doneCtx: ctx, GRPCIOClient: GRPCIOClient{ - client: client}}, nil + client: client, + doneCtx: ctx}}, nil } type GRPCConfigServer struct { @@ -47,33 +50,45 @@ type GRPCConfigServer struct { func (s *GRPCConfigServer) ConfigAttributes(ctx context.Context, req *vagrant_common.NullRequest) (resp *vagrant_config.AttributesResponse, err error) { resp = &vagrant_config.AttributesResponse{} - r, e := s.Impl.ConfigAttributes() - if e != nil { - resp.Error = e.Error() - return + n := make(chan struct{}, 1) + go func() { + resp.Attributes, err = s.Impl.ConfigAttributes() + n <- struct{}{} + }() + select { + case <-ctx.Done(): + case <-n: } - resp.Attributes = r return } func (s *GRPCConfigServer) ConfigLoad(ctx context.Context, req *vagrant_config.LoadRequest) (resp *vagrant_config.LoadResponse, err error) { resp = &vagrant_config.LoadResponse{} - var data map[string]interface{} + var data, r map[string]interface{} err = json.Unmarshal([]byte(req.Data), &data) if err != nil { - resp.Error = err.Error() return } - r, err := s.Impl.ConfigLoad(data) + n := make(chan struct{}, 1) + go func() { + r, err = s.Impl.ConfigLoad(ctx, data) + n <- struct{}{} + }() + + select { + case <-ctx.Done(): + return + case <-n: + } + if err != nil { - resp.Error = err.Error() return } mdata, err := json.Marshal(r) if err != nil { - resp.Error = err.Error() return } + resp.Data = string(mdata) return } @@ -83,39 +98,50 @@ func (s *GRPCConfigServer) ConfigValidate(ctx context.Context, req *vagrant_conf var data map[string]interface{} err = json.Unmarshal([]byte(req.Data), &data) if err != nil { - resp.Error = err.Error() return } m, err := vagrant.LoadMachine(req.Machine, s.Impl) if err != nil { - resp.Error = err.Error() return } - r, err := s.Impl.ConfigValidate(data, m) - if err != nil { - resp.Error = err.Error() - return + n := make(chan struct{}, 1) + go func() { + resp.Errors, err = s.Impl.ConfigValidate(ctx, data, m) + n <- struct{}{} + }() + + select { + case <-ctx.Done(): + case <-n: } - resp.Errors = r + return } func (s *GRPCConfigServer) ConfigFinalize(ctx context.Context, req *vagrant_config.FinalizeRequest) (resp *vagrant_config.FinalizeResponse, err error) { resp = &vagrant_config.FinalizeResponse{} - var data map[string]interface{} + var data, r map[string]interface{} err = json.Unmarshal([]byte(req.Data), &data) if err != nil { - resp.Error = err.Error() return } - r, err := s.Impl.ConfigFinalize(data) + n := make(chan struct{}, 1) + go func() { + r, err = s.Impl.ConfigFinalize(ctx, data) + n <- struct{}{} + }() + + select { + case <-ctx.Done(): + return + case <-n: + } + if err != nil { - resp.Error = err.Error() return } mdata, err := json.Marshal(r) if err != nil { - resp.Error = err.Error() return } resp.Data = string(mdata) @@ -125,40 +151,37 @@ func (s *GRPCConfigServer) ConfigFinalize(ctx context.Context, req *vagrant_conf type GRPCConfigClient struct { GRPCCoreClient GRPCIOClient - client vagrant_config.ConfigClient + client vagrant_config.ConfigClient + doneCtx context.Context } func (c *GRPCConfigClient) ConfigAttributes() (attrs []string, err error) { - resp, err := c.client.ConfigAttributes(context.Background(), &vagrant_common.NullRequest{}) + ctx := context.Background() + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.ConfigAttributes(jctx, &vagrant_common.NullRequest{}) if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) + return nil, handleGrpcError(err, c.doneCtx, nil) } attrs = resp.Attributes return } -func (c *GRPCConfigClient) ConfigLoad(data map[string]interface{}) (loaddata map[string]interface{}, err error) { +func (c *GRPCConfigClient) ConfigLoad(ctx context.Context, data map[string]interface{}) (loaddata map[string]interface{}, err error) { mdata, err := json.Marshal(data) if err != nil { return } - resp, err := c.client.ConfigLoad(context.Background(), &vagrant_config.LoadRequest{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.ConfigLoad(jctx, &vagrant_config.LoadRequest{ Data: string(mdata)}) if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) - return + return nil, handleGrpcError(err, c.doneCtx, ctx) } err = json.Unmarshal([]byte(resp.Data), &loaddata) return } -func (c *GRPCConfigClient) ConfigValidate(data map[string]interface{}, m *vagrant.Machine) (errs []string, err error) { +func (c *GRPCConfigClient) ConfigValidate(ctx context.Context, data map[string]interface{}, m *vagrant.Machine) (errs []string, err error) { machData, err := vagrant.DumpMachine(m) if err != nil { return @@ -167,33 +190,27 @@ func (c *GRPCConfigClient) ConfigValidate(data map[string]interface{}, m *vagran if err != nil { return } - resp, err := c.client.ConfigValidate(context.Background(), &vagrant_config.ValidateRequest{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.ConfigValidate(jctx, &vagrant_config.ValidateRequest{ Data: string(mdata), Machine: machData}) if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) - return + return nil, handleGrpcError(err, c.doneCtx, ctx) } errs = resp.Errors return } -func (c *GRPCConfigClient) ConfigFinalize(data map[string]interface{}) (finaldata map[string]interface{}, err error) { +func (c *GRPCConfigClient) ConfigFinalize(ctx context.Context, data map[string]interface{}) (finaldata map[string]interface{}, err error) { mdata, err := json.Marshal(data) if err != nil { return } - resp, err := c.client.ConfigFinalize(context.Background(), &vagrant_config.FinalizeRequest{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.ConfigFinalize(jctx, &vagrant_config.FinalizeRequest{ Data: string(mdata)}) if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) - return + return nil, handleGrpcError(err, c.doneCtx, ctx) } err = json.Unmarshal([]byte(resp.Data), &finaldata) return diff --git a/ext/go-plugin/vagrant/plugin/config_test.go b/ext/go-plugin/vagrant/plugin/config_test.go index 64b460b21..0e264a512 100644 --- a/ext/go-plugin/vagrant/plugin/config_test.go +++ b/ext/go-plugin/vagrant/plugin/config_test.go @@ -1,7 +1,9 @@ package plugin import ( + "context" "testing" + "time" "github.com/hashicorp/go-plugin" "github.com/hashicorp/vagrant/ext/go-plugin/vagrant" @@ -50,8 +52,8 @@ func TestConfigPlugin_Load(t *testing.T) { } data := map[string]interface{}{} - - resp, err := impl.ConfigLoad(data) + var resp map[string]interface{} + resp, err = impl.ConfigLoad(context.Background(), data) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -64,6 +66,70 @@ func TestConfigPlugin_Load(t *testing.T) { } } +func TestConfigPlugin_Load_context_timeout(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "configs": &ConfigPlugin{Impl: &MockConfig{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("configs") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Config) + if !ok { + t.Fatalf("bad %#v", raw) + } + + data := map[string]interface{}{"pause": true} + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + n := make(chan struct{}, 1) + go func() { + _, err = impl.ConfigLoad(ctx, data) + n <- struct{}{} + }() + <-n + if err != context.DeadlineExceeded { + t.Fatalf("bad resp: %s", err) + } +} + +func TestConfigPlugin_Load_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "configs": &ConfigPlugin{Impl: &MockConfig{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("configs") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Config) + if !ok { + t.Fatalf("bad %#v", raw) + } + + data := map[string]interface{}{"pause": true} + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}, 1) + go func() { + _, err = impl.ConfigLoad(ctx, data) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("bad resp: %s", err) + } +} + func TestConfigPlugin_Validate(t *testing.T) { client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ "configs": &ConfigPlugin{Impl: &MockConfig{}}}) @@ -82,7 +148,7 @@ func TestConfigPlugin_Validate(t *testing.T) { data := map[string]interface{}{} machine := &vagrant.Machine{} - resp, err := impl.ConfigValidate(data, machine) + resp, err := impl.ConfigValidate(context.Background(), data, machine) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -94,6 +160,43 @@ func TestConfigPlugin_Validate(t *testing.T) { } } +func TestConfigPlugin_Validate_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "configs": &ConfigPlugin{Impl: &MockConfig{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("configs") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Config) + if !ok { + t.Fatalf("bad %#v", raw) + } + + data := map[string]interface{}{"pause": true} + machine := &vagrant.Machine{} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}, 1) + go func() { + _, err = impl.ConfigValidate(ctx, data, machine) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("bad resp: %s", err) + } +} + func TestConfigPlugin_Finalize(t *testing.T) { client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ "configs": &ConfigPlugin{Impl: &MockConfig{}}}) @@ -113,7 +216,7 @@ func TestConfigPlugin_Finalize(t *testing.T) { "test_key": "test_val", "other_key": "other_val"} - resp, err := impl.ConfigFinalize(data) + resp, err := impl.ConfigFinalize(context.Background(), data) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -129,3 +232,41 @@ func TestConfigPlugin_Finalize(t *testing.T) { t.Errorf("%s != other_val-updated", v) } } + +func TestConfigPlugin_Finalize_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "configs": &ConfigPlugin{Impl: &MockConfig{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("configs") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Config) + if !ok { + t.Fatalf("bad %#v", raw) + } + + data := map[string]interface{}{ + "pause": true, + "test_key": "test_val", + "other_key": "other_val"} + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}, 1) + go func() { + _, err = impl.ConfigFinalize(ctx, data) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("bad resp: %s", err) + } +} diff --git a/ext/go-plugin/vagrant/plugin/io.go b/ext/go-plugin/vagrant/plugin/io.go index 25f049f7f..3ea8eda29 100644 --- a/ext/go-plugin/vagrant/plugin/io.go +++ b/ext/go-plugin/vagrant/plugin/io.go @@ -2,13 +2,14 @@ package plugin import ( "context" - "errors" "google.golang.org/grpc" go_plugin "github.com/hashicorp/go-plugin" "github.com/hashicorp/vagrant/ext/go-plugin/vagrant" "github.com/hashicorp/vagrant/ext/go-plugin/vagrant/plugin/proto/vagrant_io" + + "github.com/LK4D4/joincontext" ) type IO interface { @@ -24,50 +25,62 @@ type GRPCIOServer struct { Impl vagrant.StreamIO } -func (s *GRPCIOServer) Read(ctx context.Context, req *vagrant_io.ReadRequest) (*vagrant_io.ReadResponse, error) { - r, e := s.Impl.Read(req.Target) - result := &vagrant_io.ReadResponse{Content: r} - if e != nil { - result.Error = e.Error() +func (s *GRPCIOServer) Read(ctx context.Context, req *vagrant_io.ReadRequest) (r *vagrant_io.ReadResponse, err error) { + r = &vagrant_io.ReadResponse{} + n := make(chan struct{}, 1) + go func() { + r.Content, err = s.Impl.Read(req.Target) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + case <-n: } - return result, nil + return } -func (s *GRPCIOServer) Write(ctx context.Context, req *vagrant_io.WriteRequest) (*vagrant_io.WriteResponse, error) { - n, e := s.Impl.Write(req.Content, req.Target) - result := &vagrant_io.WriteResponse{Length: int32(n)} - if e != nil { - result.Error = e.Error() +func (s *GRPCIOServer) Write(ctx context.Context, req *vagrant_io.WriteRequest) (r *vagrant_io.WriteResponse, err error) { + r = &vagrant_io.WriteResponse{} + n := make(chan struct{}, 1) + bytes := 0 + go func() { + bytes, err = s.Impl.Write(req.Content, req.Target) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + return + case <-n: + r.Length = int32(bytes) } - return result, nil + return } type GRPCIOClient struct { - client vagrant_io.IOClient + client vagrant_io.IOClient + doneCtx context.Context } func (c *GRPCIOClient) Read(target string) (content string, err error) { - resp, err := c.client.Read(context.Background(), &vagrant_io.ReadRequest{ + ctx := context.Background() + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.Read(jctx, &vagrant_io.ReadRequest{ Target: target}) if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) + return content, handleGrpcError(err, c.doneCtx, ctx) } content = resp.Content return } func (c *GRPCIOClient) Write(content, target string) (length int, err error) { - resp, err := c.client.Write(context.Background(), &vagrant_io.WriteRequest{ + ctx := context.Background() + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.Write(jctx, &vagrant_io.WriteRequest{ Content: content, Target: target}) if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) + return length, handleGrpcError(err, c.doneCtx, ctx) } length = int(resp.Length) return @@ -79,5 +92,7 @@ func (i *IOPlugin) GRPCServer(broker *go_plugin.GRPCBroker, s *grpc.Server) erro } func (i *IOPlugin) GRPCClient(ctx context.Context, broker *go_plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) { - return &GRPCIOClient{client: vagrant_io.NewIOClient(c)}, nil + return &GRPCIOClient{ + client: vagrant_io.NewIOClient(c), + doneCtx: ctx}, nil } diff --git a/ext/go-plugin/vagrant/plugin/mocks.go b/ext/go-plugin/vagrant/plugin/mocks.go index df7b4e5a7..61359c6fe 100644 --- a/ext/go-plugin/vagrant/plugin/mocks.go +++ b/ext/go-plugin/vagrant/plugin/mocks.go @@ -1,7 +1,9 @@ package plugin import ( + "context" "errors" + "time" "github.com/hashicorp/vagrant/ext/go-plugin/vagrant" ) @@ -14,9 +16,12 @@ func (g *MockGuestCapabilities) GuestCapabilities() (caps []vagrant.SystemCapabi return } -func (g *MockGuestCapabilities) GuestCapability(cap *vagrant.SystemCapability, args interface{}, m *vagrant.Machine) (result interface{}, err error) { +func (g *MockGuestCapabilities) GuestCapability(ctx context.Context, cap *vagrant.SystemCapability, args interface{}, m *vagrant.Machine) (result interface{}, err error) { if args != nil { arguments := args.([]interface{}) + if arguments[0] == "pause" { + time.Sleep(1 * time.Second) + } if len(arguments) > 0 { result = []string{ cap.Name, @@ -36,9 +41,12 @@ func (h *MockHostCapabilities) HostCapabilities() (caps []vagrant.SystemCapabili return } -func (h *MockHostCapabilities) HostCapability(cap *vagrant.SystemCapability, args interface{}, e *vagrant.Environment) (result interface{}, err error) { +func (h *MockHostCapabilities) HostCapability(ctx context.Context, cap *vagrant.SystemCapability, args interface{}, e *vagrant.Environment) (result interface{}, err error) { if args != nil { arguments := args.([]interface{}) + if arguments[0] == "pause" { + time.Sleep(1 * time.Second) + } if len(arguments) > 0 { result = []string{ cap.Name, @@ -58,9 +66,12 @@ func (p *MockProviderCapabilities) ProviderCapabilities() (caps []vagrant.Provid return } -func (p *MockProviderCapabilities) ProviderCapability(cap *vagrant.ProviderCapability, args interface{}, m *vagrant.Machine) (result interface{}, err error) { +func (p *MockProviderCapabilities) ProviderCapability(ctx context.Context, cap *vagrant.ProviderCapability, args interface{}, m *vagrant.Machine) (result interface{}, err error) { if args != nil { arguments := args.([]interface{}) + if arguments[0] == "pause" { + time.Sleep(1 * time.Second) + } if len(arguments) > 0 { result = []string{ cap.Name, @@ -81,7 +92,10 @@ func (c *MockConfig) ConfigAttributes() (attrs []string, err error) { return } -func (c *MockConfig) ConfigLoad(data map[string]interface{}) (loaddata map[string]interface{}, err error) { +func (c *MockConfig) ConfigLoad(ctx context.Context, data map[string]interface{}) (loaddata map[string]interface{}, err error) { + if data["pause"] == true { + time.Sleep(1 * time.Second) + } loaddata = map[string]interface{}{ "test_key": "test_val"} if data["test_key"] != nil { @@ -90,17 +104,26 @@ func (c *MockConfig) ConfigLoad(data map[string]interface{}) (loaddata map[strin return } -func (c *MockConfig) ConfigValidate(data map[string]interface{}, m *vagrant.Machine) (errors []string, err error) { +func (c *MockConfig) ConfigValidate(ctx context.Context, data map[string]interface{}, m *vagrant.Machine) (errors []string, err error) { errors = []string{"test error"} + if data["pause"] == true { + time.Sleep(1 * time.Second) + } return } -func (c *MockConfig) ConfigFinalize(data map[string]interface{}) (finaldata map[string]interface{}, err error) { +func (c *MockConfig) ConfigFinalize(ctx context.Context, data map[string]interface{}) (finaldata map[string]interface{}, err error) { finaldata = make(map[string]interface{}) for key, tval := range data { - val := tval.(string) + val, ok := tval.(string) + if !ok { + continue + } finaldata[key] = val + "-updated" } + if data["pause"] == true { + time.Sleep(1 * time.Second) + } return } @@ -112,24 +135,39 @@ type MockProvider struct { vagrant.NoProviderCapabilities } -func (c *MockProvider) Action(actionName string, m *vagrant.Machine) (actions []string, err error) { - if actionName == "valid" { +func (c *MockProvider) Action(ctx context.Context, actionName string, m *vagrant.Machine) (actions []string, err error) { + switch actionName { + case "valid": actions = []string{"self::DoTask"} - } else { + case "pause": + time.Sleep(1 * time.Second) + default: err = errors.New("Unknown action requested") } return } -func (c *MockProvider) IsInstalled(m *vagrant.Machine) (bool, error) { +func (c *MockProvider) IsInstalled(ctx context.Context, m *vagrant.Machine) (bool, error) { + if m.Name == "pause" { + time.Sleep(1 * time.Second) + } + return true, nil } -func (c *MockProvider) IsUsable(m *vagrant.Machine) (bool, error) { +func (c *MockProvider) IsUsable(ctx context.Context, m *vagrant.Machine) (bool, error) { + if m.Name == "pause" { + time.Sleep(1 * time.Second) + } + return true, nil } -func (c *MockProvider) MachineIdChanged(m *vagrant.Machine) error { +func (c *MockProvider) MachineIdChanged(ctx context.Context, m *vagrant.Machine) error { + if m.Name == "pause" { + time.Sleep(1 * time.Second) + } + return nil } @@ -137,13 +175,15 @@ func (c *MockProvider) Name() string { return "mock_provider" } -func (c *MockProvider) RunAction(actionName string, args interface{}, m *vagrant.Machine) (r interface{}, err error) { - if actionName != "valid" && actionName != "send_output" { - err = errors.New("invalid action name") - return - } - if actionName == "send_output" { +func (c *MockProvider) RunAction(ctx context.Context, actionName string, args interface{}, m *vagrant.Machine) (r interface{}, err error) { + switch actionName { + case "send_output": m.UI.Say("test_output_p") + case "pause": + time.Sleep(1 * time.Second) + case "valid": + default: + return nil, errors.New("invalid action name") } var arguments []interface{} if args != nil { @@ -157,13 +197,21 @@ func (c *MockProvider) RunAction(actionName string, args interface{}, m *vagrant return } -func (c *MockProvider) SshInfo(m *vagrant.Machine) (*vagrant.SshInfo, error) { +func (c *MockProvider) SshInfo(ctx context.Context, m *vagrant.Machine) (*vagrant.SshInfo, error) { + if m.Name == "pause" { + time.Sleep(1 * time.Second) + } + return &vagrant.SshInfo{ Host: "localhost", Port: 2222}, nil } -func (c *MockProvider) State(m *vagrant.Machine) (*vagrant.MachineState, error) { +func (c *MockProvider) State(ctx context.Context, m *vagrant.Machine) (*vagrant.MachineState, error) { + if m.Name == "pause" { + time.Sleep(1 * time.Second) + } + return &vagrant.MachineState{ Id: "default", ShortDesc: "running"}, nil @@ -181,7 +229,11 @@ type MockSyncedFolder struct { vagrant.NoHostCapabilities } -func (s *MockSyncedFolder) Cleanup(m *vagrant.Machine, opts vagrant.FolderOptions) error { +func (s *MockSyncedFolder) Cleanup(ctx context.Context, m *vagrant.Machine, opts vagrant.FolderOptions) error { + if m.Name == "pause" { + time.Sleep(1 * time.Second) + } + if opts != nil { err, _ := opts["error"].(bool) ui, _ := opts["ui"].(bool) @@ -196,14 +248,22 @@ func (s *MockSyncedFolder) Cleanup(m *vagrant.Machine, opts vagrant.FolderOption return nil } -func (s *MockSyncedFolder) Disable(m *vagrant.Machine, f vagrant.FolderList, opts vagrant.FolderOptions) error { +func (s *MockSyncedFolder) Disable(ctx context.Context, m *vagrant.Machine, f vagrant.FolderList, opts vagrant.FolderOptions) error { + if m.Name == "pause" { + time.Sleep(1 * time.Second) + } + if opts != nil && opts["error"].(bool) { return errors.New("disable error") } return nil } -func (s *MockSyncedFolder) Enable(m *vagrant.Machine, f vagrant.FolderList, opts vagrant.FolderOptions) error { +func (s *MockSyncedFolder) Enable(ctx context.Context, m *vagrant.Machine, f vagrant.FolderList, opts vagrant.FolderOptions) error { + if m.Name == "pause" { + time.Sleep(1 * time.Second) + } + if opts != nil && opts["error"].(bool) { return errors.New("enable error") } @@ -216,7 +276,11 @@ func (s *MockSyncedFolder) Info() *vagrant.SyncedFolderInfo { Priority: 100} } -func (s *MockSyncedFolder) IsUsable(m *vagrant.Machine) (bool, error) { +func (s *MockSyncedFolder) IsUsable(ctx context.Context, m *vagrant.Machine) (bool, error) { + if m.Name == "pause" { + time.Sleep(1 * time.Second) + } + return true, nil } @@ -224,7 +288,11 @@ func (s *MockSyncedFolder) Name() string { return "mock_folder" } -func (s *MockSyncedFolder) Prepare(m *vagrant.Machine, f vagrant.FolderList, opts vagrant.FolderOptions) error { +func (s *MockSyncedFolder) Prepare(ctx context.Context, m *vagrant.Machine, f vagrant.FolderList, opts vagrant.FolderOptions) error { + if m.Name == "pause" { + time.Sleep(1 * time.Second) + } + if opts != nil && opts["error"].(bool) { return errors.New("prepare error") } diff --git a/ext/go-plugin/vagrant/plugin/proto/genproto b/ext/go-plugin/vagrant/plugin/proto/genproto index c0503d8b1..4c2eb562a 100755 --- a/ext/go-plugin/vagrant/plugin/proto/genproto +++ b/ext/go-plugin/vagrant/plugin/proto/genproto @@ -5,7 +5,7 @@ echo -n "Parsing proto files and generating go output... " for i in * do if [ -d "${i}" ]; then - protoc --proto_path=/home/spox/.go/src --proto_path=. --go_out=plugins=grpc:. "${i}"/*.proto; + protoc --proto_path=`go env GOPATH`/src --proto_path=. --go_out=plugins=grpc:. "${i}"/*.proto; fi done diff --git a/ext/go-plugin/vagrant/plugin/provider.go b/ext/go-plugin/vagrant/plugin/provider.go index 7f483524f..d8c617a58 100644 --- a/ext/go-plugin/vagrant/plugin/provider.go +++ b/ext/go-plugin/vagrant/plugin/provider.go @@ -3,7 +3,6 @@ package plugin import ( "context" "encoding/json" - "errors" "google.golang.org/grpc" @@ -11,6 +10,8 @@ import ( "github.com/hashicorp/vagrant/ext/go-plugin/vagrant" "github.com/hashicorp/vagrant/ext/go-plugin/vagrant/plugin/proto/vagrant_common" "github.com/hashicorp/vagrant/ext/go-plugin/vagrant/plugin/proto/vagrant_provider" + + "github.com/LK4D4/joincontext" ) type Provider interface { @@ -30,29 +31,31 @@ type GRPCProviderClient struct { GRPCHostCapabilitiesClient GRPCProviderCapabilitiesClient GRPCIOClient - client vagrant_provider.ProviderClient + client vagrant_provider.ProviderClient + doneCtx context.Context } -func (c *GRPCProviderClient) Action(actionName string, m *vagrant.Machine) (r []string, err error) { +func (c *GRPCProviderClient) Action(ctx context.Context, actionName string, m *vagrant.Machine) (r []string, err error) { machData, err := vagrant.DumpMachine(m) if err != nil { return } - resp, err := c.client.Action(context.Background(), &vagrant_provider.ActionRequest{ + + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.Action(jctx, &vagrant_provider.ActionRequest{ Name: actionName, Machine: machData}) if err != nil { - return + return nil, handleGrpcError(err, c.doneCtx, ctx) } r = resp.Result - if resp.Error != "" { - err = errors.New(resp.Error) - } return } func (c *GRPCProviderClient) Info() *vagrant.ProviderInfo { - resp, err := c.client.Info(context.Background(), &vagrant_common.NullRequest{}) + ctx := context.Background() + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.Info(jctx, &vagrant_common.NullRequest{}) if err != nil { return &vagrant.ProviderInfo{} } @@ -61,57 +64,51 @@ func (c *GRPCProviderClient) Info() *vagrant.ProviderInfo { Priority: resp.Priority} } -func (c *GRPCProviderClient) IsInstalled(m *vagrant.Machine) (r bool, err error) { +func (c *GRPCProviderClient) IsInstalled(ctx context.Context, m *vagrant.Machine) (r bool, err error) { machData, err := vagrant.DumpMachine(m) if err != nil { return } - resp, err := c.client.IsInstalled(context.Background(), &vagrant_common.EmptyRequest{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.IsInstalled(jctx, &vagrant_common.EmptyRequest{ Machine: machData}) if err != nil { - return + return false, handleGrpcError(err, c.doneCtx, ctx) } r = resp.Result - if resp.Error != "" { - err = errors.New(resp.Error) - } return } -func (c *GRPCProviderClient) IsUsable(m *vagrant.Machine) (r bool, err error) { +func (c *GRPCProviderClient) IsUsable(ctx context.Context, m *vagrant.Machine) (r bool, err error) { machData, err := vagrant.DumpMachine(m) if err != nil { return } - resp, err := c.client.IsUsable(context.Background(), &vagrant_common.EmptyRequest{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.IsUsable(jctx, &vagrant_common.EmptyRequest{ Machine: machData}) if err != nil { - return + return false, handleGrpcError(err, c.doneCtx, ctx) } r = resp.Result - if resp.Error != "" { - err = errors.New(resp.Error) - } return } -func (c *GRPCProviderClient) MachineIdChanged(m *vagrant.Machine) (err error) { +func (c *GRPCProviderClient) MachineIdChanged(ctx context.Context, m *vagrant.Machine) (err error) { machData, err := vagrant.DumpMachine(m) if err != nil { return } - resp, err := c.client.MachineIdChanged(context.Background(), &vagrant_common.EmptyRequest{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + _, err = c.client.MachineIdChanged(jctx, &vagrant_common.EmptyRequest{ Machine: machData}) if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) + return handleGrpcError(err, c.doneCtx, ctx) } return } -func (c *GRPCProviderClient) RunAction(actName string, args interface{}, m *vagrant.Machine) (r interface{}, err error) { +func (c *GRPCProviderClient) RunAction(ctx context.Context, actName string, args interface{}, m *vagrant.Machine) (r interface{}, err error) { machData, err := vagrant.DumpMachine(m) if err != nil { return @@ -120,35 +117,31 @@ func (c *GRPCProviderClient) RunAction(actName string, args interface{}, m *vagr if err != nil { return } - resp, err := c.client.RunAction(context.Background(), &vagrant_provider.RunActionRequest{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.RunAction(jctx, &vagrant_provider.RunActionRequest{ Name: actName, Data: string(runData), Machine: machData}) if err != nil { - return + return nil, handleGrpcError(err, c.doneCtx, ctx) } err = json.Unmarshal([]byte(resp.Data), &r) if err != nil { return } - if resp.Error != "" { - err = errors.New(resp.Error) - } return } -func (c *GRPCProviderClient) SshInfo(m *vagrant.Machine) (r *vagrant.SshInfo, err error) { +func (c *GRPCProviderClient) SshInfo(ctx context.Context, m *vagrant.Machine) (r *vagrant.SshInfo, err error) { machData, err := vagrant.DumpMachine(m) if err != nil { return } - resp, err := c.client.SshInfo(context.Background(), &vagrant_common.EmptyRequest{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.SshInfo(jctx, &vagrant_common.EmptyRequest{ Machine: machData}) if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) + return nil, handleGrpcError(err, c.doneCtx, ctx) } r = &vagrant.SshInfo{ Host: resp.Host, @@ -158,18 +151,16 @@ func (c *GRPCProviderClient) SshInfo(m *vagrant.Machine) (r *vagrant.SshInfo, er return } -func (c *GRPCProviderClient) State(m *vagrant.Machine) (r *vagrant.MachineState, err error) { +func (c *GRPCProviderClient) State(ctx context.Context, m *vagrant.Machine) (r *vagrant.MachineState, err error) { machData, err := vagrant.DumpMachine(m) if err != nil { return } - resp, err := c.client.State(context.Background(), &vagrant_common.EmptyRequest{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.State(jctx, &vagrant_common.EmptyRequest{ Machine: machData}) if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) + return nil, handleGrpcError(err, c.doneCtx, ctx) } r = &vagrant.MachineState{ Id: resp.Id, @@ -179,7 +170,9 @@ func (c *GRPCProviderClient) State(m *vagrant.Machine) (r *vagrant.MachineState, } func (c *GRPCProviderClient) Name() string { - resp, err := c.client.Name(context.Background(), &vagrant_common.NullRequest{}) + ctx := context.Background() + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.Name(jctx, &vagrant_common.NullRequest{}) if err != nil { return "" } @@ -190,16 +183,22 @@ func (p *ProviderPlugin) GRPCClient(ctx context.Context, broker *go_plugin.GRPCB client := vagrant_provider.NewProviderClient(c) return &GRPCProviderClient{ GRPCConfigClient: GRPCConfigClient{ - client: client}, + client: client, + doneCtx: ctx}, GRPCGuestCapabilitiesClient: GRPCGuestCapabilitiesClient{ - client: client}, + client: client, + doneCtx: ctx}, GRPCHostCapabilitiesClient: GRPCHostCapabilitiesClient{ - client: client}, + client: client, + doneCtx: ctx}, GRPCProviderCapabilitiesClient: GRPCProviderCapabilitiesClient{ - client: client}, + client: client, + doneCtx: ctx}, GRPCIOClient: GRPCIOClient{ - client: client}, - client: client, + client: client, + doneCtx: ctx}, + client: client, + doneCtx: ctx, }, nil } @@ -231,14 +230,23 @@ type GRPCProviderServer struct { func (s *GRPCProviderServer) Action(ctx context.Context, req *vagrant_provider.ActionRequest) (resp *vagrant_provider.ActionResponse, err error) { resp = &vagrant_provider.ActionResponse{} - m, e := vagrant.LoadMachine(req.Machine, s.Impl) - if e != nil { - resp.Error = e.Error() + var r []string + n := make(chan struct{}, 1) + m, err := vagrant.LoadMachine(req.Machine, s.Impl) + if err != nil { return } - r, e := s.Impl.Action(req.Name, m) - if e != nil { - resp.Error = e.Error() + go func() { + r, err = s.Impl.Action(ctx, req.Name, m) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + return + case <-n: + } + + if err != nil { return } resp.Result = r @@ -247,19 +255,27 @@ func (s *GRPCProviderServer) Action(ctx context.Context, req *vagrant_provider.A func (s *GRPCProviderServer) RunAction(ctx context.Context, req *vagrant_provider.RunActionRequest) (resp *vagrant_provider.RunActionResponse, err error) { resp = &vagrant_provider.RunActionResponse{} - m, e := vagrant.LoadMachine(req.Machine, s.Impl) - if e != nil { - resp.Error = e.Error() + var args, r interface{} + n := make(chan struct{}, 1) + m, err := vagrant.LoadMachine(req.Machine, s.Impl) + if err != nil { return } - var args interface{} err = json.Unmarshal([]byte(req.Data), &args) if err != nil { return } - r, e := s.Impl.RunAction(req.Name, args, m) - if e != nil { - resp.Error = e.Error() + go func() { + r, err = s.Impl.RunAction(ctx, req.Name, args, m) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + return + case <-n: + } + + if err != nil { return } result, err := json.Marshal(r) @@ -271,7 +287,18 @@ func (s *GRPCProviderServer) RunAction(ctx context.Context, req *vagrant_provide } func (s *GRPCProviderServer) Info(ctx context.Context, req *vagrant_common.NullRequest) (*vagrant_provider.InfoResponse, error) { - r := s.Impl.Info() + var r *vagrant.ProviderInfo + n := make(chan struct{}, 1) + go func() { + r = s.Impl.Info() + n <- struct{}{} + }() + select { + case <-ctx.Done(): + return nil, nil + case <-n: + } + return &vagrant_provider.InfoResponse{ Description: r.Description, Priority: r.Priority}, nil @@ -279,14 +306,22 @@ func (s *GRPCProviderServer) Info(ctx context.Context, req *vagrant_common.NullR func (s *GRPCProviderServer) IsInstalled(ctx context.Context, req *vagrant_common.EmptyRequest) (resp *vagrant_common.IsResponse, err error) { resp = &vagrant_common.IsResponse{} - m, e := vagrant.LoadMachine(req.Machine, s.Impl) - if e != nil { - resp.Error = e.Error() + var r bool + n := make(chan struct{}, 1) + m, err := vagrant.LoadMachine(req.Machine, s.Impl) + if err != nil { return } - r, e := s.Impl.IsInstalled(m) - if e != nil { - resp.Error = e.Error() + go func() { + r, err = s.Impl.IsInstalled(ctx, m) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + return + case <-n: + } + if err != nil { return } resp.Result = r @@ -295,14 +330,22 @@ func (s *GRPCProviderServer) IsInstalled(ctx context.Context, req *vagrant_commo func (s *GRPCProviderServer) IsUsable(ctx context.Context, req *vagrant_common.EmptyRequest) (resp *vagrant_common.IsResponse, err error) { resp = &vagrant_common.IsResponse{} - m, e := vagrant.LoadMachine(req.Machine, s.Impl) - if e != nil { - resp.Error = e.Error() + var r bool + n := make(chan struct{}, 1) + m, err := vagrant.LoadMachine(req.Machine, s.Impl) + if err != nil { return } - r, e := s.Impl.IsUsable(m) - if e != nil { - resp.Error = e.Error() + go func() { + r, err = s.Impl.IsUsable(ctx, m) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + return + case <-n: + } + if err != nil { return } resp.Result = r @@ -311,14 +354,23 @@ func (s *GRPCProviderServer) IsUsable(ctx context.Context, req *vagrant_common.E func (s *GRPCProviderServer) SshInfo(ctx context.Context, req *vagrant_common.EmptyRequest) (resp *vagrant_provider.SshInfoResponse, err error) { resp = &vagrant_provider.SshInfoResponse{} - m, e := vagrant.LoadMachine(req.Machine, s.Impl) - if e != nil { - resp.Error = e.Error() + var r *vagrant.SshInfo + n := make(chan struct{}, 1) + m, err := vagrant.LoadMachine(req.Machine, s.Impl) + if err != nil { return } - r, e := s.Impl.SshInfo(m) - if e != nil { - resp.Error = e.Error() + go func() { + r, err = s.Impl.SshInfo(ctx, m) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + return + case <-n: + } + + if err != nil { return } resp = &vagrant_provider.SshInfoResponse{ @@ -331,14 +383,23 @@ func (s *GRPCProviderServer) SshInfo(ctx context.Context, req *vagrant_common.Em func (s *GRPCProviderServer) State(ctx context.Context, req *vagrant_common.EmptyRequest) (resp *vagrant_provider.StateResponse, err error) { resp = &vagrant_provider.StateResponse{} - m, e := vagrant.LoadMachine(req.Machine, s.Impl) - if e != nil { - resp.Error = e.Error() + var r *vagrant.MachineState + n := make(chan struct{}, 1) + m, err := vagrant.LoadMachine(req.Machine, s.Impl) + if err != nil { return } - r, e := s.Impl.State(m) - if e != nil { - resp.Error = e.Error() + go func() { + r, err = s.Impl.State(ctx, m) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + return + case <-n: + } + + if err != nil { return } resp = &vagrant_provider.StateResponse{ @@ -350,14 +411,18 @@ func (s *GRPCProviderServer) State(ctx context.Context, req *vagrant_common.Empt func (s *GRPCProviderServer) MachineIdChanged(ctx context.Context, req *vagrant_common.EmptyRequest) (resp *vagrant_common.EmptyResponse, err error) { resp = &vagrant_common.EmptyResponse{} - m, e := vagrant.LoadMachine(req.Machine, s.Impl) - if e != nil { - resp.Error = e.Error() + n := make(chan struct{}, 1) + m, err := vagrant.LoadMachine(req.Machine, s.Impl) + if err != nil { return } - e = s.Impl.MachineIdChanged(m) - if e != nil { - resp.Error = e.Error() + go func() { + err = s.Impl.MachineIdChanged(ctx, m) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + case <-n: } return } diff --git a/ext/go-plugin/vagrant/plugin/provider_test.go b/ext/go-plugin/vagrant/plugin/provider_test.go index 771406d58..f7f46c060 100644 --- a/ext/go-plugin/vagrant/plugin/provider_test.go +++ b/ext/go-plugin/vagrant/plugin/provider_test.go @@ -1,8 +1,10 @@ package plugin import ( + "context" "strings" "testing" + "time" "github.com/hashicorp/go-plugin" "github.com/hashicorp/vagrant/ext/go-plugin/vagrant" @@ -23,7 +25,7 @@ func TestProvider_Action(t *testing.T) { t.Fatalf("bad %#v", raw) } - resp, err := impl.Action("valid", &vagrant.Machine{}) + resp, err := impl.Action(context.Background(), "valid", &vagrant.Machine{}) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -47,12 +49,74 @@ func TestProvider_Action_invalid(t *testing.T) { t.Fatalf("bad %#v", raw) } - _, err = impl.Action("invalid", &vagrant.Machine{}) + _, err = impl.Action(context.Background(), "invalid", &vagrant.Machine{}) if err == nil { t.Errorf("illegal action") } } +func TestProvider_Action_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "provider": &ProviderPlugin{Impl: &MockProvider{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("provider") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Provider) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}) + go func() { + _, err = impl.Action(ctx, "pause", &vagrant.Machine{}) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("bad resp: %s", err) + } +} + +func TestProvider_Action_context_timeout(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "provider": &ProviderPlugin{Impl: &MockProvider{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("provider") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Provider) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + n := make(chan struct{}) + go func() { + _, err = impl.Action(ctx, "pause", &vagrant.Machine{}) + n <- struct{}{} + }() + <-n + if err != context.DeadlineExceeded { + t.Fatalf("bad resp: %s", err) + } +} + func TestProvider_IsInstalled(t *testing.T) { client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ "provider": &ProviderPlugin{Impl: &MockProvider{}}}) @@ -68,7 +132,7 @@ func TestProvider_IsInstalled(t *testing.T) { t.Fatalf("bad %#v", raw) } - installed, err := impl.IsInstalled(&vagrant.Machine{}) + installed, err := impl.IsInstalled(context.Background(), &vagrant.Machine{}) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -77,7 +141,7 @@ func TestProvider_IsInstalled(t *testing.T) { } } -func TestProvider_IsUsable(t *testing.T) { +func TestProvider_IsInstalled_context_cancel(t *testing.T) { client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ "provider": &ProviderPlugin{Impl: &MockProvider{}}}) defer server.Stop() @@ -92,7 +156,67 @@ func TestProvider_IsUsable(t *testing.T) { t.Fatalf("bad %#v", raw) } - usable, err := impl.IsUsable(&vagrant.Machine{}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}) + go func() { + _, err = impl.IsInstalled(ctx, &vagrant.Machine{Name: "pause"}) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("bad resp: %s", err) + } +} + +func TestProvider_IsInstalled_context_timeout(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "provider": &ProviderPlugin{Impl: &MockProvider{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("provider") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Provider) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + n := make(chan struct{}) + go func() { + _, err = impl.IsInstalled(ctx, &vagrant.Machine{Name: "pause"}) + n <- struct{}{} + }() + <-n + if err != context.DeadlineExceeded { + t.Fatalf("bad resp: %s", err) + } +} +func TestProvider_IsUsable(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "provider": &ProviderPlugin{Impl: &MockProvider{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("provider") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Provider) + if !ok { + t.Fatalf("bad %#v", raw) + } + usable, err := impl.IsUsable(context.Background(), &vagrant.Machine{}) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -101,6 +225,65 @@ func TestProvider_IsUsable(t *testing.T) { } } +func TestProvider_IsUsable_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "provider": &ProviderPlugin{Impl: &MockProvider{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("provider") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Provider) + if !ok { + t.Fatalf("bad %#v", raw) + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}) + go func() { + _, err = impl.IsUsable(ctx, &vagrant.Machine{Name: "pause"}) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("bad resp: %s", err) + } +} + +func TestProvider_IsUsable_context_timeout(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "provider": &ProviderPlugin{Impl: &MockProvider{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("provider") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Provider) + if !ok { + t.Fatalf("bad %#v", raw) + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + n := make(chan struct{}) + go func() { + _, err = impl.IsUsable(ctx, &vagrant.Machine{Name: "pause"}) + n <- struct{}{} + }() + <-n + if err != context.DeadlineExceeded { + t.Fatalf("bad resp: %s", err) + } +} func TestProvider_MachineIdChanged(t *testing.T) { client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ "provider": &ProviderPlugin{Impl: &MockProvider{}}}) @@ -116,12 +299,74 @@ func TestProvider_MachineIdChanged(t *testing.T) { t.Fatalf("bad %#v", raw) } - err = impl.MachineIdChanged(&vagrant.Machine{}) + err = impl.MachineIdChanged(context.Background(), &vagrant.Machine{}) if err != nil { t.Errorf("err: %s", err) } } +func TestProvider_MachineIdChanged_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "provider": &ProviderPlugin{Impl: &MockProvider{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("provider") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Provider) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}) + go func() { + err = impl.MachineIdChanged(ctx, &vagrant.Machine{Name: "pause"}) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("bad resp: %s", err) + } +} + +func TestProvider_MachineIdChanged_context_timeout(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "provider": &ProviderPlugin{Impl: &MockProvider{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("provider") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Provider) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + n := make(chan struct{}) + go func() { + err = impl.MachineIdChanged(ctx, &vagrant.Machine{Name: "pause"}) + n <- struct{}{} + }() + <-n + if err != context.DeadlineExceeded { + t.Fatalf("bad resp: %s", err) + } +} + func TestProvider_Name(t *testing.T) { client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ "provider": &ProviderPlugin{Impl: &MockProvider{}}}) @@ -161,7 +406,7 @@ func TestProvider_RunAction(t *testing.T) { args := []string{"test_arg", "other_arg"} m := &vagrant.Machine{} - resp, err := impl.RunAction("valid", args, m) + resp, err := impl.RunAction(context.Background(), "valid", args, m) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -175,6 +420,74 @@ func TestProvider_RunAction(t *testing.T) { } } +func TestProvider_RunAction_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "provider": &ProviderPlugin{Impl: &MockProvider{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("provider") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Provider) + if !ok { + t.Fatalf("bad %#v", raw) + } + + args := []string{"test_arg", "other_arg"} + m := &vagrant.Machine{} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}) + go func() { + _, err = impl.RunAction(ctx, "pause", args, m) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("bad resp: %s", err) + } +} + +func TestProvider_RunAction_context_timeout(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "provider": &ProviderPlugin{Impl: &MockProvider{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("provider") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Provider) + if !ok { + t.Fatalf("bad %#v", raw) + } + + args := []string{"test_arg", "other_arg"} + m := &vagrant.Machine{} + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + n := make(chan struct{}) + go func() { + _, err = impl.RunAction(ctx, "pause", args, m) + n <- struct{}{} + }() + <-n + if err != context.DeadlineExceeded { + t.Fatalf("bad resp: %s", err) + } +} + func TestProvider_RunAction_invalid(t *testing.T) { client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ "provider": &ProviderPlugin{Impl: &MockProvider{}}}) @@ -193,7 +506,7 @@ func TestProvider_RunAction_invalid(t *testing.T) { args := []string{"test_arg", "other_arg"} m := &vagrant.Machine{} - _, err = impl.RunAction("invalid", args, m) + _, err = impl.RunAction(context.Background(), "invalid", args, m) if err == nil { t.Fatalf("illegal action run") } @@ -214,7 +527,7 @@ func TestProvider_SshInfo(t *testing.T) { t.Fatalf("bad %#v", raw) } - resp, err := impl.SshInfo(&vagrant.Machine{}) + resp, err := impl.SshInfo(context.Background(), &vagrant.Machine{}) if err != nil { t.Fatalf("invalid resp: %s", err) } @@ -227,6 +540,68 @@ func TestProvider_SshInfo(t *testing.T) { } } +func TestProvider_SshInfo_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "provider": &ProviderPlugin{Impl: &MockProvider{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("provider") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Provider) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}) + go func() { + _, err = impl.SshInfo(ctx, &vagrant.Machine{Name: "pause"}) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("invalid resp: %s", err) + } +} + +func TestProvider_SshInfo_context_timeout(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "provider": &ProviderPlugin{Impl: &MockProvider{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("provider") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Provider) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + n := make(chan struct{}) + go func() { + _, err = impl.SshInfo(ctx, &vagrant.Machine{Name: "pause"}) + n <- struct{}{} + }() + <-n + if err != context.DeadlineExceeded { + t.Fatalf("invalid resp: %s", err) + } +} + func TestProvider_State(t *testing.T) { client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ "provider": &ProviderPlugin{Impl: &MockProvider{}}}) @@ -242,7 +617,7 @@ func TestProvider_State(t *testing.T) { t.Fatalf("bad %#v", raw) } - resp, err := impl.State(&vagrant.Machine{}) + resp, err := impl.State(context.Background(), &vagrant.Machine{}) if err != nil { t.Fatalf("invalid resp: %s", err) } @@ -255,6 +630,68 @@ func TestProvider_State(t *testing.T) { } } +func TestProvider_State_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "provider": &ProviderPlugin{Impl: &MockProvider{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("provider") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Provider) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}) + go func() { + _, err = impl.State(ctx, &vagrant.Machine{Name: "pause"}) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("invalid resp: %s", err) + } +} + +func TestProvider_State_context_timeout(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "provider": &ProviderPlugin{Impl: &MockProvider{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("provider") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(Provider) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + n := make(chan struct{}) + go func() { + _, err = impl.State(ctx, &vagrant.Machine{Name: "pause"}) + n <- struct{}{} + }() + <-n + if err != context.DeadlineExceeded { + t.Fatalf("invalid resp: %s", err) + } +} + func TestProvider_Info(t *testing.T) { client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ "provider": &ProviderPlugin{Impl: &MockProvider{}}}) @@ -295,8 +732,9 @@ func TestProvider_MachineUI_output(t *testing.T) { t.Fatalf("bad %#v", raw) } + ctx := context.Background() go func() { - _, err = impl.RunAction("send_output", nil, &vagrant.Machine{}) + _, err = impl.RunAction(ctx, "send_output", nil, &vagrant.Machine{}) if err != nil { t.Fatalf("bad resp: %s", err) } diff --git a/ext/go-plugin/vagrant/plugin/synced_folder.go b/ext/go-plugin/vagrant/plugin/synced_folder.go index 2240e2867..064e3aa1b 100644 --- a/ext/go-plugin/vagrant/plugin/synced_folder.go +++ b/ext/go-plugin/vagrant/plugin/synced_folder.go @@ -3,7 +3,6 @@ package plugin import ( "context" "encoding/json" - "errors" "google.golang.org/grpc" @@ -11,6 +10,8 @@ import ( "github.com/hashicorp/vagrant/ext/go-plugin/vagrant" "github.com/hashicorp/vagrant/ext/go-plugin/vagrant/plugin/proto/vagrant_common" "github.com/hashicorp/vagrant/ext/go-plugin/vagrant/plugin/proto/vagrant_folder" + + "github.com/LK4D4/joincontext" ) type SyncedFolder interface { @@ -28,10 +29,11 @@ type GRPCSyncedFolderClient struct { GRPCGuestCapabilitiesClient GRPCHostCapabilitiesClient GRPCIOClient - client vagrant_folder.SyncedFolderClient + client vagrant_folder.SyncedFolderClient + doneCtx context.Context } -func (c *GRPCSyncedFolderClient) Cleanup(m *vagrant.Machine, o vagrant.FolderOptions) (err error) { +func (c *GRPCSyncedFolderClient) Cleanup(ctx context.Context, m *vagrant.Machine, o vagrant.FolderOptions) (err error) { machine, err := vagrant.DumpMachine(m) if err != nil { return @@ -40,19 +42,14 @@ func (c *GRPCSyncedFolderClient) Cleanup(m *vagrant.Machine, o vagrant.FolderOpt if err != nil { return } - resp, err := c.client.Cleanup(context.Background(), &vagrant_folder.CleanupRequest{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + _, err = c.client.Cleanup(jctx, &vagrant_folder.CleanupRequest{ Machine: machine, Options: string(opts)}) - if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) - } - return + return handleGrpcError(err, c.doneCtx, ctx) } -func (c *GRPCSyncedFolderClient) Disable(m *vagrant.Machine, f vagrant.FolderList, o vagrant.FolderOptions) (err error) { +func (c *GRPCSyncedFolderClient) Disable(ctx context.Context, m *vagrant.Machine, f vagrant.FolderList, o vagrant.FolderOptions) (err error) { machine, err := vagrant.DumpMachine(m) if err != nil { return @@ -65,20 +62,15 @@ func (c *GRPCSyncedFolderClient) Disable(m *vagrant.Machine, f vagrant.FolderLis if err != nil { return } - resp, err := c.client.Disable(context.Background(), &vagrant_folder.Request{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + _, err = c.client.Disable(jctx, &vagrant_folder.Request{ Machine: machine, Folders: string(folders), Options: string(opts)}) - if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) - } - return + return handleGrpcError(err, c.doneCtx, ctx) } -func (c *GRPCSyncedFolderClient) Enable(m *vagrant.Machine, f vagrant.FolderList, o vagrant.FolderOptions) (err error) { +func (c *GRPCSyncedFolderClient) Enable(ctx context.Context, m *vagrant.Machine, f vagrant.FolderList, o vagrant.FolderOptions) (err error) { machine, err := vagrant.DumpMachine(m) if err != nil { return @@ -91,17 +83,12 @@ func (c *GRPCSyncedFolderClient) Enable(m *vagrant.Machine, f vagrant.FolderList if err != nil { return } - resp, err := c.client.Enable(context.Background(), &vagrant_folder.Request{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + _, err = c.client.Enable(jctx, &vagrant_folder.Request{ Machine: machine, Folders: string(folders), Options: string(opts)}) - if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) - } - return + return handleGrpcError(err, c.doneCtx, ctx) } func (c *GRPCSyncedFolderClient) Info() *vagrant.SyncedFolderInfo { @@ -114,17 +101,18 @@ func (c *GRPCSyncedFolderClient) Info() *vagrant.SyncedFolderInfo { Priority: resp.Priority} } -func (c *GRPCSyncedFolderClient) IsUsable(m *vagrant.Machine) (u bool, err error) { +func (c *GRPCSyncedFolderClient) IsUsable(ctx context.Context, m *vagrant.Machine) (u bool, err error) { machine, err := vagrant.DumpMachine(m) if err != nil { return } - resp, err := c.client.IsUsable(context.Background(), &vagrant_common.EmptyRequest{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + resp, err := c.client.IsUsable(jctx, &vagrant_common.EmptyRequest{ Machine: machine}) - u = resp.Result - if resp.Error != "" { - err = errors.New(resp.Error) + if err != nil { + return false, handleGrpcError(err, c.doneCtx, ctx) } + u = resp.Result return } @@ -136,7 +124,7 @@ func (c *GRPCSyncedFolderClient) Name() string { return resp.Name } -func (c *GRPCSyncedFolderClient) Prepare(m *vagrant.Machine, f vagrant.FolderList, o vagrant.FolderOptions) (err error) { +func (c *GRPCSyncedFolderClient) Prepare(ctx context.Context, m *vagrant.Machine, f vagrant.FolderList, o vagrant.FolderOptions) (err error) { machine, err := vagrant.DumpMachine(m) if err != nil { return @@ -149,17 +137,12 @@ func (c *GRPCSyncedFolderClient) Prepare(m *vagrant.Machine, f vagrant.FolderLis if err != nil { return } - resp, err := c.client.Prepare(context.Background(), &vagrant_folder.Request{ + jctx, _ := joincontext.Join(ctx, c.doneCtx) + _, err = c.client.Prepare(jctx, &vagrant_folder.Request{ Machine: machine, Folders: string(folders), Options: string(opts)}) - if err != nil { - return - } - if resp.Error != "" { - err = errors.New(resp.Error) - } - return + return handleGrpcError(err, c.doneCtx, ctx) } type GRPCSyncedFolderServer struct { @@ -180,9 +163,14 @@ func (s *GRPCSyncedFolderServer) Cleanup(ctx context.Context, req *vagrant_folde if err != nil { return } - e := s.Impl.Cleanup(machine, options) - if e != nil { - resp.Error = e.Error() + n := make(chan struct{}) + go func() { + err = s.Impl.Cleanup(ctx, machine, options) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + case <-n: } return } @@ -203,9 +191,14 @@ func (s *GRPCSyncedFolderServer) Disable(ctx context.Context, req *vagrant_folde if err != nil { return } - e := s.Impl.Disable(machine, folders, options) - if e != nil { - resp.Error = e.Error() + n := make(chan struct{}) + go func() { + err = s.Impl.Disable(ctx, machine, folders, options) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + case <-n: } return } @@ -226,15 +219,30 @@ func (s *GRPCSyncedFolderServer) Enable(ctx context.Context, req *vagrant_folder if err != nil { return } - e := s.Impl.Enable(machine, folders, options) - if e != nil { - resp.Error = e.Error() + n := make(chan struct{}) + go func() { + err = s.Impl.Enable(ctx, machine, folders, options) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + case <-n: } return } func (s *GRPCSyncedFolderServer) Info(ctx context.Context, req *vagrant_common.NullRequest) (*vagrant_folder.InfoResponse, error) { - r := s.Impl.Info() + n := make(chan struct{}) + var r *vagrant.SyncedFolderInfo + go func() { + r = s.Impl.Info() + n <- struct{}{} + }() + select { + case <-ctx.Done(): + return nil, nil + case <-n: + } return &vagrant_folder.InfoResponse{ Description: r.Description, Priority: r.Priority}, nil @@ -242,19 +250,29 @@ func (s *GRPCSyncedFolderServer) Info(ctx context.Context, req *vagrant_common.N func (s *GRPCSyncedFolderServer) IsUsable(ctx context.Context, req *vagrant_common.EmptyRequest) (resp *vagrant_common.IsResponse, err error) { resp = &vagrant_common.IsResponse{} + var r bool machine, err := vagrant.LoadMachine(req.Machine, s.Impl) if err != nil { return } - r, e := s.Impl.IsUsable(machine) - if e != nil { - resp.Error = e.Error() + n := make(chan struct{}) + go func() { + r, err = s.Impl.IsUsable(ctx, machine) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + return + case <-n: + } + if err != nil { + return } resp.Result = r return } -func (s *GRPCSyncedFolderServer) Name(ctx context.Context, req *vagrant_common.NullRequest) (*vagrant_common.NameResponse, error) { +func (s *GRPCSyncedFolderServer) Name(_ context.Context, req *vagrant_common.NullRequest) (*vagrant_common.NameResponse, error) { return &vagrant_common.NameResponse{Name: s.Impl.Name()}, nil } @@ -274,9 +292,14 @@ func (s *GRPCSyncedFolderServer) Prepare(ctx context.Context, req *vagrant_folde if err != nil { return } - e := s.Impl.Prepare(machine, folders, options) - if e != nil { - resp.Error = e.Error() + n := make(chan struct{}) + go func() { + err = s.Impl.Prepare(ctx, machine, folders, options) + n <- struct{}{} + }() + select { + case <-ctx.Done(): + case <-n: } return } @@ -299,10 +322,14 @@ func (f *SyncedFolderPlugin) GRPCClient(ctx context.Context, broker *go_plugin.G client := vagrant_folder.NewSyncedFolderClient(c) return &GRPCSyncedFolderClient{ GRPCIOClient: GRPCIOClient{ - client: client}, + client: client, + doneCtx: ctx}, GRPCGuestCapabilitiesClient: GRPCGuestCapabilitiesClient{ - client: client}, + client: client, + doneCtx: ctx}, GRPCHostCapabilitiesClient: GRPCHostCapabilitiesClient{ - client: client}, - client: client}, nil + client: client, + doneCtx: ctx}, + client: client, + doneCtx: ctx}, nil } diff --git a/ext/go-plugin/vagrant/plugin/synced_folder_test.go b/ext/go-plugin/vagrant/plugin/synced_folder_test.go index 77e9c1483..67c1a394c 100644 --- a/ext/go-plugin/vagrant/plugin/synced_folder_test.go +++ b/ext/go-plugin/vagrant/plugin/synced_folder_test.go @@ -1,8 +1,10 @@ package plugin import ( + "context" "strings" "testing" + "time" "github.com/hashicorp/go-plugin" "github.com/hashicorp/vagrant/ext/go-plugin/vagrant" @@ -23,7 +25,7 @@ func TestSyncedFolder_Cleanup(t *testing.T) { t.Fatalf("bad %#v", raw) } - err = impl.Cleanup(&vagrant.Machine{}, nil) + err = impl.Cleanup(context.Background(), &vagrant.Machine{}, nil) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -47,7 +49,7 @@ func TestSyncedFolder_Cleanup_error(t *testing.T) { args := map[string]interface{}{ "error": true} - err = impl.Cleanup(&vagrant.Machine{}, args) + err = impl.Cleanup(context.Background(), &vagrant.Machine{}, args) if err == nil { t.Fatalf("illegal cleanup") } @@ -56,6 +58,68 @@ func TestSyncedFolder_Cleanup_error(t *testing.T) { } } +func TestSyncedFolder_Cleanup_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "folder": &SyncedFolderPlugin{Impl: &MockSyncedFolder{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("folder") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(SyncedFolder) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}) + go func() { + err = impl.Cleanup(ctx, &vagrant.Machine{Name: "pause"}, nil) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("bad resp: %s", err) + } +} + +func TestSyncedFolder_Cleanup_context_timeout(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "folder": &SyncedFolderPlugin{Impl: &MockSyncedFolder{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("folder") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(SyncedFolder) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + n := make(chan struct{}) + go func() { + err = impl.Cleanup(ctx, &vagrant.Machine{Name: "pause"}, nil) + n <- struct{}{} + }() + <-n + if err != context.DeadlineExceeded { + t.Fatalf("bad resp: %s", err) + } +} + func TestSyncedFolder_Disable(t *testing.T) { client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ "folder": &SyncedFolderPlugin{Impl: &MockSyncedFolder{}}}) @@ -71,7 +135,7 @@ func TestSyncedFolder_Disable(t *testing.T) { t.Fatalf("bad %#v", raw) } - err = impl.Disable(&vagrant.Machine{}, nil, nil) + err = impl.Disable(context.Background(), &vagrant.Machine{}, nil, nil) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -97,7 +161,7 @@ func TestSyncedFolder_Disable_error(t *testing.T) { args := map[string]interface{}{ "error": true} - err = impl.Disable(&vagrant.Machine{}, folders, args) + err = impl.Disable(context.Background(), &vagrant.Machine{}, folders, args) if err == nil { t.Fatalf("illegal disable") } @@ -106,6 +170,68 @@ func TestSyncedFolder_Disable_error(t *testing.T) { } } +func TestSyncedFolder_Disable_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "folder": &SyncedFolderPlugin{Impl: &MockSyncedFolder{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("folder") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(SyncedFolder) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}) + go func() { + err = impl.Disable(ctx, &vagrant.Machine{Name: "pause"}, nil, nil) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("bad resp: %s", err) + } +} + +func TestSyncedFolder_Disable_context_timeout(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "folder": &SyncedFolderPlugin{Impl: &MockSyncedFolder{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("folder") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(SyncedFolder) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + n := make(chan struct{}) + go func() { + err = impl.Disable(ctx, &vagrant.Machine{Name: "pause"}, nil, nil) + n <- struct{}{} + }() + <-n + if err != context.DeadlineExceeded { + t.Fatalf("bad resp: %s", err) + } +} + func TestSyncedFolder_Enable(t *testing.T) { client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ "folder": &SyncedFolderPlugin{Impl: &MockSyncedFolder{}}}) @@ -121,7 +247,7 @@ func TestSyncedFolder_Enable(t *testing.T) { t.Fatalf("bad %#v", raw) } - err = impl.Enable(&vagrant.Machine{}, nil, nil) + err = impl.Enable(context.Background(), &vagrant.Machine{}, nil, nil) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -147,7 +273,7 @@ func TestSyncedFolder_Enable_error(t *testing.T) { args := map[string]interface{}{ "error": true} - err = impl.Enable(&vagrant.Machine{}, folders, args) + err = impl.Enable(context.Background(), &vagrant.Machine{}, folders, args) if err == nil { t.Fatalf("illegal enable") } @@ -156,6 +282,68 @@ func TestSyncedFolder_Enable_error(t *testing.T) { } } +func TestSyncedFolder_Enable_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "folder": &SyncedFolderPlugin{Impl: &MockSyncedFolder{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("folder") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(SyncedFolder) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}) + go func() { + err = impl.Enable(ctx, &vagrant.Machine{Name: "pause"}, nil, nil) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("bad resp: %s", err) + } +} + +func TestSyncedFolder_Enable_context_timeout(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "folder": &SyncedFolderPlugin{Impl: &MockSyncedFolder{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("folder") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(SyncedFolder) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + n := make(chan struct{}) + go func() { + err = impl.Enable(ctx, &vagrant.Machine{Name: "pause"}, nil, nil) + n <- struct{}{} + }() + <-n + if err != context.DeadlineExceeded { + t.Fatalf("bad resp: %s", err) + } +} + func TestSyncedFolder_Prepare(t *testing.T) { client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ "folder": &SyncedFolderPlugin{Impl: &MockSyncedFolder{}}}) @@ -171,7 +359,7 @@ func TestSyncedFolder_Prepare(t *testing.T) { t.Fatalf("bad %#v", raw) } - err = impl.Prepare(&vagrant.Machine{}, nil, nil) + err = impl.Prepare(context.Background(), &vagrant.Machine{}, nil, nil) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -197,7 +385,7 @@ func TestSyncedFolder_Prepare_error(t *testing.T) { args := map[string]interface{}{ "error": true} - err = impl.Prepare(&vagrant.Machine{}, folders, args) + err = impl.Prepare(context.Background(), &vagrant.Machine{}, folders, args) if err == nil { t.Fatalf("illegal prepare") } @@ -206,6 +394,67 @@ func TestSyncedFolder_Prepare_error(t *testing.T) { } } +func TestSyncedFolder_Prepare_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "folder": &SyncedFolderPlugin{Impl: &MockSyncedFolder{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("folder") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(SyncedFolder) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}) + go func() { + err = impl.Prepare(ctx, &vagrant.Machine{Name: "pause"}, nil, nil) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("bad resp: %s", err) + } +} + +func TestSyncedFolder_Prepare_context_timeout(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "folder": &SyncedFolderPlugin{Impl: &MockSyncedFolder{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("folder") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(SyncedFolder) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + n := make(chan struct{}) + go func() { + err = impl.Prepare(ctx, &vagrant.Machine{Name: "pause"}, nil, nil) + n <- struct{}{} + }() + <-n + if err != context.DeadlineExceeded { + t.Fatalf("bad resp: %s", err) + } +} func TestSyncedFolder_Info(t *testing.T) { client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ "folder": &SyncedFolderPlugin{Impl: &MockSyncedFolder{}}}) @@ -249,7 +498,7 @@ func TestSyncedFolder_IsUsable(t *testing.T) { t.Fatalf("bad %#v", raw) } - resp, err := impl.IsUsable(&vagrant.Machine{}) + resp, err := impl.IsUsable(context.Background(), &vagrant.Machine{}) if err != nil { t.Fatalf("bad resp: %s", err) } @@ -258,6 +507,68 @@ func TestSyncedFolder_IsUsable(t *testing.T) { } } +func TestSyncedFolder_IsUsable_context_cancel(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "folder": &SyncedFolderPlugin{Impl: &MockSyncedFolder{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("folder") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(SyncedFolder) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + n := make(chan struct{}) + go func() { + _, err = impl.IsUsable(ctx, &vagrant.Machine{Name: "pause"}) + n <- struct{}{} + }() + select { + case <-n: + t.Fatalf("unexpected completion") + case <-time.After(2 * time.Millisecond): + cancel() + } + <-n + if err != context.Canceled { + t.Fatalf("bad resp: %s", err) + } +} + +func TestSyncedFolder_IsUsable_context_timeout(t *testing.T) { + client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ + "folder": &SyncedFolderPlugin{Impl: &MockSyncedFolder{}}}) + defer server.Stop() + defer client.Close() + + raw, err := client.Dispense("folder") + if err != nil { + t.Fatalf("err: %s", err) + } + impl, ok := raw.(SyncedFolder) + if !ok { + t.Fatalf("bad %#v", raw) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + n := make(chan struct{}) + go func() { + _, err = impl.IsUsable(ctx, &vagrant.Machine{Name: "pause"}) + n <- struct{}{} + }() + <-n + if err != context.DeadlineExceeded { + t.Fatalf("bad resp: %s", err) + } +} + func TestSyncedFolder_Name(t *testing.T) { client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{ "folder": &SyncedFolderPlugin{Impl: &MockSyncedFolder{}}}) @@ -295,7 +606,7 @@ func TestSyncedFolder_MachineUI_output(t *testing.T) { } go func() { - err := impl.Cleanup(&vagrant.Machine{}, map[string]interface{}{"ui": true}) + err := impl.Cleanup(context.Background(), &vagrant.Machine{}, map[string]interface{}{"ui": true}) if err != nil { t.Fatalf("bad resp: %s", err) } diff --git a/ext/go-plugin/vagrant/provider.go b/ext/go-plugin/vagrant/provider.go index abe42ffc1..7527df279 100644 --- a/ext/go-plugin/vagrant/provider.go +++ b/ext/go-plugin/vagrant/provider.go @@ -1,15 +1,19 @@ package vagrant +import ( + "context" +) + type Provider interface { Info() *ProviderInfo - Action(actionName string, machData *Machine) ([]string, error) - IsInstalled(machData *Machine) (bool, error) - IsUsable(machData *Machine) (bool, error) - MachineIdChanged(machData *Machine) error + Action(ctx context.Context, actionName string, machData *Machine) ([]string, error) + IsInstalled(ctx context.Context, machData *Machine) (bool, error) + IsUsable(ctx context.Context, machData *Machine) (bool, error) + MachineIdChanged(ctx context.Context, machData *Machine) error Name() string - RunAction(actionName string, args interface{}, machData *Machine) (interface{}, error) - SshInfo(machData *Machine) (*SshInfo, error) - State(machData *Machine) (*MachineState, error) + RunAction(ctx context.Context, actionName string, args interface{}, machData *Machine) (interface{}, error) + SshInfo(ctx context.Context, machData *Machine) (*SshInfo, error) + State(ctx context.Context, machData *Machine) (*MachineState, error) Config GuestCapabilities diff --git a/ext/go-plugin/vagrant/synced_folder.go b/ext/go-plugin/vagrant/synced_folder.go index b81e780ca..c3c260b52 100644 --- a/ext/go-plugin/vagrant/synced_folder.go +++ b/ext/go-plugin/vagrant/synced_folder.go @@ -1,5 +1,9 @@ package vagrant +import ( + "context" +) + type FolderList map[string]interface{} type FolderOptions map[string]interface{} @@ -9,13 +13,13 @@ type SyncedFolderInfo struct { } type SyncedFolder interface { - Cleanup(m *Machine, opts FolderOptions) error - Disable(m *Machine, f FolderList, opts FolderOptions) error - Enable(m *Machine, f FolderList, opts FolderOptions) error + Cleanup(ctx context.Context, m *Machine, opts FolderOptions) error + Disable(ctx context.Context, m *Machine, f FolderList, opts FolderOptions) error + Enable(ctx context.Context, m *Machine, f FolderList, opts FolderOptions) error Info() *SyncedFolderInfo - IsUsable(m *Machine) (bool, error) + IsUsable(ctx context.Context, m *Machine) (bool, error) Name() string - Prepare(m *Machine, f FolderList, opts FolderOptions) error + Prepare(ctx context.Context, m *Machine, f FolderList, opts FolderOptions) error GuestCapabilities HostCapabilities diff --git a/go.mod b/go.mod index 8f38301f7..2a6a6b99e 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,7 @@ module github.com/hashicorp/vagrant require ( + github.com/LK4D4/joincontext v0.0.0-20171026170139-1724345da6d5 github.com/dylanmei/iso8601 v0.1.0 // indirect github.com/dylanmei/winrmtest v0.0.0-20190225150635-99b7fe2fddf1 github.com/golang/protobuf v1.3.0 diff --git a/go.sum b/go.sum index 06319175e..aee40f545 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/Azure/go-ntlmssp v0.0.0-20180810175552-4a21cbd618b4/go.mod h1:chxPXzS github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/ChrisTrenkamp/goxpath v0.0.0-20170922090931-c385f95c6022 h1:y8Gs8CzNfDF5AZvjr+5UyGQvQEBL7pwo+v+wX6q9JI8= github.com/ChrisTrenkamp/goxpath v0.0.0-20170922090931-c385f95c6022/go.mod h1:nuWgzSkT5PnyOd+272uUmV0dnAnAn42Mk7PiQC5VzN4= +github.com/LK4D4/joincontext v0.0.0-20171026170139-1724345da6d5 h1:U7q69tqXiCf6m097GRlNQB0/6SI1qWIOHYHhCEvDxF4= +github.com/LK4D4/joincontext v0.0.0-20171026170139-1724345da6d5/go.mod h1:nxQPcNPR/34g+HcK2hEsF99O+GJgIkW/OmPl8wtzhmk= github.com/antchfx/xpath v0.0.0-20190129040759-c8489ed3251e h1:ptBAamGVd6CfRsUtyHD+goy2JGhv1QC32v3gqM8mYAM= github.com/antchfx/xpath v0.0.0-20190129040759-c8489ed3251e/go.mod h1:Yee4kTMuNiPYJ7nSNorELQMr1J33uOpXDMByNYhvtNk= github.com/antchfx/xquery v0.0.0-20180515051857-ad5b8c7a47b0 h1:JaCC8jz0zdMLk2m+qCCVLLLM/PL93p84w4pK3aJWj60=