Skip to content

Commit

Permalink
修复odps查询不可用问题, 实例测试无效问题, 添加odps测试用例 (#1454)
Browse files Browse the repository at this point in the history
* 修复odps查询不可用问题, 实例测试无效问题

* odps 添加测试
  • Loading branch information
cpzt authored Apr 10, 2022
1 parent d40f48b commit 9709c56
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 4 deletions.
41 changes: 37 additions & 4 deletions sql/engines/odps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import re
import logging
import sqlparse

from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult
from .models import ResultSet

from odps import ODPS

Expand Down Expand Up @@ -37,16 +38,24 @@ def info(self):

def get_all_databases(self):
    """Return the database list as a ResultSet.

    ODPS only has the concept of a project, so the project name is
    returned directly.
    TODO: the ODPS API for listing all projects is currently slow, so only
    a single project is returned for now; optimize later.

    :return: ResultSet whose ``rows`` holds the connection's project name
        on success. On any failure ``rows`` is left untouched and ``error``
        carries the exception text.
    """
    result = ResultSet()

    try:
        conn = self.get_connection()

        # Check that the configured project exists
        db_exist = conn.exist_project(self.instance.db_name)

        if db_exist is False:
            raise ValueError(f"[{self.instance.db_name}]项目不存在")

        result.rows = [conn.project]
    except Exception as e:
        # Record the failure on the result instead of reporting the project
        # as available.
        logger.warning(f"ODPS执行异常, {e}")
        result.error = str(e)
    return result

def get_all_tables(self, db_name, **kwargs):
Expand Down Expand Up @@ -126,3 +135,27 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
result_set.error = str(e)
return result_set

def query_check(self, db_name=None, sql=''):
    """Check a query before execution and reduce it to a single statement.

    Comments are stripped, only the first statement is kept, and that
    statement must start with a whitelisted keyword (currently only
    ``select``, matched case-insensitively).

    :param db_name: not used by this method.
    :param sql: raw SQL text; ``None`` or an empty string is reported as
        containing no valid statement.
    :return: dict with keys ``msg`` (reason when rejected, otherwise ``''``),
        ``bad_query`` (``True`` when rejected), ``filtered_sql`` (first
        statement with comments removed, or the input text when no statement
        was found) and ``has_star`` (always ``False`` here).
    """
    sql = sql or ''
    result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False}
    sql_whitelist = ['select']
    # Build the match pattern from the whitelist
    whitelist_pattern = re.compile("^" + "|^".join(sql_whitelist), re.IGNORECASE)

    # Strip comments and split into statements; empty input is short-circuited
    # so that None never reaches sqlparse.
    statements = sqlparse.split(sqlparse.format(sql, strip_comments=True)) if sql else []
    if not statements:
        result['bad_query'] = True
        result['msg'] = '没有有效的SQL语句'
        return result

    # Only the first valid statement is executed
    sql = statements[0]
    result['filtered_sql'] = sql.strip()

    if whitelist_pattern.match(sql) is None:
        result['bad_query'] = True
        result['msg'] = '仅支持{}语法!'.format(','.join(sql_whitelist))
    return result
96 changes: 96 additions & 0 deletions sql/engines/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sql.engines.oracle import OracleEngine
from sql.engines.mongo import MongoEngine
from sql.engines.clickhouse import ClickHouseEngine
from sql.engines.odps import ODPSEngine
from sql.models import Instance, SqlWorkflow, SqlWorkflowContent

User = get_user_model()
Expand Down Expand Up @@ -1882,3 +1883,98 @@ def test_execute_workflow_success(self, _conn, _cursor, _execute):
execute_result = new_engine.execute_workflow(workflow=wf)
self.assertIsInstance(execute_result, ReviewSet)
self.assertEqual(execute_result.rows[0].__dict__.keys(), row.__dict__.keys())


