icdd-vsumm/main.py

155 lines
6.2 KiB
Python
Raw Normal View History

2024-04-20 13:52:25 +08:00
import os
2024-04-21 00:02:47 +08:00
import h5py
2024-04-20 13:52:25 +08:00
import uvicorn
2024-04-21 00:02:47 +08:00
import ffmpeg
2024-04-20 13:52:25 +08:00
from pathlib import Path
from fastapi import FastAPI, UploadFile, Form, File, Header, HTTPException
2024-04-20 13:52:25 +08:00
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from config import ServerConfig
2024-04-21 00:02:47 +08:00
from dsnet import *
2024-04-20 13:52:25 +08:00
# Application object shared by every route declared in this module.
app = FastAPI()

# Expose the contents of the local ``static`` directory under /home.
app.mount('/home', StaticFiles(directory='static', html=True), name='static')

# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# permissive setup; confirm it is intended before deploying publicly.
_CORS_OPTIONS = dict(
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
2024-04-21 00:02:47 +08:00
# Feature extractor used by /extract to turn an uploaded video into the
# arrays stored in its HDF5 file.
video_preprocessor = VideoPreprocessor()

# Summarisation model used by /analyse. The checkpoint is deserialised on the
# CPU and copied into the model, which already lives on DSNetConfig.device.
dsnet_af = DSNetAF(1024, 128, 8).to(DSNetConfig.device)
dsnet_af.load_state_dict(torch.load(ServerConfig.weights_dsnet_af, map_location=lambda storage, loc: storage))
# This process only serves inference: switch to evaluation mode so that any
# train-time-only layers (e.g. dropout) do not make predictions vary between
# calls.
dsnet_af.eval()
2024-04-21 00:02:47 +08:00
2024-04-20 15:11:08 +08:00
@app.post('/upload')
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded video in the upload directory.

    The video id is the client's file name stripped of directory components
    and of its extension; the content is always saved as ``<id>.mp4``.

    Args:
        file: Multipart upload carrying the video.

    Returns:
        A dict with ``status`` 200 and the ``id`` under which the video was
        saved.

    Raises:
        HTTPException: 400 when the upload has no usable file name.
    """
    # The file name is client-controlled: keep only its last path component
    # (for both separator styles) so it cannot point outside the upload
    # directory.
    client_name = (file.filename or '').replace('\\', '/')
    video_name = os.path.splitext(os.path.basename(client_name))[0]
    if not video_name:
        raise HTTPException(status_code=400, detail='file name is missing or invalid!')

    video_path = os.path.join(ServerConfig.videos_upload_path, video_name + '.mp4')
    with open(video_path, 'wb') as video:
        # Copy in bounded chunks so a large video is never held in memory
        # in one piece.
        while True:
            chunk = await file.read(1024 * 1024)
            if not chunk:
                break
            video.write(chunk)
    return {'status': 200, 'id': video_name}
2024-04-20 15:11:08 +08:00
def _parse_byte_range(range_value, video_size):
    """Resolve a single-range ``Range`` header value against a file size.

    Supports ``bytes=A-B``, ``bytes=A-`` and the suffix form ``bytes=-N``.

    Returns:
        An inclusive ``(start, end)`` pair with ``end`` clamped to the last
        byte of the file, or ``None`` when the value is malformed and should
        be ignored. A returned pair with ``start > end`` means the range
        cannot be satisfied for this file.
    """
    spec = range_value.strip()
    if spec.lower().startswith('bytes='):
        spec = spec[len('bytes='):]
    first, dash, last = spec.partition('-')
    if not dash:
        return None
    first, last = first.strip(), last.strip()
    try:
        if first:
            start = int(first)
            end = int(last) if last else video_size - 1
            if end < start and last:
                # Explicit last position before the first one: invalid spec.
                return None
        else:
            # Suffix form: the final N bytes of the file.
            suffix_length = int(last)
            start = max(video_size - suffix_length, 0)
            end = video_size - 1
    except ValueError:
        return None
    return start, min(end, video_size - 1)


@app.get("/fetch/{id}&location={location}")
async def fetch_file(id: str, location: str, range: str = Header(None)) -> StreamingResponse:
    """Stream a stored video, honouring an optional HTTP ``Range`` header.

    Args:
        id: Video id; the file served is ``<id>.mp4``.
        location: ``"play"`` to read from the play directory or ``"upload"``
            to read from the upload directory.
        range: Value of the request's ``Range`` header, if any.

    Returns:
        A 200 response with the whole file (as an attachment), or a 206
        response limited to the requested byte range (inline).

    Raises:
        HTTPException: 442 for an unknown ``location``, 404 when the video
            file does not exist, 416 when the requested range lies outside
            the file.
    """
    if location == "play":
        video_path = Path(os.path.join(ServerConfig.videos_play_path, id + '.mp4'))
    elif location == "upload":
        video_path = Path(os.path.join(ServerConfig.videos_upload_path, id + '.mp4'))
    else:
        # NOTE(review): 442 is not a registered HTTP status (422 may have
        # been intended); kept unchanged so existing clients are unaffected.
        raise HTTPException(status_code=442, detail=f"location {location} is not defined!")

    if not video_path.is_file():
        raise HTTPException(status_code=404, detail=f"video {id} is not found!")

    video_size = video_path.stat().st_size
    start = 0
    end = video_size - 1
    chunk_size = min(video_size, ServerConfig.videos_chunk_size)
    status = 200
    headers = {
        'Content-Type': 'video/mp4',
        'Content-Disposition': f'attachment; filename="{id}.mp4"',
    }

    # A malformed Range header is ignored and the full file is served.
    byte_range = _parse_byte_range(range, video_size) if range else None
    if byte_range is not None:
        start, end = byte_range
        if start > end:
            raise HTTPException(
                status_code=416,
                detail=f"range {range} is not satisfiable!",
                headers={'Content-Range': f'bytes */{video_size}'},
            )
        chunk_size = min(end - start + 1, ServerConfig.videos_chunk_size)
        status = 206
        headers['Content-Range'] = f'bytes {start}-{end}/{video_size}'
        headers['Accept-Ranges'] = 'bytes'
        headers['Content-Disposition'] = 'inline'

    def file_reader():
        # Stop after the last requested byte so the body always matches the
        # advertised Content-Range instead of running on to end of file.
        remaining = end - start + 1
        with open(video_path, 'rb') as video:
            video.seek(start)
            while remaining > 0:
                data = video.read(min(chunk_size, remaining))
                if not data:
                    break
                remaining -= len(data)
                yield data

    return StreamingResponse(file_reader(), status_code=status, headers=headers, media_type='video/mp4')
2024-04-20 13:52:25 +08:00
2024-04-21 00:02:47 +08:00
@app.post('/extract')
async def extract_file(id: str = Form(...)):
    """Run the video preprocessor on an uploaded video and cache its output.

    The results are written to ``<id>.h5`` under the group ``<id>`` as the
    datasets ``features``, ``change_points``, ``n_frame_per_seg``,
    ``n_frames`` and ``picks``. When the group already exists the
    preprocessing step is skipped.

    Args:
        id: Video id returned by the upload endpoint.

    Returns:
        A dict with ``status`` 200 and the video ``id``.

    Raises:
        HTTPException: 404 when no uploaded video exists for ``id``.
    """
    video_source_path = os.path.join(ServerConfig.videos_upload_path, id + '.mp4')
    # Check first so that an unknown id neither creates an empty .h5 file nor
    # fails somewhere inside the preprocessor.
    if not os.path.isfile(video_source_path):
        raise HTTPException(status_code=404, detail=f"video {id} is not found!")

    data_target_path = os.path.join(ServerConfig.h5s_root_path, id + '.h5')
    with h5py.File(data_target_path, 'a') as data:
        if id not in data:
            n_frames, seq, cps, nfps, picks = video_preprocessor.run(video_source_path)
            datasets = (
                ('features', seq),
                ('change_points', cps),
                ('n_frame_per_seg', nfps),
                ('n_frames', n_frames),
                ('picks', picks),
            )
            try:
                for key, value in datasets:
                    data.create_dataset(f'{id}/{key}', data=value)
            except Exception:
                # Do not leave a half-written group behind: the existence
                # check above would treat it as complete on the next call.
                if id in data:
                    del data[id]
                raise
    return {'status': 200, 'id': id}
@app.post('/analyse')
async def analyse_file(id: str = Form(...)):
    """Predict the summary of an extracted video and cache it as ``vsumm``.

    Reads the datasets stored under the group ``<id>`` of ``<id>.h5``, runs
    the DSNet model on the features and stores the result as the dataset
    ``<id>/vsumm``. Nothing is recomputed when ``vsumm`` already exists.

    Args:
        id: Video id whose features were stored by the extract endpoint.

    Returns:
        A dict with ``status`` 200 and the video ``id``.

    Raises:
        HTTPException: 404 when no extracted data exists for ``id``.
    """
    data_path = os.path.join(ServerConfig.h5s_root_path, id + '.h5')
    if not os.path.isfile(data_path):
        raise HTTPException(status_code=404, detail=f"data of video {id} is not found!")

    pred_summ = None
    with h5py.File(data_path, 'r') as data:
        if id not in data:
            raise HTTPException(status_code=404, detail=f"data of video {id} is not found!")
        group = data[id]
        if 'vsumm' not in group:
            seq = group['features'][...].astype(np.float32)
            cps = group['change_points'][...].astype(np.int32)
            nfps = group['n_frame_per_seg'][...].astype(np.int32)
            n_frames = group['n_frames'][...].astype(np.int32)
            picks = group['picks'][...].astype(np.int32)
            with torch.no_grad():
                seq_torch = torch.from_numpy(seq).unsqueeze(0).to(DSNetConfig.device)
                pred_cls, pred_bboxes = dsnet_af.predict(seq_torch)
                # Keep the predicted boxes inside [0, len(seq)] and make them
                # integral before the NMS / summary steps.
                pred_bboxes = np.clip(pred_bboxes, 0, len(seq)).round().astype(np.int32)
                pred_cls, pred_bboxes = nms(pred_cls, pred_bboxes, DSNetConfig.nms_thresh)
                pred_summ = bbox2summary(len(seq), pred_cls, pred_bboxes, cps, n_frames, nfps, picks)

    # The file was opened read-only above; reopen it for writing only when a
    # new summary was actually computed.
    if pred_summ is not None:
        with h5py.File(data_path, 'a') as data:
            data.create_dataset(f'{id}/vsumm', data=pred_summ)
    return {'status': 200, 'id': id}
@app.post('/generate')
async def generate_file(id: str = Form(...)):
    """Render the summary video for ``id`` into the play directory.

    Frames of the uploaded video whose ``vsumm`` entry is true are written to
    an intermediate ``<id>.mpeg4.mp4`` file, which ffmpeg then re-encodes to
    ``<id>.mp4`` with the h264 codec.

    Args:
        id: Video id that has already been analysed.

    Returns:
        A dict with ``status`` 200 and the video ``id``.

    Raises:
        HTTPException: 404 when the summary data or the uploaded video for
            ``id`` is missing.
    """
    origin_video_path = os.path.join(ServerConfig.videos_upload_path, id + '.mp4')
    target_video_path = os.path.join(ServerConfig.videos_play_path, id + '.mpeg4.mp4')
    final_video_path = os.path.join(ServerConfig.videos_play_path, id + '.mp4')
    data_path = os.path.join(ServerConfig.h5s_root_path, id + '.h5')

    if not os.path.isfile(data_path):
        raise HTTPException(status_code=404, detail=f"summary of video {id} is not found!")
    with h5py.File(data_path, 'r') as data:
        if id not in data or 'vsumm' not in data[id]:
            raise HTTPException(status_code=404, detail=f"summary of video {id} is not found!")
        vsumm = data[id]['vsumm'][...].astype(np.bool_)

    cap = cv.VideoCapture(origin_video_path)
    out = None
    try:
        if not cap.isOpened():
            raise HTTPException(status_code=404, detail=f"video {id} is not found!")
        width = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv.CAP_PROP_FPS)
        fourcc = cv.VideoWriter.fourcc(*'MP4V')
        out = cv.VideoWriter(target_video_path, fourcc, fps, (width, height))

        # Stop at the end of the summary mask: the decoder may yield more
        # frames than the mask has entries, and indexing past it would raise.
        n_marked = len(vsumm)
        frame_idx = 0
        while frame_idx < n_marked:
            ret, frame = cap.read()
            if not ret:
                break
            if vsumm[frame_idx]:
                out.write(frame)
            frame_idx += 1
    finally:
        # Release the writer and the capture even when an error occurs.
        if out is not None:
            out.release()
        cap.release()

    ffmpeg.input(target_video_path, format='mp4', vcodec='mpeg4') \
        .output(final_video_path, format='mp4', vcodec='h264') \
        .run(overwrite_output=True)
    return {'status': 200, 'id': id}
2024-04-20 13:52:25 +08:00
if __name__ == "__main__":
    # Create the storage directories (uploads, rendered videos, HDF5 data)
    # if they are missing, then start the server.
    for _dir_attr in ('videos_upload_path', 'videos_play_path', 'h5s_root_path'):
        os.makedirs(getattr(ServerConfig, _dir_attr), exist_ok=True)
    uvicorn.run(app=app)
2024-04-20 13:52:25 +08:00