# Copyright 2017 Okera Inc.
from __future__ import absolute_import
import okera
import datetime
import os
import random
import pytz
# TODO: we need to add this to the install dependencies
import certifi
import urllib3
import xml.dom.minidom
from decimal import Context, Decimal
from collections import OrderedDict
from okera._util import get_logger_and_init_null
from okera._thrift_api import (
TGetDatabasesParams, TGetRegisteredObjectsParams, TGetTablesParams, TNetworkAddress,
TPlanRequestParams, TRequestType,
TExecDDLParams, TExecTaskParams, TFetchParams, TListFilesOp, TListFilesParams,
TRecordServiceException, TRecordFormat, TTypeId,
OkeraRecordServicePlanner, RecordServiceWorker)
from okera._thrift_util import (
create_socket, get_transport, TTransportException, TBinaryProtocol,
PlannerClient, WorkerClient, KERBEROS_NOT_ENABLED_MSG, SOCKET_READ_ZERO)
from .concurrency import (BaseBackgroundTask,
ConcurrencyController,
default_max_client_process_count)
_log = get_logger_and_init_null(__name__)
""" Context for this user session."""
[docs]class OkeraContext():
def __init__(self, application_name, tz=pytz.utc):
_log.debug('Creating okera context')
self.__auth = None
self.__service_name = None
self.__token = None
self.__host_override = None
self.__user = None
self.__name = application_name
self.__configure()
self.__tz = tz
def enable_kerberos(self, service_name, host_override=None):
    """Switch this context to kerberos (GSSAPI) based authentication.

    Parameters
    ----------
    service_name : str
        Service principal to authenticate to; typically the first part of
        the 3-part principal (SERVICE_NAME/HOST@REALM).
    host_override : str, optional
        HOST portion of the server's service principal. When omitted, the
        resolved DNS name of the service being connected to is used.

    Returns
    -------
    OkeraContext
        This object, to allow call chaining.
    """
    if not service_name:
        raise ValueError("Service name must be specified.")
    self.__service_name = service_name
    self.__host_override = host_override
    self.__auth = 'GSSAPI'
    self.__user = None
    _log.debug('Enabled kerberos')
    return self
def enable_token_auth(self, token_str=None, token_file=None):
    """Enable token based authentication.

    Exactly one of *token_str* and *token_file* must be supplied.

    Parameters
    ----------
    token_str : str, optional
        Authentication token to use.
    token_file : str, optional
        File containing the token to use.

    Returns
    -------
    OkeraContext
        This object, to allow call chaining.
    """
    if not token_str and not token_file:
        raise ValueError("Must specify token_str or token_file")
    if token_str and token_file:
        raise ValueError("Cannot specify both token_str token_file")
    if token_file:
        path = os.path.expanduser(token_file)
        with open(path, 'r') as t:
            token_str = t.read()
    self.__configure_token(token_str.strip())
    _log.debug('Enabled token auth')
    return self
def disable_auth(self):
    """ Disable authentication entirely, clearing all auth related state.

    Returns
    -------
    OkeraContext
        This object, to allow call chaining.
    """
    self.__user = None
    self.__host_override = None
    self.__service_name = None
    self.__token = None
    self.__auth = None
    _log.debug('Disabled auth')
    return self
[docs] def get_auth(self):
""" Returns the configured auth mechanism. None if no auth is enabled."""
return self.__auth
[docs] def get_token(self):
""" Returns the token string. Note that logging this should be done with care."""
return self.__token
[docs] def get_name(self):
""" Returns name of this application. This is recorded for diagnostics on
the server.
"""
return self.__name
def _get_user(self):
""" Returns the user name. This is ignored if authentication is enabled. """
return self.__user
[docs] def get_timezone(self):
return self.__tz
def connect(self, host='localhost', port=12050, timeout=None):
    """Get a connection to an ODAS cluster's planner service.

    Parameters
    ----------
    host : str or list of hostnames
        The hostname for the planner. If a list is specified, picks a
        planner at random.
    port : int, optional
        The port number for the planner. The default is 12050.
    timeout : int, optional
        Connection timeout in seconds. Default is no timeout.

    Returns
    -------
    PlannerConnection
        Handle to a connection. Users should call `close()` when done.
    """
    host, port = self.__pick_host(host, port)
    # Convert from user facing names to underlying transport names.
    auth_mechanism = self.__get_auth()
    _log.debug('Connecting to planner %s:%s with %s authentication '
               'mechanism', host, port, auth_mechanism)
    sock = create_socket(host, port, timeout, False, None)
    transport = None
    try:
        transport = get_transport(sock, host, auth_mechanism, self.__service_name,
                                  None, None, self.__token, self.__host_override)
        transport.open()
        client = PlannerClient(OkeraRecordServicePlanner,
                               TBinaryProtocol(transport))
        planner = PlannerConnection(_ThriftService(client), self)
        planner.set_application(self.__name)
        return planner
    except (TTransportException, IOError) as e:
        # Close the half-open connection, then translate the transport
        # error into a more actionable message.
        sock.close()
        if transport:
            transport.close()
        self.__handle_transport_exception(e)
        raise e
    except:
        sock.close()
        if transport:
            transport.close()
        raise
