diff --git a/cmd/agola/cmd/serve.go b/cmd/agola/cmd/serve.go index 3dd0ea8..6f88d5f 100644 --- a/cmd/agola/cmd/serve.go +++ b/cmd/agola/cmd/serve.go @@ -20,6 +20,7 @@ import ( "github.com/sorintlab/agola/cmd" "github.com/sorintlab/agola/internal/services/config" + "github.com/sorintlab/agola/internal/services/configstore" "github.com/sorintlab/agola/internal/services/runservice/executor" rsscheduler "github.com/sorintlab/agola/internal/services/runservice/scheduler" "github.com/sorintlab/agola/internal/services/scheduler" @@ -116,6 +117,11 @@ func serve(cmd *cobra.Command, args []string) error { return errors.Wrapf(err, "failed to start run service executor") } + cs, err := configstore.NewConfigStore(ctx, &c.ConfigStore) + if err != nil { + return errors.Wrapf(err, "failed to start config store") + } + sched1, err := scheduler.NewScheduler(&c.Scheduler) if err != nil { return errors.Wrapf(err, "failed to start scheduler") @@ -125,6 +131,7 @@ func serve(cmd *cobra.Command, args []string) error { go func() { errCh <- rsex1.Run(ctx) }() go func() { errCh <- rssched1.Run(ctx) }() + go func() { errCh <- cs.Run(ctx) }() go func() { errCh <- sched1.Run(ctx) }() return <-errCh diff --git a/go.mod b/go.mod index 91f1139..c8b2af3 100644 --- a/go.mod +++ b/go.mod @@ -11,8 +11,7 @@ require ( github.com/docker/docker v1.13.1 github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.3.3 // indirect - github.com/dsnet/compress v0.0.0-20171208185109-cc9eb1d7ad76 // indirect - github.com/go-bindata/go-bindata v1.0.0 // indirect + github.com/go-bindata/go-bindata v1.0.0 github.com/go-ini/ini v1.42.0 // indirect github.com/go-sql-driver/mysql v1.4.1 // indirect github.com/google/go-cmp v0.3.0 @@ -41,7 +40,7 @@ require ( github.com/spf13/cobra v0.0.3 go.etcd.io/etcd v0.0.0-20181128220305-dedae6eb7c25 go.uber.org/zap v1.9.1 - golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 + golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 // indirect golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e // indirect golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 // indirect google.golang.org/appengine v1.4.0 // indirect diff --git a/go.sum b/go.sum index c07c5b7..9ebe1f6 100644 --- a/go.sum +++ b/go.sum @@ -32,7 +32,6 @@ github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKoh github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= github.com/docker/go-units v0.3.3 h1:Xk8S3Xj5sLGlG5g67hJmYMmUgXv5N4PhkjJHHqrwnTk= github.com/docker/go-units v0.3.3/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/dsnet/compress v0.0.0-20171208185109-cc9eb1d7ad76/go.mod h1:KjxHHirfLaw19iGT70HvVjHQsL1vq1SRQB4yOsAfy2s= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4 h1:qk/FSDDxo05wdJH28W+p5yivv7LuLYLRXPPD8KQCtZs= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= @@ -42,7 +41,6 @@ github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-bindata/go-bindata v1.0.0 h1:DZ34txDXWn1DyWa+vQf7V9ANc2ILTtrEjtlsdJRF26M= github.com/go-bindata/go-bindata v1.0.0/go.mod h1:xK8Dsgwmeed+BBsSy2XTopBn/8uK2HWuGSnA11C3Joo= -github.com/go-ini/ini v1.41.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= github.com/go-ini/ini v1.42.0 h1:TWr1wGj35+UiWHlBA8er89seFXxzwFn11spilrrj+38= github.com/go-ini/ini v1.42.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= diff --git a/internal/services/configstore/api/client.go b/internal/services/configstore/api/client.go new file mode 100644 index 0000000..b7781de --- /dev/null +++ b/internal/services/configstore/api/client.go @@ -0,0 +1,318 @@ +// Copyright 2019 Sorint.lab +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/sorintlab/agola/internal/services/types" + + "github.com/pkg/errors" +) + +var jsonContent = http.Header{"content-type": []string{"application/json"}} + +// Client represents a Gogs API client. +type Client struct { + url string + client *http.Client +} + +// NewClient initializes and returns a API client. +func NewClient(url string) *Client { + return &Client{ + url: strings.TrimSuffix(url, "/"), + client: &http.Client{}, + } +} + +// SetHTTPClient replaces default http.Client with user given one. +func (c *Client) SetHTTPClient(client *http.Client) { + c.client = client +} + +func (c *Client) doRequest(ctx context.Context, method, path string, query url.Values, header http.Header, ibody io.Reader) (*http.Response, error) { + u, err := url.Parse(c.url + "/api/v1alpha" + path) + if err != nil { + return nil, err + } + u.RawQuery = query.Encode() + + req, err := http.NewRequest(method, u.String(), ibody) + req = req.WithContext(ctx) + if err != nil { + return nil, err + } + for k, v := range header { + req.Header[k] = v + } + + return c.client.Do(req) +} + +func (c *Client) getResponse(ctx context.Context, method, path string, query url.Values, header http.Header, ibody io.Reader) (*http.Response, error) { + resp, err := c.doRequest(ctx, method, path, query, header, ibody) + if err != nil { + return nil, err + } + + if resp.StatusCode/100 != 2 { + defer resp.Body.Close() + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if len(data) <= 1 { + return resp, errors.New(resp.Status) + } + + // TODO(sgotti) use a json error response + + return resp, errors.New(string(data)) + } + + return resp, nil +} + +func (c *Client) getParsedResponse(ctx context.Context, method, path string, query url.Values, header http.Header, ibody io.Reader, obj interface{}) (*http.Response, error) { + resp, err := c.getResponse(ctx, method, path, query, header, ibody) + if err != nil { + return resp, err + } + defer resp.Body.Close() + + d := json.NewDecoder(resp.Body) + + return resp, d.Decode(obj) +} + +func (c *Client) GetProject(ctx context.Context, projectID string) (*types.Project, *http.Response, error) { + project := new(types.Project) + resp, err := c.getParsedResponse(ctx, "GET", fmt.Sprintf("/project/%s", projectID), nil, jsonContent, nil, project) + return project, resp, err +} + +func (c *Client) GetProjectByName(ctx context.Context, projectName string) (*types.Project, *http.Response, error) { + project := new(types.Project) + resp, err := c.getParsedResponse(ctx, "GET", fmt.Sprintf("/projects/%s", projectName), nil, jsonContent, nil, project) + return project, resp, err +} + +func (c *Client) CreateProject(ctx context.Context, project *types.Project) (*types.Project, *http.Response, error) { + pj, err := json.Marshal(project) + if err != nil { + return nil, nil, err + } + + project = new(types.Project) + resp, err := c.getParsedResponse(ctx, "PUT", "/projects", nil, jsonContent, bytes.NewReader(pj), project) + return project, resp, err +} + +func (c *Client) DeleteProject(ctx context.Context, projectName string) (*http.Response, error) { + return c.getResponse(ctx, "DELETE", fmt.Sprintf("/projects/%s", projectName), nil, jsonContent, nil) +} + +func (c *Client) GetProjects(ctx context.Context, start string, limit int, asc bool) ([]*types.Project, *http.Response, error) { + q := url.Values{} + if start != "" { + q.Add("start", start) + } + if limit > 0 { + q.Add("limit", strconv.Itoa(limit)) + } + if asc { + q.Add("asc", "") + } + + projects := []*types.Project{} + resp, err := c.getParsedResponse(ctx, "GET", "/projects", q, jsonContent, nil, &projects) + return projects, resp, err +} + +func (c *Client) GetUser(ctx context.Context, userID string) (*types.User, *http.Response, error) { + user := new(types.User) + resp, err := c.getParsedResponse(ctx, "GET", fmt.Sprintf("/user/%s", userID), nil, jsonContent, nil, user) + return user, resp, err +} + +func (c *Client) GetUserByName(ctx context.Context, userName string) (*types.User, *http.Response, error) { + user := new(types.User) + resp, err := c.getParsedResponse(ctx, "GET", fmt.Sprintf("/users/%s", userName), nil, jsonContent, nil, user) + return user, resp, err +} + +func (c *Client) GetUserByToken(ctx context.Context, token string) (*types.User, *http.Response, error) { + q := url.Values{} + q.Add("query_type", "bytoken") + q.Add("token", token) + + users := []*types.User{} + resp, err := c.getParsedResponse(ctx, "GET", "/users", q, jsonContent, nil, &users) + if err != nil { + return nil, resp, err + } + return users[0], resp, err +} + +func (c *Client) GetUserByLinkedAccountRemoteUserAndSource(ctx context.Context, remoteUserID, remoteSourceID string) (*types.User, *http.Response, error) { + q := url.Values{} + q.Add("query_type", "byremoteuser") + q.Add("remoteuserid", remoteUserID) + q.Add("remotesourceid", remoteSourceID) + + users := []*types.User{} + resp, err := c.getParsedResponse(ctx, "GET", "/users", q, jsonContent, nil, &users) + if err != nil { + return nil, resp, err + } + return users[0], resp, err +} + +func (c *Client) GetUserByLinkedAccount(ctx context.Context, linkedAccountID string) (*types.User, *http.Response, error) { + q := url.Values{} + q.Add("query_type", "bylinkedaccount") + q.Add("linkedaccountid", linkedAccountID) + + users := []*types.User{} + resp, err := c.getParsedResponse(ctx, "GET", "/users", q, jsonContent, nil, &users) + if err != nil { + return nil, resp, err + } + return users[0], resp, err +} + +func (c *Client) CreateUser(ctx context.Context, user *types.User) (*types.User, *http.Response, error) { + uj, err := json.Marshal(user) + if err != nil { + return nil, nil, err + } + + user = new(types.User) + resp, err := c.getParsedResponse(ctx, "PUT", "/users", nil, jsonContent, bytes.NewReader(uj), user) + return user, resp, err +} + +func (c *Client) DeleteUser(ctx context.Context, userName string) (*http.Response, error) { + return c.getResponse(ctx, "DELETE", fmt.Sprintf("/users/%s", userName), nil, jsonContent, nil) +} + +func (c *Client) GetUsers(ctx context.Context, start string, limit int, asc bool) ([]*types.User, *http.Response, error) { + q := url.Values{} + if start != "" { + q.Add("start", start) + } + if limit > 0 { + q.Add("limit", strconv.Itoa(limit)) + } + if asc { + q.Add("asc", "") + } + + users := []*types.User{} + resp, err := c.getParsedResponse(ctx, "GET", "/users", q, jsonContent, nil, &users) + return users, resp, err +} + +func (c *Client) CreateUserLA(ctx context.Context, userName string, req *CreateUserLARequest) (*types.LinkedAccount, *http.Response, error) { + reqj, err := json.Marshal(req) + if err != nil { + return nil, nil, err + } + + la := new(types.LinkedAccount) + resp, err := c.getParsedResponse(ctx, "PUT", fmt.Sprintf("/users/%s/linkedaccounts", userName), nil, jsonContent, bytes.NewReader(reqj), la) + return la, resp, err +} + +func (c *Client) DeleteUserLA(ctx context.Context, userName, laID string) (*http.Response, error) { + return c.getResponse(ctx, "DELETE", fmt.Sprintf("/users/%s/linkedaccounts/%s", userName, laID), nil, jsonContent, nil) +} + +func (c *Client) UpdateUserLA(ctx context.Context, userName, laID string, req *UpdateUserLARequest) (*types.LinkedAccount, *http.Response, error) { + reqj, err := json.Marshal(req) + if err != nil { + return nil, nil, err + } + + la := new(types.LinkedAccount) + resp, err := c.getParsedResponse(ctx, "PUT", fmt.Sprintf("/users/%s/linkedaccounts/%s", userName, laID), nil, jsonContent, bytes.NewReader(reqj), la) + return la, resp, err +} + +func (c *Client) CreateUserToken(ctx context.Context, userName string, req *CreateUserTokenRequest) (*CreateUserTokenResponse, *http.Response, error) { + reqj, err := json.Marshal(req) + if err != nil { + return nil, nil, err + } + + tresp := new(CreateUserTokenResponse) + resp, err := c.getParsedResponse(ctx, "PUT", fmt.Sprintf("/users/%s/tokens", userName), nil, jsonContent, bytes.NewReader(reqj), tresp) + return tresp, resp, err +} + +func (c *Client) GetRemoteSource(ctx context.Context, rsID string) (*types.RemoteSource, *http.Response, error) { + rs := new(types.RemoteSource) + resp, err := c.getParsedResponse(ctx, "GET", fmt.Sprintf("/remotesource/%s", rsID), nil, jsonContent, nil, rs) + return rs, resp, err +} + +func (c *Client) GetRemoteSourceByName(ctx context.Context, rsName string) (*types.RemoteSource, *http.Response, error) { + rs := new(types.RemoteSource) + resp, err := c.getParsedResponse(ctx, "GET", fmt.Sprintf("/remotesources/%s", rsName), nil, jsonContent, nil, rs) + return rs, resp, err +} + +func (c *Client) GetRemoteSources(ctx context.Context, start string, limit int, asc bool) ([]*types.RemoteSource, *http.Response, error) { + q := url.Values{} + if start != "" { + q.Add("start", start) + } + if limit > 0 { + q.Add("limit", strconv.Itoa(limit)) + } + if asc { + q.Add("asc", "") + } + + rss := []*types.RemoteSource{} + resp, err := c.getParsedResponse(ctx, "GET", "/remotesources", q, jsonContent, nil, &rss) + return rss, resp, err +} + +func (c *Client) CreateRemoteSource(ctx context.Context, rs *types.RemoteSource) (*types.RemoteSource, *http.Response, error) { + uj, err := json.Marshal(rs) + if err != nil { + return nil, nil, err + } + + rs = new(types.RemoteSource) + resp, err := c.getParsedResponse(ctx, "PUT", "/remotesources", nil, jsonContent, bytes.NewReader(uj), rs) + return rs, resp, err +} + +func (c *Client) DeleteRemoteSource(ctx context.Context, name string) (*http.Response, error) { + return c.getResponse(ctx, "DELETE", fmt.Sprintf("/remotesources/%s", name), nil, jsonContent, nil) +} diff --git a/internal/services/configstore/api/project.go b/internal/services/configstore/api/project.go new file mode 100644 index 0000000..e81a4c7 --- /dev/null +++ b/internal/services/configstore/api/project.go @@ -0,0 +1,207 @@ +// Copyright 2019 Sorint.lab +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "encoding/json" + "net/http" + "strconv" + + "github.com/sorintlab/agola/internal/db" + "github.com/sorintlab/agola/internal/services/configstore/command" + "github.com/sorintlab/agola/internal/services/configstore/readdb" + "github.com/sorintlab/agola/internal/services/types" + + "github.com/gorilla/mux" + "go.uber.org/zap" +) + +type GetProjectHandler struct { + log *zap.SugaredLogger + readDB *readdb.ReadDB +} + +func NewGetProjectHandler(logger *zap.Logger, readDB *readdb.ReadDB) *GetProjectHandler { + return &GetProjectHandler{log: logger.Sugar(), readDB: readDB} +} + +func (h *GetProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + projectID := vars["projectid"] + + var project *types.Project + err := h.readDB.Do(func(tx *db.Tx) error { + var err error + project, err = h.readDB.GetProject(tx, projectID) + return err + }) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if project == nil { + http.Error(w, "", http.StatusNotFound) + return + } + + if err := json.NewEncoder(w).Encode(project); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +type GetProjectByNameHandler struct { + log *zap.SugaredLogger + readDB *readdb.ReadDB +} + +func NewGetProjectByNameHandler(logger *zap.Logger, readDB *readdb.ReadDB) *GetProjectByNameHandler { + return &GetProjectByNameHandler{log: logger.Sugar(), readDB: readDB} +} + +func (h *GetProjectByNameHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + projectName := vars["projectname"] + + var project *types.Project + err := h.readDB.Do(func(tx *db.Tx) error { + var err error + project, err = h.readDB.GetProjectByName(tx, projectName) + return err + }) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if project == nil { + http.Error(w, "", http.StatusNotFound) + return + } + + if err := json.NewEncoder(w).Encode(project); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +type CreateProjectHandler struct { + log *zap.SugaredLogger + ch *command.CommandHandler + readDB *readdb.ReadDB +} + +func NewCreateProjectHandler(logger *zap.Logger, ch *command.CommandHandler) *CreateProjectHandler { + return &CreateProjectHandler{log: logger.Sugar(), ch: ch} +} + +func (h *CreateProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req types.Project + d := json.NewDecoder(r.Body) + if err := d.Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + project, err := h.ch.CreateProject(ctx, &req) + if err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if err := json.NewEncoder(w).Encode(project); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +type DeleteProjectHandler struct { + log *zap.SugaredLogger + ch *command.CommandHandler +} + +func NewDeleteProjectHandler(logger *zap.Logger, ch *command.CommandHandler) *DeleteProjectHandler { + return &DeleteProjectHandler{log: logger.Sugar(), ch: ch} +} + +func (h *DeleteProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + vars := mux.Vars(r) + projectName := vars["projectname"] + + if err := h.ch.DeleteProject(ctx, projectName); err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } +} + +const ( + DefaultProjectsLimit = 10 + MaxProjectsLimit = 20 +) + +type ProjectsHandler struct { + log *zap.SugaredLogger + readDB *readdb.ReadDB +} + +func NewProjectsHandler(logger *zap.Logger, readDB *readdb.ReadDB) *ProjectsHandler { + return &ProjectsHandler{log: logger.Sugar(), readDB: readDB} +} + +func (h *ProjectsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + + limitS := query.Get("limit") + limit := DefaultProjectsLimit + if limitS != "" { + var err error + limit, err = strconv.Atoi(limitS) + if err != nil { + http.Error(w, "", http.StatusBadRequest) + return + } + } + if limit < 0 { + http.Error(w, "limit must be greater or equal than 0", http.StatusBadRequest) + return + } + if limit > MaxProjectsLimit { + limit = MaxProjectsLimit + } + asc := false + if _, ok := query["asc"]; ok { + asc = true + } + + start := query.Get("start") + + projects, err := h.readDB.GetProjects(start, limit, asc) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if err := json.NewEncoder(w).Encode(projects); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} diff --git a/internal/services/configstore/api/remotesource.go b/internal/services/configstore/api/remotesource.go new file mode 100644 index 0000000..9ccfcd2 --- /dev/null +++ b/internal/services/configstore/api/remotesource.go @@ -0,0 +1,207 @@ +// Copyright 2019 Sorint.lab +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "encoding/json" + "net/http" + "strconv" + + "github.com/sorintlab/agola/internal/db" + "github.com/sorintlab/agola/internal/services/configstore/command" + "github.com/sorintlab/agola/internal/services/configstore/readdb" + "github.com/sorintlab/agola/internal/services/types" + + "github.com/gorilla/mux" + "go.uber.org/zap" +) + +type GetRemoteSourceHandler struct { + log *zap.SugaredLogger + readDB *readdb.ReadDB +} + +func NewGetRemoteSourceHandler(logger *zap.Logger, readDB *readdb.ReadDB) *GetRemoteSourceHandler { + return &GetRemoteSourceHandler{log: logger.Sugar(), readDB: readDB} +} + +func (h *GetRemoteSourceHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + remoteSourceID := vars["id"] + + var remoteSource *types.RemoteSource + err := h.readDB.Do(func(tx *db.Tx) error { + var err error + remoteSource, err = h.readDB.GetRemoteSource(tx, remoteSourceID) + return err + }) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if remoteSource == nil { + http.Error(w, "", http.StatusNotFound) + return + } + + if err := json.NewEncoder(w).Encode(remoteSource); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +type GetRemoteSourceByNameHandler struct { + log *zap.SugaredLogger + readDB *readdb.ReadDB +} + +func NewGetRemoteSourceByNameHandler(logger *zap.Logger, readDB *readdb.ReadDB) *GetRemoteSourceByNameHandler { + return &GetRemoteSourceByNameHandler{log: logger.Sugar(), readDB: readDB} +} + +func (h *GetRemoteSourceByNameHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + remoteSourceName := vars["name"] + + var remoteSource *types.RemoteSource + err := h.readDB.Do(func(tx *db.Tx) error { + var err error + remoteSource, err = h.readDB.GetRemoteSourceByName(tx, remoteSourceName) + return err + }) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if remoteSource == nil { + http.Error(w, "", http.StatusNotFound) + return + } + + if err := json.NewEncoder(w).Encode(remoteSource); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +type CreateRemoteSourceHandler struct { + log *zap.SugaredLogger + ch *command.CommandHandler + readDB *readdb.ReadDB +} + +func NewCreateRemoteSourceHandler(logger *zap.Logger, ch *command.CommandHandler) *CreateRemoteSourceHandler { + return &CreateRemoteSourceHandler{log: logger.Sugar(), ch: ch} +} + +func (h *CreateRemoteSourceHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req types.RemoteSource + d := json.NewDecoder(r.Body) + if err := d.Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + remoteSource, err := h.ch.CreateRemoteSource(ctx, &req) + if err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if err := json.NewEncoder(w).Encode(remoteSource); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +type DeleteRemoteSourceHandler struct { + log *zap.SugaredLogger + ch *command.CommandHandler +} + +func NewDeleteRemoteSourceHandler(logger *zap.Logger, ch *command.CommandHandler) *DeleteRemoteSourceHandler { + return &DeleteRemoteSourceHandler{log: logger.Sugar(), ch: ch} +} + +func (h *DeleteRemoteSourceHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + vars := mux.Vars(r) + remoteSourceName := vars["name"] + + if err := h.ch.DeleteRemoteSource(ctx, remoteSourceName); err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } +} + +const ( + DefaultRemoteSourcesLimit = 10 + MaxRemoteSourcesLimit = 20 +) + +type RemoteSourcesHandler struct { + log *zap.SugaredLogger + readDB *readdb.ReadDB +} + +func NewRemoteSourcesHandler(logger *zap.Logger, readDB *readdb.ReadDB) *RemoteSourcesHandler { + return &RemoteSourcesHandler{log: logger.Sugar(), readDB: readDB} +} + +func (h *RemoteSourcesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + + limitS := query.Get("limit") + limit := DefaultRemoteSourcesLimit + if limitS != "" { + var err error + limit, err = strconv.Atoi(limitS) + if err != nil { + http.Error(w, "", http.StatusBadRequest) + return + } + } + if limit < 0 { + http.Error(w, "limit must be greater or equal than 0", http.StatusBadRequest) + return + } + if limit > MaxRemoteSourcesLimit { + limit = MaxRemoteSourcesLimit + } + asc := false + if _, ok := query["asc"]; ok { + asc = true + } + + start := query.Get("start") + + remoteSources, err := h.readDB.GetRemoteSources(start, limit, asc) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if err := json.NewEncoder(w).Encode(remoteSources); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} diff --git a/internal/services/configstore/api/user.go b/internal/services/configstore/api/user.go new file mode 100644 index 0000000..6901572 --- /dev/null +++ b/internal/services/configstore/api/user.go @@ -0,0 +1,460 @@ +// Copyright 2019 Sorint.lab +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "encoding/json" + "net/http" + "strconv" + + "github.com/sorintlab/agola/internal/db" + "github.com/sorintlab/agola/internal/services/configstore/command" + "github.com/sorintlab/agola/internal/services/configstore/readdb" + "github.com/sorintlab/agola/internal/services/types" + "github.com/sorintlab/agola/internal/util" + + "github.com/gorilla/mux" + "go.uber.org/zap" +) + +type GetUserHandler struct { + log *zap.SugaredLogger + readDB *readdb.ReadDB +} + +func NewGetUserHandler(logger *zap.Logger, readDB *readdb.ReadDB) *GetUserHandler { + return &GetUserHandler{log: logger.Sugar(), readDB: readDB} +} + +func (h *GetUserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + userID := vars["userid"] + + var user *types.User + err := h.readDB.Do(func(tx *db.Tx) error { + var err error + user, err = h.readDB.GetUser(tx, userID) + return err + }) + if err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if user == nil { + http.Error(w, "", http.StatusNotFound) + return + } + + if err := json.NewEncoder(w).Encode(user); err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +type GetUserByNameHandler struct { + log *zap.SugaredLogger + readDB *readdb.ReadDB +} + +func NewGetUserByNameHandler(logger *zap.Logger, readDB *readdb.ReadDB) *GetUserByNameHandler { + return &GetUserByNameHandler{log: logger.Sugar(), readDB: readDB} +} + +func (h *GetUserByNameHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + userName := vars["username"] + + var user *types.User + err := h.readDB.Do(func(tx *db.Tx) error { + var err error + user, err = h.readDB.GetUserByName(tx, userName) + return err + }) + if err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if user == nil { + http.Error(w, "", http.StatusNotFound) + return + } + + if err := json.NewEncoder(w).Encode(user); err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +type CreateUserHandler struct { + log *zap.SugaredLogger + ch *command.CommandHandler +} + +func NewCreateUserHandler(logger *zap.Logger, ch *command.CommandHandler) *CreateUserHandler { + return &CreateUserHandler{log: logger.Sugar(), ch: ch} +} + +func (h *CreateUserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req types.User + d := json.NewDecoder(r.Body) + if err := d.Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + user, err := h.ch.CreateUser(ctx, &req) + if err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if err := json.NewEncoder(w).Encode(user); err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +type DeleteUserHandler struct { + log *zap.SugaredLogger + ch *command.CommandHandler +} + +func NewDeleteUserHandler(logger *zap.Logger, ch *command.CommandHandler) *DeleteUserHandler { + return &DeleteUserHandler{log: logger.Sugar(), ch: ch} +} + +func (h *DeleteUserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.log.Infof("deleteuserhandler") + ctx := r.Context() + + vars := mux.Vars(r) + userName := vars["username"] + + if err := h.ch.DeleteUser(ctx, userName); err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } +} + +const ( + DefaultUsersLimit = 10 + MaxUsersLimit = 20 +) + +type UsersHandler struct { + log *zap.SugaredLogger + readDB *readdb.ReadDB +} + +func NewUsersHandler(logger *zap.Logger, readDB *readdb.ReadDB) *UsersHandler { + return &UsersHandler{log: logger.Sugar(), readDB: readDB} +} + +func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + + limitS := query.Get("limit") + limit := DefaultUsersLimit + if limitS != "" { + var err error + limit, err = strconv.Atoi(limitS) + if err != nil { + http.Error(w, "", http.StatusBadRequest) + return + } + } + if limit < 0 { + http.Error(w, "limit must be greater or equal than 0", http.StatusBadRequest) + return + } + if limit > MaxUsersLimit { + limit = MaxUsersLimit + } + asc := false + if _, ok := query["asc"]; ok { + asc = true + } + + start := query.Get("start") + + // handle special queries, like get user by token + queryType := query.Get("query_type") + h.log.Infof("query_type: %s", queryType) + + var users []*types.User + switch queryType { + case "bytoken": + token := query.Get("token") + var user *types.User + err := h.readDB.Do(func(tx *db.Tx) error { + var err error + user, err = h.readDB.GetUserByTokenValue(tx, token) + return err + }) + h.log.Infof("user: %s", util.Dump(user)) + if err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if user == nil { + http.Error(w, "", http.StatusNotFound) + return + } + users = []*types.User{user} + case "bylinkedaccount": + linkedAccountID := query.Get("linkedaccountid") + var user *types.User + err := h.readDB.Do(func(tx *db.Tx) error { + var err error + user, err = h.readDB.GetUserByLinkedAccount(tx, linkedAccountID) + return err + }) + h.log.Infof("user: %s", util.Dump(user)) + if err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if user == nil { + http.Error(w, "", http.StatusNotFound) + return + } + users = []*types.User{user} + case "byremoteuser": + remoteUserID := query.Get("remoteuserid") + remoteSourceID := query.Get("remotesourceid") + var user *types.User + err := h.readDB.Do(func(tx *db.Tx) error { + var err error + user, err = h.readDB.GetUserByLinkedAccountRemoteUserIDandSource(tx, remoteUserID, remoteSourceID) + return err + }) + h.log.Infof("user: %s", util.Dump(user)) + if err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if user == nil { + http.Error(w, "", http.StatusNotFound) + return + } + users = []*types.User{user} + default: + // default query + err := h.readDB.Do(func(tx *db.Tx) error { + var err error + users, err = h.readDB.GetUsers(tx, start, limit, asc) + return err + }) + if err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + + if err := json.NewEncoder(w).Encode(users); err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +type CreateUserLARequest struct { + RemoteSourceName string `json:"remote_source_name"` + RemoteUserID string `json:"remote_user_id"` + RemoteUserName string `json:"remote_user_name"` + UserAccessToken string `json:"user_access_token"` + Oauth2AccessToken string `json:"oauth2_access_token"` + Oauth2RefreshToken string `json:"oauth2_refresh_token"` +} + +type CreateUserLAHandler struct { + log *zap.SugaredLogger + ch *command.CommandHandler +} + +func NewCreateUserLAHandler(logger *zap.Logger, ch *command.CommandHandler) *CreateUserLAHandler { + return &CreateUserLAHandler{log: logger.Sugar(), ch: ch} +} + +func (h *CreateUserLAHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + vars := mux.Vars(r) + userName := vars["username"] + + var req CreateUserLARequest + d := json.NewDecoder(r.Body) + if err := d.Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + creq := &command.CreateUserLARequest{ + UserName: userName, + RemoteSourceName: req.RemoteSourceName, + RemoteUserID: req.RemoteUserID, + RemoteUserName: req.RemoteUserName, + Oauth2AccessToken: req.Oauth2AccessToken, + Oauth2RefreshToken: req.Oauth2RefreshToken, + UserAccessToken: req.UserAccessToken, + } + user, err := h.ch.CreateUserLA(ctx, creq) + if err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if err := json.NewEncoder(w).Encode(user); err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +type DeleteUserLAHandler struct { + log *zap.SugaredLogger + ch *command.CommandHandler +} + +func NewDeleteUserLAHandler(logger *zap.Logger, ch *command.CommandHandler) *DeleteUserLAHandler { + return &DeleteUserLAHandler{log: logger.Sugar(), ch: ch} +} + +func (h *DeleteUserLAHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + vars := mux.Vars(r) + userName := vars["username"] + laID := vars["laid"] + + if err := h.ch.DeleteUserLA(ctx, userName, laID); err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } +} + +type UpdateUserLARequest struct { + RemoteUserID string `json:"remote_user_id"` + RemoteUserName string `json:"remote_user_name"` + UserAccessToken string `json:"user_access_token"` + Oauth2AccessToken string `json:"oauth2_access_token"` + Oauth2RefreshToken string `json:"oauth2_refresh_token"` +} + +type UpdateUserLAHandler struct { + log *zap.SugaredLogger + ch *command.CommandHandler +} + +func NewUpdateUserLAHandler(logger *zap.Logger, ch *command.CommandHandler) *UpdateUserLAHandler { + return &UpdateUserLAHandler{log: logger.Sugar(), ch: ch} +} + +func (h *UpdateUserLAHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + vars := mux.Vars(r) + userName := vars["username"] + linkedAccountID := vars["laid"] + + var req UpdateUserLARequest + d := json.NewDecoder(r.Body) + if err := d.Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + creq := &command.UpdateUserLARequest{ + UserName: userName, + LinkedAccountID: linkedAccountID, + RemoteUserID: req.RemoteUserID, + RemoteUserName: req.RemoteUserName, + Oauth2AccessToken: req.Oauth2AccessToken, + Oauth2RefreshToken: req.Oauth2RefreshToken, + UserAccessToken: req.UserAccessToken, + } + user, err := h.ch.UpdateUserLA(ctx, creq) + if err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if err := json.NewEncoder(w).Encode(user); err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +type CreateUserTokenRequest struct { + TokenName string `json:"token_name"` +} + +type CreateUserTokenResponse struct { + Token string `json:"token"` +} + +type CreateUserTokenHandler struct { + log *zap.SugaredLogger + ch *command.CommandHandler +} + +func NewCreateUserTokenHandler(logger *zap.Logger, ch *command.CommandHandler) *CreateUserTokenHandler { + return &CreateUserTokenHandler{log: logger.Sugar(), ch: ch} +} + +func (h *CreateUserTokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + vars := mux.Vars(r) + userName := vars["username"] + + var req CreateUserTokenRequest + d := json.NewDecoder(r.Body) + if err := d.Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + token, err := h.ch.CreateUserToken(ctx, userName, req.TokenName) + if err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + resp := &CreateUserTokenResponse{ + Token: token, + } + if err := json.NewEncoder(w).Encode(resp); err != nil { + h.log.Errorf("err: %+v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} diff --git a/internal/services/configstore/command/command.go b/internal/services/configstore/command/command.go new file mode 100644 index 0000000..9172069 --- /dev/null +++ b/internal/services/configstore/command/command.go @@ -0,0 +1,595 @@ +// Copyright 2019 Sorint.lab +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied +// See the License for the specific language governing permissions and +// limitations under the License. + +package command + +import ( + "context" + "encoding/json" + + "github.com/sorintlab/agola/internal/db" + "github.com/sorintlab/agola/internal/services/configstore/common" + "github.com/sorintlab/agola/internal/services/configstore/readdb" + "github.com/sorintlab/agola/internal/services/types" + "github.com/sorintlab/agola/internal/util" + "github.com/sorintlab/agola/internal/wal" + + "github.com/pkg/errors" + uuid "github.com/satori/go.uuid" + "go.uber.org/zap" +) + +type CommandHandler struct { + log *zap.SugaredLogger + readDB *readdb.ReadDB + wal *wal.WalManager +} + +func NewCommandHandler(logger *zap.Logger, readDB *readdb.ReadDB, wal *wal.WalManager) *CommandHandler { + return &CommandHandler{ + log: logger.Sugar(), + readDB: readDB, + wal: wal, + } +} + +func (s *CommandHandler) CreateProject(ctx context.Context, project *types.Project) (*types.Project, error) { + if project.Name == "" { + return nil, errors.Errorf("project name required") + } + + var cgt *wal.ChangeGroupsUpdateToken + cgNames := []string{project.Name} + + // must do all the check in a single transaction to avoid concurrent changes + // since the use token is related to the transaction time + err := s.readDB.Do(func(tx *db.Tx) error { + var err error + cgt, err = s.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) + if err != nil { + return err + } + + // check duplicate project name + p, err := s.readDB.GetProjectByName(tx, project.Name) + if err != nil { + return err + } + if p != nil { + return errors.Errorf("project %q already exists", p.Name) + } + return nil + }) + if err != nil { + return nil, err + } + + project.ID = uuid.NewV4().String() + + pcj, err := json.Marshal(project) + if err != nil { + return nil, errors.Wrapf(err, "failed to marshal project") + } + actions := []*wal.Action{ + { + ActionType: wal.ActionTypePut, + Path: common.StorageProjectFile(project.ID), + Data: pcj, + }, + } + + _, err = s.wal.WriteWal(ctx, actions, cgt) + return project, err +} + +func (s *CommandHandler) DeleteProject(ctx context.Context, projectName string) error { + var project *types.Project + + var cgt *wal.ChangeGroupsUpdateToken + cgNames := []string{project.Name} + + // must do all the check in a single transaction to avoid concurrent changes + // since the use token is related to the transaction time + err := s.readDB.Do(func(tx *db.Tx) error { + var err error + cgt, err = s.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) + if err != nil { + return err + } + + // check project existance + project, err = s.readDB.GetProjectByName(tx, projectName) + if err != nil { + return err + } + if project == nil { + return errors.Errorf("project %q doesn't exist", projectName) + } + return nil + }) + if err != nil { + return err + } + + actions := []*wal.Action{ + { + ActionType: wal.ActionTypeDelete, + Path: common.StorageProjectFile(project.ID), + }, + } + + _, err = s.wal.WriteWal(ctx, actions, cgt) + return err +} + +func (s *CommandHandler) CreateUser(ctx context.Context, user *types.User) (*types.User, error) { + if user.UserName == "" { + return nil, errors.Errorf("user name required") + } + + var cgt *wal.ChangeGroupsUpdateToken + cgNames := []string{user.UserName} + + // must do all the check in a single transaction to avoid concurrent changes + // since the use token is related to the transaction time + err := s.readDB.Do(func(tx *db.Tx) error { + var err error + cgt, err = s.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) + if err != nil { + return err + } + + // check duplicate user name + u, err := s.readDB.GetUserByName(tx, user.UserName) + if err != nil { + return err + } + if u != nil { + return errors.Errorf("user %q already exists", u.UserName) + } + return nil + }) + if err != nil { + return nil, err + } + + user.ID = uuid.NewV4().String() + + pcj, err := json.Marshal(user) + if err != nil { + return nil, errors.Wrapf(err, "failed to marshal user") + } + actions := []*wal.Action{ + { + ActionType: wal.ActionTypePut, + Path: common.StorageUserFile(user.ID), + Data: pcj, + }, + } + + _, err = s.wal.WriteWal(ctx, actions, cgt) + return user, err +} + +func (s *CommandHandler) DeleteUser(ctx context.Context, userName string) error { + var user *types.User + + var cgt *wal.ChangeGroupsUpdateToken + cgNames := []string{user.UserName} + + // must do all the check in a single transaction to avoid concurrent changes + // since the use token is related to the transaction time + err := s.readDB.Do(func(tx *db.Tx) error { + var err error + cgt, err = s.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) + if err != nil { + return err + } + + // check user existance + user, err = s.readDB.GetUserByName(tx, userName) + if err != nil { + return err + } + if user == nil { + return errors.Errorf("user %q doesn't exist", userName) + } + return nil + }) + if err != nil { + return err + } + + actions := []*wal.Action{ + { + ActionType: wal.ActionTypeDelete, + Path: common.StorageUserFile(user.ID), + }, + } + + // changegroup is the username (and in future the email) to ensure no + // concurrent user creation/modification using the same name + _, err = s.wal.WriteWal(ctx, actions, cgt) + return err +} + +type CreateUserLARequest struct { + UserName string + RemoteSourceName string + RemoteUserID string + RemoteUserName string + UserAccessToken string + Oauth2AccessToken string + Oauth2RefreshToken string +} + +func (s *CommandHandler) CreateUserLA(ctx context.Context, req *CreateUserLARequest) (*types.LinkedAccount, error) { + if req.UserName == "" { + return nil, errors.Errorf("user name required") + } + if req.RemoteSourceName == "" { + return nil, errors.Errorf("remote source name required") + } + + var user *types.User + var rs *types.RemoteSource + + var cgt *wal.ChangeGroupsUpdateToken + + // must do all the check in a single transaction to avoid concurrent changes + // since the use token is related to the transaction time + err := s.readDB.Do(func(tx *db.Tx) error { + var err error + user, err = s.readDB.GetUserByName(tx, req.UserName) + if err != nil { + return err + } + if user == nil { + return errors.Errorf("user %q doesn't exist", req.UserName) + } + + cgNames := []string{user.ID} + cgt, err = s.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) + if err != nil { + return err + } + + rs, err = s.readDB.GetRemoteSourceByName(tx, req.RemoteSourceName) + if err != nil { + return err + } + if rs == nil { + return errors.Errorf("remote source %q doesn't exist", req.RemoteSourceName) + } + return nil + }) + if err != nil { + return nil, err + } + + if user.LinkedAccounts == nil { + user.LinkedAccounts = make(map[string]*types.LinkedAccount) + } + + la := &types.LinkedAccount{ + ID: uuid.NewV4().String(), + RemoteSourceID: rs.ID, + RemoteUserID: req.RemoteUserID, + RemoteUserName: req.RemoteUserName, + UserAccessToken: req.UserAccessToken, + Oauth2AccessToken: req.Oauth2AccessToken, + Oauth2RefreshToken: req.Oauth2RefreshToken, + } + + user.LinkedAccounts[la.ID] = la + + pcj, err := json.Marshal(user) + if err != nil { + return nil, errors.Wrapf(err, "failed to marshal user") + } + actions := []*wal.Action{ + { + ActionType: wal.ActionTypePut, + Path: common.StorageUserFile(user.ID), + Data: pcj, + }, + } + + _, err = s.wal.WriteWal(ctx, actions, cgt) + return la, err +} + +func (s *CommandHandler) DeleteUserLA(ctx context.Context, userName, laID string) error { + if userName == "" { + return errors.Errorf("user name required") + } + if laID == "" { + return errors.Errorf("user linked account id required") + } + + var user *types.User + + var cgt *wal.ChangeGroupsUpdateToken + + // must do all the check in a single transaction to avoid concurrent changes + // since the use token is related to the transaction time + err := s.readDB.Do(func(tx *db.Tx) error { + var err error + user, err = s.readDB.GetUserByName(tx, userName) + if err != nil { + return err + } + if user == nil { + return errors.Errorf("user %q doesn't exist", userName) + } + + cgNames := []string{user.ID} + cgt, err = s.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) + if err != nil { + return err + } + + return nil + }) + if err != nil { + return err + } + + _, ok := user.LinkedAccounts[laID] + if !ok { + return errors.Errorf("linked account id %q for user %q doesn't exist", laID, userName) + } + + delete(user.LinkedAccounts, laID) + + pcj, err := json.Marshal(user) + if err != nil { + return errors.Wrapf(err, "failed to marshal user") + } + actions := []*wal.Action{ + { + ActionType: wal.ActionTypePut, + Path: common.StorageUserFile(user.ID), + Data: pcj, + }, + } + + _, err = s.wal.WriteWal(ctx, actions, cgt) + return err +} + +type UpdateUserLARequest struct { + UserName string + LinkedAccountID string + RemoteUserID string + RemoteUserName string + UserAccessToken string + Oauth2AccessToken string + Oauth2RefreshToken string +} + +func (s *CommandHandler) UpdateUserLA(ctx context.Context, req *UpdateUserLARequest) (*types.LinkedAccount, error) { + if req.UserName == "" { + return nil, errors.Errorf("user name required") + } + + var user *types.User + var rs *types.RemoteSource + + var cgt *wal.ChangeGroupsUpdateToken + + // must do all the check in a single transaction to avoid concurrent changes + // since the use token is related to the transaction time + err := s.readDB.Do(func(tx *db.Tx) error { + var err error + user, err = s.readDB.GetUserByName(tx, req.UserName) + if err != nil { + return err + } + if user == nil { + return errors.Errorf("user %q doesn't exist", req.UserName) + } + + cgNames := []string{user.ID} + cgt, err = s.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) + if err != nil { + return err + } + + la, ok := user.LinkedAccounts[req.LinkedAccountID] + if !ok { + return errors.Errorf("linked account id %q for user %q doesn't exist", req.LinkedAccountID, user.UserName) + } + + rs, err = s.readDB.GetRemoteSource(tx, la.RemoteSourceID) + if err != nil { + return err + } + if rs == nil { + return errors.Errorf("remote source with id %q doesn't exist", la.RemoteSourceID) + } + return nil + }) + if err != nil { + return nil, err + } + + la := user.LinkedAccounts[req.LinkedAccountID] + + la.RemoteUserID = req.RemoteUserID + la.RemoteUserName = req.RemoteUserName + la.UserAccessToken = req.UserAccessToken + la.Oauth2AccessToken = req.Oauth2AccessToken + la.Oauth2RefreshToken = req.Oauth2RefreshToken + + userj, err := json.Marshal(user) + if err != nil { + return nil, errors.Wrapf(err, "failed to marshal user") + } + actions := []*wal.Action{ + { + ActionType: wal.ActionTypePut, + Path: common.StorageUserFile(user.ID), + Data: userj, + }, + } + + _, err = s.wal.WriteWal(ctx, actions, cgt) + return la, err +} + +func (s *CommandHandler) CreateUserToken(ctx context.Context, userName, tokenName string) (string, error) { + if userName == "" { + return "", errors.Errorf("user name required") + } + + var user *types.User + + var cgt *wal.ChangeGroupsUpdateToken + + // must do all the check in a single transaction to avoid concurrent changes + err := s.readDB.Do(func(tx *db.Tx) error { + var err error + user, err = s.readDB.GetUserByName(tx, userName) + if err != nil { + return err + } + if user == nil { + return errors.Errorf("user %q doesn't exist", userName) + } + + cgNames := []string{user.ID} + cgt, err = s.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) + if err != nil { + return err + } + + return nil + }) + if err != nil { + return "", err + } + + if user.Tokens == nil { + user.Tokens = make(map[string]string) + } + + token := util.EncodeSha1Hex(uuid.NewV4().String()) + user.Tokens[tokenName] = token + + userj, err := json.Marshal(user) + if err != nil { + return "", errors.Wrapf(err, "failed to marshal user") + } + actions := []*wal.Action{ + { + ActionType: wal.ActionTypePut, + Path: common.StorageUserFile(user.ID), + Data: userj, + }, + } + + // changegroup is the userid + _, err = s.wal.WriteWal(ctx, actions, cgt) + return token, err +} + +func (s *CommandHandler) CreateRemoteSource(ctx context.Context, remoteSource *types.RemoteSource) (*types.RemoteSource, error) { + if remoteSource.Name == "" { + return nil, errors.Errorf("remotesource name required") + } + + var cgt *wal.ChangeGroupsUpdateToken + cgNames := []string{remoteSource.Name} + + // must do all the check in a single transaction to avoid concurrent changes + // since the use token is related to the transaction time + err := s.readDB.Do(func(tx *db.Tx) error { + var err error + cgt, err = s.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) + if err != nil { + return err + } + + // check duplicate remoteSource name + u, err := s.readDB.GetRemoteSourceByName(tx, remoteSource.Name) + if err != nil { + return err + } + if u != nil { + return errors.Errorf("remoteSource %q already exists", u.Name) + } + return nil + }) + if err != nil { + return nil, err + } + + remoteSource.ID = uuid.NewV4().String() + + rsj, err := json.Marshal(remoteSource) + if err != nil { + return nil, errors.Wrapf(err, "failed to marshal remotesource") + } + actions := []*wal.Action{ + { + ActionType: wal.ActionTypePut, + Path: common.StorageRemoteSourceFile(remoteSource.ID), + Data: rsj, + }, + } + + _, err = s.wal.WriteWal(ctx, actions, cgt) + return remoteSource, err +} + +func (s *CommandHandler) DeleteRemoteSource(ctx context.Context, remoteSourceName string) error { + var remoteSource *types.RemoteSource + + var cgt *wal.ChangeGroupsUpdateToken + cgNames := []string{remoteSource.ID} + + // must do all the check in a single transaction to avoid concurrent changes + err := s.readDB.Do(func(tx *db.Tx) error { + var err error + cgt, err = s.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) + if err != nil { + return err + } + + // check remoteSource existance + remoteSource, err = s.readDB.GetRemoteSourceByName(tx, remoteSourceName) + if err != nil { + return err + } + if remoteSource == nil { + return errors.Errorf("remotesource %q doesn't exist", remoteSourceName) + } + return nil + }) + if err != nil { + return err + } + + actions := []*wal.Action{ + { + ActionType: wal.ActionTypeDelete, + Path: common.StorageRemoteSourceFile(remoteSource.ID), + }, + } + + // changegroup is all the remote sources + _, err = s.wal.WriteWal(ctx, actions, cgt) + return err +} diff --git a/internal/services/configstore/common/common.go b/internal/services/configstore/common/common.go new file mode 100644 index 0000000..b39fa38 --- /dev/null +++ b/internal/services/configstore/common/common.go @@ -0,0 +1,68 @@ +// Copyright 2019 Sorint.lab +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import ( + "fmt" + "path" +) + +var ( + // Storage paths. Always use path (not filepath) to use the "/" separator + StorageDataDir = "data" + StorageProjectsDir = path.Join(StorageDataDir, "projects") + StorageUsersDir = path.Join(StorageDataDir, "users") + StorageRemoteSourcesDir = path.Join(StorageDataDir, "remotesources") +) + +const ( + etcdWalsMinRevisionRange = 100 +) + +func StorageProjectFile(projectID string) string { + return path.Join(StorageProjectsDir, projectID) +} + +func StorageUserFile(userID string) string { + return path.Join(StorageUsersDir, userID) +} + +func StorageRemoteSourceFile(userID string) string { + return path.Join(StorageRemoteSourcesDir, userID) +} + +type ConfigType string + +const ( + ConfigTypeProject ConfigType = "project" + ConfigTypeUser ConfigType = "user" + ConfigTypeRemoteSource ConfigType = "remotesource" +) + +func PathToTypeID(p string) (ConfigType, string) { + var configType ConfigType + switch path.Dir(p) { + case StorageProjectsDir: + configType = ConfigTypeProject + case StorageUsersDir: + configType = ConfigTypeUser + case StorageRemoteSourcesDir: + configType = ConfigTypeRemoteSource + default: + panic(fmt.Errorf("cannot determine configtype for path: %q", p)) + } + + return configType, path.Base(p) +} diff --git a/internal/services/configstore/configstore.go b/internal/services/configstore/configstore.go new file mode 100644 index 0000000..6fff4fa --- /dev/null +++ b/internal/services/configstore/configstore.go @@ -0,0 +1,198 @@ +// Copyright 2019 Sorint.lab +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied +// See the License for the specific language governing permissions and +// limitations under the License. + +package configstore + +import ( + "context" + "crypto/tls" + "net/http" + "path/filepath" + + scommon "github.com/sorintlab/agola/internal/common" + "github.com/sorintlab/agola/internal/etcd" + slog "github.com/sorintlab/agola/internal/log" + "github.com/sorintlab/agola/internal/objectstorage" + "github.com/sorintlab/agola/internal/services/config" + "github.com/sorintlab/agola/internal/services/configstore/api" + "github.com/sorintlab/agola/internal/services/configstore/command" + "github.com/sorintlab/agola/internal/services/configstore/readdb" + "github.com/sorintlab/agola/internal/util" + "github.com/sorintlab/agola/internal/wal" + + ghandlers "github.com/gorilla/handlers" + "github.com/gorilla/mux" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +var level = zap.NewAtomicLevelAt(zapcore.InfoLevel) +var logger = slog.New(level) +var log = logger.Sugar() + +type ConfigStore struct { + c *config.ConfigStore + e *etcd.Store + wal *wal.WalManager + readDB *readdb.ReadDB + lts *objectstorage.ObjStorage + ch *command.CommandHandler + listenAddress string +} + +func NewConfigStore(ctx context.Context, c *config.ConfigStore) (*ConfigStore, error) { + if c.Debug { + level.SetLevel(zapcore.DebugLevel) + } + + lts, err := scommon.NewLTS(&c.LTS) + if err != nil { + return nil, err + } + e, err := scommon.NewEtcd(&c.Etcd, logger, "configstore") + if err != nil { + return nil, err + } + + cs := &ConfigStore{ + c: c, + e: e, + lts: lts, + } + + walConf := &wal.WalManagerConfig{ + E: e, + Lts: lts, + } + wal, err := wal.NewWalManager(ctx, logger, walConf) + if err != nil { + return nil, err + } + readDB, err := readdb.NewReadDB(ctx, logger, filepath.Join(c.DataDir, "readdb"), e, lts, wal) + if err != nil { + return nil, err + } + + cs.wal = wal + cs.readDB = readDB + + ch := command.NewCommandHandler(logger, readDB, wal) + cs.ch = ch + + return cs, nil +} + +func (s *ConfigStore) Run(ctx context.Context) error { + errCh := make(chan error) + + go func() { errCh <- s.wal.Run(ctx) }() + go func() { errCh <- s.readDB.Run(ctx) }() + + // noop coors handler + corsHandler := func(h http.Handler) http.Handler { + return h + } + + corsAllowedMethodsOptions := ghandlers.AllowedMethods([]string{"GET", "HEAD", "POST", "PUT", "DELETE"}) + corsAllowedHeadersOptions := ghandlers.AllowedHeaders([]string{"Accept", "Accept-Encoding", "Authorization", "Content-Length", "Content-Type", "X-CSRF-Token", "Authorization"}) + corsAllowedOriginsOptions := ghandlers.AllowedOrigins([]string{"*"}) + corsHandler = ghandlers.CORS(corsAllowedMethodsOptions, corsAllowedHeadersOptions, corsAllowedOriginsOptions) + + getProjectHandler := api.NewGetProjectHandler(logger, s.readDB) + projectsHandler := api.NewProjectsHandler(logger, s.readDB) + getProjectByNameHandler := api.NewGetProjectByNameHandler(logger, s.readDB) + createProjectHandler := api.NewCreateProjectHandler(logger, s.ch) + deleteProjectHandler := api.NewDeleteProjectHandler(logger, s.ch) + + getUserHandler := api.NewGetUserHandler(logger, s.readDB) + usersHandler := api.NewUsersHandler(logger, s.readDB) + getUserByNameHandler := api.NewGetUserByNameHandler(logger, s.readDB) + createUserHandler := api.NewCreateUserHandler(logger, s.ch) + deleteUserHandler := api.NewDeleteUserHandler(logger, s.ch) + + createUserLAHandler := api.NewCreateUserLAHandler(logger, s.ch) + deleteUserLAHandler := api.NewDeleteUserLAHandler(logger, s.ch) + updateUserLAHandler := api.NewUpdateUserLAHandler(logger, s.ch) + + createUserTokenHandler := api.NewCreateUserTokenHandler(logger, s.ch) + + getRemoteSourceHandler := api.NewGetRemoteSourceHandler(logger, s.readDB) + remoteSourcesHandler := api.NewRemoteSourcesHandler(logger, s.readDB) + getRemoteSourceByNameHandler := api.NewGetRemoteSourceByNameHandler(logger, s.readDB) + createRemoteSourceHandler := api.NewCreateRemoteSourceHandler(logger, s.ch) + deleteRemoteSourceHandler := api.NewDeleteRemoteSourceHandler(logger, s.ch) + + router := mux.NewRouter() + apirouter := router.PathPrefix("/api/v1alpha").Subrouter() + + apirouter.Handle("/project/{projectid}", getProjectHandler).Methods("GET") + apirouter.Handle("/projects", projectsHandler).Methods("GET") + apirouter.Handle("/projects", createProjectHandler).Methods("PUT") + apirouter.Handle("/projects/{projectname}", getProjectByNameHandler).Methods("GET") + apirouter.Handle("/projects/{projectname}", deleteProjectHandler).Methods("DELETE") + + apirouter.Handle("/user/{userid}", getUserHandler).Methods("GET") + apirouter.Handle("/users", usersHandler).Methods("GET") + apirouter.Handle("/users", createUserHandler).Methods("PUT") + apirouter.Handle("/users/{username}", getUserByNameHandler).Methods("GET") + apirouter.Handle("/users/{username}", deleteUserHandler).Methods("DELETE") + + apirouter.Handle("/users/{username}/linkedaccounts", createUserLAHandler).Methods("PUT") + apirouter.Handle("/users/{username}/linkedaccounts/{laid}", deleteUserLAHandler).Methods("DELETE") + apirouter.Handle("/users/{username}/linkedaccounts/{laid}", updateUserLAHandler).Methods("PUT") + apirouter.Handle("/users/{username}/tokens", createUserTokenHandler).Methods("PUT") + + apirouter.Handle("/remotesource/{id}", getRemoteSourceHandler).Methods("GET") + apirouter.Handle("/remotesources", remoteSourcesHandler).Methods("GET") + apirouter.Handle("/remotesources", createRemoteSourceHandler).Methods("PUT") + apirouter.Handle("/remotesources/{name}", getRemoteSourceByNameHandler).Methods("GET") + apirouter.Handle("/remotesources/{name}", deleteRemoteSourceHandler).Methods("DELETE") + + mainrouter := mux.NewRouter() + mainrouter.PathPrefix("/").Handler(corsHandler(router)) + + var tlsConfig *tls.Config + if s.c.Web.TLS { + var err error + tlsConfig, err = util.NewTLSConfig(s.c.Web.TLSCertFile, s.c.Web.TLSKeyFile, "", false) + if err != nil { + log.Errorf("err: %+v") + return err + } + } + + httpServer := http.Server{ + Addr: s.c.Web.ListenAddress, + Handler: mainrouter, + TLSConfig: tlsConfig, + } + + lerrCh := make(chan error) + go func() { + lerrCh <- httpServer.ListenAndServe() + }() + + select { + case <-ctx.Done(): + log.Infof("configstore exiting") + httpServer.Close() + return nil + case err := <-lerrCh: + log.Errorf("http server listen error: %+v", err) + return err + case err := <-errCh: + log.Errorf("error: %+v", err) + return err + } +} diff --git a/internal/services/configstore/configstore_test.go b/internal/services/configstore/configstore_test.go new file mode 100644 index 0000000..fc49ff5 --- /dev/null +++ b/internal/services/configstore/configstore_test.go @@ -0,0 +1,232 @@ +// Copyright 2019 Sorint.lab +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied +// See the License for the specific language governing permissions and +// limitations under the License. + +package configstore + +import ( + "context" + "fmt" + "io/ioutil" + "net" + "os" + "reflect" + "testing" + "time" + + "github.com/sorintlab/agola/internal/services/config" + "github.com/sorintlab/agola/internal/services/types" + "github.com/sorintlab/agola/internal/testutil" + "github.com/sorintlab/agola/internal/util" +) + +func setupEtcd(t *testing.T, dir string) *testutil.TestEmbeddedEtcd { + tetcd, err := testutil.NewTestEmbeddedEtcd(t, logger, dir) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if err := tetcd.Start(); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if err := tetcd.WaitUp(30 * time.Second); err != nil { + t.Fatalf("error waiting on etcd up: %v", err) + } + return tetcd +} + +func shutdownEtcd(tetcd *testutil.TestEmbeddedEtcd) { + if tetcd.Etcd != nil { + tetcd.Kill() + } +} + +func TestResync(t *testing.T) { + dir, err := ioutil.TempDir("", "agola") + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + defer os.RemoveAll(dir) + + etcdDir, err := ioutil.TempDir(dir, "etcd") + tetcd := setupEtcd(t, etcdDir) + defer shutdownEtcd(tetcd) + + listenAddress1, port1, err := testutil.GetFreePort(true, false) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + listenAddress2, port2, err := testutil.GetFreePort(true, false) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + listenAddress3, port3, err := testutil.GetFreePort(true, false) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + + ctx := context.Background() + + ltsDir, err := ioutil.TempDir(dir, "lts") + csDir1, err := ioutil.TempDir(dir, "cs1") + csDir2, err := ioutil.TempDir(dir, "cs2") + csDir3, err := ioutil.TempDir(dir, "cs3") + + baseConfig := config.ConfigStore{ + Etcd: config.Etcd{ + Endpoints: tetcd.Endpoint, + }, + LTS: config.LTS{ + Type: config.LTSTypePosix, + Path: ltsDir, + }, + Web: config.Web{}, + } + cs1Config := baseConfig + cs1Config.DataDir = csDir1 + cs1Config.Web.ListenAddress = net.JoinHostPort(listenAddress1, port1) + + cs2Config := baseConfig + cs2Config.DataDir = csDir2 + cs2Config.Web.ListenAddress = net.JoinHostPort(listenAddress2, port2) + + cs1, err := NewConfigStore(ctx, &cs1Config) + if err != nil { + t.Fatalf("err: %v", err) + } + cs2, err := NewConfigStore(ctx, &cs2Config) + if err != nil { + t.Fatalf("err: %v", err) + } + + ctx1 := context.Background() + ctx2, cancel2 := context.WithCancel(context.Background()) + + t.Logf("starting cs1") + go func() { + if err := cs1.Run(ctx1); err != nil { + t.Fatalf("err: %v", err) + } + }() + go func() { + if err := cs2.Run(ctx2); err != nil { + t.Fatalf("err: %v", err) + } + }() + + time.Sleep(1 * time.Second) + + for i := 0; i < 10; i++ { + if _, err := cs1.ch.CreateProject(ctx, &types.Project{Name: fmt.Sprintf("project%d", i)}); err != nil { + t.Fatalf("err: %v", err) + } + time.Sleep(200 * time.Millisecond) + } + + time.Sleep(5 * time.Second) + + // stop cs2 + log.Infof("stopping cs2") + cancel2() + + // Do some more changes + for i := 11; i < 20; i++ { + if _, err := cs1.ch.CreateProject(ctx, &types.Project{Name: fmt.Sprintf("project%d", i)}); err != nil { + t.Fatalf("err: %v", err) + } + time.Sleep(200 * time.Millisecond) + } + + time.Sleep(5 * time.Second) + + // compact etcd + if err := tetcd.Compact(); err != nil { + t.Fatalf("err: %v", err) + } + + // start cs2 + // it should resync from wals since the etcd revision as been compacted + cs2, err = NewConfigStore(ctx, &cs2Config) + if err != nil { + t.Fatalf("err: %v", err) + } + log.Infof("starting cs2") + ctx2 = context.Background() + go cs2.Run(ctx2) + + time.Sleep(5 * time.Second) + + projects1, err := cs1.readDB.GetProjects("", 0, true) + if err != nil { + t.Fatalf("err: %v", err) + } + + projects2, err := cs2.readDB.GetProjects("", 0, true) + if err != nil { + t.Fatalf("err: %v", err) + } + + if !compareProjects(projects1, projects2) { + t.Logf("len(projects1): %d", len(projects1)) + t.Logf("len(projects2): %d", len(projects2)) + t.Logf("projects1: %s", util.Dump(projects1)) + t.Logf("projects2: %s", util.Dump(projects2)) + t.Fatalf("projects are different between the two readdbs") + } + + // start cs3, since it's a new instance it should do a full resync + cs3Config := baseConfig + cs3Config.DataDir = csDir3 + cs3Config.Web.ListenAddress = net.JoinHostPort(listenAddress3, port3) + + log.Infof("starting cs3") + cs3, err := NewConfigStore(ctx, &cs3Config) + if err != nil { + t.Fatalf("err: %v", err) + } + ctx3 := context.Background() + go cs3.Run(ctx3) + + time.Sleep(5 * time.Second) + + projects1, err = cs1.readDB.GetProjects("", 0, true) + if err != nil { + t.Fatalf("err: %v", err) + } + + projects3, err := cs3.readDB.GetProjects("", 0, true) + if err != nil { + t.Fatalf("err: %v", err) + } + + if !compareProjects(projects1, projects3) { + t.Logf("len(projects1): %d", len(projects1)) + t.Logf("len(projects3): %d", len(projects3)) + t.Logf("projects1: %s", util.Dump(projects1)) + t.Logf("projects3: %s", util.Dump(projects3)) + t.Fatalf("projects are different between the two readdbs") + } +} + +func compareProjects(p1, p2 []*types.Project) bool { + p1ids := map[string]struct{}{} + p2ids := map[string]struct{}{} + + for _, p := range p1 { + p1ids[p.ID] = struct{}{} + } + for _, p := range p2 { + p2ids[p.ID] = struct{}{} + } + + return reflect.DeepEqual(p1ids, p2ids) +} diff --git a/internal/services/configstore/readdb/migration.go b/internal/services/configstore/readdb/migration.go new file mode 100644 index 0000000..4a478cb --- /dev/null +++ b/internal/services/configstore/readdb/migration.go @@ -0,0 +1,40 @@ +// Copyright 2019 Sorint.lab +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied +// See the License for the specific language governing permissions and +// limitations under the License. + +package readdb + +var Stmts = []string{ + // last processed etcd event revision + "create table revision (revision bigint, PRIMARY KEY(revision))", + + // committedwalsequence stores the last committed wal sequence + "create table committedwalsequence (seq varchar, PRIMARY KEY (seq))", + + // changegrouprevision stores the current revision of the changegroup for optimistic locking + "create table changegrouprevision (id varchar, revision varchar, PRIMARY KEY (id, revision))", + + "create table project (id uuid, name varchar, ownerid varchar, data bytea, PRIMARY KEY (id))", + "create index project_name on project(name)", + + "create table user (id uuid, name varchar, data bytea, PRIMARY KEY (id))", + "create table user_token (tokenvalue varchar, userid uuid, PRIMARY KEY (tokenvalue, userid))", + + "create table remotesource (id uuid, name varchar, data bytea, PRIMARY KEY (id))", + + "create table projectsource (id uuid, name varchar, data bytea, PRIMARY KEY (id))", + + "create table linkedaccount_user (id uuid, remotesourceid uuid, userid uuid, remoteuserid uuid, PRIMARY KEY (id), FOREIGN KEY(userid) REFERENCES user(id))", + + "create table linkedaccount_project (id uuid, projectid uuid, PRIMARY KEY (id), FOREIGN KEY(projectid) REFERENCES user(id))", +} diff --git a/internal/services/configstore/readdb/project.go b/internal/services/configstore/readdb/project.go new file mode 100644 index 0000000..57dab35 --- /dev/null +++ b/internal/services/configstore/readdb/project.go @@ -0,0 +1,184 @@ +// Copyright 2019 Sorint.lab +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied +// See the License for the specific language governing permissions and +// limitations under the License. + +package readdb + +import ( + "database/sql" + "encoding/json" + + "github.com/sorintlab/agola/internal/db" + "github.com/sorintlab/agola/internal/services/types" + "github.com/sorintlab/agola/internal/util" + + sq "github.com/Masterminds/squirrel" + "github.com/pkg/errors" +) + +var ( + projectSelect = sb.Select("id", "data").From("project") + projectInsert = sb.Insert("project").Columns("id", "name", "data") +) + +func (r *ReadDB) insertProject(tx *db.Tx, data []byte) error { + project := types.Project{} + if err := json.Unmarshal(data, &project); err != nil { + return errors.Wrap(err, "failed to unmarshal project") + } + // poor man insert or update... + if err := r.deleteProject(tx, project.ID); err != nil { + return err + } + q, args, err := projectInsert.Values(project.ID, project.Name, data).ToSql() + if err != nil { + return errors.Wrap(err, "failed to build query") + } + _, err = tx.Exec(q, args...) + return errors.Wrap(err, "failed to insert project") +} + +func (r *ReadDB) deleteProject(tx *db.Tx, id string) error { + // poor man insert or update... + if _, err := tx.Exec("delete from project where id = $1", id); err != nil { + return errors.Wrap(err, "failed to delete project") + } + return nil +} + +func (r *ReadDB) GetProject(tx *db.Tx, projectID string) (*types.Project, error) { + q, args, err := projectSelect.Where(sq.Eq{"id": projectID}).ToSql() + r.log.Debugf("q: %s, args: %s", q, util.Dump(args)) + if err != nil { + return nil, errors.Wrap(err, "failed to build query") + } + + projects, _, err := fetchProjects(tx, q, args...) + if err != nil { + return nil, errors.WithStack(err) + } + if len(projects) > 1 { + return nil, errors.Errorf("too many rows returned") + } + if len(projects) == 0 { + return nil, nil + } + return projects[0], nil +} + +func (r *ReadDB) GetProjectByName(tx *db.Tx, name string) (*types.Project, error) { + q, args, err := projectSelect.Where(sq.Eq{"name": name}).ToSql() + r.log.Debugf("q: %s, args: %s", q, util.Dump(args)) + if err != nil { + return nil, errors.Wrap(err, "failed to build query") + } + + projects, _, err := fetchProjects(tx, q, args...) + if err != nil { + return nil, errors.WithStack(err) + } + if len(projects) > 1 { + return nil, errors.Errorf("too many rows returned") + } + if len(projects) == 0 { + return nil, nil + } + return projects[0], nil +} + +func getProjectsFilteredQuery(startProjectName string, limit int, asc bool) sq.SelectBuilder { + fields := []string{"id", "data"} + + s := sb.Select(fields...).From("project as project") + if asc { + s = s.OrderBy("project.name asc") + } else { + s = s.OrderBy("project.name desc") + } + if startProjectName != "" { + if asc { + s = s.Where(sq.Gt{"project.name": startProjectName}) + } else { + s = s.Where(sq.Lt{"project.name": startProjectName}) + } + } + if limit > 0 { + s = s.Limit(uint64(limit)) + } + + return s +} +func (r *ReadDB) GetProjects(startProjectName string, limit int, asc bool) ([]*types.Project, error) { + var projects []*types.Project + + s := getProjectsFilteredQuery(startProjectName, limit, asc) + q, args, err := s.ToSql() + r.log.Debugf("q: %s, args: %s", q, util.Dump(args)) + if err != nil { + return nil, errors.Wrap(err, "failed to build query") + } + + err = r.rdb.Do(func(tx *db.Tx) error { + rows, err := tx.Query(q, args...) + if err != nil { + return err + } + + projects, _, err = scanProjects(rows) + return err + }) + return projects, errors.WithStack(err) +} + +func fetchProjects(tx *db.Tx, q string, args ...interface{}) ([]*types.Project, []string, error) { + rows, err := tx.Query(q, args...) + if err != nil { + return nil, nil, err + } + defer rows.Close() + return scanProjects(rows) +} + +func scanProject(rows *sql.Rows, additionalFields ...interface{}) (*types.Project, string, error) { + var id string + var data []byte + if err := rows.Scan(&id, &data); err != nil { + return nil, "", errors.Wrap(err, "failed to scan rows") + } + project := types.Project{} + if len(data) > 0 { + if err := json.Unmarshal(data, &project); err != nil { + return nil, "", errors.Wrap(err, "failed to unmarshal project") + } + } + + return &project, id, nil +} + +func scanProjects(rows *sql.Rows) ([]*types.Project, []string, error) { + projects := []*types.Project{} + ids := []string{} + for rows.Next() { + p, id, err := scanProject(rows) + if err != nil { + rows.Close() + return nil, nil, err + } + projects = append(projects, p) + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, nil, err + } + return projects, ids, nil +} diff --git a/internal/services/configstore/readdb/readdb.go b/internal/services/configstore/readdb/readdb.go new file mode 100644 index 0000000..2609a07 --- /dev/null +++ b/internal/services/configstore/readdb/readdb.go @@ -0,0 +1,812 @@ +// Copyright 2019 Sorint.lab +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied +// See the License for the specific language governing permissions and +// limitations under the License. + +package readdb + +import ( + "context" + "database/sql" + "encoding/json" + "io" + "io/ioutil" + "os" + "path/filepath" + "sync" + "time" + + "github.com/sorintlab/agola/internal/db" + "github.com/sorintlab/agola/internal/etcd" + "github.com/sorintlab/agola/internal/objectstorage" + "github.com/sorintlab/agola/internal/sequence" + "github.com/sorintlab/agola/internal/services/configstore/common" + "github.com/sorintlab/agola/internal/util" + "github.com/sorintlab/agola/internal/wal" + + sq "github.com/Masterminds/squirrel" + "github.com/pkg/errors" + "go.uber.org/zap" +) + +var ( + // Use postgresql $ placeholder. It'll be converted to ? from the provided db functions + sb = sq.StatementBuilder.PlaceholderFormat(sq.Dollar) + + revisionSelect = sb.Select("revision").From("revision") + revisionInsert = sb.Insert("revision").Columns("revision") + + committedwalsequenceSelect = sb.Select("seq").From("committedwalsequence") + committedwalsequenceInsert = sb.Insert("committedwalsequence").Columns("seq") + + changegrouprevisionSelect = sb.Select("id, revision").From("changegrouprevision") + changegrouprevisionInsert = sb.Insert("changegrouprevision").Columns("id", "revision") +) + +type ReadDB struct { + log *zap.SugaredLogger + dataDir string + e *etcd.Store + rdb *db.DB + lts *objectstorage.ObjStorage + wal *wal.WalManager + + Initialized bool + initMutex sync.Mutex +} + +func NewReadDB(ctx context.Context, logger *zap.Logger, dataDir string, e *etcd.Store, lts *objectstorage.ObjStorage, wal *wal.WalManager) (*ReadDB, error) { + if err := os.MkdirAll(dataDir, 0770); err != nil { + return nil, err + } + rdb, err := db.NewDB(db.Sqlite3, filepath.Join(dataDir, "db")) + if err != nil { + return nil, err + } + + // populate readdb + if err := rdb.Create(Stmts); err != nil { + return nil, err + } + + readDB := &ReadDB{ + log: logger.Sugar(), + dataDir: dataDir, + rdb: rdb, + e: e, + lts: lts, + wal: wal, + } + + return readDB, nil +} + +// Initialize populates the readdb with the current etcd data and save the +// revision to then feed it with the etcd events +func (r *ReadDB) Initialize(ctx context.Context) error { + // sync the rdb + if err := r.SyncRDB(ctx); err != nil { + return errors.Wrapf(err, "error syncing db") + } + return nil +} + +func (r *ReadDB) ResetDB() error { + // TODO(sgotti) this needs to be protected by a mutex + r.rdb.Close() + + // drop rdb + if err := os.Remove(filepath.Join(r.dataDir, "db")); err != nil { + return err + } + + rdb, err := db.NewDB(db.Sqlite3, filepath.Join(r.dataDir, "db")) + if err != nil { + return err + } + + // populate readdb + if err := rdb.Create(Stmts); err != nil { + return err + } + + r.rdb = rdb + + return nil +} + +func (r *ReadDB) SyncFromFiles() (string, error) { + doneCh := make(chan struct{}) + defer close(doneCh) + + var lastCheckpointedWal string + // Get last checkpointed wal from lts + for wal := range r.wal.ListLtsWals("") { + if wal.Err != nil { + return "", wal.Err + } + if wal.Checkpointed { + lastCheckpointedWal = wal.WalSequence + } + } + + doneCh = make(chan struct{}) + haveConfigFiles := false + for object := range r.wal.List(common.StorageDataDir, "", true, doneCh) { + if object.Err != nil { + close(doneCh) + return "", object.Err + } + + haveConfigFiles = true + break + } + close(doneCh) + + if lastCheckpointedWal == "" && haveConfigFiles { + return "", errors.Errorf("no last checkpointed wal in lts but the storage has config files. This should never happen!") + } + + if !haveConfigFiles { + return lastCheckpointedWal, nil + } + + insertfunc := func(objs []string) error { + err := r.rdb.Do(func(tx *db.Tx) error { + for _, obj := range objs { + f, _, err := r.wal.ReadObject(obj, nil) + if err != nil { + if err == objectstorage.ErrNotExist { + r.log.Warnf("object %s disappeared, ignoring", obj) + } + return err + } + data, err := ioutil.ReadAll(f) + if err != nil { + f.Close() + return err + } + f.Close() + + action := &wal.Action{ + ActionType: wal.ActionTypePut, + Path: obj, + Data: data, + } + if err := r.applyAction(tx, action); err != nil { + return err + } + } + return nil + }) + return err + } + + objs := []string{} + count := 0 + doneCh = make(chan struct{}) + defer close(doneCh) + + // file may have changed in the meantime (due to checkpointing) but we don't + // need to have a consistent snapshot since we'll apply all the wals and handle + // them + for object := range r.wal.List(common.StorageDataDir, "", true, doneCh) { + if object.Err != nil { + return "", object.Err + } + + objs = append(objs, object.Path) + + if count > 100 { + if err := insertfunc(objs); err != nil { + return "", err + } + count = 0 + objs = []string{} + } else { + count++ + } + } + if err := insertfunc(objs); err != nil { + return "", err + } + + // save the wal sequence of the last checkpointed wal before syncing from files + err := r.rdb.Do(func(tx *db.Tx) error { + return r.insertCommittedWalSequence(tx, lastCheckpointedWal) + }) + if err != nil { + return "", err + } + + return lastCheckpointedWal, nil +} + +func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) { + insertfunc := func(walFiles []*wal.WalFile) error { + err := r.rdb.Do(func(tx *db.Tx) error { + for _, walFile := range walFiles { + walFilef, err := r.wal.ReadWal(walFile.WalSequence) + if err != nil { + return err + } + dec := json.NewDecoder(walFilef) + var header *wal.WalHeader + if err = dec.Decode(&header); err != nil && err != io.EOF { + walFilef.Close() + return err + } + walFilef.Close() + if err := r.insertCommittedWalSequence(tx, walFile.WalSequence); err != nil { + return err + } + if err := r.applyWal(tx, header.WalDataFileID); err != nil { + return err + } + } + return nil + }) + return err + } + + lastWalSeq := startWalSeq + walFiles := []*wal.WalFile{} + count := 0 + + doneCh := make(chan struct{}) + defer close(doneCh) + + for walFile := range r.wal.ListLtsWals(startWalSeq) { + if walFile.Err != nil { + return "", walFile.Err + } + + walFiles = append(walFiles, walFile) + lastWalSeq = walFile.WalSequence + + if count > 100 { + if err := insertfunc(walFiles); err != nil { + return "", err + } + count = 0 + walFiles = []*wal.WalFile{} + } else { + count++ + } + } + if err := insertfunc(walFiles); err != nil { + return "", err + } + + return lastWalSeq, nil +} + +func (r *ReadDB) SyncRDB(ctx context.Context) error { + // get the last committed storage wal sequence saved in the rdb + curWalSeq := "" + err := r.rdb.Do(func(tx *db.Tx) error { + var err error + curWalSeq, err = r.GetCommittedWalSequence(tx) + if err != nil { + return err + } + return nil + }) + if err != nil { + return err + } + + lastCommittedStorageWal, _, err := r.wal.LastCommittedStorageWal(ctx) + if err != nil { + return err + } + + doFullSync := false + if curWalSeq == "" { + doFullSync = true + r.log.Warn("no startWalSeq in db, doing a full sync") + } else { + ok, err := r.wal.HasLtsWal(curWalSeq) + if err != nil { + return err + } + if !ok { + r.log.Warnf("no wal with seq %q in lts, doing a full sync", curWalSeq) + doFullSync = true + } + + // if the epoch of the wals has changed this means etcd has been reset. If so we should do a full resync since we are saving in the rdb also data that was not yet committed to lts so we should have the rdb ahead of the current lts data + // TODO(sgotti) improve this to avoid doing a full resync + curWalSequence, err := sequence.Parse(curWalSeq) + if err != nil { + return err + } + curWalEpoch := curWalSequence.Epoch + + lastCommittedStorageWalSequence, err := sequence.Parse(lastCommittedStorageWal) + if err != nil { + return err + } + if curWalEpoch != lastCommittedStorageWalSequence.Epoch { + r.log.Warnf("current rdb wal sequence epoch %d different than new wal sequence epoch %d, doing a full sync", curWalEpoch, lastCommittedStorageWalSequence.Epoch) + doFullSync = true + } + } + + if doFullSync { + r.log.Infof("doing a full sync from lts files") + if err := r.ResetDB(); err != nil { + return err + } + + var err error + curWalSeq, err = r.SyncFromFiles() + if err != nil { + return err + } + } + + r.log.Infof("startWalSeq: %s", curWalSeq) + + // Sync from wals + // sync from lts until the current known lastCommittedStorageWal in etcd + // since wals are first committed to lts and then in etcd we would like to + // avoid to store in rdb something that is not yet marked as committedstorage + // in etcd + curWalSeq, err = r.SyncFromWals(curWalSeq, lastCommittedStorageWal) + if err != nil { + return errors.Wrap(err, "failed to sync from wals") + } + + // Get the first available wal from etcd and check that our current walseq + // from wals on lts is >= + // if not (this happens when syncFromWals takes some time and in the meantime + // many new wals are written, the next sync should be faster and able to continue + firstAvailableWalData, revision, err := r.wal.FirstAvailableWalData(ctx) + if err != nil { + return errors.Wrap(err, "failed to get first available wal data") + } + r.log.Infof("firstAvailableWalData: %s", util.Dump(firstAvailableWalData)) + r.log.Infof("revision: %d", revision) + if firstAvailableWalData == nil { + if curWalSeq != "" { + // this happens if etcd has been reset + return errors.Errorf("our curwalseq is %q but there's no wal data on etcd", curWalSeq) + } + } + if firstAvailableWalData != nil { + if curWalSeq < firstAvailableWalData.WalSequence { + return errors.Errorf("current applied wal seq %q is smaller than the first available wal on etcd %q", curWalSeq, firstAvailableWalData.WalSequence) + } + } + + err = r.rdb.Do(func(tx *db.Tx) error { + if err := r.insertRevision(tx, revision); err != nil { + return err + } + + // use the same revision as previous operation + for walElement := range r.wal.ListEtcdWals(ctx, revision) { + if walElement.Err != nil { + return err + } + if walElement.WalData.WalSequence <= curWalSeq { + continue + } + //if walElement.WalData.WalStatus == wal.WalStatusCommittedStorage { + + if err := r.insertCommittedWalSequence(tx, walElement.WalData.WalSequence); err != nil { + return err + } + //} + + //// update readdb only when the wal has been committed to lts + //if walElement.WalData.WalStatus != wal.WalStatusCommittedStorage { + // return nil + //} + + r.log.Debugf("applying wal to db") + if err := r.applyWal(tx, walElement.WalData.WalDataFileID); err != nil { + return err + } + } + + return nil + }) + + return err +} + +func (r *ReadDB) Run(ctx context.Context) error { + revision, err := r.GetRevision() + if err != nil { + return err + } + + if revision == 0 { + for { + err := r.Initialize(ctx) + if err == nil { + r.Initialized = true + break + } + r.log.Errorf("initialize err: %+v", err) + time.Sleep(1 * time.Second) + } + } + + for { + if !r.Initialized { + r.Initialize(ctx) + } + if err := r.HandleEvents(ctx); err != nil { + r.log.Errorf("handleevents err: %+v", err) + } + + select { + case <-ctx.Done(): + r.log.Infof("readdb exiting") + r.rdb.Close() + return nil + default: + } + + time.Sleep(1 * time.Second) + } +} + +// TODO(sgotti) improve to apply when the wal have been "committedstorage" and +// not only "committed", in this way we don't have to full resync when etcd is +// lost/reset +func (r *ReadDB) HandleEvents(ctx context.Context) error { + var revision int64 + err := r.rdb.Do(func(tx *db.Tx) error { + err := tx.QueryRow("select revision from revision order by revision desc limit 1").Scan(&revision) + if err != nil { + if err == sql.ErrNoRows { + revision = 0 + } else { + return err + } + } + return nil + }) + if err != nil { + return err + } + + wctx, cancel := context.WithCancel(ctx) + defer cancel() + r.log.Infof("revision: %d", revision) + wch := r.wal.Watch(wctx, revision+1) + for we := range wch { + r.log.Debugf("we: %s", util.Dump(we)) + if we.Err != nil { + err := we.Err + if err == wal.ErrCompacted { + r.log.Warnf("required events already compacted, reinitializing readdb") + r.Initialized = false + return nil + } + return errors.Wrapf(err, "watch error") + } + + // a single transaction for every response (every response contains all the + // events happened in an etcd revision). + err = r.rdb.Do(func(tx *db.Tx) error { + + // if theres a wal seq epoch change something happened to etcd, usually (if + // the user hasn't messed up with etcd keys) this means etcd has been reset + // in such case we should resync from the lts state to ensure we apply all the + // wal marked as committedstorage (since they could have been lost from etcd) + curWalSeq, err := r.GetCommittedWalSequence(tx) + if err != nil { + return err + } + r.log.Debugf("curWalSeq: %q", curWalSeq) + if curWalSeq != "" && we.WalData != nil { + curWalSequence, err := sequence.Parse(curWalSeq) + if err != nil { + return err + } + curWalEpoch := curWalSequence.Epoch + + weWalSequence, err := sequence.Parse(we.WalData.WalSequence) + if err != nil { + return err + } + r.log.Infof("we.WalData.WalSequence: %q", we.WalData.WalSequence) + weWalEpoch := weWalSequence.Epoch + if curWalEpoch != weWalEpoch { + r.Initialized = false + return errors.Errorf("current rdb wal sequence epoch %d different than new wal sequence epoch %d, resyncing from lts", curWalEpoch, weWalEpoch) + } + } + + if err := r.handleEvent(tx, we); err != nil { + return err + } + + if err := r.insertRevision(tx, we.Revision); err != nil { + return err + } + return nil + }) + if err != nil { + return err + } + } + r.log.Infof("wch closed") + + return nil +} + +func (r *ReadDB) handleEvent(tx *db.Tx, we *wal.WatchElement) error { + //r.log.Debugf("event: %s %q : %q\n", ev.Type, ev.Kv.Key, ev.Kv.Value) + //key := string(ev.Kv.Key) + + if err := r.handleWalEvent(tx, we); err != nil { + return err + } + return nil +} + +func (r *ReadDB) handleWalEvent(tx *db.Tx, we *wal.WatchElement) error { + // update readdb only when the wal has been committed to lts + //if we.WalData.WalStatus != wal.WalStatusCommittedStorage { + // return nil + //} + + if we.WalData != nil { + if err := r.insertCommittedWalSequence(tx, we.WalData.WalSequence); err != nil { + return err + } + } + + for cgName, cgRev := range we.ChangeGroupsRevisions { + if err := r.insertChangeGroupRevision(tx, cgName, cgRev); err != nil { + return err + } + } + + if we.WalData != nil { + r.log.Debugf("applying wal to db") + return r.applyWal(tx, we.WalData.WalDataFileID) + } + return nil +} + +func (r *ReadDB) applyWal(tx *db.Tx, walDataFileID string) error { + walFile, err := r.wal.ReadWalData(walDataFileID) + if err != nil { + return errors.Wrapf(err, "cannot read wal data file %q", walDataFileID) + } + defer walFile.Close() + + dec := json.NewDecoder(walFile) + for { + var action *wal.Action + + err := dec.Decode(&action) + if err == io.EOF { + // all done + break + } + if err != nil { + return errors.Wrapf(err, "failed to decode wal file") + } + + if err := r.applyAction(tx, action); err != nil { + return err + } + } + + return nil +} + +func (r *ReadDB) applyAction(tx *db.Tx, action *wal.Action) error { + configType, ID := common.PathToTypeID(action.Path) + + switch action.ActionType { + case wal.ActionTypePut: + switch configType { + case common.ConfigTypeProject: + if err := r.insertProject(tx, action.Data); err != nil { + return err + } + case common.ConfigTypeUser: + if err := r.insertUser(tx, action.Data); err != nil { + return err + } + case common.ConfigTypeRemoteSource: + if err := r.insertRemoteSource(tx, action.Data); err != nil { + return err + } + } + + case wal.ActionTypeDelete: + switch configType { + case common.ConfigTypeProject: + r.log.Debugf("deleting project with id: %s", ID) + if err := r.deleteProject(tx, ID); err != nil { + return err + } + case common.ConfigTypeUser: + r.log.Debugf("deleting user with id: %s", ID) + if err := r.deleteUser(tx, ID); err != nil { + return err + } + case common.ConfigTypeRemoteSource: + r.log.Debugf("deleting remote source with id: %s", ID) + if err := r.deleteRemoteSource(tx, ID); err != nil { + return err + } + } + } + + return nil +} + +func (r *ReadDB) Do(f func(tx *db.Tx) error) error { + return r.rdb.Do(f) +} + +func (r *ReadDB) insertRevision(tx *db.Tx, revision int64) error { + //r.log.Infof("insert revision: %d", revision) + // poor man insert or update that works because transaction isolation level is serializable + if _, err := tx.Exec("delete from revision"); err != nil { + return errors.Wrap(err, "failed to delete revision") + } + // TODO(sgotti) go database/sql and mattn/sqlite3 don't support uint64 types... + //q, args, err = revisionInsert.Values(int64(wresp.Header.ClusterId), run.Revision).ToSql() + q, args, err := revisionInsert.Values(revision).ToSql() + if err != nil { + return errors.Wrap(err, "failed to build query") + } + if _, err = tx.Exec(q, args...); err != nil { + return errors.WithStack(err) + } + return nil +} + +func (r *ReadDB) GetRevision() (int64, error) { + var revision int64 + + err := r.rdb.Do(func(tx *db.Tx) error { + var err error + revision, err = r.getRevision(tx) + return err + }) + return revision, err +} + +func (r *ReadDB) getRevision(tx *db.Tx) (int64, error) { + var revision int64 + + q, args, err := revisionSelect.ToSql() + r.log.Debugf("q: %s, args: %s", q, util.Dump(args)) + if err != nil { + return 0, errors.Wrap(err, "failed to build query") + } + + err = tx.QueryRow(q, args...).Scan(&revision) + if err == sql.ErrNoRows { + return 0, nil + } + return revision, err +} + +func (r *ReadDB) insertCommittedWalSequence(tx *db.Tx, seq string) error { + r.log.Infof("insert seq: %s", seq) + // poor man insert or update that works because transaction isolation level is serializable + if _, err := tx.Exec("delete from committedwalsequence"); err != nil { + return errors.Wrap(err, "failed to delete committedwalsequence") + } + q, args, err := committedwalsequenceInsert.Values(seq).ToSql() + if err != nil { + return errors.Wrap(err, "failed to build query") + } + if _, err = tx.Exec(q, args...); err != nil { + return errors.WithStack(err) + } + return nil +} + +func (r *ReadDB) GetCommittedWalSequence(tx *db.Tx) (string, error) { + var seq string + + q, args, err := committedwalsequenceSelect.OrderBy("seq").Limit(1).ToSql() + r.log.Debugf("q: %s, args: %s", q, util.Dump(args)) + if err != nil { + return "", errors.Wrap(err, "failed to build query") + } + + err = tx.QueryRow(q, args...).Scan(&seq) + if err == sql.ErrNoRows { + return "", nil + } + return seq, err +} + +func (r *ReadDB) insertChangeGroupRevision(tx *db.Tx, changegroup string, revision int64) error { + r.log.Infof("insertChangeGroupRevision: %s %d", changegroup, revision) + + // poor man insert or update that works because transaction isolation level is serializable + if _, err := tx.Exec("delete from changegrouprevision where id = $1", changegroup); err != nil { + return errors.Wrap(err, "failed to delete run") + } + // insert only if revision > 0 + if revision > 0 { + q, args, err := changegrouprevisionInsert.Values(changegroup, revision).ToSql() + if err != nil { + return errors.Wrap(err, "failed to build query") + } + if _, err = tx.Exec(q, args...); err != nil { + return err + } + } + return nil +} + +func (r *ReadDB) GetChangeGroupsUpdateTokens(tx *db.Tx, groups []string) (*wal.ChangeGroupsUpdateToken, error) { + s := changegrouprevisionSelect.Where(sq.Eq{"id": groups}) + q, args, err := s.ToSql() + r.log.Debugf("q: %s, args: %s", q, util.Dump(args)) + if err != nil { + return nil, errors.Wrap(err, "failed to build query") + } + cgr, err := fetchChangeGroupsRevision(tx, q, args...) + if err != nil { + return nil, err + } + + revision, err := r.getRevision(tx) + if err != nil { + return nil, err + } + + // for non existing changegroups use a changegroup with revision = 0 + for _, g := range groups { + if _, ok := cgr[g]; !ok { + cgr[g] = 0 + } + } + + return &wal.ChangeGroupsUpdateToken{CurRevision: revision, ChangeGroupsRevisions: cgr}, nil +} + +func fetchChangeGroupsRevision(tx *db.Tx, q string, args ...interface{}) (map[string]int64, error) { + rows, err := tx.Query(q, args...) + if err != nil { + return nil, err + } + defer rows.Close() + return scanChangeGroupsRevision(rows) +} + +func scanChangeGroupsRevision(rows *sql.Rows) (map[string]int64, error) { + changegroups := map[string]int64{} + for rows.Next() { + var ( + id string + revision int64 + ) + if err := rows.Scan(&id, &revision); err != nil { + return nil, errors.Wrap(err, "failed to scan rows") + } + changegroups[id] = revision + } + if err := rows.Err(); err != nil { + return nil, err + } + return changegroups, nil +} diff --git a/internal/services/configstore/readdb/remotesource.go b/internal/services/configstore/readdb/remotesource.go new file mode 100644 index 0000000..ab2f6b1 --- /dev/null +++ b/internal/services/configstore/readdb/remotesource.go @@ -0,0 +1,184 @@ +// Copyright 2019 Sorint.lab +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied +// See the License for the specific language governing permissions and +// limitations under the License. + +package readdb + +import ( + "database/sql" + "encoding/json" + + "github.com/sorintlab/agola/internal/db" + "github.com/sorintlab/agola/internal/services/types" + "github.com/sorintlab/agola/internal/util" + + sq "github.com/Masterminds/squirrel" + "github.com/pkg/errors" +) + +var ( + remotesourceSelect = sb.Select("id", "data").From("remotesource") + remotesourceInsert = sb.Insert("remotesource").Columns("id", "name", "data") +) + +func (r *ReadDB) insertRemoteSource(tx *db.Tx, data []byte) error { + remoteSource := types.RemoteSource{} + if err := json.Unmarshal(data, &remoteSource); err != nil { + return errors.Wrap(err, "failed to unmarshal remotesource") + } + // poor man insert or update... + if err := r.deleteRemoteSource(tx, remoteSource.ID); err != nil { + return err + } + q, args, err := remotesourceInsert.Values(remoteSource.ID, remoteSource.Name, data).ToSql() + if err != nil { + return errors.Wrap(err, "failed to build query") + } + _, err = tx.Exec(q, args...) + return errors.Wrap(err, "failed to insert remotesource") +} + +func (r *ReadDB) deleteRemoteSource(tx *db.Tx, id string) error { + // poor man insert or update... + if _, err := tx.Exec("delete from remotesource where id = $1", id); err != nil { + return errors.Wrap(err, "failed to delete remotesource") + } + return nil +} + +func (r *ReadDB) GetRemoteSource(tx *db.Tx, remoteSourceID string) (*types.RemoteSource, error) { + q, args, err := remotesourceSelect.Where(sq.Eq{"id": remoteSourceID}).ToSql() + r.log.Debugf("q: %s, args: %s", q, util.Dump(args)) + if err != nil { + return nil, errors.Wrap(err, "failed to build query") + } + + remoteSources, _, err := fetchRemoteSources(tx, q, args...) + if err != nil { + return nil, errors.WithStack(err) + } + if len(remoteSources) > 1 { + return nil, errors.Errorf("too many rows returned") + } + if len(remoteSources) == 0 { + return nil, nil + } + return remoteSources[0], nil +} + +func (r *ReadDB) GetRemoteSourceByName(tx *db.Tx, name string) (*types.RemoteSource, error) { + q, args, err := remotesourceSelect.Where(sq.Eq{"name": name}).ToSql() + r.log.Debugf("q: %s, args: %s", q, util.Dump(args)) + if err != nil { + return nil, errors.Wrap(err, "failed to build query") + } + + remoteSources, _, err := fetchRemoteSources(tx, q, args...) + if err != nil { + return nil, errors.WithStack(err) + } + if len(remoteSources) > 1 { + return nil, errors.Errorf("too many rows returned") + } + if len(remoteSources) == 0 { + return nil, nil + } + return remoteSources[0], nil +} + +func getRemoteSourcesFilteredQuery(startRemoteSourceName string, limit int, asc bool) sq.SelectBuilder { + fields := []string{"id", "data"} + + s := sb.Select(fields...).From("remotesource as remotesource") + if asc { + s = s.OrderBy("remotesource.name asc") + } else { + s = s.OrderBy("remotesource.name desc") + } + if startRemoteSourceName != "" { + if asc { + s = s.Where(sq.Gt{"remotesource.name": startRemoteSourceName}) + } else { + s = s.Where(sq.Lt{"remotesource.name": startRemoteSourceName}) + } + } + if limit > 0 { + s = s.Limit(uint64(limit)) + } + + return s +} +func (r *ReadDB) GetRemoteSources(startRemoteSourceName string, limit int, asc bool) ([]*types.RemoteSource, error) { + var remoteSources []*types.RemoteSource + + s := getRemoteSourcesFilteredQuery(startRemoteSourceName, limit, asc) + q, args, err := s.ToSql() + r.log.Debugf("q: %s, args: %s", q, util.Dump(args)) + if err != nil { + return nil, errors.Wrap(err, "failed to build query") + } + + err = r.rdb.Do(func(tx *db.Tx) error { + rows, err := tx.Query(q, args...) + if err != nil { + return err + } + + remoteSources, _, err = scanRemoteSources(rows) + return err + }) + return remoteSources, errors.WithStack(err) +} + +func fetchRemoteSources(tx *db.Tx, q string, args ...interface{}) ([]*types.RemoteSource, []string, error) { + rows, err := tx.Query(q, args...) + if err != nil { + return nil, nil, err + } + defer rows.Close() + return scanRemoteSources(rows) +} + +func scanRemoteSource(rows *sql.Rows, additionalFields ...interface{}) (*types.RemoteSource, string, error) { + var id string + var data []byte + if err := rows.Scan(&id, &data); err != nil { + return nil, "", errors.Wrap(err, "failed to scan rows") + } + remoteSource := types.RemoteSource{} + if len(data) > 0 { + if err := json.Unmarshal(data, &remoteSource); err != nil { + return nil, "", errors.Wrap(err, "failed to unmarshal remotesource") + } + } + + return &remoteSource, id, nil +} + +func scanRemoteSources(rows *sql.Rows) ([]*types.RemoteSource, []string, error) { + remoteSources := []*types.RemoteSource{} + ids := []string{} + for rows.Next() { + p, id, err := scanRemoteSource(rows) + if err != nil { + rows.Close() + return nil, nil, err + } + remoteSources = append(remoteSources, p) + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, nil, err + } + return remoteSources, ids, nil +} diff --git a/internal/services/configstore/readdb/user.go b/internal/services/configstore/readdb/user.go new file mode 100644 index 0000000..3e958d3 --- /dev/null +++ b/internal/services/configstore/readdb/user.go @@ -0,0 +1,381 @@ +// Copyright 2019 Sorint.lab +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied +// See the License for the specific language governing permissions and +// limitations under the License. + +package readdb + +import ( + "database/sql" + "encoding/json" + + "github.com/sorintlab/agola/internal/db" + "github.com/sorintlab/agola/internal/services/types" + "github.com/sorintlab/agola/internal/util" + + sq "github.com/Masterminds/squirrel" + "github.com/pkg/errors" +) + +var ( + userSelect = sb.Select("user.id", "user.data").From("user") + userInsert = sb.Insert("user").Columns("id", "name", "data") + + linkedaccountSelect = sb.Select("id", "data").From("linkedaccount") + linkedaccountInsert = sb.Insert("linkedaccount").Columns("id", "name", "data") + linkedaccountuserInsert = sb.Insert("linkedaccount_user").Columns("id", "remotesourceid", "userid", "remoteuserid") + linkedaccountuserSelect = sb.Select("id", "userid").From("linkedaccount_user") + linkedaccountprojectInsert = sb.Insert("linkedaccount_project").Columns("id", "userid") + + usertokenSelect = sb.Select("tokenvalue", "userid").From("user_token") + usertokenInsert = sb.Insert("user_token").Columns("tokenvalue", "userid") +) + +func (r *ReadDB) insertUser(tx *db.Tx, data []byte) error { + user := types.User{} + if err := json.Unmarshal(data, &user); err != nil { + return errors.Wrap(err, "failed to unmarshal user") + } + r.log.Infof("inserting user: %s", util.Dump(user)) + // poor man insert or update... + if err := r.deleteUser(tx, user.ID); err != nil { + return err + } + q, args, err := userInsert.Values(user.ID, user.UserName, data).ToSql() + if err != nil { + return errors.Wrap(err, "failed to build query") + } + if _, err := tx.Exec(q, args...); err != nil { + return errors.Wrap(err, "failed to insert user") + } + + // insert linkedaccounts_user + for _, la := range user.LinkedAccounts { + if err := r.deleteUserLinkedAccount(tx, la.ID); err != nil { + return err + } + q, args, err = linkedaccountuserInsert.Values(la.ID, la.RemoteSourceID, user.ID, la.RemoteUserID).ToSql() + if err != nil { + return errors.Wrap(err, "failed to build query") + } + if _, err := tx.Exec(q, args...); err != nil { + return errors.Wrap(err, "failed to insert user") + } + } + // insert user_token + for _, tokenValue := range user.Tokens { + r.log.Infof("inserting user token: %s", tokenValue) + if err := r.deleteUserToken(tx, tokenValue); err != nil { + return err + } + q, args, err = usertokenInsert.Values(tokenValue, user.ID).ToSql() + if err != nil { + return errors.Wrap(err, "failed to build query") + } + if _, err := tx.Exec(q, args...); err != nil { + return errors.Wrap(err, "failed to insert user") + } + } + + return nil +} + +func (r *ReadDB) deleteUser(tx *db.Tx, userID string) error { + // delete user linked accounts + if err := r.deleteUserLinkedAccounts(tx, userID); err != nil { + return errors.Wrap(err, "failed to delete user linked accounts") + } + + // delete user tokens + if _, err := tx.Exec("delete from user_token where userid = $1", userID); err != nil { + return errors.Wrap(err, "failed to delete usertokens") + } + + // poor man insert or update... + if _, err := tx.Exec("delete from user where id = $1", userID); err != nil { + return errors.Wrap(err, "failed to delete user") + } + + return nil +} + +func (r *ReadDB) deleteUserLinkedAccounts(tx *db.Tx, userID string) error { + // poor man insert or update... + if _, err := tx.Exec("delete from linkedaccount_user where userid = $1", userID); err != nil { + return errors.Wrap(err, "failed to delete linked account") + } + if _, err := tx.Exec("delete from linkedaccount_project where id = $1", userID); err != nil { + return errors.Wrap(err, "failed to delete linked account") + } + return nil +} + +func (r *ReadDB) deleteUserLinkedAccount(tx *db.Tx, id string) error { + // poor man insert or update... + if _, err := tx.Exec("delete from linkedaccount_user where id = $1", id); err != nil { + return errors.Wrap(err, "failed to delete linked account") + } + if _, err := tx.Exec("delete from linkedaccount_project where id = $1", id); err != nil { + return errors.Wrap(err, "failed to delete linked account") + } + return nil +} + +func (r *ReadDB) deleteAllUserTokens(tx *db.Tx, userID string) error { + // poor man insert or update... + if _, err := tx.Exec("delete from user_token where userid = $1", userID); err != nil { + return errors.Wrap(err, "failed to delete user_token") + } + return nil +} + +func (r *ReadDB) deleteUserToken(tx *db.Tx, tokenValue string) error { + // poor man insert or update... + if _, err := tx.Exec("delete from user_token where tokenvalue = $1", tokenValue); err != nil { + return errors.Wrap(err, "failed to delete user_token") + } + return nil +} + +func (r *ReadDB) GetUser(tx *db.Tx, userID string) (*types.User, error) { + q, args, err := userSelect.Where(sq.Eq{"id": userID}).ToSql() + r.log.Debugf("q: %s, args: %s", q, util.Dump(args)) + if err != nil { + return nil, errors.Wrap(err, "failed to build query") + } + + users, _, err := fetchUsers(tx, q, args...) + if err != nil { + return nil, errors.WithStack(err) + } + if len(users) > 1 { + return nil, errors.Errorf("too many rows returned") + } + if len(users) == 0 { + return nil, nil + } + return users[0], nil +} + +func (r *ReadDB) GetUserByName(tx *db.Tx, name string) (*types.User, error) { + q, args, err := userSelect.Where(sq.Eq{"name": name}).ToSql() + r.log.Debugf("q: %s, args: %s", q, util.Dump(args)) + if err != nil { + return nil, errors.Wrap(err, "failed to build query") + } + + users, _, err := fetchUsers(tx, q, args...) + if err != nil { + return nil, errors.WithStack(err) + } + if len(users) > 1 { + return nil, errors.Errorf("too many rows returned") + } + if len(users) == 0 { + return nil, nil + } + return users[0], nil +} + +func (r *ReadDB) GetUserByTokenValue(tx *db.Tx, tokenValue string) (*types.User, error) { + s := userSelect + s = s.Join("user_token on user_token.userid = user.id") + s = s.Where(sq.Eq{"user_token.tokenvalue": tokenValue}) + q, args, err := s.ToSql() + r.log.Debugf("q: %s, args: %s", q, util.Dump(args)) + if err != nil { + return nil, errors.Wrap(err, "failed to build query") + } + + users, _, err := fetchUsers(tx, q, args...) + if err != nil { + return nil, errors.WithStack(err) + } + if len(users) > 1 { + return nil, errors.Errorf("too many rows returned") + } + if len(users) == 0 { + return nil, nil + } + return users[0], nil +} + +func (r *ReadDB) GetUserByLinkedAccount(tx *db.Tx, linkedAccountID string) (*types.User, error) { + s := userSelect + s = s.Join("linkedaccount_user as lau on lau.userid = user.id") + s = s.Where(sq.Eq{"lau.id": linkedAccountID}) + q, args, err := s.ToSql() + r.log.Debugf("q: %s, args: %s", q, util.Dump(args)) + if err != nil { + return nil, errors.Wrap(err, "failed to build query") + } + + users, _, err := fetchUsers(tx, q, args...) + if err != nil { + return nil, errors.WithStack(err) + } + if len(users) > 1 { + return nil, errors.Errorf("too many rows returned") + } + if len(users) == 0 { + return nil, nil + } + return users[0], nil +} + +func (r *ReadDB) GetUserByLinkedAccountRemoteUserIDandSource(tx *db.Tx, remoteUserID, remoteSourceID string) (*types.User, error) { + s := userSelect + s = s.Join("linkedaccount_user as lau on lau.userid = user.id") + s = s.Where(sq.Eq{"lau.remoteuserid": remoteUserID, "lau.remotesourceid": remoteSourceID}) + q, args, err := s.ToSql() + r.log.Debugf("q: %s, args: %s", q, util.Dump(args)) + if err != nil { + return nil, errors.Wrap(err, "failed to build query") + } + + users, _, err := fetchUsers(tx, q, args...) + if err != nil { + return nil, errors.WithStack(err) + } + if len(users) > 1 { + return nil, errors.Errorf("too many rows returned") + } + if len(users) == 0 { + return nil, nil + } + return users[0], nil +} + +func getUsersFilteredQuery(startUserName string, limit int, asc bool) sq.SelectBuilder { + fields := []string{"id", "data"} + + s := sb.Select(fields...).From("user as user") + if asc { + s = s.OrderBy("user.name asc") + } else { + s = s.OrderBy("user.name desc") + } + if startUserName != "" { + if asc { + s = s.Where(sq.Gt{"user.name": startUserName}) + } else { + s = s.Where(sq.Lt{"user.name": startUserName}) + } + } + if limit > 0 { + s = s.Limit(uint64(limit)) + } + + return s +} +func (r *ReadDB) GetUsers(tx *db.Tx, startUserName string, limit int, asc bool) ([]*types.User, error) { + var users []*types.User + + s := getUsersFilteredQuery(startUserName, limit, asc) + q, args, err := s.ToSql() + r.log.Debugf("q: %s, args: %s", q, util.Dump(args)) + if err != nil { + return nil, errors.Wrap(err, "failed to build query") + } + + rows, err := tx.Query(q, args...) + if err != nil { + return nil, err + } + + users, _, err = scanUsers(rows) + return users, err +} + +func fetchUsers(tx *db.Tx, q string, args ...interface{}) ([]*types.User, []string, error) { + rows, err := tx.Query(q, args...) + if err != nil { + return nil, nil, err + } + defer rows.Close() + return scanUsers(rows) +} + +func scanUser(rows *sql.Rows, additionalFields ...interface{}) (*types.User, string, error) { + var id string + var data []byte + if err := rows.Scan(&id, &data); err != nil { + return nil, "", errors.Wrap(err, "failed to scan rows") + } + user := types.User{} + if len(data) > 0 { + if err := json.Unmarshal(data, &user); err != nil { + return nil, "", errors.Wrap(err, "failed to unmarshal user") + } + } + + return &user, id, nil +} + +func scanUsers(rows *sql.Rows) ([]*types.User, []string, error) { + users := []*types.User{} + ids := []string{} + for rows.Next() { + p, id, err := scanUser(rows) + if err != nil { + rows.Close() + return nil, nil, err + } + users = append(users, p) + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, nil, err + } + return users, ids, nil +} + +type LinkedAccountUser struct { + ID string + UserID string +} + +func fetchLinkedAccounts(tx *db.Tx, q string, args ...interface{}) ([]*LinkedAccountUser, error) { + rows, err := tx.Query(q, args...) + if err != nil { + return nil, err + } + defer rows.Close() + linkedAccounts, err := scanLinkedAccounts(rows) + + return linkedAccounts, err +} + +func scanLinkedAccount(rows *sql.Rows, additionalFields ...interface{}) (*LinkedAccountUser, error) { + var id, userid string + if err := rows.Scan(&id, &userid); err != nil { + return nil, errors.Wrap(err, "failed to scan rows") + } + + return &LinkedAccountUser{ID: id, UserID: userid}, nil +} + +func scanLinkedAccounts(rows *sql.Rows) ([]*LinkedAccountUser, error) { + linkedAccounts := []*LinkedAccountUser{} + for rows.Next() { + linkedAccount, err := scanLinkedAccount(rows) + if err != nil { + rows.Close() + return nil, err + } + linkedAccounts = append(linkedAccounts, linkedAccount) + } + if err := rows.Err(); err != nil { + return nil, err + } + return linkedAccounts, nil +}