Commit 2c4cd5c0 authored by smile2019

Merge remote-tracking branch 'refs/remotes/origin/feat_1' into feat_1

parents d3fcd34f 7296e8d4
...@@ -13,4 +13,31 @@ res/ffmpeg-4.3.1/bin/output.mp4 ...@@ -13,4 +13,31 @@ res/ffmpeg-4.3.1/bin/output.mp4
res/ffmpeg-4.3.1/bin/qiji_local.mp4 res/ffmpeg-4.3.1/bin/qiji_local.mp4
venv/ venv/
venv37/ venv37/
shenming_test shenming_test
\ No newline at end of file
cap.png
requirements3.8.txt
venv3.8-new/
webrtcvad-2.0.10-cp38-abi3-win_amd64.whl
xlsx-resource/
deal_ocr.csv
deal_srt.csv
new.srt
shenhai1.xlsx
shenhai2.xlsx
test,py
"\346\267\261\346\265\267\347\237\255\347\211\2072.xlsx"
"\346\267\261\346\265\267\347\237\255\347\211\207origin.xlsx"
11.py
222.py
cap/
cap1597.png
cap831.png
deal.py
deal_movie.py
movie_1.txt
movie_pro.txt
res/.paddleocr/2.3.0.1/ocr/paddleocr/
script1.py
test/
\ No newline at end of file
...@@ -12,10 +12,11 @@ import os ...@@ -12,10 +12,11 @@ import os
class Content: class Content:
StartTimeColumn = 0 StartTimeColumn = 0
SubtitleColumnNumber = 2
AsideColumnNumber = 4 AsideColumnNumber = 4
SpeedColumnNumber = 5 SpeedColumnNumber = 5
# ActivateColumns = [2, 3] # ActivateColumns = [2, 3]
ActivateColumns = [4,5] ActivateColumns = [2,4,5]
# ColumnCount = 3 # ColumnCount = 3
ObjectName = "all_tableWidget" ObjectName = "all_tableWidget"
# TimeFormatColumns = [0] # TimeFormatColumns = [0]
......
This diff is collapsed.
This diff is collapsed.
...@@ -18,6 +18,9 @@ class MyWidget(QWidget): ...@@ -18,6 +18,9 @@ class MyWidget(QWidget):
# def __init__(self, parent=None): # def __init__(self, parent=None):
# super(QWidget, self).__init__(parent) # super(QWidget, self).__init__(parent)
# self.painter_flag = True # self.painter_flag = True
def __init__(self, parent=None, color=Qt.red):
    """Create the line-drawing widget.

    Args:
        parent: Optional parent widget, forwarded to the base-class constructor.
        color: Pen colour read by paintEvent when it draws the line;
            defaults to Qt.red.
    """
    # Zero-argument super() starts the MRO lookup right after this class, so
    # the direct base class initialiser runs.  super(QWidget, self) started
    # the lookup *after* QWidget and therefore bypassed it.
    super().__init__(parent)
    self.color = color
def paintEvent(self, event): def paintEvent(self, event):
# print(">>>>>>>>into paint") # print(">>>>>>>>into paint")
...@@ -26,7 +29,7 @@ class MyWidget(QWidget): ...@@ -26,7 +29,7 @@ class MyWidget(QWidget):
lock.acquire() lock.acquire()
painter = QPainter(self) painter = QPainter(self)
painter.setRenderHint(QPainter.Antialiasing) # Optional: Enable anti-aliasing painter.setRenderHint(QPainter.Antialiasing) # Optional: Enable anti-aliasing
painter.setPen(QPen(Qt.red, 2, Qt.SolidLine)) painter.setPen(QPen(self.color, 2, Qt.SolidLine))
painter.drawLine(0, 1, 800, 1) painter.drawLine(0, 1, 800, 1)
painter.end() painter.end()
lock.release() lock.release()
...@@ -55,8 +58,17 @@ class MyWidget(QWidget): ...@@ -55,8 +58,17 @@ class MyWidget(QWidget):
# painter.setPen(QPen(Qt.red, 2, Qt.SolidLine)) # painter.setPen(QPen(Qt.red, 2, Qt.SolidLine))
# painter.drawLine(0, 1, 800, 1) # painter.drawLine(0, 1, 800, 1)
# painter.end() # painter.end()
print(">>>>>cur_y : " + str(self.y()))
return self.y() return self.y()
def setY(self, h):
    """Print the current y coordinate, then move the widget to (0, h)."""
    current_y = self.y()
    print(">>>>>cur_y2 : " + str(current_y))
    self.move(0, h)