def connect_worker(self, host='localhost', port=13050, timeout=None):
    """Get a connection to an ODAS worker.

    Most users should not need to call this API directly.

    Parameters
    ----------
    host : str or list of hostnames
        The hostname for the worker. If a list is specified, picks a worker
        at random.
    port : int, optional
        The port number for the worker. The default is 13050.
    timeout : int, optional
        Connection timeout in seconds. Default is no timeout.

    Returns
    -------
    WorkerConnection
        Handle to a worker connection. Users must call `close()` when done.
    """
    return self._connect_worker(host, port, timeout=timeout)
def _connect_worker(self, host, port, timeout=None, options=None):
    """Internal worker connect that also accepts host-picking *options*
    (e.g. 'PIN_HOST'); see connect_worker() for the public API."""
    host, port = self.__pick_host(host, port, options)
    auth_mechanism = self.__get_auth()
    _log.debug('Connecting to worker %s:%s with %s authentication '
               'mechanism', host, port, auth_mechanism)
    sock = create_socket(host, port, timeout, False, None)
    transport = None
    try:
        transport = get_transport(sock, host, auth_mechanism, self.__service_name,
                                  None, None, self.__token, self.__host_override)
        transport.open()
        client = WorkerClient(RecordServiceWorker, TBinaryProtocol(transport))
        worker = WorkerConnection(_ThriftService(client), self)
        worker.set_application(self.__name)
        return worker
    except (TTransportException, IOError) as e:
        # Close the half-open connection, then translate the transport
        # error into a more actionable message.
        sock.close()
        if transport:
            transport.close()
        self.__handle_transport_exception(e)
        raise e
    except:
        sock.close()
        if transport:
            transport.close()
        raise
@staticmethod
def __pick_host(host, port, options=None):
"""
Returns a host, port from the input. host can be a string or a list of strings.
If it is a list, a host is picked from the list. If the host string contains the
port that port is used, otherwise, the port argument is used.
"""
if not host:
raise ValueError("host must be specified")
if isinstance(host, list):
chosen_host = host[0]
if isinstance(chosen_host, TNetworkAddress):
# With this option, we want to pin a host instead of picking a random one.
if options and 'PIN_HOST' in options:
host.sort(key = lambda v: v.hostname)
chosen_host = host[0]
else:
chosen_host = random.choice(host)
host = chosen_host.hostname
port = chosen_host.port
elif isinstance(chosen_host, str):
if options and 'PIN_HOST' in options:
host.sort()
host = host[0]
else:
host = random.choice(host)
host = chosen_host
else:
raise ValueError("host list must be TNetworkAddress objects or strings.")
if isinstance(host, str):
parts = host.split(':')
if len(parts) == 2:
host = parts[0]
port = int(parts[1])
elif len(parts) == 1:
host = parts[0]
if port is None:
raise ValueError("port must be specified")
else:
raise ValueError("Invalid host: %s " % host)
else:
raise ValueError("Invalid host: %s" % host)
return host, port
def __configure(self):
    """ Configure the context from system wide settings (~/.cerebro/token)."""
    token_file = os.path.join(os.path.expanduser("~"), '.cerebro', 'token')
    if not os.path.exists(token_file):
        return
    # TODO: we could catch this exception and go on but having this file be
    # messed up here is likely something to fix ASAP.
    with open(token_file, 'r') as t:
        self.__configure_token(t.read().strip())
    _log.info("Configured token auth with token in home directory.")
def __configure_token(self, token):
# Valid authentication tokens contain '.' in them, either an Okera token or a JWT
# token. For API convenience, we use the token value to mean user (plain text)
# when run against unauthenticated servers.
if '.' in token:
self.__token = token
self.__auth = 'TOKEN'
self.__service_name = 'cerebro'
self.__user = None
else:
self.__token = None
self.__auth = None
self.__user = token
self.__host_override = None
def __handle_transport_exception(self, e):
    """ Map transport layer exceptions to better user facing ones, then
    re-raise."""
    if e.message == SOCKET_READ_ZERO:
        # A zero-byte read during the handshake means the two sides
        # disagree on whether authentication is enabled.
        if self.__auth:
            e.message = ("Server did not respond to authentication handshake. "
                         "Ensure server has authentication enabled.")
        else:
            e.message = ("Client does not have authentication enabled but it "
                         "appears the server does. Enable client authentication.")
    elif self.__auth == 'GSSAPI' and KERBEROS_NOT_ENABLED_MSG in e.message:
        e.message = ("Client is authenticating with kerberos but kerberos is "
                     "not enabled on the server.")
    raise e
def __get_auth(self):
""" Canonicalizes user facing auth names to transport layer ones """
auth_mechanism = self.__auth
if not auth_mechanism:
auth_mechanism = 'NOSASL'
if auth_mechanism == 'TOKEN':
auth_mechanism = 'DIGEST-MD5'
return auth_mechanism
class _ThriftService():
""" Wrapper around a thrift service client object """
def __init__(self, thrift_client, retries=3):
self.client = thrift_client
self.retries = retries
def close(self):
# pylint: disable=protected-access
_log.debug('close_service: client=%s', self.client)
self.client._iprot.trans.close()
def reconnect(self):
# pylint: disable=protected-access
_log.debug('reconnect: client=%s', self.client)
self.client._iprot.trans.close()
self.client._iprot.trans.open()
class PlannerConnection():
    """A connection to an ODAS planner. """

    def __init__(self, thrift_service, ctx):
        self.service = thrift_service
        self.ctx = ctx
        # Pool used to fetch file contents over HTTPS with verified certs.
        self.http_pool = urllib3.PoolManager(cert_reqs='CERT_REQUIRED',
                                             ca_certs=certifi.where())
        _log.debug('PlannerConnection(service=%s)', self.service)

    def __enter__(self):
        # Support `with ctx.connect() as conn:` usage.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
