58 lines
1.4 KiB
JavaScript
58 lines
1.4 KiB
JavaScript
import { Registry } from 'mixly';
|
|
import * as tf from '@tensorflow/tfjs';
|
|
|
|
|
|
// In-memory file store consulted by customFetch instead of the network,
// keyed by the request path. tensorflow.setModelsValue (re)registers entries.
// NOTE(review): values are presumably raw bytes (ArrayBuffer/TypedArray),
// since customFetch UTF-8-decodes them for json() and hands them out as-is
// from arrayBuffer() — confirm against callers of setModelsValue.
const modelsValueRegistry = new Registry();
|
|
/**
 * Minimal stand-in for `fetch` that serves files from `modelsValueRegistry`
 * instead of the network. It is handed to TensorFlow.js as `fetchFunc`, so it
 * only implements the parts of the Response interface the loaders read:
 * `ok`, `status`, `statusText`, `json()` and `arrayBuffer()`.
 *
 * @param {string} path - Registry key of the requested file.
 * @returns {{ok: boolean, status: number, statusText: string, buffer: *,
 *     json: function(): *, arrayBuffer: function(): *}} Response-like object.
 *     When nothing is registered under `path`, `ok` is false, `status` is 404
 *     and `buffer` is null.
 */
const customFetch = function (path) {
    const found = modelsValueRegistry.hasKey(path);
    const result = {
        ok: found,
        // TensorFlow.js' HTTP loader includes `status` in the error it throws
        // when `ok` is false; report a real code instead of `undefined`.
        status: found ? 200 : 404,
        statusText: found ? 'OK' : 'Not Found',
        buffer: found ? modelsValueRegistry.getItem(path) : null,
        json() {
            // The stored bytes are decoded as UTF-8 text, then parsed.
            // `result` is referenced directly so the method also works when
            // detached from the object (unlike `this`).
            const jsonText = new TextDecoder('utf-8').decode(result.buffer);
            return JSON.parse(jsonText);
        },
        arrayBuffer() {
            // The stored value is returned as-is, without copying.
            return result.buffer;
        },
    };
    return result;
};
|
|
|
|
// Namespace object for this module's helpers; it is published globally at
// the end of the file.
const tensorflow = {
    // Plain-object copy of every value passed to setModelsValue, keyed by path.
    modelsValue: {},
};
|
|
|
|
/**
 * Loads a TensorFlow.js graph model whose files were stored in memory,
 * resolving every file request through `customFetch` (the in-memory registry)
 * rather than the network.
 *
 * @param {string} path - Registry key / URL of the model's JSON file.
 * @returns {Promise<*>} Resolves to whatever `tf.loadGraphModel` yields.
 */
tensorflow.loadGraphModel = async function (path) {
    const loadOptions = {
        fromTFHub: false,
        // Only the requested URL is needed to look the file up.
        fetchFunc: (url) => customFetch(url),
    };
    return tf.loadGraphModel(path, loadOptions);
};
|
|
|
|
/**
 * Loads a TensorFlow.js layers model whose files were stored in memory,
 * resolving every file request through `customFetch` (the in-memory registry)
 * rather than the network.
 *
 * @param {string} path - Registry key / URL of the model's JSON file.
 * @returns {Promise<*>} Resolves to whatever `tf.loadLayersModel` yields.
 */
tensorflow.loadLayersModel = async function (path) {
    const loadOptions = {
        fromTFHub: false,
        // Only the requested URL is needed to look the file up.
        fetchFunc: (url) => customFetch(url),
    };
    return tf.loadLayersModel(path, loadOptions);
};
|
|
|
|
/**
 * Stores (or replaces) the file contents served for `path` by the in-memory
 * fetch, and mirrors the value into `tensorflow.modelsValue`.
 *
 * @param {string} path - Key under which the file will be requested.
 * @param {*} value - File contents to serve for that key.
 */
tensorflow.setModelsValue = function (path, value) {
    const registry = modelsValueRegistry;
    // Remove any previous entry first so the key is registered fresh.
    const isReplacing = registry.hasKey(path);
    if (isReplacing) {
        registry.unregister(path);
    }
    registry.register(path, value);
    tensorflow.modelsValue[path] = value;
};
|
|
|
|
// Publish the helper namespace on the global object so scripts outside this
// module can reach it as `window.tensorflow`. (The original line ended in a
// stray `|`, a syntax error left over from a paste.)
window.tensorflow = tensorflow;