def get_h(self):
    """Return the widget's current y coordinate, as reported by self.y()."""
    current_y = self.y()
    return current_y
def down(self, mov_len): def down(self, mov_len):
print(">>>>>>>>>>>down" + str(mov_len)) print(">>>>>>>>>>>down" + str(mov_len))
self.move(0,self.y() + mov_len) self.move(0,self.y() + mov_len)
...@@ -314,6 +326,8 @@ class Ui_MainWindow(object): ...@@ -314,6 +326,8 @@ class Ui_MainWindow(object):
self.horizontalLayout_7.setObjectName("horizontalLayout_7") self.horizontalLayout_7.setObjectName("horizontalLayout_7")
self.up_ocr_btn = QtWidgets.QPushButton(self.centralwidget) self.up_ocr_btn = QtWidgets.QPushButton(self.centralwidget)
self.up_ocr_btn.setObjectName("up_ocr_btn") self.up_ocr_btn.setObjectName("up_ocr_btn")
# self.up_ocr_btn.setAutoRepeatDelay(False)
# self.up_ocr_btn.setAutoRepeat
self.horizontalLayout_7.addWidget(self.up_ocr_btn) self.horizontalLayout_7.addWidget(self.up_ocr_btn)
self.down_ocr_btn = QtWidgets.QPushButton(self.centralwidget) self.down_ocr_btn = QtWidgets.QPushButton(self.centralwidget)
self.down_ocr_btn.setObjectName("down_ocr_btn") self.down_ocr_btn.setObjectName("down_ocr_btn")
...@@ -324,9 +338,13 @@ class Ui_MainWindow(object): ...@@ -324,9 +338,13 @@ class Ui_MainWindow(object):
self.down_ocr_bottom_btn = QtWidgets.QPushButton(self.centralwidget) self.down_ocr_bottom_btn = QtWidgets.QPushButton(self.centralwidget)
self.down_ocr_bottom_btn.setObjectName("down_ocr_bottom_btn") self.down_ocr_bottom_btn.setObjectName("down_ocr_bottom_btn")
self.horizontalLayout_7.addWidget(self.down_ocr_bottom_btn) self.horizontalLayout_7.addWidget(self.down_ocr_bottom_btn)
self.confirm_ocr_btn = QtWidgets.QPushButton(self.centralwidget)
self.confirm_ocr_btn.setObjectName("confirm_ocr_btn")
self.horizontalLayout_7.addWidget(self.confirm_ocr_btn)
self.confirm_head_aside_btn = QtWidgets.QPushButton(self.centralwidget) self.confirm_head_aside_btn = QtWidgets.QPushButton(self.centralwidget)
self.confirm_head_aside_btn.setObjectName("confirm_head_aside_btn") self.confirm_head_aside_btn.setObjectName("confirm_head_aside_btn")
self.horizontalLayout_7.addWidget(self.confirm_head_aside_btn) self.horizontalLayout_7.addWidget(self.confirm_head_aside_btn)
self.horizontalLayout_8 = QtWidgets.QHBoxLayout() self.horizontalLayout_8 = QtWidgets.QHBoxLayout()
self.horizontalLayout_8.setObjectName("horizontalLayout_8") self.horizontalLayout_8.setObjectName("horizontalLayout_8")
...@@ -523,7 +541,8 @@ class Ui_MainWindow(object): ...@@ -523,7 +541,8 @@ class Ui_MainWindow(object):
self.action_redo = QtWidgets.QAction(MainWindow) self.action_redo = QtWidgets.QAction(MainWindow)
# self.action_redo.setFont(font) # self.action_redo.setFont(font)
self.action_redo.setObjectName("action_redo") self.action_redo.setObjectName("action_redo")
self.action_3 = QtWidgets.QAction("旁白区间检测",self,triggered=self.show_detect_dialog) # self.action_3 = QtWidgets.QAction("旁白区间检测",self,triggered=self.show_detect_dialog)
self.action_3 = QtWidgets.QAction("旁白区间检测",self,triggered=self.show_confirmation_dialog)
self.action_3.setEnabled(False) self.action_3.setEnabled(False)
self.action_4 = QtWidgets.QAction("旁白音频合成",self,triggered=self.show_assemble_dialog) self.action_4 = QtWidgets.QAction("旁白音频合成",self,triggered=self.show_assemble_dialog)
self.action_4.setEnabled(False) self.action_4.setEnabled(False)
...@@ -539,7 +558,6 @@ class Ui_MainWindow(object): ...@@ -539,7 +558,6 @@ class Ui_MainWindow(object):
self.action_9.setEnabled(True) self.action_9.setEnabled(True)
self.action_10 = QtWidgets.QAction("片头旁白定位",self,triggered=self.confirm_head_aside) self.action_10 = QtWidgets.QAction("片头旁白定位",self,triggered=self.confirm_head_aside)
self.action_10.setEnabled(True) self.action_10.setEnabled(True)
# self.action_3.setObjectName("action_3") # self.action_3.setObjectName("action_3")
# self.action_4 = QtWidgets.QAction(MainWindow) # self.action_4 = QtWidgets.QAction(MainWindow)
# self.action_4.setObjectName("action_4") # self.action_4.setObjectName("action_4")
...@@ -604,6 +622,7 @@ class Ui_MainWindow(object): ...@@ -604,6 +622,7 @@ class Ui_MainWindow(object):
self.up_ocr_bottom_btn.setText(_translate("MainWindow", "字幕下边界上移")) self.up_ocr_bottom_btn.setText(_translate("MainWindow", "字幕下边界上移"))
self.down_ocr_bottom_btn.setText(_translate("MainWindow", "字幕下边界下移")) self.down_ocr_bottom_btn.setText(_translate("MainWindow", "字幕下边界下移"))
self.confirm_head_aside_btn.setText(_translate("MainWindow", "片头旁白定位")) self.confirm_head_aside_btn.setText(_translate("MainWindow", "片头旁白定位"))
self.confirm_ocr_btn.setText(_translate("MainWindow", "字幕边界确认"))
self.detect_btn.setText(_translate("MainWindow", "旁白区间检测")) self.detect_btn.setText(_translate("MainWindow", "旁白区间检测"))
self.tabWidget.setTabText(self.tabWidget.indexOf(self.all_tab), _translate("MainWindow", "字幕旁白")) self.tabWidget.setTabText(self.tabWidget.indexOf(self.all_tab), _translate("MainWindow", "字幕旁白"))
self.tabWidget.setTabText(self.tabWidget.indexOf(self.zm_tab), _translate("MainWindow", "字幕")) self.tabWidget.setTabText(self.tabWidget.indexOf(self.zm_tab), _translate("MainWindow", "字幕"))
......
...@@ -77,13 +77,14 @@ class OperateRecord: ...@@ -77,13 +77,14 @@ class OperateRecord:
# 每一行的具体信息,"起始时间", "终止时间", "字幕", '建议', '解说脚本' # 每一行的具体信息,"起始时间", "终止时间", "字幕", '建议', '解说脚本'
class Element: class Element:
def __init__(self, st_time_sec: str, ed_time_sec: str, subtitle, suggest, aside, speed = "1.00(4字/秒)"): def __init__(self, st_time_sec: str, ed_time_sec: str, subtitle, suggest, aside, speed = "1.00(4字/秒)",ocr_h = None):
self.st_time_sec = st_time_sec self.st_time_sec = st_time_sec
self.ed_time_sec = ed_time_sec self.ed_time_sec = ed_time_sec
self.subtitle = subtitle self.subtitle = subtitle
self.suggest = suggest self.suggest = suggest
self.aside = aside self.aside = aside
self.speed = speed self.speed = speed
self.ocr_h = ocr_h
# 判断当前元素是否是字幕 # 判断当前元素是否是字幕
def is_subtitle(self): def is_subtitle(self):
...@@ -263,6 +264,11 @@ class ProjectContext: ...@@ -263,6 +264,11 @@ class ProjectContext:
if not self.initial_ing: if not self.initial_ing:
save_excel_to_path(self.all_elements, self.excel_path, self.write_header, self.excel_sheet_name) save_excel_to_path(self.all_elements, self.excel_path, self.write_header, self.excel_sheet_name)
def refresh_subtitle(self, row, subtitle: str):
    """Store a new subtitle for one row and persist the table.

    Args:
        row: Index into self.all_elements; converted with int() before use.
        subtitle: Replacement subtitle text for that element.
    """
    target = self.all_elements[int(row)]
    target.subtitle = subtitle
    # Nothing is written to the workbook while initial_ing is set.
    if self.initial_ing:
        return
    save_excel_to_path(self.all_elements, self.excel_path, self.write_header, self.excel_sheet_name)