[docs] def close(self):
"""Close the session and server connection."""
_log.debug('Closing Planner connection')
self.service.close()
def _reconnect(self):
self.service.reconnect()
def _underlying_client(self):
""" Returns the underlying thrift client. Exposed for internal use. """
return self.service.client
[docs] def get_protocol_version(self):
"""Returns the RPC API version of the server."""
return self.service.client.GetProtocolVersion()
[docs] def set_application(self, name):
"""Sets the name of this session. Used for logging purposes on the server."""
self.service.client.SetApplication(name)
def ls(self, path):
    """ List the files in this directory.

    Parameters
    ----------
    path : str
        The path to list.

    Returns
    -------
    list(str)
        List of files located at this path.
    """
    if not path:
        raise ValueError("path must be specified.")
    params = TListFilesParams()
    params.op = TListFilesOp.LIST
    params.object = path
    user = self.ctx._get_user()
    if user:
        params.requesting_user = user
    return self.service.client.ListFiles(params).files
def open(self, path, preload_content=True, version=None):
    """ Return the object at this path as a byte stream.

    Parameters
    ----------
    path : str
        The path to the file to open.
    preload_content : bool, optional
        If True (default), eagerly read the whole response body.
    version : str, optional
        Specific version of the object to read, when the store supports it.

    Returns
    -------
    object
        An object that behaves like an opened urllib3 stream.
    """
    if not path:
        raise ValueError("path must be specified.")
    params = TListFilesParams()
    params.op = TListFilesOp.READ
    params.object = path
    params.version_id = version
    if self.ctx._get_user():
        params.requesting_user = self.ctx._get_user()
    try:
        urls = self.service.client.ListFiles(params).files
        # Fix: an empty URL list previously fell through to urls[0] and
        # surfaced as an opaque IndexError.
        if not urls or len(urls) != 1:
            raise ValueError(
                "Unexpected result from server. Expecting at most one url.")
        return self.http_pool.request('GET', urls[0],
                                      preload_content=preload_content)
    except TRecordServiceException as ex:
        if not ex.detail.startswith('AuthorizationException'):
            raise ex
        # This request (for the path) failed with an authorization exception,
        # meaning this user does not have full access to the path. Try to see
        # if this user has access to a table over this path.
        objs = self.get_catalog_objects_at(path, True)
        if not objs or path not in objs:
            raise ex
        for obj in objs[path]:
            if '.' in obj:
                return OkeraFsStream(self, obj)
        # No object found, raise original exception
        raise ex
[docs] def cat(self, path, as_utf8=True):
""" Returns the object at this path as a string
Parameters
----------
path : str
The path to the file to read.
as_utf8 : bool
If true, convert the returned data as a utf-8 string (instead of binary)
Returns
-------
str
Returns the contents at the path as a string.
"""
result = self.open(path)
if result.status != 200:
if 'Content-Type' in result.headers:
if result.headers['Content-Type'] == 'application/xml':
msg = result.data.decode('utf-8').replace('\n', '')
tree = xml.dom.minidom.parseString(msg).toprettyxml(indent=' ')
raise ValueError("Could not read from path: %s\n\n%s" % (path, tree))
raise ValueError("Could not read from path: %d" % result.status)
if not result.data:
if as_utf8:
return b""
return ""
# Check the types to avoid some double serialization
if isinstance(result.data, str):
if as_utf8:
# Both UTF-8
if result.data.endswith('\n'):
return result.data[:-1]
return result.data
else:
if result.data.endswith('\n'):
return result.data[:-1].encode('utf-8')
return result.data.encode('utf-8')
# Result is binary
if result.data.endswith(b'\n'):
if as_utf8:
return result.data[:-1].decode('utf-8')
return result.data[:-1]
if as_utf8:
return result.data.decode('utf-8')
else:
return result.data
def get_catalog_objects_at(self, path_prefix, include_views=False):
    """ Return the objects (databases or datasets) registered with this
    prefix path.

    Parameters
    ----------
    path_prefix : str
        The path prefix to look up objects defined with this prefix.
    include_views : bool
        If True, also return views at this path.

    Returns
    -------
    map(str, list(str))
        For each path with catalog objects, the list of objects located at
        that path. Empty map if there are none.
    """
    if not path_prefix:
        raise ValueError("path_prefix must be specified.")
    params = TGetRegisteredObjectsParams()
    params.prefix_path = path_prefix
    params.include_views = include_views
    user = self.ctx._get_user()
    if user:
        params.requesting_user = user
    return self.service.client.GetRegisteredObjects(params).object_names
def list_databases(self):
    """List all the databases in the catalog.

    Returns
    -------
    list(str)
        List of database names.

    Examples
    --------
    >>> import okera
    >>> ctx = okera.context()
    >>> with ctx.connect(host = 'localhost', port = 12050) as conn:
    ...     dbs = conn.list_databases()
    ...     'okera_sample' in dbs
    True
    """
    request = TGetDatabasesParams()
    user = self.ctx._get_user()
    if user:
        request.requesting_user = user
    result = self.service.client.GetDatabases(request)
    # Each entry's name is a list; its first element is the database name.
    return [db.name[0] for db in result.databases]
