Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions dev/dev.env
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ APP_DB_USER=vertica
APP_DB_PASSWORD=Password1
VERTICA_DB_NAME=vertica

# ORACLE credentials
ORACLE_PASSWORD=YourSecurePassword123
ORACLE_DATABASE=app
APP_USER=oracle
APP_USER_PASSWORD=Password1

# To prevent generating sample demo VMart data (more about it here https://www.vertica.com/docs/9.2.x/HTML/Content/Authoring/GettingStartedGuide/IntroducingVMart/IntroducingVMart.htm),
# leave VMART_DIR and VMART_ETL_SCRIPT empty.
VMART_DIR=
Expand Down
20 changes: 17 additions & 3 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
version: "3.8"

services:
postgres:
container_name: dd-postgresql
Expand Down Expand Up @@ -50,9 +48,24 @@ services:
networks:
- local

oracle:
container_name: dd-oracle
image: gvenzl/oracle-free:slim-faststart
hostname: oracle-db
ports:
- "1521:1521"
env_file:
- dev/dev.env
volumes:
- oracle-volume:/opt/oracle/oradata
networks:
- local
restart: always


clickhouse:
container_name: dd-clickhouse
image: clickhouse/clickhouse-server:21.12.3.32
image: clickhouse/clickhouse-server:latest
restart: always
volumes:
- clickhouse-data:/var/lib/clickhouse:delegated
Expand Down Expand Up @@ -145,6 +158,7 @@ volumes:
# The vertica docker image is 404
#vertica-data:
dremio-data:
oracle-volume:

networks:
local:
Expand Down
12 changes: 8 additions & 4 deletions sqeleton/abcs/database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@ class FractionalType(NumericType):
pass


class Float(FractionalType):
python_type = float


class IKey(ABC):
"Interface for ColType, for using a column as a key in table."

Expand All @@ -70,6 +66,14 @@ def make_value(self, value):
return self.python_type(value)


class Float(FractionalType, IKey):
@property
def python_type(self) -> type:
if self.precision == 0:
return int
return float


class Decimal(FractionalType, IKey): # Snowflake may use Decimal as a key
@property
def python_type(self) -> type:
Expand Down
23 changes: 18 additions & 5 deletions sqeleton/databases/oracle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict, List, Optional

from ..queries.compiler import CompiledCode
from ..utils import match_regexps
from ..abcs.database_types import (
Decimal,
Expand All @@ -14,6 +15,7 @@
TimestampTZ,
FractionalType,
)
from ..queries.ast_classes import ForeignKey
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
from ..abcs import Compilable
from ..queries import this, table, SKIP
Expand Down Expand Up @@ -124,6 +126,8 @@ def is_distinct_from(self, a: str, b: str) -> str:
return f"DECODE({a}, {b}, 1, 0) = 0"

