Skip to content

Commit

Permalink
chore: add Add Elements method and update for workspace provider (#78)
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Linville <[email protected]>
  • Loading branch information
g-linville authored Oct 23, 2024
1 parent 3901872 commit cd749df
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 35 deletions.
80 changes: 57 additions & 23 deletions datasets.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ type Dataset struct {
}

type datasetRequest struct {
Input string `json:"input"`
Workspace string `json:"workspace"`
DatasetToolRepo string `json:"datasetToolRepo"`
Input string `json:"input"`
WorkspaceID string `json:"workspaceID"`
DatasetToolRepo string `json:"datasetToolRepo"`
Env []string `json:"env"`
}

type createDatasetArgs struct {
Expand All @@ -47,6 +48,11 @@ type addDatasetElementArgs struct {
ElementContent string `json:"elementContent"`
}

type addDatasetElementsArgs struct {
DatasetID string `json:"datasetID"`
Elements []DatasetElement `json:"elements"`
}

type listDatasetElementArgs struct {
DatasetID string `json:"datasetID"`
}
Expand All @@ -56,15 +62,16 @@ type getDatasetElementArgs struct {
Element string `json:"element"`
}

func (g *GPTScript) ListDatasets(ctx context.Context, workspace string) ([]DatasetMeta, error) {
if workspace == "" {
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
func (g *GPTScript) ListDatasets(ctx context.Context, workspaceID string) ([]DatasetMeta, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

out, err := g.runBasicCommand(ctx, "datasets", datasetRequest{
Input: "{}",
Workspace: workspace,
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
})
if err != nil {
return nil, err
Expand All @@ -77,9 +84,9 @@ func (g *GPTScript) ListDatasets(ctx context.Context, workspace string) ([]Datas
return datasets, nil
}

func (g *GPTScript) CreateDataset(ctx context.Context, workspace, name, description string) (Dataset, error) {
if workspace == "" {
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
func (g *GPTScript) CreateDataset(ctx context.Context, workspaceID, name, description string) (Dataset, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

args := createDatasetArgs{
Expand All @@ -93,8 +100,9 @@ func (g *GPTScript) CreateDataset(ctx context.Context, workspace, name, descript

out, err := g.runBasicCommand(ctx, "datasets/create", datasetRequest{
Input: string(argsJSON),
Workspace: workspace,
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
})
if err != nil {
return Dataset{}, err
Expand All @@ -107,9 +115,9 @@ func (g *GPTScript) CreateDataset(ctx context.Context, workspace, name, descript
return dataset, nil
}

func (g *GPTScript) AddDatasetElement(ctx context.Context, workspace, datasetID, elementName, elementDescription, elementContent string) (DatasetElementMeta, error) {
if workspace == "" {
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
func (g *GPTScript) AddDatasetElement(ctx context.Context, workspaceID, datasetID, elementName, elementDescription, elementContent string) (DatasetElementMeta, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

args := addDatasetElementArgs{
Expand All @@ -125,8 +133,9 @@ func (g *GPTScript) AddDatasetElement(ctx context.Context, workspace, datasetID,

out, err := g.runBasicCommand(ctx, "datasets/add-element", datasetRequest{
Input: string(argsJSON),
Workspace: workspace,
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
})
if err != nil {
return DatasetElementMeta{}, err
Expand All @@ -139,9 +148,32 @@ func (g *GPTScript) AddDatasetElement(ctx context.Context, workspace, datasetID,
return element, nil
}

func (g *GPTScript) ListDatasetElements(ctx context.Context, workspace, datasetID string) ([]DatasetElementMeta, error) {
if workspace == "" {
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
func (g *GPTScript) AddDatasetElements(ctx context.Context, workspaceID, datasetID string, elements []DatasetElement) error {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

args := addDatasetElementsArgs{
DatasetID: datasetID,
Elements: elements,
}
argsJSON, err := json.Marshal(args)
if err != nil {
return fmt.Errorf("failed to marshal element args: %w", err)
}

_, err = g.runBasicCommand(ctx, "datasets/add-elements", datasetRequest{
Input: string(argsJSON),
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
})
return err
}

func (g *GPTScript) ListDatasetElements(ctx context.Context, workspaceID, datasetID string) ([]DatasetElementMeta, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

args := listDatasetElementArgs{
Expand All @@ -154,8 +186,9 @@ func (g *GPTScript) ListDatasetElements(ctx context.Context, workspace, datasetI

out, err := g.runBasicCommand(ctx, "datasets/list-elements", datasetRequest{
Input: string(argsJSON),
Workspace: workspace,
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
})
if err != nil {
return nil, err
Expand All @@ -168,9 +201,9 @@ func (g *GPTScript) ListDatasetElements(ctx context.Context, workspace, datasetI
return elements, nil
}

func (g *GPTScript) GetDatasetElement(ctx context.Context, workspace, datasetID, elementName string) (DatasetElement, error) {
if workspace == "" {
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
func (g *GPTScript) GetDatasetElement(ctx context.Context, workspaceID, datasetID, elementName string) (DatasetElement, error) {
if workspaceID == "" {
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}

args := getDatasetElementArgs{
Expand All @@ -184,8 +217,9 @@ func (g *GPTScript) GetDatasetElement(ctx context.Context, workspace, datasetID,

out, err := g.runBasicCommand(ctx, "datasets/get-element", datasetRequest{
Input: string(argsJSON),
Workspace: workspace,
WorkspaceID: workspaceID,
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
Env: g.globalOpts.Env,
})
if err != nil {
return DatasetElement{}, err
Expand Down
48 changes: 36 additions & 12 deletions datasets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,72 @@ package gptscript

import (
"context"
"os"
"testing"

"github.com/stretchr/testify/require"
)

func TestDatasets(t *testing.T) {
workspace, err := os.MkdirTemp("", "go-gptscript-test")
workspaceID, err := g.CreateWorkspace(context.Background(), "directory")
require.NoError(t, err)

defer func() {
_ = os.RemoveAll(workspace)
_ = g.DeleteWorkspace(context.Background(), DeleteWorkspaceOptions{WorkspaceID: workspaceID})
}()

// Create a dataset
dataset, err := g.CreateDataset(context.Background(), workspace, "test-dataset", "This is a test dataset")
dataset, err := g.CreateDataset(context.Background(), workspaceID, "test-dataset", "This is a test dataset")
require.NoError(t, err)
require.Equal(t, "test-dataset", dataset.Name)
require.Equal(t, "This is a test dataset", dataset.Description)
require.Equal(t, 0, len(dataset.Elements))

// Add an element
elementMeta, err := g.AddDatasetElement(context.Background(), workspace, dataset.ID, "test-element", "This is a test element", "This is the content")
elementMeta, err := g.AddDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element", "This is a test element", "This is the content")
require.NoError(t, err)
require.Equal(t, "test-element", elementMeta.Name)
require.Equal(t, "This is a test element", elementMeta.Description)

// Get the element
element, err := g.GetDatasetElement(context.Background(), workspace, dataset.ID, "test-element")
// Add two more
err = g.AddDatasetElements(context.Background(), workspaceID, dataset.ID, []DatasetElement{
{
DatasetElementMeta: DatasetElementMeta{
Name: "test-element-2",
Description: "This is a test element 2",
},
Contents: "This is the content 2",
},
{
DatasetElementMeta: DatasetElementMeta{
Name: "test-element-3",
Description: "This is a test element 3",
},
Contents: "This is the content 3",
},
})
require.NoError(t, err)

// Get the first element
element, err := g.GetDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element")
require.NoError(t, err)
require.Equal(t, "test-element", element.Name)
require.Equal(t, "This is a test element", element.Description)
require.Equal(t, "This is the content", element.Contents)

// Get the third element
element, err = g.GetDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element-3")
require.NoError(t, err)
require.Equal(t, "test-element-3", element.Name)
require.Equal(t, "This is a test element 3", element.Description)
require.Equal(t, "This is the content 3", element.Contents)

// List elements in the dataset
elements, err := g.ListDatasetElements(context.Background(), workspace, dataset.ID)
elements, err := g.ListDatasetElements(context.Background(), workspaceID, dataset.ID)
require.NoError(t, err)
require.Equal(t, 1, len(elements))
require.Equal(t, "test-element", elements[0].Name)
require.Equal(t, "This is a test element", elements[0].Description)
require.Equal(t, 3, len(elements))

// List datasets
datasets, err := g.ListDatasets(context.Background(), workspace)
datasets, err := g.ListDatasets(context.Background(), workspaceID)
require.NoError(t, err)
require.Equal(t, 1, len(datasets))
require.Equal(t, "test-dataset", datasets[0].Name)
Expand Down

0 comments on commit cd749df

Please sign in to comment.