def list_dataset_names(self, db, filter=None):
    """ Return the names of the datasets in this db.

    Parameters
    ----------
    db : str
        Name of database to return datasets in.
    filter : str, optional
        Substring filter on names of datasets to return.

    Returns
    -------
    list(str)
        Fully qualified dataset names.

    Examples
    --------
    >>> import okera
    >>> ctx = okera.context()
    >>> with ctx.connect(host = 'localhost', port = 12050) as conn:
    ...     datasets = conn.list_dataset_names('okera_sample')
    ...     datasets
    ['okera_sample.sample', 'okera_sample.users', 'okera_sample.users_ccn_masked', 'okera_sample.whoami']
    """
    request = TGetTablesParams()
    request.database = [db]
    request.filter = filter
    user = self.ctx._get_user()
    if user:
        request.requesting_user = user
    tables = self.service.client.GetTables(request).tables
    return [db + '.' + t.name for t in tables]
def list_datasets(self, db, filter=None):
    """ Return the datasets in this db as thrift objects.

    Parameters
    ----------
    db : str
        Name of database to return datasets in.
    filter : str, optional
        Substring filter on names of datasets to return.

    Returns
    -------
    obj
        Thrift dataset objects.

    Note
    ----
    This API is subject to change and the returned object may not be
    backwards compatible.
    """
    request = TGetTablesParams()
    request.database = [db]
    request.filter = filter
    user = self.ctx._get_user()
    if user:
        request.requesting_user = user
    return self.service.client.GetTables(request)
def plan(self, request, max_task_count=None, requesting_user=None):
    """ Plan the request to read from CDAS.

    Parameters
    ----------
    request : str, required
        Name of dataset or SQL statement to plan a scan for.
    max_task_count : int, optional
        Upper bound on the number of tasks to generate.
    requesting_user : str, optional
        Name of user to request the plan for, if different from the current
        user.

    Returns
    -------
    object
        Thrift serialized plan object.

    Note
    ----
    This API is subject to change and the returned object may not be
    backwards compatible.
    """
    if not request:
        raise ValueError("request must be specified.")
    params = TPlanRequestParams()
    params.request_type = TRequestType.Sql
    if max_task_count:
        params.max_tasks = max_task_count
    if requesting_user:
        params.requesting_user = requesting_user
    elif self.ctx._get_user():
        params.requesting_user = self.ctx._get_user()
    request = request.strip()
    if request.lower().startswith('select '):
        # Raw SQL: plan it as-is.
        _log.debug('Planning request for query: %s', request)
        params.sql_stmt = request
    else:
        # Dataset name: wrap it in a full-table scan.
        _log.debug('Planning request to read dataset: %s', request)
        params.sql_stmt = "SELECT * FROM " + request
    plan = self.service.client.PlanRequest(params)
    _log.debug('Plan complete. Number of tasks: %d', len(plan.tasks))
    return plan
def execute_ddl(self, sql):
    # pylint: disable=line-too-long
    """ Execute a DDL statement against the server.

    Parameters
    ----------
    sql : str
        DDL statement to run.

    Returns
    -------
    list(list(str))
        The result as a table.

    Examples
    --------
    >>> import okera
    >>> ctx = okera.context()
    >>> with ctx.connect(host = 'localhost', port = 12050) as conn:
    ...     result = conn.execute_ddl('describe okera_sample.users')
    ...     result
    [['uid', 'string', 'Unique user id'], ['dob', 'string', 'Formatted as DD-month-YY'], ['gender', 'string', ''], ['ccn', 'string', 'Sensitive data, should not be accessible without masking.']]
    """
    # pylint: enable=line-too-long
    if not sql:
        raise ValueError("Must specify sql string to execute_ddl")
    request = TExecDDLParams()
    request.ddl = sql
    user = self.ctx._get_user()
    if user:
        request.requesting_user = user
    return self.service.client.ExecuteDDL2(request).tabular_result
def execute_ddl_table_output(self, sql):
    """ Execute a DDL statement against the server, returning a rendered
    table.

    Parameters
    ----------
    sql : str
        DDL statement to run.

    Returns
    -------
    PrettyTable
        The result as a table object, or None when the server returned no
        columns.
    """
    # Imported lazily so prettytable is only required by this API.
    from prettytable import PrettyTable
    if not sql:
        raise ValueError("Must specify sql string to execute_ddl")
    request = TExecDDLParams()
    request.ddl = sql
    user = self.ctx._get_user()
    if user:
        request.requesting_user = user
    response = self.service.client.ExecuteDDL2(request)
    if not response.col_names:
        return None
    table = PrettyTable(response.col_names)
    for row in response.tabular_result:
        table.add_row(row)
    return table