def refresh_speed(self, row, speed: str)->None: def refresh_speed(self, row, speed: str)->None:
self.all_elements[int(row)].speed = speed self.all_elements[int(row)].speed = speed
if not self.initial_ing: if not self.initial_ing:
...@@ -307,7 +313,7 @@ class ProjectContext: ...@@ -307,7 +313,7 @@ class ProjectContext:
if d["终止时间"][i] is None: if d["终止时间"][i] is None:
# 如果是最后一条 # 如果是最后一条
if i == len(d["字幕"]) - 1: if i == len(d["字幕"]) - 1:
print(1) print(">>>>>>>>>load_excel_from_path")
# ed_time_sec = "360000" if self.duration == 0 else self.duration # todo 默认最大时长是100h # ed_time_sec = "360000" if self.duration == 0 else self.duration # todo 默认最大时长是100h
else: else:
ed_time_sec = "%.2f"%(float(d["起始时间"][i + 1]) - 0.01) ed_time_sec = "%.2f"%(float(d["起始时间"][i + 1]) - 0.01)
...@@ -428,6 +434,8 @@ def save_excel_to_path(all_element, new_excel_path, header, excel_sheet_name): ...@@ -428,6 +434,8 @@ def save_excel_to_path(all_element, new_excel_path, header, excel_sheet_name):
backup_path = os.path.dirname(new_excel_path) + "/tmp_"+str(time.time())+".xlsx" backup_path = os.path.dirname(new_excel_path) + "/tmp_"+str(time.time())+".xlsx"
# os.remove(new_excel_path) # os.remove(new_excel_path)
os.rename(new_excel_path, backup_path) os.rename(new_excel_path, backup_path)
# print(">>>>>>new_excel_path:" + new_excel_path)
# print(">>>>>>>>>>backup_path:" + backup_path)
try: try:
create_sheet(new_excel_path, "旁白插入位置建议", [header]) create_sheet(new_excel_path, "旁白插入位置建议", [header])
# for element in all_element: # for element in all_element:
......
...@@ -63,6 +63,13 @@ def detect(video_path: str, start_time: float, end_time: float, book_path: str, ...@@ -63,6 +63,13 @@ def detect(video_path: str, start_time: float, end_time: float, book_path: str,
from detect_with_ocr import detect_with_ocr from detect_with_ocr import detect_with_ocr
detect_with_ocr(video_path, book_path, start_time, end_time, state, mainWindow) detect_with_ocr(video_path, book_path, start_time, end_time, state, mainWindow)
def process_err(mainWindow: MainWindow = None):
    """Run process_err_ocr on the given window, printing any error instead of raising.

    Args:
        mainWindow: Passed through unchanged to process_err_ocr.
    """
    # Imported at call time rather than at module import time.
    from detect_with_ocr import process_err_ocr

    try:
        process_err_ocr(mainWindow)
    except Exception as exc:
        # Best-effort: report the failure on stdout and carry on.
        print("process_err err", exc, sep="\n")
if __name__ == '__main__': if __name__ == '__main__':
# 定义参数 # 定义参数
......
import re import re
import sys
import csv import csv
import jieba
import argparse import argparse
import pandas as pd import pandas as pd
import numpy as np import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity from sklearn.metrics.pairwise import cosine_similarity
from difflib import SequenceMatcher from difflib import SequenceMatcher
title = ['起始时间(转换后)', '终止时间(转换后)', '字幕'] from tqdm import tqdm
# title = ['起始时间(转换后)', '终止时间(转换后)', '字幕']
title = ['起始时间', '终止时间', '字幕']
def init(): def init():
# 获取中文停用词列表 # 获取中文停用词列表
...@@ -22,13 +27,32 @@ def change_to_second(time_str): ...@@ -22,13 +27,32 @@ def change_to_second(time_str):
time_obj.second + time_obj.microsecond / 1000000 time_obj.second + time_obj.microsecond / 1000000
return seconds return seconds
# Segment a Chinese sentence into words, while avoiding a segmentation that
# consists entirely of stop words.
def words_segment(str):
    """Return the sentence as comma-joined jieba tokens, or unchanged.

    Args:
        str: Sentence to segment (the name shadows the builtin; it is kept
            so existing callers are unaffected).

    Returns:
        The tokens produced by jieba.cut joined with ',', or the original
        sentence when is_all_stopwords accepts the joined text or when the
        number of tokens equals the number of characters.
    """
    # Tokenise once and reuse the result; the original ran jieba.cut twice
    # on the same input.
    tokens = list(jieba.cut(str))
    joined = ','.join(tokens)
    # If the sentence was split down to single-character tokens, do not
    # segment it either.
    if is_all_stopwords(joined) or len(tokens) == len(str):
        return str
    return joined
# Optionally keep only the Chinese part of a Chinese/English subtitle.
def extract_info(str, has_english=False):
    """Return the subtitle text, reduced to its Chinese runs when requested.

    Args:
        str: Subtitle text (the name shadows the builtin; kept for callers).
        has_english: When falsy the text is returned untouched; when truthy
            only runs of characters in U+4E00..U+9FFF are kept.

    Returns:
        The original text, or the matched runs joined by single spaces.
    """
    if has_english:
        return ' '.join(re.findall(r'[\u4e00-\u9fff]+', str))
    return str
# 计算字幕的相似度 # 计算字幕的相似度
def calculate_similarity(str1, str2, method='cosine'): def calculate_similarity(str1, str2, method='cosine'):
if method == 'cosine': if method == 'cosine':
tfidf_vectorizer = TfidfVectorizer() str1, str2 = words_segment(str1), words_segment(str2)
tfidf_matrix = tfidf_vectorizer.fit_transform([str1, str2]) tfidf_vectorizer = TfidfVectorizer(min_df=1)
tfidf_matrix = tfidf_vectorizer.fit_transform([str1, str2]) # shape=[2, N]
# print(np.array(tfidf_matrix.toarray()).shape, type(tfidf_matrix), tfidf_matrix.toarray())
similarity_matrix = cosine_similarity(tfidf_matrix) similarity_matrix = cosine_similarity(tfidf_matrix)
return similarity_matrix[0][1] return similarity_matrix[0][1]
elif method == 'distance':
return -String_edit_distance(str1, str2)
else : else :
return SequenceMatcher(None, str1, str2).ratio() return SequenceMatcher(None, str1, str2).ratio()
...@@ -37,23 +61,62 @@ def calculate_time_difference(time1, time2): ...@@ -37,23 +61,62 @@ def calculate_time_difference(time1, time2):
return abs(time2 - time1) return abs(time2 - time1)
def calculate_weight(x, y): def calculate_weight(x, y):
# weight = e^(-alpha * time_diff) # # weight = e^(-alpha * time_diff)
# 相差1s的系数为0.9 # # 相差1s的系数为0.9
alpha = 0.11 # alpha = 0.11
return 1 / (alpha * (x + y) + 1) # return 1 / (alpha * (x + y) + 1)
return 1.0 # 目前不考虑时间系数
# 检查句子中的每个单词是否都是停用词 # 检查句子中的每个单词是否都是停用词
def is_all_stopwords(sentence): def is_all_stopwords(sentence):
sentence = sentence.replace(' ', '')
return all(word in stop_words for word in sentence) return all(word in stop_words for word in sentence)
# Edit-distance algorithm.
# NOTE(review): the original comment flagged this routine as having an
# unspecified problem; the defect fixed here is the division by zero that
# occurred when both strings were empty.
def String_edit_distance(str1, str2):
    """Return the edit distance between two strings, normalised to [0, 1].

    Insertions, deletions and substitutions each cost 1.  The raw distance
    is divided by the length of the longer string.

    Args:
        str1: First string.
        str2: Second string.

    Returns:
        A float: 0.0 for identical strings (including two empty strings),
        1.0 when every position of the longer string has to be edited.
    """
    n, m = len(str1), len(str2)
    longest = max(n, m)
    if longest == 0:
        # Two empty strings are identical; without this guard the final
        # division would raise ZeroDivisionError.
        return 0.0

    # Only the previous row of the DP table is needed to fill the next one.
    prev = list(range(m + 1))
    for i in range(1, n + 1):
        cur = [i] + [0] * m
        ch = str1[i - 1]
        for j in range(1, m + 1):
            if ch == str2[j - 1]:
                cur[j] = prev[j - 1]
            else:
                cur[j] = min(prev[j - 1], cur[j - 1], prev[j]) + 1
        prev = cur
    return 1.0 * prev[m] / longest
### 如果其中有-符号,可能在用excel打开时自动添加=变成公式,读取的时候没问题 ### 如果其中有-符号,可能在用excel打开时自动添加=变成公式,读取的时候没问题
def read_srt_to_csv(path_srt, path_output): def read_srt_to_csv(path_srt, path_output):
with open(path_srt, 'r', encoding='utf-8-sig') as f: try:
srt_content = f.read() # str with open(path_srt, 'r', encoding='utf-8-sig') as f:
srt_content = f.read() # str
except UnicodeDecodeError:
print(f"编码错误,已经切换到utf-16编码")
try:
with open(path_srt, 'r', encoding='utf-16') as f:
srt_content = f.read() # str
except:
print(f"请选择utf-8或utf-16编码形式的srt文件")
sys.exit(1)
# 使用正则表达式匹配时间码和字幕内容 # 使用正则表达式匹配时间码和字幕内容
pattern = re.compile(r'(\d+)\n([\d:,]+) --> ([\d:,]+)\n(.+?)(?=\n\d+\n|$)', re.DOTALL) pattern = re.compile(r'(\d+)\n([\d:,]+) --> ([\d:,]+)\n(.+?)(?=\n\d+\n|$)', re.DOTALL)
matches = pattern.findall(srt_content) matches = pattern.findall(srt_content)
has_english = []
for i in range(5):
idx = np.random.randint(len(matches))
pattern = re.compile(r'[a-zA-Z]')
has_english.append(bool(pattern.search(matches[idx][3])))
has_english = all(has_english)
print('!'*20, has_english)
# 写入 csv 文件 # 写入 csv 文件
with open(path_output, 'w', newline='', encoding='utf-8') as f: with open(path_output, 'w', newline='', encoding='utf-8') as f:
csv_writer = csv.writer(f) csv_writer = csv.writer(f)
...@@ -61,7 +124,7 @@ def read_srt_to_csv(path_srt, path_output): ...@@ -61,7 +124,7 @@ def read_srt_to_csv(path_srt, path_output):
for _, start, end, subtitle in matches: # 都是str格式 for _, start, end, subtitle in matches: # 都是str格式
subtitle = re.sub(r'\{[^}]*\}', '', subtitle) # 将srt文件前的加粗等格式去掉 subtitle = re.sub(r'\{[^}]*\}', '', subtitle) # 将srt文件前的加粗等格式去掉
csv_writer.writerow([start, end, subtitle.strip()]) csv_writer.writerow([start, end, extract_info(subtitle.strip(), has_english)])
def read_from_xlsx(path_xlsx='output.xlsx', path_output='deal.csv'): def read_from_xlsx(path_xlsx='output.xlsx', path_output='deal.csv'):
data = pd.read_excel(path_xlsx) data = pd.read_excel(path_xlsx)
...@@ -70,20 +133,19 @@ def read_from_xlsx(path_xlsx='output.xlsx', path_output='deal.csv'): ...@@ -70,20 +133,19 @@ def read_from_xlsx(path_xlsx='output.xlsx', path_output='deal.csv'):
csv_writer.writerow(title) csv_writer.writerow(title)
for _, data1 in data.iterrows(): for _, data1 in data.iterrows():
start, end, subtitle = data1[1], data1[3], data1[4] # print(data1[1])
start, end, subtitle = data1[0], data1[1], data1[2]
if isinstance(subtitle, float) and np.isnan(subtitle): if isinstance(subtitle, float) and np.isnan(subtitle):
continue continue
# 与srt文件格式同步 # 与srt文件格式同步
start = start.replace('.', ',') start = start.replace('.', ',')
end = end.replace('.', ',') end = end.replace('.', ',')
# print(start, end, subtitle,)
# print(type(start), type(end), type(subtitle))
csv_writer.writerow([start, end, subtitle.strip()]) csv_writer.writerow([start, end, subtitle.strip()])
### 对于srt中的字幕计算相似性度。从ocr中找到时间戳满足<=time_t的字幕, ### 对于srt中的字幕计算相似性度。从ocr中找到时间戳满足<=time_t的字幕,
### 然后计算字幕间的相似度,取一个最大的。字幕从start和end都匹配一遍 ### 然后计算字幕间的相似度,取一个最大的。字幕从start和end都匹配一遍
# time_threshold设置阈值,用于判断时间差是否可接受 # time_threshold设置阈值,用于判断时间差是否可接受
def measure_score(path_srt, path_ocr, time_threshold=5.0, method='cosine'): def measure_score(path_srt, path_ocr, time_threshold=5.0, time_threshold_re=False, method='cosine'):
data_srt, data_ocr = [], [] data_srt, data_ocr = [], []
with open(path_srt, 'r', encoding='utf-8') as file: with open(path_srt, 'r', encoding='utf-8') as file:
csv_reader = csv.reader(file) csv_reader = csv.reader(file)
...@@ -103,22 +165,36 @@ def measure_score(path_srt, path_ocr, time_threshold=5.0, method='cosine'): ...@@ -103,22 +165,36 @@ def measure_score(path_srt, path_ocr, time_threshold=5.0, method='cosine'):
# 计算相似度 # 计算相似度
total_similarity = 0.0 total_similarity = 0.0
total_weight = 0.0 total_weight = 0.0
txt1 = []
for sub in data_srt: for sub in tqdm(data_srt, desc="Processing", ncols=100):
max_similarity = 0.0 max_similarity = 0.0 if method != 'distance' else -1.0
# 去除srt中的停用词 # 去除srt中的停用词
if is_all_stopwords(sub[2]): if is_all_stopwords(sub[2]):
continue continue
subb = ""
for sub1 in data_ocr: for sub1 in data_ocr:
x, y = abs(sub[0] - sub1[0]), abs(sub[1] - sub1[1]) x, y = abs(sub[0] - sub1[0]), abs(sub[1] - sub1[1])
if min(x, y) <= time_threshold: if time_threshold_re:
# print(sub[2], sub1[2]) time_threshold_tmp = time_threshold
score = calculate_similarity(sub[2], sub1[2], 'cosine') else :
time_threshold_tmp = (sub[1] - sub[0]) * 0.3 # 10s允许3s的误差
if min(x, y) <= time_threshold_tmp:
score = calculate_similarity(sub[2], sub1[2], method)
if max_similarity <= score * calculate_weight(x, y):
subb = sub1[2]
max_similarity = max(max_similarity, score * calculate_weight(x, y)) max_similarity = max(max_similarity, score * calculate_weight(x, y))
if max_similarity <= -0.5:
# print(max_similarity, sub[2], subb, sub[0])
txt1.append(' !!! '.join([str(max_similarity), sub[2], subb, str(sub[0])]))
total_similarity += max_similarity total_similarity += max_similarity
total_weight += 1 total_weight += 1
if method == 'distance':
total_similarity = total_weight + total_similarity
# print(total_similarity, total_similarity / len(data_srt), total_similarity / total_weight) with open('movie_pro.txt', 'w', encoding='utf-8') as f:
for i in txt1:
f.write(i + '\n')
return total_similarity / len(data_srt), total_similarity / total_weight return total_similarity / len(data_srt), total_similarity / total_weight
if __name__ == '__main__': if __name__ == '__main__':
...@@ -128,13 +204,23 @@ if __name__ == '__main__': ...@@ -128,13 +204,23 @@ if __name__ == '__main__':
# 添加命令行参数 # 添加命令行参数
parser.add_argument("--path_srt", required=True, type=str, help="Path of srt file, format is srt") parser.add_argument("--path_srt", required=True, type=str, help="Path of srt file, format is srt")
parser.add_argument("--path_ocr", required=True, type=str, help="Path of ocr file, format is xlsx") parser.add_argument("--path_ocr", required=True, type=str, help="Path of ocr file, format is xlsx")
parser.add_argument("--method", type=str, default='cosine', help="Select evaluation method") parser.add_argument("--time_threshold", type=float, default=5.0, help="Allowable time frame")
parser.add_argument("--time_threshold", type=float,default=5.0, help="Allowable time frame")
parser.add_argument("--method", type=str, default='distance',choices=['cosine', 'distance', 'sequence']
, help="Select evaluation method")
parser.add_argument("--time_threshold_re", type=bool, default=True, help="Specify whether \
time threshold is required")
args = parser.parse_args() args = parser.parse_args()
output_file_srt = 'deal_srt.csv' output_file_srt = 'deal_srt.csv'
output_file_ocr = 'deal_ocr.csv' output_file_ocr = 'deal_ocr.csv'
read_srt_to_csv(args.path_srt, output_file_srt) read_srt_to_csv(args.path_srt, output_file_srt)
read_from_xlsx(args.path_ocr, output_file_ocr) read_from_xlsx(args.path_ocr, output_file_ocr)
score = measure_score(output_file_srt, output_file_ocr, args.time_threshold, args.method) score = measure_score(output_file_srt, output_file_ocr, args.time_threshold, \
print(f'该评估算法得分: {score[1]:.5f}') args.time_threshold_re, args.method)
\ No newline at end of file print(f'该评估算法得分: {100 * score[1]:.3f}')
# python ocr_metric.py --path_srt test/new/movie_1.srt --path_ocr ../测试/the-swan-v3/The.Swan-zimu.xlsx --time_threshold 10
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment