面向对象的尝试

发布于 2020-06-10  681 次阅读


import os, re
from decimal import Decimal, InvalidOperation
from itertools import chain

TABLE_COLUMN = 13

sep = re.compile(r'\s+')#表明分隔符是至少一个空白符
def _sep(s):
    s = sep.sub(' ', s).strip()#分隔符转为一个空格后去掉首位的空格
    return s.split(' ')

def _path(name, root=None):
    return os.path.join(root or os.getcwd(),name)

def _getStrIncreasedWidth(s):
    '''
    大致计算以空格 为单位时字串增加的宽度 
    '''
    return sum(ord(c) > 127 for c in s)#历遍字串,统计非ASCII字符的数目

def _sort(s):#由于文本形式的数字大小比较和实际的数字大小比较不一致所以需要将数字转到Deciaml
    try:
        return Decimal(s)
    except InvalidOperation:
        return s
def _key(key):
    def wrapKey(*args, **kw):#对key进行修饰,在外面套一层_sort
        return _sort(key(*args, **kw))
    return wrapKey

class Table(list):#表格对象继承自列表
    def addLine(self, line):
        self.append(list(line))
        for index in range(self.column):
            column_width = len(line[index]) + _getStrIncreasedWidth(line[index])#获取添加项的宽度
            self.column_width[index] = max(self.column_width[index], column_width)#该列始终保持为最大的宽度。
    def addLines(self, data):
        self.extend((list(x) for x in data))
        for index in range(self.column):
            max_column_width = max(map(lambda line:len(line[index]) + _getStrIncreasedWidth(line[index]), data))
            self.column_width[index] = max(self.column_width[index], max_column_width)#更新表格该列宽度为最大宽度
    def getPrint(self):
        def formatLine(line):
            items = (f'{item:^{self.column_width[i]-_getStrIncreasedWidth(item)}}' for i,item in enumerate(line))
            return '\t'.join(items)
        return '\n'.join(map(formatLine, chain([self.structure], self)))
    def __init__(self, structure, equivalent_field=None):
        self.structure = list(structure)
        self.column = len(structure)
        self.column_width = list(map(lambda x:len(x) + _getStrIncreasedWidth(x), structure))
        self.equivalent_field = equivalent_field if equivalent_field else {}
    __repr__ = __str__ = getPrint
    def getColumn(self, index):
        return [line[index] for line in self]
    def addEquivalentField(self, equivalent_field):
        self.equivalent_field.update(equivalent_field)
    def getFiledName(self, field):
        return self.equivalent_field.get(field, field)
    def indexField(self, key):
        if callable(key):
            return [i for i,_field in enumerate(self.structure) if key(_field)]
        else:
            field = self.getFiledName(key)
            return [i for i,_field in enumerate(self.structure) if field == _field]
    def filter(self, index, key):
        table = Table(self.structure, self.equivalent_field)
        for line in self:
            if key(line[index]):
                table.addLine(line)
        return table
    def filterField(self, indexList):
        table = Table([self.structure[index] for index in indexList], self.equivalent_field)
        for line in self:
            table.addLine([line[index] for index in indexList])
        return table 
    def addColumn(self, field_name, key, index=None):
        if not index:
            index = self.column
        self.structure.insert(index, field_name)#更新结构
        self.column_width.insert(index, len(field_name) + _getStrIncreasedWidth(field_name))#例行维护
        self.column += 1
        for x,line in enumerate(self):
            item = key(self, x, index)#调用生成函数key来生成要添加的值
            line.insert(index, item)
            self.column_width[index] = max(self.column_width[index], len(item) + _getStrIncreasedWidth(item))
    def sort(self, key, reverse=False):
        try:
            list.sort(self, key=_key(key), reverse=reverse)
        except TypeError:#如果同时有数字和汉字会出错,这时回到最初的排序方式
            list.sort(self, key=key, reverse=reverse)
    def __getitem__(self, n):
        if isinstance(n, int):# n是索引
            return list.__getitem__(self, n)
        else:# n是切片
            table = Table(self.structure, self.equivalent_field)
            table.addLines(list.__getitem__(self, n))
            return table

