diff --git a/setup.py b/setup.py index 8a3e1db..c78e5e4 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup setup(name='tap-postgres', - version='0.0.65', + version='0.0.66', description='Singer.io tap for extracting data from PostgreSQL', author='Stitch', url='https://singer.io', diff --git a/tap_postgres/sync_strategies/logical_replication.py b/tap_postgres/sync_strategies/logical_replication.py index 3f2b82c..14ad6b4 100644 --- a/tap_postgres/sync_strategies/logical_replication.py +++ b/tap_postgres/sync_strategies/logical_replication.py @@ -1,21 +1,23 @@ #!/usr/bin/env python3 # pylint: disable=missing-docstring,not-an-iterable,too-many-locals,too-many-arguments,invalid-name,too-many-return-statements,too-many-branches,len-as-condition,too-many-nested-blocks,wrong-import-order,duplicate-code, anomalous-backslash-in-string, too-many-statements, singleton-comparison, consider-using-in -import singer +from functools import reduce +from select import select +import copy +import csv import datetime import decimal +import json +import re + +from dateutil.parser import parse +import psycopg2 +import singer from singer import utils, get_bookmark import singer.metadata as metadata import tap_postgres.db as post_db import tap_postgres.sync_strategies.common as sync_common -from dateutil.parser import parse -import psycopg2 -from psycopg2 import sql -import copy -from select import select -from functools import reduce -import json -import re + LOGGER = singer.get_logger() @@ -65,81 +67,29 @@ def get_stream_version(tap_stream_id, state): return stream_version -def tuples_to_map(accum, t): - accum[t[0]] = t[1] - return accum - -def create_hstore_elem_query(elem): - return sql.SQL("SELECT hstore_to_array({})").format(sql.Literal(elem)) - -def create_hstore_elem(conn_info, elem): - with post_db.open_connection(conn_info) as conn: - with conn.cursor() as cur: - query = create_hstore_elem_query(elem) - cur.execute(query) - res = cur.fetchone()[0] - 
hstore_elem = reduce(tuples_to_map, [res[i:i + 2] for i in range(0, len(res), 2)], {}) - return hstore_elem - -def create_array_elem(elem, sql_datatype, conn_info): +def create_hstore_elem(elem): + array = [(item.replace('"', '').split('=>')) for item in elem] + hstore = {} + for item in array: + if len(item) == 2: + key, value = item + if key in hstore: + raise KeyError('Duplicate key {} found when creating hstore'.format(key)) + if value.lower() == 'null': + value = None + hstore[key] = value + + return hstore + +def create_array_elem(elem): if elem is None: return None - with post_db.open_connection(conn_info) as conn: - with conn.cursor() as cur: - if sql_datatype == 'bit[]': - cast_datatype = 'boolean[]' - elif sql_datatype == 'boolean[]': - cast_datatype = 'boolean[]' - elif sql_datatype == 'character varying[]': - cast_datatype = 'character varying[]' - elif sql_datatype == 'cidr[]': - cast_datatype = 'cidr[]' - elif sql_datatype == 'citext[]': - cast_datatype = 'text[]' - elif sql_datatype == 'date[]': - cast_datatype = 'text[]' - elif sql_datatype == 'double precision[]': - cast_datatype = 'double precision[]' - elif sql_datatype == 'hstore[]': - cast_datatype = 'text[]' - elif sql_datatype == 'integer[]': - cast_datatype = 'integer[]' - elif sql_datatype == 'bigint[]': - cast_datatype = 'bigint[]' - elif sql_datatype == 'inet[]': - cast_datatype = 'inet[]' - elif sql_datatype == 'json[]': - cast_datatype = 'text[]' - elif sql_datatype == 'jsonb[]': - cast_datatype = 'text[]' - elif sql_datatype == 'macaddr[]': - cast_datatype = 'macaddr[]' - elif sql_datatype == 'money[]': - cast_datatype = 'text[]' - elif sql_datatype == 'numeric[]': - cast_datatype = 'text[]' - elif sql_datatype == 'real[]': - cast_datatype = 'real[]' - elif sql_datatype == 'smallint[]': - cast_datatype = 'smallint[]' - elif sql_datatype == 'text[]': - cast_datatype = 'text[]' - elif sql_datatype in ('time without time zone[]', 'time with time zone[]'): - cast_datatype = 'text[]' - elif 
sql_datatype in ('timestamp with time zone[]', 'timestamp without time zone[]'): - cast_datatype = 'text[]' - elif sql_datatype == 'uuid[]': - cast_datatype = 'text[]' - - else: - #custom datatypes like enums - cast_datatype = 'text[]' - - sql_stmt = """SELECT $stitch_quote${}$stitch_quote$::{}""".format(elem, cast_datatype) - cur.execute(sql_stmt) - res = cur.fetchone()[0] - return res + elem = [elem[1:-1]] + reader = csv.reader(elem, delimiter=',', escapechar='\\' , quotechar='"') + array = next(reader) + array = [None if element.lower() == 'null' else element for element in array] + return array #pylint: disable=too-many-branches,too-many-nested-blocks def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info): @@ -166,17 +116,21 @@ def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info): #for ordinary bits, elem will == '1' return elem == '1' or elem == True if sql_datatype == 'boolean': - return elem + return bool(elem) if sql_datatype == 'hstore': - return create_hstore_elem(conn_info, elem) + return create_hstore_elem(elem) if 'numeric' in sql_datatype: - return decimal.Decimal(str(elem)) - if isinstance(elem, int): - return elem - if isinstance(elem, float): - return elem - if isinstance(elem, str): - return elem + return decimal.Decimal(elem) + if sql_datatype == 'money': + return decimal.Decimal(elem[1:]) + if sql_datatype in ('integer', 'smallint', 'bigint'): + return int(elem) + if sql_datatype in ('double precision', 'real', 'float'): + return float(elem) + if sql_datatype in ('text', 'character varying'): + return elem # return as string + if sql_datatype in ('cidr', 'citext', 'json', 'jsonb', 'inet', 'macaddr', 'uuid'): + return elem # return as string raise Exception("do not know how to marshall value of type {}".format(elem.__class__)) @@ -189,7 +143,7 @@ def selected_array_to_singer_value(elem, sql_datatype, conn_info): def selected_value_to_singer_value(elem, sql_datatype, conn_info): #are we dealing with an 
array? if sql_datatype.find('[]') > 0: - cleaned_elem = create_array_elem(elem, sql_datatype, conn_info) + cleaned_elem = create_array_elem(elem) return list(map(lambda elem: selected_array_to_singer_value(elem, sql_datatype, conn_info), (cleaned_elem or []))) return selected_value_to_singer_value_impl(elem, sql_datatype, conn_info) diff --git a/tests/test_logical_replication.py b/tests/test_logical_replication.py new file mode 100644 index 0000000..e4df816 --- /dev/null +++ b/tests/test_logical_replication.py @@ -0,0 +1,110 @@ +from decimal import Decimal +import unittest +from unittest.mock import patch + +from utils import get_test_connection_config +from tap_postgres.sync_strategies import logical_replication + + +class TestHandlingArrays(unittest.TestCase): + def setUp(self): + self.env = patch.dict( + 'os.environ', { + 'TAP_POSTGRES_HOST':'test', + 'TAP_POSTGRES_USER':'test', + 'TAP_POSTGRES_PASSWORD':'test', + 'TAP_POSTGRES_PORT':'5432' + }, + ) + + self.arrays = [ + '{10,01,NULL}', + '{t,f,NULL}', + '{127.0.0.1/32,10.0.0.0/32,NULL}', + '{CASE_INSENSITIVE,case_insensitive,NULL,"CASE,,INSENSITIVE"}', + '{2000-12-31,2001-01-01,NULL}', + '{3.14159265359,3.1415926,NULL}', + '{"\\"foo\\"=>\\"bar\\"","\\"baz\\"=>NULL",NULL}', + '{1,2,NULL}', + '{9223372036854775807,NULL}', + '{198.24.10.0/24,NULL}', + '{"{\\"foo\\":\\"bar\\"}",NULL}', + '{"{\\"foo\\": \\"bar\\"}",NULL}', + '{08:00:2b:01:02:03,NULL}', + '{$19.99,NULL}', + '{19.9999999,NULL}', + '{3.14159,NULL}', + '{0,1,NULL}', + '{foo,bar,NULL,"foo,bar","diederik\'s motel "}', + '{16:38:47,NULL}', + '{"2019-11-19 11:38:47-05",NULL}', + '{123e4567-e89b-12d3-a456-426655440000,NULL}' + ] + + self.sql_datatypes = { + 'bit[]': bool, + 'boolean[]': bool, + 'cidr[]': str, + 'citext[]': str, + 'date[]': str, + 'double precision[]': float, + 'hstore[]': dict, + 'integer[]': int, + 'bigint[]': int, + 'inet[]': str, + 'json[]': str, + 'jsonb[]': str, + 'macaddr[]': str, + 'money[]': Decimal, + 'numeric[]': Decimal, + 
'real[]': float, + 'smallint[]': int, + 'text[]': str, + 'time with time zone[]': str, + 'timestamp with time zone[]': str, + 'uuid[]': str, + } + + def test_create_array_elem(self): + expected_arrays = [ + ['10', '01' ,None], + ['t', 'f', None], + ['127.0.0.1/32', '10.0.0.0/32', None], + ['CASE_INSENSITIVE', 'case_insensitive', None,"CASE,,INSENSITIVE"], + ['2000-12-31', '2001-01-01', None], + ['3.14159265359','3.1415926', None], + ['"foo"=>"bar"', '"baz"=>NULL', None], + ['1','2',None], + ['9223372036854775807', None], + ['198.24.10.0/24', None], + ["{\"foo\":\"bar\"}", None], + ["{\"foo\": \"bar\"}", None], + ['08:00:2b:01:02:03', None], + ['$19.99', None], + ['19.9999999', None], + ['3.14159', None], + ['0','1', None], + ['foo','bar',None,"foo,bar","diederik\'s motel "], + ['16:38:47',None], + ["2019-11-19 11:38:47-05",None], + ['123e4567-e89b-12d3-a456-426655440000', None], + ] + for elem, expected_array in zip(self.arrays, expected_arrays): + array = logical_replication.create_array_elem(elem) + self.assertEqual(array, expected_array) + + def test_selected_value_to_singer_value_impl(self): + with self.env: + conn_info = get_test_connection_config() + for elem, sql_datatype in zip(self.arrays, self.sql_datatypes.keys()): + array = logical_replication.selected_value_to_singer_value(elem, sql_datatype, conn_info) + + for element in array: + python_datatype = self.sql_datatypes[sql_datatype] + if element: + self.assertIsInstance(element, python_datatype) + +if __name__== "__main__": + test1 = TestHandlingArrays() + test1.setUp() + test1.test_selected_value_to_singer_value_impl()