const tf = require('@tensorflow/tfjs');

/**
 * Loads a TensorFlow.js Layers model from a `model.json` URL.
 *
 * The model JSON is fetched and patched in memory (input-layer shape keys and
 * weight file paths), the binary weight shards are downloaded, and the model
 * is then built through `tf.io.fromMemory`.
 */
class ModelLoader {
  /**
   * @param {object} [logger] - Object exposing `debug` and `error` methods.
   *   Falls back to the global `console` when omitted.
   */
  constructor(logger) {
    this.logger = logger || console;
    this.model = null;
  }

  /**
   * Fetches, patches and loads a Layers model.
   *
   * @param {string} modelUrl - URL of the `model.json` file.
   * @param {Array<number|null>} [inputShape=[null, 24, 166]] - Batch input
   *   shape applied to the first InputLayer when it declares no shape itself.
   * @returns {Promise<object>} The loaded model (also stored on `this.model`).
   * @throws {Error} When a request fails, the manifest carries no weight
   *   specs or weight files, or TensorFlow.js rejects the artifacts. The
   *   error is logged before being rethrown.
   */
  async loadModel(modelUrl, inputShape = [null, 24, 166]) {
    try {
      this.logger.debug(`Fetching model JSON from: ${modelUrl}`);
      const response = await fetch(modelUrl);
      if (!response.ok) {
        throw new Error(
          `Failed to fetch model JSON from ${modelUrl}: HTTP ${response.status} ${response.statusText}`
        );
      }
      const modelJSON = await response.json();

      // Validate the manifest before anything iterates over it, so a broken
      // model.json yields this message rather than an opaque TypeError.
      const manifest = modelJSON.weightsManifest;
      if (!Array.isArray(manifest) || manifest.length === 0) {
        throw new Error("Model JSON is missing weight specifications.");
      }

      // Fix input shape
      this.configureInputLayer(modelJSON, inputShape);

      // Resolve relative weight paths against the directory of model.json
      const baseUrl = this.getBaseUrl(modelUrl);
      this.fixWeightPaths(modelJSON, baseUrl);

      // Collect specs and shard URLs from every manifest group, in manifest
      // order: the weight bytes are laid out in that same order.
      const weightSpecs = [];
      const weightUrls = [];
      for (const group of manifest) {
        if (group && Array.isArray(group.weights)) {
          weightSpecs.push(...group.weights);
        }
        if (group && Array.isArray(group.paths)) {
          weightUrls.push(...group.paths);
        }
      }
      if (weightSpecs.length === 0) {
        throw new Error("Model JSON is missing weight specifications.");
      }
      if (weightUrls.length === 0) {
        throw new Error("Model JSON does not list any weight files.");
      }
      this.logger.debug(
        `Found ${weightSpecs.length} weight specs in ${weightUrls.length} weight file(s)`
      );

      // Load the binary weight data
      const weightData = await this.fetchWeightData(weightUrls);

      // Create ModelArtifacts object and load from memory
      const artifacts = {
        modelTopology: modelJSON.modelTopology,
        weightSpecs,
        weightData,
      };
      this.model = await tf.loadLayersModel(tf.io.fromMemory(artifacts));

      this.logger.debug('Model loaded successfully');
      return this.model;
    } catch (error) {
      this.logger.error(`Failed to load model: ${error.message}`);
      throw error;
    }
  }

  /**
   * Downloads every weight shard and joins them into one buffer.
   *
   * @param {string[]} weightUrls - Absolute shard URLs, in manifest order.
   * @returns {Promise<ArrayBuffer>} The concatenated weight bytes.
   * @throws {Error} When any shard request does not return a 2xx status.
   */
  async fetchWeightData(weightUrls) {
    const buffers = await Promise.all(
      weightUrls.map(async (url) => {
        const response = await fetch(url);
        if (!response.ok) {
          throw new Error(
            `Failed to fetch weights from ${url}: HTTP ${response.status} ${response.statusText}`
          );
        }
        return response.arrayBuffer();
      })
    );

    if (buffers.length === 1) {
      return buffers[0];
    }

    const totalBytes = buffers.reduce((sum, buffer) => sum + buffer.byteLength, 0);
    const merged = new Uint8Array(totalBytes);
    let offset = 0;
    for (const buffer of buffers) {
      merged.set(new Uint8Array(buffer), offset);
      offset += buffer.byteLength;
    }
    return merged.buffer;
  }

  /**
   * Makes sure the first layer, when it is an InputLayer, carries a
   * `batchInputShape`. Mutates `modelJSON` in place.
   *
   * A `batch_shape` key is renamed to `batchInputShape`; when no shape key is
   * present at all, `inputShape` is used. Anything else is left untouched.
   *
   * @param {object} modelJSON - Parsed contents of `model.json`.
   * @param {Array<number|null>} inputShape - Fallback batch input shape.
   */
  configureInputLayer(modelJSON, inputShape) {
    const topology = modelJSON.modelTopology;
    if (!topology) {
      this.logger.debug('Model JSON has no modelTopology; input layer left untouched');
      return;
    }

    // The topology is either wrapped in `model_config` or is the model
    // config itself, and `config` is either an object holding `layers` or a
    // bare array of layers.
    const modelConfig = topology.model_config != null ? topology.model_config : topology;
    const config = modelConfig.config;
    const layers = Array.isArray(config) ? config : config && config.layers;
    if (!Array.isArray(layers) || layers.length === 0) {
      return;
    }

    const firstLayer = layers[0];
    if (firstLayer.class_name !== 'InputLayer') {
      return;
    }

    if (firstLayer.config.batch_shape) {
      firstLayer.config.batchInputShape = firstLayer.config.batch_shape;
      delete firstLayer.config.batch_shape;
      this.logger.debug('Converted batch_shape to batchInputShape:', firstLayer);
    } else if (!firstLayer.config.batchInputShape && !firstLayer.config.inputShape) {
      firstLayer.config.batchInputShape = inputShape;
      this.logger.debug('Configured input layer:', firstLayer);
    } else {
      this.logger.debug('Input shape already set:', firstLayer.config);
    }
  }

  /**
   * Returns the directory part of a URL, including the trailing slash.
   *
   * @param {string} url - URL of a file.
   * @returns {string} Everything up to and including the last `/` of the
   *   path, or an empty string when the URL contains no `/`.
   */
  getBaseUrl(url) {
    // Ignore the query string and fragment: a `/` inside them is not a
    // path separator.
    const suffixStart = url.search(/[?#]/);
    const path = suffixStart === -1 ? url : url.slice(0, suffixStart);
    return path.substring(0, path.lastIndexOf('/') + 1);
  }

  /**
   * Rewrites every weight path in the manifest to an absolute URL. Mutates
   * `modelJSON` in place.
   *
   * Leading slashes are stripped, then paths that are not already
   * `http://` or `https://` URLs are prefixed with `baseUrl`.
   *
   * @param {object} modelJSON - Parsed contents of `model.json`.
   * @param {string} baseUrl - Directory URL ending in `/`.
   */
  fixWeightPaths(modelJSON, baseUrl) {
    const absoluteUrl = /^https?:\/\//i;
    for (const group of modelJSON.weightsManifest || []) {
      if (!group || !Array.isArray(group.paths)) {
        continue;
      }
      group.paths = group.paths.map((path) => {
        const trimmed = path.replace(/^\/+/, '');
        return absoluteUrl.test(trimmed) ? trimmed : `${baseUrl}${trimmed}`;
      });
    }
  }
}

/**
 * Manual smoke test: loads the model served on localhost and prints a sample
 * of the kernel of the layer named `dense_8`.
 */
async function runDemo() {
  const modelLoader = new ModelLoader();
  try {
    const localURL = "http://localhost:1880/generalFunctions/datasets/lstmData/tfjs_model/model.json";
    const model = await modelLoader.loadModel(localURL);
    console.log('Model loaded successfully');

    const denseLayer = model.getLayer('dense_8');
    const weights = denseLayer.getWeights();
    const weightArray = await weights[0].array();
    console.log('Dense layer kernel (sample):', weightArray.slice(0, 5));
  } catch (error) {
    console.error('Failed to load model:', error);
  }
}

// Only run the demo when this file is executed directly, so that requiring
// the module does not trigger network requests.
if (require.main === module) {
  runDemo();
}

module.exports = ModelLoader;