agola/internal/etcd/etcd.go
2019-02-21 16:00:12 +01:00

361 lines
8.5 KiB
Go

// Copyright 2019 Sorint.lab
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package etcd
import (
"context"
"crypto/tls"
"net/url"
"strconv"
"strings"
"time"
"github.com/sorintlab/agola/internal/util"
"github.com/pkg/errors"
"go.etcd.io/etcd/clientv3"
etcdclientv3 "go.etcd.io/etcd/clientv3"
"go.etcd.io/etcd/clientv3/namespace"
"go.etcd.io/etcd/etcdserver/api/v3rpc/rpctypes"
"go.uber.org/zap"
)
var (
// ErrKeyNotFound is returned when the key is not found in the store during a Get operation.
// FromEtcdError also maps the etcd rpctypes.ErrKeyNotFound error to this value.
ErrKeyNotFound = errors.New("key not found")
// ErrKeyModified is returned by AtomicPut and AtomicDelete when the
// transaction's revision comparison does not hold, so the write is not applied.
ErrKeyModified = errors.New("unable to complete atomic operation, key modified")
)
const (
// defaultEndpoints is used by New when Config.Endpoints is empty.
defaultEndpoints = "http://127.0.0.1:2379"
// compactKey is the key that compact updates through a transaction before
// asking etcd to compact.
compactKey = "compactkey"
// defaultCompactionInterval is used by New when Config.CompactionInterval is zero.
defaultCompactionInterval = 10 * time.Minute
)
// WriteOptions holds the optional settings accepted by Put and AtomicPut.
type WriteOptions struct {
// TTL, when greater than zero, makes the write grant a new lease and attach
// the key to it. With a zero or negative TTL no lease is requested.
TTL time.Duration
}
// Config holds the parameters consumed by New to build a Store.
type Config struct {
// Logger is the zap logger from which the store derives its sugared logger.
Logger *zap.Logger
// Endpoints is a comma separated list of etcd endpoint URLs. All of them
// must share the same scheme (http or https). When empty, defaultEndpoints
// is used.
Endpoints string
// Prefix is the namespace prefix applied to the client's KV, Watcher and
// Lease. A trailing "/" is appended when a non-empty prefix lacks one.
Prefix string
// CertFile, KeyFile, CAFile and SkipTLSVerify are passed to
// util.NewTLSConfig and are only used when the endpoints scheme is https.
CertFile string
KeyFile string
CAFile string
SkipTLSVerify bool
// CompactionInterval is the pause between two compaction attempts. When
// zero, defaultCompactionInterval is used.
CompactionInterval time.Duration
}
// FromEtcdError translates errors returned by the etcd client into this
// package's sentinel errors: rpctypes.ErrKeyNotFound becomes ErrKeyNotFound.
// Any other value, nil included, is handed back unchanged.
func FromEtcdError(err error) error {
	if err == rpctypes.ErrKeyNotFound {
		return ErrKeyNotFound
	}
	return err
}
// Store wraps an etcd v3 client and exposes the key/value operations defined
// in this file. Instances are created with New.
type Store struct {
// log is the sugared logger derived from Config.Logger.
log *zap.SugaredLogger
// c is the etcd client; New replaces its KV, Watcher and Lease with
// prefix-namespaced wrappers.
c *etcdclientv3.Client
}
// New creates a Store from cfg.
//
// The endpoints (cfg.Endpoints, or defaultEndpoints when empty) are split on
// "," and must all parse as URLs sharing the same scheme, which has to be
// either http or https. With https a TLS configuration is built from the
// certificate settings in cfg. The client's KV, Watcher and Lease are wrapped
// so that every key is placed under cfg.Prefix (a trailing "/" is added to a
// non-empty prefix that lacks one).
//
// A nil cfg.Logger is replaced by a no-op logger.
//
// New also starts a background compactor goroutine that runs every
// cfg.CompactionInterval (defaultCompactionInterval when zero).
// NOTE: the goroutine runs on context.TODO() and is therefore not stopped by
// Close.
func New(cfg Config) (*Store, error) {
	prefix := cfg.Prefix
	if prefix != "" && !strings.HasSuffix(prefix, "/") {
		prefix += "/"
	}

	endpointsStr := cfg.Endpoints
	if endpointsStr == "" {
		endpointsStr = defaultEndpoints
	}
	endpoints := strings.Split(endpointsStr, ",")

	// check that all the endpoints have the same scheme
	var scheme string
	for i, e := range endpoints {
		u, err := url.Parse(e)
		if err != nil {
			return nil, errors.Wrapf(err, "cannot parse endpoint %q", e)
		}
		// The first endpoint defines the reference scheme, even when it is
		// empty: latching on `scheme == ""` would let a scheme-less endpoint
		// followed by valid ones slip through the check.
		if i == 0 {
			scheme = u.Scheme
			continue
		}
		if scheme != u.Scheme {
			return nil, errors.Errorf("all the endpoints must have the same scheme")
		}
	}
	if scheme != "http" && scheme != "https" {
		return nil, errors.Errorf("endpoints scheme must be http or https")
	}

	var tlsConfig *tls.Config
	if scheme == "https" {
		var err error
		tlsConfig, err = util.NewTLSConfig(cfg.CertFile, cfg.KeyFile, cfg.CAFile, cfg.SkipTLSVerify)
		if err != nil {
			return nil, errors.Wrapf(err, "cannot create tls config")
		}
	}

	// Calling Sugar on a nil *zap.Logger panics; fall back to a no-op logger.
	logger := cfg.Logger
	if logger == nil {
		logger = zap.NewNop()
	}

	config := etcdclientv3.Config{
		Endpoints: endpoints,
		TLS:       tlsConfig,
	}
	c, err := etcdclientv3.New(config)
	if err != nil {
		return nil, err
	}

	// Transparently place every key, watch and lease under the prefix.
	c.KV = namespace.NewKV(c.KV, prefix)
	c.Watcher = namespace.NewWatcher(c.Watcher, prefix)
	c.Lease = namespace.NewLease(c.Lease, prefix)

	s := &Store{
		log: logger.Sugar(),
		c:   c,
	}

	compactionInterval := defaultCompactionInterval
	if cfg.CompactionInterval != 0 {
		compactionInterval = cfg.CompactionInterval
	}
	go s.compactor(context.TODO(), compactionInterval)

	return s, nil
}
// Client returns the underlying etcd client held by the store. Its KV,
// Watcher and Lease are the prefix-namespaced wrappers installed by New.
func (s *Store) Client() *etcdclientv3.Client {
return s.c
}
// Put writes value under key.
//
// When options is non-nil and options.TTL is positive, a lease lasting
// options.TTL (converted to whole seconds) is granted first and the key is
// attached to it. A failure to grant the lease is returned as is; the error
// of the put itself goes through FromEtcdError.
func (s *Store) Put(ctx context.Context, key string, value []byte, options *WriteOptions) (*etcdclientv3.PutResponse, error) {
	var putOpts []etcdclientv3.OpOption

	if options != nil && options.TTL > 0 {
		ttlSeconds := int64(options.TTL.Seconds())
		grant, err := s.c.Grant(ctx, ttlSeconds)
		if err != nil {
			return nil, err
		}
		putOpts = append(putOpts, etcdclientv3.WithLease(grant.ID))
	}

	resp, err := s.c.Put(ctx, key, string(value), putOpts...)
	return resp, FromEtcdError(err)
}
// Get reads key from the store.
//
// Client errors are returned through FromEtcdError together with whatever
// response the client produced. A successful read that yields no key/value
// pair returns the response along with ErrKeyNotFound.
func (s *Store) Get(ctx context.Context, key string) (*etcdclientv3.GetResponse, error) {
	resp, err := s.c.Get(ctx, key)
	switch {
	case err != nil:
		return resp, FromEtcdError(err)
	case len(resp.Kvs) == 0:
		return resp, ErrKeyNotFound
	default:
		return resp, nil
	}
}
// List returns the keys found under directory.
//
// directory gets a trailing "/" when it lacks one and the end of the range is
// the prefix range end of that directory. The range begins at start when it is
// not empty, otherwise at the directory itself. A non-zero revision makes the
// read happen at that revision. Errors go through FromEtcdError.
func (s *Store) List(ctx context.Context, directory, start string, revision int64) (*etcdclientv3.GetResponse, error) {
	if !strings.HasSuffix(directory, "/") {
		directory += "/"
	}

	// The upper bound always derives from the directory, never from start.
	end := clientv3.GetPrefixRangeEnd(directory)

	from := directory
	if start != "" {
		from = start
	}

	getOpts := make([]etcdclientv3.OpOption, 0, 2)
	getOpts = append(getOpts, etcdclientv3.WithRange(end))
	if revision != 0 {
		getOpts = append(getOpts, etcdclientv3.WithRev(revision))
	}

	resp, err := s.c.Get(ctx, from, getOpts...)
	return resp, FromEtcdError(err)
}
// ListPagedResp is the result of a ListPaged call.
type ListPagedResp struct {
// Resp is the etcd response for the requested page.
Resp *clientv3.GetResponse
// HasMore is copied from the response's More field.
HasMore bool
// Continuation can be passed to the next ListPaged call to read the
// following page.
Continuation *ListPagedContinuation
}
// ListPagedContinuation carries the state ListPaged needs to read the page
// that follows a previous call.
type ListPagedContinuation struct {
// Revision replaces the revision argument of ListPaged when a continuation
// is provided.
Revision int64
// LastKey is the key the next page starts from. ListPaged sets it to the
// last returned key followed by "\x00", or to the key it started from when
// the page was empty.
LastKey string
}
// ListPaged returns at most limit keys found under directory, together with a
// continuation that can be passed to the next call to read the following page.
//
// directory gets a trailing "/" when it lacks one. When continuation is nil
// the listing starts at the directory and is read at revision (zero meaning
// the latest revision). When continuation is set, its Revision and LastKey
// replace the revision argument and the starting key.
//
// The returned continuation keeps the revision used for this page, so that
// every page of a listing is read at the same revision. Only when no revision
// was requested is it taken from the response header, pinning the following
// pages to the revision observed by the first one. The response header always
// reports the store's current revision, so using it unconditionally would
// move a listing started at an older revision to a newer one on page two.
//
// Errors go through FromEtcdError.
func (s *Store) ListPaged(ctx context.Context, directory string, revision, limit int64, continuation *ListPagedContinuation) (*ListPagedResp, error) {
	if !strings.HasSuffix(directory, "/") {
		directory += "/"
	}
	key := directory
	rangeEnd := clientv3.GetPrefixRangeEnd(key)

	if continuation != nil {
		revision = continuation.Revision
		key = continuation.LastKey
	}

	opts := []etcdclientv3.OpOption{etcdclientv3.WithRange(rangeEnd), etcdclientv3.WithLimit(limit)}
	if revision != 0 {
		opts = append(opts, etcdclientv3.WithRev(revision))
	}

	resp, err := s.c.Get(ctx, key, opts...)
	if err != nil {
		return nil, FromEtcdError(err)
	}

	// Keep paging at the revision this page was read at.
	nextRevision := revision
	if nextRevision == 0 {
		nextRevision = resp.Header.Revision
	}

	// The next page starts right after the last returned key: appending a
	// zero byte gives the smallest key greater than it.
	lastKey := key
	if n := len(resp.Kvs); n > 0 {
		lastKey = string(resp.Kvs[n-1].Key) + "\x00"
	}

	return &ListPagedResp{
		Resp:    resp,
		HasMore: resp.More,
		Continuation: &ListPagedContinuation{
			Revision: nextRevision,
			LastKey:  lastKey,
		},
	}, nil
}
// AtomicPut writes value under key only if the key has not changed since
// prevRevision.
//
// With a non-zero prevRevision the write is applied when the key's
// modification revision equals prevRevision. With a zero prevRevision the key
// must not exist (its creation revision must be 0).
//
// When options is non-nil and options.TTL is positive, a lease lasting
// options.TTL (converted to whole seconds, as in Put) is granted first and the
// key is attached to it. A failure to grant the lease is returned as is.
//
// It returns ErrKeyModified, together with the transaction response, when the
// comparison does not hold. Transaction errors go through FromEtcdError.
func (s *Store) AtomicPut(ctx context.Context, key string, value []byte, prevRevision int64, options *WriteOptions) (*etcdclientv3.TxnResponse, error) {
	etcdv3Options := []etcdclientv3.OpOption{}
	if options != nil && options.TTL > 0 {
		// Grant expects the TTL in seconds; converting the Duration directly
		// would pass nanoseconds.
		lease, err := s.c.Grant(ctx, int64(options.TTL.Seconds()))
		if err != nil {
			return nil, err
		}
		etcdv3Options = append(etcdv3Options, etcdclientv3.WithLease(lease.ID))
	}

	var cmp etcdclientv3.Cmp
	if prevRevision != 0 {
		cmp = etcdclientv3.Compare(etcdclientv3.ModRevision(key), "=", prevRevision)
	} else {
		// key must not exist
		cmp = etcdclientv3.Compare(etcdclientv3.CreateRevision(key), "=", 0)
	}

	txn := s.c.Txn(ctx).If(cmp)
	txn = txn.Then(etcdclientv3.OpPut(key, string(value), etcdv3Options...))
	tresp, err := txn.Commit()
	if err != nil {
		return tresp, FromEtcdError(err)
	}
	if !tresp.Succeeded {
		return tresp, ErrKeyModified
	}

	return tresp, nil
}
// Delete removes key from the store. The client error goes through
// FromEtcdError, like the errors of the other store operations.
func (s *Store) Delete(ctx context.Context, key string) error {
	_, err := s.c.Delete(ctx, key)
	return FromEtcdError(err)
}
// AtomicDelete removes key only when its modification revision equals
// revision.
//
// It returns ErrKeyModified, together with the transaction response, when the
// comparison does not hold. Transaction errors go through FromEtcdError.
func (s *Store) AtomicDelete(ctx context.Context, key string, revision int64) (*etcdclientv3.TxnResponse, error) {
	txn := s.c.Txn(ctx).
		If(etcdclientv3.Compare(etcdclientv3.ModRevision(key), "=", revision)).
		Then(etcdclientv3.OpDelete(key))

	tresp, err := txn.Commit()
	switch {
	case err != nil:
		return tresp, FromEtcdError(err)
	case !tresp.Succeeded:
		return tresp, ErrKeyModified
	default:
		return tresp, nil
	}
}
// Watch opens a watch on every key starting with prefix and returns the
// client's watch channel. A non-zero revision makes the watch start from that
// revision.
func (s *Store) Watch(ctx context.Context, prefix string, revision int64) etcdclientv3.WatchChan {
	if revision == 0 {
		return s.c.Watch(ctx, prefix, clientv3.WithPrefix())
	}
	return s.c.Watch(ctx, prefix, clientv3.WithPrefix(), etcdclientv3.WithRev(revision))
}
// Close closes the underlying etcd client and returns the error it reports.
// It only calls the client's Close.
func (s *Store) Close() error {
return s.c.Close()
}
// compactor calls compact once per interval until ctx is done.
//
// The version and revision returned by each compact call are fed to the next
// one. Each wait starts after the previous compact call has returned. A single
// timer is reused and stopped on return, so cancelling ctx does not leave a
// pending timer behind.
func (s *Store) compactor(ctx context.Context, interval time.Duration) {
	var version, rev int64

	timer := time.NewTimer(interval)
	defer timer.Stop()

	for {
		select {
		case <-timer.C:
		case <-ctx.Done():
			return
		}

		// Compaction is best effort: the error is deliberately dropped here
		// (compact already logs a failed compaction) and the returned values
		// are used for the next attempt either way.
		version, rev, _ = s.compact(ctx, version, rev)

		// The timer channel has been drained above, so Reset is safe.
		timer.Reset(interval)
	}
}
// compact performs one compaction round and returns the compactKey version
// and the store revision to pass to the next round.
//
// It runs a transaction that, if the version of compactKey equals version,
// stores rev under compactKey; otherwise it reads compactKey.
//
//   - Transaction error: version and rev are returned unchanged with the error.
//   - Comparison failed: nothing is compacted and the version currently stored
//     for compactKey is returned with the revision from the response header.
//     If compactKey does not exist the version is 0, so the next round can
//     create it again.
//   - Comparison succeeded: version+1 and the header revision are returned.
//     When rev is non-zero etcd is asked to compact at rev; a failure is
//     logged and returned together with the updated values.
func (s *Store) compact(ctx context.Context, version, rev int64) (int64, int64, error) {
	resp, err := s.c.KV.Txn(ctx).If(
		clientv3.Compare(clientv3.Version(compactKey), "=", version),
	).Then(
		clientv3.OpPut(compactKey, strconv.FormatInt(rev, 10)),
	).Else(
		clientv3.OpGet(compactKey),
	).Commit()
	if err != nil {
		return version, rev, err
	}

	curRev := resp.Header.Revision

	if !resp.Succeeded {
		kvs := resp.Responses[0].GetResponseRange().Kvs
		if len(kvs) == 0 {
			// The key is gone while we expected a non-zero version:
			// indexing Kvs[0] would panic. A missing key has version 0.
			return 0, curRev, nil
		}
		return kvs[0].Version, curRev, nil
	}

	curVersion := version + 1

	// Nothing to compact on the first successful round.
	if rev == 0 {
		return curVersion, curRev, nil
	}

	if _, err = s.c.Compact(ctx, rev); err != nil {
		s.log.Warnf("compact error: %v", err)
		return curVersion, curRev, err
	}
	s.log.Infof("compacted revision: %d", rev)

	return curVersion, curRev, nil
}