def get_cluster_model(ckpt_path: Path | str):
    """Load per-speaker KMeans models from a checkpoint file.

    The checkpoint is read with ``torch.load`` onto the CPU. It is expected to
    map each speaker key to a dict holding the fitted-state entries
    ``n_features_in_``, ``_n_threads`` and ``cluster_centers_``. A ``KMeans``
    instance is rebuilt for each speaker by injecting those attributes
    directly.

    Args:
        ckpt_path: Path to the checkpoint file.

    Returns:
        dict mapping each speaker key from the checkpoint to its ``KMeans``.

    Raises:
        KeyError: if a speaker entry lacks one of the required keys.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints that come from a trusted source.
    with Path(ckpt_path).open("rb") as f:
        checkpoint = torch.load(f, map_location="cpu")
    kmeans_dict = {}
    for spk, ckpt in checkpoint.items():
        centers = ckpt["cluster_centers_"]
        # KMeans' first argument is n_clusters, i.e. the number of centroids
        # (rows of cluster_centers_), not the feature dimension n_features_in_.
        km = KMeans(n_clusters=len(centers))
        # Restore fitted state directly so no refit is needed before predict().
        km.__dict__["n_features_in_"] = ckpt["n_features_in_"]
        km.__dict__["_n_threads"] = ckpt["_n_threads"]
        km.__dict__["cluster_centers_"] = centers
        kmeans_dict[spk] = km
    return kmeans_dict
def check_speaker(model: Any, speaker: Any):
    """Ensure ``speaker`` is a key of ``model``.

    Args:
        model: Mapping from speaker key to a per-speaker model.
        speaker: Speaker key to look up.

    Raises:
        ValueError: if ``speaker`` is not present in ``model``; the message
            lists the available speaker keys.
    """
    if speaker in model:
        return
    available = list(model.keys())
    raise ValueError(f"Speaker {speaker} not in {available}")
def get_cluster_result(model: Any, x: Any, speaker: Any):
    """Return the cluster class for each row of ``x`` using the speaker's model.

    x: np.array [t, 256]
    return cluster class result

    Args:
        model: Mapping from speaker key to a model exposing ``predict``.
        x: Features to classify.
        speaker: Key selecting which model in ``model`` to use.

    Returns:
        Whatever ``model[speaker].predict(x)`` returns.

    Raises:
        ValueError: if ``speaker`` is not in ``model`` (via ``check_speaker``).
    """
    check_speaker(model, speaker)
    speaker_model = model[speaker]
    return speaker_model.predict(x)