def type_repr(self, t) -> str:
if isinstance(t, ForeignKey):
return self.type_repr(t.type)
try:
return {
str: "VARCHAR(1024)",
Expand Down Expand Up @@ -167,16 +171,17 @@ def current_timestamp(self) -> str:

class Oracle(ThreadedDatabase):
dialect = Dialect()
CONNECT_URI_HELP = "oracle://<user>:<password>@<host>/<database>"
CONNECT_URI_HELP = "oracle://<user>:<password>@<host>:port/<database>"
CONNECT_URI_PARAMS = ["database?"]

def __init__(self, *, host, database, thread_count, **kw):

def __init__(self, *, host, database, thread_count, port=None, **kw):
self.kwargs = kw

# Build dsn if not present
if "dsn" not in kw:
self.kwargs["dsn"] = f"{host}/{database}" if database else host
# Support for different ports
port = port or 1521
self.kwargs["dsn"] = f"{host}:{port}/{database}" if database else f"{host}:{port}"

self.default_schema = kw.get("user").upper()

Expand All @@ -192,7 +197,15 @@ def create_connection(self):
except Exception as e:
raise ConnectError(*e.args) from e

def _query_cursor(self, c, sql_code: str):
def _query_cursor(self, c, sql_code: CompiledCode):
# Convert %s style parameters to :1, :2, :3 style for Oracle to support queries built by sqeleton (tbl.create())
if sql_code.args:
# Replace %s with :1, :2, :3, etc.
oracle_sql = sql_code.code
for i in range(len(sql_code.args)):
oracle_sql = oracle_sql.replace('%s', f':{i+1}', 1)
sql_code = CompiledCode(oracle_sql, sql_code.args, sql_code.type)

try:
return super()._query_cursor(c, sql_code)
except self._oracle.DatabaseError as e:
Expand Down
6 changes: 4 additions & 2 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from parameterized import parameterized_class

from sqeleton import databases as db
from sqeleton import connect
from sqeleton.abcs.mixins import AbstractMixin_NormalizeValue
from sqeleton.queries import table
from sqeleton.databases import Database
Expand All @@ -22,11 +21,14 @@
TEST_MYSQL_CONN_STRING: str = "mysql://mysql:Password1@localhost/mysql"
TEST_POSTGRESQL_CONN_STRING: str = "postgresql://postgres:Password1@localhost/postgres"
TEST_SNOWFLAKE_CONN_STRING: str = os.environ.get("SNOWFLAKE_URI") or None
# presto uri for provided docker - "presto://presto:presto@localhost:8080/memory/default"
TEST_PRESTO_CONN_STRING: str = os.environ.get("PRESTO_URI") or None
TEST_BIGQUERY_CONN_STRING: str = os.environ.get("BIGQUERY_URI") or None
TEST_REDSHIFT_CONN_STRING: str = os.environ.get("REDSHIFT_URI") or None
TEST_ORACLE_CONN_STRING: str = None
# oracle uri for provided docker - "oracle://oracle:Password1@localhost/app"
TEST_ORACLE_CONN_STRING: str = os.environ.get("ORACLE_URI") or None
TEST_DATABRICKS_CONN_STRING: str = os.environ.get("DATABRICKS_URI")
# trino uri for provided docker - "trino://trino@localhost:8081/memory/default"
TEST_TRINO_CONN_STRING: str = os.environ.get("TRINO_URI") or None
# clickhouse uri for provided docker - "clickhouse://clickhouse:Password1@localhost:9000/clickhouse"
TEST_CLICKHOUSE_CONN_STRING: str = os.environ.get("CLICKHOUSE_URI")
Expand Down
64 changes: 35 additions & 29 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,48 +7,54 @@
from sqeleton import connect
from sqeleton import databases as dbs
from sqeleton.queries import table, current_timestamp, NormalizeAsString, ForeignKey, Compiler
from .common import TEST_MYSQL_CONN_STRING
from .common import str_to_checksum, make_test_each_database_in_list, get_conn, random_table_suffix
from common import str_to_checksum, make_test_each_database_in_list, get_conn, random_table_suffix
from sqeleton.abcs.database_types import TimestampTZ
from sqeleton.abcs.mixins import AbstractMixin_MD5

TEST_DATABASES = {
dbs.MySQL,
dbs.PostgreSQL,
# dbs.MySQL,
# dbs.PostgreSQL,
dbs.Oracle,
dbs.Redshift,
dbs.Snowflake,
dbs.DuckDB,
dbs.BigQuery,
dbs.Presto,
dbs.Trino,
dbs.Vertica,
dbs.Dremio,
# dbs.DuckDB,
# dbs.Presto,
# dbs.Trino,
# dbs.Dremio,
# dbs.BigQuery,
# dbs.Snowflake,
# dbs.Redshift,
# dbs.Vertica,
}

test_each_database: Callable = make_test_each_database_in_list(TEST_DATABASES)


@test_each_database
class TestDatabase(unittest.TestCase):
def setUp(self):
self.mysql = connect(TEST_MYSQL_CONN_STRING)

def test_connect_to_db(self):
self.assertEqual(1, self.mysql.query("SELECT 1", int))
db = get_conn(self.db_cls)
self.assertEqual(1, db.query("SELECT 1", int))


@test_each_database
class TestMD5(unittest.TestCase):
def test_md5_as_int(self):
class MD5Dialect(dbs.mysql.Dialect, dbs.mysql.Mixin_MD5):
pass

self.mysql = connect(TEST_MYSQL_CONN_STRING)
self.mysql.dialect = MD5Dialect()

str = "hello world"
query_fragment = self.mysql.dialect.md5_as_int("'{0}'".format(str))
db = get_conn(self.db_cls)

# Check if the database dialect has Mixin_MD5 in its MIXINS

has_md5_mixin = any(issubclass(mixin, AbstractMixin_MD5) for mixin in db.dialect.MIXINS)

if not has_md5_mixin:
self.skipTest(f"{self.db_cls.__name__} does not support MD5")

# Load the MD5 mixin into the dialect
dialect_with_md5 = db.dialect.load_mixins(AbstractMixin_MD5)

str_value = "hello world"
query_fragment = dialect_with_md5.md5_as_int("'{0}'".format(str_value))
query = f"SELECT {query_fragment}"

self.assertEqual(str_to_checksum(str), self.mysql.query(query, int))
self.assertEqual(str_to_checksum(str_value), db.query(query, int))


class TestConnect(unittest.TestCase):
Expand Down Expand Up @@ -111,15 +117,15 @@ def test_correct_timezone(self):

db.query(tbl.create())

tz = pytz.timezone("Europe/Berlin")
tz = pytz.timezone("UTC")

now = datetime.now(tz)
if isinstance(db, dbs.Presto) or isinstance(db, dbs.Dremio):
ms = now.microsecond // 1000 * 1000 # Presto max precision is 3
if isinstance(db, (dbs.Presto, dbs.Trino, dbs.Dremio)):
ms = now.microsecond // 1000 * 1000 # Presto/Trino max precision is 3
now = now.replace(microsecond=ms)

db.query(tbl.insert_row(1, now, now))
if self.db_cls not in [dbs.Dremio]:
if self.db_cls not in [dbs.Dremio, dbs.Presto]:
db.query(db.dialect.set_timezone_to_utc())

t = db.table(tbl).query_schema()
Expand Down
Loading