class ODPSTest(TestCase):
    """Unit tests for ODPSEngine; the connection is mocked wherever a test needs one."""

    def setUp(self) -> None:
        # Throwaway instance record backing the engine under test
        instance_fields = dict(instance_name='some_ins', type='slave', db_type='odps',
                               host='some_host', port=9200, user='ins_user', db_name='some_db')
        self.ins = Instance.objects.create(**instance_fields)
        self.engine = ODPSEngine(instance=self.ins)

    def tearDown(self) -> None:
        self.ins.delete()

    def test_get_connection(self):
        """get_connection is invoked exactly once."""
        with patch('sql.engines.odps.ODPSEngine.get_connection') as mock_odps:
            _ = self.engine.get_connection()
        mock_odps.assert_called_once()

    @patch('sql.engines.odps.ODPSEngine.get_connection')
    def test_query(self, mock_get_connection):
        """query returns a ResultSet."""
        query_result = self.engine.query('some_db', "select 123")
        self.assertIsInstance(query_result, ResultSet)

    def test_query_check(self):
        """Only the first statement survives, with its comment removed."""
        raw_sql = "select 123; -- this is comment\nselect 456;"
        expected_sql = "select 123;"

        outcome = self.engine.query_check(sql=raw_sql)

        self.assertIsInstance(outcome, dict)
        self.assertEqual(False, outcome.get("bad_query"))
        self.assertEqual(expected_sql, outcome.get("filtered_sql"))

    def test_query_check_error(self):
        """A non-select statement is flagged as a bad query."""
        outcome = self.engine.query_check(sql="drop table table_a")

        self.assertIsInstance(outcome, dict)
        self.assertEqual(True, outcome.get("bad_query"))

    @patch('sql.engines.odps.ODPSEngine.get_connection')
    def test_get_all_databases(self, mock_get_connection):
        """The connection's project is returned as the only database."""
        fake_conn = Mock()
        fake_conn.project = 'some_db'
        fake_conn.exist_project.return_value = True
        mock_get_connection.return_value = fake_conn

        databases = self.engine.get_all_databases()

        self.assertIsInstance(databases, ResultSet)
        self.assertEqual(databases.rows, ['some_db'])

    @patch('sql.engines.odps.ODPSEngine.get_connection')
    def test_get_all_tables(self, mock_get_connection):
        """Table names are taken from the objects returned by list_tables."""

        # Minimal stand-in for a table object: only ``name`` is set
        class FakeTable:
            def __init__(self, name):
                self.name = name

        fake_conn = Mock()
        fake_conn.list_tables.return_value = [FakeTable(n) for n in ('u', 'v', 'w')]
        mock_get_connection.return_value = fake_conn

        tables = self.engine.get_all_tables('some_db')

        self.assertEqual(tables.rows, ['u', 'v', 'w'])

    def test_describe_table(self):
        """describe_table calls get_all_columns_by_tb once."""
        with patch('sql.engines.odps.ODPSEngine.get_all_columns_by_tb') as mock_get_all_columns_by_tb:
            self.engine.describe_table('some_db', 'some_table')
        mock_get_all_columns_by_tb.assert_called_once()

    @patch('sql.engines.odps.ODPSEngine.get_connection')
    def test_get_all_columns_by_tb(self, mock_get_connection):
        """Column name, type and comment are reported under fixed headers."""
        fake_column = Mock()
        fake_column.name = 'XiaoMing'
        fake_column.type = 'string'
        fake_column.comment = 'name'

        fake_table = Mock()
        fake_table.schema.columns = [fake_column]

        fake_conn = Mock()
        fake_conn.get_table.return_value = fake_table
        mock_get_connection.return_value = fake_conn

        columns = self.engine.get_all_columns_by_tb('some_db', 'some_table')

        mock_get_connection.assert_called_once()
        fake_conn.get_table.assert_called_once()
        self.assertEqual(columns.rows, [['XiaoMing', 'string', 'name']])
        self.assertEqual(columns.column_list, ['COLUMN_NAME', 'COLUMN_TYPE', 'COLUMN_COMMENT'])

0 comments on commit 9709c56

Please sign in to comment.