def pretreat():
    '''
    文件内容预处理:
    把数据文件stocks.txt中的箭头和逗号去除,
    把x%转换为x/100,
    去除掉最后一列(“股吧”),
    去除掉数字前面的+。
    在同一个目录下生成新的文件stockdata.txt
    '''
    TABLE_STRUCTURE = re.compile(r'([\w/]{2,})[\t ]+')
    TABLE_DATA = re.compile(r'^((?:(?:SH)?[69]|(?:SZ)?[02]|(?:HK)?)0\d{4})\s+(\S+)\s+[↑↓]?(\S+)\s+\+?(\S+)\s+\+?(\S+)%\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+股吧',flags=re.IGNORECASE)
    with open(_path(r'stocks.txt'), encoding='utf-8') as stocks:
        structure = TABLE_STRUCTURE.findall(stocks.readline())
        if len(structure) != TABLE_COLUMN:
            return 'stocks.txt格式有误, 请确认表头是否复制完全!'
        table = Table(structure)
        hundred = Decimal('100')
        for line in stocks:
            line = TABLE_DATA.findall(line.replace(',',''))
            if not line or len(line[0]) != TABLE_COLUMN:
                return f'stocks.txt格式有误, 请确认第{len(table)+2}行的数据是否复制完全!'
            line = list(line[0])#转成list是因为tuple不支持元素的修改
            line[4] = str(Decimal(line[4]) / hundred)
            table.addLine(line)
    with open(_path(r'stockdata.txt'), 'w', encoding='utf-8') as stockdata:
        stockdata.write(table.getPrint())
    return table

def readTable():
    '''
    读入处理后的stockdata.txt文件
    '''
    with open(_path(r'stockdata.txt'), encoding='utf-8') as stockdata:
        table = Table(_sep(stockdata.readline()))#先读入表头
        for line in stockdata:
            table.addLine(_sep(line))#再分行读入数据
    return table

def search(table):
    '''
    输入以空格隔开的若干个股票名称S,S如果没有输入,则默认为所有股票。
    输入以空格隔开的若干个字段名T,T如果没有输入,则默认为”代码”。
    输入一个要排序的字段名name(该字段必须在T中),
    输入排序方式(升序、降序)。其中排序方式默认为升序。
    根据这些条件,检索文件中符合条件的股票,并显示T字段名称及相应的内容。
    前两个字段名不管有没有出现在T中,都要输出显示。
    '''
    S = input('请输入以空格隔开的若干个股票名称S:').strip()
    if S:#如果S是有意义的输入就进行筛选
        S = _sep(S)
        temp_table = table.filter(table.indexField('名称')[0], lambda x:x in S)
        if len(temp_table) > 0:
            table = temp_table
    T = input('请输入以空格隔开的若干个字段名T:').strip()
    if not T:#默认为”代码”
        T = '代码'
    T = set(table.getFiledName(field) for field in _sep(T))
    T.update({'代码','名称'})#前两个字段名不管有没有出现在T中,都要输出显示
    table = table.filterField(table.indexField(lambda x: x in T))
    name = input('请输入一个要排序的字段名name(该字段必须在T中):').strip()
    if not name:#默认为”代码”
        name = '代码'
    name = table.getFiledName(name)
    if not name in T:#该字段必须在T中
        return f'排序的字段 {name} 必须在T中!'
    sort_index = table.indexField(name)
    if len(sort_index) != 1:
        return 'name 和 T 输入错误!'
    sort_index = sort_index[0]#如果输入没问题,就将以name所在的第一个字段进行排序
    reverse = input('请输入排序方式(升序、降序):').strip()
    if reverse == '降序':#默认为升序
        reverse = True
    else:
        reverse = False
    table.sort(key=lambda x:x[sort_index], reverse=reverse)
    return table

def volatilityAnalysis(table):
    '''
    根据“最高”和“最低”两个字段,
    计算波动率=(最高-最低)/最低,
    列出波动率最大的前三个股票代码、名称、波动率。
    '''
    table = table.filter(table.indexField('最低')[0], eval)#除去最低为0的股票
    index_sup = table.indexField('最高')[0]#获得“最高”字段所在位置
    index_inf = table.indexField('最低')[0]#获得“最低”字段所在位置
    def _analysis(table, x, y):
        sup = Decimal(table[x][index_sup])#获得当前行的“最高”字段的值
        inf = Decimal(table[x][index_inf])#获得当前行的“最低”字段的值
        return f'{(sup-inf)/inf:.5}'#按公式计算,并且结果保留5位有效数字
    table.addColumn('波动率', _analysis)#使用公式生成波动率字段添加到表格最后一列
    T = {'代码', '名称', '波动率'}#只保留代码、名称、和波动率三个字段
    table = table.filterField(table.indexField(lambda x: x in T))
    sort_index = table.indexField('波动率')[0]#对波动率进行降序排列
    table.sort(key=lambda x:x[sort_index], reverse=True)
    return table

if os.path.exists(_path(r'stockdata.txt')):#如果数据已经处理就直接读入,否则进行预处理
    table = readTable()
else:
    table = pretreat()
    assert isinstance(table, Table), table #如果stocks.txt格式有误,就提示错误信息,并终止运行
table.addEquivalentField({'成交量':'成交量/手', '成交额':'成交额/万'})#为表格字段添加别名

print('\n搜索结果如下:', search(table), sep='\n')#检索前面读入的表格,并打印检索结果所形成的新表格

print('\n波动性分析结果如下:')
print(volatilityAnalysis(table)[:3])#打印波动性分析后形成的新表格的前三行

一枚爱好探索的医学生