赞
踩
自用,刚开始接触可能顺序会比较乱。
import jax.numpy as jnp from jax import jit @jit def _extractValues(matrix, positions): values = matrix[positions[:, 0], positions[:, 1]] return values matrix = jnp.array([[5,2,4,2,4,1,3,9,4], [3,4,0,2,8,8,0,9,5], [6,4,0,7,3,0,0,2,7], [2,7,1,6,9,1,6,2,4]]) positions = jnp.array([[0, 0],[1, 0],[2, 0],[2, 1],[3, 0],[3, 1],[3, 2],[0, 3],[0, 4],[0, 5],[1, 3],[1, 4],[1, 5],[2, 4],[2, 5],[2, 6],[3, 5],[3, 6]]) extracted_values = extract_values(matrix, positions) print("Extracted Values: ",extracted_values )
输出:
Extracted Values: [ 4 2 4 3 1 4 0 0 2 -1 0 4 2 1 -2 -2 -1 2]
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。