def scan_as_pandas(self,
                   request,
                   max_records=None,
                   max_client_process_count=default_max_client_process_count(),
                   max_task_count=None,
                   requesting_user=None,
                   options=None,
                   ignore_errors=False,
                   warnings=None,
                   strings_as_utf8=False):
    """Scan data, returning the result as a pandas DataFrame.

    Parameters
    ----------
    request : string, required
        Name of dataset or SQL statement to scan.
    max_records : int, optional
        Maximum number of records to return. Default is unlimited.
    options : dictionary, optional
        Optional key/value configs to specify to the request. Note that
        these options are not guaranteed to be backwards compatible.
    warnings : list(string), optional
        If not None, will be populated with any warnings generated for the
        request.

    Returns
    -------
    pandas DataFrame
        Data returned as a pandas DataFrame object.

    Examples
    --------
    >>> import okera
    >>> ctx = okera.context()
    >>> with ctx.connect(host = 'localhost', port = 12050) as conn:
    ...     pd = conn.scan_as_pandas('select * from okera_sample.sample')
    ...     print(pd)
                                record
    0     b'This is a sample test file.'
    1  b'It should consist of two lines.'
    """
    import pandas
    plan = self.plan(request,
                     max_task_count=max_task_count,
                     requesting_user=requesting_user)
    self._ensure_serialization_support(plan)
    # Surface planner warnings if the caller asked for them.
    if warnings is not None and plan.warnings:
        warnings.extend(w.message for w in plan.warnings)
    concurrency_ctl = self._get_concurrency_controller_for_plan(
        plan, max_client_process_count)
    for task in plan.tasks:
        _log.debug('Executing task %s', str(task.task_id))
        concurrency_ctl.enqueueTask(PandasScanTask(self.ctx,
                                                   plan.hosts,
                                                   task,
                                                   max_records,
                                                   options,
                                                   strings_as_utf8))
    result_list = self._start_and_wait_for_results(concurrency_ctl,
                                                   len(plan.tasks),
                                                   limit=max_records,
                                                   is_pandas=True,
                                                   ignore_errors=ignore_errors)
    if not result_list:
        # No rows: return an empty frame that still carries the schema.
        return pandas.DataFrame(columns=[c.name for c in plan.schema.cols])
    return pandas.concat(result_list).head(max_records)
def scan_as_json(self,
                 request,
                 max_records=None,
                 warnings=None,
                 max_client_process_count=default_max_client_process_count(),
                 max_task_count=None,
                 requesting_user=None,
                 ignore_errors=False):
    # pylint: disable=line-too-long
    """Scan data, returning the result in json format.

    Parameters
    ----------
    request : string, required
        Name of dataset or SQL statement to scan.
    max_records : int, optional
        Maximum number of records to return. Default is unlimited.
    warnings : list(string), optional
        If not None, will be populated with any warnings generated for the
        request.

    Returns
    -------
    list(obj)
        Data returned as a list of JSON objects.

    Examples
    --------
    >>> import okera
    >>> ctx = okera.context()
    >>> with ctx.connect(host = 'localhost', port = 12050) as conn:
    ...     data = conn.scan_as_json('okera_sample.sample')
    ...     data
    [{'record': 'This is a sample test file.'}, {'record': 'It should consist of two lines.'}]
    """
    # pylint: enable=line-too-long
    plan = self.plan(request,
                     max_task_count=max_task_count,
                     requesting_user=requesting_user)
    self._ensure_serialization_support(plan)
    concurrency_ctl = self._get_concurrency_controller_for_plan(
        plan, max_client_process_count)
    # Surface planner warnings if the caller asked for them.
    if warnings is not None and plan.warnings:
        warnings.extend(w.message for w in plan.warnings)
    if not plan.tasks:
        return []
    for task in plan.tasks:
        _log.debug('Executing task %s', str(task.task_id))
        concurrency_ctl.enqueueTask(JsonScanTask(self.ctx,
                                                 plan.hosts,
                                                 task,
                                                 max_records))
    res = self._start_and_wait_for_results(concurrency_ctl,
                                           len(plan.tasks),
                                           limit=max_records,
                                           ignore_errors=ignore_errors)
    return res if max_records is None else res[:max_records]
@staticmethod
def _get_concurrency_controller_for_plan(plan, max_client_process_count):
    """Build a controller with one worker per task, capped at the client
    process limit."""
    return ConcurrencyController(
        worker_count=min(max_client_process_count, len(plan.tasks)))
@staticmethod
def _calculate_limit(current_limit, results, is_pandas):
if current_limit is None:
return None
if is_pandas:
if len(results) == 0:
return current_limit
return max(current_limit-len(results[0].index), 0)
return max(current_limit-len(results), 0)
@staticmethod
def _start_and_wait_for_results(concurrency_ctl,
                                task_count,
                                limit=None,
                                is_pandas=False,
                                ignore_errors=False):
    """Run the enqueued tasks and gather their results.

    Blocks until all *task_count* tasks have reported a result on the
    controller's output queue, shrinking the shared `limit` as records
    arrive. Raises the first queued task error unless *ignore_errors* is
    set; always stops the controller before returning.
    """
    results = []
    if not task_count:
        return results
    try:
        task_result_count = 0
        is_completed = False
        # We're setting the limit value into a shared dict that is handed
        # to all async tasks. The main thread (this function) is responsible
        # for updating the value as records are received. The async tasks
        # will read this `limit` value from the dict and pass it as the
        # `max_records` param to the worker. As records are received in the
        # main thread, the `limit` value will decrease until it is zero, at
        # which point the remaining async tasks will immediately return
        # with empty results.
        concurrency_ctl.metrics_dict['limit'] = limit
        # All default values need to be set prior to calling start(). This is
        # because there will be a delay when setting any shared value between
        # processes because the values are cached in each process.
        concurrency_ctl.start()
        while not is_completed:
            # Blocks until some task finishes (None results still count).
            res = concurrency_ctl.output_queue.get()
            if res is not None:
                task_result_count += 1
                limit = PlannerConnection._calculate_limit(limit, res, is_pandas)
                results.extend(res)
            if task_result_count == task_count:
                is_completed = True
            # Publish the reduced limit so pending tasks fetch less.
            concurrency_ctl.metrics_dict['limit'] = limit
        if not ignore_errors and concurrency_ctl.errors_queue.qsize():
            # Drain and print every error, then raise the last one seen.
            print('One or more errors occurred while processing this query:')
            while concurrency_ctl.errors_queue.qsize():
                err = concurrency_ctl.errors_queue.get()
                print('{0}'.format(err))
            raise err
    finally:
        concurrency_ctl.stop()
    return results
@staticmethod
def _ensure_serialization_support(plan):
    """Raise IOError unless the server can return ColumnarNumPy results."""
    formats = plan.supported_result_formats
    if not formats or TRecordFormat.ColumnarNumPy not in formats:
        raise IOError("PyOkera requires the server to support the " +
                      "`ColumnarNumPy` serialization format. Please upgrade the " +
                      "server to at least 0.8.1.")
class WorkerConnection():
    """A connection to a CDAS worker. """

    def __init__(self, thrift_service, ctx):
        self.service = thrift_service
        self.ctx = ctx
        _log.debug('WorkerConnection(service=%s)', self.service)

    def __enter__(self):
        # Support `with ctx.connect_worker() as worker:` usage.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the session and server connection."""
        _log.debug('Closing Worker connection')
        self.service.close()

    def _reconnect(self):
        # Re-establish the underlying thrift transport.
        self.service.reconnect()

    def get_protocol_version(self):
        """Return the RPC API version of the server."""
        return self.service.client.GetProtocolVersion()

    def set_application(self, name):
        """Set the name of this session. Used for logging purposes on the server."""
        self.service.client.SetApplication(name)

    def exec_task(self, task, max_records=None):
        """ Start executing a task that scans records.

        Parameters
        ----------
        task : obj
            Description of task. This is the result from the planner's
            plan() call.
        max_records : int, optional
            Maximum number of records to return for this task. Default is
            unlimited.

        Returns
        -------
        object
            Handle for this task. Used in subsequent API calls.
        object
            Schema for records returned from this task.
        """
        request = TExecTaskParams()
        request.task = task.task
        request.limit = max_records
        # Records come back in batches of this size.
        request.fetch_size = 20000
        request.record_format = TRecordFormat.ColumnarNumPy
        result = self.service.client.ExecTask(request)
        return result.handle, result.schema

    def close_task(self, handle):
        """ Close the task identified by *handle*. """
        self.service.client.CloseTask(handle)

    def fetch(self, handle):
        """ Fetch the next batch of records for this task. """
        request = TFetchParams()
        request.handle = handle
        return self.service.client.Fetch(request)
