diff --git a/internal/api/tickers.go b/internal/api/tickers.go index 6bbb718..0af1e87 100644 --- a/internal/api/tickers.go +++ b/internal/api/tickers.go @@ -42,6 +42,12 @@ type TickerLocationParam struct { Lon float64 `json:"lon"` } +type TickerUsersParam struct { + Users []struct { + ID int `json:"id"` + } `json:"users" binding:"required"` +} + type TickerWebsitesParam struct { Websites []TickerWebsiteParam `json:"websites" binding:"required"` } @@ -145,18 +151,24 @@ func (h *handler) PutTickerUsers(c *gin.Context) { return } - var body struct { - Users []storage.User `json:"users" binding:"required"` - } - + var body TickerUsersParam err = c.Bind(&body) if err != nil { c.JSON(http.StatusBadRequest, response.ErrorResponse(response.CodeDefault, response.FormError)) return } - ticker.Users = body.Users + userIds := make([]int, 0) + for _, user := range body.Users { + userIds = append(userIds, user.ID) + } + users, err := h.storage.FindUsersByIDs(userIds) + if err != nil { + c.JSON(http.StatusInternalServerError, response.ErrorResponse(response.CodeDefault, response.StorageError)) + return + } + ticker.Users = users err = h.storage.SaveTicker(&ticker) if err != nil { c.JSON(http.StatusInternalServerError, response.ErrorResponse(response.CodeDefault, response.StorageError)) diff --git a/internal/api/tickers_test.go b/internal/api/tickers_test.go index af83c12..911eae3 100644 --- a/internal/api/tickers_test.go +++ b/internal/api/tickers_test.go @@ -270,11 +270,25 @@ func (s *TickerTestSuite) TestPutTickerUsers() { s.store.AssertExpectations(s.T()) }) + s.Run("when find users not working", func() { + s.ctx.Set("ticker", storage.Ticker{}) + body := `{"users":[{"id":1},{"id":2},{"id":3}]}` + s.ctx.Request = httptest.NewRequest(http.MethodPut, "/v1/admin/tickers/1/user", strings.NewReader(body)) + s.ctx.Request.Header.Add("Content-Type", "application/json") + s.store.On("FindUsersByIDs", mock.Anything).Return(nil, errors.New("storage error")).Once() + h := s.handler() + h.PutTickerUsers(s.ctx) + + s.Equal(http.StatusInternalServerError, s.w.Code) + s.store.AssertExpectations(s.T()) + }) + s.Run("when storage returns error", func() { s.ctx.Set("ticker", storage.Ticker{}) body := `{"users":[{"id":1},{"id":2},{"id":3}]}` s.ctx.Request = httptest.NewRequest(http.MethodPut, "/v1/admin/tickers/1/user", strings.NewReader(body)) s.ctx.Request.Header.Add("Content-Type", "application/json") + s.store.On("FindUsersByIDs", mock.Anything).Return([]storage.User{{ID: 1}, {ID: 2}, {ID: 3}}, nil).Once() s.store.On("SaveTicker", mock.Anything).Return(errors.New("storage error")).Once() h := s.handler() h.PutTickerUsers(s.ctx) @@ -288,6 +302,7 @@ func (s *TickerTestSuite) TestPutTickerUsers() { body := `{"users":[{"id":1},{"id":2},{"id":3}]}` s.ctx.Request = httptest.NewRequest(http.MethodPut, "/v1/admin/tickers/1/user", strings.NewReader(body)) s.ctx.Request.Header.Add("Content-Type", "application/json") + s.store.On("FindUsersByIDs", mock.Anything).Return([]storage.User{{ID: 1}, {ID: 2}, {ID: 3}}, nil).Once() s.store.On("SaveTicker", mock.Anything).Return(nil).Once() h := s.handler() h.PutTickerUsers(s.ctx) diff --git a/internal/storage/sql_storage.go b/internal/storage/sql_storage.go index 27c3ab5..f30031c 100644 --- a/internal/storage/sql_storage.go +++ b/internal/storage/sql_storage.go @@ -54,6 +54,11 @@ func (s *SqlStorage) FindUserByID(id int, opts ...func(*gorm.DB) *gorm.DB) (User func (s *SqlStorage) FindUsersByIDs(ids []int, opts ...func(*gorm.DB) *gorm.DB) ([]User, error) { users := make([]User, 0) + + if len(ids) == 0 { + return users, nil + } + db := s.prepareDb(opts...) err := db.Find(&users, ids).Error diff --git a/internal/storage/sql_storage_test.go b/internal/storage/sql_storage_test.go index 0b600ad..ea54e26 100644 --- a/internal/storage/sql_storage_test.go +++ b/internal/storage/sql_storage_test.go @@ -151,6 +151,12 @@ func (s *SqlStorageTestSuite) TestFindUsersByIDs() { s.Empty(users) }) + s.Run("when empty ids", func() { + users, err := s.store.FindUsersByIDs([]int{}) + s.NoError(err) + s.Empty(users) + }) + s.Run("when users exist", func() { user.Tickers = []Ticker{{ID: 1}} err := s.db.Create(&user).Error