diff --git a/config.py b/config.py
index 8dd12e4..600baf1 100644
--- a/config.py
+++ b/config.py
@@ -1,14 +1,25 @@
 import os
+import torch
+
 
 class ServerConfig(object):
     # Video file settings
-    video_root_path : str = os.path.join('data', 'video')
-    video_play_path : str = os.path.join(video_root_path, 'play')
-    video_upload_path : str = os.path.join(video_root_path, 'upload')
-    video_chunk_size : int = 1024 * 1024 * 1024
+    videos_root_path : str = os.path.join('data', 'videos')
+    videos_play_path : str = os.path.join(videos_root_path, 'play')
+    videos_upload_path : str = os.path.join(videos_root_path, 'upload')
+    videos_chunk_size : int = 1024 * 1024 * 1024
+    # Data (HDF5) file settings
+    hdf5_root_path : str = os.path.join('data', 'h5s')
     # TODO
 
-class ModelConfig(object):
+
+class DSNetConfig(object):
+    device : str = 'cuda' if torch.cuda.is_available() else 'cpu'
+    random_seed : int = 4906
+    sample_rate : int = 15
+    nms_thresh : float = 0.5
+    lambda_center : float = 1.
     # TODO
     pass
diff --git a/dsnet.py b/dsnet.py
new file mode 100644
index 0000000..3d7fbf6
--- /dev/null
+++ b/dsnet.py
@@ -0,0 +1,415 @@
+import math
+import torch
+import torch.nn as nn
+import cv2 as cv
+import numpy as np
+import torchvision
+
+from pathlib import Path
+from PIL import Image
+from torch import Tensor
+from torch.nn import Module
+from numpy import ndarray, linalg
+from torchvision import transforms
+from ortools.algorithms.pywrapknapsack_solver import KnapsackSolver
+from config import DSNetConfig
+from typing import Iterable
+
+
+class FeatureExtractor(object):
+
+    def __init__(self):
+        super(FeatureExtractor, self).__init__()
+        self.preprocess = transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ])
+        self.model = torchvision.models.googlenet(pretrained=True)
+        self.model = torch.nn.Sequential(*list(self.model.children())[:-2])
+        self.model = self.model.to(DSNetConfig.device).eval()
+
+    def run(self, img: ndarray) -> ndarray:
+        img : Image = Image.fromarray(img)
+        tensor : Tensor = self.preprocess(img)
+        batch : Tensor = tensor.unsqueeze(0)
+        with torch.no_grad():
+            feat : Tensor|ndarray|None = None
+            feat = self.model(batch.to(DSNetConfig.device))
+            feat = feat.squeeze().cpu().numpy()
+
+        assert feat.shape == (1024,), f'Invalid feature shape {feat.shape}: expected 1024'
+        feat /= linalg.norm(feat) + 1e-10
+        return feat
+
+
+class VideoPreprocessor(object):
+
+    def __init__(self) -> None:
+        super(VideoPreprocessor, self).__init__()
+        self.model = FeatureExtractor()
+        self.sample_rate = DSNetConfig.sample_rate
+
+    def get_features(self, video_path: str) -> tuple[int, ndarray]:
+        video_path = Path(video_path)
+        video_capture = cv.VideoCapture(str(video_path))
+        assert video_capture.isOpened(), f'Cannot open video: {video_path}'
+
+        features: list[ndarray] = []
+        n_frames: int = 0
+        while True:
+            ret, frame = video_capture.read()
+            if not ret:
+                break
+            if n_frames % self.sample_rate == 0:
+                frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
+                feat = self.model.run(frame)
+                features.append(feat)
+            n_frames += 1
+        video_capture.release()
+        features = np.array(features)
+        return n_frames, features
+
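+    # Kernel temporal segmentation (KTS) helpers. calculate_scatters() takes the Gram
+    # matrix K of the sampled features and returns, for every candidate segment [i, j],
+    # its within-segment scatter
+    #     scatters[i, j] = sum_{t=i..j} K[t, t] - (1 / (j - i + 1)) * sum_{s,t=i..j} K[s, t],
+    # evaluated for all segments at once via cumulative sums.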
+    def calculate_scatters(K: ndarray) -> ndarray:
+        n = K.shape[0]
+        K1 = np.cumsum([0] + list(np.diag(K)))
+        K2 = np.zeros((n + 1, n + 1))
+        K2[1:, 1:] = np.cumsum(np.cumsum(K, 0), 1)
+
+        diagK2 = np.diag(K2)
+
+        i = np.arange(n).reshape((-1, 1))
+        j = np.arange(n).reshape((1, -1))
+        scatters = (
+            K1[1:].reshape((1, -1)) - K1[:-1].reshape((-1, 1)) -
+            (diagK2[1:].reshape((1, -1)) + diagK2[:-1].reshape((-1, 1)) -
+             K2[1:, :-1].T - K2[:-1, 1:]) /
+            ((j - i + 1).astype(np.float32) + (j == i - 1).astype(np.float32))
+        )
+        scatters[j < i] = 0
+        return scatters
+
+    def change_point_detect_nonlin(K: ndarray, ncp: int, lmin: int = 1, lmax: int = 100000, backtrack=True) -> tuple[ndarray, ndarray]:
+        m = int(ncp)
+        n, n1 = K.shape
+        assert n == n1, 'Kernel matrix awaited.'
+        assert (m + 1) * lmin <= n <= (m + 1) * lmax
+        assert 1 <= lmin <= lmax
+
+        J = VideoPreprocessor.calculate_scatters(K)
+        I = 1e101 * np.ones((m + 1, n + 1))
+        I[0, lmin:lmax] = J[0, lmin - 1:lmax - 1]
+        p = np.zeros((m + 1, n + 1), dtype=int) if backtrack else np.zeros((1, 1), dtype=int)
+
+        for k in range(1, m + 1):                    # k: k change points placed so far, i.e. the video is split into (k + 1) segments
+            for l in range((k + 1) * lmin, n + 1):   # l: end position of the current segment once k change points are present
+                tmin = max(k * lmin, l - lmax)       # tmin: the k existing segments occupy at least k * lmin frames, i.e. the earliest start of the current segment
+                tmax = l - lmin + 1
+                c = J[tmin:tmax, l - 1].reshape(-1) + \
+                    I[k - 1, tmin:tmax].reshape(-1)
+                I[k, l] = np.min(c)
+                if backtrack:
+                    p[k, l] = np.argmin(c) + tmin
+
+        cps = np.zeros(m, dtype=int)
+        if backtrack:
+            cur = n
+            for k in range(m, 0, -1):
+                cps[k - 1] = p[k, cur]
+                cur = cps[k - 1]
+        scores = I[:, n].copy()
+        scores[scores > 1e99] = np.inf
+        return cps, scores
+
+    def change_point_detect_auto(K: ndarray) -> tuple[ndarray, ndarray]:
+        m, N = len(K) - 1, len(K)
+        _, scores = VideoPreprocessor.change_point_detect_nonlin(K, m, backtrack=False)
+
+        penalties = np.zeros(m + 1)
+        ncp = np.arange(1, m + 1)
+        penalties[1:] = (ncp / (2.0 * N)) * (np.log(float(N) / ncp) + 1)
+
+        costs = scores / float(N) + penalties
+        m_best = np.argmin(costs)
+        return VideoPreprocessor.change_point_detect_nonlin(K, m_best)
+
+    def kernel_temporal_segment(self, n_frames: int, features: ndarray) -> tuple[ndarray, ndarray, ndarray]:
+        seq_len = len(features)
+        picks = np.arange(0, seq_len) * self.sample_rate
+
+        kernel = np.matmul(features, features.T)
+        change_points, _ = VideoPreprocessor.change_point_detect_auto(kernel)
+        change_points *= self.sample_rate
+        change_points = np.hstack((0, change_points, n_frames))
+        begin_frames = change_points[:-1]
+        end_frames = change_points[1:]
+        change_points = np.vstack((begin_frames, end_frames - 1)).T
+
+        n_frame_per_seg = end_frames - begin_frames
+        return change_points, n_frame_per_seg, picks
+
+    def run(self, video_path: str) -> tuple[int, ndarray, ndarray, ndarray, ndarray]:
+        n_frames, features = self.get_features(video_path)
+        cps, nfps, picks = self.kernel_temporal_segment(n_frames, features)
+        return n_frames, features, cps, nfps, picks
+
+
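+# Attention backbone of DSNet: standard scaled dot-product attention,
+# softmax(Q K^T / sqrt(d_k)) V, with dropout applied to the attention weights.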
+class ScaledDotProductAttention(Module):
+    def __init__(self, d_k: float):
+        super(ScaledDotProductAttention, self).__init__()
+        self.dropout = nn.Dropout(0.5)
+        self.sqrt_d_k = math.sqrt(d_k)
+
+    def forward(self, Q: Tensor, K: Tensor, V: Tensor) -> tuple[Tensor, Tensor]:
+        attn = torch.bmm(Q, K.transpose(2, 1))
+        attn = attn / self.sqrt_d_k
+        attn = torch.softmax(attn, dim=-1)
+        attn = self.dropout(attn)
+        y = torch.bmm(attn, V)
+        return y, attn
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, num_head: int, num_feature: int) -> None:
+        super(MultiHeadAttention, self).__init__()
+        self.num_head = num_head
+        self.Q = nn.Linear(num_feature, num_feature, bias=False)
+        self.K = nn.Linear(num_feature, num_feature, bias=False)
+        self.V = nn.Linear(num_feature, num_feature, bias=False)
+        self.d_k = num_feature // num_head
+        self.attention = ScaledDotProductAttention(self.d_k)
+        self.fc = nn.Sequential(
+            nn.Linear(num_feature, num_feature, bias=False),
+            nn.Dropout(0.5)
+        )
+
+    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
+        _, seq_len, num_feature = x.shape  # [1, seq_len, 1024]
+        K: Tensor = self.K(x)  # [1, seq_len, 1024]
+        Q: Tensor = self.Q(x)  # [1, seq_len, 1024]
+        V: Tensor = self.V(x)  # [1, seq_len, 1024]
+
+        K = K.view(1, seq_len, self.num_head, self.d_k).permute(
+            2, 0, 1, 3).contiguous().view(self.num_head, seq_len, self.d_k)
+        Q = Q.view(1, seq_len, self.num_head, self.d_k).permute(
+            2, 0, 1, 3).contiguous().view(self.num_head, seq_len, self.d_k)
+        V = V.view(1, seq_len, self.num_head, self.d_k).permute(
+            2, 0, 1, 3).contiguous().view(self.num_head, seq_len, self.d_k)
+
+        y, attn = self.attention(Q, K, V)  # [num_head, seq_len, d_k]
+        y = y.view(1, self.num_head, seq_len, self.d_k).permute(
+            0, 2, 1, 3).contiguous().view(1, seq_len, num_feature)
+
+        y = self.fc(y)
+        return y, attn
+
+
+class AttentionExtractor(MultiHeadAttention):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, *inputs):
+        out, _ = super().forward(*inputs)
+        return out
+
+
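+# Anchor-free summarization head (DSNet-AF): for every down-sampled position the model
+# predicts an importance score (fc_cls), left/right offsets to the segment boundaries
+# (fc_loc) and a center-ness score (fc_ctr); offset2bbox() turns the offsets into
+# [left, right) segment proposals.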
+class DSNetAF(Module):
+
+    def __init__(self, num_feature: int, num_hidden: int, num_head: int) -> None:
+        super(DSNetAF, self).__init__()
+        self.base_model = AttentionExtractor(num_head, num_feature)
+        self.layer_norm = nn.LayerNorm(num_feature)
+        self.fc1 = nn.Sequential(
+            nn.Linear(num_feature, num_hidden),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.5),
+            nn.LayerNorm(num_hidden),
+        )
+        self.fc_cls = nn.Linear(num_hidden, 1)
+        self.fc_loc = nn.Linear(num_hidden, 2)
+        self.fc_ctr = nn.Linear(num_hidden, 1)
+
+    def offset2bbox(offsets: np.ndarray) -> np.ndarray:
+        offset_left, offset_right = offsets[:, 0], offsets[:, 1]
+        seq_len, _ = offsets.shape
+        indices = np.arange(seq_len)
+        bbox_left = indices - offset_left
+        bbox_right = indices + offset_right + 1
+        bboxes = np.vstack((bbox_left, bbox_right)).T
+        return bboxes
+
+    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
+        _, seq_len, _ = x.shape
+        out = self.base_model(x)
+        out = out + x
+        out = self.layer_norm(out)
+        out = self.fc1(out)
+        pred_cls = self.fc_cls(out).sigmoid().view(seq_len)
+        pred_loc = self.fc_loc(out).exp().view(seq_len, 2)
+        pred_ctr = self.fc_ctr(out).sigmoid().view(seq_len)
+        return pred_cls, pred_loc, pred_ctr
+
+    def predict(self, seq: Tensor) -> tuple[ndarray, ndarray]:
+        pred_cls, pred_loc, pred_ctr = self(seq)
+        pred_cls *= pred_ctr
+        pred_cls /= pred_cls.max() + 1e-8
+        pred_cls = pred_cls.cpu().numpy()
+        pred_loc = pred_loc.cpu().numpy()
+        pred_bboxes = DSNetAF.offset2bbox(pred_loc)
+        return pred_cls, pred_bboxes
+
+
+def iou_lr(anchor_bbox: np.ndarray, target_bbox: np.ndarray) -> np.ndarray:
+    anchor_left, anchor_right = anchor_bbox[:, 0], anchor_bbox[:, 1]
+    target_left, target_right = target_bbox[:, 0], target_bbox[:, 1]
+
+    inter_left = np.maximum(anchor_left, target_left)
+    inter_right = np.minimum(anchor_right, target_right)
+    union_left = np.minimum(anchor_left, target_left)
+    union_right = np.maximum(anchor_right, target_right)
+
+    intersect = inter_right - inter_left
+    intersect[intersect < 0] = 0
+    union = union_right - union_left
+    union[union <= 0] = 1e-6
+
+    iou = intersect / union
+    return iou
+
+
+def nms(scores: np.ndarray, bboxes: np.ndarray, thresh: float) -> tuple[np.ndarray, np.ndarray]:
+    valid_idx = bboxes[:, 0] < bboxes[:, 1]
+    scores = scores[valid_idx]
+    bboxes = bboxes[valid_idx]
+
+    arg_desc = scores.argsort()[::-1]
+
+    scores_remain = scores[arg_desc]
+    bboxes_remain = bboxes[arg_desc]
+
+    keep_bboxes = []
+    keep_scores = []
+
+    while bboxes_remain.size > 0:
+        bbox = bboxes_remain[0]
+        score = scores_remain[0]
+        keep_bboxes.append(bbox)
+        keep_scores.append(score)
+
+        iou = iou_lr(bboxes_remain, np.expand_dims(bbox, axis=0))
+
+        keep_indices = (iou < thresh)
+        bboxes_remain = bboxes_remain[keep_indices]
+        scores_remain = scores_remain[keep_indices]
+
+    keep_bboxes = np.asarray(keep_bboxes, dtype=bboxes.dtype)
+    keep_scores = np.asarray(keep_scores, dtype=scores.dtype)
+    return keep_scores, keep_bboxes
+
+
+def knapsack(values: Iterable[int],
+             weights: Iterable[int],
+             capacity: int
+             ) -> list[int]:
+
+    knapsack_solver = KnapsackSolver(
+        KnapsackSolver.KNAPSACK_DYNAMIC_PROGRAMMING_SOLVER, 'test'
+    )
+
+    values = list(values)
+    weights = list(weights)
+    capacity = int(capacity)
+
+    knapsack_solver.Init(values, [weights], [capacity])
+    knapsack_solver.Solve()
+    packed_items = [x for x in range(0, len(weights))
+                    if knapsack_solver.BestSolutionContains(x)]
+
+    return packed_items
+
+
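+# Key-shot summary: spread the down-sampled scores back to frame level, average them per
+# KTS shot, then select shots with a 0/1 knapsack so the summary stays within
+# `proportion` (15% by default) of the original video length.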
+def get_keyshot_summ(pred: np.ndarray,
+                     cps: np.ndarray,
+                     n_frames: int,
+                     nfps: np.ndarray,
+                     picks: np.ndarray,
+                     proportion: float = 0.15
+                     ) -> np.ndarray:
+    assert pred.shape == picks.shape
+    picks = np.asarray(picks, dtype=np.int32)
+
+    # Get original frame scores from downsampled sequence
+    frame_scores = np.zeros(n_frames, dtype=np.float32)
+    for i in range(len(picks)):
+        pos_lo = picks[i]
+        pos_hi = picks[i + 1] if i + 1 < len(picks) else n_frames
+        frame_scores[pos_lo:pos_hi] = pred[i]
+
+    # Assign scores to video shots as the average of the frames.
+    seg_scores = np.zeros(len(cps), dtype=np.int32)
+    for seg_idx, (first, last) in enumerate(cps):
+        scores = frame_scores[first:last + 1]
+        seg_scores[seg_idx] = int(1000 * scores.mean())
+
+    # Apply knapsack algorithm to find the best shots
+    limits = int(n_frames * proportion)
+    packed = knapsack(seg_scores, nfps, limits)
+
+    # Get key-shot based summary
+    summary = np.zeros(n_frames, dtype=np.bool_)
+    for seg_idx in packed:
+        first, last = cps[seg_idx]
+        summary[first:last + 1] = True
+    return summary
+
+
+def bbox2summary(seq_len: int,
+                 pred_cls: np.ndarray,
+                 pred_bboxes: np.ndarray,
+                 change_points: np.ndarray,
+                 n_frames: int,
+                 nfps: np.ndarray,
+                 picks: np.ndarray
+                 ) -> np.ndarray:
+    score = np.zeros(seq_len, dtype=np.float32)
+    for bbox_idx in range(len(pred_bboxes)):
+        lo, hi = pred_bboxes[bbox_idx, 0], pred_bboxes[bbox_idx, 1]
+        score[lo:hi] = np.maximum(score[lo:hi], [pred_cls[bbox_idx]])
+
+    pred_summ = get_keyshot_summ(score, change_points, n_frames, nfps, picks)
+    return pred_summ
+
+
+video_preprocessor = VideoPreprocessor()
+dsnet_af = DSNetAF(1024, 128, 8).to(DSNetConfig.device)
+
+
+def extract(video_path: str, save_path: str) -> None:
+    n_frames, seq, cps, nfps, picks = video_preprocessor.run(video_path)
+    with torch.no_grad():
+        seq_torch = torch.from_numpy(seq).unsqueeze(0).to(DSNetConfig.device)
+        pred_cls, pred_bboxes = dsnet_af.predict(seq_torch)
+    pred_bboxes = np.clip(pred_bboxes, 0, len(seq)).round().astype(np.int32)
+    pred_cls, pred_bboxes = nms(pred_cls, pred_bboxes, DSNetConfig.nms_thresh)
+    pred_summ = bbox2summary(len(seq), pred_cls, pred_bboxes, cps, n_frames, nfps, picks)
+
+    cap = cv.VideoCapture(video_path)
+    width = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))
+    fps = cap.get(cv.CAP_PROP_FPS)
+
+    # create summary video writer
+    fourcc = cv.VideoWriter.fourcc(*'mp4v')
+    out = cv.VideoWriter(save_path, fourcc, fps, (width, height))
+
+    frame_idx = 0
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            break
+        if pred_summ[frame_idx]:
+            out.write(frame)
+        frame_idx += 1
+    out.release()
+    cap.release()
diff --git a/main.py b/main.py
index f458f76..b4ca6e3 100644
--- a/main.py
+++ b/main.py
@@ -1,12 +1,15 @@
 import os
+import h5py
 import uvicorn
+import ffmpeg
 
 from pathlib import Path
-from fastapi import FastAPI, UploadFile, File, Header
+from fastapi import FastAPI, UploadFile, Form, File, Header
 from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
 from config import ServerConfig
+from dsnet import *
 
 
 app = FastAPI()
@@ -20,21 +23,26 @@ app.add_middleware(
 )
 
 
+video_preprocessor = VideoPreprocessor()
+dsnet_af = DSNetAF(1024, 128, 8).to(DSNetConfig.device)
+
+
 @app.post('/upload')
 async def upload_file(file: UploadFile = File(...)):
-    video_path = os.path.join(ServerConfig.video_upload_path, file.filename + '.mp4')
+    video_name = os.path.splitext(file.filename)[0]
+    video_path = os.path.join(ServerConfig.videos_upload_path, video_name + '.mp4')
     with open(video_path, 'wb') as video:
         video.write(await file.read())
-    return {'status': 200, 'id': os.path.splitext(file.filename)[0]}
+    return {'status': 200, 'id': video_name}
 
 
 @app.get("/fetch/{id}")
 async def fetch_file(id: str, range: str = Header(None)) -> StreamingResponse:
-    video_path = Path(os.path.join(ServerConfig.video_play_path, id + '.mp4'))
+    video_path = Path(os.path.join(ServerConfig.videos_play_path, id + '.mp4'))
     video_size = video_path.stat().st_size
     start = 0
     end = video_size - 1
-    chunk_size = min(video_size, ServerConfig.video_chunk_size)
+    chunk_size = min(video_size, ServerConfig.videos_chunk_size)
     headers = {
         'Content-Type': 'video/mp4',
         'Content-Disposition': f'attachment; filename="{id}.mp4"',
@@ -43,7 +51,7 @@ async def fetch_file(id: str, range: str = Header(None)) -> StreamingResponse:
         start, end = range.replace('bytes=', '').split('-')
         start = int(start)
         end = int(end) if end else video_size - 1
-        chunk_size = min(end - start + 1, ServerConfig.video_chunk_size)
+        chunk_size = min(end - start + 1, ServerConfig.videos_chunk_size)
         headers['Content-Range'] = f'bytes {start}-{end}/{video_size}'
         headers['Accept-Ranges'] = 'bytes'
         headers['Content-Disposition'] = 'inline'
@@ -58,9 +66,80 @@ async def fetch_file(id: str, range: str = Header(None)) -> StreamingResponse:
     return StreamingResponse(file_reader(), status_code=206, headers=headers, media_type='video/mp4')
 
 
+@app.post('/extract')
+async def extract_file(id: str = Form(...)):
+    video_source_path = os.path.join(ServerConfig.videos_upload_path, id + '.mp4')
+    data_target_path = os.path.join(ServerConfig.hdf5_root_path, id + '.h5')
+    with h5py.File(data_target_path, 'a') as data:
+        if id not in data:
+            n_frames, seq, cps, nfps, picks = video_preprocessor.run(video_source_path)
+            data.create_dataset(f'{id}/features', data=seq)
+            data.create_dataset(f'{id}/change_points', data=cps)
+            data.create_dataset(f'{id}/n_frame_per_seg', data=nfps)
+            data.create_dataset(f'{id}/n_frames', data=n_frames)
+            data.create_dataset(f'{id}/picks', data=picks)
+    return {'status': 200, 'id': id}
+
+
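+# /analyse runs DSNet-AF on the cached features: per-position importance scores and
+# left/right segment offsets are predicted, fused with the center-ness score, filtered
+# with NMS and converted into a key-shot summary stored back into the h5 file as 'vsumm'.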
+@app.post('/analyse')
+async def analyse_file(id: str = Form(...)):
+    data_path = os.path.join(ServerConfig.hdf5_root_path, id + '.h5')
+    pred_summ = None
+    with h5py.File(data_path, 'r') as data:
+        if 'vsumm' not in data[id]:
+            seq = data[id]['features'][...].astype(np.float32)
+            cps = data[id]['change_points'][...].astype(np.int32)
+            nfps = data[id]['n_frame_per_seg'][...].astype(np.int32)
+            n_frames = data[id]['n_frames'][...].astype(np.int32)
+            picks = data[id]['picks'][...].astype(np.int32)
+            with torch.no_grad():
+                seq_torch = torch.from_numpy(seq).unsqueeze(0).to(DSNetConfig.device)
+                pred_cls, pred_bboxes = dsnet_af.predict(seq_torch)
+            pred_bboxes = np.clip(pred_bboxes, 0, len(seq)).round().astype(np.int32)
+            pred_cls, pred_bboxes = nms(pred_cls, pred_bboxes, DSNetConfig.nms_thresh)
+            pred_summ = bbox2summary(len(seq), pred_cls, pred_bboxes, cps, n_frames, nfps, picks)
+    if pred_summ is not None:
+        with h5py.File(data_path, 'a') as data:
+            data.create_dataset(f'{id}/vsumm', data=pred_summ)
+    return {'status': 200, 'id': id}
+
+
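+# /generate writes the frames flagged in 'vsumm' to an intermediate file with OpenCV and
+# then re-encodes it with ffmpeg (output vcodec 'h264'), presumably so the result can be
+# streamed to the browser through /fetch.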
+@app.post('/generate')
+async def generate_file(id: str = Form(...)):
+    origin_video_path = os.path.join(ServerConfig.videos_upload_path, id + '.mp4')
+    target_video_path = os.path.join(ServerConfig.videos_play_path, id + '.mpeg4.mp4')
+    final_video_path = os.path.join(ServerConfig.videos_play_path, id + '.mp4')
+    data_path = os.path.join(ServerConfig.hdf5_root_path, id + '.h5')
+    with h5py.File(data_path, 'r') as data:
+        vsumm = data[id]['vsumm'][...].astype(np.bool_)
+
+    cap = cv.VideoCapture(origin_video_path)
+    width = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))
+    fps = cap.get(cv.CAP_PROP_FPS)
+
+    # create summary video writer
+    fourcc = cv.VideoWriter.fourcc(*'H264')
+    out = cv.VideoWriter(target_video_path, fourcc, fps, (width, height))
+
+    frame_idx = 0
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            break
+        if vsumm[frame_idx]:
+            out.write(frame)
+
+        frame_idx += 1
+    out.release()
+    cap.release()
+
+    ffmpeg.input(target_video_path, format='mp4', vcodec='mpeg4').output(final_video_path, format='mp4', vcodec='h264').run()
+    return {'status': 200, 'id': id}
+
 
 if __name__ == "__main__":
-    uvicorn.run(app=app)
+    fourcc = cv.VideoWriter.fourcc(*'H264')
+    out = cv.VideoWriter('123.mp4', fourcc, 30, (1024, 768))
+    pass
+    #uvicorn.run(app=app)
diff --git a/requirements.txt b/requirements.txt
index 0b78062..863c1f2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 mypy==1.9.0
-numpy==1.22.0
+h5py==3.11.0
+numpy==1.26.4
 torch==2.1.0
 torchvision==0.16.0
 opencv-python==4.9.0.80
@@ -7,3 +8,5 @@ opencv-contrib-python==4.9.0.80
 uvicorn==0.29.0
 fastapi==0.110.2
 python-multipart==0.0.9
+ortools==9.9.396
+ffmpeg-python==0.2.0
diff --git a/static/js/index.js b/static/js/index.js
index 3cc9f14..71e2d41 100644
--- a/static/js/index.js
+++ b/static/js/index.js
@@ -59,7 +59,7 @@ function onFileUploadStepped(index, percent) {
  *
  * 1) Change the status column of the status-table to '正在提取'
  * 2) Reset the progress bar of the status-table to '0%'
- * 3) Call the server API to start inference
+ * 3) Call the server API to start feature extraction
  *
  * @param {number} index Index of the file
  * @param {File} file File object
@@ -69,7 +69,7 @@ function onFileUploadFinished(index, file, data) {
     $(`#stat-label-${index}`).html('正在提取');
     $(`#stat-prog-${index}`).css({'width': '0%'});
     $(`#stat-prog-${index}`).html('0%');
-
+    extractFile(index, file, data.id);
 }
 
 /**
@@ -97,6 +97,134 @@
 }
 
 /**
+ * Called after feature extraction succeeds; the server has saved the video's features
+ * into the h5 file.
+ *
+ * 1) Update the progress bar of the status-table to 50%
+ * 2) Call the server API to start the key-frame analysis
+ *
+ * @param {number} index Index of the file
+ * @param {File} file File object
+ * @param {object} data Server response data
+ */
+function onExtractFileFinished(index, file, data) {
+    $(`#stat-label-${index}`).html('正在分析');
+    $(`#stat-prog-${index}`).css({'width': '50%'});
+    $(`#stat-prog-${index}`).html('50%');
+    analyseFile(index, file, data.id);
+}
+
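+// These callbacks chain the server pipeline: upload -> /extract (features, 50%)
+// -> /analyse (key shots, 75%) -> /generate (summary video, 100%). Each *Finished
+// handler updates the status-table row and starts the next request; each *Failed
+// handler resets the progress bar and reports the error through errorModal().
+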
+/**
+ * Called when feature extraction fails; updates the related UI components.
+ *
+ * 1) Change the status column of the status-table to '提取失败'.
+ * 2) Reset the progress bar of the status-table to '0%'
+ * 3) Switch the action button of the status-table to danger style and enable it so the
+ *    failure reason can be viewed.
+ *
+ * @param {number} index Index of the file
+ * @param {File} file File object
+ * @param {error} err Failure reason
+ */
+function onExtractFileFailed(index, file, err) {
+    $(`#stat-label-${index}`).html('提取失败');
+    $(`#stat-prog-${index}`).css({'width': '0%'});
+    $(`#stat-prog-${index}`).html('0%');
+    $(`#stat-button-${index}`).removeClass('btn-dark disabled');
+    $(`#stat-button-${index}`).addClass('btn-danger');
+    $(`#stat-button-${index}`).find('span').remove();
+    $(`#stat-button-${index}`).text('查看');
+    $(`#stat-button-${index}`).on('click', function() {
+        errorModal('错误', `文件${file.name}提取失败: ${err}`);
+    });
+}
+
+/**
+ * Called after the analysis succeeds; the server has saved both the features and the
+ * analysis result into the h5 file.
+ *
+ * 1) Update the progress bar of the status-table to 75%
+ * 2) Call the server API to generate the summary video
+ *
+ * @param {number} index Index of the file
+ * @param {File} file File object
+ * @param {object} data Server response data
+ */
+function onAnalyseFileFinished(index, file, data) {
+    $(`#stat-label-${index}`).html('正在导出');
+    $(`#stat-prog-${index}`).css({'width': '75%'});
+    $(`#stat-prog-${index}`).html('75%');
+    generateFile(index, file, data.id);
+}
+
+/**
+ * Called when the analysis fails; updates the related UI components.
+ *
+ * 1) Change the status column of the status-table to '分析失败'.
+ * 2) Reset the progress bar of the status-table to '0%'
+ * 3) Switch the action button of the status-table to danger style and enable it so the
+ *    failure reason can be viewed.
+ *
+ * @param {number} index Index of the file
+ * @param {File} file File object
+ * @param {error} err Failure reason
+ */
+function onAnalyseFileFailed(index, file, err) {
+    $(`#stat-label-${index}`).html('分析失败');
+    $(`#stat-prog-${index}`).css({'width': '0%'});
+    $(`#stat-prog-${index}`).html('0%');
+    $(`#stat-button-${index}`).removeClass('btn-dark disabled');
+    $(`#stat-button-${index}`).addClass('btn-danger');
+    $(`#stat-button-${index}`).find('span').remove();
+    $(`#stat-button-${index}`).text('查看');
+    $(`#stat-button-${index}`).on('click', function() {
+        errorModal('错误', `文件${file.name}分析失败: ${err}`);
+    });
+}
+
+/**
+ * Called after the summary video has been generated; the server has saved the video
+ * summary.
+ *
+ * 1) Update the progress bar of the status-table to 100%
+ *
+ * @param {number} index Index of the file
+ * @param {File} file File object
+ * @param {object} data Server response data
+ */
+function onGenerateFileFinished(index, file, data) {
+    $(`#stat-label-${index}`).html('操作完成');
+    $(`#stat-prog-${index}`).css({'width': '100%'});
+    $(`#stat-prog-${index}`).html('100%');
+    //TODO:
+}
+
+/**
+ * Called when summary generation fails; updates the related UI components.
+ *
+ * 1) Change the status column of the status-table to '生成失败'.
+ * 2) Reset the progress bar of the status-table to '0%'
+ * 3) Switch the action button of the status-table to danger style and enable it so the
+ *    failure reason can be viewed.
+ *
+ * @param {number} index Index of the file
+ * @param {File} file File object
+ * @param {error} err Failure reason
+ */
+function onGenerateFileFailed(index, file, err) {
+    $(`#stat-label-${index}`).html('生成失败');
+    $(`#stat-prog-${index}`).css({'width': '0%'});
+    $(`#stat-prog-${index}`).html('0%');
+    $(`#stat-button-${index}`).removeClass('btn-dark disabled');
+    $(`#stat-button-${index}`).addClass('btn-danger');
+    $(`#stat-button-${index}`).find('span').remove();
+    $(`#stat-button-${index}`).text('查看');
+    $(`#stat-button-${index}`).on('click', function() {
+        errorModal('错误', `文件${file.name}生成失败: ${err}`);
+    });
+}
+
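+// Note: in extractFile/analyseFile/generateFile below, a response with data.status != 200
+// still falls through to the corresponding *Finished callback; the error branches are
+// left as TODOs.
+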
+/**
  * Uploads the given file to the server.
  *
  * This function calls back onFileUploadStepped to report the upload progress.
@@ -148,6 +276,8 @@
 /**
  * Plays the server video with the given id in the result window.
  *
+ * @param {index} index Index of the file
+ * @param {file} file File object
  * @param {string} id Server-side video id
  */
 function playFile(index, file, id) {
@@ -161,6 +291,90 @@
 }
 
 /**
+ * Asks the server to extract the target file's features.
+ *
+ * @param {index} index Index of the file
+ * @param {file} file File object
+ * @param {string} id Server-side video id
+ */
+function extractFile(index, file, id) {
+    var formData = new FormData();
+    formData.append('id', id);
+    $.ajax({
+        url: 'http://127.0.0.1:8000/extract',
+        type: 'POST',
+        data: formData,
+        processData: false,  // do not process the data being sent
+        contentType: false,  // do not set a content type
+        success: function(data) {
+            if (data.status != 200) {
+                //TODO: add error handling
+            }
+            onExtractFileFinished(index, file, data);
+        },
+        error: function(xhr, status, error) {
+            onExtractFileFailed(index, file, error);
+        },
+    });
+}
+
+/**
+ * Asks the server to analyse the target file's key frames.
+ *
+ * @param {index} index Index of the file
+ * @param {file} file File object
+ * @param {string} id Server-side video id
+ */
+function analyseFile(index, file, id) {
+    var formData = new FormData();
+    formData.append('id', id);
+    $.ajax({
+        url: 'http://127.0.0.1:8000/analyse',
+        type: 'POST',
+        data: formData,
+        processData: false,  // do not process the data being sent
+        contentType: false,  // do not set a content type
+        success: function(data) {
+            if (data.status != 200) {
+                //TODO: add error handling
+            }
+            onAnalyseFileFinished(index, file, data);
+        },
+        error: function(xhr, status, error) {
+            onAnalyseFileFailed(index, file, error);
+        },
+    });
+}
+
+/**
+ * Asks the server to generate the target file's summary video.
+ *
+ * @param {index} index Index of the file
+ * @param {file} file File object
+ * @param {string} id Server-side video id
+ */
+function generateFile(index, file, id) {
+    var formData = new FormData();
+    formData.append('id', id);
+    $.ajax({
+        url: 'http://127.0.0.1:8000/generate',
+        type: 'POST',
+        data: formData,
+        processData: false,  // do not process the data being sent
+        contentType: false,  // do not set a content type
+        success: function(data) {
+            if (data.status != 200) {
+                //TODO: add error handling
+            }
+            onGenerateFileFinished(index, file, data);
+        },
+        error: function(xhr, status, error) {
+            onGenerateFileFailed(index, file, error);
+        },
+    });
+}
+
+/**
  * Appends a row to the 'status-table' showing the uploaded file's info and current status, along with the required actions.
  *
  * @param {number} index Number of the file being processed.