diff --git a/config.py b/config.py index 600baf1..5a3fba7 100644 --- a/config.py +++ b/config.py @@ -9,8 +9,10 @@ class ServerConfig(object): videos_upload_path : str = os.path.join(videos_root_path, 'upload') videos_chunk_size : int = 1024 * 1024 * 1024 # 数据文件配置 - hdf5_root_path : str = os.path.join('data', 'h5s') - + h5s_root_path : str = os.path.join('data', 'h5s') + # 权重文件配置 + weights_root_path : str = os.path.join('data', 'weights') + weights_dsnet_af : str = os.path.join(weights_root_path, 'dsnet_af.pth') # TODO diff --git a/data/weights/dsnet_af.pth b/data/weights/dsnet_af.pth new file mode 100644 index 0000000..6b02664 Binary files /dev/null and b/data/weights/dsnet_af.pth differ diff --git a/dsnet.py b/dsnet.py index 3d7fbf6..ca8482b 100644 --- a/dsnet.py +++ b/dsnet.py @@ -379,37 +379,3 @@ def bbox2summary(seq_len: int, pred_summ = get_keyshot_summ(score, change_points, n_frames, nfps, picks) return pred_summ - - -video_preprocessor = VideoPreprocessor() -dsnet_af = DSNetAF(1024, 128, 8).to(DSNetConfig.device) - - -def extract(video_path: str) -> Tensor: - n_frames, seq, cps, nfps, picks = video_preprocessor.run(video_path) - with torch.no_grad(): - seq_torch = torch.from_numpy(seq).unsqueeze(0).to(DSNetConfig.device) - pred_cls, pred_bboxes = dsnet_af.predict(seq_torch) - pred_bboxes = np.clip(pred_bboxes, 0, len(seq)).round().astype(np.int32) - pred_cls, pred_bboxes = nms(pred_cls, pred_bboxes, DSNetConfig.nms_thresh) - pred_summ = bbox2summary(len(seq), pred_cls, pred_bboxes, cps, n_frames, nfps, picks) - cap = cv.VideoCapture(video_path) - width = int(cap.get(cv.CAP_PROP_FRAME_WIDTH)) - height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT)) - fps = cap.get(cv.CAP_PROP_FPS) - - # create summary video writer - fourcc = cv.VideoWriter.fourcc(*'mp4v') - out = cv.VideoWriter(args.save_path, fourcc, fps, (width, height)) - - frame_idx = 0 - while True: - ret, frame = cap.read() - if not ret: - break - if pred_summ[frame_idx]: - out.write(frame) - frame_idx += 1 - out.release() - cap.release() - diff --git a/main.py b/main.py index b4ca6e3..985a8e3 100644 --- a/main.py +++ b/main.py @@ -25,6 +25,7 @@ app.add_middleware( video_preprocessor = VideoPreprocessor() dsnet_af = DSNetAF(1024, 128, 8).to(DSNetConfig.device) +dsnet_af.load_state_dict(torch.load(ServerConfig.weights_dsnet_af, map_location=lambda storage, loc: storage)) @app.post('/upload') @@ -43,6 +44,7 @@ async def fetch_file(id: str, range: str = Header(None)) -> StreamingResponse: start = 0 end = video_size - 1 chunk_size = min(video_size, ServerConfig.videos_chunk_size) + status = 200 headers = { 'Content-Type': 'video/mp4', 'Content-Disposition': f'attachment; filename="{id}.mp4"', @@ -52,6 +54,7 @@ async def fetch_file(id: str, range: str = Header(None)) -> StreamingResponse: start = int(start) end = int(end) if end else video_size - 1 chunk_size = min(end - start + 1, ServerConfig.videos_chunk_size) + status = 206 headers['Content-Range'] = f'bytes {start}-{end}/{video_size}' headers['Accept-Ranges'] = 'bytes' headers['Content-Disposition'] = 'inline' @@ -63,13 +66,13 @@ async def fetch_file(id: str, range: str = Header(None)) -> StreamingResponse: if not data: break yield data - return StreamingResponse(file_reader(), status_code=206, headers=headers, media_type='video/mp4') + return StreamingResponse(file_reader(), status_code=status, headers=headers, media_type='video/mp4') @app.post('/extract') async def extract_file(id: str = Form(...)): video_source_path = os.path.join(ServerConfig.videos_upload_path, id + '.mp4') - data_target_path = os.path.join(ServerConfig.hdf5_root_path, id + '.h5') + data_target_path = os.path.join(ServerConfig.h5s_root_path, id + '.h5') with h5py.File(data_target_path, 'a') as data: if id not in data: n_frames, seq, cps, nfps, picks = video_preprocessor.run(video_source_path) @@ -83,7 +86,7 @@ async def extract_file(id: str = Form(...)): @app.post('/analyse') async def analyse_file(id: str = Form(...)): - data_path = os.path.join(ServerConfig.hdf5_root_path, id + '.h5') + data_path = os.path.join(ServerConfig.h5s_root_path, id + '.h5') pred_summ = None with h5py.File(data_path, 'r') as data: if 'vsumm' not in data[id]: @@ -109,7 +112,7 @@ async def generate_file(id: str = Form(...)): origin_video_path = os.path.join(ServerConfig.videos_upload_path, id + '.mp4') target_video_path = os.path.join(ServerConfig.videos_play_path, id + '.mpeg4.mp4') final_video_path = os.path.join(ServerConfig.videos_play_path, id + '.mp4') - data_path = os.path.join(ServerConfig.hdf5_root_path, id + '.h5') + data_path = os.path.join(ServerConfig.h5s_root_path, id + '.h5') with h5py.File(data_path, 'r') as data: vsumm = data[id]['vsumm'][...].astype(np.bool_) cap = cv.VideoCapture(origin_video_path) @@ -117,8 +120,7 @@ async def generate_file(id: str = Form(...)): height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT)) fps = cap.get(cv.CAP_PROP_FPS) - # create summary video writer - fourcc = cv.VideoWriter.fourcc(*'H264') + fourcc = cv.VideoWriter.fourcc(*'MP4V') out = cv.VideoWriter(target_video_path, fourcc, fps, (width, height)) frame_idx = 0 @@ -138,8 +140,5 @@ async def generate_file(id: str = Form(...)): if __name__ == "__main__": - fourcc = cv.VideoWriter.fourcc(*'H264') - out = cv.VideoWriter('123.mp4', fourcc, 30, (1024, 768)) - pass - #uvicorn.run(app=app) + uvicorn.run(app=app) diff --git a/static/index.html b/static/index.html index a7eea90..df00e2c 100644 --- a/static/index.html +++ b/static/index.html @@ -50,7 +50,6 @@ @@ -63,7 +62,7 @@
- +
diff --git a/static/js/index.js b/static/js/index.js index 71e2d41..31248d1 100644 --- a/static/js/index.js +++ b/static/js/index.js @@ -178,16 +178,12 @@ function onAnalyseFileFailed(index, file, err) { }); } - - - - - /** * 当文件生成成功后, 此函数将被回调. 此时, 服务器已经保存了视频文件的摘要. * * 1) 将表格 status-table 中进度栏的进度信息更新为100% * 2) 调用服务器接口, 生成摘要视频 + * 3) * * @param {number} index 文件的索引 * @param {File} file 文件对象 @@ -197,7 +193,13 @@ function onGenerateFileFinished(index, file, data) { $(`#stat-label-${index}`).html('
操作完成
'); $(`#stat-prog-${index}`).css({'width': '100%'}); $(`#stat-prog-${index}`).html('100%'); - //TODO: + $(`#stat-button-${index}`).removeClass('btn-dark disabled'); + $(`#stat-button-${index}`).addClass('btn-success'); + $(`#stat-button-${index}`).find('span').remove(); + $(`#stat-button-${index}`).text('查看'); + $(`#stat-button-${index}`).on('click', function() { + playFile(index, file, data.id); + }); } /** @@ -284,10 +286,10 @@ function playFile(index, file, id) { $('#video-modal').find('.modal-title').text(`预览${file ? file.name : '此视频'}的关键镜头`); $('#video-modal').find('video').attr('src', `http://127.0.0.1:8000/fetch/${id}`); var player = new Plyr('video'); - $('#video-modal').modal(); $('#video-modal').on('hidden.bs.modal', function(event) { player.pause(); }); + $('#video-modal').modal(); } /**