def _columnar_batch_to_python(schema, columnar_records, num_records,
                              ctx_tz=pytz.utc, strings_as_utf8=False):
    """Convert one thrift columnar record batch into python column arrays.

    Parameters
    ----------
    schema : obj
        Thrift schema describing the columns (names and types).
    columnar_records : obj
        Thrift columnar batch: per-column packed data plus a null bitmap.
    num_records : int
        Number of records in this batch.
    ctx_tz : tzinfo, optional
        Timezone applied when materializing timestamps.
    strings_as_utf8 : bool, optional
        If True, decode string columns to str instead of returning bytes.

    Returns
    -------
    (col_names, data, any_nulls, is_nulls)
        Deduplicated column names, per-column value arrays, per-column
        "has any null" flags, and per-column null masks.
    """
    # Issues with numpy, thrift and this function being perf optimized
    # pylint: disable=no-member
    # pylint: disable=protected-access
    # pylint: disable=too-many-locals
    import numpy
    cols = columnar_records.cols
    # Things we will return.
    col_names = []
    # Checks if any of the values in this batch are null. Handling NULL can be
    # noticeably slower, so skip it in bulk if possible.
    any_nulls = []
    is_nulls = [None] * len(cols)
    data = [None] * len(cols)
    # For each column seen, the index to append to it. Empty means nothing to
    # append. The planner does not need to generate unique column names in all
    # cases. e.g. 'select c1, c1 from t' will generate two columns called 'c1'.
    # We need to dedup here as we put the columns in a dictionary. In this case
    # we will name the second "c1_2".
    col_names_dedup = {}
    # Go over each column and convert the binary data to python objects. This
    # is very perf sensitive.
    for col in range(0, len(cols)):
        buf = cols[col].data
        if isinstance(buf, str):
            buf = buf.encode()
        name = schema.cols[col].name
        if name not in col_names_dedup:
            col_names_dedup[name] = 2
        else:
            # Keep resolving to dedup
            while name in col_names_dedup:
                idx = col_names_dedup[name]
                col_names_dedup[name] = idx + 1
                name += '_' + str(idx)
        col_names.append(name)
        # Fix: numpy.bool was deprecated in numpy 1.20 and removed in 1.24;
        # numpy.bool_ is the portable spelling of the boolean scalar type.
        is_null = numpy.frombuffer(cols[col].is_null.encode(),
                                   dtype=numpy.bool_)
        any_nulls.append(numpy.any(is_null))
        is_nulls[col] = is_null
        t = schema.cols[col].type.type_id
        if t == TTypeId.STRING or t == TTypeId.VARCHAR:
            # Layout: num_records int32 lengths, then the packed bytes.
            off = 4 * num_records
            column = [numpy.nan] * num_records
            lens = numpy.frombuffer(buf[0: off], dtype=numpy.int32)
            if any_nulls[col]:
                for i in range(0, num_records):
                    if not is_null[i]:
                        length = lens[i]
                        column[i] = buf[off:off + length]
                        if strings_as_utf8:
                            column[i] = column[i].decode('utf-8')
                        off += length
            else:
                for i in range(0, num_records):
                    length = lens[i]
                    column[i] = buf[off:off + length]
                    if strings_as_utf8:
                        column[i] = column[i].decode('utf-8')
                    off += length
            if strings_as_utf8:
                data[col] = column
            else:
                data[col] = numpy.array(column, dtype=object)
        elif t == TTypeId.CHAR:
            # Fixed width; no length prefixes.
            off = 0
            column = [numpy.nan] * num_records
            length = schema.cols[col].type.len
            if any_nulls[col]:
                for i in range(0, num_records):
                    if not is_null[i]:
                        column[i] = buf[off:off + length]
                        off += length
            else:
                for i in range(0, num_records):
                    column[i] = buf[off:off + length]
                    off += length
            if strings_as_utf8:
                data[col] = ''.join(column)
            else:
                data[col] = numpy.array(column, dtype=object)
        elif t == TTypeId.BOOLEAN:
            data[col] = numpy.frombuffer(buf, dtype=numpy.bool_)
        elif t == TTypeId.TINYINT:
            data[col] = numpy.frombuffer(buf, dtype=numpy.int8)
        elif t == TTypeId.SMALLINT:
            data[col] = numpy.frombuffer(buf, dtype=numpy.int16)
        elif t == TTypeId.INT:
            data[col] = numpy.frombuffer(buf, dtype=numpy.int32)
        elif t == TTypeId.BIGINT:
            data[col] = numpy.frombuffer(buf, dtype=numpy.int64)
        elif t == TTypeId.FLOAT:
            data[col] = numpy.frombuffer(buf, dtype=numpy.float32)
        elif t == TTypeId.DOUBLE:
            data[col] = numpy.frombuffer(buf, dtype=numpy.float64)
        elif t == TTypeId.TIMESTAMP_NANOS:
            # Stored as (millis int64, nanos int32) pairs.
            dt = numpy.dtype([('millis', numpy.int64), ('nanos', numpy.int32)])
            values = numpy.frombuffer(buf, dtype=dt)
            millis = values['millis']
            column = [numpy.nan] * num_records
            for i in range(0, num_records):
                if not is_null[i]:
                    # TODO: use nanos?
                    column[i] = datetime.datetime.fromtimestamp(
                        millis[i] / 1000.0, ctx_tz)
            data[col] = column
        elif t == TTypeId.DECIMAL:
            column = [numpy.nan] * num_records
            scale = -schema.cols[col].type.scale
            precision = schema.cols[col].type.precision
            if precision <= 18:
                # Fits in a single int32/int64 per value.
                if precision <= 9:
                    values = numpy.frombuffer(buf, dtype=numpy.int32)
                else:
                    values = numpy.frombuffer(buf, dtype=numpy.int64)
                for i in range(0, num_records):
                    if not is_null[i]:
                        column[i] = Decimal(int(values[i])).scaleb(scale)
            else:
                # These decimals are stored as up to 128 bits with two longs
                # back to back. This needs to be reconstructed and we want to
                # compute:
                #    v = longs[i*2+1] << 64 + longs[i*2]
                # This is done carefully to avoid overflow.
                ctx = Context(precision)
                multiple = ctx.power(2, 64)
                longs = numpy.frombuffer(buf, dtype=numpy.int64)
                for i in range(0, num_records):
                    if is_null[i]:
                        continue
                    v = Decimal(int(longs[i * 2 + 1])) * multiple + longs[i * 2]
                    column[i] = v.scaleb(scale)
            data[col] = column
        else:
            raise RuntimeError("Unsupported type: " + TTypeId._VALUES_TO_NAMES[t])
    return col_names, data, any_nulls, is_nulls
def context(application_name=None):
    """ Get the top level context object used to interact with pyokera.

    Parameters
    ----------
    application_name : str, optional
        Name of this application. Used for logging and diagnostics; a
        default including the library version is used when omitted.

    Returns
    -------
    OkeraContext
        Context object.

    Examples
    --------
    >>> import okera
    >>> ctx = okera.context()
    >>> ctx # doctest: +ELLIPSIS
    <okera.odas.OkeraContext object at 0x...>
    """
    name = application_name or 'pyokera (%s)' % version()
    return OkeraContext(name)
def version():
    """ Returns version string of this library. """
    # Imported lazily to avoid a circular import at package load time.
    from . import __version__
    return __version__
