赞
踩
引入依赖
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>1.10.0</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.13.1</version>
</dependency>
定义一个类,用来保存转换后的 ByteString 和维度信息。因为 ByteString 本身不包含数组的维度信息,所以维度信息需要单独保存。
import com.google.protobuf.ByteString; public class ByteStringWithDims { //byteString private ByteString byteString; //原始数组的维度信息 private long[] dims; public ByteStringWithDims(ByteString byteString, long[] dims) { this.byteString = byteString; this.dims = dims; } public ByteString getByteString() { return byteString; } public long[] getDims() { return dims; } }
如下的方法即可将输入的任意维度的基本类型的数组和 String 数组转成 ByteString:
/** * 将输入的数组转换成 ByteString 并保存 dims 信息 * @param array * @return */ public static ByteStringWithDims toByteString(Object array) { try { if(null == array) { return null; } long[] dims = shape(array); if(null == dims) { return null; } //多维数组降成一维 if(dims.length > 1) { array = toSingleArray(array, dims); } //获取数据类型 Class<?> type = Array.get(array, 0).getClass(); //将 string 转成 byte 数组 if(type.equals(String.class)) { array = toByteArray((String[])array); } //利用 tensor 进行转换,提高速度 Tensor<?> tensor = Tensor.create(array, type); ByteBuffer bb = ByteBuffer.allocate(tensor.numBytes()); tensor.writeTo(bb); return new ByteStringWithDims(ByteString.copyFrom(bb.array()), dims); } catch (Exception e) { throw e; } }
如下的方法可以将 ByteString 转成指定形状,指定类型的数组:
/**
 * Converts a {@code ByteString} into an array of the given element type and dimensions.
 *
 * @param byteString the serialized data
 * @param type the element type; must be a key of {@code BUFF_CLAZZ_MAP}
 * @param dims the dimension sizes of the resulting array
 * @return the decoded array, reshaped by {@code toMultiArray}
 * @throws RuntimeException if anything fails while decoding, including an
 *         unsupported {@code type}; the original failure is kept as the cause
 */
public static Object byteString2Array(ByteString byteString, Class<?> type, long[] dims) {
    try {
        // Resolve the buffer class before creating the tensor so an unsupported
        // type fails with a clear cause and without allocating anything.
        Class<? extends Buffer> bufferClass = BUFF_CLAZZ_MAP.get(type);
        if (bufferClass == null) {
            throw new IllegalArgumentException("unsupported element type: " + type);
        }

        ByteBuffer byteBuffer = byteString.asReadOnlyByteBuffer();
        // Close the tensor once its contents have been copied into the buffer.
        try (Tensor<?> tensor = Tensor.create(type, dims, byteBuffer)) {
            // String data is read as raw bytes, everything else element by element.
            int capacity = String.class.equals(type) ? tensor.numBytes() : tensor.numElements();

            // Call the static allocate(int) of the matching buffer class reflectively.
            Method allocateMethod = bufferClass.getMethod("allocate", int.class);
            Buffer buffer = (Buffer) allocateMethod.invoke(null, capacity);

            // Pick the writeTo overload that matches the buffer class.
            Method writeToMethod = Tensor.class.getMethod("writeTo", bufferClass);
            writeToMethod.invoke(tensor, buffer);

            return toMultiArray(buffer.array(), dims);
        }
    } catch (Exception e) {
        throw new RuntimeException("read result form byte string failed", e);
    }
}
完整的代码如下:
import java.lang.reflect.Array; import java.lang.reflect.Method; import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Random; import org.tensorflow.Tensor; import com.google.protobuf.ByteString; public class MultiArrayUtils { /** * 将输入的 byteString 转成 type 类型,维度为 dims 定义的数组 * @param byteString * @param type * @param dims * @return */ public static Object byteString2Array(ByteString byteString, Class<?> type, long[] dims) { try { ByteBuffer byteBuffer = byteString.asReadOnlyByteBuffer(); Tensor<?> tensor = Tensor.create(type, dims, byteBuffer); Method allocateMethod = BUFF_CLAZZ_MAP.get(type).getMethod("allocate", int.class); Buffer buffer = (Buffer) allocateMethod.invoke(null, (String.class.equals(type)) ? tensor.numBytes() :tensor.numElements()); Method m = Tensor.class.getMethod("writeTo", BUFF_CLAZZ_MAP.get(type)); m.invoke(tensor, buffer); return toMultiArray(buffer.array(), dims); } catch (Exception e) { throw new RuntimeException("read result form byte string failed", e); } } /** * 将输入的数组转换成 ByteString 并保存 dims 信息 * @param array * @return */ public static ByteStringWithDims toByteString(Object array) { try { if(null == array) { return null; } long[] dims = shape(array); if(null == dims) { return null; } //多维数组降成一维 if(dims.length > 1) { array = toSingleArray(array, dims); } //获取数据类型 Class<?> type = Array.get(array, 0).getClass(); //将 string 转成 byte 数组 if(type.equals(String.class)) { array = toByteArray((String[])array); } //利用 tensor 进行转换,提高速度 Tensor<?> tensor = Tensor.create(array, type); ByteBuffer bb = ByteBuffer.allocate(tensor.numBytes()); tensor.writeTo(bb); return new ByteStringWithDims(ByteString.copyFrom(bb.array()), dims); } catch (Exception e) { throw e; } } /** * String 数组转成二维的 byte 数组 * @param array 
* @return */ private static Object toByteArray(String[] array) { byte[][] ret = new byte[array.length][]; for(int i = 0; i < array.length; i++) { if(null != array[i]) { ret[i] = array[i].getBytes(); } } return ret; } /** * 多维数组的字符串 * @param array * @return */ public static String toString(Object array) { if(!array.getClass().isArray()) { return array.toString(); } long length = getArrayLength(array); StringBuffer buf = new StringBuffer(); buf.append('['); for(int i = 0; i < length; i++) { buf.append(toString(Array.get(array, i))); buf.append(','); } buf.deleteCharAt(buf.length() - 1).append(']'); return buf.toString(); } /** * 产生一个随机数组 * @param dims * @return */ public static Object randomFloatArray(long[] dims) { Random r = new Random(); Object ret = Array.newInstance(float.class, longArray2intArray(dims)); MultiPoint points = new MultiPoint(dims); while(points.hasNext()) { long[] index = points.next(); Object tmp = ret; for(int i = 0; i < index.length - 1; i++) { tmp = Array.get(tmp, (int) index[i]); } Array.set(tmp, (int) index[index.length - 1], (float)r.nextInt(10)); } return ret; } /** * 产生一个随机数组 * @param dims * @return */ public static Object randomIntegerArray(long[] dims) { Random r = new Random(); Object ret = Array.newInstance(int.class, longArray2intArray(dims)); MultiPoint points = new MultiPoint(dims); while(points.hasNext()) { long[] index = points.next(); Object tmp = ret; for(int i = 0; i < index.length - 1; i++) { tmp = Array.get(tmp, (int) index[i]); } Array.set(tmp, (int) index[index.length - 1], r.nextInt(10)); } return ret; } /** * 将一维数组转成多维数组 * @param result * @param dims * @return */ private static Object toMultiArray(Object result, long[] dims) { if(null == result || dims.length == 0 || dims.length == 1 || !result.getClass().isArray()) { return result; } Object ret = Array.newInstance(result.getClass().getComponentType(), longArray2intArray(dims)); MultiPoint point = new MultiPoint(dims); long length = getArrayLength(result); if(length != 
point.getMaxEleNum()) { throw new IndexOutOfBoundsException("result size: " + length + ", not match dims: " + Arrays.toString(dims)); } int resultIndex = 0; while(point.hasNext()) { long[] index = point.next(); Object tmp = ret; for(int i = 0; i < index.length - 1; i++) { tmp = Array.get(tmp, (int) index[i]); } Array.set(tmp, (int) index[index.length - 1], Array.get(result, resultIndex)); resultIndex++; } return ret; } /** * 将高维数组降成一维数组 * @param array * @param dims * @return */ private static Object toSingleArray(Object array, long[] dims) { MultiPoint tp = new MultiPoint(dims); Object sample = array; for(int i = 0; i < dims.length - 1; i++) { sample = Array.get(sample, 0); } Object newArray = Array.newInstance(sample.getClass().getComponentType(), (int) tp.getMaxEleNum()); int index = 0; while(tp.hasNext()) { long[] point = tp.next(); Object tmp = array; for(long p : point) { tmp = Array.get(tmp, (int)p); } Array.set(newArray, index, tmp); index++; } return newArray; } /** * 获取输入的多维数据的维度 * @param array * @return */ private static long[] shape(Object array) { long length = getArrayLength(array); if(length <= 0) { return null; } List<Long> tmp = new ArrayList<>(); while(length > 0) { tmp.add((long) length); array = Array.get(array, 0); length = getArrayLength(array); } long[] ret = new long[tmp.size()]; for(int i = 0; i < ret.length; i++) { ret[i] = tmp.get(i); } return ret; } /** * 获取数组的长度 * @param obj * @return */ private static long getArrayLength(Object obj) { if(obj == null || !obj.getClass().isArray()) { return -1; } return Array.getLength(obj); } /** * long[] 转 int[] * @param dims * @return */ private static int[] longArray2intArray(long[] dims) { int[] ret = new int[dims.length]; for(int i = 0; i < dims.length; i++) { ret[i] = (int) dims[i]; } return ret; } private final static Map<Class<?>, Class<? 
extends Buffer>> BUFF_CLAZZ_MAP = new HashMap<>(); static { BUFF_CLAZZ_MAP.put(Float.class, FloatBuffer.class); BUFF_CLAZZ_MAP.put(Double.class, DoubleBuffer.class); BUFF_CLAZZ_MAP.put(Integer.class, IntBuffer.class); BUFF_CLAZZ_MAP.put(Long.class, LongBuffer.class); BUFF_CLAZZ_MAP.put(String.class, ByteBuffer.class); } }
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。