Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions config/sample_xconfwebconfig.conf
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,17 @@ xconfwebconfig {
concurrent_queries = 5
connections = 5
local_dc = ""
port = 9042

//Config to create database client to AWS Keyspace using IAM temporary credentials
aws_keyspace_enabled = false
role_based_access_enabled = false
aws_region = ""
aws_keyspace_ca_path = "path_to_file/sf-class2-root.crt"

//If role_based_access_enabled is true, access_key_id and secret_access_key will be fetched using IAM temporary credentials
access_key_id = ""
secret_access_key = ""
}

misc {
Expand Down
127 changes: 127 additions & 0 deletions db/aws_keyspace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package db

import (
"fmt"
"os"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sigv4-auth-cassandra-gocql-driver-plugin/sigv4"
"github.com/go-akka/configuration"
"github.com/gocql/gocql"
log "github.com/sirupsen/logrus"
)

func awsKeyspaceClient(conf *configuration.Config, testOnly bool) (*CassandraClient, error) {
// init
hosts := conf.GetStringList("xconfwebconfig.database.hosts")
cluster := gocql.NewCluster(hosts...)

cluster.Consistency = gocql.LocalQuorum
cluster.ProtoVersion = int(conf.GetInt32("xconfwebconfig.database.protocolversion", ProtocolVersion))
cluster.DisableInitialHostLookup = DisableInitialHostLookup
cluster.Timeout = time.Duration(conf.GetInt32("xconfwebconfig.database.timeout_in_sec", 1)) * time.Second
cluster.ConnectTimeout = time.Duration(conf.GetInt32("xconfwebconfig.database.connect_timeout_in_sec", 1)) * time.Second
cluster.NumConns = int(conf.GetInt32("xconfwebconfig.database.connections", DefaultConnections))
cluster.Port = int(conf.GetInt64("xconfwebconfig.database.port", DefaultPort))

cluster.RetryPolicy = &gocql.DowngradingConsistencyRetryPolicy{
[]gocql.Consistency{
gocql.LocalQuorum,
gocql.LocalOne,
gocql.One,
},
}

localDc := conf.GetString("xconfwebconfig.database.local_dc")
if len(localDc) > 0 {
cluster.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy(localDc)
}

awsRegion, err := getAwsRegionForCassandra(conf)
if err != nil {
log.Error(err.Error())
return nil, err
}

var auth sigv4.AwsAuthenticator = sigv4.NewAwsAuthenticator()
auth.Region = awsRegion

isRoleBasedAccessEnabled := conf.GetBoolean("xconfwebconfig.database.role_based_access_enabled")
if isRoleBasedAccessEnabled {
sess, err := session.NewSession(&aws.Config{
Region: aws.String(awsRegion)},
)
if err != nil {
log.Error(err.Error())
return nil, err
}

// Set up the callback to refresh credentials
auth.CredentialsCallback = func() (sigv4.SigV4Credentials, error) {
creds, err := sess.Config.Credentials.Get()
if err != nil {
return sigv4.SigV4Credentials{}, err
}

return sigv4.SigV4Credentials{
AccessKeyId: creds.AccessKeyID,
SecretAccessKey: creds.SecretAccessKey,
SessionToken: creds.SessionToken,
}, nil
}
} else {
auth.AccessKeyId = conf.GetString("xconfwebconfig.database.access_key_id")
auth.SecretAccessKey = conf.GetString("xconfwebconfig.database.secret_access_key")
}
cluster.Authenticator = auth

awsKeySpaceCaPath := conf.GetString("xconfwebconfig.database.aws_keyspace_ca_path")
cluster.SslOpts = &gocql.SslOptions{
CaPath: awsKeySpaceCaPath,
EnableHostVerification: false,
}

// Use the appropriate keyspace
var deviceKeyspace string
if testOnly {
cluster.Keyspace = conf.GetString("xconfwebconfig.database.test_keyspace", DefaultTestKeyspace)
deviceKeyspace = conf.GetString("webconfig.database.device_test_keyspace", DefaultDeviceTestKeyspace)
} else {
cluster.Keyspace = conf.GetString("xconfwebconfig.database.keyspace", DefaultKeyspace)
deviceKeyspace = conf.GetString("webconfig.database.device_keyspace", DefaultDeviceKeyspace)
}
log.Debug(fmt.Sprintf("Init CassandraClient with keyspace: %v", cluster.Keyspace))

session, err := cluster.CreateSession()
if err != nil {
return nil, err
}

devicePodTableName := conf.GetString("webconfig.database.device_pod_table_name", DefaultDevicePodTableName)

return &CassandraClient{
Session: session,
ClusterConfig: cluster,
sleepTime: conf.GetInt32("xconfwebconfig.perftest.sleep_in_msecs", DefaultSleepTimeInMillisecond),
concurrentQueries: make(chan bool, conf.GetInt32("xconfwebconfig.database.concurrent_queries", 500)),
localDc: localDc,
deviceKeyspace: deviceKeyspace,
devicePodTableName: devicePodTableName,
testOnly: testOnly,
}, nil
}

func getAwsRegionForCassandra(conf *configuration.Config) (string, error) {
awsRegion := conf.GetString("xconfwebconfig.database.aws_region")
if len(awsRegion) == 0 {
awsRegion = os.Getenv("AWS_REGION")
}

if len(awsRegion) == 0 {
return "", fmt.Errorf("%s", "Aws region is not provided")
}

return awsRegion, nil
}
95 changes: 95 additions & 0 deletions db/cassandra.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package db

import (
"fmt"
"time"
"xconfwebconfig/security"

"github.com/go-akka/configuration"
"github.com/gocql/gocql"
log "github.com/sirupsen/logrus"
)

func cassandraClient(conf *configuration.Config, testOnly bool) (*CassandraClient, error) {
// init
hosts := conf.GetStringList("xconfwebconfig.database.hosts")
cluster := gocql.NewCluster(hosts...)

cluster.Consistency = gocql.LocalQuorum
cluster.ProtoVersion = int(conf.GetInt32("xconfwebconfig.database.protocolversion", ProtocolVersion))
cluster.DisableInitialHostLookup = DisableInitialHostLookup
cluster.Timeout = time.Duration(conf.GetInt32("xconfwebconfig.database.timeout_in_sec", 1)) * time.Second
cluster.ConnectTimeout = time.Duration(conf.GetInt32("xconfwebconfig.database.connect_timeout_in_sec", 1)) * time.Second
cluster.NumConns = int(conf.GetInt32("xconfwebconfig.database.connections", DefaultConnections))

cluster.RetryPolicy = &gocql.DowngradingConsistencyRetryPolicy{
[]gocql.Consistency{
gocql.LocalQuorum,
gocql.LocalOne,
gocql.One,
},
}

localDc := conf.GetString("xconfwebconfig.database.local_dc")
if len(localDc) > 0 {
cluster.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy(localDc)
}

isSslEnabled := conf.GetBoolean("xconfwebconfig.database.is_ssl_enabled")

var password string
var err error

encryptedPassword := conf.GetString("xconfwebconfig.database.encrypted_password")
if encryptedPassword != "" {
codec := security.NewAesCodec()
password, err = codec.Decrypt(encryptedPassword)
if err != nil {
log.Error(err.Error())
return nil, err
}
} else {
password = conf.GetString("xconfwebconfig.database.password")
}

user := conf.GetString("xconfwebconfig.database.user")
cluster.Authenticator = gocql.PasswordAuthenticator{
Username: user,
Password: password,
}

if isSslEnabled {
cluster.SslOpts = &gocql.SslOptions{
EnableHostVerification: false,
}
}

// Use the appropriate keyspace
var deviceKeyspace string
if testOnly {
cluster.Keyspace = conf.GetString("xconfwebconfig.database.test_keyspace", DefaultTestKeyspace)
deviceKeyspace = conf.GetString("webconfig.database.device_test_keyspace", DefaultDeviceTestKeyspace)
} else {
cluster.Keyspace = conf.GetString("xconfwebconfig.database.keyspace", DefaultKeyspace)
deviceKeyspace = conf.GetString("webconfig.database.device_keyspace", DefaultDeviceKeyspace)
}
log.Debug(fmt.Sprintf("Init CassandraClient with keyspace: %v", cluster.Keyspace))

session, err := cluster.CreateSession()
if err != nil {
return nil, err
}

devicePodTableName := conf.GetString("webconfig.database.device_pod_table_name", DefaultDevicePodTableName)

return &CassandraClient{
Session: session,
ClusterConfig: cluster,
sleepTime: conf.GetInt32("xconfwebconfig.perftest.sleep_in_msecs", DefaultSleepTimeInMillisecond),
concurrentQueries: make(chan bool, conf.GetInt32("xconfwebconfig.database.concurrent_queries", 500)),
localDc: localDc,
deviceKeyspace: deviceKeyspace,
devicePodTableName: devicePodTableName,
testOnly: testOnly,
}, nil
}
88 changes: 5 additions & 83 deletions db/cassandra_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"github.com/go-akka/configuration"
"github.com/gocql/gocql"

"xconfwebconfig/security"
"xconfwebconfig/util"

log "github.com/sirupsen/logrus"
Expand All @@ -48,6 +47,7 @@ const (
DefaultColumnValue = "data"
NamedListPartColumnValue = "NamedListData_part_"
NamedListCountColumnValue = "NamedListData_parts_count"
DefaultPort = 9042
)

type CassandraClient struct {
Expand Down Expand Up @@ -76,90 +76,12 @@ type PenetrationMetrics struct {
}

func NewCassandraClient(conf *configuration.Config, testOnly bool) (*CassandraClient, error) {
// init
hosts := conf.GetStringList("xconfwebconfig.database.hosts")
cluster := gocql.NewCluster(hosts...)

cluster.Consistency = gocql.LocalQuorum
cluster.ProtoVersion = int(conf.GetInt32("xconfwebconfig.database.protocolversion", ProtocolVersion))
cluster.DisableInitialHostLookup = DisableInitialHostLookup
cluster.Timeout = time.Duration(conf.GetInt32("xconfwebconfig.database.timeout_in_sec", 1)) * time.Second
cluster.ConnectTimeout = time.Duration(conf.GetInt32("xconfwebconfig.database.connect_timeout_in_sec", 1)) * time.Second
cluster.NumConns = int(conf.GetInt32("xconfwebconfig.database.connections", DefaultConnections))

cluster.RetryPolicy = &gocql.DowngradingConsistencyRetryPolicy{
[]gocql.Consistency{
gocql.LocalQuorum,
gocql.LocalOne,
gocql.One,
},
}

localDc := conf.GetString("xconfwebconfig.database.local_dc")
if len(localDc) > 0 {
cluster.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy(localDc)
}

user := conf.GetString("xconfwebconfig.database.user")
encryptedPassword := conf.GetString("xconfwebconfig.database.encrypted_password")
isSslEnabled := conf.GetBoolean("xconfwebconfig.database.is_ssl_enabled")

//build codec
codec := security.NewAesCodec()

var password string
var err error

if encryptedPassword != "" {
password, err = codec.Decrypt(encryptedPassword)
if err != nil {
log.Error(err.Error())
return nil, err
}
} else {
password = conf.GetString("xconfwebconfig.database.password")
}

cluster.Authenticator = gocql.PasswordAuthenticator{
Username: user,
Password: password,
}

if isSslEnabled {
sslOpts := &gocql.SslOptions{
EnableHostVerification: false,
}
cluster.SslOpts = sslOpts
}

// Use the appropriate keyspace
var deviceKeyspace string
if testOnly {
cluster.Keyspace = conf.GetString("xconfwebconfig.database.test_keyspace", DefaultTestKeyspace)
deviceKeyspace = conf.GetString("webconfig.database.device_test_keyspace", DefaultDeviceTestKeyspace)
isAwsKeyspaceEnabled := conf.GetBoolean("xconfwebconfig.database.aws_keyspace_enabled")
if isAwsKeyspaceEnabled {
return awsKeyspaceClient(conf, testOnly)
} else {
cluster.Keyspace = conf.GetString("xconfwebconfig.database.keyspace", DefaultKeyspace)
deviceKeyspace = conf.GetString("webconfig.database.device_keyspace", DefaultDeviceKeyspace)
return cassandraClient(conf, testOnly)
}
log.Debug(fmt.Sprintf("Init CassandraClient with keyspace: %v", cluster.Keyspace))

session, err := cluster.CreateSession()
if err != nil {
return nil, err
}

devicePodTableName := conf.GetString("webconfig.database.device_pod_table_name", DefaultDevicePodTableName)

return &CassandraClient{
Session: session,
ClusterConfig: cluster,
sleepTime: conf.GetInt32("xconfwebconfig.perftest.sleep_in_msecs", DefaultSleepTimeInMillisecond),
concurrentQueries: make(chan bool, conf.GetInt32("xconfwebconfig.database.concurrent_queries", 500)),
localDc: localDc,
deviceKeyspace: deviceKeyspace,
devicePodTableName: devicePodTableName,
testOnly: testOnly,
}, nil
}

// Cassandra Impl of DatabaseClient
Expand Down
Loading