Commit 0824724a authored by smile2019's avatar smile2019

Merge remote-tracking branch 'refs/remotes/origin/feat_1' into feat_1

parents a5f36591 9f2adad2
...@@ -12,12 +12,16 @@ import os ...@@ -12,12 +12,16 @@ import os
class Content: class Content:
StartTimeColumn = 0 StartTimeColumn = 0
AsideColumnNumber = 2 AsideColumnNumber = 4
SpeedColumnNumber = 3 SpeedColumnNumber = 5
ActivateColumns = [2, 3] # ActivateColumns = [2, 3]
ActivateColumns = [4,5]
# ColumnCount = 3 # ColumnCount = 3
ObjectName = "all_tableWidget" ObjectName = "all_tableWidget"
TimeFormatColumns = [0] # TimeFormatColumns = [0]
TimeFormatColumns = [0, 1]
SpeedList = ["1.00(4字/秒)", "1.10(4.5字/秒)", "1.25(5字/秒)",
"1.50(6字/秒)", "1.75(7字/秒)", "2.00(8字/秒)", "2.50(10字/秒)"]
class Aside: class Aside:
......
...@@ -24,17 +24,20 @@ class Ui_Dialog(object): ...@@ -24,17 +24,20 @@ class Ui_Dialog(object):
self.name_input.setObjectName("name_input") self.name_input.setObjectName("name_input")
self.root_input = QtWidgets.QLineEdit(Dialog) self.root_input = QtWidgets.QLineEdit(Dialog)
self.root_input.setObjectName("root_input") self.root_input.setObjectName("root_input")
self.gridLayout.addWidget(self.root_input, 1, 1, 1, 1) self.gridLayout.addWidget(self.root_input, 0, 1, 1, 1)
self.get_dir = QtWidgets.QPushButton(Dialog) self.get_dir = QtWidgets.QPushButton(Dialog)
self.get_dir.setObjectName("get_dir") self.get_dir.setObjectName("get_dir")
self.gridLayout.addWidget(self.get_dir, 1, 2, 1, 1) self.gridLayout.addWidget(self.get_dir, 0, 2, 1, 1)
self.rootLabel = QtWidgets.QLabel(Dialog) self.rootLabel = QtWidgets.QLabel(Dialog)
self.rootLabel.setObjectName("rootLabel") self.rootLabel.setObjectName("rootLabel")
self.gridLayout.addWidget(self.rootLabel, 1, 0, 1, 1) self.gridLayout.addWidget(self.rootLabel, 0, 0, 1, 1)
self.nameLabel = QtWidgets.QLabel(Dialog) self.nameLabel = QtWidgets.QLabel(Dialog)
self.nameLabel.setObjectName("nameLabel") self.nameLabel.setObjectName("nameLabel")
self.gridLayout.addWidget(self.nameLabel, 0, 0, 1, 1) self.gridLayout.addWidget(self.nameLabel, 1, 0, 1, 1)
self.gridLayout.addWidget(self.name_input, 0, 1, 1, 1)
self.gridLayout.addWidget(self.name_input, 1, 1, 1, 1)
self.gridLayout_2.addLayout(self.gridLayout, 0, 0, 1, 1) self.gridLayout_2.addLayout(self.gridLayout, 0, 0, 1, 1)
self.horizontalLayout = QtWidgets.QHBoxLayout() self.horizontalLayout = QtWidgets.QHBoxLayout()
self.horizontalLayout.setObjectName("horizontalLayout") self.horizontalLayout.setObjectName("horizontalLayout")
...@@ -64,7 +67,7 @@ class Ui_Dialog(object): ...@@ -64,7 +67,7 @@ class Ui_Dialog(object):
_translate = QtCore.QCoreApplication.translate _translate = QtCore.QCoreApplication.translate
Dialog.setWindowTitle(_translate("Dialog", "Dialog")) Dialog.setWindowTitle(_translate("Dialog", "Dialog"))
self.nameLabel.setText(_translate("Dialog", "工程名称")) self.nameLabel.setText(_translate("Dialog", "工程名称"))
self.rootLabel.setText(_translate("Dialog", "工程文件夹")) self.rootLabel.setText(_translate("Dialog", "目标路径"))
self.get_dir.setText(_translate("Dialog", "打开文件夹")) self.get_dir.setText(_translate("Dialog", "打开文件夹"))
self.confirm.setText(_translate("Dialog", "确认")) self.confirm.setText(_translate("Dialog", "确认"))
self.cancel.setText(_translate("Dialog", "取消")) self.cancel.setText(_translate("Dialog", "取消"))
This diff is collapsed.
This diff is collapsed.
<?xml version="1.0" encoding="UTF-8"?> <!-- <?xml version="1.0" encoding="UTF-8"?>
<ui version="4.0"> <ui version="4.0">
<class>MainWindow</class> <class>MainWindow</class>
<widget class="QMainWindow" name="MainWindow"> <widget class="QMainWindow" name="MainWindow">
...@@ -709,4 +709,4 @@ ...@@ -709,4 +709,4 @@
</customwidgets> </customwidgets>
<resources/> <resources/>
<connections/> <connections/>
</ui> </ui> -->
...@@ -7,7 +7,33 @@ ...@@ -7,7 +7,33 @@
# WARNING! All changes made in this file will be lost! # WARNING! All changes made in this file will be lost!
from PyQt5 import QtCore, QtGui, QtWidgets from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import QMainWindow, QFileDialog, QTableWidget, QTableWidgetItem, QAbstractItemView, QProgressBar, QLabel, QApplication, QPushButton, QMenu, QWidget
from PyQt5.QtCore import QUrl, Qt, QTimer, QRect, pyqtSignal, QPersistentModelIndex
from PyQt5.QtMultimedia import *
from PyQt5.QtGui import QIcon, QPainter, QColor, QPen
class MyWidget(QWidget):
def paintEvent(self, event):
print(">>>>>>>>>>>>>>>into paint")
painter = QPainter(self)
painter.setRenderHint(QPainter.Antialiasing) # Optional: Enable anti-aliasing
# painter.setCompositionMode(QPainter.CompositionMode_SourceOver) # Set composition mode
# # Draw existing content
# painter.fillRect(event.rect(), QColor(255, 255, 255)) # Fill with white color (you can adjust as needed)
# Draw a transparent horizontal line
painter.setPen(QPen(Qt.red, 2, Qt.SolidLine))
painter.drawLine(0, 1, 800, 1)
def up(self, mov_len):
print(">>>>>>>>>>>up" + str(mov_len))
self.move(0, self.y() - mov_len)
return self.y()
def down(self, mov_len):
print(">>>>>>>>>>>down" + str(mov_len))
self.move(0,self.y() + mov_len)
return self.y()
class Ui_MainWindow(object): class Ui_MainWindow(object):
def setupUi(self, MainWindow): def setupUi(self, MainWindow):
...@@ -32,8 +58,12 @@ class Ui_MainWindow(object): ...@@ -32,8 +58,12 @@ class Ui_MainWindow(object):
self.verticalLayout_2 = QtWidgets.QVBoxLayout() self.verticalLayout_2 = QtWidgets.QVBoxLayout()
self.verticalLayout_2.setObjectName("verticalLayout_2") self.verticalLayout_2.setObjectName("verticalLayout_2")
self.wgt_video = myVideoWidget(self.centralwidget) self.wgt_video = myVideoWidget(self.centralwidget)
self.wgt_video.setMinimumSize(QtCore.QSize(410, 200)) # self.wgt_video.setMinimumSize(QtCore.QSize(410, 200))
self.wgt_video.setMaximumSize(QtCore.QSize(16777215, 16777215)) # self.wgt_video.setMaximumSize(QtCore.QSize(16777215, 16777215))
self.widget = MyWidget(self.centralwidget)
self.widget.setGeometry(0,400,800,3)
self.widget_bottom = MyWidget(self.centralwidget)
self.widget_bottom.setGeometry(0,430,800,3)
palette = QtGui.QPalette() palette = QtGui.QPalette()
brush = QtGui.QBrush(QtGui.QColor(0, 0, 0)) brush = QtGui.QBrush(QtGui.QColor(0, 0, 0))
brush.setStyle(QtCore.Qt.SolidPattern) brush.setStyle(QtCore.Qt.SolidPattern)
...@@ -225,7 +255,7 @@ class Ui_MainWindow(object): ...@@ -225,7 +255,7 @@ class Ui_MainWindow(object):
self.zm_tableWidget.setRowCount(0) self.zm_tableWidget.setRowCount(0)
self.zm_tableWidget.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows) self.zm_tableWidget.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows)
self.horizontalLayout_2.addWidget(self.zm_tableWidget) self.horizontalLayout_2.addWidget(self.zm_tableWidget)
self.tabWidget.addTab(self.zm_tab, "") # self.tabWidget.addTab(self.zm_tab, "")
self.pb_tab = QtWidgets.QWidget() self.pb_tab = QtWidgets.QWidget()
self.pb_tab.setObjectName("pb_tab") self.pb_tab.setObjectName("pb_tab")
self.horizontalLayout_3 = QtWidgets.QHBoxLayout(self.pb_tab) self.horizontalLayout_3 = QtWidgets.QHBoxLayout(self.pb_tab)
...@@ -236,7 +266,7 @@ class Ui_MainWindow(object): ...@@ -236,7 +266,7 @@ class Ui_MainWindow(object):
self.pb_tableWidget.setRowCount(0) self.pb_tableWidget.setRowCount(0)
self.pb_tableWidget.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows) self.pb_tableWidget.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows)
self.horizontalLayout_3.addWidget(self.pb_tableWidget) self.horizontalLayout_3.addWidget(self.pb_tableWidget)
self.tabWidget.addTab(self.pb_tab, "") # self.tabWidget.addTab(self.pb_tab, "")
self.shuiping.addWidget(self.tabWidget) self.shuiping.addWidget(self.tabWidget)
self.shuiping.setStretch(0, 3) self.shuiping.setStretch(0, 3)
self.shuiping.setStretch(1, 5) self.shuiping.setStretch(1, 5)
...@@ -327,8 +357,14 @@ class Ui_MainWindow(object): ...@@ -327,8 +357,14 @@ class Ui_MainWindow(object):
self.menu.setObjectName("menu") self.menu.setObjectName("menu")
self.menu_2 = QtWidgets.QMenu(self.menubar) self.menu_2 = QtWidgets.QMenu(self.menubar)
self.menu_2.setObjectName("menu_2") self.menu_2.setObjectName("menu_2")
self.menu_3 = QtWidgets.QMenu(self.menubar) # self.menu_3 = QtWidgets.QMenu(self.menubar)
self.menu_3.setObjectName("menu_3") # self.menu_3.setObjectName("menu_3")
self.menu_4 = QtWidgets.QMenu(self.menubar)
self.menu_4.setObjectName("menu_4")
self.menu_5 = QtWidgets.QMenu(self.menubar)
self.menu_5.setObjectName("menu_5")
self.menu_6 = QtWidgets.QMenu(self.menubar)
self.menu_6.setObjectName("menu_6")
MainWindow.setMenuBar(self.menubar) MainWindow.setMenuBar(self.menubar)
self.statusbar = QtWidgets.QStatusBar(MainWindow) self.statusbar = QtWidgets.QStatusBar(MainWindow)
self.statusbar.setObjectName("statusbar") self.statusbar.setObjectName("statusbar")
...@@ -355,12 +391,26 @@ class Ui_MainWindow(object): ...@@ -355,12 +391,26 @@ class Ui_MainWindow(object):
self.action_redo = QtWidgets.QAction(MainWindow) self.action_redo = QtWidgets.QAction(MainWindow)
# self.action_redo.setFont(font) # self.action_redo.setFont(font)
self.action_redo.setObjectName("action_redo") self.action_redo.setObjectName("action_redo")
self.action_3 = QtWidgets.QAction(MainWindow) self.action_3 = QtWidgets.QAction("旁白区间检测",self,triggered=self.show_detect_dialog)
self.action_3.setObjectName("action_3") self.action_3.setEnabled(False)
self.action_4 = QtWidgets.QAction(MainWindow) self.action_4 = QtWidgets.QAction("旁白音频合成",self,triggered=self.show_assemble_dialog)
self.action_4.setObjectName("action_4") self.action_4.setEnabled(False)
self.action_5 = QtWidgets.QAction(MainWindow) self.action_5 = QtWidgets.QAction("旁白导入",self,triggered=self.import_excel)
self.action_5.setObjectName("action_5") self.action_5.setEnabled(False)
self.action_6 = QtWidgets.QAction("字幕上边界++",self,triggered=self.up_ocr)
self.action_6.setEnabled(True)
self.action_7 = QtWidgets.QAction("字幕上边界--",self,triggered=self.down_ocr)
self.action_7.setEnabled(True)
self.action_8 = QtWidgets.QAction("字幕下边界++",self,triggered=self.up_ocr_bottom)
self.action_8.setEnabled(True)
self.action_9 = QtWidgets.QAction("字幕下边界--",self,triggered=self.down_ocr_bottom)
self.action_9.setEnabled(True)
# self.action_3.setObjectName("action_3")
# self.action_4 = QtWidgets.QAction(MainWindow)
# self.action_4.setObjectName("action_4")
# self.action_5 = QtWidgets.QAction(MainWindow)
# self.action_5.setObjectName("action_5")
self.action_operate = QtWidgets.QAction(MainWindow) self.action_operate = QtWidgets.QAction(MainWindow)
self.action_operate.setObjectName("action_operate") self.action_operate.setObjectName("action_operate")
self.action_export = QtWidgets.QAction(MainWindow) self.action_export = QtWidgets.QAction(MainWindow)
...@@ -384,13 +434,22 @@ class Ui_MainWindow(object): ...@@ -384,13 +434,22 @@ class Ui_MainWindow(object):
self.menu_2.addSeparator() self.menu_2.addSeparator()
self.menu_2.addAction(self.action_insert_aside_from_now) self.menu_2.addAction(self.action_insert_aside_from_now)
self.menu_2.addAction(self.action_operate) self.menu_2.addAction(self.action_operate)
self.menu_3.addAction(self.action_3) # self.menu_3.addAction(self.action_3)
self.menu_3.addAction(self.action_4) # self.menu_3.addAction(self.action_4)
self.menu_3.addAction(self.action_5) # self.menu_3.addAction(self.action_5)
self.menu_3.addSeparator() # self.menu_3.addSeparator()
self.menubar.addAction(self.menu.menuAction()) self.menubar.addAction(self.menu.menuAction())
self.menubar.addAction(self.menu_2.menuAction()) self.menubar.addAction(self.menu_2.menuAction())
self.menubar.addAction(self.menu_3.menuAction()) self.menubar.addAction(self.action_3)
self.menubar.addAction(self.action_4)
self.menubar.addAction(self.action_5)
self.menubar.addAction(self.action_6)
self.menubar.addAction(self.action_7)
self.menubar.addAction(self.action_8)
self.menubar.addAction(self.action_9)
# self.menubar.addAction(self.menu_5.menuAction())
# self.menubar.addAction(self.menu_6.menuAction())
# self.menubar.addAction(self.menu_3.menuAction())
self.retranslateUi(MainWindow) self.retranslateUi(MainWindow)
self.tabWidget.setCurrentIndex(0) self.tabWidget.setCurrentIndex(0)
...@@ -410,7 +469,10 @@ class Ui_MainWindow(object): ...@@ -410,7 +469,10 @@ class Ui_MainWindow(object):
self.pb_label.setText(_translate("MainWindow", "刻度")) self.pb_label.setText(_translate("MainWindow", "刻度"))
self.menu.setTitle(_translate("MainWindow", "文件")) self.menu.setTitle(_translate("MainWindow", "文件"))
self.menu_2.setTitle(_translate("MainWindow", "编辑")) self.menu_2.setTitle(_translate("MainWindow", "编辑"))
self.menu_3.setTitle(_translate("MainWindow", "功能按键")) # self.menu_3.setTitle(_translate("MainWindow", "功能按键"))
self.menu_4.setTitle(_translate("MainWindow", "旁白区间检测"))
self.menu_5.setTitle(_translate("MainWindow", "旁白音频合成"))
self.menu_6.setTitle(_translate("MainWindow", "旁白导入"))
self.setting.setText(_translate("MainWindow", "设置")) self.setting.setText(_translate("MainWindow", "设置"))
self.action_open_project.setText(_translate("MainWindow", "打开")) self.action_open_project.setText(_translate("MainWindow", "打开"))
self.import_movie.setText(_translate("MainWindow", "视频导入")) self.import_movie.setText(_translate("MainWindow", "视频导入"))
...@@ -418,15 +480,14 @@ class Ui_MainWindow(object): ...@@ -418,15 +480,14 @@ class Ui_MainWindow(object):
self.action_save.setText(_translate("MainWindow", "保存并备份")) self.action_save.setText(_translate("MainWindow", "保存并备份"))
self.action_undo.setText(_translate("MainWindow", "撤销")) self.action_undo.setText(_translate("MainWindow", "撤销"))
self.action_redo.setText(_translate("MainWindow", "重做")) self.action_redo.setText(_translate("MainWindow", "重做"))
self.action_3.setText(_translate("MainWindow", "旁白区间检测")) # self.action_3.setText(_translate("MainWindow", "旁白区间检测"))
self.action_4.setText(_translate("MainWindow", "旁白音频合成")) # self.action_4.setText(_translate("MainWindow", "旁白音频合成"))
self.action_5.setText(_translate("MainWindow", "旁白导入")) # self.action_5.setText(_translate("MainWindow", "旁白导入"))
self.action_operate.setText(_translate("MainWindow", "操作表格")) self.action_operate.setText(_translate("MainWindow", "操作表格"))
self.action_export.setText(_translate("MainWindow", "导出")) self.action_export.setText(_translate("MainWindow", "导出"))
self.action_insert_aside_from_now.setText(_translate("MainWindow", "当前位置插入旁白")) self.action_insert_aside_from_now.setText(_translate("MainWindow", "当前位置插入旁白"))
self.action_create.setText(_translate("MainWindow", "新建")) self.action_create.setText(_translate("MainWindow", "新建"))
from myVideoWidget import myVideoWidget from myVideoWidget import myVideoWidget
from myvideoslider import myVideoSlider from myvideoslider import myVideoSlider
from mywidgetcontents import myWidgetContents from mywidgetcontents import myWidgetContents
...@@ -96,7 +96,8 @@ class Element: ...@@ -96,7 +96,8 @@ class Element:
def to_list(self): def to_list(self):
return [self.st_time_sec, self.ed_time_sec, self.subtitle, self.suggest, self.aside, self.speed] return [self.st_time_sec, self.ed_time_sec, self.subtitle, self.suggest, self.aside, self.speed]
def to_short_list(self): def to_short_list(self):
return [self.st_time_sec, self.subtitle, self.aside, self.speed] # return [self.st_time_sec, self.subtitle, self.aside, self.speed]
return [self.st_time_sec, self.ed_time_sec, self.subtitle, self.suggest, self.aside, self.speed]
def to_aside_list(self): def to_aside_list(self):
# return [self.st_time_sec, self.ed_time_sec, self.suggest, self.aside, self.speed] # return [self.st_time_sec, self.ed_time_sec, self.suggest, self.aside, self.speed]
return [self.st_time_sec, self.suggest, self.aside, self.speed] return [self.st_time_sec, self.suggest, self.aside, self.speed]
...@@ -119,6 +120,7 @@ class ProjectContext: ...@@ -119,6 +120,7 @@ class ProjectContext:
self.subtitle_list = [] self.subtitle_list = []
self.aside_list = [] self.aside_list = []
self.all_elements = [] self.all_elements = []
self.speaker_type = None
self.speaker_info = None self.speaker_info = None
self.speaker_speed = None self.speaker_speed = None
self.duration = 0 self.duration = 0
...@@ -128,7 +130,8 @@ class ProjectContext: ...@@ -128,7 +130,8 @@ class ProjectContext:
self.aside_header = ['起始时间', '推荐字数', '解说脚本',"语速", "预览音频"] self.aside_header = ['起始时间', '推荐字数', '解说脚本',"语速", "预览音频"]
self.subtitle_header = ["起始时间", "终止时间", "字幕"] self.subtitle_header = ["起始时间", "终止时间", "字幕"]
self.contentHeader = ["起始时间", "字幕", "解说脚本", "语速"] # self.contentHeader = ["起始时间", "字幕", "解说脚本", "语速"]
self.contentHeader = ["起始时间", "结束时间", "字幕", "推荐字数", "解说脚本", "语速", "预览音频"]
self.excel_sheet_name = "旁白插入位置建议" self.excel_sheet_name = "旁白插入位置建议"
self.history_records = [] self.history_records = []
self.records_pos = 0 self.records_pos = 0
...@@ -174,6 +177,7 @@ class ProjectContext: ...@@ -174,6 +177,7 @@ class ProjectContext:
if not os.path.exists(self.conf_path): if not os.path.exists(self.conf_path):
print("conf file does not exist, 找管理员要") print("conf file does not exist, 找管理员要")
return return
print(self.conf_path)
with open(self.conf_path, 'r', encoding='utf8') as f: with open(self.conf_path, 'r', encoding='utf8') as f:
info = json.load(f) info = json.load(f)
# print(json.dumps(info, ensure_ascii=False, indent=4)) # print(json.dumps(info, ensure_ascii=False, indent=4))
...@@ -181,17 +185,20 @@ class ProjectContext: ...@@ -181,17 +185,20 @@ class ProjectContext:
self.excel_path = info["excel_path"] self.excel_path = info["excel_path"]
self.speaker_info = info["speaker_info"]["speaker_id"] self.speaker_info = info["speaker_info"]["speaker_id"]
self.speaker_speed = info["speaker_info"]["speaker_speed"] self.speaker_speed = info["speaker_info"]["speaker_speed"]
self.speaker_type = info["speaker_info"]["speaker_type"] if "speaker_type" in info["speaker_info"] else "科大讯飞"
self.detected = info["detection_info"]["detected"] self.detected = info["detection_info"]["detected"]
self.nd_process = info["detection_info"]["nd_process"] self.nd_process = info["detection_info"]["nd_process"]
self.last_time = info["detection_info"]["last_time"] self.last_time = info["detection_info"]["last_time"]
self.caption_boundings = info["detection_info"]["caption_boundings"] self.caption_boundings = info["detection_info"]["caption_boundings"]
self.has_subtitle = info["detection_info"]["has_subtitle"] self.has_subtitle = info["detection_info"]["has_subtitle"]
# 当前工程下没有配置文件,就初始化一份 # 当前工程下没有配置文件,就初始化一份``
if self.conf_path != this_conf_path: if self.conf_path != this_conf_path:
self.conf_path = this_conf_path self.conf_path = this_conf_path
print("11111sava")
self.save_conf() self.save_conf()
def save_conf(self): def save_conf(self):
print(self.speaker_speed)
with open(self.conf_path, 'w', encoding='utf-8') as f: with open(self.conf_path, 'w', encoding='utf-8') as f:
# if len(self.caption_boundings) > 0: # if len(self.caption_boundings) > 0:
# print(type(self.caption_boundings[0])) # print(type(self.caption_boundings[0]))
...@@ -207,6 +214,7 @@ class ProjectContext: ...@@ -207,6 +214,7 @@ class ProjectContext:
"has_subtitle": self.has_subtitle "has_subtitle": self.has_subtitle
}, },
"speaker_info": { "speaker_info": {
"speaker_type": self.speaker_type,
"speaker_id": self.speaker_info, "speaker_id": self.speaker_info,
"speaker_speed": self.speaker_speed "speaker_speed": self.speaker_speed
} }
...@@ -223,6 +231,7 @@ class ProjectContext: ...@@ -223,6 +231,7 @@ class ProjectContext:
# 先备份文件,再覆盖主文件,可选是否需要备份,默认需要备份 # 先备份文件,再覆盖主文件,可选是否需要备份,默认需要备份
# 20221030:添加旁白检测的进度 # 20221030:添加旁白检测的进度
def save_project(self, need_save_new: bool=False) -> str: def save_project(self, need_save_new: bool=False) -> str:
print("22222sava")
self.save_conf() self.save_conf()
# all_element = sorted(all_element, key=lambda x: float(x.st_time_sec)) # all_element = sorted(all_element, key=lambda x: float(x.st_time_sec))
print("current excel_path:", self.excel_path) print("current excel_path:", self.excel_path)
...@@ -254,6 +263,11 @@ class ProjectContext: ...@@ -254,6 +263,11 @@ class ProjectContext:
if not self.initial_ing: if not self.initial_ing:
save_excel_to_path(self.all_elements, self.excel_path, self.write_header, self.excel_sheet_name) save_excel_to_path(self.all_elements, self.excel_path, self.write_header, self.excel_sheet_name)
def refresh_speed(self, row, speed: str)->None:
self.all_elements[int(row)].speed = speed
if not self.initial_ing:
save_excel_to_path(self.all_elements, self.excel_path, self.write_header, self.excel_sheet_name)
# 加载整个工程,填充到ProjectContext上下文中 # 加载整个工程,填充到ProjectContext上下文中
def load_project(self): def load_project(self):
pass pass
...@@ -344,6 +358,22 @@ class ProjectContext: ...@@ -344,6 +358,22 @@ class ProjectContext:
self.speaker_info = speaker_name[0] self.speaker_info = speaker_name[0]
return tuple(speaker_name) return tuple(speaker_name)
def get_all_speaker_zju_info(self):
"""获取所有说话人的名字、性别及年龄段等信息
用于显示在人机交互界面上,方便用户了解说话人并进行选择
"""
f = open(constant.Pathes.speaker_conf_path, encoding="utf-8")
content = json.load(f)
speaker_name = []
for speaker in content["speaker_zju_details"]:
speaker_name.append(
",".join([speaker["name"], speaker["gender"], speaker["age_group"]]))
if self.speaker_info is None:
self.speaker_info = speaker_name[0]
return tuple(speaker_name)
def init_speakers(self): def init_speakers(self):
"""初始化说话人信息 """初始化说话人信息
...@@ -354,6 +384,8 @@ class ProjectContext: ...@@ -354,6 +384,8 @@ class ProjectContext:
content = json.load(f) content = json.load(f)
for speaker_info in content["speaker_details"]: for speaker_info in content["speaker_details"]:
self.speakers.append(Speaker(speaker_info)) self.speakers.append(Speaker(speaker_info))
for speaker_info in content["speaker_zju_details"]:
self.speakers.append(Speaker(speaker_info))
def choose_speaker(self, speaker_name: str) -> Speaker: def choose_speaker(self, speaker_name: str) -> Speaker:
"""选择说话人 """选择说话人
......
from PyQt5.QtMultimediaWidgets import QVideoWidget from PyQt5.QtMultimediaWidgets import QVideoWidget
from PyQt5.QtCore import * from PyQt5.QtCore import *
from PyQt5.QtMultimedia import QMediaPlayer
class myVideoWidget(QVideoWidget): class myVideoWidget(QVideoWidget):
...@@ -7,6 +8,8 @@ class myVideoWidget(QVideoWidget): ...@@ -7,6 +8,8 @@ class myVideoWidget(QVideoWidget):
def __init__(self, parent=None): def __init__(self, parent=None):
super(QVideoWidget, self).__init__(parent) super(QVideoWidget, self).__init__(parent)
self.setAspectRatioMode(Qt.IgnoreAspectRatio)
def mouseDoubleClickEvent(self, QMouseEvent): #双击事件 def mouseDoubleClickEvent(self, QMouseEvent): #双击事件
......
{"video_path": null, "excel_path": null, "detection_info": {"detected": false, "nd_process": 0.0, "last_time": 0.0, "caption_boundings": [], "has_subtitle": true}, "speaker_info": {"speaker_id": "\u6653\u6653\uff0c\u5973\uff0c\u5e74\u8f7b\u4eba", "speaker_speed": "1.10(4.5\u5b57/\u79d2)"}} {"video_path": null, "excel_path": null, "detection_info": {"detected": false, "nd_process": 0.0, "last_time": 0.0, "caption_boundings": [], "has_subtitle": true}, "speaker_info": {"speaker_type": "\u6d59\u5927\u5185\u90e8tts", "speaker_id": "test\uff0c\u5973\uff0c\u5e74\u8f7b\u4eba", "speaker_speed": "1.00(4\u5b57/\u79d2)"}}
\ No newline at end of file \ No newline at end of file
...@@ -139,5 +139,16 @@ ...@@ -139,5 +139,16 @@
"audio_path": "./res/speaker_audio/Yunye.wav", "audio_path": "./res/speaker_audio/Yunye.wav",
"speaker_code": "zh-CN-YunyeNeural" "speaker_code": "zh-CN-YunyeNeural"
} }
] ],
"speaker_zju_details": [{
"id": 0,
"name": "test",
"language": "中文(普通话,简体)",
"age_group": "年轻人",
"gender": "女",
"description": "休闲、放松的语音,用于自发性对话和会议听录。",
"audio_path": "./res/speaker_zju_audio/local_tts_example.wav",
"speaker_code": "",
"speaker_type":"1"
}]
} }
\ No newline at end of file
...@@ -8,6 +8,7 @@ from setting_dialog_ui import Ui_Dialog ...@@ -8,6 +8,7 @@ from setting_dialog_ui import Ui_Dialog
from utils import validate_and_get_filepath, replace_path_suffix from utils import validate_and_get_filepath, replace_path_suffix
import winsound import winsound
import constant
audioPlayed = winsound.PlaySound(None, winsound.SND_NODEFAULT) audioPlayed = winsound.PlaySound(None, winsound.SND_NODEFAULT)
...@@ -19,41 +20,98 @@ class Setting_Dialog(QDialog, Ui_Dialog): ...@@ -19,41 +20,98 @@ class Setting_Dialog(QDialog, Ui_Dialog):
self.setupUi(self) self.setupUi(self)
self.setWindowTitle("设置") self.setWindowTitle("设置")
self.projectContext = projectContext self.projectContext = projectContext
# todo 把所有说话人都加上来 self.refresh(self.projectContext)
self.speaker_li = self.projectContext.get_all_speaker_info() self.refresh_flag = False
for i in self.speaker_li: self.clear_flag = False
self.comboBox.addItem(i) self.comboBox_0.currentIndexChanged.connect(self.choose)
self.speed_li_2 = ["1.00(4字/秒)", "1.10(4.5字/秒)", "1.25(5字/秒)", "1.50(6字/秒)", "1.75(7字/秒)", "2.00(8字/秒)", "2.50(10字/秒)"]
self.comboBox_2.addItems(self.speed_li_2)
if self.projectContext.speaker_info is None:
self.comboBox.setCurrentIndex(0)
else:
self.comboBox.setCurrentIndex(self.speaker_li.index(self.projectContext.speaker_info))
if self.projectContext.speaker_speed is None:
self.comboBox_2.setCurrentIndex(0)
else:
self.comboBox_2.setCurrentIndex(self.speed_li_2.index(self.projectContext.speaker_speed))
self.comboBox.currentIndexChanged.connect(self.speaker_change_slot) self.comboBox.currentIndexChanged.connect(self.speaker_change_slot)
self.comboBox_2.currentIndexChanged.connect(self.speed_change_slot) self.comboBox_2.currentIndexChanged.connect(self.speed_change_slot)
self.pushButton.clicked.connect(self.play_audio_slot) self.pushButton.clicked.connect(self.play_audio_slot)
def refresh(self,projectContext):
try:
self.refresh_flag = True
self.clear_flag = True
self.comboBox_0.clear()
self.comboBox.clear()
self.comboBox_2.clear()
# todo 把所有说话人都加上来
self.speaker_li = projectContext.get_all_speaker_info()
self.speaker_zju_li = projectContext.get_all_speaker_zju_info() #本地tts
self.speed_list_zju = ["1.00(4字/秒)", "1.10(4.5字/秒)", "1.25(5字/秒)", "1.50(6字/秒)", "1.75(7字/秒)", "2.00(8字/秒)", "2.50(10字/秒)"] #本地tts
# for i in self.speaker_li:
# self.comboBox.addItem(i)
self.speed_li_2 = ["1.00(4字/秒)", "1.10(4.5字/秒)", "1.25(5字/秒)", "1.50(6字/秒)", "1.75(7字/秒)", "2.00(8字/秒)", "2.50(10字/秒)"]
# self.comboBox_2.addItems(self.speed_li_2)
self.speaker_types = ["科大讯飞", "浙大内部tts"]
self.comboBox_0.addItems(self.speaker_types)
print(projectContext.speaker_type)
if projectContext.speaker_type is None or projectContext.speaker_type == "":
self.comboBox_0.setCurrentIndex(0)
else:
self.comboBox_0.setCurrentIndex(self.speaker_types.index(projectContext.speaker_type))
if self.comboBox_0.currentIndex() ==0: #讯飞
self.comboBox.addItems(self.speaker_li)
self.comboBox_2.addItems(self.speed_li_2)
else:
# local
self.comboBox.addItems(self.speaker_zju_li)
self.comboBox_2.addItems(self.speed_list_zju)
self.clear_flag = False
if projectContext.speaker_info is None or projectContext.speaker_info == "":
self.comboBox.setCurrentIndex(0)
else:
print(projectContext.speaker_info)
self.comboBox.setCurrentIndex(self.speaker_li.index(projectContext.speaker_info) if self.comboBox_0.currentIndex() ==0 else self.speaker_zju_li.index(projectContext.speaker_info))
print(projectContext.speaker_speed)
if projectContext.speaker_speed is None or projectContext.speaker_speed == "":
self.comboBox_2.setCurrentIndex(0)
else:
self.comboBox_2.setCurrentIndex(self.speed_li_2.index(projectContext.speaker_speed) if self.comboBox_0.currentIndex() ==0 else self.speed_list_zju.index(projectContext.speaker_speed))
finally:
self.refresh_flag = False
def choose(self):
if self.refresh_flag:
return
print(self.comboBox_0.currentIndex())
self.comboBox.clear()
self.comboBox_2.clear()
self.projectContext.speaker_type = self.comboBox_0.currentText()
if self.comboBox_0.currentIndex() ==0:
print("讯飞")
self.comboBox.addItems(self.speaker_li)
self.comboBox_2.addItems(self.speed_li_2)
# constant.Content.SpeedList.clear()
# constant.Content.SpeedList = self.speed_li_2
else:
print("local")
self.comboBox.addItems(self.speaker_zju_li)
self.comboBox_2.addItems(self.speed_list_zju)
# constant.Content.SpeedList.clear()
# constant.Content.SpeedList = self.speed_list_zju
def content_fresh(self): def content_fresh(self):
"""刷新界面中的内容 """刷新界面中的内容
将工程信息中的说话人信息、说话人语速更新到界面中,如果未选择则初始化为第一个选项 将工程信息中的说话人信息、说话人语速更新到界面中,如果未选择则初始化为第一个选项
""" """
if self.projectContext.speaker_info is None: print(self.projectContext.speaker_info)
if self.projectContext.speaker_info is None or self.projectContext.speaker_info == "" :
self.comboBox.setCurrentIndex(0) self.comboBox.setCurrentIndex(0)
else: else:
self.comboBox.setCurrentIndex(self.speaker_li.index(self.projectContext.speaker_info)) self.comboBox.setCurrentIndex(self.speaker_li.index(self.projectContext.speaker_info) if self.comboBox_0.currentIndex() ==0 else self.speaker_zju_li.index(self.projectContext.speaker_info))
if self.projectContext.speaker_speed is None:
if self.projectContext.speaker_speed is None or self.projectContext.speaker_speed == "":
self.comboBox_2.setCurrentIndex(0) self.comboBox_2.setCurrentIndex(0)
else: else:
self.comboBox_2.setCurrentIndex(self.speed_li_2.index(self.projectContext.speaker_speed)) self.comboBox_2.setCurrentIndex(self.speed_li_2.index(self.projectContext.speaker_speed) if self.comboBox_0.currentIndex() ==0 else self.speed_list_zju.index(self.projectContext.speaker_speed))
def speaker_change_slot(self): def speaker_change_slot(self):
"""切换说话人 """切换说话人
...@@ -61,6 +119,8 @@ class Setting_Dialog(QDialog, Ui_Dialog): ...@@ -61,6 +119,8 @@ class Setting_Dialog(QDialog, Ui_Dialog):
将当前的说话人设置为工程的说话人,并保存到配置文件中 将当前的说话人设置为工程的说话人,并保存到配置文件中
""" """
if self.clear_flag:
return
self.projectContext.speaker_info = self.comboBox.currentText() self.projectContext.speaker_info = self.comboBox.currentText()
self.projectContext.save_conf() self.projectContext.save_conf()
# print("self.projectContext.speaker_info:", self.projectContext.speaker_info) # print("self.projectContext.speaker_info:", self.projectContext.speaker_info)
...@@ -71,6 +131,8 @@ class Setting_Dialog(QDialog, Ui_Dialog): ...@@ -71,6 +131,8 @@ class Setting_Dialog(QDialog, Ui_Dialog):
将当前的语速设置为工程的语速,并保存到配置文件中 将当前的语速设置为工程的语速,并保存到配置文件中
""" """
if self.clear_flag:
return
self.projectContext.speaker_speed = self.comboBox_2.currentText() self.projectContext.speaker_speed = self.comboBox_2.currentText()
self.projectContext.save_conf() self.projectContext.save_conf()
......
...@@ -19,20 +19,32 @@ class Ui_Dialog(object): ...@@ -19,20 +19,32 @@ class Ui_Dialog(object):
self.gridLayout_2.setObjectName("gridLayout_2") self.gridLayout_2.setObjectName("gridLayout_2")
self.gridLayout = QtWidgets.QGridLayout() self.gridLayout = QtWidgets.QGridLayout()
self.gridLayout.setObjectName("gridLayout") self.gridLayout.setObjectName("gridLayout")
self.label_2 = QtWidgets.QLabel(Dialog)
self.label_2.setObjectName("label_2")
self.gridLayout.addWidget(self.label_2, 0, 0, 1, 1)
self.comboBox_0 = QtWidgets.QComboBox(Dialog)
self.comboBox_0.setCurrentText("")
self.comboBox_0.setObjectName("comboBox_0")
self.gridLayout.addWidget(self.comboBox_0, 0, 1, 1, 1)
self.label_3 = QtWidgets.QLabel(Dialog) self.label_3 = QtWidgets.QLabel(Dialog)
self.label_3.setObjectName("label_3") self.label_3.setObjectName("label_3")
self.gridLayout.addWidget(self.label_3, 0, 0, 1, 1) self.gridLayout.addWidget(self.label_3, 1, 0, 1, 1)
self.comboBox = QtWidgets.QComboBox(Dialog) self.comboBox = QtWidgets.QComboBox(Dialog)
self.comboBox.setCurrentText("") self.comboBox.setCurrentText("")
self.comboBox.setObjectName("comboBox") self.comboBox.setObjectName("comboBox")
self.gridLayout.addWidget(self.comboBox, 0, 1, 1, 1) self.gridLayout.addWidget(self.comboBox, 1, 1, 1, 1)
self.label_4 = QtWidgets.QLabel(Dialog) self.label_4 = QtWidgets.QLabel(Dialog)
self.label_4.setObjectName("label_4") self.label_4.setObjectName("label_4")
self.gridLayout.addWidget(self.label_4, 1, 0, 1, 1) self.gridLayout.addWidget(self.label_4, 2, 0, 1, 1)
self.comboBox_2 = QtWidgets.QComboBox(Dialog) self.comboBox_2 = QtWidgets.QComboBox(Dialog)
self.comboBox_2.setCurrentText("") self.comboBox_2.setCurrentText("")
self.comboBox_2.setObjectName("comboBox_2") self.comboBox_2.setObjectName("comboBox_2")
self.gridLayout.addWidget(self.comboBox_2, 1, 1, 1, 1) self.gridLayout.addWidget(self.comboBox_2, 2, 1, 1, 1)
self.gridLayout.setRowMinimumHeight(0, 60) self.gridLayout.setRowMinimumHeight(0, 60)
self.gridLayout.setRowMinimumHeight(1, 60) self.gridLayout.setRowMinimumHeight(1, 60)
self.gridLayout.setColumnStretch(1, 1) self.gridLayout.setColumnStretch(1, 1)
...@@ -50,6 +62,8 @@ class Ui_Dialog(object): ...@@ -50,6 +62,8 @@ class Ui_Dialog(object):
def retranslateUi(self, Dialog): def retranslateUi(self, Dialog):
_translate = QtCore.QCoreApplication.translate _translate = QtCore.QCoreApplication.translate
Dialog.setWindowTitle(_translate("Dialog", "Dialog")) Dialog.setWindowTitle(_translate("Dialog", "Dialog"))
self.label_2.setText(_translate("Dialog", "TTS引擎"))
self.label_3.setText(_translate("Dialog", "旁白说话人:"))
self.label_3.setText(_translate("Dialog", "旁白说话人:")) self.label_3.setText(_translate("Dialog", "旁白说话人:"))
self.label_4.setText(_translate("Dialog", "旁白语速:")) self.label_4.setText(_translate("Dialog", "旁白语速:"))
self.pushButton.setText(_translate("Dialog", "播放样例音频")) self.pushButton.setText(_translate("Dialog", "播放样例音频"))
...@@ -27,6 +27,7 @@ from azure.cognitiveservices.speech import SpeechConfig, SpeechSynthesizer, Resu ...@@ -27,6 +27,7 @@ from azure.cognitiveservices.speech import SpeechConfig, SpeechSynthesizer, Resu
from azure.cognitiveservices.speech.audio import AudioOutputConfig from azure.cognitiveservices.speech.audio import AudioOutputConfig
import openpyxl import openpyxl
import shutil import shutil
from vits_chinese import tts
tmp_file = 'tmp.wav' tmp_file = 'tmp.wav'
adjusted_wav_path = "adjusted.wav" adjusted_wav_path = "adjusted.wav"
...@@ -53,6 +54,8 @@ class Speaker: ...@@ -53,6 +54,8 @@ class Speaker:
self.speaker_code = speaker_info["speaker_code"] self.speaker_code = speaker_info["speaker_code"]
self.description = speaker_info["description"] self.description = speaker_info["description"]
self.voice_example = speaker_info["audio_path"] self.voice_example = speaker_info["audio_path"]
self.speaker_type = speaker_info["speaker_type"] if "speaker_type" in speaker_info else None #speakers.json里面新加字段speaker_type =1 表示用local tts
def init_speakers(): def init_speakers():
...@@ -65,6 +68,8 @@ def init_speakers(): ...@@ -65,6 +68,8 @@ def init_speakers():
global speakers global speakers
for speaker_info in content["speaker_details"]: for speaker_info in content["speaker_details"]:
speakers.append(Speaker(speaker_info)) speakers.append(Speaker(speaker_info))
for speaker_info in content["speaker_zju_details"]:
speakers.append(Speaker(speaker_info))
def choose_speaker(speaker_name: str) -> Speaker: def choose_speaker(speaker_name: str) -> Speaker:
...@@ -94,44 +99,48 @@ def speech_synthesis(text: str, output_file: str, speaker: Speaker, speed: float ...@@ -94,44 +99,48 @@ def speech_synthesis(text: str, output_file: str, speaker: Speaker, speed: float
speed (float, optional): 指定的音频语速. Defaults to 1.0. speed (float, optional): 指定的音频语速. Defaults to 1.0.
""" """
speech_config = SpeechConfig(
subscription="db34d38d2d3447d482e0f977c66bd624",
region="eastus"
)
speech_config.speech_synthesis_language = "zh-CN"
speech_config.speech_synthesis_voice_name = speaker.speaker_code
# 先把合成的语音文件输出得到tmp.wav中,便于可能的调速需求
if not os.path.exists(os.path.dirname(output_file)): # 如果路径不存在 if not os.path.exists(os.path.dirname(output_file)): # 如果路径不存在
print("output_file路径不存在,创建:", os.path.dirname(output_file)) print("output_file路径不存在,创建:", os.path.dirname(output_file))
os.makedirs(os.path.dirname(output_file)) os.makedirs(os.path.dirname(output_file))
synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=None) if speaker.speaker_type != None and speaker.speaker_type == "1":
ssml_string = f""" tts(text, speed, output_file)
<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{speech_config.speech_synthesis_language}"> else:
<voice name="{speaker.speaker_code}"> speech_config = SpeechConfig(
<prosody rate="{round((speed - 1.0) * 100, 2)}%"> subscription="db34d38d2d3447d482e0f977c66bd624",
{text} region="eastus"
</prosody> )
</voice>
</speak>""" speech_config.speech_synthesis_language = "zh-CN"
result = synthesizer.speak_ssml_async(ssml_string).get() speech_config.speech_synthesis_voice_name = speaker.speaker_code
stream = AudioDataStream(result)
stream.save_to_wav_file(output_file) # 先把合成的语音文件输出得到tmp.wav中,便于可能的调速需求
print(result.reason)
while result.reason == ResultReason.Canceled:
cancellation_details = result.cancellation_details
print("取消的原因", cancellation_details.reason, cancellation_details.error_details)
time.sleep(1)
synthesizer.stop_speaking()
del synthesizer
synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=None) synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=None)
ssml_string = f"""
<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{speech_config.speech_synthesis_language}">
<voice name="{speaker.speaker_code}">
<prosody rate="{round((speed - 1.0) * 100, 2)}%">
{text}
</prosody>
</voice>
</speak>"""
result = synthesizer.speak_ssml_async(ssml_string).get() result = synthesizer.speak_ssml_async(ssml_string).get()
stream = AudioDataStream(result) stream = AudioDataStream(result)
stream.save_to_wav_file(output_file) stream.save_to_wav_file(output_file)
print(result.reason) print(result.reason)
while result.reason == ResultReason.Canceled:
cancellation_details = result.cancellation_details
print("取消的原因", cancellation_details.reason, cancellation_details.error_details)
time.sleep(1)
synthesizer.stop_speaking()
del synthesizer
synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=None)
result = synthesizer.speak_ssml_async(ssml_string).get()
stream = AudioDataStream(result)
stream.save_to_wav_file(output_file)
print(result.reason)
# detached # detached
def change_speed_and_volume(wav_path: str, speed: float = 1.0): def change_speed_and_volume(wav_path: str, speed: float = 1.0):
"""调整语速,顺便把音量调大,语音合成的声音太小了 """调整语速,顺便把音量调大,语音合成的声音太小了
......
...@@ -69,20 +69,21 @@ if __name__ == '__main__': ...@@ -69,20 +69,21 @@ if __name__ == '__main__':
QCoreApplication.setAttribute(Qt.AA_UseHighDpiPixmaps) QCoreApplication.setAttribute(Qt.AA_UseHighDpiPixmaps)
currentExitCode = MainWindow.EXIT_CODE_REBOOT currentExitCode = MainWindow.EXIT_CODE_REBOOT
while currentExitCode == MainWindow.EXIT_CODE_REBOOT: while currentExitCode == MainWindow.EXIT_CODE_REBOOT:
app = QApplication(sys.argv) app = QApplication(sys.argv)
app.setWindowIcon(QIcon("./res/images/eagle_2.ico")) app.setWindowIcon(QIcon("./res/images/eagle_2.ico"))
mainWindow = MainWindow(project_path) mainWindow = MainWindow(project_path)
if project_path == None or project_path == "": if project_path == None or project_path == "":
mainWindow.setWindowTitle(f"无障碍电影制作软件(请打开或新建工程)") mainWindow.setWindowTitle(f"无障碍电影制作软件(请打开或新建工程)")
else: else:
project_name = os.path.basename(project_path) project_name = os.path.basename(project_path)
mainWindow.setWindowTitle(f"无障碍电影制作软件(当前工程为:{project_name})") mainWindow.setWindowTitle(f"无障碍电影制作软件(当前工程为:{project_name})")
mainWindow.renew_signal.connect(change_project_path) mainWindow.renew_signal.connect(change_project_path)
apply_stylesheet(app, theme='dark_amber.xml') apply_stylesheet(app, theme='dark_amber.xml')
mainWindow.show() # mainWindow.show()
currentExitCode = app.exec_() mainWindow.showMaximized()
app = None currentExitCode = app.exec_()
app = None
except Exception as e: except Exception as e:
exc_traceback = ''.join( exc_traceback = ''.join(
traceback.format_exception(*sys.exc_info())) traceback.format_exception(*sys.exc_info()))
......
### 安装环境
```
pip install -r requirements.txt
```
### 接口
infer.py
\ No newline at end of file
import sys
import os
sys.path.append(os.path.dirname(__file__))
from .infer import tts
from .utils import get_hparams_from_file
\ No newline at end of file
This diff is collapsed.
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertConfig, BertTokenizer
class CharEmbedding(nn.Module):
def __init__(self, model_dir):
super().__init__()
self.tokenizer = BertTokenizer.from_pretrained(model_dir)
self.bert_config = BertConfig.from_pretrained(model_dir)
self.hidden_size = self.bert_config.hidden_size
self.bert = BertModel(self.bert_config)
self.proj = nn.Linear(self.hidden_size, 256)
self.linear = nn.Linear(256, 3)
def text2Token(self, text):
token = self.tokenizer.tokenize(text)
txtid = self.tokenizer.convert_tokens_to_ids(token)
return txtid
def forward(self, inputs_ids, inputs_masks, tokens_type_ids):
out_seq = self.bert(input_ids=inputs_ids,
attention_mask=inputs_masks,
token_type_ids=tokens_type_ids)[0]
out_seq = self.proj(out_seq)
return out_seq
class TTSProsody(object):
def __init__(self, path, device):
self.device = device
self.char_model = CharEmbedding(path)
self.char_model.load_state_dict(
torch.load(
os.path.join(path, 'prosody_model.pt'),
map_location="cpu"
),
strict=False
)
self.char_model.eval()
self.char_model.to(self.device)
def get_char_embeds(self, text):
input_ids = self.char_model.text2Token(text)
input_masks = [1] * len(input_ids)
type_ids = [0] * len(input_ids)
input_ids = torch.LongTensor([input_ids]).to(self.device)
input_masks = torch.LongTensor([input_masks]).to(self.device)
type_ids = torch.LongTensor([type_ids]).to(self.device)
with torch.no_grad():
char_embeds = self.char_model(
input_ids, input_masks, type_ids).squeeze(0).cpu()
return char_embeds
def expand_for_phone(self, char_embeds, length): # length of phones for char
assert char_embeds.size(0) == len(length)
expand_vecs = list()
for vec, leng in zip(char_embeds, length):
vec = vec.expand(leng, -1)
expand_vecs.append(vec)
expand_embeds = torch.cat(expand_vecs, 0)
assert expand_embeds.size(0) == sum(length)
return expand_embeds.numpy()
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
prosody = TTSProsody('./bert/', device)
while True:
text = input("请输入文本:")
prosody.get_char_embeds(text)
from .ProsodyModel import TTSProsody
\ No newline at end of file
{
"attention_probs_dropout_prob": 0.1,
"directionality": "bidi",
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pooler_fc_size": 768,
"pooler_num_attention_heads": 12,
"pooler_num_fc_layers": 3,
"pooler_size_per_head": 128,
"pooler_type": "first_token_transform",
"type_vocab_size": 2,
"vocab_size": 21128
}
This diff is collapsed.
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment