From d947c0ebd04e2ca636e7fe59558b80b5ebf079f7 Mon Sep 17 00:00:00 2001 From: Nobuhiro MIKI Date: Tue, 20 Aug 2024 13:24:29 +0000 Subject: [PATCH] Support context in VCPC thread. TODO: We need to more test cases. Signed-off-by: Nobuhiro MIKI --- machine/machine.go | 69 ++++++++++++++++++++++++----------------- machine/machine_test.go | 9 ++++-- vmm/vmm.go | 5 +-- 3 files changed, 50 insertions(+), 33 deletions(-) diff --git a/machine/machine.go b/machine/machine.go index 1075cc2..86d9063 100644 --- a/machine/machine.go +++ b/machine/machine.go @@ -2,6 +2,7 @@ package machine import ( "bytes" + "context" "debug/elf" "encoding/binary" "encoding/hex" @@ -803,7 +804,7 @@ func (m *Machine) SingleStep(onoff bool) error { // RunInfiniteLoop runs the guest cpu until there is an error. // If the error is ErrExitDebug, this function can be called again. -func (m *Machine) RunInfiniteLoop(cpu int) error { +func (m *Machine) RunInfiniteLoop(ctx context.Context, cpu int) error { // https://www.kernel.org/doc/Documentation/virtual/kvm/api.txt // - vcpu ioctls: These query and set attributes that control the operation // of a single virtual cpu. @@ -822,17 +823,22 @@ func (m *Machine) RunInfiniteLoop(cpu int) error { defer runtime.UnlockOSThread() for { - isContinue, err := m.RunOnce(cpu) - if isContinue { - if err != nil { - fmt.Printf("%v\r\n", err) - } + select { + case <-ctx.Done(): + return nil + default: + isContinue, err := m.RunOnce(cpu) + if isContinue { + if err != nil { + fmt.Printf("%v\r\n", err) + } - continue - } + continue + } - if err != nil { - return err + if err != nil { + return err + } } } } @@ -1248,7 +1254,7 @@ func initVMandVCPU( return kvmFd, vmFd, vcpuFds, runs, nil } -func (m *Machine) VCPU(stdout io.Writer, cpu, traceCount int) error { +func (m *Machine) VCPU(ctx context.Context, stdout io.Writer, cpu, traceCount int) error { trace := traceCount > 0 var err error @@ -1256,28 +1262,33 @@ func (m *Machine) VCPU(stdout io.Writer, cpu, traceCount int) error { // exit this loop after a certain number of instructions // were run. for tc := 0; ; tc++ { - err = m.RunInfiniteLoop(cpu) - if err == nil { - continue - } + select { + case <-ctx.Done(): + return nil + default: + err = m.RunInfiniteLoop(ctx, cpu) + if err == nil { + continue + } - if !errors.Is(err, kvm.ErrDebug) { - return fmt.Errorf("CPU %d: %w", cpu, err) - } + if !errors.Is(err, kvm.ErrDebug) { + return fmt.Errorf("CPU %d: %w", cpu, err) + } - if err := m.SingleStep(trace); err != nil { - fmt.Fprintf(stdout, "Setting trace to %v:%v", trace, err) - } + if err := m.SingleStep(trace); err != nil { + fmt.Fprintf(stdout, "Setting trace to %v:%v", trace, err) + } - if tc%traceCount != 0 { - continue - } + if tc%traceCount != 0 { + continue + } - _, r, s, err := m.Inst(cpu) - if err != nil { - fmt.Fprintf(stdout, "disassembling after debug exit:%v", err) - } else { - fmt.Fprintf(stdout, "%#x:%s\r\n", r.RIP, s) + _, r, s, err := m.Inst(cpu) + if err != nil { + fmt.Fprintf(stdout, "disassembling after debug exit:%v", err) + } else { + fmt.Fprintf(stdout, "%#x:%s\r\n", r.RIP, s) + } } } } diff --git a/machine/machine_test.go b/machine/machine_test.go index 43a7aba..ae17215 100644 --- a/machine/machine_test.go +++ b/machine/machine_test.go @@ -2,6 +2,7 @@ package machine_test import ( "bytes" + "context" "errors" "fmt" "os" @@ -71,8 +72,10 @@ func testNewAndLoadLinux(t *testing.T, kernel, tap, guestIPv4, hostIPv4, prefixL m.RunData() + ctx := context.Background() + go func() { - if err = m.RunInfiniteLoop(0); err != nil { + if err = m.RunInfiniteLoop(ctx, 0); err != nil { panic(err) } }() @@ -146,8 +149,10 @@ func TestNewAndLoadEDK2PVH(t *testing.T) { // nolint:paralleltest m.RunData() + ctx := context.Background() + go func() { - if err = m.RunInfiniteLoop(0); err != nil { + if err = m.RunInfiniteLoop(ctx, 0); err != nil { panic(err) } }() diff --git a/vmm/vmm.go b/vmm/vmm.go index 23bb1d8..6c4410c 100644 --- a/vmm/vmm.go +++ b/vmm/vmm.go @@ -2,6 +2,7 @@ package vmm import ( "bufio" + "context" "fmt" "log" "os" @@ -104,7 +105,7 @@ func (v *VMM) Boot() error { return fmt.Errorf("setting trace to %v:%w", trace, err) } - g := new(errgroup.Group) + g, ctx := errgroup.WithContext(context.Background()) for cpu := 0; cpu < v.NCPUs; cpu++ { fmt.Printf("Start CPU %d of %d\r\n", cpu, v.NCPUs) @@ -112,7 +113,7 @@ func (v *VMM) Boot() error { i := cpu f := func() error { - return v.VCPU(os.Stderr, i, v.TraceCount) + return v.VCPU(ctx, os.Stderr, i, v.TraceCount) } g.Go(f)