class ScanTask(BaseBackgroundTask):
    """ Background task that executes a single scan task against a worker and
    accumulates the deserialized results.

    Subclasses implement ``deserialize`` to convert each columnar record
    batch into their output representation (e.g. dicts, DataFrames).
    """
    def __init__(self, name, ctx, plan_hosts, task, max_records, options):
        BaseBackgroundTask.__init__(self, "ScanTask.{0}".format(name))
        self.ctx = ctx                  # context used to connect to workers
        self.plan_hosts = plan_hosts    # candidate worker hosts from the plan
        self.task = task                # task to execute on the worker
        self.max_records = max_records  # cap on records to fetch, or None
        self.options = options          # worker connection options, or None
        self.errors = []                # exceptions captured by __call__

    def __call__(self):
        """ Runs the task to completion and returns the accumulated results.

        Exceptions raised while executing or fetching are captured in
        ``self.errors`` rather than propagated, so callers must inspect
        ``errors`` to detect failures.
        """
        results = []
        total = 0
        if self.max_records is not None and self.max_records <= 0:
            return results
        with self.ctx._connect_worker(self.plan_hosts, None, options=self.options) as worker:
            # Initialize handle up front: if exec_task itself raises, the
            # finally clause below must not trip over an unbound local
            # (previously that UnboundLocalError masked the real failure).
            handle = None
            try:
                handle, schema = worker.exec_task(self.task, self.max_records)
                while True:
                    fetch_result = worker.fetch(handle)
                    assert fetch_result.record_format == TRecordFormat.ColumnarNumPy
                    if fetch_result.num_records:
                        t_results = self.deserialize(schema,
                                                     fetch_result.columnar_records,
                                                     fetch_result.num_records)
                        if t_results:
                            results.extend(t_results)
                        total += fetch_result.num_records
                    if fetch_result.done or (self.max_records and total >= self.max_records):
                        break
            except Exception as ex:
                self.errors.append(ex)
            finally:
                # Only close a task that was successfully started.
                if handle is not None:
                    worker.close_task(handle)
        return results

    def deserialize(self, schema, columnar_records, num_records):
        '''Abstract definition to deserialize the returned dataset'''
        raise Exception('Invalid invocation of an abstract function: ' +
                        'BaseBackgroundTask::deserialize')
class JsonScanTask(ScanTask):
    """ Scan task that materializes each record as a python dict keyed by
    column name (a JSON-friendly representation). """
    def __init__(self, ctx, plan_hosts, task, max_records):
        ScanTask.__init__(self, "JsonScanTask", ctx, plan_hosts, task,
                          max_records, None)

    def deserialize(self, schema, columnar_records, num_records):
        """ Converts one columnar batch into a list of per-row dicts. """
        names, cols, _, nulls = _columnar_batch_to_python(
            schema, columnar_records, num_records, self.ctx.get_timezone())
        rows = []
        # Transpose the columnar batch into one dict per row. Null cells
        # stay None; byte strings are decoded as utf-8.
        for r in range(num_records):
            values = []
            for c, col in enumerate(cols):
                if nulls[c][r]:
                    values.append(None)
                else:
                    datum = col[r]
                    values.append(
                        datum.decode('utf-8') if isinstance(datum, bytes) else datum)
            rows.append(dict(zip(names, values)))
        return rows
class PandasScanTask(ScanTask):
    """ Scan task that materializes each record batch as a pandas DataFrame. """
    def __init__(self, ctx, plan_hosts, task, max_records, options, strings_as_utf8):
        ScanTask.__init__(self, "PandasScanTask", ctx, plan_hosts, task,
                          max_records, options)
        self.__strings_as_utf8 = strings_as_utf8

    def deserialize(self, schema, columnar_records, num_records):
        """ Converts one columnar batch into a single-element list containing
        a pandas DataFrame. """
        import numpy
        import pandas
        names, cols, any_nulls, nulls = _columnar_batch_to_python(
            schema, columnar_records, num_records, self.ctx.get_timezone(),
            self.__strings_as_utf8)
        frame = pandas.DataFrame(OrderedDict(zip(names, cols)))
        if len(frame):
            for idx, name in enumerate(names):
                series = frame[name]
                if not any_nulls[idx] or series.dtype == 'object':
                    # Either no nulls, or objects are already handled.
                    continue
                if isinstance(series[0], str):
                    continue
                # Fix up nulls, replace with nan
                # TODO: this is not the cheapest
                frame[name] = series.where(~nulls[idx], other=numpy.nan)
        return [frame]
class OkeraFsStream():
    """ Wrapper object which behaves like a stream to send serialized results back
    in a byte stream based API. The API is intended to be compatible with a
    urllib stream object. """
    def __init__(self, planner, tbl, delimiter=',', quote_strings=True):
        # TODO: this needs to stream the result instead of all at once
        # NOTE(review): delimiter and quote_strings are currently unused —
        # confirm whether they should be forwarded to to_csv.
        self.planner = planner
        self.tbl = tbl
        self.status = 200
        self.headers = {}
        # Eagerly scans the whole table and serializes it to CSV up front.
        self.data = planner.scan_as_pandas(
            tbl, max_task_count=1, strings_as_utf8=True).to_csv(
            None, header=False, index=False)

    def read(self, amt=None):
        """ Returns the buffered data.

        The first call returns the entire buffered payload; subsequent calls
        return an empty string so read-until-EOF consumers terminate.
        (Previously every call returned the full payload, making such loops
        spin forever.) ``amt`` is accepted for stream-API compatibility but
        ignored since all data is already buffered.
        """
        data, self.data = self.data, ""
        return data
Binary = memoryview